# NB02: All Constraints Training v6.5

**Version:** 6.5 (Extended 2-Phase Curriculum)

---

## Reviewer Diagnosis of v6.3/v6.4

**Phase Transition Instability**:
- Epochs 0-440: Stuck in a "no growth" local minimum
- Epoch 460: Sudden breakout
- Epochs 460-500: Chaotic blob growth (no spatial discipline)

**Root cause**: The model only started growing at epoch 460, leaving just 40 epochs to learn WHERE to grow.

## v6.5 Solution: Extended 2-Phase Curriculum

| Phase | Epochs | Spill Weight | Goal |
|-------|--------|--------------|------|
| **Growth** | 0-600 | 0 → 10 | Learn TO grow |
| **Sculpting** | 600-1500 | 10 → 50 | Learn WHERE to grow |

**Key changes:**
1. **1500 epochs** (3x longer)
2. **2-phase curriculum** - Growth then Sculpting
3. **Progressive spill hardening** - 0 → 50
4. **sqrt distance weighting** of the spill loss in the sculpting phase (spill farther from access points is penalized more)

---

In [None]:
# Mount Google Drive (Colab only) so the config file and checkpoints stored
# under PROJECT_ROOT persist across sessions.
from google.colab import drive
drive.mount('/content/drive')
PROJECT_ROOT = '/content/drive/MyDrive/Constraint-NCA'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from typing import Dict, Tuple, List
import json
from tqdm.notebook import tqdm
from itertools import combinations

# Prefer the GPU when one is visible; tensors and the model are placed on `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')


def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch for reproducible runs.

    Args:
        seed: integer seed applied to every generator (and to all CUDA
            devices when CUDA is available).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)

In [None]:
# Load the shared Step B base config. It supplies every key not overridden
# below (e.g. grid_size, channel indices, hidden_dim, batch_size).
with open(f'{PROJECT_ROOT}/config_step_b.json', 'r', encoding='utf-8') as f:
    CONFIG = json.load(f)

# v6.5 EXTENDED CURRICULUM
CONFIG.update({
    'epochs': 1500,
    'steps_min': 30,
    'steps_max': 50,
    'difficulty': 'easy',
    'log_every': 50,
    'viz_every': 300,
    'save_every': 300,
    'corridor_width': 4,
    'max_thickness': 2,
    'vertical_envelope': 2,
    'lr_initial': 1e-3,
    # 2-PHASE CURRICULUM
    'growth_phase_end': 600,
    'sculpt_phase_end': 1500,
    'spill_weight_min': 0.0,
    'spill_weight_growth': 10.0,
    'spill_weight_max': 50.0,
})

print('='*60)
print('v6.5 EXTENDED 2-PHASE CURRICULUM')
print('='*60)
print(f"Total epochs: {CONFIG['epochs']}")
print(f"Phase 1 (Growth): 0-{CONFIG['growth_phase_end']}")
print(f"Phase 2 (Sculpt): {CONFIG['growth_phase_end']}-{CONFIG['sculpt_phase_end']}")
# Derived from CONFIG (previously hard-coded) so the banner cannot drift from the
# actual schedule; ':g' prints 0.0/10.0/50.0 as 0/10/50.
print(f"Spill weight: {CONFIG['spill_weight_min']:g} -> "
      f"{CONFIG['spill_weight_growth']:g} -> {CONFIG['spill_weight_max']:g}")
print('='*60)

In [None]:
class Perceive3D(nn.Module):
    """Fixed (non-learned) depthwise 3x3x3 perception filters for the NCA.

    Every input channel is convolved with four kernels: identity, and Sobel
    derivatives along the last (W, "x"), middle (H, "y") and first (D, "z")
    spatial axes. Input is replicate-padded, so spatial size is preserved.
    Output channels are grouped by kernel:
    [identity x C, d/dx x C, d/dy x C, d/dz x C].
    """

    def __init__(self, n_channels=8):
        super().__init__()
        identity = torch.zeros(3, 3, 3)
        identity[1, 1, 1] = 1.0
        bank = [identity] + [self._sobel(axis) for axis in ('x', 'y', 'z')]
        self.register_buffer('kernels', torch.stack(bank))
        # Kept for reference; forward() infers the channel count from its input.
        self.n_channels = n_channels

    def _sobel(self, d):
        """Return a normalised 3x3x3 Sobel kernel differentiating along axis `d`."""
        diff = torch.tensor([-1., 0., 1.])
        blur = torch.tensor([1., 2., 1.])
        # einsum 'i,j,k->ijk' lays the 1-D factors along (D, H, W); any `d`
        # other than 'x' or 'y' yields the z (D-axis) kernel.
        factors = {'x': (blur, blur, diff), 'y': (blur, diff, blur)}.get(d, (diff, blur, blur))
        return torch.einsum('i,j,k->ijk', *factors) / 16

    def forward(self, x):
        _, channels, _, _, _ = x.shape
        padded = F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')
        responses = [
            F.conv3d(padded, kernel.expand(channels, 1, 3, 3, 3), padding=0, groups=channels)
            for kernel in self.kernels
        ]
        return torch.cat(responses, dim=1)

