# Step D - NB02: Training

**Constraint-Based Architectural NCA - Step D**

**Version:** 1.0
**Date:** December 2025
**Purpose:** Train the NCA to grow a structure that connects all three access points while staying within the volume budget

---

## Constraints Summary

| ID | Constraint | Type |
|----|------------|------|
| C1 | No structure inside buildings | Hard |
| C2A | Ground zone (0-5m) mostly empty | Soft |
| C2B | No-go zone (12m center, 0-5m) forbidden | Hard |
| C3 | Height ceiling (tallest + 2m) | Hard |
| C4 | Structure volume 20-40% of the reference building's volume | Soft |
| C5 | All 3 access points connected | Critical |
| C6 | Structural support (load path) | Soft |

---

## 1. Setup

In [None]:
# Mount Google Drive at /content/drive; PROJECT_ROOT below lives on that mount.
from google.colab import drive
drive.mount('/content/drive')

import os
# Root folder for this step; the config JSON is read from here and
# checkpoints are written beneath it.
PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA-StepD'
# Create the checkpoint folder up front (no error if it already exists).
os.makedirs(f'{PROJECT_ROOT}/checkpoints', exist_ok=True)
print(f'Project root: {PROJECT_ROOT}')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, List
import json
from tqdm.notebook import tqdm

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    The CUDA generators are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(42)

In [None]:
# Read the Step D configuration stored in the project folder
with open(f'{PROJECT_ROOT}/config_step_d.json', 'r') as f:
    CONFIG = json.load(f)

# Overlay the training settings used by this notebook
# (these replace any same-named keys that came from the JSON file)
CONFIG.update(
    epochs=500,
    steps_min=40,
    steps_max=60,
    difficulty='easy',
    log_every=20,
    viz_every=100,
    save_every=100,
)

print('Config loaded')
print(f"Grid: {CONFIG['grid_size']}")
print(f"Epochs: {CONFIG['epochs']}")

## 2. Copy Foundation Components

*(From NB01)*

In [None]:
# ============================================================
# PERCEPTION MODULE
# ============================================================

class Perceive3D(nn.Module):
    """Fixed (non-learned) 3D perception: identity plus three Sobel filters.

    Four 3x3x3 kernels (identity, Sobel-x, Sobel-y, Sobel-z) are stored as a
    buffer and applied to every input channel independently.

    Args:
        n_channels: Stored on the module; ``forward`` takes the channel count
            from the input tensor itself.
    """

    def __init__(self, n_channels: int = 8):
        super().__init__()
        self.n_channels = n_channels
        sobel_x = self._create_sobel_kernel('x')
        sobel_y = self._create_sobel_kernel('y')
        sobel_z = self._create_sobel_kernel('z')
        identity = self._create_identity_kernel()
        # Buffer (not a parameter): it follows .to(device) but is never trained.
        kernels = torch.stack([identity, sobel_x, sobel_y, sobel_z], dim=0)
        self.register_buffer('kernels', kernels)

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        """Return a 3x3x3 Sobel kernel, scaled by 1/16.

        The derivative is taken along the last kernel axis for ``'x'``, the
        middle axis for ``'y'`` and the first axis for ``'z'``; the other two
        axes are smoothed with ``[1, 2, 1]``.

        Raises:
            ValueError: If ``direction`` is not ``'x'``, ``'y'`` or ``'z'``.
        """
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])
        if direction == 'x':
            kernel = torch.einsum('i,j,k->ijk', smoothing, smoothing, derivative)
        elif direction == 'y':
            kernel = torch.einsum('i,j,k->ijk', smoothing, derivative, smoothing)
        elif direction == 'z':
            kernel = torch.einsum('i,j,k->ijk', derivative, smoothing, smoothing)
        else:
            # Previously fell through to an unbound local variable.
            raise ValueError(
                f"Unknown Sobel direction {direction!r}; expected 'x', 'y' or 'z'"
            )
        return kernel / 16.0

    def _create_identity_kernel(self) -> torch.Tensor:
        """Return a 3x3x3 kernel that passes the centre voxel through."""
        kernel = torch.zeros(3, 3, 3)
        kernel[1, 1, 1] = 1.0
        return kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the four kernels to each channel of a 5-D tensor.

        Args:
            x: Tensor of shape ``(B, C, D, H, W)``.

        Returns:
            Tensor of shape ``(B, 4*C, D, H, W)`` laid out as
            ``[identity(C), sobel_x(C), sobel_y(C), sobel_z(C)]``.
        """
        B, C, D, H, W = x.shape
        # Replicate padding keeps the spatial size and avoids artificial
        # gradients at the grid border.
        x_padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        outputs = []
        for k in range(4):
            # Same kernel for every channel; groups=C keeps channels independent.
            kernel = self.kernels[k:k+1].unsqueeze(0).expand(C, 1, 3, 3, 3)
            out = F.conv3d(x_padded, kernel, padding=0, groups=C)
            outputs.append(out)
        return torch.cat(outputs, dim=1)

In [None]:
# ============================================================
# NCA MODEL
# ============================================================

class StepDNCA(nn.Module):
    """Neural cellular automaton that updates only the 'grown' channels.

    Channels ``[0, n_frozen)`` of the state are passed through untouched;
    the remaining channels receive a learned residual update each step.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        channels = config['n_channels']
        width = config['hidden_dim']
        # Perception emits four responses per input channel.
        perceived = channels * 4
        n_out = config['n_grown']

        self.perceive = Perceive3D(channels)
        layers = [
            nn.Conv3d(perceived, width, 1),
            nn.ReLU(),
            nn.Conv3d(width, width, 1),
            nn.ReLU(),
            nn.Conv3d(width, n_out, 1),
        ]
        self.update_net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the 1x1x1 convolutions and bias the first output."""
        gain = self.config['xavier_gain']
        convs = [layer for layer in self.update_net if isinstance(layer, nn.Conv3d)]
        for conv in convs:
            nn.init.xavier_uniform_(conv.weight, gain=gain)
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)
        # Output 0 is the channel masked against existing buildings in _step.
        with torch.no_grad():
            self.update_net[-1].bias[0] = self.config['structure_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """Run ``steps`` update steps and return the resulting state."""
        current = state
        for _ in range(steps):
            current = self._step(current)
        return current

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        """Apply one residual update to the grown channels."""
        cfg = self.config
        batch, _, depth, height, width = state.shape
        delta = self.update_net(self.perceive(state))

        # Stochastic per-voxel update mask, applied in training mode only.
        if self.training:
            fire = torch.rand(batch, 1, depth, height, width, device=state.device) < cfg['fire_rate']
            delta = delta * fire.float()

        split = cfg['n_frozen']
        frozen = state[:, :split]
        grown = (state[:, split:] + cfg['update_scale'] * delta).clamp(0.0, 1.0)

        # The first grown channel is zeroed wherever an existing building sits.
        ex = cfg['ch_existing']
        free = 1.0 - state[:, ex:ex + 1]
        structure = grown[:, 0:1] * free

        return torch.cat([frozen, structure, grown[:, 1:]], dim=1)

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Switch to eval mode and roll out ``steps`` steps without gradients."""
        self.eval()
        with torch.no_grad():
            return self.forward(seed, steps)

