In [None]:
# fscil_adsp_cfr_full.py
# Fully functional FSCIL (ADSP-CFR style) reference implementation.
# Runs end-to-end on CIFAR-100 by default.
#
# Key modules:
# - Plastic stream (trainable backbone + session prompt)
# - Stable stream (EMA backbone + base prompt) as teacher/memory
# - Prototype Memory + NCM classifier
# - Losses: CE + proto + KD + domain
#
# NOTE: This is a research-grade reference implementation (not SOTA optimized).
# It is designed to be clear, correct, and runnable.

import os
import math
import copy
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision
import torchvision.transforms as T
from tqdm import tqdm
import timm


In [None]:
# ---------------------------
# Reproducibility
# ---------------------------
def seed_all(seed: int = 0):
    """Seed Python's ``random`` module and torch (CPU generator and every CUDA device)."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


In [None]:
# ---------------------------
# Helpers
# ---------------------------
def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Scale ``x`` to unit L2 norm along ``dim``; ``eps`` is added to the norm to avoid 0/0."""
    magnitude = torch.norm(x, p=2, dim=dim, keepdim=True)
    return x / magnitude.add(eps)


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float):
    """
    In-place exponential moving average of parameters:
    teacher = alpha*teacher + (1-alpha)*student.

    Parameters are matched by name; every student parameter must exist in the
    teacher (KeyError otherwise). Buffers are not touched.
    """
    teacher_lookup = {name: param for name, param in teacher.named_parameters()}
    keep, take = alpha, 1.0 - alpha
    for param_name, student_param in student.named_parameters():
        target = teacher_lookup[param_name].data
        target.mul_(keep)
        target.add_(student_param.data, alpha=take)


def set_requires_grad(module: nn.Module, flag: bool):
    """Enable or disable gradients for every parameter of ``module`` (recursively)."""
    module.requires_grad_(flag)


In [None]:
# ---------------------------
# Dataset: MiniImageNet (optional)
# Folder structure expected:
# root/
#   train/class_x/xxx.png ...
#   test/class_x/xxx.png ...
#   val/class_x/xxx.png ...
# ---------------------------
class MiniImageNetFolder(torchvision.datasets.ImageFolder):
    """ImageFolder over a single MiniImageNet split directory (e.g. ``root/train``)."""

    def __init__(self, root_split: str, transform=None):
        super(MiniImageNetFolder, self).__init__(root_split, transform=transform)




In [None]:
# ---------------------------
# Split FSCIL sessions
# ---------------------------
def make_fscil_splits(
    targets: List[int],
    base_classes: int = 40,
    ways_per_session: int = 5,
    shots_per_class: int = 5,
    seed: int = 0,
) -> Tuple[List[int], List[List[int]]]:
    """
    Split the sorted class ids into a base session and incremental sessions.

    Classes are deliberately NOT shuffled. The evaluation code filters samples
    with ``y < seen_C``, so it assumes the base session is exactly
    ``0 .. base_classes-1`` and each session adds the next
    ``ways_per_session`` ids.

    Args:
      targets: per-sample class labels of the training set.
      base_classes: number of classes in the base session.
      ways_per_session: number of new classes per incremental session.
      shots_per_class: unused here (shots are drawn separately); kept for
        API compatibility.
      seed: unused because the split is deterministic; kept for API
        compatibility. The global RNG state is not modified.

    Returns:
      base_class_ids: list of class ids in the base session.
      inc_sessions: list of sessions, each a list of class ids. A trailing
        group with fewer than ``ways_per_session`` classes is dropped.

    Raises:
      ValueError: if there are fewer than ``base_classes`` distinct classes,
        or ``ways_per_session`` is not positive.
    """
    classes = sorted(set(targets))
    if len(classes) < base_classes:
        raise ValueError("Not enough classes.")
    if ways_per_session <= 0:
        raise ValueError("ways_per_session must be positive.")

    base_cls = classes[:base_classes]
    remaining = classes[base_classes:]

    # Only complete sessions are kept.
    n_sessions = len(remaining) // ways_per_session
    inc_sessions = [
        remaining[i * ways_per_session:(i + 1) * ways_per_session]
        for i in range(n_sessions)
    ]
    return base_cls, inc_sessions


def indices_for_classes(targets: List[int], class_ids: List[int]) -> List[int]:
    """Return, in ascending order, every dataset index whose label is in ``class_ids``."""
    wanted = frozenset(class_ids)
    return [pos for pos, label in enumerate(targets) if label in wanted]


def fewshot_indices_for_classes(
    targets: List[int],
    class_ids: List[int],
    shots_per_class: int,
    seed: int = 0
) -> List[int]:
    """
    Pick up to ``shots_per_class`` random dataset indices for each class.

    Args:
      targets: per-sample class labels.
      class_ids: classes to sample, processed in the given order.
      shots_per_class: maximum number of indices per class. Classes with
        fewer samples contribute all of them; absent classes contribute none.
      seed: seed of the private RNG used for shuffling.

    Returns:
      Concatenated per-class index lists, in ``class_ids`` order.

    The selection matches ``random.seed(seed)`` followed by ``random.shuffle``.
    Here it uses a private ``random.Random(seed)``, so the global Python and
    torch RNG states are left untouched.
    """
    # Group indices by label in one pass; each list stays in ascending order.
    by_class: Dict[int, List[int]] = {}
    for i, y in enumerate(targets):
        by_class.setdefault(y, []).append(i)

    rng = random.Random(seed)
    idxs: List[int] = []
    for c in class_ids:
        # Copy so a repeated class id reshuffles the original ascending order.
        pool = list(by_class.get(c, []))
        rng.shuffle(pool)
        idxs.extend(pool[:shots_per_class])
    return idxs


