# Experiment 032 v2.3: Pre-FFT vs Post-FFT Control Conditioning Ablation

**Purpose**: Test whether the control signal should be applied BEFORE or AFTER the FFT transform.

**Key Insight from v2.2**: Post-FFT control conditioning destroys spatial locality.
When control features are concatenated AFTER the FFT transform, learned actuator maps
become identical concentric ring patterns instead of localized Gaussian blobs at
different spatial positions.

**Ablation Tests (5 models)**:
1. **SBM_PostFFT_Known**: Original flawed design with known maps (for comparison)
2. **SBM_PreFFT_Known**: Correct design with known maps - control applied BEFORE FFT
3. **SBM_PreFFT_Learned**: Correct design with learned maps - tests structure discovery
4. **Flat_Known**: Baseline ConvNet with known actuator physics
5. **Flat_Learned**: Baseline ConvNet learning actuator locations

**Architecture Comparison**:
- Post-FFT (Wrong): `x -> FFT -> bands -> [concat ctrl_feat] -> process -> iFFT -> output`
- Pre-FFT (Correct): `x + ctrl_spatial -> FFT -> bands -> process -> iFFT -> output`

**Key Hypothesis**: Pre-FFT control preserves spatial locality, enabling proper
localized actuator pattern learning.

**Run on Colab with GPU**: Runtime -> Change runtime type -> A100

