In [None]:
# fscil_adsp_cfr_full.py
# Fully functional FSCIL (ADSP-CFR style) reference implementation.
# Runs end-to-end on CIFAR-100 by default.
#
# Key modules:
# - Plastic stream (trainable backbone + session prompt)
# - Stable stream (EMA backbone + base prompt) as teacher/memory
# - Prototype Memory + NCM classifier
# - Losses: CE + proto + KD + domain
#
# NOTE: This is a research-grade reference implementation (not SOTA optimized).
# It is designed to be clear, correct, and runnable.

import os
import math
import copy
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset

import torchvision
import torchvision.transforms as T
from tqdm import tqdm


In [None]:
# ---------------------------
# Reproducibility
# ---------------------------
def seed_all(seed: int = 0):
    """Seed Python's `random` module and PyTorch (CPU generator and all CUDA devices)."""
    # Order matches the original: Python RNG, torch CPU, then every CUDA device.
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


In [None]:
# ---------------------------
# Helpers
# ---------------------------
def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Divide `x` by its L2 norm along `dim`; `eps` is added to the norm so zero vectors do not divide by zero."""
    length = x.norm(p=2, dim=dim, keepdim=True)
    return x / (length + eps)


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float):
    """
    In-place exponential moving average of parameters:
      teacher = alpha * teacher + (1 - alpha) * student

    Parameters are matched by name; a student parameter name that the teacher
    does not have raises KeyError. Only parameters are touched (not buffers).
    """
    by_name = dict(teacher.named_parameters())
    student_weight = 1.0 - alpha
    for name, source in student.named_parameters():
        target = by_name[name]
        target.data.mul_(alpha)
        target.data.add_(source.data, alpha=student_weight)


def set_requires_grad(module: nn.Module, flag: bool):
    """Set `requires_grad` to `flag` on every parameter of `module` (recursively)."""
    for param in module.parameters():
        param.requires_grad_(flag)


In [None]:
# ---------------------------
# Dataset: MiniImageNet (optional)
# Folder structure expected:
# root/
#   train/class_x/xxx.png ...
#   test/class_x/xxx.png ...
#   val/class_x/xxx.png ...
# ---------------------------
class MiniImageNetFolder(torchvision.datasets.ImageFolder):
    """ImageFolder rooted at a single split directory (e.g. root/train) with one sub-folder per class."""

    def __init__(self, root_split: str, transform=None):
        # `root_split` is the split directory itself, not the dataset root.
        super().__init__(root_split, transform=transform)




In [None]:
# ---------------------------
# Split FSCIL sessions
# ---------------------------
def make_fscil_splits(
    targets: List[int],
    base_classes: int = 60,
    ways_per_session: int = 5,
    shots_per_class: int = 5,
    seed: int = 0,
) -> Tuple[List[int], List[List[int]]]:
    """
    Randomly partition the class ids found in `targets` into a base session
    and equally sized incremental sessions.

    Args:
      targets: per-sample class ids; only the set of distinct ids is used.
      base_classes: number of classes assigned to the base session.
      ways_per_session: number of classes per incremental session.
      shots_per_class: accepted for interface compatibility; not used here
        (few-shot sampling is done by `fewshot_indices_for_classes`).
      seed: seed of the class permutation.

    Returns:
      base_class_ids: list of class ids in base session
      inc_sessions: list of sessions, each session is list of class ids.
        Only full sessions are returned; leftover classes that do not fill
        a whole session are dropped.
    """
    classes = sorted(set(targets))
    assert len(classes) >= base_classes, "Not enough classes."

    # A private generator yields the same permutation as seeding the global
    # `random` module with `seed`, but leaves the global Python / torch RNG
    # state of the caller untouched.
    random.Random(seed).shuffle(classes)
    base_cls = classes[:base_classes]
    remaining = classes[base_classes:]

    # Start offsets of every *complete* chunk of `ways_per_session` classes.
    full_starts = range(0, len(remaining) - ways_per_session + 1, ways_per_session)
    inc_sessions = [remaining[i:i + ways_per_session] for i in full_starts]

    return base_cls, inc_sessions


def indices_for_classes(targets: List[int], class_ids: List[int]) -> List[int]:
    """Return, in ascending order, the positions in `targets` whose label is one of `class_ids`."""
    wanted = set(class_ids)
    selected = []
    for position, label in enumerate(targets):
        if label in wanted:
            selected.append(position)
    return selected


def fewshot_indices_for_classes(
    targets: List[int],
    class_ids: List[int],
    shots_per_class: int,
    seed: int = 0
) -> List[int]:
    """
    Pick up to `shots_per_class` random sample indices for each class.

    Args:
      targets: per-sample class ids (position == dataset index).
      class_ids: classes to sample, processed in the given order.
      shots_per_class: maximum number of indices taken per class; a class
        with fewer samples contributes all of them.
      seed: seed of the sampling.

    Returns:
      Flat list of dataset indices, grouped by class in `class_ids` order.
    """
    # Private generator: same draws as seeding the global `random` module with
    # `seed`, without resetting the caller's global Python / torch RNG state.
    rng = random.Random(seed)

    # Group indices per requested class in a single pass over `targets`
    # (instead of rescanning the whole list once per class).
    wanted = set(class_ids)
    by_class: Dict[int, List[int]] = {c: [] for c in wanted}
    for i, y in enumerate(targets):
        if y in wanted:
            by_class[y].append(i)

    idxs: List[int] = []
    for c in class_ids:
        pool = list(by_class[c])  # copy: a repeated class id starts from the unshuffled pool
        rng.shuffle(pool)
        idxs.extend(pool[:shots_per_class])
    return idxs


In [None]:
# ---------------------------
# Memory buffer for replay
# ---------------------------
class ReplayBuffer:
    """
    Per-class store of dataset indices used as replay exemplars.

    `storage` maps class id -> list of dataset indices and holds at most
    `max_per_class` indices per class after every `add_indices` call.
    """

    def __init__(self, max_per_class: int = 20, seed: int = 0):
        self.max_per_class = max_per_class
        self.seed = seed
        self.storage: Dict[int, List[int]] = {}  # class_id -> list of dataset indices

    def add_indices(self, targets: List[int], new_indices: List[int]):
        """
        Add `new_indices` (positions into `targets`) to their class buckets,
        then randomly trim every bucket to `max_per_class` entries.

        An index that is already stored for its class is not stored again.
        """
        known: Dict[int, set] = {}  # per-class membership sets, built lazily
        for idx in new_indices:
            c = targets[idx]
            bucket = self.storage.setdefault(c, [])
            members = known.get(c)
            if members is None:
                members = known[c] = set(bucket)
            if idx not in members:
                members.add(idx)
                bucket.append(idx)

        # Trim with a private generator seeded by `self.seed`, so updating the
        # buffer does not reset the global Python / torch RNG state.
        rng = random.Random(self.seed)
        for bucket in self.storage.values():
            rng.shuffle(bucket)
            del bucket[self.max_per_class:]

    def all_indices(self) -> List[int]:
        """Return every stored index as one flat list (classes in insertion order)."""
        return [idx for bucket in self.storage.values() for idx in bucket]

In [None]:
# ---------------------------
# Prompt module (learnable tokens)
# ---------------------------
class PromptModule(nn.Module):
    """
    Learnable prompt tokens.

    prompts: [m, D] parameter, initialised from a normal distribution scaled by `init_std`.
    """

    def __init__(self, num_prompts: int, dim: int, init_std: float = 0.02):
        super().__init__()
        initial = torch.randn(num_prompts, dim) * init_std
        self.prompts = nn.Parameter(initial)

    def forward(self, B: int) -> torch.Tensor:
        """Return the prompts broadcast over a batch of size `B`: [B, m, D] (an expanded view, no copy)."""
        m, d = self.prompts.shape
        return self.prompts[None].expand(B, m, d)

In [None]:
# ---------------------------
# A small self-contained ViT built on torch.nn only (no external model libraries).
# Prompt tokens are inserted between the CLS token and the patch tokens:
# [CLS | PROMPTS | PATCHES]
# ---------------------------
class PatchEmbed(nn.Module):
    """Split a square image into non-overlapping patches and project each patch to `embed_dim`."""

    def __init__(self, img_size: int, patch_size: int, in_chans: int, embed_dim: int):
        super().__init__()
        assert img_size % patch_size == 0
        per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid = per_side
        self.num_patches = per_side * per_side
        # A conv with kernel == stride == patch_size embeds every patch independently.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B,C,H,W] -> conv -> [B,D,grid,grid] -> flatten spatial dims -> [B,D,N] -> [B,N,D]
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)


class SimpleViT(nn.Module):
    """
    Minimal ViT encoder: patch embedding, CLS token, pre-norm Transformer
    encoder stack and a final LayerNorm.

    Positional embeddings are NOT stored here: the token count is 1 + M + N
    and M depends on the number of prompts, so the caller passes the
    positional embedding to `forward_tokens`.
    """

    def __init__(
        self,
        img_size: int = 32,
        patch_size: int = 4,
        in_chans: int = 3,
        embed_dim: int = 256,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        hidden_dim = int(embed_dim * mlp_ratio)
        block = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(block, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)

        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward_tokens(self, x_tokens: torch.Tensor, pos_embed: torch.Tensor) -> torch.Tensor:
        """Add `pos_embed` to `x_tokens` ([B, 1+M+N, D]), then apply dropout, the encoder and the final norm."""
        embedded = self.pos_drop(x_tokens + pos_embed)
        encoded = self.encoder(embedded)
        return self.norm(encoded)

    def patch_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Return the patch embeddings of image batch `x`."""
        return self.patch(x)