In [None]:
class UrbanPavilionNCA(nn.Module):
    """3-D neural cellular automaton that grows structure into a seed scene.

    Channels [0, n_frozen) are context and are copied through every step
    unchanged. Channels [n_frozen, C) receive a residual update from a per-voxel
    MLP (1x1x1 convolutions) applied to fixed identity/Sobel perception. Updates
    are clamped to [0, 1], and the first grown channel is zeroed wherever the
    `ch_existing` channel is set.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        in_ch = config['n_channels'] * 4  # identity + 3 Sobel responses per channel
        width = config['hidden_dim']
        out_ch = config['n_grown']
        self.perceive = Perceive3D(config['n_channels'])
        self.update_net = nn.Sequential(
            nn.Conv3d(in_ch, width, 1), nn.ReLU(),
            nn.Conv3d(width, width, 1), nn.ReLU(),
            nn.Conv3d(width, out_ch, 1),
        )
        convs = [layer for layer in self.update_net if isinstance(layer, nn.Conv3d)]
        for conv in convs:
            nn.init.xavier_uniform_(conv.weight, gain=config['xavier_gain'])
            if conv.bias is not None:
                nn.init.zeros_(conv.bias)
        # Initial bias for grown channel 0 (the channel masked by ch_existing in
        # _step); defaults to -0.5.
        with torch.no_grad():
            self.update_net[-1].bias[0] = config.get('structure_bias', -0.5)

    def forward(self, state, steps=1):
        """Apply `steps` NCA updates to `state` and return the result."""
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state):
        """One NCA update (stochastic per-voxel firing when in training mode)."""
        cfg = self.config
        batch, _, depth, height, width = state.shape
        delta = self.update_net(self.perceive(state))
        if self.training:
            # Each voxel fires with probability fire_rate.
            fire = torch.rand(batch, 1, depth, height, width, device=state.device) < cfg['fire_rate']
            delta = delta * fire.float()
        n_frozen = cfg['n_frozen']
        grown = torch.clamp(state[:, n_frozen:] + cfg['update_scale'] * delta, 0, 1)
        ex = cfg['ch_existing']
        vacant = 1 - state[:, ex:ex + 1]
        return torch.cat([state[:, :n_frozen], grown[:, :1] * vacant, grown[:, 1:]], dim=1)

    def grow(self, seed, steps=50):
        """Run `steps` updates without gradients; leaves the module in eval mode."""
        self.eval()
        with torch.no_grad():
            grown = self.forward(seed, steps)
        return grown

In [None]:
class UrbanSceneGenerator:
    """Procedurally samples two-building urban scenes as NCA seed states.

    A scene is a (1, n_channels, G, G, G) tensor whose spatial axes are indexed
    (z, y, x). It contains a ground plane, two building boxes flanking a central
    gap, 2x2x2 access blocks, and street-level anchor patches.
    """

    def __init__(self, config):
        self.config = config
        self.G = config['grid_size']
        self.C = config['n_channels']

    def generate(self, difficulty='easy', device='cuda'):
        """Sample one scene; return (state, {'buildings': [...], 'access': [...]}).

        `difficulty` is accepted but not used: the parameter set below is always applied.
        """
        cfg = self.config
        size = self.G
        state = torch.zeros(1, self.C, size, size, size, device=device)
        state[:, cfg['ch_ground'], 0, :, :] = 1.0  # ground plane at z = 0
        params = dict(
            n_buildings=2,
            height_range=(12, 16),
            width_range=(8, 12),
            gap_width=random.randint(14, 18),
            n_ground_access=1,
            n_elevated_access=1,
        )
        buildings = self._place_buildings(state, params)
        access = self._place_access(state, params, buildings)
        self._gen_anchors(state, params, buildings, access)
        return state, {'buildings': buildings, 'access': access}

    def _place_buildings(self, state, params):
        """Write one building box on each side of the centred gap; return their extents."""
        size, ch = self.G, self.config['ch_existing']
        mid = size // 2
        half_gap = params['gap_width'] // 2
        placed = []
        for side in ('left', 'right'):
            width = random.randint(*params['width_range'])
            depth = random.randint(size // 2, size - 2)
            height = random.randint(*params['height_range'])
            if side == 'left':
                hi = mid - half_gap
                lo = max(0, hi - width)
                face = hi  # first free x column past the building
            else:
                lo = mid + half_gap
                hi = min(size, lo + width)
                face = lo
            state[:, ch, :height, :depth, lo:hi] = 1.0
            placed.append({'x': (lo, hi), 'y': (0, depth), 'z': (0, height), 'gap_x': face, 'side': side})
        return placed

    def _place_access(self, state, params, buildings):
        """Write 2x2x2 ground access blocks in the gap and elevated blocks on building faces."""
        size, ch = self.G, self.config['ch_access']
        left_faces = [b['gap_x'] for b in buildings if b['side'] == 'left']
        right_faces = [b['gap_x'] for b in buildings if b['side'] == 'right']
        lo = max(left_faces) if left_faces else 0
        hi = min(right_faces) if right_faces else size
        points = []
        for _ in range(params['n_ground_access']):
            x = random.randint(lo + 1, hi - 3)
            y = random.randint(0, size - 3)
            state[:, ch, 0:2, y:y + 2, x:x + 2] = 1.0
            points.append({'x': x, 'y': y, 'z': 0, 'type': 'ground'})
        for _ in range(params['n_elevated_access']):
            host = random.choice(buildings)
            z = random.randint(3, max(4, host['z'][1] - 2))
            y_lo, y_hi = host['y']
            y = random.randint(y_lo, min(y_hi - 2, y_lo + size // 3))
            x = host['x'][1] if host['side'] == 'left' else host['x'][0] - 2
            x = max(0, min(size - 2, x))
            state[:, ch, z:z + 2, y:y + 2, x:x + 2] = 1.0
            points.append({'x': x, 'y': y, 'z': z, 'type': 'elevated'})
        return points

    def _gen_anchors(self, state, params, buildings, access):
        """Fill the anchor channel at street levels (z < street_levels) where ground support is allowed."""
        cfg = self.config
        size, levels = self.G, cfg['street_levels']
        occupied = state[:, cfg['ch_existing'], 0, :, :]  # building footprint at z = 0
        anchors = torch.zeros(1, 1, size, size, size, device=state.device)
        street = anchors[:, 0, :levels]  # view onto the street-level slices
        for ap in access:
            if ap['type'] != 'ground':
                continue
            ay, ax = ap['y'], ap['x']
            street[:, :, max(0, ay - 2):min(size, ay + 4), max(0, ax - 2):min(size, ax + 4)] = 1.0
        for b in buildings:
            gx = b['gap_x']
            # One-voxel-wide strip just outside the building on its gap side.
            x0, x1 = (gx, gx + 1) if b['side'] == 'left' else (gx - 1, gx)
            y0 = b['y'][0]
            street[:, :, y0:min(y0 + 4, b['y'][1]), max(0, x0):min(size, x1)] = 1.0
        street *= (1 - occupied)
        state[:, cfg['ch_anchors']:cfg['ch_anchors'] + 1] = anchors

    def batch(self, difficulty, batch_size, device):
        """Concatenate `batch_size` freshly sampled scene states along dim 0."""
        scenes = []
        for _ in range(batch_size):
            scene_state, _info = self.generate(difficulty, device)
            scenes.append(scene_state)
        return torch.cat(scenes, dim=0)

In [None]:
def find_access_centroids(access_channel, threshold=0.5, cluster_radius=4):
    """Greedily cluster active access voxels and return one centroid per cluster.

    Voxels with value > `threshold` are visited in `nonzero()` order. Each voxel
    not yet assigned seeds a new cluster. That cluster absorbs every other
    unassigned voxel within L1 (Manhattan) distance `cluster_radius` of the
    *seed*; grouping is not transitive.

    Args:
        access_channel: 3-D tensor of access occupancy.
        threshold: activation cutoff for a voxel to count as access.
        cluster_radius: maximum L1 distance from a cluster's seed voxel.

    Returns:
        List of integer coordinate tuples in the tensor's axis order, one per
        cluster. Each is the cluster's mean coordinate truncated with `.long()`.
        Returns [] when no voxel is active.
    """
    # Work on CPU: the greedy loop below does scalar tests, which on CUDA
    # would each force a device synchronisation.
    positions = (access_channel > threshold).nonzero(as_tuple=False).cpu()
    if positions.shape[0] == 0:
        return []
    unassigned = torch.ones(positions.shape[0], dtype=torch.bool)
    centroids = []
    for seed_idx in range(positions.shape[0]):
        if not unassigned[seed_idx]:
            continue
        # Every earlier index is already assigned, so this mask holds the seed
        # plus later unassigned voxels in ascending order, as in a pairwise scan.
        dist = (positions - positions[seed_idx]).abs().sum(dim=1)
        members = unassigned & (dist <= cluster_radius)
        unassigned &= ~members
        c = positions[members].float().mean(dim=0).long()
        centroids.append((c[0].item(), c[1].item(), c[2].item()))
    return centroids

def compute_distance_field_3d(starts, legal_mask, max_iters=64):
    """Unit-step, 26-connected geodesic distance from `starts` through legal voxels.

    Args:
        starts: iterable of integer (i, j, k) voxel coordinates. Out-of-bounds
            entries are ignored; in-bounds starts get distance 0 even if illegal.
        legal_mask: 3-D tensor; only voxels with value > 0.5 receive a
            propagated distance.
        max_iters: maximum relaxation sweeps. Each sweep extends the front by
            one voxel, so paths longer than this stay unreached.

    Returns:
        Float tensor of the same shape as `legal_mask`, with `inf` where unreached.
    """
    D, H, W = legal_mask.shape
    dist = torch.full((D, H, W), float('inf'), device=legal_mask.device)
    for z, y, x in starts:
        if 0 <= z < D and 0 <= y < H and 0 <= x < W:
            dist[z, y, x] = 0
    legal = legal_mask > 0.5
    for _ in range(max_iters):
        # Neighbourhood minimum via max-pool of the negated field. Index [0, 0]
        # rather than .squeeze(): squeeze also dropped singleton spatial dims,
        # which then broadcast against `dist` into a wrongly shaped field.
        neighbour_min = -F.max_pool3d(-dist[None, None], 3, 1, 1)[0, 0]
        new_dist = torch.where(legal, torch.minimum(dist, neighbour_min + 1), dist)
        if torch.allclose(dist, new_dist, atol=1e-5):
            break
        dist = new_dist
    return dist

def compute_corridor_and_access_distance(seed, config, corridor_width=4, vertical_envelope=2):
    """Build a soft corridor mask and an access-distance field for every scene in a batch.

    For each batch element the access channel is clustered into centroids.
    - With >= 2 centroids, the corridor is the set of voxels whose geodesic
      detour between some connected centroid pair is at most `corridor_width`
      (d_i + d_j <= d_ij + corridor_width). This set is box-dilated by
      `corridor_width`, and a 0.7-valued halo is added within
      +/- `vertical_envelope` z-levels.
    - Otherwise (or if no pair is connected), the corridor is the dilated
      access mask.
    Corridors are multiplied by the legal (non-existing) mask.

    Args:
        seed: batch of scene states; channels `ch_access` / `ch_existing` are read.
        config: dict providing 'grid_size', 'ch_access', 'ch_existing'.
        corridor_width: detour slack and dilation radius, in voxels.
        vertical_envelope: half-height of the 0.7 halo in z-levels (0 disables it).

    Returns:
        corridors: (B, G, G, G) soft corridor mask.
        access_dists: (B, G, G, G) geodesic distance to the nearest access
            centroid, clamped to [0, G] (unreachable -> G). Left as zeros for
            scenes without access voxels.
    """
    cfg, G = config, config['grid_size']
    B = seed.shape[0]
    corridors = torch.zeros(B, G, G, G, device=seed.device)
    access_dists = torch.zeros(B, G, G, G, device=seed.device)
    kernel = 2 * corridor_width + 1

    def dilate(mask):
        # Cubic box dilation by corridor_width; [0, 0] (not .squeeze()) keeps a
        # 3-D result even when a spatial size is 1.
        return F.max_pool3d(mask[None, None], kernel, 1, corridor_width)[0, 0]

    for b in range(B):
        access = seed[b, cfg['ch_access']]
        legal = 1.0 - seed[b, cfg['ch_existing']]
        centroids = find_access_centroids(access)
        if centroids:
            access_dists[b] = torch.clamp(compute_distance_field_3d(centroids, legal), 0, G)
        if len(centroids) < 2:
            corridors[b] = dilate(access) * legal
            continue

        # One geodesic field per centroid, reused for every pair it belongs to
        # (both fields used to be recomputed for each pair).
        fields = [compute_distance_field_3d([c], legal) for c in centroids]
        corr_mask = torch.zeros(G, G, G, device=seed.device)
        for i, j in combinations(range(len(centroids)), 2):
            total = fields[i][centroids[j]]  # geodesic distance between the pair
            if total == float('inf'):
                continue  # pair not connected through legal space
            corr_mask = torch.max(corr_mask, (fields[i] + fields[j] <= total + corridor_width).float())

        if corr_mask.sum() > 0:
            dilated = dilate(corr_mask)
            if vertical_envelope > 0:
                # Halo from the max over [z - ve, z + ve] of the *unmodified* mask.
                # The former per-slice in-place loop read slices it had already
                # haloed, so the 0.7 halo cascaded upward (0.49, 0.34, ...)
                # through the full height instead of staying within the envelope.
                ve = vertical_envelope
                z_local_max = F.max_pool3d(dilated[None, None], (2 * ve + 1, 1, 1), 1, (ve, 0, 0))[0, 0]
                dilated = torch.maximum(dilated, 0.7 * z_local_max)
            corridors[b] = dilated * legal
        else:
            corridors[b] = dilate(access) * legal
    return corridors, access_dists

In [None]:
class LocalLegalityLoss(nn.Module):
    """Fraction of structure mass placed on illegal voxels.

    A voxel is legal when it is not an existing building AND either lies at
    z >= street_levels or is a street-level voxel covered by the anchor channel.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sl = config['street_levels']

    def compute_legality_field(self, state):
        """Return a (B, G, G, G) legality field in [0, 1] for a (B, C, G, G, G) state."""
        cfg = self.config
        G = cfg['grid_size']
        n = state.shape[0]
        free = 1 - state[:, cfg['ch_existing']]
        anchors = state[:, cfg['ch_anchors']]
        level = torch.arange(G, device=state.device).view(1, G, 1, 1).expand(n, G, G, G)
        street = (level < self.sl).float()
        elevated = (level >= self.sl).float()
        return torch.clamp(free * (elevated + street * anchors), 0, 1)

    def forward(self, state):
        struct = state[:, self.config['ch_structure']]
        illegal = 1 - self.compute_legality_field(state)
        return (struct * illegal).sum() / (struct.sum() + 1e-8)

