# In [None]:  (notebook cell marker left over from export)
"""
Minimal PyTorch skeleton for variable-length, multi-sensor, multitemporal change segmentation
(Per-sensor encoders → temporal transformer aggregation → U-Net style decoder)

Key ideas:
- Accepts a list of frames per sample; each frame has (image tensor, sensor id, Δdays to event, optional per-pixel valid mask).
- Per-sensor encoders produce two scales of features (H/2, H/4).  
- We add **time** (Δdays) and **sensor** embeddings to each frame's features.
- A **temporal Transformer** aggregates across time **per spatial location** at each scale, with masking for missing/invalid pixels.
- A light U-Net decoder upsamples to full resolution to output a binary change mask.

This is a scaffold you can adapt/extend (e.g., deeper backbones, more scales, SAR-specific preprocessing, etc.).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------
# Utility: small building blocks
# -----------------------------

class ConvBNAct(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        k: kernel size.
        s: stride.
        p: padding; defaults to ``k // 2`` (keeps H/W for odd ``k`` at stride 1).
    """

    def __init__(self, in_ch: int, out_ch: int, k: int = 3, s: int = 1, p: Optional[int] = None):
        super().__init__()
        padding = k // 2 if p is None else p
        # Bias is redundant because BatchNorm immediately re-centres the activations.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)