In [None]:
class PromptedViT(nn.Module):
    """
    ViT wrapper that injects prompt tokens into the token sequence:
      tokens = [CLS | PROMPTS | PATCHES]

    Positional embeddings are owned by this wrapper (not by `vit`):
      pos_embed_base:   [1, 1+N, D] for the CLS token and the N patch tokens
      pos_embed_prompt: [1, M, D]   for the M prompt tokens

    `forward` returns the CLS feature [B, D].
    """

    def __init__(self, vit: "SimpleViT", prompt: "PromptModule", img_size: int, patch_size: int):
        super().__init__()
        self.vit = vit
        self.prompt = prompt
        self.img_size = img_size
        self.patch_size = patch_size

        N = vit.patch.num_patches
        self.pos_embed_base = nn.Parameter(torch.zeros(1, 1 + N, vit.embed_dim))
        nn.init.trunc_normal_(self.pos_embed_base, std=0.02)

        # Register the prompt positional embedding NOW rather than on the first
        # forward pass. A parameter created inside forward() appears only after
        # callers have already built their optimizer / frozen the module /
        # moved it to a device, so it would silently be left out of all three.
        prompt_tokens = getattr(prompt, "prompts", None)
        if prompt_tokens is not None and prompt_tokens.shape[0] > 0:
            self.pos_embed_prompt = self._new_prompt_pos_embed(int(prompt_tokens.shape[0]))

    def _new_prompt_pos_embed(self, M: int) -> nn.Parameter:
        """Create a freshly initialised [1, M, D] parameter on the device/dtype of `pos_embed_base`."""
        base = self.pos_embed_base
        param = nn.Parameter(torch.zeros(1, M, base.shape[2], device=base.device, dtype=base.dtype))
        nn.init.trunc_normal_(param, std=0.02)
        return param

    def _build_pos_embed(self, total_tokens: int) -> torch.Tensor:
        """
        Return the positional embedding for a sequence of `total_tokens`
        tokens, laid out as [CLS | PROMPTS | PATCHES].

        Raises:
          ValueError: if `total_tokens` is smaller than 1 + N.
        """
        base = self.pos_embed_base                      # [1, 1+N, D]
        base_len = base.shape[1]

        if total_tokens == base_len:
            return base

        # total_tokens = 1 + M + N, base_len = 1 + N  => M = total_tokens - base_len
        M = total_tokens - base_len
        if M < 0:
            raise ValueError(
                f"total_tokens={total_tokens} is smaller than the base length {base_len}"
            )

        prompt_pos = getattr(self, "pos_embed_prompt", None)
        if prompt_pos is None or prompt_pos.shape[1] != M:
            # Fallback for a prompt count that differs from the one seen at
            # construction time. NOTE: a parameter created here is not known to
            # optimizers that were built before this call.
            prompt_pos = self._new_prompt_pos_embed(M)
            self.pos_embed_prompt = prompt_pos
        elif prompt_pos.device != base.device or prompt_pos.dtype != base.dtype:
            # Keep it aligned with the base embedding, preserving its trainable flag.
            prompt_pos = nn.Parameter(
                prompt_pos.data.to(device=base.device, dtype=base.dtype),
                requires_grad=prompt_pos.requires_grad,
            )
            self.pos_embed_prompt = prompt_pos

        cls_pos = base[:, :1, :]         # [1,1,D]
        patch_pos = base[:, 1:, :]       # [1,N,D]
        return torch.cat([cls_pos, prompt_pos, patch_pos], dim=1)  # [1,1+M+N,D]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B = x.shape[0]
        patches = self.vit.patch_tokens(x)              # [B,N,D]
        cls = self.vit.cls_token.expand(B, -1, -1)      # [B,1,D]
        prompts = self.prompt(B)                        # [B,M,D]

        tokens = torch.cat([cls, prompts, patches], dim=1)  # [B,1+M+N,D]
        pos = self._build_pos_embed(tokens.shape[1])
        out = self.vit.forward_tokens(tokens, pos)          # [B,1+M+N,D]
        return out[:, 0, :]                                 # CLS feature [B,D]