class AccessSeededCoverageLoss(nn.Module):
    """Weighted fraction of the corridor left unfilled by structure.

    Each corridor voxel is weighted by 1 + 2*exp(-d/5), where d is its access
    distance. Corridor near access points therefore counts up to 3x as much.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def forward(self, struct, corridor, access_dist):
        near_access_boost = 2.0 * torch.exp(-access_dist / 5.0)
        weight = 1.0 + near_access_boost
        missing = corridor * (1 - struct) * weight
        denom = (corridor * weight).sum() + 1e-8
        return missing.sum() / denom

class ProgressiveSpillLoss(nn.Module):
    """Share of structure mass that is legal but lies outside the corridor.

    With `use_sqrt`, each spilled voxel is scaled by sqrt(max(d, 1) / 5), where
    d is its access distance. Spill within 5 voxels of access is discounted and
    spill farther away is amplified.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def forward(self, struct, corridor, legality, access_dist, use_sqrt=False):
        spilled = struct * (1 - corridor) * legality
        if use_sqrt:
            spilled = spilled * torch.sqrt(access_dist.clamp(min=1.0) / 5.0)
        total_struct = struct.sum() + 1e-8
        return spilled.sum() / total_struct

class ThicknessLoss(nn.Module):
    """Fraction of soft occupancy that survives `max_thickness` 3x3x3 erosions.

    Occupancy is sigmoid(10 * (struct - 0.3)). Erosion is a 3x3x3 min-filter
    over in-bounds neighbours. A high value means bulky, thick members.
    """

    def __init__(self, max_thickness=2):
        super().__init__()
        self.mt = max_thickness

    @staticmethod
    def _erode(field):
        # Min-filter as the negated max-pool of the negated field.
        return -F.max_pool3d(-field.unsqueeze(1), 3, 1, 1).squeeze(1)

    def forward(self, struct):
        occupancy = torch.sigmoid(10 * (struct - 0.3))
        survivors = occupancy
        for _ in range(self.mt):
            survivors = self._erode(survivors)
        return survivors.sum() / (occupancy.sum() + 1e-8)

