# **Generalized Few-Shot Learning (GFSL)**

This demo implements **Dynamic Few-Shot Learning without Forgetting** (DFSLwF) on the **CIFAR-100** dataset under the *generalized setting*, where both base and novel classes are present at test time. The notebook reproduces the two-stage training pipeline (supervised base training followed by episodic meta-training with a generator) and allows readers to evaluate base, novel, and harmonic accuracies in the **GFSL scenario**.

---

## **1. Background**

### **1.1 Generalized Few-Shot Learning (GFSL)**

In GFSL, the evaluation setup changes: at test time, queries can belong to either  
- **base classes** (those already seen during training), or  
- **novel classes** (new ones introduced only at test time).  

The learner must recognize queries across this **joint label space**, handling both familiar and unseen categories simultaneously. This setting is more realistic than classical FSL, where queries always belong to novel classes only.

> **Example (GFSL 3-way 2-shot):** suppose the model was trained on 64 base classes. At test time, an episode may include queries from both these base classes and 3 novel ones, each with only 2 support examples. The learner must correctly classify across all base + novel classes.

A challenge arises: the model tends to be **biased toward base classes**, since they are represented by many more examples during training.  
Approaches like **Dynamic Few-Shot Learning without Forgetting (DFSLwF)** mitigate this issue by combining a base classifier (trained on abundant data) with a novel classifier (trained from few supports), and balancing their predictions at inference time.

*(Evaluation tip: in GFSL, report separate accuracy on base and novel classes, plus the **harmonic mean (H-mean)** to capture overall balance.)*
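A quick numeric illustration of why the H-mean is reported instead of the arithmetic mean: two hypothetical learners with the same average accuracy (0.5) get very different H-means. `harmonic_mean` is an illustrative helper, not a function of the notebook.

```python
def harmonic_mean(acc_base: float, acc_novel: float) -> float:
    """Harmonic mean of base and novel accuracy (defined as 0 when both are 0)."""
    denom = acc_base + acc_novel
    return 0.0 if denom == 0 else 2.0 * acc_base * acc_novel / denom


# Same arithmetic mean (0.5), very different balance
balanced = harmonic_mean(0.5, 0.5)    # 0.5
unbalanced = harmonic_mean(0.9, 0.1)  # about 0.18: the imbalance is penalized
print(balanced, unbalanced)
```

A learner that ignores novel classes (novel accuracy near 0) gets an H-mean near 0, however high its base accuracy is.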

> **Practical scenario:** GFSL is valuable in real-world systems (e.g., an image recognition app) where the model must keep performing well on frequent categories like “dogs” or “cars” while also learning to recognize new, rare categories from just a handful of examples.


### **1.2 Dynamic Few-Shot Learning without Forgetting**

Let $F_\theta(\cdot)$ be a ConvNet feature extractor.  
A classifier is built using a set of **weight vectors** $W = \{ w_k \}$, one per class, and computes scores via a **cosine similarity**:

$$
s_k(x) \;=\; \tau \cdot \cos\!\big( F_\theta(x), w_k \big),
$$

where both features and weights are $\ell_2$-normalized, and $\tau$ is a learnable scale.  
This design ensures that base and novel categories are treated in a **unified space**.
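The scoring rule above can be sketched as a small PyTorch module. This is a minimal illustration, not the notebook's classifier: the class name `CosineClassifier`, the random weight initialization, and the optional `weight` argument (to score against an external matrix such as base plus novel weights) are our assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineClassifier(nn.Module):
    """Illustrative cosine classifier: s_k(x) = tau * cos(F(x), w_k)."""

    def __init__(self, feat_dim: int, n_classes: int, tau_init: float = 10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_classes, feat_dim) * 0.01)  # one w_k per class
        self.tau = nn.Parameter(torch.tensor(tau_init))                      # learnable scale

    def forward(self, feats: torch.Tensor, weight: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Optionally score against an externally supplied weight matrix
        w = self.weight if weight is None else weight
        feats = F.normalize(feats, p=2, dim=1)  # l2-normalize features
        w = F.normalize(w, p=2, dim=1)          # l2-normalize class weights
        return self.tau * feats @ w.t()         # (B, n_classes), values in [-tau, tau]


clf = CosineClassifier(feat_dim=256, n_classes=64)
scores = clf(torch.randn(8, 256))
print(scores.shape)  # torch.Size([8, 64])
```

Because both sides are normalized, every score is bounded by $\tau$, which is why a learnable scale is needed to obtain peaked softmax outputs.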

To incorporate novel classes at test time, DFSLwF introduces a **few-shot weight generator** $G(\cdot)$:
- Input: a small support set of feature vectors $Z' = \{z'_i\}$ from a novel class, plus the set of base weights $W_\text{base}$,  
- Output: a novel classification weight $w'$ for that class.

The generator combines two mechanisms:
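1. **Feature averaging:**  
   $w'_\text{avg} = \tfrac{1}{N'} \sum_{i=1}^{N'} \bar z'_i$, where $\bar z'_i$ are the $\ell_2$-normalized support features.  
2. **Attention over base weights:**  
   $w'_\text{att} = \tfrac{1}{N'} \sum_{i=1}^{N'} \sum_{b} \mathrm{Att}\big(\phi_q \bar z'_i, k_b\big)\, \bar w_b$: each support feature, projected by a learnable matrix $\phi_q$, attends over a learnable key $k_b$ per base class, and the normalized base weights $\bar w_b$ are summed with the resulting (softmax) attention weights. This exploits prior knowledge about the visual world stored in $W_\text{base}$.  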

The final novel weight is a linear combination:  
$$
w' = \phi_\text{avg} \odot w'_\text{avg} \;+\; \phi_\text{att} \odot w'_\text{att}.
$$
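A minimal PyTorch sketch of the generator equations, for one novel class at a time. The class name `FewShotWeightGenerator`, the initializations ($\phi_\text{avg} = \phi_\text{att} = 1$, attention scale $\gamma = 10$) and the single-class interface are our assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FewShotWeightGenerator(nn.Module):
    """Illustrative generator: w' = phi_avg * w_avg + phi_att * w_att."""

    def __init__(self, feat_dim: int, n_base: int, gamma_init: float = 10.0):
        super().__init__()
        self.phi_avg = nn.Parameter(torch.ones(feat_dim))                # element-wise scale of w'_avg
        self.phi_att = nn.Parameter(torch.ones(feat_dim))                # element-wise scale of w'_att
        self.phi_q = nn.Linear(feat_dim, feat_dim, bias=False)           # query projection phi_q
        self.keys = nn.Parameter(torch.randn(n_base, feat_dim) * 0.01)   # one key k_b per base class
        self.gamma = nn.Parameter(torch.tensor(gamma_init))              # attention scale (assumed)

    def forward(self, z_support: torch.Tensor, w_base: torch.Tensor) -> torch.Tensor:
        # z_support: (N', d) support features of ONE novel class; w_base: (n_base, d)
        z = F.normalize(z_support, p=2, dim=1)
        w_avg = z.mean(dim=0)                                        # feature averaging
        q = F.normalize(self.phi_q(z), p=2, dim=1)                   # (N', d)
        k = F.normalize(self.keys, p=2, dim=1)                       # (n_base, d)
        att = torch.softmax(self.gamma * q @ k.t(), dim=1)           # (N', n_base), rows sum to 1
        w_att = (att @ F.normalize(w_base, p=2, dim=1)).mean(dim=0)  # (d,)
        return self.phi_avg * w_avg + self.phi_att * w_att


gen = FewShotWeightGenerator(feat_dim=256, n_base=64)
w_novel = gen(torch.randn(5, 256), torch.randn(64, 256))  # 5-shot support set
print(w_novel.shape)  # torch.Size([256])
```

Setting $\phi_\text{att} = 0$ reduces the generator to plain (normalized) prototype averaging, which is a useful sanity check.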

At inference, the classifier uses  
$$
W^* = W_\text{base} \cup W_\text{novel},
$$  
thus predicting across **both base and novel classes** without retraining.
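The union $W^*$ amounts to stacking the two weight matrices, so a single cosine scoring pass yields logits over the joint label space. A sketch with random stand-in weights and features (all values here are illustrative, not trained):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_base, n_way, d, tau = 64, 5, 256, 10.0
w_base = torch.randn(n_base, d)   # stand-in for learned base weights
w_novel = torch.randn(n_way, d)   # stand-in for generator outputs

# W* = W_base ∪ W_novel: columns [0, n_base) are base, [n_base, n_base + n_way) are novel
w_star = torch.cat([w_base, w_novel], dim=0)

queries = torch.randn(12, d)      # stand-in query features F(x)
queries[0] = 3.0 * w_novel[2]     # a query aligned with novel class 2 (cosine ignores the scale)

logits = tau * F.normalize(queries, p=2, dim=1) @ F.normalize(w_star, p=2, dim=1).t()
pred = logits.argmax(dim=1)       # indices in the joint label space
is_novel_pred = pred >= n_base    # True where a query is assigned to a novel class
print(pred[0].item())             # 66 = n_base + 2
```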

**Training procedure (two stages):**
1. **Stage 1:** Train feature extractor + base classifier on abundant base data.  
2. **Stage 2:** Train the weight generator using “fake” novel tasks sampled from base categories (episodic style).  
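In a Stage-2 episode, the sampled pseudo-novel classes also have base weights. The sketch below assumes that, for the duration of the episode, those rows are removed from $W_\text{base}$ and replaced by the generated weights appended at the end, so the generator cannot rely on the true base weights of the classes it is imitating; this is an assumption of the sketch, and `stage2_episode_weights` is a hypothetical helper. It also returns the map from original base labels to columns of the episode classifier.

```python
import torch


def stage2_episode_weights(w_base: torch.Tensor, pseudo_novel: torch.Tensor, w_gen: torch.Tensor):
    """Hypothetical helper: build the episode classifier for one Stage-2 task.

    Rows of W_base for the pseudo-novel classes are dropped (assumption) and the
    generated weights are appended. Returns (episode weights, label remap), where
    remap[original_base_label] is the column of that class in the episode classifier.
    """
    n_base = w_base.shape[0]
    keep = torch.ones(n_base, dtype=torch.bool)
    keep[pseudo_novel] = False
    n_keep = int(keep.sum())

    w_episode = torch.cat([w_base[keep], w_gen], dim=0)

    remap = torch.empty(n_base, dtype=torch.long)
    remap[keep] = torch.arange(n_keep)                              # remaining base classes first
    remap[pseudo_novel] = n_keep + torch.arange(len(pseudo_novel))  # pseudo-novel -> appended rows
    return w_episode, remap


w_episode, remap = stage2_episode_weights(
    w_base=torch.randn(6, 4), pseudo_novel=torch.tensor([1, 4]), w_gen=torch.randn(2, 4)
)
print(remap.tolist())  # [0, 4, 1, 2, 5, 3]
```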

**Notes (implementation):**
- Removing the final ReLU in $F_\theta$ lets features take negative values, so cosine similarities are not confined to $[0, 1]$; this helps cosine-similarity classification.  
- The $\ell_2$-normalization enforces compact clusters, improving generalization to novel classes.  
- Evaluation reports accuracy on **Base**, **Novel**, and the **harmonic mean (H-mean)**.  
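A quick numerical check of the ReLU note, using random vectors as stand-ins for real features: after a ReLU, all features lie in the non-negative orthant, so the cosine between any two of them is never negative, while unrectified features use the full angular range.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
raw = torch.randn(100, 64)  # stand-in pre-activation features
with_relu = F.relu(raw)     # features restricted to the non-negative orthant


def pairwise_cos(x: torch.Tensor) -> torch.Tensor:
    """Matrix of pairwise cosine similarities between the rows of x."""
    x = F.normalize(x, p=2, dim=1)
    return x @ x.t()


cos_relu = pairwise_cos(with_relu)
cos_raw = pairwise_cos(raw)
print(cos_relu.min().item(), cos_raw.min().item())  # >= 0 vs. clearly negative
```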

> **Why it matters:** DFSLwF directly addresses the **GFSL setting**: real systems must keep high performance on frequent base categories (e.g., “dogs”, “cars”) while adapting dynamically to rare, unseen ones (e.g., “drone types”) from few examples, without retraining or forgetting.


### **1.3 CIFAR-100 dataset**
We use [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html), a widely adopted benchmark for image classification.  
It consists of **100 classes** (e.g., animals, vehicles, household objects), each containing **600 color images** of size $32 \times 32$.  
For every class, there are **500 training images** and **100 test images**.

CIFAR-100 is more challenging than Omniglot or MNIST-like datasets, since it involves **natural RGB images** with high intra-class variability at a **low resolution** ($32 \times 32$), which leaves little visual detail per object.

It's included in the `torchvision` package, making it straightforward to download and preprocess for few-shot or generalized few-shot experiments.


---

## **2. Practice**

In [None]:
from __future__ import annotations
import math
import random
import types
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import copy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Seed:

In [None]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_rng = np.random.default_rng(SEED + 1)
test_rng = np.random.default_rng(SEED + 2)

Settings:

In [None]:
# Few-shot episode configuration
N_WAY = 5           # Number of classes per task
K_SHOT = 5          # Support examples per class
Q_NOVEL = 15        # Query examples per novel class
Q_BASE_TOTAL = 75   # Total number of base-class query examples per episode

# CIFAR-100 split sizes
N_BASE = 64         # Number of classes retained as base
N_VALNOVEL = 16     # Number of classes retained as novel (for validation)
N_TESTNOVEL = 20    # Number of classes retained as novel (for test)

# Network params
TAU_INIT = 10.0             # Initial value of the learnable cosine scale tau
BACKBONE = "cifar_resnet"   # Feature extractor network
assert BACKBONE in ["resnet18", "conv", "cifar_resnet"], f"Unknown backbone '{BACKBONE}'. Use 'conv', 'resnet18' or 'cifar_resnet'."

# Stage 1
STAGE1_EPOCHS = 120
STAGE1_LR = 3e-3
STAGE1_BS = 512
STAGE1_WEIGHT_DECAY = 5e-4
S1_VAL_FRAC   = 0.10   # Fraction of examples of base classes retained for validation
S1_VAL_EVERY  = 2      # Validate every N epochs
S1_PATIENCE   = 5      # Early stopping after N validations without improvement
S1_LOG_EVERY  = 50     # Log loss every N batches

# Stage 2
STAGE2_TASKS = 20_000         # Number of training episodes
STAGE2_LR = 5e-4
STAGE2_GRAD_CLIP  = 1.0
S2_VAL_TASKS      = 1_000     # Number of validation episodes
STAGE2_VAL_EVERY  = 500       # Validate every N episodes
S2_PATIENCE       = 5         # Early stopping after N validations without improvement
S2_SELECT_METRIC  = "hmean"   # Validation metric used for model selection
assert S2_SELECT_METRIC in ["hmean", "base", "novel"], f"Unknown metric '{S2_SELECT_METRIC}'. Use 'hmean', 'base' or 'novel'."

# Test
TEST_TASKS = 1_000    # Number of test episodes

Utility functions:

In [None]:
def split_cifar100_classes(seed: int, n_base=64, n_val=16, n_test=20):
    """Split CIFAR-100 class IDs into base/val-novel/test-novel sets.
    """
    rng = np.random.default_rng(seed)
    classes = np.arange(100); rng.shuffle(classes)
    return classes[:n_base].tolist(), classes[n_base:n_base+n_val].tolist(), classes[n_base+n_val:n_base+n_val+n_test].tolist()


def subset_by_classes(ds, keep):
    """Return a Subset containing only samples whose label is in `keep`.

    Uses vectorized filtering over `ds.targets` to select the indices that
    belong to the provided set of class IDs.
    """
    t = np.array(ds.targets)
    idx = np.nonzero(np.isin(t, keep))[0]
    return Subset(ds, idx)


def class_to_local_indices(subset):
    """Build a mapping class_id -> list of *local* indices within `subset`.

    Iterates over the subset indices and groups them by their original class
    ID (read from `subset.dataset.targets`). Useful for fast episodic sampling
    (e.g., drawing K support and Q query images per class).

    Example:
    ```
    {
        0: [0, 5, 9, ...],
        1: [1, 7, ...],
        ...
    }
    ```
    """
    t = np.array(subset.dataset.targets)
    out = {}
    for j, i in enumerate(subset.indices):
        y = int(t[i])
        (out.setdefault(y, [])).append(j)
    return out


def stratified_split_subset(subset: Subset, val_frac: float, seed: int):
    """
    Split a Subset into (train_part, val_part), preserving the per-class
    proportions of the original CIFAR classes (stratified split).
    """
    rng = np.random.default_rng(seed)
    ds_targets = np.array(subset.dataset.targets)
    # map: class_id -> local indexes list in Subset
    cls2locals = {}
    for j, i in enumerate(subset.indices):
        y = int(ds_targets[i])
        (cls2locals.setdefault(y, [])).append(j)

    train_locals, val_locals = [], []
    for y, locals_ in cls2locals.items():
        locals_ = np.array(locals_, dtype=int)
        n_val = max(1, int(round(len(locals_) * val_frac)))
        rng.shuffle(locals_)
        val_locals.append(locals_[:n_val])
        train_locals.append(locals_[n_val:])

    train_locals = np.concatenate(train_locals).tolist()
    val_locals   = np.concatenate(val_locals).tolist()

    train_indices = [subset.indices[i] for i in train_locals]
    val_indices   = [subset.indices[i] for i in val_locals]

    return Subset(subset.dataset, train_indices), Subset(subset.dataset, val_indices)

In [None]:
def l2_normalize(x: torch.Tensor, dim: int = 1, eps: float = 1e-6) -> torch.Tensor:
    """L2-normalize a tensor along a given dimension."""
    return x / (x.norm(p=2, dim=dim, keepdim=True).clamp_min(eps))

In [None]:
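def set_bn_eval(m: nn.Module):
    """If `m` is a BatchNorm2d layer, put it in eval mode and freeze its affine parameters.

    In eval mode the layer normalizes with its stored running statistics and stops
    updating them; freezing gamma/beta disables their gradient updates. Use it as
    `model.apply(set_bn_eval)` and re-apply it after any `model.train()` call,
    which switches BatchNorm layers back to training mode.
    """
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
        for p in m.parameters():
            p.requires_grad = False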

In [None]:
def sliding_avg(xs: List[float], k: int = 20) -> float:
    """Mean of the last `k` values of `xs` (of all values if fewer than `k`; 0.0 if empty)."""
    if not xs:
        return 0.0
    if len(xs) < k:
        return float(sum(xs) / len(xs))
    return float(sum(xs[-k:]) / k)

In [None]:
def mean_ci95(xs: np.ndarray) -> Tuple[float, float]:
    """Return (mean, half-width of the 95% normal-approximation confidence interval)."""
    xs = np.asarray(xs, dtype=float)
    n  = xs.size
    if n < 2:
        return float(xs.mean()), 0.0
    std = xs.std(ddof=1)
    stderr = std / np.sqrt(n)
    z = 1.96
    return float(xs.mean()), float(z * stderr)


def gfsl_stats(
    acc_per_ep_base: List[float],
    acc_per_ep_novel: List[float],
) -> Dict[str, Dict[str, float]]:
    """Aggregate per-episode Base / Novel accuracies and their per-episode harmonic mean.

    The H-mean is computed episode by episode and then averaged, so its CI reflects
    the variability across episodes.
    """
    if len(acc_per_ep_base) != len(acc_per_ep_novel):
        raise ValueError("base and novel must be of the same length")

    base = np.asarray(acc_per_ep_base, dtype=float)
    novel = np.asarray(acc_per_ep_novel, dtype=float)

    # Per-episode H-mean; episodes with base + novel == 0 get H = 0 without dividing by zero
    denom = base + novel
    h_per_ep = np.divide(2.0 * base * novel, denom, out=np.zeros_like(denom), where=denom > 0)

    base_mean, base_ci  = mean_ci95(base)
    novel_mean, novel_ci = mean_ci95(novel)
    h_mean, h_ci = mean_ci95(h_per_ep)

    return {
        "base":  {"mean": base_mean,  "conf": base_ci},
        "novel": {"mean": novel_mean, "conf": novel_ci},
        "hmean": {"mean": h_mean,     "conf": h_ci},
    }


def print_stats(T: int, stats: Dict[str, Dict[str, float]], model: str = ""):
    """Pretty-print Base / Novel / H-mean accuracies with their 95% CI over T episodes."""
    print(f"[test] - {model} (95% CI on {T} tasks)")
    print(f" - [Base]   acc={100*stats['base']['mean']:.2f}% ± {100*stats['base']['conf']:.2f}%")
    print(f" - [Novel]  acc={100*stats['novel']['mean']:.2f}% ± {100*stats['novel']['conf']:.2f}%")
    print(f" - [H-mean] acc={100*stats['hmean']['mean']:.2f}% ± {100*stats['hmean']['conf']:.2f}%")

### **2.1 Environment**

#### **2.1.1 CIFAR-100 dataset**

We define **data transformations** separately for training and evaluation.  
On **train**, we apply light stochastic augmentation suited to 32x32 images (random crop with padding and horizontal flip) to improve invariances without distorting small objects.  
On **eval**, we keep a deterministic pipeline to ensure consistent measurement.

> Augmentation reduces overfitting of the feature extractor in Stage 1, while a deterministic eval pipeline is crucial for consistent Base / Novel / H-mean measurements in GFSL.

In [None]:
if BACKBONE == "conv" or BACKBONE == "cifar_resnet":
    IMAGE_SIZE = 32

    train_tf = transforms.Compose([
        transforms.RandomCrop(IMAGE_SIZE, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    eval_tf = transforms.Compose([transforms.ToTensor(),])

else:
    IMNET_MEAN = [0.485, 0.456, 0.406]
    IMNET_STD  = [0.229, 0.224, 0.225]
    IM_RESIZE = 128
    IM_CROP   = 112

    train_tf = transforms.Compose([
        transforms.Resize(IM_RESIZE),
        transforms.RandomCrop(IM_CROP),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMNET_MEAN, IMNET_STD),
    ])

    eval_tf = transforms.Compose([
        transforms.Resize(IM_RESIZE),
        transforms.CenterCrop(IM_CROP),
        transforms.ToTensor(),
        transforms.Normalize(IMNET_MEAN, IMNET_STD),
    ])

We instantiate the **GFSL** splits on CIFAR-100 as follows:
- `ds_train` / `ds_test` load the standard CIFAR-100 **training** and **test** partitions with their respective transforms (`train_tf` randomized, `eval_tf` deterministic).
- `split_cifar100_classes(SEED)` yields a **disjoint class split** into:
  - **Base** classes (used for Stage-1 supervised training and as the base pool in GFSL episodes),
  - **Val-Novel** classes (for model selection with GFSL episodes),
  - **Test-Novel** classes (for final GFSL evaluation).

All data live under `./data` and are fetched on demand with `download=True`.

> *Note.* In GFSL, splits must be done **by class**, not by image, to avoid leakage. Label spaces are remapped locally for base vs. novel episodes so that evaluation is over the **joint label space** (Base ∪ Novel) while keeping indices compact inside each block.


In [None]:
ds_train = CIFAR100(root="./data", train=True,  transform=train_tf, download=True)
ds_test  = CIFAR100(root="./data", train=False, transform=eval_tf,  download=True)

base, valn, testn = split_cifar100_classes(SEED)



We now **separate** the base training set into a train and validation split.  
This allows us to monitor Stage-1 supervised training on base classes without leaking novel information.  
We also prepare a distinct set of **val-novel classes** (from the train partition) to be used later for episodic validation in GFSL.

In [None]:
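# Stage 1: Train + validation (base) / Stage 2: Train (base and pseudo-novel) + validation (base) - from train
train_base_full = subset_by_classes(ds_train, base)
train_base_tr, train_base_val = stratified_split_subset(train_base_full, S1_VAL_FRAC, SEED)

# Validation subsets must use the deterministic `eval_tf`, not the augmented `train_tf`:
# re-wrap the same indices over a second view of the CIFAR-100 train partition.
ds_train_eval  = CIFAR100(root="./data", train=True, transform=eval_tf, download=True)
train_base_val = Subset(ds_train_eval, train_base_val.indices)

# Stage 2: Validation (Novel) - from train
train_valnovel = subset_by_classes(ds_train_eval, valn)

# Test (base + novel) - from test
test_base  = subset_by_classes(ds_test,  base)
test_novel = subset_by_classes(ds_test,  testn)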

Then we create the **label maps** for each split:  
- local indices for base-train and base-val (Stage-1),  
- mappings for val-novel classes (Stage-2 validation),  
- and indices for base and novel classes in the final test set.  

These class-to-local-index maps let the episodic samplers draw K support and Q query images per class directly by position inside each `Subset`.
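The distinction between *local* positions (inside a `Subset`) and *global* dataset indices is easy to get wrong, so here is the same grouping logic as `class_to_local_indices` on a toy dataset. `ToyDS` is a hypothetical stand-in that, like `CIFAR100`, only needs a `targets` list.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, Subset


class ToyDS(Dataset):
    """Hypothetical stand-in for CIFAR100: exposes `targets`, returns (position, label)."""

    def __init__(self, targets):
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, i):
        return torch.tensor(float(i)), self.targets[i]


ds = ToyDS([3, 7, 3, 9, 7, 3])
sub = Subset(ds, np.nonzero(np.isin(ds.targets, [3, 9]))[0])  # global indices [0, 2, 3, 5]

# Same grouping as class_to_local_indices: class id -> positions *inside* `sub`
cti = {}
for j, i in enumerate(sub.indices):
    cti.setdefault(int(ds.targets[i]), []).append(j)

print(cti)        # {3: [0, 1, 3], 9: [2]}
print(sub[1][1])  # 3  (local index 1 -> global index 2)
```

Samplers built on these maps therefore yield indices that are valid for the `Subset`, not for the underlying `CIFAR100` object.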

In [None]:
# Stage-1 (supervised)
cti_train_base_tr  = class_to_local_indices(train_base_tr)
cti_train_base_val = class_to_local_indices(train_base_val)

# Stage-2 Validation (episodic GFSL) - from train
cti_val_base   = cti_train_base_val
cti_val_novel  = class_to_local_indices(train_valnovel)

# For test (episodic GFSL)
cti_test_base  = class_to_local_indices(test_base)
cti_test_novel = class_to_local_indices(test_novel)

**2.1.1.1 DataLoader: Stage 1**

Stage 1 trains the backbone and the base classifier on abundant base data, so we build a classic **supervised loader** over **base classes only**, remapping the original CIFAR-100 ids to contiguous labels `0..N_BASE-1`:

In [None]:
class Stage1TrainDS(torch.utils.data.Dataset):
    def __init__(self, subset: Subset, orig_targets: List[int], order_map: Dict[int, int]):
        self.subset = subset
        self.local_labels = [order_map[int(orig_targets[i])] for i in subset.indices]

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, _ = self.subset[idx]
        return x, self.local_labels[idx]


base_order = sorted(base)
order_map  = {cid: i for i, cid in enumerate(base_order)}

stage1_train_ds = Stage1TrainDS(train_base_tr,  ds_train.targets, order_map)  # train
stage1_val_ds   = Stage1TrainDS(train_base_val, ds_train.targets, order_map)  # val (disjoint from train)

We instantiate them:

In [None]:
train_loader_s1 = DataLoader(
    stage1_train_ds, batch_size=STAGE1_BS, shuffle=True,
    num_workers=2, pin_memory=True
)

val_loader_s1   = DataLoader(
    stage1_val_ds, batch_size=STAGE1_BS*2, shuffle=False,
    num_workers=2, pin_memory=True
)

**2.1.1.2 DataLoader: Stage 2 Train**

For Stage-2 training (and also for the validation and test phases), the model must be trained and evaluated on **few-shot tasks** (episodes).
A standard PyTorch `DataLoader` builds mini-batches of images without
considering whether they belong to a support or query set.

Moreover, GFSL episodes differ from classical FSL ones: each episode combines **few-shot novel supports**, **novel queries**, and **base queries**.

To achieve this, we use a **custom task sampler** and a **custom collate function**.

**Collate function**

It splits the batch into:
  1. `support_novel` images (size `n_way * k_shot`)  
  2. `query_images` = `[novel queries | base queries]`  
  3. `true_novel_ids` (original CIFAR-100 ids for the sampled novel classes)  
  4. `base_query_labels_cifar` (original CIFAR-100 ids for base queries)
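The slicing in the collate function relies on the sampler yielding the `k_shot + q_novel` items of each novel class contiguously, followed by the base queries. A toy run of the same reshape logic, where each stand-in "image" is a 1-element tensor holding its position in the batch, makes the layout visible (sizes are illustrative):

```python
import torch

n_way, k_shot, q_novel, q_base_total = 2, 1, 2, 3
per_novel = k_shot + q_novel
total = n_way * per_novel + q_base_total

# Stand-in "images": each one stores its own position in the batch
images = torch.arange(total, dtype=torch.float32).view(total, 1)

novel_block = images[: n_way * per_novel].view(n_way, per_novel, *images.shape[1:])
support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])  # first K of each class
query_novel = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])    # last Qn of each class
query_base = images[n_way * per_novel:]                                 # trailing Qb base queries
query_images = torch.cat([query_novel, query_base], dim=0)

print(support_novel.flatten().tolist())  # [0.0, 3.0]
print(query_images.flatten().tolist())   # [1.0, 2.0, 4.0, 5.0, 6.0, 7.0, 8.0]
```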

In [None]:
def stage2_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base = images[total_novel:]

    query_images = torch.cat([query_novel, query_base], dim=0)

    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    base_query_labels_cifar = labels[total_novel:]  # original CIFAR ids
    return support_novel, query_images, true_novel_ids, base_query_labels_cifar


stage2_collate_fn = partial(
    stage2_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)

**Batch sampler**

Samples `n_way` pseudo-novel classes from the base pool, drawing `k_shot + q_novel` examples per class. It then adds `q_base_total` queries from the remaining base classes.
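For comparison, a compact sketch of one episode of this sampling scheme. It is a hypothetical variant, not the notebook's sampler: base queries are drawn without replacement from the pooled indices of the remaining classes (sampling proportional to class size) instead of the per-class rejection loop used below, which picks a class uniformly first. With CIFAR-100's equal class sizes the two behave very similarly.

```python
from typing import Dict, List

import numpy as np


def sample_gfsl_train_episode(cti: Dict[int, List[int]], n_way: int, k_shot: int,
                              q_novel: int, q_base_total: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical one-episode sampler.

    Layout: [n_way blocks of (k_shot + q_novel) pseudo-novel indices | q_base_total base indices].
    """
    classes = list(cti.keys())
    novel = rng.choice(classes, size=n_way, replace=False)
    novel_blocks = [rng.choice(cti[int(c)], size=k_shot + q_novel, replace=False) for c in novel]
    novel_set = {int(c) for c in novel}
    # Pool all indices of the non-pseudo-novel classes and sample without replacement
    base_pool = np.concatenate([cti[c] for c in classes if c not in novel_set])
    base_q = rng.choice(base_pool, size=q_base_total, replace=False)
    return np.concatenate(novel_blocks + [base_q]).astype(int)


rng = np.random.default_rng(0)
toy_cti = {c: list(range(c * 30, (c + 1) * 30)) for c in range(10)}  # 10 classes x 30 items
ep = sample_gfsl_train_episode(toy_cti, n_way=3, k_shot=2, q_novel=4, q_base_total=12, rng=rng)
print(len(ep))  # 3 * (2 + 4) + 12 = 30
```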


In [None]:
class GFSLTrainEpisodicBatchSampler:
    """
    Stage-2 sampler: N_WAY pseudo-novel classes drawn from the BASE train split, with K+Qn images each,
    followed by Qb base queries drawn from the remaining base classes (pseudo-novel classes excluded).
    Yields indices over the Subset(train_base_tr).
    """
    def __init__(self, class_to_indices: Dict[int, List[int]], n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        self.cti = class_to_indices
        self.all_classes = list(class_to_indices.keys())
        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self): return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):

            novel = self.rng.choice(self.all_classes, size=self.n_way, replace=False)

            batch = []
            for c in novel:
                pool = self.cti[c]
                need = self.k_shot + self.q_novel
                if len(pool) < need:
                    raise ValueError(f"Class {c} has {len(pool)} < {need}")
                idx = self.rng.choice(pool, size=need, replace=False)
                batch.append(idx)

            novel_classes = set(novel.tolist())
            base_pool_classes = [c for c in self.all_classes if c not in novel_classes]

            used = set(np.concatenate(batch).tolist())

            base_q = []
            while len(base_q) < self.q_base_total:
                c = int(self.rng.choice(base_pool_classes))
                cand = int(self.rng.choice(self.cti[c]))
                if cand not in used:
                    base_q.append(cand)
                    used.add(cand)

            batch.append(np.array(base_q, dtype=int))
            yield np.concatenate(batch)


stage2_train_batch_sampler = GFSLTrainEpisodicBatchSampler(
    cti_train_base_tr, n_tasks=STAGE2_TASKS, rng=train_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

We can now create our Dataset...

In [None]:
class Stage2TrainDS(torch.utils.data.Dataset):
    """Thin wrapper returning (image, original CIFAR-100 label) pairs from a Subset."""
    def __init__(self, subset: Subset):
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        return self.subset[idx]


stage2_train_ds = Stage2TrainDS(train_base_tr)

... and build the DataLoader:

In [None]:
train_loader_s2 = DataLoader(
    stage2_train_ds, batch_sampler=stage2_train_batch_sampler,
    collate_fn=stage2_collate_fn, num_workers=2, pin_memory=True
)

**2.1.1.3 DataLoader: Stage 2 Validation and Test**

For validation and test we need a **different sampler + collate** than in training, because episodes must follow the **GFSL evaluation protocol**:  
- sample supports and queries only from the **novel split**,  
- add a fixed number of queries from the **base split**,  
- keep the layout `[novel block | base block]` so evaluation can compute accuracy on Base, Novel, and H-mean.  

The custom sampler/collate ensure this structure and prevent any leakage between base and novel classes.
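Given this `[novel | base]` layout, per-episode Base and Novel accuracies are just the accuracies over the two slices of the query set. A sketch, assuming the ground-truth labels have already been mapped into the joint label space (base classes in `[0, n_base)`, novel classes in `[n_base, n_base + n_way)`); the collate function returns original CIFAR ids, so that remapping is an extra step, and `episode_accuracies` is a hypothetical helper:

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def episode_accuracies(logits: torch.Tensor, y_joint: torch.Tensor, n_novel_queries: int) -> Tuple[float, float]:
    """Return (base_acc, novel_acc) for one episode whose queries are laid out [novel | base]."""
    correct = (logits.argmax(dim=1) == y_joint).float()
    acc_novel = correct[:n_novel_queries].mean().item()
    acc_base = correct[n_novel_queries:].mean().item()
    return acc_base, acc_novel


# Toy episode: 3 base classes, 2 novel classes (joint labels 3 and 4), 2 novel + 3 base queries
y_joint = torch.tensor([3, 4, 0, 1, 2])
logits = F.one_hot(torch.tensor([3, 3, 0, 1, 0]), num_classes=5).float()  # stand-in predictions
acc_base, acc_novel = episode_accuracies(logits, y_joint, n_novel_queries=2)
print(acc_base, acc_novel)  # 2/3 of base queries and 1/2 of novel queries are correct
```

Collecting these two numbers over all episodes is exactly the input expected by `gfsl_stats`.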


In [None]:
def eval_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    """Reconstruct support/query tensors for a GFSL validation/test episode.

    Input `batch` is a list of (image, label) pairs from the concatenated eval dataset
    (`val_concat_s2` or `test_concat`): the first N_WAY*(K_SHOT+Q_NOVEL) items belong to
    the novel subset and the remaining Q_BASE_TOTAL items belong to the base subset.
    """
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    # reshape novel block into (N, K+Qn, C, H, W) and (N, K+Qn) for labels
    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    # split into support (first K) and query (last Qn)
    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel   = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base    = images[total_novel:]  # remaining Qb from base subset

    # concatenate all queries [novel | base]
    query_images = torch.cat([query_novel, query_base], dim=0)

    # collect true novel CIFAR IDs (one per class) and per-query GT labels
    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    gt_novel = novel_labels_block[:, k_shot:].reshape(-1)
    gt_base  = labels[total_novel:]

    return support_novel, query_images, true_novel_ids, gt_novel, gt_base


eval_collate_fn = partial(
    eval_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)

In [None]:
class GFSLEvalEpisodeSampler:
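    """Batch sampler for GFSL validation/test episodes over a single ConcatDataset.

    Each yielded batch is a 1D numpy array of indices into the ConcatDataset
    (`val_concat_s2` or `test_concat`), laid out as:
        [ N_WAY*(K_SHOT+Q_NOVEL) indices from novel part | Q_BASE_TOTAL indices from base part ]
    Base indices are shifted by `offset_base` (the length of the novel part).
    """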

    def __init__(self, base_cti: Dict[int, List[int]], novel_cti: Dict[int, List[int]], offset_base: int,
                 n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        self.base_cti = base_cti
        self.novel_cti = novel_cti
        self.offset_base = offset_base

        self.base_classes  = list(base_cti.keys())
        self.novel_classes = list(novel_cti.keys())

        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):
            # ---- NOVEL BLOCK (test_novel) ----
            chosen_novel = self.rng.choice(self.novel_classes, size=self.n_way, replace=False)
            per_novel = self.k_shot + self.q_novel

            novel_chunks = []
            for c in chosen_novel:
                pool = self.novel_cti[c]  # local indices within test_novel
                if len(pool) < per_novel:
                    raise ValueError(f"Novel class {c} has {len(pool)} < {per_novel}")
                idx = self.rng.choice(pool, size=per_novel, replace=False)
                novel_chunks.append(idx)

            novel_block = np.concatenate(novel_chunks).astype(int)  # still in "novel" namespace (no offset)

            # ---- BASE BLOCK (test_base) - NO DUPLICATES, NO REPLACEMENT ----
            # Build a single pool of local indices from all base classes
            base_pool = np.concatenate([self.base_cti[c] for c in self.base_classes]) if self.base_classes else np.array([], dtype=int)
            if len(base_pool) < self.q_base_total:
                raise ValueError(f"Not enough base candidates: have {len(base_pool)} < {self.q_base_total}")

            # Sample Qb distinct local indices, then shift by offset to address ConcatDataset second component
            base_q_local = self.rng.choice(base_pool, size=self.q_base_total, replace=False)
            base_block   = (base_q_local + self.offset_base).astype(int)

            # ---- FINAL EPISODE ----
            full_episode = np.concatenate([novel_block, base_block])
            yield full_episode

Validation (Stage-2) DataLoader:

In [None]:
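# Dedicated RNG for validation episodes, so that the number of validation rounds
# (which varies with early stopping) does not change the test episodes drawn from `test_rng`.
val_rng = np.random.default_rng(SEED + 3)

sampler_val_s2 = GFSLEvalEpisodeSampler(
    cti_val_base, cti_val_novel, len(train_valnovel), n_tasks=S2_VAL_TASKS, rng=val_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)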

In [None]:
# Build a single validation dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
val_concat_s2 = ConcatDataset([train_valnovel, train_base_val])
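The `offset_base` passed to the eval sampler works because `ConcatDataset` addresses its second component starting right after the last index of the first one. A toy check with `TensorDataset` stand-ins (values are arbitrary):

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

novel_part = TensorDataset(torch.tensor([10, 11, 12]))     # stand-in for the novel subset
base_part = TensorDataset(torch.tensor([20, 21, 22, 23]))  # stand-in for the base subset
concat = ConcatDataset([novel_part, base_part])            # novel first, base second

offset_base = len(novel_part)  # same role as len(train_valnovel) / len(test_novel)
local_base_idx = 2
print(concat[offset_base + local_base_idx])  # (tensor(22),)
```

Swapping the order of the two subsets would silently make the offset point into the wrong block, hence the "order matters" comment.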

In [None]:
val_loader_s2 = DataLoader(
    val_concat_s2, batch_sampler=sampler_val_s2, collate_fn=eval_collate_fn,
    num_workers=2, pin_memory=True
)

Test DataLoader:

In [None]:
sampler_test = GFSLEvalEpisodeSampler(
    cti_test_base, cti_test_novel, len(test_novel), n_tasks=TEST_TASKS, rng=test_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

In [None]:
# Build a single test dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
test_concat = ConcatDataset([test_novel, test_base])

In [None]:
test_loader = DataLoader(
    test_concat, batch_sampler=sampler_test, collate_fn=eval_collate_fn,
    num_workers=2, pin_memory=True
)

#### **2.1.2 DFSLwF module**

We define our network as a PyTorch module composed of three main components: a **feature extractor**, a **cosine classifier** and a **few-shot weight generator**.

**2.1.2.1 Feature extractor**

It consists of a convolutional backbone that maps each input image to a feature vector.

In our code we can choose among three different backbones:

- **A small Residual Network**:

In [None]:
class BasicBlockCIFAR(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1, remove_last_relu=False):
        super().__init__()
        self.remove_last_relu = remove_last_relu
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.down  = None
        if stride != 1 or in_ch != out_ch:
            self.down = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.down is not None:
            identity = self.down(x)
        out = out + identity
        if not self.remove_last_relu:
            out = self.relu(out)
        return out


class CIFARResNetSmall(nn.Module):
    """
    CIFAR-style ResNet for 32x32 inputs:
      stem 3x3 s1 -> [64]x2 -> [128]x2 (s2) -> [256]x2 (s2) -> GAP -> 256-D
    """
    def __init__(self, remove_last_relu: bool = True):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer1 = nn.Sequential(
            BasicBlockCIFAR(64, 64, 1),
            BasicBlockCIFAR(64, 64, 1),
        )
        self.layer2 = nn.Sequential(
            BasicBlockCIFAR(64, 128, 2),   # downsample 32->16
            BasicBlockCIFAR(128, 128, 1),
        )
        self.layer3 = nn.Sequential(
            BasicBlockCIFAR(128, 256, 2),  # downsample 16->8
            BasicBlockCIFAR(256, 256, 1, remove_last_relu=remove_last_relu),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.out_dim = 256

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x); x = self.layer2(x); x = self.layer3(x)
        x = self.gap(x).squeeze(-1).squeeze(-1)  # (B,256)
        return x

- A simple **Convolutional Neural Network**:

In [None]:
class ConvBlock(nn.Module):
    """Conv3x3 -> BN -> ReLU -> (optional MaxPool2d)."""
    def __init__(self, in_ch, out_ch, pool: bool):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        x = self.conv(x); x = self.bn(x); x = self.relu(x); x = self.pool(x)
        return x

- Or the standard torchvision **ResNet-18** (optionally ImageNet-pretrained, fed with images resized and cropped to 112x112).

The selected backbone is then wrapped in the `FeatureExtractor` module:



In [None]:
class FeatureExtractor(nn.Module):
    """Feature extractor with selectable backbone: 'conv' (light), 'cifar_resnet' or 'resnet18'.

    Args:
        backbone: 'conv', 'resnet18' or 'cifar_resnet'.
        normalize_out: if True, L2-normalize the output features.
        resnet_pretrained: if True (only for 'resnet18'), load ImageNet pretrained weights.
        remove_last_relu: if True (for 'cifar_resnet' and 'resnet18'), remove the last post-add ReLU
                          in the final BasicBlock (useful with cosine classifiers).
    """
    def __init__(
        self,
        backbone: str = "conv",
        normalize_out: bool = True,
        resnet_pretrained: bool = True,
        remove_last_relu: bool = True,
    ):
        super().__init__()
        self.normalize_out = normalize_out

        if backbone.lower() == "conv":
            self.fe = nn.Sequential(
                ConvBlock(3,   64, pool=True),   # 32 -> 16
                ConvBlock(64,  64, pool=True),   # 16 -> 8
                ConvBlock(64,  64, pool=False),
                ConvBlock(64,  64, pool=False),
                nn.AdaptiveAvgPool2d(1),         # -> (B,64,1,1)
            )
            self._mode = "conv"
            self.out_dim = 64

        elif backbone.lower() == "cifar_resnet":
            self.fe = CIFARResNetSmall(remove_last_relu=remove_last_relu)
            self._mode = "cifar_resnet"; self.out_dim = 256

        elif backbone.lower() == "resnet18":
            weights = ResNet18_Weights.IMAGENET1K_V1 if resnet_pretrained else None
            m = resnet18(weights=weights)
            m.fc = nn.Identity()  # we want the 512-D penultimate features

            if remove_last_relu:
                # Patch only the final BasicBlock to skip the post-add ReLU
                last_block = m.layer4[-1]

                def forward_norelu(self_block, x):
                    identity = x
                    out = self_block.conv1(x); out = self_block.bn1(out); out = self_block.relu(out)
                    out = self_block.conv2(out); out = self_block.bn2(out)
                    if self_block.downsample is not None:
                        identity = self_block.downsample(x)
                    out = out + identity
                    return out  # no final ReLU

                last_block.forward = types.MethodType(forward_norelu, last_block)

            self.fe = m
            self._mode = "resnet18"
            self.out_dim = 512

        else:
            raise ValueError(f"Unknown backbone '{backbone}'. Use 'conv', 'resnet18' or 'cifar_resnet'.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._mode == "conv":
            z = self.fe(x).squeeze(-1).squeeze(-1)    # (B, 64)
        else:  # 'cifar_resnet' -> (B, 256), 'resnet18' -> (B, 512)
            z = self.fe(x)                            # (B, out_dim)
        return l2_normalize(z, dim=1) if self.normalize_out else z

**2.1.2.2 Cosine Classifier**

This component maps feature vectors to class scores.

Instead of a standard dot product, it computes the **cosine similarity** between the feature vector $\mathbf{x}$ extracted from an image and the weight vector $\mathbf{w}$ of each class:


$$
\cos(\mathbf{x}, \mathbf{w}) = \frac{\mathbf{x}^{\top}\mathbf{w}}{\|\mathbf{x}\|_2 \, \|\mathbf{w}\|_2}
$$

The result is scaled by a **learnable temperature $\tau$**, which controls how sharp or flat the final softmax distribution is:

$$
y = \tau \cdot \cos(\mathbf{x}, \mathbf{w})
$$
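A minimal sketch of the temperature's effect, in plain Python so it runs without PyTorch: the same hypothetical cosine scores go through a softmax at $\tau = 1$ and $\tau = 10$ (illustrative values, not the notebook's `TAU_INIT`). Since cosines lie in $[-1, 1]$, an unscaled softmax stays close to uniform; scaling by $\tau$ is what allows confident predictions.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical cosine similarities between one query and three class weights.
cosines = [0.9, 0.6, 0.1]

for tau in (1.0, 10.0):
    probs = softmax([tau * c for c in cosines])
    print(f"tau={tau:>4}: " + ", ".join(f"{p:.3f}" for p in probs))
```

With $\tau = 1$ the top class receives less than half of the probability mass; with $\tau = 10$ it receives more than 95%.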

Normalization matters here because base class weights are learned gradually during Stage-1 training, while novel class weights are generated on the fly from just a few support examples.
With raw magnitudes (as in a dot-product classifier), the two sets of weights would not be directly comparable: a weight vector could dominate simply because of its scale.  
**By using cosine similarity, the classifier places both base and novel weights in the same normalized space**, allowing a fair competition between them when predicting the label of a query image.
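The effect of normalization can be checked on a toy 2-D example (hypothetical vectors, plain Python rather than the notebook's PyTorch code): a base weight with a large norm and a novel weight with a small norm that is better aligned with the query.

```python
import math

def l2_normalize(v):
    """Return v / ||v||_2 (v is a non-zero list of floats)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def dot_logits(feat, weights):
    """Plain dot-product scores: sensitive to the norm of each weight vector."""
    return [dot(feat, w) for w in weights]

def cosine_logits(feat, weights, tau):
    """tau * cos(feat, w): only the direction of each weight vector matters."""
    f = l2_normalize(feat)
    return [tau * dot(f, l2_normalize(w)) for w in weights]

# Hypothetical example: large-norm base weight vs small-norm, better-aligned novel weight.
query   = [1.0, 0.9]
w_base  = [10.0, 0.0]
w_novel = [0.5, 0.5]

print("dot:   ", dot_logits(query, [w_base, w_novel]))
print("cosine:", cosine_logits(query, [w_base, w_novel], tau=10.0))
```

The dot product picks the base class (10.0 vs 0.95) purely because of its norm, whereas the cosine scores pick the novel class; multiplying `w_base` by any positive constant leaves its cosine logit unchanged.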


In [None]:
class CosineClassifier(nn.Module):
    """
    [DFSLwF] Cosine classifier with learnable temperature tau: logits = tau * cos(feats, W).
    """
    def __init__(self, in_dim: int, n_classes: int, init_scale: float = TAU_INIT):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_classes, in_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.tau = nn.Parameter(torch.tensor(float(init_scale)))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        W = l2_normalize(self.weight, dim=1)
        feats = l2_normalize(feats, dim=1)
        logits = feats @ W.t()
        return self.tau * logits

**2.1.2.3 Few-Shot Weight Generator**


Given $N$ novel classes, each with $K$ support examples, the generator produces one
classification weight $ \mathbf{w}'_c \in \mathbb{R}^D $ per novel class $c$ by combining:
1) a **prototype** from support features and
2) an **attention-guided mixture** of base weights.

All outputs are L2-normalized so they are comparable to cosine-classifier weights.

**Inputs**
- Support features $ Z \in \mathbb{R}^{N\times K \times D} $ (already L2-normalized).
- Base weights $ W_b \in \mathbb{R}^{C_{\text{base}}\times D} $.
- Learnable **keys** $ K_b \in \mathbb{R}^{C_{\text{base}}\times D} $ (one per base class).
- Learnable parameters: $ \Phi_q \in \mathbb{R}^{D\times D} $ (query layer),
  $ \boldsymbol{\phi}_{\text{avg}}, \boldsymbol{\phi}_{\text{att}} \in \mathbb{R}^D $ (diagonal re-weighting),
  and **temperature** $ \gamma \in \mathbb{R}_{>0} $ for the attention softmax.

**Step 1 - Class prototype (feature averaging)**  
For each novel class $c$:
$$
\mathbf{w}^{\text{avg}}_c
=\operatorname{norm}_2\!\Big(\frac{1}{K}\sum_{i=1}^{K} \mathbf{z}_{c,i}\Big)
\qquad\in\mathbb{R}^D.
$$
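As a sketch, Step 1 in plain Python (the helper names are illustrative, not from the notebook): average the $K$ support vectors of one class and re-normalize.

```python
import math

def l2_normalize(v):
    """Return v / ||v||_2 (v is a non-zero list of floats)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def prototype(support):
    """w_avg = norm2(mean of the K support features of one class)."""
    k = len(support)
    mean = [sum(col) / k for col in zip(*support)]
    return l2_normalize(mean)

# Hypothetical K=3 shots for one novel class, each L2-normalized (D=2).
shots = [l2_normalize(v) for v in ([1.0, 0.0], [0.8, 0.6], [0.6, 0.8])]
w_avg = prototype(shots)
print(w_avg)
```

With a single shot, `prototype` returns that support vector unchanged, which is the $K=1$ case mentioned in the notes.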

**Step 2 - Attention over base classes**  
Compute **queries** from support features and attend to base **keys**:
$$
\tilde{\mathbf{z}}_{c,i}=\operatorname{norm}_2(\Phi_q\,\mathbf{z}_{c,i}),\quad
\tilde{K}_b=\operatorname{norm}_2(K_b),\quad
\tilde{W}_b=\operatorname{norm}_2(W_b).
$$
Cosine attention logits and normalized weights:
$$
a_{c,i,b}
=\operatorname{softmax}_b\!\big(\gamma\,\tilde{\mathbf{z}}_{c,i}^{\top}\tilde{\mathbf{k}}_b\big),
$$
(optionally masking some base classes before the softmax).
Aggregate the attended base weights and average across shots:
$$
\mathbf{w}^{\text{att}}_{c}
=\frac{1}{K}\sum_{i=1}^{K}\sum_{b=1}^{C_{\text{base}}} a_{c,i,b}\,\tilde{\mathbf{w}}_{b}
\quad\in\mathbb{R}^D.
$$
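A plain-Python sketch of Step 2 for one novel class. It assumes $\Phi_q$ is the identity and, for the toy example only, uses the base weights themselves as keys (in `FewShotWeightGenerator` the keys are separate learnable parameters). The optional `keep` argument mimics the masking: excluded classes get a logit of $-\infty$ and therefore zero attention.

```python
import math

def l2_normalize(v):
    """Return v / ||v||_2 (v is a non-zero list of floats)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(logits):
    """Stable softmax; entries equal to -inf receive probability 0."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def attend_base(queries, keys, base_weights, gamma, keep=None):
    """Step 2 for one novel class (hypothetical helper, identity Phi_q).

    queries:      K query vectors, one per shot (normalized here)
    keys:         C_base key vectors k_b
    base_weights: C_base base classifier weights w_b
    keep:         optional list of bools; False masks a base class out
    Returns (w_att, per-shot attention distributions).
    """
    keys_n = [l2_normalize(k) for k in keys]
    base_n = [l2_normalize(w) for w in base_weights]
    dim = len(base_n[0])
    w_att = [0.0] * dim
    attentions = []
    for q in queries:
        q_n = l2_normalize(q)
        logits = [gamma * dot(q_n, k) for k in keys_n]
        if keep is not None:
            logits = [z if k else float("-inf") for z, k in zip(logits, keep)]
        a = softmax(logits)
        attentions.append(a)
        for b, a_b in enumerate(a):
            for d in range(dim):
                # weighted sum over base classes, averaged over the K shots
                w_att[d] += a_b * base_n[b][d] / len(queries)
    return w_att, attentions

# Hypothetical toy memory: 3 base classes in 2-D; keys set equal to the weights.
base = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
queries = [[0.9, 0.1], [0.8, 0.3]]
w_soft, a_soft = attend_base(queries, base, base, gamma=1.0)
w_sharp, a_sharp = attend_base(queries, base, base, gamma=10.0)
```

Raising $\gamma$ from 1 to 10 moves almost all of the first shot's attention onto base class 0, the class whose key is most aligned with it.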

**Step 3 - Combine & normalize**  
Re-weight the two components **per-dimension** and combine (Hadamard product $ \odot $):
$$
\mathbf{w}'_{c}
=\operatorname{norm}_2\!\Big(\boldsymbol{\phi}_{\text{avg}}\odot \mathbf{w}^{\text{avg}}_{c}
+\boldsymbol{\phi}_{\text{att}}\odot \mathbf{w}^{\text{att}}_{c}\Big).
$$
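Step 3 can be sketched the same way; `combine` is an illustrative helper and the vectors are hypothetical. The $\boldsymbol{\phi}$ vectors start at all ones in `FewShotWeightGenerator`, so initially the two components are simply summed before re-normalization.

```python
import math

def l2_normalize(v):
    """Return v / ||v||_2 (v is a non-zero list of floats)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def combine(w_avg, w_att, phi_avg, phi_att):
    """Step 3: w' = norm2(phi_avg * w_avg + phi_att * w_att), all element-wise."""
    mixed = [pa * a + pt * t for a, t, pa, pt in zip(w_avg, w_att, phi_avg, phi_att)]
    return l2_normalize(mixed)

w_avg = [1.0, 0.0]            # hypothetical prototype
w_att = [0.0, 1.0]            # hypothetical attended base mixture
ones, zeros = [1.0, 1.0], [0.0, 0.0]

w_equal = combine(w_avg, w_att, ones, ones)    # both terms weighted equally
w_proto = combine(w_avg, w_att, ones, zeros)   # attention term switched off
```

Setting `phi_att` to zeros recovers the normalized prototype, so the learned $\boldsymbol{\phi}$ vectors decide, dimension by dimension, how much to rely on the few supports versus the base-class memory.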

**Notes**
- With $K=1$ the prototype reduces to the single support vector.
- The temperature $\gamma$ controls the **sharpness** of attention: a large $\gamma$ concentrates attention on a few base classes, while a small $\gamma$ gives a smoother mixture.
- Dropout on $Z$ is applied **only during Stage-2 training** to regularize the generator.
- The final normalization ensures $ \mathbf{w}'_c $ lives in the same space as cosine-classifier weights.

In [None]:
class FewShotWeightGenerator(nn.Module):
    """
    [DFSLwF] Few-shot classification weight generator (Avg + Attention).
    """
    def __init__(self, dim: int, num_base: int, p_dropout: float = 0.5):
        super().__init__()
        self.dim = dim
        self.num_base = num_base
        self.phi_q = nn.Linear(dim, dim, bias=False)
        nn.init.kaiming_uniform_(self.phi_q.weight, a=math.sqrt(5))
        self.keys = nn.Parameter(l2_normalize(torch.randn(num_base, dim), dim=1))
        self.phi_avg = nn.Parameter(torch.ones(dim))
        self.phi_att = nn.Parameter(torch.ones(dim))
        self.gamma = nn.Parameter(torch.tensor(10.0))  # attention temperature (like τ)
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(
        self,
        support_feats: torch.Tensor,          # (N*K, D), L2-normalized
        base_weights: torch.Tensor,           # (C_base, D), not necessarily normalized
        shots_per_class: int,
        exclude_mask: Optional[torch.Tensor] = None  # (C_base,) bool; True = keep; False = exclude
    ) -> torch.Tensor:
        D = support_feats.size(1)
        N = support_feats.size(0) // shots_per_class
        z = support_feats.view(N, shots_per_class, D)
        if self.training:
            z = self.dropout(z)
        # w_avg
        w_avg = l2_normalize(z.mean(dim=1), dim=1)  # (N, D)
        # attention over base weights
        Wb = l2_normalize(base_weights, dim=1)      # (C_base, D)
        Kb = l2_normalize(self.keys, dim=1)         # (C_base, D)
        # queries
        q = self.phi_q(z)                           # (N, K, D)
        q = l2_normalize(q, dim=2)                 # normalize across D
        # cosine(q, Kb) => (N, K, C_base)
        att_logits = torch.einsum("nkd,bd->nkb", q, Kb) * self.gamma
        if exclude_mask is not None:
            # set -inf on excluded classes before softmax
            mask = exclude_mask.view(1, 1, -1)  # broadcast
            att_logits = att_logits.masked_fill(~mask, float("-inf"))
        att = torch.softmax(att_logits, dim=2)      # (N, K, C_base)
        # weighted sum of base weights -> (N, K, D), then average over K (shots)
        w_att = torch.einsum("nkb,bd->nkd", att, Wb).mean(dim=1)  # (N, D)
        # combine
        w = self.phi_avg * w_avg + self.phi_att * w_att            # Hadamard
        w = l2_normalize(w, dim=1)  # (N, D)
        return w

**2.1.2.4 Final Network: DFSLwF module**

The **Feature Extractor (FE)** maps each image to an L2-normalized embedding.

For few-shot episodes, the **Few-Shot Weight Generator** takes the support embeddings and the **base** class weights and produces **novel** class weights by combining:
1. Feature averaging (prototypes),
2. Attention over base classes.

The result is L2-normalized so it’s comparable to base weights.

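The **Cosine Classifier** then scores a query embedding against the (normalized) base weights and, if provided, against the generated novel weights, scaling all logits by the same learnable temperature $\tau$. A **softmax** over the concatenated base + novel logits gives the class probabilities (applied through the cross-entropy loss during training); the predicted label is their argmax.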
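The logits are laid out as `[base | novel]`: indices `0 .. C_base-1` are base classes, and the generated novel classes follow in the order their weights were produced. This is why `evaluate_gfsl` subtracts `Cb` from predictions on novel queries. A small plain-Python sketch of that decoding, with hypothetical scores and illustrative helper names:

```python
def argmax(xs):
    """Index of the largest element."""
    return max(range(len(xs)), key=xs.__getitem__)

def decode_prediction(pred, num_base):
    """Map an argmax index over [base | novel] logits to (branch, local index)."""
    if pred < num_base:
        return ("base", pred)
    return ("novel", pred - num_base)

# Hypothetical scores for C_base = 3 base classes followed by N = 2 novel classes.
base_logits  = [4.1, 2.0, 0.5]
novel_logits = [6.3, 1.2]
logits = base_logits + novel_logits   # same concatenation order as forward_logits

print(decode_prediction(argmax(logits), len(base_logits)))
```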


In [None]:
class DFSLwF(nn.Module):
    def __init__(self, fe: FeatureExtractor, clf_base: CosineClassifier, gen: FewShotWeightGenerator):
        super().__init__()
        self.fe = fe
        self.clf_base = clf_base
        self.gen = gen

    def forward_logits(
        self,
        x: torch.Tensor,
        novel_weights: torch.Tensor | None = None,
        base_keep_mask: torch.Tensor | None = None,   # optional (C_base,) bool; True = keep
    ) -> torch.Tensor:
        feats = self.fe(x)                                        # (B, D)
        Wb = l2_normalize(self.clf_base.weight, dim=1)            # (C_base, D)
        if base_keep_mask is not None:
            Wb = Wb[base_keep_mask]                               # keep-only base
        logits = self.clf_base.tau * (feats @ Wb.t())             # base logits

        if novel_weights is not None and novel_weights.numel() > 0:
            Wn = l2_normalize(novel_weights, dim=1)               # (C_novel, D)
            logits_n = self.clf_base.tau * (feats @ Wn.t())
            logits = torch.cat([logits, logits_n], dim=1)
        return logits

    def build_novel_weights(
        self,
        support_imgs: torch.Tensor,
        k_shot: int,
        exclude_mask: torch.Tensor | None = None,                 # forwarded to the generator (True = keep)
    ) -> torch.Tensor:
        supp = self.fe(support_imgs)                              # (N*K, D), L2-norm
        Wb = self.clf_base.weight
        return self.gen(supp, Wb, k_shot, exclude_mask=exclude_mask)

We can now build the network according to the configuration chosen at the start of the notebook (`BACKBONE`):

In [None]:
if BACKBONE == "conv":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True)
elif BACKBONE == "cifar_resnet":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True, remove_last_relu=True)
elif BACKBONE == "resnet18":
    fe = FeatureExtractor(backbone=BACKBONE, normalize_out=True, resnet_pretrained=True, remove_last_relu=True)
else:
    raise ValueError(f"Unknown backbone '{BACKBONE}'. Use 'conv', 'resnet18' or 'cifar_resnet'.")

clf = CosineClassifier(in_dim=fe.out_dim, n_classes=len(base_order))
gen = FewShotWeightGenerator(dim=fe.out_dim, num_base=len(base_order), p_dropout=0.5)

model = DFSLwF(fe=fe, clf_base=clf, gen=gen).to(device)

In [None]:
baseline_model = copy.deepcopy(model)

### **2.2 Training**

The training is split into two distinct stages:

1. **Base supervised training** which builds a discriminative embedding for the base classes;
2. **Episodic training** that teaches the model to dynamically generate classification weights for unseen classes, while still preserving performance on the base ones *without forgetting*.

#### **2.2.1 Stage 1: supervised base training**

The first phase consists of training the **feature extractor (FE)** and the **cosine base classifier** on the base classes with standard cross-entropy.  

The goal is to learn a strong embedding space where base weights are well aligned with their class features.

> *Note:* BatchNorm layers remain learnable (not frozen), as in a standard supervised classifier.

In [None]:
@torch.no_grad()
def evaluate_stage1_base_top1(model: DFSLwF, loader: DataLoader, device) -> float:
    model.fe.eval(); model.clf_base.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        logits = model.clf_base(model.fe(xb))
        pred = logits.argmax(dim=1)
        correct += (pred == yb).sum().item()
        total   += yb.numel()
    return correct / max(1, total)

In [None]:
def train_stage1(model: DFSLwF, loader: DataLoader, device: torch.device,
                 epochs: int = STAGE1_EPOCHS, lr: float = STAGE1_LR,
                 weight_decay: float = STAGE1_WEIGHT_DECAY,
                 val_loader: Optional[DataLoader] = None):

    # [DFSLwF] Train feature extractor + base classifier (cosine)
    model.fe.train(); model.clf_base.train(); model.gen.eval()

    # Keep BN learnable here (paper trains a standard classifier in Stage-1)
    params    = list(model.fe.parameters()) + list(model.clf_base.parameters())
    opt       = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Params for validation
    best_val = -1.0
    patience = 0
    best_state = None

    with tqdm(range(epochs), desc="[Stage1] Supervised base training") as epbar:
        for ep in epbar:
            model.fe.train(); model.clf_base.train()
            batch_losses = []

            for i, (xb, yb) in enumerate(loader):
                xb = xb.to(device); yb = yb.to(device)
                opt.zero_grad()
                logits = model.clf_base(model.fe(xb))
                loss   = criterion(logits, yb)
                loss.backward()
                opt.step()
                batch_losses.append(float(loss.item()))

                if (i + 1) % S1_LOG_EVERY == 0:
                    epbar.set_postfix(loss=f"{sliding_avg(batch_losses, S1_LOG_EVERY):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((ep + 1) % S1_VAL_EVERY == 0):
                val_acc = evaluate_stage1_base_top1(model, val_loader, device)
                epbar.write(f"\t[stage1/val] epoch {ep+1:03d}: acc_base={100*val_acc:.2f}%")

                if val_acc > best_val + 1e-6:
                    best_val   = val_acc
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "epoch": ep + 1,
                        "val":   best_val,
                    }
                else:
                    patience += 1
                    if patience >= S1_PATIENCE:
                        epbar.write(f"[stage1] Early stopping (no improvement for {S1_PATIENCE} validations).")
                        break

    # Restore best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])

In [None]:
model = copy.deepcopy(baseline_model)
train_stage1(model, train_loader_s1, device=device, val_loader=val_loader_s1)

[Stage1] Supervised base training:   2%|▏         | 2/120 [00:37<37:06, 18.86s/it, loss=3.1845]

	[stage1/val] epoch 002: acc_base=11.84%


[Stage1] Supervised base training:   3%|▎         | 4/120 [01:15<36:26, 18.85s/it, loss=2.5496]

	[stage1/val] epoch 004: acc_base=26.06%


[Stage1] Supervised base training:   5%|▌         | 6/120 [01:52<35:36, 18.75s/it, loss=2.1277]

	[stage1/val] epoch 006: acc_base=25.47%


[Stage1] Supervised base training:   7%|▋         | 8/120 [02:29<34:55, 18.71s/it, loss=1.8668]

	[stage1/val] epoch 008: acc_base=30.88%


[Stage1] Supervised base training:   8%|▊         | 10/120 [03:06<34:20, 18.73s/it, loss=1.6590]

	[stage1/val] epoch 010: acc_base=39.69%


[Stage1] Supervised base training:  10%|█         | 12/120 [03:43<33:39, 18.70s/it, loss=1.5021]

	[stage1/val] epoch 012: acc_base=38.91%


[Stage1] Supervised base training:  12%|█▏        | 14/120 [04:21<33:01, 18.69s/it, loss=1.3505]

	[stage1/val] epoch 014: acc_base=34.22%


[Stage1] Supervised base training:  13%|█▎        | 16/120 [04:58<32:25, 18.71s/it, loss=1.2418]

	[stage1/val] epoch 016: acc_base=50.56%


[Stage1] Supervised base training:  15%|█▌        | 18/120 [05:35<31:49, 18.72s/it, loss=1.1456]

	[stage1/val] epoch 018: acc_base=49.12%


[Stage1] Supervised base training:  17%|█▋        | 20/120 [06:12<31:14, 18.74s/it, loss=1.0685]

	[stage1/val] epoch 020: acc_base=53.81%


[Stage1] Supervised base training:  18%|█▊        | 22/120 [06:50<30:41, 18.79s/it, loss=0.9970]

	[stage1/val] epoch 022: acc_base=55.84%


[Stage1] Supervised base training:  20%|██        | 24/120 [07:27<30:03, 18.79s/it, loss=0.9345]

	[stage1/val] epoch 024: acc_base=48.62%


[Stage1] Supervised base training:  22%|██▏       | 26/120 [08:05<29:27, 18.80s/it, loss=0.8893]

	[stage1/val] epoch 026: acc_base=58.13%


[Stage1] Supervised base training:  23%|██▎       | 28/120 [08:42<28:55, 18.86s/it, loss=0.8559]

	[stage1/val] epoch 028: acc_base=59.00%


[Stage1] Supervised base training:  25%|██▌       | 30/120 [09:20<28:16, 18.85s/it, loss=0.8178]

	[stage1/val] epoch 030: acc_base=57.12%


[Stage1] Supervised base training:  27%|██▋       | 32/120 [09:57<27:37, 18.83s/it, loss=0.7743]

	[stage1/val] epoch 032: acc_base=62.44%


[Stage1] Supervised base training:  28%|██▊       | 34/120 [10:34<27:00, 18.85s/it, loss=0.7587]

	[stage1/val] epoch 034: acc_base=60.31%


[Stage1] Supervised base training:  30%|███       | 36/120 [11:12<26:21, 18.83s/it, loss=0.7165]

	[stage1/val] epoch 036: acc_base=58.91%


[Stage1] Supervised base training:  32%|███▏      | 38/120 [11:49<25:44, 18.83s/it, loss=0.7054]

	[stage1/val] epoch 038: acc_base=60.62%


[Stage1] Supervised base training:  33%|███▎      | 40/120 [12:27<25:10, 18.89s/it, loss=0.6642]

	[stage1/val] epoch 040: acc_base=59.94%


[Stage1] Supervised base training:  34%|███▍      | 41/120 [13:04<25:12, 19.14s/it, loss=0.6691]

	[stage1/val] epoch 042: acc_base=61.75%
[stage1] Early stopping (no improvement for 5 validations).





In [None]:
stage1_model = copy.deepcopy(model)

#### **2.2.2 Stage 2: episodic training**

In the second phase, we freeze the **feature extractor** (to avoid forgetting base knowledge) and train the **few-shot weight generator** together with the base classifier weights and the learnable temperature $\tau$.

In each episode, some base classes are treated as *pseudo-novel*:
- Their weights are excluded, through a shared mask, both from the base branch of the classifier and from the generator's attention memory.
- The generator produces new weights for them from a few support examples and attention over the remaining base classes.

In practice, DFSLwF does not reserve a separate novel set for Stage-2 training.  
Splitting classes further would reduce the already limited data and break the standard evaluation protocol. Instead, pseudo-novel classes are sampled from the base set during Stage-2, while the true novel classes remain untouched until the final test.
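How the sampler picks pseudo-novel classes is not shown in this section (it lives behind `train_loader_s2`); the sketch below is a hypothetical stand-in that samples `n_way` base classes and builds the keep-mask over the sorted base ids, with `True` meaning kept, the same convention as `exclude_mask` in `train_stage2`.

```python
import random

def sample_pseudo_novel(base_classes, n_way, rng):
    """Pick n_way distinct base classes to play the role of novel classes."""
    return rng.sample(base_classes, n_way)

def keep_mask(base_classes, pseudo_novel):
    """Bool list over the sorted base classes: True = kept, False = excluded."""
    excluded = set(pseudo_novel)
    return [cid not in excluded for cid in sorted(base_classes)]

rng = random.Random(0)
base_classes = list(range(10))          # hypothetical: 10 base class ids
pseudo = sample_pseudo_novel(base_classes, n_way=3, rng=rng)
mask = keep_mask(base_classes, pseudo)
print(pseudo, mask)
```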

The query images are then classified against both the kept base classes and the generated novel weights.  
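Because the pseudo-novel columns are removed, base labels must be remapped to the kept columns and novel targets shifted by the number of kept classes. A plain-Python sketch of the target construction in `train_stage2` (the helper name is illustrative), assuming, as the notebook does, that the first `N * Q_NOVEL` queries are novel and grouped class by class:

```python
def build_targets(base_query_labels, keep, n_novel, q_novel):
    """Targets for one episode over logits laid out as [kept base | novel].

    base_query_labels: local base indices (0..C_base-1) of the base queries
    keep:              list of bools over the C_base base classes (True = kept)
    """
    kept_idx = [b for b, k in enumerate(keep) if k]
    to_kept = {b: i for i, b in enumerate(kept_idx)}   # 0..C_base-1 -> 0..C_kept-1
    c_kept = len(kept_idx)
    # Novel queries first: class 0 repeated q_novel times, then class 1, ...
    novel_targets = [c_kept + c for c in range(n_novel) for _ in range(q_novel)]
    base_targets = []
    for b in base_query_labels:
        if b not in to_kept:
            raise ValueError(f"base query belongs to excluded class {b}")
        base_targets.append(to_kept[b])
    return novel_targets + base_targets

# Hypothetical episode: C_base = 5, base classes 1 and 4 are pseudo-novel.
keep = [True, False, True, True, False]
print(build_targets([0, 3, 2], keep, n_novel=2, q_novel=2))
```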


In [None]:
@torch.no_grad()
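def evaluate_gfsl(model: DFSLwF, test_loader: DataLoader, cifar_targets_all: List[int],
                  base_classes: List[int], device: torch.device) -> Tuple[int, dict]:
    """Episodic GFSL evaluation. Returns (number of episodes, gfsl_stats(...) dict).

    `cifar_targets_all` is currently unused; it is kept for call-site compatibility.
    """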
    model.fe.eval(); model.clf_base.eval(); model.gen.eval()
    acc_per_episode_base, acc_per_episode_novel = [], []
    base_order = sorted(base_classes)
    b2local = {cid: i for i, cid in enumerate(base_order)}
    Cb = model.clf_base.weight.size(0)

    for (support_novel, query_images, true_novel_ids, gt_novel, gt_base) in test_loader:
        support_novel = support_novel.to(device)
        query_images = query_images.to(device)
        gt_novel = gt_novel.to(device)
        gt_base = gt_base.to(device)

        novel_weights = model.build_novel_weights(support_novel, K_SHOT)  # (N, D)
        logits = model.forward_logits(query_images, novel_weights)
        preds = logits.argmax(dim=1)

        N = novel_weights.size(0)
        pred_novel = preds[:N * Q_NOVEL] - Cb
        pred_base = preds[N * Q_NOVEL:]

        id2local = {cid: i for i, cid in enumerate(true_novel_ids)}
        gt_novel_local = torch.tensor([id2local[int(y.item())] for y in gt_novel], device=device)
        gt_base_local = torch.tensor([b2local[int(y.item())] for y in gt_base], device=device)

        acc_b = (pred_base == gt_base_local).float().mean().item()
        acc_n = (pred_novel == gt_novel_local).float().mean().item()
        acc_per_episode_base.append(acc_b); acc_per_episode_novel.append(acc_n)

    return len(acc_per_episode_base), gfsl_stats(acc_per_episode_base, acc_per_episode_novel)

In [None]:
def train_stage2(model: DFSLwF, meta_loader: DataLoader, device: torch.device,
                 base_order: List[int], n_tasks: int = STAGE2_TASKS, val_every: int = STAGE2_VAL_EVERY,
                 lr: float = STAGE2_LR, gc: float = STAGE2_GRAD_CLIP,
                 val_loader: Optional[DataLoader] = None):
    """
    [DFSLwF] Freeze the feature extractor, train the generator and continue training W_base (and τ).
    Exclude pseudo-novel classes both from the generator attention memory
    and from the base branch of the classifier (using the same mask).
    Use the true base labels for base queries; novel queries are indexed starting from Cb_kept.
    """

    # Freeze feature extractor; keep BN statistics frozen
    model.fe.eval(); model.fe.apply(set_bn_eval)
    for p in model.fe.parameters():
        p.requires_grad = False

    # Train generator, base weights, and tau
    for p in model.clf_base.parameters():
        p.requires_grad = True
    params = list(model.gen.parameters()) + [model.clf_base.tau, model.clf_base.weight]
    opt = torch.optim.Adam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Mapping CIFAR id -> local base index [0..Cb-1]
    b2local = {cid: i for i, cid in enumerate(sorted(base_order))}
    Cb = model.clf_base.weight.size(0)

    # Validation parameters
    best_score = -1.0
    patience   = 0
    best_state = None

    model.gen.train(); model.clf_base.train()
    with tqdm(enumerate(meta_loader), total=len(meta_loader), desc="[Stage2] Episodic Training") as pbar:
        for step, (support_novel, query_images, true_novel_ids, base_q_labels_cifar) in pbar:
            support_novel       = support_novel.to(device)
            query_images        = query_images.to(device)
            base_q_labels_cifar = base_q_labels_cifar.to(device)

            # 1) Mask: exclude pseudo-novel classes from base (False = excluded)
            exclude_mask = torch.ones(Cb, dtype=torch.bool, device=device)
            for cid in true_novel_ids:
                if cid in b2local:
                    exclude_mask[b2local[cid]] = False
            Cb_kept = int(exclude_mask.sum().item())

            # 2) Novel weights via model helper (FE frozen inside, GEN with grad)
            novel_weights = model.build_novel_weights(
                support_imgs=support_novel,
                k_shot=K_SHOT,
                exclude_mask=exclude_mask
            )  # (N, D)
            N = int(novel_weights.size(0))

            # 3) Logits: apply the same mask to the classifier
            logits = model.forward_logits(
                query_images, novel_weights, base_keep_mask=exclude_mask
            )  # [B, Cb_kept + N]

            # 4) Targets: novel after the "kept" base; remap base labels
            #    Assume the first N*Q_NOVEL queries are novel, the rest are base.
            y_novel = torch.arange(N, device=device).repeat_interleave(Q_NOVEL)  # (N*Q_NOVEL,)

            # Map CIFAR base id -> local base index [0..Cb-1]
            base_local_full = torch.tensor(
                [b2local[int(y.item())] for y in base_q_labels_cifar],
                device=device
            )

            # Map from [0..Cb-1] -> [0..Cb_kept-1] (=-1 if excluded)
            keep_idx = torch.nonzero(exclude_mask, as_tuple=True)[0]
            map_to_kept = torch.full((Cb,), -1, dtype=torch.long, device=device)
            map_to_kept[keep_idx] = torch.arange(Cb_kept, device=device)
            base_local_kept = map_to_kept[base_local_full]

            # If the sampler produced base queries belonging to excluded (pseudo-novel) classes, fail explicitly
            if (base_local_kept < 0).any():
                raise ValueError(
                    "Stage-2: base queries contain pseudo-novel classes. "
                    "Fix the sampler or filter/remap them before CE."
                )

            # Build targets
            B_total = logits.size(0)
            targets = torch.empty(B_total, dtype=torch.long, device=device)
            # Novel occupy [Cb_kept .. Cb_kept+N-1]
            targets[:N * Q_NOVEL] = Cb_kept + y_novel
            # Base occupy [0 .. Cb_kept-1]
            targets[N * Q_NOVEL:] = base_local_kept

            # 5) Backpropagation
            opt.zero_grad()
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, gc)
            opt.step()
            pbar.set_postfix(loss=f"{float(loss.item()):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((step + 1) % val_every == 0):
                Tval, vstats = evaluate_gfsl(model, val_loader, ds_train.targets, base_order, device)
                v_base  = vstats["base"]["mean"];  v_base_ci  = vstats["base"]["conf"]
                v_novel = vstats["novel"]["mean"]; v_novel_ci = vstats["novel"]["conf"]
                v_h     = vstats["hmean"]["mean"]; v_h_ci     = vstats["hmean"]["conf"]

                pbar.write(f"\t[stage2/val] step {step+1:05d}: "
                           f"base={100*v_base:.2f}%±{100*v_base_ci:.2f}  "
                           f"novel={100*v_novel:.2f}%±{100*v_novel_ci:.2f}  "
                           f"h-mean={100*v_h:.2f}%±{100*v_h_ci:.2f}  (T={Tval})")

                sel = {"base": v_base, "novel": v_novel, "hmean": v_h}[S2_SELECT_METRIC]
                if sel > best_score + 1e-6:
                    best_score = sel
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "step":  step + 1,
                        "score": best_score,
                    }
                else:
                    patience += 1
                    if patience >= S2_PATIENCE:
                        pbar.write(f"[stage2] Early stopping (no improvement for {S2_PATIENCE} validations).")
                        break

                model.gen.train(); model.clf_base.train()

    # Restore the best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])

In [None]:
model = copy.deepcopy(stage1_model)
train_stage2(model, train_loader_s2, device=device, base_order=base_order, val_loader=val_loader_s2)

[Stage2] Episodic Training:   3%|▎         | 502/20000 [01:27<31:28:07,  5.81s/it, loss=0.9047]

	[stage2/val] step 00500: base=56.23%±0.36  novel=63.45%±0.54  h-mean=59.15%±0.31  (T=1000)


[Stage2] Episodic Training:   5%|▌         | 1002/20000 [02:54<33:52:32,  6.42s/it, loss=0.6350]

	[stage2/val] step 01000: base=55.45%±0.37  novel=66.51%±0.54  h-mean=60.02%±0.31  (T=1000)


[Stage2] Episodic Training:   8%|▊         | 1503/20000 [04:22<25:46:53,  5.02s/it, loss=0.6000]

	[stage2/val] step 01500: base=55.60%±0.37  novel=66.45%±0.52  h-mean=60.09%±0.31  (T=1000)


[Stage2] Episodic Training:  10%|█         | 2003/20000 [05:48<24:38:15,  4.93s/it, loss=0.6037]

	[stage2/val] step 02000: base=55.53%±0.36  novel=67.17%±0.53  h-mean=60.34%±0.30  (T=1000)


[Stage2] Episodic Training:  13%|█▎        | 2503/20000 [07:14<24:09:16,  4.97s/it, loss=0.6208]

	[stage2/val] step 02500: base=55.06%±0.37  novel=67.08%±0.54  h-mean=60.00%±0.30  (T=1000)


[Stage2] Episodic Training:  15%|█▌        | 3002/20000 [08:41<27:14:00,  5.77s/it, loss=0.5253]

	[stage2/val] step 03000: base=55.40%±0.35  novel=67.49%±0.53  h-mean=60.42%±0.30  (T=1000)


[Stage2] Episodic Training:  18%|█▊        | 3502/20000 [10:06<25:55:03,  5.66s/it, loss=0.8207]

	[stage2/val] step 03500: base=55.36%±0.37  novel=67.44%±0.55  h-mean=60.33%±0.31  (T=1000)


[Stage2] Episodic Training:  20%|██        | 4002/20000 [11:33<28:13:38,  6.35s/it, loss=0.8080]

	[stage2/val] step 04000: base=55.83%±0.37  novel=67.69%±0.53  h-mean=60.72%±0.30  (T=1000)


[Stage2] Episodic Training:  23%|██▎       | 4502/20000 [13:00<24:34:41,  5.71s/it, loss=0.7757]

	[stage2/val] step 04500: base=55.47%±0.37  novel=67.62%±0.53  h-mean=60.49%±0.30  (T=1000)


[Stage2] Episodic Training:  25%|██▌       | 5002/20000 [14:26<24:09:38,  5.80s/it, loss=0.5833]

	[stage2/val] step 05000: base=55.66%±0.36  novel=67.37%±0.54  h-mean=60.48%±0.30  (T=1000)


[Stage2] Episodic Training:  28%|██▊       | 5502/20000 [15:53<21:35:37,  5.36s/it, loss=0.6606]

	[stage2/val] step 05500: base=55.47%±0.37  novel=68.15%±0.54  h-mean=60.69%±0.31  (T=1000)


[Stage2] Episodic Training:  30%|███       | 6002/20000 [17:19<22:35:51,  5.81s/it, loss=0.5648]

	[stage2/val] step 06000: base=55.80%±0.38  novel=67.57%±0.52  h-mean=60.67%±0.31  (T=1000)


[Stage2] Episodic Training:  33%|███▎      | 6502/20000 [18:47<24:09:11,  6.44s/it, loss=0.5850]

	[stage2/val] step 06500: base=55.86%±0.37  novel=68.20%±0.52  h-mean=60.96%±0.30  (T=1000)


[Stage2] Episodic Training:  35%|███▌      | 7003/20000 [20:14<19:42:09,  5.46s/it, loss=0.6108]

	[stage2/val] step 07000: base=56.08%±0.36  novel=68.33%±0.54  h-mean=61.15%±0.30  (T=1000)


[Stage2] Episodic Training:  38%|███▊      | 7502/20000 [21:39<19:38:00,  5.66s/it, loss=0.6294]

	[stage2/val] step 07500: base=55.70%±0.37  novel=67.94%±0.53  h-mean=60.76%±0.30  (T=1000)


[Stage2] Episodic Training:  40%|████      | 8002/20000 [23:06<21:13:34,  6.37s/it, loss=0.6494]

	[stage2/val] step 08000: base=55.93%±0.36  novel=67.63%±0.53  h-mean=60.78%±0.30  (T=1000)


[Stage2] Episodic Training:  43%|████▎     | 8502/20000 [24:31<20:12:01,  6.32s/it, loss=0.5712]

	[stage2/val] step 08500: base=56.27%±0.37  novel=67.44%±0.53  h-mean=60.93%±0.31  (T=1000)


[Stage2] Episodic Training:  45%|████▌     | 9002/20000 [25:57<19:17:06,  6.31s/it, loss=0.5397]

	[stage2/val] step 09000: base=56.12%±0.37  novel=68.67%±0.52  h-mean=61.34%±0.31  (T=1000)


[Stage2] Episodic Training:  48%|████▊     | 9502/20000 [27:23<16:32:27,  5.67s/it, loss=0.6812]

	[stage2/val] step 09500: base=56.41%±0.36  novel=68.77%±0.52  h-mean=61.58%±0.31  (T=1000)


[Stage2] Episodic Training:  50%|█████     | 10003/20000 [28:50<14:05:47,  5.08s/it, loss=0.4331]

	[stage2/val] step 10000: base=56.37%±0.37  novel=68.36%±0.52  h-mean=61.34%±0.30  (T=1000)


[Stage2] Episodic Training:  53%|█████▎    | 10502/20000 [30:16<16:47:34,  6.37s/it, loss=0.5228]

	[stage2/val] step 10500: base=55.78%±0.37  novel=68.79%±0.51  h-mean=61.18%±0.31  (T=1000)


[Stage2] Episodic Training:  55%|█████▌    | 11002/20000 [31:42<15:50:29,  6.34s/it, loss=0.5635]

	[stage2/val] step 11000: base=56.33%±0.37  novel=67.81%±0.52  h-mean=61.10%±0.31  (T=1000)


[Stage2] Episodic Training:  58%|█████▊    | 11502/20000 [33:09<13:29:46,  5.72s/it, loss=0.6125]

	[stage2/val] step 11500: base=56.57%±0.36  novel=68.39%±0.51  h-mean=61.51%±0.29  (T=1000)


[Stage2] Episodic Training:  60%|█████▉    | 11999/20000 [34:34<23:03,  5.78it/s, loss=0.7056]

	[stage2/val] step 12000: base=56.21%±0.36  novel=68.14%±0.52  h-mean=61.17%±0.30  (T=1000)
[stage2] Early stopping (no improvement for 5 validations).
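The early-stopping message above can be read as a simple patience rule. The sketch below is a hypothetical reconstruction, not the notebook's code. It assumes three things: the monitored metric is the validation H-mean, an "improvement" means strictly beating the best value so far, and the patience is 5 validations. If we replay only the validations shown in this log (steps 4500 to 12000), the rule stops at step 12000, and the best H-mean is the one recorded at step 9500.

```python
# Hypothetical sketch of patience-based early stopping on the validation H-mean.
# Assumptions (not confirmed by the notebook): metric = H-mean,
# improvement = strictly greater than the best so far, patience = 5.

class EarlyStopping:
    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best = float("-inf")
        self.best_step = None
        self.bad_rounds = 0  # consecutive validations without improvement

    def update(self, step: int, metric: float) -> bool:
        """Record one validation result; return True if training should stop."""
        if metric > self.best:
            self.best, self.best_step = metric, step
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1
        return self.bad_rounds >= self.patience


# Validation H-means copied from the Stage-2 log above (steps 4500-12000).
val_hmeans = {
    4500: 60.49, 5000: 60.48, 5500: 60.69, 6000: 60.67, 6500: 60.96,
    7000: 61.15, 7500: 60.76, 8000: 60.78, 8500: 60.93, 9000: 61.34,
    9500: 61.58, 10000: 61.34, 10500: 61.18, 11000: 61.10, 11500: 61.51,
    12000: 61.17,
}

stopper = EarlyStopping(patience=5)
for step, h in val_hmeans.items():
    if stopper.update(step, h):
        print(f"stop at step {step}; best h-mean {stopper.best:.2f}% at step {stopper.best_step}")
        # prints: stop at step 12000; best h-mean 61.58% at step 9500
        break
```

The five validations after step 9500 (10000 to 12000) all stay below 61.58%. The value at step 11500 (61.51%) comes close but does not reset the counter.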





### **2.3 Evaluation**


After training, we evaluate on the **test split** with episodic testing. We report *base accuracy*, *novel accuracy*, and their **harmonic mean**, each with a *95% confidence interval* over the test episodes (1000 tasks in the runs below).
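The sketch below shows one way to aggregate such per-episode results. It makes three assumptions: each test episode yields one base accuracy and one novel accuracy (in %), the H-mean is computed per episode before averaging, and the 95% CI uses the normal approximation $1.96\,\sigma/\sqrt{T}$ with the sample standard deviation. The function names are hypothetical, and `evaluate_gfsl` / `print_stats` may compute these quantities differently.

```python
import math
import statistics


def mean_ci95(values):
    """Mean and 95% CI half-width (normal approximation: 1.96 * std / sqrt(n))."""
    n = len(values)
    mean = statistics.fmean(values)
    half_width = 1.96 * statistics.stdev(values) / math.sqrt(n)
    return mean, half_width


def harmonic_mean(a, b):
    """Harmonic mean of two accuracies; defined as 0 when both are 0."""
    return 2 * a * b / (a + b) if (a + b) > 0 else 0.0


def summarize_episodes(base_accs, novel_accs):
    """Aggregate per-episode accuracies (in %) into mean +/- 95% CI.

    Assumption: the H-mean is computed per episode, then averaged.
    """
    hmeans = [harmonic_mean(b, n) for b, n in zip(base_accs, novel_accs)]
    return {
        "base": mean_ci95(base_accs),
        "novel": mean_ci95(novel_accs),
        "h-mean": mean_ci95(hmeans),
    }


# Toy usage with three made-up episodes (not results from this notebook).
stats = summarize_episodes([60.0, 55.0, 58.0], [70.0, 65.0, 66.0])
for name, (mean, ci) in stats.items():
    print(f"[{name}] acc={mean:.2f}% ± {ci:.2f}%")
```

With T = 1000 episodes, the half-width shrinks as $1/\sqrt{T}$. This is why the intervals printed below are only a few tenths of a percentage point wide.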

We report performance after both Stage-1 and Stage-2:
- Stage-1 evaluation reflects the quality of the backbone and of the cosine classifier trained only on base classes.

In [None]:
T, stats = evaluate_gfsl(stage1_model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 1 training")

[test] - After stage 1 training (95% CI on 1000 tasks)
 - [Base]   acc=57.83% ± 0.36%
 - [Novel]  acc=49.35% ± 0.54%
 - [H-mean] acc=52.64% ± 0.35%
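The Stage-1 classifier is trained only on base classes, so at test time the novel weight vectors must be built directly from the few supports. The NumPy sketch below shows this kind of joint base + novel scoring under three assumptions: novel weights are the mean of L2-normalized support features (one common way to form prototypes), novel classes are indexed after the base classes, and the cosine scale is fixed at 10. The helper names and toy data are hypothetical.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit L2 norm along `axis`."""
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def prototype_weights(support_feats, support_labels, n_novel):
    """Novel weight for class c = mean of its L2-normalized support features."""
    feats = l2_normalize(support_feats)
    return np.stack([feats[support_labels == c].mean(axis=0) for c in range(n_novel)])


def cosine_logits(query_feats, weights, scale=10.0):
    """Scaled cosine similarity between query features and class weights."""
    return scale * l2_normalize(query_feats) @ l2_normalize(weights).T


# Toy episode: random stand-ins for learned base weights and backbone features.
rng = np.random.default_rng(0)
d, n_base, n_novel, k_shot = 16, 5, 3, 2
W_base = rng.normal(size=(n_base, d))
novel_centers = rng.normal(size=(n_novel, d))
support = np.repeat(novel_centers, k_shot, axis=0) + 0.05 * rng.normal(size=(n_novel * k_shot, d))
support_labels = np.repeat(np.arange(n_novel), k_shot)

# Joint label space: base classes 0..n_base-1, novel class c -> n_base + c.
W_joint = np.concatenate([W_base, prototype_weights(support, support_labels, n_novel)])
queries = novel_centers + 0.05 * rng.normal(size=(n_novel, d))
pred = cosine_logits(queries, W_joint).argmax(axis=1)
print(pred)
```

On real CIFAR-100 features, a handful of supports gives a much noisier estimate than in this toy setup. This is consistent with the low Stage-1 novel accuracy.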


- Stage-2 evaluation measures the full DFSLwF model, where the weight generator is trained to handle pseudo-novel episodes while preserving base performance.  

In [None]:
T, stats = evaluate_gfsl(model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 2 training")

[test] - After stage 2 training (95% CI on 1000 tasks)
 - [Base]   acc=57.07% ± 0.37%
 - [Novel]  acc=67.33% ± 0.56%
 - [H-mean] acc=61.30% ± 0.32%
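Stage-2 replaces the plain prototype with a learned weight generator. The sketch below follows the attention-based formulation of the original DFSLwF paper (Gidaris & Komodakis, 2018). There, a novel weight combines a feature-averaging term with an attention-weighted sum of base weights, $w' = \phi_{avg} \odot w_{avg} + \phi_{att} \odot w_{att}$. This is a NumPy illustration with random parameters, not this notebook's implementation, and all names are hypothetical. In the real model, `phi_q`, `keys`, `phi_avg`, and `phi_att` are learned during Stage-2 episodic training.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=axis, keepdims=True)


def generate_novel_weight(support_feats, base_weights, keys, phi_q,
                          phi_avg, phi_att, gamma=10.0):
    """Attention-based weight generator for ONE novel class (sketch).

    support_feats: (k, d) support features; base_weights, keys: (n_base, d);
    phi_q: (d, d) query projection; phi_avg, phi_att: (d,) mixing vectors.
    """
    z = l2_normalize(support_feats)                       # (k, d)
    w_avg = z.mean(axis=0)                                # feature-averaging term
    q = l2_normalize(z @ phi_q.T)                         # (k, d) attention queries
    att = softmax(gamma * q @ l2_normalize(keys).T)       # (k, n_base), rows sum to 1
    w_att = (att @ l2_normalize(base_weights)).mean(axis=0)  # attention term
    return phi_avg * w_avg + phi_att * w_att


rng = np.random.default_rng(0)
d, n_base, k_shot = 8, 5, 3
support = rng.normal(size=(k_shot, d))
W_base, keys = rng.normal(size=(n_base, d)), rng.normal(size=(n_base, d))
phi_q = rng.normal(size=(d, d))
w_new = generate_novel_weight(support, W_base, keys, phi_q,
                              phi_avg=np.full(d, 0.5), phi_att=np.full(d, 0.5))
print(w_new.shape)  # (8,)
```

Setting `phi_att` to zero and `phi_avg` to ones reduces the generator to the plain prototype used after Stage-1. The attention term lets a novel weight borrow structure from related base classes, which is one intuition for the large jump in novel accuracy.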


Comparing the two results shows how much the generator improves novel recognition without forgetting base knowledge:

- **After Stage-1**, base accuracy is solid (57.83%) but novel accuracy is limited (49.35%), since novel weights are only crude prototypes.  
- **After Stage-2**, **base accuracy remains essentially stable** (57.07%, a drop of under one point), while **novel accuracy increases substantially** (67.33%, about +18 points) thanks to the generator.  

As a result, the **harmonic mean rises from $52.64\%$ to $61.30\%$** (about $+8.7$ points), showing that DFSLwF balances base and novel performance far better than the Stage-1 model alone.
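As a sanity check on these figures, we can plug the averaged accuracies into the harmonic-mean formula:

$$
H = \frac{2\,A_{\text{base}}\,A_{\text{novel}}}{A_{\text{base}} + A_{\text{novel}}},\qquad
H_{\text{Stage-1}} = \frac{2 \cdot 57.83 \cdot 49.35}{57.83 + 49.35} \approx 53.25\%,\qquad
H_{\text{Stage-2}} = \frac{2 \cdot 57.07 \cdot 67.33}{57.07 + 67.33} \approx 61.78\%.
$$

Both values sit slightly above the reported H-means (52.64% and 61.30%). This is what one would expect if `evaluate_gfsl` computes the H-mean per episode and then averages it, though that is an assumption about its internals. The harmonic mean is concave, so by Jensen's inequality the average of per-episode H-means can never exceed the H-mean of the averaged accuracies.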

