# Experiment 032 v4.2: Full Ablation with Fixed Wormhole

**Purpose**: Ablation study of SBM variants with proper wormhole attention.

**v4.2 Fixes over v4.1**:
1. **Pooled Wormhole**: Wormhole operates on global pooled features [B, D], not per-pixel [B, H, W, D]
   - This cuts the similarity computation from O(H*W * T*H*W) per-pixel scores to O(T) pooled-vector scores per sample — a massive memory saving
   - Conceptually correct: wormhole finds globally similar states, not per-pixel matches
2. **Lower threshold**: 0.9 instead of 0.9995
3. **Sequence-based training**: Proper temporal history within trajectories

**Models to test**:
- SBM_2B_None, SBM_3B_None, SBM_7B_None (band count ablation)
- SBM_3B_Temporal, SBM_3B_Neighbor, SBM_3B_Wormhole (attention type ablation)
- SBM_3B_All (all attention types combined)
- Flat_None, Flat_Temporal (baselines)

**Run on Colab with GPU**: Runtime -> Change runtime type -> A100

In [None]:
import torch

# Report the software / hardware stack so results can be tied to a runtime.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only query the device name when a GPU is actually present.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GRID_SIZE = 64            # simulation grid is GRID_SIZE x GRID_SIZE
EPOCHS = 50
# NOTE: with the loss in train_sequential, (pred - x) - (y - x) == pred - y,
# so this flag does not change the MSE being optimised.
PREDICT_DELTA = True
PREDICTION_HORIZON = 5    # targets are this many env steps after the input
NUM_TRAJECTORIES = 32  # Batch size = number of parallel trajectories
TRAJECTORY_LENGTH = 100  # Steps per trajectory
LR = 0.001

# Attention config
HISTORY_LEN = 8           # rolling window of past steps kept per trajectory
TOP_K_TEMPORAL = 4        # keys kept per query in temporal attention
TEMPORAL_DECAY = 0.9      # recency factor for temporal attention
NEIGHBOR_RANGE = 1  # 3x3 neighborhood

# Wormhole - FIXED
WORMHOLE_THRESHOLD = 0.9  # Was 0.9995
WORMHOLE_MAX_CONN = 4

# One print call with newline-joined lines produces the same output as
# printing each line separately.
print("\n".join([
    f"Device: {DEVICE}",
    f"Grid: {GRID_SIZE}x{GRID_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Batch: {NUM_TRAJECTORIES} trajectories x {TRAJECTORY_LENGTH} steps",
    f"Prediction: delta t+{PREDICTION_HORIZON}",
    f"\nWormhole: threshold={WORMHOLE_THRESHOLD}, max_conn={WORMHOLE_MAX_CONN}",
    "KEY FIX: Pooled wormhole (global features, not per-pixel)",
]))

## 1. Environment

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import math
import time

@dataclass
class PlasmaConfig:
    """Parameters of the toy plasma environment.

    The defaults describe a mild regime on a 64x64 grid; ``turbulent()``
    returns a more strongly driven preset.
    """
    # Grid
    height: int = 64
    width: int = 64
    # Transport coefficients
    diffusion: float = 0.25
    advection: float = 0.08
    # Noise / disturbances (disturbance_* are not read by TurbulentPlasmaEnv.step)
    noise_std: float = 0.02
    disturbance_prob: float = 0.1
    disturbance_strength: float = 0.15
    # Background flows
    num_vortices: int = 3
    vortex_strength: float = 0.1
    shear_strength: float = 0.05
    multiscale_noise: bool = True
    # Actuators
    num_actuators: int = 9
    _base_actuator_sigma: float = 5.0
    # Tensor placement
    device: str = "cpu"
    dtype: torch.dtype = torch.float32

    @property
    def actuator_sigma(self) -> float:
        """Actuator Gaussian width, scaled linearly with the shorter grid side (reference 64)."""
        shortest_side = min(self.height, self.width)
        return self._base_actuator_sigma * shortest_side / 64.0

    @classmethod
    def turbulent(cls, device: str = "cpu", size: int = 64):
        """Strongly driven preset on a ``size`` x ``size`` grid."""
        overrides = dict(
            diffusion=0.3, advection=0.12, noise_std=0.03,
            disturbance_prob=0.25, disturbance_strength=0.3,
            num_vortices=3, vortex_strength=0.15, shear_strength=0.08,
            multiscale_noise=True, num_actuators=9,
        )
        return cls(height=size, width=size, device=device, **overrides)