class SparsityLoss(nn.Module):
    """Keep the structure-to-available-volume ratio within [min_r, max_r].

    Overshoot above `max_r` is penalised steeply (x500); undershoot below
    `min_r` is penalised more gently (x30).
    """

    def __init__(self, max_r=0.15, min_r=0.05):
        super().__init__()
        self.max_r = max_r
        self.min_r = min_r

    def forward(self, struct, available):
        ratio = struct.sum() / (available.sum() + 1e-8)
        too_dense = F.relu(ratio - self.max_r)
        too_sparse = F.relu(self.min_r - ratio)
        return too_dense * 500 + too_sparse * 30

class LoadPathLoss(nn.Module):
    """Fraction of elevated structure mass not softly connected to support.

    Support is existing buildings at every level, plus anchors at street levels
    (z < street_levels). Connectivity is a soft flood fill: each sweep dilates
    the reached field by one voxel (3x3x3) and multiplies it by soft occupancy
    sigmoid(10 * (struct - 0.3)). The fill runs for up to `iters` sweeps or until
    it stops changing.
    """

    def __init__(self, config, iters=32):
        super().__init__()
        self.config = config
        self.sl = config['street_levels']
        self.iters = iters

    def forward(self, state):
        cfg = self.config
        sl = self.sl
        struct = state[:, cfg['ch_structure']]
        existing = state[:, cfg['ch_existing']]
        anchors = state[:, cfg['ch_anchors']]
        street_support = torch.max(existing[:, :sl], anchors[:, :sl])
        reached = torch.cat([street_support, existing[:, sl:]], dim=1).clone()
        occupancy = torch.sigmoid(10 * (struct - 0.3))
        for _ in range(self.iters):
            spread = F.max_pool3d(reached.unsqueeze(1), 3, 1, 1).squeeze(1)
            grown = torch.max(reached, spread * occupancy)
            if torch.allclose(reached, grown, atol=1e-5):
                break
            reached = grown
        upper = struct[:, sl:]
        floating = upper * (1 - reached[:, sl:])
        return floating.sum() / (upper.sum() + 1e-8)

