# NB02: Phase 1 - Core Constraints

**Constraint-Based Architectural NCA**

**Version:** 1.6
**Date:** December 2025
**Purpose:** Train Phase 1 - Core Constraints (C1A, C1B, C3A, C4B, C4C + Quality Losses)

---

## Aims

1. Implement ConnectivityLoss (C1A) - No floating voxels
2. Implement CantileverLoss (C1B) - Limit horizontal overhangs
3. Implement SparsityLoss (C3A) - Prevent fill-everything solution
4. Implement AccessReachLoss (C4B) - Structure/walkable must reach access points
5. Implement WalkableCoverageLoss (C4C) - Walkable surfaces must develop
6. Implement ExclusionLoss - No structure inside existing buildings
7. Implement DensityPenalty (SIMP) - Encourage binary outputs
8. Implement TotalVariation - Smooth surfaces
9. Implement DiceLoss for structure growth incentive
10. Implement WalkableDiceLoss for walkable growth incentive
11. Train on easy scenes with all core constraints
12. Save Phase 1 checkpoint
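Aim 3's SparsityLoss is a hinge on the fill ratio (structure volume divided by available volume). It is zero inside the 3-25% band and linear outside it, and the under-fill side is scaled by 0.3. The scalar sketch below applies the same formula to a plain float. The function name is hypothetical; the notebook's own version operates on tensors.

```python
def sparsity_penalty(fill_ratio: float, max_ratio: float = 0.25,
                     min_ratio: float = 0.03) -> float:
    """Scalar version of SparsityLoss: fill_ratio = structure / available volume."""
    over = max(0.0, fill_ratio - max_ratio)          # too much material
    under = 0.3 * max(0.0, min_ratio - fill_ratio)   # too little (down-weighted)
    return over + under


for r in (0.01, 0.03, 0.10, 0.25, 0.40):
    print(f'fill ratio {r:.2f} -> penalty {sparsity_penalty(r):.4f}')
```

For each unit of distance outside the band, over-filling costs about 3.3 times as much as under-filling. This ratio holds before the trainer applies its `sparsity` weight of 20.0.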

## Success Criteria

- Connectivity rate >95% on easy scenes
- Cantilever compliance >90%
- Fill ratio between 3-25%
- Access reach >80%
- Walkable coverage 20-100%
- Overlap ratio <1% (no structure inside buildings)
- Structure is non-blobby (distributed, not solid mass)
- Model produces vertical growth reaching elevated access points
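The quantitative criteria above can be checked mechanically once the per-scene rates have been aggregated. This is a minimal sketch that assumes a hypothetical `metrics` dict whose values are fractions in [0, 1]. Only `fill_ratio` and `walk_coverage` appear by name in the trainer's metrics below. The other keys, and how they would be computed, are assumptions. The two qualitative criteria (non-blobby structure, vertical growth) are not covered.

```python
# Hypothetical helper: thresholds copied from the Success Criteria list.
# Bounds are treated as inclusive for simplicity.
PHASE1_CRITERIA = {
    'connectivity_rate': (0.95, None),      # > 95%
    'cantilever_compliance': (0.90, None),  # > 90%
    'fill_ratio': (0.03, 0.25),             # 3-25%
    'access_reach': (0.80, None),           # > 80%
    'walk_coverage': (0.20, 1.00),          # 20-100%
    'overlap_ratio': (None, 0.01),          # < 1%
}


def check_phase1_criteria(metrics: dict) -> dict:
    """Return {criterion: passed} for every criterion present in `metrics`."""
    results = {}
    for name, (low, high) in PHASE1_CRITERIA.items():
        if name not in metrics:
            continue
        value = metrics[name]
        ok = True
        if low is not None:
            ok = ok and value >= low
        if high is not None:
            ok = ok and value <= high
        results[name] = ok
    return results


example = {'connectivity_rate': 0.97, 'fill_ratio': 0.31, 'overlap_ratio': 0.0}
print(check_phase1_criteria(example))
```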

## Dependencies

- NB01_Foundation (core components)

## Key Fixes

**v1.6:**
- **Perception padding**: REPLICATE padding instead of zero-padding to eliminate boundary gradient artifacts (systematic x=0 bias)
- **Access point randomization**: Full randomization across valid locations (ground in gap, facades facing gap - not rooftops, not back facades)
- **Axis labels**: Correct tensor dimension mapping in visualizations
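The padding fix can be demonstrated with a central difference along x on a constant field. This is a stripped-down stand-in for the Sobel-x filter in `Perceive3D`. With zero padding, each grid edge looks like a step from 0.5 down to 0, so perception reports a spurious gradient at x=0 and x=W-1. With replicate padding (NumPy's `mode='edge'`, analogous to PyTorch's `mode='replicate'`), no such gradient appears. The sketch uses NumPy and a hypothetical helper name. It shows one plausible source of the edge bias; it is not a reproduction of the training run.

```python
import numpy as np


def central_diff_x(vol: np.ndarray, mode: str) -> np.ndarray:
    """f(x+1) - f(x-1) along the last (W = x) axis of a (D, H, W) volume."""
    padded = np.pad(vol, ((0, 0), (0, 0), (1, 1)), mode=mode)
    return padded[:, :, 2:] - padded[:, :, :-2]


field = np.full((4, 4, 6), 0.5)             # constant field: true gradient is 0
g_zero = central_diff_x(field, 'constant')  # zero padding
g_edge = central_diff_x(field, 'edge')      # replicate padding

print(g_zero[0, 0])  # spurious +0.5 at x=0 and -0.5 at x=W-1
print(g_edge[0, 0])  # all zeros
```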

**v1.5:**
- DensityPenalty (SIMP): `(s * (1-s)).mean()` - penalizes intermediate values
- TotalVariation: Smooth surfaces
- Quality ramp-up: Quality losses start after epoch 100
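The two v1.5 quality terms complement each other, which three toy volumes make visible: a uniform 0.5 field, a clean binary block, and a binary checkerboard. The density term only looks at each voxel's own value, so it scores the checkerboard as perfect. Total variation is what catches that kind of noise. The NumPy sketch below mirrors the formulas of `DensityPenalty` and `TotalVariation3D`. The helper names are hypothetical, and the input is assumed to be an unbatched `(D, H, W)` volume.

```python
import numpy as np


def density_penalty(s: np.ndarray) -> float:
    # Same formula as DensityPenalty: mean(s * (1 - s))
    return float((s * (1.0 - s)).mean())


def total_variation_3d(s: np.ndarray) -> float:
    # Same idea as TotalVariation3D: mean |neighbor difference| per axis, summed
    return float(sum(np.abs(np.diff(s, axis=a)).mean() for a in range(3)))


shape = (6, 6, 6)
fuzzy = np.full(shape, 0.5)                   # undecided everywhere
block = np.zeros(shape)
block[:3, 2:4, 2:4] = 1.0                     # clean binary solid
checker = (np.indices(shape).sum(axis=0) % 2).astype(float)  # binary noise

for name, s in [('fuzzy', fuzzy), ('block', block), ('checker', checker)]:
    print(f'{name:8s} density={density_penalty(s):.3f}  tv={total_variation_3d(s):.3f}')
```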

**v1.4:**
- ExclusionLoss: No structure inside existing buildings
- Hard mask in NCA

---

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'
print(f'Project root: {PROJECT_ROOT}')

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, Optional, List
import json
from datetime import datetime
from tqdm.notebook import tqdm
import os

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# Set seeds
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

In [None]:
# Load configuration from NB01
with open(f'{PROJECT_ROOT}/config.json', 'r') as f:
    CONFIG = json.load(f)

# Add Phase 1 specific config
CONFIG.update({
    'phase': 1,
    'epochs': 400,
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',
    'access_type': 'mixed',  # Must include elevated access points
    'log_every': 20,
    'viz_every': 100,
    'save_every': 100,
})

print('Configuration loaded')
print(f"Access type: {CONFIG['access_type']} (includes elevated access points)")
print(f"Walkable bias: {CONFIG['walkable_bias']}")
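Besides the Phase 1 overrides above, this notebook reads several keys from the NB01 `config.json`. The list below was collected from the model, scene generator and trainer code visible in this notebook. Checking for these keys right after loading gives a clearer error than a `KeyError` deep inside training. `_step` also assumes that the grown channels start with structure and then walkable (see the channel-order comment in the model). That layout can be checked too.

This is a sketch. `check_config` and `check_channel_layout` are hypothetical helpers, and the demo uses small made-up dicts rather than the real `CONFIG`.

```python
# Keys read from the NB01 config by the code in this notebook
# (the Phase 1 overrides such as epochs and steps_min are set locally).
NB01_REQUIRED_KEYS = [
    'grid_size', 'n_channels', 'hidden_dim', 'n_grown', 'n_frozen',
    'xavier_gain', 'structure_bias', 'walkable_bias', 'fire_rate',
    'update_scale', 'ch_structure', 'ch_walkable', 'ch_existing',
    'ch_ground', 'ch_access', 'lr_initial', 'batch_size', 'grad_clip',
]


def check_config(config: dict, required=NB01_REQUIRED_KEYS) -> list:
    """Return the required keys that are missing from `config`."""
    return [key for key in required if key not in config]


def check_channel_layout(config: dict) -> bool:
    """_step treats grown channel 0 as structure and grown channel 1 as walkable."""
    return (config['ch_structure'] == config['n_frozen']
            and config['ch_walkable'] == config['n_frozen'] + 1)


demo = {'grid_size': 32, 'n_channels': 8}  # made-up partial config
missing = check_config(demo)
print(f'{len(missing)} missing keys, e.g. {missing[:3]}')
print(check_channel_layout({'n_frozen': 3, 'ch_structure': 3, 'ch_walkable': 4}))
```

In the notebook itself, you would call `check_config(CONFIG)` and `check_channel_layout(CONFIG)` right after the `CONFIG.update(...)` cell.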

## 2. Copy Foundation Components from NB01

*(In practice, these would be imported from a shared module. For Colab, we include them here.)*

In [None]:
# ============================================================
# PERCEPTION MODULE (from NB01 - with REPLICATE padding fix)
# ============================================================

class Perceive3D(nn.Module):
    """3D Sobel perception for NCA.
    
    Uses REPLICATE padding to avoid boundary artifacts that cause
    systematic growth bias toward grid edges.
    """

    def __init__(self, n_channels: int = 8):
        super().__init__()
        self.n_channels = n_channels

        sobel_x = self._create_sobel_kernel('x')
        sobel_y = self._create_sobel_kernel('y')
        sobel_z = self._create_sobel_kernel('z')
        identity = self._create_identity_kernel()

        kernels = torch.stack([identity, sobel_x, sobel_y, sobel_z], dim=0)
        self.register_buffer('kernels', kernels)

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])

        if direction == 'x':
            kernel = torch.einsum('i,j,k->ijk', smoothing, smoothing, derivative)
        elif direction == 'y':
            kernel = torch.einsum('i,j,k->ijk', smoothing, derivative, smoothing)
        elif direction == 'z':
            kernel = torch.einsum('i,j,k->ijk', derivative, smoothing, smoothing)
        return kernel / 16.0

    def _create_identity_kernel(self) -> torch.Tensor:
        kernel = torch.zeros(3, 3, 3)
        kernel[1, 1, 1] = 1.0
        return kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = x.shape
        
        # Use REPLICATE padding to avoid boundary artifacts
        x_padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        
        outputs = []
        for k in range(4):
            kernel = self.kernels[k:k+1].unsqueeze(0).expand(C, 1, 3, 3, 3)
            out = F.conv3d(x_padded, kernel, padding=0, groups=C)
            outputs.append(out)
        return torch.cat(outputs, dim=1)
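Tensors are laid out as `(B, C, D, H, W)`. As a result, the kernel named `'x'` differentiates along W, `'y'` along H, and `'z'` along D. A quick sanity check is to feed in a ramp that increases along x only. The x-response should then be constant in the interior, and the y- and z-responses should be zero.

The sketch below rebuilds the same separable kernels in NumPy. It applies them as a cross-correlation, which is what `F.conv3d` computes, with edge-replicated padding. The helper names are hypothetical. With replicate padding, the x-response at the two x-boundaries drops to half (a one-sided difference) rather than jumping to a large spurious value.

```python
import numpy as np


def sobel_kernel(direction: str) -> np.ndarray:
    """3x3x3 Sobel kernel indexed (d, h, w), like Perceive3D._create_sobel_kernel."""
    deriv = np.array([-1.0, 0.0, 1.0])
    smooth = np.array([1.0, 2.0, 1.0])
    factors = {'z': (deriv, smooth, smooth),   # derivative along D
               'y': (smooth, deriv, smooth),   # derivative along H
               'x': (smooth, smooth, deriv)}   # derivative along W
    a, b, c = factors[direction]
    return np.einsum('i,j,k->ijk', a, b, c) / 16.0


def correlate3d_replicate(vol: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Cross-correlation with edge-replicated padding; output has vol's shape."""
    padded = np.pad(vol, 1, mode='edge')
    D, H, W = vol.shape
    out = np.zeros(vol.shape)
    for a in range(3):
        for b in range(3):
            for c in range(3):
                out += kernel[a, b, c] * padded[a:a + D, b:b + H, c:c + W]
    return out


D, H, W = 5, 6, 7
ramp_x = np.broadcast_to(np.arange(W, dtype=float), (D, H, W)).copy()  # f = x
gx, gy, gz = (correlate3d_replicate(ramp_x, sobel_kernel(d)) for d in 'xyz')
print(gx[2, 3])  # 2.0 in the interior, 1.0 at the two x-boundaries
```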

In [None]:
# ============================================================
# NCA MODEL (v1.4 - with exclusion mask for existing buildings)
# ============================================================

class UrbanPavilionNCA(nn.Module):
    """Neural Cellular Automaton for urban pavilion generation.
    
    v1.4 FIX: Added hard mask to prevent structure/walkable inside existing buildings.
    v1.3 FIX: Added hard mask to ensure walkable ONLY exists where structure exists.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        n_channels = config['n_channels']
        hidden_dim = config['hidden_dim']
        perception_dim = n_channels * 4
        n_grown = config['n_grown']

        self.perceive = Perceive3D(n_channels)

        self.update_net = nn.Sequential(
            nn.Conv3d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, n_grown, 1),
        )

        self._init_weights()

    def _init_weights(self):
        gain = self.config['xavier_gain']
        for m in self.update_net:
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        last_layer = self.update_net[-1]
        with torch.no_grad():
            last_layer.bias[0] = self.config['structure_bias']
            last_layer.bias[1] = self.config['walkable_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = state.shape
        cfg = self.config

        perception = self.perceive(state)
        delta = self.update_net(perception)

        if self.training:
            fire_mask = (torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']).float()
            delta = delta * fire_mask

        # Soft gate for walkable updates (gradient-friendly)
        structure = state[:, cfg['ch_structure']:cfg['ch_structure']+1]
        gate = torch.sigmoid(10.0 * (structure - 0.3))

        delta_gated = torch.cat([
            delta[:, 0:1],  # structure delta (unchanged)
            delta[:, 1:2] * gate,  # walkable delta (gated)
            delta[:, 2:],  # other channels (unchanged)
        ], dim=1)

        # Apply updates to grown channels
        grown_start = cfg['n_frozen']
        grown_new = state[:, grown_start:] + cfg['update_scale'] * delta_gated
        grown_new = torch.clamp(grown_new, 0.0, 1.0)

        # ====== HARD MASKS (v1.4) - avoid in-place operations ======
        
        # Get existing buildings mask (frozen, doesn't change)
        existing = state[:, cfg['ch_existing']:cfg['ch_existing']+1]
        available_mask = 1.0 - existing  # Where structure CAN exist
        
        # Extract structure and walkable from grown channels
        # grown channels order: structure (0), walkable (1), alive (2), hidden (3)
        struct_new = grown_new[:, 0:1] * available_mask  # Mask out existing buildings
        
        # Walkable: only where structure exists AND not in existing buildings
        walkable_mask = (struct_new > 0.3).float()  # Must be on structure
        walk_new = grown_new[:, 1:2] * walkable_mask
        
        # Reconstruct grown channels (no in-place modification)
        grown_masked = torch.cat([
            struct_new,
            walk_new,
            grown_new[:, 2:],  # alive and hidden unchanged
        ], dim=1)

        # Reconstruct full state (frozen channels unchanged)
        new_state = torch.cat([
            state[:, :grown_start],  # frozen channels
            grown_masked,  # masked grown channels
        ], dim=1)

        return new_state

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        self.eval()
        with torch.no_grad():
            return self.forward(seed, steps)
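The v1.3/v1.4 masking rules in `_step` reduce to two elementwise operations applied after the clamped update. A soft sigmoid gate is also applied to the walkable delta before those operations. In `_step`, the soft gate reads structure from the incoming state, while the hard mask uses the updated structure.

The `> 0.3` comparison in the hard mask passes no gradient through itself. That is presumably why the gradient-friendly soft gate also exists. The minimal NumPy sketch below runs both steps on a toy row of four voxels. The helper name and values are illustrative only.

```python
import numpy as np


def apply_hard_masks(struct, walk, existing, threshold=0.3):
    """Mirror of the hard masks in UrbanPavilionNCA._step (elementwise)."""
    struct = struct * (1.0 - existing)   # no structure inside existing buildings
    walk = walk * (struct > threshold)   # walkable only where structure exists
    return struct, walk


struct = np.array([0.9, 0.9, 0.2, 0.6])
walk = np.array([0.8, 0.8, 0.8, 0.8])
existing = np.array([0.0, 1.0, 0.0, 0.0])  # voxel 1 is inside a building

s, w = apply_hard_masks(struct, walk, existing)
print(s, w)  # structure removed at voxel 1; walkable kept only at voxels 0 and 3

# Soft gate that scales the walkable *delta* before the hard mask
gate = 1.0 / (1.0 + np.exp(-10.0 * (struct - 0.3)))
print(gate.round(3))
```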

In [None]:
# ============================================================
# SCENE GENERATOR (v1.6 - Randomized access point placement)
# ============================================================

class UrbanSceneGenerator:
    """Generate urban scenes with buildings and access points.
    
    v1.6 Updates:
    - Full randomization of access point placement
    - Access points can be anywhere on facades facing the gap, or on ground
    - NOT inside buildings, NOT on rooftops, NOT on back facades
    - Tracks gap_facing_x for each building to identify valid facade positions
    """

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']
        self.C = config['n_channels']

    def generate(self, difficulty: str = 'easy', access_type: str = 'mixed',
                 device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        G = self.G
        state = torch.zeros(1, self.C, G, G, G, device=device)
        state[:, self.config['ch_ground'], 0, :, :] = 1.0

        params = self._get_difficulty_params(difficulty)
        building_info = self._place_buildings(state, params)
        access_info = self._place_access_points(state, access_type, params, building_info)

        metadata = {
            'difficulty': difficulty,
            'access_type': access_type,
            'buildings': building_info,
            'access_points': access_info,
            'gap_width': params['gap_width'],
        }
        return state, metadata

    def _get_difficulty_params(self, difficulty: str) -> dict:
        G = self.G
        if difficulty == 'easy':
            return {
                'n_buildings': 2, 'height_range': (12, 16), 'height_variance': False,
                'width_range': (8, 12), 'gap_width': random.randint(12, 16),
                'n_ground_access': 1, 'n_elevated_access': 1,
            }
        elif difficulty == 'medium':
            return {
                'n_buildings': 2, 'height_range': (10, 20), 'height_variance': True,
                'width_range': (6, 10), 'gap_width': random.randint(8, 12),
                'n_ground_access': random.randint(1, 2), 'n_elevated_access': random.randint(1, 2),
            }
        elif difficulty == 'hard':
            return {
                'n_buildings': random.randint(2, 4), 'height_range': (8, 24), 'height_variance': True,
                'width_range': (5, 8), 'gap_width': random.randint(5, 8),
                'n_ground_access': random.randint(2, 3), 'n_elevated_access': random.randint(2, 3),
            }
        else:
            return {
                'n_buildings': random.randint(2, 4), 'height_range': (8, 24), 'height_variance': True,
                'width_range': (5, 12), 'gap_width': random.randint(5, 16),
                'n_ground_access': random.randint(1, 3), 'n_elevated_access': random.randint(1, 3),
            }

    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        G = self.G
        ch = self.config['ch_existing']
        buildings = []
        gap_width = params['gap_width']
        gap_center = G // 2

        # Building 1 (left of gap)
        w1 = random.randint(*params['width_range'])
        d1 = random.randint(G//2, G-2)
        h1 = random.randint(*params['height_range'])
        x1_end = gap_center - gap_width // 2
        x1_start = max(0, x1_end - w1)
        state[:, ch, :h1, :d1, x1_start:x1_end] = 1.0
        buildings.append({
            'x': (x1_start, x1_end), 
            'y': (0, d1), 
            'z': (0, h1),
            'gap_facing_x': x1_end  # Facade facing the gap
        })

        # Building 2 (right of gap)
        w2 = random.randint(*params['width_range'])
        d2 = random.randint(G//2, G-2)
        h2 = h1 if not params['height_variance'] else random.randint(*params['height_range'])
        x2_start = gap_center + gap_width // 2
        x2_end = min(G, x2_start + w2)
        state[:, ch, :h2, :d2, x2_start:x2_end] = 1.0
        buildings.append({
            'x': (x2_start, x2_end), 
            'y': (0, d2), 
            'z': (0, h2),
            'gap_facing_x': x2_start  # Facade facing the gap
        })

        return buildings

    def _place_access_points(self, state: torch.Tensor, access_type: str,
                             params: dict, buildings: list) -> list:
        G = self.G
        ch = self.config['ch_access']
        access_points = []
        
        n_ground = params.get('n_ground_access', 1)
        n_elevated = params.get('n_elevated_access', 1)
        
        # Compute gap boundaries from the facades closest to the gap:
        # the innermost left facade and the innermost right facade
        gap_x_min = max(b['gap_facing_x'] for b in buildings if b['x'][1] <= G // 2)
        gap_x_max = min(b['gap_facing_x'] for b in buildings if b['x'][0] >= G // 2)
        
        # Ground access points: anywhere in the gap on the ground
        for i in range(n_ground):
            # Full randomization across the gap width
            x = random.randint(gap_x_min, gap_x_max - 2)
            # Full randomization across Y (depth)
            y = random.randint(0, G - 3)
            z = 0
            
            state[:, ch, z:z+2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': z, 'type': 'ground'})
        
        # Elevated access points: on building facades FACING THE GAP
        for i in range(n_elevated):
            # Choose a random building
            building = random.choice(buildings)
            bx_start, bx_end = building['x']
            by_start, by_end = building['y']
            bz_max = building['z'][1]
            gap_facing_x = building['gap_facing_x']
            
            # Determine if this is left or right building
            is_left_building = bx_end <= G // 2
            
            # Random height on the facade (not ground level, not rooftop)
            z = random.randint(2, bz_max - 2)
            
            # Random depth along the facade
            y = random.randint(by_start, by_end - 2)
            
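            # X position: just outside the facade facing the gap.
            # Access points are 2 voxels wide (x:x+2), so on the right building
            # the footprint must end before the building's first column.
            if is_left_building:
                x = gap_facing_x  # exclusive x-end of the left building = first free column
            else:
                x = gap_facing_x - 2  # gap_facing_x is the right building's first column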
            
            # Clamp to valid range
            x = max(0, min(G - 2, x))
            y = max(0, min(G - 2, y))
            z = max(1, min(G - 2, z))
            
            state[:, ch, z:z+2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': z, 'type': 'elevated'})

        return access_points

    def batch(self, difficulty: str, access_type: str, batch_size: int, device: str) -> torch.Tensor:
        scenes = [self.generate(difficulty, access_type, device)[0] for _ in range(batch_size)]
        return torch.cat(scenes, dim=0)


# Test scene generator
print('Testing scene generator with randomized access points (v1.6)...')
test_gen = UrbanSceneGenerator(CONFIG)
test_scene, test_meta = test_gen.generate('easy', 'mixed', device)
print(f"  Access points: {len(test_meta['access_points'])}")
for ap in test_meta['access_points']:
    print(f"    - {ap['type']}: z={ap['z']}, y={ap['y']}, x={ap['x']}")
print('  ✓ Scene generator working')
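Elevated access points are stamped as 2x2x2 blocks starting at `(z, y, x)`. The starting x therefore decides whether the footprint touches the facade from outside or cuts one column into the building. `gap_facing_x` means different things on the two sides. For the left building it is the exclusive x-end. For the right building it is the first occupied column. As a result, the two sides need different offsets.

The pure-Python sketch below shows that bookkeeping. The helper names and the example building extents are hypothetical.

```python
def facade_access_x(gap_facing_x: int, is_left_building: bool, width: int = 2) -> int:
    """Start x of an access-point footprint placed just outside a gap-facing facade.

    Left building: gap_facing_x is its exclusive x-end, i.e. the first free column.
    Right building: gap_facing_x is its first occupied column, so step back `width`.
    """
    return gap_facing_x if is_left_building else gap_facing_x - width


def footprint(x_start: int, width: int = 2) -> set:
    """Columns covered by a block of `width` voxels starting at x_start."""
    return set(range(x_start, x_start + width))


left_cols = set(range(2, 10))    # left building occupies x in [2, 10)
right_cols = set(range(20, 28))  # right building occupies x in [20, 28)

left_fp = footprint(facade_access_x(10, True))    # {10, 11}
right_fp = footprint(facade_access_x(20, False))  # {18, 19}
print(left_fp & left_cols, right_fp & right_cols)  # both empty: no overlap

# Stepping back only one column on the right side would cut into the building
print(footprint(20 - 1) & right_cols)  # {20}
```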

## 3. Loss Functions

In [None]:
class ConnectivityLoss(nn.Module):
    """
    Constraint 1A: No floating voxels.

    All structure voxels must be connected to support (ground + buildings)
    via a continuous path of solid material.

    Uses differentiable flood-fill via iterative max-pooling.
    """

    def __init__(self, threshold: float = 0.3, iterations: int = 32):
        super().__init__()
        self.threshold = threshold
        self.iterations = iterations

    def forward(self, structure: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
        struct_soft = torch.sigmoid(10 * (structure - self.threshold))
        connected = support.clone()

        for _ in range(self.iterations):
            dilated = F.max_pool3d(
                connected.unsqueeze(1), kernel_size=3, stride=1, padding=1
            ).squeeze(1)
            new_connected = torch.max(connected, dilated * struct_soft)
            if torch.allclose(connected, new_connected, atol=1e-5):
                break
            connected = new_connected

        disconnected = structure * (1.0 - connected)
        loss = disconnected.sum() / (structure.sum() + 1e-8)
        return loss


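class CantileverLoss(nn.Module):
    """
    Constraint 1B: Limit horizontal overhangs.

    A voxel at height z counts as supported if some structure voxel lies in
    its 3x3 horizontal neighborhood in any of the `max_overhang` layers
    directly below it. Layers z < max_overhang are not checked, and only the
    structure channel counts as support (not ground or existing buildings).
    """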

    def __init__(self, max_overhang: int = 3, threshold: float = 0.3):
        super().__init__()
        self.max_overhang = max_overhang
        self.threshold = threshold

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        B, D, H, W = structure.shape
        N = self.max_overhang
        total_loss = 0.0
        count = 0

        for z in range(N, D):
            layer = structure[:, z]
            support_volume = structure[:, max(0, z-N):z]
            support_max = support_volume.max(dim=1)[0]
            support_dilated = F.max_pool2d(
                support_max.unsqueeze(1), kernel_size=3, stride=1, padding=1
            ).squeeze(1)
            has_support = torch.sigmoid(10 * (support_dilated - self.threshold))
            unsupported = layer * (1.0 - has_support)
            total_loss += unsupported.mean()
            count += 1

        return total_loss / max(1, count)


class DiceLoss(nn.Module):
    """Growth incentive via Dice loss."""

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, structure: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        intersection = (structure * target).sum()
        dice = (2 * intersection + self.smooth) / (
            structure.sum() + target.sum() + self.smooth
        )
        return 1.0 - dice


class SparsityLoss(nn.Module):
    """
    Constraint 3A (Baseline): Limit total volume.
    Prevents trivial "fill everything" solution.
    """

    def __init__(self, max_ratio: float = 0.25, min_ratio: float = 0.03):
        super().__init__()
        self.max_ratio = max_ratio
        self.min_ratio = min_ratio

    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        ratio = structure.sum() / (available.sum() + 1e-8)
        over_penalty = F.relu(ratio - self.max_ratio)
        under_penalty = 0.3 * F.relu(self.min_ratio - ratio)
        return over_penalty + under_penalty


class ExclusionLoss(nn.Module):
    """
    Prevent structure from growing inside existing buildings.
    
    Structure should NEVER overlap with existing buildings. The NCA already
    enforces this with a hard mask (v1.4), so this loss should stay at zero;
    it acts as a safeguard and as a monitor for that mask.
    """

    def __init__(self):
        super().__init__()

    def forward(self, structure: torch.Tensor, existing: torch.Tensor) -> torch.Tensor:
        overlap = (structure * existing).sum()
        loss = overlap / (structure.sum() + 1e-8)
        return loss


# ============================================================
# QUALITY LOSSES (NEW v1.5 - from reference implementation)
# ============================================================

class DensityPenalty(nn.Module):
    """
    SIMP-inspired density penalty for binary outputs.
    
    Penalizes intermediate values (between 0 and 1), encouraging the NCA
    to commit to either solid (1) or void (0). This prevents blobby,
    fuzzy structures.
    
    Formula: mean(s * (1 - s))
    - Maximum at s=0.5 (penalty = 0.25)
    - Zero at s=0 or s=1
    """

    def __init__(self):
        super().__init__()

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        # s * (1-s) is maximum at 0.5, zero at 0 and 1
        penalty = (structure * (1.0 - structure)).mean()
        return penalty


class TotalVariation3D(nn.Module):
    """
    Total Variation loss for smooth 3D surfaces.
    
    Penalizes differences between neighboring voxels, encouraging
    smoother, more coherent surfaces instead of noisy patterns.
    
    Computes the mean absolute difference between neighboring voxels along
    each of the D, H and W axes, then sums the three terms.
    """

    def __init__(self):
        super().__init__()

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        # Structure shape: (B, D, H, W)
        # Compute differences along each axis
        tv_d = (structure[:, 1:, :, :] - structure[:, :-1, :, :]).abs().mean()
        tv_h = (structure[:, :, 1:, :] - structure[:, :, :-1, :]).abs().mean()
        tv_w = (structure[:, :, :, 1:] - structure[:, :, :, :-1]).abs().mean()
        return tv_d + tv_h + tv_w


class AccessReachLoss(nn.Module):
    """
    Constraint 4B (Baseline): Structure and walkable must reach access points.
    
    Critical for forcing vertical growth when elevated access points exist.
    """

    def __init__(self, dilation_radius: int = 3):
        super().__init__()
        self.dilation_radius = dilation_radius

    def forward(self, structure: torch.Tensor, walkable: torch.Tensor, 
                access: torch.Tensor) -> torch.Tensor:
        # Dilate access points to create reach zone
        kernel_size = 2 * self.dilation_radius + 1
        access_dilated = F.max_pool3d(
            access.unsqueeze(1), kernel_size=kernel_size, stride=1, 
            padding=self.dilation_radius
        ).squeeze(1)
        
        # Structure should reach access zones
        struct_reach = (structure * access_dilated).sum() / (access_dilated.sum() + 1e-8)
        
        # Walkable should also reach access zones  
        walk_reach = (walkable * access_dilated).sum() / (access_dilated.sum() + 1e-8)
        
        # Combined loss (both need to reach)
        loss = (1.0 - struct_reach) + 0.5 * (1.0 - walk_reach)
        return loss


class WalkableCoverageLoss(nn.Module):
    """
    Constraint 4C (Baseline): Walkable surfaces must develop on structure.
    
    Ensures the model creates traversable surfaces, not just structural blobs.
    """

    def __init__(self, min_coverage: float = 0.2):
        super().__init__()
        self.min_coverage = min_coverage

    def forward(self, structure: torch.Tensor, walkable: torch.Tensor) -> torch.Tensor:
        # Coverage = walkable / structure
        coverage = walkable.sum() / (structure.sum() + 1e-8)
        
        # Penalize if coverage is below minimum
        loss = F.relu(self.min_coverage - coverage)
        return loss


class WalkableDiceLoss(nn.Module):
    """
    Direct growth incentive for walkable channel.
    
    Encourages walkable to grow where structure exists, especially near access points.
    This provides positive gradient for walkable growth, not just penalty for absence.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, walkable: torch.Tensor, structure: torch.Tensor,
                target: torch.Tensor) -> torch.Tensor:
        # Walkable should grow where structure exists AND target is active
        # target is the dilated access zone
        walkable_target = structure * target  # Only where structure exists near access
        
        intersection = (walkable * walkable_target).sum()
        dice = (2 * intersection + self.smooth) / (
            walkable.sum() + walkable_target.sum() + self.smooth
        )
        return 1.0 - dice


# Test losses
print('Testing loss functions...')

scene_gen = UrbanSceneGenerator(CONFIG)
scene, meta = scene_gen.generate('easy', 'mixed', device)

existing = scene[:, CONFIG['ch_existing']]
ground = scene[:, CONFIG['ch_ground']]
access = scene[:, CONFIG['ch_access']]
support = torch.clamp(ground + existing, 0, 1)
available = 1.0 - existing

fake_structure = torch.rand_like(ground) * 0.5
fake_walkable = torch.rand_like(ground) * 0.3

conn_loss = ConnectivityLoss()
L_conn = conn_loss(fake_structure, support)
print(f'  Connectivity loss: {L_conn.item():.4f}')

cant_loss = CantileverLoss()
L_cant = cant_loss(fake_structure)
print(f'  Cantilever loss: {L_cant.item():.4f}')

dice_loss = DiceLoss()
target = F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1)
L_dice = dice_loss(fake_structure, target)
print(f'  Dice loss: {L_dice.item():.4f}')

sparse_loss = SparsityLoss()
L_sparse = sparse_loss(fake_structure, available)
print(f'  Sparsity loss: {L_sparse.item():.4f}')

excl_loss = ExclusionLoss()
L_excl = excl_loss(fake_structure, existing)
print(f'  Exclusion loss: {L_excl.item():.4f}')

# NEW v1.5: Quality losses
density_loss = DensityPenalty()
L_density = density_loss(fake_structure)
print(f'  Density penalty: {L_density.item():.4f}')

tv_loss = TotalVariation3D()
L_tv = tv_loss(fake_structure)
print(f'  Total variation: {L_tv.item():.4f}')

reach_loss = AccessReachLoss()
L_reach = reach_loss(fake_structure, fake_walkable, access)
print(f'  Access reach loss: {L_reach.item():.4f}')

walk_cov_loss = WalkableCoverageLoss()
L_walk_cov = walk_cov_loss(fake_structure, fake_walkable)
print(f'  Walkable coverage loss: {L_walk_cov.item():.4f}')

walk_dice_loss = WalkableDiceLoss()
L_walk_dice = walk_dice_loss(fake_walkable, fake_structure, target)
print(f'  Walkable dice loss: {L_walk_dice.item():.4f}')

print('  ✓ All losses working')
print('\nAccess points in test scene:')
for ap in meta['access_points']:
    print(f"  - {ap['type']}: z={ap['z']}")
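The differentiable flood fill in `ConnectivityLoss` can be traced by hand on a tiny grid. The grid has a ground layer, a three-voxel grounded column, and optionally one floating voxel. One detail shows up clearly. Empty voxels have a soft value of sigmoid(10 * (0 - 0.3)) ≈ 0.047, not 0, so they relay a weak "connected" signal. A floating voxel two cells from the column therefore scores about 0.95 disconnected rather than exactly 1.

The sketch below is an unbatched `(D, H, W)` NumPy version without gradients. The helper names are hypothetical, and zero padding stands in for `max_pool3d`'s implicit padding. That substitution is valid here because all values are non-negative.

```python
import numpy as np


def dilate3(vol: np.ndarray) -> np.ndarray:
    """3x3x3 max filter (like max_pool3d, stride 1, padding 1) for non-negative input."""
    padded = np.pad(vol, 1, mode='constant')  # 0-padding is safe for values >= 0
    D, H, W = vol.shape
    out = np.zeros(vol.shape)
    for a in range(3):
        for b in range(3):
            for c in range(3):
                out = np.maximum(out, padded[a:a + D, b:b + H, c:c + W])
    return out


def connectivity_loss(structure, support, threshold=0.3, iterations=32):
    """Unbatched NumPy version of ConnectivityLoss.forward."""
    soft = 1.0 / (1.0 + np.exp(-10.0 * (structure - threshold)))
    connected = support.copy()
    for _ in range(iterations):
        new = np.maximum(connected, dilate3(connected) * soft)
        if np.allclose(connected, new, atol=1e-5):
            break
        connected = new
    return (structure * (1.0 - connected)).sum() / (structure.sum() + 1e-8)


D, H, W = 6, 5, 5
support = np.zeros((D, H, W))
support[0] = 1.0                # ground layer
column = np.zeros((D, H, W))
column[1:4, 2, 2] = 1.0         # grounded three-voxel column
floating = column.copy()
floating[5, 0, 0] = 1.0         # plus one floating voxel

print(f'{connectivity_loss(column, support):.4f}')    # close to 0
print(f'{connectivity_loss(floating, support):.4f}')  # roughly 1/4
```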

## 4. Phase 1 Training Loop
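The trainer below uses `quality_start_epoch = 100` and `quality_ramp_epochs = 100`, and Phase 1 runs for 400 epochs. Assuming epochs are counted from 0, the density and TV terms are therefore switched off for the first quarter of training. They ramp up linearly during the second quarter and run at full weight (15.0 and 2.0) for the second half. A standalone sketch of the schedule follows. The function name is hypothetical; it applies the same piecewise-linear rule as `_get_quality_weight`.

```python
def quality_multiplier(epoch: int, start: int = 100, ramp: int = 100) -> float:
    """Piecewise-linear schedule matching Phase1Trainer._get_quality_weight."""
    if epoch < start:
        return 0.0
    if epoch < start + ramp:
        return (epoch - start) / ramp
    return 1.0


for epoch in (0, 99, 100, 150, 199, 200, 399):
    m = quality_multiplier(epoch)
    print(epoch, m, f'density weight {15.0 * m:.1f}, tv weight {2.0 * m:.1f}')
```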

In [None]:
class Phase1Trainer:
    """
    Phase 1 Trainer: Core Constraints (v1.5)

    Constraints:
        - C1A: Connectivity (no floating voxels)
        - C1B: Cantilever (limit overhangs)
        - C3A: Sparsity (prevent fill-everything)
        - C4B: Access Reach (structure/walkable must reach access points)
        - C4C: Walkable Coverage (walkable must develop on structure)
        - Exclusion: No structure inside existing buildings
        - Dice: Growth incentive for structure
        - Walkable Dice: Growth incentive for walkable
        
    Quality Losses (NEW v1.5):
        - DensityPenalty: Encourage binary (0 or 1) outputs
        - TotalVariation: Smooth surfaces
        - Ramp-up: Quality losses start after epoch 100, increase gradually
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model = model
        self.config = config
        self.device = device

        # Loss functions
        self.dice_loss = DiceLoss()
        self.walk_dice_loss = WalkableDiceLoss()
        self.conn_loss = ConnectivityLoss()
        self.cant_loss = CantileverLoss()
        self.sparse_loss = SparsityLoss(max_ratio=0.25, min_ratio=0.03)
        self.excl_loss = ExclusionLoss()
        self.reach_loss = AccessReachLoss(dilation_radius=3)
        self.walk_cov_loss = WalkableCoverageLoss(min_coverage=0.2)
        
        # NEW v1.5: Quality losses
        self.density_loss = DensityPenalty()
        self.tv_loss = TotalVariation3D()

        # Loss weights (base weights)
        self.weights = {
            'dice': 10.0,
            'walk_dice': 8.0,
            'connectivity': 10.0,
            'cantilever': 5.0,
            'sparsity': 20.0,
            'exclusion': 50.0,
            'access_reach': 10.0,
            'walkable_cov': 5.0,
            # NEW v1.5: Quality weights (will be ramped up)
            'density': 15.0,  # Encourage binary outputs
            'tv': 2.0,        # Smooth surfaces (lower weight - don't over-smooth)
        }
        
        # NEW v1.5: Quality loss ramp-up schedule
        self.quality_start_epoch = 100   # Start quality losses after this epoch
        self.quality_ramp_epochs = 100   # Ramp up over this many epochs

        # Optimizer
        self.optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config['lr_initial']
        )

        # Scene generator
        self.scene_gen = UrbanSceneGenerator(config)

        # History
        self.history = []

    def _get_quality_weight(self, epoch: int) -> float:
        """
        Compute quality loss weight multiplier based on epoch.
        
        Ramp-up schedule:
        - Before quality_start_epoch: 0.0
        - During ramp: linear from 0.0 to 1.0
        - After ramp: 1.0
        """
        if epoch < self.quality_start_epoch:
            return 0.0
        elif epoch < self.quality_start_epoch + self.quality_ramp_epochs:
            progress = (epoch - self.quality_start_epoch) / self.quality_ramp_epochs
            return progress
        else:
            return 1.0

    def train_epoch(self, epoch: int) -> dict:
        """Train for one epoch."""
        self.model.train()
        cfg = self.config

        # Generate batch of scenes (with elevated access points)
        seeds = self.scene_gen.batch(
            cfg['difficulty'], cfg['access_type'],
            cfg['batch_size'], self.device
        )

        # Random number of steps
        steps = random.randint(cfg['steps_min'], cfg['steps_max'])

        # Forward pass
        final = self.model(seeds, steps=steps)

        # Extract channels
        existing = final[:, cfg['ch_existing']]
        ground = final[:, cfg['ch_ground']]
        access = final[:, cfg['ch_access']]
        structure = final[:, cfg['ch_structure']]
        walkable = final[:, cfg['ch_walkable']]

        support = torch.clamp(ground + existing, 0, 1)
        available = 1.0 - existing
        
        # Target for Dice: dilated access, but EXCLUDE existing buildings
        target_raw = F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1)
        target = target_raw * available  # Don't incentivize growth inside buildings

        # Compute core losses
        L_dice = self.dice_loss(structure, target)
        L_walk_dice = self.walk_dice_loss(walkable, structure, target)
        L_conn = self.conn_loss(structure, support)
        L_cant = self.cant_loss(structure)
        L_sparse = self.sparse_loss(structure, available)
        L_excl = self.excl_loss(structure, existing)
        L_reach = self.reach_loss(structure, walkable, access)
        L_walk_cov = self.walk_cov_loss(structure, walkable)
        
        # NEW v1.5: Compute quality losses
        L_density = self.density_loss(structure)
        L_tv = self.tv_loss(structure)
        
        # Get quality weight multiplier for this epoch
        quality_mult = self._get_quality_weight(epoch)

        # Total loss (with ramped quality losses)
        total_loss = (
            self.weights['dice'] * L_dice +
            self.weights['walk_dice'] * L_walk_dice +
            self.weights['connectivity'] * L_conn +
            self.weights['cantilever'] * L_cant +
            self.weights['sparsity'] * L_sparse +
            self.weights['exclusion'] * L_excl +
            self.weights['access_reach'] * L_reach +
            self.weights['walkable_cov'] * L_walk_cov +
            # Quality losses with ramp-up
            quality_mult * self.weights['density'] * L_density +
            quality_mult * self.weights['tv'] * L_tv
        )

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        # Compute metrics for monitoring
        fill_ratio = structure.sum() / (available.sum() + 1e-8)
        walk_coverage = walkable.sum() / (structure.sum() + 1e-8)
        
        # Track overlap with existing buildings
        overlap_ratio = (structure * existing).sum() / (structure.sum() + 1e-8)

        # Record metrics
        metrics = {
            'epoch': epoch,
            'total_loss': total_loss.item(),
            'dice': L_dice.item(),
            'walk_dice': L_walk_dice.item(),
            'connectivity': L_conn.item(),
            'cantilever': L_cant.item(),
            'sparsity': L_sparse.item(),
            'exclusion': L_excl.item(),
            'access_reach': L_reach.item(),
            'walkable_cov': L_walk_cov.item(),
            # NEW v1.5: Quality metrics
            'density': L_density.item(),
            'tv': L_tv.item(),
            'quality_mult': quality_mult,
            # Other metrics
            'fill_ratio': fill_ratio.item(),
            'walk_coverage': walk_coverage.item(),
            'overlap_ratio': overlap_ratio.item(),
            'steps': steps,
            'structure_mean': structure.mean().item(),
            'walkable_mean': walkable.mean().item(),
        }

        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 20) -> dict:
        """Evaluate model on test scenes."""
        self.model.eval()
        cfg = self.config

        results = []
        with torch.no_grad():
            for _ in range(n_samples):
                scene, meta = self.scene_gen.generate(
                    cfg['difficulty'], cfg['access_type'], self.device
                )

                # Grow structure
                grown = self.model.grow(scene, steps=50)

                # Extract channels
                existing = grown[:, cfg['ch_existing']]
                ground = grown[:, cfg['ch_ground']]
                access = grown[:, cfg['ch_access']]
                structure = grown[:, cfg['ch_structure']]
                walkable = grown[:, cfg['ch_walkable']]
                support = torch.clamp(ground + existing, 0, 1)
                available = 1.0 - existing

                # Compute connectivity rate
                struct_binary = (structure > 0.5).float()
                connected = support.clone()
                for _ in range(32):
                    dilated = F.max_pool3d(connected.unsqueeze(1), 3, 1, 1).squeeze(1)
                    connected = torch.max(connected, dilated * struct_binary)

                total_struct = struct_binary.sum().item()
                connected_struct = (connected * struct_binary).sum().item()
                conn_rate = connected_struct / (total_struct + 1e-8)

                # Check cantilever compliance
                cant_ok = self._check_cantilever(struct_binary)  # compliance fraction in [0, 1], not a bool

                # Compute fill ratio
                fill_ratio = struct_binary.sum().item() / (available.sum().item() + 1e-8)

                # Compute access reach
                access_dilated = F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1)
                struct_reach = ((structure > 0.5).float() * access_dilated).sum().item()
                access_reach = struct_reach / (access_dilated.sum().item() + 1e-8)

                # Compute walkable coverage
                walk_coverage = walkable.sum().item() / (total_struct + 1e-8)

                # Compute overlap with existing buildings
                overlap = (struct_binary * existing).sum().item()
                overlap_ratio = overlap / (total_struct + 1e-8)
                
                # NEW v1.5: Compute density penalty (for evaluation)
                density_penalty = (structure * (1.0 - structure)).mean().item()

                # Check if elevated access was reached
                elevated_reached = False
                for ap in meta['access_points']:
                    if ap['type'] == 'elevated':
                        z, y, x = ap['z'], ap['y'], ap['x']
                        region = structure[0, max(0,z-2):z+3, max(0,y-2):y+3, max(0,x-2):x+3]
                        if region.max() > 0.5:
                            elevated_reached = True
                            break

                results.append({
                    'connectivity_rate': conn_rate,
                    'cantilever_ok': cant_ok,
                    'voxel_count': total_struct,
                    'fill_ratio': fill_ratio,
                    'access_reach': access_reach,
                    'walkable_coverage': walk_coverage,
                    'overlap_ratio': overlap_ratio,
                    'density_penalty': density_penalty,
                    'elevated_reached': elevated_reached,
                })

        # Aggregate results
        return {
            'connectivity_rate': np.mean([r['connectivity_rate'] for r in results]),
            'cantilever_compliance': np.mean([r['cantilever_ok'] for r in results]),
            'avg_voxels': np.mean([r['voxel_count'] for r in results]),
            'avg_fill_ratio': np.mean([r['fill_ratio'] for r in results]),
            'avg_access_reach': np.mean([r['access_reach'] for r in results]),
            'avg_walkable_coverage': np.mean([r['walkable_coverage'] for r in results]),
            'avg_overlap_ratio': np.mean([r['overlap_ratio'] for r in results]),
            'avg_density_penalty': np.mean([r['density_penalty'] for r in results]),
            'elevated_reach_rate': np.mean([r['elevated_reached'] for r in results]),
            'n_samples': n_samples,
        }

    def _check_cantilever(self, structure: torch.Tensor, max_overhang: int = 3) -> float:
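        """Check cantilever compliance (hard threshold).

        Returns the fraction of occupied voxels at z >= max_overhang that have an
        occupied voxel somewhere in the max_overhang layers directly below, within
        a 3x3 horizontal neighbourhood. Layers z < max_overhang are not checked.
        """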
        B, D, H, W = structure.shape
        violations = 0
        total = 0

        for z in range(max_overhang, D):
            layer = structure[:, z]
            support_max = structure[:, max(0, z-max_overhang):z].max(dim=1)[0]
            support_dilated = F.max_pool2d(support_max.unsqueeze(1), 3, 1, 1).squeeze(1)

            unsupported = layer * (1 - (support_dilated > 0.5).float())
            violations += unsupported.sum().item()
            total += layer.sum().item()

        compliance = 1.0 - (violations / (total + 1e-8))
        return compliance

    def save_checkpoint(self, path: str):
        """Save model checkpoint."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'history': self.history,
        }, path)
        print(f'Checkpoint saved to {path}')
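
The hard cantilever check in `_check_cantilever` can be reproduced outside PyTorch to sanity-check it on hand-built voxel grids. The sketch below is a NumPy re-implementation under two assumptions: the input is an already-binarized `[D, H, W]` (Z, Y, X) array, and zero padding stands in for the implicit `-inf` padding of `F.max_pool2d`, which gives the same result on non-negative masks. The function names are hypothetical.

```python
import numpy as np


def max_filter_3x3(mask2d: np.ndarray) -> np.ndarray:
    """3x3 max filter with zero padding (matches F.max_pool2d(x, 3, 1, 1) on masks >= 0)."""
    padded = np.pad(mask2d, 1, mode="constant")
    h, w = mask2d.shape
    out = np.zeros_like(mask2d)
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


def cantilever_compliance(structure: np.ndarray, max_overhang: int = 3) -> float:
    """Fraction of checked voxels supported within `max_overhang` layers below.

    structure: binary float array [D, H, W] in (Z, Y, X) order.
    """
    depth = structure.shape[0]
    violations = 0.0
    total = 0.0
    for z in range(max_overhang, depth):
        layer = structure[z]
        # Any occupied voxel in the layers z-max_overhang .. z-1, per (y, x) column
        support_max = structure[z - max_overhang:z].max(axis=0)
        # Allow a one-voxel horizontal offset via the 3x3 dilation
        supported = max_filter_3x3(support_max) > 0.5
        violations += (layer * ~supported).sum()
        total += layer.sum()
    return float(1.0 - violations / (total + 1e-8))
```

A vertical column scores 1.0, while a voxel floating more than `max_overhang` layers above any support counts as a violation. Voxels in the bottom `max_overhang` layers are never checked, matching the loop range in `_check_cantilever`.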

## 5. Visualization Utilities

In [None]:
def visualize_result(model: nn.Module, scene_gen: UrbanSceneGenerator,
                     config: dict, device: str, title: str = 'Result'):
    """Visualize a grown structure with walkable surfaces.
    
    Axis labels reflect tensor dimensions after transpose(1,2,0):
    - Tensor [Z, Y, X] -> Plot axes [Y, X, Z]
    - Y = depth (into scene), X = left-right, Z = height
    """
    model.eval()

    scene, meta = scene_gen.generate(config['difficulty'], config['access_type'], device)

    with torch.no_grad():
        grown = model.grow(scene, steps=50)

    # Extract for visualization
    state = grown[0].cpu().numpy()
    existing = state[config['ch_existing']] > 0.5
    structure = state[config['ch_structure']] > 0.5
    walkable = state[config['ch_walkable']] > 0.3
    access = state[config['ch_access']] > 0.5
    
    # Check for overlap (structure inside existing buildings)
    overlap = structure & existing

    fig = plt.figure(figsize=(20, 5))

    # 3D view - transpose(1,2,0) maps [Z,Y,X] -> [Y,X,Z]
    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any():
        ax1.voxels(existing.transpose(1, 2, 0), facecolors='gray', alpha=0.3)
    if access.any():
        ax1.voxels(access.transpose(1, 2, 0), facecolors='green', alpha=0.9)
    if overlap.any():
        # Show overlap in RED (this is a problem!)
        ax1.voxels(overlap.transpose(1, 2, 0), facecolors='red', alpha=0.9)
    if structure.any():
        # Show structure that's NOT overlapping
        struct_clean = structure & ~existing
        if struct_clean.any():
            ax1.voxels(struct_clean.transpose(1, 2, 0), facecolors='royalblue', alpha=0.6)
    if walkable.any():
        walk_only = walkable & ~existing & ~structure
        if walk_only.any():
            ax1.voxels(walk_only.transpose(1, 2, 0), facecolors='yellow', alpha=0.4)
    ax1.set_title(f'{title} (3D)')
    # Correct axis labels: tensor [Z,Y,X].transpose(1,2,0) -> plot [Y,X,Z]
    ax1.set_xlabel('Y (depth)')
    ax1.set_ylabel('X (left-right)')
    ax1.set_zlabel('Z (height)')

    # Plan view (top-down) - max over Z axis
    ax2 = fig.add_subplot(142)
    G = state.shape[1]  # grid size; the 2D views below assume a cubic [G, G, G] grid
    plan = np.zeros((G, G, 3))
    plan[existing.max(axis=0)] = [0.5, 0.5, 0.5]  # Gray for existing
    struct_clean = structure & ~existing
    plan[struct_clean.max(axis=0)] = [0.2, 0.4, 0.8]  # Blue for clean structure
    plan[overlap.max(axis=0)] = [1.0, 0.0, 0.0]  # RED for overlap (bad!)
    plan[walkable.max(axis=0) & ~existing.max(axis=0)] = [0.8, 0.8, 0.2]  # Yellow for walkable
    plan[access.max(axis=0)] = [0.2, 0.8, 0.2]  # Green for access
    ax2.imshow(plan.transpose(1, 0, 2), origin='lower')
    ax2.set_title(f'{title} (Plan View)')
    ax2.set_xlabel('Y (depth)')
    ax2.set_ylabel('X (left-right)')

    # Side view (elevation) - max over Y axis
    ax3 = fig.add_subplot(143)
    elev = np.zeros((G, G, 3))
    elev[existing.max(axis=1)] = [0.5, 0.5, 0.5]
    elev[struct_clean.max(axis=1)] = [0.2, 0.4, 0.8]
    elev[overlap.max(axis=1)] = [1.0, 0.0, 0.0]  # RED for overlap
    elev[walkable.max(axis=1) & ~existing.max(axis=1)] = [0.8, 0.8, 0.2]
    elev[access.max(axis=1)] = [0.2, 0.8, 0.2]
    ax3.imshow(elev, origin='lower')  # elev is [Z, X]: X horizontal, Z vertical
    ax3.set_title(f'{title} (Elevation)')
    ax3.set_xlabel('X (left-right)')
    ax3.set_ylabel('Z (height)')

    # Cross-section at mid-Y
    ax4 = fig.add_subplot(144)
    mid_y = G // 2
    cross = np.zeros((G, G, 3))
    cross[existing[:, mid_y, :]] = [0.5, 0.5, 0.5]
    cross[struct_clean[:, mid_y, :]] = [0.2, 0.4, 0.8]
    cross[overlap[:, mid_y, :]] = [1.0, 0.0, 0.0]
    cross[walkable[:, mid_y, :] & ~existing[:, mid_y, :]] = [0.8, 0.8, 0.2]
    cross[access[:, mid_y, :]] = [0.2, 0.8, 0.2]
    ax4.imshow(cross, origin='lower')  # cross is [Z, X]: X horizontal, Z vertical
    ax4.set_title(f'{title} (Cross-section Y={mid_y})')
    ax4.set_xlabel('X (left-right)')
    ax4.set_ylabel('Z (height)')

    plt.tight_layout()
    plt.show()

    # Print stats
    struct_count = structure.sum()
    walk_count = walkable.sum()
    overlap_count = overlap.sum()
    walk_cov = walk_count / (struct_count + 1e-8)
    overlap_pct = overlap_count / (struct_count + 1e-8) * 100
    
    # Compute density penalty for this result
    structure_soft = state[config['ch_structure']]
    density_penalty = (structure_soft * (1.0 - structure_soft)).mean()
    
    print(f"  Structure: {struct_count:.0f} voxels, Walkable: {walk_count:.0f}, Coverage: {walk_cov:.1%}")
    print(f"  Density penalty: {density_penalty:.4f} (lower = more binary)")
    if overlap_count > 0:
        print(f"  WARNING: OVERLAP: {overlap_count:.0f} voxels ({overlap_pct:.1f}%) inside buildings!")
    else:
        print(f"  No overlap with existing buildings")
    print(f"  Access points: {len(meta['access_points'])}")
    for ap in meta['access_points']:
        print(f"    - {ap['type']}: z={ap['z']}, y={ap['y']}, x={ap['x']}")

    return grown, meta


def plot_training_curves(history: list):
    """Plot training loss curves including quality metrics."""
    epochs = [h['epoch'] for h in history]

    fig, axes = plt.subplots(2, 6, figsize=(26, 8))

    # Total loss
    axes[0, 0].plot(epochs, [h['total_loss'] for h in history])
    axes[0, 0].set_title('Total Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_yscale('log')

    # Structural losses
    axes[0, 1].plot(epochs, [h['dice'] for h in history], label='Dice (struct)')
    axes[0, 1].plot(epochs, [h.get('walk_dice', 0) for h in history], label='Dice (walk)')
    axes[0, 1].plot(epochs, [h['connectivity'] for h in history], label='Connectivity')
    axes[0, 1].plot(epochs, [h['cantilever'] for h in history], label='Cantilever')
    axes[0, 1].set_title('Core Losses')
    axes[0, 1].legend()
    axes[0, 1].set_xlabel('Epoch')

    # Sparsity loss
    axes[0, 2].plot(epochs, [h['sparsity'] for h in history], color='purple')
    axes[0, 2].set_title('Sparsity Loss')
    axes[0, 2].set_xlabel('Epoch')

    # Exclusion loss
    axes[0, 3].plot(epochs, [h.get('exclusion', 0) for h in history], color='red')
    axes[0, 3].set_title('Exclusion Loss')
    axes[0, 3].set_xlabel('Epoch')

    # Density penalty
    axes[0, 4].plot(epochs, [h.get('density', 0) for h in history], color='magenta')
    axes[0, 4].axhline(y=0.25, color='r', linestyle='--', alpha=0.5, label='Max (all 0.5)')
    axes[0, 4].set_title('Density Penalty (lower=binary)')
    axes[0, 4].set_xlabel('Epoch')
    axes[0, 4].legend()

    # Total Variation
    axes[0, 5].plot(epochs, [h.get('tv', 0) for h in history], color='cyan')
    axes[0, 5].set_title('Total Variation')
    axes[0, 5].set_xlabel('Epoch')

    # Fill ratio
    axes[1, 0].plot(epochs, [h['fill_ratio'] for h in history], color='orange')
    axes[1, 0].axhline(y=0.25, color='r', linestyle='--', label='Max (25%)')
    axes[1, 0].axhline(y=0.03, color='g', linestyle='--', label='Min (3%)')
    axes[1, 0].set_title('Fill Ratio')
    axes[1, 0].legend()
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylim(0, 0.5)

    # Walkable coverage
    axes[1, 1].plot(epochs, [h['walk_coverage'] for h in history], color='gold')
    axes[1, 1].axhline(y=0.2, color='g', linestyle='--', label='Min (20%)')
    axes[1, 1].set_title('Walkable Coverage')
    axes[1, 1].legend()
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylim(0, 1.5)

    # Overlap ratio
    axes[1, 2].plot(epochs, [h.get('overlap_ratio', 0) for h in history], color='red')
    axes[1, 2].axhline(y=0.0, color='g', linestyle='--', label='Target (0%)')
    axes[1, 2].set_title('Overlap Ratio (should be 0!)')
    axes[1, 2].legend()
    axes[1, 2].set_xlabel('Epoch')

    # Structure and walkable means
    axes[1, 3].plot(epochs, [h['structure_mean'] for h in history], label='Structure')
    axes[1, 3].plot(epochs, [h['walkable_mean'] for h in history], label='Walkable')
    axes[1, 3].set_title('Channel Means')
    axes[1, 3].legend()
    axes[1, 3].set_xlabel('Epoch')

    # Access reach loss
    axes[1, 4].plot(epochs, [h['access_reach'] for h in history], color='orange')
    axes[1, 4].set_title('Access Reach Loss')
    axes[1, 4].set_xlabel('Epoch')

    # Quality weight multiplier
    axes[1, 5].plot(epochs, [h.get('quality_mult', 0) for h in history], color='green')
    axes[1, 5].set_title('Quality Weight Ramp')
    axes[1, 5].set_xlabel('Epoch')
    axes[1, 5].set_ylim(-0.1, 1.1)

    plt.tight_layout()
    plt.show()
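
When checking the axis labels it helps to confirm which array index ends up on which screen axis. Matplotlib's `imshow` draws `A[row, col]` with rows on the vertical axis (row 0 at the bottom when `origin='lower'`) and columns on the horizontal axis, and `Axes3D.voxels` maps `filled[i, j, k]` to plot `(x, y, z) = (i, j, k)`. A minimal NumPy sketch, assuming the `[Z, Y, X]` layout stated in the `visualize_result` docstring and using deliberately unequal (hypothetical) grid sizes so each axis is identifiable:

```python
import numpy as np

# Hypothetical, deliberately unequal sizes so every axis is identifiable.
D, H, W = 4, 5, 6               # Z (height), Y (depth), X (left-right)
vol = np.zeros((D, H, W), dtype=bool)
vol[3, 1, 5] = True             # a single voxel at z=3, y=1, x=5

plan = vol.max(axis=0)          # top-down: reduce over Z -> [Y, X]
elev = vol.max(axis=1)          # elevation: reduce over Y -> [Z, X]
section = vol[:, 1, :]          # cross-section at y=1 -> [Z, X]
vol3d = vol.transpose(1, 2, 0)  # input for Axes3D.voxels -> [Y, X, Z]

# imshow(elev, origin='lower') puts Z (rows) vertically and X (columns)
# horizontally; imshow(elev.T, ...) would swap them and draw height left-to-right.
print(plan.shape, elev.shape, section.shape, vol3d.shape)
```

The printed shapes are `(5, 6) (4, 6) (4, 6) (5, 6, 4)`, so an `[Z, X]` projection needs no transpose to show height on the vertical axis.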

## 6. Training

In [None]:
# Initialize model and trainer
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = Phase1Trainer(model, CONFIG, device)

print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')
print(f'Training for {CONFIG["epochs"]} epochs')
print(f'Difficulty: {CONFIG["difficulty"]}')

In [None]:
# Visualize before training
print('Before training:')
visualize_result(model, trainer.scene_gen, CONFIG, device, 'Before Training')

In [None]:
# Training loop
print('\nStarting Phase 1 Training (Core Constraints v1.6)...')
print('='*70)
print('Constraints: C1A (Connectivity), C1B (Cantilever), C3A (Sparsity),')
print('             C4B (Access Reach), C4C (Walkable Coverage)')
print('             Exclusion (no structure inside buildings)')
print('NEW v1.5:    DensityPenalty (binary outputs), TotalVariation (smooth)')
print('='*70)
print(f"Walkable bias: {CONFIG['walkable_bias']}")
print(f"Quality losses: start epoch {trainer.quality_start_epoch}, ramp over {trainer.quality_ramp_epochs} epochs")
print('='*70)

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    metrics = trainer.train_epoch(epoch)

    # Log progress
    if epoch % CONFIG['log_every'] == 0:
        overlap_pct = metrics['overlap_ratio'] * 100
        overlap_str = f"Ovlp: {overlap_pct:.1f}%" if overlap_pct > 0.1 else "Ovlp: OK"
        quality_str = f"Dens: {metrics['density']:.3f}" if metrics['quality_mult'] > 0 else "Qual: OFF"
        tqdm.write(
            f"Epoch {epoch:4d} | Loss: {metrics['total_loss']:.3f} | "
            f"Fill: {metrics['fill_ratio']:.1%} | Walk: {metrics['walk_coverage']:.1%} | "
            f"{overlap_str} | {quality_str}"
        )

    # Visualize progress
    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize_result(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')

    # Save checkpoint
    if epoch > 0 and epoch % CONFIG['save_every'] == 0:
        trainer.save_checkpoint(f'{PROJECT_ROOT}/checkpoints/phase1_epoch{epoch}.pth')

print('='*70)
print('Training complete!')

## 7. Evaluation

In [None]:
# Plot training curves
plot_training_curves(trainer.history)

In [None]:
# Evaluate on test scenes
print('Evaluating on test scenes...')
eval_results = trainer.evaluate(n_samples=50)

print('\n' + '='*70)
print('PHASE 1 EVALUATION RESULTS (v1.6)')
print('='*70)
print(f"Connectivity Rate:     {eval_results['connectivity_rate']*100:.1f}% (target: >95%)")
print(f"Cantilever Compliance: {eval_results['cantilever_compliance']*100:.1f}% (target: >90%)")
print(f"Fill Ratio:            {eval_results['avg_fill_ratio']*100:.1f}% (target: 3-25%)")
print(f"Access Reach:          {eval_results['avg_access_reach']*100:.1f}% (target: >80%)")
print(f"Walkable Coverage:     {eval_results['avg_walkable_coverage']*100:.1f}% (target: >20%)")
print(f"Overlap Ratio:         {eval_results['avg_overlap_ratio']*100:.1f}% (target: 0%)")
print(f"Density Penalty:       {eval_results['avg_density_penalty']:.4f} (target: <0.1 = binary)")
print(f"Elevated Reach Rate:   {eval_results['elevated_reach_rate']*100:.1f}% (target: >70%)")
print(f"Average Voxel Count:   {eval_results['avg_voxels']:.0f}")
print('='*70)

# Check success criteria
conn_pass = eval_results['connectivity_rate'] > 0.95
cant_pass = eval_results['cantilever_compliance'] > 0.90
fill_pass = 0.03 < eval_results['avg_fill_ratio'] < 0.25
reach_pass = eval_results['avg_access_reach'] > 0.80
walk_pass = eval_results['avg_walkable_coverage'] > 0.20
overlap_pass = eval_results['avg_overlap_ratio'] < 0.01
density_pass = eval_results['avg_density_penalty'] < 0.1  # Binary-ish outputs
elev_pass = eval_results['elevated_reach_rate'] > 0.70

print(f"\nConnectivity:    {'PASS' if conn_pass else 'FAIL'}")
print(f"Cantilever:      {'PASS' if cant_pass else 'FAIL'}")
print(f"Fill Ratio:      {'PASS' if fill_pass else 'FAIL'}")
print(f"Access Reach:    {'PASS' if reach_pass else 'FAIL'}")
print(f"Walkable Cov:    {'PASS' if walk_pass else 'FAIL'}")
print(f"No Overlap:      {'PASS' if overlap_pass else 'FAIL'}")
print(f"Binary Output:   {'PASS' if density_pass else 'FAIL'}")
print(f"Elevated Reach:  {'PASS' if elev_pass else 'FAIL'}")

all_pass = conn_pass and cant_pass and fill_pass and reach_pass and walk_pass and overlap_pass and elev_pass
if all_pass:
    print('\nPhase 1 SUCCESS - Ready for Phase 2')
else:
    print('\nPhase 1 needs more training')
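
The pass/fail block above writes each threshold twice (once in the printout, once in the check). A table-driven variant keeps them in one place. This is a hypothetical refactor, not part of the notebook: the thresholds are copied from the checks above, and the density check is kept report-only because the original `all_pass` does not include it.

```python
from typing import Callable, Dict, Tuple

Check = Callable[[float], bool]

# Gating criteria: thresholds copied from the Phase 1 evaluation cell.
GATING: Dict[str, Check] = {
    'connectivity_rate': lambda v: v > 0.95,
    'cantilever_compliance': lambda v: v > 0.90,
    'avg_fill_ratio': lambda v: 0.03 < v < 0.25,
    'avg_access_reach': lambda v: v > 0.80,
    'avg_walkable_coverage': lambda v: v > 0.20,
    'avg_overlap_ratio': lambda v: v < 0.01,
    'elevated_reach_rate': lambda v: v > 0.70,
}

# Reported but not part of all_pass, mirroring the original check.
REPORT_ONLY: Dict[str, Check] = {
    'avg_density_penalty': lambda v: v < 0.1,
}


def check_criteria(results: Dict[str, float]) -> Tuple[Dict[str, bool], bool]:
    """Return per-criterion PASS flags and whether every gating criterion passed."""
    status = {name: bool(check(results[name]))
              for name, check in {**GATING, **REPORT_ONLY}.items()}
    all_pass = all(status[name] for name in GATING)
    return status, all_pass
```

With `eval_results` from `trainer.evaluate`, `status, all_pass = check_criteria(eval_results)` reproduces the same decision as the cell above.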

In [None]:
# Visualize final results
print('\nFinal Results:')
for i in range(3):
    visualize_result(model, trainer.scene_gen, CONFIG, device, f'Final Result {i+1}')

## 8. Save Outputs

In [None]:
# Save final checkpoint
trainer.save_checkpoint(f'{PROJECT_ROOT}/checkpoints/phase1_structural.pth')

# Save training history
history_path = f'{PROJECT_ROOT}/logs/phase1_history.json'
with open(history_path, 'w') as f:
    json.dump(trainer.history, f)
print(f'History saved to {history_path}')

# Save evaluation results
eval_path = f'{PROJECT_ROOT}/logs/phase1_eval.json'
with open(eval_path, 'w') as f:
    json.dump(eval_results, f, indent=2)
print(f'Evaluation saved to {eval_path}')

print('\n' + '='*60)
print('PHASE 1 COMPLETE')
print('='*60)
print(f"\nCheckpoint: {PROJECT_ROOT}/checkpoints/phase1_structural.pth")
print('\nNext: NB03_Phase2_Openness')

---

## Summary

### What was implemented:

1. **ConnectivityLoss (C1A)**: Differentiable flood-fill to ensure all voxels connect to support
2. **CantileverLoss (C1B)**: Penalizes unsupported horizontal overhangs
3. **SparsityLoss (C3A)**: Prevents "fill everything" solution by limiting fill ratio to 3-25%
4. **ExclusionLoss**: Prevents structure from growing inside existing buildings
5. **AccessReachLoss (C4B)**: Ensures structure and walkable reach all access points
6. **WalkableCoverageLoss (C4C)**: Ensures walkable surfaces develop on structure
7. **DiceLoss**: Growth incentive for structure toward access points
8. **WalkableDiceLoss**: Growth incentive for walkable toward access zones
9. **UrbanSceneGenerator**: Updated to include elevated access points; v1.6 randomizes access points across valid locations (ground in the gap, facades facing the gap)
10. **Phase1Trainer**: Complete training loop with all core constraints
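
The connectivity metric in `evaluate` is a hard flood fill: starting from the support voxels (ground plus existing buildings), it repeatedly dilates with a 3x3x3 max-pool and keeps only voxels that are structure, so connection is 26-neighbour. A NumPy sketch of that evaluation-time fill, assuming binary `[D, H, W]` arrays and the same fixed cap of 32 iterations (structure reachable only by longer paths is reported as disconnected). The function names are hypothetical, and the differentiable ConnectivityLoss used in training is not reproduced here.

```python
import numpy as np


def dilate_3x3x3(mask: np.ndarray) -> np.ndarray:
    """3x3x3 max filter with zero padding (26-neighbour dilation of a binary mask)."""
    padded = np.pad(mask, 1, mode="constant")
    d, h, w = mask.shape
    out = np.zeros_like(mask)
    for dz in range(3):
        for dy in range(3):
            for dx in range(3):
                out = np.maximum(out, padded[dz:dz + d, dy:dy + h, dx:dx + w])
    return out


def connectivity_rate(structure: np.ndarray, support: np.ndarray, n_iters: int = 32) -> float:
    """Fraction of structure voxels reachable from support through structure."""
    connected = support.copy()
    for _ in range(n_iters):
        # Grow one voxel per iteration, but only into structure voxels
        connected = np.maximum(connected, dilate_3x3x3(connected) * structure)
    total = structure.sum()
    return float((connected * structure).sum() / (total + 1e-8))
```

For a three-voxel column standing on a ground plane plus one isolated floating voxel, three of the four structure voxels are connected and the rate is 0.75.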

### Key Additions (v1.5):

**Problem:** Structure was a solid blob despite passing fill ratio checks.

**Root Cause:** No incentive for binary (0 or 1) outputs. Voxels with intermediate values (roughly 0.4-0.6) form fuzzy, blobby structures that technically satisfy the constraints but lack crisp, readable surfaces.

**Fixes:**
1. **DensityPenalty (SIMP)**: `s * (1 - s)` is 0 at s = 0 or 1 and peaks at 0.25 at s = 0.5, so minimizing it pushes voxels toward binary values
2. **TotalVariation3D**: Penalizes differences between neighbors for smooth surfaces
3. **Quality ramp-up**: Quality losses start at epoch 100, ramp up over 100 epochs
   - First 100 epochs: Learn basic growth and constraints
   - Epochs 100-200: Gradually introduce quality requirements
   - After 200: Full quality enforcement

### Quality Loss Weights:
- **Density**: 15.0 (strong push toward binary)
- **TV**: 2.0 (mild smoothness - don't over-smooth)
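
The ramp and weights above can be written as a small schedule. This sketch mirrors the branches of `Phase1Trainer._get_quality_weight`; the start epoch (100), ramp length (100) and base weights (15.0, 2.0) are taken from this summary rather than read from `CONFIG`, so treat them as assumptions if the config differs. The helper names are hypothetical.

```python
def quality_multiplier(epoch: int, start: int = 100, ramp: int = 100) -> float:
    """0 before `start`, linear 0 -> 1 over `ramp` epochs, then 1."""
    if epoch < start:
        return 0.0
    if epoch < start + ramp:
        return (epoch - start) / ramp
    return 1.0


# Base quality-loss weights as listed in this summary (assumed, not read from CONFIG)
BASE_QUALITY_WEIGHTS = {'density': 15.0, 'tv': 2.0}


def effective_quality_weights(epoch: int) -> dict:
    """Weights actually applied to the density and TV losses at `epoch`."""
    m = quality_multiplier(epoch)
    return {name: m * w for name, w in BASE_QUALITY_WEIGHTS.items()}
```

At epoch 150, halfway through the ramp, the effective weights are 7.5 (density) and 1.0 (TV). Note that the multiplier is still 0 at epoch 100 itself, so the quality losses first contribute at epoch 101.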

### Expected Results:

- Density penalty should drop below 0.1 (binary-ish outputs)
- Structure should have cleaner edges, less fuzzy blob appearance
- Fill ratio still 3-25%, but with more architectural character
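
Because the density penalty is a mean over every voxel in the grid, its value depends on how much of the grid is occupied as well as on how binary the occupied voxels are. A short NumPy check of the arithmetic (the arrays are illustrative, not model outputs, and `density_penalty` is a hypothetical helper reproducing the `(s * (1 - s)).mean()` formula):

```python
import numpy as np


def density_penalty(s: np.ndarray) -> float:
    """Mean of s * (1 - s): 0 for binary fields, 0.25 when every voxel is 0.5."""
    return float((s * (1.0 - s)).mean())


binary = np.array([0.0, 1.0, 1.0, 0.0])
undecided = np.full(4, 0.5)
half_undecided = np.array([0.0, 1.0, 0.5, 0.5])

# Sparse grid: 10% of voxels are fuzzy (0.5), the remaining 90% are empty.
sparse_fuzzy = np.zeros(1000)
sparse_fuzzy[:100] = 0.5
```

With 10% of the grid at 0.5 and the rest empty, the mean is 0.1 * 0.25 = 0.025, under the 0.1 target even though none of the occupied voxels is binary. A per-structure measure (for example, averaging only over voxels above a small threshold) would be stricter; that is a suggestion, not something this notebook implements.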

### Visualization Colors:
- **Gray**: Existing buildings (frozen)
- **Blue**: Generated structure (clean)
- **RED**: Overlap - structure inside buildings (BAD!)
- **Yellow**: Walkable surfaces
- **Green**: Access points
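
In the 2D views each colour is written into an RGB image in a fixed order, so a pixel that matches several masks shows the last colour painted. A NumPy sketch of that layering, assuming the masks have already been projected to 2D and using the same RGB values and order as the plan view (the `LAYERS` table and `paint` helper are hypothetical):

```python
import numpy as np

# (mask name, RGB) in painting order; later entries overwrite earlier ones.
LAYERS = [
    ('existing', (0.5, 0.5, 0.5)),   # gray
    ('structure', (0.2, 0.4, 0.8)),  # blue (structure outside buildings)
    ('overlap', (1.0, 0.0, 0.0)),    # red
    ('walkable', (0.8, 0.8, 0.2)),   # yellow
    ('access', (0.2, 0.8, 0.2)),     # green
]


def paint(masks: dict) -> np.ndarray:
    """Compose boolean 2D masks into an RGB image; unmatched pixels stay black."""
    shape = next(iter(masks.values())).shape
    img = np.zeros(shape + (3,))
    for name, rgb in LAYERS:
        if name in masks:
            img[masks[name]] = rgb
    return img
```

Because access is painted last, a green access marker hides red overlap in the same pixel. In the notebook's plan and elevation views the walkable mask excludes columns containing existing buildings, so yellow cannot cover red there.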

### Next steps (NB03):

1. Load Phase 1 checkpoint
2. Implement GroundVoidLoss (C2A)
3. Add curriculum annealing for smooth constraint activation
4. Train Phase 2: Street-Level Openness

---

*NB02_Phase1_Structural_v1.6 - December 2025*