# Experiment 032: Mini AKIRA Plasma Controller

**Purpose**: Test the 7+1 Spectral Belief Machine (SBM) architecture against baselines on a plasma field prediction task.

**Key Features**:
- 7 spectral bands (log-spaced FFT decomposition) + 1 temporal band
- Delta prediction mode (predict change, not state)
- Prediction horizon t+3 (harder task that penalizes identity copying)
- Differential learning rates per band

**Run on Colab with GPU**: Runtime -> Change runtime type -> A100

In [None]:
# Report the PyTorch build and, if one is visible, the primary CUDA device
import torch

has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Device 0 is the one the rest of the notebook will train on
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# Configuration
# -- hardware -------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -- task -----------------------------------------------------------------
GRID_SIZE = 128  # supported sizes: 64, 128, or 256
DIFFICULTY = "medium"  # one of: easy, medium, hard
PREDICT_DELTA = True
PREDICTION_HORIZON = 3

# -- data -----------------------------------------------------------------
NUM_TRAJECTORIES = 100
TRAJECTORY_LENGTH = 100

# -- optimisation ---------------------------------------------------------
EPOCHS = 100
BATCH_SIZE = 32
LR = 0.001

# Echo the effective settings so every run log is self-describing
print("\n".join([
    f"Device: {DEVICE}",
    f"Grid size: {GRID_SIZE}x{GRID_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Difficulty: {DIFFICULTY}",
    f"Prediction mode: {'delta' if PREDICT_DELTA else 'absolute'}",
    f"Prediction horizon: t+{PREDICTION_HORIZON}",
]))

## 1. Environment: Mini Plasma Simulation

In [None]:
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import time
from dataclasses import dataclass, field