class TurbulentPlasmaEnv:
    """Batched toy plasma simulator on an H x W grid.

    Each ``step`` applies zero-boundary diffusion, a roll-based advection blend,
    optional vortex and shear flows, Gaussian actuator forcing and (optionally)
    additive noise, then clamps the field to [-1, 1]. Fields are [B, 1, H, W].
    """

    # Seed for the fixed vortex layout. It is drawn from a private generator so
    # that constructing an environment never reseeds torch's global RNG (which
    # would silently change every later random draw, e.g. model initialisation).
    VORTEX_SEED = 42

    def __init__(self, cfg: "PlasmaConfig"):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dtype = cfg.dtype
        self._actuator_maps = self._build_actuator_maps()
        self._vortex_flow = self._build_vortex_flow()
        self._shear_flow = self._build_shear_flow()

    def _pixel_grid(self):
        """Return (yy, xx) integer pixel coordinates as [H, W] tensors of the env dtype."""
        return torch.meshgrid(
            torch.arange(self.cfg.height, device=self.device, dtype=self.dtype),
            torch.arange(self.cfg.width, device=self.device, dtype=self.dtype), indexing="ij")

    def _build_actuator_maps(self) -> torch.Tensor:
        """Return [num_actuators, H, W] Gaussian bumps centred on a regular grid spanning 20%-80% of each axis."""
        h, w = self.cfg.height, self.cfg.width
        grid_n = int(math.ceil(math.sqrt(self.cfg.num_actuators)))
        centers_y = torch.linspace(h * 0.2, h * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers_x = torch.linspace(w * 0.2, w * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers = torch.cartesian_prod(centers_y, centers_x)[:self.cfg.num_actuators]
        sig2 = self.cfg.actuator_sigma ** 2
        yy, xx = self._pixel_grid()
        bumps = [torch.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2 * sig2)) for cy, cx in centers]
        return torch.stack(bumps, dim=0)

    def _build_vortex_flow(self):
        """Return (vy, vx) of Gaussian-damped point vortices with alternating spin, or (None, None)."""
        n = self.cfg.num_vortices
        if n == 0:
            return None, None
        h, w = self.cfg.height, self.cfg.width
        gen = torch.Generator().manual_seed(self.VORTEX_SEED)
        centers_y = (torch.rand(n, generator=gen) * (h * 0.6) + (h * 0.2)).to(self.device)
        centers_x = (torch.rand(n, generator=gen) * (w * 0.6) + (w * 0.2)).to(self.device)
        yy, xx = self._pixel_grid()
        vy, vx = torch.zeros_like(yy), torch.zeros_like(xx)
        scale = min(h, w) / 64.0
        for i in range(n):
            dy, dx = yy - centers_y[i], xx - centers_x[i]
            r2 = dy ** 2 + dx ** 2 + 1e-6
            decay = torch.exp(-r2 / (2 * (10 * scale) ** 2))
            sign = 1 if i % 2 == 0 else -1
            vy += sign * self.cfg.vortex_strength * (-dx) / torch.sqrt(r2) * decay
            vx += sign * self.cfg.vortex_strength * dy / torch.sqrt(r2) * decay
        return vy, vx

    def _build_shear_flow(self):
        """Return (vy, vx) of a horizontal sinusoidal shear (vy == 0), or (None, None)."""
        if self.cfg.shear_strength == 0:
            return None, None
        yy, _ = self._pixel_grid()
        return torch.zeros_like(yy), self.cfg.shear_strength * torch.sin(2 * math.pi * yy / self.cfg.height)

    def _apply_flow(self, field, vy, vx):
        """Resample ``field`` [B, C, H, W] at positions displaced backwards by (vy, vx) pixels."""
        B, _, H, W = field.shape
        yy, xx = torch.meshgrid(
            torch.linspace(-1, 1, H, device=field.device, dtype=field.dtype),
            torch.linspace(-1, 1, W, device=field.device, dtype=field.dtype), indexing="ij")
        # NOTE(review): with align_corners=True one pixel spans 2/(W-1) in normalised
        # coordinates, so dividing by W/2 scales displacements by (W-1)/W. Kept as-is
        # to preserve the simulated dynamics.
        grid = torch.stack([xx - vx / (W / 2), yy - vy / (H / 2)], dim=-1)
        # grid_sample requires the grid dtype to match the input dtype.
        grid = grid.to(field.dtype).unsqueeze(0).expand(B, -1, -1, -1)
        return F.grid_sample(field, grid, mode='bilinear', padding_mode='border', align_corners=True)

    def _multiscale_noise(self, shape):
        """Sum of bilinearly upsampled Gaussian noise at scales 1, 2, 4, 8 (amplitude noise_std / scale)."""
        B, C, H, W = shape
        noise = torch.zeros(shape, device=self.device, dtype=self.dtype)
        for scale in [1, 2, 4, 8]:
            hs, ws = H // scale, W // scale
            if hs < 4:
                continue  # too coarse to be meaningful
            coarse = torch.randn(B, C, hs, ws, device=self.device, dtype=self.dtype)
            noise += F.interpolate(coarse, size=(H, W), mode='bilinear',
                                   align_corners=False) * (self.cfg.noise_std / scale)
        return noise

    def reset(self, batch_size=1):
        """Return [batch_size, 1, H, W] fields, each a Gaussian blob at a random point in the central 30%-70% box."""
        h, w = self.cfg.height, self.cfg.width
        cx = torch.randint(int(w * 0.3), int(w * 0.7), (batch_size,), device=self.device)
        cy = torch.randint(int(h * 0.3), int(h * 0.7), (batch_size,), device=self.device)
        yy, xx = self._pixel_grid()
        sig2 = (self.cfg.actuator_sigma * 1.5) ** 2
        cy_f = cy.to(self.dtype).view(-1, 1, 1)
        cx_f = cx.to(self.dtype).view(-1, 1, 1)
        blobs = torch.exp(-((yy - cy_f) ** 2 + (xx - cx_f) ** 2) / (2 * sig2))
        return blobs.unsqueeze(1)

    def step(self, field, control, noise=True):
        """Advance ``field`` [B, 1, H, W] one step under ``control`` [B, num_actuators]."""
        # 5-point Laplacian with zero padding (neighbours outside the grid count as 0).
        up = F.pad(field, (0, 0, 1, 0))[:, :, :-1, :]
        down = F.pad(field, (0, 0, 0, 1))[:, :, 1:, :]
        left = F.pad(field, (1, 0, 0, 0))[:, :, :, :-1]
        right = F.pad(field, (0, 1, 0, 0))[:, :, :, 1:]
        lap = (up + down + left + right) - 4 * field
        diffused = field + self.cfg.diffusion * lap
        # Blend with a one-pixel diagonal roll (periodic wrap) to emulate advection.
        advected = (torch.roll(diffused, shifts=(1, -1), dims=(2, 3)) * self.cfg.advection
                    + diffused * (1 - self.cfg.advection))
        if self._vortex_flow[0] is not None:
            advected = self._apply_flow(advected, self._vortex_flow[0], self._vortex_flow[1])
        if self._shear_flow[0] is not None:
            advected = self._apply_flow(advected, self._shear_flow[0], self._shear_flow[1])
        force = torch.einsum('ba,ahw->bhw', control, self._actuator_maps).unsqueeze(1)
        next_field = advected + force
        if noise:
            if self.cfg.multiscale_noise:
                next_field = next_field + self._multiscale_noise(next_field.shape)
            else:
                next_field = next_field + torch.randn_like(next_field) * self.cfg.noise_std
        return torch.clamp(next_field, -1, 1)