In [None]:
def compute_metrics(state, corridor, config):
    """Evaluate a grown state against its corridor; return a dict of 0-d tensors.

    Keys:
        coverage: corridor-weighted share of the corridor that holds structure.
        spill: share of structure mass that is legal but outside the corridor.
        thickness: 1 - ThicknessLoss (higher = thinner members).
        fill: structure mass over non-existing volume.
        legality: share of structure mass on legal voxels.
        loadpath: 1 - LoadPathLoss (higher = better supported).
    """
    eps = 1e-8
    struct = state[:, config['ch_structure']]
    free_volume = 1 - state[:, config['ch_existing']]
    legal = LocalLegalityLoss(config).compute_legality_field(state)
    struct_mass = struct.sum() + eps
    metrics = {}
    metrics['coverage'] = (struct * corridor).sum() / (corridor.sum() + eps)
    metrics['spill'] = (struct * (1 - corridor) * legal).sum() / struct_mass
    metrics['thickness'] = 1.0 - ThicknessLoss(config.get('max_thickness', 2))(struct)
    metrics['fill'] = struct.sum() / (free_volume.sum() + eps)
    metrics['legality'] = (struct * legal).sum() / struct_mass
    metrics['loadpath'] = 1.0 - LoadPathLoss(config)(state)
    return metrics