class ResidualBlock(nn.Module):
    """Identity-shortcut block computing ``x + ConvBNAct(ConvBNAct(x))``; channels preserved.

    The second ConvBNAct already ends in ReLU, so the residual branch is non-negative and no
    activation is applied after the addition.
    """

    def __init__(self, ch: int):
        super().__init__()
        self.conv1 = ConvBNAct(ch, ch)
        self.conv2 = ConvBNAct(ch, ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.conv1(x)
        branch = self.conv2(branch)
        return x + branch


class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2, 3x3 ConvBNAct mapping in_ch -> out_ch."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # positional order is (in_ch, out_ch, k, s, p)
        self.conv = ConvBNAct(in_ch, out_ch, 3, 2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        return out


# ---------------------------------
# Per-sensor lightweight encoders
# ---------------------------------

class SensorEncoder(nn.Module):
    """Minimal two-stage CNN encoder returning feature maps at ~1/2 and ~1/4 resolution.

    Stage 1 (stride-2 stem + residual block) gives ``base_ch`` channels at H/2; stage 2
    (stride-2 downsample + residual block) gives ``2 * base_ch`` channels at H/4.
    Each sensor gets its own instance, but all share ``base_ch`` so their outputs can be fused
    directly. Swap in a stronger backbone (e.g. timm / SegFormer) if desired.
    """

    def __init__(self, in_ch: int, base_ch: int = 48):
        super().__init__()
        c_half, c_quarter = base_ch, 2 * base_ch
        self.stem = ConvBNAct(in_ch, c_half, k=3, s=2, p=1)   # -> H/2
        self.res1 = ResidualBlock(c_half)
        self.down2 = Downsample(c_half, c_quarter)            # -> H/4
        self.res2 = ResidualBlock(c_quarter)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        half = self.stem(x)
        half = self.res1(half)          # (B, base_ch, H/2, W/2)
        quarter = self.down2(half)
        quarter = self.res2(quarter)    # (B, 2*base_ch, H/4, W/4)
        return half, quarter


# ---------------------------------
# Time & sensor embeddings (FiLM-ish)
# ---------------------------------

class TimeEmbedding(nn.Module):
    """Encode a time offset Δdays (float) as a learned ``d_model``-dim vector.

    Δdays is expanded into ``num_freqs`` sine/cosine Fourier pairs, then a 2-layer MLP maps them
    to ``d_model``. Callers add the result to feature maps as a per-channel bias; no
    multiplicative (FiLM) scale is produced here.

    Periodicity: every frequency is an integer multiple of ``2π / period``, so the encoding
    repeats every ``period`` days. With the default ``period=365`` an offset Δ and Δ ± 365 map to
    identical features (e.g. -20 and +345 days), which confuses pre- and post-event frames once
    the Δdays range spans a year. Pass a larger ``period`` in that case.

    Args:
        d_model: output embedding size.
        num_freqs: number of Fourier frequencies; the MLP input size is ``2 * num_freqs``.
        period: period in days of the lowest frequency (and of the whole encoding).
    """

    def __init__(self, d_model: int, num_freqs: int = 8, period: float = 365.0):
        super().__init__()
        if num_freqs < 1:
            raise ValueError(f"num_freqs must be >= 1, got {num_freqs}")
        if period <= 0:
            raise ValueError(f"period must be > 0, got {period}")
        self.num_freqs = num_freqs
        self.period = float(period)
        self.proj = nn.Sequential(
            nn.Linear(2 * num_freqs, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, d_model),
        )

    @staticmethod
    def _fourier_features(x: torch.Tensor, B: int = 8, period: float = 365.0) -> torch.Tensor:
        """Map days ``x`` (flattened to (N,)) to (N, 2B) features ``[sin(x·f_k), cos(x·f_k)]``.

        Frequencies are ``f_k = 2^k · 2π / period`` for ``k = 0..B-1``.
        """
        # Flatten so a 0-d scalar works, and cast to float32 so a float64 input does not
        # produce features the (float32) Linear layers cannot consume.
        x = x.reshape(-1).float()
        freqs = 2 ** torch.arange(B, device=x.device).float() * 2 * math.pi / period  # (B,)
        xf = x[:, None] * freqs[None, :]                                              # (N, B)
        return torch.cat([torch.sin(xf), torch.cos(xf)], dim=-1)                      # (N, 2B)

    def forward(self, delta_days: torch.Tensor) -> torch.Tensor:
        ff = self._fourier_features(delta_days, self.num_freqs, self.period)  # (N, 2*num_freqs)
        return self.proj(ff)                                                  # (N, d_model)


class SensorEmbedding(nn.Module):
    """Learned per-sensor vector (``d_model`` dims) looked up by sensor id string.

    Args:
        sensor_vocab: sensor ids; the position in this list is the embedding row index
            (if an id appears twice, the later position wins).
        d_model: embedding size.
    """

    def __init__(self, sensor_vocab: List[str], d_model: int):
        super().__init__()
        self.sensor_to_idx: Dict[str, int] = {s: i for i, s in enumerate(sensor_vocab)}
        self.emb = nn.Embedding(len(sensor_vocab), d_model)

    def forward(self, sensors: List[str]) -> torch.Tensor:
        """Return an (N, d_model) tensor for N sensor ids.

        Raises:
            KeyError: if any id is not in the vocabulary (message lists the known ids).
        """
        unknown = [s for s in sensors if s not in self.sensor_to_idx]
        if unknown:
            raise KeyError(
                f"Unknown sensor id(s) {unknown!r}; expected one of {list(self.sensor_to_idx)}"
            )
        idx = torch.tensor(
            [self.sensor_to_idx[s] for s in sensors],
            dtype=torch.long,
            device=self.emb.weight.device,
        )
        return self.emb(idx)  # (N, d_model)


# ---------------------------------
# Temporal aggregation: Transformer per spatial location
# ---------------------------------

class TemporalAggregator(nn.Module):
    """Aggregate a variable-length list of frame features (C, H, W) with a Transformer along time.

    Every spatial location is an independent sequence of T tokens (one per frame), encoded by a
    shared ``nn.TransformerEncoder`` and reduced to one vector by a masked mean over valid frames.
    No positional encoding is applied here, so temporal information must already be present in
    the features (e.g. via added time embeddings).

    One **sample** is processed per call to keep padding/masking simple; for speed, pack several
    samples by padding T to the batch max and extending the key-padding mask.

    Args:
        channels: feature size C (must be divisible by ``nhead``).
        nhead: attention heads.
        num_layers: number of Transformer encoder layers.
        dropout: dropout inside each encoder layer.
    """

    def __init__(self, channels: int, nhead: int = 8, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(d_model=channels, nhead=nhead, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

    @staticmethod
    def _build_pad_mask(
        valid_masks: Optional[List[Optional[torch.Tensor]]],
        T: int,
        H: int,
        W: int,
        device: torch.device,
    ) -> Optional[torch.Tensor]:
        """Return an (H*W, T) bool key-padding mask (True = ignore frame), or None if unmasked."""
        if valid_masks is None or all(vm is None for vm in valid_masks):
            return None
        if len(valid_masks) != T:
            raise ValueError(f"Got {len(valid_masks)} valid masks for {T} frames")
        columns = []
        for vm in valid_masks:
            if vm is None:
                # A frame without a mask is treated as fully valid.
                columns.append(torch.ones(H * W, dtype=torch.bool, device=device))
            else:
                columns.append((vm.to(device) > 0.5).reshape(H * W))
        valid = torch.stack(columns, dim=1)                 # (HW, T)
        has_any_valid = valid.any(dim=1, keepdim=True)      # (HW, 1) -> broadcasts over T
        # Pixels with no valid frame at all are left fully unpadded: a fully padded attention
        # row would yield NaNs.
        return ~valid & has_any_valid

    def forward(
        self,
        feats: List[torch.Tensor],          # List[T] of (C, H, W)
        valid_masks: Optional[List[Optional[torch.Tensor]]] = None,  # List[T] of (1, H, W) / None
    ) -> torch.Tensor:
        """Aggregate T frames into one (C, H, W) feature map.

        Args:
            feats: T tensors of shape (C, H, W).
            valid_masks: optional list of T masks, each (1, H, W) or (H, W), where values > 0.5
                mean valid. Individual entries may be None (that frame counts as fully valid).
        Returns:
            (C, H, W) masked temporal mean of the Transformer outputs.
        """
        assert len(feats) > 0, "No frames provided"
        C, H, W = feats[0].shape
        T = len(feats)

        # (T, C, H, W) -> (H, W, T, C) -> (HW, T, C): one time sequence per pixel.
        x = torch.stack(feats, dim=0).permute(2, 3, 0, 1).reshape(H * W, T, C)

        pad_mask = self._build_pad_mask(valid_masks, T, H, W, x.device)
        x = self.encoder(x, src_key_padding_mask=pad_mask)  # (HW, T, C)

        # Aggregate over time: mean over the non-padded frames.
        if pad_mask is not None:
            weights = (~pad_mask).to(x.dtype).unsqueeze(-1)  # (HW, T, 1)
            agg = (x * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0)  # (HW, C)
        else:
            agg = x.mean(dim=1)  # (HW, C)

        return agg.reshape(H, W, C).permute(2, 0, 1).contiguous()  # (C, H, W)


# ---------------------------------
# Decoder (2-scale U-Net style)
# ---------------------------------

class UNetDecoder2Scale(nn.Module):
    """Two-scale U-Net style decoder producing per-pixel logits.

    ``f_high`` (H/4) is upsampled x2 with a transposed conv, concatenated with the skip ``f_low``
    (H/2), fused, upsampled x2 again and projected to ``out_ch`` logits. The output spatial size
    is twice that of ``f_low``.

    Args:
        ch_low: channels of ``f_low``; keep it >= 4 so the head's ``ch_low // 4`` is positive.
        ch_high: channels of ``f_high``.
        out_ch: number of output logit channels.
    """

    def __init__(self, ch_low: int, ch_high: int, out_ch: int = 1):
        super().__init__()
        self.up1 = nn.ConvTranspose2d(ch_high, ch_low, kernel_size=2, stride=2)
        self.fuse1 = nn.Sequential(
            ConvBNAct(ch_low + ch_low, ch_low),
            ResidualBlock(ch_low),
        )
        self.up2 = nn.ConvTranspose2d(ch_low, ch_low // 2, kernel_size=2, stride=2)
        self.head = nn.Sequential(
            ConvBNAct(ch_low // 2, ch_low // 4),
            nn.Conv2d(ch_low // 4, out_ch, kernel_size=1)
        )

    def forward(self, f_low: torch.Tensor, f_high: torch.Tensor) -> torch.Tensor:
        # f_low: (B, C1, H/2, W/2); f_high: (B, C2, H/4, W/4)
        x = self.up1(f_high)  # → (B, C1, 2*(H/4), 2*(W/4))
        if x.shape[-2:] != f_low.shape[-2:]:
            # When the H/2 map has odd size, doubling the H/4 map overshoots by one pixel;
            # resize to the skip connection's size so the concat below is valid.
            x = F.interpolate(x, size=f_low.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, f_low], dim=1)
        x = self.fuse1(x)
        x = self.up2(x)       # → (B, C1//2, 2*(H/2), 2*(W/2))
        logits = self.head(x) # → (B, out_ch, 2*(H/2), 2*(W/2))
        return logits


# ---------------------------------
# Full model
# ---------------------------------

@dataclass
class Frame:
    """A single observation belonging to one sample.

    Attributes:
        image: (C, H, W) tensor; C must match the channel count configured for ``sensor``.
        sensor: sensor id, e.g. 'S2', 'S1', 'Landsat', 'BR'.
        delta_days: days relative to the event date (negative = pre-event, positive = post-event).
        valid_mask: optional (1, H, W) mask with 1 = valid and 0 = invalid (e.g. clouds);
            None means no per-pixel mask is supplied for this frame.
    """

    image: torch.Tensor
    sensor: str
    delta_days: float
    valid_mask: Optional[torch.Tensor] = None


class MultiSensorTemporalChangeSeg(nn.Module):
    """Variable-length, multi-sensor change segmentation network.

    Per sample: each frame goes through its sensor's encoder → time (Δdays) and sensor
    embeddings are added → temporal Transformer aggregation at H/2 and H/4 → shared 2-scale
    decoder → (1, H, W) logits.

    Args:
        sensor_specs: mapping sensor id → number of input channels for that sensor.
        base_ch: encoder width; features have ``base_ch`` channels at H/2 and ``2 * base_ch``
            at H/4. ``t_nhead`` must divide both (nn.MultiheadAttention requires
            embed_dim % num_heads == 0).
        t_nhead: attention heads in the temporal Transformers.
        t_layers: Transformer encoder layers per scale.
    """

    def __init__(
        self,
        sensor_specs: Dict[str, int],   # sensor_id → in_channels
        base_ch: int = 48,
        t_nhead: int = 8,
        t_layers: int = 2,
    ):
        super().__init__()
        self.sensors = list(sensor_specs.keys())
        # Build per-sensor encoders with the same channel dims
        self.encoders = nn.ModuleDict({s: SensorEncoder(in_ch=cin, base_ch=base_ch) for s, cin in sensor_specs.items()})

        ch_low = base_ch
        ch_high = base_ch * 2

        # Embeddings to add to features
        self.time_emb_low = TimeEmbedding(ch_low)
        self.time_emb_high = TimeEmbedding(ch_high)
        self.sensor_emb_low = SensorEmbedding(self.sensors, ch_low)
        self.sensor_emb_high = SensorEmbedding(self.sensors, ch_high)

        # Temporal aggregators (per scale)
        self.tagg_low = TemporalAggregator(channels=ch_low, nhead=t_nhead, num_layers=t_layers)
        self.tagg_high = TemporalAggregator(channels=ch_high, nhead=t_nhead, num_layers=t_layers)

        # Decoder
        self.decoder = UNetDecoder2Scale(ch_low=ch_low, ch_high=ch_high, out_ch=1)

    def _apply_emb(self, feat: torch.Tensor, t_vec: torch.Tensor, s_vec: torch.Tensor) -> torch.Tensor:
        """Add time and sensor embeddings to a feature map (C, H, W)."""
        C, H, W = feat.shape
        t = t_vec.view(C, 1, 1)
        s = s_vec.view(C, 1, 1)
        return feat + t + s

    @staticmethod
    def _resize_mask(mask: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
        """Nearest-resize a (1, H, W) or (H, W) mask to ``size``; returns a (1, h, w) float mask."""
        # float() so bool masks work (nearest interpolation is not defined for bool tensors).
        m = mask.float().reshape(1, 1, *mask.shape[-2:])
        return F.interpolate(m, size=size, mode="nearest")[0]

    def _encode_frame(
        self, fr: Frame, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Encode one frame; return embedded low/high features and their masks (or None)."""
        x = fr.image.unsqueeze(0).to(device)            # (1, C, H, W)
        f1, f2 = self.encoders[fr.sensor](x)            # (1, C1, H/2, W/2), (1, C2, H/4, W/4)
        f1, f2 = f1.squeeze(0), f2.squeeze(0)

        dd = torch.tensor([fr.delta_days], device=device, dtype=torch.float32)
        f1 = self._apply_emb(f1, self.time_emb_low(dd)[0], self.sensor_emb_low([fr.sensor])[0])
        f2 = self._apply_emb(f2, self.time_emb_high(dd)[0], self.sensor_emb_high([fr.sensor])[0])

        if fr.valid_mask is None:
            return f1, f2, None, None
        vm = fr.valid_mask.to(device)
        # Resize to the exact feature sizes (the stride-2 convs round up) rather than by a fixed
        # 0.5 / 0.25 factor (which rounds down), so odd H/W still line up with the features.
        return f1, f2, self._resize_mask(vm, f1.shape[-2:]), self._resize_mask(vm, f2.shape[-2:])

    def forward(self, batch: List[List[Frame]]) -> torch.Tensor:
        """
        Args:
            batch: list of samples; each sample is a list of Frame objects (variable length)
        Returns:
            logits: (B, 1, H, W) binary change logits
        Raises:
            ValueError: if any sample has zero frames.
        Notes:
            - All frames within a sample must share the same (H, W). Different samples in the same batch
              must also share (H, W) for this simple implementation.
        """
        B = len(batch)
        assert B > 0, "Empty batch"
        # Validate before indexing batch[0][0] below.
        for sample in batch:
            if len(sample) == 0:
                raise ValueError("A sample in the batch has zero frames.")
        # Determine spatial size from first frame of first sample
        H, W = batch[0][0].image.shape[-2:]
        device = next(self.parameters()).device

        f_low_aggs: List[torch.Tensor] = []
        f_high_aggs: List[torch.Tensor] = []
        for sample in batch:
            low_feats: List[torch.Tensor] = []
            high_feats: List[torch.Tensor] = []
            low_masks: List[Optional[torch.Tensor]] = []
            high_masks: List[Optional[torch.Tensor]] = []
            for fr in sample:
                assert fr.image.shape[-2:] == (H, W), "All frames in a batch must share H, W"
                f1, f2, vm1, vm2 = self._encode_frame(fr, device)
                low_feats.append(f1)
                high_feats.append(f2)
                low_masks.append(vm1)
                high_masks.append(vm2)

            # Temporal aggregation per scale
            f_low_aggs.append(self.tagg_low(low_feats, low_masks))      # (C1, H/2, W/2)
            f_high_aggs.append(self.tagg_high(high_feats, high_masks))  # (C2, H/4, W/4)

        f_low_agg = torch.stack(f_low_aggs, dim=0)    # (B, C1, H/2, W/2)
        f_high_agg = torch.stack(f_high_aggs, dim=0)  # (B, C2, H/4, W/4)

        logits = self.decoder(f_low_agg, f_high_agg)
        if logits.shape[-2:] != (H, W):
            # Odd input sizes: the stride-2 stem rounds up, so the decoder output can be one
            # pixel larger than the input; resize so callers always get (B, 1, H, W).
            logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False)
        return logits


# ---------------------------------
# Example usage & simple sanity test
# ---------------------------------

def _dummy_batch(
    B: int = 2,
    T_list: List[int] = (4, 7),
    H: int = 256,
    W: int = 256,
    device: str = "cpu",
    mask_prob: float = 0.5,
    invalid_frac: float = 0.2,
) -> List[List[Frame]]:
    """Build a random toy batch of ``B`` samples for smoke-testing the model.

    Sample ``b`` has ``T_list[b % len(T_list)]`` frames; sensors cycle through S2, S1, Landsat,
    BR, and Δdays runs -20, -15, -10, ... Each frame gets a random valid mask with probability
    ``mask_prob``; in such a mask roughly ``invalid_frac`` of the pixels are invalid.

    Note: reseeds torch's global RNG with 0 so the batch is reproducible.
    """
    torch.manual_seed(0)
    sensors = ["S2", "S1", "Landsat", "BR"]
    in_channels = {"S2": 4, "S1": 2, "Landsat": 6, "BR": 4}
    batch: List[List[Frame]] = []
    for b in range(B):
        frames: List[Frame] = []
        T = T_list[b % len(T_list)]
        for t in range(T):
            sensor = sensors[t % len(sensors)]
            img = torch.randn(in_channels[sensor], H, W, device=device)
            delta_days = float(-20 + 5 * t)  # toy
            # With probability mask_prob, attach a cloud mask (~(1 - invalid_frac) valid pixels).
            if torch.rand(1).item() < mask_prob:
                vm = (torch.rand(1, H, W, device=device) > invalid_frac).float()
            else:
                vm = None
            frames.append(Frame(image=img, sensor=sensor, delta_days=delta_days, valid_mask=vm))
        batch.append(frames)
    return batch


def _sanity_run():
    """Run one forward pass of a small model on a random toy batch and print the logits shape."""
    specs = {"S2": 4, "S1": 2, "Landsat": 6, "BR": 4}
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    net = MultiSensorTemporalChangeSeg(specs, base_ch=32, t_nhead=8, t_layers=2)
    net = net.to(dev)
    toy = _dummy_batch(B=2, T_list=[5, 3], H=128, W=128, device=dev)
    out = net(toy)  # expected (B, 1, H, W)
    print("Logits shape:", out.shape)

# Script entry point: run the forward-pass smoke test only when executed directly, not on import.
if __name__ == "__main__":
    _sanity_run()
