# NB02: All Constraints Training v6.5

**Version:** 6.5 (2-Phase Curriculum + Extended Training)
**Date:** December 2025

---

## Why v6.3/v6.4 Had Phase Transition Problems

| Problem | Root Cause | Result |
|---------|------------|--------|
| No growth for epochs 0-350 | Competing losses created a "no growth" attractor | Wasted 70% of training |
| Sudden explosion at ~360 | Breaking out of the attractor had no guidance | Chaotic, uncontrolled growth |
| Never learned discipline | Only 140 epochs remained after the breakout | Grew OUTSIDE the corridor |

## v6.5 Solution: 2-Phase Curriculum Learning

| Phase | Epochs | Spill Weight | Goal |
|-------|--------|--------------|------|
| Growth | 0-600 | 0 → 10 | Learn to grow WITHOUT harsh penalties |
| Sculpting | 600-1500 | 10 → 50 (sqrt) | Guide existing growth INTO corridor |

**Key Insight**: First let the model learn HOW to grow, THEN teach it WHERE to grow.

---

## 1. Setup

In [None]:
# Colab-only: expose Google Drive under /content/drive so the project folder
# (which holds the JSON configs loaded below) is reachable from this runtime.
from google.colab import drive

drive.mount('/content/drive')

PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'
print(f'Project root: {PROJECT_ROOT}')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, List
import json
from tqdm.notebook import tqdm
from itertools import combinations

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
def set_seed(seed: int = 42) -> None:
    """Seed every RNG this notebook draws from so runs are reproducible.

    Seeds Python's ``random`` (used by the scene generator), NumPy and PyTorch
    (CPU, plus all CUDA devices when available), and forces cuDNN into
    deterministic, non-autotuned mode.

    Args:
        seed: Value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # ``deterministic`` alone is not sufficient: with benchmark mode on, cuDNN
    # autotunes per input shape and may select different algorithms per run.
    torch.backends.cudnn.benchmark = False


set_seed(42)

In [None]:
# Load the shared base config, then overlay the v6.5 curriculum settings.
with open(f'{PROJECT_ROOT}/config_step_b.json', 'r', encoding='utf-8') as f:
    CONFIG = json.load(f)

# v6.5 CONFIGURATION - 2-PHASE CURRICULUM
CONFIG.update({
    # Extended training
    'epochs': 1500,           # 3x longer than v6.3/v6.4
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',
    'log_every': 20,
    'viz_every': 150,         # Less frequent viz for longer training
    'save_every': 300,

    # Corridor parameters
    'corridor_width': 3,
    'max_thickness': 2,
    'max_facade_contact': 0.15,
    'vertical_envelope': 1,
    'lr_initial': 5e-4,

    # v6.5 CURRICULUM PARAMETERS
    'growth_phase_end': 600,      # End of growth phase
    'sculpt_phase_end': 1500,     # End of sculpting phase (= total epochs)
    'spill_weight_min': 0.0,      # Spill weight at epoch 0
    'spill_weight_growth': 10.0,  # Spill weight at end of growth phase
    'spill_weight_max': 50.0,     # Spill weight at end of sculpting phase
})

# Echo the effective curriculum. Every number is read from CONFIG (not typed
# by hand) so this summary cannot drift from what training actually uses.
print('='*60)
print('v6.5 Configuration: 2-PHASE CURRICULUM')
print('='*60)
print(f"  epochs: {CONFIG['epochs']} (was 500)")
print(f"  growth_phase: 0-{CONFIG['growth_phase_end']} "
      f"(spill {CONFIG['spill_weight_min']:g}→{CONFIG['spill_weight_growth']:g})")
print(f"  sculpt_phase: {CONFIG['growth_phase_end']}-{CONFIG['sculpt_phase_end']} "
      f"(spill {CONFIG['spill_weight_growth']:g}→{CONFIG['spill_weight_max']:g})")
print('  Sculpting uses sqrt-weighted distance')
print('='*60)

## 2. Core Components

In [None]:
class Perceive3D(nn.Module):
    """Fixed (non-learned) 3D perception filter bank for the NCA.

    Every input channel is convolved depthwise with four 3x3x3 kernels:
    identity, then Sobel derivatives along x (last axis), y (middle axis) and
    z (first spatial axis). The responses are concatenated filter-major,
    ``[identity(C), sobel_x(C), sobel_y(C), sobel_z(C)]``, giving ``4 * C``
    output channels at the input's spatial size.

    Args:
        n_channels: Number of state channels. Stored for reference only; the
            forward pass reads C from the input tensor.
    """

    def __init__(self, n_channels: int = 8):
        super().__init__()
        self.n_channels = n_channels
        kernels = torch.stack([
            self._create_identity_kernel(),
            self._create_sobel_kernel('x'),
            self._create_sobel_kernel('y'),
            self._create_sobel_kernel('z'),
        ], dim=0)
        # Buffer, not Parameter: follows .to(device) but is never trained.
        self.register_buffer('kernels', kernels)

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        """Return a 3x3x3 Sobel kernel that differentiates along *direction*.

        Kernel axes are (z, y, x). The result is divided by 16, the sum of the
        two outer-producted ``[1, 2, 1]`` smoothing vectors.

        Raises:
            ValueError: if *direction* is not ``'x'``, ``'y'`` or ``'z'``.
        """
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])
        if direction == 'x':
            kernel = torch.einsum('i,j,k->ijk', smoothing, smoothing, derivative)
        elif direction == 'y':
            kernel = torch.einsum('i,j,k->ijk', smoothing, derivative, smoothing)
        elif direction == 'z':
            kernel = torch.einsum('i,j,k->ijk', derivative, smoothing, smoothing)
        else:
            # Previously fell through to an UnboundLocalError on `kernel`.
            raise ValueError(f"direction must be 'x', 'y' or 'z', got {direction!r}")
        return kernel / 16.0

    def _create_identity_kernel(self) -> torch.Tensor:
        """Return a 3x3x3 kernel that copies the center voxel."""
        kernel = torch.zeros(3, 3, 3)
        kernel[1, 1, 1] = 1.0
        return kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the filter bank to a (B, C, D, H, W) tensor -> (B, 4C, D, H, W)."""
        n_ch = x.shape[1]
        # Replicate padding keeps the spatial size; border voxels see copies of
        # their own edge values instead of zeros.
        x_padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        outputs = []
        for k in range(self.kernels.shape[0]):
            # Same kernel for every channel: weight (C, 1, 3, 3, 3) with groups=C.
            kernel = self.kernels[k:k+1].unsqueeze(0).expand(n_ch, 1, 3, 3, 3)
            outputs.append(F.conv3d(x_padded, kernel, padding=0, groups=n_ch))
        return torch.cat(outputs, dim=1)