In [None]:
# ---------------------------
# Prototype memory + NCM logits
# ---------------------------
class PrototypeMemory(nn.Module):
    """
    Running-mean class prototypes with a nearest-class-mean (NCM) scorer.

    Buffers:
      prototypes: [C, feat_dim] mean feature per class id (row index == class id)
      counts:     [C] number of features averaged into each prototype

    C grows on demand to `max label + 1`. Class ids that were never updated
    keep an all-zero prototype and a count of 0.
    """

    def __init__(self, feat_dim: int):
        super().__init__()
        self.feat_dim = feat_dim
        self.register_buffer("prototypes", torch.empty(0, feat_dim))
        self.register_buffer("counts", torch.empty(0))

    @property
    def num_classes(self) -> int:
        """Number of prototype rows (populated or not)."""
        return int(self.prototypes.shape[0])

    def ensure_size(self, C: int, device: torch.device):
        """Grow both buffers with zero rows so that at least `C` classes fit; never shrinks."""
        missing = C - self.num_classes
        if missing <= 0:
            return
        pad_proto = torch.zeros(missing, self.feat_dim, device=device)
        pad_cnt = torch.zeros(missing, device=device)
        self.prototypes = torch.cat([self.prototypes.to(device), pad_proto], dim=0)
        self.counts = torch.cat([self.counts.to(device), pad_cnt], dim=0)

    @torch.no_grad()
    def update_from_batch(self, feats: torch.Tensor, labels: torch.Tensor):
        """
        Fold a batch of features into the running means of their classes.

        Args:
          feats: [B, feat_dim] features.
          labels: [B] integer class ids. An empty batch is a no-op.
        """
        if labels.numel() == 0:
            # labels.max() is undefined on an empty tensor.
            return

        self.ensure_size(int(labels.max().item()) + 1, feats.device)

        # One host sync for the whole batch instead of one .item() per class.
        for c in labels.unique().tolist():
            f_c = feats[labels == c]
            k = float(f_c.shape[0])
            n_old = self.counts[c].clamp(min=0.0)
            # Incremental mean: (n_old * mu_old + sum of new feats) / (n_old + k)
            self.prototypes[c] = (n_old * self.prototypes[c] + f_c.sum(dim=0)) / (n_old + k)
            self.counts[c] = n_old + k

    def logits_ncm(
        self,
        feats: torch.Tensor,
        tau: float = 1.0,
        normalize: bool = True,
        mask_empty: bool = False,
    ) -> torch.Tensor:
        """
        Score features against all prototypes: logits = -||h - mu_c||^2 / tau.

        Args:
          feats: [B, feat_dim] features.
          tau: temperature dividing the squared distances.
          normalize: L2-normalise features and prototypes first.
          mask_empty: if True, classes whose count is 0 (never updated) get a
            logit of -inf so they cannot be selected. Default False keeps
            every row in play.

        Returns:
          [B, C] logits.

        Raises:
          RuntimeError: if the memory holds no prototypes.
        """
        if self.num_classes == 0:
            raise RuntimeError("Empty prototype memory.")
        proto = self.prototypes.to(feats.device)
        h = feats
        if normalize:
            proto = l2_normalize(proto, dim=-1)
            h = l2_normalize(h, dim=-1)
        logits = -(torch.cdist(h, proto, p=2) ** 2) / tau
        if mask_empty:
            empty = self.counts.to(feats.device) <= 0
            logits = logits.masked_fill(empty.unsqueeze(0), float("-inf"))
        return logits

