# NB02: Step C - Connectivity Only (Clean)

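Goal: connect the two access points with a minimal-volume structure that does not flood the free space around the buildings.

Notes:
- Only the connectivity loss and the anti-flood regularizers (corridor spill, total mass, ground contact, facade contact) are used.
- Buildings are 20m+ tall (20-26 voxels in the voxel grid), and every scene has exactly two access points: one on the ground in the gap between the buildings and one elevated on a facade.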


## 1. Setup

In [None]:
# Mount Google Drive when running on Colab. Outside Colab the `google.colab`
# package does not exist, so fall back to a local project root instead of
# crashing on the import.
import os

try:
    from google.colab import drive
except ImportError:  # not running inside Google Colab
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'
else:
    # Local run: honour a PROJECT_ROOT environment variable, else use the CWD.
    PROJECT_ROOT = os.environ.get('PROJECT_ROOT', os.getcwd())
print(f'Project root: {PROJECT_ROOT}')


In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Tuple, List
import json
from datetime import datetime
from tqdm.notebook import tqdm
import os

# Pick the compute device once: CUDA when a GPU is visible, otherwise CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print(f'Device: {device}')
if _cuda_ok:
    # Report which GPU (index 0) will be used.
    print(f'GPU: {torch.cuda.get_device_name(0)}')


In [None]:
# Set seeds
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    ``cudnn.benchmark`` is disabled as well: benchmark mode lets cuDNN pick
    convolution algorithms at runtime, which can differ between runs even when
    ``cudnn.deterministic`` is set.

    Args:
        seed: value passed to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)


## 2. Configuration (Step C)

In [None]:
CONFIG = {
    # Grid
    'grid_size': 32,

    # Channels (6 total: 3 frozen + 3 grown)
    'n_channels': 6,
    'n_frozen': 3,
    'n_grown': 3,

    # Frozen channels
    'ch_ground': 0,
    'ch_existing': 1,
    'ch_access': 2,

    # Grown channels
    'ch_structure': 3,
    'ch_alive': 4,
    'ch_hidden': 5,

    # Network
    'hidden_dim': 96,
    'update_scale': 0.1,
    'fire_rate': 0.5,
    'xavier_gain': 0.5,
    'structure_bias': 0.0,

    # Training
    'lr_initial': 1e-3,
    'batch_size': 4,
    'grad_clip': 1.0,
    'epochs': 400,
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',
    'log_every': 20,
    'viz_every': 100,
    'save_every': 100,

    # Corridor
    'corridor_width': 1,
    'vertical_envelope': 1,

    # Anti-flood targets
    'max_fill_ratio': 0.08,
    'max_ground_contact': 0.05,
    'max_facade_contact': 0.05,

    # Connectivity
    'connectivity_iters': 48,
}


def validate_config(cfg: dict) -> None:
    """Raise ValueError if ``cfg`` describes a layout the NCA code cannot use.

    The model copies channels ``[0, n_frozen)`` through unchanged, writes the
    remaining ``n_grown`` channels from its update net, and treats the first
    grown channel as structure; the trainer reads structure at
    ``ch_structure``. These relations must therefore hold.
    """
    if cfg['n_channels'] != cfg['n_frozen'] + cfg['n_grown']:
        raise ValueError(
            f"n_channels ({cfg['n_channels']}) must equal n_frozen + n_grown "
            f"({cfg['n_frozen']} + {cfg['n_grown']})"
        )
    frozen = (cfg['ch_ground'], cfg['ch_existing'], cfg['ch_access'])
    if any(not 0 <= ch < cfg['n_frozen'] for ch in frozen):
        raise ValueError(f'frozen channel indices {frozen} must lie in [0, n_frozen)')
    if cfg['ch_structure'] != cfg['n_frozen']:
        raise ValueError('ch_structure must be the first grown channel (index n_frozen)')
    if cfg['steps_min'] > cfg['steps_max']:
        raise ValueError('steps_min must not exceed steps_max')


validate_config(CONFIG)

print('Step C Configuration:')
print(f"  Grid: {CONFIG['grid_size']}^3")
print(f"  Channels: {CONFIG['n_channels']} ({CONFIG['n_frozen']} frozen + {CONFIG['n_grown']} grown)")
print(f"  Corridor width: {CONFIG['corridor_width']}")


## 3. Perception Module

In [None]:
class Perceive3D(nn.Module):
    """Fixed (non-learned) 3D perception: identity plus three Sobel filters.

    Every input channel is filtered independently, after replicate padding,
    with four 3x3x3 kernels in the order [identity, d/dW ('x'), d/dH ('y'),
    d/dD ('z')]. Sobel kernels are normalised by their absolute sum. The
    output has ``4 * C`` channels grouped per input channel:
    ``[c0_id, c0_x, c0_y, c0_z, c1_id, ...]``.
    """

    def __init__(self, n_channels: int = 6):
        super().__init__()
        self.n_channels = n_channels

        sobel_x = self._create_sobel_kernel('x')
        sobel_y = self._create_sobel_kernel('y')
        sobel_z = self._create_sobel_kernel('z')
        identity = self._create_identity_kernel()

        kernels = torch.stack([identity, sobel_x, sobel_y, sobel_z], dim=0)
        # Buffer: moves with .to(device) and is saved in state_dict, but is not trained.
        self.register_buffer('kernels', kernels)

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        """Separable Sobel kernel: derivative along one axis, smoothing on the others."""
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])
        if direction == 'x':
            kernel = torch.einsum('i,j,k->ijk', smoothing, smoothing, derivative)
        elif direction == 'y':
            kernel = torch.einsum('i,j,k->ijk', smoothing, derivative, smoothing)
        else:
            kernel = torch.einsum('i,j,k->ijk', derivative, smoothing, smoothing)
        return kernel / kernel.abs().sum()

    def _create_identity_kernel(self) -> torch.Tensor:
        """3x3x3 kernel that copies the centre voxel."""
        kernel = torch.zeros(3, 3, 3)
        kernel[1, 1, 1] = 1.0
        return kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(B, C, D, H, W)`` to ``(B, 4*C, D, H, W)`` perception features."""
        B, C, D, H, W = x.shape
        # reshape (not view): also accepts non-contiguous inputs such as permuted tensors.
        x = x.reshape(B * C, 1, D, H, W)
        kernels = self.kernels.view(4, 1, 3, 3, 3)
        # Replicate padding keeps the gradient at the grid border from seeing a fake edge.
        x = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        y = F.conv3d(x, kernels)
        return y.reshape(B, C * 4, D, H, W)


## 4. NCA Model

In [None]:
class UrbanPavilionNCA(nn.Module):
    """Minimal neural cellular automaton for Step C connectivity.

    State layout is ``(B, n_channels, D, H, W)``. Channels ``[0, n_frozen)``
    are copied through unchanged every step. The ``n_grown`` channels after
    them receive a residual update from a per-voxel MLP applied to the Sobel
    perception. The first grown channel (structure) is zeroed wherever the
    ``ch_existing`` channel is set.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        n_channels = config['n_channels']
        hidden_dim = config['hidden_dim']
        perception_dim = n_channels * 4  # identity + 3 Sobel responses per channel
        n_grown = config['n_grown']

        self.perceive = Perceive3D(n_channels)
        # 1x1x1 convolutions: the same small MLP applied independently at every voxel.
        self.update_net = nn.Sequential(
            nn.Conv3d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, n_grown, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights, zero biases, then the configured structure bias."""
        gain = self.config['xavier_gain']
        for m in self.update_net:
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        last_layer = self.update_net[-1]
        with torch.no_grad():
            # Output channel 0 feeds the first grown channel (structure).
            last_layer.bias[0] = self.config['structure_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """Apply ``steps`` NCA updates and return the final state."""
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        """One NCA update: perceive, compute delta, apply to grown channels only."""
        B, C, D, H, W = state.shape
        cfg = self.config
        perception = self.perceive(state)
        delta = self.update_net(perception)

        if self.training:
            # Stochastic update: each voxel fires with probability fire_rate.
            fire_mask = (torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']).float()
            delta = delta * fire_mask

        grown_start = cfg['n_frozen']
        grown_new = state[:, grown_start:] + cfg['update_scale'] * delta
        grown_new = torch.clamp(grown_new, 0.0, 1.0)

        # Structure may not occupy voxels of existing buildings.
        existing = state[:, cfg['ch_existing']:cfg['ch_existing']+1]
        available_mask = 1.0 - existing
        struct_new = grown_new[:, 0:1] * available_mask
        grown_masked = torch.cat([struct_new, grown_new[:, 1:]], dim=1)
        return torch.cat([state[:, :grown_start], grown_masked], dim=1)

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Roll out ``steps`` deterministic updates without gradients.

        The module runs in eval mode (no stochastic fire mask) for the rollout
        and is restored to whichever train/eval mode it was in beforehand, so
        calling ``grow`` has no lasting side effect on the module.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(seed, steps)
        finally:
            self.train(was_training)


## 5. Scene Generator (20m+ buildings, 2 access points)

In [None]:
class UrbanSceneGenerator:
    """Procedurally build two-building scenes as NCA seed states.

    Each scene has:
    - a ground plane at z = 0;
    - a left and a right building facing each other across a central gap
      (written to ``ch_existing``);
    - exactly two 2x2x2 access cubes (``ch_access``): one on the ground
      inside the gap, and one elevated next to a randomly chosen facade.
    """

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']
        self.C = config['n_channels']

    def _get_difficulty_params(self, difficulty: str) -> dict:
        """Scene parameters; ``difficulty`` is currently ignored (one fixed regime)."""
        gap = random.randint(12, 16)
        return dict(
            n_buildings=2,
            height_range=(20, 26),
            width_range=(8, 12),
            gap_width=gap,
            n_ground_access=1,
            n_elevated_access=1,
        )

    def generate(self, difficulty: str = 'easy', device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        """Return a ``(1, C, G, G, G)`` seed state and a metadata dict."""
        cfg = self.config
        size = self.G
        scene = torch.zeros(1, self.C, size, size, size, device=device)
        scene[:, cfg['ch_ground'], 0] = 1.0

        params = self._get_difficulty_params(difficulty)
        buildings = self._place_buildings(scene, params)
        access_points = self._place_access_points(scene, params, buildings)

        meta = {'difficulty': difficulty, 'buildings': buildings, 'access_points': access_points}
        return scene, meta

    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        """Stamp a left and a right building into ``state`` around the grid centre."""
        size = self.G
        channel = self.config['ch_existing']
        half_gap = params['gap_width'] // 2
        centre = size // 2
        placed = []
        for side in ('left', 'right'):
            # Draw order per building: width, depth, height.
            width = random.randint(*params['width_range'])
            depth = random.randint(size // 2, size - 2)
            height = random.randint(*params['height_range'])
            if side == 'left':
                x_hi = centre - half_gap
                x_lo = max(0, x_hi - width)
                facing = x_hi
            else:
                x_lo = centre + half_gap
                x_hi = min(size, x_lo + width)
                facing = x_lo
            state[:, channel, :height, :depth, x_lo:x_hi] = 1.0
            placed.append({'x': (x_lo, x_hi), 'y': (0, depth), 'z': (0, height),
                           'gap_facing_x': facing, 'side': side})
        return placed

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: list) -> list:
        """Stamp one ground and one facade-mounted 2x2x2 access cube into ``state``."""
        size = self.G
        channel = self.config['ch_access']
        left_faces = [b['gap_facing_x'] for b in buildings if b['side'] == 'left']
        right_faces = [b['gap_facing_x'] for b in buildings if b['side'] == 'right']
        gap_lo = max(left_faces) if left_faces else 0
        gap_hi = min(right_faces) if right_faces else size
        points = []

        # Ground access, strictly inside the gap.
        gx = random.randint(gap_lo + 1, gap_hi - 3)
        gy = random.randint(0, size - 3)
        state[:, channel, 0:2, gy:gy + 2, gx:gx + 2] = 1.0
        points.append({'x': gx, 'y': gy, 'z': 0, 'type': 'ground'})

        # Elevated access, just outside the gap-facing facade of one building.
        host = random.choice(buildings)
        y_lo, y_hi = host['y']
        ez = random.randint(4, max(5, host['z'][1] - 2))
        ey = random.randint(y_lo, min(y_hi - 2, y_lo + size // 3))
        ex = host['x'][1] if host['side'] == 'left' else host['x'][0] - 2
        ex = max(0, min(size - 2, ex))
        state[:, channel, ez:ez + 2, ey:ey + 2, ex:ex + 2] = 1.0
        points.append({'x': ex, 'y': ey, 'z': ez, 'type': 'elevated'})
        return points

    def batch(self, difficulty: str, batch_size: int, device: str) -> torch.Tensor:
        """Concatenate ``batch_size`` independently generated scenes along dim 0."""
        scenes = []
        for _ in range(batch_size):
            scene, _meta = self.generate(difficulty, device)
            scenes.append(scene)
        return torch.cat(scenes, dim=0)


## 6. Corridor Target (minimal path)

In [None]:
def find_access_centroids(access_channel: torch.Tensor) -> List[Tuple[int, int, int]]:
    """Return one integer centroid per connected cluster of access voxels.

    Voxels with value > 0.5 are grouped into 26-connected components: voxels
    touching by a face, edge or corner belong to the same access point, and
    non-touching cubes are always kept separate. Components are returned in
    the row-major order of their first voxel, so for the generated scenes
    the ground access (z = 0) comes first. Each centroid is the per-axis mean
    of the component's voxel indices, floored to an int.

    Args:
        access_channel: 3D tensor for a single scene (one access channel).

    Returns:
        List of ``(i0, i1, i2)`` index tuples; empty if no voxel exceeds 0.5.
    """
    # Single device->host transfer; the BFS below runs on a handful of voxels.
    coords = [tuple(c) for c in (access_channel > 0.5).nonzero(as_tuple=False).tolist()]
    unvisited = set(coords)
    offsets = [(dz, dy, dx)
               for dz in (-1, 0, 1) for dy in (-1, 0, 1) for dx in (-1, 0, 1)
               if (dz, dy, dx) != (0, 0, 0)]

    centroids = []
    for start in coords:  # row-major order -> deterministic component order
        if start not in unvisited:
            continue
        unvisited.remove(start)
        component = [start]
        frontier = [start]
        while frontier:
            z, y, x = frontier.pop()
            for dz, dy, dx in offsets:
                neighbour = (z + dz, y + dy, x + dx)
                if neighbour in unvisited:
                    unvisited.remove(neighbour)
                    component.append(neighbour)
                    frontier.append(neighbour)
        n = len(component)
        # Indices are non-negative, so floor division equals truncating the mean.
        centroids.append(tuple(sum(axis) // n for axis in zip(*component)))
    return centroids


def compute_distance_field_3d(start_points: List[Tuple[int, int, int]],
                              legal_mask: torch.Tensor, max_iters: int = 64) -> torch.Tensor:
    """Unit-cost 26-neighbour geodesic distance from ``start_points``.

    Distances are relaxed iteratively: each legal voxel (mask > 0.5) takes
    ``min(own, min_neighbour + 1)``; illegal voxels are never updated. Start
    points outside the grid are ignored. Relaxation stops after ``max_iters``
    sweeps or once the field no longer changes (within atol 1e-5).

    Returns:
        Float tensor shaped like ``legal_mask``; ``inf`` where unreachable.
    """
    D, H, W = legal_mask.shape
    field = torch.full((D, H, W), float('inf'), device=legal_mask.device)
    for z, y, x in start_points:
        inside = 0 <= z < D and 0 <= y < H and 0 <= x < W
        if inside:
            field[z, y, x] = 0

    passable = legal_mask > 0.5
    for _ in range(max_iters):
        # min over the 3x3x3 neighbourhood == -max_pool(-field); padding acts as -inf.
        neighbour_min = -F.max_pool3d(-field[None, None], 3, 1, 1)[0, 0]
        relaxed = torch.where(passable, torch.min(field, neighbour_min + 1), field)
        if torch.allclose(field, relaxed, atol=1e-5):
            break
        field = relaxed
    return field


def compute_corridor_target(seed_state: torch.Tensor, config: dict,
                            corridor_width: int = 1, vertical_envelope: int = 1) -> torch.Tensor:
    """Build a binary "allowed corridor" mask between the two access points.

    For each scene the first two access centroids are joined by all
    near-shortest free-space paths. These are voxels whose start+end
    geodesic distance is within ``corridor_width`` of the optimum. The path
    set is then dilated by ``corridor_width`` in all directions and by
    ``vertical_envelope`` layers along the first grid axis (z), and masked
    out of existing buildings.

    Args:
        seed_state: ``(B, C, G, G, G)`` seed states.
        config: config dict (uses ``grid_size``, ``ch_access``, ``ch_existing``).
        corridor_width: path slack and isotropic dilation radius in voxels.
        vertical_envelope: extra dilation radius along z in voxels.

    Returns:
        ``(B, G, G, G)`` tensor: 1.0 inside the corridor, 0.0 elsewhere.
        Fallbacks:
        - fewer than two access clusters: a 1-voxel halo around the access voxels;
        - endpoints not connected through free space: all zeros.
    """
    cfg = config
    G = cfg['grid_size']
    device = seed_state.device
    B = seed_state.shape[0]
    corridors = torch.zeros(B, G, G, G, device=device)

    for b in range(B):
        access = seed_state[b, cfg['ch_access']]
        existing = seed_state[b, cfg['ch_existing']]
        legal_mask = 1.0 - existing
        centroids = find_access_centroids(access)

        if len(centroids) < 2:
            dilated = F.max_pool3d(access[None, None], 3, 1, 1)
            corridors[b] = dilated[0, 0] * legal_mask
            continue

        start, end = centroids[0], centroids[1]
        dist_from_start = compute_distance_field_3d([start], legal_mask)
        dist_from_end = compute_distance_field_3d([end], legal_mask)
        total_dist = dist_from_start[end[0], end[1], end[2]]
        if torch.isinf(total_dist):
            continue  # unreachable: corridors[b] stays all-zero

        path_cost = dist_from_start + dist_from_end
        on_path = (path_cost <= total_dist + corridor_width).float()

        corridor = F.max_pool3d(on_path[None, None], 2 * corridor_width + 1, 1, corridor_width)

        if vertical_envelope > 0:
            # Pool along z only, against an unmodified snapshot, so every slice
            # gains exactly the slices within +/- vertical_envelope. Updating
            # slices in place while reading neighbours would instead accumulate
            # every lower slice into each higher one.
            corridor = F.max_pool3d(
                corridor,
                kernel_size=(2 * vertical_envelope + 1, 1, 1),
                stride=1,
                padding=(vertical_envelope, 0, 0),
            )

        corridors[b] = corridor[0, 0] * legal_mask

    return corridors


## 7. Losses (connectivity + anti-flood)

In [None]:
class AccessConnectivityLoss(nn.Module):
    """Soft flood-fill loss: fraction of access mass NOT reached through structure.

    Per batch item, the flood starts from the first maximal access voxel
    (row-major ``argmax``). Each sweep, a voxel takes the max of its own value
    and its 3x3x3 neighbourhood max times its own soft passability
    ``sigmoid(10 * (structure - 0.3))``. Sweeps stop after ``iterations`` or
    once the fill stops changing (atol 1e-5).
    Returns ``1 - sum(reached * access) / sum(access)`` over the whole batch.
    """

    def __init__(self, iterations: int = 48):
        super().__init__()
        self.iterations = iterations

    def forward(self, structure: torch.Tensor, access: torch.Tensor) -> torch.Tensor:
        passable = torch.sigmoid((structure - 0.3) * 10)

        batch = access.shape[0]
        flat = access.reshape(batch, -1)
        seed_idx = flat.argmax(dim=1)
        seed = torch.zeros_like(flat)
        seed[torch.arange(batch), seed_idx] = 1.0
        reached = seed.view_as(access)

        for _ in range(self.iterations):
            spread = F.max_pool3d(reached.unsqueeze(1), 3, 1, 1).squeeze(1) * passable
            candidate = torch.max(reached, spread)
            if torch.allclose(reached, candidate, atol=1e-5):
                break
            reached = candidate

        coverage = (reached * access).sum() / (access.sum() + 1e-8)
        return 1.0 - coverage
class SpillLoss(nn.Module):
    """Structure mass outside the corridor, normalised by corridor volume."""

    def forward(self, structure: torch.Tensor, corridor: torch.Tensor) -> torch.Tensor:
        corridor_volume = corridor.sum() + 1e-8
        spilled = (structure * (1 - corridor)).sum()
        return spilled / corridor_volume


class MassLoss(nn.Module):
    """Hinge penalty on the fraction of available space filled with structure.

    Zero while ``sum(structure) / sum(available) <= max_ratio``; grows
    linearly beyond it.
    """

    def __init__(self, max_ratio: float = 0.08):
        super().__init__()
        self.max_ratio = max_ratio

    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        fill = structure.sum() / (available.sum() + 1e-8)
        excess = fill - self.max_ratio
        return F.relu(excess)


class GroundContactLoss(nn.Module):
    """Hinge penalty on the share of structure in the lowest ``ground_levels`` slices.

    ``structure`` is indexed as ``(B, z, ...)``: slices ``[0, ground_levels)``
    along dim 1 count as ground contact.
    """

    def __init__(self, max_ratio: float = 0.05, ground_levels: int = 2):
        super().__init__()
        self.max_ratio = max_ratio
        self.ground_levels = ground_levels

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        low_mass = structure[:, 0:self.ground_levels].sum()
        share = low_mass / (structure.sum() + 1e-8)
        return F.relu(share - self.max_ratio)


class FacadeContactLoss(nn.Module):
    """Hinge penalty on the share of structure touching building facades.

    The facade is the one-voxel shell (3x3x3 neighbourhood) around
    ``existing``, excluding ``existing`` itself.
    """

    def __init__(self, max_contact: float = 0.05):
        super().__init__()
        self.max_contact = max_contact

    def forward(self, structure: torch.Tensor, existing: torch.Tensor) -> torch.Tensor:
        grown_shell = F.max_pool3d(existing.unsqueeze(1), 3, 1, 1).squeeze(1)
        shell = (grown_shell - existing).clamp(0, 1)
        touching = (structure * shell).sum()
        share = touching / (structure.sum() + 1e-8)
        return F.relu(share - self.max_contact)


## 8. Trainer

In [None]:
class StepCTrainer:
    """Training, evaluation and checkpointing for Step C.

    Each training epoch samples a fresh batch of scenes and grows it for a
    random number of NCA steps in ``[steps_min, steps_max]``. It then takes
    one optimiser step on a weighted sum of the connectivity loss and four
    anti-flood losses (spill, mass, ground, facade).
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model = model
        self.config = config
        self.device = device

        self.conn_loss = AccessConnectivityLoss(iterations=config['connectivity_iters'])
        self.spill_loss = SpillLoss()
        self.mass_loss = MassLoss(max_ratio=config['max_fill_ratio'])
        self.ground_loss = GroundContactLoss(max_ratio=config['max_ground_contact'])
        self.facade_loss = FacadeContactLoss(max_contact=config['max_facade_contact'])

        # Relative weight of each term in the total objective.
        self.weights = {
            'connectivity': 40.0,
            'spill': 20.0,
            'mass': 10.0,
            'ground': 10.0,
            'facade': 10.0,
        }

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        self.history = []

    @staticmethod
    def _facade_shell(existing: torch.Tensor) -> torch.Tensor:
        """One-voxel shell (3x3x3 neighbourhood) around ``existing``, excluding it."""
        dilated = F.max_pool3d(existing.unsqueeze(1), 3, 1, 1).squeeze(1)
        return (dilated - existing).clamp(0, 1)

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation step on a new batch; record and return metrics."""
        self.model.train()
        cfg = self.config

        seeds = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)
        corridor = compute_corridor_target(seeds, cfg, cfg['corridor_width'], cfg['vertical_envelope'])

        steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        final = self.model(seeds, steps=steps)

        structure = final[:, cfg['ch_structure']]
        existing = final[:, cfg['ch_existing']]
        available = 1.0 - existing
        access = final[:, cfg['ch_access']]

        L_conn = self.conn_loss(structure, access)
        L_spill = self.spill_loss(structure, corridor)
        L_mass = self.mass_loss(structure, available)
        L_ground = self.ground_loss(structure)
        L_facade = self.facade_loss(structure, existing)

        total_loss = (
            self.weights['connectivity'] * L_conn +
            self.weights['spill'] * L_spill +
            self.weights['mass'] * L_mass +
            self.weights['ground'] * L_ground +
            self.weights['facade'] * L_facade
        )

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        # Monitoring ratios. Unlike SpillLoss (normalised by corridor volume),
        # spill here is reported as a fraction of total structure.
        with torch.no_grad():
            struct_total = structure.sum() + 1e-8
            fill_ratio = structure.sum() / (available.sum() + 1e-8)
            ground_ratio = structure[:, :self.ground_loss.ground_levels].sum() / struct_total
            facade_ratio = (structure * self._facade_shell(existing)).sum() / struct_total
            spill_ratio = (structure * (1 - corridor)).sum() / struct_total

        metrics = {
            'epoch': epoch,
            'total_loss': total_loss.item(),
            'connectivity': 1.0 - L_conn.item(),
            'spill': spill_ratio.item(),
            'fill_ratio': fill_ratio.item(),
            'ground_ratio': ground_ratio.item(),
            'facade_ratio': facade_ratio.item(),
            'steps': steps,
        }
        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 50, steps: int = 50) -> dict:
        """Grow ``n_samples`` fresh scenes deterministically and average metrics.

        Args:
            n_samples: number of scenes to evaluate; must be >= 1.
            steps: NCA steps per rollout.

        Returns:
            Dict mapping ``avg_<metric>`` to the mean over all samples.

        Raises:
            ValueError: if ``n_samples`` < 1 (there would be nothing to average).
        """
        if n_samples < 1:
            raise ValueError(f'n_samples must be >= 1, got {n_samples}')
        self.model.eval()
        cfg = self.config
        results = []

        with torch.no_grad():
            for _ in range(n_samples):
                scene, _ = self.scene_gen.generate(cfg['difficulty'], self.device)
                corridor = compute_corridor_target(scene, cfg, cfg['corridor_width'], cfg['vertical_envelope'])
                grown = self.model.grow(scene, steps=steps)

                structure = grown[:, cfg['ch_structure']]
                existing = grown[:, cfg['ch_existing']]
                available = 1.0 - existing
                access = grown[:, cfg['ch_access']]

                struct_total = structure.sum().item() + 1e-8
                conn = 1.0 - self.conn_loss(structure, access).item()
                spill = (structure * (1 - corridor)).sum().item() / struct_total
                fill = structure.sum().item() / (available.sum().item() + 1e-8)
                ground = structure[:, :self.ground_loss.ground_levels].sum().item() / struct_total
                facade = (structure * self._facade_shell(existing)).sum().item() / struct_total

                results.append({
                    'connectivity': conn,
                    'spill': spill,
                    'fill_ratio': fill,
                    'ground_ratio': ground,
                    'facade_ratio': facade,
                })

        return {f'avg_{k}': float(np.mean([r[k] for r in results])) for k in results[0]}

    def save_checkpoint(self, path: str):
        """Save model/optimizer state, config and history to ``path``.

        Missing parent directories are created first; ``torch.save`` does not
        create them and would raise on a fresh Drive folder.
        """
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'history': self.history,
        }, path)
        print(f'Checkpoint saved to {path}')


## 9. Visualization

In [None]:
def visualize_result(model: nn.Module, scene_gen: UrbanSceneGenerator,
                     config: dict, device: str, title: str = 'Result'):
    """Grow one fresh scene for 50 steps and plot it with its corridor.

    Panels:
    - 3D voxel view (gray buildings, green access, blue structure inside the
      corridor, orange structure outside it);
    - top-down ground plan;
    - front elevation with height (z) on the vertical axis;
    - text box with spill / fill / ground / facade ratios.
    """
    model.eval()
    scene, _ = scene_gen.generate(config['difficulty'], device)

    with torch.no_grad():
        corridor = compute_corridor_target(scene, config, config['corridor_width'], config['vertical_envelope'])
        grown = model.grow(scene, steps=50)

    cfg = config
    s = grown[0].cpu().numpy()
    G = s.shape[1]

    # Boolean (z, y, x) masks.
    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    corr = corridor[0].cpu().numpy() > 0.5

    in_corr = structure & corr
    outside = structure & ~corr

    fig = plt.figure(figsize=(16, 4))

    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any(): ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if access.any(): ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if outside.any(): ax1.voxels(outside.transpose(1,2,0), facecolors='orange', alpha=0.6)
    if in_corr.any(): ax1.voxels(in_corr.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_title(title)

    # Plan: (y, x) images, transposed so the axes match the 3D panel's layout.
    ax2 = fig.add_subplot(142)
    plan = np.zeros((G, G, 3))
    plan[existing[0]] = [0.5, 0.5, 0.5]
    plan[corr[0] & ~existing[0]] = [0.8, 0.9, 1.0]
    plan[in_corr.max(axis=0)] = [0.2, 0.4, 0.8]
    plan[outside.max(axis=0)] = [1.0, 0.6, 0.2]
    plan[access.max(axis=0)] = [0.2, 0.8, 0.2]
    ax2.imshow(plan.transpose(1, 0, 2), origin='lower')
    ax2.set_title('Ground')

    # Elevation: projection along y gives a (z, x) image; plotted untransposed
    # so rows (z) run up the vertical axis.
    ax3 = fig.add_subplot(143)
    elev = np.zeros((G, G, 3))
    elev[existing.max(axis=1)] = [0.5, 0.5, 0.5]
    elev[in_corr.max(axis=1)] = [0.2, 0.4, 0.8]
    elev[outside.max(axis=1)] = [1.0, 0.6, 0.2]
    elev[access.max(axis=1)] = [0.2, 0.8, 0.2]
    ax3.imshow(elev, origin='lower')
    ax3.set_xlabel('x')
    ax3.set_ylabel('z')
    ax3.set_title('Elevation')

    ax4 = fig.add_subplot(144)
    ax4.axis('off')
    struct = grown[:, config['ch_structure']]
    existing_t = grown[:, config['ch_existing']]

    fill = struct.sum() / ((1 - existing_t).sum() + 1e-8)
    spill = (struct * (1 - corridor)).sum() / (struct.sum() + 1e-8)
    ground = struct[:, :2].sum() / (struct.sum() + 1e-8)
    facade = (struct * (F.max_pool3d(existing_t.unsqueeze(1), 3, 1, 1).squeeze(1) - existing_t).clamp(0,1)).sum() / (struct.sum() + 1e-8)

    # One metric per line (adjacent string literals concatenate, so the
    # newlines must be explicit).
    txt = (
        'METRICS (Step C)\n'
        f'Spill: {spill.item()*100:.1f}%\n'
        f'Fill: {fill.item()*100:.1f}%\n'
        f'Ground: {ground.item()*100:.1f}%\n'
        f'Facade: {facade.item()*100:.1f}%'
    )
    ax4.text(0.1, 0.9, txt, transform=ax4.transAxes, fontsize=10, va='top', family='monospace')

    plt.tight_layout()
    plt.show()


## 10. Training

In [None]:
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = StepCTrainer(model, CONFIG, device)

# torch.save does not create parent folders, so make sure the checkpoint
# directory exists before the first periodic save.
checkpoint_dir = f"{PROJECT_ROOT}/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
print('Starting Step C training...')

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    m = trainer.train_epoch(epoch)

    if epoch % CONFIG['log_every'] == 0:
        tqdm.write(
            f"Epoch {epoch:4d} | Loss: {m['total_loss']:.3f} | "
            f"Conn: {m['connectivity']*100:5.1f}% | "
            f"Spill: {m['spill']*100:5.1f}% | "
            f"Fill: {m['fill_ratio']*100:5.1f}% | "
            f"Ground: {m['ground_ratio']*100:5.1f}% | "
            f"Facade: {m['facade_ratio']*100:5.1f}%"
        )

    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize_result(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')

    if epoch > 0 and epoch % CONFIG['save_every'] == 0:
        trainer.save_checkpoint(f"{checkpoint_dir}/stepc_epoch_{epoch}.pth")

# Periodic saves only fire on positive multiples of save_every, which need not
# include the last epoch, so always persist the final weights explicitly.
trainer.save_checkpoint(f"{checkpoint_dir}/stepc_final.pth")

print('Training complete')


## 11. Evaluation

In [None]:
print('Evaluating...')
res = trainer.evaluate(n_samples=50)

print('STEP C EVALUATION')
# (label, key) pairs in display order; every metric is printed as a percentage.
for label, key in (
    ('Connectivity', 'avg_connectivity'),
    ('Spill', 'avg_spill'),
    ('Fill', 'avg_fill_ratio'),
    ('Ground', 'avg_ground_ratio'),
    ('Facade', 'avg_facade_ratio'),
):
    print(f"{label}: {res[key]*100:.1f}%")