In [None]:
class TrainerV65:
    """Trainer for the v6.5 two-phase (GROWTH -> SCULPT) curriculum.

    Each epoch:
    - samples a fresh batch of scenes;
    - grows them for a random number of NCA steps;
    - takes one Adam step on a weighted sum of constraint losses.

    Only the spill weight is scheduled (see `get_spill_weight`); the others are
    fixed in `base_weights`.
    """

    def __init__(self, model, config, device):
        self.model, self.config, self.device = model, config, device
        self.legality = LocalLegalityLoss(config)
        self.coverage = AccessSeededCoverageLoss(config)
        self.spill = ProgressiveSpillLoss(config)
        self.thickness = ThicknessLoss(config.get('max_thickness', 2))
        self.sparsity = SparsityLoss()
        self.loadpath = LoadPathLoss(config)
        # 'spill' here is a placeholder; it is overwritten each epoch by the schedule.
        self.base_weights = {'legality': 30, 'coverage': 35, 'spill': 50, 'thickness': 15, 'sparsity': 20, 'loadpath': 8}
        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        self.history = []

    def get_spill_weight(self, epoch):
        """Piecewise-linear spill-loss weight for `epoch`.

        GROWTH (epoch < growth_phase_end): ramps spill_weight_min -> spill_weight_growth.
        SCULPT: ramps spill_weight_growth -> spill_weight_max, reaching it at
        sculpt_phase_end and holding it afterwards.
        """
        cfg = self.config
        g_end, s_end = cfg['growth_phase_end'], cfg['sculpt_phase_end']
        w_min, w_g, w_max = cfg['spill_weight_min'], cfg['spill_weight_growth'], cfg['spill_weight_max']
        if epoch < g_end:
            return w_min + (w_g - w_min) * (epoch / g_end)
        # Guard a zero/negative-length sculpt phase (previously a ZeroDivisionError
        # when sculpt_phase_end == growth_phase_end): ramp to w_max within one epoch.
        span = max(1, s_end - g_end)
        prog = min(1.0, (epoch - g_end) / span)
        return w_g + (w_max - w_g) * prog

    def train_epoch(self, epoch):
        """Run one optimisation step on a fresh scene batch; return a metrics dict."""
        self.model.train()
        cfg = self.config
        seeds = self.scene_gen.batch(cfg['difficulty'], cfg['batch_size'], self.device)
        # Corridor targets are geometry-only; no gradients needed.
        with torch.no_grad():
            corridor, access_dist = compute_corridor_and_access_distance(
                seeds, cfg, cfg.get('corridor_width', 4), cfg.get('vertical_envelope', 2))
        steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        final = self.model(seeds, steps=steps)
        struct = final[:, cfg['ch_structure']]
        available = 1.0 - final[:, cfg['ch_existing']]
        leg_field = self.legality.compute_legality_field(final)

        spill_w = self.get_spill_weight(epoch)
        # Distance-weighted spill only during the SCULPT phase.
        use_sqrt = epoch >= cfg['growth_phase_end']

        L = {
            'legality': self.legality(final),
            'coverage': self.coverage(struct, corridor, access_dist),
            'spill': self.spill(struct, corridor, leg_field, access_dist, use_sqrt),
            'thickness': self.thickness(struct),
            'sparsity': self.sparsity(struct, available),
            'loadpath': self.loadpath(final),
        }
        weights = dict(self.base_weights)
        weights['spill'] = spill_w
        total = sum(weights[k] * v for k, v in L.items())

        self.optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        with torch.no_grad():
            m = compute_metrics(final, corridor, cfg)
            phase = 'GROWTH' if epoch < cfg['growth_phase_end'] else 'SCULPT'
            metrics = {
                'epoch': epoch, 'phase': phase, 'total_loss': total.item(),
                'coverage': m['coverage'].item(), 'spill': m['spill'].item(),
                'spill_weight': spill_w, 'thickness': m['thickness'].item(),
                'fill_ratio': m['fill'].item(),
            }
        self.history.append(metrics)
        return metrics

    def save(self, path):
        """Write a checkpoint with model weights, optimizer state and metric history.

        Missing parent directories are created first, so the first periodic
        checkpoint of a long run does not crash on a fresh checkpoint folder.
        The checkpoint keeps the 'model' and 'history' keys; 'optimizer' is
        added so a run can be resumed.
        """
        import os  # local import: only needed for directory creation here
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'history': self.history,
        }, path)

