# NB02: All Constraints Training v6.0

**Constraint-Based Architectural NCA - Step B v6.0**

**Version:** 6.0 (Clean)
**Date:** December 2025
**Purpose:** Clean implementation with only working components

---

## What This Version Does

v6.0 is a **clean restart** removing all experimental porosity losses that don't work.

### Removed (Useless)
- CorridorPorosityLoss
- LocalPorosityLoss
- InteriorVoidLoss
- CorridorBlobPenalty
- MultiScaleLocalPorosityLoss
- LowZFarFromBuildingSuppressionLoss

### Kept (Working)
- LocalLegalityLoss
- CorridorCoverageLoss (A: fill corridor)
- CorridorSpillLoss (B: stay in corridor)
- GroundOpennessLoss
- ThicknessLoss
- SparsityLoss
- FacadeContactLoss
- AccessConnectivityLoss
- LoadPathLoss
- CantileverLoss
- DensityPenalty
- TotalVariation3D

### Config Reset
- corridor_width: 4 (was 1)
- vertical_envelope: 3 (was 1)
- Balanced weights

---

## 1. Setup

In [None]:
# Mount Google Drive so files stored under MyDrive are reachable at /content/drive.
# NOTE: `google.colab` only exists inside a Colab runtime; this cell fails elsewhere.
from google.colab import drive
drive.mount('/content/drive')

# Root folder of the project on the mounted drive; the config-loading cell
# below builds its file path from this value.
PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'
print(f'Project root: {PROJECT_ROOT}')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, List
import json
from datetime import datetime
from tqdm.notebook import tqdm
import os
from itertools import combinations

# Pick the compute device once; every later cell reads the global `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
# `device.type` is 'cuda' exactly when CUDA was reported available above.
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices
    when CUDA is available) and configures cuDNN for reproducible runs.

    Args:
        seed: Seed value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm per input shape and
    # may pick a different one between runs. Pin it off explicitly so
    # reproducibility does not depend on whatever an earlier cell left behind.
    torch.backends.cudnn.benchmark = False

set_seed(42)

In [None]:
# Load the base configuration from the project folder. The JSON file supplies
# every key that is not overridden below (the rest of the notebook reads e.g.
# 'grid_size', 'n_channels', 'lr_initial', 'batch_size' from CONFIG).
with open(f'{PROJECT_ROOT}/config_step_b.json', 'r') as f:
    CONFIG = json.load(f)

# v6.0 CLEAN configuration: these values replace any same-named keys that
# came from the JSON file.
CONFIG.update({
    'epochs': 500,
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',
    'log_every': 20,
    'viz_every': 100,
    'save_every': 100,
    # CLEAN CONFIG - sensible defaults
    'corridor_width': 4,
    'max_thickness': 2,
    'max_facade_contact': 0.15,
    'vertical_envelope': 3,
})

# Echo the key overrides so the active settings are visible in the cell output.
print('v6.0 CLEAN Configuration')
print(f"  corridor_width: {CONFIG['corridor_width']}")
print(f"  max_thickness: {CONFIG['max_thickness']}")
print(f"  vertical_envelope: {CONFIG['vertical_envelope']}")
print(f"  epochs: {CONFIG['epochs']}")

## 2. Core Components

In [None]:
class Perceive3D(nn.Module):
    """Fixed (non-learned) 3D perception: identity plus Sobel gradients.

    Every input channel is filtered independently with four 3x3x3 kernels
    (identity, Sobel-x, Sobel-y, Sobel-z). The output has ``4 * C`` channels
    laid out kernel-major: all identity responses first, then all Sobel-x
    responses, and so on.
    """

    def __init__(self, n_channels: int = 8):
        super().__init__()
        self.n_channels = n_channels
        bank = [self._create_identity_kernel()]
        bank.extend(self._create_sobel_kernel(axis) for axis in ('x', 'y', 'z'))
        # Stored as a buffer so it follows the module across devices but is
        # never trained. Shape: (4, 3, 3, 3).
        self.register_buffer('kernels', torch.stack(bank, dim=0))

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        """Build a separable 3x3x3 Sobel kernel differentiating along ``direction``.

        The derivative factor sits on the last kernel axis for 'x', the middle
        axis for 'y' and the first axis for 'z'; the other two axes smooth.
        """
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])
        if direction == 'x':
            factors = (smoothing, smoothing, derivative)
        elif direction == 'y':
            factors = (smoothing, derivative, smoothing)
        elif direction == 'z':
            factors = (derivative, smoothing, smoothing)
        # 16 is the sum of the 2D smoothing weights (1+2+1)^2.
        return torch.einsum('i,j,k->ijk', *factors) / 16.0

    def _create_identity_kernel(self) -> torch.Tensor:
        """Return a 3x3x3 kernel that passes the centre voxel through unchanged."""
        centre_only = torch.zeros(3, 3, 3)
        centre_only[1, 1, 1] = 1.0
        return centre_only

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply all four kernels depthwise to a 5D tensor.

        Args:
            x: Tensor of shape (B, C, D, H, W).

        Returns:
            Tensor of shape (B, 4 * C, D, H, W).
        """
        _, C, _, _, _ = x.shape
        # Replicate padding keeps border gradients at zero for constant inputs.
        padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        responses = [
            F.conv3d(padded, kernel.expand(C, 1, 3, 3, 3), padding=0, groups=C)
            for kernel in self.kernels
        ]
        return torch.cat(responses, dim=1)

In [None]:
class UrbanPavilionNCA(nn.Module):
    """Neural cellular automaton that grows a pavilion inside an urban scene.

    The state tensor is (B, n_channels, D, H, W). The first ``n_frozen``
    channels are copied through unchanged on every step; the remaining
    channels are updated by a per-voxel network fed with Perceive3D features.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        channels = config['n_channels']
        width = config['hidden_dim']

        self.perceive = Perceive3D(channels)
        # Perception yields four responses per channel, hence channels * 4 inputs.
        # 1x1x1 convolutions make this a per-voxel MLP.
        layers = [
            nn.Conv3d(channels * 4, width, 1),
            nn.ReLU(),
            nn.Conv3d(width, width, 1),
            nn.ReLU(),
            nn.Conv3d(width, config['n_grown'], 1),
        ]
        self.update_net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise all convolutions, then bias the first two outputs."""
        xavier_gain = self.config['xavier_gain']
        for layer in self.update_net:
            if not isinstance(layer, nn.Conv3d):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=xavier_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        # Output channels 0 and 1 start from the configured structure/surface biases.
        head = self.update_net[-1]
        with torch.no_grad():
            head.bias[0] = self.config['structure_bias']
            head.bias[1] = self.config['surface_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """Run ``steps`` consecutive update steps and return the final state."""
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        """Apply a single NCA update to ``state`` of shape (B, C, D, H, W)."""
        B, _, D, H, W = state.shape
        cfg = self.config
        delta = self.update_net(self.perceive(state))

        if self.training:
            # Stochastic update: each voxel fires with probability fire_rate.
            fires = torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']
            delta = delta * fires.float()

        n_frozen = cfg['n_frozen']
        frozen_part = state[:, :n_frozen]
        grown_part = (state[:, n_frozen:] + cfg['update_scale'] * delta).clamp(0.0, 1.0)

        # The first grown channel is zeroed wherever an existing building sits.
        ch_existing = cfg['ch_existing']
        free_space = 1.0 - state[:, ch_existing:ch_existing + 1]
        masked_first = grown_part[:, 0:1] * free_space
        return torch.cat([frozen_part, masked_first, grown_part[:, 1:]], dim=1)

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Inference helper: switch to eval mode and grow without gradients.

        Note that the module is left in eval mode afterwards.
        """
        self.eval()
        with torch.no_grad():
            return self.forward(seed, steps)