In [None]:
# ---------------------------
# Memory buffer for replay
# ---------------------------
class ReplayBuffer:
    """
    Exemplar memory that stores dataset indices (not images), capped per class.

    Trimming shuffles each class with a private ``random.Random(self.seed)``.
    This gives the same result as reseeding the global ``random`` module, but
    leaves the global Python and torch RNG states untouched.
    """

    def __init__(self, max_per_class: int = 20, seed: int = 0):
        self.max_per_class = max_per_class
        self.seed = seed
        self.storage: Dict[int, List[int]] = {}  # class_id -> list of dataset indices

    def add_indices(self, targets: List[int], new_indices: List[int]):
        """
        Add ``new_indices``, grouped by ``targets[idx]``. Then trim every class
        to at most ``max_per_class`` entries with a seeded shuffle.

        An index appearing more than once for a class is stored only once.
        """
        for idx in new_indices:
            self.storage.setdefault(targets[idx], []).append(idx)

        rng = random.Random(self.seed)
        for c in list(self.storage.keys()):
            # dict.fromkeys removes repeated indices while keeping first-seen order.
            lst = list(dict.fromkeys(self.storage[c]))
            rng.shuffle(lst)
            self.storage[c] = lst[: self.max_per_class]

    def all_indices(self) -> List[int]:
        """Return all stored indices, class by class in insertion order."""
        return [idx for lst in self.storage.values() for idx in lst]

In [None]:
# ---------------------------
# Prompt module (learnable tokens)
# ---------------------------
class PromptModule(nn.Module):
    """
    Learnable prompt tokens shared by every sample in a batch.

    ``prompts`` is a parameter of shape [num_prompts, dim], initialised from
    N(0, init_std^2). ``forward(B)`` returns an expanded view of shape
    [B, num_prompts, dim].
    """
    def __init__(self, num_prompts: int, dim: int, init_std: float = 0.02):
        super().__init__()
        initial = torch.randn(num_prompts, dim).mul(init_std)
        self.prompts = nn.Parameter(initial)

    def forward(self, B: int) -> torch.Tensor:
        batched = self.prompts[None, :, :]  # [1, m, D]
        return batched.expand(B, *self.prompts.shape)  # [B, m, D]

In [None]:
#--------------------------------------------------
# Model 
#--------------------------------------------------