In [None]:
# ============================================================
# SCENE GENERATOR
# ============================================================

class StepDSceneGenerator:
    """Builds seed states: two buildings, three access points and zone masks.

    Spatial indexing of the state is ``[batch, channel, z, y, x]``.
    """

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']

    def generate(self, difficulty: str = 'easy', device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        """Create one scene and return ``(state, metadata)``."""
        cfg = self.config
        side = self.G
        state = torch.zeros(1, cfg['n_channels'], side, side, side, device=device)
        # The ground plane is the z = 0 layer.
        state[:, cfg['ch_ground'], 0, :, :] = 1.0

        params = self._get_params(difficulty)
        buildings = self._place_buildings(state, params)
        access_points = self._place_access_points(state, params, buildings)
        zones = self._create_zone_masks(state, params, buildings)

        meta = {
            'difficulty': difficulty,
            'buildings': buildings,
            'access_points': access_points,
            'zones': zones,
            'params': params,
        }
        return state, meta

    def _get_params(self, difficulty: str) -> dict:
        """Return scene parameters; any value other than 'easy'/'medium' uses the hardest ranges."""
        draw = random.randint

        if difficulty == 'easy':
            return dict(
                height_A=20, height_B=20, gap=18,
                building_width=12, building_depth=20,
                ap1_height=10, ap2_height=10,
            )

        if difficulty == 'medium':
            # Draws happen in key order so seeded runs stay reproducible.
            height_a = draw(20, 25)
            height_b = draw(20, 25)
            gap = draw(18, 22)
            width = draw(10, 14)
            depth = draw(18, 24)
            ap1 = draw(8, 15)
            ap2 = draw(8, 15)
        else:
            height_a = draw(20, 30)
            height_b = draw(20, 30)
            gap = draw(18, 24)
            width = draw(8, 14)
            depth = draw(16, 26)
            # Facade access points stay at least 4 below the roof.
            ap1 = draw(6, height_a - 4)
            ap2 = draw(6, height_b - 4)

        return {
            'height_A': height_a, 'height_B': height_b,
            'gap': gap,
            'building_width': width,
            'building_depth': depth,
            'ap1_height': ap1,
            'ap2_height': ap2,
        }

    def _place_buildings(self, state: torch.Tensor, params: dict) -> dict:
        """Write buildings A and B either side of the centred gap; return their extents."""
        ch = self.config['ch_existing']
        gap = params['gap']
        width = params['building_width']
        depth = params['building_depth']
        height_a = params['height_A']
        height_b = params['height_B']

        centre = self.G // 2
        half_gap = gap // 2
        y0 = (self.G - depth) // 2
        y1 = y0 + depth

        # A ends where the gap begins; B starts where the gap ends.
        a_x1 = centre - half_gap
        a_x0 = a_x1 - width
        b_x0 = centre + half_gap
        b_x1 = b_x0 + width

        state[:, ch, :height_a, y0:y1, a_x0:a_x1] = 1.0
        state[:, ch, :height_b, y0:y1, b_x0:b_x1] = 1.0

        info_a = {'x': (a_x0, a_x1), 'y': (y0, y1), 'z': (0, height_a),
                  'height': height_a, 'volume': height_a * width * depth, 'facade_x': a_x1}
        info_b = {'x': (b_x0, b_x1), 'y': (y0, y1), 'z': (0, height_b),
                  'height': height_b, 'volume': height_b * width * depth, 'facade_x': b_x0}
        return {
            'A': info_a,
            'B': info_b,
            'gap': {'x': (a_x1, b_x0), 'width': gap, 'center': centre},
            'max_height': max(height_a, height_b),
        }

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: dict) -> list:
        """Paint the three access points and return their descriptors."""
        ch = self.config['ch_access']
        half_ped = self.config['pedestrian_zone'] // 2
        facade_a = buildings['A']['facade_x']
        facade_b = buildings['B']['facade_x']
        y_lo, y_hi = buildings['A']['y']
        y_mid = (y_lo + y_hi) // 2

        # AP1: on building A's facade, extending into the gap
        z1 = params['ap1_height']
        state[:, ch, z1:z1 + 2, y_mid - 1:y_mid + 2, facade_a:facade_a + 2] = 1.0
        ap1 = {'id': 'AP1', 'type': 'building_A', 'x': facade_a, 'y': y_mid, 'z': z1}

        # AP2: mirrored on building B's facade
        x2 = facade_b - 1
        z2 = params['ap2_height']
        state[:, ch, z2:z2 + 2, y_mid - 1:y_mid + 2, x2 - 1:x2 + 1] = 1.0
        ap2 = {'id': 'AP2', 'type': 'building_B', 'x': x2, 'y': y_mid, 'z': z2}

        # AP3: at ground level, randomly beside A or beside B
        beside_a = random.random() < 0.5
        x3 = facade_a + half_ped if beside_a else facade_b - half_ped - 1
        y3 = y_mid + random.randint(-3, 3)
        z3 = 0
        state[:, ch, z3:z3 + 2, y3 - 1:y3 + 2, x3 - 1:x3 + 2] = 1.0
        ap3 = {'id': 'AP3', 'type': 'ground', 'x': x3, 'y': y3, 'z': z3}

        return [ap1, ap2, ap3]

    def _create_zone_masks(self, state: torch.Tensor, params: dict, buildings: dict) -> dict:
        """Encode the zones as levels in the zone channel and return their bounds.

        Levels written: 1.0 inside buildings, 0.75 above the ceiling, 0.25 in
        the pedestrian strips and 0.5 in the no-go strip (the last two only
        below the ground clearance).
        """
        cfg = self.config
        ch = cfg['ch_zones']
        ped = cfg['pedestrian_zone']
        clearance = cfg['ground_clearance']
        ceiling = buildings['max_height'] + cfg['ceiling_margin']

        gap_lo, gap_hi = buildings['gap']['x']
        ped_a_end = gap_lo + ped
        ped_b_start = gap_hi - ped

        zones = state[:, ch] + state[:, cfg['ch_existing']]
        zones[:, ceiling:, :, :] = 0.75

        # Later assignments win where strips overlap, so order matters.
        strips = (
            ((gap_lo, ped_a_end), 0.25),
            ((ped_b_start, gap_hi), 0.25),
            ((ped_a_end, ped_b_start), 0.5),
        )
        for (x0, x1), level in strips:
            zones[:, :clearance, :, x0:x1] = level

        state[:, ch] = zones.clamp(0, 1)

        return {
            'pedestrian_A': (gap_lo, ped_a_end),
            'pedestrian_B': (ped_b_start, gap_hi),
            'nogo': (ped_a_end, ped_b_start),
            'ground_clearance': clearance,
            'ceiling': ceiling,
        }

    def batch(self, difficulty: str, batch_size: int, device: str) -> Tuple[torch.Tensor, List[dict]]:
        """Generate ``batch_size`` independent scenes stacked along the batch axis."""
        pairs = [self.generate(difficulty, device) for _ in range(batch_size)]
        scenes = [scene for scene, _ in pairs]
        metas = [meta for _, meta in pairs]
        return torch.cat(scenes, dim=0), metas

