# Experiment 032 v4.1: Full Architecture Ablation

**Purpose**: Comprehensive ablation testing ALL architectural variants.

**Key Change from v2.3**: NO control signal passed to models.
We replicate the winning 032 setup (pure prediction) and test architecture variations.

**v4.1 Fix**: Proper batch size handling - incomplete final batches are dropped so
every batch has the same shape. The rolling history buffer is concatenated across
batches, so its batch dimension must stay fixed; this matters most for wormhole
attention, which needs persistent history across batches.
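
As a rough illustration of the batch-dropping rule, the sketch below counts complete batches for the dataset size implied by the configuration cell (100 trajectories x 100 steps, batch size 32). The helper name `complete_batches` is hypothetical and does not appear in the notebook.

```python
from typing import Tuple

def complete_batches(num_samples: int, batch_size: int) -> Tuple[int, int, int]:
    """Return (num_batches, usable_samples, dropped_samples)."""
    num_batches = num_samples // batch_size   # integer division drops the remainder
    usable = num_batches * batch_size
    return num_batches, usable, num_samples - usable

# 100 trajectories x 100 steps = 10,000 samples, batch size 32
print(complete_batches(100 * 100, 32))  # (312, 9984, 16)
```

With these settings, 16 samples per epoch are left out so that all 312 batches have identical shape.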

## Ablation Dimensions

### 1. Number of Spectral Bands
- **2-band**: Low/High frequency split
- **3-band**: Low/Mid/High (like 033 experiment)
- **7-band**: Original SBM design (log-spaced radial masks)
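
The band splits above can be sketched in plain Python. This mirrors the edge choices made in `make_radial_masks` (Section 3), assuming a normalized radius that runs from 0 to sqrt(2); the function name `band_edges` is a hypothetical stand-in, and the real code builds soft sigmoid masks from these edges rather than hard cutoffs.

```python
import math

def band_edges(num_bands: int) -> list:
    """Radial band edges on a grid whose normalized radius spans [0, sqrt(2)]."""
    r_max = math.sqrt(2)
    if num_bands == 2:
        return [0.0, 0.5, r_max]          # low / high
    if num_bands == 3:
        return [0.0, 0.3, 0.7, r_max]     # low / mid / high
    # Log-spaced edges from 1e-3 to sqrt(2); the first edge is pinned to 0
    lo, hi = -3.0, math.log10(r_max)
    edges = [10 ** (lo + (hi - lo) * i / num_bands) for i in range(num_bands + 1)]
    edges[0] = 0.0
    return edges

for n in (2, 3, 7):
    print(n, [round(e, 4) for e in band_edges(n)])
```

Log spacing concentrates most of the 7 bands near the low-frequency center, whereas the 2- and 3-band splits use fixed radii.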

### 2. Attention Types (from 033 experiment)
- **None**: No attention (pure spectral processing)
- **Temporal**: Per-position Top-K temporal attention (object permanence)
- **Neighbor**: 3x3 local spatial attention (local physics)
- **Wormhole**: Sparse similarity-gated non-local attention (teleportation)
- **Full**: All three attention types
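
To make the temporal variant concrete, here is a minimal pure-Python sketch of Top-K selection with recency decay for a single grid position. It follows the logic of `PerPositionTemporalAttention` in Section 2 under simplifying assumptions: scores are already-scaled scalars, history is ordered oldest first, and the function name `topk_decay_weights` is hypothetical. Neighbor and wormhole attention are not covered here.

```python
import math

def topk_decay_weights(scores, top_k=4, decay=0.95):
    """Attention weights over T past steps (oldest first) for one grid position."""
    T = len(scores)
    if T == 0:
        return []
    # Keep the top_k highest-scoring steps; all other steps get zero weight
    keep = set(sorted(range(T), key=lambda t: scores[t], reverse=True)[:top_k])
    # Recency bias: step t is scaled by decay ** (T - t), so newer steps decay less
    logits = {t: scores[t] + (T - t) * math.log(decay) for t in keep}
    m = max(logits.values())                      # subtract max for a stable softmax
    exp = {t: math.exp(v - m) for t, v in logits.items()}
    z = sum(exp.values())
    return [exp[t] / z if t in keep else 0.0 for t in range(T)]

weights = topk_decay_weights([0.1, 0.9, 0.3, 0.8, 0.2, 0.7], top_k=3)
print([round(w, 3) for w in weights])
```

Only three of the six past steps receive non-zero weight, and the weights sum to 1.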

### 3. Baselines
- **Flat ConvNet**: Direct spatial convolutions

## Models Tested (12 total)

| Model | Bands | Temporal | Neighbor | Wormhole |
|-------|-------|----------|----------|----------|
| SBM_2B_None | 2 | - | - | - |
| SBM_3B_None | 3 | - | - | - |
| SBM_7B_None | 7 | - | - | - |
| SBM_3B_Temporal | 3 | Y | - | - |
| SBM_3B_Neighbor | 3 | - | Y | - |
| SBM_3B_Wormhole | 3 | - | - | Y |
| SBM_3B_TempNeigh | 3 | Y | Y | - |
| SBM_3B_Full | 3 | Y | Y | Y |
| SBM_7B_Temporal | 7 | Y | - | - |
| SBM_7B_Full | 7 | Y | Y | Y |
| Flat_Baseline | - | - | - | - |
| Flat_WithAttn | - | Y | Y | - |
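
The table can also be written as data, which makes it easy to sanity-check the variant count. This is only a sketch: `VariantSpec` and `VARIANTS` are hypothetical names, separate from the `SBMConfig` class the notebook actually uses to build models.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class VariantSpec:
    """One row of the ablation table (hypothetical helper, not in the notebook)."""
    bands: Optional[int]      # None for the flat (non-spectral) models
    temporal: bool = False
    neighbor: bool = False
    wormhole: bool = False

VARIANTS = {
    "SBM_2B_None":      VariantSpec(2),
    "SBM_3B_None":      VariantSpec(3),
    "SBM_7B_None":      VariantSpec(7),
    "SBM_3B_Temporal":  VariantSpec(3, temporal=True),
    "SBM_3B_Neighbor":  VariantSpec(3, neighbor=True),
    "SBM_3B_Wormhole":  VariantSpec(3, wormhole=True),
    "SBM_3B_TempNeigh": VariantSpec(3, temporal=True, neighbor=True),
    "SBM_3B_Full":      VariantSpec(3, temporal=True, neighbor=True, wormhole=True),
    "SBM_7B_Temporal":  VariantSpec(7, temporal=True),
    "SBM_7B_Full":      VariantSpec(7, temporal=True, neighbor=True, wormhole=True),
    "Flat_Baseline":    VariantSpec(None),
    "Flat_WithAttn":    VariantSpec(None, temporal=True, neighbor=True),
}

print(len(VARIANTS))                                  # 12
print(sum(v.wormhole for v in VARIANTS.values()))     # 3
```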

**Run on Colab with GPU**: Runtime -> Change runtime type -> A100

In [None]:
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GRID_SIZE = 64
EPOCHS = 50
DIFFICULTY = "turbulent"
PREDICT_DELTA = True
PREDICTION_HORIZON = 3  # Like original winning experiment
NUM_TRAJECTORIES = 100
TRAJECTORY_LENGTH = 100
BATCH_SIZE = 32
LR = 0.001

# Attention config (from 033)
HISTORY_LEN = 8
TOP_K_TEMPORAL = 4
TEMPORAL_DECAY = 0.95
NEIGHBOR_RANGE = 3
WORMHOLE_THRESHOLD = 0.9995
WORMHOLE_MAX_CONN = 4

print(f"Device: {DEVICE}")
print(f"Grid: {GRID_SIZE}x{GRID_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Difficulty: {DIFFICULTY}")
print(f"Prediction: delta t+{PREDICTION_HORIZON}")
print("\nNO CONTROL SIGNAL - Pure prediction task (like winning 032)")

## 1. Environment (Same as original 032)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import math
import time

@dataclass
class PlasmaConfig:
    height: int = 64
    width: int = 64
    diffusion: float = 0.25
    advection: float = 0.08
    noise_std: float = 0.02
    disturbance_prob: float = 0.1
    disturbance_strength: float = 0.15
    num_vortices: int = 3
    vortex_strength: float = 0.1
    shear_strength: float = 0.05
    multiscale_noise: bool = True
    num_actuators: int = 9
    _base_actuator_sigma: float = 5.0
    device: str = "cpu"
    dtype: torch.dtype = torch.float32
    
    @property
    def actuator_sigma(self) -> float:
        return self._base_actuator_sigma * min(self.height, self.width) / 64.0
    
    @classmethod
    def turbulent(cls, device: str = "cpu", size: int = 64):
        return cls(height=size, width=size, diffusion=0.3, advection=0.12, noise_std=0.03,
                   disturbance_prob=0.25, disturbance_strength=0.3,
                   num_vortices=3, vortex_strength=0.15, shear_strength=0.08,
                   multiscale_noise=True, num_actuators=9, device=device)
    
    @classmethod
    def medium(cls, device: str = "cpu", size: int = 64):
        return cls(height=size, width=size, device=device)


class TurbulentPlasmaEnv:
    def __init__(self, cfg: PlasmaConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dtype = cfg.dtype
        self._actuator_maps = self._build_actuator_maps()
        self._vortex_flow = self._build_vortex_flow()
        self._shear_flow = self._build_shear_flow()
    
    def _build_actuator_maps(self) -> torch.Tensor:
        h, w = self.cfg.height, self.cfg.width
        grid_n = int(math.ceil(math.sqrt(self.cfg.num_actuators)))
        centers_y = torch.linspace(h * 0.2, h * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers_x = torch.linspace(w * 0.2, w * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers = torch.cartesian_prod(centers_y, centers_x)[:self.cfg.num_actuators]
        sig2 = self.cfg.actuator_sigma ** 2
        yy, xx = torch.meshgrid(
            torch.arange(h, device=self.device, dtype=self.dtype),
            torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        bumps = [torch.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2 * sig2)) for cy, cx in centers]
        return torch.stack(bumps, dim=0)
    
    def _build_vortex_flow(self):
        if self.cfg.num_vortices == 0: return None, None
        h, w = self.cfg.height, self.cfg.width
        torch.manual_seed(42)  # Reproducible vortex positions (side effect: reseeds the global RNG)
        centers_y = torch.rand(self.cfg.num_vortices, device=self.device) * (h * 0.6) + (h * 0.2)
        centers_x = torch.rand(self.cfg.num_vortices, device=self.device) * (w * 0.6) + (w * 0.2)
        yy, xx = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                                torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        vy, vx = torch.zeros_like(yy), torch.zeros_like(xx)
        scale = min(h, w) / 64.0
        for i in range(self.cfg.num_vortices):
            dy, dx = yy - centers_y[i], xx - centers_x[i]
            r2 = dy**2 + dx**2 + 1e-6
            decay = torch.exp(-r2 / (2 * (10 * scale)**2))
            sign = 1 if i % 2 == 0 else -1
            vy += sign * self.cfg.vortex_strength * (-dx) / torch.sqrt(r2) * decay
            vx += sign * self.cfg.vortex_strength * dy / torch.sqrt(r2) * decay
        return vy, vx
    
    def _build_shear_flow(self):
        if self.cfg.shear_strength == 0: return None, None
        h, w = self.cfg.height, self.cfg.width
        yy, _ = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                               torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        return torch.zeros_like(yy), self.cfg.shear_strength * torch.sin(2 * math.pi * yy / h)
    
    def _apply_flow(self, field, vy, vx):
        B, C, H, W = field.shape
        yy, xx = torch.meshgrid(torch.linspace(-1, 1, H, device=self.device),
                                torch.linspace(-1, 1, W, device=self.device), indexing="ij")
        grid = torch.stack([xx - vx/(W/2), yy - vy/(H/2)], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
        return F.grid_sample(field, grid, mode='bilinear', padding_mode='border', align_corners=True)
    
    def _multiscale_noise(self, shape):
        B, C, H, W = shape
        noise = torch.zeros(shape, device=self.device, dtype=self.dtype)
        for scale in [1, 2, 4, 8]:
            hs, ws = H // scale, W // scale
            if hs < 4: continue
            coarse = torch.randn(B, C, hs, ws, device=self.device, dtype=self.dtype)
            noise += F.interpolate(coarse, size=(H, W), mode='bilinear', align_corners=False) * (self.cfg.noise_std / scale)
        return noise
    
    def reset(self, batch_size=1):
        h, w = self.cfg.height, self.cfg.width
        cx = torch.randint(int(w*0.3), int(w*0.7), (batch_size,), device=self.device)
        cy = torch.randint(int(h*0.3), int(h*0.7), (batch_size,), device=self.device)
        yy, xx = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                                torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        sig2 = (self.cfg.actuator_sigma * 1.5) ** 2
        field = torch.zeros(batch_size, 1, h, w, device=self.device, dtype=self.dtype)
        for b in range(batch_size):
            field[b, 0] = torch.exp(-((yy - cy[b].float())**2 + (xx - cx[b].float())**2) / (2*sig2))
        return field
    
    def step(self, field, control, noise=True):
        B = field.shape[0]
        # Diffusion
        lap = (F.pad(field, (0,0,1,0))[:,:,:-1,:] + F.pad(field, (0,0,0,1))[:,:,1:,:] +
               F.pad(field, (1,0,0,0))[:,:,:,:-1] + F.pad(field, (0,1,0,0))[:,:,:,1:]) - 4*field
        diffused = field + self.cfg.diffusion * lap
        # Advection
        advected = torch.roll(diffused, shifts=(1,-1), dims=(2,3)) * self.cfg.advection + diffused * (1 - self.cfg.advection)
        if self._vortex_flow[0] is not None:
            advected = self._apply_flow(advected, self._vortex_flow[0], self._vortex_flow[1])
        if self._shear_flow[0] is not None:
            advected = self._apply_flow(advected, self._shear_flow[0], self._shear_flow[1])
        # Actuators
        force = torch.einsum('ba,ahw->bhw', control, self._actuator_maps).unsqueeze(1)
        next_field = advected + force
        # Noise
        if noise:
            next_field = next_field + (self._multiscale_noise(next_field.shape) if self.cfg.multiscale_noise 
                                       else torch.randn_like(next_field) * self.cfg.noise_std)
        return torch.clamp(next_field, -1, 1)


def generate_trajectories_no_control(env, num_traj, traj_len, control_scale=0.1, horizon=1):
    """Generate data WITHOUT returning control - pure prediction task."""
    all_fields, all_next = [], []
    for _ in range(num_traj):
        field = env.reset(1)
        traj = [field.clone()]
        for _ in range(traj_len + horizon):
            # Control still affects physics, but NOT passed to model
            ctrl = torch.clamp(torch.randn(1, env.cfg.num_actuators, device=env.device) * control_scale, -1, 1)
            traj.append(env.step(field, ctrl).clone())
            field = traj[-1].detach()
        for t in range(traj_len):
            all_fields.append(traj[t])
            all_next.append(traj[t + horizon])
    return torch.cat(all_fields), torch.cat(all_next)

print("Environment defined.")

## 2. Attention Modules (from 033)

In [None]:
# ============================================================================
# TEMPORAL ATTENTION: Per-position Top-K (object permanence)
# Each (i,j) attends to its OWN history at (i,j,t'<t)
# ============================================================================

class PerPositionTemporalAttention(nn.Module):
    """Per-position temporal attention with Top-K selection.
    
    Key insight: Each spatial position tracks its own history,
    NOT a pooled summary. This enables proper object tracking.
    """
    
    def __init__(self, feature_dim: int, attn_dim: int, top_k: int = 4, 
                 decay_rate: float = 0.95, device: str = 'cpu'):
        super().__init__()
        self.feature_dim = feature_dim
        self.attn_dim = attn_dim
        self.top_k = top_k
        self.decay_rate = decay_rate
        self.device = device
        
        self.W_q = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_k = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_v = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_o = nn.Linear(attn_dim, feature_dim, bias=False)
        self.to(device)
    
    def forward(self, query: torch.Tensor, history: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: Current features [B, H, W, D]
            history: Past features [B, T, H, W, D]
        Returns:
            output: Attended features [B, H, W, D]
        """
        B, H, W, D = query.shape
        T = history.shape[1]
        
        if T == 0:
            return torch.zeros_like(query)
        
        # Project
        Q = self.W_q(query)  # [B, H, W, attn_dim]
        K = self.W_k(history)  # [B, T, H, W, attn_dim]
        V = self.W_v(history)  # [B, T, H, W, attn_dim]
        
        # Reshape for per-position attention
        Q_flat = Q.reshape(B * H * W, self.attn_dim)  # [BHW, attn_dim]
        K_flat = K.permute(0, 2, 3, 1, 4).reshape(B * H * W, T, self.attn_dim)  # [BHW, T, attn_dim]
        V_flat = V.permute(0, 2, 3, 1, 4).reshape(B * H * W, T, self.attn_dim)  # [BHW, T, attn_dim]
        
        # Attention scores
        scores = torch.bmm(Q_flat.unsqueeze(1), K_flat.transpose(1, 2)).squeeze(1)  # [BHW, T]
        scores = scores / math.sqrt(self.attn_dim)
        
        # Top-K selection
        if self.top_k < T:
            _, topk_idx = torch.topk(scores, self.top_k, dim=-1)
            mask = torch.ones_like(scores, dtype=torch.bool)
            mask.scatter_(-1, topk_idx, False)
            scores = scores.masked_fill(mask, float('-inf'))
        
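        # Temporal decay (history is ordered oldest-first, so index T-1 is the newest
        # step and gets the mildest penalty, decay_rate ** 1)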
        time_offsets = torch.arange(T, device=query.device, dtype=torch.float32)
        decay_weights = self.decay_rate ** (T - time_offsets)
        scores = scores + torch.log(decay_weights.unsqueeze(0) + 1e-10)
        
        # Softmax and attend
        attn = F.softmax(scores, dim=-1)  # [BHW, T]
        output = torch.bmm(attn.unsqueeze(1), V_flat).squeeze(1)  # [BHW, attn_dim]
        output = self.W_o(output)
        
        return output.reshape(B, H, W, D)


# ============================================================================
# NEIGHBOR ATTENTION: 3x3 local spatial (local physics)
# ============================================================================

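class NeighborAttention(nn.Module):
    """3x3 neighborhood attention (self + 8-connected neighbors) for local physics.
    
    Models diffusion, collision, wavefront propagation.
    Like a learned convolution but with attention weights.
    Note: `layer_range` is stored but unused; the neighborhood is fixed at 3x3.
    """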
    
    def __init__(self, feature_dim: int, attn_dim: int, layer_range: int = 3,
                 device: str = 'cpu'):
        super().__init__()
        self.feature_dim = feature_dim
        self.attn_dim = attn_dim
        self.layer_range = layer_range
        self.device = device
        
        self.W_q = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_k = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_v = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_o = nn.Linear(attn_dim, feature_dim, bias=False)
        self.to(device)
    
    def forward(self, query: torch.Tensor, history: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: Current features [B, H, W, D]
            history: Past features [B, T, H, W, D]
        Returns:
            output: Attended features [B, H, W, D]
        """
        B, H, W, D = query.shape
        T = history.shape[1]
        
        if T == 0:
            return torch.zeros_like(query)
        
        # Use most recent frame for neighbor attention
        key_frame = history[:, -1]  # [B, H, W, D]
        
        Q = self.W_q(query)  # [B, H, W, attn_dim]
        K = self.W_k(key_frame)  # [B, H, W, attn_dim]
        V = self.W_v(key_frame)  # [B, H, W, attn_dim]
        
        # Pad for boundary handling
        K_padded = F.pad(K.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='replicate').permute(0, 2, 3, 1)
        V_padded = F.pad(V.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='replicate').permute(0, 2, 3, 1)
        
        # 9 neighbors (self + 8-connected)
        neighbor_offsets = [
            (-1, -1), (-1, 0), (-1, 1),
            (0, -1),  (0, 0),  (0, 1),
            (1, -1),  (1, 0),  (1, 1)
        ]
        
        K_neighbors = torch.zeros(B, H, W, 9, self.attn_dim, device=query.device)
        V_neighbors = torch.zeros(B, H, W, 9, self.attn_dim, device=query.device)
        
        for n_idx, (di, dj) in enumerate(neighbor_offsets):
            K_neighbors[:, :, :, n_idx, :] = K_padded[:, 1+di:H+1+di, 1+dj:W+1+dj, :]
            V_neighbors[:, :, :, n_idx, :] = V_padded[:, 1+di:H+1+di, 1+dj:W+1+dj, :]
        
        # Compute attention
        scores = torch.einsum('bhwd,bhwnd->bhwn', Q, K_neighbors) / math.sqrt(self.attn_dim)
        attn = F.softmax(scores, dim=-1)  # [B, H, W, 9]
        output = torch.einsum('bhwn,bhwnd->bhwd', attn, V_neighbors)
        
        return self.W_o(output)


# ============================================================================
# WORMHOLE ATTENTION: Sparse similarity-gated non-local (teleportation)
# ============================================================================

class WormholeAttention(nn.Module):
    """Sparse non-local attention via cosine similarity gating.
    
    Connects distant tokens based on content similarity.
    Enables instant global synchronization and resonance.
    """
    
    def __init__(self, feature_dim: int, attn_dim: int, threshold: float = 0.9995,
                 max_connections: int = 4, device: str = 'cpu'):
        super().__init__()
        self.feature_dim = feature_dim
        self.attn_dim = attn_dim
        self.threshold = threshold
        self.max_connections = max_connections
        self.device = device
        
        self.W_q = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_k = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_v = nn.Linear(feature_dim, attn_dim, bias=False)
        self.W_o = nn.Linear(attn_dim, feature_dim, bias=False)
        self.to(device)
    
    def forward(self, query: torch.Tensor, history: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query: Current features [B, H, W, D]
            history: Past features [B, T, H, W, D]
        Returns:
            output: Attended features [B, H, W, D]
        """
        B, H, W, D = query.shape
        T = history.shape[1]
        
        if T == 0:
            return torch.zeros_like(query)
        
        # Flatten spatial dimensions
        Q = self.W_q(query).reshape(B, H * W, self.attn_dim)  # [B, HW, attn_dim]
        K = self.W_k(history).reshape(B, T * H * W, self.attn_dim)  # [B, THW, attn_dim]
        V = self.W_v(history).reshape(B, T * H * W, self.attn_dim)  # [B, THW, attn_dim]
        
        # Normalize for cosine similarity
        Q_norm = F.normalize(Q, p=2, dim=-1)
        K_norm = F.normalize(K, p=2, dim=-1)
        
        # Cosine similarity (dense [B, HW, THW] matrix; memory grows with T * (H*W)^2 per sample)
        sim = torch.bmm(Q_norm, K_norm.transpose(1, 2))  # [B, HW, THW]
        
        # Top-K selection
        K_conn = min(self.max_connections, T * H * W)
        topk_sim, topk_idx = torch.topk(sim, K_conn, dim=-1)  # [B, HW, K]
        
        # Threshold mask
        mask = topk_sim > self.threshold
        
        # Gather values
        V_gathered = torch.gather(
            V.unsqueeze(1).expand(-1, H * W, -1, -1),
            2,
            topk_idx.unsqueeze(-1).expand(-1, -1, -1, self.attn_dim)
        )  # [B, HW, K, attn_dim]
        
        # Attention scores with threshold mask
        scores = topk_sim / math.sqrt(self.attn_dim)
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)  # [B, HW, K]
        
        # Handle positions with no valid connections
        valid_rows = mask.any(dim=-1)  # [B, HW]
        attn = torch.where(
            valid_rows.unsqueeze(-1),
            attn,
            torch.zeros_like(attn)
        )
        
        output = torch.einsum('bhk,bhkd->bhd', attn, V_gathered)  # [B, HW, attn_dim]
        output = self.W_o(output)
        
        return output.reshape(B, H, W, D)

print("Attention modules defined.")

## 3. Spectral Band Masks

In [None]:
def make_radial_masks(h: int, w: int, num_bands: int, device: str) -> torch.Tensor:
    """Create radial frequency masks for spectral decomposition.
    
    Args:
        h, w: Grid dimensions
        num_bands: Number of frequency bands (2, 3, or 7)
        device: Torch device
    Returns:
        masks: [num_bands, H, W] normalized masks
    """
    yy, xx = torch.meshgrid(
        torch.linspace(-1, 1, h, device=device),
        torch.linspace(-1, 1, w, device=device), indexing="ij")
    rr = torch.sqrt(yy**2 + xx**2).clamp(min=1e-6)
    
    # Different edge spacing for different band counts
    if num_bands == 2:
        # Simple low/high split at 0.5
        edges = torch.tensor([0.0, 0.5, math.sqrt(2)], device=device)
    elif num_bands == 3:
        # Low/mid/high (like 033 experiment)
        edges = torch.tensor([0.0, 0.3, 0.7, math.sqrt(2)], device=device)
    else:
        # Log-spaced for 7 bands (original SBM)
        edges = torch.logspace(-3, math.log10(math.sqrt(2)), num_bands + 1, device=device)
        edges[0] = 0.0
    
    masks = []
    for i in range(num_bands):
        lo, hi = edges[i], edges[i + 1]
        mask = torch.sigmoid((rr - lo) * 20) * torch.sigmoid((hi - rr) * 20)
        masks.append(mask)
    
    masks = torch.stack(masks)
    return masks / (masks.sum(0, keepdim=True) + 1e-8)


# Visualize masks
import matplotlib.pyplot as plt

fig, axes = plt.subplots(3, 7, figsize=(14, 6))

for row, (num_bands, name) in enumerate([(2, '2-band'), (3, '3-band'), (7, '7-band')]):
    masks = make_radial_masks(64, 64, num_bands, 'cpu')
    for i in range(7):
        if i < num_bands:
            axes[row, i].imshow(masks[i].numpy(), cmap='viridis')
            axes[row, i].set_title(f'{name} band {i}')
        else:
            axes[row, i].axis('off')
        axes[row, i].set_xticks([])
        axes[row, i].set_yticks([])

plt.suptitle('Spectral Band Masks Comparison')
plt.tight_layout()
plt.show()

print("Mask generation defined.")

## 4. SBM Models with Configurable Attention

In [None]:
@dataclass
class SBMConfig:
    height: int = 64
    width: int = 64
    num_bands: int = 3
    channels: int = 16
    attn_dim: int = 32
    history_len: int = 8
    # Attention flags
    use_temporal: bool = False
    use_neighbor: bool = False
    use_wormhole: bool = False
    # Temporal config
    top_k: int = 4
    decay: float = 0.95
    # Neighbor config
    neighbor_range: int = 3
    # Wormhole config
    wormhole_threshold: float = 0.9995
    wormhole_max_conn: int = 4
    device: str = "cpu"


class SBMWithAttention(nn.Module):
    """Spectral Belief Machine with configurable attention types.
    
    NO CONTROL SIGNAL - pure prediction like winning 032 experiment.
    
    Supports:
    - Variable band count (2, 3, or 7)
    - Per-position temporal attention (from 033)
    - Neighbor attention (from 033)
    - Wormhole attention (from 033)
    """
    
    def __init__(self, cfg: SBMConfig):
        super().__init__()
        self.cfg = cfg
        
        # Spectral masks
        self.register_buffer("masks", make_radial_masks(
            cfg.height, cfg.width, cfg.num_bands, cfg.device))
        self.register_buffer("window", torch.ones(cfg.height, cfg.width, device=cfg.device))
        
        # Band processing blocks: input = 2 (real/imag)
        self.bands = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, 2, 1),
            ) for _ in range(cfg.num_bands)
        ])
        
        # Nominal pooled feature size (2 channels x 4x4 per band); not used in forward()
        self.feat_dim = cfg.num_bands * 2 * 4 * 4
        
        # Optional attention modules
        if cfg.use_temporal:
            self.temporal_attn = PerPositionTemporalAttention(
                feature_dim=2,  # Per band: real/imag
                attn_dim=cfg.attn_dim,
                top_k=cfg.top_k,
                decay_rate=cfg.decay,
                device=cfg.device
            )
        
        if cfg.use_neighbor:
            self.neighbor_attn = NeighborAttention(
                feature_dim=2,
                attn_dim=cfg.attn_dim,
                layer_range=cfg.neighbor_range,
                device=cfg.device
            )
        
        if cfg.use_wormhole:
            self.wormhole_attn = WormholeAttention(
                feature_dim=2,
                attn_dim=cfg.attn_dim,
                threshold=cfg.wormhole_threshold,
                max_connections=cfg.wormhole_max_conn,
                device=cfg.device
            )
        
        # Fusion layer if using attention
        num_attn = sum([cfg.use_temporal, cfg.use_neighbor, cfg.use_wormhole])
        if num_attn > 0:
            self.fusion = nn.Sequential(
                nn.Linear(2 * (1 + num_attn), cfg.attn_dim),
                nn.GELU(),
                nn.Linear(cfg.attn_dim, 2)
            )
        
        self.to(cfg.device)
    
    def forward(self, x: torch.Tensor, history: Optional[Dict] = None) -> Tuple[torch.Tensor, Dict]:
        """
        Args:
            x: Input field [B, 1, H, W]
            history: Dict with 'bands': [B, T, num_bands, H, W, 2]
        Returns:
            pred: Predicted field [B, 1, H, W]
            info: Dict with features for history
        """
        B, _, H, W = x.shape
        
        # FFT decompose
        fft = torch.fft.fftshift(torch.fft.fft2(x.squeeze(1) * self.window))
        
        proc_bands = []
        current_band_feats = []  # For history
        
        for i in range(self.cfg.num_bands):
            band = fft * self.masks[i].unsqueeze(0)
            band_feat = torch.stack([band.real, band.imag], dim=1)  # [B, 2, H, W]
            
            # Apply band processing
            proc = self.bands[i](band_feat)
            processed = band_feat + proc  # Residual
            
            # Apply attention if enabled and history available
            if history is not None and 'bands' in history and history['bands'].shape[1] > 0:
                band_history = history['bands'][:, :, i]  # [B, T, H, W, 2]
                current_as_feat = processed.permute(0, 2, 3, 1)  # [B, H, W, 2]
                
                attn_outputs = [current_as_feat]
                
                if self.cfg.use_temporal:
                    t_out = self.temporal_attn(current_as_feat, band_history)
                    attn_outputs.append(t_out)
                
                if self.cfg.use_neighbor:
                    n_out = self.neighbor_attn(current_as_feat, band_history)
                    attn_outputs.append(n_out)
                
                if self.cfg.use_wormhole:
                    w_out = self.wormhole_attn(current_as_feat, band_history)
                    attn_outputs.append(w_out)
                
                if len(attn_outputs) > 1:
                    # Fuse attention outputs
                    combined = torch.cat(attn_outputs, dim=-1)  # [B, H, W, 2*(1+num_attn)]
                    fused = self.fusion(combined)  # [B, H, W, 2]
                    processed = processed + 0.1 * fused.permute(0, 3, 1, 2)  # Add attention contribution
            
            proc_bands.append(processed)
            current_band_feats.append(processed.permute(0, 2, 3, 1))  # [B, H, W, 2]
        
        # Stack band features for history
        band_feats = torch.stack(current_band_feats, dim=1)  # [B, num_bands, H, W, 2]
        
        # Reconstruct
        recon = sum(
            torch.complex(b[:, 0], b[:, 1]) * self.masks[i].unsqueeze(0)
            for i, b in enumerate(proc_bands)
        )
        pred = torch.fft.ifft2(torch.fft.ifftshift(recon)).real.unsqueeze(1)
        
        return pred, {'bands': band_feats.detach()}


class FlatBaseline(nn.Module):
    """Flat ConvNet baseline - no spectral decomposition."""
    
    def __init__(self, height: int, width: int, channels: int = 32, device: str = 'cpu'):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, 1, 3, padding=1),
        )
        self.to(device)
    
    def forward(self, x, history=None):
        return self.net(x), {}


class FlatWithAttention(nn.Module):
    """Flat ConvNet with temporal and neighbor attention."""
    
    def __init__(self, height: int, width: int, channels: int = 32, 
                 attn_dim: int = 32, history_len: int = 8, device: str = 'cpu'):
        super().__init__()
        self.device = device
        self.channels = channels
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, channels, 3, padding=1), nn.GELU(),
        )
        
        # Attention
        self.temporal_attn = PerPositionTemporalAttention(
            feature_dim=channels, attn_dim=attn_dim, device=device)
        self.neighbor_attn = NeighborAttention(
            feature_dim=channels, attn_dim=attn_dim, device=device)
        
        # Fusion
        self.fusion = nn.Sequential(
            nn.Linear(channels * 3, channels),
            nn.GELU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, 1, 3, padding=1),
        )
        
        self.to(device)
    
    def forward(self, x, history=None):
        B, _, H, W = x.shape
        
        # Encode
        feat = self.encoder(x)  # [B, C, H, W]
        feat_hwc = feat.permute(0, 2, 3, 1)  # [B, H, W, C]
        
        # Apply attention if history available
        if history is not None and 'feat' in history and history['feat'].shape[1] > 0:
            hist_feat = history['feat']  # [B, T, H, W, C]
            
            t_out = self.temporal_attn(feat_hwc, hist_feat)
            n_out = self.neighbor_attn(feat_hwc, hist_feat)
            
            combined = torch.cat([feat_hwc, t_out, n_out], dim=-1)
            fused = self.fusion(combined)  # [B, H, W, C]
            feat = feat + 0.1 * fused.permute(0, 3, 1, 2)
        
        # Decode
        pred = self.decoder(feat)
        
        return pred, {'feat': feat.permute(0, 2, 3, 1).detach()}

print("Models defined.")

## 5. Training

**Key v4.1 Fix**: Drop the last incomplete batch so every batch has the same shape.
This is critical for wormhole attention, which needs persistent history.
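
The reason for the fix can be sketched with plain lists standing in for tensors. This is a simplified model of the rolling buffer in `train()` below, which uses `torch.cat(..., dim=1)[:, -history_len:]`; the helper `update_history` is hypothetical, and its `ValueError` stands in for the shape error `torch.cat` raises when batch sizes differ.

```python
def update_history(history, new_batch, history_len=8):
    """Append one batch of features; keep only the last `history_len` steps.

    `history` is a list of steps, each a list with one entry per batch sample.
    """
    if history and len(history[-1]) != len(new_batch):
        # Stand-in for the shape mismatch torch.cat would hit on a smaller final batch
        raise ValueError(f"batch size changed: {len(history[-1])} -> {len(new_batch)}")
    return (history + [new_batch])[-history_len:]

history = []
for step in range(12):
    history = update_history(history, [step] * 32)   # 12 full batches of 32
print(len(history), history[0][0], history[-1][0])   # 8 4 11

try:
    update_history(history, [12] * 16)               # incomplete final batch
except ValueError as err:
    print(err)
```

Dropping the incomplete batch avoids the second case entirely, at the cost of leaving a few samples unused each epoch.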

In [None]:
def train(model, fields, targets, epochs, lr, device, delta=True, batch=32, history_len=8):
    """Train model with proper history management.
    
    v4.1: Drop incomplete batches to maintain consistent tensor shapes.
    This is important for wormhole attention which builds history across batches.
    """
    fields = fields.to(device)
    targets = targets.to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    losses = []
    
    # Calculate number of complete batches (drop incomplete last batch)
    num_samples = len(fields)
    num_complete_batches = num_samples // batch
    usable_samples = num_complete_batches * batch
    
    print(f"    Total samples: {num_samples}, Using: {usable_samples} ({num_complete_batches} batches of {batch})")
    
    for ep in range(epochs):
        t0 = time.time()
        ep_loss = []
        
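        # Shuffle only the first `usable_samples` samples; the trailing remainder
        # (num_samples - usable_samples) is never used for training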
        perm = torch.randperm(usable_samples)
        
        # Reset history at start of each epoch (fresh start)
        history = None
        
        for batch_idx in range(num_complete_batches):
            i = batch_idx * batch
            idx = perm[i:i+batch]
            x, y = fields[idx], targets[idx]
            
            pred, info = model(x, history)
            
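            # Note: for MSE the delta form is algebraically identical to the direct form,
            # since (pred - x) - (y - x) == pred - y; `delta` only affects float rounding.
            loss = F.mse_loss(pred - x, y - x) if delta else F.mse_loss(pred, y)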
            opt.zero_grad()
            loss.backward()
            opt.step()
            ep_loss.append(loss.item())
            
            # Update history (rolling buffer)
            if 'bands' in info and info['bands'] is not None:
                new_feat = info['bands'].unsqueeze(1)  # [B, 1, num_bands, H, W, 2]
                if history is None or 'bands' not in history:
                    history = {'bands': new_feat}
                else:
                    # Append and keep last history_len entries
                    history['bands'] = torch.cat([history['bands'], new_feat], dim=1)[:, -history_len:]
            elif 'feat' in info and info['feat'] is not None:
                new_feat = info['feat'].unsqueeze(1)  # [B, 1, H, W, C]
                if history is None or 'feat' not in history:
                    history = {'feat': new_feat}
                else:
                    # Append and keep last history_len entries
                    history['feat'] = torch.cat([history['feat'], new_feat], dim=1)[:, -history_len:]
        
        avg = sum(ep_loss) / len(ep_loss)
        losses.append(avg)
        print(f"  Ep {ep+1}/{epochs}: loss={avg:.6f}, time={time.time()-t0:.1f}s")
    
    return losses

print("Training function defined.")

## 6. Run Full Ablation

In [None]:
# Setup
print("[1] Creating environment...")
plasma_cfg = PlasmaConfig.turbulent(device=DEVICE, size=GRID_SIZE)
env = TurbulentPlasmaEnv(plasma_cfg)

print("\n[2] Generating data (NO CONTROL SIGNAL)...")
fields, targets = generate_trajectories_no_control(
    env, NUM_TRAJECTORIES, TRAJECTORY_LENGTH, 0.1, PREDICTION_HORIZON)
print(f"    Fields: {fields.shape}, Targets: {targets.shape}")

In [None]:
# Create all models for ablation
print("\n[3] Creating models...")

models = {}

# ============================================================================
# BAND COUNT ABLATION (no attention)
# ============================================================================
for num_bands in [2, 3, 7]:
    cfg = SBMConfig(
        height=GRID_SIZE, width=GRID_SIZE, num_bands=num_bands,
        use_temporal=False, use_neighbor=False, use_wormhole=False,
        device=DEVICE
    )
    models[f'SBM_{num_bands}B_None'] = SBMWithAttention(cfg)

# ============================================================================
# ATTENTION TYPE ABLATION (3-band base)
# ============================================================================
attention_configs = [
    ('Temporal', True, False, False),
    ('Neighbor', False, True, False),
    ('Wormhole', False, False, True),
    ('TempNeigh', True, True, False),
    ('Full', True, True, True),
]

for name, use_t, use_n, use_w in attention_configs:
    cfg = SBMConfig(
        height=GRID_SIZE, width=GRID_SIZE, num_bands=3,
        use_temporal=use_t, use_neighbor=use_n, use_wormhole=use_w,
        history_len=HISTORY_LEN, top_k=TOP_K_TEMPORAL, decay=TEMPORAL_DECAY,
        neighbor_range=NEIGHBOR_RANGE, wormhole_threshold=WORMHOLE_THRESHOLD,
        wormhole_max_conn=WORMHOLE_MAX_CONN, device=DEVICE
    )
    models[f'SBM_3B_{name}'] = SBMWithAttention(cfg)

# ============================================================================
# 7-BAND WITH ATTENTION
# ============================================================================
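# Pass the same attention hyperparameters as the 3-band runs so that band count
# is the only variable that changes between the 3B and 7B variants.
cfg = SBMConfig(
    height=GRID_SIZE, width=GRID_SIZE, num_bands=7,
    use_temporal=True, use_neighbor=False, use_wormhole=False,
    history_len=HISTORY_LEN, top_k=TOP_K_TEMPORAL, decay=TEMPORAL_DECAY,
    device=DEVICE
)
models['SBM_7B_Temporal'] = SBMWithAttention(cfg)

cfg = SBMConfig(
    height=GRID_SIZE, width=GRID_SIZE, num_bands=7,
    use_temporal=True, use_neighbor=True, use_wormhole=True,
    history_len=HISTORY_LEN, top_k=TOP_K_TEMPORAL, decay=TEMPORAL_DECAY,
    neighbor_range=NEIGHBOR_RANGE, wormhole_threshold=WORMHOLE_THRESHOLD,
    wormhole_max_conn=WORMHOLE_MAX_CONN, device=DEVICE
)
models['SBM_7B_Full'] = SBMWithAttention(cfg)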

# ============================================================================
# BASELINES
# ============================================================================
models['Flat_Baseline'] = FlatBaseline(GRID_SIZE, GRID_SIZE, 32, DEVICE)
models['Flat_WithAttn'] = FlatWithAttention(GRID_SIZE, GRID_SIZE, 32, 32, HISTORY_LEN, DEVICE)

# Print model info
print("\nModels created:")
for name, model in models.items():
    params = sum(p.numel() for p in model.parameters())
    print(f"    {name:20s}: {params:>8,} params")

In [None]:
# Train all models
results = {}

for name, model in models.items():
    print(f"\n{'='*60}")
    print(f"Training {name}...")
    print(f"{'='*60}")
    results[name] = train(model, fields, targets, EPOCHS, LR, DEVICE, 
                          PREDICT_DELTA, BATCH_SIZE, HISTORY_LEN)

## 7. Results Analysis

In [None]:
import matplotlib.pyplot as plt

# Group models by category
band_models = ['SBM_2B_None', 'SBM_3B_None', 'SBM_7B_None']
attn_models = ['SBM_3B_None', 'SBM_3B_Temporal', 'SBM_3B_Neighbor', 'SBM_3B_Wormhole', 'SBM_3B_TempNeigh', 'SBM_3B_Full']
seven_band_models = ['SBM_7B_None', 'SBM_7B_Temporal', 'SBM_7B_Full']
flat_models = ['Flat_Baseline', 'Flat_WithAttn']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Band count comparison
ax = axes[0, 0]
for name in band_models:
    ax.semilogy(results[name], label=f"{name}: {results[name][-1]:.6f}")
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('Band Count Ablation (No Attention)')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 2. Attention type comparison (3-band)
ax = axes[0, 1]
for name in attn_models:
    ax.semilogy(results[name], label=f"{name.replace('SBM_3B_', '')}: {results[name][-1]:.6f}")
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('Attention Type Ablation (3-Band)')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 3. 7-band comparison
ax = axes[1, 0]
for name in seven_band_models:
    ax.semilogy(results[name], label=f"{name.replace('SBM_7B_', '7B_')}: {results[name][-1]:.6f}")
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('7-Band with Attention')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# 4. Best SBM vs Flat
ax = axes[1, 1]
# Find best SBM
sbm_final = {k: v[-1] for k, v in results.items() if k.startswith('SBM')}
best_sbm = min(sbm_final, key=sbm_final.get)
compare = [best_sbm] + flat_models
for name in compare:
    ax.semilogy(results[name], label=f"{name}: {results[name][-1]:.6f}", linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title(f'Best SBM ({best_sbm}) vs Flat Baselines')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Final results table
print("\n" + "="*70)
print("FULL ABLATION RESULTS (No Control Signal - Pure Prediction)")
print("="*70)

# Sort by final loss
sorted_results = sorted(results.items(), key=lambda x: x[1][-1])

# Column widths: 25 + 12 + 12 + 10 plus three separating spaces = 62
print("\n{:25s} {:>12s} {:>12s} {:>10s}".format('Model', 'Final Loss', 'Improvement', 'Rank'))
print("-"*62)
for rank, (name, losses) in enumerate(sorted_results, 1):
    improvement = losses[0] / losses[-1]  # first-epoch loss divided by final loss
    print(f"{name:25s} {losses[-1]:12.6f} {improvement:11.1f}x {rank:>10d}")

In [None]:
# Detailed analysis
print("\n" + "="*70)
print("KEY COMPARISONS")
print("="*70)

# 1. Band count effect
print("\n1. BAND COUNT EFFECT (No Attention):")
for name in band_models:
    print(f"   {name}: {results[name][-1]:.6f}")
best_bands = min(band_models, key=lambda x: results[x][-1])
print(f"   Winner: {best_bands}")

# 2. Attention type effect
print("\n2. ATTENTION TYPE EFFECT (3-Band Base):")
base = results['SBM_3B_None'][-1]
for name in attn_models:
    loss = results[name][-1]
    diff = (base - loss) / base * 100
    print(f"   {name.replace('SBM_3B_', ''):12s}: {loss:.6f} ({diff:+.1f}% vs None)")

# 3. SBM vs Flat
print("\n3. ARCHITECTURE COMPARISON:")
sbm_final = {k: v[-1] for k, v in results.items() if k.startswith('SBM')}
flat_final = {k: v[-1] for k, v in results.items() if k.startswith('Flat')}
best_sbm = min(sbm_final, key=sbm_final.get)
best_flat = min(flat_final, key=flat_final.get)
print(f"   Best SBM:  {best_sbm} = {sbm_final[best_sbm]:.6f}")
print(f"   Best Flat: {best_flat} = {flat_final[best_flat]:.6f}")
diff = (sbm_final[best_sbm] - flat_final[best_flat]) / flat_final[best_flat] * 100
winner = "SBM" if diff < 0 else "Flat"
print(f"   Difference: {abs(diff):.1f}% ({winner} wins)")

In [None]:
# Bar chart of all results
fig, ax = plt.subplots(figsize=(14, 6))

names = [x[0] for x in sorted_results]
finals = [x[1][-1] for x in sorted_results]

# Color by category
colors = []
for n in names:
    if 'Flat' in n:
        colors.append('green')
    elif '7B' in n:
        colors.append('darkblue')
    elif '3B' in n:
        colors.append('blue')
    else:
        colors.append('lightblue')

bars = ax.bar(range(len(names)), finals, color=colors)
ax.set_xticks(range(len(names)))
ax.set_xticklabels([n.replace('SBM_', '').replace('_', '\n') for n in names], fontsize=8, rotation=45, ha='right')
ax.set_ylabel('Final Loss (MSE)')
ax.set_title('Full Ablation Results - Sorted by Performance')

for bar, val in zip(bars, finals):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{val:.4f}', 
            ha='center', va='bottom', fontsize=7, rotation=90)

# Legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='lightblue', label='2-Band SBM'),
    Patch(facecolor='blue', label='3-Band SBM'),
    Patch(facecolor='darkblue', label='7-Band SBM'),
    Patch(facecolor='green', label='Flat Baseline'),
]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()

In [None]:
# Visualize predictions from best models
print("\nVisualizing predictions from top models...")

# Get sample
idx = min(500, len(fields) - 1)  # stay in range if the dataset is small
x_sample = fields[idx:idx+1].to(DEVICE)
y_sample = targets[idx:idx+1].to(DEVICE)

# Get predictions from top 3 models
top_models = [x[0] for x in sorted_results[:3]]

fig, axes = plt.subplots(2, len(top_models) + 2, figsize=(16, 6))

# Input and target
axes[0, 0].imshow(x_sample[0, 0].cpu().numpy(), cmap='viridis')
axes[0, 0].set_title('Input')
axes[0, 1].imshow(y_sample[0, 0].cpu().numpy(), cmap='viridis')
axes[0, 1].set_title(f'Target (t+{PREDICTION_HORIZON})')
axes[1, 0].axis('off')
axes[1, 1].axis('off')

# Predictions
with torch.no_grad():
    for i, name in enumerate(top_models):
        pred, _ = models[name](x_sample, None)
        err = torch.abs(pred - y_sample)  # absolute error map for display
        mse = (pred - y_sample).pow(2).mean().item()  # squared error, to match the label
        
        axes[0, i+2].imshow(pred[0, 0].cpu().numpy(), cmap='viridis')
        axes[0, i+2].set_title(f'{name}\nPrediction')
        
        axes[1, i+2].imshow(err[0, 0].cpu().numpy(), cmap='hot')
        axes[1, i+2].set_title(f'Abs. Error\n(MSE: {mse:.4f})')

for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

plt.suptitle('Top 3 Model Predictions')
plt.tight_layout()
plt.show()

## 8. Summary

### Ablation Dimensions Tested:

1. **Band Count**: 2, 3, 7 bands
2. **Attention Types**: None, Temporal, Neighbor, Wormhole, TempNeigh, Full
3. **Architecture**: SBM vs Flat ConvNet
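
One way to keep such a grid organized is to describe each variant as a small record and derive its name from the flags. The sketch below is an assumption about how this could be structured: `AblationSpec` and `spec_name` are hypothetical and do not appear in the notebook, which builds its models with explicit `SBMConfig` calls.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AblationSpec:
    """Hypothetical record for one SBM variant in the ablation grid."""
    num_bands: int
    temporal: bool = False
    neighbor: bool = False
    wormhole: bool = False


# Map (temporal, neighbor, wormhole) flags to the suffixes used in the model table.
ATTENTION_LABELS = {
    (False, False, False): "None",
    (True, False, False): "Temporal",
    (False, True, False): "Neighbor",
    (False, False, True): "Wormhole",
    (True, True, False): "TempNeigh",
    (True, True, True): "Full",
}


def spec_name(spec: AblationSpec) -> str:
    flags = (spec.temporal, spec.neighbor, spec.wormhole)
    return f"SBM_{spec.num_bands}B_{ATTENTION_LABELS[flags]}"


specs = [AblationSpec(b) for b in (2, 3, 7)]
specs += [AblationSpec(3, *flags) for flags in ATTENTION_LABELS if any(flags)]
specs += [AblationSpec(7, temporal=True), AblationSpec(7, True, True, True)]

names = [spec_name(s) for s in specs]
print(len(names))   # 10 SBM variants; the two flat baselines are separate
print(names[:3])    # ['SBM_2B_None', 'SBM_3B_None', 'SBM_7B_None']
```

Deriving names from flags keeps the dictionary keys used for plotting in step with the configurations that were actually built.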

### Key Questions Answered:

1. Does band count matter for prediction quality?
2. Does per-position temporal attention help (vs pooled)?
3. Does neighbor attention capture local physics better than convolutions?
4. Does wormhole attention provide useful non-local information?
5. Does spectral decomposition (SBM) outperform flat spatial processing?

### Task: Pure Prediction

This experiment uses NO control signal - models must predict future state from
current state alone, exactly like the original winning 032 experiment.

The actuators still affect the physics (control is applied in simulation),
but models don't have access to control signals. They must learn to predict
despite actuator-induced uncertainty.
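
As a rough sketch of what "pure prediction" means for the dataset, the snippet below pairs each state with the state `horizon` steps later and records nothing about the control input. `make_prediction_pairs` is a hypothetical helper written for illustration; it is not the notebook's `generate_trajectories_no_control`, whose internals may differ.

```python
from typing import List, Sequence, Tuple, TypeVar

State = TypeVar("State")


def make_prediction_pairs(
    trajectory: Sequence[State], horizon: int
) -> List[Tuple[State, State]]:
    """Pair state t with state t + horizon; no control signal is stored."""
    if horizon < 1:
        raise ValueError("horizon must be at least 1")
    return [
        (trajectory[t], trajectory[t + horizon])
        for t in range(len(trajectory) - horizon)
    ]


# Strings stand in for field tensors here.
states = ["s0", "s1", "s2", "s3", "s4"]
pairs = make_prediction_pairs(states, horizon=2)
print(pairs)   # [('s0', 's2'), ('s1', 's3'), ('s2', 's4')]
```

A trajectory of length T yields T - horizon pairs, so longer horizons shrink the usable dataset.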

### v4.1 Fix

The incomplete last batch is dropped so that tensor shapes stay consistent for history
management. This ensures wormhole attention can build proper temporal context
across batches within an epoch.
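
The rolling history can be pictured as a fixed-length buffer. The sketch below tracks plain shape tuples with `collections.deque` rather than tensors; `HistoryBuffer` is a hypothetical name, and the notebook itself concatenates tensors along a time axis instead.

```python
from collections import deque
from typing import Deque, Tuple


class HistoryBuffer:
    """Keep the most recent `history_len` entries, all with one batch size."""

    def __init__(self, history_len: int) -> None:
        # deque(maxlen=...) discards the oldest entry automatically
        self.entries: Deque[Tuple[int, ...]] = deque(maxlen=history_len)

    def append(self, shape: Tuple[int, ...]) -> None:
        # Every entry must share the leading (batch) dimension, which is
        # why a shorter final batch cannot be added to the buffer.
        if self.entries and self.entries[-1][0] != shape[0]:
            raise ValueError(
                f"batch size changed from {self.entries[-1][0]} to {shape[0]}"
            )
        self.entries.append(shape)


buf = HistoryBuffer(history_len=3)
for _ in range(5):
    buf.append((32, 64, 64, 16))   # (B, H, W, C); the values are illustrative
print(len(buf.entries))            # 3: older entries were discarded

try:
    buf.append((8, 64, 64, 16))    # a partial batch
except ValueError as exc:
    print("rejected:", exc)
```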