In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from typing import Optional, Dict, Tuple, List

# ==========================================
# 1. DATASET LOADING & DATALOADER CREATION
# ==========================================
# -- Import your custom composite dataset class
from src.dataset.composite_dataset import CompositeDataset  # Update if your path differs

# -- Define file paths (set these to your actual data locations)
# All four values are placeholders and are forwarded unchanged to CompositeDataset below.
gim_parquet_file = '/path/to/gim_data.parquet'
celestrak_data_file = '/path/to/celestrak_data.parquet'
solar_index_data_file = '/path/to/solar_indices.parquet'
omniweb_dir = '/path/to/omniweb_dir'
date_start = None   # Use datetimes or None for full range
date_end = None
normalize = True    # passed through as CompositeDataset(normalize=...); exact effect is defined by that class

# -- Create the full dataset and split into training/validation sets
full_dataset = CompositeDataset(
    gim_parquet_file=gim_parquet_file,
    celestrak_data_file=celestrak_data_file,
    solar_index_data_file=solar_index_data_file,
    omniweb_dir=omniweb_dir,
    date_start=date_start,
    date_end=date_end,
    normalize=normalize
)
# 90% of the samples go to training; the remainder (n_total - n_train) to validation.
n_total = len(full_dataset)
n_train = int(0.9 * n_total)
n_val = n_total - n_train
# NOTE(review): random_split is called without a `generator`, so the split follows the
# global torch RNG state and differs between runs unless torch.manual_seed is set first.
# NOTE(review): samples are assigned at random; if consecutive dataset indices are
# adjacent in time, train and validation windows may overlap -- confirm this is intended.
train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [n_train, n_val])

# -- Collate function: Batches nested dicts to batched tensors
def multitask_collate(batch: List[Dict]):
    """
    Collate a list of per-sample dicts into one dict of batched tensors.

    The first sample is used as the template for the key layout. A value that
    is itself a dict is collated one level deep (each sub-key stacked on a new
    leading axis); any other value is stacked directly with torch.stack.

    Args:
        batch: non-empty list of samples sharing the same keys.
    Returns:
        Dict mirroring the sample layout, with tensors stacked along dim 0.
    """
    template = batch[0]
    collated = {}
    for name, sample_value in template.items():
        if not isinstance(sample_value, dict):
            # Flat entry: stack the tensors of every sample directly.
            collated[name] = torch.stack([sample[name] for sample in batch])
            continue
        # Nested entry: stack each sub-field independently.
        nested = {}
        for field in sample_value:
            nested[field] = torch.stack([sample[name][field] for sample in batch])
        collated[name] = nested
    return collated

# -- Dataloaders
# Both loaders use multitask_collate so nested sample dicts arrive as batched tensors.
batch_size = 8
# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=multitask_collate)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=multitask_collate)

# -- Optional: Sanity check the loader shapes
# Pulls a single training batch, prints the shapes of the entries the training
# code reads ('gim', 'solar_index', 'omniweb', 'celestrak'), then stops.
# Note: this leaves `batch` bound at module scope after the loop.
for batch in train_loader:
    print("Batch keys:", batch.keys())
    print("GIM input shape:", batch['gim']['input'].shape)
    print("GIM target shape:", batch['gim']['target'].shape)
    print("Solar index input shape:", batch['solar_index']['input'].shape)
    print("OMNI input shape:", batch['omniweb']['input'].shape)
    print("Celestrak input shape:", batch['celestrak']['input'].shape)
    break