print('Scene generator defined')

## 3. Loss Functions

In [None]:
# ============================================================
# C1: BUILDING EXCLUSION (Hard)
# ============================================================

class BuildingExclusionLoss(nn.Module):
    """Fraction of structure mass that lies inside existing buildings."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return overlap(structure, existing) / total structure."""
        ch_structure = self.config['ch_structure']
        ch_existing = self.config['ch_existing']
        grown = state[:, ch_structure]
        inside = grown * state[:, ch_existing]
        # The epsilon keeps the ratio finite when nothing has grown yet.
        return inside.sum() / (grown.sum() + 1e-8)

print('BuildingExclusionLoss defined')

In [None]:
# ============================================================
# C2A: GROUND ZONE CLEARANCE (Soft)
# ============================================================

class GroundClearanceLoss(nn.Module):
    """Penalise structure in the ground zone away from ground access points."""

    def __init__(self, config: dict, access_radius: int = 4):
        super().__init__()
        self.config = config
        self.access_radius = access_radius

    def forward(self, state: torch.Tensor, access_points: List[dict]) -> torch.Tensor:
        """Return the share of ground-zone structure that sits in a forbidden spot.

        Structure below ``ground_clearance`` is permitted inside a square
        (x, y) window of half-width ``access_radius`` around every access
        point whose type is ``'ground'``; the same window applies to the
        whole batch.
        """
        cfg = self.config
        side = cfg['grid_size']
        clearance = cfg['ground_clearance']
        radius = self.access_radius

        low_structure = state[:, cfg['ch_structure']][:, :clearance, :, :]
        low_existing = state[:, cfg['ch_existing']][:, :clearance]

        # Start with everything forbidden, then open a window per ground access.
        forbidden = torch.ones_like(low_structure)
        for ap in access_points:
            if ap['type'] != 'ground':
                continue
            x_lo, x_hi = max(0, ap['x'] - radius), min(side, ap['x'] + radius + 1)
            y_lo, y_hi = max(0, ap['y'] - radius), min(side, ap['y'] + radius + 1)
            forbidden[:, :, y_lo:y_hi, x_lo:x_hi] = 0.0

        # Voxels inside existing buildings are left to the exclusion loss.
        illegal = low_structure * forbidden * (1 - low_existing)

        return illegal.sum() / (low_structure.sum() + 1e-8)