In [None]:
# ---------------------------
# Losses
# ---------------------------
def proto_alignment_loss(feats: torch.Tensor, labels: torch.Tensor, proto_mem: PrototypeMemory) -> torch.Tensor:
    """
    Mean squared distance between each L2-normalised feature and the
    L2-normalised prototype of its label.
    """
    # Look up the prototype row of every sample's class on the features' device.
    targets = proto_mem.prototypes.to(feats.device)[labels]
    gap = l2_normalize(feats, dim=-1) - l2_normalize(targets, dim=-1)
    return gap.pow(2).sum(dim=-1).mean()


def kd_loss_logits(logits_teacher: torch.Tensor, logits_student: torch.Tensor, T: float = 2.0) -> torch.Tensor:
    """
    Knowledge-distillation loss: KL(teacher || student) between the
    temperature-softened distributions, averaged over the batch and scaled by T^2.
    """
    teacher_probs = F.softmax(logits_teacher / T, dim=-1)
    student_log_probs = F.log_softmax(logits_student / T, dim=-1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return divergence * (T * T)


def domain_shift_loss(feats_old: torch.Tensor, feats_new: torch.Tensor) -> torch.Tensor:
    """Squared distance between the L2-normalised batch-mean features of the two inputs."""
    centroid_old = feats_old.mean(dim=0)
    centroid_new = feats_new.mean(dim=0)
    gap = l2_normalize(centroid_old, dim=-1) - l2_normalize(centroid_new, dim=-1)
    return gap.pow(2).sum()

In [None]:
# ---------------------------
# System wrapper
# ---------------------------
@dataclass
class FSCILConfig:
    """Hyper-parameters of the FSCIL pipeline: model, optimisation, session split and loss weights."""

    # --- model ---
    img_size: int = 32            # square input resolution
    patch_size: int = 4           # ViT patch side length
    feat_dim: int = 256           # embedding / feature dimension
    depth: int = 6                # number of encoder layers
    heads: int = 8                # attention heads per layer
    num_prompts: int = 10         # prompt tokens per prompt module

    # --- optimisation ---
    base_epochs: int = 5          # epochs of base-session training
    inc_epochs: int = 5           # epochs per incremental session
    lr_base: float = 3e-4         # AdamW learning rate, base session
    lr_inc: float = 3e-4          # AdamW learning rate, incremental sessions
    weight_decay: float = 0.05    # AdamW weight decay

    # --- FSCIL split ---
    base_classes: int = 60        # classes in the base session
    ways_per_session: int = 5     # new classes per incremental session
    shots_per_class: int = 5      # training samples per new class
    replay_max_per_class: int = 20  # replay-buffer quota per class
    batch_size: int = 64          # batch size of every DataLoader

    # --- losses ---
    tau: float = 1.0              # temperature of the NCM logits
    kd_T: float = 2.0             # distillation temperature
    ema_alpha: float = 0.999      # EMA factor kept by the stable stream
    lam_proto: float = 1.0        # weight of the prototype-alignment loss
    lam_kd: float = 1.0           # weight of the distillation loss
    lam_domain: float = 0.5       # weight of the domain-shift loss

    seed: int = 0                 # global random seed

In [None]:
class FSCILSystem(nn.Module):
    """
    Container for the persistent parts of the FSCIL model:
      - `base_prompt`: prompt shared with the stable stream
      - `stable`: frozen PromptedViT used as teacher and for inference
      - `proto_mem`: class prototypes used by the NCM classifier
    plus the per-session `plastic` model / `session_prompt`, which are
    created by `new_plastic_for_session`.
    """

    def __init__(self, cfg: FSCILConfig, device: torch.device):
        super().__init__()
        self.cfg = cfg
        self.device = device

        # base prompt (persistent)
        self.base_prompt = PromptModule(cfg.num_prompts, cfg.feat_dim).to(device)

        # template ViT
        vit_template = SimpleViT(
            img_size=cfg.img_size,
            patch_size=cfg.patch_size,
            embed_dim=cfg.feat_dim,
            depth=cfg.depth,
            num_heads=cfg.heads,
        ).to(device)

        # stable = EMA teacher (persistent)
        self.stable = PromptedViT(copy.deepcopy(vit_template), self.base_prompt, cfg.img_size, cfg.patch_size).to(device)
        set_requires_grad(self.stable, False)

        # plastic/session prompt are created per session (temporary)
        self.plastic: Optional[PromptedViT] = None
        self.session_prompt: Optional[PromptModule] = None

        # prototype memory (persistent)
        self.proto_mem = PrototypeMemory(cfg.feat_dim).to(device)

        # keep vit template for plastic recreation
        self._vit_template = vit_template

    def new_plastic_for_session(self):
        """
        Create a trainable plastic model for a new session, initialised from
        the stable stream: ViT weights, positional embeddings, and a session
        prompt equal to the base prompt plus small Gaussian noise.
        """
        cfg = self.cfg
        # plastic vit initialized from stable vit weights
        vit = copy.deepcopy(self._vit_template).to(self.device)
        vit.load_state_dict(self.stable.vit.state_dict(), strict=True)

        # session prompt initialized from base prompt + small noise
        P_base = self.base_prompt.prompts.detach()
        sp = PromptModule(cfg.num_prompts, cfg.feat_dim).to(self.device)
        with torch.no_grad():
            sp.prompts.copy_(P_base + 0.01 * torch.randn_like(P_base))

        plastic = PromptedViT(vit, sp, cfg.img_size, cfg.patch_size).to(self.device)

        # The positional embeddings live on the PromptedViT wrapper, not inside
        # `vit`, so load_state_dict above does not transfer them. Without this
        # copy the plastic model would pair the stable's transformer weights
        # with freshly random positional embeddings in every session.
        with torch.no_grad():
            plastic.pos_embed_base.copy_(self.stable.pos_embed_base)
        stable_prompt_pos = getattr(self.stable, "pos_embed_prompt", None)
        if stable_prompt_pos is not None:
            plastic.pos_embed_prompt = nn.Parameter(
                stable_prompt_pos.detach().clone().to(self.device)
            )

        set_requires_grad(plastic, True)
        self.session_prompt = sp
        self.plastic = plastic

    @torch.no_grad()
    def build_prototypes(self, loader: DataLoader):
        """Run the stable stream over `loader` and fold its features into the prototype memory."""
        self.stable.eval()
        for x, y in loader:
            x, y = x.to(self.device), y.to(self.device)
            h = self.stable(x)
            self.proto_mem.update_from_batch(h, y)

    @torch.no_grad()
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict class ids for `x` with the stable stream and the NCM classifier.

        Prototype rows that were never populated (count 0) are excluded: the
        memory is indexed by class id, so unseen ids leave all-zero rows that
        would otherwise compete with the real prototypes.
        """
        self.stable.eval()
        h = self.stable(x.to(self.device))
        logits = self.proto_mem.logits_ncm(h, tau=self.cfg.tau, normalize=True)
        unseen = self.proto_mem.counts.to(logits.device) <= 0
        logits = logits.masked_fill(unseen.unsqueeze(0), float("-inf"))
        return logits.argmax(dim=-1)

In [None]:
# ---------------------------
# Training / Evaluation
# ---------------------------
def accuracy(model: FSCILSystem, loader: DataLoader) -> float:
    """Fraction of samples in `loader` for which `model.predict` equals the label (0.0 for an empty loader)."""
    n_correct = 0
    n_seen = 0
    for x, y in loader:
        x = x.to(model.device)
        y = y.to(model.device)
        hits = model.predict(x) == y
        n_correct += hits.sum().item()
        n_seen += y.numel()
    # max(..., 1) avoids a division by zero when the loader yields nothing.
    return n_correct / max(n_seen, 1)

In [None]:
# ---------------------------
# Training Base (with tqdm)
# ---------------------------
def _grow_linear_head(old_head: Optional[nn.Linear], in_features: int, out_features: int, device: torch.device) -> nn.Linear:
    """Return a Linear(in_features, out_features) that keeps the already-learned rows of `old_head`."""
    new_head = nn.Linear(in_features, out_features).to(device)
    if old_head is not None:
        kept = old_head.out_features
        with torch.no_grad():
            new_head.weight[:kept].copy_(old_head.weight)
            new_head.bias[:kept].copy_(old_head.bias)
    return new_head


def _ema_wrapper_params(stable: nn.Module, plastic: nn.Module, alpha: float):
    """EMA the parameters owned directly by the PromptedViT wrappers (the positional embeddings)."""
    with torch.no_grad():
        stable_params = dict(stable.named_parameters(recurse=False))
        for name, p_src in plastic.named_parameters(recurse=False):
            p_dst = stable_params.get(name)
            if p_dst is None or p_dst.shape != p_src.shape:
                # The stable stream has no matching parameter yet: adopt a frozen copy.
                setattr(stable, name, nn.Parameter(p_src.detach().clone(), requires_grad=False))
            else:
                p_dst.mul_(alpha).add_(p_src.detach(), alpha=1.0 - alpha)


def train_base(model: FSCILSystem, base_loader: DataLoader):
    """
    Base-session training.

    A temporary plastic PromptedViT (sharing the base prompt with the stable
    stream) is trained with cross-entropy through a temporary linear head.
    After every step the stable stream is EMA-updated from the plastic one.
    Finally the prototype memory is reset and rebuilt from `base_loader`
    using the stable stream.
    """
    cfg = model.cfg
    device = model.device

    # train a temporary plastic that uses BASE prompt (same as stable prompt)
    plastic = PromptedViT(
        copy.deepcopy(model._vit_template),
        model.base_prompt,
        cfg.img_size,
        cfg.patch_size
    ).to(device)

    # Start the wrapper-level positional embeddings from the stable stream's
    # values so both streams share an origin for the EMA below.
    with torch.no_grad():
        stable_top = dict(model.stable.named_parameters(recurse=False))
        for name, p in plastic.named_parameters(recurse=False):
            src = stable_top.get(name)
            if src is not None and src.shape == p.shape:
                p.copy_(src)

    set_requires_grad(plastic, True)

    # stable is EMA-updated only
    model.stable.eval()

    opt = torch.optim.AdamW(
        plastic.parameters(),
        lr=cfg.lr_base,
        weight_decay=cfg.weight_decay
    )

    # Temporary linear head for base training only. It is local to this call
    # (no state leaks between calls) and, when a batch brings a larger class
    # id, it is widened while keeping the rows learned so far.
    head: Optional[nn.Linear] = None
    head_opt: Optional[torch.optim.Optimizer] = None

    plastic.train()
    for ep in range(cfg.base_epochs):

        epoch_bar = tqdm(
            base_loader,
            desc=f"[Base] Epoch {ep+1}/{cfg.base_epochs}",
            leave=True
        )

        for x, y in epoch_bar:
            x, y = x.to(device), y.to(device)
            h = plastic(x)  # [B,D]

            needed = int(y.max().item()) + 1
            if head is None or head.out_features < needed:
                head = _grow_linear_head(head, cfg.feat_dim, needed, device)
                head_opt = torch.optim.AdamW(
                    head.parameters(),
                    lr=cfg.lr_base,
                    weight_decay=cfg.weight_decay
                )

            logits = head(l2_normalize(h))
            loss = F.cross_entropy(logits, y)

            opt.zero_grad(set_to_none=True)
            head_opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            head_opt.step()

            # EMA: stable <- plastic. The positional embeddings are owned by
            # the PromptedViT wrapper rather than by `.vit`, so they need their
            # own update; otherwise the stable stream would keep its untrained
            # random positional embeddings.
            ema_update(model.stable.vit, plastic.vit, alpha=cfg.ema_alpha)
            _ema_wrapper_params(model.stable, plastic, cfg.ema_alpha)

            # update tqdm bar
            epoch_bar.set_postfix(loss=f"{loss.item():.4f}")

        epoch_bar.close()

    # Build prototypes using stable (and base prompt)
    model.proto_mem.prototypes = model.proto_mem.prototypes[:0]
    model.proto_mem.counts = model.proto_mem.counts[:0]
    model.build_prototypes(base_loader)

    print(f"[Base] prototypes built: {model.proto_mem.num_classes}")


In [None]:
# ---------------------------
# Training Incremental
# ---------------------------

def train_incremental(
    model: FSCILSystem,
    new_loader: DataLoader,
    replay_loader: Optional[DataLoader],
):
    """
    Train one incremental session.

    A fresh plastic model is trained on the few-shot `new_loader` with
    NCM cross-entropy + prototype alignment and, when `replay_loader` is
    given, distillation from the stable stream on replayed samples plus a
    domain-shift term. After every step the stable ViT is EMA-updated from
    the plastic ViT and the prototypes of the batch's classes are updated
    with the plastic features. The plastic model is discarded at the end.
    """
    cfg = model.cfg
    device = model.device

    model.new_plastic_for_session()
    assert model.plastic is not None
    plastic = model.plastic

    # Seed the prototypes of this session's classes from the stable stream
    # before training. This (a) grows the prototype memory to cover the new
    # class ids, which the NCM cross-entropy and the prototype lookup below
    # index into, and (b) gives the first steps a meaningful target instead of
    # all-zero prototype rows.
    model.build_prototypes(new_loader)

    model.stable.eval()
    plastic.train()

    opt = torch.optim.AdamW(plastic.parameters(), lr=cfg.lr_inc, weight_decay=cfg.weight_decay)

    replay_iter = iter(replay_loader) if replay_loader is not None else None

    for ep in range(cfg.inc_epochs):
        loss = None  # stays None if new_loader yields no batch
        for x_new, y_new in new_loader:
            x_new, y_new = x_new.to(device), y_new.to(device)

            # ---- new-class forward (plastic) ----
            h_new = plastic(x_new)
            logits_new = model.proto_mem.logits_ncm(h_new, tau=cfg.tau, normalize=True)
            loss_cls = F.cross_entropy(logits_new, y_new)
            loss_proto = proto_alignment_loss(h_new, y_new, model.proto_mem)

            # ---- old-class forward (stable teacher + plastic student) ----
            loss_kd = torch.tensor(0.0, device=device)
            loss_dom = torch.tensor(0.0, device=device)

            if replay_iter is not None:
                try:
                    x_old, y_old = next(replay_iter)
                except StopIteration:
                    # replay loader exhausted: restart it
                    replay_iter = iter(replay_loader)
                    x_old, y_old = next(replay_iter)

                x_old, y_old = x_old.to(device), y_old.to(device)

                with torch.no_grad():
                    h_old_s = model.stable(x_old)
                    logits_old_s = model.proto_mem.logits_ncm(h_old_s, tau=cfg.tau, normalize=True)

                h_old_p = plastic(x_old)
                logits_old_p = model.proto_mem.logits_ncm(h_old_p, tau=cfg.tau, normalize=True)

                loss_kd = kd_loss_logits(logits_old_s, logits_old_p, T=cfg.kd_T)
                loss_dom = domain_shift_loss(h_old_s, h_new)

            loss = loss_cls + cfg.lam_proto * loss_proto + cfg.lam_kd * loss_kd + cfg.lam_domain * loss_dom

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            # EMA merge plastic->stable (vit weights only)
            ema_update(model.stable.vit, plastic.vit, alpha=cfg.ema_alpha)

            # Update prototypes for new classes
            with torch.no_grad():
                model.proto_mem.update_from_batch(h_new.detach(), y_new.detach())

        if loss is None:
            print(f"[Inc] epoch {ep+1}/{cfg.inc_epochs} | no batches")
        else:
            print(f"[Inc] epoch {ep+1}/{cfg.inc_epochs} | loss={loss.item():.4f}")

    # discard plastic/session prompt
    model.plastic = None
    model.session_prompt = None
    print(f"[Inc] session done. prototypes now: {model.proto_mem.num_classes}")



In [None]:
# ---------------------------
# Main: CIFAR-100 FSCIL
# ---------------------------
def main_cifar100(cfg: FSCILConfig):
    """
    Run the full FSCIL pipeline on CIFAR-100 (downloaded to ./data):
    base training, then one incremental few-shot session per class group,
    printing the accuracy on the full test set after each stage.
    """
    seed_all(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # transforms (match cfg.img_size)
    tf_train = T.Compose([
        T.Resize((cfg.img_size, cfg.img_size)),
        T.RandomCrop(cfg.img_size, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    tf_test = T.Compose([
        T.Resize((cfg.img_size, cfg.img_size)),
        T.ToTensor(),
        T.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    trainset = torchvision.datasets.CIFAR100(root="./data", train=True, download=True, transform=tf_train)
    testset  = torchvision.datasets.CIFAR100(root="./data", train=False, download=True, transform=tf_test)

    train_targets = list(trainset.targets)

    base_cls, inc_sessions = make_fscil_splits(
        targets=train_targets,
        base_classes=cfg.base_classes,
        ways_per_session=cfg.ways_per_session,
        shots_per_class=cfg.shots_per_class,
        seed=cfg.seed
    )

    # Build loaders
    base_idx = indices_for_classes(train_targets, base_cls)
    base_loader = DataLoader(Subset(trainset, base_idx), batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)

    # Test loader: we evaluate on full test set (all classes)
    test_loader = DataLoader(testset, batch_size=cfg.batch_size, shuffle=False, num_workers=2, pin_memory=True)

    # Replay buffer
    replay = ReplayBuffer(max_per_class=cfg.replay_max_per_class, seed=cfg.seed)

    # System
    model = FSCILSystem(cfg, device=device).to(device)

    print("\n=== BASE TRAINING ===")
    train_base(model, base_loader)

    # Add base exemplars to replay. Hand the buffer every base index and let
    # it keep a random `replay_max_per_class` per class: a global random
    # sample of the base set is not class-balanced and leaves some classes
    # below their quota.
    replay.add_indices(train_targets, base_idx)

    acc0 = accuracy(model, test_loader)
    print(f"[Eval] After base: Acc(all test classes) = {acc0*100:.2f}%")

    # Incremental sessions
    for s, sess_classes in enumerate(inc_sessions, start=1):
        print(f"\n=== INCREMENTAL SESSION {s}/{len(inc_sessions)} | classes={sess_classes} ===")

        # new few-shot indices (shots per class)
        new_idx = fewshot_indices_for_classes(train_targets, sess_classes, cfg.shots_per_class, seed=cfg.seed + s)
        new_loader = DataLoader(Subset(trainset, new_idx), batch_size=min(cfg.batch_size, len(new_idx)), shuffle=True, num_workers=2, pin_memory=True)

        # replay loader from buffer (old classes)
        replay_idx = replay.all_indices()
        replay_loader = None
        if len(replay_idx) > 0:
            replay_loader = DataLoader(Subset(trainset, replay_idx), batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)

        # Train session
        train_incremental(model, new_loader, replay_loader)

        # Update replay with new indices
        replay.add_indices(train_targets, new_idx)

        acc = accuracy(model, test_loader)
        print(f"[Eval] After session {s}: Acc(all test classes) = {acc*100:.2f}%")

    print("\nDone.")

In [None]:
# Script entry point: build the CIFAR-100 configuration and run the full pipeline.
if __name__ == "__main__":
    cfg = FSCILConfig(
        # CIFAR-100 recommended fast settings
        img_size=32,
        patch_size=4,     # 32/4=8 per side -> 8*8 = 64 patches
        feat_dim=256,
        depth=6,
        heads=8,
        num_prompts=10,

        base_epochs=20,    # dataclass default is 5; increase for better results
        inc_epochs=20,     # dataclass default is 5; increase for better results
        lr_base=3e-4,
        lr_inc=3e-4,

        base_classes=60,
        ways_per_session=5,
        shots_per_class=5,
        replay_max_per_class=20,
        batch_size=128,    # dataclass default is 64

        tau=1.0,
        kd_T=2.0,
        ema_alpha=0.999,
        lam_proto=1.0,
        lam_kd=1.0,
        lam_domain=0.5,

        seed=0
    )
    main_cifar100(cfg)