In [None]:
class UrbanSceneGenerator:
    """Procedurally generate random urban seed scenes for the NCA.

    A scene is a (1, n_channels, G, G, G) tensor indexed as
    ``state[batch, channel, z, y, x]``. It contains a ground slab, two
    buildings on either side of a gap, access-point markers and street-level
    anchor zones. All randomness comes from Python's ``random`` module, so the
    scenes depend on the order of the ``random.*`` calls below.
    """

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']   # edge length of the cubic grid
        self.C = config['n_channels']  # total number of state channels

    def generate(self, difficulty: str = 'easy', device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        """Create a single scene.

        Args:
            difficulty: 'easy', 'medium', or anything else (treated as the
                hardest setting by ``_get_difficulty_params``).
            device: Device on which the state tensor is allocated.

        Returns:
            ``(state, metadata)`` where ``state`` has shape (1, C, G, G, G) and
            ``metadata`` records the difficulty, building boxes, access points
            and anchor-zone statistics.
        """
        G = self.G
        cfg = self.config
        state = torch.zeros(1, self.C, G, G, G, device=device)
        # Mark the whole z=0 slab in the ground channel.
        state[:, cfg['ch_ground'], 0, :, :] = 1.0

        # Each helper writes its channel into `state` in place and returns a
        # description of what it placed.
        params = self._get_difficulty_params(difficulty)
        building_info = self._place_buildings(state, params)
        access_info = self._place_access_points(state, params, building_info)
        anchor_info = self._generate_anchor_zones(state, params, building_info, access_info)

        metadata = {
            'difficulty': difficulty,
            'buildings': building_info,
            'access_points': access_info,
            'anchor_zones': anchor_info,
        }
        return state, metadata

    def _get_difficulty_params(self, difficulty: str) -> dict:
        """Return the (partly randomised) scene parameters for a difficulty level.

        NOTE(review): 'n_buildings' and 'anchor_budget' are returned here but
        are not read by any method of this class; the local ``G`` is unused.
        """
        G = self.G
        if difficulty == 'easy':
            return {
                'n_buildings': 2, 'height_range': (12, 16), 'height_variance': False,
                'width_range': (8, 12), 'gap_width': random.randint(14, 18),
                'n_ground_access': 1, 'n_elevated_access': 1, 'anchor_budget': 0.10,
            }
        elif difficulty == 'medium':
            return {
                'n_buildings': 2, 'height_range': (10, 20), 'height_variance': True,
                'width_range': (6, 10), 'gap_width': random.randint(10, 14),
                'n_ground_access': random.randint(1, 2), 'n_elevated_access': random.randint(1, 2),
                'anchor_budget': 0.07,
            }
        else:
            # Any other value falls through to the hardest preset.
            return {
                'n_buildings': random.randint(2, 4), 'height_range': (8, 24), 'height_variance': True,
                'width_range': (5, 8), 'gap_width': random.randint(6, 10),
                'n_ground_access': random.randint(2, 3), 'n_elevated_access': random.randint(2, 3),
                'anchor_budget': 0.05,
            }

    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        """Write two box-shaped buildings into the existing-building channel.

        One building sits to the left of a gap centred on x = G // 2 and one to
        the right; both start at y = 0 and z = 0. Exactly two buildings are
        placed regardless of ``params['n_buildings']``.

        Returns:
            A list of two dicts with the box extents ('x', 'y', 'z' as
            (start, end) pairs), the x coordinate of the gap-facing wall and
            the side ('left' / 'right').
        """
        G = self.G
        ch = self.config['ch_existing']
        buildings = []
        gap_width = params['gap_width']
        gap_center = G // 2

        # Left building: its right wall is the left edge of the gap.
        w1 = random.randint(*params['width_range'])
        d1 = random.randint(G//2, G-2)
        h1 = random.randint(*params['height_range'])
        x1_end = gap_center - gap_width // 2
        x1_start = max(0, x1_end - w1)  # clipped at the grid boundary
        state[:, ch, :h1, :d1, x1_start:x1_end] = 1.0
        buildings.append({'x': (x1_start, x1_end), 'y': (0, d1), 'z': (0, h1),
                         'gap_facing_x': x1_end, 'side': 'left'})

        # Right building: its left wall is the right edge of the gap. Without
        # height variance it reuses the left building's height.
        w2 = random.randint(*params['width_range'])
        d2 = random.randint(G//2, G-2)
        h2 = h1 if not params['height_variance'] else random.randint(*params['height_range'])
        x2_start = gap_center + gap_width // 2
        x2_end = min(G, x2_start + w2)  # clipped at the grid boundary
        state[:, ch, :h2, :d2, x2_start:x2_end] = 1.0
        buildings.append({'x': (x2_start, x2_end), 'y': (0, d2), 'z': (0, h2),
                         'gap_facing_x': x2_start, 'side': 'right'})
        return buildings

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: list) -> list:
        """Write 2x2x2 access markers into the access channel.

        Ground access points are placed at z = 0 at a random position inside
        the gap; elevated ones are attached to the gap-facing side of a
        randomly chosen building.

        Returns:
            A list of dicts with the marker's lower corner ('x', 'y', 'z') and
            its 'type' ('ground' or 'elevated').
        """
        G = self.G
        ch = self.config['ch_access']
        access_points = []
        left_buildings = [b for b in buildings if b['side'] == 'left']
        right_buildings = [b for b in buildings if b['side'] == 'right']
        # The gap spans from the innermost left wall to the innermost right wall.
        gap_x_min = max(b['gap_facing_x'] for b in left_buildings) if left_buildings else 0
        gap_x_max = min(b['gap_facing_x'] for b in right_buildings) if right_buildings else G

        for i in range(params.get('n_ground_access', 1)):
            # NOTE(review): random.randint raises ValueError if the gap is
            # narrower than 4 cells (gap_x_max - 3 < gap_x_min + 1).
            x = random.randint(gap_x_min + 1, gap_x_max - 3)
            y = random.randint(0, G - 3)
            state[:, ch, 0:2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': 0, 'type': 'ground'})

        for i in range(params.get('n_elevated_access', 1)):
            building = random.choice(buildings)
            bz_max = building['z'][1]
            is_left = building['side'] == 'left'
            z = random.randint(3, max(4, bz_max - 2))
            y = random.randint(building['y'][0], min(building['y'][1] - 2, building['y'][0] + G//3))
            # Left building: marker starts at its gap-facing wall. Right
            # building: marker ends at its gap-facing wall.
            x = building['x'][1] if is_left else building['x'][0] - 2
            x = max(0, min(G - 2, x))  # keep the 2-wide marker inside the grid
            state[:, ch, z:z+2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': z, 'type': 'elevated'})
        return access_points

    def _generate_anchor_zones(self, state: torch.Tensor, params: dict,
                               buildings: list, access_points: list) -> dict:
        """Write street-level anchor zones into the anchor channel.

        Anchors are placed around every ground access point and in a
        one-cell-wide strip next to each building's gap-facing wall, for all
        ``street_levels`` lowest z layers, then removed wherever the z = 0
        layer of the existing-building channel is occupied.

        Returns:
            ``{'total_area': n}`` with ``n`` the number of anchor voxels.
        """
        G = self.G
        ch = self.config['ch_anchors']
        street_levels = self.config['street_levels']
        # Building footprint at z = 0; anchors may only sit outside it.
        existing = state[:, self.config['ch_existing'], 0, :, :]
        street_mask = 1.0 - existing
        anchors = torch.zeros(1, 1, G, G, G, device=state.device)

        for ap in access_points:
            if ap['type'] == 'ground':
                x, y = ap['x'], ap['y']
                for z in range(street_levels):
                    # 2-cell margin around the 2x2 access marker, clipped to the grid.
                    anchors[:, 0, z, max(0,y-2):min(G,y+4), max(0,x-2):min(G,x+4)] = 1.0

        for building in buildings:
            by_start, by_end = building['y']
            gap_x = building['gap_facing_x']
            is_left = building['side'] == 'left'
            # One-cell strip on the gap side of the gap-facing wall.
            x_start = gap_x if is_left else gap_x - 1
            x_end = gap_x + 1 if is_left else gap_x
            for z in range(street_levels):
                anchors[:, 0, z, by_start:min(by_start+4, by_end), max(0,x_start):min(G,x_end)] = 1.0

        for z in range(street_levels):
            anchors[:, 0, z, :, :] *= street_mask
        state[:, ch:ch+1, :, :, :] = anchors
        return {'total_area': (anchors > 0.5).sum().item()}

    def batch(self, difficulty: str, batch_size: int, device: str) -> torch.Tensor:
        """Generate ``batch_size`` independent scenes stacked along dim 0.

        The per-scene metadata is discarded.
        """
        scenes = [self.generate(difficulty, device)[0] for _ in range(batch_size)]
        return torch.cat(scenes, dim=0)

print('Core components defined')

## 3. Corridor Computation

In [None]:
def find_access_centroids(access_channel: torch.Tensor, threshold: float = 0.5) -> List[Tuple[int, int, int]]:
    """Find centroids of access point regions."""
    binary = (access_channel > threshold).float()
    if binary.sum() == 0:
        return []
    centroids = []
    positions = (binary > 0).nonzero(as_tuple=False)
    if len(positions) == 0:
        return []

    used = set()
    for idx in range(len(positions)):
        if idx in used:
            continue
        pos = positions[idx]
        cluster = [pos]
        used.add(idx)
        for idx2 in range(idx + 1, len(positions)):
            if idx2 in used:
                continue
            pos2 = positions[idx2]
            dist = (pos - pos2).abs().sum().item()
            if dist <= 4:
                cluster.append(pos2)
                used.add(idx2)
        cluster = torch.stack(cluster).float()
        centroid = cluster.mean(dim=0).long()
        centroids.append((centroid[0].item(), centroid[1].item(), centroid[2].item()))
    return centroids


def compute_distance_field_3d(start_points: List[Tuple[int, int, int]],
                               legal_mask: torch.Tensor, max_iters: int = 64) -> torch.Tensor:
    """Flood-fill a step-count distance field through the legal voxels.

    Each iteration lets the distance spread to all 26 neighbours of a voxel
    (3x3x3 window) at a cost of 1 per step. Only voxels where
    ``legal_mask > 0.5`` are ever updated; all others keep their initial
    value (``inf``, or 0 if they are a start point).

    Args:
        start_points: (z, y, x) voxels with distance 0; points outside the
            grid are ignored.
        legal_mask: 3D tensor marking traversable voxels.
        max_iters: Upper bound on propagation steps.

    Returns:
        Tensor shaped like ``legal_mask``; unreachable voxels hold ``inf``.
    """
    depth, height, width = legal_mask.shape
    field = torch.full((depth, height, width), float('inf'), device=legal_mask.device)

    for z, y, x in start_points:
        inside = 0 <= z < depth and 0 <= y < height and 0 <= x < width
        if inside:
            field[z, y, x] = 0

    passable = legal_mask > 0.5
    for _ in range(max_iters):
        # Min-pooling expressed as a negated max-pool over the negated field.
        neighbour_min = -F.max_pool3d(-field[None, None], 3, 1, 1)[0, 0]
        relaxed = torch.where(passable, torch.min(field, neighbour_min + 1), field)
        if torch.allclose(field, relaxed, atol=1e-5):
            break  # converged: no voxel changed in this sweep
        field = relaxed
    return field


def compute_corridor_target(seed_state: torch.Tensor, config: dict,
                            corridor_width: int = 4, vertical_envelope: int = 3) -> torch.Tensor:
    """Compute 3D corridor target from SEED state.

    For every scene in the batch, access-point centroids are located and every
    pair is connected through the legal (non-building) space: a voxel is on the
    corridor when its detour over the shortest path between the two centroids
    is at most ``2 * corridor_width`` steps. The union of these corridors is
    dilated by ``corridor_width`` and then extended vertically. If a scene has
    fewer than two centroids, or no pair is connected, the target falls back to
    the dilated access markers.

    Args:
        seed_state: Batch of seed scenes, shape (B, C, G, G, G).
        config: Must provide 'grid_size', 'ch_access' and 'ch_existing'.
        corridor_width: Dilation radius and half of the allowed path slack.
        vertical_envelope: Number of z layers above and below the dilated
            corridor that receive a reduced target value of 0.8; 0 disables it.

    Returns:
        Tensor of shape (B, G, G, G) with values in [0, 1], zero inside
        existing buildings.
    """
    cfg = config
    G = cfg['grid_size']
    device = seed_state.device
    B = seed_state.shape[0]
    corridors = torch.zeros(B, G, G, G, device=device)
    slack = corridor_width * 2

    def _dilate(volume: torch.Tensor) -> torch.Tensor:
        # Cubic dilation with radius corridor_width.
        pooled = F.max_pool3d(volume.unsqueeze(0).unsqueeze(0),
                              2*corridor_width+1, 1, corridor_width)
        return pooled.squeeze()

    for b in range(B):
        access = seed_state[b, cfg['ch_access']]
        existing = seed_state[b, cfg['ch_existing']]
        legal_mask = 1.0 - existing
        centroids = find_access_centroids(access)

        corridor_mask = torch.zeros(G, G, G, device=device)
        if len(centroids) >= 2:
            # One distance field per centroid, shared by all pairs it takes
            # part in (previously every pair recomputed both fields).
            fields = [compute_distance_field_3d([c], legal_mask) for c in centroids]
            for i, j in combinations(range(len(centroids)), 2):
                end = centroids[j]
                total_dist = fields[i][end[0], end[1], end[2]]
                if total_dist == float('inf'):
                    continue  # the two access points are not connected
                path_cost = fields[i] + fields[j]
                on_path = (path_cost <= total_dist + slack).float()
                corridor_mask = torch.max(corridor_mask, on_path)

        if corridor_mask.sum() == 0:
            # Too few access points or no connected pair: use the access
            # markers themselves as the target.
            corridors[b] = _dilate(access) * legal_mask
            continue

        corridor_dilated = _dilate(corridor_mask)

        # Vertical envelope: layers within +/- vertical_envelope of the dilated
        # corridor get at least 0.8. It is computed from a snapshot of the
        # dilated corridor (one max-pool along z), so it is symmetric and
        # bounded. The earlier in-place z-scan re-read layers it had just
        # written and cascaded upward with decaying values (0.8, 0.64, ...).
        if vertical_envelope > 0:
            envelope = F.max_pool3d(
                corridor_dilated.unsqueeze(0).unsqueeze(0),
                (2 * vertical_envelope + 1, 1, 1), 1, (vertical_envelope, 0, 0)
            ).squeeze(0).squeeze(0)
            corridor_dilated = torch.max(corridor_dilated, envelope * 0.8)
        corridors[b] = corridor_dilated * legal_mask
    return corridors

print('Corridor computation defined')

## 4. Loss Functions (CLEAN - 12 losses only)

In [None]:
class LocalLegalityLoss(nn.Module):
    """Fraction of structure mass that sits on illegal voxels.

    A voxel is legal when it is not an existing building and either lies at or
    above ``street_levels`` or is covered by an anchor zone.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']

    def compute_legality_field(self, state: torch.Tensor) -> torch.Tensor:
        """Return the per-voxel legality in [0, 1], shape (B, G, G, G)."""
        cfg = self.config
        G = cfg['grid_size']
        batch = state.shape[0]

        # z index of every voxel, broadcast to the full grid.
        z_grid = torch.arange(G, device=state.device).view(1, G, 1, 1).expand(batch, G, G, G)
        at_street = (z_grid < self.street_levels).float()
        above_street = (z_grid >= self.street_levels).float()

        # Street-level voxels need an anchor; higher voxels are free.
        by_position = above_street + at_street * state[:, cfg['ch_anchors']]
        free_of_buildings = 1 - state[:, cfg['ch_existing']]
        return torch.clamp(free_of_buildings * by_position, 0, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Illegal structure mass divided by total structure mass."""
        built = state[:, self.config['ch_structure']]
        illegal_mass = (built * (1 - self.compute_legality_field(state))).sum()
        return illegal_mass / (built.sum() + 1e-8)

In [None]:
class CorridorCoverageLoss(nn.Module):
    """(A) Share of the corridor target that the structure leaves empty."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor) -> torch.Tensor:
        """Return unfilled corridor mass divided by total corridor mass."""
        missing_mass = (corridor_target * (1 - structure)).sum()
        target_mass = corridor_target.sum() + 1e-8
        return missing_mass / target_mass

In [None]:
class CorridorSpillLoss(nn.Module):
    """(B) Penalise legal structure that lies outside the corridor target.

    Combines a ratio term (spill relative to total structure) with a small
    absolute term (spill relative to the grid volume).
    """

    def __init__(self, config: dict, absolute_weight: float = 0.01):
        super().__init__()
        self.config = config
        self.absolute_weight = absolute_weight

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor,
                legality_field: torch.Tensor) -> torch.Tensor:
        """Return ``spill / structure + absolute_weight * spill / G^3``."""
        spill = structure * (1 - corridor_target) * legality_field
        grid = self.config['grid_size']
        relative_term = spill.sum() / (structure.sum() + 1e-8)
        volume_term = spill.sum() / (grid * grid * grid)
        return relative_term + self.absolute_weight * volume_term

In [None]:
class GroundOpennessLoss(nn.Module):
    """Penalise street-level structure that the corridor does not call for."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor,
                legality_field: torch.Tensor) -> torch.Tensor:
        """Return the non-corridor, legal share of the street-level structure.

        Only the lowest ``street_levels`` z layers of each input are used.
        """
        street = slice(None, self.street_levels)
        low_structure = structure[:, street]
        outside_corridor = 1 - corridor_target[:, street]
        surplus = low_structure * outside_corridor * legality_field[:, street]
        return surplus.sum() / (low_structure.sum() + 1e-8)

In [None]:
class ThicknessLoss(nn.Module):
    """Penalise blob-like thick structures.

    The structure is soft-thresholded, eroded ``max_thickness`` times with a
    3x3x3 minimum filter, and the surviving mass is reported relative to the
    soft-thresholded mass. A spatially uniform input erodes to itself and
    therefore scores 1.
    """

    def __init__(self, max_thickness: int = 2):
        super().__init__()
        self.max_thickness = max_thickness

    def erode_3d(self, x: torch.Tensor) -> torch.Tensor:
        """3x3x3 grey-scale erosion of a batched (4D) or single (3D) volume."""
        batched = x.dim() == 4
        volume = x.unsqueeze(1) if batched else x[None, None]
        # Min filter written as a negated max-pool of the negated volume.
        shrunk = -F.max_pool3d(-volume, 3, 1, 1)
        return shrunk.squeeze(1) if batched else shrunk[0, 0]

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        """Return eroded soft mass divided by total soft mass."""
        # Steep sigmoid centred on 0.3 acts as a differentiable threshold.
        soft = torch.sigmoid(10 * (structure - 0.3))
        remaining = soft
        for _ in range(self.max_thickness):
            remaining = self.erode_3d(remaining)
        return remaining.sum() / (soft.sum() + 1e-8)

In [None]:
class SparsityLoss(nn.Module):
    """Keep the fill ratio inside ``[min_ratio, max_ratio]``.

    Exceeding ``max_ratio`` costs 100 times the squared excess; falling below
    ``min_ratio`` costs 3 times the shortfall.
    """

    def __init__(self, max_ratio: float = 0.15, min_ratio: float = 0.05):
        super().__init__()
        self.max_ratio = max_ratio
        self.min_ratio = min_ratio

    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        """Return the combined over- and under-fill penalty."""
        fill = structure.sum() / (available.sum() + 1e-8)
        excess = F.relu(fill - self.max_ratio)
        shortfall = F.relu(self.min_ratio - fill)
        return excess ** 2 * 100 + 3.0 * shortfall

In [None]:
class FacadeContactLoss(nn.Module):
    """Penalise structure that hugs existing buildings too much.

    The facade zone is the one-voxel shell around existing buildings. The loss
    is the share of structure inside that shell beyond ``max_contact_ratio``.
    """

    def __init__(self, config: dict, max_contact_ratio: float = 0.15):
        super().__init__()
        self.config = config
        self.max_contact_ratio = max_contact_ratio

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return ``relu(contact_share - max_contact_ratio)``."""
        built = state[:, self.config['ch_structure']]
        buildings = state[:, self.config['ch_existing']]
        # Dilate by one voxel and subtract the buildings to get the shell.
        grown = F.max_pool3d(buildings.unsqueeze(1), 3, 1, 1).squeeze(1)
        shell = torch.clamp(grown - buildings, 0, 1)
        touching = (built * shell).sum()
        share = touching / (built.sum() + 1e-8)
        return F.relu(share - self.max_contact_ratio)

In [None]:
class AccessConnectivityLoss(nn.Module):
    """Penalise access points that are cut off at street level.

    Within the lowest ``street_levels`` layers, a soft flood fill starts next
    to the access markers and spreads through voxels that are neither
    structure nor existing building. The loss is one minus the flooded share
    of the access voxels.
    """

    def __init__(self, config: dict, iterations: int = 32):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return ``1 - reachable_access / total_access``."""
        cfg = self.config
        street = slice(None, self.street_levels)
        built = state[:, cfg['ch_structure'], street]
        buildings = state[:, cfg['ch_existing'], street]
        access = state[:, cfg['ch_access'], street]

        open_space = (1 - built) * (1 - buildings)
        # Start from the access markers grown by one voxel.
        start = F.max_pool3d(access.unsqueeze(1), 3, 1, 1).squeeze(1)
        flooded = start * open_space

        for _ in range(self.iterations):
            spread = F.max_pool3d(flooded.unsqueeze(1), 3, 1, 1).squeeze(1)
            candidate = torch.max(flooded, spread * open_space)
            if torch.allclose(flooded, candidate, atol=1e-5):
                break  # flood fill has converged
            flooded = candidate

        access_voxels = (access > 0.5).float()
        reached = (flooded * access_voxels).sum()
        return 1 - (reached / (access_voxels.sum() + 1e-8))

In [None]:
class LoadPathLoss(nn.Module):
    """Penalise elevated structure with no load path to a support.

    Supports are existing buildings everywhere plus anchor zones within the
    lowest ``street_levels`` layers. Support spreads iteratively through the
    soft-thresholded structure; elevated structure it does not reach counts
    as unsupported.
    """

    def __init__(self, config: dict, iterations: int = 32):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return unsupported elevated mass divided by total elevated mass."""
        cfg = self.config
        street = self.street_levels
        built = state[:, cfg['ch_structure']]
        buildings = state[:, cfg['ch_existing']]
        anchor_zones = state[:, cfg['ch_anchors']]

        # Anchors only count as support inside the street layers.
        support = torch.cat(
            [torch.max(buildings[:, :street], anchor_zones[:, :street]), buildings[:, street:]],
            dim=1,
        )

        soft_built = torch.sigmoid(10 * (built - 0.3))
        reached = support.clone()
        for _ in range(self.iterations):
            spread = F.max_pool3d(reached.unsqueeze(1), 3, 1, 1).squeeze(1)
            candidate = torch.max(reached, spread * soft_built)
            if torch.allclose(reached, candidate, atol=1e-5):
                break  # propagation has converged
            reached = candidate

        raised = built[:, street:]
        floating = raised * (1 - reached[:, street:])
        return floating.sum() / (raised.sum() + 1e-8)

In [None]:
class CantileverLoss(nn.Module):
    """Penalise structure that has nothing underneath it.

    For every layer ``z >= max_overhang``, the ``max_overhang`` layers directly
    below are collapsed with a max over z and dilated by one voxel in-plane.
    Structure in layer ``z`` is penalised in proportion to how little support
    that footprint provides. Layers below ``max_overhang`` are not checked.
    """

    def __init__(self, max_overhang: int = 3):
        super().__init__()
        self.max_overhang = max_overhang

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        """Return the mean unsupported mass per checked layer.

        Args:
            structure: Tensor of shape (B, D, H, W).

        Returns:
            A scalar tensor. It is zero when the volume has no layer at or
            above ``max_overhang`` (previously a plain Python float was
            returned there, which broke ``.item()`` / ``.backward()``).
        """
        _, D, _, _ = structure.shape
        N = self.max_overhang
        # Tensor accumulator so the result is a tensor even with zero layers.
        total_loss = structure.new_zeros(())
        count = 0
        for z in range(N, D):
            layer = structure[:, z]
            # z >= N holds here, so the window below never starts before 0.
            support_max = structure[:, z - N:z].max(dim=1)[0]
            support_dilated = F.max_pool2d(support_max.unsqueeze(1), 3, 1, 1).squeeze(1)
            # Steep sigmoid centred on 0.3 acts as a differentiable threshold.
            has_support = torch.sigmoid(10 * (support_dilated - 0.3))
            unsupported = layer * (1.0 - has_support)
            total_loss = total_loss + unsupported.mean()
            count += 1
        return total_loss / max(1, count)


class DensityPenalty(nn.Module):
    """Push structure values towards 0 or 1.

    The mean of ``s * (1 - s)`` is zero for binary values and peaks at 0.5.
    """

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        ambiguity = structure * (1.0 - structure)
        return torch.mean(ambiguity)


class TotalVariation3D(nn.Module):
    """Smoothness term: summed mean absolute neighbour difference per axis."""

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        # Dims 1, 2, 3 are the three spatial axes of a (B, D, H, W) tensor.
        tv_d = torch.diff(structure, dim=1).abs().mean()
        tv_h = torch.diff(structure, dim=2).abs().mean()
        tv_w = torch.diff(structure, dim=3).abs().mean()
        return tv_d + tv_h + tv_w


print('All 12 loss functions defined (CLEAN)')

## 5. Metrics

In [None]:
def compute_legality_compliance(state, config):
    """Share of the structure mass that sits on legal voxels.

    Legality comes from ``LocalLegalityLoss.compute_legality_field``.
    """
    field = LocalLegalityLoss(config).compute_legality_field(state)
    built = state[:, config['ch_structure']]
    legal_mass = (built * field).sum()
    return legal_mass / (built.sum() + 1e-8)

def compute_corridor_coverage(structure, corridor_target):
    """Share of the corridor target mass that is covered by structure."""
    covered_mass = (structure * corridor_target).sum()
    target_mass = corridor_target.sum() + 1e-8
    return covered_mass / target_mass

def compute_corridor_spill(structure, corridor_target, legality_field):
    """Share of the structure mass that is legal but outside the corridor."""
    spill = structure * (1 - corridor_target) * legality_field
    total_mass = structure.sum() + 1e-8
    return spill.sum() / total_mass

def compute_ground_openness(structure, corridor_target, legality_field, config):
    """Share of the street-level structure that lies inside the corridor.

    Only the lowest ``config['street_levels']`` layers are considered.
    ``legality_field`` is accepted but not used.
    """
    street = slice(None, config['street_levels'])
    low_structure = structure[:, street]
    inside_mass = (low_structure * corridor_target[:, street]).sum()
    return inside_mass / (low_structure.sum() + 1e-8)

def compute_thickness_compliance(state, config):
    """Return ``1 - ThicknessLoss`` of the structure channel as a Python float.

    Uses ``config['max_thickness']`` (default 2) as the erosion depth.
    """
    depth = config.get('max_thickness', 2)
    built = state[:, config['ch_structure']]
    penalty = ThicknessLoss(depth)(built)
    return 1.0 - penalty.item()

def compute_facade_contact(state, config):
    """Share of the structure mass inside the one-voxel shell around buildings."""
    built = state[:, config['ch_structure']]
    buildings = state[:, config['ch_existing']]
    # Dilate the buildings by one voxel and remove the buildings themselves.
    grown = F.max_pool3d(buildings.unsqueeze(1), 3, 1, 1).squeeze(1)
    shell = torch.clamp(grown - buildings, 0, 1)
    touching = (built * shell).sum()
    return touching / (built.sum() + 1e-8)

def compute_load_path_compliance(state, config):
    """Return ``1 - LoadPathLoss`` evaluated with its default iteration count."""
    unsupported_share = LoadPathLoss(config)(state)
    return 1.0 - unsupported_share

def compute_access_reachability(state, config):
    """Return ``1 - AccessConnectivityLoss`` evaluated with its default iteration count."""
    unreachable_share = AccessConnectivityLoss(config)(state)
    return 1.0 - unreachable_share

def compute_fill_ratio(state, config):
    """Structure mass divided by the volume not occupied by existing buildings."""
    built_mass = state[:, config['ch_structure']].sum()
    free_volume = (1.0 - state[:, config['ch_existing']]).sum()
    return built_mass / (free_volume + 1e-8)

print('Metrics defined')

## 6. Trainer (CLEAN)

In [None]:
class CleanTrainerV6:
    """Step B v6.0 Trainer - Clean implementation.

    Owns the twelve loss modules, their weights, the Adam optimizer and the
    scene generator. One "epoch" is a single optimisation step on one freshly
    generated batch of scenes.
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        """Build losses, weights, optimizer and scene generator.

        Args:
            model: The NCA to train; called as ``model(seeds, steps=...)``.
            config: Training/scene configuration dictionary.
            device: Device on which scenes are generated.
        """
        self.model = model
        self.config = config
        self.device = device

        # 12 losses only
        self.legality_loss = LocalLegalityLoss(config)
        self.coverage_loss = CorridorCoverageLoss(config)
        self.spill_loss = CorridorSpillLoss(config)
        self.ground_loss = GroundOpennessLoss(config)
        self.thickness_loss = ThicknessLoss(config.get('max_thickness', 2))
        self.sparsity_loss = SparsityLoss()
        self.facade_loss = FacadeContactLoss(config, config.get('max_facade_contact', 0.15))
        self.access_conn_loss = AccessConnectivityLoss(config)
        self.loadpath_loss = LoadPathLoss(config)
        self.cantilever_loss = CantileverLoss()
        self.density_loss = DensityPenalty()
        self.tv_loss = TotalVariation3D()

        # CLEAN weights - balanced
        self.weights = {
            'legality': 30.0,
            'coverage': 25.0,
            'spill': 25.0,
            'ground': 15.0,
            'thickness': 20.0,
            'sparsity': 25.0,
            'facade': 10.0,
            'access_conn': 15.0,
            'loadpath': 8.0,
            'cantilever': 5.0,
            'density': 3.0,
            'tv': 1.0,
        }

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        self.history = []

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation step and return its metrics.

        Generates a batch of seed scenes, derives the corridor target from the
        seeds, grows the model for a random number of steps in
        ``[steps_min, steps_max]``, and backpropagates the weighted loss sum
        with gradient-norm clipping.

        Args:
            epoch: Index stored in the returned metrics.

        Returns:
            Metrics dict for this step; it is also appended to ``self.history``.
        """
        self.model.train()
        cfg = self.config

        seeds = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)

        # Corridor from SEED: the target is a constant, so no gradients needed.
        with torch.no_grad():
            corridor_target = compute_corridor_target(
                seeds, cfg,
                corridor_width=cfg.get('corridor_width', 4),
                vertical_envelope=cfg.get('vertical_envelope', 3)
            )

        steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        final = self.model(seeds, steps=steps)

        structure = final[:, cfg['ch_structure']]
        existing = final[:, cfg['ch_existing']]
        available = 1.0 - existing
        legality_field = self.legality_loss.compute_legality_field(final)

        # Losses keyed exactly like self.weights so each term is paired with
        # its weight by name. Insertion order fixes the summation order.
        losses = {
            'legality': self.legality_loss(final),
            'coverage': self.coverage_loss(structure, corridor_target),
            'spill': self.spill_loss(structure, corridor_target, legality_field),
            'ground': self.ground_loss(structure, corridor_target, legality_field),
            'thickness': self.thickness_loss(structure),
            'sparsity': self.sparsity_loss(structure, available),
            'facade': self.facade_loss(final),
            'access_conn': self.access_conn_loss(final),
            'loadpath': self.loadpath_loss(final),
            'cantilever': self.cantilever_loss(structure),
            'density': self.density_loss(structure),
            'tv': self.tv_loss(structure),
        }
        total_loss = sum(self.weights[name] * value for name, value in losses.items())

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        with torch.no_grad():
            metrics = {
                'epoch': epoch,
                'total_loss': total_loss.item(),
                'L_coverage': losses['coverage'].item(),
                'L_spill': losses['spill'].item(),
                'L_thickness': losses['thickness'].item(),
                'coverage': compute_corridor_coverage(structure, corridor_target).item(),
                'spill': compute_corridor_spill(structure, corridor_target, legality_field).item(),
                'thickness': compute_thickness_compliance(final, cfg),
                'fill_ratio': compute_fill_ratio(final, cfg).item(),
                'legality': compute_legality_compliance(final, cfg).item(),
                'loadpath': compute_load_path_compliance(final, cfg).item(),
                'access_reach': compute_access_reachability(final, cfg).item(),
            }
        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 50, steps: int = 50) -> dict:
        """Average the evaluation metrics over freshly generated scenes.

        The model is left in eval mode afterwards.

        Args:
            n_samples: Number of scenes to evaluate; must be at least 1.
            steps: Number of growth steps per scene.

        Returns:
            Dict mapping ``'avg_<metric>'`` to the mean over all samples.

        Raises:
            ValueError: If ``n_samples`` is less than 1 (there would be
                nothing to average).
        """
        if n_samples < 1:
            raise ValueError(f'n_samples must be >= 1, got {n_samples}')

        self.model.eval()
        cfg = self.config
        results = []

        with torch.no_grad():
            for _ in range(n_samples):
                scene, _ = self.scene_gen.generate(cfg['difficulty'], self.device)
                corridor_target = compute_corridor_target(
                    scene, cfg,
                    corridor_width=cfg.get('corridor_width', 4),
                    vertical_envelope=cfg.get('vertical_envelope', 3)
                )
                grown = self.model.grow(scene, steps=steps)
                structure = grown[:, cfg['ch_structure']]
                legality_field = self.legality_loss.compute_legality_field(grown)

                results.append({
                    'legality': compute_legality_compliance(grown, cfg).item(),
                    'coverage': compute_corridor_coverage(structure, corridor_target).item(),
                    'spill': compute_corridor_spill(structure, corridor_target, legality_field).item(),
                    'thickness': compute_thickness_compliance(grown, cfg),
                    'facade': compute_facade_contact(grown, cfg).item(),
                    'loadpath': compute_load_path_compliance(grown, cfg).item(),
                    'access_reach': compute_access_reachability(grown, cfg).item(),
                    'fill_ratio': compute_fill_ratio(grown, cfg).item(),
                })

        return {f'avg_{k}': np.mean([r[k] for r in results]) for k in results[0]}

    def save_checkpoint(self, path: str):
        """Save model and optimizer state, config and metric history to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'history': self.history,
        }, path)
        print(f'Saved: {path}')
        print(f'Saved: {path}')

print('CleanTrainerV6 defined')

## 7. Visualization

In [None]:
def visualize_result(model, scene_gen, config, device, title='Result', steps=50):
    """Grow one generated scene and display it with a metrics read-out.

    A scene is drawn from ``scene_gen`` at ``config['difficulty']``, grown
    with ``model.grow`` and rendered as a 1x4 matplotlib figure: a 3D voxel
    view, a 'Ground Level' plan, an 'Elevation' projection and a text panel
    with PASS/FAIL metrics. Only the first batch element is drawn.

    Structure voxels (> 0.5) whose legality is >= 0.5 are drawn blue when
    they lie inside the corridor target and orange when they lie outside it.

    Args:
        model: Object exposing ``eval()`` and ``grow(scene, steps=...)``.
        scene_gen: Object exposing ``generate(difficulty, device)``.
        config: Dict providing 'difficulty', 'ch_existing', 'ch_access' and
            'ch_structure'; 'corridor_width' and 'vertical_envelope' are
            optional and default to 4 and 3.
        device: Forwarded to ``scene_gen.generate``.
        title: Title of the 3D panel.
        steps: Growth steps passed to ``model.grow``. Defaults to 50.

    Returns:
        The grown state tensor returned by ``model.grow``.

    Note:
        The model is switched to eval mode and is not switched back.
    """
    model.eval()
    scene, _ = scene_gen.generate(config['difficulty'], device)

    with torch.no_grad():
        corridor_target = compute_corridor_target(
            scene, config,
            corridor_width=config.get('corridor_width', 4),
            vertical_envelope=config.get('vertical_envelope', 3)
        )
        grown = model.grow(scene, steps=steps)

    # Computed once and shared by the voxel masks and the spill metric.
    leg_f = LocalLegalityLoss(config).compute_legality_field(grown)
    struct_t = grown[:, config['ch_structure']]

    cfg = config
    s = grown[0].cpu().numpy()
    # Side length used for the 2D images.
    # NOTE(review): assumes all spatial dims are equal — confirm.
    G = s.shape[1]

    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    corridor = corridor_target[0].cpu().numpy() > 0.5
    legality_field = leg_f[0].cpu().numpy()

    # Split legal structure into corridor / non-corridor parts.
    legal_structure = structure & (legality_field >= 0.5)
    in_corridor = legal_structure & corridor
    outside_corridor = legal_structure & ~corridor

    fig = plt.figure(figsize=(16, 4))

    # Panel 1: 3D voxels (only non-empty masks are drawn).
    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any():
        ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if access.any():
        ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if outside_corridor.any():
        ax1.voxels(outside_corridor.transpose(1,2,0), facecolors='orange', alpha=0.6)
    if in_corridor.any():
        ax1.voxels(in_corridor.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_title(f'{title}')

    # Panel 2: plan. Existing/corridor use slice 0 of the first axis
    # (presumably the vertical axis — TODO confirm); structure and access are
    # max-projected along that axis. Later assignments overwrite earlier ones.
    ax2 = fig.add_subplot(142)
    plan = np.zeros((G,G,3))
    plan[existing[0,:,:]] = [0.5,0.5,0.5]
    plan[corridor[0,:,:] & ~existing[0,:,:]] = [0.8,0.9,1.0]
    plan[in_corridor.max(axis=0)] = [0.2,0.4,0.8]
    plan[outside_corridor.max(axis=0)] = [1.0,0.6,0.2]
    plan[access.max(axis=0)] = [0.2,0.8,0.2]
    ax2.imshow(plan.transpose(1,0,2), origin='lower')
    ax2.set_title('Ground Level')

    # Panel 3: elevation, max-projected along the second axis.
    ax3 = fig.add_subplot(143)
    elev = np.zeros((G,G,3))
    elev[existing.max(axis=1)] = [0.5,0.5,0.5]
    elev[in_corridor.max(axis=1)] = [0.2,0.4,0.8]
    elev[outside_corridor.max(axis=1)] = [1.0,0.6,0.2]
    elev[access.max(axis=1)] = [0.2,0.8,0.2]
    ax3.imshow(elev.transpose(1,0,2), origin='lower')
    ax3.set_title('Elevation')

    # Panel 4: text-only metrics summary.
    ax4 = fig.add_subplot(144)
    ax4.axis('off')

    cov = compute_corridor_coverage(struct_t, corridor_target).item()
    spl = compute_corridor_spill(struct_t, corridor_target, leg_f).item()
    thk = compute_thickness_compliance(grown, config)
    fill = compute_fill_ratio(grown, config).item()
    load = compute_load_path_compliance(grown, config).item()

    text = f"""METRICS
Coverage: {cov*100:.1f}% {'PASS' if cov>0.7 else 'FAIL'}
Spill: {spl*100:.1f}% {'PASS' if spl<0.2 else 'FAIL'}
Thickness: {thk*100:.1f}% {'PASS' if thk>0.9 else 'FAIL'}
Fill: {fill*100:.1f}% {'PASS' if 0.05<fill<0.15 else 'FAIL'}
LoadPath: {load*100:.1f}% {'PASS' if load>0.95 else 'FAIL'}
"""
    ax4.text(0.1, 0.9, text, transform=ax4.transAxes, fontsize=10, verticalalignment='top', fontfamily='monospace')

    plt.tight_layout()
    plt.show()
    return grown


def plot_training_curves(history):
    """Plot the logged training metrics on a 2x3 grid of axes.

    Args:
        history: Sequence of per-epoch dicts containing the keys 'epoch',
            'total_loss', 'coverage', 'spill', 'thickness', 'fill_ratio'
            and 'loadpath'.

    The total loss is drawn on a log-scaled y axis; every other panel gets
    dashed horizontal lines at its target threshold(s) and a fixed y range.
    """
    epochs = [entry['epoch'] for entry in history]
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    def column(key):
        # One metric across all logged epochs.
        return [entry[key] for entry in history]

    # Loss panel: log scale, no threshold line, automatic limits.
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, column('total_loss'))
    loss_ax.set_title('Total Loss')
    loss_ax.set_yscale('log')

    # (axis, history key, line format, title, [(level, colour)], y-limits)
    panels = [
        (axes[0, 1], 'coverage', 'b-', 'Coverage (>70%)', [(0.70, 'k')], (0, 1.1)),
        (axes[0, 2], 'spill', 'orange', 'Spill (<20%)', [(0.20, 'k')], (0, 1.1)),
        (axes[1, 0], 'thickness', 'purple', 'Thickness (>90%)', [(0.90, 'k')], (0, 1.1)),
        (axes[1, 1], 'fill_ratio', 'g-', 'Fill Ratio (5-15%)',
         [(0.15, 'r'), (0.05, 'g')], (0, 0.35)),
        (axes[1, 2], 'loadpath', 'm-', 'LoadPath (>95%)', [(0.95, 'k')], (0, 1.1)),
    ]
    for ax, key, line_fmt, panel_title, thresholds, (y_lo, y_hi) in panels:
        ax.plot(epochs, column(key), line_fmt)
        for level, colour in thresholds:
            ax.axhline(level, color=colour, linestyle='--')
        ax.set_title(panel_title)
        ax.set_ylim(y_lo, y_hi)

    plt.tight_layout()
    plt.show()

# Progress marker printed when this cell finishes executing.
print('Visualization defined')

## 8. Training

In [None]:
# Build the NCA on the selected device and wrap it in the v6 trainer.
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = CleanTrainerV6(model, CONFIG, device)

# Report model size and the loss weights the trainer will use.
param_count = sum(p.numel() for p in model.parameters())
print(f'Parameters: {param_count:,}')
print(f'\nv6.0 CLEAN - 12 losses only')
print(f'Weights: {trainer.weights}')

In [None]:
# Baseline: render what the untrained model grows on one generated scene.
# NOTE(review): visualize_result returns the grown tensor; as the final
# expression of a notebook cell it is presumably echoed as cell output.
visualize_result(model, trainer.scene_gen, CONFIG, device, title='Before Training')

In [None]:
# Main training loop: one trainer.train_epoch call per epoch, with periodic
# logging, visualization and checkpointing driven by CONFIG intervals.
rule = '=' * 60
print(rule)
print('STEP B v6.0 TRAINING - CLEAN')
print(rule)

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    metrics = trainer.train_epoch(epoch)

    # One-line progress report every `log_every` epochs (epoch 0 included).
    if not epoch % CONFIG['log_every']:
        progress_line = (
            f"E{epoch:4d} | Loss:{metrics['total_loss']:.1f} | "
            f"Cov:{metrics['coverage']*100:.0f}% | "
            f"Spl:{metrics['spill']*100:.0f}% | "
            f"Thk:{metrics['thickness']*100:.0f}% | "
            f"Fill:{metrics['fill_ratio']*100:.1f}%"
        )
        tqdm.write(progress_line)

    # Epoch 0 never triggers a visualization or a checkpoint.
    if epoch == 0:
        continue

    if epoch % CONFIG['viz_every'] == 0:
        visualize_result(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')

    if epoch % CONFIG['save_every'] == 0:
        trainer.save_checkpoint(f"{PROJECT_ROOT}/step_b/checkpoints/v6_epoch_{epoch}.pth")

print('Training complete')

## 9. Evaluation

In [None]:
# Plot the per-epoch metrics the trainer recorded during training.
plot_training_curves(history=trainer.history)

In [None]:
# Score the trained model on freshly generated scenes and print a
# PASS/FAIL table against the target thresholds.
print('Evaluating on 50 samples...')
eval_results = trainer.evaluate(n_samples=50)

banner = '=' * 60
print('\n' + banner)
print('STEP B v6.0 EVALUATION')
print(banner)

# metric -> (threshold, direction): '>' means the average must exceed the
# threshold, '<' means it must stay below it.
# NOTE(review): fill_ratio is only upper-bounded here, while visualize_result
# and plot_training_curves use a 5-15% band — confirm which is intended.
targets = {
    'legality': (0.99, '>'),
    'coverage': (0.70, '>'),
    'spill': (0.20, '<'),
    'thickness': (0.90, '>'),
    'facade': (0.15, '<'),
    'loadpath': (0.95, '>'),
    'access_reach': (0.90, '>'),
    'fill_ratio': (0.15, '<'),
}

for metric, (target, op) in targets.items():
    val = eval_results[f'avg_{metric}']
    # Anything other than '>' is treated as an upper bound.
    if op != '>':
        passed = val < target
        print(f"{metric:15s}: {val*100:5.1f}% {'PASS' if passed else 'FAIL'} (<{target*100:.0f}%)")
        continue
    passed = val > target
    print(f"{metric:15s}: {val*100:5.1f}% {'PASS' if passed else 'FAIL'} (>{target*100:.0f}%)")

print(banner)

In [None]:
# Render three independently generated scenes with the trained model,
# titled 'Final 1' .. 'Final 3'.
for i in range(3):
    panel_title = f'Final {i+1}'
    visualize_result(model, trainer.scene_gen, CONFIG, device, title=panel_title)

In [None]:
# Persist the final training state next to the periodic checkpoints.
final_checkpoint = f"{PROJECT_ROOT}/step_b/checkpoints/v6_final.pth"
trainer.save_checkpoint(final_checkpoint)
print('Done!')