In [None]:
import torch
# Report the library version and whether a CUDA device is usable; name the GPU if so.
print(f"PyTorch: {torch.__version__}\nCUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Data / optimisation settings
GRID_SIZE = 64
EPOCHS = 50
DIFFICULTY = "turbulent"
PREDICT_DELTA = True
PREDICTION_HORIZON = 5
NUM_TRAJECTORIES = 100
TRAJECTORY_LENGTH = 100
BATCH_SIZE = 32
LR = 1e-3

# Attention
TOP_K_TEMPORAL = 4
TEMPORAL_DECAY = 0.9

# Control
NUM_ACTUATORS = 9
CONTROL_EMBED_DIM = 32

# Echo the key settings so the notebook output records what was run.
print("\n".join([
    f"Device: {DEVICE}",
    f"Grid: {GRID_SIZE}x{GRID_SIZE}",
    f"Epochs: {EPOCHS}",
    f"Difficulty: {DIFFICULTY}",
    f"Prediction: delta t+{PREDICTION_HORIZON}",
    f"Actuators: {NUM_ACTUATORS} (3x3 grid)",
]))

## 1. Environment

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional
import math
import time

@dataclass
class PlasmaConfig:
    """Settings consumed by ``TurbulentPlasmaEnv``: grid size, transport, noise and actuators."""
    # Grid dimensions in cells.
    height: int = 64
    width: int = 64
    # Weight of the Laplacian term added to the field each step.
    diffusion: float = 0.25
    # Blend weight between the field and its rolled (shifted) copy each step.
    advection: float = 0.08
    # Base standard deviation of the per-step noise.
    noise_std: float = 0.02
    # NOTE(review): the two disturbance_* fields are not read by the environment
    # code shown in this file.
    disturbance_prob: float = 0.1
    disturbance_strength: float = 0.15
    # Number and amplitude of the static vortices; 0 vortices disables the vortex flow.
    num_vortices: int = 3
    vortex_strength: float = 0.1
    # Amplitude of the sinusoidal shear flow; 0 disables it.
    shear_strength: float = 0.05
    # True: sum of noise at several spatial scales; False: plain white noise.
    multiscale_noise: bool = True
    # Number of Gaussian actuators placed on a square grid.
    num_actuators: int = 9
    # Actuator Gaussian width, in cells, for a 64-cell grid (see ``actuator_sigma``).
    _base_actuator_sigma: float = 5.0
    device: str = "cpu"
    dtype: torch.dtype = torch.float32

    @property
    def actuator_sigma(self) -> float:
        """Actuator width rescaled in proportion to the smaller grid side (64 is the reference)."""
        return self._base_actuator_sigma * min(self.height, self.width) / 64.0

    @classmethod
    def turbulent(cls, device: str = "cpu", size: int = 64, num_actuators: int = 9) -> "PlasmaConfig":
        """Build the stronger-transport, stronger-noise preset on a square grid.

        Args:
            device: torch device string stored on the config.
            size: side length of the square grid.
            num_actuators: number of actuators; defaults to 9, the value the
                preset previously hard-coded.
        """
        return cls(height=size, width=size, diffusion=0.3, advection=0.12, noise_std=0.03,
                   disturbance_prob=0.25, disturbance_strength=0.3,
                   num_vortices=3, vortex_strength=0.15, shear_strength=0.08,
                   multiscale_noise=True, num_actuators=num_actuators, device=device)


class TurbulentPlasmaEnv:
    """Batched 2-D scalar-field simulator driven by Gaussian actuators.

    The actuator footprints, vortex flow and shear flow are built once in
    ``__init__``. ``step`` then applies diffusion, advection, actuator forcing
    and optional noise, and clamps the result to [-1, 1].
    """
    def __init__(self, cfg: PlasmaConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        self.dtype = cfg.dtype
        # Static structures; the vortex centres are drawn from the global RNG here.
        self._actuator_maps = self._build_actuator_maps()
        self._vortex_flow = self._build_vortex_flow()
        self._shear_flow = self._build_shear_flow()
    
    def _build_actuator_maps(self) -> torch.Tensor:
        """Return a (num_actuators, H, W) stack of Gaussian bumps.

        Centres lie on a ceil(sqrt(num_actuators)) square lattice spanning
        20%-80% of each axis; only the first ``num_actuators`` lattice points
        are kept.
        """
        h, w = self.cfg.height, self.cfg.width
        grid_n = int(math.ceil(math.sqrt(self.cfg.num_actuators)))
        centers_y = torch.linspace(h * 0.2, h * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers_x = torch.linspace(w * 0.2, w * 0.8, grid_n, device=self.device, dtype=self.dtype)
        centers = torch.cartesian_prod(centers_y, centers_x)[:self.cfg.num_actuators]
        sig2 = self.cfg.actuator_sigma ** 2
        yy, xx = torch.meshgrid(
            torch.arange(h, device=self.device, dtype=self.dtype),
            torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        bumps = [torch.exp(-((yy - cy) ** 2 + (xx - cx) ** 2) / (2 * sig2)) for cy, cx in centers]
        return torch.stack(bumps, dim=0)
    
    def _build_vortex_flow(self):
        """Build the static vortex velocity field ``(vy, vx)``, each of shape (H, W).

        Vortex centres are drawn uniformly from the central 60% of each axis and
        alternate in rotation sign. Returns ``(None, None)`` when
        ``num_vortices`` is 0.
        """
        if self.cfg.num_vortices == 0: return None, None
        h, w = self.cfg.height, self.cfg.width
        centers_y = torch.rand(self.cfg.num_vortices, device=self.device) * (h * 0.6) + (h * 0.2)
        centers_x = torch.rand(self.cfg.num_vortices, device=self.device) * (w * 0.6) + (w * 0.2)
        yy, xx = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                                torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        vy, vx = torch.zeros_like(yy), torch.zeros_like(xx)
        # The Gaussian envelope radius (10 cells at a 64-cell grid) scales with grid size.
        scale = min(h, w) / 64.0
        for i in range(self.cfg.num_vortices):
            dy, dx = yy - centers_y[i], xx - centers_x[i]
            # The small epsilon avoids division by zero at the vortex centre.
            r2 = dy**2 + dx**2 + 1e-6
            decay = torch.exp(-r2 / (2 * (10 * scale)**2))
            sign = 1 if i % 2 == 0 else -1
            # Tangential unit vector (-dx, dy)/r, weighted by strength and envelope.
            vy += sign * self.cfg.vortex_strength * (-dx) / torch.sqrt(r2) * decay
            vx += sign * self.cfg.vortex_strength * dy / torch.sqrt(r2) * decay
        return vy, vx
    
    def _build_shear_flow(self):
        """Build the static shear flow: zero ``vy`` and ``vx = shear_strength * sin(2*pi*y/h)``.

        Returns ``(None, None)`` when ``shear_strength`` is 0.
        """
        if self.cfg.shear_strength == 0: return None, None
        h, w = self.cfg.height, self.cfg.width
        yy, _ = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                               torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        return torch.zeros_like(yy), self.cfg.shear_strength * torch.sin(2 * math.pi * yy / h)
    
    def _apply_flow(self, field, vy, vx):
        """Warp ``field`` (B, C, H, W) along the velocity field with bilinear sampling.

        Each output cell samples the input at its own position minus the
        velocity. Velocities are divided by half the grid size to express them
        in ``grid_sample``'s normalised [-1, 1] coordinates. Samples that fall
        outside the grid take the border value.
        """
        B, C, H, W = field.shape
        # NOTE(review): the base grid uses the default dtype rather than self.dtype.
        yy, xx = torch.meshgrid(torch.linspace(-1, 1, H, device=self.device),
                                torch.linspace(-1, 1, W, device=self.device), indexing="ij")
        # grid_sample expects the last dimension ordered (x, y).
        grid = torch.stack([xx - vx/(W/2), yy - vy/(H/2)], dim=-1).unsqueeze(0).expand(B, -1, -1, -1)
        return F.grid_sample(field, grid, mode='bilinear', padding_mode='border', align_corners=True)
    
    def _multiscale_noise(self, shape):
        """Return noise of the given (B, C, H, W) shape summed over several spatial scales.

        For each scale in 1, 2, 4, 8, Gaussian noise is drawn on a grid coarsened
        by that factor, upsampled bilinearly, and weighted by ``noise_std / scale``.
        Scales whose coarse height would be below 4 are skipped.
        """
        B, C, H, W = shape
        noise = torch.zeros(shape, device=self.device, dtype=self.dtype)
        for scale in [1, 2, 4, 8]:
            hs, ws = H // scale, W // scale
            if hs < 4: continue
            coarse = torch.randn(B, C, hs, ws, device=self.device, dtype=self.dtype)
            noise += F.interpolate(coarse, size=(H, W), mode='bilinear', align_corners=False) * (self.cfg.noise_std / scale)
        return noise
    
    def reset(self, batch_size=1):
        """Return a fresh (batch_size, 1, H, W) field holding one Gaussian blob per sample.

        Blob centres are random integers within 30%-70% of each axis; the blob
        width is 1.5x the actuator sigma.
        """
        h, w = self.cfg.height, self.cfg.width
        cx = torch.randint(int(w*0.3), int(w*0.7), (batch_size,), device=self.device)
        cy = torch.randint(int(h*0.3), int(h*0.7), (batch_size,), device=self.device)
        yy, xx = torch.meshgrid(torch.arange(h, device=self.device, dtype=self.dtype),
                                torch.arange(w, device=self.device, dtype=self.dtype), indexing="ij")
        sig2 = (self.cfg.actuator_sigma * 1.5) ** 2
        field = torch.zeros(batch_size, 1, h, w, device=self.device, dtype=self.dtype)
        for b in range(batch_size):
            field[b, 0] = torch.exp(-((yy - cy[b].float())**2 + (xx - cx[b].float())**2) / (2*sig2))
        return field
    
    def step(self, field, control, noise=True):
        """Advance the field by one time step.

        Args:
            field: current state, indexed as (B, C, H, W) by the code below.
            control: per-actuator amplitudes, (B, num_actuators) per the einsum.
            noise: when True, add multiscale or white Gaussian noise, chosen by
                ``cfg.multiscale_noise``.

        Returns:
            The next field, clamped to [-1, 1].
        """
        B = field.shape[0]
        # Diffusion
        # 5-point Laplacian from four one-cell shifts; F.pad's default zero fill
        # makes cells outside the grid count as 0.
        lap = (F.pad(field, (0,0,1,0))[:,:,:-1,:] + F.pad(field, (0,0,0,1))[:,:,1:,:] +
               F.pad(field, (1,0,0,0))[:,:,:,:-1] + F.pad(field, (0,1,0,0))[:,:,:,1:]) - 4*field
        diffused = field + self.cfg.diffusion * lap
        # Advection
        # Blend with a copy rolled by +1 row and -1 column (torch.roll wraps around).
        advected = torch.roll(diffused, shifts=(1,-1), dims=(2,3)) * self.cfg.advection + diffused * (1 - self.cfg.advection)
        if self._vortex_flow[0] is not None:
            advected = self._apply_flow(advected, self._vortex_flow[0], self._vortex_flow[1])
        if self._shear_flow[0] is not None:
            advected = self._apply_flow(advected, self._shear_flow[0], self._shear_flow[1])
        # Actuators
        # Weighted sum of the Gaussian footprints, added as a single-channel force.
        force = torch.einsum('ba,ahw->bhw', control, self._actuator_maps).unsqueeze(1)
        next_field = advected + force
        # Noise
        if noise:
            next_field = next_field + (self._multiscale_noise(next_field.shape) if self.cfg.multiscale_noise 
                                       else torch.randn_like(next_field) * self.cfg.noise_std)
        return torch.clamp(next_field, -1, 1)
    
    def get_actuator_maps(self):
        """Return a copy of the actuator Gaussian maps for known control encoding."""
        return self._actuator_maps.clone()


def generate_trajectories(env, num_traj, traj_len, control_scale=0.1, horizon=1):
    """Roll out random-control trajectories and slice them into training samples.

    Each trajectory starts from ``env.reset(1)`` and is driven for
    ``traj_len + horizon`` steps by Gaussian controls scaled by
    ``control_scale`` and clipped to [-1, 1]. Every start index ``t`` in
    ``range(traj_len)`` yields one sample: the state at ``t``, the state at
    ``t + horizon``, and the ``horizon`` controls applied in between.

    Returns:
        ``(fields, next_fields, controls)``: the first two concatenated along
        the batch axis, the controls stacked as ``[N, horizon, num_actuators]``.
    """
    inputs, targets, control_windows = [], [], []
    total_steps = traj_len + horizon

    for _traj in range(num_traj):
        state = env.reset(1)
        states = [state.clone()]
        controls = []

        for _step in range(total_steps):
            raw = torch.randn(1, env.cfg.num_actuators, device=env.device)
            action = torch.clamp(raw * control_scale, -1, 1)
            states.append(env.step(state, action).clone())
            controls.append(action)
            state = states[-1].detach()

        for start in range(traj_len):
            stop = start + horizon
            inputs.append(states[start])
            targets.append(states[stop])
            # Each control is [1, A]; stack along dim 1 and drop the leading 1.
            window = torch.stack(controls[start:stop], dim=1)
            control_windows.append(window.squeeze(0))

    return torch.cat(inputs), torch.cat(targets), torch.stack(control_windows)

# Confirm in the cell output that the environment definitions above were executed.
print("Environment defined.")

## 2. Control Encoders

Three families of control encoders:
- **Post-FFT encoders**: Output `embed_dim` channels for concatenation with FFT bands
- **Pre-FFT (Spatial) encoders**: Output 1 channel for spatial addition before FFT
- **Flat encoders**: Output `embed_dim` channels for concatenation with the ConvNet input

In [None]:
# ============================================================================
# POST-FFT CONTROL ENCODERS (for SBMPostFFT - the flawed design)
# Output: [B, embed_dim, H, W] for concatenation with FFT bands
# ============================================================================

class ControlEncoderKnownPostFFT(nn.Module):
    """Project a control sequence onto fixed actuator maps as ``embed_dim`` feature planes.

    Each actuator's length-``horizon`` command history is embedded by a shared
    MLP; every embedding channel is then spread over the grid through that
    actuator's known footprint. The result is meant for Post-FFT concatenation.
    """

    def __init__(self, actuator_maps: torch.Tensor, horizon: int, embed_dim: int):
        super().__init__()
        # Fixed (non-trainable) footprints, stored as a buffer.
        self.register_buffer('actuator_maps', actuator_maps)
        layers = [
            nn.Linear(horizon, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        ]
        self.temporal_enc = nn.Sequential(*layers)

    def forward(self, ctrl_seq):
        """Map ``[B, horizon, A]`` controls to ``[B, embed_dim, H, W]`` features."""
        per_actuator = ctrl_seq.transpose(2, 1)     # [B, A, horizon]
        embedded = self.temporal_enc(per_actuator)  # [B, A, embed_dim]
        return torch.einsum('bae,ahw->behw', embedded, self.actuator_maps)


# ============================================================================
# PRE-FFT CONTROL ENCODERS (for SBMPreFFT - the correct design)
# Output: [B, 1, H, W] for spatial addition before FFT
# ============================================================================

class ControlEncoderKnownPreFFT(nn.Module):
    """Turn a control sequence into one spatial plane using fixed actuator maps.

    A shared MLP reduces each actuator's command history to a single gain. The
    gains weight the known footprints, and the weighted sum is returned as a
    one-channel image to be added to the field before the FFT.
    """

    def __init__(self, actuator_maps: torch.Tensor, horizon: int):
        super().__init__()
        self.register_buffer('actuator_maps', actuator_maps)
        hidden = 16
        self.temporal_enc = nn.Sequential(
            nn.Linear(horizon, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, ctrl_seq):
        """Map ``[B, horizon, A]`` controls to a ``[B, 1, H, W]`` plane."""
        histories = ctrl_seq.transpose(2, 1)               # [B, A, horizon]
        gains = self.temporal_enc(histories).squeeze(-1)   # [B, A]
        plane = torch.einsum('ba,ahw->bhw', gains, self.actuator_maps)
        return plane[:, None]


class ControlEncoderLearnedPreFFT(nn.Module):
    """Turn a control sequence into one spatial plane using trainable actuator maps.

    Works like the known-map Pre-FFT encoder, except the per-actuator footprints
    are free parameters initialised to small Gaussian noise.
    """

    def __init__(self, num_actuators: int, horizon: int, height: int, width: int):
        super().__init__()
        initial = 0.01 * torch.randn(num_actuators, height, width)
        self.actuator_patterns = nn.Parameter(initial)
        hidden = 16
        self.temporal_enc = nn.Sequential(
            nn.Linear(horizon, hidden),
            nn.GELU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, ctrl_seq):
        """Map ``[B, horizon, A]`` controls to a ``[B, 1, H, W]`` plane."""
        histories = ctrl_seq.transpose(2, 1)               # [B, A, horizon]
        gains = self.temporal_enc(histories).squeeze(-1)   # [B, A]
        plane = torch.einsum('ba,ahw->bhw', gains, self.actuator_patterns)
        return plane[:, None]

    def get_learned_maps(self):
        """Return the current footprints, detached from the autograd graph."""
        patterns = self.actuator_patterns
        return patterns.detach()


# ============================================================================
# FLAT MODEL CONTROL ENCODERS (multi-channel for ConvNet input)
# ============================================================================

class ControlEncoderKnownFlat(nn.Module):
    """Multi-channel control features from fixed actuator maps, for the flat ConvNet.

    Produces ``embed_dim`` planes by embedding each actuator's command history
    and spreading the embedding over the grid through the known footprints.
    """

    def __init__(self, actuator_maps: torch.Tensor, horizon: int, embed_dim: int):
        super().__init__()
        self.register_buffer('actuator_maps', actuator_maps)
        first = nn.Linear(horizon, embed_dim)
        second = nn.Linear(embed_dim, embed_dim)
        self.temporal_enc = nn.Sequential(first, nn.GELU(), second)

    def forward(self, ctrl_seq):
        """Map ``[B, horizon, A]`` controls to ``[B, embed_dim, H, W]`` features."""
        codes = self.temporal_enc(ctrl_seq.transpose(2, 1))  # [B, A, embed_dim]
        return torch.einsum('bae,ahw->behw', codes, self.actuator_maps)


class ControlEncoderLearnedFlat(nn.Module):
    """Multi-channel control features from trainable actuator maps, for the flat ConvNet.

    The footprints start as small Gaussian noise and are optimised together
    with the rest of the model.
    """

    def __init__(self, num_actuators: int, horizon: int, embed_dim: int, height: int, width: int):
        super().__init__()
        initial = 0.01 * torch.randn(num_actuators, height, width)
        self.actuator_patterns = nn.Parameter(initial)
        first = nn.Linear(horizon, embed_dim)
        second = nn.Linear(embed_dim, embed_dim)
        self.temporal_enc = nn.Sequential(first, nn.GELU(), second)

    def forward(self, ctrl_seq):
        """Map ``[B, horizon, A]`` controls to ``[B, embed_dim, H, W]`` features."""
        codes = self.temporal_enc(ctrl_seq.transpose(2, 1))  # [B, A, embed_dim]
        return torch.einsum('bae,ahw->behw', codes, self.actuator_patterns)

    def get_learned_maps(self):
        """Return the current footprints, detached from the autograd graph."""
        patterns = self.actuator_patterns
        return patterns.detach()

# Confirm in the cell output that the control-encoder definitions above were executed.
print("Control encoders defined.")

## 3. Models

In [None]:
@dataclass
class SpectralConfig:
    # Hyperparameters shared by the spectral (SBM) models in this notebook.
    # Grid size; used for the radial masks, the FFT window and learned actuator patterns.
    height: int = 64
    width: int = 64
    # Number of radial frequency bands, each with its own small conv stack.
    num_bands: int = 7
    # Hidden channels of the per-band convs; the temporal block is channels * 8 wide.
    channels: int = 16
    # The temporal attention mask is built for history_len + 1 time steps.
    history_len: int = 8
    # Forwarded to TopKTemporal as heads / top_k / decay.
    num_heads: int = 4
    top_k: int = 4
    decay: float = 0.9
    # Number of actuators (sizes the learned actuator patterns).
    num_actuators: int = 9
    # Number of control steps per sample (input width of the temporal MLPs).
    horizon: int = 5
    # Channels produced by the Post-FFT control encoder.
    ctrl_embed: int = 32
    # Device the model and its buffers are moved to.
    device: str = "cpu"


def make_radial_masks(h, w, num_bands, device):
    """Build ``num_bands`` soft ring masks over an ``h x w`` centred frequency grid.

    The radius is measured on a [-1, 1] x [-1, 1] grid. Band edges are
    log-spaced from 1e-3 to sqrt(2), with the first edge forced to 0. Each mask
    is the product of two steep sigmoids (slope 20) at its inner and outer
    edge. The stack is normalised so the masks sum to ~1 at every cell.

    Returns:
        Tensor of shape ``[num_bands, h, w]``.
    """
    rows = torch.linspace(-1, 1, h, device=device)
    cols = torch.linspace(-1, 1, w, device=device)
    yy, xx = torch.meshgrid(rows, cols, indexing="ij")
    radius = torch.sqrt(yy**2 + xx**2).clamp(min=1e-6)

    outer_limit = math.log10(math.sqrt(2))
    edges = torch.logspace(-3, outer_limit, num_bands+1, device=device)
    edges[0] = 0

    rings = []
    for band in range(num_bands):
        inner, outer = edges[band], edges[band+1]
        rising = torch.sigmoid((radius - inner)*20)
        falling = torch.sigmoid((outer - radius)*20)
        rings.append(rising * falling)
    stacked = torch.stack(rings)

    total = stacked.sum(0, keepdim=True)
    return stacked / (total + 1e-8)


class TopKTemporal(nn.Module):
    """Causal multi-head self-attention that keeps only the ``top_k`` scores per query.

    Args:
        dim: feature width of the input; split evenly across ``heads``.
        heads: number of attention heads.
        max_len: sequence length for which the causal mask is precomputed.
            Longer sequences are still accepted; their mask is built on the fly.
        top_k: number of keys each query may attend to once the sequence is
            longer than ``top_k``.
        decay: stored on the module but not applied anywhere in ``forward``.
    """
    def __init__(self, dim, heads, max_len, top_k=4, decay=0.9):
        super().__init__()
        self.heads, self.head_dim, self.top_k, self.decay = heads, dim//heads, top_k, decay
        self.qkv = nn.Linear(dim, 3*dim)
        self.out = nn.Linear(dim, dim)
        # True above the diagonal = "future" positions that must be hidden.
        self.register_buffer("mask", torch.triu(torch.ones(max_len, max_len), 1).bool())

    def _causal_mask(self, T, device):
        """Return a ``[T, T]`` boolean mask that is True for future positions."""
        if T <= self.mask.shape[0]:
            return self.mask[:T, :T]
        # The precomputed buffer is too small; slicing it would give a
        # [max_len, max_len] mask that cannot broadcast against [T, T] scores.
        return torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)

    def forward(self, x):
        """Attend over a ``[B, T, D]`` sequence.

        Returns:
            ``(out, weight_summary)``: ``out`` is ``[B, T, D]``;
            ``weight_summary`` is ``[B]``, the attention weights averaged over
            heads, summed over keys and averaged over queries.
        """
        B, T, D = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.heads, self.head_dim).permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]
        scores = (Q @ K.transpose(-2,-1)) / math.sqrt(self.head_dim)
        causal = self._causal_mask(T, scores.device)
        scores = scores.masked_fill(causal.unsqueeze(0).unsqueeze(0), float('-inf'))
        if T > self.top_k:
            # Keep the k largest scores per query and mask out the rest. Future
            # positions are already -inf, so they never displace a valid key.
            _, topk_idx = torch.topk(scores, self.top_k, dim=-1)
            drop = torch.ones_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, False)
            scores = scores.masked_fill(drop, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        out = (attn @ V).transpose(1, 2).reshape(B, T, D)
        return self.out(out), attn.mean(1).sum(-1).mean(-1)


# ============================================================================
# SBMPostFFT: FLAWED DESIGN - Control applied AFTER FFT (kept for comparison)
# ============================================================================

class SBMPostFFT(nn.Module):
    """Spectral Belief Machine with POST-FFT control conditioning.
    
    FLAWED DESIGN - kept for ablation comparison.
    
    Problem: Control features concatenated AFTER FFT destroys spatial locality.
    Learned actuator maps become identical concentric rings instead of
    localized Gaussian blobs at different positions.
    
    Architecture: x -> FFT -> bands -> [concat ctrl_feat] -> process -> iFFT -> output

    Args:
        cfg: model hyperparameters (grid size, band count, channel widths, device).
        actuator_maps: fixed actuator footprints handed to the known-map
            control encoder.
    """
    
    def __init__(self, cfg: SpectralConfig, actuator_maps: torch.Tensor):
        super().__init__()
        self.cfg = cfg
        
        # Post-FFT control encoder (multi-channel output)
        self.ctrl_enc = ControlEncoderKnownPostFFT(actuator_maps, cfg.horizon, cfg.ctrl_embed)
        
        # Spectral
        self.register_buffer("masks", make_radial_masks(cfg.height, cfg.width, cfg.num_bands, torch.device(cfg.device)))
        # All-ones window: multiplying by it leaves the input unchanged.
        self.register_buffer("window", torch.ones(cfg.height, cfg.width, device=cfg.device))
        
        # Band blocks: input = 2 (real/imag) + ctrl_embed
        self.bands = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2 + cfg.ctrl_embed, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, 2, 1),
            ) for _ in range(cfg.num_bands)
        ])
        
        # Temporal
        # Per band: 2 channels (real/imag) x 2x2 pooled cells = 8 values.
        tdim = cfg.num_bands * 2 * 4
        self.t_in = nn.Linear(tdim, cfg.channels * 8)
        self.temporal = TopKTemporal(cfg.channels * 8, cfg.num_heads, cfg.history_len + 1, cfg.top_k, cfg.decay)
        self.t_out = nn.Linear(cfg.channels * 8, tdim)
        self.to(cfg.device)
    
    def forward(self, x, ctrl, history=None):
        """Predict the next field from ``x`` and the control sequence.

        Args:
            x: input field; squeezed on dim 1 before the FFT, so a single
                channel is expected.
            ctrl: control sequence passed to the control encoder.
            history: optional list of pooled feature tensors from earlier calls.
                Entries whose batch size or width differ from the current
                pooled features are dropped.

        Returns:
            ``(pred, info)``: ``pred`` has a singleton channel dim restored;
            ``info["feat"]`` is the detached pooled feature vector for the
            caller to append to ``history``.
        """
        B = x.shape[0]
        ctrl_feat = self.ctrl_enc(ctrl)  # [B, embed, H, W]
        
        # FFT decompose
        fft = torch.fft.fftshift(torch.fft.fft2(x.squeeze(1) * self.window))
        proc_bands = []
        for i in range(self.cfg.num_bands):
            band = fft * self.masks[i].unsqueeze(0)
            band_feat = torch.stack([band.real, band.imag], dim=1)
            # POST-FFT: concat control features with band features
            band_with_ctrl = torch.cat([band_feat, ctrl_feat], dim=1)
            proc = self.bands[i](band_with_ctrl)
            # Residual: the conv stack predicts a correction to the band spectrum.
            proc_bands.append(band_feat + proc)
        
        # Pool for temporal
        pooled = torch.cat([F.adaptive_avg_pool2d(b, (2,2)).flatten(1) for b in proc_bands], dim=1)
        
        # Temporal attention
        if history and len(history) > 0:
            history = [h for h in history if h.shape[0] == B and h.shape[1] == pooled.shape[1]]
            if history:
                seq = torch.stack(history + [pooled], dim=1)
                t_out, _ = self.temporal(self.t_in(seq))
                # Only the output for the current (last) time step is used.
                t_feat = self.t_out(t_out[:, -1, :])
                chunk = pooled.shape[1] // self.cfg.num_bands
                for i in range(self.cfg.num_bands):
                    delta = t_feat[:, i*chunk:(i+1)*chunk].view(B, 2, 2, 2)
                    # Tile each 2x2 cell over half the grid; this matches the band
                    # size only when height and width are even.
                    delta_up = delta.repeat_interleave(self.cfg.height//2, 2).repeat_interleave(self.cfg.width//2, 3)
                    proc_bands[i] = proc_bands[i] + 0.1 * delta_up
        
        # Reconstruct
        # Re-apply each band's mask, sum the bands, and invert the shifted FFT.
        recon = sum(torch.complex(b[:,0], b[:,1]) * self.masks[i].unsqueeze(0) for i, b in enumerate(proc_bands))
        pred = torch.fft.ifft2(torch.fft.ifftshift(recon)).real.unsqueeze(1)
        
        return pred, {"feat": pooled.detach()}


# ============================================================================
# SBMPreFFT: CORRECT DESIGN - Control applied BEFORE FFT
# ============================================================================

class SBMPreFFT(nn.Module):
    """Spectral Belief Machine with PRE-FFT control conditioning.
    
    CORRECT DESIGN - preserves spatial locality.
    
    Control is applied spatially BEFORE the FFT transform.
    This allows proper localized actuator pattern learning because
    spatial locality is preserved before the frequency transform.
    
    Architecture: x + ctrl_spatial -> FFT -> bands -> process -> iFFT -> output

    Args:
        cfg: model hyperparameters (grid size, band count, channel widths, device).
        actuator_maps: fixed actuator footprints; required when ``mode == 'known'``.
        mode: ``'known'`` uses the fixed footprints; any other value selects the
            learned-pattern encoder.
    """
    
    def __init__(self, cfg: SpectralConfig, actuator_maps=None, mode='known'):
        super().__init__()
        self.cfg = cfg
        self.mode = mode
        
        # Pre-FFT control encoder (1-channel spatial output)
        if mode == 'known':
            assert actuator_maps is not None, "actuator_maps required for known mode"
            self.ctrl_enc = ControlEncoderKnownPreFFT(actuator_maps, cfg.horizon)
        else:
            self.ctrl_enc = ControlEncoderLearnedPreFFT(cfg.num_actuators, cfg.horizon, cfg.height, cfg.width)
        
        # Spectral
        self.register_buffer("masks", make_radial_masks(cfg.height, cfg.width, cfg.num_bands, torch.device(cfg.device)))
        # All-ones window: multiplying by it leaves the input unchanged.
        self.register_buffer("window", torch.ones(cfg.height, cfg.width, device=cfg.device))
        
        # Band blocks: input = 2 (real/imag) only - NO ctrl_embed
        self.bands = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(2, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, cfg.channels, 3, padding=1), nn.GELU(),
                nn.Conv2d(cfg.channels, 2, 1),
            ) for _ in range(cfg.num_bands)
        ])
        
        # Temporal
        # Per band: 2 channels (real/imag) x 2x2 pooled cells = 8 values.
        tdim = cfg.num_bands * 2 * 4
        self.t_in = nn.Linear(tdim, cfg.channels * 8)
        self.temporal = TopKTemporal(cfg.channels * 8, cfg.num_heads, cfg.history_len + 1, cfg.top_k, cfg.decay)
        self.t_out = nn.Linear(cfg.channels * 8, tdim)
        self.to(cfg.device)
    
    def forward(self, x, ctrl, history=None):
        """Predict the next field from ``x`` and the control sequence.

        Args:
            x: input field; squeezed on dim 1 before the FFT, so a single
                channel is expected.
            ctrl: control sequence passed to the control encoder.
            history: optional list of pooled feature tensors from earlier calls.
                Entries whose batch size or width differ from the current
                pooled features are dropped.

        Returns:
            ``(pred, info)``: ``pred`` has a singleton channel dim restored;
            ``info["feat"]`` is the detached pooled feature vector for the
            caller to append to ``history``.
        """
        B = x.shape[0]
        
        # PRE-FFT: Apply control spatially BEFORE FFT
        ctrl_spatial = self.ctrl_enc(ctrl)  # [B, 1, H, W]
        x_conditioned = x + ctrl_spatial
        
        # FFT decompose on conditioned input
        fft = torch.fft.fftshift(torch.fft.fft2(x_conditioned.squeeze(1) * self.window))
        proc_bands = []
        for i in range(self.cfg.num_bands):
            band = fft * self.masks[i].unsqueeze(0)
            band_feat = torch.stack([band.real, band.imag], dim=1)
            # No control concat here - just process the band
            proc = self.bands[i](band_feat)
            # Residual: the conv stack predicts a correction to the band spectrum.
            proc_bands.append(band_feat + proc)
        
        # Pool for temporal
        pooled = torch.cat([F.adaptive_avg_pool2d(b, (2,2)).flatten(1) for b in proc_bands], dim=1)
        
        # Temporal attention
        if history and len(history) > 0:
            history = [h for h in history if h.shape[0] == B and h.shape[1] == pooled.shape[1]]
            if history:
                seq = torch.stack(history + [pooled], dim=1)
                t_out, _ = self.temporal(self.t_in(seq))
                # Only the output for the current (last) time step is used.
                t_feat = self.t_out(t_out[:, -1, :])
                chunk = pooled.shape[1] // self.cfg.num_bands
                for i in range(self.cfg.num_bands):
                    delta = t_feat[:, i*chunk:(i+1)*chunk].view(B, 2, 2, 2)
                    # Tile each 2x2 cell over half the grid; this matches the band
                    # size only when height and width are even.
                    delta_up = delta.repeat_interleave(self.cfg.height//2, 2).repeat_interleave(self.cfg.width//2, 3)
                    proc_bands[i] = proc_bands[i] + 0.1 * delta_up
        
        # Reconstruct
        # Re-apply each band's mask, sum the bands, and invert the shifted FFT.
        recon = sum(torch.complex(b[:,0], b[:,1]) * self.masks[i].unsqueeze(0) for i, b in enumerate(proc_bands))
        pred = torch.fft.ifft2(torch.fft.ifftshift(recon)).real.unsqueeze(1)
        
        return pred, {"feat": pooled.detach()}


# ============================================================================
# FlatWithControl: Baseline ConvNet
# ============================================================================

class FlatWithControl(nn.Module):
    """Flat ConvNet baseline with control conditioning.

    The control sequence is encoded into ``ctrl_embed`` spatial planes, which
    are concatenated with the single-channel field and passed through a
    five-layer 3x3 conv stack that outputs one channel.

    Args:
        height, width: grid size (used for the learned actuator patterns).
        channels: base channel width of the conv stack.
        num_actuators: number of actuators (learned mode).
        horizon: number of control steps per sample.
        ctrl_embed: number of control feature planes.
        actuator_maps: fixed actuator footprints; required when ``mode == 'known'``.
        mode: ``'known'`` uses the fixed footprints; any other value selects the
            learned-pattern encoder.
        device: device the module is moved to.

    Raises:
        ValueError: if ``mode == 'known'`` and ``actuator_maps`` is ``None``.
    """

    def __init__(self, height, width, channels, num_actuators, horizon, ctrl_embed, actuator_maps=None, mode='known', device='cpu'):
        super().__init__()
        self.mode = mode

        if mode == 'known':
            # Fail here rather than later inside einsum with a None buffer.
            if actuator_maps is None:
                raise ValueError("actuator_maps required for known mode")
            self.ctrl_enc = ControlEncoderKnownFlat(actuator_maps, horizon, ctrl_embed)
        else:
            self.ctrl_enc = ControlEncoderLearnedFlat(num_actuators, horizon, ctrl_embed, height, width)

        self.net = nn.Sequential(
            nn.Conv2d(1 + ctrl_embed, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels*2, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels*2, channels, 3, padding=1), nn.GELU(),
            nn.Conv2d(channels, 1, 3, padding=1),
        )
        self.to(device)

    def forward(self, x, ctrl, history=None):
        """Return ``(prediction, {})``.

        ``history`` is accepted for signature parity with the spectral models
        and is ignored; the empty info dict tells the training loop there are
        no features to remember.
        """
        ctrl_feat = self.ctrl_enc(ctrl)
        return self.net(torch.cat([x, ctrl_feat], dim=1)), {}

# Confirm in the cell output that the model definitions above were executed.
print("Models defined.")

## 4. Training

In [None]:
def train(model, data, epochs, lr, device, delta=True, batch=32):
    """Train ``model`` with Adam on ``(fields, targets, controls)`` and return per-epoch mean losses.

    Args:
        model: callable as ``model(x, ctrl, history)`` returning ``(pred, info)``.
            If ``info`` contains ``"feat"``, it is appended to a rolling history
            that is passed to the next call within the same epoch.
        data: three tensors ``(fields, targets, ctrls)`` with a shared first dimension.
        epochs: number of passes over the data.
        lr: Adam learning rate.
        device: device the data tensors are moved to.
        delta: selects ``mse(pred - x, y - x)`` over ``mse(pred, y)``. ``x``
            cancels in the difference, so both forms are the same objective;
            the flag is kept for interface compatibility.
        batch: mini-batch size.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if the dataset is empty.
    """
    fields, targets, ctrls = [d.to(device) for d in data]
    if len(fields) == 0:
        raise ValueError("train() received an empty dataset")
    opt = optim.Adam(model.parameters(), lr=lr)
    # The rolling history must not exceed what the model's temporal block was
    # built for (history_len + 1 steps including the current one). Models
    # without a cfg keep the previous fixed window of 8.
    max_history = getattr(getattr(model, "cfg", None), "history_len", 8)
    losses = []

    for ep in range(epochs):
        t0 = time.time()
        # History is reset every epoch; features never carry over between epochs.
        ep_loss, history = [], []
        perm = torch.randperm(len(fields))

        for i in range(0, len(fields), batch):
            idx = perm[i:i+batch]
            x, y, c = fields[idx], targets[idx], ctrls[idx]

            pred, info = model(x, c, history if history else None)

            loss = F.mse_loss(pred - x, y - x) if delta else F.mse_loss(pred, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            ep_loss.append(loss.item())

            if "feat" in info:
                history.append(info["feat"])
                # history[-0:] would keep everything, so handle 0 explicitly.
                history = history[-max_history:] if max_history > 0 else []

        avg = sum(ep_loss) / len(ep_loss)
        losses.append(avg)
        print(f"  Ep {ep+1}/{epochs}: loss={avg:.6f}, time={time.time()-t0:.1f}s")

    return losses

# Confirm in the cell output that the training function above was defined.
print("Training function defined.")

## 5. Run Ablation

In [None]:
# Setup
print("[1] Creating environment...")
plasma_cfg = PlasmaConfig.turbulent(device=DEVICE, size=GRID_SIZE)
# Keep the simulator's actuator count in step with the notebook-level setting.
# The preset fixes it at 9, so changing NUM_ACTUATORS alone would leave the
# controls and the models disagreeing about the number of actuators.
plasma_cfg.num_actuators = NUM_ACTUATORS
env = TurbulentPlasmaEnv(plasma_cfg)
actuator_maps = env.get_actuator_maps()
print(f"    Actuator maps shape: {actuator_maps.shape}")

print("\n[2] Generating data...")
# data = (fields, next_fields, controls); targets lie PREDICTION_HORIZON steps ahead.
data = generate_trajectories(env, NUM_TRAJECTORIES, TRAJECTORY_LENGTH, 0.1, PREDICTION_HORIZON)
print(f"    Fields: {data[0].shape}, Controls: {data[2].shape}")

In [None]:
# Create all 5 models for ablation
print("\n[3] Creating models...")

# Pass the notebook-level attention and actuator settings through explicitly.
# Previously they were defined but never used, and the config relied on its
# defaults happening to match.
cfg = SpectralConfig(height=GRID_SIZE, width=GRID_SIZE,
                     top_k=TOP_K_TEMPORAL, decay=TEMPORAL_DECAY,
                     num_actuators=NUM_ACTUATORS, horizon=PREDICTION_HORIZON,
                     ctrl_embed=CONTROL_EMBED_DIM, device=DEVICE)

# Construction order is kept as-is: parameter initialisation draws from the
# global RNG, so reordering would change the starting weights of each model.
models = {
    # Post-FFT (flawed design) - for comparison
    'SBM_PostFFT_Known': SBMPostFFT(cfg, actuator_maps),
    
    # Pre-FFT (correct design) - known maps
    'SBM_PreFFT_Known': SBMPreFFT(cfg, actuator_maps, mode='known'),
    
    # Pre-FFT (correct design) - learned maps
    'SBM_PreFFT_Learned': SBMPreFFT(cfg, None, mode='learned'),
    
    # Flat baselines
    'Flat_Known': FlatWithControl(GRID_SIZE, GRID_SIZE, 32, NUM_ACTUATORS, PREDICTION_HORIZON, 
                                   CONTROL_EMBED_DIM, actuator_maps, 'known', DEVICE),
    'Flat_Learned': FlatWithControl(GRID_SIZE, GRID_SIZE, 32, NUM_ACTUATORS, PREDICTION_HORIZON,
                                     CONTROL_EMBED_DIM, None, 'learned', DEVICE),
}

# Report the parameter count of each model for a size comparison.
for name, model in models.items():
    params = sum(p.numel() for p in model.parameters())
    print(f"    {name}: {params:,} params")

In [None]:
# Train every model on the same data with the same settings, keeping each
# model's per-epoch loss curve under its name.
results = {}

for name, model in models.items():
    print("\n" + "=" * 60)
    print(f"Training {name}...")
    print("=" * 60)
    results[name] = train(
        model,
        data,
        epochs=EPOCHS,
        lr=LR,
        device=DEVICE,
        delta=PREDICT_DELTA,
        batch=BATCH_SIZE,
    )

## 6. Results

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: per-epoch training loss on a log scale, one curve per model.
ax = axes[0]
colors = dict(
    SBM_PostFFT_Known='red',
    SBM_PreFFT_Known='blue',
    SBM_PreFFT_Learned='lightblue',
    Flat_Known='green',
    Flat_Learned='lightgreen',
)
for name, losses in results.items():
    ax.semilogy(losses, label=f"{name}: {losses[-1]:.6f}", color=colors.get(name, 'gray'))
ax.set(xlabel='Epoch', ylabel='Loss (MSE)',
       title='Pre-FFT vs Post-FFT Control Conditioning Ablation')
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)

# Right panel: final-epoch loss per model as a bar chart.
ax = axes[1]
names = list(results)
finals = [results[n][-1] for n in names]
bar_colors = [colors.get(n, 'gray') for n in names]
bars = ax.bar(range(len(names)), finals, color=bar_colors)
ax.set_xticks(range(len(names)))
# Break the model names at underscores so the tick labels stay narrow.
ax.set_xticklabels(['\n'.join(n.split('_')) for n in names], fontsize=8)
ax.set(ylabel='Final Loss', title='Final Loss Comparison')
for bar, val in zip(bars, finals):
    x_mid = bar.get_x() + bar.get_width()/2
    ax.text(x_mid, bar.get_height(), f'{val:.5f}',
            ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()

# Text summary: final loss and the first-epoch / last-epoch loss ratio.
print("\n" + "="*70, "ABLATION RESULTS: Pre-FFT vs Post-FFT Control Conditioning", "="*70, sep="\n")
for name, losses in results.items():
    improvement = losses[0] / losses[-1]
    print(f"{name:20s}: {losses[-1]:.6f} ({improvement:.1f}x improvement)")

In [None]:
# Analysis: pairwise comparisons of final-epoch losses.
print("\n" + "="*70, "KEY COMPARISONS", "="*70, sep="\n")

# Final loss of each model, pulled out in a fixed order.
post_known, pre_known, pre_learned, flat_known, flat_learned = (
    results[key][-1]
    for key in ('SBM_PostFFT_Known', 'SBM_PreFFT_Known', 'SBM_PreFFT_Learned',
                'Flat_Known', 'Flat_Learned')
)

# 1. Where the control is injected; positive means Pre-FFT has the lower loss.
print("\n1. POST-FFT vs PRE-FFT (both with known maps):",
      f"   PostFFT: {post_known:.6f}",
      f"   PreFFT:  {pre_known:.6f}",
      sep="\n")
diff = (post_known - pre_known) / post_known * 100
print(f"   Difference: {diff:+.1f}% {'(PreFFT better)' if diff > 0 else '(PostFFT better)'}")

# 2. Cost of having to discover the actuator maps (relative to known maps).
print("\n2. KNOWN vs LEARNED (Pre-FFT design):",
      f"   Known:   {pre_known:.6f}",
      f"   Learned: {pre_learned:.6f}",
      sep="\n")
diff = (pre_learned - pre_known) / pre_known * 100
print(f"   Difference: {diff:+.1f}%")

# 3. Spectral model vs flat ConvNet, both given the true maps.
print("\n3. SBM vs FLAT (both known):",
      f"   SBM PreFFT: {pre_known:.6f}",
      f"   Flat:       {flat_known:.6f}",
      sep="\n")
diff = (pre_known - flat_known) / flat_known * 100
print(f"   Difference: {diff:+.1f}%")

# 4. Spectral model vs flat ConvNet, both learning the maps.
print("\n4. SBM vs FLAT (both learned):",
      f"   SBM PreFFT: {pre_learned:.6f}",
      f"   Flat:       {flat_learned:.6f}",
      sep="\n")
diff = (pre_learned - flat_learned) / flat_learned * 100
print(f"   Difference: {diff:+.1f}%")

In [None]:
# Visualize learned actuator maps
print("\nComparing Known vs Learned Actuator Maps (Pre-FFT SBM):")

# Known maps come from the environment, learned maps from SBM_PreFFT_Learned.
known_maps = actuator_maps.cpu()
learned_maps = models['SBM_PreFFT_Learned'].ctrl_enc.get_learned_maps().cpu()

# Size the figure from the data rather than assuming 9 actuators.
# squeeze=False keeps `axes` 2-D even when there is a single column.
num_cols = max(known_maps.shape[0], learned_maps.shape[0])
fig, axes = plt.subplots(2, num_cols, figsize=(2 * num_cols, 4), squeeze=False)

for row, (label, maps) in enumerate((('Known', known_maps), ('Learned', learned_maps))):
    for col in range(num_cols):
        panel = axes[row, col]
        if col < maps.shape[0]:
            panel.imshow(maps[col], cmap='viridis')
            panel.set_title(f'{label} {col}')
        # Hide ticks and frame by hand. axis('off') would also suppress the
        # row labels set below.
        panel.set_xticks([])
        panel.set_yticks([])
        for spine in panel.spines.values():
            spine.set_visible(False)

axes[0, 0].set_ylabel('Known\n(Physics)', fontsize=12)
axes[1, 0].set_ylabel('Learned\n(PreFFT SBM)', fontsize=12)

plt.suptitle('Pre-FFT Control: Spatial Locality Preserved?', fontsize=14)
plt.tight_layout()
plt.show()

print("\nKey question: Are learned maps localized Gaussians at different positions?")
print("(vs identical concentric rings as seen with Post-FFT conditioning)")

In [None]:
# Also show Flat_Learned maps for comparison
print("\nFlat Baseline Learned Maps (for comparison):")

flat_learned_maps = models['Flat_Learned'].ctrl_enc.get_learned_maps().cpu()

# One panel per learned map, however many actuators the model has.
# squeeze=False keeps `axes` 2-D even for a single actuator.
num_maps = flat_learned_maps.shape[0]
fig, axes = plt.subplots(1, num_maps, figsize=(2 * num_maps, 2), squeeze=False)

for i, panel in enumerate(axes[0]):
    panel.imshow(flat_learned_maps[i], cmap='viridis')
    panel.set_title(f'Flat Learned {i}')
    panel.axis('off')

plt.suptitle('Flat ConvNet Learned Actuator Maps', fontsize=14)
plt.tight_layout()
plt.show()