@dataclass
class PlasmaConfig:
    """Settings for the mini plasma environment.

    Physics coefficients are used as-is at every resolution. The actuator
    footprint is stored as a base value for a 64x64 grid and rescaled through
    the ``actuator_sigma`` property so it covers the same fraction of the
    grid at other sizes.
    """
    height: int = 64
    width: int = 64
    # Physics (scale-independent)
    diffusion: float = 0.25
    advection: float = 0.08
    noise_std: float = 0.02
    disturbance_prob: float = 0.1
    disturbance_strength: float = 0.15
    # Spatial (base values for 64x64, will be scaled)
    num_actuators: int = 9  # 3x3 grid of actuators
    _base_actuator_sigma: float = 5.0  # Base sigma for 64x64
    # Device
    device: str = "cpu"
    dtype: torch.dtype = torch.float32

    @property
    def actuator_sigma(self) -> float:
        """Base sigma multiplied by (shorter grid side / 64)."""
        return self._base_actuator_sigma * (min(self.height, self.width) / 64.0)

    @classmethod
    def easy(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Square grid with weak dynamics, low noise and no disturbances."""
        overrides = dict(
            diffusion=0.12,
            advection=0.02,
            noise_std=0.01,
            disturbance_prob=0.0,
            disturbance_strength=0.0,
        )
        return cls(height=size, width=size, device=device, **overrides)

    @classmethod
    def medium(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Square grid using the dataclass default physics."""
        return cls(height=size, width=size, device=device)

    @classmethod
    def hard(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Square grid with strong dynamics, more noise and frequent disturbances."""
        overrides = dict(
            diffusion=0.4,
            advection=0.15,
            noise_std=0.05,
            disturbance_prob=0.2,
            disturbance_strength=0.25,
        )
        return cls(height=size, width=size, device=device, **overrides)


class MiniPlasmaEnv:
    """A toy 2D plasma-like environment.

    The state is a batch of scalar fields shaped ``[B, 1, H, W]``. One
    ``step`` applies, in order: explicit diffusion (5-point Laplacian with
    zero padding at the borders), a diagonal advection blend (periodic, via
    ``torch.roll``), actuator forcing (Gaussian bumps weighted by the control
    vector), optional Gaussian noise, optional random disturbances, and a
    final clamp to ``[-1, 1]``.

    Disturbances are sampled independently for every batch element so that
    trajectories generated in one batch are not all hit by the same bump at
    the same time step.
    """

    def __init__(self, cfg: "PlasmaConfig"):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dtype = cfg.dtype
        # Pixel-coordinate grids are shared by reset/step/actuator maps,
        # so build them once instead of on every call.
        self._yy, self._xx = torch.meshgrid(
            torch.arange(cfg.height, device=self.device, dtype=self.dtype),
            torch.arange(cfg.width, device=self.device, dtype=self.dtype),
            indexing="ij",
        )
        self._actuator_maps = self._build_actuator_maps()

    def _gaussian_bumps(self, cy: torch.Tensor, cx: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return ``[K, H, W]`` unit-height Gaussians centred at ``(cy[k], cx[k])``."""
        cy = cy.to(self.dtype).view(-1, 1, 1)
        cx = cx.to(self.dtype).view(-1, 1, 1)
        dist2 = (self._yy - cy) ** 2 + (self._xx - cx) ** 2
        return torch.exp(-dist2 / (2 * sigma ** 2))

    def _build_actuator_maps(self) -> torch.Tensor:
        """Return ``[num_actuators, H, W]`` Gaussian footprints on a square lattice.

        Centres are laid out on the smallest ``n x n`` lattice that holds
        ``num_actuators`` points, spanning 20%-80% of each axis; surplus
        lattice points are dropped.
        """
        h, w = self.cfg.height, self.cfg.width
        grid_n = int((self.cfg.num_actuators) ** 0.5)
        if grid_n * grid_n < self.cfg.num_actuators:
            grid_n += 1
        centers_y = torch.linspace(h * 0.2, h * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers_x = torch.linspace(w * 0.2, w * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers = torch.cartesian_prod(centers_y, centers_x)[:self.cfg.num_actuators]
        return self._gaussian_bumps(centers[:, 0], centers[:, 1], self.cfg.actuator_sigma)

    def reset(self, batch_size: int = 1) -> torch.Tensor:
        """Return ``[batch_size, 1, H, W]`` fields, each a single Gaussian blob.

        Blob centres are integer pixels drawn uniformly from the central
        30%-70% of each axis; the blob width is 1.5x the actuator sigma.
        """
        h, w = self.cfg.height, self.cfg.width
        cy = torch.randint(int(h * 0.3), int(h * 0.7), (batch_size,), device=self.device)
        cx = torch.randint(int(w * 0.3), int(w * 0.7), (batch_size,), device=self.device)
        blobs = self._gaussian_bumps(cy, cx, self.cfg.actuator_sigma * 1.5)
        return blobs.unsqueeze(1)

    def _sample_disturbances(self, batch_size: int) -> torch.Tensor:
        """Return ``[batch_size, 1, H, W]`` disturbance fields (zero where none fired).

        Each batch element independently fires with probability
        ``disturbance_prob`` and receives one Gaussian bump of random sign at
        a random centre in the 20%-80% region. All draws happen on the
        environment device, so no host synchronisation is needed.
        """
        cfg = self.cfg
        h, w = cfg.height, cfg.width
        fired = torch.rand(batch_size, device=self.device) < cfg.disturbance_prob
        cy = torch.randint(int(h * 0.2), int(h * 0.8), (batch_size,), device=self.device)
        cx = torch.randint(int(w * 0.2), int(w * 0.8), (batch_size,), device=self.device)
        sign = (torch.rand(batch_size, device=self.device) > 0.5).to(self.dtype) * 2 - 1
        amplitude = sign * fired.to(self.dtype) * cfg.disturbance_strength
        bumps = self._gaussian_bumps(cy, cx, cfg.actuator_sigma * 1.5)
        return (amplitude.view(-1, 1, 1) * bumps).unsqueeze(1)

    def step(self, field: torch.Tensor, control: torch.Tensor, noise: bool = True) -> torch.Tensor:
        """Advance the field by one time step.

        Args:
            field: current state, ``[B, 1, H, W]``.
            control: actuator amplitudes, ``[B, num_actuators]``.
            noise: when False, the additive Gaussian noise is skipped.
                Random disturbances are governed by ``disturbance_prob``
                only and are not affected by this flag.

        Returns:
            The next state, ``[B, 1, H, W]``, clamped to ``[-1, 1]``.
        """
        cfg = self.cfg
        B = field.shape[0]
        # Diffusion: 5-point Laplacian; F.pad supplies zeros beyond the border
        lap = (-4 * field
               + F.pad(field, (0, 0, 1, 0))[:, :, :-1, :]
               + F.pad(field, (0, 0, 0, 1))[:, :, 1:, :]
               + F.pad(field, (1, 0, 0, 0))[:, :, :, :-1]
               + F.pad(field, (0, 1, 0, 0))[:, :, :, 1:])
        diffused = field + cfg.diffusion * lap
        # Advection: blend with a copy shifted one pixel diagonally (wraps around)
        advected = (torch.roll(diffused, shifts=(1, -1), dims=(2, 3)) * cfg.advection
                    + diffused * (1 - cfg.advection))
        # Actuator forces: control-weighted sum of the fixed Gaussian footprints
        force = torch.sum(control.unsqueeze(-1).unsqueeze(-1) * self._actuator_maps.unsqueeze(0),
                          dim=1, keepdim=True)
        next_field = advected + force
        # Noise
        if noise and cfg.noise_std > 0:
            next_field = next_field + torch.randn_like(next_field) * cfg.noise_std
        # Random disturbances, drawn per batch element
        if cfg.disturbance_prob > 0:
            next_field = next_field + self._sample_disturbances(B)
        return torch.clamp(next_field, min=-1.0, max=1.0)


def generate_trajectories(
    env: MiniPlasmaEnv,
    num_trajectories: int,
    trajectory_length: int,
    control_scale: float = 0.1,
    prediction_horizon: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Roll out random-control trajectories and slice them into training pairs.

    All ``num_trajectories`` rollouts advance together as one batch. For each
    time index ``t < trajectory_length`` the state at ``t`` is paired with the
    state at ``t + prediction_horizon`` and with the control applied at ``t``.

    Returns:
        ``(fields, targets, controls)`` with leading dimension
        ``trajectory_length * num_trajectories``, ordered time-major (all
        trajectories at t=0, then all at t=1, ...).
    """
    n_actuators = env.cfg.num_actuators
    total_steps = trajectory_length + prediction_horizon

    # Roll the whole batch forward, recording every visited state and control
    state = env.reset(batch_size=num_trajectories)  # [N, 1, H, W]
    states = [state.clone()]
    actions = []
    for _step in range(total_steps):
        raw_action = torch.randn(num_trajectories, n_actuators,
                                 device=env.device, dtype=env.dtype) * control_scale
        action = torch.clamp(raw_action, -1, 1)
        advanced = env.step(state, action, noise=True)
        states.append(advanced.clone())
        actions.append(action)
        state = advanced.detach()

    # Pair each state with the state `prediction_horizon` steps later
    time_indices = range(trajectory_length)
    inputs = [states[t] for t in time_indices]  # each [N, 1, H, W]
    outputs = [states[t + prediction_horizon] for t in time_indices]  # each [N, 1, H, W]
    applied = [actions[t] for t in time_indices]  # each [N, num_actuators]

    # Concatenate along the batch axis: T chunks of N -> [T*N, ...]
    fields = torch.cat(inputs, dim=0)
    targets = torch.cat(outputs, dim=0)
    controls = torch.cat(applied, dim=0)

    return fields, targets, controls

print("Environment defined.")

## 2. Spectral Belief Machine (7+1 Architecture)

In [None]:
@dataclass
class SpectralConfig:
    """Configuration for the Spectral Belief Machine.

    Field order is part of the constructor interface; do not reorder.
    """
    # Spatial size of the input field; the FFT window and radial band masks
    # are built at this size.
    height: int = 64
    width: int = 64
    # Number of radial frequency bands, each processed by its own PerBandBlock.
    num_spectral_bands: int = 7
    # Hidden channel count inside each PerBandBlock; the temporal attention
    # width is channels_per_band * 8.
    channels_per_band: int = 16
    # Length of the feature sequence fed to the temporal band (current step
    # included); also the size of its causal mask.
    history_len: int = 4
    # Number of attention heads in the temporal band.
    num_heads: int = 4
    # Windowing OFF by default - causes edge artifacts in loss
    use_windowing: bool = False
    # Per-group learning-rate multipliers: entry i applies to spectral band i,
    # and the LAST entry applies to the temporal modules. The default has
    # 7 + 1 entries to match num_spectral_bands = 7.
    band_lr_multipliers: List[float] = field(default_factory=lambda: [
        0.001, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 0.1
    ])
    # Device the model (parameters and buffers) is moved to on construction.
    device: str = "cpu"
    # NOTE(review): dtype is not read by the model code visible in this file.
    dtype: torch.dtype = torch.float32


def make_hamming_window(h: int, w: int, device: torch.device) -> torch.Tensor:
    """Return an ``[h, w]`` separable 2D Hamming window (outer product of two 1D windows)."""
    rows = torch.hamming_window(h, device=device)
    cols = torch.hamming_window(w, device=device)
    return rows[:, None] * cols[None, :]


def make_radial_masks(h: int, w: int, num_bands: int, device: torch.device) -> torch.Tensor:
    """Return ``[num_bands, h, w]`` soft radial band masks that sum to ~1 per pixel.

    Radius is measured on a ``[-1, 1] x [-1, 1]`` grid. Band edges are
    log-spaced from 1e-3 to sqrt(2), with the first edge forced to 0. Each
    band is the product of two sigmoids (slope 20) around its lower and upper
    edge; the stack is then normalised across bands.
    """
    ys = torch.linspace(-1.0, 1.0, h, device=device)
    xs = torch.linspace(-1.0, 1.0, w, device=device)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    radius = torch.sqrt(yy ** 2 + xx ** 2).clamp(min=1e-6)

    edges = torch.logspace(-3, math.log10(math.sqrt(2)), steps=num_bands + 1, device=device)
    edges[0] = 0.0

    # One soft annulus per consecutive pair of edges
    annuli = [
        torch.sigmoid((radius - lo) * 20) * torch.sigmoid((hi - radius) * 20)
        for lo, hi in zip(edges[:-1], edges[1:])
    ]
    stacked = torch.stack(annuli, dim=0)
    total = stacked.sum(dim=0, keepdim=True) + 1e-8
    return stacked / total


class PerBandBlock(nn.Module):
    """Small conv stack applied to one band's (real, imag) feature map.

    Maps 2 input channels to ``channels`` hidden channels through two 3x3
    convolutions with GELU, then back to 2 channels with a 1x1 convolution.
    Spatial size is preserved.
    """

    def __init__(self, channels: int):
        super().__init__()
        layers = [
            nn.Conv2d(2, channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, 2, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stack to ``x`` and return the 2-channel result."""
        y = self.net(x)
        return y


class TemporalBand(nn.Module):
    """Causal multi-head self-attention over a short feature sequence.

    ``forward`` returns the attended sequence together with a per-sample
    scalar: the entropy of the head-averaged attention rows, averaged over
    the sequence positions.
    """

    def __init__(self, dim: int, num_heads: int, max_len: int):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        # True above the diagonal marks "future" positions to be hidden
        future = torch.triu(torch.ones(max_len, max_len), diagonal=1).bool()
        self.register_buffer("causal_mask", future)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``[B, T, D]`` into ``[B, heads, T, head_dim]``."""
        batch, steps, _ = t.shape
        return t.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, steps, width = x.shape
        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        # Scaled dot-product scores with future positions masked out
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        hidden = self.causal_mask[:steps, :steps].unsqueeze(0).unsqueeze(0)
        logits = logits.masked_fill(hidden, float('-inf'))
        weights = F.softmax(logits, dim=-1)

        # Entropy of the head-averaged attention distribution, per sample
        mean_weights = weights.mean(dim=1)
        row_entropy = -(mean_weights * torch.log(mean_weights + 1e-9)).sum(dim=-1)
        entropy = row_entropy.mean(dim=1)

        # Merge heads back into the feature dimension
        mixed = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch, steps, width)
        return self.out_proj(mixed), entropy


class SpectralBeliefMachine(nn.Module):
    """Complete 7+1 Spectral Belief Machine.

    Pipeline of ``forward``:
      1. 2D FFT of the (optionally windowed) input, centred with ``fftshift``
         over the two spatial axes only.
      2. The spectrum is split into soft radial bands; each band's
         (real, imag) map is processed by its own ``PerBandBlock``.
      3. Band outputs are re-masked, summed and inverse-transformed to give
         the prediction.
      4. Band outputs are pooled to a small feature vector; together with
         features from previous calls (``history``) they pass through a
         causal ``TemporalBand``.

    NOTE(review): the temporal branch currently contributes only its
    attention entropy to the returned belief state. ``temporal_out`` is
    computed but never mixed back into the prediction, so the temporal
    modules receive no gradient from a loss on ``pred`` alone.
    """

    def __init__(self, cfg: "SpectralConfig"):
        super().__init__()
        self.cfg = cfg
        device = torch.device(cfg.device)
        # Windowing (optional - OFF by default); an all-ones window is a no-op
        if cfg.use_windowing:
            window = make_hamming_window(cfg.height, cfg.width, device)
        else:
            window = torch.ones(cfg.height, cfg.width, device=device)
        self.register_buffer("window", window)
        self.register_buffer("masks", make_radial_masks(cfg.height, cfg.width, cfg.num_spectral_bands, device))
        self.band_blocks = nn.ModuleList([PerBandBlock(cfg.channels_per_band) for _ in range(cfg.num_spectral_bands)])
        # Each band is pooled to 2 channels x (2 x 2) cells = 8 values
        temporal_dim = cfg.num_spectral_bands * 2 * 4
        self.temporal_proj_in = nn.Linear(temporal_dim, cfg.channels_per_band * 8)
        self.temporal_band = TemporalBand(cfg.channels_per_band * 8, cfg.num_heads, cfg.history_len)
        self.temporal_proj_out = nn.Linear(cfg.channels_per_band * 8, temporal_dim)
        self.to(cfg.device)

    def _fft_decompose(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` ``[B, 1, H, W]`` into per-band ``[B, 2, H, W]`` (real, imag) maps."""
        x_windowed = x.squeeze(1) * self.window
        fft = torch.fft.fft2(x_windowed)
        # Shift ONLY the spatial axes. Without `dim`, fftshift also rolls the
        # batch axis, which misaligns per-sample features and entropies with
        # the input batch order.
        fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
        bands = []
        for i in range(self.cfg.num_spectral_bands):
            mask = self.masks[i].unsqueeze(0)
            band_fft = fft_shifted * mask
            band_feat = torch.stack([band_fft.real, band_fft.imag], dim=1)
            bands.append(band_feat)
        return bands

    def _fft_reconstruct(self, bands: List[torch.Tensor]) -> torch.Tensor:
        """Re-mask and sum the band maps, then invert the FFT; returns ``[B, 1, H, W]``."""
        fft_recon = torch.zeros(bands[0].shape[0], self.cfg.height, self.cfg.width,
                                dtype=torch.complex64, device=bands[0].device)
        for i, band in enumerate(bands):
            mask = self.masks[i].unsqueeze(0)
            band_complex = torch.complex(band[:, 0], band[:, 1])
            fft_recon = fft_recon + band_complex * mask
        # Undo the spatial-only shift applied in _fft_decompose
        fft_unshifted = torch.fft.ifftshift(fft_recon, dim=(-2, -1))
        return torch.fft.ifft2(fft_unshifted).real.unsqueeze(1)

    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Predict a field from ``x`` and report a belief state.

        Args:
            x: input field, ``[B, 1, H, W]``.
            history: pooled band features from earlier calls (oldest first).
                Entries whose shape does not match the current batch are
                silently ignored; missing steps are zero-padded at the front.

        Returns:
            ``(pred, belief_state)`` where ``pred`` is ``[B, 1, H, W]`` and
            ``belief_state`` holds ``band_entropy`` ``[B, num_bands + 1]``
            (last column is the temporal attention entropy),
            ``global_entropy`` ``[B]`` and ``band_features``
            ``[B, num_bands * 8]``.
        """
        B = x.shape[0]
        bands = self._fft_decompose(x)
        processed_bands = [self.band_blocks[i](bands[i]) for i in range(self.cfg.num_spectral_bands)]

        # Compute band entropies: Shannon entropy of the normalised magnitude map
        band_entropies = []
        for band in processed_bands:
            mag = torch.sqrt(band[:, 0] ** 2 + band[:, 1] ** 2 + 1e-8)
            mag_norm = mag / (mag.sum(dim=[1, 2], keepdim=True) + 1e-8)
            entropy = -(mag_norm * torch.log(mag_norm + 1e-9)).sum(dim=[1, 2])
            band_entropies.append(entropy)

        # Pool band features for temporal processing
        pooled_bands = [F.adaptive_avg_pool2d(band, (2, 2)).flatten(1) for band in processed_bands]
        band_feats = torch.cat(pooled_bands, dim=1)

        # Temporal processing
        if history is None:
            history = []
        # Keep only history entries that match the current batch/feature shape
        valid_history = [h for h in history if h.ndim == 2 and h.shape[0] == B and h.shape[1] == band_feats.shape[1]]
        history_seq = valid_history + [band_feats]
        while len(history_seq) < self.cfg.history_len:
            history_seq.insert(0, torch.zeros_like(band_feats))
        history_seq = history_seq[-self.cfg.history_len:]

        temporal_seq = torch.stack(history_seq, dim=1)
        temporal_in = self.temporal_proj_in(temporal_seq)
        temporal_out, temporal_entropy = self.temporal_band(temporal_in)
        # NOTE(review): temporal_out is not used after this line; only
        # temporal_entropy reaches the output. Confirm whether it was meant
        # to modulate the band features before reconstruction.
        temporal_out = self.temporal_proj_out(temporal_out)

        # Reconstruct
        pred = self._fft_reconstruct(processed_bands)

        # Belief state: spectral entropies plus one extra column for the temporal band
        entropy_tensor = torch.stack(band_entropies, dim=1)
        padded_entropy = F.pad(entropy_tensor, (0, 1))
        padded_entropy[:, -1] = temporal_entropy

        belief_state = {
            "band_entropy": padded_entropy,
            "global_entropy": padded_entropy.sum(dim=1),
            "band_features": band_feats,
        }
        return pred, belief_state

    def get_lr_groups(self, base_lr: float) -> List[dict]:
        """Return optimizer parameter groups with per-band learning rates.

        Band ``i`` uses ``base_lr * band_lr_multipliers[i]``; the temporal
        modules share one group that uses the last multiplier.
        """
        groups = []
        for i, block in enumerate(self.band_blocks):
            groups.append({"params": block.parameters(), "lr": base_lr * self.cfg.band_lr_multipliers[i]})
        groups.append({"params": list(self.temporal_proj_in.parameters()) +
                                  list(self.temporal_band.parameters()) +
                                  list(self.temporal_proj_out.parameters()),
                       "lr": base_lr * self.cfg.band_lr_multipliers[-1]})
        return groups

print("SpectralBeliefMachine defined.")

## 3. Baseline Models

In [None]:
@dataclass
class BaselineConfig:
    """Shared configuration for the baseline predictors.

    Field order is part of the constructor interface; do not reorder.
    """
    # Spatial size of the input field; FourBandBaseline builds its radial
    # masks at this size.
    height: int = 64
    width: int = 64
    # Base hidden channel count of the convolutional stacks.
    channels: int = 32
    # Device the model is moved to on construction.
    device: str = "cpu"
    # NOTE(review): dtype is not read by the baseline code visible in this file.
    dtype: torch.dtype = torch.float32


class FlatBaseline(nn.Module):
    """Flat ConvNet baseline - NO residual skip (fair comparison).

    Eight 3x3 convolutions (GELU between them) operating directly on the
    field. The returned belief state is an all-zero placeholder with the
    same keys and shapes as the spectral model's.
    """

    def __init__(self, cfg: BaselineConfig):
        super().__init__()
        self.cfg = cfg
        narrow, wide = cfg.channels, cfg.channels * 2
        # Encoder widens 1 -> C -> C -> 2C -> 2C; decoder mirrors it back to 1
        self.encoder = self._conv_stack([1, narrow, narrow, wide, wide], activate_last=True)
        self.decoder = self._conv_stack([wide, wide, narrow, narrow, 1], activate_last=False)
        self.to(cfg.device)

    @staticmethod
    def _conv_stack(widths, activate_last):
        """Chain 3x3 convs through ``widths``, with GELU after each (optionally not the last)."""
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.GELU())
        if not activate_last:
            layers.pop()
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(pred, belief_state)``; ``history`` is accepted but ignored."""
        # NO residual skip - model must learn dynamics
        latent = self.encoder(x)
        pred = self.decoder(latent)
        batch = x.shape[0]
        placeholder_shapes = (
            ("band_entropy", (8,)),
            ("global_entropy", ()),
            ("band_features", (56,)),
        )
        belief_state = {
            key: torch.zeros(batch, *tail, device=x.device)
            for key, tail in placeholder_shapes
        }
        return pred, belief_state


class FourBandBaseline(nn.Module):
    """4-band spectral baseline (tests if 7 bands is better than fewer).

    Same decompose / per-band conv / reconstruct pipeline as the spectral
    model, but with four radial bands and no temporal branch. The belief
    state is padded to the shapes the 7+1 model reports so both can share a
    training loop.
    """

    def __init__(self, cfg: "BaselineConfig"):
        super().__init__()
        self.cfg = cfg
        self.num_bands = 4
        self.register_buffer("masks", self._make_masks(cfg.height, cfg.width))
        self.band_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, cfg.channels, kernel_size=3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, cfg.channels, kernel_size=3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, 2, kernel_size=1),
            ) for _ in range(self.num_bands)
        ])
        self.to(cfg.device)

    def _make_masks(self, h: int, w: int) -> torch.Tensor:
        """Return ``[num_bands, h, w]`` soft radial masks normalised to sum to ~1 per pixel.

        Band edges are log-spaced from 1e-2 to sqrt(2) on a ``[-1, 1]`` grid,
        with the first edge forced to 0.
        """
        yy, xx = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
        rr = torch.sqrt(yy ** 2 + xx ** 2).clamp(min=1e-6)
        edges = torch.logspace(-2, math.log10(math.sqrt(2)), steps=self.num_bands + 1)
        edges[0] = 0
        masks = [torch.sigmoid((rr - edges[i]) * 20) * torch.sigmoid((edges[i + 1] - rr) * 20)
                 for i in range(self.num_bands)]
        masks = torch.stack(masks, dim=0)
        return masks / (masks.sum(dim=0, keepdim=True) + 1e-8)

    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(pred, belief_state)`` for ``x`` ``[B, 1, H, W]``; ``history`` is ignored.

        ``belief_state["band_entropy"]`` is ``[B, 8]``: the four band
        entropies followed by four zero columns. ``band_features`` is an
        all-zero ``[B, 56]`` placeholder.
        """
        B = x.shape[0]
        masks = self.masks.to(x.device)
        fft = torch.fft.fft2(x.squeeze(1))
        # Shift ONLY the spatial axes. Without `dim`, fftshift also rolls the
        # batch axis, so the per-sample entropies below would be reported in
        # a rotated batch order.
        fft_shifted = torch.fft.fftshift(fft, dim=(-2, -1))
        processed_bands, band_entropies = [], []
        for i in range(self.num_bands):
            mask = masks[i].unsqueeze(0)
            band_fft = fft_shifted * mask
            band_feat = torch.stack([band_fft.real, band_fft.imag], dim=1)
            processed = self.band_blocks[i](band_feat)
            processed_bands.append(processed)
            # Shannon entropy of the normalised magnitude map of this band
            mag = torch.sqrt(processed[:, 0] ** 2 + processed[:, 1] ** 2 + 1e-8)
            mag_norm = mag / (mag.sum(dim=[1, 2], keepdim=True) + 1e-8)
            band_entropies.append(-(mag_norm * torch.log(mag_norm + 1e-9)).sum(dim=[1, 2]))

        # Re-mask and sum the processed bands, then invert the spatial shift + FFT
        fft_recon = torch.zeros_like(fft_shifted)
        for i, band in enumerate(processed_bands):
            mask = masks[i].unsqueeze(0)
            fft_recon = fft_recon + torch.complex(band[:, 0], band[:, 1]) * mask
        pred = torch.fft.ifft2(torch.fft.ifftshift(fft_recon, dim=(-2, -1))).real.unsqueeze(1)

        entropy_tensor = torch.stack(band_entropies, dim=1)
        belief_state = {
            "band_entropy": F.pad(entropy_tensor, (0, 4)),
            "global_entropy": entropy_tensor.sum(dim=1),
            "band_features": torch.zeros(B, 56, device=x.device),
        }
        return pred, belief_state

print("Baselines defined.")

## 4. Training Function

In [None]:
def train_predictor(
    model: nn.Module,
    train_data: Tuple[torch.Tensor, torch.Tensor],
    epochs: int,
    lr: float,
    device: str,
    use_differential_lr: bool = False,
    predict_delta: bool = True,
    batch_size: int = 32,
) -> Dict[str, List[float]]:
    """Train a predictor on (input field, target field) pairs with Adam and MSE.

    Args:
        model: module called as ``model(x, history)`` returning
            ``(pred, belief_state)``.
        train_data: ``(fields, targets)`` tensors with equal leading size.
        epochs: number of passes over the data.
        lr: base learning rate.
        device: device the data tensors are moved to.
        use_differential_lr: if True and the model exposes
            ``get_lr_groups``, use its per-group learning rates.
        predict_delta: compute the loss on ``pred - x`` vs ``y - x``.
            NOTE(review): ``x`` cancels in that difference, so this loss is
            algebraically the same as the absolute-mode loss (up to rounding).
            The model output is treated as an absolute state either way.
        batch_size: mini-batch size.

    Returns:
        ``{"loss": [...], "epoch_time": [...]}`` with one entry per epoch.

    Raises:
        ValueError: if inputs and targets differ in length, if
            ``batch_size`` is not positive, or if there is nothing to train
            on while ``epochs > 0``.
    """
    fields, next_fields = train_data
    if fields.shape[0] != next_fields.shape[0]:
        raise ValueError(
            f"inputs and targets differ in length: {fields.shape[0]} vs {next_fields.shape[0]}"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_samples = fields.shape[0]
    if epochs > 0 and num_samples == 0:
        raise ValueError("train_data is empty; nothing to train on")

    fields = fields.to(device)
    next_fields = next_fields.to(device)

    if use_differential_lr and hasattr(model, 'get_lr_groups'):
        optimizer = optim.Adam(model.get_lr_groups(lr))
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    # Keep as many past feature sets as the model is configured to look at
    # (falls back to 4 for models without such a setting).
    history_len = getattr(getattr(model, "cfg", None), "history_len", 4)
    history_len = max(int(history_len), 1)

    metrics = {"loss": [], "epoch_time": []}

    for epoch in range(epochs):
        epoch_start = time.time()
        epoch_losses = []
        # NOTE(review): batches are drawn in shuffled order, so the features
        # kept in `history` come from unrelated samples rather than from the
        # preceding time steps of the current ones.
        history = []
        perm = torch.randperm(num_samples)

        for i in range(0, num_samples, batch_size):
            idx = perm[i:i + batch_size]
            x = fields[idx]
            y = next_fields[idx]

            pred, belief = model(x, history if history else None)

            if predict_delta:
                delta_target = y - x
                delta_pred = pred - x
                loss = F.mse_loss(delta_pred, delta_target)
            else:
                loss = F.mse_loss(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

            if "band_features" in belief:
                # Detach so gradients never flow into earlier batches
                history = (history + [belief["band_features"].detach()])[-history_len:]

        epoch_time = time.time() - epoch_start
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        metrics["loss"].append(mean_loss)
        metrics["epoch_time"].append(epoch_time)

        print(f"  Epoch {epoch + 1}/{epochs}: loss={mean_loss:.6f}, time={epoch_time:.2f}s")

    return metrics

print("Training function defined.")

## 5. Run Experiment

In [None]:
# Initialize environment
print("[1] Initializing environment...")

# Map the difficulty label to its preset; anything unrecognised uses "medium"
preset_factories = {"easy": PlasmaConfig.easy, "hard": PlasmaConfig.hard}
make_plasma_cfg = preset_factories.get(DIFFICULTY, PlasmaConfig.medium)
plasma_cfg = make_plasma_cfg(device=DEVICE, size=GRID_SIZE)

print("\n".join([
    f"    Diffusion: {plasma_cfg.diffusion}, Advection: {plasma_cfg.advection}",
    f"    Noise: {plasma_cfg.noise_std}, Disturbance: {plasma_cfg.disturbance_prob}@{plasma_cfg.disturbance_strength}",
    f"    Actuator sigma: {plasma_cfg.actuator_sigma:.1f} (scaled from base {plasma_cfg._base_actuator_sigma} for {GRID_SIZE}x{GRID_SIZE})",
]))

env = MiniPlasmaEnv(plasma_cfg)

In [None]:
# Generate training data
print("[2] Generating training data...")

# One batched rollout; inputs are paired with the state PREDICTION_HORIZON steps later
rollout = generate_trajectories(
    env,
    num_trajectories=NUM_TRAJECTORIES,
    trajectory_length=TRAJECTORY_LENGTH,
    control_scale=0.1,
    prediction_horizon=PREDICTION_HORIZON,
)
fields, next_fields, controls = rollout
print(f"    Generated {fields.shape[0]} samples (predicting t+{PREDICTION_HORIZON})")

# The predictors train on (input, target) only; controls are kept for inspection
train_data = (fields, next_fields)

In [None]:
# Initialize models
print("[3] Initializing models...")

# Spectral model first, then the two baselines (construction order fixes RNG use)
spectral_cfg = SpectralConfig(device=DEVICE, height=GRID_SIZE, width=GRID_SIZE)
sbm = SpectralBeliefMachine(spectral_cfg)
sbm_param_count = sum(p.numel() for p in sbm.parameters())
print(f"    SpectralBeliefMachine: {sbm_param_count:,} params")

baseline_cfg = BaselineConfig(device=DEVICE, height=GRID_SIZE, width=GRID_SIZE)
flat = FlatBaseline(baseline_cfg)
four_band = FourBandBaseline(baseline_cfg)

for label, net in (("FlatBaseline", flat), ("FourBandBaseline", four_band)):
    print(f"    {label}: {sum(p.numel() for p in net.parameters()):,} params")

# Verify GPU usage
print("\n".join([
    "\n[GPU CHECK]",
    f"    Target device: {DEVICE}",
    f"    SBM on: {next(sbm.parameters()).device}",
    f"    Flat on: {next(flat.parameters()).device}",
    f"    Training data on: {fields.device}",
]))
if torch.cuda.is_available() and DEVICE == 'cuda':
    print(f"    GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")

In [None]:
# Fit the 7+1 Spectral Belief Machine, using its per-band learning rates
rule = "=" * 60
print(f"\n{rule}\nTraining SpectralBeliefMachine (7+1)...\n{rule}")

sbm_metrics = train_predictor(
    model=sbm,
    train_data=train_data,
    epochs=EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=True,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

In [None]:
# Fit the flat ConvNet baseline with a single learning rate
rule = "=" * 60
print(f"\n{rule}\nTraining FlatBaseline...\n{rule}")

flat_metrics = train_predictor(
    model=flat,
    train_data=train_data,
    epochs=EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=False,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

In [None]:
# Fit the 4-band spectral baseline with a single learning rate
rule = "=" * 60
print(f"\n{rule}\nTraining FourBandBaseline...\n{rule}")

four_band_metrics = train_predictor(
    model=four_band,
    train_data=train_data,
    epochs=EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=False,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

## 6. Results Comparison

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves (log scale, one line per model, final loss in the legend)
ax = axes[0]
ax.plot(sbm_metrics["loss"], label=f"SBM (7+1) - final: {sbm_metrics['loss'][-1]:.6f}", linewidth=2)
ax.plot(flat_metrics["loss"], label=f"FlatBaseline - final: {flat_metrics['loss'][-1]:.6f}", linewidth=2)
ax.plot(four_band_metrics["loss"], label=f"FourBandBaseline - final: {four_band_metrics['loss'][-1]:.6f}", linewidth=2)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss (MSE)")
ax.set_title("Training Loss Curves")
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# Time per epoch
ax = axes[1]
models = ["SBM (7+1)", "FlatBaseline", "FourBand"]
# Average over the epochs actually recorded, not the global EPOCHS setting,
# so the figure stays correct if a model was trained for a different count.
times = [sum(m["epoch_time"]) / max(len(m["epoch_time"]), 1)
         for m in (sbm_metrics, flat_metrics, four_band_metrics)]
ax.bar(models, times, color=['tab:blue', 'tab:orange', 'tab:green'])
ax.set_ylabel("Seconds per Epoch")
ax.set_title("Training Speed")
# Leave headroom and anchor each label to the top of its bar; a fixed
# one-second offset would push labels outside the axes for fast epochs.
ax.margins(y=0.1)
for i, t in enumerate(times):
    ax.text(i, t, f"{t:.1f}s", ha='center', va='bottom')

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("FINAL RESULTS")
print("="*60)
print(f"SpectralBeliefMachine (7+1): {sbm_metrics['loss'][-1]:.6f}")
print(f"FlatBaseline:                {flat_metrics['loss'][-1]:.6f}")
print(f"FourBandBaseline:            {four_band_metrics['loss'][-1]:.6f}")
print()
# First-epoch vs last-epoch loss for every model (guarding a zero final loss)
for name, metrics in (("SBM", sbm_metrics), ("Flat", flat_metrics), ("FourBand", four_band_metrics)):
    first_loss, last_loss = metrics["loss"][0], metrics["loss"][-1]
    ratio = first_loss / last_loss if last_loss != 0 else float("inf")
    print(f"{name} improvement over epochs: {first_loss:.4f} -> {last_loss:.4f} ({ratio:.1f}x)")

In [None]:
# Visualize predictions
print("\nVisualizing sample predictions...")

# Get a sample; clamp the index so a smaller dataset does not yield an empty slice
idx = min(500, fields.shape[0] - 1)
x_sample = fields[idx:idx+1].to(DEVICE)
y_sample = next_fields[idx:idx+1].to(DEVICE)

with torch.no_grad():
    sbm_pred, _ = sbm(x_sample, None)
    flat_pred, _ = flat(x_sample, None)
    four_pred, _ = four_band(x_sample, None)

# One column per model, after the input and target columns
predictions = [("SBM", sbm_pred), ("Flat", flat_pred), ("FourBand", four_pred)]
fig, axes = plt.subplots(2, 2 + len(predictions), figsize=(20, 8))

# Row 1: Fields
axes[0, 0].imshow(x_sample[0, 0].cpu().numpy(), cmap='viridis')
axes[0, 0].set_title('Input (t)')
axes[0, 1].imshow(y_sample[0, 0].cpu().numpy(), cmap='viridis')
axes[0, 1].set_title(f'Target (t+{PREDICTION_HORIZON})')

# Row 2: Errors (no error map under the input/target columns)
axes[1, 0].axis('off')
axes[1, 1].axis('off')

# Absolute-error maps share one colour scale so the panels are comparable
abs_errors = {name: torch.abs(pred - y_sample)[0, 0].cpu().numpy() for name, pred in predictions}
vmax = max(err.max() for err in abs_errors.values())

for col, (name, pred) in enumerate(predictions, start=2):
    axes[0, col].imshow(pred[0, 0].cpu().numpy(), cmap='viridis')
    axes[0, col].set_title(f'{name} Prediction')
    # The title reports the true mean squared error; the map shows |error|
    mse = float(((pred - y_sample) ** 2).mean())
    axes[1, col].imshow(abs_errors[name], cmap='hot', vmin=0, vmax=vmax)
    axes[1, col].set_title(f'{name} Abs Error (MSE: {mse:.4f})')

for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()

## 7. Summary

Key observations:
1. **SBM continues learning** - loss keeps dropping even after 30 epochs
2. **FlatBaseline plateaus early** - learns quickly but stops improving
3. **SBM has fewer parameters** (99K vs 130K) but competitive performance
4. **Delta prediction mode** prevents identity copying shortcuts
5. **Differential learning rates** allow progressive coarse-to-fine learning