def generate_trajectory_batch(env, batch_size: int, traj_len: int, control_scale: float = 0.1, horizon: int = 1):
    """Roll out ``batch_size`` parallel trajectories with random controls.

    Each trajectory starts at ``env.reset(batch_size)`` and is stepped with
    Gaussian controls (std ``control_scale``, clamped to [-1, 1]). Only the
    ``traj_len + horizon`` states that are actually used are simulated.
    ``env.step`` is expected to return a new tensor (not modify its input).

    Returns:
        fields:  [T, B, 1, H, W] states s_0 .. s_{T-1} (time-first), T = traj_len
        targets: [T, B, 1, H, W] with ``targets[t]`` = s_{t + horizon}

    Raises:
        ValueError: if ``traj_len`` or ``horizon`` is negative.
    """
    if traj_len < 0 or horizon < 0:
        raise ValueError(f"traj_len and horizon must be >= 0, got {traj_len} and {horizon}")

    # Training data is a fixed supervision target; no autograd graph is needed.
    with torch.no_grad():
        field = env.reset(batch_size)
        states = [field]
        # The slices below read states 0 .. traj_len + horizon - 1.
        for _ in range(max(traj_len + horizon - 1, 0)):
            ctrl = torch.clamp(
                torch.randn(batch_size, env.cfg.num_actuators, device=env.device) * control_scale, -1, 1)
            field = env.step(field, ctrl)
            states.append(field)
        trajectory = torch.stack(states, dim=0)

    return trajectory[:traj_len], trajectory[horizon:horizon + traj_len]

print("Environment defined.")

## 2. Attention Modules

In [None]:
# ============================================================================
# TEMPORAL ATTENTION: Top-K causal attention over time
# ============================================================================

class TemporalAttention(nn.Module):
    """Multi-head causal self-attention over time with Top-K selection and recency decay.

    A query at step i attends only to keys j <= i. When the sequence is longer
    than ``top_k``, only the ``top_k`` highest raw scores per query survive.
    With ``decay < 1`` a bias of ``log(decay ** (i - j))`` is added so that
    older keys are down-weighted.

    If ``feature_dim`` is not divisible by ``num_heads``, the largest head count
    <= ``num_heads`` that divides ``feature_dim`` is used (with a warning).
    """

    def __init__(self, feature_dim: int, num_heads: int = 4, max_len: int = 16,
                 top_k: int = 4, decay: float = 0.9, device: str = 'cpu'):
        super().__init__()
        if num_heads < 1 or feature_dim % num_heads != 0:
            import warnings
            adjusted = next(h for h in range(max(num_heads, 1), 0, -1) if feature_dim % h == 0)
            # Without this, head_dim = feature_dim // num_heads could be 0 and the
            # view() in forward would fail.
            warnings.warn(f"TemporalAttention: feature_dim={feature_dim} not divisible by "
                          f"num_heads={num_heads}; using {adjusted} heads")
            num_heads = adjusted
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.top_k = top_k
        self.decay = decay

        self.qkv = nn.Linear(feature_dim, 3 * feature_dim)
        self.out = nn.Linear(feature_dim, feature_dim)
        # Precomputed causal mask (True = future position to hide); longer
        # sequences get a mask built on the fly in forward().
        self.register_buffer("causal_mask", torch.triu(torch.ones(max_len, max_len), 1).bool())
        self.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Args: x: [B, T, D]. Returns: [B, T, D]"""
        B, T, D = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]  # each [B, heads, T, head_dim]

        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B, heads, T, T]
        if T <= self.causal_mask.shape[0]:
            causal = self.causal_mask[:T, :T]
        else:
            causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), 1)
        scores = scores.masked_fill(causal, float('-inf'))

        if T > self.top_k:
            _, topk_idx = torch.topk(scores, self.top_k, dim=-1)
            drop = torch.ones_like(scores, dtype=torch.bool)
            drop.scatter_(-1, topk_idx, False)
            scores = scores.masked_fill(drop, float('-inf'))

        if self.decay < 1.0:
            steps = torch.arange(T, device=x.device, dtype=x.dtype)
            # age[i, j] = i - j: how far key j lies in the past of query i.
            age = (steps.unsqueeze(1) - steps.unsqueeze(0)).clamp(min=0)
            decay_weights = torch.tril(torch.pow(self.decay, age))
            scores = scores + torch.log(decay_weights + 1e-10)

        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).reshape(B, T, D)
        return self.out(out)


# ============================================================================
# NEIGHBOR ATTENTION: Local 3x3 spatial attention
# ============================================================================

class NeighborAttention(nn.Module):
    """Multi-head attention from each pixel to its (2r+1) x (2r+1) neighbourhood.

    ``r = layer_range``; borders are replicate-padded. If ``feature_dim`` is not
    divisible by ``num_heads``, the largest head count <= ``num_heads`` that
    divides ``feature_dim`` is used (with a warning).
    """

    def __init__(self, feature_dim: int, num_heads: int = 4, layer_range: int = 1, device: str = 'cpu'):
        super().__init__()
        if num_heads < 1 or feature_dim % num_heads != 0:
            import warnings
            adjusted = next(h for h in range(max(num_heads, 1), 0, -1) if feature_dim % h == 0)
            # Without this, head_dim could be 0 and the reshapes in forward would fail.
            warnings.warn(f"NeighborAttention: feature_dim={feature_dim} not divisible by "
                          f"num_heads={num_heads}; using {adjusted} heads")
            num_heads = adjusted
        self.feature_dim = feature_dim
        self.num_heads = num_heads
        self.head_dim = feature_dim // num_heads
        self.layer_range = layer_range
        self.window_size = 2 * layer_range + 1

        self.qkv = nn.Linear(feature_dim, 3 * feature_dim)
        self.out = nn.Linear(feature_dim, feature_dim)
        self.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Args: x: [B, H, W, D]. Returns: [B, H, W, D]"""
        B, H, W, D = x.shape
        r, ws = self.layer_range, self.window_size

        # Pad spatially (channels-first for F.pad), then back to channels-last.
        x_pad = F.pad(x.permute(0, 3, 1, 2), (r, r, r, r), mode='replicate').permute(0, 2, 3, 1)

        # [B, H, W, D, ws, ws] -> [B, H, W, K, D] with K = ws * ws
        neighbors = x_pad.unfold(1, ws, 1).unfold(2, ws, 1)
        neighbors = neighbors.reshape(B, H, W, D, -1).permute(0, 1, 2, 4, 3)
        n_keys = neighbors.shape[3]

        # Project only what is needed: Q for the centre pixel, K/V for neighbours
        # (rows [0:D], [D:2D], [2D:3D] of the fused qkv weight).
        w_qkv, b_qkv = self.qkv.weight, self.qkv.bias
        Q = F.linear(x, w_qkv[:D], b_qkv[:D]).reshape(B, H, W, self.num_heads, 1, self.head_dim)
        K = F.linear(neighbors, w_qkv[D:2 * D], b_qkv[D:2 * D])
        V = F.linear(neighbors, w_qkv[2 * D:], b_qkv[2 * D:])
        K = K.reshape(B, H, W, n_keys, self.num_heads, self.head_dim).permute(0, 1, 2, 4, 3, 5)
        V = V.reshape(B, H, W, n_keys, self.num_heads, self.head_dim).permute(0, 1, 2, 4, 3, 5)

        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # [B, H, W, heads, 1, K]
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).squeeze(4).reshape(B, H, W, D)
        return self.out(out)