In [None]:
def visualize(model, scene_gen, config, device, title=''):
    """Grow one random scene and plot a 3-D view, plan, elevation and metrics.

    Colours: gray = existing buildings, green = access, royal blue = legal
    structure inside the corridor, orange = legal structure outside it (spill).
    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        scene, _ = scene_gen.generate(config['difficulty'], device)
        with torch.no_grad():
            corridor, _ = compute_corridor_and_access_distance(
                scene, config, config.get('corridor_width', 4), config.get('vertical_envelope', 2))
            grown = model.grow(scene, steps=50)
    finally:
        model.train(was_training)

    cfg, G = config, config['grid_size']
    s = grown[0].cpu().numpy()
    # Voxel arrays below are indexed (z, y, x).
    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    corr = corridor[0].cpu().numpy() > 0.5
    leg = LocalLegalityLoss(config).compute_legality_field(grown)[0].cpu().numpy()
    in_corr = structure & corr & (leg >= 0.5)
    outside = structure & ~corr & (leg >= 0.5)

    fig = plt.figure(figsize=(16, 4))
    # 3-D view: transpose(1, 2, 0) maps (z, y, x) -> (y, x, z) so z is the vertical axis.
    ax1 = fig.add_subplot(141, projection='3d')
    if existing.any(): ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if access.any(): ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if outside.any(): ax1.voxels(outside.transpose(1,2,0), facecolors='orange', alpha=0.6)
    if in_corr.any(): ax1.voxels(in_corr.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_title(title)

    ax2, ax3 = fig.add_subplot(142), fig.add_subplot(143)
    plan, elev = np.zeros((G,G,3)), np.zeros((G,G,3))
    # Plan: z = 0 slice of buildings/corridor plus structure projected over z -> (y, x) image.
    plan[existing[0]] = [0.5,0.5,0.5]; plan[corr[0] & ~existing[0]] = [0.8,0.9,1.0]
    plan[in_corr.max(axis=0)] = [0.2,0.4,0.8]; plan[outside.max(axis=0)] = [1.0,0.6,0.2]
    # Elevation: projection over y -> (z, x) image.
    elev[existing.max(axis=1)] = [0.5,0.5,0.5]; elev[in_corr.max(axis=1)] = [0.2,0.4,0.8]
    elev[outside.max(axis=1)] = [1.0,0.6,0.2]
    # Plan transposed to (x, y) so its axes match the 3-D view's X = y, Y = x.
    ax2.imshow(plan.transpose(1,0,2), origin='lower'); ax2.set_title('Ground')
    ax2.set_xlabel('y'); ax2.set_ylabel('x')
    # Elevation shown untransposed so height (z) runs vertically; it was
    # previously transposed, which laid height along the horizontal axis.
    ax3.imshow(elev, origin='lower'); ax3.set_title('Elevation')
    ax3.set_xlabel('x'); ax3.set_ylabel('z')

    ax4 = fig.add_subplot(144); ax4.axis('off')
    m = compute_metrics(grown, corridor, config)
    txt = f"""METRICS (v6.5)