# ==========================================
# 2. MODEL DEFINITION: SPHERICAL FOURIER FNO
# ==========================================
class SphericalFourierLayer(nn.Module):
    """
    Spectral convolution over a 2D lat/lon grid.

    The input is transformed with a planar real 2D FFT (torch.fft.rfft2), the
    lowest `modes_lat` x `modes_lon` coefficients are mixed across channels
    with a learned complex weight, every other coefficient is set to zero, and
    the result is transformed back to the original grid size.

    Args:
        in_channels: number of input channels C.
        out_channels: number of output channels.
        modes_lat: number of retained modes along the first spatial axis.
        modes_lon: number of retained modes along the (half-spectrum) last axis.

    NOTE(review): only the leading rows of the latitude spectrum are kept; the
    trailing rows of that FFT axis are zeroed as well. Confirm this is the
    intended truncation.
    """
    def __init__(self, in_channels, out_channels, modes_lat, modes_lon):
        super().__init__()
        # Complex weights stored as (real, imag) pairs in the last dimension.
        self.weight = nn.Parameter(
            torch.randn(in_channels, out_channels, modes_lat, modes_lon, 2) * 0.01
        )

    def compl_mul2d(self, x, w):
        """Contract complex `x` (b, c, h, w) with real-pair weights `w` (c, o, h, w, 2)."""
        w = torch.view_as_complex(w)
        return torch.einsum("bchw,cohw->bohw", x, w)

    def forward(self, x):
        """
        Apply the truncated spectral convolution.

        Args:
            x: real tensor of shape (B, C, H, W).
        Returns:
            Real tensor of shape (B, out_channels, H, W).
        """
        B, C, H, W = x.shape
        x_ft = torch.fft.rfft2(x, norm="ortho")
        # A small grid can expose fewer coefficients than the layer was built
        # for; clamp so the weight slice and the spectrum slice always agree
        # instead of failing inside einsum.
        m_lat = min(self.weight.shape[2], x_ft.shape[-2])
        m_lon = min(self.weight.shape[3], x_ft.shape[-1])
        x_ft = x_ft[:, :, :m_lat, :m_lon]
        out_ft = self.compl_mul2d(x_ft, self.weight[:, :, :m_lat, :m_lon])
        # Zero-padded full spectrum; dtype follows the product rather than a
        # hard-coded complex64.
        out_ft_full = torch.zeros(
            B, self.weight.shape[1], H, W // 2 + 1, dtype=out_ft.dtype, device=x.device
        )
        out_ft_full[:, :, :m_lat, :m_lon] = out_ft
        return torch.fft.irfft2(out_ft_full, s=(H, W), norm="ortho")

class MCDropout(nn.Dropout):
    """
    Dropout with an extra `mc` switch.

    When `mc` is truthy the layer keeps dropping activations even if the
    module is in eval mode, which allows Monte Carlo sampling at inference.
    With `mc` off it behaves according to the usual train/eval flag.
    """
    def __init__(self, p=0.1):
        super().__init__(p)
        # Off by default; toggled externally (e.g. by the owning block).
        self.mc = False

    def forward(self, x):
        # Active while training, or whenever MC sampling has been requested.
        active = self.training or self.mc
        return F.dropout(x, self.p, training=active)

class SFNOBlock(nn.Module):
    """
    One trunk block: spectral branch plus 1x1-conv branch, summed, passed
    through MC-capable dropout and normalised per pixel over the channel axis.

    NOTE(review): the block applies no pointwise activation and has no skip
    connection; confirm that is intentional.
    """
    def __init__(self, width, modes_lat, modes_lon, dropout=0.0, mc_dropout=False):
        super().__init__()
        self.fourier = SphericalFourierLayer(width, width, modes_lat, modes_lon)
        self.linear = nn.Conv2d(width, width, 1)
        self.dropout = MCDropout(dropout)
        self.dropout.mc = mc_dropout
        self.norm = nn.LayerNorm([width])

    def set_mc_dropout(self, mc=True):
        """Switch the block's dropout layer in or out of Monte Carlo mode."""
        self.dropout.mc = mc

    def forward(self, x):
        # Global (spectral) and local (1x1 conv) paths see the same input.
        mixed = self.dropout(self.fourier(x) + self.linear(x))
        # LayerNorm normalises the last axis, so move channels there and back:
        # (B, C, H, W) -> (B, H, W, C) -> (B, C, H, W).
        channels_last = self.norm(mixed.permute(0, 2, 3, 1))
        return channels_last.permute(0, 3, 1, 2)