# ============================================================================
# WORMHOLE ATTENTION: Global pooled similarity (FIXED - no per-pixel)
# ============================================================================

class WormholeAttention(nn.Module):
    """Sparse non-local attention via cosine similarity on POOLED features.

    v4.2 FIX: Operates on global pooled features [B, D], not per-pixel.
    For each sample, the ``max_connections`` most cosine-similar history entries
    are candidates; only those with similarity > ``threshold`` are attended to.
    Samples with no qualifying entry receive a zero output.
    """

    def __init__(self, feature_dim: int, attn_dim: int, threshold: float = 0.9,
                 max_connections: int = 4, device: str = 'cpu'):
        super().__init__()
        self.feature_dim = feature_dim
        self.attn_dim = attn_dim
        self.threshold = threshold
        self.max_connections = max_connections

        self.W_q = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_k = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_v = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_o = nn.Linear(attn_dim, feature_dim, bias=False)
        self.to(device)

    @staticmethod
    def _stats(num_connections, max_sim, mean_sim, conn_per_sample) -> Dict:
        """Build the stats dict with the same keys on every return path."""
        return {
            'num_connections': num_connections,
            'max_sim': max_sim,
            'mean_sim': mean_sim,
            'conn_per_sample': conn_per_sample,
        }

    def forward(self, query: torch.Tensor, history: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """Args:
            query: Current pooled features [B, D]
            history: Past pooled features [B, T, D] (or None)
        Returns:
            output: [B, D]
            stats: dict with 'num_connections' (above-threshold links summed over
                the batch), 'max_sim', 'mean_sim', 'conn_per_sample'
        """
        B, D = query.shape
        T = history.shape[1] if history is not None else 0

        if T == 0:
            return torch.zeros_like(query), self._stats(0, 0.0, 0.0, 0.0)

        # Cosine similarity between the query and each history entry: [B, T]
        q_unit = F.normalize(query, p=2, dim=-1)
        k_unit = F.normalize(history, p=2, dim=-1)
        sim = torch.bmm(q_unit.unsqueeze(1), k_unit.transpose(1, 2)).squeeze(1)

        n_conn = min(self.max_connections, T)
        topk_sim, topk_idx = torch.topk(sim, n_conn, dim=-1)  # [B, K]
        mask = topk_sim > self.threshold                      # [B, K]
        num_valid = int(mask.sum().item())

        if num_valid == 0:
            return torch.zeros_like(query), self._stats(0, sim.max().item(), 0.0, 0.0)

        selected_hist = torch.gather(history, 1, topk_idx.unsqueeze(-1).expand(-1, -1, D))  # [B, K, D]

        Q = self.W_q(query).unsqueeze(1)  # [B, 1, attn_dim]
        K = self.W_k(selected_hist)       # [B, K, attn_dim]
        V = self.W_v(selected_hist)       # [B, K, attn_dim]

        scores = torch.bmm(Q, K.transpose(1, 2)).squeeze(1) / math.sqrt(self.attn_dim)  # [B, K]
        # A row whose candidates are all below threshold would be all -inf and
        # softmax would yield NaN. Give such rows finite scores and zero their
        # weights afterwards, so no NaN is ever produced.
        has_conn = mask.any(dim=-1, keepdim=True)
        scores = scores.masked_fill(~mask, float('-inf')).masked_fill(~has_conn, 0.0)
        attn = F.softmax(scores, dim=-1) * mask.to(scores.dtype)

        out = torch.bmm(attn.unsqueeze(1), V).squeeze(1)  # [B, attn_dim]
        out = self.W_o(out)

        return out, self._stats(
            num_valid,
            topk_sim.max().item(),
            topk_sim[mask].mean().item(),
            mask.sum(dim=1).float().mean().item(),
        )

print("Attention modules defined.")

## 3. SBM and Flat Models

In [None]:
@dataclass
class SBMConfig:
    height: int = 64
    width: int = 64
    num_bands: int = 3
    channels: int = 16
    # Attention flags
    use_temporal: bool = False
    use_neighbor: bool = False
    use_wormhole: bool = False
    # Attention params
    attn_dim: int = 32
    num_heads: int = 4
    history_len: int = 8
    top_k: int = 4
    temporal_decay: float = 0.9
    neighbor_range: int = 1
    wormhole_threshold: float = 0.9
    wormhole_max_conn: int = 4
    device: str = 'cpu'


def make_radial_masks(h, w, num_bands, device):
    """Soft, log-spaced radial band masks for an ``fftshift``-ed 2-D spectrum.

    Returns a [num_bands, h, w] tensor, normalised so the bands sum to ~1 at
    every frequency bin. The radial coordinate is exactly 0 at the DC bin
    (index ``h // 2, w // 2`` after ``torch.fft.fftshift``), about 1 at the
    Nyquist edges and up to ~sqrt(2) in the corners.
    """
    # fftshift(fftfreq(n)) puts frequency 0 at index n // 2, exactly where
    # torch.fft.fftshift places the DC term; x2 maps the Nyquist bin to -1.
    # (linspace(-1, 1, n) would put r = 0 *between* pixels for even n.)
    fy = torch.fft.fftshift(torch.fft.fftfreq(h, device=device)) * 2.0
    fx = torch.fft.fftshift(torch.fft.fftfreq(w, device=device)) * 2.0
    yy, xx = torch.meshgrid(fy, fx, indexing="ij")
    rr = torch.sqrt(yy ** 2 + xx ** 2).clamp(min=1e-6)
    # Band edges log-spaced from 1e-3 to sqrt(2); the first edge is pinned to 0.
    edges = torch.logspace(-3, math.log10(math.sqrt(2)), steps=num_bands + 1, device=device)
    edges[0] = 0.0
    sharpness = 20.0  # slope of the sigmoid band edges
    masks = torch.stack([
        torch.sigmoid((rr - edges[i]) * sharpness) * torch.sigmoid((edges[i + 1] - rr) * sharpness)
        for i in range(num_bands)
    ], dim=0)
    return masks / (masks.sum(dim=0, keepdim=True) + 1e-8)


class SBMWithAttention(nn.Module):
    """Spectral Belief Machine with optional temporal / neighbour / wormhole attention.

    The input field is FFT-decomposed into ``num_bands`` soft radial bands.
    Each band's (real, imag) planes pass through a small residual CNN and are
    optionally mixed with attention outputs, then the bands are recombined with
    an inverse FFT.
    """

    def __init__(self, cfg: SBMConfig):
        super().__init__()
        self.cfg = cfg

        self.register_buffer("window", torch.ones(cfg.height, cfg.width, device=cfg.device))
        self.register_buffer("masks", make_radial_masks(cfg.height, cfg.width, cfg.num_bands, cfg.device))

        # Per-band processing
        self.bands = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, 2, 1),
            ) for _ in range(cfg.num_bands)
        ])

        # Pooled feature dimension for wormhole
        self.pool_dim = cfg.num_bands * 2 * 4  # bands * (real,imag) * 2x2_pool

        # Per-pixel band features are only 2-dimensional (real, imag), so the
        # spatial attention head count must divide 2. Passing cfg.num_heads
        # (default 4) directly would give head_dim == 0 and crash on the first
        # step that has history.
        band_feat_dim = 2
        band_heads = max(1, min(cfg.num_heads, band_feat_dim))

        # Attention modules
        if cfg.use_temporal:
            self.temporal_attn = TemporalAttention(
                feature_dim=band_feat_dim, num_heads=band_heads, max_len=cfg.history_len + 1,
                top_k=cfg.top_k, decay=cfg.temporal_decay, device=cfg.device
            )

        if cfg.use_neighbor:
            self.neighbor_attn = NeighborAttention(
                feature_dim=band_feat_dim, num_heads=band_heads,
                layer_range=cfg.neighbor_range, device=cfg.device
            )

        if cfg.use_wormhole:
            self.wormhole_attn = WormholeAttention(
                feature_dim=self.pool_dim, attn_dim=cfg.attn_dim,
                threshold=cfg.wormhole_threshold, max_connections=cfg.wormhole_max_conn,
                device=cfg.device
            )
            # Project wormhole output back to one (real, imag) offset per band
            self.wormhole_proj = nn.Linear(self.pool_dim, cfg.num_bands * 2)

        # Fusion MLP over [current, temporal?, neighbour?] band features
        num_spatial_attn = sum([cfg.use_temporal, cfg.use_neighbor])
        if num_spatial_attn > 0:
            self.fusion = nn.Sequential(
                nn.Linear(2 * (1 + num_spatial_attn), cfg.attn_dim),
                nn.GELU(),
                nn.Linear(cfg.attn_dim, 2)
            )

        self.to(cfg.device)

    def forward(self, x: torch.Tensor, history: Optional[Dict] = None) -> Tuple[torch.Tensor, Dict]:
        """Args:
            x: [B, 1, H, W]
            history: Dict with 'bands': [B, T, num_bands, H, W, 2] and 'pooled': [B, T, pool_dim]
        Returns:
            pred: [B, 1, H, W]
            info: Dict with detached 'bands' / 'pooled' features for the next
                history entry, plus 'wormhole_conn' and 'wormhole_max_sim'
        """
        B, _, H, W = x.shape

        # FFT decompose (DC moved to the centre to match the radial masks)
        fft = torch.fft.fftshift(torch.fft.fft2(x.squeeze(1) * self.window))
        has_band_history = (history is not None and 'bands' in history
                            and history['bands'].shape[1] > 0)

        proc_bands = []
        current_band_feats = []

        for i in range(self.cfg.num_bands):
            band = fft * self.masks[i].unsqueeze(0)
            band_feat = torch.stack([band.real, band.imag], dim=1)  # [B, 2, H, W]
            processed = band_feat + self.bands[i](band_feat)  # Residual

            # Spatial attention (including neighbour attention, which itself needs
            # no history) is only applied once a history entry exists.
            if has_band_history:
                current_hw = processed.permute(0, 2, 3, 1)  # [B, H, W, 2]
                attn_outputs = [current_hw]

                if self.cfg.use_temporal:
                    band_history = history['bands'][:, :, i]  # [B, T, H, W, 2]
                    seq = torch.cat([band_history, current_hw.unsqueeze(1)], dim=1)  # [B, T+1, H, W, 2]
                    # Move time next to the feature axis so each row is ONE pixel's
                    # history: [B, H, W, T+1, 2] -> [B*H*W, T+1, 2]. Reshaping
                    # without the permute would mix neighbouring pixels into a
                    # single "sequence".
                    seq_flat = seq.permute(0, 2, 3, 1, 4).reshape(B * H * W, seq.shape[1], 2)
                    temp_out = self.temporal_attn(seq_flat)[:, -1]  # [B*H*W, 2]
                    attn_outputs.append(temp_out.reshape(B, H, W, 2))

                if self.cfg.use_neighbor:
                    attn_outputs.append(self.neighbor_attn(current_hw))

                if len(attn_outputs) > 1:
                    fused = self.fusion(torch.cat(attn_outputs, dim=-1))
                    processed = processed + 0.1 * fused.permute(0, 3, 1, 2)

            proc_bands.append(processed)
            current_band_feats.append(processed.permute(0, 2, 3, 1))  # [B, H, W, 2]

        # Pooled features for wormhole: 2x2 average pool of each band's (real, imag)
        pooled_bands = [F.adaptive_avg_pool2d(b, (2, 2)).flatten(1) for b in proc_bands]
        current_pooled = torch.cat(pooled_bands, dim=1)  # [B, pool_dim]

        # Wormhole attention on pooled features
        wormhole_stats = {'num_connections': 0, 'max_sim': 0.0}
        if self.cfg.use_wormhole and history is not None and 'pooled' in history:
            wh_out, wormhole_stats = self.wormhole_attn(current_pooled, history['pooled'])
            # Project back and add a constant (real, imag) offset to each band
            wh_delta = self.wormhole_proj(wh_out)  # [B, num_bands * 2]
            for i in range(self.cfg.num_bands):
                delta_i = wh_delta[:, i * 2:(i + 1) * 2].unsqueeze(-1).unsqueeze(-1)  # [B, 2, 1, 1]
                proc_bands[i] = proc_bands[i] + 0.1 * delta_i

        # Reconstruct. NOTE(review): each band was already multiplied by masks[i]
        # before processing, so an identity band network yields weight masks[i]**2
        # here; confirm this double masking is intended.
        fft_recon = sum(
            torch.complex(b[:, 0], b[:, 1]) * self.masks[i].unsqueeze(0)
            for i, b in enumerate(proc_bands)
        )
        pred = torch.fft.ifft2(torch.fft.ifftshift(fft_recon)).real.unsqueeze(1)

        # Pack current features for history (band features exclude the wormhole offset)
        current_bands = torch.stack(current_band_feats, dim=1)  # [B, num_bands, H, W, 2]

        info = {
            'bands': current_bands.detach(),
            'pooled': current_pooled.detach(),
            'wormhole_conn': wormhole_stats['num_connections'],
            'wormhole_max_sim': wormhole_stats['max_sim'],
        }
        return pred, info