print('GroundClearanceLoss defined')

In [None]:
# ============================================================
# C2B: NO-GO ZONE (Hard)
# ============================================================

class NoGoZoneLoss(nn.Module):
    """Fraction of structure mass that lies inside the no-go strip."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, state: torch.Tensor, zone_info: dict) -> torch.Tensor:
        """Return structure inside the no-go strip divided by total structure.

        The strip spans x in ``zone_info['nogo']`` below
        ``zone_info['ground_clearance']`` and covers every y.
        """
        grown = state[:, self.config['ch_structure']]
        z_top = zone_info['ground_clearance']
        x_lo, x_hi = zone_info['nogo'][0], zone_info['nogo'][1]

        inside = grown[:, :z_top, :, x_lo:x_hi].sum()
        total = grown.sum() + 1e-8
        return inside / total

print('NoGoZoneLoss defined')

In [None]:
# ============================================================
# C3: HEIGHT CEILING (Hard)
# ============================================================

class CeilingLoss(nn.Module):
    """Fraction of structure mass at or above the ceiling height."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, state: torch.Tensor, zone_info: dict) -> torch.Tensor:
        """Return structure at z >= ``zone_info['ceiling']`` over total structure."""
        grown = state[:, self.config['ch_structure']]
        z_limit = zone_info['ceiling']

        too_high = grown[:, z_limit:, :, :].sum()
        total = grown.sum() + 1e-8
        return too_high / total

print('CeilingLoss defined')

In [None]:
# ============================================================
# C4: VOLUME BUDGET (Soft)
# ============================================================

class VolumeBudgetLoss(nn.Module):
    """Structure volume must be 20-40% of reference building.

    The actual bounds come from ``config['min_volume_ratio']`` and
    ``config['max_volume_ratio']``.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.min_ratio = config['min_volume_ratio']
        self.max_ratio = config['max_volume_ratio']

    def forward(self, state: torch.Tensor, building_volume: float) -> torch.Tensor:
        """Return the mean per-scene volume-budget penalty.

        Args:
            state: Batched state; the structure channel is read from it.
            building_volume: Reference volume each scene is measured against.

        Returns:
            Scalar tensor: linear penalty below ``min_ratio`` plus
            ``10 * excess**2`` above ``max_ratio``, averaged over the batch.
        """
        cfg = self.config
        structure = state[:, cfg['ch_structure']]

        # Each scene is measured against its own budget. Summing over the
        # whole batch (the previous behaviour) inflated the ratio by the
        # batch size. For a batch of one the result is unchanged.
        struct_volume = structure.flatten(start_dim=1).sum(dim=1)
        ratio = struct_volume / building_volume

        # Under penalty
        under = F.relu(self.min_ratio - ratio)

        # Over penalty (squared for sharper cutoff)
        over = F.relu(ratio - self.max_ratio)
        over_penalty = over ** 2 * 10

        return (under + over_penalty).mean()

print('VolumeBudgetLoss defined')

In [None]:
# ============================================================
# C5: ACCESS CONNECTIVITY (Critical)
# ============================================================

class AccessConnectivityLoss(nn.Module):
    """All 3 access points must be connected via structure.

    A soft flood fill starts at the first access point and spreads through
    structure and existing buildings; the loss is the share of the other
    access voxels it fails to reach.

    Args:
        config: Provides ``ch_structure``, ``ch_access`` and ``ch_existing``.
        iterations: Maximum number of flood-fill dilation steps.
        seed_radius: Half-width of the window around the first access
            point's (x, y, z) used to pick its voxels out of the access channel.
    """

    def __init__(self, config: dict, iterations: int = 64, seed_radius: int = 2):
        super().__init__()
        self.config = config
        self.iterations = iterations
        self.seed_radius = seed_radius

    def _seed_mask(self, access: torch.Tensor, ap: dict) -> torch.Tensor:
        """Return the access voxels that fall in the window around ``ap``."""
        r = self.seed_radius
        z0, y0, x0 = (max(0, ap[k] - r) for k in ('z', 'y', 'x'))
        window = torch.zeros_like(access)
        window[:, z0:ap['z'] + r + 1, y0:ap['y'] + r + 1, x0:ap['x'] + r + 1] = 1.0
        return access * window

    def forward(self, state: torch.Tensor, access_points: List[dict]) -> torch.Tensor:
        """Return ``1 - reached`` over the non-seed access voxels.

        Args:
            state: Batched state tensor.
            access_points: Descriptors with integer ``x``, ``y``, ``z``; the
                first one is the flood-fill seed and is applied to every
                scene in the batch.

        Returns:
            Scalar tensor in [0, 1]; zero when fewer than two access points
            are supplied, since there is nothing to connect.
        """
        cfg = self.config
        structure = state[:, cfg['ch_structure']]
        access = state[:, cfg['ch_access']]
        existing = state[:, cfg['ch_existing']]

        if len(access_points) < 2:
            return structure.sum() * 0.0

        # Soft structure for flood fill
        struct_soft = torch.sigmoid(10 * (structure - 0.3))

        # Combined traversable space: structure OR existing buildings
        traversable = torch.max(struct_soft, existing)

        # Seed from the first access point only. Seeding from every access
        # voxel (the previous behaviour) marked all of them as reached before
        # any spreading happened, so the loss was always ~0 with no gradient.
        seed = self._seed_mask(access, access_points[0])
        targets = access - seed
        connected = seed

        for _ in range(self.iterations):
            dilated = F.max_pool3d(connected.unsqueeze(1), 3, 1, 1).squeeze(1)
            new_connected = torch.max(connected, dilated * traversable)
            if torch.allclose(connected, new_connected, atol=1e-5):
                break
            connected = new_connected

        # Share of the remaining access voxels the flood fill reached
        access_reached = (connected * targets).sum() / (targets.sum() + 1e-8)

        return 1 - access_reached

print('AccessConnectivityLoss defined')

In [None]:
# ============================================================
# C6: STRUCTURAL SUPPORT (Soft)
# ============================================================

class LoadPathLoss(nn.Module):
    """Share of structure mass without a path to the ground or a building."""

    def __init__(self, config: dict, iterations: int = 64):
        super().__init__()
        self.config = config
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Flood-fill support through soft structure and score what stays unreached."""
        cfg = self.config
        structure = state[:, cfg['ch_structure']]

        # Anything touching the ground plane or an existing building is anchored.
        anchors = torch.max(state[:, cfg['ch_ground']], state[:, cfg['ch_existing']])

        # Soft occupancy gates how far support can spread.
        occupancy = torch.sigmoid(10 * (structure - 0.3))

        reach = anchors.clone()
        done = 0
        while done < self.iterations:
            spread = F.max_pool3d(reach.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1)
            grown = torch.max(reach, spread * occupancy)
            # Stop early once another dilation no longer changes anything.
            if torch.allclose(reach, grown, atol=1e-5):
                break
            reach = grown
            done += 1

        floating = structure * (1 - reach)

        return floating.sum() / (structure.sum() + 1e-8)

