# NB02: All Constraints Training

**Constraint-Based Architectural NCA - Step B**

**Version:** 1.1 (Rebalanced)  
**Date:** December 2025  
**Purpose:** Train with ALL constraints + balanced growth incentives

---

## Rebalanced Weight Strategy

**Problem in v1.0:** Protection losses (void=50, anchor=30) overwhelmed the growth incentive (dice=5), so the model learned that doing nothing was optimal.

**Solution:** Rebalance the weights and add an elevated-structure bonus. Pair the bonus with a support-ratio constraint so the model cannot exploit it by building towers on small pedestals.

## Loss Weights (v1.1 - Rebalanced)

| Category | Loss | Weight | Change |
|----------|------|--------|--------|
| **Protection** | Street Void | 25.0 | Reduced from 50 |
| | Anchor Budget | 20.0 | Reduced from 30 |
| **Growth** | Dice | 25.0 | Increased from 5 |
| | Elevated Bonus | 5.0 | NEW |
| | Sparsity | 20.0 (under-fill coef 3.0) | Under-fill coef increased from 0.5 |
| **Balance** | Support Ratio | 15.0 | NEW (max 4:1) |
| **Existing** | Connectivity | 10.0 | - |
| | Street Conn | 10.0 | - |
| | Access Reach | 10.0 | - |
| | Cantilever | 5.0 | - |
| **Quality** | Density | 5.0 | - |
| | TV | 1.0 | - |

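The interaction between the v1.1 weights is easiest to see if you treat the objective as a weighted sum. The sketch below assumes the trainer adds `weight * loss` for each term, because the trainer's combination code is not part of this section. `combine_losses` is a hypothetical helper, and the component values are made up for illustration. `ElevatedBonus` returns a negative value, so its positive weight lowers the total.

```python
# Minimal sketch, assuming total = sum(weight * component); combine_losses is hypothetical.
WEIGHTS_V11 = {
    "street_void": 25.0, "anchor_budget": 20.0,   # protection
    "dice": 25.0, "elevated_bonus": 5.0,          # growth
    "support_ratio": 15.0,                        # balance
    "sparsity": 20.0, "connectivity": 10.0, "street_conn": 10.0,
    "access_reach": 10.0, "cantilever": 5.0,      # existing
    "density": 5.0, "tv": 1.0,                    # quality
}


def combine_losses(components: dict, weights: dict) -> float:
    """Weighted sum of scalar component losses; components without a weight are ignored."""
    return sum(weights[name] * value for name, value in components.items() if name in weights)


# Illustrative values only (not measured results).
components = {"street_void": 0.04, "dice": 0.8, "elevated_bonus": -0.02}
total = combine_losses(components, WEIGHTS_V11)
print(f"{total:.2f}")  # 25*0.04 + 25*0.8 + 5*(-0.02) = 20.90
```
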
## New Constraints

- **Elevated Bonus**: Rewards structure above street level (z >= 2)
- **Support Ratio**: Prevents towers on pedestals (max 4:1 elevated:ground)

## Success Criteria

- Street void ratio >70%
- Anchor compliance 100%
- Street connectivity >90%
- Structural connectivity >95%
- Fill ratio 5-15%
- Support ratio <4:1

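You can check the success criteria mechanically at the end of a run. This is a minimal sketch with these assumptions:

- The metric names and `check_success` are hypothetical.
- All ratios are fractions in [0, 1].
- The small tolerance on anchor compliance absorbs the `1e-8` epsilon that `compute_anchor_compliance` adds to its denominator.

```python
# Hypothetical metric names; thresholds copied from the Step B success criteria.
def check_success(m: dict, compliance_tol: float = 1e-6) -> dict:
    """Return a pass/fail flag per criterion for a dict of scalar metrics."""
    return {
        "street_void": m["street_void_ratio"] > 0.70,
        "anchor_compliance": m["anchor_compliance"] >= 1.0 - compliance_tol,
        "street_connectivity": m["street_connectivity"] > 0.90,
        "structural_connectivity": m["structural_connectivity"] > 0.95,
        "fill_ratio": 0.05 <= m["fill_ratio"] <= 0.15,
        "support_ratio": m["support_ratio"] < 4.0,
    }


metrics = {  # made-up example values
    "street_void_ratio": 0.82, "anchor_compliance": 0.9999999,
    "street_connectivity": 0.95, "structural_connectivity": 0.97,
    "fill_ratio": 0.03, "support_ratio": 3.2,
}
report = check_success(metrics)
print([name for name, ok in report.items() if not ok])  # ['fill_ratio']
```
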
---

## 1. Setup

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'
print(f'Project root: {PROJECT_ROOT}')

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, Optional, List
import json
from datetime import datetime
from tqdm.notebook import tqdm
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# Set seeds
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning can select nondeterministic kernels

set_seed(42)

In [None]:
# Load Step B configuration
with open(f'{PROJECT_ROOT}/config_step_b.json', 'r') as f:
    CONFIG = json.load(f)

# Training configuration
CONFIG.update({
    'epochs': 400,
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',  # Start with easy scenes
    'log_every': 20,
    'viz_every': 100,
    'save_every': 100,
})

print('Configuration loaded')
print(f"Difficulty: {CONFIG['difficulty']}")
print(f"Epochs: {CONFIG['epochs']}")

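The components below index `CONFIG` directly. A missing key or an inconsistent channel layout therefore only shows up as a shape error deep inside a forward pass. The check below is a hypothetical helper, not part of the notebook. The key list comes from the code in this section. The consistency rules are inferred from `UrbanPavilionNCA._step`, which adds an `n_grown`-channel update to `state[:, n_frozen:]` and treats the first grown channel as structure. The example values are placeholders, not the contents of `config_step_b.json`.

```python
# Keys read by the components in this notebook (inferred from the code).
REQUIRED_KEYS = [
    "grid_size", "street_levels", "n_channels", "n_frozen", "n_grown",
    "hidden_dim", "xavier_gain", "structure_bias", "surface_bias",
    "fire_rate", "update_scale",
    "ch_ground", "ch_existing", "ch_access", "ch_anchors", "ch_structure",
]


def validate_config(config: dict) -> list:
    """Return a list of problems (empty if the config looks consistent)."""
    problems = [f"missing key: {k}" for k in REQUIRED_KEYS if k not in config]
    if problems:
        return problems
    # _step adds an n_grown-channel delta to state[:, n_frozen:]
    if config["n_frozen"] + config["n_grown"] != config["n_channels"]:
        problems.append("n_frozen + n_grown must equal n_channels")
    # _step masks grown channel 0 as structure, so structure must sit at index n_frozen
    if config["ch_structure"] != config["n_frozen"]:
        problems.append("ch_structure must equal n_frozen")
    for key in ("ch_ground", "ch_existing", "ch_access", "ch_anchors"):
        if not 0 <= config[key] < config["n_frozen"]:
            problems.append(f"{key} must be a frozen channel (< n_frozen)")
    if not 0 < config["street_levels"] < config["grid_size"]:
        problems.append("street_levels must be between 1 and grid_size - 1")
    return problems


# Placeholder values for illustration only.
example = {
    "grid_size": 32, "street_levels": 2, "n_channels": 8, "n_frozen": 4, "n_grown": 4,
    "hidden_dim": 64, "xavier_gain": 0.1, "structure_bias": -1.0, "surface_bias": -1.0,
    "fire_rate": 0.5, "update_scale": 0.1,
    "ch_ground": 0, "ch_existing": 1, "ch_access": 2, "ch_anchors": 3, "ch_structure": 4,
}
print(validate_config(example))  # []
```

In the notebook you would call this as `validate_config(CONFIG)` right after the configuration cell.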
## 2. Copy Foundation Components from NB01

In [None]:
# ============================================================
# PERCEPTION MODULE (from NB01 v2.0)
# ============================================================

class Perceive3D(nn.Module):
    """3D Sobel perception with replicate padding."""

    def __init__(self, n_channels: int = 8):
        super().__init__()
        self.n_channels = n_channels

        sobel_x = self._create_sobel_kernel('x')
        sobel_y = self._create_sobel_kernel('y')
        sobel_z = self._create_sobel_kernel('z')
        identity = self._create_identity_kernel()

        kernels = torch.stack([identity, sobel_x, sobel_y, sobel_z], dim=0)
        self.register_buffer('kernels', kernels)

    def _create_sobel_kernel(self, direction: str) -> torch.Tensor:
        derivative = torch.tensor([-1., 0., 1.])
        smoothing = torch.tensor([1., 2., 1.])

        if direction == 'x':
            kernel = torch.einsum('i,j,k->ijk', smoothing, smoothing, derivative)
        elif direction == 'y':
            kernel = torch.einsum('i,j,k->ijk', smoothing, derivative, smoothing)
        elif direction == 'z':
            kernel = torch.einsum('i,j,k->ijk', derivative, smoothing, smoothing)
        return kernel / 16.0

    def _create_identity_kernel(self) -> torch.Tensor:
        kernel = torch.zeros(3, 3, 3)
        kernel[1, 1, 1] = 1.0
        return kernel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = x.shape
        x_padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        
        outputs = []
        for k in range(4):
            kernel = self.kernels[k:k+1].unsqueeze(0).expand(C, 1, 3, 3, 3)
            out = F.conv3d(x_padded, kernel, padding=0, groups=C)
            outputs.append(out)
        return torch.cat(outputs, dim=1)

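A quick sanity check on kernel orientation: a Sobel filter applied to a linear ramp should respond only along the ramp's axis. The sketch below re-creates `_create_sobel_kernel` in NumPy, which is an assumption for portability since the module itself uses PyTorch. It evaluates one output voxel the way `F.conv3d` does, as a cross-correlation over a 3×3×3 patch indexed (z, y, x). The derivative taps are [-1, 0, 1], so a unit-slope ramp gives 2 after the 1/16 normalization, not 1.

```python
import numpy as np


def sobel_kernel(axis: int) -> np.ndarray:
    """3x3x3 Sobel kernel indexed (z, y, x); `axis` (0=z, 1=y, 2=x) carries the derivative."""
    derivative = np.array([-1.0, 0.0, 1.0])
    smoothing = np.array([1.0, 2.0, 1.0])
    factors = [smoothing, smoothing, smoothing]
    factors[axis] = derivative
    return np.einsum("i,j,k->ijk", *factors) / 16.0


# Field that increases by 1 per voxel along x (the last axis), constant in z and y.
ramp = np.broadcast_to(np.arange(5, dtype=float), (5, 5, 5))

# One output voxel of a cross-correlation: centre 3x3x3 patch times kernel, summed.
patch = ramp[1:4, 1:4, 1:4]
responses = {name: float((patch * sobel_kernel(a)).sum()) for a, name in enumerate("zyx")}
print(responses)  # z and y responses are 0; x response is 2.0
```

`Perceive3D` concatenates these responses after the identity channel. Its output therefore has `4 * n_channels` channels in blocks of [identity, x-gradient, y-gradient, z-gradient], each block covering every input channel. This matches `perception_dim = n_channels * 4` in `UrbanPavilionNCA`.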
In [None]:
# ============================================================
# NCA MODEL (from NB01 v2.0)
# ============================================================

class UrbanPavilionNCA(nn.Module):
    """Neural Cellular Automaton for urban pavilion generation."""

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        n_channels = config['n_channels']
        hidden_dim = config['hidden_dim']
        perception_dim = n_channels * 4
        n_grown = config['n_grown']

        self.perceive = Perceive3D(n_channels)

        self.update_net = nn.Sequential(
            nn.Conv3d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, n_grown, 1),
        )

        self._init_weights()

    def _init_weights(self):
        gain = self.config['xavier_gain']
        for m in self.update_net:
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        last_layer = self.update_net[-1]
        with torch.no_grad():
            last_layer.bias[0] = self.config['structure_bias']
            last_layer.bias[1] = self.config['surface_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = state.shape
        cfg = self.config

        perception = self.perceive(state)
        delta = self.update_net(perception)

        if self.training:
            fire_mask = (torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']).float()
            delta = delta * fire_mask

        grown_start = cfg['n_frozen']
        grown_new = state[:, grown_start:] + cfg['update_scale'] * delta
        grown_new = torch.clamp(grown_new, 0.0, 1.0)

        # Hard mask: no structure inside buildings
        existing = state[:, cfg['ch_existing']:cfg['ch_existing']+1]
        available_mask = 1.0 - existing
        struct_new = grown_new[:, 0:1] * available_mask
        
        grown_masked = torch.cat([struct_new, grown_new[:, 1:]], dim=1)
        new_state = torch.cat([state[:, :grown_start], grown_masked], dim=1)

        return new_state

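    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Run `steps` updates without gradients, then restore the previous train/eval mode."""
        was_training = self.training
        self.eval()
        with torch.no_grad():
            out = self.forward(seed, steps)
        self.train(was_training)
        return out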

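After the network call, the update in `_step` performs three operations:

1. Gate `delta` with a per-voxel Bernoulli fire mask.
2. Add the result to the grown channels and clamp to [0, 1].
3. Zero structure inside buildings.

The NumPy sketch below restates those operations. It assumes grown channel 0 is structure (as in `_step`) and supplies `delta` directly instead of taking it from `update_net`. `masked_update` is a hypothetical name. As in the original's `(B, 1, D, H, W)` mask, all channels share the same fire mask.

```python
import numpy as np


def masked_update(grown, delta, existing, update_scale, fire_rate, rng):
    """One stochastic NCA update on the grown channels.

    grown, delta: (C_grown, D, H, W); channel 0 is structure.
    existing: (D, H, W) building mask with values in {0, 1}.
    """
    # One Bernoulli draw per voxel, shared by all channels
    fire_mask = (rng.random(existing.shape) < fire_rate).astype(grown.dtype)
    new = np.clip(grown + update_scale * delta * fire_mask, 0.0, 1.0)
    # Hard mask: structure cannot exist inside buildings
    new[0] = new[0] * (1.0 - existing)
    return new


rng = np.random.default_rng(0)
grown = np.zeros((2, 4, 4, 4))
delta = np.ones_like(grown)
existing = np.zeros((4, 4, 4))
existing[:, :, :2] = 1.0  # building occupies x = 0..1
out = masked_update(grown, delta, existing, update_scale=0.5, fire_rate=1.0, rng=rng)
print(out[0, 0, 0], out[1, 0, 0])  # structure is 0 inside the building; channel 1 is not masked
```
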
In [None]:
# ============================================================
# SCENE GENERATOR (from NB01 v2.0)
# ============================================================

class UrbanSceneGenerator:
    """Generate urban scenes with buildings, access points, and anchor zones."""

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']
        self.C = config['n_channels']

    def generate(self, difficulty: str = 'easy',
                 device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        G = self.G
        cfg = self.config
        
        state = torch.zeros(1, self.C, G, G, G, device=device)
        state[:, cfg['ch_ground'], 0, :, :] = 1.0

        params = self._get_difficulty_params(difficulty)
        building_info = self._place_buildings(state, params)
        access_info = self._place_access_points(state, params, building_info)
        anchor_info = self._generate_anchor_zones(state, params, building_info, access_info)

        metadata = {
            'difficulty': difficulty,
            'buildings': building_info,
            'access_points': access_info,
            'anchor_zones': anchor_info,
            'gap_width': params['gap_width'],
        }
        
        return state, metadata

    def _get_difficulty_params(self, difficulty: str) -> dict:
        G = self.G
        if difficulty == 'easy':
            return {
                'n_buildings': 2, 'height_range': (12, 16), 'height_variance': False,
                'width_range': (8, 12), 'gap_width': random.randint(14, 18),
                'n_ground_access': 1, 'n_elevated_access': 1, 'anchor_budget': 0.10,
            }
        elif difficulty == 'medium':
            return {
                'n_buildings': 2, 'height_range': (10, 20), 'height_variance': True,
                'width_range': (6, 10), 'gap_width': random.randint(10, 14),
                'n_ground_access': random.randint(1, 2), 'n_elevated_access': random.randint(1, 2),
                'anchor_budget': 0.07,
            }
        else:  # hard
            return {
                'n_buildings': random.randint(2, 4), 'height_range': (8, 24), 'height_variance': True,
                'width_range': (5, 8), 'gap_width': random.randint(6, 10),
                'n_ground_access': random.randint(2, 3), 'n_elevated_access': random.randint(2, 3),
                'anchor_budget': 0.05,
            }

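    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        """Place one building on each side of the central gap.

        Always places exactly two buildings; params['n_buildings'] is not used here.
        """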
        G = self.G
        ch = self.config['ch_existing']
        buildings = []
        gap_width = params['gap_width']
        gap_center = G // 2

        w1 = random.randint(*params['width_range'])
        d1 = random.randint(G//2, G-2)
        h1 = random.randint(*params['height_range'])
        x1_end = gap_center - gap_width // 2
        x1_start = max(0, x1_end - w1)
        state[:, ch, :h1, :d1, x1_start:x1_end] = 1.0
        buildings.append({'x': (x1_start, x1_end), 'y': (0, d1), 'z': (0, h1),
                         'gap_facing_x': x1_end, 'side': 'left'})

        w2 = random.randint(*params['width_range'])
        d2 = random.randint(G//2, G-2)
        h2 = h1 if not params['height_variance'] else random.randint(*params['height_range'])
        x2_start = gap_center + gap_width // 2
        x2_end = min(G, x2_start + w2)
        state[:, ch, :h2, :d2, x2_start:x2_end] = 1.0
        buildings.append({'x': (x2_start, x2_end), 'y': (0, d2), 'z': (0, h2),
                         'gap_facing_x': x2_start, 'side': 'right'})

        return buildings

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: list) -> list:
        G = self.G
        ch = self.config['ch_access']
        access_points = []
        
        left_buildings = [b for b in buildings if b['side'] == 'left']
        right_buildings = [b for b in buildings if b['side'] == 'right']
        gap_x_min = max(b['gap_facing_x'] for b in left_buildings) if left_buildings else 0
        gap_x_max = min(b['gap_facing_x'] for b in right_buildings) if right_buildings else G
        
        for i in range(params.get('n_ground_access', 1)):
            x = random.randint(gap_x_min + 1, gap_x_max - 3)
            y = random.randint(0, G - 3)
            state[:, ch, 0:2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': 0, 'type': 'ground'})
        
        for i in range(params.get('n_elevated_access', 1)):
            building = random.choice(buildings)
            bz_max = building['z'][1]
            is_left = building['side'] == 'left'
            
            z = random.randint(3, max(4, bz_max - 2))
            y = random.randint(building['y'][0], min(building['y'][1] - 2, building['y'][0] + G//3))
            x = building['x'][1] if is_left else building['x'][0] - 2
            x = max(0, min(G - 2, x))
            
            state[:, ch, z:z+2, y:y+2, x:x+2] = 1.0
            access_points.append({'x': x, 'y': y, 'z': z, 'type': 'elevated'})

        return access_points

    def _generate_anchor_zones(self, state: torch.Tensor, params: dict,
                               buildings: list, access_points: list) -> dict:
        G = self.G
        ch = self.config['ch_anchors']
        street_levels = self.config['street_levels']
        
        existing = state[:, self.config['ch_existing'], 0, :, :]
        street_mask = 1.0 - existing
        total_street_area = street_mask.sum().item()
        # Anchor budget in voxels: returned in the metadata but not enforced below
        max_anchor_area = int(total_street_area * params['anchor_budget'])
        
        anchors = torch.zeros(1, 1, G, G, G, device=state.device)
        
        for ap in access_points:
            if ap['type'] == 'ground':
                x, y = ap['x'], ap['y']
                for z in range(street_levels):
                    anchors[:, 0, z, max(0,y-1):min(G,y+3), max(0,x-1):min(G,x+3)] = 1.0
        
        for building in buildings:
            by_start, by_end = building['y']
            gap_x = building['gap_facing_x']
            is_left = building['side'] == 'left'
            x_start = gap_x if is_left else gap_x - 1
            x_end = gap_x + 1 if is_left else gap_x
            for z in range(street_levels):
                anchors[:, 0, z, by_start:by_end, max(0,x_start):min(G,x_end)] = 1.0
        
        for z in range(street_levels):
            anchors[:, 0, z, :, :] *= street_mask
        
        state[:, ch:ch+1, :, :, :] = anchors
        final_anchor_area = (anchors > 0.5).sum().item()
        
        return {'total_area': final_anchor_area, 'budget': max_anchor_area,
                'budget_ratio': params['anchor_budget'], 'street_area': total_street_area}

    def batch(self, difficulty: str, batch_size: int, device: str) -> torch.Tensor:
        scenes = [self.generate(difficulty, device)[0] for _ in range(batch_size)]
        return torch.cat(scenes, dim=0)


print('Foundation components loaded')

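The gap geometry in `_place_buildings` is plain integer arithmetic, so you can check it by hand. The sketch below extracts it into a hypothetical `gap_layout` helper. `G = 32` is an assumed grid size; the actual `grid_size` comes from `config_step_b.json`. Both sides use `gap_width // 2`, so the street between the building faces is `2 * (gap_width // 2)` voxels wide. An odd sampled gap width therefore loses one voxel.

```python
def gap_layout(G: int, gap_width: int, w1: int, w2: int) -> dict:
    """Building x-extents as computed in UrbanSceneGenerator._place_buildings."""
    gap_center = G // 2
    x1_end = gap_center - gap_width // 2      # left building's gap-facing face
    x1_start = max(0, x1_end - w1)            # clipped at the grid edge
    x2_start = gap_center + gap_width // 2    # right building's gap-facing face
    x2_end = min(G, x2_start + w2)
    return {
        "left": (x1_start, x1_end),
        "right": (x2_start, x2_end),
        "street_width": x2_start - x1_end,
    }


even = gap_layout(G=32, gap_width=16, w1=10, w2=10)
odd = gap_layout(G=32, gap_width=15, w1=10, w2=10)
print(even)  # left (0, 8): w1 = 10 is clipped to 8 voxels; street_width 16
print(odd["street_width"])  # 14, not 15
```
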
## 3. Street Void Utilities (from NB01 v2.0)

In [None]:
def compute_street_zone(state: torch.Tensor, config: dict) -> torch.Tensor:
    """Compute street zone mask (ground level, not in buildings)."""
    street_levels = config['street_levels']
    existing = state[:, config['ch_existing'], :street_levels, :, :]
    return 1.0 - existing


def compute_street_void_ratio(state: torch.Tensor, config: dict) -> torch.Tensor:
    """Ratio of street zone that is empty (target >70%)."""
    street_levels = config['street_levels']
    street_zone = compute_street_zone(state, config)
    structure = state[:, config['ch_structure'], :street_levels, :, :]
    void_mask = street_zone * (1.0 - structure)
    return void_mask.sum(dim=(1,2,3)) / (street_zone.sum(dim=(1,2,3)) + 1e-8)


def compute_anchor_compliance(state: torch.Tensor, config: dict) -> torch.Tensor:
    """Check if ground structure is only in anchors (target 100%)."""
    street_levels = config['street_levels']
    structure = state[:, config['ch_structure'], :street_levels, :, :]
    anchors = state[:, config['ch_anchors'], :street_levels, :, :]
    street_zone = compute_street_zone(state, config)
    
    struct_in_street = structure * street_zone
    violation = struct_in_street * (1.0 - anchors)
    
    total_struct = struct_in_street.sum(dim=(1,2,3)) + 1e-8
    return 1.0 - (violation.sum(dim=(1,2,3)) / total_struct)


def compute_street_connectivity(state: torch.Tensor, config: dict,
                                iterations: int = 32) -> torch.Tensor:
    """Fraction of street-level void reachable from the scene edges (target >90%)."""
    street_levels = config['street_levels']
    
    structure = state[:, config['ch_structure'], :street_levels, :, :]
    existing = state[:, config['ch_existing'], :street_levels, :, :]
    void_mask = (1.0 - structure) * (1.0 - existing)
    
    # Seed the flood fill at the side edges; void_mask is indexed (B, z, y, x)
    connected = torch.zeros_like(void_mask)
    connected[:, :, :, 0] = 1.0    # x = 0 edge
    connected[:, :, :, -1] = 1.0   # x = max edge
    connected[:, :, 0, :] = 1.0    # y = 0 edge
    connected[:, :, -1, :] = 1.0   # y = max edge
    connected = connected * void_mask
    
    for _ in range(iterations):
        conn_3d = connected.unsqueeze(1)
        dilated = F.max_pool3d(conn_3d, 3, 1, 1).squeeze(1)
        new_connected = torch.max(connected, dilated * void_mask)
        if torch.allclose(connected, new_connected, atol=1e-5):
            break
        connected = new_connected
    
    return connected.sum(dim=(1,2,3)) / (void_mask.sum(dim=(1,2,3)) + 1e-8)

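`compute_street_connectivity` is an iterative flood fill. It seeds every void voxel on the scene's side edges, then repeatedly dilates with a 3×3×3 max-pool and re-masks by the void. The 2D NumPy sketch below simplifies the 3D version and uses a hand-written 3×3 max filter in place of `F.max_pool3d`. `edge_connected_ratio` is a hypothetical name. The example shows a closed block of structure trapping one void cell, which lowers the ratio.

```python
import numpy as np


def max_filter3x3(mask: np.ndarray) -> np.ndarray:
    """3x3 max filter with zero padding (2D analogue of F.max_pool3d(x, 3, 1, 1) for x >= 0)."""
    h, w = mask.shape
    padded = np.pad(mask, 1)
    out = np.zeros_like(mask)
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


def edge_connected_ratio(void: np.ndarray, iterations: int = 32) -> float:
    """Fraction of void cells reachable from the grid border through void cells."""
    connected = np.zeros_like(void)
    connected[:, 0] = connected[:, -1] = 1.0
    connected[0, :] = connected[-1, :] = 1.0
    connected = connected * void
    for _ in range(iterations):
        new_connected = np.maximum(connected, max_filter3x3(connected) * void)
        if np.allclose(connected, new_connected, atol=1e-5):
            break
        connected = new_connected
    return float(connected.sum() / (void.sum() + 1e-8))


structure = np.zeros((7, 7))
structure[2:5, 2:5] = 1.0   # 3x3 block of structure...
structure[3, 3] = 0.0       # ...with an empty cell in the middle
void = 1.0 - structure
ratio = edge_connected_ratio(void)
print(round(ratio, 4))  # 40 of 41 void cells are reachable
```

The 3×3 neighbourhood includes diagonals, so void can leak through a wall that is closed only edge-to-edge. To disconnect the void behind a wall, the wall must also be closed diagonally.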
## 4. Loss Functions (ALL CONSTRAINTS)

In [None]:
# ============================================================
# CRITICAL: STREET VOID LOSS (C2A)
# ============================================================

class StreetVoidLoss(nn.Module):
    """CRITICAL: Protect ground-level circulation space.
    
    Street zone must remain mostly empty (>70% void).
    This is the most important loss in Step B.
    """

    def __init__(self, config: dict, target_void: float = 0.70):
        super().__init__()
        self.config = config
        self.target_void = target_void
        self.street_levels = config['street_levels']

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        
        # Get street zone (ground level, not in buildings)
        existing = state[:, cfg['ch_existing'], :self.street_levels, :, :]
        street_zone = 1.0 - existing
        
        # Get structure at street level
        structure = state[:, cfg['ch_structure'], :self.street_levels, :, :]
        
        # Occupation ratio (what we want to minimize)
        occupation = (structure * street_zone).sum() / (street_zone.sum() + 1e-8)
        
        # Target: occupation < (1 - target_void)
        # Loss increases as occupation exceeds threshold
        max_occupation = 1.0 - self.target_void
        loss = F.relu(occupation - max_occupation)
        
        # Also add soft penalty for any occupation
        loss = loss + 0.1 * occupation
        
        return loss

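`StreetVoidLoss` reduces to a scalar function of street occupation, so its shape is easy to tabulate. The pure-Python restatement below is a hypothetical helper, not the module itself. It shows two regimes:

- Below 30% occupation, only the `0.1 * occupation` term acts, so the slope is 0.1.
- Above 30%, the hinge adds a slope of 1.0, for 1.1 in total.

```python
def street_void_loss(occupation: float, target_void: float = 0.70) -> float:
    """Scalar form of StreetVoidLoss: hinge above (1 - target_void) plus a 0.1 linear term."""
    max_occupation = 1.0 - target_void
    return max(0.0, occupation - max_occupation) + 0.1 * occupation


for occ in (0.0, 0.2, 0.3, 0.4, 0.6):
    print(f"occupation={occ:.1f}  loss={street_void_loss(occ):.3f}")
```

At the v1.1 weight of 25, an occupation of 40% costs 25 × 0.14 = 3.5.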
In [None]:
# ============================================================
# HIGH: ANCHOR BUDGET LOSS (C3B)
# ============================================================

class AnchorBudgetLoss(nn.Module):
    """Ensure ground structure only exists in anchor zones.
    
    Structure at street level must be within designated anchors.
    Any structure outside anchors is a violation.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        
        # Get structure at street level
        structure = state[:, cfg['ch_structure'], :self.street_levels, :, :]
        
        # Get anchor zones
        anchors = state[:, cfg['ch_anchors'], :self.street_levels, :, :]
        
        # Get street zone (exclude buildings)
        existing = state[:, cfg['ch_existing'], :self.street_levels, :, :]
        street_zone = 1.0 - existing
        
        # Structure in street zone but NOT in anchors = violation
        struct_in_street = structure * street_zone
        violation = struct_in_street * (1.0 - anchors)
        
        # Loss = violation sum normalized by total street structure
        loss = violation.sum() / (struct_in_street.sum() + 1e-8)
        
        return loss

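On a handful of voxels the anchor loss is easy to trace: it is the share of street-zone structure that lies outside anchors. The pure-Python sketch below is a hypothetical helper that works on flattened lists instead of tensors. It uses four voxels, one of which is inside a building and therefore ignored.

```python
def anchor_violation(structure, anchors, street):
    """Fraction of street-zone structure mass that lies outside anchor voxels."""
    in_street = [s * z for s, z in zip(structure, street)]
    violation = sum(s * (1.0 - a) for s, a in zip(in_street, anchors))
    return violation / (sum(in_street) + 1e-8)


structure = [1.0, 1.0, 0.5, 0.8]
anchors   = [1.0, 0.0, 0.0, 0.0]
street    = [1.0, 1.0, 1.0, 0.0]   # last voxel is inside a building
loss = anchor_violation(structure, anchors, street)
print(round(loss, 3))  # (1.0 + 0.5) / 2.5 = 0.6
```

The loss is a ratio, so scaling all street structure down leaves it unchanged. Only moving structure into anchors, or removing the part outside them, lowers it. The anchor budget lives in the scene metadata, and this loss does not read it.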
In [None]:
# ============================================================
# HIGH: CONNECTIVITY LOSS (C1A) - Revised
# ============================================================

class ConnectivityLoss(nn.Module):
    """Ensure structure is connected to support.
    
    Support = existing buildings OR anchor zones (not free ground).
    This prevents structure from floating.
    """

    def __init__(self, config: dict, threshold: float = 0.3, iterations: int = 32):
        super().__init__()
        self.config = config
        self.threshold = threshold
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        
        # Get structure
        structure = state[:, cfg['ch_structure']]
        struct_soft = torch.sigmoid(10 * (structure - self.threshold))
        
        # Support = existing buildings + anchor zones
        existing = state[:, cfg['ch_existing']]
        anchors = state[:, cfg['ch_anchors']]
        support = torch.clamp(existing + anchors, 0, 1)
        
        # Flood fill from support
        connected = support.clone()
        for _ in range(self.iterations):
            dilated = F.max_pool3d(connected.unsqueeze(1), 3, 1, 1).squeeze(1)
            new_connected = torch.max(connected, dilated * struct_soft)
            if torch.allclose(connected, new_connected, atol=1e-5):
                break
            connected = new_connected
        
        # Disconnected = structure not reached by flood fill
        disconnected = structure * (1.0 - connected)
        loss = disconnected.sum() / (structure.sum() + 1e-8)
        
        return loss

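`ConnectivityLoss` gates the flood fill with a soft copy of the structure, `sigmoid(10 * (structure - threshold))`. As a result, an empty voxel still passes about `sigmoid(-3) ≈ 0.047` of the connection through. The 1D NumPy sketch below simplifies the 3D module; `disconnected_fraction` is a hypothetical name. It places a building at index 0 and two structure pieces separated by a one-voxel gap.

```python
import numpy as np


def disconnected_fraction(structure, support, threshold=0.3, iterations=32):
    """1D analogue of ConnectivityLoss; returns (loss, connected)."""
    struct_soft = 1.0 / (1.0 + np.exp(-10.0 * (structure - threshold)))
    connected = support.copy()
    for _ in range(iterations):
        padded = np.pad(connected, 1)
        # Max over each voxel and its two neighbours (1D max-pool, kernel 3)
        dilated = np.maximum(np.maximum(padded[:-2], padded[1:-1]), padded[2:])
        new_connected = np.maximum(connected, dilated * struct_soft)
        if np.allclose(connected, new_connected, atol=1e-5):
            break
        connected = new_connected
    disconnected = structure * (1.0 - connected)
    return float(disconnected.sum() / (structure.sum() + 1e-8)), connected


support = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])      # building at index 0
structure = np.array([0.0, 0.9, 0.9, 0.0, 0.9, 0.9])    # gap at index 3
loss, connected = disconnected_fraction(structure, support)
print(np.round(connected, 3), round(loss, 3))
```

The piece touching the building ends up almost fully connected. The piece beyond the gap ends up with a connectivity of roughly 0.05, so it contributes almost fully to the loss (about 0.48 here).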
In [None]:
# ============================================================
# HIGH: SPARSITY LOSS (C3A) - Updated with stronger under-penalty
# ============================================================

class SparsityLoss(nn.Module):
    """Limit total volume (5-15% fill ratio).
    
    Updated: Stronger under-penalty to encourage growth.
    """

    def __init__(self, max_ratio: float = 0.15, min_ratio: float = 0.05, 
                 under_coef: float = 3.0):
        super().__init__()
        self.max_ratio = max_ratio
        self.min_ratio = min_ratio
        self.under_coef = under_coef  # Increased from 0.5 to encourage growth

    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        ratio = structure.sum() / (available.sum() + 1e-8)
        over_penalty = F.relu(ratio - self.max_ratio)
        under_penalty = self.under_coef * F.relu(self.min_ratio - ratio)
        return over_penalty + under_penalty

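With `under_coef=3.0`, under-filling is penalized three times more steeply than over-filling. The scalar restatement below is a hypothetical helper that mirrors `SparsityLoss.forward` for a precomputed fill ratio. Combined with the trainer's `sparsity` weight of 20.0, an empty grid costs 20 × 0.15 = 3.0.

```python
def sparsity_loss(ratio: float, max_ratio: float = 0.15, min_ratio: float = 0.05,
                  under_coef: float = 3.0) -> float:
    """Scalar form of SparsityLoss for a fill ratio = structure / available volume."""
    over_penalty = max(0.0, ratio - max_ratio)
    under_penalty = under_coef * max(0.0, min_ratio - ratio)
    return over_penalty + under_penalty


for r in (0.0, 0.02, 0.10, 0.25):
    print(f"fill={r:.2f}  loss={sparsity_loss(r):.3f}")
```
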
In [None]:
# ============================================================
# MEDIUM: CANTILEVER LOSS (C1B)
# ============================================================

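class CantileverLoss(nn.Module):
    """Penalize structure that lacks support from below.

    A voxel at height z counts as supported if any of the `max_overhang`
    layers beneath it holds structure within one voxel horizontally
    (3x3 max-pool). Layers z < max_overhang are not checked, and only the
    structure channel counts as support.
    """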

    def __init__(self, max_overhang: int = 3, threshold: float = 0.3):
        super().__init__()
        self.max_overhang = max_overhang
        self.threshold = threshold

    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        B, D, H, W = structure.shape
        N = self.max_overhang
        total_loss = 0.0
        count = 0

        for z in range(N, D):
            layer = structure[:, z]
            support_volume = structure[:, max(0, z-N):z]
            support_max = support_volume.max(dim=1)[0]
            support_dilated = F.max_pool2d(support_max.unsqueeze(1), 3, 1, 1).squeeze(1)
            has_support = torch.sigmoid(10 * (support_dilated - self.threshold))
            unsupported = layer * (1.0 - has_support)
            total_loss += unsupported.mean()
            count += 1

        return total_loss / max(1, count)

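The support test in `CantileverLoss` uses a vertical window, not a horizontal distance. A voxel at height z is supported if any of the `max_overhang` layers beneath it has structure within one voxel horizontally. The NumPy sketch below replaces the sigmoid with a hard `> threshold` test so the outcome is binary. `unsupported_voxels` is a hypothetical name, and the grid is indexed (z, y, x) like the module's input.

```python
import numpy as np


def max_filter3x3(layer: np.ndarray) -> np.ndarray:
    """3x3 horizontal max filter with zero padding (like F.max_pool2d(x, 3, 1, 1) for x >= 0)."""
    h, w = layer.shape
    padded = np.pad(layer, 1)
    out = np.zeros_like(layer)
    for dy in range(3):
        for dx in range(3):
            out = np.maximum(out, padded[dy:dy + h, dx:dx + w])
    return out


def unsupported_voxels(structure: np.ndarray, max_overhang: int = 3,
                       threshold: float = 0.3) -> np.ndarray:
    """Hard-threshold version of CantileverLoss's support test on a (D, H, W) grid."""
    result = np.zeros_like(structure)
    for z in range(max_overhang, structure.shape[0]):
        support = structure[z - max_overhang:z].max(axis=0)
        has_support = max_filter3x3(support) > threshold
        result[z] = structure[z] * ~has_support
    return result


# Column at x = 0 from z = 0 to 4, with a slab at z = 4 reaching x = 3.
grid = np.zeros((5, 1, 6))
grid[:, 0, 0] = 1.0
grid[4, 0, 0:4] = 1.0
print(np.argwhere(unsupported_voxels(grid) > 0))  # z=4 at x=2 and x=3
```

`max_overhang` therefore controls how far down the module looks for support, not how far a slab may project sideways in one layer.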
In [None]:
# ============================================================
# MEDIUM: ACCESS REACH LOSS (C4A)
# ============================================================

class AccessReachLoss(nn.Module):
    """Structure must reach access points."""

    def __init__(self, config: dict, dilation_radius: int = 3):
        super().__init__()
        self.config = config
        self.dilation_radius = dilation_radius

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        
        structure = state[:, cfg['ch_structure']]
        access = state[:, cfg['ch_access']]
        
        # Dilate access points
        kernel_size = 2 * self.dilation_radius + 1
        access_dilated = F.max_pool3d(
            access.unsqueeze(1), kernel_size, 1, self.dilation_radius
        ).squeeze(1)
        
        # Structure should be in dilated access zone
        reach = (structure * access_dilated).sum() / (access_dilated.sum() + 1e-8)
        
        return 1.0 - reach

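The reach term is normalized by the volume of the dilated access zone, so its best achievable value depends on how much of that zone structure may occupy. For a 2×2×2 access block and `dilation_radius=3`, `F.max_pool3d` with kernel 7 grows the block to 8×8×8 = 512 voxels away from grid borders.

The arithmetic below assumes, for illustration, an elevated access point on a building face where the building covers the zone's full y and z extent. In that case 3 of the 8 x-columns fall inside the building, where `_step` forces structure to zero. Both helpers are hypothetical.

```python
def dilated_zone_volume(block: int = 2, radius: int = 3) -> int:
    """Voxels covered by max-pool dilation of a cubic block, ignoring grid borders."""
    return (block + 2 * radius) ** 3


def best_case_reach(zone_voxels: int, blocked_voxels: int) -> float:
    """Reach if structure (value 1) filled every zone voxel except the blocked ones."""
    return (zone_voxels - blocked_voxels) / zone_voxels


zone = dilated_zone_volume()       # 8 ** 3 = 512
blocked = 3 * 8 * 8                # assumed: 3 of 8 x-columns inside the building
reach = best_case_reach(zone, blocked)
print(zone, reach, 1.0 - reach)    # 512 0.625 0.375
```

Under that assumption, the access-reach loss for such a point cannot fall below about 0.375 whatever the model does. A constant floor does not change gradients, but keep it in mind when reading logged values.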
In [None]:
# ============================================================
# HIGH: STREET CONNECTIVITY LOSS (C2B) - NEW
# ============================================================

class StreetConnectivityLoss(nn.Module):
    """Ensure street void forms connected network for circulation.
    
    Void at street level must be connected across the scene
    (people need continuous paths to walk through).
    """

    def __init__(self, config: dict, iterations: int = 32):
        super().__init__()
        self.config = config
        self.street_levels = config['street_levels']
        self.iterations = iterations

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        cfg = self.config
        # Get void at street level (not structure, not existing)
        structure = state[:, cfg['ch_structure'], :self.street_levels, :, :]
        existing = state[:, cfg['ch_existing'], :self.street_levels, :, :]
        void_mask = (1.0 - structure) * (1.0 - existing)
        
        # Seed the flood fill at the scene's side edges (where people enter);
        # void_mask is indexed (B, z, y, x)
        connected = torch.zeros_like(void_mask)
        connected[:, :, :, 0] = 1.0    # x = 0 edge
        connected[:, :, :, -1] = 1.0   # x = max edge
        connected[:, :, 0, :] = 1.0    # y = 0 edge
        connected[:, :, -1, :] = 1.0   # y = max edge
        connected = connected * void_mask
        
        # Flood fill
        for _ in range(self.iterations):
            conn_3d = connected.unsqueeze(1)
            dilated = F.max_pool3d(conn_3d, 3, 1, 1).squeeze(1)
            new_connected = torch.max(connected, dilated * void_mask)
            if torch.allclose(connected, new_connected, atol=1e-5):
                break
            connected = new_connected
        
        # Loss = fraction of void that is NOT connected
        total_void = void_mask.sum() + 1e-8
        disconnected_void = void_mask.sum() - connected.sum()
        
        return disconnected_void / total_void


# ============================================================
# MEDIUM: DICE LOSS (Growth Incentive)
# ============================================================

class DiceLoss(nn.Module):
    """Growth incentive via Dice loss."""

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, structure: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        intersection = (structure * target).sum()
        dice = (2 * intersection + self.smooth) / (
            structure.sum() + target.sum() + self.smooth
        )
        return 1.0 - dice

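Dice is the term that makes "grow nothing" expensive. With `smooth=1`, an empty output against a non-empty target scores a Dice coefficient of `1 / (|target| + 1)`, so the loss sits just below 1. The scalar sketch below compares three cases. It is a hypothetical helper that works on precomputed sums, and the 500-voxel target is an arbitrary example.

```python
def dice_loss(structure_sum: float, target_sum: float, intersection: float,
              smooth: float = 1.0) -> float:
    """Scalar form of DiceLoss from precomputed sums."""
    dice = (2 * intersection + smooth) / (structure_sum + target_sum + smooth)
    return 1.0 - dice


empty = dice_loss(0, 500, 0)          # grow nothing
half = dice_loss(250, 500, 250)       # half the target, all inside it
full = dice_loss(500, 500, 500)       # exact match
print(round(empty, 4), round(half, 4), round(full, 4))
```

With the v1.1 weight of 25, the empty case costs about 24.95. With v1.0's weight of 5, it cost about 4.99.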
In [None]:
# ============================================================
# LOW: QUALITY LOSSES (Reduced weights in Step B)
# ============================================================

class DensityPenalty(nn.Module):
    """SIMP density penalty for binary outputs."""
    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        return (structure * (1.0 - structure)).mean()


class TotalVariation3D(nn.Module):
    """Total variation for smooth surfaces."""
    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        tv_d = (structure[:, 1:, :, :] - structure[:, :-1, :, :]).abs().mean()
        tv_h = (structure[:, :, 1:, :] - structure[:, :, :-1, :]).abs().mean()
        tv_w = (structure[:, :, :, 1:] - structure[:, :, :, :-1]).abs().mean()
        return tv_d + tv_h + tv_w


# ============================================================
# NEW: ELEVATED STRUCTURE BONUS
# ============================================================

class ElevatedBonus(nn.Module):
    """Reward structure above street level.
    
    Encourages upward growth to preserve street void.
    Returns negative loss (reward) for elevated structure.
    """
    
    def __init__(self, config: dict):
        super().__init__()
        self.street_levels = config['street_levels']
    
    def forward(self, structure: torch.Tensor, available: torch.Tensor) -> torch.Tensor:
        # Structure above street level (z >= street_levels)
        elevated = structure[:, self.street_levels:, :, :]
        
        # Normalize by available space
        elevated_ratio = elevated.sum() / (available.sum() + 1e-8)
        
        # Return a negative value (a reward): adding weight * this term lowers the total loss
        return -elevated_ratio


# ============================================================
# NEW: SUPPORT RATIO CONSTRAINT
# ============================================================

class SupportRatioLoss(nn.Module):
    """Prevent towers on tiny pedestals.
    
    Penalizes when elevated mass is disproportionate to ground support.
    Ensures structurally realistic forms.
    """
    
    def __init__(self, config: dict, max_ratio: float = 4.0):
        super().__init__()
        self.street_levels = config['street_levels']
        self.max_ratio = max_ratio  # Max elevated:ground ratio
    
    def forward(self, structure: torch.Tensor) -> torch.Tensor:
        # Ground contact (structure at z < street_levels)
        ground_contact = structure[:, :self.street_levels, :, :].sum()
        
        # Elevated mass (structure at z >= street_levels)
        elevated_mass = structure[:, self.street_levels:, :, :].sum()
        
        # Support ratio
        ratio = elevated_mass / (ground_contact + 1e-8)
        
        # Penalize if ratio exceeds max (towers on pedestals)
        return F.relu(ratio - self.max_ratio)


print('All loss functions defined (including new growth incentives)')

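The support ratio and the elevated bonus pull in opposite directions. The test geometry in the next cell shows how hard the ratio bites. Assume `street_levels = 2`, which is consistent with the z >= 2 threshold above. The test structure then has 2×4×4 = 32 ground voxels and 8×10×8 = 640 elevated voxels, all at 0.6. The scalar helpers below are hypothetical restatements of the two modules, and the `available` value passed to the bonus is an arbitrary placeholder.

```python
def support_ratio_loss(ground_mass: float, elevated_mass: float, max_ratio: float = 4.0) -> float:
    """Scalar form of SupportRatioLoss: hinge on elevated:ground above max_ratio."""
    ratio = elevated_mass / (ground_mass + 1e-8)
    return max(0.0, ratio - max_ratio)


def elevated_bonus(elevated_mass: float, available: float) -> float:
    """Scalar form of ElevatedBonus: negative elevated fill ratio (a reward)."""
    return -elevated_mass / (available + 1e-8)


ground = 32 * 0.6        # about 19.2
elevated = 640 * 0.6     # 384.0
print(support_ratio_loss(ground, elevated))        # ~16 (ratio ~20)
print(support_ratio_loss(4 * ground, elevated))    # ~1 (ratio ~5)
print(support_ratio_loss(5 * ground, elevated))    # 0.0 (ratio just under 4)
print(elevated_bonus(elevated, available=10_000.0))
```

At a weight of 15.0, a ratio of 20 adds roughly 15 × 16 = 240 to the total. The elevated bonus (weight 5.0) cannot offset that: when structure stays out of buildings, its contribution is at most 5 in magnitude.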
In [None]:
# Test all losses
print('Testing loss functions...')

scene_gen = UrbanSceneGenerator(CONFIG)
scene, meta = scene_gen.generate('easy', device)

# Add fake structure (elevated, with some ground contact)
test_state = scene.clone()
test_state[:, CONFIG['ch_structure'], 2:10, 10:20, 12:20] = 0.6  # Elevated
test_state[:, CONFIG['ch_structure'], 0:2, 12:16, 14:18] = 0.6   # Ground contact

structure = test_state[:, CONFIG['ch_structure']]
existing = test_state[:, CONFIG['ch_existing']]
available = 1.0 - existing

# Test each loss
print('\nProtection losses:')
street_void = StreetVoidLoss(CONFIG)(test_state)
print(f'  Street void (C2A): {street_void.item():.4f}')

anchor_budget = AnchorBudgetLoss(CONFIG)(test_state)
print(f'  Anchor budget (C3B): {anchor_budget.item():.4f}')

street_conn = StreetConnectivityLoss(CONFIG)(test_state)
print(f'  Street connectivity (C2B): {street_conn.item():.4f}')

print('\nGrowth incentives:')
access = test_state[:, CONFIG['ch_access']]
target = F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1) * available
dice = DiceLoss()(structure, target)
print(f'  Dice: {dice.item():.4f}')

elevated = ElevatedBonus(CONFIG)(structure, available)
print(f'  Elevated bonus: {elevated.item():.4f} (negative = reward)')

print('\nBalance constraint:')
support = SupportRatioLoss(CONFIG)(structure)
print(f'  Support ratio: {support.item():.4f}')

print('\nExisting constraints:')
connectivity = ConnectivityLoss(CONFIG)(test_state)
print(f'  Connectivity (C1A): {connectivity.item():.4f}')

sparsity = SparsityLoss(under_coef=3.0)(structure, available)
print(f'  Sparsity (C3A): {sparsity.item():.4f}')

cantilever = CantileverLoss()(structure)
print(f'  Cantilever (C1B): {cantilever.item():.4f}')

access_reach = AccessReachLoss(CONFIG)(test_state)
print(f'  Access reach (C4A): {access_reach.item():.4f}')

print('\nQuality losses:')
density = DensityPenalty()(structure)
print(f'  Density penalty: {density.item():.4f}')

tv = TotalVariation3D()(structure)
print(f'  Total variation: {tv.item():.4f}')

print('\n✓ All 12 loss functions working')

## 5. Trainer (All Constraints Active)
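The Dice target used by the trainer below, and by the loss smoke test above, is `F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1) * available`. With kernel size 7, stride 1 and padding 3, max pooling a binary mask acts as a morphological dilation. Every voxel within 3 cells of an access voxel (a 7×7×7 cube neighbourhood) becomes 1, and multiplying by `available` then zeroes voxels occupied by existing buildings. The toy check below assumes a `(B, Z, X, Y)` layout, which matches how the trainer slices the street band. The `dice_target` helper, grid size and positions are illustrative only.

```python
import torch
import torch.nn.functional as F

def dice_target(access: torch.Tensor, available: torch.Tensor, radius: int = 3) -> torch.Tensor:
    """Dilate a binary access mask by `radius` voxels (cube neighbourhood), then mask it.

    access, available: (B, Z, X, Y) float tensors in [0, 1].
    radius=3 reproduces the notebook's max_pool3d(..., 7, 1, 3) call.
    """
    k = 2 * radius + 1
    # max_pool3d expects (B, C, D, H, W): add a channel dim, then remove it again.
    dilated = F.max_pool3d(access.unsqueeze(1), kernel_size=k, stride=1, padding=radius).squeeze(1)
    # Voxels already taken by existing buildings cannot be growth targets.
    return dilated * available

# Toy scene: one access voxel in a 16^3 grid and one "existing" voxel next to it.
G = 16
access = torch.zeros(1, G, G, G)
access[0, 8, 8, 8] = 1.0
existing = torch.zeros(1, G, G, G)
existing[0, 8, 8, 10] = 1.0
available = 1.0 - existing

target = dice_target(access, available)
print(target.shape)                # torch.Size([1, 16, 16, 16])
print(target[0, 5, 11, 8].item())  # corner of the 7x7x7 cube -> 1.0
print(target[0, 8, 8, 12].item())  # 4 cells away -> 0.0
print(target[0, 8, 8, 10].item())  # inside the cube but occupied -> 0.0
print(int(target.sum().item()))    # 7**3 - 1 = 342
```

The `radius` argument is a hypothetical knob; the notebook hard-codes kernel 7, stride 1, padding 3.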

In [None]:
class AllConstraintsTrainer:
    """Step B Trainer: All constraints active from epoch 0.
    
    Updated with rebalanced weights and growth incentives.
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model = model
        self.config = config
        self.device = device

        # Initialize all loss functions
        self.street_void_loss = StreetVoidLoss(config)
        self.anchor_budget_loss = AnchorBudgetLoss(config)
        self.street_conn_loss = StreetConnectivityLoss(config)
        self.connectivity_loss = ConnectivityLoss(config)
        self.sparsity_loss = SparsityLoss(max_ratio=0.15, min_ratio=0.05, under_coef=3.0)
        self.cantilever_loss = CantileverLoss()
        self.access_reach_loss = AccessReachLoss(config)
        self.dice_loss = DiceLoss()
        self.density_loss = DensityPenalty()
        self.tv_loss = TotalVariation3D()
        
        # NEW: Growth incentive losses
        self.elevated_bonus = ElevatedBonus(config)
        self.support_ratio_loss = SupportRatioLoss(config, max_ratio=4.0)

        # REBALANCED weights - encourage growth while protecting street
        self.weights = {
            # Protection (reduced)
            'street_void': 25.0,      # Reduced from 50 (still important but not crushing)
            'anchor_budget': 20.0,    # Reduced from 30
            
            # Growth incentives (increased)
            'dice': 25.0,             # Increased from 5 (main growth driver)
            'elevated_bonus': 5.0,    # NEW: reward upward growth
            
            # Balance constraint
            'support_ratio': 15.0,    # NEW: prevent towers on pedestals
            
            # Existing constraints
            'sparsity': 20.0,         # under_coef raised from 0.5 to 3.0
            'connectivity': 10.0,
            'street_conn': 10.0,
            'access_reach': 10.0,
            'cantilever': 5.0,
            
            # Quality (unchanged)
            'density': 5.0,
            'tv': 1.0,
        }

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        self.history = []

    def train_epoch(self, epoch: int) -> dict:
        self.model.train()
        cfg = self.config

        # Generate batch
        seeds = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)
        steps = random.randint(cfg['steps_min'], cfg['steps_max'])

        # Forward pass
        final = self.model(seeds, steps=steps)

        # Extract channels
        structure = final[:, cfg['ch_structure']]
        existing = final[:, cfg['ch_existing']]
        access = final[:, cfg['ch_access']]
        available = 1.0 - existing

        # Compute ALL losses
        L_street = self.street_void_loss(final)
        L_anchor = self.anchor_budget_loss(final)
        L_street_conn = self.street_conn_loss(final)
        L_conn = self.connectivity_loss(final)
        L_sparse = self.sparsity_loss(structure, available)
        L_cant = self.cantilever_loss(structure)
        L_reach = self.access_reach_loss(final)
        
        # Dice target
        target = F.max_pool3d(access.unsqueeze(1), 7, 1, 3).squeeze(1) * available
        L_dice = self.dice_loss(structure, target)
        
        L_density = self.density_loss(structure)
        L_tv = self.tv_loss(structure)
        
        # NEW: Growth incentive losses
        L_elevated = self.elevated_bonus(structure, available)  # Negative = reward
        L_support = self.support_ratio_loss(structure)

        # Total loss with rebalanced weights
        total_loss = (
            # Protection
            self.weights['street_void'] * L_street +
            self.weights['anchor_budget'] * L_anchor +
            self.weights['street_conn'] * L_street_conn +
            
            # Growth incentives
            self.weights['dice'] * L_dice +
            self.weights['elevated_bonus'] * L_elevated +  # Negative, so subtracts
            
            # Balance
            self.weights['support_ratio'] * L_support +
            
            # Existing
            self.weights['connectivity'] * L_conn +
            self.weights['sparsity'] * L_sparse +
            self.weights['cantilever'] * L_cant +
            self.weights['access_reach'] * L_reach +
            
            # Quality
            self.weights['density'] * L_density +
            self.weights['tv'] * L_tv
        )

        # Backward pass
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        # Compute metrics
        with torch.no_grad():
            void_ratio = compute_street_void_ratio(final, cfg).mean().item()
            anchor_comp = compute_anchor_compliance(final, cfg).mean().item()
            street_conn = compute_street_connectivity(final, cfg).mean().item()
            fill_ratio = structure.sum().item() / (available.sum().item() + 1e-8)
            
            # NEW: Compute support ratio for logging
            ground = structure[:, :cfg['street_levels'], :, :].sum().item()
            elevated = structure[:, cfg['street_levels']:, :, :].sum().item()
            support_ratio = elevated / (ground + 1e-8)

        metrics = {
            'epoch': epoch,
            'total_loss': total_loss.item(),
            'street_void': L_street.item(),
            'anchor_budget': L_anchor.item(),
            'street_conn_loss': L_street_conn.item(),
            'connectivity': L_conn.item(),
            'sparsity': L_sparse.item(),
            'cantilever': L_cant.item(),
            'access_reach': L_reach.item(),
            'dice': L_dice.item(),
            'elevated_bonus': L_elevated.item(),
            'support_ratio_loss': L_support.item(),
            'density': L_density.item(),
            'tv': L_tv.item(),
            'void_ratio': void_ratio,
            'anchor_compliance': anchor_comp,
            'street_connectivity': street_conn,
            'fill_ratio': fill_ratio,
            'support_ratio': support_ratio,
            'steps': steps,
        }

        self.history.append(metrics)
        return metrics

    def evaluate(self, n_samples: int = 20) -> dict:
        self.model.eval()
        cfg = self.config
        results = []

        with torch.no_grad():
            for _ in range(n_samples):
                scene, meta = self.scene_gen.generate(cfg['difficulty'], self.device)
                grown = self.model.grow(scene, steps=50)
                
                void_ratio = compute_street_void_ratio(grown, cfg).item()
                anchor_comp = compute_anchor_compliance(grown, cfg).item()
                street_conn = compute_street_connectivity(grown, cfg).item()
                
                structure = grown[:, cfg['ch_structure']]
                existing = grown[:, cfg['ch_existing']]
                available = 1.0 - existing
                fill_ratio = structure.sum().item() / (available.sum().item() + 1e-8)
                
                # Support ratio
                ground = structure[:, :cfg['street_levels'], :, :].sum().item()
                elevated = structure[:, cfg['street_levels']:, :, :].sum().item()
                support_ratio = elevated / (ground + 1e-8)
                
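                # Structural connectivity: report 1 - loss as a connectivity rate
                # (assumes ConnectivityLoss returns a disconnected fraction in [0, 1])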
                conn_loss = self.connectivity_loss(grown).item()
                conn_rate = 1.0 - conn_loss
                
                results.append({
                    'void_ratio': void_ratio,
                    'anchor_compliance': anchor_comp,
                    'street_connectivity': street_conn,
                    'fill_ratio': fill_ratio,
                    'support_ratio': support_ratio,
                    'connectivity_rate': conn_rate,
                })

        return {
            'avg_void_ratio': np.mean([r['void_ratio'] for r in results]),
            'avg_anchor_compliance': np.mean([r['anchor_compliance'] for r in results]),
            'avg_street_connectivity': np.mean([r['street_connectivity'] for r in results]),
            'avg_fill_ratio': np.mean([r['fill_ratio'] for r in results]),
            'avg_support_ratio': np.mean([r['support_ratio'] for r in results]),
            'avg_connectivity_rate': np.mean([r['connectivity_rate'] for r in results]),
            'n_samples': n_samples,
        }

    def save_checkpoint(self, path: str):
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'history': self.history,
        }, path)
        print(f'Checkpoint saved: {path}')

## 6. Visualization
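`visualize_result` flags a violation when a voxel holds structure inside the street band (z < `street_levels`), is not part of an existing building, and is not an anchor. The same boolean logic on a toy `(Z, X, Y)` grid is shown below. The helper name is hypothetical, and `street_levels = 2` is an assumption taken from the z >= 2 elevated threshold.

```python
import numpy as np

def street_violations(structure, existing, anchors, street_levels):
    """Boolean mask of structure voxels that block the street band.

    All inputs are boolean arrays shaped (Z, X, Y); z < street_levels is the street band.
    Mirrors the masking in visualize_result.
    """
    street_struct = structure.copy()
    street_struct[street_levels:, :, :] = False  # only the street band matters
    street_struct &= ~existing                   # voxels inside existing buildings are allowed
    return street_struct & ~anchors              # anchors are the permitted touchdown points

Z = X = Y = 6
structure = np.zeros((Z, X, Y), dtype=bool)
existing = np.zeros_like(structure)
anchors = np.zeros_like(structure)

structure[0, 1, 1] = True  # street level, on an anchor -> allowed
anchors[0, 1, 1] = True
structure[1, 2, 2] = True  # street level, no anchor -> violation
structure[0, 4, 4] = True  # inside an existing building -> allowed
existing[0, 4, 4] = True
structure[3, 2, 2] = True  # elevated -> never a street violation

v = street_violations(structure, existing, anchors, street_levels=2)
print(int(v.sum()), np.argwhere(v).tolist())  # 1 [[1, 2, 2]]
```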

In [None]:
def visualize_result(model, scene_gen, config, device, title='Result'):
    """Visualize grown structure with void metrics."""
    model.eval()
    scene, meta = scene_gen.generate(config['difficulty'], device)
    
    with torch.no_grad():
        grown = model.grow(scene, steps=50)
    
    cfg = config
    s = grown[0].cpu().numpy()
    G = s.shape[1]
    street_levels = cfg['street_levels']
    
    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    anchors = s[cfg['ch_anchors']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    
    # Violations
    street_struct = structure.copy()
    street_struct[street_levels:, :, :] = False
    street_struct = street_struct & ~existing
    violations = street_struct & ~anchors
    valid_struct = structure & ~violations
    
    fig = plt.figure(figsize=(20, 5))
    
    # 3D View
    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any():
        ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if anchors.any():
        anchor_disp = anchors.copy()
        anchor_disp[street_levels:,:,:] = False
        if anchor_disp.any():
            ax1.voxels(anchor_disp.transpose(1,2,0), facecolors='yellow', alpha=0.3)
    if access.any():
        ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if violations.any():
        ax1.voxels(violations.transpose(1,2,0), facecolors='red', alpha=0.9)
    if valid_struct.any():
        ax1.voxels(valid_struct.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_xlabel('Y'); ax1.set_ylabel('X'); ax1.set_zlabel('Z')
    ax1.set_title(f'{title} (3D)')
    
    # Plan view: buildings, anchors and violations at z=0; structure and access projected over all heights
    ax2 = fig.add_subplot(142)
    plan = np.zeros((G,G,3))
    plan[existing[0,:,:]] = [0.5,0.5,0.5]
    plan[anchors[0,:,:] & ~existing[0,:,:]] = [1,1,0.5]
    plan[valid_struct.max(axis=0)] = [0.2,0.4,0.8]
    plan[violations[0,:,:]] = [1,0,0]
    plan[access.max(axis=0)] = [0.2,0.8,0.2]
    ax2.imshow(plan.transpose(1,0,2), origin='lower')
    ax2.set_title('Ground Level')
    
    # Elevation
    ax3 = fig.add_subplot(143)
    elev = np.zeros((G,G,3))
    elev[existing.max(axis=1)] = [0.5,0.5,0.5]
    elev[valid_struct.max(axis=1)] = [0.2,0.4,0.8]
    elev[violations.max(axis=1)] = [1,0,0]
    elev[access.max(axis=1)] = [0.2,0.8,0.2]
    ax3.imshow(elev.transpose(1,0,2), origin='lower')
    ax3.set_title('Elevation')
    
    # Metrics
    ax4 = fig.add_subplot(144)
    ax4.axis('off')
    
    void_ratio = compute_street_void_ratio(grown, config).item()
    anchor_comp = compute_anchor_compliance(grown, config).item()
    street_conn = compute_street_connectivity(grown, config).item()
    
    status_void = 'PASS' if void_ratio > 0.70 else 'FAIL'
    status_anchor = 'PASS' if anchor_comp > 0.99 else 'FAIL'
    status_conn = 'PASS' if street_conn > 0.90 else 'FAIL'
    
    text = f"""STEP B METRICS

Street Void: {void_ratio*100:.1f}% [{status_void}]
(Target: >70%)

Anchor Compliance: {anchor_comp*100:.1f}% [{status_anchor}]
(Target: 100%)

Street Connectivity: {street_conn*100:.1f}% [{status_conn}]
(Target: >90%)

---
Structure: {structure.sum():.0f} voxels
Violations: {violations.sum():.0f}
"""
    ax4.text(0.1, 0.9, text, transform=ax4.transAxes, fontsize=10,
             verticalalignment='top', fontfamily='monospace')
    
    plt.tight_layout()
    plt.show()
    
    return grown, meta


def plot_training_curves(history):
    """Plot training curves."""
    epochs = [h['epoch'] for h in history]
    
    fig, axes = plt.subplots(2, 5, figsize=(25, 8))
    
    # Total loss
    axes[0,0].plot(epochs, [h['total_loss'] for h in history])
    axes[0,0].set_title('Total Loss')
    axes[0,0].set_yscale('log')
    
    # Street void loss (CRITICAL)
    axes[0,1].plot(epochs, [h['street_void'] for h in history], 'r-')
    axes[0,1].set_title('Street Void Loss (CRITICAL)')
    
    # Anchor budget loss
    axes[0,2].plot(epochs, [h['anchor_budget'] for h in history], 'orange')
    axes[0,2].set_title('Anchor Budget Loss')
    
    # Connectivity
    axes[0,3].plot(epochs, [h['connectivity'] for h in history])
    axes[0,3].set_title('Connectivity Loss')
    
    # Sparsity
    axes[0,4].plot(epochs, [h['sparsity'] for h in history], 'purple')
    axes[0,4].set_title('Sparsity Loss')
    
    # Void ratio metric
    axes[1,0].plot(epochs, [h['void_ratio'] for h in history], 'g-')
    axes[1,0].axhline(0.70, color='r', linestyle='--', label='Target')
    axes[1,0].set_title('Street Void Ratio')
    axes[1,0].set_ylim(0, 1.1)
    axes[1,0].legend()
    
    # Anchor compliance
    axes[1,1].plot(epochs, [h['anchor_compliance'] for h in history], 'orange')
    axes[1,1].axhline(1.0, color='r', linestyle='--', label='Target')
    axes[1,1].set_title('Anchor Compliance')
    axes[1,1].set_ylim(0, 1.1)
    axes[1,1].legend()
    
    # Street connectivity
    axes[1,2].plot(epochs, [h['street_connectivity'] for h in history], 'cyan')
    axes[1,2].axhline(0.90, color='r', linestyle='--', label='Target')
    axes[1,2].set_title('Street Connectivity')
    axes[1,2].set_ylim(0, 1.1)
    axes[1,2].legend()
    
    # Fill ratio
    axes[1,3].plot(epochs, [h['fill_ratio'] for h in history], 'magenta')
    axes[1,3].axhline(0.15, color='r', linestyle='--', label='Max')
    axes[1,3].axhline(0.05, color='g', linestyle='--', label='Min')
    axes[1,3].set_title('Fill Ratio')
    axes[1,3].set_ylim(0, 0.3)
    axes[1,3].legend()
    
    # Access reach
    axes[1,4].plot(epochs, [h['access_reach'] for h in history], 'brown')
    axes[1,4].set_title('Access Reach Loss')
    
    plt.tight_layout()
    plt.show()

## 7. Training
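The training loop writes periodic checkpoints to `{PROJECT_ROOT}/step_b/checkpoints/`, and Section 9 writes JSON logs to `{PROJECT_ROOT}/step_b/logs/`. Neither `torch.save` nor `open(..., 'w')` creates missing parent directories. If these folders were not created earlier in the notebook, a small helper such as the hypothetical `ensure_step_b_dirs` below can be run once before training.

```python
import os

def ensure_step_b_dirs(project_root: str) -> dict:
    """Create the Step B output folders if missing and return their paths.

    Hypothetical helper; the folder names match the paths used in this notebook.
    """
    paths = {
        'checkpoints': os.path.join(project_root, 'step_b', 'checkpoints'),
        'logs': os.path.join(project_root, 'step_b', 'logs'),
    }
    for p in paths.values():
        os.makedirs(p, exist_ok=True)  # no error if the folder already exists
    return paths

# In the notebook: ensure_step_b_dirs(PROJECT_ROOT)
```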

In [None]:
# Initialize
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = AllConstraintsTrainer(model, CONFIG, device)

print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')
print(f"Difficulty: {CONFIG['difficulty']}")
print(f"Epochs: {CONFIG['epochs']}")

In [None]:
# Before training
print('Before training:')
visualize_result(model, trainer.scene_gen, CONFIG, device, 'Before Training')

In [None]:
# Training loop
print('\n' + '='*70)
print('STEP B TRAINING v1.1: REBALANCED WEIGHTS')
print('='*70)
print('NEW: Elevated bonus + Support ratio constraint')
print('CHANGED: Reduced void/anchor weights, increased dice/sparsity-under')
print('='*70)

for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    metrics = trainer.train_epoch(epoch)
    
    if epoch % CONFIG['log_every'] == 0:
        tqdm.write(
            f"Epoch {epoch:4d} | Loss: {metrics['total_loss']:.2f} | "
            f"Void: {metrics['void_ratio']*100:.0f}% | "
            f"Fill: {metrics['fill_ratio']*100:.1f}% | "
            f"Sup.Ratio: {metrics['support_ratio']:.1f}"
        )
    
    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize_result(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')
    
    if epoch > 0 and epoch % CONFIG['save_every'] == 0:
        trainer.save_checkpoint(f"{PROJECT_ROOT}/step_b/checkpoints/epoch_{epoch}.pth")

print('='*70)
print('Training complete!')

## 8. Evaluation
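`plot_training_curves` does not chart the quantities introduced in v1.1, although `train_epoch` already logs them under `'dice'`, `'elevated_bonus'` and `'support_ratio'`. The companion plot below assumes only that `history` is the list of metric dicts built by `train_epoch`. The function name is hypothetical.

```python
import matplotlib.pyplot as plt

def plot_v11_curves(history, max_support_ratio=4.0):
    """Plot the v1.1 growth/balance signals logged by AllConstraintsTrainer.train_epoch."""
    epochs = [h['epoch'] for h in history]
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].plot(epochs, [h['dice'] for h in history])
    axes[0].set_title('Dice Loss (growth driver)')

    # Elevated bonus is logged unweighted; negative values are a reward.
    axes[1].plot(epochs, [h['elevated_bonus'] for h in history], 'teal')
    axes[1].axhline(0.0, color='gray', linewidth=0.8)
    axes[1].set_title('Elevated Bonus')

    # Logged metric: elevated structure mass / ground structure mass.
    axes[2].plot(epochs, [h['support_ratio'] for h in history], 'navy')
    axes[2].axhline(max_support_ratio, color='r', linestyle='--', label='Max')
    axes[2].set_title('Support Ratio (elevated:ground)')
    axes[2].legend()

    fig.tight_layout()
    return fig

# In the notebook: plot_v11_curves(trainer.history); plt.show()
```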

In [None]:
# Plot curves
plot_training_curves(trainer.history)

In [None]:
# Evaluate
print('Evaluating...')
eval_results = trainer.evaluate(n_samples=50)

print('\n' + '='*70)
print('STEP B v1.1 EVALUATION RESULTS')
print('='*70)

void_pass = eval_results['avg_void_ratio'] > 0.70
anchor_pass = eval_results['avg_anchor_compliance'] > 0.99
conn_pass = eval_results['avg_street_connectivity'] > 0.90
fill_pass = 0.05 < eval_results['avg_fill_ratio'] < 0.15
struct_pass = eval_results['avg_connectivity_rate'] > 0.95
support_pass = eval_results['avg_support_ratio'] < 4.0

print(f"Street Void Ratio:    {eval_results['avg_void_ratio']*100:.1f}% {'PASS' if void_pass else 'FAIL'} (target >70%)")
print(f"Anchor Compliance:    {eval_results['avg_anchor_compliance']*100:.1f}% {'PASS' if anchor_pass else 'FAIL'} (target 100%)")
print(f"Street Connectivity:  {eval_results['avg_street_connectivity']*100:.1f}% {'PASS' if conn_pass else 'FAIL'} (target >90%)")
print(f"Fill Ratio:           {eval_results['avg_fill_ratio']*100:.1f}% {'PASS' if fill_pass else 'FAIL'} (target 5-15%)")
print(f"Struct Connectivity:  {eval_results['avg_connectivity_rate']*100:.1f}% {'PASS' if struct_pass else 'FAIL'} (target >95%)")
print(f"Support Ratio:        {eval_results['avg_support_ratio']:.1f}:1 {'PASS' if support_pass else 'FAIL'} (target <4:1)")
print('='*70)

all_pass = void_pass and anchor_pass and conn_pass and fill_pass and struct_pass and support_pass
if all_pass:
    print('\nALL CRITERIA PASSED - Ready for medium difficulty')
else:
    print('\nSome criteria not met - Analyze results and adjust if needed')

In [None]:
# Final visualizations
print('\nFinal Results:')
for i in range(3):
    visualize_result(model, trainer.scene_gen, CONFIG, device, f'Final {i+1}')

## 9. Save Final Checkpoint
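The checkpoint written in this section holds `model_state_dict`, `optimizer_state_dict`, `config` and `history`, as defined in `save_checkpoint`. It could later be restored (for example, before moving to medium difficulty) with a helper like the hypothetical `load_checkpoint` below. `weights_only=False` is passed explicitly because recent PyTorch releases default `torch.load` to `weights_only=True`, which rejects pickled Python objects outside its allowlist. Only use this on checkpoints you created yourself.

```python
import torch

def load_checkpoint(path, model, optimizer=None, device='cpu'):
    """Restore a checkpoint written by AllConstraintsTrainer.save_checkpoint.

    Returns (config, history). Hypothetical helper; only load trusted files,
    because weights_only=False unpickles arbitrary Python objects.
    """
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    return ckpt['config'], ckpt['history']

# In the notebook (names as defined above):
# model = UrbanPavilionNCA(CONFIG).to(device)
# trainer = AllConstraintsTrainer(model, CONFIG, device)
# cfg, hist = load_checkpoint(f"{PROJECT_ROOT}/step_b/checkpoints/all_constraints_easy.pth",
#                             model, trainer.optimizer, device)
```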

In [None]:
# Save final
trainer.save_checkpoint(f"{PROJECT_ROOT}/step_b/checkpoints/all_constraints_easy.pth")

# Save history
with open(f"{PROJECT_ROOT}/step_b/logs/training_history.json", 'w') as f:
    json.dump(trainer.history, f)

# Save eval
with open(f"{PROJECT_ROOT}/step_b/logs/evaluation.json", 'w') as f:
    json.dump(eval_results, f, indent=2)

print('\nAll outputs saved')
print(f"Checkpoint: {PROJECT_ROOT}/step_b/checkpoints/all_constraints_easy.pth")

## Summary

### Step B Training v1.1 - Rebalanced Weights

**Problem in v1.0:** Model learned "do nothing" because protection losses dominated.

**Solution:** Rebalance weights + add growth incentives with structural balance.

### Weight Changes
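Summing the protection weights and the growth weights from the table below gives a rough picture of the rebalance. This ignores loss magnitudes and the sign of the elevated bonus, so it indicates emphasis rather than actual gradient balance.

```python
# Weights from the weight-change table; the sparsity change is a coefficient, not a weight, so it is omitted.
v10 = {'street_void': 50, 'anchor_budget': 30, 'dice': 5}
v11 = {'street_void': 25, 'anchor_budget': 20, 'dice': 25, 'elevated_bonus': 5}

def protection_to_growth(w):
    """Ratio of summed protection weights to summed growth weights."""
    protection = w['street_void'] + w['anchor_budget']
    growth = w['dice'] + w.get('elevated_bonus', 0)
    return protection / growth

print(protection_to_growth(v10))  # 16.0
print(protection_to_growth(v11))  # 1.5
```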

| Loss | v1.0 | v1.1 | Reason |
|------|------|------|--------|
| Street Void | 50 | **25** | Reduced to allow growth |
| Anchor Budget | 30 | **20** | Reduced to allow growth |
| Dice | 5 | **25** | Main growth driver |
| Sparsity under-coef | 0.5 | **3.0** | Stronger growth pressure |
| Elevated Bonus | - | **5** | NEW: Reward upward growth |
| Support Ratio | - | **15** | NEW: Prevent towers on pedestals |

### New Constraints
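The support ratio in the second constraint below is the same quantity the trainer logs: structure mass at or above `street_levels` divided by the mass below it. The sketch below computes that metric and adds one plausible penalty form, a hinge that is zero up to the 4:1 cap and linear beyond it. Both the hinge and `street_levels = 2` (taken from the z >= 2 threshold) are assumptions; `SupportRatioLoss` is defined earlier in the notebook and may differ.

```python
import numpy as np

def support_ratio(structure, street_levels=2, eps=1e-8):
    """Elevated:ground structure mass for a (Z, X, Y) occupancy array, as logged by the trainer."""
    ground = structure[:street_levels].sum()
    elevated = structure[street_levels:].sum()
    return elevated / (ground + eps)

def support_ratio_hinge(structure, max_ratio=4.0, street_levels=2):
    """Hypothetical penalty: zero while ratio <= max_ratio, linear above it."""
    return max(0.0, support_ratio(structure, street_levels) - max_ratio)

s = np.zeros((8, 10, 10))
s[0:2, 4:6, 4:6] = 1.0  # 2 * 2 * 2 = 8 ground voxels
s[2:8, 2:8, 3:4] = 1.0  # 6 * 6 * 1 = 36 elevated voxels
print(round(support_ratio(s), 3))        # 4.5
print(round(support_ratio_hinge(s), 3))  # 0.5
```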

- **Elevated Bonus**: Rewards structure at z >= 2, steering growth above the street band so the street void is preserved
- **Support Ratio**: Max 4:1 elevated:ground ratio (structural realism)

### Success Criteria

| Metric | Target |
|--------|--------|
| Street void | >70% |
| Anchor compliance | 100% |
| Street connectivity | >90% |
| Structural connectivity | >95% |
| Fill ratio | 5-15% |
| Support ratio | <4:1 |

---

*NB02_AllConstraints_v1.1 - Step B Rebalanced - December 2025*