Coverage: {m['coverage'].item()*100:.1f}%
Spill: {m['spill'].item()*100:.1f}%
Thickness: {m['thickness'].item()*100:.1f}%
Fill: {m['fill'].item()*100:.1f}%"""
    ax4.text(0.1, 0.9, txt, transform=ax4.transAxes, fontsize=10, va='top', family='monospace')
    plt.tight_layout(); plt.show()

In [None]:
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = TrainerV65(model, CONFIG, device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')
# Schedule summary derived from CONFIG (previously hard-coded "0->10" / "10->50")
# so it cannot drift from the curriculum actually used by the trainer.
print(f"\nPhase 1 (GROWTH): 0-{CONFIG['growth_phase_end']}, "
      f"spill: {CONFIG['spill_weight_min']:g}->{CONFIG['spill_weight_growth']:g}")
print(f"Phase 2 (SCULPT): {CONFIG['growth_phase_end']}-{CONFIG['sculpt_phase_end']}, "
      f"spill: {CONFIG['spill_weight_growth']:g}->{CONFIG['spill_weight_max']:g} (sqrt)")

In [None]:
# Baseline: grow one random scene with the freshly initialised (untrained) model.
visualize(model, trainer.scene_gen, CONFIG, device, 'Before Training')

In [None]:
# Header uses CONFIG['epochs'] (previously hard-coded as 1500).
print('='*70)
print(f"v6.5 TRAINING - EXTENDED 2-PHASE CURRICULUM ({CONFIG['epochs']} epochs)")
print('='*70)

last_epoch = CONFIG['epochs'] - 1
for epoch in tqdm(range(CONFIG['epochs']), desc='Training'):
    m = trainer.train_epoch(epoch)
    # Log on the regular cadence, and always on the final epoch so the end state is reported.
    if epoch % CONFIG['log_every'] == 0 or epoch == last_epoch:
        tqdm.write(
            f"E{epoch:4d} [{m['phase']:6s}] | Loss:{m['total_loss']:6.1f} | "
            f"Cov:{m['coverage']*100:4.0f}% | Spl:{m['spill']*100:4.0f}% (w={m['spill_weight']:4.1f}) | "
            f"Thk:{m['thickness']*100:4.0f}% | Fill:{m['fill_ratio']*100:5.1f}%")
    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch} ({m["phase"]})')
    if epoch > 0 and epoch % CONFIG['save_every'] == 0:
        trainer.save(f"{PROJECT_ROOT}/step_b/checkpoints/v65_epoch_{epoch}.pth")

print('Training complete')

In [None]:
# Plot training curves. The red dashed vertical line on every panel marks the
# GROWTH -> SCULPT transition; dashed horizontal lines are target thresholds.
epochs = [h['epoch'] for h in trainer.history]
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
g_end = CONFIG['growth_phase_end']
# Normalise the spill-weight overlay by the configured maximum (was a hard-coded 50).
spill_w_scale = CONFIG['spill_weight_max'] or 1.0

axes[0,0].plot(epochs, [h['total_loss'] for h in trainer.history])
axes[0,0].set_title('Total Loss'); axes[0,0].set_yscale('log')

axes[0,1].plot(epochs, [h['coverage'] for h in trainer.history], 'b-')
axes[0,1].axhline(0.70, color='k', ls='--')
axes[0,1].set_title('Coverage (>70%)'); axes[0,1].set_ylim(0, 1.1)

axes[0,2].plot(epochs, [h['spill'] for h in trainer.history], 'orange')
axes[0,2].plot(epochs, [h['spill_weight'] / spill_w_scale for h in trainer.history], 'r--', alpha=0.5)
axes[0,2].axhline(0.20, color='k', ls='--')
axes[0,2].set_title('Spill (<20%)'); axes[0,2].set_ylim(0, 1.1)

axes[1,0].plot(epochs, [h['thickness'] for h in trainer.history], 'purple')
axes[1,0].axhline(0.90, color='k', ls='--')
axes[1,0].set_title('Thickness (>90%)'); axes[1,0].set_ylim(0, 1.1)

axes[1,1].plot(epochs, [h['fill_ratio'] for h in trainer.history], 'g-')
axes[1,1].axhline(0.15, color='r', ls='--'); axes[1,1].axhline(0.05, color='g', ls='--')
axes[1,1].set_title('Fill (5-15%)'); axes[1,1].set_ylim(0, 0.3)

axes[1,2].plot(epochs, [h['spill_weight'] for h in trainer.history], 'r-')
axes[1,2].set_title('Spill Weight')

for ax in axes.flat:
    ax.axvline(g_end, color='r', ls='--', alpha=0.5)

plt.tight_layout(); plt.show()

In [None]:
# Grow and display three independently sampled scenes with the trained model.
for i in range(3):
    visualize(model, trainer.scene_gen, CONFIG, device, f'Final {i+1}')

In [None]:
# Persist the final checkpoint (model weights plus the full per-epoch metric history).
trainer.save(f"{PROJECT_ROOT}/step_b/checkpoints/v65_final.pth")
print('Done!')