print('LoadPathLoss defined')

In [None]:
# ============================================================
# GROWTH INCENTIVE
# ============================================================

class AccessReachLoss(nn.Module):
    """Reward structure that fills the neighbourhood of the access points."""

    def __init__(self, config: dict, dilation: int = 5):
        super().__init__()
        self.config = config
        self.dilation = dilation

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return one minus the structure coverage of the dilated access zone."""
        cfg = self.config
        structure = state[:, cfg['ch_structure']]
        access = state[:, cfg['ch_access']]
        existing = state[:, cfg['ch_existing']]

        # Grow the access voxels into a cube-shaped neighbourhood ...
        radius = self.dilation
        zone = F.max_pool3d(
            access.unsqueeze(1), kernel_size=2 * radius + 1, stride=1, padding=radius
        ).squeeze(1)
        # ... and drop the part that lies inside buildings.
        zone = zone * (1 - existing)

        covered = (structure * zone).sum()
        coverage = covered / (zone.sum() + 1e-8)

        return 1 - coverage

print('AccessReachLoss defined')

In [None]:
# ============================================================
# QUALITY LOSSES
# ============================================================

class DensityPenalty(nn.Module):
    """Push values towards 0 or 1 by penalising the mean of s * (1 - s)."""
    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        complement = 1 - structure
        return torch.mean(structure * complement)


class ThicknessLoss(nn.Module):
    """Penalise the mass that survives repeated erosion (thick cores)."""

    def __init__(self, max_thickness: int = 3):
        super().__init__()
        self.max_thickness = max_thickness

    def erode_3d(self, x: torch.Tensor) -> torch.Tensor:
        """Grey-scale 3x3x3 erosion (neighbourhood minimum) via negated max-pooling.

        Accepts a 4-D tensor, treated as a batch of volumes, or a single
        volume; the result has the same shape as the input.
        """
        batched = x.dim() == 4
        volume = x.unsqueeze(1) if batched else x[None, None]
        shrunk = -F.max_pool3d(-volume, 3, 1, 1)
        if batched:
            return shrunk.squeeze(1)
        return shrunk.squeeze(0).squeeze(0)

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        """Return eroded soft mass divided by the original soft mass."""
        soft = torch.sigmoid(10 * (structure - 0.3))
        core = soft
        remaining = self.max_thickness
        while remaining > 0:
            core = self.erode_3d(core)
            remaining -= 1
        return core.sum() / (soft.sum() + 1e-8)

print('Quality losses defined')

## 4. Trainer

In [None]:
class StepDTrainer:
    """Step D Trainer.

    Owns the optimizer, the scene generator, every loss term and the
    per-epoch metric history for one model.

    Args:
        model: The NCA to train.
        config: Shared configuration dictionary.
        device: Device on which scenes are generated.
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model = model
        self.config = config
        self.device = device

        # Loss functions
        self.exclusion_loss = BuildingExclusionLoss(config)
        self.ground_loss = GroundClearanceLoss(config)
        self.nogo_loss = NoGoZoneLoss(config)
        self.ceiling_loss = CeilingLoss(config)
        self.volume_loss = VolumeBudgetLoss(config)
        self.connectivity_loss = AccessConnectivityLoss(config)
        self.loadpath_loss = LoadPathLoss(config)
        self.reach_loss = AccessReachLoss(config)
        self.density_loss = DensityPenalty()
        self.thickness_loss = ThicknessLoss()

        # Weights
        self.weights = {
            'exclusion': 100.0,    # Hard - no building penetration
            'nogo': 50.0,          # Hard - no-go zone
            'ceiling': 50.0,       # Hard - height limit
            'connectivity': 40.0,  # Critical - connect access points
            'ground': 30.0,        # Soft - ground clearance
            'volume': 25.0,        # Soft - volume budget
            'loadpath': 20.0,      # Soft - structural support
            'reach': 15.0,         # Growth incentive
            'thickness': 10.0,     # Anti-blob
            'density': 5.0,        # Binary outputs
        }

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = StepDSceneGenerator(config)
        self.history = []

    def _scene_dependent_losses(self, final: torch.Tensor, metas: List[dict]) -> dict:
        """Evaluate the metadata-dependent losses scene by scene and average them.

        Each scene carries its own zones, access points and reference volume
        (the ground access point is randomised even on 'easy'), so the first
        scene's metadata must not be applied to the whole batch.
        """
        per_scene = {k: [] for k in ('ground', 'nogo', 'ceiling', 'volume', 'connectivity')}
        for i, meta in enumerate(metas):
            scene = final[i:i + 1]
            zone_info = meta['zones']
            access_points = meta['access_points']
            per_scene['ground'].append(self.ground_loss(scene, access_points))
            per_scene['nogo'].append(self.nogo_loss(scene, zone_info))
            per_scene['ceiling'].append(self.ceiling_loss(scene, zone_info))
            per_scene['volume'].append(
                self.volume_loss(scene, meta['buildings']['A']['volume']))
            per_scene['connectivity'].append(
                self.connectivity_loss(scene, access_points))
        return {k: torch.stack(v).mean() for k, v in per_scene.items()}

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation step on a fresh batch of scenes.

        Args:
            epoch: Epoch index, stored in the returned metrics.

        Returns:
            Metrics dictionary for this step; it is also appended to
            ``self.history``.
        """
        self.model.train()
        cfg = self.config

        # Generate batch
        seeds, metas = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)

        # Forward: a random rollout length within the configured range
        steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        final = self.model(seeds, steps=steps)

        structure = final[:, cfg['ch_structure']]

        # Losses that need per-scene metadata
        L = self._scene_dependent_losses(final, metas)
        # Losses that only read the state channels
        L.update({
            'exclusion': self.exclusion_loss(final),
            'loadpath': self.loadpath_loss(final),
            'reach': self.reach_loss(final),
            'thickness': self.thickness_loss(structure),
            'density': self.density_loss(structure),
        })

        total = sum(self.weights[k] * v for k, v in L.items())

        # Backward
        self.optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        # Metrics
        with torch.no_grad():
            # Mean per-scene ratio; summing structure over the batch would
            # overstate it by the batch size.
            vol_ratio = float(np.mean([
                structure[i].sum().item() / meta['buildings']['A']['volume']
                for i, meta in enumerate(metas)
            ]))

            metrics = {
                'epoch': epoch,
                'total_loss': total.item(),
                'exclusion': L['exclusion'].item(),
                'connectivity': L['connectivity'].item(),
                'volume_ratio': vol_ratio,
                'ground': L['ground'].item(),
                'nogo': L['nogo'].item(),
                'ceiling': L['ceiling'].item(),
                'loadpath': L['loadpath'].item(),
            }

        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 20) -> dict:
        """Grow ``n_samples`` fresh scenes and report hard-thresholded compliance.

        Structure is binarised at 0.5 before each check. The model is left
        in eval mode.

        Returns:
            Average volume ratio plus the fraction of scenes passing the
            exclusion, no-go, ceiling and volume checks.
        """
        self.model.eval()
        cfg = self.config
        results = []

        with torch.no_grad():
            for _ in range(n_samples):
                scene, meta = self.scene_gen.generate(cfg['difficulty'], self.device)
                grown = self.model.grow(scene, steps=50)

                structure = grown[:, cfg['ch_structure']]
                existing = grown[:, cfg['ch_existing']]

                building_vol = meta['buildings']['A']['volume']
                zone_info = meta['zones']

                # Metrics
                struct_vol = (structure > 0.5).float().sum().item()
                vol_ratio = struct_vol / building_vol

                exclusion = ((structure > 0.5).float() * existing).sum().item() == 0

                nogo_struct = structure[:, :zone_info['ground_clearance'], :,
                                        zone_info['nogo'][0]:zone_info['nogo'][1]]
                nogo_clean = (nogo_struct > 0.5).float().sum().item() == 0

                ceiling_struct = structure[:, zone_info['ceiling']:, :, :]
                ceiling_clean = (ceiling_struct > 0.5).float().sum().item() == 0

                vol_ok = cfg['min_volume_ratio'] <= vol_ratio <= cfg['max_volume_ratio']

                results.append({
                    'volume_ratio': vol_ratio,
                    'exclusion_ok': exclusion,
                    'nogo_ok': nogo_clean,
                    'ceiling_ok': ceiling_clean,
                    'volume_ok': vol_ok,
                })

        return {
            'avg_volume_ratio': np.mean([r['volume_ratio'] for r in results]),
            'exclusion_rate': np.mean([r['exclusion_ok'] for r in results]),
            'nogo_rate': np.mean([r['nogo_ok'] for r in results]),
            'ceiling_rate': np.mean([r['ceiling_ok'] for r in results]),
            'volume_rate': np.mean([r['volume_ok'] for r in results]),
        }

    def save(self, path: str):
        """Write model weights, metric history and config to ``path``."""
        torch.save({
            'model': self.model.state_dict(),
            'history': self.history,
            'config': self.config,
        }, path)
        print(f'Saved: {path}')

print('StepDTrainer defined')

## 5. Visualization

In [None]:
def visualize_result(model, scene_gen, config, device, title='Result'):
    """Visualize training result - ZOOMED to region of interest.

    Generates a fresh scene, grows it for 50 steps and draws four panels:
    a 3D voxel view, a plan view, an elevation and a text summary.

    Args:
        model: NCA exposing ``grow(seed, steps)``; switched to eval mode here.
        scene_gen: Scene generator exposing ``generate(difficulty, device)``.
        config: Configuration dictionary (channel indices, difficulty, margins).
        device: Device on which the scene is generated.
        title: Title of the 3D panel.

    Returns:
        Tuple ``(grown, meta)``: the grown state tensor and the scene metadata.
    """
    model.eval()
    scene, meta = scene_gen.generate(config['difficulty'], device)

    with torch.no_grad():
        grown = model.grow(scene, steps=50)

    cfg = config
    # First (only) batch element as a NumPy array indexed [channel, z, y, x]
    s = grown[0].cpu().numpy()
    G = s.shape[1]

    # Binarise every channel at 0.5 for display and violation checks
    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5

    # Check violations
    overlap = structure & existing
    zone_info = meta['zones']
    nogo_violation = structure[:zone_info['ground_clearance'], :, zone_info['nogo'][0]:zone_info['nogo'][1]]

    # === CALCULATE REGION OF INTEREST ===
    buildings = meta['buildings']
    margin = 4

    # Bounding box of both buildings plus a margin, clipped to the grid
    x_min = max(0, buildings['A']['x'][0] - margin)
    x_max = min(G, buildings['B']['x'][1] + margin)
    y_min = max(0, buildings['A']['y'][0] - margin)
    y_max = min(G, buildings['A']['y'][1] + margin)
    z_min = 0
    z_max = min(G, buildings['max_height'] + cfg['ceiling_margin'] + margin)

    # Crop arrays
    existing_crop = existing[z_min:z_max, y_min:y_max, x_min:x_max]
    access_crop = access[z_min:z_max, y_min:y_max, x_min:x_max]
    structure_crop = structure[z_min:z_max, y_min:y_max, x_min:x_max]
    overlap_crop = overlap[z_min:z_max, y_min:y_max, x_min:x_max]
    # Structure that does not collide with a building (drawn in blue)
    struct_clean_crop = structure_crop & ~existing_crop

    fig = plt.figure(figsize=(20, 5))

    # 3D view - cropped
    # Arrays are transposed from (z, y, x) to (y, x, z) for voxels(), hence
    # the Y/X axis labels below.
    ax1 = fig.add_subplot(141, projection='3d')
    if existing_crop.any():
        ax1.voxels(existing_crop.transpose(1, 2, 0), facecolors='gray', alpha=0.3)
    if access_crop.any():
        ax1.voxels(access_crop.transpose(1, 2, 0), facecolors='green', alpha=0.9)
    if overlap_crop.any():
        ax1.voxels(overlap_crop.transpose(1, 2, 0), facecolors='red', alpha=0.9)
    if struct_clean_crop.any():
        ax1.voxels(struct_clean_crop.transpose(1, 2, 0), facecolors='royalblue', alpha=0.6)
    ax1.set_xlabel('Y (m)')
    ax1.set_ylabel('X (m)')
    ax1.set_zlabel('Z (m)')
    ax1.set_title(title)
    ax1.view_init(elev=25, azim=45)

    # Plan view (top-down) - cropped
    ax2 = fig.add_subplot(142)
    h_crop = z_max - z_min
    w_crop = x_max - x_min
    d_crop = y_max - y_min
    # Later layers overwrite earlier ones: buildings, then structure, then access
    plan = np.zeros((d_crop, w_crop, 3))
    plan[existing_crop.max(axis=0)] = [0.5, 0.5, 0.5]
    plan[struct_clean_crop.max(axis=0)] = [0.2, 0.4, 0.8]
    plan[access_crop.max(axis=0)] = [0.2, 0.8, 0.2]
    # Plan: Y is rows (vertical), X is columns (horizontal)
    ax2.imshow(plan, origin='lower',
               extent=[x_min, x_max, y_min, y_max])
    ax2.set_xlabel('X (m)')
    ax2.set_ylabel('Y (m)')
    ax2.set_title('Plan')

    # Elevation (front view) - cropped
    ax3 = fig.add_subplot(143)
    # Same layer order as the plan, projected along Y instead of Z
    elev = np.zeros((h_crop, w_crop, 3))
    elev[existing_crop.max(axis=1)] = [0.5, 0.5, 0.5]
    elev[struct_clean_crop.max(axis=1)] = [0.2, 0.4, 0.8]
    elev[access_crop.max(axis=1)] = [0.2, 0.8, 0.2]
    # Elevation: Z is rows (vertical), X is columns (horizontal) - NO transpose
    ax3.imshow(elev, origin='lower',
               extent=[x_min, x_max, z_min, z_max])
    # Zone markers - horizontal lines at constant Z values
    clearance = zone_info['ground_clearance']
    ceiling = zone_info['ceiling']
    ax3.axhline(y=clearance, color='orange', linestyle='--', alpha=0.7, label='Ground clearance')
    ax3.axhline(y=ceiling, color='red', linestyle='--', alpha=0.7, label='Ceiling')
    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Z (m)')
    # NOTE(review): the title mentions a red no-go zone, but this panel only
    # draws the dashed clearance and ceiling lines.
    ax3.set_title('Elevation (red=no-go zone)')
    ax3.legend(loc='upper right', fontsize=7)

    # Metrics
    ax4 = fig.add_subplot(144)
    ax4.axis('off')

    # Volume is counted on the full (uncropped) binarised structure and
    # compared against building A's volume.
    building_vol = meta['buildings']['A']['volume']
    struct_vol = structure.sum()
    vol_ratio = struct_vol / building_vol

    txt = f"""METRICS

