# NB02: Minimal-Volume Connectivity v5.0

Goal: connect the access points using the smallest possible voxel volume.
- Context buildings and access points are the same as in the previous notebook.
- Bulk is allowed near access points.
- Elsewhere, volume should be minimal.
- No other constraints.

## 1. Setup

In [None]:

import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from typing import Tuple, List

# The current working directory is treated as the repository root; the
# config lookup resolves config_step_b.json relative to it.
PROJECT_ROOT = os.getcwd()
print('Project root:', PROJECT_ROOT)

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Device:', device)


In [None]:
from pathlib import Path
# Resolve config_step_b.json: try the known repo-relative locations first,
# then walk upward from PROJECT_ROOT.
candidates = [
    Path(PROJECT_ROOT) / 'config_step_b.json',
    Path(PROJECT_ROOT) / 'Constraint-Based-Architectural-NCA' / 'config_step_b.json',
]
config_path = next((c for c in candidates if c.exists()), None)

if config_path is None:
    # Fallback: PROJECT_ROOT itself plus up to three of its parent
    # directories (four directories in total).
    root = Path(PROJECT_ROOT).resolve()
    for folder in [root, *root.parents][:4]:
        probe = folder / 'config_step_b.json'
        if probe.exists():
            config_path = probe
            break

if config_path is None:
    raise FileNotFoundError('config_step_b.json not found; set PROJECT_ROOT to repo root')

# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    CONFIG = json.load(f)


## 2. Foundation Components

In [None]:

class Perceive3D(nn.Module):
    """Fixed (non-learned) depthwise perception filters for a 3D NCA.

    For an input of shape (B, C, D, H, W) the output concatenates, along dim 1:
      * the raw state,
      * a 6-neighbour discrete Laplacian (3x3x3 stencil, -6 at the centre),
      * a central difference along the last (W) axis,
      * a central difference along the H axis,
    giving ``4 * C`` channels with unchanged spatial size (zero padding).
    No gradient along the D axis is computed.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.n_channels = n_channels
        # 3x3x3 Laplacian stencil: -6 at the centre, +1 at the six face neighbours.
        self.kernel = torch.tensor([[[0,0,0],[0,1,0],[0,0,0]],
                                    [[0,1,0],[1,-6,1],[0,1,0]],
                                    [[0,0,0],[0,1,0],[0,0,0]]], dtype=torch.float32)

    def forward(self, x):
        """Apply the Laplacian and x/y central differences to every channel of ``x``."""
        _, C, _, _, _ = x.shape  # enforces a 5D (batched) input
        # Every stencil must match the input's device *and* dtype: conv3d
        # rejects weights of a different dtype, and integer literals would
        # otherwise produce int64 weights.
        lap_k = self.kernel.to(device=x.device, dtype=x.dtype).view(1, 1, 3, 3, 3)
        diff = torch.tensor([-1.0, 0.0, 1.0], device=x.device, dtype=x.dtype)
        gx_k = diff.view(1, 1, 1, 1, 3)  # along W
        gy_k = diff.view(1, 1, 1, 3, 1)  # along H

        # groups=C applies each stencil independently per channel.
        lap = F.conv3d(x, lap_k.repeat(C, 1, 1, 1, 1), padding=1, groups=C)
        gradx = F.conv3d(x, gx_k.repeat(C, 1, 1, 1, 1), padding=(0, 0, 1), groups=C)
        grady = F.conv3d(x, gy_k.repeat(C, 1, 1, 1, 1), padding=(0, 1, 0), groups=C)
        return torch.cat([x, lap, gradx, grady], dim=1)


In [None]:

class UrbanPavilionNCA(nn.Module):
    """3D neural cellular automaton that grows voxel channels on top of a fixed scene.

    State layout (dim 1): the first ``config['n_frozen']`` channels are frozen
    context and are copied through unchanged every step; the remaining
    channels are "grown" and receive ``update_scale * delta`` per step,
    clamped to [0, 1]. Within the grown block, channel 0 is treated as
    structure (output bias ``structure_bias``; zeroed wherever the
    ``ch_existing`` channel is set) and channel 1 as surface (bias
    ``surface_bias``).

    NOTE(review): ``_step`` adds an ``n_grown``-channel delta to
    ``state[:, n_frozen:]``, so ``n_channels - n_frozen`` must equal
    ``n_grown``. Code that reads ``config['ch_structure']`` presumably expects
    it to equal ``n_frozen`` — confirm against the config.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

        n_channels = config['n_channels']
        hidden_dim = config['hidden_dim']
        # Matches the four feature groups Perceive3D concatenates
        # (state, Laplacian, x-gradient, y-gradient).
        perception_dim = n_channels * 4
        n_grown = config['n_grown']

        self.perceive = Perceive3D(n_channels)

        # Per-voxel MLP implemented with 1x1x1 convolutions.
        self.update_net = nn.Sequential(
            nn.Conv3d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv3d(hidden_dim, n_grown, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform conv weights, zero biases, then seed the output biases."""
        gain = self.config['xavier_gain']
        for m in self.update_net:
            if isinstance(m, nn.Conv3d):
                nn.init.xavier_uniform_(m.weight, gain=gain)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Output bias sets the initial growth tendency of the structure
        # (grown channel 0) and surface (grown channel 1) channels.
        last_layer = self.update_net[-1]
        with torch.no_grad():
            last_layer.bias[0] = self.config['structure_bias']
            last_layer.bias[1] = self.config['surface_bias']

    def forward(self, state: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """Apply ``steps`` NCA updates to ``state`` (B, C, D, H, W)."""
        for _ in range(steps):
            state = self._step(state)
        return state

    def _step(self, state: torch.Tensor) -> torch.Tensor:
        """One NCA update: perceive, predict a delta, apply it to the grown channels."""
        B, C, D, H, W = state.shape
        cfg = self.config

        perception = self.perceive(state)
        delta = self.update_net(perception)

        # Stochastic (asynchronous) updates during training only: each voxel
        # fires with probability fire_rate.
        if self.training:
            fire_mask = (torch.rand(B, 1, D, H, W, device=state.device) < cfg['fire_rate']).float()
            delta = delta * fire_mask

        grown_start = cfg['n_frozen']
        grown_new = state[:, grown_start:] + cfg['update_scale'] * delta
        grown_new = torch.clamp(grown_new, 0.0, 1.0)

        # Structure may not occupy voxels already filled by existing buildings.
        existing = state[:, cfg['ch_existing']:cfg['ch_existing']+1]
        struct_new = grown_new[:, 0:1] * (1.0 - existing)

        grown_masked = torch.cat([struct_new, grown_new[:, 1:]], dim=1)
        new_state = torch.cat([state[:, :grown_start], grown_masked], dim=1)
        return new_state

    def grow(self, seed: torch.Tensor, steps: int = 50) -> torch.Tensor:
        """Run ``steps`` updates from ``seed`` in eval mode without autograd.

        In eval mode every voxel updates each step (no fire mask). The module's
        previous train/eval mode is restored afterwards, so calling this in the
        middle of training does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(seed, steps)
        finally:
            self.train(was_training)


In [None]:

class UrbanSceneGenerator:
    """Procedurally generate two-building urban scenes with access points.

    Each scene is a (1, C, G, G, G) tensor indexed [batch, channel, z, y, x]:
    a ground plane at z == 0 in ``ch_ground``, two box buildings flanking a gap
    centred at x == G // 2 in ``ch_existing``, and 2x2x2 access blocks in
    ``ch_access``. All randomness comes from the ``random`` module, so seeding
    it makes scenes reproducible.
    """

    def __init__(self, config: dict):
        self.config = config
        self.G = config['grid_size']   # side length of the cubic grid
        self.C = config['n_channels']  # channels in the generated state

    def generate(self, difficulty: str = 'easy', device: str = 'cuda') -> Tuple[torch.Tensor, dict]:
        """Create one scene and return ``(state, metadata)``.

        ``metadata`` records the difficulty, the building boxes, the access
        points and the sampled gap width.
        """
        side = self.G
        scene = torch.zeros(1, self.C, side, side, side, device=device)
        # The z == 0 slab of the ground channel is the ground plane.
        scene[:, self.config['ch_ground'], 0] = 1.0

        params = self._get_difficulty_params(difficulty)
        buildings = self._place_buildings(scene, params)
        access = self._place_access_points(scene, params, buildings)

        return scene, dict(
            difficulty=difficulty,
            buildings=buildings,
            access_points=access,
            gap_width=params['gap_width'],
        )

    def _get_difficulty_params(self, difficulty: str) -> dict:
        """Sample layout parameters for one scene.

        ``'easy'`` and ``'medium'`` are matched explicitly; every other value
        gets the hardest preset. Heights come from
        ``config['height_range_override']`` (default (12, 16)).
        NOTE(review): ``n_buildings`` is drawn but ``_place_buildings`` always
        places exactly two buildings.
        """
        lo, hi = self.config.get('height_range_override', (12, 16))

        if difficulty == 'easy':
            return dict(
                n_buildings=2,
                height_range=(lo, hi),
                height_variance=False,
                width_range=(8, 12),
                gap_width=random.randint(14, 18),
                n_ground_access=1,
                n_elevated_access=1,
            )
        if difficulty == 'medium':
            return dict(
                n_buildings=2,
                height_range=(lo, hi),
                height_variance=True,
                width_range=(6, 10),
                gap_width=random.randint(10, 14),
                n_ground_access=random.randint(1, 2),
                n_elevated_access=random.randint(1, 2),
            )
        # Any other label: the hardest preset.
        return dict(
            n_buildings=random.randint(2, 4),
            height_range=(lo, hi),
            height_variance=True,
            width_range=(5, 8),
            gap_width=random.randint(6, 10),
            n_ground_access=random.randint(2, 3),
            n_elevated_access=random.randint(2, 3),
        )

    def _place_buildings(self, state: torch.Tensor, params: dict) -> list:
        """Stamp a left and a right box building into ``ch_existing`` and describe them.

        Both buildings start at y == 0 and z == 0. Their gap-facing faces sit
        ``gap_width // 2`` voxels either side of x == G // 2.
        """
        grid = self.G
        channel = self.config['ch_existing']
        half_gap = params['gap_width'] // 2
        centre = grid // 2

        # Left building: its exclusive x end borders the gap.
        width_l = random.randint(*params['width_range'])
        depth_l = random.randint(grid // 2, grid - 2)
        height_l = random.randint(*params['height_range'])
        right_edge_l = centre - half_gap
        left_edge_l = max(0, right_edge_l - width_l)
        state[:, channel, :height_l, :depth_l, left_edge_l:right_edge_l] = 1.0

        # Right building: its first x column borders the gap; it shares the
        # left building's height unless height_variance is enabled.
        width_r = random.randint(*params['width_range'])
        depth_r = random.randint(grid // 2, grid - 2)
        if params['height_variance']:
            height_r = random.randint(*params['height_range'])
        else:
            height_r = height_l
        left_edge_r = centre + half_gap
        right_edge_r = min(grid, left_edge_r + width_r)
        state[:, channel, :height_r, :depth_r, left_edge_r:right_edge_r] = 1.0

        return [
            {'x': (left_edge_l, right_edge_l), 'y': (0, depth_l), 'z': (0, height_l),
             'gap_facing_x': right_edge_l, 'side': 'left'},
            {'x': (left_edge_r, right_edge_r), 'y': (0, depth_r), 'z': (0, height_r),
             'gap_facing_x': left_edge_r, 'side': 'right'},
        ]

    def _place_access_points(self, state: torch.Tensor, params: dict, buildings: list) -> list:
        """Stamp 2x2x2 access blocks into ``ch_access`` and return their corners.

        Ground blocks sit at z in [0, 2), 1-3 voxels out from a random
        building's gap-facing facade and clamped into the gap. Elevated blocks
        sit directly against the facade at a random height.
        """
        grid = self.G
        channel = self.config['ch_access']

        left_faces = [b['gap_facing_x'] for b in buildings if b['side'] == 'left']
        right_faces = [b['gap_facing_x'] for b in buildings if b['side'] == 'right']
        gap_lo = max(left_faces) if left_faces else 0
        gap_hi = min(right_faces) if right_faces else grid

        placed = []

        for _ in range(params.get('n_ground_access', 1)):
            bldg = random.choice(buildings)
            y_lo, y_hi = bldg['y']
            offset = random.randint(1, 3)
            if bldg['side'] == 'left':
                px = bldg['gap_facing_x'] + offset
            else:
                # First empty column left of the right building, then step out.
                px = bldg['gap_facing_x'] - 1 - offset
            px = max(gap_lo + 1, min(gap_hi - 3, px))

            py = random.randint(y_lo, min(y_hi - 2, y_lo + grid // 2))
            py = max(0, min(grid - 3, py))

            state[:, channel, 0:2, py:py + 2, px:px + 2] = 1.0
            placed.append({'x': px, 'y': py, 'z': 0, 'type': 'ground'})

        for _ in range(params.get('n_elevated_access', 1)):
            bldg = random.choice(buildings)
            y_lo, y_hi = bldg['y']
            pz = random.randint(3, max(4, bldg['z'][1] - 2))
            py = random.randint(y_lo, min(y_hi - 2, y_lo + grid // 3))
            px = bldg['x'][1] if bldg['side'] == 'left' else bldg['x'][0] - 2
            px = max(0, min(grid - 2, px))

            state[:, channel, pz:pz + 2, py:py + 2, px:px + 2] = 1.0
            placed.append({'x': px, 'y': py, 'z': pz, 'type': 'elevated'})

        return placed

    def batch(self, difficulty: str, batch_size: int, device: str) -> torch.Tensor:
        """Generate ``batch_size`` scenes and stack them along dim 0 (metadata is dropped)."""
        states = []
        for _ in range(batch_size):
            scene, _meta = self.generate(difficulty, device)
            states.append(scene)
        return torch.cat(states, dim=0)


## 3. Corridor Target

In [None]:

def _extract_access_centroids(access_mask):
    D, H, W = access_mask.shape
    visited = set()
    centroids = []

    def neighbors(z, y, x):
        for dz, dy, dx in [(1,0,0),(-1,0,0),(0,1,0),(0,-1,0),(0,0,1),(0,0,-1)]:
            nz, ny, nx = z+dz, y+dy, x+dx
            if 0 <= nz < D and 0 <= ny < H and 0 <= nx < W:
                yield nz, ny, nx

    for z in range(D):
        for y in range(H):
            for x in range(W):
                if access_mask[z, y, x] < 0.5:
                    continue
                if (z, y, x) in visited:
                    continue
                stack = [(z, y, x)]
                vox = []
                visited.add((z, y, x))
                while stack:
                    cz, cy, cx = stack.pop()
                    vox.append((cz, cy, cx))
                    for nz, ny, nx in neighbors(cz, cy, cx):
                        if access_mask[nz, ny, nx] < 0.5:
                            continue
                        if (nz, ny, nx) in visited:
                            continue
                        visited.add((nz, ny, nx))
                        stack.append((nz, ny, nx))
                if vox:
                    zz = sum(v[0] for v in vox) / len(vox)
                    yy = sum(v[1] for v in vox) / len(vox)
                    xx = sum(v[2] for v in vox) / len(vox)
                    centroids.append((int(round(zz)), int(round(yy)), int(round(xx))))
    return centroids


def compute_distance_field_3d(start_points, legal_mask, max_iters=64):
    """Grid distance from ``start_points`` through ``legal_mask`` with 26-connectivity.

    Each relaxation takes the minimum over the 3x3x3 neighbourhood (a max-pool
    of the negated field) plus one step, so a move to any of the 26 neighbours
    costs 1. Only voxels with ``legal_mask > 0.5`` are updated. In-range start
    points are seeded with 0 (even if illegal); out-of-range ones are ignored.
    Relaxation stops once the field stops changing or after ``max_iters``
    rounds; unreached voxels remain ``inf``.
    """
    depth, height, width = legal_mask.shape
    field = torch.full((depth, height, width), float('inf'), device=legal_mask.device)

    for z, y, x in start_points:
        inside = 0 <= z < depth and 0 <= y < height and 0 <= x < width
        if inside:
            field[z, y, x] = 0

    allowed = legal_mask > 0.5
    for _ in range(max_iters):
        pooled = F.max_pool3d(-field[None, None], 3, 1, 1)[0, 0]
        neighbour_min = -pooled
        relaxed = torch.where(allowed, torch.minimum(field, neighbour_min + 1), field)
        if torch.allclose(field, relaxed, atol=1e-5):
            break
        field = relaxed

    return field


def compute_corridor_target_v5(seed_state: torch.Tensor, config: dict,
                               corridor_width: int = 1, vertical_envelope: int = 1) -> torch.Tensor:
    """Build a corridor target linking every pair of access points in each scene.

    Args:
        seed_state: (B, C, G, G, G) scene tensor; reads channels
            ``config['ch_access']`` and ``config['ch_existing']``.
        config: needs ``grid_size``, ``ch_access`` and ``ch_existing``.
        corridor_width: distance slack for "near-shortest" voxels, and the
            half-width of the cubic dilation applied afterwards.
        vertical_envelope: half-height of an extra dilation along dim 2 (the
            scene's z / height axis); 0 disables it.

    Returns:
        (B, G, G, G) tensor on ``seed_state.device``, multiplied by
        ``1 - existing`` so voxels occupied by existing buildings are excluded.

    With at least two access-point components, a voxel joins the corridor of
    a centroid pair (s, e) when d(s, v) + d(v, e) <= d(s, e) + corridor_width,
    where d is the 26-connected distance through voxels with
    ``1 - existing > 0.5`` (``compute_distance_field_3d``). Unreachable pairs
    are skipped. With fewer than two components, the access mask itself is
    dilated instead.
    """
    G = config['grid_size']
    device = seed_state.device
    B = seed_state.shape[0]
    pool_k = 2 * corridor_width + 1

    corridors = torch.zeros(B, G, G, G, device=device)

    for b in range(B):
        access = seed_state[b, config['ch_access']]
        existing = seed_state[b, config['ch_existing']]
        legal_mask = 1.0 - existing

        centroids = _extract_access_centroids(access.detach().cpu())
        if len(centroids) < 2:
            # Nothing to connect: fall back to a dilated copy of the access mask.
            dilated = F.max_pool3d(access[None, None], pool_k, 1, corridor_width)
            corridors[b] = dilated[0, 0] * legal_mask
            continue

        # One distance field per centroid, reused for every pair it belongs to
        # (k fields instead of k*(k-1) recomputations).
        fields = [compute_distance_field_3d([c], legal_mask) for c in centroids]

        corridor_mask = torch.zeros(G, G, G, device=device)
        n = len(centroids)
        for i in range(n):
            for j in range(i + 1, n):
                end = centroids[j]
                total_dist = fields[i][end]
                if total_dist == float('inf'):
                    continue  # pair not connected through legal space
                path_cost = fields[i] + fields[j]
                on_path = (path_cost <= total_dist + corridor_width).float()
                corridor_mask = torch.maximum(corridor_mask, on_path)

        dilated = F.max_pool3d(corridor_mask[None, None], pool_k, 1, corridor_width)

        if vertical_envelope > 0:
            # Extra dilation along the height axis only.
            dilated = F.max_pool3d(
                dilated,
                kernel_size=(2 * vertical_envelope + 1, 1, 1),
                stride=1,
                padding=(vertical_envelope, 0, 0),
            )

        corridors[b] = dilated[0, 0] * legal_mask

    return corridors


## 4. Losses (Minimal)

In [None]:

class CorridorCoverageLoss(nn.Module):
    """Unfilled fraction of the corridor target: ``1 - sum(s * t) / sum(t)``.

    Sums run over the whole batch. A 1e-8 epsilon guards against an empty
    target, in which case the loss is 1.
    """

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor) -> torch.Tensor:
        overlap = torch.sum(structure * corridor_target)
        target_volume = torch.sum(corridor_target) + 1e-8
        return 1.0 - overlap / target_volume


class CorridorSpillLoss(nn.Module):
    """Share of total structure mass lying outside the corridor target.

    Computed as ``sum(s * (1 - t)) / sum(s)`` over the whole batch, with a
    1e-8 epsilon so an empty structure yields 0.
    """

    def forward(self, structure: torch.Tensor, corridor_target: torch.Tensor) -> torch.Tensor:
        spilled = (structure * (1 - corridor_target)).sum()
        total = structure.sum() + 1e-8
        return spilled / total


class AccessProximitySparsity(nn.Module):
    """Penalise structure placed far from every access point.

    For each scene, the "far" region is the set of voxels whose Manhattan (L1)
    distance to the nearest access-point centroid exceeds
    ``config['access_near_radius']`` (default 3) and that are not occupied by
    existing buildings (``1 - existing > 0.5``). The per-scene loss is
    ``config['access_far_weight']`` (default 2.0) times the mean structure
    value over that region. Results are averaged over the batch; scenes with
    no access points or no far voxels contribute 0 but still count toward B.
    """

    def __init__(self, config: dict):
        super().__init__()
        self.config = config

    def forward(self, structure: torch.Tensor, access: torch.Tensor, existing: torch.Tensor) -> torch.Tensor:
        """Compute the batch-averaged far-from-access penalty.

        Args:
            structure: (B, D, H, W) structure values.
            access: (B, D, H, W) access-point mask.
            existing: (B, D, H, W) existing-building mask.

        Returns:
            0-dim tensor with ``structure``'s dtype/device. It is always a
            tensor, even when no scene contributes, so callers can use
            ``.item()`` uniformly.
        """
        B, D, H, W = structure.shape
        device = structure.device
        near_radius = self.config.get('access_near_radius', 3)
        far_weight = self.config.get('access_far_weight', 2.0)

        # Coordinate grids are the same for every scene; build them once.
        z_idx = torch.arange(D, device=device).view(D, 1, 1)
        y_idx = torch.arange(H, device=device).view(1, H, 1)
        x_idx = torch.arange(W, device=device).view(1, 1, W)

        loss = structure.new_zeros(())
        for b in range(B):
            centroids = _extract_access_centroids(access[b].detach().cpu())
            if not centroids:
                continue

            # L1 distance to the nearest centroid for every voxel.
            dist = torch.stack([
                (z_idx - cz).abs() + (y_idx - cy).abs() + (x_idx - cx).abs()
                for cz, cy, cx in centroids
            ]).amin(dim=0)

            far_legal = (dist > near_radius) & ((1.0 - existing[b]) > 0.5)
            if far_legal.any():
                loss = loss + structure[b][far_legal].mean() * far_weight
        return loss / max(1, B)


## 5. Trainer

In [None]:

class MinimalConnectivityTrainer:
    """One-batch-per-epoch trainer for the minimal-volume connectivity objective.

    Total loss = 50 * corridor coverage + 20 * corridor spill
                 + 40 * far-from-access sparsity.
    Scenes are regenerated every epoch from ``config['difficulty']``
    (default 'easy').
    """

    def __init__(self, model: nn.Module, config: dict, device: str):
        self.model, self.config, self.device = model, config, device

        # Loss terms and their fixed weights.
        self.coverage_loss = CorridorCoverageLoss()
        self.spill_loss = CorridorSpillLoss()
        self.sparsity_loss = AccessProximitySparsity(config)
        self.weights = dict(coverage=50.0, spill=20.0, sparsity=40.0)

        self.optimizer = torch.optim.Adam(model.parameters(), lr=config['lr_initial'])
        self.scene_gen = UrbanSceneGenerator(config)
        self.history = []

    def train_epoch(self, epoch: int) -> dict:
        """Run one optimisation step on a fresh batch and return (and record) its metrics."""
        cfg = self.config
        self.model.train()

        batch = self.scene_gen.batch(cfg.get('difficulty', 'easy'), cfg['batch_size'], self.device)

        with torch.no_grad():
            target = compute_corridor_target_v5(
                batch, cfg,
                corridor_width=cfg.get('corridor_width', 1),
                vertical_envelope=cfg.get('vertical_envelope', 1),
            )

        # Optionally pre-seed the structure channel with a scaled copy of the corridor.
        seed_scale = cfg.get('corridor_seed_scale', 0.0)
        if seed_scale > 0:
            s_ch = cfg['ch_structure']
            batch[:, s_ch] = (batch[:, s_ch] + seed_scale * target).clamp(0.0, 1.0)

        # Random rollout length per epoch.
        n_steps = random.randint(cfg['steps_min'], cfg['steps_max'])
        out = self.model(batch, steps=n_steps)

        structure = out[:, cfg['ch_structure']]
        access = out[:, cfg['ch_access']]
        existing = out[:, cfg['ch_existing']]

        terms = {
            'coverage': self.coverage_loss(structure, target),
            'spill': self.spill_loss(structure, target),
            'sparsity': self.sparsity_loss(structure, access, existing),
        }
        w = self.weights
        total_loss = (w['coverage'] * terms['coverage']
                      + w['spill'] * terms['spill']
                      + w['sparsity'] * terms['sparsity'])

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), cfg['grad_clip'])
        self.optimizer.step()

        with torch.no_grad():
            filled = structure.sum().item()
            legal_volume = (1.0 - existing).sum().item()
            metrics = {
                'epoch': epoch,
                'total_loss': total_loss.item(),
                'coverage': 1.0 - terms['coverage'].item(),
                'spill': terms['spill'].item(),
                'fill_ratio': filled / (legal_volume + 1e-8),
                'steps': n_steps,
            }

        self.history.append(metrics)
        return metrics


## 6. Visualization

In [None]:

def visualize_v5(model, scene_gen, config, device, title='Result'):
    """Grow one freshly generated scene with ``model`` (50 steps) and plot it.

    Panels:
      1. 3D voxels: existing (gray), corridor target (light blue), access
         (green), structure (blue).
      2. Plan: the z == 0 slice of existing/corridor combined with the
         top-down projection of structure/access.
      3. Elevation: projection along y, with height (z) on the vertical axis.

    Returns:
        ``(grown_state, corridor_target, scene_metadata)``.
    """
    model.eval()
    scene, meta = scene_gen.generate(config.get('difficulty', 'easy'), device)

    with torch.no_grad():
        corridor_target = compute_corridor_target_v5(
            scene, config,
            corridor_width=config.get('corridor_width', 1),
            vertical_envelope=config.get('vertical_envelope', 1)
        )
        # Mirror the trainer: optionally pre-seed structure with the corridor.
        seed_scale = config.get('corridor_seed_scale', 0.0)
        if seed_scale > 0:
            struct_idx = config['ch_structure']
            scene[:, struct_idx] = torch.clamp(scene[:, struct_idx] + seed_scale * corridor_target, 0.0, 1.0)
        grown = model.grow(scene, steps=50)

    cfg = config
    s = grown[0].cpu().numpy()  # (C, z, y, x)
    G = s.shape[1]

    existing = s[cfg['ch_existing']] > 0.5
    access = s[cfg['ch_access']] > 0.5
    structure = s[cfg['ch_structure']] > 0.5
    corridor = corridor_target[0].cpu().numpy() > 0.5

    fig = plt.figure(figsize=(18, 5))

    # 3D view: (z, y, x) -> (y, x, z) so height maps to matplotlib's vertical axis.
    ax1 = fig.add_subplot(131, projection='3d')
    if existing.any():
        ax1.voxels(existing.transpose(1,2,0), facecolors='gray', alpha=0.3)
    if corridor.any():
        ax1.voxels(corridor.transpose(1,2,0), facecolors='lightblue', alpha=0.2)
    if access.any():
        ax1.voxels(access.transpose(1,2,0), facecolors='green', alpha=0.9)
    if structure.any():
        ax1.voxels(structure.transpose(1,2,0), facecolors='royalblue', alpha=0.6)
    ax1.set_title(f'{title} (3D)')

    # Plan image is indexed [y, x]; after the transpose, rows are x and columns are y.
    ax2 = fig.add_subplot(132)
    plan = np.zeros((G,G,3))
    plan[existing[0,:,:]] = [0.5,0.5,0.5]
    plan[corridor[0,:,:] & ~existing[0,:,:]] = [0.8,0.9,1.0]
    plan[structure.max(axis=0)] = [0.2,0.4,0.8]
    plan[access.max(axis=0)] = [0.2,0.8,0.2]
    ax2.imshow(plan.transpose(1,0,2), origin='lower')
    ax2.set_xlabel('y')
    ax2.set_ylabel('x')
    ax2.set_title('Ground Level')

    # Elevation image is indexed [z, x]. It is shown untransposed so z runs up
    # the vertical axis (origin='lower' puts the ground at the bottom);
    # transposing it would lay the height out horizontally.
    ax3 = fig.add_subplot(133)
    elev = np.zeros((G,G,3))
    elev[existing.max(axis=1)] = [0.5,0.5,0.5]
    elev[corridor.max(axis=1) & ~existing.max(axis=1)] = [0.8,0.9,1.0]
    elev[structure.max(axis=1)] = [0.2,0.4,0.8]
    elev[access.max(axis=1)] = [0.2,0.8,0.2]
    ax3.imshow(elev, origin='lower')
    ax3.set_xlabel('x')
    ax3.set_ylabel('z')
    ax3.set_title('Elevation')

    plt.tight_layout()
    plt.show()

    return grown, corridor_target, meta


## 7. Training

In [None]:

# Build the NCA and its trainer from the loaded config.
model = UrbanPavilionNCA(CONFIG).to(device)
trainer = MinimalConnectivityTrainer(model, CONFIG, device)

n_params = sum(param.numel() for param in model.parameters())
print('Model parameters:', n_params)
print('Weights:', trainer.weights)


In [None]:

# Baseline: what the untrained model grows from a freshly generated scene.
print('Before training:')
visualize_v5(model, trainer.scene_gen, CONFIG, device, title='Before Training')


In [None]:

# Main training loop: one freshly generated batch per epoch, with periodic
# console logging and visual checkpoints.
# '\n' must be an escape sequence: a raw line break inside a quoted string is
# an unterminated literal (SyntaxError).
print('\n' + '='*70)
print('STEP B TRAINING v5.0: MINIMAL VOLUME CONNECTIVITY')
print('='*70)

for epoch in range(CONFIG['epochs']):
    metrics = trainer.train_epoch(epoch)

    if epoch % CONFIG['log_every'] == 0:
        print(
            f"Epoch {epoch:4d} | Loss: {metrics['total_loss']:.2f} | "
            f"Cov: {metrics['coverage']*100:.0f}% | "
            f"Spill: {metrics['spill']*100:.0f}% | "
            f"Fill: {metrics['fill_ratio']*100:.1f}%"
        )

    # Visual checkpoint every viz_every epochs (skipped at epoch 0).
    if epoch > 0 and epoch % CONFIG['viz_every'] == 0:
        visualize_v5(model, trainer.scene_gen, CONFIG, device, f'Epoch {epoch}')


## 8. Evaluation

In [None]:

# Show a few fresh scenes grown by the trained model.
# '\n' is an escape sequence; a raw line break inside the quotes would be an
# unterminated string literal (SyntaxError).
print('\nFinal Results:')
for i in range(3):
    visualize_v5(model, trainer.scene_gen, CONFIG, device, f'Final {i+1}')