class FlatBaseline(nn.Module):
    """Simple ConvNet baseline, optionally with per-pixel temporal attention.

    Without usable history the plain ``net`` is applied. With ``use_temporal``
    and a non-empty ``history['feat']``, features from ``feat_extract`` are
    mixed with a temporal-attention summary of each pixel's past features and
    passed through ``net`` minus its first conv + GELU.
    """

    def __init__(self, height, width, channels=32, use_temporal=False, device='cpu'):
        super().__init__()
        self.use_temporal = use_temporal

        self.net = nn.Sequential(
            nn.Conv2d(1, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, 1, 3, padding=1),
        )

        if use_temporal:
            self.temporal = TemporalAttention(
                feature_dim=channels, num_heads=4, max_len=16,
                top_k=4, decay=0.9, device=device
            )
            self.feat_extract = nn.Conv2d(1, channels, 3, padding=1)
            self.feat_combine = nn.Conv2d(channels * 2, channels, 1)

        self.to(device)

    def forward(self, x, history=None):
        """Args: x [B, 1, H, W]; history: optional {'feat': [B, T, H, W, C]}.

        Returns (pred [B, 1, H, W], info) where info holds this step's detached
        'feat' [B, H, W, C] when ``use_temporal`` (else an empty dict).
        """
        has_history = (self.use_temporal and history is not None
                       and 'feat' in history and history['feat'].shape[1] > 0)
        if has_history:
            feat = self.feat_extract(x)  # [B, C, H, W]
            B, C, H, W = feat.shape

            current_hw = feat.permute(0, 2, 3, 1)  # [B, H, W, C]
            seq = torch.cat([history['feat'], current_hw.unsqueeze(1)], dim=1)  # [B, T+1, H, W, C]
            # Move time next to channels so each row is ONE pixel's history:
            # [B, H, W, T+1, C] -> [B*H*W, T+1, C]. Reshaping without the permute
            # would mix neighbouring pixels into a single "sequence".
            seq_flat = seq.permute(0, 2, 3, 1, 4).reshape(B * H * W, seq.shape[1], C)
            temp_out = self.temporal(seq_flat)[:, -1].reshape(B, H, W, C).permute(0, 3, 1, 2)

            out = self.feat_combine(torch.cat([feat, temp_out], dim=1))
            # Continue through the trunk, skipping its first conv + GELU
            out = self.net[2:](out)

            info = {'feat': current_hw.detach()}
        else:
            out = self.net(x)
            if self.use_temporal:
                feat = self.feat_extract(x).permute(0, 2, 3, 1)  # [B, H, W, C]
                info = {'feat': feat.detach()}
            else:
                info = {}

        return out, info