Volume: {struct_vol:.0f} voxels
Ratio: {vol_ratio*100:.1f}% (target 20-40%)

Building exclusion: {'OK' if not overlap.any() else 'VIOLATION'}
No-go zone: {'OK' if not nogo_violation.any() else 'VIOLATION'}
Ceiling: {'OK' if not structure[ceiling:].any() else 'VIOLATION'}

Access points:
"""
    for ap in meta['access_points']:
        txt += f"  {ap['id']}: z={ap['z']}m\n"

    ax4.text(0.1, 0.9, txt, transform=ax4.transAxes, fontsize=9, va='top', family='monospace')

    plt.tight_layout()
    plt.show()
    return grown, meta


def plot_curves(history):
    """Plot six training curves from the per-epoch metric history in a 2x3 grid."""
    epochs = [record['epoch'] for record in history]
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    # (metric key, panel title) in row-major panel order
    panels = (
        ('total_loss', 'Total Loss'),
        ('connectivity', 'Connectivity Loss'),
        ('volume_ratio', 'Volume Ratio (20-40%)'),
        ('exclusion', 'Exclusion Loss'),
        ('nogo', 'No-Go Zone Loss'),
        ('loadpath', 'Load Path Loss'),
    )
    for ax, (key, label) in zip(axes.flat, panels):
        ax.plot(epochs, [record[key] for record in history])
        ax.set_title(label)

    # The total loss is shown on a log axis.
    axes[0, 0].set_yscale('log')

    # Mark the target band on the volume-ratio panel.
    volume_ax = axes[0, 2]
    volume_ax.axhline(0.20, color='g', linestyle='--')
    volume_ax.axhline(0.40, color='r', linestyle='--')

    plt.tight_layout()
    plt.show()

print('Visualization defined')

## 6. Training

In [None]:
# Build the NCA from the loaded config and move it to the selected device
model = StepDNCA(CONFIG).to(device)
# The trainer creates its own optimizer, loss terms and scene generator
trainer = StepDTrainer(model, CONFIG, device)

# Summary of the run about to start
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
print(f'Epochs: {CONFIG["epochs"]}')
print(f'Difficulty: {CONFIG["difficulty"]}')

In [None]:
# Before training: baseline rollout of the untrained model for later comparison
print('Before training:')
visualize_result(model, trainer.scene_gen, CONFIG, device, 'Before Training')

In [None]:
print('=' * 60)
print('STEP D TRAINING')
print('=' * 60)

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    m = trainer.train_epoch(epoch)

    # One-line progress report every `log_every` epochs (including epoch 0)
    if not epoch % CONFIG['log_every']:
        tqdm.write(
            f"E{epoch:4d} | Loss:{m['total_loss']:7.1f} | "
            f"Conn:{m['connectivity']:.3f} | "
            f"Vol:{m['volume_ratio']*100:5.1f}% | "
            f"Excl:{m['exclusion']:.3f} | "
            f"NoGo:{m['nogo']:.3f}"
        )

    # Epoch 0 gets neither a preview nor a checkpoint
    if epoch == 0:
        continue

    if epoch % CONFIG['viz_every'] == 0:
        visualize_result(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')

    if epoch % CONFIG['save_every'] == 0:
        trainer.save(f"{PROJECT_ROOT}/checkpoints/step_d_epoch_{epoch}.pth")

print('=' * 60)
print('Training complete')

## 7. Evaluation

In [None]:
# Loss and volume-ratio curves over all recorded training epochs
plot_curves(trainer.history)

In [None]:
# Evaluate on 30 freshly generated scenes
print('Evaluating...')
eval_results = trainer.evaluate(30)

# Rates are fractions in [0, 1]; print them as percentages
print('\n' + '='*60)
print('STEP D EVALUATION')
print('='*60)
print(f"Volume ratio:      {eval_results['avg_volume_ratio']*100:.1f}% (target: 20-40%)")
print(f"Exclusion rate:    {eval_results['exclusion_rate']*100:.0f}% (target: 100%)")
print(f"No-go compliance:  {eval_results['nogo_rate']*100:.0f}% (target: 100%)")
print(f"Ceiling compliance:{eval_results['ceiling_rate']*100:.0f}% (target: 100%)")
print(f"Volume compliance: {eval_results['volume_rate']*100:.0f}% (target: high)")
print('='*60)

In [None]:
# Final visualizations: three rollouts, each on a newly generated scene
print('Final results:')
for i in range(3):
    visualize_result(model, trainer.scene_gen, CONFIG, device, f'Final {i+1}')

In [None]:
# Save the final weights, history and config; the training loop only
# checkpoints on multiples of `save_every`.
trainer.save(f"{PROJECT_ROOT}/checkpoints/step_d_final.pth")
print('Done!')