class SFNOMultiTask(nn.Module):
    """
    Multi-task network built from a shared SFNOBlock trunk.

    Pipeline: 1x1 conv input projection -> `trunk_depth` SFNOBlock layers ->
    optional additive auxiliary embedding -> one output head per task.

    Args:
        in_channels: channels of the gridded input.
        trunk_width: channel width used throughout the trunk.
        trunk_depth: number of SFNOBlock layers.
        modes_lat, modes_lon: retained spectral modes, forwarded to each block.
        aux_dim: size of the auxiliary feature vector; 0 disables the aux path.
        tasks: names of the tasks to produce outputs for.
        out_shapes: maps task name -> (out_dim, head_type) with head_type
            "grid" (1x1 conv, per-pixel output) or "scalar" (global average
            pool + linear). Defaults to {"vtec": (1, "grid")}.
        probabilistic: if True each head emits 2 * out_dim channels that are
            split into (mean, logvar).
        dropout: dropout probability inside each block.
        mc_dropout: initial state of the blocks' MC-dropout switch.

    Raises:
        KeyError: if a task has no entry in `out_shapes`.
        ValueError: if a head type is neither "grid" nor "scalar".
    """
    def __init__(
        self,
        in_channels: int, trunk_width: int = 64, trunk_depth: int = 6,
        modes_lat: int = 32, modes_lon: int = 64,
        aux_dim: int = 0, tasks: Tuple[str, ...] = ("vtec",),
        out_shapes: Optional[Dict[str, Tuple[int, str]]] = None,
        probabilistic: bool = True, dropout: float = 0.1, mc_dropout: bool = False
    ):
        super().__init__()
        # A dict literal as a default would be shared by every instance, so
        # the default is materialised per call instead.
        if out_shapes is None:
            out_shapes = {"vtec": (1, "grid")}
        self.in_proj = nn.Conv2d(in_channels, trunk_width, 1)
        self.blocks = nn.ModuleList([
            SFNOBlock(trunk_width, modes_lat, modes_lon, dropout=dropout, mc_dropout=mc_dropout)
            for _ in range(trunk_depth)
        ])
        self.aux_proj = nn.Linear(aux_dim, trunk_width) if aux_dim > 0 else None
        self.tasks = tasks
        # Copy so later edits to the caller's dict cannot desynchronise the
        # stored shapes from the heads built below.
        self.out_shapes = dict(out_shapes)
        self.probabilistic = probabilistic

        # One head per task
        self.heads = nn.ModuleDict()
        for task in tasks:
            out_dim, out_type = self.out_shapes[task]
            # Probabilistic heads carry mean and log-variance side by side.
            out_channels = out_dim * (2 if probabilistic else 1)
            if out_type == "grid":
                self.heads[task] = nn.Conv2d(trunk_width, out_channels, 1)
            elif out_type == "scalar":
                self.heads[task] = nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(trunk_width, out_channels)
                )
            else:
                raise ValueError(f"Unknown head type: {out_type}")

    def set_mc_dropout(self, mc=True):
        """Switch MC dropout on or off in every trunk block."""
        for b in self.blocks:
            b.set_mc_dropout(mc)

    def forward(self, x, aux: Optional[torch.Tensor] = None):
        """
        Run the trunk and all task heads.

        Args:
            x: gridded input, (B, in_channels, H, W).
            aux: optional auxiliary features; used only when the model was
                built with aux_dim > 0. Assumed to be (B, aux_dim) so that the
                embedding broadcasts over H and W -- TODO confirm with callers.
        Returns:
            Dict task -> (mean, logvar) when probabilistic, else task -> tensor.
        """
        x = self.in_proj(x)
        for block in self.blocks:
            x = block(x)
        if self.aux_proj is not None and aux is not None:
            # Added once, after the trunk, as a per-channel bias over the grid.
            aux_emb = self.aux_proj(aux).unsqueeze(-1).unsqueeze(-1)
            x = x + aux_emb
        outputs = {}
        for task in self.tasks:
            head = self.heads[task]
            out_dim, out_type = self.out_shapes[task]
            y = head(x)
            if self.probabilistic:
                # First out_dim channels are the mean, the rest the log-variance.
                split = out_dim
                mean, logvar = y[:, :split], y[:, split:]
                outputs[task] = (mean, logvar)
            else:
                outputs[task] = y
        return outputs

# ==========================================
# 3. TRAINING/VALIDATION ROUTINES
# ==========================================
def nll_gaussian(pred, logvar, target, reduce=True):
    """
    Gaussian negative log-likelihood (without the constant term).

    Computes 0.5 * (logvar + (pred - target)^2 / exp(logvar)) element-wise.

    Args:
        pred: predicted mean.
        logvar: predicted log-variance, broadcastable with `pred`.
        target: ground truth, broadcastable with `pred`.
        reduce: if True return the mean over all elements, otherwise the
            element-wise loss tensor.
    """
    squared_error = (pred - target) ** 2
    per_element = 0.5 * (logvar + squared_error / logvar.exp())
    if reduce:
        return per_element.mean()
    return per_element