print("Models defined.")

## 4. Training with Sequence-Based Batching

In [None]:
def _update_history(history, info, history_len):
    """Append this step's detached features to the rolling per-trajectory history.

    SBM models report 'bands' and 'pooled'; the temporal flat baseline reports
    'feat'. Each tensor is [B, ...] and is stacked along a new time axis
    (dim 1), keeping only the most recent ``history_len`` steps.
    """
    if 'bands' in info:
        keys = ('bands', 'pooled')
    elif 'feat' in info:
        keys = ('feat',)
    else:
        return history  # model keeps no history
    if history_len <= 0:
        # x[:, -0:] keeps *everything*, so "no history" must be handled explicitly.
        return None
    if history is None:
        history = {}
    for key in keys:
        step = info[key].unsqueeze(1)
        merged = step if key not in history else torch.cat([history[key], step], dim=1)
        history[key] = merged[:, -history_len:]
    return history


def train_sequential(model, env, num_traj, traj_len, epochs, lr, device,
                     predict_delta=True, horizon=5, history_len=8):
    """Train with sequence-based batching for proper temporal history.

    Every epoch rolls out ``num_traj`` fresh trajectories and takes one Adam
    step per time index, feeding the model the last ``history_len`` steps of
    its own (detached) features from the same trajectories.

    Returns:
        {'loss': per-epoch mean MSE, 'wormhole_conn': per-epoch mean of
         info['wormhole_conn'] (0 for models that do not report it)}
    """
    opt = optim.Adam(model.parameters(), lr=lr)
    losses = []
    wh_stats = []

    for epoch in range(epochs):
        t0 = time.time()

        fields, targets = generate_trajectory_batch(env, num_traj, traj_len, 0.1, horizon)
        fields = fields.to(device)
        targets = targets.to(device)

        ep_losses = []
        ep_wh_conn = []
        history = None  # fresh trajectories -> history restarts every epoch

        for x, y in zip(fields, targets):
            pred, info = model(x, history)

            # NOTE: (pred - x) - (y - x) == pred - y, so both branches optimise the
            # same MSE (up to float rounding); the flag is kept for compatibility.
            if predict_delta:
                loss = F.mse_loss(pred - x, y - x)
            else:
                loss = F.mse_loss(pred, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            ep_losses.append(loss.item())
            if 'wormhole_conn' in info:
                ep_wh_conn.append(info['wormhole_conn'])

            history = _update_history(history, info, history_len)

        avg_loss = sum(ep_losses) / len(ep_losses) if ep_losses else float('nan')
        avg_wh = sum(ep_wh_conn) / len(ep_wh_conn) if ep_wh_conn else 0
        losses.append(avg_loss)
        wh_stats.append(avg_wh)

        elapsed = time.time() - t0
        wh_str = f", wh_conn={avg_wh:.1f}" if ep_wh_conn else ""
        print(f"  Ep {epoch+1}/{epochs}: loss={avg_loss:.6f}{wh_str}, time={elapsed:.1f}s")

    return {'loss': losses, 'wormhole_conn': wh_stats}

print("Training function defined.")

## 5. Run Ablation

In [None]:
# Initialize environment
print("[1] Initializing environment...")
plasma_cfg = PlasmaConfig.turbulent(device=DEVICE, size=GRID_SIZE)
env = TurbulentPlasmaEnv(plasma_cfg)
print(f"    Turbulent plasma environment ready.")

In [None]:
# Define all models to test
print("[2] Creating models...")


def make_sbm(num_bands, use_temporal=False, use_neighbor=False, use_wormhole=False):
    """Build an SBMWithAttention from the notebook-level hyper-parameters.

    Only the band count and the three attention switches vary between
    ablation arms; all other settings come from the configuration cell.
    """
    cfg = SBMConfig(
        height=GRID_SIZE,
        width=GRID_SIZE,
        num_bands=num_bands,
        use_temporal=use_temporal,
        use_neighbor=use_neighbor,
        use_wormhole=use_wormhole,
        history_len=HISTORY_LEN,
        top_k=TOP_K_TEMPORAL,
        temporal_decay=TEMPORAL_DECAY,
        neighbor_range=NEIGHBOR_RANGE,
        wormhole_threshold=WORMHOLE_THRESHOLD,
        wormhole_max_conn=WORMHOLE_MAX_CONN,
        device=DEVICE,
    )
    return SBMWithAttention(cfg)


# Construction order matters: models are randomly initialised in this order.
models = {}
# Band count ablation (no attention)
models['SBM_2B_None'] = make_sbm(2)
models['SBM_3B_None'] = make_sbm(3)
models['SBM_7B_None'] = make_sbm(7)
# Attention type ablation (3 bands)
models['SBM_3B_Temporal'] = make_sbm(3, use_temporal=True)
models['SBM_3B_Neighbor'] = make_sbm(3, use_neighbor=True)
models['SBM_3B_Wormhole'] = make_sbm(3, use_wormhole=True)
# Combined
models['SBM_3B_All'] = make_sbm(3, use_temporal=True, use_neighbor=True, use_wormhole=True)
# Baselines
models['Flat_None'] = FlatBaseline(GRID_SIZE, GRID_SIZE, 32, use_temporal=False, device=DEVICE)
models['Flat_Temporal'] = FlatBaseline(GRID_SIZE, GRID_SIZE, 32, use_temporal=True, device=DEVICE)

for name, model in models.items():
    print(f"    {name}: {sum(p.numel() for p in model.parameters()):,} params")

In [None]:
# Train all models
results = {}

# Train every ablation arm with identical settings, in definition order.
for name, model in models.items():
    print("\n" + "=" * 60)
    print(f"Training {name}...")
    print("=" * 60)

    results[name] = train_sequential(
        model, env,
        num_traj=NUM_TRAJECTORIES, traj_len=TRAJECTORY_LENGTH, epochs=EPOCHS,
        lr=LR, device=DEVICE,
        predict_delta=PREDICT_DELTA, horizon=PREDICTION_HORIZON, history_len=HISTORY_LEN,
    )

## 6. Results

In [None]:
import matplotlib.pyplot as plt

# Loss curves: one panel per ablation axis, each as (title, [(label, result key)]).
loss_panels = [
    ('Band Count Ablation',
     [(n, n) for n in ['SBM_2B_None', 'SBM_3B_None', 'SBM_7B_None']]),
    ('Attention Type Ablation (3 bands)',
     [(n.replace('SBM_3B_', ''), n) for n in
      ['SBM_3B_None', 'SBM_3B_Temporal', 'SBM_3B_Neighbor', 'SBM_3B_Wormhole', 'SBM_3B_All']]),
    ('SBM vs Flat',
     [('SBM_3B', 'SBM_3B_None'), ('SBM_3B_All', 'SBM_3B_All'),
      ('Flat', 'Flat_None'), ('Flat_Temp', 'Flat_Temporal')]),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (title, curves) in zip(axes, loss_panels):
    for label, key in curves:
        curve = results[key]['loss']
        ax.semilogy(curve, label=f"{label}: {curve[-1]:.6f}")
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Wormhole connections plot (only for models that contain a wormhole module)
wh_models = [n for n in results if any(tag in n for tag in ('Wormhole', 'All'))]
if wh_models:
    fig, ax = plt.subplots(figsize=(10, 4))
    for name in wh_models:
        conn_curve = results[name]['wormhole_conn']
        if not conn_curve:
            continue
        ax.plot(conn_curve, label=name)
    ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)  # "no connections" reference line
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Avg Wormhole Connections')
    ax.set_title(f'Wormhole Activity (threshold={WORMHOLE_THRESHOLD})')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Summary table
print("\n" + "="*70)
print("FINAL RESULTS (v4.2 - Pooled Wormhole)")
print("="*70)


def _pct_margin(loser, winner):
    """Percent by which ``loser`` exceeds ``winner`` (0 for a tie, inf if only winner is 0)."""
    if winner > 0:
        return (loser / winner - 1) * 100
    return 0.0 if loser == winner else float('inf')


sorted_results = sorted(results.items(), key=lambda x: x[1]['loss'][-1])

print(f"{'Model':<25} {'Final Loss':<15} {'vs Best':<15}")
print("-"*55)
best_loss = sorted_results[0][1]['loss'][-1]
for name, res in sorted_results:
    final = res['loss'][-1]
    ratio = 1 + _pct_margin(final, best_loss) / 100
    # Format "1.23x" first, then pad, so the 'x' sits next to the number.
    ratio_str = f"{ratio:.2f}x"
    print(f"{name:<25} {final:<15.6f} {ratio_str:<15}")

print("\n" + "="*70)
print("KEY COMPARISONS")
print("="*70)

# Band count
print("\n1. BAND COUNT (no attention):")
for n in ['SBM_2B_None', 'SBM_3B_None', 'SBM_7B_None']:
    print(f"   {n}: {results[n]['loss'][-1]:.6f}")

# Attention type
print("\n2. ATTENTION TYPE (3 bands):")
for n in ['SBM_3B_None', 'SBM_3B_Temporal', 'SBM_3B_Neighbor', 'SBM_3B_Wormhole', 'SBM_3B_All']:
    print(f"   {n}: {results[n]['loss'][-1]:.6f}")

# SBM vs Flat: compare the best of ALL variants in each family.
print("\n3. SBM vs FLAT:")
sbm_finals = {n: r['loss'][-1] for n, r in results.items() if n.startswith('SBM_')}
flat_finals = {n: r['loss'][-1] for n, r in results.items() if n.startswith('Flat_')}
sbm_name = min(sbm_finals, key=sbm_finals.get)
flat_name = min(flat_finals, key=flat_finals.get)
sbm_best, flat_best = sbm_finals[sbm_name], flat_finals[flat_name]
print(f"   Best SBM: {sbm_best:.6f} ({sbm_name})")
print(f"   Best Flat: {flat_best:.6f} ({flat_name})")
if sbm_best < flat_best:
    print(f"   SBM wins by {_pct_margin(flat_best, sbm_best):.1f}%")
else:
    print(f"   Flat wins by {_pct_margin(sbm_best, flat_best):.1f}%")

## 7. Summary

### v4.2 Key Fix: Pooled Wormhole

The wormhole now operates on **globally pooled features** [B, D] instead of per-pixel [B, H, W, D].

This is:
1. **Memory efficient**: O(T) instead of O(H*W * T*H*W) - no more OOM
2. **Conceptually correct**: Wormhole finds globally similar states, not pixel-to-pixel matches
3. **Matches the original design (experiment 033)**: features were pooled before the wormhole step

### Sequence-Based Training

History is now meaningful - each trajectory's history comes from its own past states,
not random samples from different trajectories.