class PromptedTimmViT(nn.Module):
    """
    Prompt injection into a timm ViT:
      tokens = [CLS | PROMPTS | PATCHES]
    Works with timm VisionTransformer models (e.g. vit_small_patch16_224,
    vit_base_patch16_224) whose ``pos_embed`` covers ``[CLS | PATCHES]``.

    Prompt tokens get a fixed all-zero positional embedding, stored in the
    ``pos_embed_prompt`` buffer. The prompt tokens themselves are learnable,
    so a separate learnable prompt position would be redundant.

    The buffer is registered in ``__init__`` rather than lazily in
    ``forward``. Therefore it:
      * is in ``state_dict()`` before the first forward pass, so a fresh model
        can strictly load a checkpoint;
      * is identical across wrappers of the same ViT weights and prompts (the
        stable and plastic streams);
      * cannot end up as a random, never-optimized parameter created after the
        optimizer was built.
    """
    def __init__(self, timm_vit: nn.Module, prompt: nn.Module):
        super().__init__()
        self.vit = timm_vit
        self.prompt = prompt

        # Ensure the wrapped model has the expected timm ViT parts.
        for attr in ("patch_embed", "pos_embed", "cls_token", "blocks", "norm"):
            assert hasattr(self.vit, attr), f"timm ViT is missing '{attr}'"

        self.embed_dim = self.vit.embed_dim

        # Probe the prompt module once (forward(B) -> [B, M, D]) to learn M.
        with torch.no_grad():
            num_prompts = int(self.prompt(1).shape[1])
        self.register_buffer("pos_embed_prompt", torch.zeros(1, num_prompts, self.embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the final-layer CLS feature, shape [B, embed_dim]."""
        B = x.shape[0]

        # 1) Patch embedding (timm): [B, N, D]
        x = self.vit.patch_embed(x)

        # 2) CLS token: [B, 1, D]
        cls = self.vit.cls_token.expand(B, -1, -1)

        # 3) Prompt tokens: [B, M, D]
        prompts = self.prompt(B)
        if prompts.shape[-1] != self.embed_dim:
            raise ValueError(f"Prompt dim {prompts.shape[-1]} != ViT embed_dim {self.embed_dim}")

        # 4) Concat: [B, 1+M+N, D]
        x = torch.cat([cls, prompts, x], dim=1)

        # 5) Positional embedding for [CLS | PROMPTS | PATCHES]
        x = x + self._pos_embed_with_prompts(x)
        x = self.vit.pos_drop(x)

        # timm applies norm_pre after the positional embedding. It is Identity
        # for standard ViTs and a real norm for pre-norm variants.
        norm_pre = getattr(self.vit, "norm_pre", None)
        if norm_pre is not None:
            x = norm_pre(x)

        # 6) Transformer blocks
        for blk in self.vit.blocks:
            x = blk(x)

        x = self.vit.norm(x)

        # 7) Use the CLS output as the feature: [B, D]
        return x[:, 0]

    def _pos_embed_with_prompts(self, x_tokens: torch.Tensor) -> torch.Tensor:
        """
        Build positional embeddings for [CLS | PROMPTS | PATCHES], cast to the
        tokens' device and dtype.

        Raises:
          ValueError: if the number of prompt tokens in ``x_tokens`` differs
            from the count probed at construction time.
        """
        device = x_tokens.device
        dtype = x_tokens.dtype

        pos_base = self.vit.pos_embed.to(device=device, dtype=dtype)  # [1, 1+N, D]
        num_prompts = x_tokens.shape[1] - pos_base.shape[1]
        if num_prompts == 0:
            return pos_base
        if num_prompts != self.pos_embed_prompt.shape[1]:
            raise ValueError(
                f"Got {num_prompts} prompt tokens but pos_embed_prompt was built for "
                f"{self.pos_embed_prompt.shape[1]}."
            )

        prompt_pos = self.pos_embed_prompt.to(device=device, dtype=dtype)  # [1, M, D]
        cls_pos = pos_base[:, :1, :]    # [1, 1, D]
        patch_pos = pos_base[:, 1:, :]  # [1, N, D]
        return torch.cat([cls_pos, prompt_pos, patch_pos], dim=1)  # [1, 1+M+N, D]


In [None]:
# ---------------------------
# Prototype memory + NCM logits
# ---------------------------
class PrototypeMemory(nn.Module):
    """
    Per-class feature prototypes with a nearest-class-mean (NCM) classifier.

    Registered buffers:
      prototypes: [C, feat_dim] running mean of the features seen per class.
      counts:     [C] number of features folded into each prototype.
    Row ``c`` belongs to class id ``c``. Rows allocated by ``ensure_size``
    but never updated stay all-zero with count 0.
    """
    def __init__(self, feat_dim: int):
        super().__init__()
        self.feat_dim = feat_dim
        self.register_buffer("prototypes", torch.empty(0, feat_dim))
        self.register_buffer("counts", torch.empty(0))

    @property
    def num_classes(self) -> int:
        """Number of allocated prototype rows, including empty ones."""
        return int(self.prototypes.shape[0])

    def ensure_size(self, C: int, device: torch.device):
        """
        Grow the memory to at least ``C`` rows on ``device``.

        New rows are zero with count 0 and keep the existing buffers' dtypes.
        Does nothing when ``C`` rows already exist.
        """
        if self.num_classes >= C:
            return
        old = self.num_classes
        pad_proto = torch.zeros(C - old, self.feat_dim, device=device, dtype=self.prototypes.dtype)
        pad_cnt = torch.zeros(C - old, device=device, dtype=self.counts.dtype)
        if old == 0:
            self.prototypes = pad_proto
            self.counts = pad_cnt
        else:
            self.prototypes = torch.cat([self.prototypes.to(device), pad_proto], dim=0)
            self.counts = torch.cat([self.counts.to(device), pad_cnt], dim=0)

    @torch.no_grad()
    def update_from_batch(self, feats: torch.Tensor, labels: torch.Tensor):
        """
        Fold a batch of features into the per-class running means.

        For each class c present in ``labels``:
          mu_c <- (n_c * mu_c + sum(feats_c)) / (n_c + k_c),  n_c <- n_c + k_c.
        The memory grows to ``labels.max() + 1`` rows if needed. An empty
        batch is a no-op, since ``labels.max()`` would fail on an empty tensor.
        """
        if labels.numel() == 0:
            return
        device = feats.device
        max_label = int(labels.max().item())
        self.ensure_size(max_label + 1, device)

        for c in labels.unique():
            c = int(c.item())
            idx = labels == c
            f_c = feats[idx]
            k = f_c.shape[0]
            if k == 0:
                continue
            mu_old = self.prototypes[c]
            n_old = self.counts[c].clamp(min=0.0)
            mu_new = (n_old * mu_old + f_c.sum(dim=0)) / (n_old + float(k))
            self.prototypes[c] = mu_new
            self.counts[c] = n_old + float(k)

    def logits_ncm(self, feats: torch.Tensor, tau: float = 1.0, normalize: bool = True) -> torch.Tensor:
        """
        Return NCM logits ``-||h - mu_c||^2 / tau`` of shape [B, C].

        With ``normalize=True`` both features and prototypes are L2-normalized
        first. Empty (all-zero) prototype rows then behave as a zero vector.

        Raises:
          RuntimeError: if no prototype rows are allocated.
        """
        if self.num_classes == 0:
            raise RuntimeError("Empty prototype memory.")
        proto = self.prototypes.to(feats.device)
        h = feats
        if normalize:
            proto = l2_normalize(proto, dim=-1)
            h = l2_normalize(h, dim=-1)
        d2 = torch.cdist(h, proto, p=2) ** 2
        return -d2 / tau

In [None]:
# ---------------------------
# Losses
# ---------------------------
def proto_alignment_loss(feats: torch.Tensor, labels: torch.Tensor, proto_mem: PrototypeMemory) -> torch.Tensor:
    """
    Mean squared L2 distance between each unit-normalized feature and the
    unit-normalized prototype of its label.
    """
    targets = proto_mem.prototypes.to(feats.device)[labels]
    gap = l2_normalize(feats, dim=-1) - l2_normalize(targets, dim=-1)
    return gap.pow(2).sum(dim=-1).mean()


def kd_loss_logits(logits_teacher: torch.Tensor, logits_student: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    """
    Distillation loss KL(teacher || student) on temperature-softened softmax
    distributions. It is averaged over the batch and multiplied by T*T.
    """
    soft_targets = (logits_teacher / T).softmax(dim=-1)
    student_log_probs = (logits_student / T).log_softmax(dim=-1)
    temperature_scale = T * T
    return F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * temperature_scale


def domain_shift_loss(feats_old: torch.Tensor, feats_new: torch.Tensor) -> torch.Tensor:
    """Squared L2 distance between the unit-normalized batch means of two feature sets."""
    center_old, center_new = (
        l2_normalize(batch.mean(dim=0), dim=-1) for batch in (feats_old, feats_new)
    )
    return (center_old - center_new).pow(2).sum()

In [None]:
# ---------------------------
# System wrapper
# ---------------------------
@dataclass
class FSCILConfig:
    """Hyper-parameters for the FSCIL pipeline: backbone, optimization, protocol, losses."""
    # -----------------------
    # Backbone (timm ViT)
    # -----------------------
    vit_name: str = "vit_base_patch16_224"  # timm model name
    pretrained: bool = True  # NOTE(review): confirm FSCILSystem forwards this to timm.create_model

    # img_size: the CIFAR transforms resize inputs to img_size x img_size.
    # feat_dim must equal the ViT's embed_dim (FSCILSystem raises otherwise):
    # 768 for vit_base_patch16_224, 384 for vit_small_patch16_224.
    img_size: int = 224
    patch_size: int = 16          # fixed by model name (/16); not read elsewhere in this file
    feat_dim: int = 768           # ViT-Base embedding dim
    num_prompts: int = 10         # number of learnable prompt tokens per prompt module

    # -----------------------
    # Optimization
    # -----------------------
    base_epochs: int = 10         # epochs of base-session training
    inc_epochs: int = 20          # epochs per incremental session
    lr_base: float = 1e-4         # AdamW lr for base-stage backbone, prompt and head
    lr_inc: float = 5e-5          # AdamW lr for the plastic stream in incremental sessions
    weight_decay: float = 0.05    # AdamW weight decay

    batch_size: int = 64

    # -----------------------
    # FSCIL protocol (MiniImageNet/CIFAR100 style)
    # -----------------------
    base_classes: int = 60        # base-session classes (ids 0..base_classes-1 with the sorted split)
    ways_per_session: int = 5     # new classes per incremental session
    shots_per_class: int = 5      # training samples per new class
    replay_max_per_class: int = 20  # max stored exemplar indices per class in ReplayBuffer

    # -----------------------
    # Loss / EMA
    # -----------------------
    tau: float = 1.0              # NCM temperature: logits = -squared_distance / tau
    kd_T: float = 2.0             # softmax temperature of the distillation loss
    ema_alpha: float = 0.999      # stable = alpha*stable + (1-alpha)*plastic after every step

    # Weights of the incremental-session loss terms (the CE term has weight 1).
    lam_proto: float = 1.0
    lam_kd: float = 1.0
    lam_domain: float = 0.5

    seed: int = 0

In [None]:
import copy
import timm
import torch
import torch.nn as nn
from typing import Optional
from torch.utils.data import DataLoader

class FSCILSystem(nn.Module):
    """
    Container for the persistent FSCIL state.

    Members:
      stable:         PromptedTimmViT using the persistent ``base_prompt``.
                      Gradients are disabled; it changes only through EMA
                      updates and state loading.
      proto_mem:      prototype memory used for NCM prediction.
      plastic / session_prompt: per-session trainable copies created by
                      ``new_plastic_for_session`` (None otherwise).
      _vit_template:  the freshly created timm ViT, cloned for new plastic
                      backbones.
    """
    def __init__(self, cfg: FSCILConfig, device: torch.device):
        super().__init__()
        self.cfg = cfg
        self.device = device

        # Honor cfg.pretrained (it was previously hard-coded to True).
        vit_template = timm.create_model(
            cfg.vit_name,
            pretrained=cfg.pretrained,
            num_classes=0  # feature extractor
        ).to(device)

        # cfg.feat_dim must equal the ViT embedding dim
        # (e.g. 768 for vit_base_patch16_224, 384 for vit_small_patch16_224).
        if cfg.feat_dim != vit_template.embed_dim:
            raise ValueError(f"cfg.feat_dim={cfg.feat_dim} but vit.embed_dim={vit_template.embed_dim}. "
                             f"Set cfg.feat_dim={vit_template.embed_dim} (768 for vit_base).")

        # base prompt (persistent)
        self.base_prompt = PromptModule(cfg.num_prompts, cfg.feat_dim).to(device)

        # stable = EMA teacher (persistent); gradients disabled
        self.stable = PromptedTimmViT(copy.deepcopy(vit_template), self.base_prompt).to(device)
        set_requires_grad(self.stable, False)

        # plastic/session prompt are created per session (temporary)
        self.plastic: Optional[PromptedTimmViT] = None
        self.session_prompt: Optional[PromptModule] = None

        # prototype memory (persistent)
        self.proto_mem = PrototypeMemory(cfg.feat_dim).to(device)

        # keep the ViT template for plastic recreation
        self._vit_template = vit_template

    def new_plastic_for_session(self):
        """
        Create a fresh trainable plastic stream for an incremental session.

        The backbone is a copy of the template loaded with the current stable
        ViT weights. The session prompt is the base prompt plus Gaussian noise
        with std 0.01. Sets ``self.plastic`` and ``self.session_prompt``.
        """
        cfg = self.cfg

        # clone the ViT and initialize it from the stable ViT weights
        vit = copy.deepcopy(self._vit_template).to(self.device)
        vit.load_state_dict(self.stable.vit.state_dict(), strict=True)

        # session prompt initialized from the base prompt (+ noise)
        P_base = self.base_prompt.prompts.detach()
        sp = PromptModule(cfg.num_prompts, cfg.feat_dim).to(self.device)
        with torch.no_grad():
            sp.prompts.copy_(P_base + 0.01 * torch.randn_like(P_base))

        self.session_prompt = sp
        self.plastic = PromptedTimmViT(vit, self.session_prompt).to(self.device)
        set_requires_grad(self.plastic, True)

    @torch.no_grad()
    def build_prototypes(self, loader: DataLoader):
        """Fold stable-stream features of every batch in ``loader`` into ``proto_mem``."""
        self.stable.eval()
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            h = self.stable(x)  # [B, feat_dim]
            self.proto_mem.update_from_batch(h, y)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """Return the NCM class index for each input, using stable-stream features."""
        self.stable.eval()
        h = self.stable(x.to(self.device))
        logits = self.proto_mem.logits_ncm(h, tau=self.cfg.tau, normalize=True)
        return logits.argmax(dim=-1)


In [None]:
# ---------------------------
# Training / Evaluation
# ---------------------------
def accuracy(model: FSCILSystem, loader: DataLoader) -> float:
    """
    Top-1 accuracy of ``model.predict`` over every sample yielded by ``loader``.
    Returns 0.0 when the loader yields no samples.
    """
    n_correct, n_seen = 0, 0
    for inputs, labels in loader:
        labels = labels.to(model.device)
        preds = model.predict(inputs.to(model.device))
        n_correct += (preds == labels).sum().item()
        n_seen += labels.numel()
    return n_correct / max(n_seen, 1)

In [None]:
import os
import json
import torch
from dataclasses import asdict

def save_checkpoint(
    model: "FSCILSystem",
    path: str,
    session_id: int,
    replay_buffer: Optional[object] = None,  # your ReplayBuffer instance
    extra: Optional[dict] = None,
):
    """
    Serialize the persistent FSCIL state to ``path`` with ``torch.save``.

    Saved entries:
      - the session id and the config (as a dict);
      - stable-stream and base-prompt weights;
      - the prototype memory, on CPU;
      - the temporary base-stage linear head on ``train_base.head``, if any;
      - the replay buffer's ``storage`` mapping, if a buffer is given;
      - the torch CPU, CUDA and Python ``random`` RNG states.
    ``extra`` is stored under ``"extra"`` when given. Parent directories are
    created as needed.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    # The base head is stored as an attribute on the train_base function.
    # Looking it up through globals() avoids a NameError when train_base has
    # not been defined (yet).
    base_head = getattr(globals().get("train_base"), "head", None)

    ckpt = {
        "session_id": session_id,
        "cfg": asdict(model.cfg) if hasattr(model, "cfg") else None,

        # Persistent model parts
        "stable": model.stable.state_dict(),
        "base_prompt": model.base_prompt.state_dict(),

        # Prototype memory
        "proto_prototypes": model.proto_mem.prototypes.detach().cpu(),
        "proto_counts": model.proto_mem.counts.detach().cpu(),

        # Optional: base linear head (only used in base stage)
        "base_head": base_head.state_dict() if base_head is not None else None,

        # Optional: replay buffer indices
        "replay": replay_buffer.storage if replay_buffer is not None else None,

        # RNG states for exact reproducibility (Python `random` drives splits/replay)
        "rng_torch": torch.get_rng_state(),
        "rng_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "rng_python": random.getstate(),
    }

    if extra is not None:
        ckpt["extra"] = extra

    torch.save(ckpt, path)
    print(f"[CKPT] Saved -> {path}")


def load_checkpoint(
    model: "FSCILSystem",
    path: str,
    replay_buffer: Optional[object] = None,
    map_location: str = "cpu",
):
    """
    Restore, in place, the state written by ``save_checkpoint``.

    Restores:
      - stable-stream and base-prompt weights;
      - the prototype memory, moved to ``model.device``;
      - the base-stage head on ``train_base.head``, plus a fresh
        ``train_base.head_opt`` for it, when the checkpoint has a head and
        ``train_base`` is defined;
      - the replay storage, when a buffer is passed;
      - whichever RNG states were saved.

    Returns:
      (session_id, extra); ``extra`` is None when the checkpoint has none.
    """
    ckpt = torch.load(path, map_location=map_location)

    model.stable.load_state_dict(ckpt["stable"])
    model.base_prompt.load_state_dict(ckpt["base_prompt"])

    # Restore prototypes
    with torch.no_grad():
        model.proto_mem.prototypes = ckpt["proto_prototypes"].to(model.device)
        model.proto_mem.counts = ckpt["proto_counts"].to(model.device)

    # Restore base head if present
    head_state = ckpt.get("base_head", None)
    if head_state is not None:
        trainer = globals().get("train_base")
        if trainer is None:
            print("[CKPT] Warning: checkpoint has a base head but train_base is not defined; skipping head restore.")
        else:
            C = head_state["weight"].shape[0]
            trainer.head = nn.Linear(model.cfg.feat_dim, C).to(model.device)
            trainer.head.load_state_dict(head_state)
            # train_base steps train_base.head_opt together with the head.
            # Without a matching optimizer, a resumed run would hit an
            # AttributeError or step a stale head's parameters.
            trainer.head_opt = torch.optim.AdamW(
                trainer.head.parameters(),
                lr=model.cfg.lr_base,
                weight_decay=model.cfg.weight_decay,
            )

    # Restore replay buffer if provided
    if replay_buffer is not None and ckpt.get("replay", None) is not None:
        replay_buffer.storage = ckpt["replay"]

    # Restore RNG states (optional)
    if ckpt.get("rng_torch", None) is not None:
        torch.set_rng_state(ckpt["rng_torch"])
    if torch.cuda.is_available() and ckpt.get("rng_cuda", None) is not None:
        torch.cuda.set_rng_state_all(ckpt["rng_cuda"])
    if ckpt.get("rng_python", None) is not None:
        random.setstate(ckpt["rng_python"])

    session_id = ckpt.get("session_id", 0)
    print(f"[CKPT] Loaded <- {path} | session_id={session_id} | prototypes={model.proto_mem.num_classes}")
    return session_id, ckpt.get("extra", None)


# Report

In [None]:
from sklearn.metrics import accuracy_score, classification_report
import torch

@torch.no_grad()
def collect_preds(
    model: FSCILSystem,
    loader: DataLoader,
    allowed_max_class: int,     # evaluate only y < allowed_max_class
):
    """
    Run stable-stream NCM prediction on every sample whose label is below
    ``allowed_max_class``. Batches with no such sample are skipped.

    Returns:
      (y_true, y_pred): two flat Python lists of ints.
    """
    model.stable.eval()
    y_true, y_pred = [], []

    for x, y in loader:
        keep = y < allowed_max_class
        if not keep.any():
            continue

        feats = model.stable(x[keep].to(model.device))
        logits = model.proto_mem.logits_ncm(feats, tau=model.cfg.tau, normalize=True)

        y_true += y[keep].tolist()
        y_pred += logits.argmax(dim=-1).cpu().tolist()

    return y_true, y_pred


In [None]:
def report_base(model: FSCILSystem, test_loader: DataLoader, class_names=None):
    """
    Print accuracy and a per-class report over the base classes
    ``0 .. cfg.base_classes-1``. Samples of later classes are ignored.

    Returns the base-class accuracy in [0, 1].
    """
    base_C = model.cfg.base_classes
    y_true, y_pred = collect_preds(model, test_loader, allowed_max_class=base_C)

    acc = accuracy_score(y_true, y_pred)
    print(f"\n========== BASE REPORT (classes 0 to {base_C-1}) ==========")
    print(f"[BASE] Accuracy: {acc*100:.2f}%")

    # optional names: must match only the base classes
    names = None if class_names is None else list(class_names[:base_C])

    # Pin labels to 0..base_C-1 so target_names stays aligned even when a class
    # is missing from both y_true and y_pred.
    print(classification_report(
        y_true, y_pred,
        labels=list(range(base_C)),
        target_names=names,
        digits=4,
        zero_division=0,
    ))
    return acc


In [None]:
def report_session(model: FSCILSystem, test_loader: DataLoader, session_id: int, class_names=None):
    """
    Print accuracy and a per-class report over every class seen after
    ``session_id`` incremental sessions. These are the ids ``0 .. seen_C-1``,
    where ``seen_C = base_classes + session_id * ways_per_session``.

    Returns the seen-class accuracy in [0, 1].
    """
    seen_C = model.cfg.base_classes + session_id * model.cfg.ways_per_session
    y_true, y_pred = collect_preds(model, test_loader, allowed_max_class=seen_C)

    acc = accuracy_score(y_true, y_pred)
    print(f"\n========== SESSION {session_id} REPORT (classes 0 to {seen_C-1}) ==========")
    print(f"[SESSION {session_id}] Accuracy: {acc*100:.2f}%")

    names = None if class_names is None else list(class_names[:seen_C])

    # Pin labels to 0..seen_C-1 so target_names stays aligned even when a seen
    # class has no test samples and is never predicted.
    print(classification_report(
        y_true, y_pred,
        labels=list(range(seen_C)),
        target_names=names,
        digits=4,
        zero_division=0,
    ))
    return acc


In [None]:
def report_final(model: FSCILSystem, test_loader: DataLoader, total_classes: int, class_names=None):
    """
    Print accuracy and a per-class report over the classes
    ``0 .. total_classes-1``.

    Returns the accuracy in [0, 1].
    """
    y_true, y_pred = collect_preds(model, test_loader, allowed_max_class=total_classes)

    acc = accuracy_score(y_true, y_pred)
    print(f"\n========== FINAL REPORT (classes 0 to {total_classes-1}) ==========")
    print(f"[FINAL] Accuracy: {acc*100:.2f}%")

    # Slice the names so they match the evaluated classes, even if more are given.
    names = None if class_names is None else list(class_names[:total_classes])
    print(classification_report(
        y_true, y_pred,
        labels=list(range(total_classes)),
        target_names=names,
        digits=4,
        zero_division=0,
    ))
    return acc


In [None]:
# Accuracy history, filled by train_base / main_cifar100:
# base accuracy, one accuracy per incremental session, final accuracy.
acc_log = dict(base=None, sessions=[], final=None)

In [None]:
# ---------------------------
# Training Base 
# ---------------------------

def evaluate_stable_accuracy(model: FSCILSystem, loader: DataLoader) -> float:
    """
    Base-stage evaluation: stable-stream features fed (L2-normalized) to the
    temporary linear head stored on ``train_base.head``.

    Only samples with label < ``cfg.base_classes`` are counted. Returns 0.0
    when no head exists yet or no base-class samples were seen.
    """
    model.stable.eval()
    head = getattr(train_base, "head", None)
    if head is None:
        return 0.0

    base_C = model.cfg.base_classes
    hits = 0
    count = 0

    with torch.no_grad():
        for x, y in loader:
            keep = y < base_C
            if not keep.any():
                continue

            x_base = x[keep].to(model.device)
            y_base = y[keep].to(model.device)

            pred = head(l2_normalize(model.stable(x_base))).argmax(dim=-1)

            hits += (pred == y_base).sum().item()
            count += y_base.numel()

    return hits / max(count, 1)



# ---------------------------
# Training Base (tqdm + val + early stopping)
# ---------------------------
def train_base(
    model: FSCILSystem,
    base_loader: DataLoader,
    val_loader: DataLoader,
    patience: int = 35,
    min_delta: float = 1e-4,
):
    """
    Base-session training with tqdm progress, per-epoch validation and early stopping.

    Training setup:
      - A temporary plastic ViT (a fresh copy of the template) shares the
        persistent base prompt with the stable stream.
      - It is trained with cross-entropy through a temporary linear head,
        kept on ``train_base.head`` with its optimizer on
        ``train_base.head_opt``.
      - After every step the stable ViT is EMA-updated from the plastic ViT.

    After each epoch the stable stream plus head is validated on base
    classes; the best epoch's stable weights and head are restored at the end.
    Then prototypes are rebuilt from ``base_loader`` with the stable stream,
    ``checkpoints/base.ckpt`` is written, and the base report is stored in
    ``acc_log["base"]``.

    Args:
      model: FSCIL system; stable, base prompt and prototypes are updated in place.
      base_loader: training data of the base classes.
      val_loader: data used for early stopping and the final base report.
      patience: consecutive epochs without improvement before stopping.
      min_delta: minimum validation-accuracy gain that counts as an improvement.
    """
    cfg = model.cfg
    device = model.device

    # train a temporary plastic that uses the BASE prompt (same object as the stable prompt)
    plastic = PromptedTimmViT(
        copy.deepcopy(model._vit_template).to(device),
        model.base_prompt
    ).to(device)
    set_requires_grad(plastic, True)

    # stable is EMA-updated only
    model.stable.eval()

    opt = torch.optim.AdamW(
        plastic.parameters(),
        lr=cfg.lr_base,
        weight_decay=cfg.weight_decay
    )

    # Allocate the temporary head for every base class before training.
    # Sizing it from a batch's max label could under-allocate, and a later
    # batch would then re-create the head, discarding all trained weights.
    _ensure_base_head(cfg, device, cfg.base_classes)

    best_acc = -1.0
    best_state = None
    no_improve = 0

    for ep in range(cfg.base_epochs):
        plastic.train()

        epoch_bar = tqdm(
            base_loader,
            desc=f"[Base] Epoch {ep+1}/{cfg.base_epochs}",
            leave=True
        )

        for x, y in epoch_bar:
            x, y = x.to(device), y.to(device)
            h = plastic(x)  # [B, D]

            # Safety net for labels >= cfg.base_classes: grow the head while
            # keeping the rows that are already trained.
            _ensure_base_head(cfg, device, int(y.max().item()) + 1)

            logits = train_base.head(l2_normalize(h))
            loss = F.cross_entropy(logits, y)

            opt.zero_grad(set_to_none=True)
            train_base.head_opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            train_base.head_opt.step()

            # EMA: stable <- plastic (vit weights only)
            ema_update(model.stable.vit, plastic.vit, alpha=cfg.ema_alpha)

            epoch_bar.set_postfix(loss=f"{loss.item():.4f}")

        epoch_bar.close()

        # Validation after each epoch
        val_acc = evaluate_stable_accuracy(model, val_loader)
        print(f"[Base] Epoch {ep+1}: val_acc={val_acc*100:.2f}%")

        # Early stopping logic
        if val_acc > best_acc + min_delta:
            best_acc = val_acc
            no_improve = 0

            # Snapshot the best stable weights (includes the shared base prompt) and head
            best_state = {
                "stable": copy.deepcopy(model.stable.state_dict()),
                "head": copy.deepcopy(train_base.head.state_dict()),
            }
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"[Base] Early stopping triggered at epoch {ep+1}. Best val_acc={best_acc*100:.2f}%")
                break

    # Restore best weights
    if best_state is not None:
        model.stable.load_state_dict(best_state["stable"])
        train_base.head.load_state_dict(best_state["head"])
        print(f"[Base] Restored best model. Best val_acc={best_acc*100:.2f}%")

    # Build prototypes from scratch using stable (and base prompt)
    model.proto_mem.prototypes = model.proto_mem.prototypes[:0]
    model.proto_mem.counts = model.proto_mem.counts[:0]
    model.build_prototypes(base_loader)

    save_checkpoint(
        model,
        path="checkpoints/base.ckpt",
        session_id=0,
        replay_buffer=None  # if you use replay during base, pass it
    )

    print(f"[Base] prototypes built: {model.proto_mem.num_classes}")

    acc_log["base"] = report_base(model, val_loader)


def _ensure_base_head(cfg: FSCILConfig, device: torch.device, num_classes: int):
    """
    Ensure ``train_base.head`` has at least ``num_classes`` outputs and that
    ``train_base.head_opt`` optimizes exactly that head.

    Growing copies the existing rows into the new head, so trained weights
    survive. The optimizer is rebuilt whenever the head is replaced, or when
    the current optimizer does not own the head's parameters (e.g. a head
    restored from a checkpoint).
    """
    head = getattr(train_base, "head", None)
    if head is None or head.out_features < num_classes:
        new_head = nn.Linear(cfg.feat_dim, num_classes).to(device)
        if head is not None:
            k = head.out_features
            with torch.no_grad():
                new_head.weight[:k].copy_(head.weight)
                new_head.bias[:k].copy_(head.bias)
        train_base.head = new_head
        train_base.head_opt = None
        head = new_head

    head_opt = getattr(train_base, "head_opt", None)
    owned = set()
    if head_opt is not None:
        owned = {id(p) for group in head_opt.param_groups for p in group["params"]}
    if not all(id(p) in owned for p in head.parameters()):
        train_base.head_opt = torch.optim.AdamW(
            head.parameters(),
            lr=cfg.lr_base,
            weight_decay=cfg.weight_decay
        )



In [None]:
# ---------------------------
# Training Incremental
# ---------------------------


def train_incremental(
    model: FSCILSystem,
    new_loader: DataLoader,
    replay_loader: Optional[DataLoader],
    session_id: int,
    replay_buffer=None
):
    """
    Train one incremental few-shot session.

    A fresh plastic stream (stable ViT weights plus a session prompt) is
    trained on ``new_loader`` with the loss
        CE(NCM logits) + lam_proto * prototype alignment
        + lam_kd * KD(stable -> plastic on replay samples)
        + lam_domain * domain shift (stable replay features vs plastic new features).
    The KD and domain terms are active only when ``replay_loader`` is given.
    After every step the stable ViT is EMA-updated from the plastic ViT. The
    new-class prototypes are also updated with the plastic features, since
    they drive the CE and alignment terms during training.

    When training ends, the prototypes of the classes seen in ``new_loader``
    are rebuilt from scratch with the *stable* stream. ``FSCILSystem.predict``
    classifies stable features, and old-class prototypes were built from
    stable features. Without the rebuild, new-class prototypes would be a mean
    of plastic features re-accumulated every epoch.

    Finally the plastic stream is discarded and
    ``checkpoints/session_{session_id}.ckpt`` is written; ``replay_buffer`` is
    only used for that checkpoint.
    """
    cfg = model.cfg
    device = model.device

    model.new_plastic_for_session()
    assert model.plastic is not None
    plastic = model.plastic

    model.stable.eval()
    plastic.train()

    opt = torch.optim.AdamW(
        plastic.parameters(),
        lr=cfg.lr_inc,
        weight_decay=cfg.weight_decay
    )

    replay_iter = iter(replay_loader) if replay_loader is not None else None

    new_classes = set()        # class ids seen in new_loader this session
    last_loss = float("nan")   # reported even if new_loader yields no batches

    for ep in range(cfg.inc_epochs):

        epoch_bar = tqdm(
            new_loader,
            desc=f"[Inc S{session_id}] Epoch {ep+1}/{cfg.inc_epochs}",
            leave=True
        )

        for x_new, y_new in epoch_bar:
            x_new, y_new = x_new.to(device), y_new.to(device)
            new_classes.update(y_new.tolist())

            # IMPORTANT: allocate prototype slots for new classes
            model.proto_mem.ensure_size(int(y_new.max().item()) + 1, device=device)

            # ---- new-class forward (plastic) ----
            h_new = plastic(x_new)
            logits_new = model.proto_mem.logits_ncm(
                h_new, tau=cfg.tau, normalize=True
            )

            loss_cls = F.cross_entropy(logits_new, y_new)
            loss_proto = proto_alignment_loss(h_new, y_new, model.proto_mem)

            # ---- old-class forward (stable teacher + plastic student) ----
            loss_kd = torch.tensor(0.0, device=device)
            loss_dom = torch.tensor(0.0, device=device)

            if replay_iter is not None:
                try:
                    x_old, y_old = next(replay_iter)
                except StopIteration:
                    replay_iter = iter(replay_loader)
                    x_old, y_old = next(replay_iter)

                x_old, y_old = x_old.to(device), y_old.to(device)

                with torch.no_grad():
                    h_old_s = model.stable(x_old)
                    logits_old_s = model.proto_mem.logits_ncm(
                        h_old_s, tau=cfg.tau, normalize=True
                    )

                h_old_p = plastic(x_old)
                logits_old_p = model.proto_mem.logits_ncm(
                    h_old_p, tau=cfg.tau, normalize=True
                )

                loss_kd = kd_loss_logits(
                    logits_old_s, logits_old_p, T=cfg.kd_T
                )

                loss_dom = domain_shift_loss(h_old_s, h_new)

            # ---- total loss ----
            loss = (
                loss_cls
                + cfg.lam_proto * loss_proto
                + cfg.lam_kd * loss_kd
                + cfg.lam_domain * loss_dom
            )

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            # EMA merge plastic -> stable (ViT weights only)
            ema_update(
                model.stable.vit,
                plastic.vit,
                alpha=cfg.ema_alpha
            )

            # Update new-class prototypes (plastic features) used by the training losses
            with torch.no_grad():
                model.proto_mem.update_from_batch(
                    h_new.detach(),
                    y_new.detach()
                )

            last_loss = loss.item()
            epoch_bar.set_postfix(loss=f"{last_loss:.4f}")

        epoch_bar.close()

        print(
            f"[Inc S{session_id}] Epoch {ep+1}/{cfg.inc_epochs} completed | "
            f"last_loss={last_loss:.4f}"
        )

    # discard plastic/session prompt
    model.plastic = None
    model.session_prompt = None

    # Re-estimate the new-class prototypes in the stable feature space that predict() uses.
    if new_classes:
        with torch.no_grad():
            for c in new_classes:
                model.proto_mem.prototypes[c].zero_()
                model.proto_mem.counts[c] = 0.0
        model.build_prototypes(new_loader)

    save_checkpoint(
        model,
        path=f"checkpoints/session_{session_id}.ckpt",
        session_id=session_id,
        replay_buffer=replay_buffer
    )

    print(
        f"[Inc] session {session_id} done. "
        f"Total prototypes now: {model.proto_mem.num_classes}"
    )

    



In [None]:
# ---------------------------
# Main: CIFAR-100 FSCIL
# ---------------------------
def main_cifar100(cfg: FSCILConfig):
    """
    Run the full FSCIL protocol on CIFAR-100.

    Steps:
      1. Base training on the first ``cfg.base_classes`` classes.
      2. One few-shot session per group of ``cfg.ways_per_session`` classes,
         with class-balanced exemplar replay.
      3. Accuracy reports after each stage, recorded in the global ``acc_log``.
    """
    seed_all(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transforms (match cfg.img_size)
    tf_train = T.Compose([
        T.Resize((cfg.img_size, cfg.img_size)),
        T.RandomCrop(cfg.img_size, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    tf_test = T.Compose([
        T.Resize((cfg.img_size, cfg.img_size)),
        T.ToTensor(),
        T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    trainset = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=tf_train)
    testset  = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=tf_test)

    train_targets = list(trainset.targets)

    base_cls, inc_sessions = make_fscil_splits(
        targets=train_targets,
        base_classes=cfg.base_classes,
        ways_per_session=cfg.ways_per_session,
        shots_per_class=cfg.shots_per_class,
        seed=cfg.seed
    )
    # Number of classes seen once every incremental session has run.
    total_classes = cfg.base_classes + len(inc_sessions) * cfg.ways_per_session

    # Build loaders
    print(f"[Split] base classes ({len(base_cls)}): {base_cls}")
    base_idx = indices_for_classes(train_targets, base_cls)

    base_loader = DataLoader(Subset(trainset, base_idx), batch_size=cfg.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    # Test loader: we evaluate on the full test set (all classes).
    # NOTE(review): this loader is also passed to train_base as its validation
    # loader (early stopping / best-epoch selection), so test data influences
    # model selection. Use a held-out split of the training data for a clean
    # protocol.
    test_loader = DataLoader(testset, batch_size=cfg.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Replay buffer
    replay = ReplayBuffer(max_per_class=cfg.replay_max_per_class, seed=cfg.seed)

    # System
    model = FSCILSystem(cfg, device=device).to(device)

    print("\n=== BASE TRAINING ===")
    train_base(model, base_loader, test_loader)

    # Class-balanced base exemplars: up to replay_max_per_class per base class.
    # A uniform sample over all base indices gives uneven per-class counts.
    base_exemplars = fewshot_indices_for_classes(
        train_targets, base_cls, cfg.replay_max_per_class, seed=cfg.seed
    )
    replay.add_indices(train_targets, base_exemplars)

    acc0 = accuracy(model, test_loader)
    print(f"[Eval] After base: Acc(all test classes) = {acc0*100:.2f}%")

    # Incremental sessions
    for s, sess_classes in enumerate(inc_sessions, start=1):
        print(f"\n=== INCREMENTAL SESSION {s}/{len(inc_sessions)} | classes={sess_classes} ===")

        # new few-shot indices (shots per class)
        new_idx = fewshot_indices_for_classes(train_targets, sess_classes, cfg.shots_per_class, seed=cfg.seed + s)
        new_loader = DataLoader(Subset(trainset, new_idx), batch_size=min(cfg.batch_size, len(new_idx)), shuffle=True, num_workers=2, pin_memory=True)

        # replay loader from buffer (old classes)
        replay_idx = replay.all_indices()
        replay_loader = None
        if len(replay_idx) > 0:
            replay_loader = DataLoader(Subset(trainset, replay_idx), batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)

        # Train session
        train_incremental(model, new_loader, replay_loader, session_id=s, replay_buffer=replay)

        # Update replay with new indices
        replay.add_indices(train_targets, new_idx)

        acc = accuracy(model, test_loader)
        acc_log["sessions"].append(report_session(model, test_loader, s))
        print(f"[Eval] After session {s}: Acc(all test classes) = {acc*100:.2f}%")

    acc_log["final"] = report_final(model, test_loader, total_classes=total_classes)

    print("\nDone.")

In [None]:
# Script entry point: run the CIFAR-100 FSCIL experiment with a ViT-S/16 backbone.
if __name__ == "__main__":
    cfg = FSCILConfig(
        vit_name="vit_small_patch16_224",
        pretrained=True,

        img_size=224,
        patch_size=16,   # fixed by the model name (/16); informational only
        feat_dim=384,    # must equal vit_small_patch16_224's embed_dim (FSCILSystem checks this)
        num_prompts=10,

        # NOTE(review): train_base's default patience is 35, so with 35 base
        # epochs early stopping can never trigger (best-epoch restore still applies).
        base_epochs=35,
        inc_epochs=20,
        lr_base=1e-4,
        lr_inc=5e-5,
        weight_decay=0.05,

        base_classes=60,
        ways_per_session=5,
        shots_per_class=5,
        replay_max_per_class=20,
        batch_size=64,   # lower this if the GPU runs out of memory

        tau=1.0,
        kd_T=2.0,
        ema_alpha=0.999,
        lam_proto=1.0,
        lam_kd=1.0,
        lam_domain=0.5,

        seed=0
    )

    main_cifar100(cfg)


In [None]:
# Display the accuracy log (shown as the notebook cell's output).
acc_log

In [None]:
# Persist the accuracy log as "key: value" lines. Iterating the dict directly
# would write only the keys ("base", "sessions", "final") and lose every accuracy.
with open("acc_log.txt", "w", encoding="utf-8") as f:
    for key, value in acc_log.items():
        f.write(f"{key}: {value}\n")
