# Experiment 032 v2: Turbulent Plasma with Top-K Sparse Attention

**Purpose**: Test the 7+1 Spectral Belief Machine (SBM: 7 spectral bands + 1 temporal band) on a HARDER environment with multi-scale turbulent structure.

**Key Improvements over v1**:
- **Turbulent plasma**: Adds vortices, shear, and multi-scale noise
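- **Top-K sparse temporal attention** (from Exp 033): each query attends only to its K highest-scoring past states, focusing on relevant history (implemented by masking the full score matrix, so it does not reduce compute)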
- **Exponential decay weighting**: Older states contribute less
- **Harder prediction**: t+5 horizon, more disturbances

**Run on Colab with GPU**: Runtime -> Change runtime type -> A100

In [None]:
# Check GPU availability: report the PyTorch build and, if present, the first CUDA device.
import torch

print(f"PyTorch version: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {_gpu_name}")
    print(f"Memory: {_gpu_total_bytes / 1e9:.1f} GB")

In [None]:
# Configuration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Environment / data generation
GRID_SIZE = 128  # 64, 128, or 256
DIFFICULTY = "turbulent"  # easy, medium, hard, turbulent (NEW)
PREDICT_DELTA = True
PREDICTION_HORIZON = 5  # Harder than v1's t+3
NUM_TRAJECTORIES = 100
TRAJECTORY_LENGTH = 100

# Optimisation
EPOCHS = 100
BATCH_SIZE = 32
LR = 0.001

# Top-K attention settings (from Exp 033)
TOP_K_TEMPORAL = 4  # how many past states each query may attend to
TEMPORAL_DECAY = 0.9  # per-step down-weighting of older states

for _summary_line in (
    f"Device: {DEVICE}",
    f"Grid size: {GRID_SIZE}x{GRID_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Difficulty: {DIFFICULTY}",
    f"Prediction mode: {'delta' if PREDICT_DELTA else 'absolute'}",
    f"Prediction horizon: t+{PREDICTION_HORIZON}",
    f"Top-K temporal: {TOP_K_TEMPORAL}, decay: {TEMPORAL_DECAY}",
):
    print(_summary_line)

## 1. Environment: Turbulent Plasma Simulation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import math
import time

@dataclass
class PlasmaConfig:
    """Settings for the (turbulent) plasma environment.

    v2 adds vortices, shear flow and multi-scale noise for harder prediction.
    The actuator width is specified for a 64x64 grid and rescaled to the actual
    grid through ``actuator_sigma``; the presets build square grids.
    """
    # Grid dimensions (pixels).
    height: int = 64
    width: int = 64
    # Scale-independent physics coefficients.
    diffusion: float = 0.25
    advection: float = 0.08
    noise_std: float = 0.02
    disturbance_prob: float = 0.1
    disturbance_strength: float = 0.15
    # v2 turbulence: number/strength of rotating vortices, horizontal shear,
    # and whether noise is injected at several spatial frequencies.
    num_vortices: int = 3
    vortex_strength: float = 0.1
    shear_strength: float = 0.05
    multiscale_noise: bool = True
    # Actuators: keep the count a perfect square (4, 9, 16) so they fill the
    # grid evenly; 9 gives a 3x3 layout.
    num_actuators: int = 9
    _base_actuator_sigma: float = 5.0  # Gaussian width at 64x64
    # Tensor placement.
    device: str = "cpu"
    dtype: torch.dtype = torch.float32

    @property
    def actuator_sigma(self) -> float:
        """Actuator Gaussian width in pixels, scaled by ``min(height, width) / 64``."""
        return self._base_actuator_sigma * (min(self.height, self.width) / 64.0)

    @classmethod
    def _preset(cls, size: int, device: str, **overrides) -> "PlasmaConfig":
        """Build a ``size`` x ``size`` config with 9 actuators plus ``overrides``."""
        return cls(height=size, width=size, num_actuators=9, device=device, **overrides)

    @classmethod
    def easy(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Gentle dynamics: weak diffusion/advection, no disturbances, no turbulence."""
        return cls._preset(
            size, device,
            diffusion=0.12, advection=0.02, noise_std=0.01,
            disturbance_prob=0.0, disturbance_strength=0.0,
            num_vortices=0, vortex_strength=0.0, shear_strength=0.0,
            multiscale_noise=False,
        )

    @classmethod
    def medium(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Default physics coefficients with all turbulence features disabled."""
        return cls._preset(
            size, device,
            num_vortices=0, vortex_strength=0.0, shear_strength=0.0,
            multiscale_noise=False,
        )

    @classmethod
    def hard(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """Strong physics and disturbances with a single weak vortex and mild shear."""
        return cls._preset(
            size, device,
            diffusion=0.4, advection=0.15, noise_std=0.05,
            disturbance_prob=0.2, disturbance_strength=0.25,
            num_vortices=1, vortex_strength=0.05, shear_strength=0.02,
            multiscale_noise=False,
        )

    @classmethod
    def turbulent(cls, device: str = "cpu", size: int = 64) -> "PlasmaConfig":
        """NEW: Turbulent mode with vortices, shear, and multi-scale noise."""
        return cls._preset(
            size, device,
            diffusion=0.3, advection=0.12, noise_std=0.03,
            disturbance_prob=0.25, disturbance_strength=0.3,
            num_vortices=3, vortex_strength=0.15, shear_strength=0.08,
            multiscale_noise=True,
        )


class TurbulentPlasmaEnv:
    """v2 plasma environment with turbulence, vortices, and multi-scale structure.

    The state is a batch of scalar fields shaped ``[B, 1, H, W]``. ``step``
    applies, in order: explicit diffusion, a fixed diagonal advection blend,
    vortex and shear flows (semi-Lagrangian), Gaussian actuator forcing,
    optional noise and an occasional random disturbance, then clamps the
    result to ``[-1, 1]``.

    Vortex centres are drawn once at construction, so the flow field is fixed
    for the lifetime of the environment instance.
    """

    # Forward-Euler diffusion with the 5-point Laplacian amplifies the
    # checkerboard mode by |1 - 8*D|, which exceeds 1 once D > 0.25. Larger
    # coefficients are split into equal sub-steps that are each stable, so the
    # total amount of diffusion per step is kept without numerical blow-up.
    _MAX_STABLE_DIFFUSION = 0.25

    def __init__(self, cfg: "PlasmaConfig"):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dtype = cfg.dtype
        # Pixel-index coordinate grids ([H, W]); built once and reused.
        self._yy, self._xx = torch.meshgrid(
            torch.arange(cfg.height, device=self.device, dtype=self.dtype),
            torch.arange(cfg.width, device=self.device, dtype=self.dtype),
            indexing="ij",
        )
        self._actuator_maps = self._build_actuator_maps()
        self._vortex_centers = self._init_vortex_centers()
        self._vortex_flow = self._build_vortex_flow()
        self._shear_flow = self._build_shear_flow()

    def _build_actuator_maps(self) -> torch.Tensor:
        """Return ``[num_actuators, H, W]`` Gaussian bumps on a regular grid of centres."""
        h, w = self.cfg.height, self.cfg.width
        grid_n = int(math.ceil(math.sqrt(self.cfg.num_actuators)))
        centers_y = torch.linspace(h * 0.2, h * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers_x = torch.linspace(w * 0.2, w * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers = torch.cartesian_prod(centers_y, centers_x)[:self.cfg.num_actuators]
        sig2 = self.cfg.actuator_sigma ** 2
        cy = centers[:, 0].view(-1, 1, 1)
        cx = centers[:, 1].view(-1, 1, 1)
        return torch.exp(-((self._yy - cy) ** 2 + (self._xx - cx) ** 2) / (2 * sig2))

    def _init_vortex_centers(self) -> Optional[torch.Tensor]:
        """Draw random vortex centres ``[num_vortices, 2]`` (y, x), or None if disabled."""
        if self.cfg.num_vortices == 0:
            return None
        h, w = self.cfg.height, self.cfg.width
        # Random positions in the central 60% of the grid.
        centers_y = torch.rand(self.cfg.num_vortices, device=self.device) * (h * 0.6) + (h * 0.2)
        centers_x = torch.rand(self.cfg.num_vortices, device=self.device) * (w * 0.6) + (w * 0.2)
        return torch.stack([centers_y, centers_x], dim=1)

    def _build_vortex_flow(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Build the (vy, vx) pixel-velocity field of all vortices, or (None, None)."""
        if self.cfg.num_vortices == 0 or self._vortex_centers is None:
            return None, None
        h, w = self.cfg.height, self.cfg.width
        yy, xx = self._yy, self._xx
        vy = torch.zeros_like(yy)
        vx = torch.zeros_like(xx)
        scale = min(h, w) / 64.0  # Vortex radius scales with the grid.
        for i in range(self.cfg.num_vortices):
            cy, cx = self._vortex_centers[i]
            dy = yy - cy
            dx = xx - cx
            r2 = dy ** 2 + dx ** 2 + 1e-6
            # Tangential velocity (perpendicular to the radius) with Gaussian falloff.
            decay = torch.exp(-r2 / (2 * (10 * scale) ** 2))
            sign = 1 if i % 2 == 0 else -1  # Alternate rotation direction.
            vy += sign * self.cfg.vortex_strength * (-dx) / torch.sqrt(r2) * decay
            vx += sign * self.cfg.vortex_strength * dy / torch.sqrt(r2) * decay
        return vy, vx

    def _build_shear_flow(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Build a horizontal shear flow whose vx varies sinusoidally with y, or (None, None)."""
        if self.cfg.shear_strength == 0:
            return None, None
        vx = self.cfg.shear_strength * torch.sin(2 * math.pi * self._yy / self.cfg.height)
        vy = torch.zeros_like(vx)
        return vy, vx

    def _apply_flow(self, field: torch.Tensor, vy: torch.Tensor, vx: torch.Tensor) -> torch.Tensor:
        """Advect ``field`` by the pixel velocity (vy, vx) via semi-Lagrangian backtracking."""
        B, _, H, W = field.shape
        yy, xx = torch.meshgrid(
            torch.linspace(-1, 1, H, device=field.device, dtype=field.dtype),
            torch.linspace(-1, 1, W, device=field.device, dtype=field.dtype),
            indexing="ij",
        )
        # With align_corners=True the corner pixel centres sit at -1 and +1, so
        # neighbouring pixels are 2/(N-1) apart in normalized coordinates.
        vy_norm = vy * (2.0 / max(H - 1, 1))
        vx_norm = vx * (2.0 / max(W - 1, 1))
        # Backtrack: sample each point from where it came from. grid_sample
        # expects (x, y) order in the last dimension.
        grid = torch.stack([xx - vx_norm, yy - vy_norm], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
        return F.grid_sample(field, grid, mode='bilinear', padding_mode='border', align_corners=True)

    def _multiscale_noise(self, shape: Tuple[int, ...]) -> torch.Tensor:
        """Sum of upsampled Gaussian noise at 1x, 2x, 4x, 8x coarseness (std / scale each)."""
        B, C, H, W = shape
        noise = torch.zeros(shape, device=self.device, dtype=self.dtype)
        for scale in [1, 2, 4, 8]:
            h_s, w_s = H // scale, W // scale
            if h_s < 4 or w_s < 4:
                continue
            coarse = torch.randn(B, C, h_s, w_s, device=self.device, dtype=self.dtype)
            upsampled = F.interpolate(coarse, size=(H, W), mode='bilinear', align_corners=False)
            noise += upsampled * (self.cfg.noise_std / scale)
        return noise

    @staticmethod
    def _laplacian(field: torch.Tensor) -> torch.Tensor:
        """5-point Laplacian with zero values outside the grid."""
        return (F.pad(field, (0, 0, 1, 0))[:, :, :-1, :]
                + F.pad(field, (0, 0, 0, 1))[:, :, 1:, :]
                + F.pad(field, (1, 0, 0, 0))[:, :, :, :-1]
                + F.pad(field, (0, 1, 0, 0))[:, :, :, 1:]) - 4 * field

    def _diffuse(self, field: torch.Tensor) -> torch.Tensor:
        """Apply ``cfg.diffusion`` worth of explicit diffusion in stable sub-steps."""
        n_sub = max(1, math.ceil(self.cfg.diffusion / self._MAX_STABLE_DIFFUSION))
        coeff = self.cfg.diffusion / n_sub
        for _ in range(n_sub):
            field = field + coeff * self._laplacian(field)
        return field

    def reset(self, batch_size: int = 1) -> torch.Tensor:
        """Return ``[batch_size, 1, H, W]`` fields, each a Gaussian blob at a random central spot."""
        h, w = self.cfg.height, self.cfg.width
        cx = torch.randint(int(w * 0.3), int(w * 0.7), (batch_size,), device=self.device)
        cy = torch.randint(int(h * 0.3), int(h * 0.7), (batch_size,), device=self.device)
        sig2 = (self.cfg.actuator_sigma * 1.5) ** 2
        cy_f = cy.to(self.dtype).view(-1, 1, 1)
        cx_f = cx.to(self.dtype).view(-1, 1, 1)
        blobs = torch.exp(-((self._yy - cy_f) ** 2 + (self._xx - cx_f) ** 2) / (2 * sig2))
        return blobs.unsqueeze(1)

    def step(self, field: torch.Tensor, control: torch.Tensor, noise: bool = True) -> torch.Tensor:
        """Advance ``field`` one step under ``control`` (``[B, num_actuators]``).

        Returns the next field, clamped to ``[-1, 1]``. When a random
        disturbance fires, the same bump is added to every batch element.
        """
        B = field.shape[0]
        diffused = self._diffuse(field)

        # Fixed diagonal advection: blend with a (+1, -1)-rolled copy.
        advected = (torch.roll(diffused, shifts=(1, -1), dims=(2, 3)) * self.cfg.advection
                    + diffused * (1 - self.cfg.advection))

        # v2: vortex and shear flows.
        vort_vy, vort_vx = self._vortex_flow
        if vort_vy is not None:
            advected = self._apply_flow(advected, vort_vy, vort_vx)
        shear_vy, shear_vx = self._shear_flow
        if shear_vy is not None:
            advected = self._apply_flow(advected, shear_vy, shear_vx)

        # Actuator forces: control-weighted sum of the Gaussian bumps.
        bumps = self._actuator_maps.unsqueeze(0).expand(B, -1, -1, -1)
        control_expanded = control.unsqueeze(-1).unsqueeze(-1)
        force = torch.sum(control_expanded * bumps, dim=1, keepdim=True)
        next_field = advected + force

        if noise:
            if self.cfg.multiscale_noise:
                next_field = next_field + self._multiscale_noise(next_field.shape)
            elif self.cfg.noise_std > 0:
                next_field = next_field + torch.randn_like(next_field) * self.cfg.noise_std

        # Random disturbance: a signed Gaussian bump at a random central spot.
        if self.cfg.disturbance_prob > 0 and torch.rand(1).item() < self.cfg.disturbance_prob:
            h, w = self.cfg.height, self.cfg.width
            cy = torch.randint(int(h * 0.2), int(h * 0.8), (1,)).item()
            cx = torch.randint(int(w * 0.2), int(w * 0.8), (1,)).item()
            sig2 = (self.cfg.actuator_sigma * 1.5) ** 2
            bump = torch.exp(-((self._yy - cy) ** 2 + (self._xx - cx) ** 2) / (2 * sig2))
            sign = 2 * (torch.rand(1).item() > 0.5) - 1
            disturbance = sign * self.cfg.disturbance_strength * bump.unsqueeze(0).unsqueeze(0)
            next_field = next_field + disturbance.expand(B, -1, -1, -1)

        return torch.clamp(next_field, min=-1.0, max=1.0)


def generate_trajectories(
    env: "TurbulentPlasmaEnv",
    num_trajectories: int,
    trajectory_length: int,
    control_scale: float = 0.1,
    prediction_horizon: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Roll out random-control trajectories and pair frames with later targets.

    Each trajectory starts from ``env.reset(batch_size=1)`` and is stepped
    ``trajectory_length + prediction_horizon`` times with Gaussian controls
    (std ``control_scale``, clamped to [-1, 1]). For every ``t`` in
    ``range(trajectory_length)`` one sample
    ``(frame[t], frame[t + prediction_horizon], control[t])`` is emitted.

    Returns:
        ``(fields, next_fields, controls)`` concatenated over all trajectories.
        ``controls`` holds only the control applied at step ``t``; the
        intermediate controls up to the target frame are not recorded.

    Raises:
        ValueError: If ``prediction_horizon`` < 1 (a non-positive horizon
            would pair a frame with itself or index from the end of the
            trajectory), or if ``num_trajectories`` / ``trajectory_length``
            < 1 (there would be nothing to return).
    """
    if prediction_horizon < 1:
        raise ValueError(f"prediction_horizon must be >= 1, got {prediction_horizon}")
    if num_trajectories < 1 or trajectory_length < 1:
        raise ValueError(
            "num_trajectories and trajectory_length must be >= 1, "
            f"got {num_trajectories} and {trajectory_length}"
        )

    all_fields, all_next_fields, all_controls = [], [], []

    for _ in range(num_trajectories):
        field = env.reset(batch_size=1)
        trajectory = [field.clone()]
        controls_traj = []

        # Extra `prediction_horizon` steps so every t < trajectory_length has a target.
        for _ in range(trajectory_length + prediction_horizon):
            control = torch.randn(1, env.cfg.num_actuators, device=env.device, dtype=env.dtype) * control_scale
            control = torch.clamp(control, -1, 1)
            next_field = env.step(field, control, noise=True)
            trajectory.append(next_field.clone())
            controls_traj.append(control)
            field = next_field.detach()

        for t in range(trajectory_length):
            all_fields.append(trajectory[t])
            all_next_fields.append(trajectory[t + prediction_horizon])
            all_controls.append(controls_traj[t])

    return torch.cat(all_fields), torch.cat(all_next_fields), torch.cat(all_controls)

## 2. Models: SBM v2 with Top-K Sparse Attention

In [None]:
def _default_band_lr_multipliers() -> List[float]:
    """LR multipliers: one per spectral band (band 0 first), then one for the temporal band."""
    return [0.001, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 0.1]


@dataclass
class SpectralConfig:
    """Hyper-parameters of the Spectral Belief Machine v2 (7 spectral + 1 temporal band)."""
    # Input field size (pixels).
    height: int = 64
    width: int = 64
    # Spectral decomposition: number of radial bands and per-band CNN width.
    num_spectral_bands: int = 7
    channels_per_band: int = 16
    # Temporal attention: past feature vectors kept (increased from 4) and heads.
    history_len: int = 8
    num_heads: int = 4
    # v2 Top-K sparse attention: keys kept per query, and per-step recency decay.
    top_k: int = 4
    temporal_decay: float = 0.9
    # Hamming windowing before the FFT (off by default).
    use_windowing: bool = False
    band_lr_multipliers: List[float] = field(default_factory=_default_band_lr_multipliers)
    # Tensor placement.
    device: str = "cpu"
    dtype: torch.dtype = torch.float32


def make_hamming_window(h: int, w: int, device: torch.device) -> torch.Tensor:
    """Return the separable 2-D Hamming window (outer product of 1-D windows), shape ``[h, w]``."""
    return torch.outer(
        torch.hamming_window(h, device=device),
        torch.hamming_window(w, device=device),
    )


def make_radial_masks(h: int, w: int, num_bands: int, device: torch.device) -> torch.Tensor:
    """Build ``num_bands`` soft annular masks over a centred ``[h, w]`` frequency grid.

    Band edges are log-spaced from 1e-3 to sqrt(2) (the first edge forced to
    0), each annulus is the product of two sigmoid ramps of slope 20, and the
    stack is normalized so the masks sum to ~1 at every pixel.

    Returns:
        Tensor of shape ``[num_bands, h, w]``.
    """
    ys = torch.linspace(-1.0, 1.0, h, device=device)
    xs = torch.linspace(-1.0, 1.0, w, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    radius = torch.sqrt(grid_y ** 2 + grid_x ** 2).clamp(min=1e-6)

    edges = torch.logspace(-3, math.log10(math.sqrt(2)), steps=num_bands + 1, device=device)
    edges[0] = 0.0
    inner = edges[:-1].view(-1, 1, 1)
    outer = edges[1:].view(-1, 1, 1)

    # Broadcast every band at once: [num_bands, 1, 1] against [h, w].
    raw = torch.sigmoid((radius - inner) * 20) * torch.sigmoid((outer - radius) * 20)
    return raw / (raw.sum(dim=0, keepdim=True) + 1e-8)


class PerBandBlock(nn.Module):
    """Small CNN applied to a single spectral band.

    Maps a 2-channel (real, imaginary) band tensor to a 2-channel output of
    the same spatial size: two 3x3 conv + GELU stages of width ``channels``
    followed by a 1x1 projection back to 2 channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        layers = [
            nn.Conv2d(2, channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, 2, kernel_size=1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        return out


class TopKTemporalBand(nn.Module):
    """Causal multi-head temporal attention with Top-K selection and recency decay.

    From Experiment 033. For each query position ``i`` only the ``top_k``
    highest-scoring causally visible keys are kept (when the sequence is longer
    than ``top_k``). A log-space bias of about ``log(decay ** (i - j))`` is then
    added, so a key ``i - j`` steps older than the query is down-weighted
    geometrically.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        max_len: Longest sequence accepted (size of the causal mask).
        top_k: Keys kept per query when the sequence is longer than this.
        decay: Per-step recency factor; values >= 1.0 disable the decay bias.
    """
    def __init__(self, dim: int, num_heads: int, max_len: int, top_k: int = 4, decay: float = 0.9):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.top_k = top_k
        self.decay = decay
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.register_buffer("causal_mask", torch.triu(torch.ones(max_len, max_len), diagonal=1).bool())

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``x`` (``[B, T, dim]``).

        Returns:
            ``(out, entropy)``: ``out`` is ``[B, T, dim]``; ``entropy`` is
            ``[B]``, the attention entropy (averaged over heads, then over
            query positions).

        Raises:
            ValueError: If ``T`` exceeds the ``max_len`` given at construction.
        """
        B, T, D = x.shape
        max_len = self.causal_mask.shape[0]
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds max_len {max_len}")

        Q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Causal mask: query i may not see keys j > i.
        scores = scores.masked_fill(self.causal_mask[:T, :T].unsqueeze(0).unsqueeze(0), float('-inf'))

        # v2: keep only the top-k keys per query; everything else becomes -inf.
        # The diagonal is never causally masked, so every row keeps >= 1 finite score.
        if T > self.top_k:
            _, topk_inds = torch.topk(scores, self.top_k, dim=-1)
            drop = torch.ones_like(scores, dtype=torch.bool)
            drop.scatter_(-1, topk_inds, False)
            scores = scores.masked_fill(drop, float('-inf'))

        # v2: exponential recency decay. age[i, j] = i - j is how many steps
        # key j lies before query i (negative for future keys, which are
        # already masked and additionally zeroed by tril below).
        if self.decay < 1.0:
            positions = torch.arange(T, device=x.device, dtype=x.dtype)
            age = positions.unsqueeze(1) - positions.unsqueeze(0)  # [T, T]
            decay_weights = torch.pow(self.decay, age.clamp(min=0).float())
            decay_weights = torch.tril(decay_weights)
            scores = scores + torch.log(decay_weights + 1e-10).unsqueeze(0).unsqueeze(0)

        attn = F.softmax(scores, dim=-1)
        attn_avg = attn.mean(dim=1)
        entropy = -(attn_avg * torch.log(attn_avg + 1e-9)).sum(dim=-1).mean(dim=1)

        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, T, D)
        return self.out_proj(out), entropy


class SpectralBeliefMachineV2(nn.Module):
    """v2 SBM with Top-K sparse temporal attention.

    Pipeline (see ``forward``): FFT the (optionally windowed) input, split the
    centred spectrum into ``num_spectral_bands`` radial bands, refine each band
    with its own ``PerBandBlock`` (added to the band), optionally add a
    temporal update computed by ``TopKTemporalBand`` over past pooled band
    features, then inverse-FFT the recombined bands.
    """
    
    def __init__(self, cfg: SpectralConfig):
        super().__init__()
        self.cfg = cfg
        
        # Windowing (optional): a Hamming window, or an all-ones window (no-op).
        if cfg.use_windowing:
            self.register_buffer("window", make_hamming_window(cfg.height, cfg.width, torch.device(cfg.device)))
        else:
            self.register_buffer("window", torch.ones(cfg.height, cfg.width, device=torch.device(cfg.device)))
        
        # One soft radial mask per spectral band, indexed as self.masks[i].
        self.register_buffer("masks", make_radial_masks(cfg.height, cfg.width, cfg.num_spectral_bands, torch.device(cfg.device)))
        # A separate small CNN per spectral band.
        self.band_blocks = nn.ModuleList([PerBandBlock(cfg.channels_per_band) for _ in range(cfg.num_spectral_bands)])
        
        # Pooled feature width: per band, 2 channels (real, imag) x 2x2 pooling = 8 values.
        temporal_dim = cfg.num_spectral_bands * 2 * 4
        self.temporal_proj_in = nn.Linear(temporal_dim, cfg.channels_per_band * 8)
        # v2: Use TopKTemporalBand instead of regular
        # max_len = history_len + 1 because we append current frame to history.
        # NOTE(review): forward() fails if callers pass more than history_len
        # past features (the causal mask would be too small).
        self.temporal_band = TopKTemporalBand(
            cfg.channels_per_band * 8, cfg.num_heads, cfg.history_len + 1,
            top_k=cfg.top_k, decay=cfg.temporal_decay
        )
        self.temporal_proj_out = nn.Linear(cfg.channels_per_band * 8, temporal_dim)
        self.to(cfg.device)
    
    def _fft_decompose(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` into per-band (real, imag) spectra.

        Assumes ``x`` is ``[B, 1, H, W]`` (TODO confirm with callers). Returns one
        ``[B, 2, H, W]`` tensor per band: the fftshift-ed spectrum times that
        band's mask, with real and imaginary parts stacked on dim 1.
        """
        x_windowed = x.squeeze(1) * self.window
        fft = torch.fft.fft2(x_windowed)
        fft_shifted = torch.fft.fftshift(fft)
        bands = []
        for i in range(self.cfg.num_spectral_bands):
            mask = self.masks[i].unsqueeze(0)
            band_fft = fft_shifted * mask
            band_feat = torch.stack([band_fft.real, band_fft.imag], dim=1)
            bands.append(band_feat)
        return bands
    
    def _fft_reconstruct(self, bands: List[torch.Tensor]) -> torch.Tensor:
        """Recombine per-band (real, imag) spectra into a ``[B, 1, H, W]`` spatial field.

        Each band is multiplied by its mask again before summation and the
        real part of the inverse FFT is returned.
        NOTE(review): because ``_fft_decompose`` already applied the mask, an
        unprocessed band round-trips as ``spectrum * mask_i**2``; the sum over
        bands is an exact identity only if ``sum(mask_i**2) == 1`` per pixel.
        """
        fft_recon = None
        for i, band in enumerate(bands):
            mask = self.masks[i].unsqueeze(0)
            band_fft = torch.complex(band[:, 0], band[:, 1]) * mask
            if fft_recon is None:
                fft_recon = band_fft
            else:
                fft_recon = fft_recon + band_fft
        fft_unshifted = torch.fft.ifftshift(fft_recon)
        spatial = torch.fft.ifft2(fft_unshifted).real
        return spatial.unsqueeze(1)
    
    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Predict the next field from ``x`` and optional past band features.

        Args:
            x: Input field, presumably ``[B, 1, H, W]``.
            history: Past ``belief["current_feat"]`` tensors. Entries whose
                batch size or feature width do not match the current input
                (or that are not 2-D) are silently dropped.

        Returns:
            ``(pred, belief)``. ``pred`` is ``[B, 1, H, W]``. ``belief`` holds
            ``entropy_per_band`` (placeholder zeros, one per spectral band),
            ``temporal_entropy`` (``[B]``; zeros when no usable history) and
            ``current_feat`` (detached pooled band features, to be fed back as
            history).
        """
        B = x.shape[0]
        bands = self._fft_decompose(x)
        processed_bands = []
        band_entropies = []
        
        for i, band in enumerate(bands):
            proc = self.band_blocks[i](band)
            processed_bands.append(band + proc)
            # Placeholder: per-band entropy is not computed yet.
            band_entropies.append(torch.tensor(0.0, device=x.device))
        
        # Temporal band features: 2x2-pooled (real, imag) per band, concatenated.
        # Computed before temporal mixing below.
        band_pooled = [F.adaptive_avg_pool2d(b, (2, 2)).flatten(1) for b in processed_bands]
        current_feat = torch.cat(band_pooled, dim=1)
        
        if history is not None:
            history = [h for h in history if h.shape[0] == B and h.dim() == 2 and h.shape[1] == current_feat.shape[1]]
        
        if history is not None and len(history) > 0:
            # Sequence = past features followed by the current frame; only the
            # last (current) position of the attention output is used.
            history_seq = torch.stack(history + [current_feat], dim=1)
            history_proj = self.temporal_proj_in(history_seq)
            temporal_out, temporal_entropy = self.temporal_band(history_proj)
            temporal_feat = self.temporal_proj_out(temporal_out[:, -1, :])
            
            # Cross-band mixing: each band's 8-value chunk is reshaped to
            # [B, 2, 2, 2], upsampled by repetition to [B, 2, H, W] (assumes even
            # height/width), and added to that band with a fixed 0.1 scale.
            chunk_size = current_feat.shape[1] // self.cfg.num_spectral_bands
            for i in range(self.cfg.num_spectral_bands):
                delta = temporal_feat[:, i*chunk_size:(i+1)*chunk_size]
                delta_spatial = delta.view(B, 2, 2, 2).repeat_interleave(self.cfg.height//2, dim=2).repeat_interleave(self.cfg.width//2, dim=3)
                processed_bands[i] = processed_bands[i] + 0.1 * delta_spatial
            band_entropies.append(temporal_entropy)
        else:
            band_entropies.append(torch.zeros(B, device=x.device))
        
        pred = self._fft_reconstruct(processed_bands)
        belief = {
            "entropy_per_band": torch.stack(band_entropies[:-1]),
            "temporal_entropy": band_entropies[-1],
            "current_feat": current_feat.detach(),
        }
        return pred, belief
    
    def get_lr_groups(self, base_lr: float) -> List[Dict]:
        """Optimizer parameter groups with per-band learning rates.

        Band block ``i`` uses ``band_lr_multipliers[i]``; the three temporal
        modules share the last multiplier. Returns Adam-style dicts with
        ``params`` and ``lr`` keys.
        """
        groups = []
        for i, block in enumerate(self.band_blocks):
            mult = self.cfg.band_lr_multipliers[i]
            groups.append({"params": block.parameters(), "lr": base_lr * mult})
        groups.append({"params": self.temporal_proj_in.parameters(), "lr": base_lr * self.cfg.band_lr_multipliers[-1]})
        groups.append({"params": self.temporal_band.parameters(), "lr": base_lr * self.cfg.band_lr_multipliers[-1]})
        groups.append({"params": self.temporal_proj_out.parameters(), "lr": base_lr * self.cfg.band_lr_multipliers[-1]})
        return groups

In [None]:
# Baselines (same as v1 for fair comparison)

@dataclass
class BaselineConfig:
    """Shared hyper-parameters for the non-SBM baseline predictors."""
    # Field size in pixels.
    height: int = 64
    width: int = 64
    # Base hidden-channel count used by the baseline networks.
    channels: int = 32
    # Tensor placement.
    device: str = "cpu"
    dtype: torch.dtype = field(default=torch.float32)


class FlatBaseline(nn.Module):
    """Plain ConvNet predictor with no spectral decomposition.

    Five 3x3 convolutions with GELU between them, channel plan
    1 -> c -> 2c -> 2c -> c -> 1 (``c = cfg.channels``). The network output is
    returned directly as the prediction (no residual connection).
    """
    def __init__(self, cfg: "BaselineConfig"):
        super().__init__()
        self.cfg = cfg
        c = cfg.channels
        widths = [1, c, 2 * c, 2 * c, c, 1]
        layers: List[nn.Module] = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                layers.append(nn.GELU())
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
        self.net = nn.Sequential(*layers)
        self.to(cfg.device)

    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # `history` is accepted only for interface parity with the SBM and ignored.
        pred = self.net(x)
        belief = {"entropy_per_band": torch.zeros(7, device=x.device)}
        return pred, belief


class FourBandBaseline(nn.Module):
    """4-band spectral baseline.

    The centred FFT of the input is split into four soft radial bands. Each
    band's (real, imag) pair is refined by a small per-band CNN (added to the
    band), multiplied by its mask again and summed. The real part of the
    inverse FFT is returned.
    """
    def __init__(self, cfg: "BaselineConfig"):
        super().__init__()
        self.cfg = cfg
        self.num_bands = 4
        self.register_buffer("masks", self._make_masks(cfg.height, cfg.width))
        self.band_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, cfg.channels, kernel_size=3, padding=1),
                nn.GELU(),
                nn.Conv2d(cfg.channels, 2, kernel_size=1),
            )
            for _ in range(self.num_bands)
        ])
        self.to(cfg.device)

    def _make_masks(self, h: int, w: int) -> torch.Tensor:
        """Soft annular masks ``[num_bands, h, w]``, edges log-spaced from 1e-2, summing to ~1."""
        grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing="ij")
        radius = torch.sqrt(grid_y ** 2 + grid_x ** 2).clamp(min=1e-6)
        edges = torch.logspace(-2, math.log10(math.sqrt(2)), steps=self.num_bands + 1)
        edges[0] = 0
        inner = edges[:-1].view(-1, 1, 1)
        outer = edges[1:].view(-1, 1, 1)
        raw = torch.sigmoid((radius - inner) * 20) * torch.sigmoid((outer - radius) * 20)
        return raw / (raw.sum(dim=0, keepdim=True) + 1e-8)

    def forward(self, x: torch.Tensor, history: Optional[List[torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # `history` is accepted only for interface parity with the SBM and ignored.
        spectrum = torch.fft.fftshift(torch.fft.fft2(x.squeeze(1)))
        recon = torch.zeros_like(spectrum)
        entropies = []
        for mask_2d, block in zip(self.masks, self.band_blocks):
            mask = mask_2d.unsqueeze(0).to(x.device)
            masked = spectrum * mask
            feats = torch.stack([masked.real, masked.imag], dim=1)
            feats = feats + block(feats)
            # Accumulate bands in order, exactly as a separate summation pass would.
            recon = recon + torch.complex(feats[:, 0], feats[:, 1]) * mask
            entropies.append(torch.tensor(0.0, device=x.device))
        spatial = torch.fft.ifft2(torch.fft.ifftshift(recon)).real
        return spatial.unsqueeze(1), {"entropy_per_band": torch.stack(entropies)}

## 3. Training

In [None]:
def train_predictor(
    model: nn.Module,
    train_data: Tuple[torch.Tensor, torch.Tensor],
    epochs: int,
    lr: float,
    device: str,
    use_differential_lr: bool = False,
    predict_delta: bool = True,
    batch_size: int = 32,
) -> Dict[str, List[float]]:
    """Train a predictor on next-frame prediction.

    Args:
        model: Module whose ``forward(x, history)`` returns ``(pred, belief)``.
            If ``belief`` contains ``"current_feat"``, it is appended to a
            rolling history that is passed to the next minibatch.
        train_data: ``(fields, next_fields)`` input/target tensors sharing the
            leading (sample) dimension.
        epochs: Number of passes over the data.
        lr: Base Adam learning rate.
        device: Device the data is moved to.
        use_differential_lr: Use ``model.get_lr_groups(lr)`` as optimizer
            parameter groups when the model provides it.
        predict_delta: Compute the loss as MSE(pred - x, y - x). NOTE: since
            ``(pred - x) - (y - x) == pred - y``, this is the same loss as
            absolute mode up to floating-point rounding.
        batch_size: Minibatch size; the last batch may be smaller.

    Returns:
        ``{"loss": [...], "epoch_time": [...]}``: one entry per epoch with the
        mean minibatch loss and the wall-clock seconds.

    Raises:
        ValueError: If ``train_data`` contains no samples.
    """
    fields, next_fields = train_data
    fields = fields.to(device)
    next_fields = next_fields.to(device)
    num_samples = fields.shape[0]
    if num_samples == 0:
        raise ValueError("train_data must contain at least one sample")

    if use_differential_lr and hasattr(model, 'get_lr_groups'):
        optimizer = optim.Adam(model.get_lr_groups(lr))
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr)

    # Cap the rolling history at what the model's temporal attention accepts.
    # The SBM builds its causal mask for history_len past steps plus the
    # current frame. Models without cfg.history_len keep the old cap of 8.
    max_history = getattr(getattr(model, "cfg", None), "history_len", 8)

    metrics = {"loss": [], "epoch_time": []}
    # Make sure dropout/normalization layers (if any) are in training mode,
    # even if the model was put in eval() earlier.
    model.train()

    for epoch in range(epochs):
        epoch_start = time.time()
        epoch_losses = []
        # NOTE(review): the history holds features from the previous *shuffled*
        # minibatches, not earlier frames of the same trajectories.
        history = []
        perm = torch.randperm(num_samples)

        for i in range(0, num_samples, batch_size):
            idx = perm[i:i + batch_size]
            x = fields[idx]
            y = next_fields[idx]

            pred, belief = model(x, history if history else None)

            if predict_delta:
                loss = F.mse_loss(pred - x, y - x)
            else:
                loss = F.mse_loss(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

            if "current_feat" in belief:
                history.append(belief["current_feat"])
                if len(history) > max_history:
                    # Slice by start index so a cap of 0 yields [] (history[-0:] would not).
                    history = history[len(history) - max_history:]

        epoch_time = time.time() - epoch_start
        avg_loss = sum(epoch_losses) / len(epoch_losses)
        metrics["loss"].append(avg_loss)
        metrics["epoch_time"].append(epoch_time)
        print(f"  Epoch {epoch+1}/{epochs}: loss={avg_loss:.6f}, time={epoch_time:.2f}s")

    return metrics

## 4. Run Experiment

In [None]:
# Initialize environment
print("[1] Initializing turbulent plasma environment...")

# Explicit preset dispatch: a mistyped DIFFICULTY now fails loudly instead of
# silently falling back to the "medium" preset.
_DIFFICULTY_PRESETS = {
    "easy": PlasmaConfig.easy,
    "medium": PlasmaConfig.medium,
    "hard": PlasmaConfig.hard,
    "turbulent": PlasmaConfig.turbulent,
}
if DIFFICULTY not in _DIFFICULTY_PRESETS:
    raise ValueError(
        f"Unknown DIFFICULTY {DIFFICULTY!r}; expected one of {sorted(_DIFFICULTY_PRESETS)}"
    )
plasma_cfg = _DIFFICULTY_PRESETS[DIFFICULTY](device=DEVICE, size=GRID_SIZE)

print(f"    Diffusion: {plasma_cfg.diffusion}, Advection: {plasma_cfg.advection}")
print(f"    Noise: {plasma_cfg.noise_std}, Disturbance: {plasma_cfg.disturbance_prob}@{plasma_cfg.disturbance_strength}")
print(f"    Actuator sigma: {plasma_cfg.actuator_sigma:.1f} (scaled for {GRID_SIZE}x{GRID_SIZE})")
print(f"    Vortices: {plasma_cfg.num_vortices}, strength: {plasma_cfg.vortex_strength}")
print(f"    Shear: {plasma_cfg.shear_strength}, Multiscale noise: {plasma_cfg.multiscale_noise}")

env = TurbulentPlasmaEnv(plasma_cfg)

In [None]:
# Generate training data: inputs at t, targets at t + PREDICTION_HORIZON.
print("[2] Generating training data...")
fields, next_fields, controls = generate_trajectories(
    env,
    NUM_TRAJECTORIES,
    TRAJECTORY_LENGTH,
    control_scale=0.1,
    prediction_horizon=PREDICTION_HORIZON,
)
num_generated = fields.shape[0]
print(f"    Generated {num_generated} samples (predicting t+{PREDICTION_HORIZON})")

# Only (input, target) pairs are used for training; controls are not fed to the models.
train_data = (fields, next_fields)

In [None]:
# Initialize models (construction order is kept fixed: it determines weight-init RNG draws).
print("[3] Initializing models...")

spectral_cfg = SpectralConfig(
    height=GRID_SIZE,
    width=GRID_SIZE,
    device=DEVICE,
    top_k=TOP_K_TEMPORAL,
    temporal_decay=TEMPORAL_DECAY,
)
sbm = SpectralBeliefMachineV2(spectral_cfg)
_sbm_param_count = sum(p.numel() for p in sbm.parameters())
print(f"    SpectralBeliefMachine v2: {_sbm_param_count:,} params")
print(f"      Top-K: {spectral_cfg.top_k}, Decay: {spectral_cfg.temporal_decay}")

baseline_cfg = BaselineConfig(height=GRID_SIZE, width=GRID_SIZE, device=DEVICE)
flat = FlatBaseline(baseline_cfg)
four_band = FourBandBaseline(baseline_cfg)

for _label, _model in (("FlatBaseline", flat), ("FourBandBaseline", four_band)):
    print(f"    {_label}: {sum(p.numel() for p in _model.parameters()):,} params")

# Confirm where the models and the training data actually live.
print("\n[GPU CHECK]")
print(f"    Target device: {DEVICE}")
print(f"    SBM on: {next(sbm.parameters()).device}")
print(f"    Flat on: {next(flat.parameters()).device}")
print(f"    Training data on: {fields.device}")
if DEVICE == 'cuda' and torch.cuda.is_available():
    print(f"    GPU Memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated")

In [None]:
# Train SBM v2 with per-band (differential) learning rates.
print(f"\n{'=' * 60}")
print("Training SpectralBeliefMachine v2 (7+1 with Top-K)...")
print("=" * 60)

sbm_metrics = train_predictor(
    sbm,
    train_data,
    EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=True,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

In [None]:
# Train FlatBaseline with a single learning rate.
print(f"\n{'=' * 60}")
print("Training FlatBaseline...")
print("=" * 60)

flat_metrics = train_predictor(
    flat,
    train_data,
    EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=False,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

In [None]:
# Train FourBandBaseline with a single learning rate.
print(f"\n{'=' * 60}")
print("Training FourBandBaseline...")
print("=" * 60)

four_band_metrics = train_predictor(
    four_band,
    train_data,
    EPOCHS,
    lr=LR,
    device=DEVICE,
    use_differential_lr=False,
    predict_delta=PREDICT_DELTA,
    batch_size=BATCH_SIZE,
)

## 5. Results

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax = axes[0]
ax.semilogy(sbm_metrics['loss'], label=f'SBM v2 (7+1 Top-K) - final: {sbm_metrics["loss"][-1]:.6f}')
ax.semilogy(flat_metrics['loss'], label=f'FlatBaseline - final: {flat_metrics["loss"][-1]:.6f}')
ax.semilogy(four_band_metrics['loss'], label=f'FourBandBaseline - final: {four_band_metrics["loss"][-1]:.6f}')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (MSE)')
ax.set_title('Training Loss Curves (Turbulent Environment)')
ax.legend()
ax.grid(True, alpha=0.3)

# Training speed
ax = axes[1]
avg_times = [
    sum(sbm_metrics['epoch_time']) / len(sbm_metrics['epoch_time']),
    sum(flat_metrics['epoch_time']) / len(flat_metrics['epoch_time']),
    sum(four_band_metrics['epoch_time']) / len(four_band_metrics['epoch_time']),
]
bars = ax.bar(['SBM v2 (7+1)', 'FlatBaseline', 'FourBand'], avg_times)
ax.set_ylabel('Seconds per Epoch')
ax.set_title('Training Speed')
for bar, t in zip(bars, avg_times):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(), f'{t:.1f}s', 
            ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Final results
print("\n" + "="*60)
print("FINAL RESULTS (Turbulent Environment)")
print("="*60)
print(f"SpectralBeliefMachine v2 (7+1 Top-K): {sbm_metrics['loss'][-1]:.6f}")
print(f"FlatBaseline:                        {flat_metrics['loss'][-1]:.6f}")
print(f"FourBandBaseline:                    {four_band_metrics['loss'][-1]:.6f}")
print(f"")
print(f"SBM improvement over epochs: {sbm_metrics['loss'][0]:.4f} -> {sbm_metrics['loss'][-1]:.4f} ({sbm_metrics['loss'][0]/sbm_metrics['loss'][-1]:.1f}x)")
print(f"Flat improvement over epochs: {flat_metrics['loss'][0]:.4f} -> {flat_metrics['loss'][-1]:.4f} ({flat_metrics['loss'][0]/flat_metrics['loss'][-1]:.1f}x)")

In [None]:
# Visualize predictions
sbm.eval()
flat.eval()

with torch.no_grad():
    idx = 500
    x_sample = fields[idx:idx+1].to(DEVICE)
    y_sample = next_fields[idx:idx+1].to(DEVICE)
    
    sbm_pred, _ = sbm(x_sample, None)
    flat_pred, _ = flat(x_sample, None)
    
    sbm_error = (sbm_pred - y_sample).abs()
    flat_error = (flat_pred - y_sample).abs()

fig, axes = plt.subplots(2, 4, figsize=(16, 8))

axes[0, 0].imshow(x_sample[0, 0].cpu(), cmap='viridis')
axes[0, 0].set_title('Input (t)')
axes[0, 0].axis('off')

axes[0, 1].imshow(y_sample[0, 0].cpu(), cmap='viridis')
axes[0, 1].set_title(f'Target (t+{PREDICTION_HORIZON})')
axes[0, 1].axis('off')

axes[0, 2].imshow(sbm_pred[0, 0].cpu(), cmap='viridis')
axes[0, 2].set_title('SBM v2 Prediction')
axes[0, 2].axis('off')

axes[0, 3].imshow(flat_pred[0, 0].cpu(), cmap='viridis')
axes[0, 3].set_title('Flat Prediction')
axes[0, 3].axis('off')

axes[1, 0].axis('off')
axes[1, 1].axis('off')

im1 = axes[1, 2].imshow(sbm_error[0, 0].cpu(), cmap='hot', vmin=0)
axes[1, 2].set_title(f'SBM Error (MSE: {F.mse_loss(sbm_pred, y_sample).item():.4f})')
axes[1, 2].axis('off')

im2 = axes[1, 3].imshow(flat_error[0, 0].cpu(), cmap='hot', vmin=0)
axes[1, 3].set_title(f'Flat Error (MSE: {F.mse_loss(flat_pred, y_sample).item():.4f})')
axes[1, 3].axis('off')

plt.tight_layout()
plt.show()