def train_epoch(model, loader, optimizer, device="cuda", tasks=("vtec",), reduce=True):
    """
    Run one optimisation pass over `loader` and return the average loss.

    For each batch the GIM input is fed to the model together with the
    concatenated solar-index / OMNI / Celestrak inputs, the per-task losses
    are summed, and one optimizer step is taken.

    Args:
        model: network returning a dict task -> (mean, logvar) when
            `model.probabilistic` is true, else task -> prediction.
        loader: iterable of collated batches exposing `loader.dataset`.
        optimizer: optimizer over the model parameters.
        device: device the batch tensors are moved to.
        tasks: tasks to supervise. Only tasks with an available target
            (currently just 'vtec') contribute; the others are skipped with a
            RuntimeWarning instead of failing on a missing key.
        reduce: forwarded to nll_gaussian; an unreduced result is averaged
            afterwards because backward() needs a scalar.
    Returns:
        Sum of per-batch losses weighted by batch size, divided by
        len(loader.dataset).
    Raises:
        ValueError: if none of the requested tasks has a target.
    """
    import warnings

    model.train()
    total_loss = 0.0
    warned = set()
    for batch in loader:
        optimizer.zero_grad()
        x = batch['gim']['input'].to(device)      # (B, Cin, H, W)
        y = batch['gim']['target'].to(device)     # (B, Cout, H, W)
        # assumes the three aux inputs agree on every dim but the last -- TODO confirm
        aux = torch.cat([
            batch['solar_index']['input'],
            batch['omniweb']['input'],
            batch['celestrak']['input']
        ], dim=-1).to(device)
        targets = {'vtec': y}  # Add more as needed
        outputs = model(x, aux)
        loss = None
        for task in tasks:
            if task not in targets:
                # Warn once per call so missing supervision stays visible.
                if task not in warned:
                    warnings.warn(
                        f"train_epoch: no target available for task {task!r}; "
                        "its loss term is skipped",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    warned.add(task)
                continue
            if model.probabilistic:
                pred, logvar = outputs[task]
                task_loss = nll_gaussian(pred, logvar, targets[task], reduce=reduce)
                if task_loss.dim() > 0:
                    # reduce=False yields an element-wise tensor; collapse it
                    # so differently shaped task losses can be summed.
                    task_loss = task_loss.mean()
            else:
                task_loss = F.mse_loss(outputs[task], targets[task])
            loss = task_loss if loss is None else loss + task_loss
        if loss is None:
            raise ValueError(
                f"train_epoch: none of the requested tasks {tuple(tasks)!r} has a target"
            )
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.shape[0]
    return total_loss / len(loader.dataset)

@torch.no_grad()
def eval_epoch(model, loader, device="cuda", tasks=("vtec",), mc_dropout_samples: int = 1):
    """
    Collect model outputs over `loader`, optionally with MC-dropout sampling.

    The loader is traversed `mc_dropout_samples` times. MC dropout is enabled
    for the passes only when more than one sample is requested, and the
    dropout switch is put back to its previous state afterwards, even if a
    forward pass raises.

    Args:
        model: network exposing `blocks`, `set_mc_dropout(bool)` and a
            forward of the form model(x, aux) -> dict.
        loader: iterable of collated batches.
        device: device the batch tensors are moved to.
        tasks: accepted for signature compatibility; not used here.
        mc_dropout_samples: number of passes over the loader.
    Returns:
        List (one entry per pass) of lists (one entry per batch) of dicts
        task -> output moved to CPU; tuple outputs stay 2-tuples.
    """
    model.eval()
    # The previous state is read from the first block and restored to all blocks.
    orig_mc = getattr(model.blocks[0].dropout, "mc", False) if len(model.blocks) > 0 else False
    all_mc_outputs = []
    model.set_mc_dropout(mc_dropout_samples > 1)
    try:
        for _ in range(mc_dropout_samples):
            batch_outputs = []
            for batch in loader:
                x = batch['gim']['input'].to(device)
                aux = torch.cat([
                    batch['solar_index']['input'],
                    batch['omniweb']['input'],
                    batch['celestrak']['input']
                ], dim=-1).to(device)
                out = model(x, aux)
                batch_outputs.append({k: (v[0].cpu(), v[1].cpu()) if isinstance(v, tuple) else v.cpu() for k, v in out.items()})
            all_mc_outputs.append(batch_outputs)
    finally:
        # Never leave the model stuck in MC mode after a failed pass.
        model.set_mc_dropout(orig_mc)
    return all_mc_outputs

def compute_total_uncertainty(mc_outputs, task="vtec"):
    """
    Aggregate MC dropout samples for full uncertainty.

    Args:
        mc_outputs: list over MC samples; each entry is a list over batches of
            dicts task -> (mean, logvar), as produced by eval_epoch. Every
            sample must cover the same batches in the same order.
        task: which task's outputs to aggregate.
    Returns:
        Tuple (mean, epistemic, aleatoric, total_std), each covering all
        batches concatenated along dim 0:
          mean       -- average of the predicted means over MC samples
          epistemic  -- standard deviation of the means over MC samples
                        (zeros when only one sample is given)
          aleatoric  -- average predicted variance, exp(logvar)
          total_std  -- sqrt(epistemic ** 2 + aleatoric)
    Raises:
        ValueError: if `mc_outputs` is empty.
    """
    if not mc_outputs:
        raise ValueError("mc_outputs must contain at least one MC sample")
    # Join every batch of a sample so the whole evaluation set is aggregated,
    # not only its first batch.
    preds = torch.stack(
        [torch.cat([batch[task][0] for batch in sample], dim=0) for sample in mc_outputs],
        dim=0,
    )
    logvars = torch.stack(
        [torch.cat([batch[task][1] for batch in sample], dim=0) for sample in mc_outputs],
        dim=0,
    )
    if preds.shape[0] > 1:
        epistemic = preds.std(dim=0)
    else:
        # The sample std of one draw is NaN; report no spread instead.
        epistemic = torch.zeros_like(preds[0])
    aleatoric = logvars.exp().mean(dim=0)
    total_var = epistemic ** 2 + aleatoric
    total_std = total_var.sqrt()
    return preds.mean(dim=0), epistemic, aleatoric, total_std

# ==========================================
# 4. TRAINING LOOP & MODEL INSTANTIATION
# ==========================================
# Tasks the model builds heads for; each needs a matching entry in out_shapes.
# NOTE(review): train_epoch needs a target for every task listed here -- verify that
# 'kp' and 'dst' targets are actually wired up before relying on those heads.
tasks = ("vtec", "kp", "dst")
# task -> (out_dim, head type); with probabilistic=True each head emits 2 * out_dim channels.
out_shapes = {
    "vtec": (3, "grid"),  # Predict 3 time steps of VTEC on grid
    "kp": (2, "scalar"),  # 2-step Kp
    "dst": (2, "scalar"), # 2-step Dst
}
# NOTE(review): must equal the last-dim size of the concatenated solar_index / omniweb /
# celestrak inputs built in train_epoch and eval_epoch; 20 is a placeholder to confirm.
aux_dim = 20  # Sum of auxiliary feature dims (edit if needed)

# NOTE(review): mc_dropout=True keeps the blocks' dropout active even after model.eval(),
# so plain inference (e.g. autoregressive_rollout) is stochastic unless
# model.set_mc_dropout(False) is called first -- confirm this is intended.
model = SFNOMultiTask(
    in_channels=3, trunk_width=64, trunk_depth=8, modes_lat=32, modes_lon=64,
    aux_dim=aux_dim, tasks=tasks, out_shapes=out_shapes,
    probabilistic=True, dropout=0.2, mc_dropout=True
)
# Fall back to CPU when no CUDA device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# -- Main Training Loop
# Each epoch: one training pass, then 10 MC-dropout passes over the validation loader.
# Only the number of MC samples is reported; no validation loss is computed here.
for epoch in range(30):
    train_loss = train_epoch(model, train_loader, optimizer, device, tasks)
    val_outputs = eval_epoch(model, val_loader, device, tasks, mc_dropout_samples=10)
    print(f"Epoch {epoch:02d}: Train Loss={train_loss:.4f} | MC samples={len(val_outputs)}")

# ==========================================
# 5. AUTOREGRESSIVE ROLLOUT FOR FORECASTING
# ==========================================
def autoregressive_rollout(model, init_input, aux_seq, horizon=3, device="cuda"):
    """
    Roll forward for multiple timesteps (autoregressive inference).
    Args:
        model: Trained SFNO model
        init_input: (B, Cin, H, W) initial input window
        aux_seq: (B, horizon, aux_dim) sequence of aux features
    Returns:
        (B, horizon, Cout, H, W): predictions for each future step
    """
    model.eval()
    preds = []
    cur_input = init_input.to(device)
    for t in range(horizon):
        aux_t = aux_seq[:, t].to(device)
        with torch.no_grad():
            out = model(cur_input, aux_t)
            pred_mean, _ = out['vtec']  # (B, Cout, H, W)
            preds.append(pred_mean.unsqueeze(1))
            # Slide window: drop oldest, append new
            cur_input = torch.cat([cur_input[:, 1:], pred_mean], dim=1)
    return torch.cat(preds, dim=1)  # (B, horizon, Cout, H, W)