In [None]:
class UrbanPavilionNCA(nn.Module):
    """3D neural cellular automaton that grows structure inside an urban scene.

    The state is ``(B, C, D, H, W)``. The first ``config['n_frozen']`` channels
    are passed through unchanged. The remaining channels receive a residual
    update ``update_scale * delta`` from a 1x1x1-conv MLP over ``Perceive3D``
    features, and are clamped to [0, 1]. Grown channel 0 (structure) is then
    zeroed wherever ``config['ch_existing']`` marks an existing building.

    Args:
        config: Dict providing ``n_channels``, ``hidden_dim``, ``n_grown``,
            ``xavier_gain``, ``structure_bias``, ``surface_bias``,
            ``fire_rate``, ``n_frozen``, ``update_scale`` and ``ch_existing``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        n_channels = config['n_channels']
        hidden_dim = config['hidden_dim']
        # Perceive3D emits 4 filter responses per state channel.
        perception_dim = n_channels * 4
        n_grown = config['n_grown']

        self.perceive = Perceive3D(n_channels)
        self.update_net = nn.Sequential(
            nn.Conv3d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, n_grown, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-init every conv and zero its bias.

        Afterwards, the last layer's bias for output channel 0 is set to
        ``config['structure_bias']`` and for channel 1 to ``config['surface_bias']``.
        """
        gain = self.config['xavier_gain']
        for m in self.update_net:
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        last_layer = self.update_net[-1]
        with torch.no_grad():
            last_layer.bias[0] = self.config['structure_bias']
            last_layer.bias[1] = self.config['surface_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """Apply ``steps`` NCA updates to ``state`` and return the final state."""
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        """One NCA update.

        Assumes ``n_grown == C - n_frozen`` (TODO confirm against the config).
        """
        B, C, D, H, W = state.shape
        cfg = self.config
        perception = self.perceive(state)
        delta = self.update_net(perception)

        if self.training:
            # Stochastic firing: each voxel applies its update with probability fire_rate.
            fire_mask = (torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']).float()
            delta = delta * fire_mask

        grown_start = cfg['n_frozen']
        grown_new = state[:, grown_start:] + cfg['update_scale'] * delta
        grown_new = torch.clamp(grown_new, 0.0, 1.0)

        # Structure (grown channel 0) may not occupy existing-building voxels.
        existing = state[:, cfg['ch_existing']:cfg['ch_existing']+1]
        available_mask = 1.0 - existing
        struct_new = grown_new[:, 0:1] * available_mask
        grown_masked = torch.cat([struct_new, grown_new[:, 1:]], dim=1)
        return torch.cat([state[:, :grown_start], grown_masked], dim=1)

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Roll out ``steps`` deterministic updates from ``seed`` without gradients.

        Stochastic firing is disabled by switching to eval mode for the
        rollout. The module's previous train/eval mode is restored afterwards,
        so calling this for visualization in the middle of training does not
        leave subsequent training rollouts deterministic.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(seed, steps)
        finally:
            self.train(was_training)

In [None]:
class UrbanSceneGenerator:
    """Procedurally builds random urban seed scenes for NCA training.

    A scene is a ``(1, n_channels, G, G, G)`` tensor indexed
    ``[batch, channel, z, y, x]`` (the axis names match the ``'z'``/``'y'``/``'x'``
    keys of the building dicts below). It holds a ground plane, two building
    blocks separated by a gap, access points, and street-level anchor zones.

    NOTE: every method draws from Python's global ``random`` module. The order
    of those draws, including the evaluation order of the dict literals in
    ``_get_difficulty_params``, determines which scene a fixed seed produces.
    Reordering statements changes the generated data.
    """
    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']
        self.C = config['n_channels']

    def generate(self, difficulty: str = 'easy', device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        """Generate one scene.

        Returns:
            ``(state, info)``: ``state`` is ``(1, C, G, G, G)``; ``info`` holds the
            difficulty, the building boxes and the access-point records.
        """
        G = self.G
        cfg = self.config
        state = torch.zeros(1, self.C, G, G, G, device=device)
        # Ground plane: the z = 0 layer of the ground channel.
        state[:, cfg['ch_ground'], 0, :, :] = 1.0

        params = self._get_difficulty_params(difficulty)
        building_info = self._place_buildings(state, params)
        access_info = self._place_access_points(state, params, building_info)
        self._generate_anchor_zones(state, params, building_info, access_info)
        return state, {'difficulty': difficulty, 'buildings': building_info, 'access_points': access_info}

    def _get_difficulty_params(self, difficulty: str) -> dict:
        """Sample scene parameters for 'easy' or 'medium'; anything else is treated as hard.

        NOTE(review): ``n_buildings`` and ``anchor_budget`` are not read anywhere
        in this class; ``_place_buildings`` always places exactly two buildings.
        """
        if difficulty == 'easy':
            return {'n_buildings': 2, 'height_range': (12, 16), 'height_variance': False,
                    'width_range': (8, 12), 'gap_width': random.randint(14, 18),
                    'n_ground_access': 1, 'n_elevated_access': 1, 'anchor_budget': 0.10}
        elif difficulty == 'medium':
            return {'n_buildings': 2, 'height_range': (10, 20), 'height_variance': True,
                    'width_range': (6, 10), 'gap_width': random.randint(10, 14),
                    'n_ground_access': random.randint(1, 2), 'n_elevated_access': random.randint(1, 2),
                    'anchor_budget': 0.07}
        else:
            return {'n_buildings': random.randint(2, 4), 'height_range': (8, 24), 'height_variance': True,
                    'width_range': (5, 8), 'gap_width': random.randint(6, 10),
                    'n_ground_access': random.randint(2, 3), 'n_elevated_access': random.randint(2, 3),
                    'anchor_budget': 0.05}

    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        """Write a left and a right building block into the existing channel (in place).

        The blocks flank a gap of ``params['gap_width']`` centered at x = G // 2.
        Each returned dict has half-open ``'x'``, ``'y'``, ``'z'`` ranges. Its
        ``'gap_facing_x'`` is the left block's exclusive x end (the first free
        column) or the right block's first x column, and ``'side'`` is
        ``'left'`` or ``'right'``.
        """
        G = self.G
        ch = self.config['ch_existing']
        buildings = []
        gap_width = params['gap_width']
        gap_center = G // 2

        w1 = random.randint(*params['width_range'])
        d1 = random.randint(G//2, G-2)
        h1 = random.randint(*params['height_range'])
        x1_end = gap_center - gap_width // 2
        x1_start = max(0, x1_end - w1)
        state[:, ch, :h1, :d1, x1_start:x1_end] = 1.0
        buildings.append({'x': (x1_start, x1_end), 'y': (0, d1), 'z': (0, h1), 'gap_facing_x': x1_end, 'side': 'left'})

        w2 = random.randint(*params['width_range'])
        d2 = random.randint(G//2, G-2)
        # Without height variance the right building copies the left height.
        h2 = h1 if not params['height_variance'] else random.randint(*params['height_range'])
        x2_start = gap_center + gap_width // 2
        x2_end = min(G, x2_start + w2)
        state[:, ch, :h2, :d2, x2_start:x2_end] = 1.0
        buildings.append({'x': (x2_start, x2_end), 'y': (0, d2), 'z': (0, h2), 'gap_facing_x': x2_start, 'side': 'right'})
        return buildings

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: list) -> list:
        """Write 2x2x2 access blocks into the access channel (in place) and return their records.

        Ground access points sit at z = 0..1 somewhere inside the gap between
        buildings. Elevated access points sit at a random height on the gap side
        of a randomly chosen building, placed just outside its facade.
        """
        G = self.G
        ch = self.config['ch_access']
        access_points = []
        left_buildings = [b for b in buildings if b['side'] == 'left']
        right_buildings = [b for b in buildings if b['side'] == 'right']
        gap_x_min = max(b['gap_facing_x'] for b in left_buildings) if left_buildings else 0
        gap_x_max = min(b['gap_facing_x'] for b in right_buildings) if right_buildings else G

        for _ in range(params.get('n_ground_access', 1)):
            x = random.randint(gap_x_min + 1, gap_x_max - 3)
            y = random.randint(0, G - 3)
            state[:, ch, 0:2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': 0, 'type': 'ground'})

        for _ in range(params.get('n_elevated_access', 1)):
            building = random.choice(buildings)
            bz_max = building['z'][1]
            is_left = building['side'] == 'left'
            z = random.randint(3, max(4, bz_max - 2))
            y = random.randint(building['y'][0], min(building['y'][1] - 2, building['y'][0] + G//3))
            # Left building: first column past its x end; right building: two
            # columns before its x start (the block is 2 voxels wide).
            x = building['x'][1] if is_left else building['x'][0] - 2
            x = max(0, min(G - 2, x))
            state[:, ch, z:z+2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': z, 'type': 'elevated'})
        return access_points

    def _generate_anchor_zones(self, state: torch.Tensor, params: dict, buildings: list, access_points: list):
        """Overwrite the anchors channel with street-level (z < street_levels) anchor zones.

        The zones are:
        - a square of up to 6x6 (y-2..y+3, x-2..x+3, clipped to the grid)
          around each ground access point;
        - a one-voxel-wide column just outside each building's gap-facing
          facade, covering the first 4 y voxels of the building.

        Anchors are then cleared wherever the z = 0 layer has an existing
        building. ``params`` is accepted but not read.
        """
        G = self.G
        ch = self.config['ch_anchors']
        sl = self.config['street_levels']
        existing = state[:, self.config['ch_existing'], 0, :, :]
        street_mask = 1.0 - existing
        anchors = torch.zeros(1, 1, G, G, G, device=state.device)

        for ap in access_points:
            if ap['type'] == 'ground':
                x, y = ap['x'], ap['y']
                for z in range(sl):
                    anchors[:, 0, z, max(0,y-2):min(G,y+4), max(0,x-2):min(G,x+4)] = 1.0

        for building in buildings:
            by_start, by_end = building['y']
            gap_x = building['gap_facing_x']
            is_left = building['side'] == 'left'
            x_start = gap_x if is_left else gap_x - 1
            x_end = gap_x + 1 if is_left else gap_x
            for z in range(sl):
                anchors[:, 0, z, by_start:min(by_start+4, by_end), max(0,x_start):min(G,x_end)] = 1.0

        for z in range(sl):
            anchors[:, 0, z, :, :] *= street_mask
        state[:, ch:ch+1, :, :, :] = anchors

    def batch(self, difficulty: str, batch_size: int, device: str) -> torch.Tensor:
        """Stack ``batch_size`` freshly generated scenes along dim 0 (info dicts are dropped)."""
        return torch.cat([self.generate(difficulty, device)[0] for _ in range(batch_size)], dim=0)

print('Core components defined')

## 3. Corridor + SDF Computation

In [None]:
def find_access_centroids(access_channel: torch.Tensor) -> List[Tuple[int, int, int]]:
    """Cluster access voxels and return one integer centroid per cluster.

    Voxels with value > 0.5 count as access. Clustering is greedy and
    non-transitive: the first unassigned voxel (in ``nonzero`` order) seeds a
    cluster and absorbs every other unassigned voxel within Manhattan distance
    4 of that seed. Each centroid is the cluster mean, truncated to ints.

    Args:
        access_channel: 3D tensor indexed (z, y, x).

    Returns:
        ``(z, y, x)`` tuples, one per cluster; empty if no voxel exceeds 0.5.
    """
    coords = (access_channel > 0.5).nonzero(as_tuple=False)
    remaining = list(range(coords.shape[0]))
    result = []
    while remaining:
        anchor = coords[remaining.pop(0)]
        members = [anchor]
        leftovers = []
        for other in remaining:
            if (anchor - coords[other]).abs().sum().item() <= 4:
                members.append(coords[other])
            else:
                leftovers.append(other)
        remaining = leftovers
        mean_pos = torch.stack(members).float().mean(dim=0).long()
        result.append(tuple(int(v) for v in mean_pos.tolist()))
    return result


def compute_distance_field_3d(start_points: List[Tuple[int, int, int]],
                              legal_mask: torch.Tensor, max_iters: int = 64) -> torch.Tensor:
    """Step distance from ``start_points`` through legal voxels.

    Distances spread by repeated 3x3x3 min-filtering: 26-connectivity where
    every step, diagonal or not, costs 1. Only voxels with
    ``legal_mask > 0.5`` are updated; all others keep their initial value
    (0 for a start point, ``inf`` otherwise). Propagation stops after
    ``max_iters`` sweeps, which caps the largest reachable distance, or as
    soon as a sweep changes nothing.

    Args:
        start_points: ``(z, y, x)`` seeds; out-of-bounds seeds are ignored.
        legal_mask: (D, H, W) traversability mask.
        max_iters: Upper bound on propagation sweeps.

    Returns:
        (D, H, W) float tensor of distances, ``inf`` where unreached.
    """
    depth, height, width = legal_mask.shape
    dist = torch.full((depth, height, width), float('inf'), device=legal_mask.device)
    for z, y, x in start_points:
        inside = 0 <= z < depth and 0 <= y < height and 0 <= x < width
        if inside:
            dist[z, y, x] = 0
    passable = legal_mask > 0.5
    for _ in range(max_iters):
        # Min-pool expressed as a negated max-pool of the negated field.
        neighbour_min = -F.max_pool3d(-dist[None, None], 3, 1, 1)[0, 0]
        candidate = torch.where(passable, torch.minimum(dist, neighbour_min + 1), dist)
        if torch.allclose(dist, candidate, atol=1e-5):
            break
        dist = candidate
    return dist

In [None]:
def _dilate_cube(volume: torch.Tensor, radius: int) -> torch.Tensor:
    """Grayscale dilation of a (D, H, W) volume with a (2*radius+1)^3 cube."""
    return F.max_pool3d(volume[None, None], 2 * radius + 1, 1, radius)[0, 0]


def _vertical_halo(volume: torch.Tensor, envelope: int) -> torch.Tensor:
    """Add a 0.5-valued halo ``envelope`` layers above and below each voxel of *volume*.

    Layer z becomes ``max(volume[z], 0.5 * max(volume[z-e : z+e+1]))``, with the
    window max taken from the ORIGINAL volume. The previous in-place loop read
    layers it had already updated, so the halo fed on itself and leaked a
    0.25, 0.125, ... tail into every higher layer.
    """
    window_max = F.max_pool3d(volume[None, None],
                              kernel_size=(2 * envelope + 1, 1, 1),
                              stride=1,
                              padding=(envelope, 0, 0))[0, 0]
    return torch.max(volume, 0.5 * window_max)


def _corridor_sdf(corridor: torch.Tensor, legal_mask: torch.Tensor,
                  grid_size: int) -> torch.Tensor:
    """Signed-distance-like field for *corridor*: ``dist * (1 - corridor) - corridor``.

    ``dist`` is the step distance to the nearest voxel with ``corridor > 0.5``,
    clamped to ``[0, grid_size]``. The clamp matters: voxels the flood fill
    never reaches (e.g. existing buildings, which are not legal) would
    otherwise hold ``inf``, and ``inf * 0`` is NaN in any downstream product.
    Returns zeros when no voxel exceeds 0.5.
    """
    seeds = [tuple(p) for p in (corridor > 0.5).nonzero(as_tuple=False).tolist()]
    if not seeds:
        return torch.zeros_like(corridor)
    dist = torch.clamp(compute_distance_field_3d(seeds, legal_mask), 0, grid_size)
    return dist * (1 - corridor) - corridor * 1.0


def _path_corridor(centroids: List[Tuple[int, int, int]], legal_mask: torch.Tensor,
                   corridor_width: int, vertical_envelope: int):
    """Union of near-shortest legal paths between all centroid pairs, dilated and masked.

    Returns a (D, H, W) tensor, or ``None`` when no centroid pair is connected.
    """
    # One flood fill per centroid, reused by every pair it belongs to.
    fields = [compute_distance_field_3d([c], legal_mask) for c in centroids]
    path_mask = torch.zeros_like(legal_mask)
    for i, j in combinations(range(len(centroids)), 2):
        shortest = fields[i][centroids[j]]
        if torch.isinf(shortest):
            continue
        # A voxel is "on path" if detouring through it costs at most
        # `corridor_width` extra steps over the shortest route.
        on_path = (fields[i] + fields[j] <= shortest + corridor_width).float()
        path_mask = torch.max(path_mask, on_path)
    if path_mask.sum() == 0:
        return None
    dilated = _dilate_cube(path_mask, corridor_width)
    if vertical_envelope > 0:
        dilated = _vertical_halo(dilated, vertical_envelope)
    return dilated * legal_mask


def compute_corridor_and_sdf(seed_state: torch.Tensor, config: dict,
                             corridor_width: int = 3, vertical_envelope: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the corridor target AND its signed distance field for each sample.

    Per batch item, access voxels are clustered into centroids. For every pair
    of centroids connected through legal (non-building) voxels, voxels within
    ``corridor_width`` extra steps of the shortest path are collected. The
    collected set is dilated by ``corridor_width``, given a 0.5 vertical halo
    of ``vertical_envelope`` layers, and masked to legal voxels. With fewer
    than two centroids, or when no pair is connected, the corridor falls back
    to the dilated access voxels. In every case the SDF is derived from the
    final corridor.

    Args:
        seed_state: (B, C, G, G, G) scene; reads ``config['ch_access']`` and
            ``config['ch_existing']``.
        config: Must provide ``grid_size``, ``ch_access``, ``ch_existing``.
        corridor_width: Path slack and dilation radius in voxels.
        vertical_envelope: Half-height of the soft vertical halo (0 disables).

    Returns:
        ``(corridors, sdfs)``, both (B, G, G, G). ``sdfs`` is -1 on the
        corridor core and the clamped step distance to it outside (always
        finite); it is all zeros for a sample with an empty corridor.
    """
    cfg = config
    G = cfg['grid_size']
    device = seed_state.device
    B = seed_state.shape[0]
    corridors = torch.zeros(B, G, G, G, device=device)
    sdfs = torch.zeros(B, G, G, G, device=device)

    for b in range(B):
        access = seed_state[b, cfg['ch_access']]
        existing = seed_state[b, cfg['ch_existing']]
        legal_mask = 1.0 - existing
        centroids = find_access_centroids(access)

        corridor = None
        if len(centroids) >= 2:
            corridor = _path_corridor(centroids, legal_mask, corridor_width, vertical_envelope)
        if corridor is None:
            # Fallback: fewer than two access clusters or no connected pair.
            corridor = _dilate_cube(access, corridor_width) * legal_mask

        corridors[b] = corridor
        sdfs[b] = _corridor_sdf(corridor, legal_mask, G)

    return corridors, sdfs

print('Corridor + SDF computation defined')

## 4. Loss Functions (v6.5 with Progressive Spill)

In [None]:
class LocalLegalityLoss(nn.Module):
    """Fraction of structure mass on voxels where building is not allowed.

    A voxel is legal when it holds no existing building and either sits at
    z >= ``street_levels``, or sits below that height and carries an anchor.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']

    def compute_legality_field(self, state: torch.Tensor) -> torch.Tensor:
        """Return a (B, G, G, G) field in [0, 1] that is 1 where structure may exist."""
        cfg = self.config
        grid = cfg['grid_size']
        batch = state.shape[0]
        free = 1 - state[:, cfg['ch_existing']]
        anchor_field = state[:, cfg['ch_anchors']]
        heights = torch.arange(grid, device=state.device).view(1, grid, 1, 1).expand(batch, grid, grid, grid)
        is_elevated = (heights >= self.street_levels).float()
        is_street = (heights < self.street_levels).float()
        allowed_here = is_elevated + is_street * anchor_field
        return torch.clamp(free * allowed_here, 0, 1)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Illegal structure mass divided by total structure mass."""
        struct = state[:, self.config['ch_structure']]
        outside_legal = struct * (1 - self.compute_legality_field(state))
        return outside_legal.sum() / (struct.sum() + 1e-8)

In [None]:
class CorridorCoverageLoss(nn.Module):
    """(A) Fill the corridor - POSITIVE incentive to grow.

    Loss = corridor mass not covered by structure / total corridor mass:
    0 when the corridor is fully filled, about 1 when it is empty.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor) -> torch.Tensor:
        gap = corridor_target * (1 - structure)
        denominator = corridor_target.sum() + 1e-8
        return gap.sum() / denominator

In [None]:
class ProgressiveSpillLoss(nn.Module):
    """v6.5 SDF-based penalty for structure outside the corridor.

    Only voxels with ``sdf > 0`` (outside the corridor) are penalized. Each is
    weighted by the legality field and by a distance factor ``w(d)``, where
    ``d = relu(sdf)``:

    - ``use_sqrt=False`` (growth phase):    ``w(d) = d / 5 + 1``
    - ``use_sqrt=True``  (sculpting phase): ``w(d) = sqrt(max(d, 1) / 5)``

    NOTE(review): the sqrt factor is SMALLER than the linear one at every
    distance (d=1: 0.45 vs 1.2; d=20: 2.0 vs 5.0). Switching to it therefore
    softens the per-voxel penalty and flattens its growth with distance. Any
    extra harshness in the sculpting phase must come from the caller's
    spill-weight schedule, not from this factor. Confirm that is intended.

    Returns:
        Penalized mass divided by total structure mass.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, structure: torch.Tensor, sdf: torch.Tensor,
                legality_field: torch.Tensor, use_sqrt: bool = False) -> torch.Tensor:
        outside = sdf > 0
        dist = F.relu(sdf)
        if not use_sqrt:
            factor = dist / 5.0 + 1.0
        else:
            factor = torch.sqrt(dist.clamp(min=1.0) / 5.0)
        penalty = structure * factor * legality_field * outside.float()
        return penalty.sum() / (structure.sum() + 1e-8)


print('ProgressiveSpillLoss: uses sqrt in sculpting phase for harsher gradient')

In [None]:
class GroundOpennessLoss(nn.Module):
    """Keep street-level layers open outside the corridor.

    Over z < ``street_levels``: legal structure lying outside the corridor
    target, divided by all structure mass at those levels.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.sl = config['street_levels']

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor,
                legality_field: torch.Tensor) -> torch.Tensor:
        lv = self.sl
        struct_low = structure[:, :lv]
        extra = struct_low * (1 - corridor_target[:, :lv]) * legality_field[:, :lv]
        return extra.sum() / (struct_low.sum() + 1e-8)

In [None]:
class ThicknessLoss(nn.Module):
    """Penalize overly thick structure.

    The structure is soft-binarized (sigmoid centered at 0.3, slope 10) and
    eroded ``max_thickness`` times with a 3x3x3 min filter. What survives are
    voxels whose whole cube of radius ``max_thickness`` is filled.
    Loss = surviving "core" mass / soft-binarized mass.
    """

    def __init__(self, max_thickness: int = 2):
        super().__init__()
        self.max_thickness = max_thickness

    def erode_3d(self, x: torch.Tensor) -> torch.Tensor:
        """3x3x3 min filter on a (B, D, H, W) or (D, H, W) tensor."""
        batched = x.dim() == 4
        vol = x.unsqueeze(1) if batched else x[None, None]
        shrunk = -F.max_pool3d(-vol, 3, 1, 1)
        if batched:
            return shrunk.squeeze(1)
        return shrunk.squeeze(0).squeeze(0)

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        solid = torch.sigmoid(10 * (structure - 0.3))
        interior = solid
        for _ in range(self.max_thickness):
            interior = self.erode_3d(interior)
        return interior.sum() / (solid.sum() + 1e-8)

In [None]:
class SparsityLoss(nn.Module):
    """Encourages structure to stay within fill ratio bounds.

    ratio = structure mass / available mass. Each unit of overshoot above
    ``max_ratio`` costs 50; each unit of undershoot below ``min_ratio`` costs
    20, which forces a minimum amount of growth.
    """

    def __init__(self, max_ratio: float = 0.15, min_ratio: float = 0.05):
        super().__init__()
        self.max_ratio = max_ratio
        self.min_ratio = min_ratio

    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        fill = structure.sum() / (available.sum() + 1e-8)
        too_full = F.relu(fill - self.max_ratio) * 50
        too_empty = F.relu(self.min_ratio - fill) * 20.0
        return too_full + too_empty

In [None]:
class FacadeContactLoss(nn.Module):
    """Limit how much structure touches existing building facades.

    The facade shell is the 3x3x3 dilation of the existing channel minus the
    existing voxels themselves. The loss is the share of structure mass in
    that shell above ``max_contact``; it is 0 while the share is at or below
    the limit.
    """

    def __init__(self, config: dict, max_contact: float = 0.15):
        super().__init__()
        self.config = config
        self.max_contact = max_contact

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        struct = state[:, cfg['ch_structure']]
        buildings = state[:, cfg['ch_existing']]
        grown = F.max_pool3d(buildings.unsqueeze(1), 3, 1, 1).squeeze(1)
        shell = torch.clamp(grown - buildings, 0, 1)
        share = (struct * shell).sum() / (struct.sum() + 1e-8)
        return F.relu(share - self.max_contact)

In [None]:
class AccessConnectivityLoss(nn.Module):
    """Penalize access voxels that the street-level flood fill cannot reach.

    Works on z < ``street_levels`` only. Open space is voxels with neither
    structure nor existing building. The flood starts from the 3x3x3-dilated
    access voxels restricted to open space and grows for up to ``iterations``
    sweeps, or until it stops changing. Loss = 1 - (reached access mass /
    access voxel count).
    """

    def __init__(self, config: dict, iterations: int = 32):
        super().__init__()
        self.config = config
        self.sl = config['street_levels']
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        lv = self.sl
        struct = state[:, cfg['ch_structure'], :lv]
        buildings = state[:, cfg['ch_existing'], :lv]
        entrances = state[:, cfg['ch_access'], :lv]

        open_space = (1 - struct) * (1 - buildings)
        reached = F.max_pool3d(entrances.unsqueeze(1), 3, 1, 1).squeeze(1) * open_space

        for _ in range(self.iterations):
            spread = F.max_pool3d(reached.unsqueeze(1), 3, 1, 1).squeeze(1)
            expanded = torch.max(reached, spread * open_space)
            if torch.allclose(reached, expanded, atol=1e-5):
                break
            reached = expanded

        targets = (entrances > 0.5).float()
        hit = (reached * targets).sum()
        return 1 - (hit / (targets.sum() + 1e-8))

In [None]:
class LoadPathLoss(nn.Module):
    """Penalize elevated structure not connected to a support.

    Supports are existing buildings at every height, plus anchors at
    z < ``street_levels``. Connectivity is flooded from the supports through
    soft-binarized structure (sigmoid centered at 0.3) for up to
    ``iterations`` sweeps. Loss = unconnected structure mass above street
    level / structure mass above street level.
    """

    def __init__(self, config: dict, iterations: int = 32):
        super().__init__()
        self.config = config
        self.sl = config['street_levels']
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        lv = self.sl
        struct = state[:, cfg['ch_structure']]
        buildings = state[:, cfg['ch_existing']]
        anchor_field = state[:, cfg['ch_anchors']]

        ground_support = torch.max(buildings[:, :lv], anchor_field[:, :lv])
        linked = torch.cat([ground_support, buildings[:, lv:]], dim=1)
        solid = torch.sigmoid(10 * (struct - 0.3))

        for _ in range(self.iterations):
            spread = F.max_pool3d(linked.unsqueeze(1), 3, 1, 1).squeeze(1)
            following = torch.max(linked, spread * solid)
            if torch.allclose(linked, following, atol=1e-5):
                break
            linked = following

        upper = struct[:, lv:]
        floating = upper * (1 - linked[:, lv:])
        return floating.sum() / (upper.sum() + 1e-8)

In [None]:
class CantileverLoss(nn.Module):
    """Penalize structure with no support in the ``max_overhang`` layers below it.

    Each layer z >= ``max_overhang`` is checked against the max of the
    ``max_overhang`` layers beneath it, dilated by a 3x3 horizontal
    neighborhood and soft-thresholded (sigmoid centered at 0.3). The loss is
    the mean unsupported structure per checked layer, averaged over those
    layers. Layers below ``max_overhang`` are not checked.
    """

    def __init__(self, max_overhang: int = 3):
        super().__init__()
        self.N = max_overhang

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        _, depth, _, _ = structure.shape
        per_layer = []
        for z in range(self.N, depth):
            below = structure[:, max(0, z - self.N):z].max(dim=1)[0]
            reach = F.max_pool2d(below.unsqueeze(1), 3, 1, 1).squeeze(1)
            held = torch.sigmoid(10 * (reach - 0.3))
            per_layer.append((structure[:, z] * (1 - held)).mean())
        return sum(per_layer) / max(1, len(per_layer))


class DensityPenalty(nn.Module):
    """Push soft structure values toward 0 or 1.

    Returns the mean of ``s * (1 - s)``, which peaks at 0.25 for s = 0.5.
    """

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        undecided = structure * (1 - structure)
        return undecided.mean()


class TotalVariation3D(nn.Module):
    """Anisotropic total variation of a (B, D, H, W) field.

    Sums, over the three spatial axes, the mean absolute difference between
    neighboring voxels along that axis.
    """

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        total = 0
        for axis in (1, 2, 3):
            total = total + torch.diff(structure, dim=axis).abs().mean()
        return total


print('All losses defined')

## 5. Metrics

In [None]:
def compute_legality(state, config):
    """Share of structure mass on legal voxels (1.0 means fully legal)."""
    field = LocalLegalityLoss(config).compute_legality_field(state)
    struct = state[:, config['ch_structure']]
    return (struct * field).sum() / (struct.sum() + 1e-8)

def compute_coverage(struct, corridor):
    """Share of corridor mass covered by structure."""
    covered = struct * corridor
    return covered.sum() / (corridor.sum() + 1e-8)

def compute_spill(struct, corridor, legality):
    """Share of structure mass that is legal but lies outside the corridor."""
    spilled = struct * (1 - corridor) * legality
    return spilled.sum() / (struct.sum() + 1e-8)

def compute_thickness(state, config):
    """1 - ThicknessLoss, as a Python float (1.0 means no over-thick core)."""
    struct = state[:, config['ch_structure']]
    return 1.0 - ThicknessLoss(config.get('max_thickness', 2))(struct).item()

def compute_fill(state, config):
    """Structure mass divided by the number of non-building voxels."""
    struct = state[:, config['ch_structure']]
    free = 1.0 - state[:, config['ch_existing']]
    return struct.sum() / (free.sum() + 1e-8)

def compute_loadpath(state, config):
    """1 - LoadPathLoss: share of elevated structure connected to a support."""
    return 1.0 - LoadPathLoss(config)(state)

def compute_access_reach(state, config):
    """1 - AccessConnectivityLoss: share of access voxels reached by the street flood fill."""
    return 1.0 - AccessConnectivityLoss(config)(state)

def compute_facade(state, config):
    """Share of structure mass lying in the one-voxel shell around existing buildings."""
    struct = state[:, config['ch_structure']]
    buildings = state[:, config['ch_existing']]
    grown = F.max_pool3d(buildings.unsqueeze(1), 3, 1, 1).squeeze(1)
    shell = torch.clamp(grown - buildings, 0, 1)
    return (struct * shell).sum() / (struct.sum() + 1e-8)

print('Metrics defined')

## 6. Trainer v6.5 (2-Phase Curriculum)

In [None]:
class TrainerV65:
    """
    v6.5 Trainer with 2-PHASE CURRICULUM LEARNING.

    Phase boundaries and spill weights come from the config keys
    ``growth_phase_end``, ``sculpt_phase_end``, ``spill_weight_min``,
    ``spill_weight_growth`` and ``spill_weight_max``. With the v6.5
    values described at the top of this notebook:

    Phase 1 (Growth): epochs 0-600
      - Spill weight: 0 -> 10 (gentle, allow exploration)
      - Spill loss called with use_sqrt=False
      - Focus: Learn HOW to grow structure

    Phase 2 (Sculpting): epochs 600-1500
      - Spill weight: 10 -> 50, then held at 50
      - Spill loss called with use_sqrt=True
      - Focus: Guide growth INTO corridor
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model = model
        self.config = config
        self.device = device

        # One instance of each loss module, reused every epoch.
        self.legality_loss = LocalLegalityLoss(config)
        self.coverage_loss = CorridorCoverageLoss(config)
        self.spill_loss = ProgressiveSpillLoss(config)
        self.ground_loss = GroundOpennessLoss(config)
        self.thickness_loss = ThicknessLoss(config.get('max_thickness', 2))
        self.sparsity_loss = SparsityLoss()
        self.facade_loss = FacadeContactLoss(config)
        self.access_loss = AccessConnectivityLoss(config)
        self.loadpath_loss = LoadPathLoss(config)
        self.cantilever_loss = CantileverLoss()
        self.density_loss = DensityPenalty()
        self.tv_loss = TotalVariation3D()

        # Base weights (spill weight is dynamic, see get_spill_weight)
        self.weights = {
            'legality': 30.0,
            'coverage': 45.0,
            'ground': 10.0,
            'thickness': 15.0,
            'sparsity': 30.0,
            'facade': 8.0,
            'access': 15.0,
            'loadpath': 8.0,
            'cantilever': 5.0,
            'density': 3.0,
            'tv': 1.0,
        }

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        # One metrics dict is appended per train_epoch call.
        self.history = []

    def get_spill_weight(self, epoch: int) -> float:
        """
        Return the spill-loss weight for ``epoch``.

        Growth phase (0 <= epoch < growth_phase_end):
            linear ramp spill_weight_min -> spill_weight_growth
        Sculpting phase (epoch >= growth_phase_end):
            linear ramp spill_weight_growth -> spill_weight_max over
            [growth_phase_end, sculpt_phase_end], then held at spill_weight_max.

        If sculpt_phase_end <= growth_phase_end the sculpting ramp has no
        length, so spill_weight_max is returned for every sculpting epoch
        instead of dividing by zero or ramping backwards.
        """
        cfg = self.config
        g_end = cfg['growth_phase_end']
        s_end = cfg['sculpt_phase_end']
        w_min = cfg['spill_weight_min']
        w_growth = cfg['spill_weight_growth']
        w_max = cfg['spill_weight_max']

        if epoch < g_end:
            # Growth phase. For epoch >= 0 this branch implies g_end > 0,
            # so the division is safe.
            progress = epoch / g_end
            return w_min + (w_growth - w_min) * progress

        span = s_end - g_end
        if span <= 0:
            # Degenerate schedule: no room for a ramp, go straight to the cap.
            return float(w_max)
        progress = min(1.0, (epoch - g_end) / span)
        return w_growth + (w_max - w_growth) * progress

    def is_sculpting_phase(self, epoch: int) -> bool:
        """Returns True if we're in the sculpting phase."""
        return epoch >= self.config['growth_phase_end']

    def train_epoch(self, epoch: int) -> dict:
        """
        Run one optimisation step on a freshly generated batch of scenes.

        Samples a batch from the scene generator and rolls the model out for
        a random number of steps in [steps_min, steps_max] (inclusive). It
        then combines all losses: spill with its scheduled weight, the rest
        with ``self.weights``. Finally it takes one gradient-clipped Adam
        step and records metrics.

        Returns:
            Dict of metrics for this epoch (also appended to ``self.history``).
        """
        self.model.train()
        cfg = self.config

        seeds = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)

        # Corridor/SDF targets are derived from the seeds and carry no gradient.
        with torch.no_grad():
            corridor, sdf = compute_corridor_and_sdf(
                seeds, cfg,
                corridor_width=cfg.get('corridor_width', 3),
                vertical_envelope=cfg.get('vertical_envelope', 1)
            )

        steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        final = self.model(seeds, steps=steps)

        struct = final[:, cfg['ch_structure']]
        existing = final[:, cfg['ch_existing']]
        available = 1.0 - existing
        legality_field = self.legality_loss.compute_legality_field(final)

        # Get current spill weight and phase
        spill_weight = self.get_spill_weight(epoch)
        use_sqrt = self.is_sculpting_phase(epoch)

        L = {
            'legality': self.legality_loss(final),
            'coverage': self.coverage_loss(struct, corridor),
            'spill': self.spill_loss(struct, sdf, legality_field, use_sqrt=use_sqrt),
            'ground': self.ground_loss(struct, corridor, legality_field),
            'thickness': self.thickness_loss(struct),
            'sparsity': self.sparsity_loss(struct, available),
            'facade': self.facade_loss(final),
            'access': self.access_loss(final),
            'loadpath': self.loadpath_loss(final),
            'cantilever': self.cantilever_loss(struct),
            'density': self.density_loss(struct),
            'tv': self.tv_loss(struct),
        }

        # Apply weights (spill weight is dynamic)
        total = spill_weight * L['spill']
        for k, v in L.items():
            if k != 'spill':
                total += self.weights[k] * v

        self.optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        with torch.no_grad():
            # Mean fraction of grid volume inside the corridor, per scene.
            corridor_vol = corridor.sum().item() / (cfg['grid_size']**3 * cfg['batch_size'])
            phase = 'SCULPT' if use_sqrt else 'GROWTH'
            metrics = {
                'epoch': epoch,
                'total_loss': total.item(),
                'coverage': compute_coverage(struct, corridor).item(),
                'spill': compute_spill(struct, corridor, legality_field).item(),
                'spill_weight': spill_weight,
                'spill_loss': L['spill'].item(),
                'thickness': compute_thickness(final, cfg),
                'fill_ratio': compute_fill(final, cfg).item(),
                'legality': compute_legality(final, cfg).item(),
                'loadpath': compute_loadpath(final, cfg).item(),
                'corridor_vol': corridor_vol,
                'phase': phase,
            }
        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 50) -> dict:
        """
        Average evaluation metrics over ``n_samples`` single generated scenes.

        Each scene is grown with ``self.model.grow(scene, steps=50)`` under
        ``torch.no_grad()``.

        Returns:
            Dict mapping ``avg_<metric>`` to the mean over samples.

        Raises:
            ValueError: if ``n_samples < 1`` (there would be nothing to average).
        """
        if n_samples < 1:
            raise ValueError(f'n_samples must be >= 1, got {n_samples}')

        self.model.eval()
        cfg = self.config
        results = []

        with torch.no_grad():
            for _ in range(n_samples):
                scene, _ = self.scene_gen.generate(cfg['difficulty'], self.device)
                corridor, _ = compute_corridor_and_sdf(
                    scene, cfg,
                    corridor_width=cfg.get('corridor_width', 3),
                    vertical_envelope=cfg.get('vertical_envelope', 1)
                )
                grown = self.model.grow(scene, steps=50)
                struct = grown[:, cfg['ch_structure']]
                leg_field = self.legality_loss.compute_legality_field(grown)

                results.append({
                    'legality': compute_legality(grown, cfg).item(),
                    'coverage': compute_coverage(struct, corridor).item(),
                    'spill': compute_spill(struct, corridor, leg_field).item(),
                    'thickness': compute_thickness(grown, cfg),
                    'facade': compute_facade(grown, cfg).item(),
                    'loadpath': compute_loadpath(grown, cfg).item(),
                    'access': compute_access_reach(grown, cfg).item(),
                    'fill_ratio': compute_fill(grown, cfg).item(),
                })

        return {f'avg_{k}': np.mean([r[k] for r in results]) for k in results[0].keys()}

    def save(self, path: str):
        """Save model weights and training history to ``path``.

        Missing parent directories are created first, so a fresh checkpoint
        folder does not make ``torch.save`` fail mid-training.
        """
        import os

        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({'model': self.model.state_dict(), 'history': self.history}, path)
        print(f'Saved: {path}')


# Summary banner. NOTE: the epoch ranges and weights below are hard-coded text,
# not read from CONFIG; the actual schedule is TrainerV65.get_spill_weight().
print('TrainerV65 defined (2-Phase Curriculum)')
print('  Growth phase: epochs 0-600 (spill 0->10, linear distance)')
print('  Sculpting phase: epochs 600-1500 (spill 10->50, sqrt distance)')

## 7. Visualization

In [None]:
def visualize(model, scene_gen, config, device, title=''):
    """
    Grow one generated scene with ``model`` and show four panels.

    Panels:
      1. 3D voxels (existing grey, access green, legal spill orange,
         legal in-corridor structure blue).
      2. Ground plan: bottom slice along array axis 0 plus max-projection
         of structure along axis 0.
      3. Elevation: max-projection along array axis 1, drawn with array
         axis 0 vertical (the same axis panel 1 maps to the plot's z).
      4. Text box with coverage / spill / thickness / fill metrics.

    Args:
        model: object exposing ``eval()`` and ``grow(scene, steps=...)``.
        scene_gen: object exposing ``generate(difficulty, device)``.
        config: notebook config dict (channel indices, corridor params).
        device: device passed to the scene generator.
        title: title for the 3D panel.

    Returns:
        The grown state tensor.
    """
    model.eval()
    scene, _ = scene_gen.generate(config['difficulty'], device)

    with torch.no_grad():
        corridor, _ = compute_corridor_and_sdf(
            scene, config,
            corridor_width=config.get('corridor_width', 3),
            vertical_envelope=config.get('vertical_envelope', 1)
        )
        grown = model.grow(scene, steps=50)
        # Computed once and reused by both the voxel masks and the metrics.
        leg_f = LocalLegalityLoss(config).compute_legality_field(grown)

    cfg = config
    s = grown[0].cpu().numpy()
    G = s.shape[1]

    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    corr = corridor[0].cpu().numpy() > 0.5
    leg = leg_f[0].cpu().numpy()

    # Thresholded legal structure, split by corridor membership.
    in_corr = structure & corr & (leg >= 0.5)
    outside = structure & ~corr & (leg >= 0.5)

    fig = plt.figure(figsize=(16, 4))

    # Panel 1: transpose(1, 2, 0) puts array axis 0 on the plot's z (vertical) axis.
    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any(): ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if access.any(): ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if outside.any(): ax1.voxels(outside.transpose(1,2,0), facecolors='orange', alpha=0.6)
    if in_corr.any(): ax1.voxels(in_corr.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_title(title)

    # Panel 2: plan view over array axes (1, 2).
    ax2 = fig.add_subplot(142)
    plan = np.zeros((G,G,3))
    plan[existing[0]] = [0.5,0.5,0.5]
    plan[corr[0] & ~existing[0]] = [0.8,0.9,1.0]
    plan[in_corr.max(axis=0)] = [0.2,0.4,0.8]
    plan[outside.max(axis=0)] = [1.0,0.6,0.2]
    plan[access.max(axis=0)] = [0.2,0.8,0.2]
    ax2.imshow(plan.transpose(1,0,2), origin='lower')
    ax2.set_title('Ground')

    # Panel 3: projection along axis 1 gives an (axis 0, axis 2) image. It is
    # drawn untransposed so axis 0 (height in panel 1) runs vertically;
    # transposing it like the plan would lay the elevation on its side.
    ax3 = fig.add_subplot(143)
    elev = np.zeros((G,G,3))
    elev[existing.max(axis=1)] = [0.5,0.5,0.5]
    elev[in_corr.max(axis=1)] = [0.2,0.4,0.8]
    elev[outside.max(axis=1)] = [1.0,0.6,0.2]
    elev[access.max(axis=1)] = [0.2,0.8,0.2]
    ax3.imshow(elev, origin='lower')
    ax3.set_title('Elevation')

    # Panel 4: scalar metrics for this single scene.
    ax4 = fig.add_subplot(144)
    ax4.axis('off')
    struct = grown[:, config['ch_structure']]

    cov = compute_coverage(struct, corridor).item()
    spl = compute_spill(struct, corridor, leg_f).item()
    thk = compute_thickness(grown, config)
    fill = compute_fill(grown, config).item()
    corr_vol = corridor.sum().item() / (G**3)

    txt = f"""METRICS (v6.5)
Coverage: {cov*100:.1f}% {'OK' if cov>0.7 else 'LOW'}
Spill: {spl*100:.1f}% {'OK' if spl<0.2 else 'HIGH'}
Thickness: {thk*100:.1f}% {'OK' if thk>0.9 else 'LOW'}
Fill: {fill*100:.1f}% {'OK' if 0.05<fill<0.15 else 'BAD'}
---
Corridor Vol: {corr_vol*100:.1f}%
"""
    ax4.text(0.1, 0.9, txt, transform=ax4.transAxes, fontsize=10, va='top', family='monospace')

    plt.tight_layout()
    plt.show()
    return grown


def plot_curves(history, phase_change=600):
    """
    Plot training curves from a trainer history in a 2x4 grid.

    Args:
        history: list of per-epoch metric dicts with keys ``epoch``,
            ``total_loss``, ``coverage``, ``spill``, ``spill_weight``,
            ``thickness``, ``fill_ratio``, ``spill_loss`` and ``legality``
            (as produced by ``TrainerV65.train_epoch``).
        phase_change: epoch at which the growth -> sculpting switch is marked
            on the spill-weight panel. Defaults to 600 (the v6.5 schedule);
            pass ``CONFIG['growth_phase_end']`` to keep it in sync with the
            config actually used.
    """
    def series(key):
        # One value per recorded epoch, in history order.
        return [h[key] for h in history]

    epochs = series('epoch')
    fig, axes = plt.subplots(2, 4, figsize=(20, 8))

    # Row 1
    axes[0,0].plot(epochs, series('total_loss'))
    axes[0,0].set_title('Total Loss')
    axes[0,0].set_yscale('log')

    axes[0,1].plot(epochs, series('coverage'), 'b-')
    axes[0,1].axhline(0.70, color='k', ls='--')
    axes[0,1].set_title('Coverage (>70%)')
    axes[0,1].set_ylim(0, 1.1)

    axes[0,2].plot(epochs, series('spill'), 'orange')
    axes[0,2].axhline(0.20, color='k', ls='--')
    axes[0,2].set_title('Spill (<20%)')
    axes[0,2].set_ylim(0, 1.1)

    axes[0,3].plot(epochs, series('spill_weight'), 'r-')
    axes[0,3].axvline(phase_change, color='g', ls='--', label='Phase change')
    axes[0,3].set_title('Spill Weight (progressive)')
    axes[0,3].legend()

    # Row 2
    axes[1,0].plot(epochs, series('thickness'), 'purple')
    axes[1,0].axhline(0.90, color='k', ls='--')
    axes[1,0].set_title('Thickness (>90%)')
    axes[1,0].set_ylim(0, 1.1)

    axes[1,1].plot(epochs, series('fill_ratio'), 'g-')
    axes[1,1].axhline(0.15, color='r', ls='--')
    axes[1,1].axhline(0.05, color='g', ls='--')
    axes[1,1].set_title('Fill (5-15%)')
    axes[1,1].set_ylim(0, 0.3)

    axes[1,2].plot(epochs, series('spill_loss'), 'r-')
    axes[1,2].set_title('Spill Loss (raw)')

    axes[1,3].plot(epochs, series('legality'), 'k-')
    axes[1,3].axhline(0.99, color='g', ls='--')
    axes[1,3].set_title('Legality (>99%)')
    axes[1,3].set_ylim(0.9, 1.01)

    plt.tight_layout()
    plt.show()

# Progress marker printed once the plotting helpers in this cell have executed.
print('Visualization defined')

## 8. Training

In [None]:
# Build the model and trainer, then echo the curriculum schedule read from CONFIG.
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = TrainerV65(model, CONFIG, device)

# Total number of parameter elements, printed with thousands separators.
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
print()
print('v6.5 2-PHASE CURRICULUM:')
print(f"  Total epochs: {CONFIG['epochs']}")
print(f"  Growth phase: 0-{CONFIG['growth_phase_end']} (spill {CONFIG['spill_weight_min']}->{CONFIG['spill_weight_growth']})")
print(f"  Sculpting phase: {CONFIG['growth_phase_end']}-{CONFIG['sculpt_phase_end']} (spill {CONFIG['spill_weight_growth']}->{CONFIG['spill_weight_max']}, sqrt)")

In [None]:
# Baseline render of the untrained model, for comparison with later snapshots.
visualize(model, trainer.scene_gen, CONFIG, device, 'Before Training')

In [None]:
# Main training loop: one TrainerV65.train_epoch call per epoch, with periodic
# logging, visualisation and checkpointing controlled by CONFIG.
print('='*70)
# Epoch count comes from CONFIG so the banner matches the loop below.
print(f"v6.5 TRAINING - 2-PHASE CURRICULUM ({CONFIG['epochs']} epochs)")
print('='*70)
print(f"Phase 1 (GROWTH): epochs 0-{CONFIG['growth_phase_end']}")
print(f"Phase 2 (SCULPT): epochs {CONFIG['growth_phase_end']}-{CONFIG['sculpt_phase_end']}")
print('='*70)

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    m = trainer.train_epoch(epoch)

    if epoch % CONFIG['log_every'] == 0:
        # tqdm.write keeps log lines from breaking the progress bar.
        tqdm.write(
            f"E{epoch:4d} [{m['phase']:6s}] | Loss:{m['total_loss']:6.1f} | "
            f"Cov:{m['coverage']*100:4.0f}% | "
            f"Spl:{m['spill']*100:4.0f}% | "
            f"Thk:{m['thickness']*100:4.0f}% | "
            f"Fill:{m['fill_ratio']*100:5.1f}% | "
            f"SpillW:{m['spill_weight']:.1f}"
        )

    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch} ({m["phase"]})')

    if epoch > 0 and epoch % CONFIG['save_every'] == 0:
        trainer.save(f"{PROJECT_ROOT}/step_b/checkpoints/v65_epoch_{epoch}.pth")

print('='*70)
print('Training complete')

## 9. Evaluation

In [None]:
# Plot the per-epoch metrics that TrainerV65.train_epoch appended to trainer.history.
plot_curves(trainer.history)

In [None]:
# Evaluate on 50 fresh scenes and print PASS/FAIL against the v6.5 targets.
print('Evaluating...')
eval_results = trainer.evaluate(50)

print('\n' + '='*60)
print('STEP B v6.5 EVALUATION (2-Phase Curriculum)')
print('='*60)

# op is '>' / '<' for one-sided thresholds, or 'range' with a (low, high)
# tuple for a two-sided band.
targets = {
    'legality': (0.99, '>'),
    'coverage': (0.70, '>'),
    'spill': (0.20, '<'),
    'thickness': (0.90, '>'),
    'facade': (0.15, '<'),
    'loadpath': (0.95, '>'),
    'access': (0.90, '>'),
    # Same 5-15% band used by visualize() and plot_curves(). An upper bound
    # alone would let an empty structure (fill 0, the "no growth" failure)
    # pass.
    'fill_ratio': ((0.05, 0.15), 'range'),
}

for metric, (target, op) in targets.items():
    val = eval_results[f'avg_{metric}']
    if op == '>':
        ok = val > target
        bound = f'>{target*100:.0f}%'
    elif op == '<':
        ok = val < target
        bound = f'<{target*100:.0f}%'
    else:
        lo, hi = target
        ok = lo < val < hi
        bound = f'{lo*100:.0f}-{hi*100:.0f}%'
    print(f"{metric:12s}: {val*100:5.1f}% {'PASS' if ok else 'FAIL'} ({bound})")

print('='*60)

In [None]:
# Render three independently generated scenes with the trained model.
for i in range(3):
    visualize(model, trainer.scene_gen, CONFIG, device, f'Final {i+1}')

In [None]:
# Persist the final model state_dict together with the full training history.
trainer.save(f"{PROJECT_ROOT}/step_b/checkpoints/v65_final.pth")
print('Done!')