# **Generalized Few-Shot Learning (GFSL)**

This demo implements **Dynamic Few-Shot Learning without Forgetting** (DFSLwF) on the **CIFAR-100** dataset under the *generalized setting*, where both base and novel classes are present at test time. The notebook reproduces the two-stage training pipeline (supervised base training followed by episodic meta-training with a generator) and allows readers to evaluate base, novel, and harmonic accuracies in the **GFSL scenario**.

---

## **1. Background**

### **1.1 Generalized Few-Shot Learning (GFSL)**

Compared with classical few-shot learning (FSL), GFSL changes the evaluation setup: during testing, queries can belong to either  
- **base classes** (those already seen during training), or  
- **novel classes** (new ones introduced only at test time).  

The learner must recognize queries across this **joint label space**, handling both familiar and unseen categories simultaneously. This setting is more realistic than classical FSL, where queries always belong to novel classes only.

> **Example (GFSL 3-way 2-shot):** suppose the model was trained on 64 base classes. At test time, an episode may include queries from both these base classes and 3 novel ones, each novel class with only 2 support examples. The learner must classify correctly across all 64 + 3 = 67 classes.

A challenge arises: the model tends to be **biased toward base classes**, since they are represented by many more examples during training.  
Approaches like **Dynamic Few-Shot Learning without Forgetting (DFSLwF)** mitigate this issue by combining a base classifier (trained on abundant data) with a novel classifier (trained from few supports), and balancing their predictions at inference time.

*(Evaluation tip: in GFSL, report separate accuracy on base and novel classes, plus the **harmonic mean (H-mean)** to capture overall balance.)*
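
As a quick illustration of the tip above, the harmonic mean of base and novel accuracy can be computed as follows. This is a minimal sketch; the helper name `harmonic_mean` is hypothetical and is not part of the notebook code.

```python
def harmonic_mean(acc_base: float, acc_novel: float) -> float:
    """Harmonic mean of two accuracies in [0, 1]; returns 0.0 if both are zero."""
    denom = acc_base + acc_novel
    if denom == 0:
        return 0.0
    return 2.0 * acc_base * acc_novel / denom


# A model that is strong on base classes but weak on novel ones is penalized:
# its arithmetic mean would be 0.50, but its H-mean is 2*0.8*0.2/1.0 = 0.32.
h_unbalanced = harmonic_mean(0.80, 0.20)

# A balanced model with the same arithmetic mean keeps an H-mean of 0.50.
h_balanced = harmonic_mean(0.50, 0.50)
```

The H-mean only stays high when both accuracies are high, which is why it is preferred over the arithmetic mean in GFSL.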

> **Practical scenario:** GFSL is valuable in real-world systems (e.g., an image recognition app) where the model must keep performing well on frequent categories like “dogs” or “cars” while also learning to recognize new, rare categories from just a handful of examples.


### **1.2 Dynamic Few-Shot Learning without Forgetting**

Let $F_\theta(\cdot)$ be a ConvNet feature extractor.  
A classifier is built using a set of **weight vectors** $W = \{ w_k \}$, one per class, and computes scores via a **cosine similarity**:

$$
s_k(x) \;=\; \tau \cdot \cos\!\big( F_\theta(x), w_k \big),
$$

where both features and weights are $\ell_2$-normalized, and $\tau$ is a learnable scale.  
This design ensures that base and novel categories are treated in a **unified space**.
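
The scoring rule above can be sketched in a few lines of PyTorch. This is an illustrative sketch only: the function name `cosine_scores` and the toy tensor sizes are assumptions, and $\tau$ is passed as a plain float here rather than a learnable parameter.

```python
import torch
import torch.nn.functional as F


def cosine_scores(features: torch.Tensor, weights: torch.Tensor, tau: float = 10.0) -> torch.Tensor:
    """Return tau * cos(features, weights) with shape (B, K)."""
    z = F.normalize(features, p=2, dim=1)  # (B, D), unit-norm features
    w = F.normalize(weights, p=2, dim=1)   # (K, D), unit-norm class weights
    return tau * (z @ w.t())


torch.manual_seed(0)
feats = torch.randn(4, 16)   # toy batch of 4 feature vectors
W = torch.randn(7, 16)       # toy weights for 7 classes
scores = cosine_scores(feats, W)
```

Because both operands are normalized, the scores depend only on direction: rescaling `feats` leaves `scores` unchanged, and every score lies in $[-\tau, \tau]$.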

To incorporate novel classes at test time, DFSLwF introduces a **few-shot weight generator** $G(\cdot)$:
- Input: a small support set of feature vectors $Z' = \{z'_i\}$ from a novel class, plus the set of base weights $W_\text{base}$,  
- Output: a novel classification weight $w'$ for that class.

The generator combines two mechanisms:
1. **Feature averaging:**  
   $w'_\text{avg} = \tfrac{1}{N'} \sum_i z'_i$, scaled by learnable parameters.  
2. **Attention over base weights:**  
   Composes $w'_\text{att}$ as a similarity-weighted sum of $W_\text{base}$, exploiting prior knowledge about the visual world.  

The final novel weight is a linear combination:  
$$
w' = \phi_\text{avg} \odot w'_\text{avg} \;+\; \phi_\text{att} \odot w'_\text{att}.
$$

At inference, the classifier uses  
$$
W^* = W_\text{base} \cup W_\text{novel},
$$  
thus predicting across **both base and novel classes** without retraining.
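
The generator and the joint weight set $W^*$ can be sketched as below. This is a simplified illustration under stated assumptions, not the notebook's implementation: the class name `WeightGeneratorSketch`, the linear query projection, the one-learnable-key-per-base-class attention, and the temperature `gamma` are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightGeneratorSketch(nn.Module):
    """Illustrative few-shot weight generator (hypothetical, simplified)."""

    def __init__(self, feat_dim: int, n_base: int, gamma: float = 10.0):
        super().__init__()
        self.phi_avg = nn.Parameter(torch.ones(feat_dim))  # element-wise scale for w'_avg
        self.phi_att = nn.Parameter(torch.ones(feat_dim))  # element-wise scale for w'_att
        self.query = nn.Linear(feat_dim, feat_dim)          # assumed query projection
        self.keys = nn.Parameter(0.01 * torch.randn(n_base, feat_dim))  # assumed keys, one per base class
        self.gamma = gamma                                  # assumed attention temperature

    def forward(self, support_feats: torch.Tensor, w_base: torch.Tensor) -> torch.Tensor:
        # support_feats: (N', D) features of ONE novel class; w_base: (K_base, D)
        z = F.normalize(support_feats, dim=1)
        w_avg = z.mean(dim=0)                               # feature averaging, (D,)

        q = F.normalize(self.query(z), dim=1)               # (N', D)
        k = F.normalize(self.keys, dim=1)                   # (K_base, D)
        att = F.softmax(self.gamma * (q @ k.t()), dim=1)    # (N', K_base), rows sum to 1
        w_att = (att @ F.normalize(w_base, dim=1)).mean(dim=0)  # similarity-weighted sum, (D,)

        return self.phi_avg * w_avg + self.phi_att * w_att  # (D,)


torch.manual_seed(0)
gen = WeightGeneratorSketch(feat_dim=16, n_base=8)
w_base = torch.randn(8, 16)                      # toy base weights
w_novel = gen(torch.randn(5, 16), w_base)        # 5 support features of one novel class
w_star = torch.cat([w_base, w_novel.unsqueeze(0)], dim=0)  # joint weights: base rows, then novel row
```

With `w_star` in hand, the same cosine classifier scores a query against all 8 base classes and the new class at once, with no gradient step on the novel data.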

**Training procedure (two stages):**
1. **Stage 1:** Train feature extractor + base classifier on abundant base data.  
2. **Stage 2:** Train the weight generator using “fake” novel tasks sampled from base categories (episodic style).  

**Notes (implementation):**
- Removing the final ReLU in $F_\theta$ helps cosine similarity classification.  
- The $\ell_2$-normalization enforces compact clusters, improving generalization to novel classes.  
- Evaluation reports accuracy on **Base**, **Novel**, and the **harmonic mean (H-mean)**.  
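
One intuition behind the first note can be checked numerically: features that pass through a final ReLU are non-negative, so the cosine similarity between any two of them cannot be negative, which compresses the range available to a cosine classifier. The sketch below uses random tensors as a stand-in for backbone features; it is an illustration of this intuition, not an experiment from the notebook.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
raw = torch.randn(100, 32)     # stand-in for features without a final ReLU
post_relu = F.relu(raw)        # what a final ReLU would output


def pairwise_cos(x: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between every pair of rows of x."""
    x = F.normalize(x, dim=1)
    return x @ x.t()


min_cos_raw = pairwise_cos(raw).min().item()         # can be clearly negative
min_cos_relu = pairwise_cos(post_relu).min().item()  # never below zero
```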

> **Why it matters:** DFSLwF directly addresses the **GFSL setting**: real systems must keep high performance on frequent base categories (e.g., “dogs”, “cars”) while adapting dynamically to rare, unseen ones (e.g., “drone types”) from few examples, without retraining or forgetting.


### **1.3 CIFAR-100 dataset**
We use [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html), a widely adopted benchmark for image classification.  
It consists of **100 classes** (e.g., animals, vehicles, household objects), each containing **600 color images** of size $32 \times 32$.  
For every class, there are **500 training images** and **100 test images**.

CIFAR-100 is more challenging than Omniglot or MNIST-like datasets, since it involves **natural RGB images** with high intra-class variability at a low $32 \times 32$ resolution.

It's included in the `torchvision` package, making it straightforward to download and preprocess for few-shot or generalized few-shot experiments.


---

## **2. Practice**

In [None]:
from __future__ import annotations
import math
import random
import types
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import copy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Seed:

In [None]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_rng = np.random.default_rng(SEED + 1)
test_rng = np.random.default_rng(SEED + 2)

Settings:

In [None]:
# Few-shot episode configuration
N_WAY = 5           # Number of novel classes per task
K_SHOT = 5          # Support examples per class
Q_NOVEL = 15        # Query examples per novel class
Q_BASE_TOTAL = 75   # Query examples of base classes

# CIFAR-100 split sizes
N_BASE = 64         # Number of classes retained as base
N_VALNOVEL = 16     # Number of classes retained as novel (for validation)
N_TESTNOVEL = 20    # Number of classes retained as novel (for test)

# Network params
TAU_INIT = 10.0             # Temperature init
BACKBONE = "cifar_resnet"   # Feature extractor network
assert BACKBONE in ["resnet18", "conv", "cifar_resnet"], f"Unknown backbone '{BACKBONE}'. Use 'conv', 'resnet18' or 'cifar_resnet'."

# Stage 1
STAGE1_EPOCHS = 120
STAGE1_LR = 3e-3
STAGE1_BS = 512
STAGE1_WEIGHT_DECAY = 5e-4
S1_VAL_FRAC   = 0.10   # Fraction of each base class's training images held out for validation
S1_VAL_EVERY  = 2      # Validate every N epochs
S1_PATIENCE   = 5      # Early-stopping after N validations without improvements
S1_LOG_EVERY  = 50     # Log loss every N batches

# Stage 2
STAGE2_TASKS = 20_000         # Number of training episodes
STAGE2_LR = 5e-4
# STAGE2_GRAD_CLIP = 1.0
S2_VAL_TASKS      = 1_000     # Number of validation episodes
STAGE2_VAL_EVERY = 500        # Validate every N episodes
S2_PATIENCE       = 5         # Early-stopping after N validations without improvements
S2_SELECT_METRIC  = "hmean"   # validation metric
assert S2_SELECT_METRIC in ["hmean", "base", "novel"], f"Unknown metric '{S2_SELECT_METRIC}'. Use 'hmean', 'base' or 'novel'."

# Test
TEST_TASKS = 1_000    # Number of test episodes

Utility functions:

In [None]:
def split_cifar100_classes(seed: int, n_base=64, n_val=16, n_test=20):
    """Split CIFAR-100 class IDs into base/val-novel/test-novel sets.
    """
    rng = np.random.default_rng(seed)
    classes = np.arange(100)
    rng.shuffle(classes)
    base = classes[:n_base].tolist()
    val_novel = classes[n_base:n_base + n_val].tolist()
    test_novel = classes[n_base + n_val:n_base + n_val + n_test].tolist()
    return base, val_novel, test_novel


def subset_by_classes(ds, keep):
    """Return a Subset containing only samples whose label is in `keep`.

    Uses vectorized filtering over `ds.targets` to select the indices that
    belong to the provided set of class IDs.
    """
    t = np.array(ds.targets)
    idx = np.nonzero(np.isin(t, keep))[0]
    return Subset(ds, idx)


def class_to_local_indices(subset):
    """Build a mapping class_id -> list of *local* indices within `subset`.

    Iterates over the subset indices and groups them by their original class
    ID (read from `subset.dataset.targets`). Useful for fast episodic sampling
    (e.g., drawing K support and Q query images per class).

    Example:
    ```
    {
        0: [0, 5, 9, ...],
        1: [1, 7, ...],
        ...
    }
    ```
    """
    t = np.array(subset.dataset.targets)
    out = {}
    for j, i in enumerate(subset.indices):
        y = int(t[i])
        (out.setdefault(y, [])).append(j)
    return out


def stratified_split_subset(subset: Subset, val_frac: float, seed: int):
    """
    Split a Subset into (train_part, val_part),
    preserving the per-class proportions of the original CIFAR classes.
    """
    rng = np.random.default_rng(seed)
    ds_targets = np.array(subset.dataset.targets)
    # map: class_id -> local indexes list in Subset
    cls2locals = {}
    for j, i in enumerate(subset.indices):
        y = int(ds_targets[i])
        (cls2locals.setdefault(y, [])).append(j)

    train_locals, val_locals = [], []
    for y, locals_ in cls2locals.items():
        locals_ = np.array(locals_, dtype=int)
        n_val = max(1, int(round(len(locals_) * val_frac)))
        rng.shuffle(locals_)
        val_locals.append(locals_[:n_val])
        train_locals.append(locals_[n_val:])

    train_locals = np.concatenate(train_locals).tolist()
    val_locals   = np.concatenate(val_locals).tolist()

    train_indices = [subset.indices[i] for i in train_locals]
    val_indices   = [subset.indices[i] for i in val_locals]

    return Subset(subset.dataset, train_indices), Subset(subset.dataset, val_indices)

In [None]:
def l2_normalize(x: torch.Tensor, dim: int = 1, eps: float = 1e-6) -> torch.Tensor:
  """L2-normalize a tensor along a given dimension.
  """
  return x / (x.norm(p=2, dim=dim, keepdim=True).clamp_min(eps))

In [None]:
def set_bn_eval(m: nn.Module):
  """Put BatchNorm2d layers in eval mode and freeze their parameters,
  so it uses stored running statistics and stops updating them,
  and it disables gradient updates for its affine parameters (gamma/beta).
  """
  if isinstance(m, nn.BatchNorm2d):
      m.eval()
      for p in m.parameters():
          p.requires_grad = False

In [None]:
def sliding_avg(xs: List[float], k: int = 20) -> float:
  if not xs:
      return 0.0
  if len(xs) < k:
      return float(sum(xs) / len(xs))
  return float(sum(xs[-k:]) / k)

In [None]:
def mean_ci95(xs: np.ndarray) -> Tuple[float, float]:
    xs = np.asarray(xs, dtype=float)
    n  = xs.size
    if n < 2:
        return float(xs.mean()), 0.0
    std = xs.std(ddof=1)
    stderr = std / np.sqrt(n)
    z = 1.96
    return float(xs.mean()), float(z * stderr)


def gfsl_stats(
    acc_per_ep_base: List[float],
    acc_per_ep_novel: List[float],
) -> Dict[str, Dict[str, float]]:

    if len(acc_per_ep_base) != len(acc_per_ep_novel):
        raise ValueError("base and novel must be of the same length")
    T = len(acc_per_ep_base)

    base = np.asarray(acc_per_ep_base, dtype=float)
    novel = np.asarray(acc_per_ep_novel, dtype=float)

    denom = base + novel
    h_per_ep = np.where(denom > 0, 2.0 * base * novel / denom, 0.0)

    base_mean, base_ci  = mean_ci95(base)
    novel_mean, novel_ci = mean_ci95(novel)
    h_mean, h_ci = mean_ci95(h_per_ep)

    return {
        "base":  {"mean": base_mean,  "conf": base_ci},
        "novel": {"mean": novel_mean, "conf": novel_ci},
        "hmean": {"mean": h_mean,     "conf": h_ci},
    }


def print_stats(T: int, stats: Dict[str, Dict[str, float]], model: str = ""):
  print(f"[test] - {model} (95% CI on {T} tasks)")
  print(f" - [Base]   acc={100*stats['base']['mean']:.2f}% ± {100*stats['base']['conf']:.2f}%")
  print(f" - [Novel]  acc={100*stats['novel']['mean']:.2f}% ± {100*stats['novel']['conf']:.2f}%")
  print(f" - [H-mean] acc={100*stats['hmean']['mean']:.2f}% ± {100*stats['hmean']['conf']:.2f}%")
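
To make the interval reported by `mean_ci95` concrete, here is a small worked example on five made-up per-episode accuracies. It is a sketch that repeats the same normal-approximation formula (mean ± 1.96 · standard error) inline; the numbers are illustrative, not results from the notebook.

```python
import numpy as np

# Hypothetical per-episode accuracies (illustrative values only)
acc = np.array([0.60, 0.70, 0.65, 0.75, 0.80])

mean = acc.mean()                               # 0.70
std = acc.std(ddof=1)                           # sample standard deviation
half_width = 1.96 * std / np.sqrt(acc.size)     # 95% CI half-width, about 0.069
```

With only five episodes the interval is wide (roughly 70% ± 7%); averaging over many more episodes shrinks it in proportion to $1/\sqrt{T}$.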

### **2.1 Environment**

#### **2.1.1 CIFAR100 dataset**

We define **data transformations** separately for training and evaluation.  
On **train**, we apply light stochastic augmentation suited to 32x32 images (random crop with padding and horizontal flip) to improve invariances without distorting small objects.  
On **eval**, we keep a deterministic pipeline to ensure consistent measurement.

> Augmentations reduce overfitting of the feature extractor in Stage-1. A stable eval pipeline is instead crucial when computing Base / Novel / H-mean in GFSL.

In [None]:
if BACKBONE == "conv" or BACKBONE == "cifar_resnet":
    IMAGE_SIZE = 32

    train_tf = transforms.Compose([
        transforms.RandomCrop(IMAGE_SIZE, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    eval_tf = transforms.Compose([transforms.ToTensor(),])

else:
    IMNET_MEAN = [0.485, 0.456, 0.406]
    IMNET_STD  = [0.229, 0.224, 0.225]
    IM_RESIZE = 128
    IM_CROP   = 112

    train_tf = transforms.Compose([
        transforms.Resize(IM_RESIZE),
        transforms.RandomCrop(IM_CROP),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMNET_MEAN, IMNET_STD),
    ])

    eval_tf = transforms.Compose([
        transforms.Resize(IM_RESIZE),
        transforms.CenterCrop(IM_CROP),
        transforms.ToTensor(),
        transforms.Normalize(IMNET_MEAN, IMNET_STD),
    ])

We instantiate the **GFSL** splits on CIFAR-100 as follows:
- `ds_train` / `ds_test` load the standard CIFAR-100 **training** and **test** partitions with their respective transforms (`train_tf` randomized, `eval_tf` deterministic).
- `split_cifar100_classes(SEED)` yields a **disjoint class split** into:
  - **Base** classes (used for Stage-1 supervised training and as the base pool in GFSL episodes),
  - **Val-Novel** classes (for model selection with GFSL episodes),
  - **Test-Novel** classes (for final GFSL evaluation).

All data live under `./data` and are fetched on demand with `download=True`.

> *Note.* In GFSL, splits must be done **by class**, not by image, to avoid leakage. Label spaces are remapped locally for base vs. novel episodes so that evaluation is over the **joint label space** (Base ∪ Novel) while keeping indices compact inside each block.
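
The class-level disjointness mentioned in the note is easy to verify. The sketch below mirrors the logic of `split_cifar100_classes` in a self-contained helper (the name `split_classes` and the variable names are hypothetical) and checks that no class ID appears in more than one split.

```python
import numpy as np


def split_classes(seed: int, n_base: int = 64, n_val: int = 16, n_test: int = 20):
    """Shuffle the 100 class IDs and cut them into three consecutive chunks."""
    rng = np.random.default_rng(seed)
    classes = np.arange(100)
    rng.shuffle(classes)
    base = classes[:n_base].tolist()
    val_novel = classes[n_base:n_base + n_val].tolist()
    test_novel = classes[n_base + n_val:n_base + n_val + n_test].tolist()
    return base, val_novel, test_novel


base_ids, val_ids, test_ids = split_classes(seed=42)

# No class may appear in more than one split...
no_overlap = (
    set(base_ids).isdisjoint(val_ids)
    and set(base_ids).isdisjoint(test_ids)
    and set(val_ids).isdisjoint(test_ids)
)
# ...and with 64 + 16 + 20 = 100 the three splits cover every class exactly once.
n_covered = len(set(base_ids) | set(val_ids) | set(test_ids))
```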


In [None]:
ds_train = CIFAR100(root="./data", train=True,  transform=train_tf, download=True)
ds_test  = CIFAR100(root="./data", train=False, transform=eval_tf,  download=True)

base, valn, testn = split_cifar100_classes(SEED)



We now **separate** the base training set into a train and validation split.  
This allows us to monitor Stage-1 supervised training on base classes without leaking novel information.  
We also prepare a distinct set of **val-novel classes** (from the train partition) to be used later for episodic validation in GFSL.

In [None]:
# Stage 1: Train + validation (base) / Stage 2: Train (base)
train_base_full = subset_by_classes(ds_train, base)
train_base_tr, train_base_val = stratified_split_subset(train_base_full, S1_VAL_FRAC, SEED)

# Stage 2: Validation (Novel) - from train
train_valnovel = subset_by_classes(ds_train, valn)

# Test (base + novel)
test_base  = subset_by_classes(ds_test,  base)
test_novel = subset_by_classes(ds_test,  testn)

Then we create the **label maps** for each split:  
- local indices for base-train and base-val (Stage-1),  
- mappings for val-novel classes (Stage-2 validation),  
- and indices for base and novel classes in the final test set.  

These compact class-to-indices maps are required for episodic samplers and joint logits during GFSL evaluation.

In [None]:
# Stage-1 (supervised)
cti_train_base_tr  = class_to_local_indices(train_base_tr)
cti_train_base_val = class_to_local_indices(train_base_val)

# Stage-2 Validation (episodic GFSL) - from train
cti_val_base   = cti_train_base_val
cti_val_novel  = class_to_local_indices(train_valnovel)

# For test (episodic GFSL)
cti_test_base  = class_to_local_indices(test_base)
cti_test_novel = class_to_local_indices(test_novel)

**2.1.1.1 DataLoader: Stage 1**

Stage-1 trains the backbone + base classifier with sufficient base data, so we construct a classic **supervised loader** over **base classes only**:

In [None]:
class Stage1TrainDS(torch.utils.data.Dataset):
    def __init__(self, subset: Subset, orig_targets: List[int], order_map: Dict[int, int]):
        self.subset = subset
        self.local_labels = [order_map[int(orig_targets[i])] for i in subset.indices]

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, _ = self.subset[idx]
        return x, self.local_labels[idx]


base_order = sorted(base)
order_map  = {cid: i for i, cid in enumerate(base_order)}

stage1_train_ds = Stage1TrainDS(train_base_tr,  ds_train.targets, order_map)  # train
stage1_val_ds   = Stage1TrainDS(train_base_val, ds_train.targets, order_map)  # val (disjoint from train)

We instantiate them:

In [None]:
train_loader_s1 = DataLoader(
    stage1_train_ds, batch_size=STAGE1_BS, shuffle=True,
    num_workers=2, pin_memory=True
)

val_loader_s1   = DataLoader(
    stage1_val_ds, batch_size=STAGE1_BS*2, shuffle=False,
    num_workers=2, pin_memory=True
)

**2.1.1.2 DataLoader: Stage 2 Train**

For Stage-2 training (and also for the validation and test phases), we need to train and evaluate the model on **few-shot tasks**.
A standard PyTorch `DataLoader` builds mini-batches of images without
considering whether they belong to a support or query set.

However, in a GFSL setting, the dataloader works differently from the classical FSL one, since we need episodes that combine **few-shot novel supports** with **base queries**.

To achieve this, we use a **custom task sampler** and a **custom collate function**.

**Collate function**

It splits the batch into:
  1. `support_novel` images (size `n_way * k_shot`)  
  2. `query_images` = `[novel queries | base queries]`  
  3. `true_novel_ids` (original CIFAR-100 ids for the sampled novel classes)  
  4. `base_query_labels_cifar` (original CIFAR-100 ids for base queries)

In [None]:
def stage2_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base = images[total_novel:]

    query_images = torch.cat([query_novel, query_base], dim=0)

    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    base_query_labels_cifar = labels[total_novel:]  # original CIFAR ids
    return support_novel, query_images, true_novel_ids, base_query_labels_cifar


stage2_collate_fn = partial(
    stage2_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)
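
To see how the collate function carves one flat batch into supports and queries, the following toy walkthrough applies the same slicing and reshaping to fake "images" whose single pixel equals their position in the batch. The tiny episode sizes are assumptions chosen to keep the output readable.

```python
import torch

n_way, k_shot, q_novel, q_base_total = 2, 1, 2, 3
per_novel = k_shot + q_novel          # 3 items per novel class
total_novel = n_way * per_novel       # 6 items in the novel block

# Fake batch: 9 "images" of shape (1, 1, 1); pixel value = position in the batch.
images = torch.arange(total_novel + q_base_total, dtype=torch.float32).view(-1, 1, 1, 1)

# Novel block reshaped to (n_way, per_novel, C, H, W): class 0 -> [0, 1, 2], class 1 -> [3, 4, 5]
novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])

support = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])      # first K of each class
query_novel = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])  # remaining Qn of each class
query_base = images[total_novel:]                                     # trailing base queries
queries = torch.cat([query_novel, query_base], dim=0)                 # [novel | base]

support_pos = support.flatten().tolist()   # [0.0, 3.0]
query_pos = queries.flatten().tolist()     # [1.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```

This only works because the sampler emits indices class by class, with the base queries last; the collate function relies on that ordering.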

**Batch sampler**

Samples `n_way` pseudo-novel classes from the base pool, drawing `k_shot + q_novel` examples per class. It then adds `q_base_total` queries from the remaining base classes.


In [None]:
class GFSLTrainEpisodicBatchSampler:
    """
    Stage-2: N_WAY pseudo-novel classes (from BASE train) with K+Qn samples each, plus Qb base queries drawn from the remaining BASE classes (pseudo-novel classes excluded).
    Returns indices over the Subset(train_base).
    """
    def __init__(self, class_to_indices: Dict[int, List[int]], n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        self.cti = class_to_indices
        self.all_classes = list(class_to_indices.keys())
        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self): return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):

            novel = self.rng.choice(self.all_classes, size=self.n_way, replace=False)

            batch = []
            for c in novel:
                pool = self.cti[c]
                need = self.k_shot + self.q_novel
                if len(pool) < need:
                    raise ValueError(f"Class {c} has {len(pool)} < {need}")
                idx = self.rng.choice(pool, size=need, replace=False)
                batch.append(idx)

            novel_classes = set(novel.tolist())
            base_pool_classes = [c for c in self.all_classes if c not in novel_classes]

            used = set(np.concatenate(batch).tolist())

            base_q = []
            while len(base_q) < self.q_base_total:
                c = int(self.rng.choice(base_pool_classes))
                cand = int(self.rng.choice(self.cti[c]))
                if cand not in used:
                    base_q.append(cand)
                    used.add(cand)

            batch.append(np.array(base_q, dtype=int))
            yield np.concatenate(batch)


stage2_train_batch_sampler = GFSLTrainEpisodicBatchSampler(
    cti_train_base_tr, n_tasks=STAGE2_TASKS, rng=train_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

We can now create our Dataset...

In [None]:
class Stage2TrainDS(torch.utils.data.Dataset):
    def __init__(self, subset: Subset):
        self.subset = subset
    def __len__(self):
      return len(self.subset)
    def __getitem__(self, idx):
      return self.subset[idx]


stage2_train_ds = Stage2TrainDS(train_base_tr)

... and build the DataLoader:

In [None]:
train_loader_s2 = DataLoader(
    stage2_train_ds, batch_sampler=stage2_train_batch_sampler,
    collate_fn=stage2_collate_fn, num_workers=2, pin_memory=True
)

**2.1.1.3 DataLoader: Stage 2 Validation and Test**

For validation and test we need a **different sampler + collate** than in training, because episodes must follow the **GFSL evaluation protocol**:  
- sample supports and queries only from the **novel split**,  
- add a fixed number of queries from the **base split**,  
- keep the layout `[novel block | base block]` so evaluation can compute accuracy on Base, Novel, and H-mean.  

The custom sampler/collate ensure this structure and prevent any leakage between base and novel classes.


In [None]:
def eval_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    """Reconstruct support/query tensors for a GFSL test episode.

    Input `batch` is a list of (image, label) from `test_concat` where the first
    N_WAY*(K_SHOT+Q_NOVEL) items belong to the novel subset and the remaining Q_BASE_TOTAL
    items belong to the base subset.
    """
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    # reshape novel block into (N, K+Qn, C, H, W) and (N, K+Qn) for labels
    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    # split into support (first K) and query (last Qn)
    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel   = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base    = images[total_novel:]  # remaining Qb from base subset

    # concatenate all queries [novel | base]
    query_images = torch.cat([query_novel, query_base], dim=0)

    # collect true novel CIFAR IDs (one per class) and per-query GT labels
    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    gt_novel = novel_labels_block[:, k_shot:].reshape(-1)
    gt_base  = labels[total_novel:]

    return support_novel, query_images, true_novel_ids, gt_novel, gt_base


eval_collate_fn = partial(
    eval_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)

In [None]:
class GFSLEvalEpisodeSampler:
    """Batch sampler for GFSL test episodes over a single ConcatDataset.

    Each yielded batch is a 1D numpy array of indices into `test_concat`
    laid out as:
        [ N_WAY*(K_SHOT+Q_NOVEL) indices from novel part | Q_BASE_TOTAL indices from base part ]
    """

    def __init__(self, base_cti: Dict[int, List[int]], novel_cti: Dict[int, List[int]], offset_base: int,
                 n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        self.base_cti = base_cti
        self.novel_cti = novel_cti
        self.offset_base = offset_base

        self.base_classes  = list(base_cti.keys())
        self.novel_classes = list(novel_cti.keys())

        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):
            # ---- NOVEL BLOCK (test_novel) ----
            chosen_novel = self.rng.choice(self.novel_classes, size=self.n_way, replace=False)
            per_novel = self.k_shot + self.q_novel

            novel_chunks = []
            for c in chosen_novel:
                pool = self.novel_cti[c]  # local indices within test_novel
                if len(pool) < per_novel:
                    raise ValueError(f"Novel class {c} has {len(pool)} < {per_novel}")
                idx = self.rng.choice(pool, size=per_novel, replace=False)
                novel_chunks.append(idx)

            novel_block = np.concatenate(novel_chunks).astype(int)  # still in "novel" namespace (no offset)

            # ---- BASE BLOCK (test_base) — NO DUPLICATES, NO REPLACEMENT ----
            # Build a single pool of local indices from all base classes
            base_pool = np.concatenate([self.base_cti[c] for c in self.base_classes]) if self.base_classes else np.array([], dtype=int)
            if len(base_pool) < self.q_base_total:
                raise ValueError(f"Not enough base candidates: have {len(base_pool)} < {self.q_base_total}")

            # Sample Qb distinct local indices, then shift by offset to address ConcatDataset second component
            base_q_local = self.rng.choice(base_pool, size=self.q_base_total, replace=False)
            base_block   = (base_q_local + self.offset_base).astype(int)

            # ---- FINAL EPISODE ----
            full_episode = np.concatenate([novel_block, base_block])
            yield full_episode

Validation (Stage-2) DataLoader:

In [None]:
sampler_val_s2 = GFSLEvalEpisodeSampler(
    cti_val_base, cti_val_novel, len(train_valnovel), n_tasks=S2_VAL_TASKS, rng=test_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

In [None]:
# Build a single validation dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
val_concat_s2 = ConcatDataset([train_valnovel, train_base_val])
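
The reason the sampler receives an `offset_base` equal to the length of the novel subset can be seen on a toy `ConcatDataset`: items of the second component are addressed by their local index plus the length of the first component. The datasets and labels below are made up for illustration.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Toy "novel" part with 4 items (labels 90/91) and toy "base" part with 3 items (labels 0..2)
novel_part = TensorDataset(torch.zeros(4, 1), torch.tensor([90, 90, 91, 91]))
base_part = TensorDataset(torch.ones(3, 1), torch.tensor([0, 1, 2]))

concat = ConcatDataset([novel_part, base_part])  # novel first, base second
offset_base = len(novel_part)                    # base items start at global index 4

local_base_idx = 1                               # second item inside base_part
_, label = concat[local_base_idx + offset_base]  # global index 5
base_label = int(label)                          # label 1, taken from base_part
```

Swapping the order of the two components without updating the offset would silently read novel images as base queries, hence the "order matters" comment.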

In [None]:
val_loader_s2 = DataLoader(
    val_concat_s2, batch_sampler=sampler_val_s2, collate_fn=eval_collate_fn,
    num_workers=2, pin_memory=True
)

Test DataLoader:

In [None]:
sampler_test = GFSLEvalEpisodeSampler(
    cti_test_base, cti_test_novel, len(test_novel), n_tasks=TEST_TASKS, rng=test_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

In [None]:
# Build a single test dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
test_concat = ConcatDataset([test_novel, test_base])

In [None]:
test_loader = DataLoader(
    test_concat, batch_sampler=sampler_test, collate_fn=eval_collate_fn,
    num_workers=2, pin_memory=True
)

#### **2.1.2 DFSLwF module**

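We now build the DFSLwF model components introduced in Section 1.2, starting with the feature extractor $F_\theta$ and the cosine classifier.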

**2.1.2.1 Feature extractor**

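The feature extractor $F_\theta$ supports three backbones, selected via `BACKBONE`: a lightweight four-block ConvNet (`conv`, 64-D output), a small CIFAR-style ResNet (`cifar_resnet`, 256-D output), and a torchvision ResNet-18 (`resnet18`, 512-D output, optionally ImageNet-pretrained). Output features can be $\ell_2$-normalized, and both ResNet variants can drop the final ReLU, as suggested in the implementation notes of Section 1.2.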

In [None]:
class BasicBlockCIFAR(nn.Module):
    expansion = 1
    def __init__(self, in_ch, out_ch, stride=1, remove_last_relu=False):
        super().__init__()
        self.remove_last_relu = remove_last_relu
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.down  = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.down is not None:
            identity = self.down(x)
        out = out + identity
        if not self.remove_last_relu:
            out = self.relu(out)
        return out


class CIFARResNetSmall(nn.Module):
    """
    CIFAR-style ResNet for 32x32 inputs:
      stem 3x3 s1 -> [64]x2 -> [128]x2 (s2) -> [256]x2 (s2) -> GAP -> 256-D
    """
    def __init__(self, remove_last_relu: bool = True):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer1 = nn.Sequential(
            BasicBlockCIFAR(64, 64, 1),
            BasicBlockCIFAR(64, 64, 1),
        )
        self.layer2 = nn.Sequential(
            BasicBlockCIFAR(64, 128, 2),   # downsample 32->16
            BasicBlockCIFAR(128, 128, 1),
        )
        self.layer3 = nn.Sequential(
            BasicBlockCIFAR(128, 256, 2),  # downsample 16->8
            BasicBlockCIFAR(256, 256, 1, remove_last_relu=remove_last_relu),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.out_dim = 256

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x)
        x = self.gap(x).squeeze(-1).squeeze(-1)  # (B,256)
        return x

In [None]:
class ConvBlock(nn.Module):
    """Conv3x3 -> BN -> ReLU -> (optional MaxPool2d)."""
    def __init__(self, in_ch, out_ch, pool: bool):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        x = self.conv(x); x = self.bn(x); x = self.relu(x); x = self.pool(x)
        return x

In [None]:
class FeatureExtractor(nn.Module):
    """Feature extractor with selectable backbone: 'conv' (light), 'cifar_resnet' or 'resnet18'.

    Args:
        backbone: 'conv', 'resnet18' or 'cifar_resnet'.
        normalize_out: if True, L2-normalize the output features.
        resnet_pretrained: if True (only for 'resnet18'), load ImageNet pretrained weights.
        remove_last_relu: if True (for 'resnet18' and 'cifar_resnet'), remove the last post-add ReLU
                          in the final BasicBlock (useful with cosine classifiers).
    """
    def __init__(
        self,
        backbone: str = "conv",
        normalize_out: bool = True,
        resnet_pretrained: bool = True,
        remove_last_relu: bool = True,
    ):
        super().__init__()
        self.normalize_out = normalize_out

        if backbone.lower() == "conv":
            # Lightweight CIFAR-style CNN trained from scratch: 64-D output
            self.fe = nn.Sequential(
                ConvBlock(3,   64, pool=True),   # 32 -> 16
                ConvBlock(64,  64, pool=True),   # 16 -> 8
                ConvBlock(64,  64, pool=False),
                ConvBlock(64,  64, pool=False),
                nn.AdaptiveAvgPool2d(1),         # -> (B,64,1,1)
            )
            self._mode = "conv"
            self.out_dim = 64

        elif backbone.lower() == "cifar_resnet":
            self.fe = CIFARResNetSmall(remove_last_relu=remove_last_relu)
            self._mode = "cifar_resnet"; self.out_dim = 256

        elif backbone.lower() == "resnet18":
            weights = ResNet18_Weights.IMAGENET1K_V1 if resnet_pretrained else None
            m = resnet18(weights=weights)
            m.fc = nn.Identity()  # we want the 512-D penultimate features

            if remove_last_relu:
                # Patch only the final BasicBlock to skip the post-add ReLU
                last_block = m.layer4[-1]

                def forward_norelu(self_block, x):
                    identity = x
                    out = self_block.conv1(x); out = self_block.bn1(out); out = self_block.relu(out)
                    out = self_block.conv2(out); out = self_block.bn2(out)
                    if self_block.downsample is not None:
                        identity = self_block.downsample(x)
                    out = out + identity
                    return out  # no final ReLU

                last_block.forward = types.MethodType(forward_norelu, last_block)

            self.fe = m
            self._mode = "resnet18"
            self.out_dim = 512

        else:
            raise ValueError(f"Unknown backbone '{backbone}'. Use 'conv', 'resnet18' or 'cifar_resnet'.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._mode == "conv":
            z = self.fe(x).squeeze(-1).squeeze(-1)    # (B, 64)
        else:  # 'cifar_resnet' or 'resnet18': the backbone already returns a flat tensor
            z = self.fe(x)                            # (B, 256) or (B, 512)
        return l2_normalize(z, dim=1) if self.normalize_out else z

**2.1.2.2 Cosine Classifier**

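The base classifier replaces the usual dot-product layer with cosine similarity: both the features and the class weight vectors are L2-normalized, and the resulting similarities are scaled by a learnable temperature τ (initialized to `TAU_INIT`).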

In [None]:
class CosineClassifier(nn.Module):
    """Cosine classifier with learnable temperature τ (tau)."""
    def __init__(self, in_dim: int, n_classes: int, init_scale: float = TAU_INIT):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_classes, in_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.tau = nn.Parameter(torch.tensor(float(init_scale)))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        W = l2_normalize(self.weight, dim=1)
        feats = l2_normalize(feats, dim=1)
        logits = feats @ W.t()
        return self.tau * logits
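
To make the effect of the two normalizations concrete, here is a minimal, dependency-free sketch in plain Python (lists instead of tensors). The helper names `l2_normalize_list` and `cosine_logits` are illustrative and not part of the notebook. It only shows that the logits depend on the direction of the feature, not on its magnitude, and that every logit lies in [-τ, τ].

```python
import math

def l2_normalize_list(v, eps=1e-12):
    """Scale a vector (list of floats) to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]

def cosine_logits(feat, weights, tau):
    """Return tau * cos(feat, w_c) for every class weight vector w_c."""
    f = l2_normalize_list(feat)
    return [tau * sum(a * b for a, b in zip(f, l2_normalize_list(w))) for w in weights]

weights = [[1.0, 0.0], [0.0, 2.0], [-3.0, 0.0]]   # three toy class weights
feat = [0.6, 0.8]

logits = cosine_logits(feat, weights, tau=10.0)
# Rescaling the feature leaves the logits unchanged (up to rounding).
scaled = cosine_logits([100 * x for x in feat], weights, tau=10.0)

print([round(x, 3) for x in logits])   # [6.0, 8.0, -6.0]
```

Because the similarities are bounded, the temperature is what lets the softmax become sharp; this is presumably why τ is learned rather than fixed.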

**2.1.2.3 Few-Shot Weight Generator**

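The generator produces one classification weight vector per novel class from its K support features. It combines a feature-averaging term with an attention term that reads the base-class weights through learnable keys; base classes marked `False` in `exclude_mask` are masked out of the attention.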

In [None]:
class FewShotWeightGenerator(nn.Module):
    """
    [DFSLwF] Few-shot classification weight generator = Avg + Attention:
      - w_avg  = mean(z_i)                     → φ_avg ⊙ w_avg
      - w_att  = avg_i softmax(γ cos(φ_q z_i, k_b)) · w_b (over base classes b) → φ_att ⊙ w_att
      - w'     = φ_avg ⊙ w_avg + φ_att ⊙ w_att, then L2-normalize
    Includes:
      - learnable keys k_b (one per base class), size (C_base, D)
      - learnable φ_q (Linear D→D, no bias), φ_avg, φ_att (vectors), γ (scalar)
      - optional dropout on features during training (Stage-2)
    """
    def __init__(self, dim: int, num_base: int, p_dropout: float = 0.5):
        super().__init__()
        self.dim = dim
        self.num_base = num_base
        self.phi_q = nn.Linear(dim, dim, bias=False)
        nn.init.kaiming_uniform_(self.phi_q.weight, a=math.sqrt(5))
        self.keys = nn.Parameter(l2_normalize(torch.randn(num_base, dim), dim=1))
        self.phi_avg = nn.Parameter(torch.ones(dim))
        self.phi_att = nn.Parameter(torch.ones(dim))
        self.gamma = nn.Parameter(torch.tensor(10.0))  # attention temperature (like τ)
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(
        self,
        support_feats: torch.Tensor,          # (N*K, D), L2-normalized
        base_weights: torch.Tensor,           # (C_base, D), not necessarily normalized
        shots_per_class: int,
        exclude_mask: Optional[torch.Tensor] = None  # (C_base,) bool; True = keep; False = exclude
    ) -> torch.Tensor:
        D = support_feats.size(1)
        N = support_feats.size(0) // shots_per_class
        z = support_feats.view(N, shots_per_class, D)
        if self.training:
            z = self.dropout(z)
        # w_avg
        w_avg = l2_normalize(z.mean(dim=1), dim=1)  # (N, D)
        # attention over base weights
        Wb = l2_normalize(base_weights, dim=1)      # (C_base, D)
        Kb = l2_normalize(self.keys, dim=1)         # (C_base, D)
        # queries
        q = self.phi_q(z)                           # (N, K, D)
        q = l2_normalize(q, dim=2)                 # normalize across D
        # cosine(q, Kb) => (N, K, C_base)
        att_logits = torch.einsum("nkd,bd->nkb", q, Kb) * self.gamma
        if exclude_mask is not None:
            # set -inf on excluded classes before softmax
            mask = exclude_mask.view(1, 1, -1)  # broadcast
            att_logits = att_logits.masked_fill(~mask, float("-inf"))
        att = torch.softmax(att_logits, dim=2)      # (N, K, C_base)
        # weighted sum of base weights -> (N, K, D), then average over K (shots)
        w_att = torch.einsum("nkb,bd->nkd", att, Wb).mean(dim=1)  # (N, D)
        # combine
        w = self.phi_avg * w_avg + self.phi_att * w_att            # Hadamard
        w = l2_normalize(w, dim=1)  # (N, D)
        return w
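
The masking step can be illustrated in isolation. The following is a small plain-Python sketch (the function name `masked_softmax` and the toy numbers are assumptions made for illustration): setting a score to `-inf` before the softmax gives that base class exactly zero attention, while the remaining weights still sum to one.

```python
import math

def masked_softmax(scores, keep):
    """Softmax over `scores`; entries with keep[i] == False receive weight 0.

    Assumes at least one entry is kept, otherwise the result is undefined (NaN).
    """
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    m = max(masked)                              # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in masked]     # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

gamma = 10.0                      # attention temperature, as in the generator
cosines = [0.9, 0.8, 0.1]         # toy similarities between one query and three base keys
keep = [False, True, True]        # class 0 is excluded in this episode

att = masked_softmax([gamma * c for c in cosines], keep)
print([round(a, 4) for a in att])
```

Even though class 0 has the highest similarity, it contributes nothing to the attention-based weight once it is masked.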

**2.1.2.4 Final Network**

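`DFSLwF` wraps the feature extractor, the cosine base classifier, and the weight generator. `forward_logits` returns the base logits and, when novel weights are supplied, appends the novel logits, so predictions range over the joint label space `[C_base | N]`.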

In [None]:
class DFSLwF(nn.Module):
    def __init__(self, fe: FeatureExtractor, clf_base: CosineClassifier, gen: FewShotWeightGenerator):
        super().__init__()
        self.fe = fe
        self.clf_base = clf_base
        self.gen = gen

    def forward_logits(self, x: torch.Tensor, novel_weights: Optional[torch.Tensor] = None) -> torch.Tensor:
        feats = self.fe(x)
        Wb = l2_normalize(self.clf_base.weight, dim=1)
        logits = self.clf_base.tau * (feats @ Wb.t())
        if novel_weights is not None and novel_weights.numel() > 0:
            Wn = l2_normalize(novel_weights, dim=1)
            logits_n = self.clf_base.tau * (feats @ Wn.t())
            logits = torch.cat([logits, logits_n], dim=1)
        return logits

    @torch.no_grad()
    def build_novel_weights(self, support_imgs: torch.Tensor, k_shot: int) -> torch.Tensor:
        supp = self.fe(support_imgs)  # (N*K, D), already normalized
        Wb = self.clf_base.weight
        return self.gen(supp, Wb, k_shot, exclude_mask=None)
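
Since `forward_logits` concatenates base and novel logits along the class dimension, an argmax index has to be decoded before it can be compared with a label. A minimal sketch of that decoding follows; the helper `decode_joint_prediction` and the example class ids are hypothetical and not part of the notebook.

```python
def decode_joint_prediction(index, num_base, novel_ids):
    """Map a column index of the concatenated logits [C_base | N] back to a class."""
    if index < num_base:
        return ("base", index)                      # local base-class index
    return ("novel", novel_ids[index - num_base])   # episode-local position -> class id

novel_ids = [87, 91, 95]   # hypothetical class ids of a 3-way episode

print(decode_joint_prediction(10, 64, novel_ids))   # ('base', 10)
print(decode_joint_prediction(65, 64, novel_ids))   # ('novel', 91)
```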

Now we can build our network following the initial settings:

In [None]:
if BACKBONE == "conv":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True)
elif BACKBONE == "cifar_resnet":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True, remove_last_relu=True)
elif BACKBONE == "resnet18":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True, resnet_pretrained=True, remove_last_relu=True)
else:
    raise ValueError(f"Unknown backbone '{BACKBONE}'. Use 'conv', 'resnet18' or 'cifar_resnet'.")

clf = CosineClassifier(in_dim=fe.out_dim, n_classes=len(base_order))
gen = FewShotWeightGenerator(dim=fe.out_dim, num_base=len(base_order), p_dropout=0.5)

model = DFSLwF(fe=fe, clf_base=clf, gen=gen).to(device)

In [None]:
baseline_model = copy.deepcopy(model)

### **2.2 Training**

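Training proceeds in two stages: supervised training on the base classes, then episodic training of the weight generator on top of the frozen feature extractor.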

#### **2.2.1 Stage 1: supervised base training**

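Stage 1 trains the feature extractor and the cosine base classifier with cross-entropy on the base classes. When a validation loader is given, base-class top-1 accuracy is checked every `S1_VAL_EVERY` epochs, training stops early after `S1_PATIENCE` validations without improvement, and the best checkpoint is restored.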

In [None]:
@torch.no_grad()
def evaluate_stage1_base_top1(model: DFSLwF, loader: DataLoader, device) -> float:
    model.fe.eval(); model.clf_base.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        logits = model.clf_base(model.fe(xb))
        pred = logits.argmax(dim=1)
        correct += (pred == yb).sum().item()
        total   += yb.numel()
    return correct / max(1, total)

In [None]:
def train_stage1(model: DFSLwF, loader: DataLoader, device: torch.device,
                 epochs: int = STAGE1_EPOCHS, lr: float = STAGE1_LR,
                 weight_decay: float = STAGE1_WEIGHT_DECAY,
                 val_loader: Optional[DataLoader] = None):

    # [DFSLwF] Train feature extractor + base classifier (cosine)
    model.fe.train(); model.clf_base.train(); model.gen.eval()

    # Keep BN learnable here (paper trains a standard classifier in Stage-1)
    params    = list(model.fe.parameters()) + list(model.clf_base.parameters())
    opt       = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Params for validation
    best_val = -1.0
    patience = 0
    best_state = None

    with tqdm(range(epochs), desc="[Stage1] Supervised base training") as epbar:
        for ep in epbar:
            model.fe.train(); model.clf_base.train()
            batch_losses = []

            for i, (xb, yb) in enumerate(loader):
                xb = xb.to(device); yb = yb.to(device)
                opt.zero_grad()
                logits = model.clf_base(model.fe(xb))
                loss   = criterion(logits, yb)
                loss.backward()
                opt.step()
                batch_losses.append(float(loss.item()))

                if (i + 1) % S1_LOG_EVERY == 0:
                    epbar.set_postfix(loss=f"{sliding_avg(batch_losses, S1_LOG_EVERY):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((ep + 1) % S1_VAL_EVERY == 0):
                val_acc = evaluate_stage1_base_top1(model, val_loader, device)
                epbar.write(f"\t[stage1/val] epoch {ep+1:03d}: acc_base={100*val_acc:.2f}%")

                if val_acc > best_val + 1e-6:
                    best_val   = val_acc
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "epoch": ep + 1,
                        "val":   best_val,
                    }
                else:
                    patience += 1
                    if patience >= S1_PATIENCE:
                        epbar.write(f"[stage1] Early stopping (no improvement for {S1_PATIENCE} validations).")
                        break

    # Restore best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])

In [None]:
model = copy.deepcopy(baseline_model)
train_stage1(model, train_loader_s1, device=device, val_loader=val_loader_s1)

[Stage1] Supervised base training:   2%|▏         | 2/120 [00:36<36:11, 18.40s/it, loss=3.2044]

	[stage1/val] epoch 002: acc_base=12.91%


[Stage1] Supervised base training:   3%|▎         | 4/120 [01:13<35:44, 18.48s/it, loss=2.5435]

	[stage1/val] epoch 004: acc_base=19.34%


[Stage1] Supervised base training:   5%|▌         | 6/120 [01:49<34:45, 18.29s/it, loss=2.1755]

	[stage1/val] epoch 006: acc_base=30.34%


[Stage1] Supervised base training:   7%|▋         | 8/120 [02:26<34:10, 18.31s/it, loss=1.8601]

	[stage1/val] epoch 008: acc_base=35.31%


[Stage1] Supervised base training:   8%|▊         | 10/120 [03:02<33:44, 18.40s/it, loss=1.6402]

	[stage1/val] epoch 010: acc_base=45.06%


[Stage1] Supervised base training:  10%|█         | 12/120 [03:40<33:28, 18.60s/it, loss=1.4634]

	[stage1/val] epoch 012: acc_base=35.97%


[Stage1] Supervised base training:  12%|█▏        | 14/120 [04:18<33:47, 19.13s/it, loss=1.3379]

	[stage1/val] epoch 014: acc_base=45.06%


[Stage1] Supervised base training:  13%|█▎        | 16/120 [04:56<32:59, 19.03s/it, loss=1.2211]

	[stage1/val] epoch 016: acc_base=47.94%


[Stage1] Supervised base training:  15%|█▌        | 18/120 [05:34<32:14, 18.96s/it, loss=1.1496]

	[stage1/val] epoch 018: acc_base=32.06%


[Stage1] Supervised base training:  17%|█▋        | 20/120 [06:11<31:42, 19.03s/it, loss=1.0355]

	[stage1/val] epoch 020: acc_base=50.12%


[Stage1] Supervised base training:  18%|█▊        | 22/120 [06:49<30:54, 18.93s/it, loss=0.9855]

	[stage1/val] epoch 022: acc_base=57.66%


[Stage1] Supervised base training:  20%|██        | 24/120 [07:27<30:18, 18.94s/it, loss=0.9150]

	[stage1/val] epoch 024: acc_base=49.62%


[Stage1] Supervised base training:  22%|██▏       | 26/120 [08:04<29:37, 18.91s/it, loss=0.8782]

	[stage1/val] epoch 026: acc_base=53.06%


[Stage1] Supervised base training:  23%|██▎       | 28/120 [08:41<28:50, 18.81s/it, loss=0.8353]

	[stage1/val] epoch 028: acc_base=56.41%


[Stage1] Supervised base training:  25%|██▌       | 30/120 [09:19<28:07, 18.75s/it, loss=0.7878]

	[stage1/val] epoch 030: acc_base=53.84%


[Stage1] Supervised base training:  27%|██▋       | 32/120 [09:56<27:25, 18.70s/it, loss=0.7713]

	[stage1/val] epoch 032: acc_base=63.19%


[Stage1] Supervised base training:  28%|██▊       | 34/120 [10:33<26:43, 18.64s/it, loss=0.7545]

	[stage1/val] epoch 034: acc_base=53.84%


[Stage1] Supervised base training:  30%|███       | 36/120 [11:10<26:05, 18.63s/it, loss=0.7252]

	[stage1/val] epoch 036: acc_base=61.66%


[Stage1] Supervised base training:  32%|███▏      | 38/120 [11:47<25:28, 18.64s/it, loss=0.7003]

	[stage1/val] epoch 038: acc_base=60.09%


[Stage1] Supervised base training:  33%|███▎      | 40/120 [12:24<24:51, 18.65s/it, loss=0.6819]

	[stage1/val] epoch 040: acc_base=62.25%


[Stage1] Supervised base training:  34%|███▍      | 41/120 [13:01<25:06, 19.07s/it, loss=0.6441]

	[stage1/val] epoch 042: acc_base=62.06%
	[stage1] Early stopping (no improvement for 5 validations).
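
The stopping point in the log above can be checked by replaying the patience rule from `train_stage1` on the logged validation accuracies. The sketch below is plain Python; `replay_early_stopping` is an illustrative helper, and the values `val_every=2` and `patience_limit=5` are read off the log rather than from the configuration cell.

```python
def replay_early_stopping(val_accs, val_every, patience_limit, min_delta=1e-6):
    """Return (best_epoch, best_acc, stop_epoch); stop_epoch is None if never triggered."""
    best_acc, best_epoch, patience = -1.0, None, 0
    for i, acc in enumerate(val_accs):
        epoch = (i + 1) * val_every
        if acc > best_acc + min_delta:
            # New best: remember it and reset the patience counter.
            best_acc, best_epoch, patience = acc, epoch, 0
        else:
            patience += 1
            if patience >= patience_limit:
                return best_epoch, best_acc, epoch
    return best_epoch, best_acc, None

# Validation accuracies (in %) copied from the Stage-1 log, epochs 2, 4, ..., 42.
val_accs = [12.91, 19.34, 30.34, 35.31, 45.06, 35.97, 45.06, 47.94, 32.06, 50.12, 57.66,
            49.62, 53.06, 56.41, 53.84, 63.19, 53.84, 61.66, 60.09, 62.25, 62.06]

best_epoch, best_acc, stop_epoch = replay_early_stopping(val_accs, val_every=2, patience_limit=5)
print(best_epoch, best_acc, stop_epoch)   # 32 63.19 42
```

Under these assumptions the checkpoint restored at the end of Stage 1 is the one from epoch 32 (63.19% validation accuracy), not the last one trained.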





In [None]:
stage1_model = copy.deepcopy(model)

#### **2.2.2 Stage 2: episodic training**

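Stage 2 freezes the feature extractor and trains the weight generator, together with the base weights and τ, on episodes. In each episode the weights of the episode's novel classes are generated from their support set; any of those classes that is also a base class (a pseudo-novel class) is masked out of the generator's attention. Validation runs every `val_every` steps, and the checkpoint with the best `S2_SELECT_METRIC` is restored at the end.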

In [None]:
@torch.no_grad()
def evaluate_gfsl(model: DFSLwF, test_loader: DataLoader, cifar_targets_all: List[int],
                  base_classes: List[int], device: torch.device) -> Tuple[int, dict]:
    """Return (number of episodes, stats from gfsl_stats); `cifar_targets_all` is currently unused."""
    model.fe.eval(); model.clf_base.eval(); model.gen.eval()
    acc_per_episode_base, acc_per_episode_novel = [], []
    base_order = sorted(base_classes)
    b2local = {cid: i for i, cid in enumerate(base_order)}
    Cb = model.clf_base.weight.size(0)

    for (support_novel, query_images, true_novel_ids, gt_novel, gt_base) in test_loader:
        support_novel = support_novel.to(device)
        query_images = query_images.to(device)
        gt_novel = gt_novel.to(device)
        gt_base = gt_base.to(device)

        novel_weights = model.build_novel_weights(support_novel, K_SHOT)  # (N, D)
        logits = model.forward_logits(query_images, novel_weights)
        preds = logits.argmax(dim=1)

        N = novel_weights.size(0)
        pred_novel = preds[:N * Q_NOVEL] - Cb
        pred_base = preds[N * Q_NOVEL:]

        id2local = {cid: i for i, cid in enumerate(true_novel_ids)}
        gt_novel_local = torch.tensor([id2local[int(y.item())] for y in gt_novel], device=device)
        gt_base_local = torch.tensor([b2local[int(y.item())] for y in gt_base], device=device)

        acc_b = (pred_base == gt_base_local).float().mean().item()
        acc_n = (pred_novel == gt_novel_local).float().mean().item()
        acc_per_episode_base.append(acc_b); acc_per_episode_novel.append(acc_n)

    return len(acc_per_episode_base), gfsl_stats(acc_per_episode_base, acc_per_episode_novel)

In [None]:
def train_stage2(model: DFSLwF, meta_loader: DataLoader, device: torch.device,
                 base_order: List[int], n_tasks: int = STAGE2_TASKS, val_every: int = STAGE2_VAL_EVERY,
                 lr: float = STAGE2_LR, val_loader: Optional[DataLoader] = None):
    """
    [DFSLwF] Freeze the feature extractor F; train the generator and keep training W_base (and τ).
    Pseudo-novel classes are excluded from the attention memory.
    Base queries use their *true* base labels; novel query targets are the indices Cb .. Cb+N-1.
    Note: `n_tasks` is currently unused; the number of steps is len(meta_loader).
    """
    # Freeze feature extractor; freeze BN stats
    model.fe.eval(); model.fe.apply(set_bn_eval)
    for p in model.fe.parameters(): p.requires_grad = False

    # Train generator, base weights, and tau
    for p in model.clf_base.parameters(): p.requires_grad = True
    params = list(model.gen.parameters()) + [model.clf_base.tau, model.clf_base.weight]
    opt = torch.optim.Adam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # mapping CIFAR id -> local base index
    b2local = {cid: i for i, cid in enumerate(sorted(base_order))}
    Cb = model.clf_base.weight.size(0)

    # Validation params
    best_score = -1.0
    patience   = 0
    best_state = None

    model.gen.train(); model.clf_base.train()
    with tqdm(enumerate(meta_loader), total=len(meta_loader), desc="[Stage2] Episodic Training") as pbar:
        for step, (support_novel, query_images, true_novel_ids, base_q_labels_cifar) in pbar:
            support_novel       = support_novel.to(device)
            query_images        = query_images.to(device)
            base_q_labels_cifar = base_q_labels_cifar.to(device)

            # Mask: exclude pseudo-novel classes from the attention
            exclude_mask = torch.ones(Cb, dtype=torch.bool, device=device)
            for cid in true_novel_ids:
                if cid in b2local:
                    exclude_mask[b2local[cid]] = False

            # Novel weights (support -> gen); the FE is frozen
            with torch.no_grad():
                supp_feats = model.fe(support_novel)
            novel_weights = model.gen(supp_feats, model.clf_base.weight, K_SHOT, exclude_mask=exclude_mask)

            # Logits and targets
            logits = model.forward_logits(query_images, novel_weights)  # [Cb | N]
            N = novel_weights.size(0)

            y_novel = torch.arange(N, device=device).repeat_interleave(Q_NOVEL)
            targets = torch.empty(logits.size(0), dtype=torch.long, device=device)
            targets[:N * Q_NOVEL] = Cb + y_novel
            base_local = torch.tensor([b2local[int(y.item())] for y in base_q_labels_cifar], device=device)
            targets[N * Q_NOVEL:] = base_local

            opt.zero_grad()
            loss = criterion(logits, targets)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(params, STAGE2_GRAD_CLIP)
            opt.step()
            pbar.set_postfix(loss=f"{float(loss.item()):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((step + 1) % val_every == 0):
                Tval, vstats = evaluate_gfsl(model, val_loader, ds_train.targets, base_order, device)
                v_base  = vstats["base"]["mean"];  v_base_ci  = vstats["base"]["conf"]
                v_novel = vstats["novel"]["mean"]; v_novel_ci = vstats["novel"]["conf"]
                v_h     = vstats["hmean"]["mean"]; v_h_ci     = vstats["hmean"]["conf"]

                pbar.write(f"\t[stage2/val] step {step+1:05d}: "
                          f"base={100*v_base:.2f}%±{100*v_base_ci:.2f}  "
                          f"novel={100*v_novel:.2f}%±{100*v_novel_ci:.2f}  "
                          f"h-mean={100*v_h:.2f}%±{100*v_h_ci:.2f}  (T={Tval})")

                sel = {"base": v_base, "novel": v_novel, "hmean": v_h}[S2_SELECT_METRIC]
                if sel > best_score + 1e-6:
                    best_score = sel
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "step":  step + 1,
                        "score": best_score,
                    }
                else:
                    patience += 1
                    if patience >= S2_PATIENCE:
                        pbar.write(f"[stage2] Early stopping (no improvement for {S2_PATIENCE} validations).")
                        break

                model.gen.train(); model.clf_base.train()

    # Restore the best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])
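
The target construction inside the Stage-2 loop is easy to get wrong, so here is a tiny plain-Python sketch of the same layout. The helper `build_episode_targets` is hypothetical; it assumes, as the `repeat_interleave` call in `train_stage2` implies, that novel queries come first and are ordered class by class, followed by the base queries.

```python
def build_episode_targets(num_base, n_way, q_novel, base_local_labels):
    """Targets for logits laid out as [C_base | N]: novel queries first, base queries after."""
    # List equivalent of torch.arange(n_way).repeat_interleave(q_novel), shifted by num_base.
    novel_targets = [num_base + c for c in range(n_way) for _ in range(q_novel)]
    return novel_targets + list(base_local_labels)

targets = build_episode_targets(num_base=64, n_way=3, q_novel=2, base_local_labels=[5, 40, 12])
print(targets)   # [64, 64, 65, 65, 66, 66, 5, 40, 12]
```

Novel targets therefore never collide with base targets: they always occupy the last `n_way` columns of the logits.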

In [None]:
model = copy.deepcopy(stage1_model)
train_stage2(model, train_loader_s2, device=device, base_order=base_order, val_loader=val_loader_s2)

[Stage2] Episodic Training:   3%|▎         | 502/20000 [01:32<37:31:25,  6.93s/it, loss=0.9299]

	[stage2/val] step 00500: base=47.98%±0.38  novel=72.39%±0.52  h-mean=57.26%±0.31  (T=1000)


[Stage2] Episodic Training:   5%|▌         | 1003/20000 [03:05<31:42:53,  6.01s/it, loss=0.6792]

	[stage2/val] step 01000: base=47.37%±0.38  novel=73.81%±0.50  h-mean=57.28%±0.32  (T=1000)


[Stage2] Episodic Training:   8%|▊         | 1502/20000 [04:37<30:48:50,  6.00s/it, loss=0.7881]

	[stage2/val] step 01500: base=47.01%±0.39  novel=74.52%±0.52  h-mean=57.20%±0.33  (T=1000)


[Stage2] Episodic Training:  10%|█         | 2003/20000 [06:10<29:27:38,  5.89s/it, loss=0.8059]

	[stage2/val] step 02000: base=47.03%±0.35  novel=74.77%±0.51  h-mean=57.37%±0.31  (T=1000)


[Stage2] Episodic Training:  13%|█▎        | 2502/20000 [07:40<29:32:09,  6.08s/it, loss=0.8211]

	[stage2/val] step 02500: base=46.98%±0.38  novel=75.33%±0.51  h-mean=57.44%±0.32  (T=1000)


[Stage2] Episodic Training:  15%|█▌        | 3002/20000 [09:10<28:28:10,  6.03s/it, loss=0.8468]

	[stage2/val] step 03000: base=46.83%±0.38  novel=74.90%±0.51  h-mean=57.20%±0.32  (T=1000)


[Stage2] Episodic Training:  18%|█▊        | 3503/20000 [10:48<29:20:39,  6.40s/it, loss=0.8594]

	[stage2/val] step 03500: base=46.86%±0.37  novel=75.61%±0.50  h-mean=57.45%±0.31  (T=1000)


[Stage2] Episodic Training:  20%|██        | 4002/20000 [12:34<30:46:59,  6.93s/it, loss=0.7984]

	[stage2/val] step 04000: base=46.54%±0.38  novel=75.35%±0.50  h-mean=57.12%±0.32  (T=1000)


[Stage2] Episodic Training:  23%|██▎       | 4502/20000 [14:10<29:33:17,  6.87s/it, loss=0.7996]

	[stage2/val] step 04500: base=47.38%±0.39  novel=74.88%±0.51  h-mean=57.57%±0.32  (T=1000)


[Stage2] Episodic Training:  25%|██▌       | 5003/20000 [15:44<26:09:56,  6.28s/it, loss=0.7873]

	[stage2/val] step 05000: base=47.52%±0.38  novel=74.71%±0.51  h-mean=57.66%±0.32  (T=1000)


[Stage2] Episodic Training:  28%|██▊       | 5503/20000 [17:21<23:25:10,  5.82s/it, loss=0.7184]

	[stage2/val] step 05500: base=47.81%±0.37  novel=74.71%±0.52  h-mean=57.88%±0.31  (T=1000)


[Stage2] Episodic Training:  30%|███       | 6003/20000 [18:51<22:55:24,  5.90s/it, loss=0.7107]

	[stage2/val] step 06000: base=47.81%±0.38  novel=75.24%±0.51  h-mean=58.05%±0.32  (T=1000)


[Stage2] Episodic Training:  33%|███▎      | 6502/20000 [20:22<23:02:21,  6.14s/it, loss=0.6854]

	[stage2/val] step 06500: base=48.09%±0.39  novel=75.18%±0.53  h-mean=58.21%±0.33  (T=1000)


[Stage2] Episodic Training:  35%|███▌      | 7003/20000 [21:52<19:11:12,  5.31s/it, loss=0.6599]

	[stage2/val] step 07000: base=47.56%±0.39  novel=75.01%±0.52  h-mean=57.74%±0.32  (T=1000)


[Stage2] Episodic Training:  38%|███▊      | 7503/20000 [23:23<20:13:13,  5.82s/it, loss=0.7589]

	[stage2/val] step 07500: base=47.00%±0.39  novel=75.37%±0.50  h-mean=57.46%±0.33  (T=1000)


[Stage2] Episodic Training:  40%|████      | 8003/20000 [24:54<17:47:43,  5.34s/it, loss=0.7400]

	[stage2/val] step 08000: base=47.62%±0.37  novel=74.94%±0.49  h-mean=57.86%±0.31  (T=1000)


[Stage2] Episodic Training:  43%|████▎     | 8502/20000 [26:25<21:31:23,  6.74s/it, loss=0.8193]

	[stage2/val] step 08500: base=48.09%±0.38  novel=75.09%±0.50  h-mean=58.21%±0.32  (T=1000)


[Stage2] Episodic Training:  45%|████▌     | 9002/20000 [27:55<20:40:33,  6.77s/it, loss=0.8550]

	[stage2/val] step 09000: base=48.49%±0.38  novel=75.03%±0.48  h-mean=58.50%±0.31  (T=1000)


[Stage2] Episodic Training:  48%|████▊     | 9502/20000 [29:25<17:43:55,  6.08s/it, loss=0.6644]

	[stage2/val] step 09500: base=48.21%±0.38  novel=75.41%±0.52  h-mean=58.38%±0.32  (T=1000)


[Stage2] Episodic Training:  50%|█████     | 10002/20000 [30:55<16:40:15,  6.00s/it, loss=0.6931]

	[stage2/val] step 10000: base=47.86%±0.37  novel=75.93%±0.52  h-mean=58.28%±0.31  (T=1000)


[Stage2] Episodic Training:  53%|█████▎    | 10502/20000 [32:25<16:11:31,  6.14s/it, loss=0.8324]

	[stage2/val] step 10500: base=48.72%±0.36  novel=74.75%±0.51  h-mean=58.59%±0.31  (T=1000)


[Stage2] Episodic Training:  55%|█████▌    | 11003/20000 [33:56<13:23:33,  5.36s/it, loss=0.5748]

	[stage2/val] step 11000: base=48.53%±0.38  novel=74.77%±0.51  h-mean=58.44%±0.32  (T=1000)


[Stage2] Episodic Training:  58%|█████▊    | 11502/20000 [35:25<14:25:20,  6.11s/it, loss=0.7183]

	[stage2/val] step 11500: base=48.58%±0.38  novel=74.98%±0.51  h-mean=58.51%±0.31  (T=1000)


[Stage2] Episodic Training:  60%|██████    | 12002/20000 [36:55<13:32:00,  6.09s/it, loss=0.5841]

	[stage2/val] step 12000: base=48.41%±0.38  novel=74.92%±0.49  h-mean=58.41%±0.32  (T=1000)


[Stage2] Episodic Training:  63%|██████▎   | 12502/20000 [38:28<13:21:54,  6.42s/it, loss=0.6418]

	[stage2/val] step 12500: base=48.59%±0.39  novel=75.53%±0.50  h-mean=58.70%±0.32  (T=1000)


[Stage2] Episodic Training:  65%|██████▌   | 13003/20000 [40:00<10:39:05,  5.48s/it, loss=0.7062]

	[stage2/val] step 13000: base=48.93%±0.39  novel=75.07%±0.51  h-mean=58.78%±0.31  (T=1000)


[Stage2] Episodic Training:  68%|██████▊   | 13502/20000 [41:30<12:10:17,  6.74s/it, loss=0.6671]

	[stage2/val] step 13500: base=48.70%±0.38  novel=74.49%±0.51  h-mean=58.44%±0.30  (T=1000)


[Stage2] Episodic Training:  70%|███████   | 14002/20000 [42:59<10:05:16,  6.05s/it, loss=0.5749]

	[stage2/val] step 14000: base=49.00%±0.37  novel=74.23%±0.51  h-mean=58.63%±0.31  (T=1000)


[Stage2] Episodic Training:  73%|███████▎  | 14502/20000 [44:29<9:20:44,  6.12s/it, loss=0.6961] 

	[stage2/val] step 14500: base=49.02%±0.38  novel=75.15%±0.50  h-mean=58.91%±0.31  (T=1000)


[Stage2] Episodic Training:  75%|███████▌  | 15002/20000 [45:59<9:24:10,  6.77s/it, loss=0.6869]

	[stage2/val] step 15000: base=49.37%±0.37  novel=74.86%±0.52  h-mean=59.10%±0.31  (T=1000)


[Stage2] Episodic Training:  78%|███████▊  | 15502/20000 [47:28<8:26:16,  6.75s/it, loss=0.5240]

	[stage2/val] step 15500: base=49.56%±0.37  novel=74.81%±0.52  h-mean=59.21%±0.32  (T=1000)


[Stage2] Episodic Training:  80%|████████  | 16002/20000 [48:57<6:43:32,  6.06s/it, loss=0.5591]

	[stage2/val] step 16000: base=49.37%±0.38  novel=74.59%±0.51  h-mean=58.98%±0.31  (T=1000)


[Stage2] Episodic Training:  83%|████████▎ | 16503/20000 [50:30<5:57:31,  6.13s/it, loss=0.5547] 

	[stage2/val] step 16500: base=49.65%±0.37  novel=75.05%±0.50  h-mean=59.36%±0.31  (T=1000)


[Stage2] Episodic Training:  85%|████████▌ | 17002/20000 [52:05<6:10:02,  7.41s/it, loss=0.5431]

	[stage2/val] step 17000: base=49.45%±0.36  novel=74.81%±0.50  h-mean=59.16%±0.30  (T=1000)


[Stage2] Episodic Training:  88%|████████▊ | 17502/20000 [53:40<5:07:23,  7.38s/it, loss=0.5720]

	[stage2/val] step 17500: base=48.93%±0.37  novel=74.53%±0.52  h-mean=58.66%±0.32  (T=1000)


[Stage2] Episodic Training:  90%|█████████ | 18003/20000 [55:15<2:56:14,  5.30s/it, loss=0.6514]

	[stage2/val] step 18000: base=49.26%±0.37  novel=74.45%±0.52  h-mean=58.89%±0.32  (T=1000)


[Stage2] Episodic Training:  93%|█████████▎| 18503/20000 [56:45<2:07:49,  5.12s/it, loss=0.7167]

	[stage2/val] step 18500: base=49.51%±0.38  novel=74.49%±0.51  h-mean=59.07%±0.32  (T=1000)


[Stage2] Episodic Training:  95%|█████████▌| 19002/20000 [58:14<1:39:09,  5.96s/it, loss=0.5008]

	[stage2/val] step 19000: base=49.95%±0.37  novel=74.40%±0.51  h-mean=59.37%±0.31  (T=1000)


[Stage2] Episodic Training:  98%|█████████▊| 19502/20000 [59:43<55:28,  6.68s/it, loss=0.6629]

	[stage2/val] step 19500: base=49.70%±0.38  novel=74.49%±0.51  h-mean=59.22%±0.32  (T=1000)


[Stage2] Episodic Training: 100%|██████████| 20000/20000 [1:01:12<00:00,  5.45it/s, loss=0.6459]

	[stage2/val] step 20000: base=49.89%±0.37  novel=74.80%±0.52  h-mean=59.46%±0.32  (T=1000)





### **2.3 Evaluation**

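We evaluate on test episodes and report base accuracy, novel accuracy, and their harmonic mean with 95% confidence intervals, first for the Stage-1 checkpoint and then for the model after Stage 2.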

In [None]:
T, stats = evaluate_gfsl(stage1_model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 1 training")

[test] - After stage 1 training (95% CI on 1000 tasks)
 - [Base]   acc=57.54% ± 0.36%
 - [Novel]  acc=51.71% ± 0.51%
 - [H-mean] acc=53.97% ± 0.33%


In [None]:
T, stats = evaluate_gfsl(model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 2 training")

[test] - After stage 2 training (95% CI on 1000 tasks)
 - [Base]   acc=49.93% ± 0.38%
 - [Novel]  acc=71.31% ± 0.53%
 - [H-mean] acc=58.28% ± 0.31%
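
Note that the reported H-mean (58.28%) is slightly lower than the harmonic mean of the reported base and novel means (about 58.73%). This is consistent with the H-mean being computed per episode and then averaged, although `gfsl_stats` is defined elsewhere in the notebook, so that is an assumption. The sketch below uses made-up per-episode accuracies and a normal-approximation confidence interval (also an assumption) to show the difference between the two ways of aggregating.

```python
import math
import statistics

def harmonic_mean(a, b):
    """Harmonic mean of two accuracies; 0.0 if both are zero."""
    return 0.0 if (a + b) == 0 else 2 * a * b / (a + b)

def mean_and_ci95(values):
    """Mean and half-width of a normal-approximation 95% confidence interval."""
    mean = statistics.fmean(values)
    if len(values) < 2:
        return mean, 0.0
    half_width = 1.96 * statistics.stdev(values) / math.sqrt(len(values))
    return mean, half_width

# Toy per-episode accuracies (made-up numbers, not results from this notebook).
base_acc = [0.40, 0.60, 0.50, 0.50]
novel_acc = [0.80, 0.60, 0.70, 0.75]

# Option 1: harmonic mean per episode, then average over episodes.
h_per_episode = [harmonic_mean(b, n) for b, n in zip(base_acc, novel_acc)]
h_mean, h_ci = mean_and_ci95(h_per_episode)

# Option 2: harmonic mean of the averaged accuracies.
h_of_means = harmonic_mean(statistics.fmean(base_acc), statistics.fmean(novel_acc))

print(f"per-episode H-mean: {h_mean:.4f} ± {h_ci:.4f}, H-mean of means: {h_of_means:.4f}")
```

The per-episode average can never exceed the harmonic mean of the averages, so the two numbers should not be mixed when comparing against other reported results.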
