# In [None]:
"""
Minimal, working PyTorch skeleton for:
- Dual-stream FSCIL (Stable=EMA teacher, Plastic=student)
- Prompt learning (base prompt + session prompt)
- Global prototype memory (NCM classifier)
- Knowledge Distillation (stable -> plastic) in prototype-logit space
- Domain alignment loss (old vs new feature mean)
- Base training + incremental sessions + inference

Assumptions:
- You already have dataloaders that yield (images, labels).
- Labels are global class IDs (e.g., 0..C-1 across all sessions).
- Backbone here is a simple ViT-like "token" encoder for clarity.
  Replace TinyTokenBackbone with a timm ViT (recommended) if you want, and
  adapt PromptedBackbone to concatenate prompt tokens with the patch tokens.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# -----------------------------
# Utils
# -----------------------------
def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """
    Rescale ``x`` to unit L2 length along ``dim``.

    ``eps`` is added to the norm before dividing, so an all-zero vector maps
    to zeros instead of NaN.
    """
    length = x.norm(p=2, dim=dim, keepdim=True)
    safe_length = length + eps
    return x / safe_length


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float = 0.999) -> None:
    """
    In-place EMA step: teacher = alpha*teacher + (1-alpha)*student

    Parameters are paired by position and must carry identical names.
    Buffers that exist under the same name and shape in both modules (e.g.
    normalisation running statistics) are copied from the student verbatim,
    because they are not reached by ``named_parameters()``.

    Args:
        teacher: module whose parameters are overwritten with the average.
        student: module providing the fresh weights; it is not modified.
        alpha: weight kept from the teacher's current value.

    Raises:
        ValueError: if the two modules do not expose the same parameter names
            in the same order.
    """
    teacher_params = list(teacher.named_parameters())
    student_params = list(student.named_parameters())

    # zip() alone would silently stop at the shorter list, so check up front.
    if len(teacher_params) != len(student_params):
        raise ValueError(
            f"Param count mismatch: {len(teacher_params)} vs {len(student_params)}"
        )
    # Validate every pair before touching anything to avoid a partial update.
    for (n_t, _), (n_s, _) in zip(teacher_params, student_params):
        if n_t != n_s:
            raise ValueError(f"Param mismatch: {n_t} vs {n_s}")

    for (_, p_t), (_, p_s) in zip(teacher_params, student_params):
        if p_t is p_s:
            # Shared tensor: the average of a value with itself is the value.
            # Running the in-place update would scale it by alpha*(2-alpha).
            continue
        p_t.mul_(alpha).add_(p_s.to(device=p_t.device, dtype=p_t.dtype), alpha=(1 - alpha))

    student_buffers = dict(student.named_buffers())
    for name, b_t in teacher.named_buffers():
        b_s = student_buffers.get(name)
        if b_s is None or b_s is b_t or b_s.shape != b_t.shape:
            continue
        b_t.copy_(b_s)


# -----------------------------
# Prompt module
# -----------------------------
class PromptModule(nn.Module):
    """
    A bank of ``num_prompts`` learnable prompt tokens, each of width ``dim``
    (the backbone embedding dim D).
    Stored as one [m, D] parameter initialised as randn * init_std.
    """
    def __init__(self, num_prompts: int, dim: int, init_std: float = 0.02):
        super().__init__()
        initial_tokens = init_std * torch.randn(num_prompts, dim)
        self.prompts = nn.Parameter(initial_tokens)

    def forward(self, batch_size: int) -> torch.Tensor:
        # Broadcast the shared tokens over the batch: [m, D] -> [B, m, D]
        num_tokens, width = self.prompts.shape
        return self.prompts[None, :, :].expand(batch_size, num_tokens, width)


# -----------------------------
# A small backbone (toy ViT-style token encoder)
# Replace with timm ViT for real runs.
# -----------------------------
class TinyTokenBackbone(nn.Module):
    """
    Toy stand-in for a real encoder, kept deliberately small for demonstration.
    For real FSCIL, swap in a real ViT (e.g., timm vit_base_patch16_224).
    """
    def __init__(self, in_ch: int = 3, dim: int = 768):
        super().__init__()
        # Stand-in for patch embedding: average each channel over space, then embed.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.proj = nn.Linear(in_ch, dim)
        # Pre-norm two-layer MLP playing the role of transformer depth.
        mlp_layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        self.out_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W] -> features [B, dim]
        batch, channels, _height, _width = x.shape
        channel_means = self.pool(x).reshape(batch, channels)   # [B, C]
        embedded = self.proj(channel_means)                     # [B, dim]
        with_residual = embedded + self.mlp(embedded)           # skip connection around the MLP
        return self.out_norm(with_residual)


class PromptedBackbone(nn.Module):
    """
    Backbone plus prompt injection.
    A real ViT would concatenate the prompt tokens with the patch tokens.
    This toy version instead averages the prompt tokens, passes the average
    through a linear projection, and adds the result to the backbone features.
    """
    def __init__(self, backbone: nn.Module, prompt: PromptModule, dim: int = 768):
        super().__init__()
        self.backbone = backbone
        self.prompt = prompt
        self.prompt_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.backbone(x)                        # [B, D]
        prompt_tokens = self.prompt(x.shape[0])         # [B, m, D]
        prompt_summary = prompt_tokens.mean(dim=1)      # [B, D], averaged over the m tokens
        # Additive prompt conditioning of the backbone features.
        return feats + self.prompt_proj(prompt_summary)


# -----------------------------
# Prototype memory (global across sessions)
# -----------------------------
class PrototypeMemory(nn.Module):
    """
    Stores class prototypes (centroids). Global, persistent.
    - prototypes: [C, D] running mean feature of each class
    - counts: [C] number of samples accumulated per class

    Both tensors are registered buffers, so they follow ``.to(device)`` and
    are part of ``state_dict()``. Rows are indexed by global class id; a class
    that has never received a sample keeps an all-zero prototype and count 0.
    """
    def __init__(self, feat_dim: int):
        super().__init__()
        self.feat_dim = feat_dim
        self.register_buffer("prototypes", torch.empty(0, feat_dim))
        self.register_buffer("counts", torch.empty(0))

    @property
    def num_classes(self) -> int:
        """Number of prototype rows currently allocated (including empty ones)."""
        return int(self.prototypes.shape[0])

    def ensure_size(self, num_classes: int, device: torch.device) -> None:
        """
        Grow the memory with zero rows until it holds ``num_classes`` classes.

        No-op when the memory is already large enough. When it grows, both
        buffers are moved to ``device``.
        """
        old_C = self.num_classes
        if old_C >= num_classes:
            return
        extra = num_classes - old_C
        pad_proto = torch.zeros(extra, self.feat_dim, dtype=self.prototypes.dtype, device=device)
        pad_cnt = torch.zeros(extra, dtype=self.counts.dtype, device=device)
        # Concatenating onto the (possibly empty) existing buffers covers both
        # the first allocation and later expansions.
        self.prototypes = torch.cat([self.prototypes.to(device), pad_proto], dim=0)
        self.counts = torch.cat([self.counts.to(device), pad_cnt], dim=0)

    @torch.no_grad()
    def update_from_batch(self, feats: torch.Tensor, labels: torch.Tensor) -> None:
        """
        Online mean update:
          mu_new = (n*mu_old + sum(feats))/ (n + k)
        feats: [B,D], labels: [B]

        An empty batch is a no-op. All classes in the batch are updated in one
        vectorised pass (no per-class Python loop).

        Raises:
            ValueError: if ``feats`` is not [B, feat_dim] with one label per
                row, or if any label is negative (a negative id would silently
                index from the end of the memory).
        """
        labels = labels.reshape(-1)
        if labels.numel() == 0:
            return
        if feats.dim() != 2 or feats.shape[0] != labels.numel() or feats.shape[1] != self.feat_dim:
            raise ValueError(
                f"Expected feats of shape [{labels.numel()}, {self.feat_dim}], got {tuple(feats.shape)}"
            )
        labels = labels.long()
        if int(labels.min().item()) < 0:
            raise ValueError("Class labels must be non-negative.")

        self.ensure_size(int(labels.max().item()) + 1, device=feats.device)

        # Work on the memory's own device/dtype; the batch is moved if needed.
        proto = self.prototypes
        feats = feats.detach().to(device=proto.device, dtype=proto.dtype)
        labels = labels.to(proto.device)

        # Per-class feature sums and sample counts for this batch.
        batch_sum = torch.zeros_like(proto).index_add_(0, labels, feats)
        batch_cnt = torch.bincount(labels, minlength=self.num_classes).to(self.counts.dtype)

        new_cnt = self.counts + batch_cnt
        # clamp only protects untouched empty classes from 0/0; their rows are
        # discarded by the torch.where below.
        merged = (proto * self.counts.unsqueeze(1) + batch_sum) / new_cnt.clamp(min=1.0).unsqueeze(1)
        touched = (batch_cnt > 0).unsqueeze(1)

        self.prototypes.copy_(torch.where(touched, merged, proto))
        self.counts.copy_(new_cnt)

    def logits_ncm(self, feats: torch.Tensor, tau: float = 1.0, normalize: bool = True) -> torch.Tensor:
        """
        Prototype-based logits:
          logits_c = -||h - mu_c||^2 / tau

        Args:
            feats: [B, D] features.
            tau: temperature dividing the squared distances.
            normalize: L2-normalise both features and prototypes first.

        Returns:
            [B, C] logits, one column per allocated class.

        Raises:
            RuntimeError: if no prototype has been allocated yet.

        NOTE: classes with count 0 still get a column. With ``normalize=True``
        their zero prototype stays zero, so their squared distance to any unit
        feature is 1; callers that take an argmax may want to mask them.
        """
        if self.num_classes == 0:
            raise RuntimeError("PrototypeMemory is empty. Build prototypes first (base session).")
        proto = self.prototypes.to(feats.device)
        h = feats
        if normalize:
            proto = l2_normalize(proto, dim=-1)
            h = l2_normalize(h, dim=-1)

        # squared Euclidean distances
        # dists: [B,C]
        dists = torch.cdist(h, proto, p=2) ** 2
        logits = -dists / tau
        return logits


# -----------------------------
# Losses
# -----------------------------
def proto_alignment_loss(feats: torch.Tensor, labels: torch.Tensor, proto_mem: PrototypeMemory) -> torch.Tensor:
    """
    Mean squared distance between each (L2-normalised) feature and the
    (L2-normalised) prototype of its own class: ||h - mu_y||^2
    """
    class_protos = proto_mem.prototypes.to(feats.device)
    targets = class_protos[labels]  # [B,D] prototype row for every sample
    gap = l2_normalize(feats, dim=-1) - l2_normalize(targets, dim=-1)
    per_sample = gap.pow(2).sum(dim=-1)
    return per_sample.mean()


def kd_loss_logits(logits_teacher: torch.Tensor, logits_student: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    """
    Temperature-scaled distillation loss on logits:
      KL( softmax(teacher/T) || softmax(student/T) ) * T^2
    averaged over the batch ("batchmean").
    """
    soft_targets = torch.softmax(logits_teacher / T, dim=-1)
    student_log_probs = torch.log_softmax(logits_student / T, dim=-1)
    divergence = F.kl_div(student_log_probs, soft_targets, reduction="batchmean")
    # Multiply by T^2 as in the formula above.
    return divergence * (T * T)


def domain_shift_loss(feats_old: torch.Tensor, feats_new: torch.Tensor) -> torch.Tensor:
    """
    Squared distance between the L2-normalised batch-mean features of the
    old-domain (stable) and new-domain (plastic) batches:
      || mean(old) - mean(new) ||^2
    """
    old_center = l2_normalize(feats_old.mean(dim=0), dim=-1)
    new_center = l2_normalize(feats_new.mean(dim=0), dim=-1)
    gap = old_center - new_center
    return gap.pow(2).sum()


# -----------------------------
# FSCIL model wrapper
# -----------------------------
@dataclass
class FSCILConfig:
    """Hyperparameters shared by FSCILSystem and the training routines."""
    # Width D of backbone features, prompt tokens and stored prototypes.
    feat_dim: int = 768
    # Number of learnable tokens in each prompt (base and session prompts).
    num_prompts: int = 10
    # Temperature dividing the negative squared distances in the NCM logits.
    tau: float = 1.0
    # Softmax temperature passed to the distillation loss.
    kd_T: float = 2.0
    # EMA decay: teacher = alpha*teacher + (1-alpha)*student.
    ema_alpha: float = 0.999

    # loss weights
    # Weight of the prototype-alignment term in the incremental loss.
    lam_proto: float = 1.0
    # Weight of the stable -> plastic distillation term.
    lam_kd: float = 1.0
    # Weight of the old/new feature-mean alignment term.
    lam_domain: float = 0.5


class FSCILSystem(nn.Module):
    """
    Holds:
    - stable backbone + base prompt (teacher)
    - plastic backbone + session prompt (student) during training
    - global prototype memory

    ``plastic`` / ``plastic_prompt`` are ``None`` outside of an incremental
    session; they are created by ``new_plastic_for_session``.
    """
    def __init__(self, cfg: FSCILConfig, device: torch.device):
        super().__init__()
        self.cfg = cfg
        self.device = device

        # Base prompt (learned in base session)
        self.base_prompt = PromptModule(cfg.num_prompts, cfg.feat_dim).to(device)

        # Backbones
        base_backbone = TinyTokenBackbone(dim=cfg.feat_dim).to(device)

        # Plastic uses (backbone + session prompt) that will be re-created per session
        self.plastic_prompt = None
        self.plastic = None

        # Stable uses (backbone + base prompt), EMA-updated
        self.stable = PromptedBackbone(copy.deepcopy(base_backbone), self.base_prompt, dim=cfg.feat_dim).to(device)
        # Note: this also freezes the shared base prompt until a trainer re-enables it.
        for p in self.stable.parameters():
            p.requires_grad = False

        # Prototype memory
        self.proto_mem = PrototypeMemory(cfg.feat_dim).to(device)

        # Keep a reference backbone init for creating new plastic streams
        self._backbone_template = base_backbone

    def new_plastic_for_session(self) -> None:
        """
        Create a fresh plastic stream and session prompt for a new incremental session.

        The student starts as a copy of the teacher: backbone and prompt
        projection weights are loaded from the stable stream, and the session
        prompt is the base prompt plus small Gaussian noise.
        """
        # Clone stable backbone weights into a trainable backbone
        backbone = copy.deepcopy(self._backbone_template).to(self.device)
        backbone.load_state_dict(self.stable.backbone.state_dict(), strict=True)

        # Create session prompt initialized from base prompt (+ small noise)
        P_base = self.base_prompt.prompts.detach()
        session_prompt = PromptModule(self.cfg.num_prompts, self.cfg.feat_dim).to(self.device)
        with torch.no_grad():
            session_prompt.prompts.copy_(P_base + 0.01 * torch.randn_like(P_base))

        plastic = PromptedBackbone(backbone, session_prompt, dim=self.cfg.feat_dim).to(self.device)
        # Without this the student's prompt projection would be a fresh random
        # layer, so student and teacher features would differ from step zero.
        plastic.prompt_proj.load_state_dict(self.stable.prompt_proj.state_dict(), strict=True)

        # Plastic trainable
        for p in plastic.parameters():
            p.requires_grad = True

        self.plastic_prompt = session_prompt
        self.plastic = plastic

    @torch.no_grad()
    def build_prototypes_from_loader(self, loader) -> None:
        """
        Compute/update prototypes using stable stream + base prompt (recommended).

        Every (images, labels) batch from ``loader`` is encoded by the stable
        stream in eval mode and folded into the running class means.
        """
        self.stable.eval()
        for x, y in loader:
            x = x.to(self.device)
            y = y.to(self.device)
            h = self.stable(x)  # [B,D]
            self.proto_mem.update_from_batch(h, y)

    def inference(self, x: torch.Tensor) -> torch.Tensor:
        """
        Final testing:
        x -> stable (base prompt) -> NCM logits -> argmax

        Classes whose prototype has never received a sample are excluded from
        the argmax. Their all-zero prototype would otherwise sit at squared
        distance 1 from every normalised feature and could outrank real classes.

        Returns:
            [B] predicted global class ids.
        """
        self.stable.eval()
        with torch.no_grad():
            h = self.stable(x.to(self.device))
            logits = self.proto_mem.logits_ncm(h, tau=self.cfg.tau, normalize=True)
            empty = self.proto_mem.counts.to(logits.device) <= 0   # [C]
            logits = logits.masked_fill(empty.unsqueeze(0), float("-inf"))
            pred = logits.argmax(dim=-1)
        return pred


# -----------------------------
# Training routines
# -----------------------------
def _grow_linear_head(
    head: Optional[nn.Linear],
    in_dim: int,
    num_classes: int,
    device: torch.device,
) -> nn.Linear:
    """Return a linear head with at least ``num_classes`` outputs, keeping rows ``head`` already learned."""
    if head is not None and head.out_features >= num_classes:
        return head
    grown = nn.Linear(in_dim, num_classes).to(device)
    if head is not None:
        with torch.no_grad():
            grown.weight[: head.out_features].copy_(head.weight)
            grown.bias[: head.out_features].copy_(head.bias)
    return grown


def train_base_session(
    sys: FSCILSystem,
    base_loader,
    epochs: int = 5,
    lr: float = 1e-3,
):
    """
    Base session:
    - Train a plastic stream (trainable) + base prompt
    - EMA-update stable stream (backbone and prompt projection)
    - Build prototypes at end

    While the prototype memory does not yet cover the labels of a batch (the
    normal case, since prototypes are only built at the end), classification
    uses a temporary linear head. The head and its optimizer persist across
    batches and epochs; the head is only enlarged when a batch contains a
    class id it has no output for.

    Args:
        sys: the FSCIL system to train; its stable stream, base prompt and
            prototype memory are modified in place.
        base_loader: iterable of (images, labels) batches with global class ids.
        epochs: number of passes over ``base_loader``.
        lr: AdamW learning rate for both the plastic stream and the temp head.
    """
    cfg = sys.cfg
    device = sys.device

    # Create a trainable plastic that uses the BASE prompt directly in base session
    backbone = copy.deepcopy(sys._backbone_template).to(device)
    plastic = PromptedBackbone(backbone, sys.base_prompt, dim=cfg.feat_dim).to(device)
    # Start the student's prompt projection from the teacher's so the EMA below
    # averages a trajectory with a common origin, as it does for the backbone.
    plastic.prompt_proj.load_state_dict(sys.stable.prompt_proj.state_dict())
    for p in plastic.parameters():
        p.requires_grad = True

    # Stable is frozen EMA teacher, initialized already
    sys.stable.eval()

    opt = torch.optim.AdamW(plastic.parameters(), lr=lr)

    # Temporary classifier, created lazily on the first batch that needs it.
    temp_head: Optional[nn.Linear] = None
    opt_head = None

    plastic.train()
    for ep in range(epochs):
        for x, y in base_loader:
            x = x.to(device)
            y = y.to(device)

            # Forward (student)
            h_p = plastic(x)

            classes_needed = int(y.max().item()) + 1
            # NCM logits are only usable if a prototype row exists for every label.
            use_ncm = sys.proto_mem.num_classes >= classes_needed

            if use_ncm:
                # If prototypes exist (rare at start), you can use NCM directly
                logits = sys.proto_mem.logits_ncm(h_p, tau=cfg.tau)
            else:
                if temp_head is None or temp_head.out_features < classes_needed:
                    temp_head = _grow_linear_head(temp_head, cfg.feat_dim, classes_needed, device)
                    temp_head.train()
                    # The parameter tensors changed, so the optimizer (and its
                    # moment estimates) must be rebuilt; this only happens when
                    # a previously unseen, larger class id shows up.
                    opt_head = torch.optim.AdamW(temp_head.parameters(), lr=lr)
                logits = temp_head(l2_normalize(h_p))

            loss_cls = F.cross_entropy(logits, y)

            opt.zero_grad(set_to_none=True)
            if not use_ncm:
                opt_head.zero_grad(set_to_none=True)
            loss_cls.backward()
            opt.step()
            if not use_ncm:
                opt_head.step()

            # EMA update stable from plastic. The base prompt is the same
            # object in both streams, so only the backbone and the prompt
            # projection need averaging.
            ema_update(sys.stable.backbone, plastic.backbone, alpha=cfg.ema_alpha)
            ema_update(sys.stable.prompt_proj, plastic.prompt_proj, alpha=cfg.ema_alpha)

        print(f"[Base] epoch {ep+1}/{epochs} done.")

    # Build base prototypes using stable stream
    sys.proto_mem.prototypes = sys.proto_mem.prototypes[:0]  # reset
    sys.proto_mem.counts = sys.proto_mem.counts[:0]
    sys.build_prototypes_from_loader(base_loader)
    print("[Base] prototypes built:", sys.proto_mem.num_classes)


@torch.no_grad()
def _seed_missing_prototypes(sys: FSCILSystem, loader) -> None:
    """Give every class in ``loader`` that has no stored samples yet a prototype from stable features."""
    mem = sys.proto_mem
    # Snapshot taken before seeding, so every batch of a new class contributes.
    had_samples = mem.counts.to(sys.device) > 0   # [C_old]
    sys.stable.eval()
    for x, y in loader:
        x = x.to(sys.device)
        y = y.to(sys.device)
        known = torch.zeros_like(y, dtype=torch.bool)
        in_range = y < had_samples.numel()
        known[in_range] = had_samples[y[in_range]]
        fresh = ~known
        if fresh.any():
            mem.update_from_batch(sys.stable(x[fresh]), y[fresh])


def train_incremental_session(
    sys: FSCILSystem,
    new_loader,
    old_loader: Optional[object],  # can be None if you do not have replay
    epochs: int = 5,
    lr: float = 5e-4,
):
    """
    Incremental session t:
    - Create new plastic stream & session prompt
    - Seed prototypes for classes that have none yet (stable features)
    - Stable stream is frozen teacher
    - Loss = cls + proto + kd + domain
    - Update prototypes for new classes
    - EMA merge plastic -> stable

    Args:
        sys: the FSCIL system; its stable backbone and prototype memory are
            modified in place, and its plastic stream is discarded at the end.
        new_loader: iterable of (images, labels) batches for the new classes.
        old_loader: optional replay loader of (images, labels) for old classes.
            It is cycled as often as needed; without it the KD and domain
            terms are zero.
        epochs: number of passes over ``new_loader``.
        lr: AdamW learning rate for the plastic stream.
    """
    cfg = sys.cfg
    device = sys.device

    sys.new_plastic_for_session()
    assert sys.plastic is not None

    # The NCM logits and the prototype-alignment loss index the memory by
    # label, so every new class needs a prototype row before the first batch;
    # otherwise the new labels are out of range for the logits.
    _seed_missing_prototypes(sys, new_loader)

    sys.stable.eval()
    sys.plastic.train()

    opt = torch.optim.AdamW(sys.plastic.parameters(), lr=lr)

    # If you have replay, we iterate both loaders; otherwise KD/domain become weaker.
    old_iter = iter(old_loader) if old_loader is not None else None

    for ep in range(epochs):
        for x_new, y_new in new_loader:
            x_new = x_new.to(device)
            y_new = y_new.to(device)

            # -------- new-class forward (plastic) --------
            h_new = sys.plastic(x_new)  # [B,D]
            logits_new = sys.proto_mem.logits_ncm(h_new, tau=cfg.tau, normalize=True)
            loss_cls = F.cross_entropy(logits_new, y_new)

            # Prototype alignment: pull each feature toward its class prototype
            # (seeded above, then refined online below).
            loss_proto = proto_alignment_loss(h_new, y_new, sys.proto_mem)

            # -------- old-class forward (stable teacher + plastic student) --------
            loss_kd = torch.tensor(0.0, device=device)
            loss_dom = torch.tensor(0.0, device=device)

            if old_iter is not None:
                try:
                    x_old, _ = next(old_iter)
                except StopIteration:
                    # Replay loader exhausted: restart it.
                    old_iter = iter(old_loader)
                    x_old, _ = next(old_iter)

                x_old = x_old.to(device)

                with torch.no_grad():
                    h_old_s = sys.stable(x_old)  # stable features
                    logits_old_s = sys.proto_mem.logits_ncm(h_old_s, tau=cfg.tau, normalize=True)

                h_old_p = sys.plastic(x_old)  # plastic features
                logits_old_p = sys.proto_mem.logits_ncm(h_old_p, tau=cfg.tau, normalize=True)

                # KD on prototype-logits (unique aspect)
                loss_kd = kd_loss_logits(logits_old_s, logits_old_p, T=cfg.kd_T)

                # Domain alignment: align mean(old stable) with mean(new plastic)
                loss_dom = domain_shift_loss(h_old_s, h_new)

            # -------- total --------
            loss = loss_cls + cfg.lam_proto * loss_proto + cfg.lam_kd * loss_kd + cfg.lam_domain * loss_dom

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            # EMA merge plastic -> stable backbone ONLY (stable prompt is base prompt, stays fixed)
            ema_update(sys.stable.backbone, sys.plastic.backbone, alpha=cfg.ema_alpha)

            # Update prototypes online using plastic features for new classes
            with torch.no_grad():
                sys.proto_mem.update_from_batch(h_new.detach(), y_new.detach())

        print(f"[Inc] epoch {ep+1}/{epochs} done.")

    # Discard plastic stream and session prompt to keep inference light
    sys.plastic = None
    sys.plastic_prompt = None
    print("[Inc] session complete. total prototypes:", sys.proto_mem.num_classes)


# -----------------------------
# Example usage (you plug your loaders)
# -----------------------------
if __name__ == "__main__":
    # Prefer the GPU when one is visible, otherwise run on CPU.
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")

    hyperparams = dict(
        feat_dim=768,          # ViT-B style
        num_prompts=10,
        tau=1.0,
        kd_T=2.0,
        ema_alpha=0.999,
        lam_proto=1.0,
        lam_kd=1.0,
        lam_domain=0.5,
    )
    cfg = FSCILConfig(**hyperparams)

    sys = FSCILSystem(cfg, device=device)

    # Data you have to supply before enabling the calls below:
    #   base_loader       - many images per base class
    #   inc_loaders       - one loader per incremental session (few-shot new classes)
    #   old_replay_loader - optional, e.g. a loader over a memory buffer

    # train_base_session(sys, base_loader, epochs=50, lr=1e-4)
    # for t, new_loader in enumerate(inc_loaders, start=1):
    #     train_incremental_session(sys, new_loader, old_replay_loader, epochs=20, lr=5e-5)
    #
    # # Testing
    # for x_test, y_test in test_loader:
    #     pred = sys.inference(x_test)
    #     ...

    print("Skeleton ready. Plug your MiniImageNet loaders into the commented section.")
