# **Generalized Few-Shot Learning (GFSL)**

Intro

---

## **1. Background**

### **1.1 Generalized Few-Shot Learning (GFSL)**

???

### **1.2 Dynamic Few-Shot Learning without Forgetting**

???

### **1.3 CIFAR100 dataset**

???

---

## **2. Practice**

In [1]:
from __future__ import annotations
import math
import random
import types
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import copy
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Seed:

In [2]:
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_rng = np.random.default_rng(SEED + 1)
test_rng = np.random.default_rng(SEED + 2)

Settings:

In [3]:
# Few-shot / episodic config (Stage 2 + Test)
N_WAY = 5
K_SHOT = 5
Q_NOVEL = 15
Q_BASE_TOTAL = 75

# CIFAR-100 split sizes
N_BASE = 64
N_VALNOVEL = 16
N_TESTNOVEL = 20

# Initial network params
TAU_INIT = 10.0  # temperature init

# Stage 1
STAGE1_EPOCHS = 120
STAGE1_LR = 3e-3
STAGE1_BS = 512
STAGE1_WEIGHT_DECAY = 5e-4
STAGE1_LABEL_SMOOTH = 0.1
S1_VAL_FRAC   = 0.10   # per-class fraction held out for validation
S1_VAL_EVERY  = 2      # validate every N epochs
S1_PATIENCE   = 5      # early stop after N validations without improvement
S1_LOG_EVERY  = 50     # log the loss every N batches

# Stage 2
STAGE2_TASKS = 20_000
STAGE2_VAL_EVERY = 500
STAGE2_LR = 5e-4
STAGE2_GRAD_CLIP = 1.0
S2_VAL_TASKS      = 1_000    # number of episodes used at each Stage-2 validation
S2_PATIENCE       = 5        # early stopping (GFSL) after N validations without improvement
S2_SELECT_METRIC  = "hmean"  # metric used to select the best checkpoint (hmean|base|novel)

# Test
TEST_TASKS = 1_000
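
The episodic settings above fix the size of every episode. As a quick sanity check, the sketch below redoes that arithmetic in plain Python. The constants are copied from the settings cell; the derived names (`n_support`, `n_query_novel`, `n_query`, `episode_size`) are illustrative and are not used elsewhere in the notebook.

```python
# Constants copied from the settings cell.
N_WAY, K_SHOT, Q_NOVEL, Q_BASE_TOTAL = 5, 5, 15, 75
N_BASE, N_VALNOVEL, N_TESTNOVEL = 64, 16, 20

# Illustrative derived quantities (hypothetical names).
n_support = N_WAY * K_SHOT               # support images per episode
n_query_novel = N_WAY * Q_NOVEL          # novel-class queries per episode
n_query = n_query_novel + Q_BASE_TOTAL   # all queries (novel + base)
episode_size = n_support + n_query       # images loaded per episode

print(n_support, n_query_novel, n_query, episode_size)  # 25 75 150 175

# The three class splits together cover all 100 CIFAR-100 classes.
assert N_BASE + N_VALNOVEL + N_TESTNOVEL == 100
```

With these values the novel and base query sets have the same size (75 each), so neither group dominates the per-episode query batch.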

Utility functions:

In [4]:
def split_cifar100_classes(seed: int, n_base=64, n_val=16, n_test=20):
    """Split CIFAR-100 class IDs into base/val-novel/test-novel sets.

    Shuffles the 100 class IDs with a reproducible RNG and returns three
    disjoint lists for base classes (used for supervised training), validation
    novel classes (optional episodic validation), and test novel classes
    (used in GFSL evaluation).

    Args:
        seed: Random seed for the class shuffling.
        n_base: Number of base classes.
        n_val: Number of validation novel classes.
        n_test: Number of test novel classes.

    Returns:
        A tuple (base, valn, testn) where each element is a list of class IDs.
    """
    rng = np.random.default_rng(seed)
    classes = np.arange(100)
    rng.shuffle(classes)
    base = classes[:n_base].tolist()
    valn = classes[n_base:n_base + n_val].tolist()
    testn = classes[n_base + n_val:n_base + n_val + n_test].tolist()
    return base, valn, testn

def subset_by_classes(ds, keep):
    """Return a Subset containing only samples whose label is in `keep`.

    Uses vectorized filtering over `ds.targets` to select the indices that
    belong to the provided set of class IDs.

    Args:
        ds: A torchvision-style dataset exposing `targets` (list/array of ints).
        keep: Iterable of class IDs to retain.

    Returns:
        torch.utils.data.Subset wrapping `ds` with filtered indices.
    """
    t = np.array(ds.targets)
    idx = np.nonzero(np.isin(t, keep))[0]
    return Subset(ds, idx)

def class_to_local_indices(subset):
    """Build a mapping class_id -> list of *local* indices within `subset`.

    Iterates over the subset indices and groups them by their original class
    ID (read from `subset.dataset.targets`). Useful for fast episodic sampling
    (e.g., drawing K support and Q query images per class).

    Args:
        subset: A torch Subset whose `dataset` exposes `targets`
            and whose `indices` reference the original dataset.

    Returns:
        Dict[int, List[int]] mapping each class ID to a list of local indices
        (0..len(subset)-1) within the subset.

    Notes:
        The returned indices are local to `subset` (not the original dataset).
    """
    t = np.array(subset.dataset.targets)
    out = {}
    for j, i in enumerate(subset.indices):
        y = int(t[i])
        (out.setdefault(y, [])).append(j)
    return out

def stratified_split_subset(subset: Subset, val_frac: float, seed: int):
    """
    Split a Subset (same classes) into (train_part, val_part),
    preserving the proportions of each original CIFAR class.
    """
    rng = np.random.default_rng(seed)
    ds_targets = np.array(subset.dataset.targets)
    # map: class_id -> list of LOCAL indices within the subset
    cls2locals = {}
    for j, i in enumerate(subset.indices):
        y = int(ds_targets[i])
        (cls2locals.setdefault(y, [])).append(j)

    train_locals, val_locals = [], []
    for y, locals_ in cls2locals.items():
        locals_ = np.array(locals_, dtype=int)
        n_val = max(1, int(round(len(locals_) * val_frac)))
        rng.shuffle(locals_)
        val_locals.append(locals_[:n_val])
        train_locals.append(locals_[n_val:])

    train_locals = np.concatenate(train_locals).tolist()
    val_locals   = np.concatenate(val_locals).tolist()

    train_indices = [subset.indices[i] for i in train_locals]
    val_indices   = [subset.indices[i] for i in val_locals]

    return Subset(subset.dataset, train_indices), Subset(subset.dataset, val_indices)
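
To make the per-class hold-out logic of `stratified_split_subset` concrete, here is a minimal sketch on plain Python lists. It uses the standard-library `random` module instead of NumPy, and `stratified_split` plus the toy labels are hypothetical stand-ins, not part of the notebook.

```python
import random

def stratified_split(labels, val_frac, seed):
    """Return (train_idx, val_idx) with about val_frac of each class in val."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    train_idx, val_idx = [], []
    for y, idxs in by_class.items():
        idxs = idxs[:]          # copy so the grouping is not mutated
        rng.shuffle(idxs)
        n_val = max(1, int(round(len(idxs) * val_frac)))  # at least one per class
        val_idx.extend(idxs[:n_val])
        train_idx.extend(idxs[n_val:])
    return train_idx, val_idx

# Toy labels: 3 classes with 20 samples each (hypothetical data).
labels = [c for c in range(3) for _ in range(20)]
train_idx, val_idx = stratified_split(labels, val_frac=0.10, seed=0)
print(len(train_idx), len(val_idx))  # 54 6
```

Applied to the CIFAR-100 train split (500 images per class) with `S1_VAL_FRAC = 0.10`, the same rule holds out 50 images per class, i.e. 3,200 validation and 28,800 training images over the 64 base classes.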

In [5]:
def l2_normalize(x: torch.Tensor, dim: int = 1, eps: float = 1e-6) -> torch.Tensor:
  """L2-normalize a tensor along a given dimension.

  Each vector along `dim` is divided by its L2 norm, producing unit-length
  vectors. A small epsilon is used to avoid division by zero.

  Args:
      x: Input tensor.
      dim: Dimension along which to compute the L2 norm (default: 1).
      eps: Minimum norm value used for numerical stability (default: 1e-6).

  Returns:
      A tensor with the same shape as `x`, L2-normalized along `dim`.
  """
  return x / (x.norm(p=2, dim=dim, keepdim=True).clamp_min(eps))
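
A small worked example may help here. The sketch below mirrors the same formula (divide by the norm, clamped from below by `eps`) on plain Python lists; `l2_normalize_list` is a hypothetical helper written only for this illustration.

```python
import math

def l2_normalize_list(v, eps=1e-6):
    """Divide a vector by its L2 norm, clamping the norm to at least eps."""
    norm = math.sqrt(sum(x * x for x in v))
    norm = max(norm, eps)   # same role as clamp_min(eps) in the tensor version
    return [x / norm for x in v]

print(l2_normalize_list([3.0, 4.0]))  # [0.6, 0.8]
print(l2_normalize_list([0.0, 0.0]))  # [0.0, 0.0]
```

The second call shows why the clamp matters: a zero vector stays zero instead of producing a division by zero.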

In [6]:
def set_bn_eval(m: nn.Module):
  """Put BatchNorm2d layers in eval mode and freeze their parameters.

  When applied (e.g., `model.apply(set_bn_eval)`), this sets each
  `nn.BatchNorm2d` module to evaluation mode so it uses stored running
  statistics and stops updating them, and it disables gradient updates
  for its affine parameters (gamma/beta).

  Args:
      m: A module that may be an instance of `nn.BatchNorm2d`.

  Returns:
      None. The module is modified in place if it is BatchNorm2d.
  """
  if isinstance(m, nn.BatchNorm2d):
      m.eval()
      for p in m.parameters():
          p.requires_grad = False

In [7]:
def sliding_avg(xs: List[float], k: int = 20) -> float:
  """Mean of the last `k` values of `xs` (of all values if fewer than `k`; 0.0 if empty)."""
  if not xs:
      return 0.0
  if len(xs) < k:
      return float(sum(xs) / len(xs))
  return float(sum(xs[-k:]) / k)

In [8]:
def mean_ci95(xs: np.ndarray) -> Tuple[float, float]:
    """Return (mean, half-width of the 95% CI) using the normal approximation (z = 1.96)."""
    xs = np.asarray(xs, dtype=float)
    n  = xs.size
    if n < 2:
        return float(xs.mean()), 0.0
    std = xs.std(ddof=1)
    stderr = std / np.sqrt(n)
    z = 1.96
    return float(xs.mean()), float(z * stderr)

def gfsl_stats(
    acc_per_ep_base: List[float],
    acc_per_ep_novel: List[float],
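) -> Dict[str, Dict[str, float]]:
    """Summarize per-episode base/novel accuracy and their harmonic mean.

    The harmonic mean is computed per episode and then averaged; each entry
    reports the mean and the 95% CI half-width.
    """
    if len(acc_per_ep_base) != len(acc_per_ep_novel):
        raise ValueError("base and novel must be of the same length")

    base = np.asarray(acc_per_ep_base, dtype=float)
    novel = np.asarray(acc_per_ep_novel, dtype=float)

    # Harmonic mean per episode; episodes with base + novel == 0 get 0.0.
    # np.divide with `where` skips those entries, avoiding a divide-by-zero warning.
    denom = base + novel
    h_per_ep = np.divide(2.0 * base * novel, denom, out=np.zeros_like(denom), where=denom > 0)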

    base_mean, base_ci  = mean_ci95(base)
    novel_mean, novel_ci = mean_ci95(novel)
    h_mean, h_ci = mean_ci95(h_per_ep)

    return {
        "base":  {"mean": base_mean,  "conf": base_ci},
        "novel": {"mean": novel_mean, "conf": novel_ci},
        "hmean": {"mean": h_mean,     "conf": h_ci},
    }

def print_stats(T: int, stats: Dict[str, Dict[str, float]], model: str = ""):
  print(f"[test] - {model} (95% CI on {T} tasks)")
  print(f" - [Base]   acc={100*stats['base']['mean']:.2f}% ± {100*stats['base']['conf']:.2f}%")
  print(f" - [Novel]  acc={100*stats['novel']['mean']:.2f}% ± {100*stats['novel']['conf']:.2f}%")
  print(f" - [H-mean] acc={100*stats['hmean']['mean']:.2f}% ± {100*stats['hmean']['conf']:.2f}%")
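
The statistics above can be checked by hand on a few numbers. The following sketch recomputes the per-episode harmonic mean and the 95% confidence half-width with the standard library only; the helper names and the accuracy values are hypothetical and chosen purely for illustration.

```python
import math
import statistics

def harmonic_mean(base_acc, novel_acc):
    """Harmonic mean of two accuracies; 0.0 when both are zero."""
    s = base_acc + novel_acc
    return 2.0 * base_acc * novel_acc / s if s > 0 else 0.0

def mean_ci95_list(xs):
    """Mean and 1.96 * standard error (sample std, ddof=1)."""
    m = statistics.fmean(xs)
    if len(xs) < 2:
        return m, 0.0
    return m, 1.96 * statistics.stdev(xs) / math.sqrt(len(xs))

# Hypothetical per-episode accuracies.
base = [0.60, 0.70, 0.65, 0.55]
novel = [0.30, 0.40, 0.35, 0.25]
h = [harmonic_mean(b, n) for b, n in zip(base, novel)]

print(round(harmonic_mean(0.6, 0.3), 4))  # 0.4
print(round(harmonic_mean(0.9, 0.1), 4))  # 0.18
print(mean_ci95_list(h))                  # (mean, CI half-width) over episodes
```

The second line illustrates why the harmonic mean is a useful GFSL summary: a model with 90% base and 10% novel accuracy has an arithmetic mean of 50% but a harmonic mean of only 18%, so a strong imbalance between the two groups is penalized.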


### **2.1 Environment**

#### **2.1.1 CIFAR100 dataset**

???

In [9]:
# Transforms
IMAGE_SIZE = 32

train_tf = transforms.Compose([
    transforms.RandomCrop(IMAGE_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
eval_tf = transforms.Compose([transforms.ToTensor(),])

In [10]:
ds_train = CIFAR100(root="./data", train=True,  transform=train_tf, download=True)
ds_test  = CIFAR100(root="./data", train=False, transform=eval_tf,  download=True)

base, valn, testn = split_cifar100_classes(SEED)

In [11]:
# Stage 1: Train + validation (base) / Stage 2: Train (base)
train_base_full = subset_by_classes(ds_train, base)
train_base_tr, train_base_val = stratified_split_subset(train_base_full, S1_VAL_FRAC, SEED+7)

# Stage 2: Validation (Novel)
train_valnovel = subset_by_classes(ds_train, valn)

# Test (base + novel)
test_base  = subset_by_classes(ds_test,  base)
test_novel = subset_by_classes(ds_test,  testn)

In [12]:
# For Stage 1 (supervised)
cti_train_base_tr  = class_to_local_indices(train_base_tr)
cti_train_base_val = class_to_local_indices(train_base_val)

# For Stage-2 validation (episodic GFSL): everything comes from the train split
cti_val_base   = cti_train_base_val
cti_val_novel  = class_to_local_indices(train_valnovel)

# For test (episodic GFSL): as before
cti_test_base  = class_to_local_indices(test_base)
cti_test_novel = class_to_local_indices(test_novel)

In [13]:
len(test_novel)

2000

**2.1.1.1 DataLoader: Stage 1**

In [14]:
class Stage1TrainDS(torch.utils.data.Dataset):
    def __init__(self, subset: Subset, orig_targets: List[int], order_map: Dict[int, int]):
        """Subset over base classes with precomputed local labels in [0..Cb-1]."""
        self.subset = subset
        self.local_labels = [order_map[int(orig_targets[i])] for i in subset.indices]

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        x, _ = self.subset[idx]
        return x, self.local_labels[idx]

In [15]:
base_order = sorted(base)
order_map  = {cid: i for i, cid in enumerate(base_order)}

stage1_train_ds = Stage1TrainDS(train_base_tr,  ds_train.targets, order_map)  # train
stage1_val_ds   = Stage1TrainDS(train_base_val, ds_train.targets, order_map)  # val (disjoint from train)

train_loader_s1 = DataLoader(stage1_train_ds, batch_size=STAGE1_BS, shuffle=True, num_workers=2, pin_memory=True)
val_loader_s1   = DataLoader(stage1_val_ds,   batch_size=STAGE1_BS*2, shuffle=False, num_workers=2, pin_memory=True)

**2.1.1.2 DataLoader: Stage 2**

In [16]:
class GFSLTrainEpisodicBatchSampler:
    """
    Stage-2: N_WAY pseudo-novel (from BASE train) with K+Qn each + Qb base queries from BASE (any class).
    Returns indices over the Subset(train_base).
    """
    def __init__(self, class_to_indices: Dict[int, List[int]], n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        self.cti = class_to_indices
        self.all_classes = list(class_to_indices.keys())
        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self): return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):

            novel = self.rng.choice(self.all_classes, size=self.n_way, replace=False)

            batch = []
            for c in novel:
                pool = self.cti[c]
                need = self.k_shot + self.q_novel
                if len(pool) < need:
                    raise ValueError(f"Class {c} has {len(pool)} < {need}")
                idx = self.rng.choice(pool, size=need, replace=False)
                batch.append(idx)

            novel_classes = set(novel.tolist())
            base_pool_classes = [c for c in self.all_classes if c not in novel_classes]

            used = set(np.concatenate(batch).tolist())

            base_q = []
            while len(base_q) < self.q_base_total:
                c = int(self.rng.choice(base_pool_classes))
                cand = int(self.rng.choice(self.cti[c]))
                if cand not in used:
                    base_q.append(cand)
                    used.add(cand)

            batch.append(np.array(base_q, dtype=int))
            yield np.concatenate(batch)

stage2_train_batch_sampler = GFSLTrainEpisodicBatchSampler(
    cti_train_base_tr, n_tasks=STAGE2_TASKS, rng=train_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

In [17]:
def stage2_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    """
    Output:
      - support_novel: (N*K, C, H, W)
      - query_images : (N*Qn + Qb, C, H, W)
      - true_novel_ids (original CIFAR ids, one per novel class)
      - base_query_labels_cifar: (Qb,) original CIFAR ids for base queries (for true targets)
    """
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base = images[total_novel:]

    query_images = torch.cat([query_novel, query_base], dim=0)

    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    base_query_labels_cifar = labels[total_novel:]  # original CIFAR ids
    return support_novel, query_images, true_novel_ids, base_query_labels_cifar

stage2_collate_fn = partial(
    stage2_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)
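
The collate function relies on the fixed episode layout produced by the sampler: one contiguous chunk of `K + Qn` items per pseudo-novel class, followed by the base queries. The sketch below replays that slicing on a list of string tags instead of image tensors; `split_episode` and the tags are hypothetical and only illustrate the indexing.

```python
def split_episode(items, n_way, k_shot, q_novel):
    """Split a flat episode into (support, queries) following the sampler layout."""
    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel
    support, query_novel = [], []
    for c in range(n_way):
        chunk = items[c * per_novel:(c + 1) * per_novel]
        support.extend(chunk[:k_shot])       # first K items of each class chunk
        query_novel.extend(chunk[k_shot:])   # remaining Qn items of each class chunk
    query_base = items[total_novel:]         # everything after the novel block
    return support, query_novel + query_base

# Toy episode: 2 novel classes (A, B), 1 shot, 2 novel queries each, 3 base queries.
episode = ["A_s0", "A_q0", "A_q1", "B_s0", "B_q0", "B_q1", "base0", "base1", "base2"]
support, queries = split_episode(episode, n_way=2, k_shot=1, q_novel=2)
print(support)  # ['A_s0', 'B_s0']
print(queries)  # ['A_q0', 'A_q1', 'B_q0', 'B_q1', 'base0', 'base1', 'base2']
```

The query order is therefore always novel first, base second, which is what any downstream target construction has to assume.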

In [18]:
class Stage2TrainDS(torch.utils.data.Dataset):
    def __init__(self, subset: Subset):
        self.subset = subset
    def __len__(self):
      return len(self.subset)
    def __getitem__(self, idx):
      return self.subset[idx]

stage2_train_ds = Stage2TrainDS(train_base_tr)

In [19]:
train_loader_s2 = DataLoader(
    stage2_train_ds, batch_sampler=stage2_train_batch_sampler,
    collate_fn=stage2_collate_fn, num_workers=2, pin_memory=True
)

**2.1.1.3 DataLoader: Stage 2 Validation and Test**

In [20]:
class GFSLEvalEpisodeSampler:
    """Batch sampler for GFSL test episodes over a single ConcatDataset.

    Each yielded batch is a 1D numpy array of indices into `test_concat`
    laid out as:
        [ N_WAY*(K_SHOT+Q_NOVEL) indices from novel part | Q_BASE_TOTAL indices from base part ]

    This mirrors Stage-2 structure, but draws from the *test* splits:
      - support + query for true novel classes from the novel test subset
      - base queries from the base test subset
    """

    def __init__(self, base_cti: Dict[int, List[int]], novel_cti: Dict[int, List[int]], offset_base: int, n_tasks: int, rng: np.random.Generator,
                 n_way: int = N_WAY, k_shot: int = K_SHOT, q_novel: int = Q_NOVEL, q_base_total: int = Q_BASE_TOTAL):
        # class-id -> list of local indices (within each Subset) for base/novel test splits
        self.base_cti = base_cti
        self.novel_cti = novel_cti
        self.offset_base = offset_base

        # explicit class-id pools
        self.base_classes  = list(base_cti.keys())
        self.novel_classes = list(novel_cti.keys())

        # episode config
        self.n_tasks = n_tasks
        self.rng = rng
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_novel = q_novel
        self.q_base_total = q_base_total

    def __len__(self):
        return self.n_tasks

    def __iter__(self):
        for _ in range(self.n_tasks):
            # ---- NOVEL BLOCK: pick N_WAY true novel classes and sample K+Qn per class (no replacement) ----
            chosen_novel = self.rng.choice(self.novel_classes, size=self.n_way, replace=False)

            novel_chunks = []
            per_novel = self.k_shot + self.q_novel

            for c in chosen_novel:
                pool = self.novel_cti[c]                 # local indices within test_novel
                if len(pool) < per_novel:
                    raise ValueError(f"Novel class {c} has {len(pool)} < {per_novel}")
                # sample K+Qn *without* replacement to avoid reusing the same image as support/query
                idx = self.rng.choice(pool, size=per_novel, replace=False)
                novel_chunks.append(idx)

            # Flatten the novel part; indices are still in the "novel namespace" (no offset)
            novel_block = np.concatenate(novel_chunks).astype(int)

            # ---- BASE BLOCK: sample Qb indices from the base part (with replacement, so duplicates are possible) ----
            base_q = []
            while len(base_q) < self.q_base_total:
                c = int(self.rng.choice(self.base_classes))
                # sample a local index within test_base
                cand_local = int(self.rng.choice(self.base_cti[c]))
                # shift to address the second component of ConcatDataset
                base_q.append(cand_local + self.offset_base)

            base_block = np.array(base_q, dtype=int)

            # ---- FINAL EPISODE ----
            # Concatenate [novel | base] to match the downstream collate expectations
            full_episode = np.concatenate([novel_block, base_block])
            yield full_episode
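
The `offset_base` shift is the only subtle part of this sampler. `ConcatDataset` indexes its component datasets one after the other, so a local index into the second component must be shifted by the length of the first. A minimal sketch with plain lists standing in for the two subsets (the names and contents are hypothetical):

```python
novel_part = ["n0", "n1", "n2", "n3"]   # stands in for the novel Subset
base_part = ["b0", "b1", "b2"]          # stands in for the base Subset

# Same ordering as ConcatDataset([novel, base]): novel first, base second.
concat = novel_part + base_part
offset_base = len(novel_part)

def to_global(local_idx, is_base):
    """Map a local index within one part to an index into the concatenation."""
    return local_idx + offset_base if is_base else local_idx

print(concat[to_global(2, is_base=False)])  # n2
print(concat[to_global(1, is_base=True)])   # b1
```

This is why the sampler must be constructed with the length of the novel subset as the offset, and why the two subsets must be concatenated in the same order.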

In [21]:
def eval_collate(batch, n_way: int, k_shot: int, q_novel: int, q_base_total: int):
    """Reconstruct support/query tensors for a GFSL test episode.

    Input `batch` is a list of (image, label) from `test_concat` where the first
    N_WAY*(K_SHOT+Q_NOVEL) items belong to the novel subset and the remaining Q_BASE_TOTAL
    items belong to the base subset.
    """
    imgs, labs = list(zip(*batch))
    images = torch.stack(imgs)
    labels = torch.tensor([int(y) for y in labs])

    per_novel = k_shot + q_novel
    total_novel = n_way * per_novel

    # reshape novel block into (N, K+Qn, C, H, W) and (N, K+Qn) for labels
    novel_block = images[:total_novel].view(n_way, per_novel, *images.shape[1:])
    novel_labels_block = labels[:total_novel].view(n_way, per_novel)

    # split into support (first K) and query (last Qn)
    support_novel = novel_block[:, :k_shot].reshape(-1, *images.shape[1:])
    query_novel   = novel_block[:, k_shot:].reshape(-1, *images.shape[1:])
    query_base    = images[total_novel:]  # remaining Qb from base subset

    # concatenate all queries [novel | base]
    query_images = torch.cat([query_novel, query_base], dim=0)

    # collect true novel CIFAR IDs (one per class) and per-query GT labels
    true_novel_ids = [int(novel_labels_block[i, 0].item()) for i in range(n_way)]
    gt_novel = novel_labels_block[:, k_shot:].reshape(-1)
    gt_base  = labels[total_novel:]

    return support_novel, query_images, true_novel_ids, gt_novel, gt_base

eval_collate_fn = partial(
    eval_collate,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)

Validation:

In [22]:
sampler_val_s2 = GFSLEvalEpisodeSampler(
    cti_val_base, cti_val_novel, len(train_valnovel),
    n_tasks=S2_VAL_TASKS, rng=test_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL
)

In [23]:
# Build a single validation dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
val_concat_s2 = ConcatDataset([train_valnovel, train_base_val])

In [24]:
val_loader_s2 = DataLoader(
    val_concat_s2,
    batch_sampler=sampler_val_s2,
    collate_fn=eval_collate_fn,
    num_workers=2,
    pin_memory=True,
)

Test:

In [25]:
sampler_test = GFSLEvalEpisodeSampler(
    cti_test_base, cti_test_novel, len(test_novel), n_tasks=TEST_TASKS, rng=test_rng,
    n_way=N_WAY, k_shot=K_SHOT, q_novel=Q_NOVEL, q_base_total=Q_BASE_TOTAL,
)

In [26]:
# Build a single test dataset by concatenating the two subsets.
# Order matters: novel part first, base part second.
test_concat = ConcatDataset([test_novel, test_base])

In [27]:
test_loader = DataLoader(
    test_concat,
    batch_sampler=sampler_test,   # emits indices into the single concatenated dataset
    collate_fn=eval_collate_fn,   # reconstructs (support/query etc.) from that batch
    num_workers=2,
    pin_memory=True,
)

#### **2.1.2 DFSLwF module**

???

In [28]:
class ConvBlock(nn.Module):
    """Conv3x3 -> BN -> ReLU -> (optional MaxPool2d)."""
    def __init__(self, in_ch, out_ch, pool: bool):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2) if pool else nn.Identity()

    def forward(self, x):
        x = self.conv(x); x = self.bn(x); x = self.relu(x); x = self.pool(x)
        return x


class FeatureExtractor(nn.Module):
    """Feature extractor with selectable backbone: 'conv' (light) or 'resnet18'.

    Args:
        backbone: 'conv' or 'resnet18'.
        normalize_out: if True, L2-normalize the output features.
        resnet_pretrained: if True (only for 'resnet18'), load ImageNet pretrained weights.
        remove_last_relu: if True (only for 'resnet18'), remove the last post-add ReLU in the final BasicBlock
                          (useful with cosine classifiers, per DFSLwF ablations).
    """
    def __init__(
        self,
        backbone: str = "conv",
        normalize_out: bool = True,
        resnet_pretrained: bool = True,
        remove_last_relu: bool = True,
    ):
        super().__init__()
        self.normalize_out = normalize_out

        if backbone.lower() == "conv":
            # Lightweight CIFAR-style CNN trained from scratch: 64-D output
            self.fe = nn.Sequential(
                ConvBlock(3,   64, pool=True),   # 32 -> 16
                ConvBlock(64,  64, pool=True),   # 16 -> 8
                ConvBlock(64,  64, pool=False),
                ConvBlock(64,  64, pool=False),
                nn.AdaptiveAvgPool2d(1),         # -> (B,64,1,1)
            )
            self._mode = "conv"
            self.out_dim = 64

        elif backbone.lower() == "resnet18":
            weights = ResNet18_Weights.IMAGENET1K_V1 if resnet_pretrained else None
            m = resnet18(weights=weights)
            m.fc = nn.Identity()  # we want the 512-D penultimate features

            if remove_last_relu:
                # Patch only the final BasicBlock to skip the post-add ReLU
                last_block = m.layer4[-1]

                def forward_norelu(self_block, x):
                    identity = x
                    out = self_block.conv1(x); out = self_block.bn1(out); out = self_block.relu(out)
                    out = self_block.conv2(out); out = self_block.bn2(out)
                    if self_block.downsample is not None:
                        identity = self_block.downsample(x)
                    out = out + identity
                    return out  # no final ReLU

                last_block.forward = types.MethodType(forward_norelu, last_block)

            self.fe = m
            self._mode = "resnet18"
            self.out_dim = 512

        else:
            raise ValueError(f"Unknown backbone '{backbone}'. Use 'conv' or 'resnet18'.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._mode == "conv":
            z = self.fe(x).squeeze(-1).squeeze(-1)    # (B, 64)
        else:  # resnet18
            z = self.fe(x)                            # (B, 512)
        return l2_normalize(z, dim=1) if self.normalize_out else z

In [29]:
class CosineClassifier(nn.Module):
    """Cosine classifier with learnable temperature τ (tau)."""
    def __init__(self, in_dim: int, n_classes: int, init_scale: float = TAU_INIT):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_classes, in_dim))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.tau = nn.Parameter(torch.tensor(float(init_scale)))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        W = l2_normalize(self.weight, dim=1)
        feats = l2_normalize(feats, dim=1)
        logits = feats @ W.t()
        return self.tau * logits
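
To see what the temperature does, the sketch below computes cosine logits for one feature vector against three class weights and applies a softmax at two values of τ. It is plain Python with made-up 2-D vectors, so it only illustrates the scaling effect and is not the notebook's classifier.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical feature and class weights.
feat = [1.0, 0.2]
weights = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
cos = [cosine(feat, w) for w in weights]

probs_by_tau = {}
for tau in (1.0, 10.0):
    probs_by_tau[tau] = softmax([tau * c for c in cos])
    print(tau, [round(p, 3) for p in probs_by_tau[tau]])
```

Because cosine similarities are bounded in [-1, 1], unscaled logits give a fairly flat softmax; multiplying by a larger τ (the notebook initializes it to `TAU_INIT = 10.0`) sharpens the distribution.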

In [30]:
class FewShotWeightGenerator(nn.Module):
    """
    [DFSLwF] Few-shot classification weight generator = Avg + Attention:
      - w_avg  = mean(z_i)                     → φ_avg ⊙ w_avg
      - w_att  = avg_i softmax(γ cos(φ_q z_i, k_b)) · w_b (over base classes b) → φ_att ⊙ w_att
      - w'     = φ_avg ⊙ w_avg + φ_att ⊙ w_att, then L2-normalize
    Includes:
      - learnable keys k_b (one per base class), size (C_base, D)
      - learnable φ_q (Linear D→D, no bias), φ_avg, φ_att (vectors), γ (scalar)
      - optional dropout on features during training (Stage-2)
    """
    def __init__(self, dim: int, num_base: int, p_dropout: float = 0.5):
        super().__init__()
        self.dim = dim
        self.num_base = num_base
        self.phi_q = nn.Linear(dim, dim, bias=False)
        nn.init.kaiming_uniform_(self.phi_q.weight, a=math.sqrt(5))
        self.keys = nn.Parameter(l2_normalize(torch.randn(num_base, dim), dim=1))
        self.phi_avg = nn.Parameter(torch.ones(dim))
        self.phi_att = nn.Parameter(torch.ones(dim))
        self.gamma = nn.Parameter(torch.tensor(10.0))  # attention temperature (like τ)
        self.dropout = nn.Dropout(p=p_dropout)

    def forward(
        self,
        support_feats: torch.Tensor,          # (N*K, D), L2-normalized
        base_weights: torch.Tensor,           # (C_base, D), not necessarily normalized
        shots_per_class: int,
        exclude_mask: Optional[torch.Tensor] = None  # (C_base,) bool; True = keep; False = exclude
    ) -> torch.Tensor:
        D = support_feats.size(1)
        N = support_feats.size(0) // shots_per_class
        z = support_feats.view(N, shots_per_class, D)
        if self.training:
            z = self.dropout(z)
        # w_avg
        w_avg = l2_normalize(z.mean(dim=1), dim=1)  # (N, D)
        # attention over base weights
        Wb = l2_normalize(base_weights, dim=1)      # (C_base, D)
        Kb = l2_normalize(self.keys, dim=1)         # (C_base, D)
        # queries
        q = self.phi_q(z)                           # (N, K, D)
        q = l2_normalize(q, dim=2)                 # normalize across D
        # cosine(q, Kb) => (N, K, C_base)
        att_logits = torch.einsum("nkd,bd->nkb", q, Kb) * self.gamma
        if exclude_mask is not None:
            # set -inf on excluded classes before softmax
            mask = exclude_mask.view(1, 1, -1)  # broadcast
            att_logits = att_logits.masked_fill(~mask, float("-inf"))
        att = torch.softmax(att_logits, dim=2)      # (N, K, C_base)
        # weighted sum of base weights -> (N, K, D), then average over K (shots)
        w_att = torch.einsum("nkb,bd->nkd", att, Wb).mean(dim=1)  # (N, D)
        # combine
        w = self.phi_avg * w_avg + self.phi_att * w_att            # Hadamard
        w = l2_normalize(w, dim=1)  # (N, D)
        return w
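
The two branches of the generator can be traced by hand on a tiny example. The sketch below is a simplified plain-Python version for a single novel class: it assumes φ_q is the identity and φ_avg = φ_att = 1 (the notebook learns these), and it omits dropout and the exclusion mask. All vectors are hypothetical.

```python
import math

def normalize(v, eps=1e-6):
    n = max(math.sqrt(sum(x * x for x in v)), eps)
    return [x / n for x in v]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def generate_weight(support, base_weights, keys, gamma=10.0):
    """Avg + attention weight for one novel class (simplifying assumptions above)."""
    dim = len(support[0])
    z = [normalize(s) for s in support]
    # Feature-averaging branch: normalized mean of the support features.
    w_avg = normalize([sum(f[d] for f in z) / len(z) for d in range(dim)])
    # Attention branch: each shot attends over base keys, then mixes base weights.
    Wb = [normalize(w) for w in base_weights]
    Kb = [normalize(k) for k in keys]
    w_att = [0.0] * dim
    for f in z:
        att = softmax([gamma * dot(f, k) for k in Kb])
        for a, w in zip(att, Wb):
            for d in range(dim):
                w_att[d] += a * w[d] / len(z)   # average over shots
    # Combine both branches and renormalize.
    return normalize([a + b for a, b in zip(w_avg, w_att)])

support = [[1.0, 0.1, 0.0], [0.9, 0.0, 0.1]]        # 2 shots
base_weights = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # 2 base classes
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
w = generate_weight(support, base_weights, keys)
print([round(x, 3) for x in w])
```

In this toy setting the support features are close to the first base key, so the attention branch mostly copies the first base weight and the generated weight points largely along the first axis.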

In [31]:
class DFSLwF(nn.Module):
    def __init__(self, fe: FeatureExtractor, clf_base: CosineClassifier, gen: FewShotWeightGenerator):
        super().__init__()
        self.fe = fe
        self.clf_base = clf_base
        self.gen = gen

    def forward_logits(self, x: torch.Tensor, novel_weights: torch.Tensor | None = None) -> torch.Tensor:
        feats = self.fe(x)
        Wb = l2_normalize(self.clf_base.weight, dim=1)
        logits = self.clf_base.tau * (feats @ Wb.t())
        if novel_weights is not None and novel_weights.numel() > 0:
            Wn = l2_normalize(novel_weights, dim=1)
            logits_n = self.clf_base.tau * (feats @ Wn.t())
            logits = torch.cat([logits, logits_n], dim=1)
        return logits

    @torch.no_grad()
    def build_novel_weights(self, support_imgs: torch.Tensor, k_shot: int) -> torch.Tensor:
        supp = self.fe(support_imgs)  # (N*K, D), already normalized
        Wb = self.clf_base.weight
        return self.gen(supp, Wb, k_shot, exclude_mask=None)

In [32]:
# fe  = FeatureExtractor(backbone="resnet18", normalize_out=True, resnet_pretrained=True, remove_last_relu=True)
fe  = FeatureExtractor(backbone="conv", normalize_out=True)

clf = CosineClassifier(in_dim=fe.out_dim, n_classes=len(base_order))
gen = FewShotWeightGenerator(dim=fe.out_dim, num_base=len(base_order), p_dropout=0.5)

model = DFSLwF(fe=fe, clf_base=clf, gen=gen).to(device)

In [33]:
baseline_model = copy.deepcopy(model)

### **2.2 Training**

???

#### **2.2.1 Stage 1: supervised base training**

???

In [34]:
@torch.no_grad()
def evaluate_stage1_base_top1(model: DFSLwF, loader: DataLoader, device) -> float:
    model.fe.eval(); model.clf_base.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        logits = model.clf_base(model.fe(xb))
        pred = logits.argmax(dim=1)
        correct += (pred == yb).sum().item()
        total   += yb.numel()
    return correct / max(1, total)

In [35]:
def train_stage1(model: DFSLwF, loader: DataLoader, device: torch.device,
                 epochs: int = STAGE1_EPOCHS, lr: float = STAGE1_LR,
                 weight_decay: float = STAGE1_WEIGHT_DECAY,
                 label_smoothing: float = STAGE1_LABEL_SMOOTH,
                 val_loader: Optional[DataLoader] = None):

    # [DFSLwF] Train feature extractor + base classifier (cosine)
    model.fe.train(); model.clf_base.train(); model.gen.eval()

    # Keep BN learnable here (paper trains a standard classifier in Stage-1)
    params    = list(model.fe.parameters()) + list(model.clf_base.parameters())
    opt       = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    # Params for validation
    best_val = -1.0
    patience = 0
    best_state = None

    with tqdm(range(epochs), desc="[Stage1] Supervised base training") as epbar:
        for ep in epbar:
            model.fe.train(); model.clf_base.train()
            batch_losses = []

            for i, (xb, yb) in enumerate(loader):
                xb = xb.to(device); yb = yb.to(device)
                opt.zero_grad()
                logits = model.clf_base(model.fe(xb))
                loss   = criterion(logits, yb)
                loss.backward()
                opt.step()
                batch_losses.append(float(loss.item()))

                if (i + 1) % S1_LOG_EVERY == 0:
                    epbar.set_postfix(loss=f"{sliding_avg(batch_losses, S1_LOG_EVERY):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((ep + 1) % S1_VAL_EVERY == 0):
                val_acc = evaluate_stage1_base_top1(model, val_loader, device)
                epbar.write(f"\t[stage1/val] epoch {ep+1:03d}: acc_base={100*val_acc:.2f}%")

                if val_acc > best_val + 1e-6:
                    best_val   = val_acc
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "epoch": ep + 1,
                        "val":   best_val,
                    }
                else:
                    patience += 1
                    if patience >= S1_PATIENCE:
                        epbar.write(f"\t[stage1] Early stopping (no improvement for {S1_PATIENCE} validations).")
                        break

    # Restore best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])

In [36]:
train_stage1(model, train_loader_s1, device=device, val_loader=val_loader_s1)

[Stage1] Supervised base training:   2%|▏         | 2/120 [00:18<18:10,  9.24s/it, loss=3.1534]

	[stage1/val] epoch 002: acc_base=13.34%


[Stage1] Supervised base training:   3%|▎         | 4/120 [00:40<20:15, 10.48s/it, loss=2.7863]

	[stage1/val] epoch 004: acc_base=30.06%


[Stage1] Supervised base training:   5%|▌         | 6/120 [00:56<17:09,  9.03s/it, loss=2.5861]

	[stage1/val] epoch 006: acc_base=33.59%


[Stage1] Supervised base training:   7%|▋         | 8/120 [01:12<15:53,  8.52s/it, loss=2.4521]

	[stage1/val] epoch 008: acc_base=37.88%


[Stage1] Supervised base training:   8%|▊         | 10/120 [01:28<15:27,  8.44s/it, loss=2.3539]

	[stage1/val] epoch 010: acc_base=29.69%


[Stage1] Supervised base training:  10%|█         | 12/120 [01:45<15:11,  8.44s/it, loss=2.2590]

	[stage1/val] epoch 012: acc_base=42.03%


[Stage1] Supervised base training:  12%|█▏        | 14/120 [02:02<14:41,  8.32s/it, loss=2.1819]

	[stage1/val] epoch 014: acc_base=41.56%


[Stage1] Supervised base training:  13%|█▎        | 16/120 [02:18<14:27,  8.34s/it, loss=2.1294]

	[stage1/val] epoch 016: acc_base=43.84%


[Stage1] Supervised base training:  15%|█▌        | 18/120 [02:34<13:56,  8.20s/it, loss=2.0740]

	[stage1/val] epoch 018: acc_base=45.66%


[Stage1] Supervised base training:  17%|█▋        | 20/120 [02:50<13:35,  8.16s/it, loss=2.0301]

	[stage1/val] epoch 020: acc_base=48.62%


[Stage1] Supervised base training:  18%|█▊        | 22/120 [03:05<13:01,  7.97s/it, loss=1.9891]

	[stage1/val] epoch 022: acc_base=47.16%


[Stage1] Supervised base training:  20%|██        | 24/120 [03:22<13:06,  8.20s/it, loss=1.9553]

	[stage1/val] epoch 024: acc_base=42.72%


[Stage1] Supervised base training:  22%|██▏       | 26/120 [03:38<12:38,  8.06s/it, loss=1.9199]

	[stage1/val] epoch 026: acc_base=41.62%


[Stage1] Supervised base training:  23%|██▎       | 28/120 [03:54<12:26,  8.11s/it, loss=1.8964]

	[stage1/val] epoch 028: acc_base=50.47%


[Stage1] Supervised base training:  25%|██▌       | 30/120 [04:10<12:03,  8.04s/it, loss=1.8753]

	[stage1/val] epoch 030: acc_base=48.56%


[Stage1] Supervised base training:  27%|██▋       | 32/120 [04:27<12:11,  8.32s/it, loss=1.8473]

	[stage1/val] epoch 032: acc_base=47.25%


[Stage1] Supervised base training:  28%|██▊       | 34/120 [04:42<11:43,  8.18s/it, loss=1.8328]

	[stage1/val] epoch 034: acc_base=47.97%


[Stage1] Supervised base training:  30%|███       | 36/120 [04:58<11:20,  8.10s/it, loss=1.8084]

	[stage1/val] epoch 036: acc_base=50.38%


[Stage1] Supervised base training:  32%|███▏      | 38/120 [05:14<10:54,  7.98s/it, loss=1.7903]

	[stage1/val] epoch 038: acc_base=52.91%


[Stage1] Supervised base training:  33%|███▎      | 40/120 [05:31<10:57,  8.22s/it, loss=1.7788]

	[stage1/val] epoch 040: acc_base=52.25%


[Stage1] Supervised base training:  35%|███▌      | 42/120 [05:46<10:33,  8.13s/it, loss=1.7662]

	[stage1/val] epoch 042: acc_base=54.31%


[Stage1] Supervised base training:  37%|███▋      | 44/120 [06:02<10:15,  8.10s/it, loss=1.7675]

	[stage1/val] epoch 044: acc_base=51.94%


[Stage1] Supervised base training:  38%|███▊      | 46/120 [06:18<09:52,  8.00s/it, loss=1.7527]

	[stage1/val] epoch 046: acc_base=42.19%


[Stage1] Supervised base training:  40%|████      | 48/120 [06:35<09:48,  8.18s/it, loss=1.7376]

	[stage1/val] epoch 048: acc_base=52.84%


[Stage1] Supervised base training:  42%|████▏     | 50/120 [06:50<09:26,  8.09s/it, loss=1.7297]

	[stage1/val] epoch 050: acc_base=51.47%


[Stage1] Supervised base training:  42%|████▎     | 51/120 [07:06<09:36,  8.36s/it, loss=1.7148]

	[stage1/val] epoch 052: acc_base=51.91%
	[stage1] Early stopping (no improvement for 5 validations).





In [37]:
stage1_model = copy.deepcopy(model)
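
The early-stopping rule in `train_stage1` can be replayed on the validation accuracies logged above. This is a minimal sketch rather than the notebook's own code: `replay_early_stopping` is a hypothetical helper, the accuracies are copied from the `[stage1/val]` lines as percentages, and the defaults mirror `S1_VAL_EVERY = 2` and `S1_PATIENCE = 5`.

```python
from typing import List, Tuple


def replay_early_stopping(val_accs: List[float], val_every: int = 2,
                          patience_limit: int = 5, min_delta: float = 1e-6
                          ) -> Tuple[int, float, int]:
    """Replay the patience counter over a list of validation accuracies.

    Returns (best_epoch, best_acc, stop_epoch). Hypothetical helper that
    mirrors the comparison `val_acc > best_val + 1e-6` used in train_stage1.
    """
    best_acc, best_epoch, patience = -1.0, 0, 0
    for i, acc in enumerate(val_accs):
        epoch = (i + 1) * val_every          # validation runs every `val_every` epochs
        if acc > best_acc + min_delta:       # improvement: store best, reset counter
            best_acc, best_epoch, patience = acc, epoch, 0
        else:                                # no improvement: count towards the limit
            patience += 1
            if patience >= patience_limit:
                return best_epoch, best_acc, epoch
    return best_epoch, best_acc, len(val_accs) * val_every


# Validation accuracies (%) for epochs 2, 4, ..., 52, copied from the log above.
logged = [13.34, 30.06, 33.59, 37.88, 29.69, 42.03, 41.56, 43.84, 45.66,
          48.62, 47.16, 42.72, 41.62, 50.47, 48.56, 47.25, 47.97, 50.38,
          52.91, 52.25, 54.31, 51.94, 42.19, 52.84, 51.47, 51.91]

best_epoch, best_acc, stop_epoch = replay_early_stopping(logged)
print(best_epoch, best_acc, stop_epoch)
```

On the logged values the best validation accuracy is 54.31% at epoch 42, and the fifth validation without improvement falls on epoch 52. Because `train_stage1` reloads `best_state` before returning, `stage1_model` holds the epoch-42 weights rather than the final ones.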

#### **2.2.2 Stage 2: episodic training**

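Stage 2 freezes the feature extractor and trains the novel-weight generator on episodes, while continuing to train the base classifier weights and the temperature τ. In each episode the pseudo-novel classes are masked out of the attention memory, and the queries are classified jointly over the base weights and the generated novel weights. Validation runs every `STAGE2_VAL_EVERY` steps, and training stops early after `S2_PATIENCE` validations without improvement in `S2_SELECT_METRIC`.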

In [38]:
@torch.no_grad()
def evaluate_gfsl(model: DFSLwF, test_loader: DataLoader, cifar_targets_all: List[int],
                  base_classes: List[int], device: torch.device) -> Tuple[int, Dict[str, Dict[str, float]]]:
    model.fe.eval(); model.clf_base.eval(); model.gen.eval()
    acc_per_episode_base, acc_per_episode_novel = [], []
    base_order = sorted(base_classes)
    b2local = {cid: i for i, cid in enumerate(base_order)}
    Cb = model.clf_base.weight.size(0)

    for (support_novel, query_images, true_novel_ids, gt_novel, gt_base) in test_loader:
        support_novel = support_novel.to(device)
        query_images = query_images.to(device)
        gt_novel = gt_novel.to(device)
        gt_base = gt_base.to(device)

        novel_weights = model.build_novel_weights(support_novel, K_SHOT)  # (N, D)
        logits = model.forward_logits(query_images, novel_weights)
        preds = logits.argmax(dim=1)

        N = novel_weights.size(0)
        pred_novel = preds[:N * Q_NOVEL] - Cb
        pred_base = preds[N * Q_NOVEL:]

        id2local = {cid: i for i, cid in enumerate(true_novel_ids)}
        gt_novel_local = torch.tensor([id2local[int(y.item())] for y in gt_novel], device=device)
        gt_base_local = torch.tensor([b2local[int(y.item())] for y in gt_base], device=device)

        acc_b = (pred_base == gt_base_local).float().mean().item()
        acc_n = (pred_novel == gt_novel_local).float().mean().item()
        acc_per_episode_base.append(acc_b); acc_per_episode_novel.append(acc_n)

    return len(acc_per_episode_base), gfsl_stats(acc_per_episode_base, acc_per_episode_novel)
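
The slicing in `evaluate_gfsl` assumes that each episode lists the novel queries first and the base queries after them, with logits laid out as `[Cb base columns | N novel columns]`. A minimal pure-Python sketch of that split follows. `split_episode_accuracy` is a hypothetical helper that uses lists in place of tensors. It also shows one consequence of subtracting `Cb`: a novel query predicted as a base class gets a negative local index and therefore counts as an error.

```python
from typing import List, Tuple


def split_episode_accuracy(preds: List[int], gt_novel_local: List[int],
                           gt_base_local: List[int], n_base_classes: int
                           ) -> Tuple[float, float]:
    """Return (base_acc, novel_acc) for one episode.

    Hypothetical helper. `preds` are argmax indices over [Cb | N] columns,
    ordered as novel queries first, then base queries.
    """
    n_novel_queries = len(gt_novel_local)
    # Shift novel predictions into the local 0..N-1 range; a base-class
    # prediction becomes negative here and can never match a novel label.
    pred_novel = [p - n_base_classes for p in preds[:n_novel_queries]]
    pred_base = preds[n_novel_queries:]
    acc_novel = sum(p == g for p, g in zip(pred_novel, gt_novel_local)) / len(gt_novel_local)
    acc_base = sum(p == g for p, g in zip(pred_base, gt_base_local)) / len(gt_base_local)
    return acc_base, acc_novel


# Toy episode with Cb = 64: three novel queries followed by two base queries.
acc_b, acc_n = split_episode_accuracy(
    preds=[64, 65, 3, 7, 66], gt_novel_local=[0, 1, 2],
    gt_base_local=[7, 2], n_base_classes=64)
print(acc_b, acc_n)
```

In the toy episode the third novel query is predicted as base class 3, and the second base query is predicted as novel column 66, so both are misclassified under the joint label space.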

In [39]:
def train_stage2(model: DFSLwF, meta_loader: DataLoader, device: torch.device,
                 base_order: List[int], n_tasks: int = STAGE2_TASKS, val_every: int = STAGE2_VAL_EVERY,
                 lr: float = STAGE2_LR, val_loader: Optional[DataLoader] = None):
    """
    [DFSLwF] Freeze F, train generator + continue training W_base (and τ). Exclude pseudo-novel from attention memory.
    Use *true* base labels for base queries; novel queries target indices are (Cb .. Cb+N-1).
    """
    # Freeze feature extractor; freeze BN stats
    model.fe.eval(); model.fe.apply(set_bn_eval)
    for p in model.fe.parameters(): p.requires_grad = False

    # Train generator, base weights, and tau
    for p in model.clf_base.parameters(): p.requires_grad = True
    params = list(model.gen.parameters()) + [model.clf_base.tau, model.clf_base.weight]
    opt = torch.optim.Adam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # mapping CIFAR id -> local base index
    b2local = {cid: i for i, cid in enumerate(sorted(base_order))}
    Cb = model.clf_base.weight.size(0)

    # Validation params
    best_score = -1.0
    patience   = 0
    best_state = None

    model.gen.train(); model.clf_base.train()
    with tqdm(enumerate(meta_loader), total=len(meta_loader), desc="[Stage2] Episodic Training") as pbar:
        for step, (support_novel, query_images, true_novel_ids, base_q_labels_cifar) in pbar:
            support_novel       = support_novel.to(device)
            query_images        = query_images.to(device)
            base_q_labels_cifar = base_q_labels_cifar.to(device)

            # Mask: exclude pseudo-novel classes from the attention memory
            exclude_mask = torch.ones(Cb, dtype=torch.bool, device=device)
            for cid in true_novel_ids:
                if cid in b2local:
                    exclude_mask[b2local[cid]] = False

            # Novel weights (support -> generator); the feature extractor is frozen
            with torch.no_grad():
                supp_feats = model.fe(support_novel)
            novel_weights = model.gen(supp_feats, model.clf_base.weight, K_SHOT, exclude_mask=exclude_mask)

            # Logits and targets
            logits = model.forward_logits(query_images, novel_weights)  # [Cb | N]
            N = novel_weights.size(0)

            y_novel = torch.arange(N, device=device).repeat_interleave(Q_NOVEL)
            targets = torch.empty(logits.size(0), dtype=torch.long, device=device)
            targets[:N * Q_NOVEL] = Cb + y_novel
            base_local = torch.tensor([b2local[int(y.item())] for y in base_q_labels_cifar], device=device)
            targets[N * Q_NOVEL:] = base_local

            opt.zero_grad()
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, STAGE2_GRAD_CLIP)
            opt.step()
            pbar.set_postfix(loss=f"{float(loss.item()):.4f}")

            # ---- PERIODIC VALIDATION ----
            if val_loader is not None and ((step + 1) % val_every == 0):
                Tval, vstats = evaluate_gfsl(model, val_loader, ds_train.targets, base_order, device)
                v_base  = vstats["base"]["mean"];  v_base_ci  = vstats["base"]["conf"]
                v_novel = vstats["novel"]["mean"]; v_novel_ci = vstats["novel"]["conf"]
                v_h     = vstats["hmean"]["mean"]; v_h_ci     = vstats["hmean"]["conf"]

                pbar.write(f"\t[stage2/val] step {step+1:05d}: "
                          f"base={100*v_base:.2f}%±{100*v_base_ci:.2f}  "
                          f"novel={100*v_novel:.2f}%±{100*v_novel_ci:.2f}  "
                          f"h-mean={100*v_h:.2f}%±{100*v_h_ci:.2f}  (T={Tval})")

                sel = {"base": v_base, "novel": v_novel, "hmean": v_h}[S2_SELECT_METRIC]
                if sel > best_score + 1e-6:
                    best_score = sel
                    patience   = 0
                    best_state = {
                        "model": copy.deepcopy(model.state_dict()),
                        "opt":   copy.deepcopy(opt.state_dict()),
                        "step":  step + 1,
                        "score": best_score,
                    }
                else:
                    patience += 1
                    if patience >= S2_PATIENCE:
                        pbar.write(f"\t[stage2] Early stopping (no improvement for {S2_PATIENCE} validations).")
                        break

                model.gen.train(); model.clf_base.train()

    # Restore the best result
    if best_state is not None:
        model.load_state_dict(best_state["model"])
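
Two bookkeeping steps in `train_stage2` are easy to misread: the attention mask and the target layout. The sketch below restates both with plain lists under stated assumptions. `attention_mask_sketch` and `build_targets_sketch` are hypothetical names, and reading `True` as "kept in the attention memory" is an assumption, since the generator that consumes `exclude_mask` is defined outside this section.

```python
from typing import List


def attention_mask_sketch(base_order: List[int], pseudo_novel_ids: List[int]) -> List[bool]:
    """Start from all True and set False at the local index of each pseudo-novel class.

    Hypothetical helper mirroring how `exclude_mask` is filled in train_stage2.
    """
    b2local = {cid: i for i, cid in enumerate(sorted(base_order))}
    mask = [True] * len(b2local)
    for cid in pseudo_novel_ids:
        if cid in b2local:               # ids outside the base set are ignored
            mask[b2local[cid]] = False
    return mask


def build_targets_sketch(n_base_classes: int, n_way: int, q_novel: int,
                         base_query_local: List[int]) -> List[int]:
    """Targets for logits laid out as [Cb base columns | N novel columns].

    Novel queries come first, q_novel per class, and are mapped to
    Cb .. Cb+N-1; base queries follow with their local base indices.
    """
    novel = [n_base_classes + c for c in range(n_way) for _ in range(q_novel)]
    return novel + list(base_query_local)


mask = attention_mask_sketch(base_order=[10, 3, 7, 5], pseudo_novel_ids=[7, 99])
targets = build_targets_sketch(n_base_classes=64, n_way=5, q_novel=15,
                               base_query_local=[3, 10])
print(mask)
print(len(targets), targets[0], targets[15], targets[74], targets[-2:])
```

The nested comprehension reproduces `torch.arange(N).repeat_interleave(Q_NOVEL)`: each novel class index is repeated `q_novel` times before moving to the next class.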

In [40]:
train_stage2(model, train_loader_s2, device=device, base_order=base_order, val_loader=val_loader_s2)

[Stage2] Episodic Training:   3%|▎         | 504/20000 [01:20<20:42:54,  3.83s/it, loss=1.2390]

	[stage2/val] step 00500: base=31.99%±0.36  novel=70.36%±0.54  h-mean=43.47%±0.35  (T=1000)


[Stage2] Episodic Training:   5%|▌         | 1003/20000 [02:41<26:11:14,  4.96s/it, loss=1.4171]

	[stage2/val] step 01000: base=32.05%±0.36  novel=70.92%±0.54  h-mean=43.67%±0.35  (T=1000)


[Stage2] Episodic Training:   8%|▊         | 1504/20000 [04:02<23:13:00,  4.52s/it, loss=1.0668]

	[stage2/val] step 01500: base=32.23%±0.37  novel=71.06%±0.53  h-mean=43.85%±0.36  (T=1000)


[Stage2] Episodic Training:  10%|█         | 2004/20000 [05:23<20:50:13,  4.17s/it, loss=1.3270]

	[stage2/val] step 02000: base=32.31%±0.36  novel=71.46%±0.54  h-mean=43.99%±0.35  (T=1000)


[Stage2] Episodic Training:  13%|█▎        | 2504/20000 [06:44<20:18:13,  4.18s/it, loss=1.4309]

	[stage2/val] step 02500: base=32.27%±0.37  novel=71.57%±0.52  h-mean=44.00%±0.36  (T=1000)


[Stage2] Episodic Training:  15%|█▌        | 3004/20000 [08:04<19:16:08,  4.08s/it, loss=1.3204]

	[stage2/val] step 03000: base=33.12%±0.37  novel=71.52%±0.54  h-mean=44.79%±0.36  (T=1000)


[Stage2] Episodic Training:  18%|█▊        | 3504/20000 [09:24<20:16:52,  4.43s/it, loss=1.4077]

	[stage2/val] step 03500: base=33.05%±0.35  novel=71.94%±0.54  h-mean=44.85%±0.34  (T=1000)


[Stage2] Episodic Training:  20%|██        | 4003/20000 [10:44<20:18:35,  4.57s/it, loss=1.5198]

	[stage2/val] step 04000: base=33.62%±0.36  novel=71.44%±0.53  h-mean=45.26%±0.35  (T=1000)


[Stage2] Episodic Training:  23%|██▎       | 4504/20000 [12:06<16:49:55,  3.91s/it, loss=1.1601]

	[stage2/val] step 04500: base=34.28%±0.37  novel=70.58%±0.53  h-mean=45.60%±0.35  (T=1000)


[Stage2] Episodic Training:  25%|██▌       | 5004/20000 [13:25<18:03:33,  4.34s/it, loss=1.5624]

	[stage2/val] step 05000: base=33.31%±0.36  novel=71.00%±0.55  h-mean=44.85%±0.36  (T=1000)


[Stage2] Episodic Training:  28%|██▊       | 5503/20000 [14:47<18:02:36,  4.48s/it, loss=1.3188]

	[stage2/val] step 05500: base=34.02%±0.36  novel=70.99%±0.54  h-mean=45.53%±0.35  (T=1000)


[Stage2] Episodic Training:  30%|███       | 6003/20000 [16:09<20:25:51,  5.25s/it, loss=1.2473]

	[stage2/val] step 06000: base=35.19%±0.37  novel=70.96%±0.53  h-mean=46.57%±0.35  (T=1000)


[Stage2] Episodic Training:  33%|███▎      | 6504/20000 [17:30<14:42:49,  3.92s/it, loss=1.2309]

	[stage2/val] step 06500: base=33.93%±0.37  novel=71.06%±0.55  h-mean=45.44%±0.36  (T=1000)


[Stage2] Episodic Training:  35%|███▌      | 7004/20000 [18:51<14:44:13,  4.08s/it, loss=1.0990]

	[stage2/val] step 07000: base=35.45%±0.38  novel=70.33%±0.55  h-mean=46.60%±0.35  (T=1000)


[Stage2] Episodic Training:  38%|███▊      | 7504/20000 [20:12<15:28:31,  4.46s/it, loss=1.3853]

	[stage2/val] step 07500: base=34.12%±0.36  novel=70.92%±0.52  h-mean=45.62%±0.34  (T=1000)


[Stage2] Episodic Training:  40%|████      | 8005/20000 [21:32<10:48:07,  3.24s/it, loss=1.2866]

	[stage2/val] step 08000: base=35.07%±0.39  novel=70.22%±0.53  h-mean=46.26%±0.36  (T=1000)


[Stage2] Episodic Training:  43%|████▎     | 8503/20000 [22:53<13:37:08,  4.26s/it, loss=1.2410]

	[stage2/val] step 08500: base=35.57%±0.38  novel=70.02%±0.52  h-mean=46.69%±0.35  (T=1000)


[Stage2] Episodic Training:  45%|████▌     | 9004/20000 [24:13<12:07:17,  3.97s/it, loss=1.0533]

	[stage2/val] step 09000: base=34.55%±0.38  novel=70.84%±0.53  h-mean=45.94%±0.36  (T=1000)


[Stage2] Episodic Training:  48%|████▊     | 9503/20000 [25:34<14:52:40,  5.10s/it, loss=1.4923]

	[stage2/val] step 09500: base=35.21%±0.38  novel=70.89%±0.55  h-mean=46.54%±0.35  (T=1000)


[Stage2] Episodic Training:  50%|█████     | 10004/20000 [26:54<9:51:53,  3.55s/it, loss=1.5270]

	[stage2/val] step 10000: base=34.77%±0.39  novel=71.15%±0.55  h-mean=46.17%±0.36  (T=1000)


[Stage2] Episodic Training:  53%|█████▎    | 10503/20000 [28:13<10:53:09,  4.13s/it, loss=1.3028]

	[stage2/val] step 10500: base=36.22%±0.37  novel=69.51%±0.54  h-mean=47.13%±0.34  (T=1000)


[Stage2] Episodic Training:  55%|█████▌    | 11004/20000 [29:33<9:09:27,  3.66s/it, loss=1.4387]

	[stage2/val] step 11000: base=35.88%±0.37  novel=69.99%±0.55  h-mean=46.94%±0.35  (T=1000)


[Stage2] Episodic Training:  58%|█████▊    | 11504/20000 [30:53<10:03:03,  4.26s/it, loss=1.2889]

	[stage2/val] step 11500: base=35.77%±0.39  novel=70.31%±0.55  h-mean=46.87%±0.36  (T=1000)


[Stage2] Episodic Training:  60%|██████    | 12005/20000 [32:13<8:36:59,  3.88s/it, loss=1.3924] 

	[stage2/val] step 12000: base=35.78%±0.38  novel=70.00%±0.53  h-mean=46.85%±0.34  (T=1000)


[Stage2] Episodic Training:  63%|██████▎   | 12504/20000 [33:33<8:19:46,  4.00s/it, loss=1.6867]

	[stage2/val] step 12500: base=35.80%±0.37  novel=70.73%±0.52  h-mean=47.06%±0.34  (T=1000)


[Stage2] Episodic Training:  65%|██████▍   | 12999/20000 [34:52<18:46,  6.21it/s, loss=1.2108]

	[stage2/val] step 13000: base=35.83%±0.37  novel=70.56%±0.52  h-mean=47.03%±0.35  (T=1000)
	[stage2] Early stopping (no improvement for 5 validations).





### **2.3 Evaluation**

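The baseline model, the model saved after Stage 1, and the model after Stage 2 are each evaluated with the same `test_loader`. Every report gives the mean base accuracy, novel accuracy, and harmonic mean with a 95% confidence interval over `TEST_TASKS` = 1,000 tasks.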

In [41]:
T, stats = evaluate_gfsl(baseline_model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="Baseline model")

[test] - Baseline model (95% CI on 1000 tasks)
 - [Base]   acc=0.00% ± 0.00%
 - [Novel]  acc=38.89% ± 0.53%
 - [H-mean] acc=0.00% ± 0.00%


In [42]:
T, stats = evaluate_gfsl(stage1_model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 1 training")

[test] - After stage 1 training (95% CI on 1000 tasks)
 - [Base]   acc=14.85% ± 0.28%
 - [Novel]  acc=68.77% ± 0.57%
 - [H-mean] acc=23.97% ± 0.37%


In [44]:
T, stats = evaluate_gfsl(model, test_loader, ds_test.targets, base, device=device)
print_stats(T, stats, model="After stage 2 training")

[test] - After stage 2 training (95% CI on 1000 tasks)
 - [Base]   acc=33.13% ± 0.37%
 - [Novel]  acc=67.77% ± 0.56%
 - [H-mean] acc=43.96% ± 0.35%
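
`gfsl_stats` is defined outside this section, so its exact formula is not shown here. The sketch below is one plausible reading, offered as an assumption: the harmonic mean is computed per episode and then averaged, and the interval is a normal-approximation 95% CI. `gfsl_stats_sketch`, `harmonic_mean`, and `summarize` are hypothetical names.

```python
import math
from statistics import mean, stdev
from typing import Dict, List


def harmonic_mean(a: float, b: float) -> float:
    """Harmonic mean of two accuracies; defined as 0 when both are 0."""
    return 0.0 if (a + b) == 0 else 2.0 * a * b / (a + b)


def summarize(values: List[float], z: float = 1.96) -> Dict[str, float]:
    """Mean and half-width of a normal-approximation 95% confidence interval."""
    half = z * stdev(values) / math.sqrt(len(values)) if len(values) > 1 else 0.0
    return {"mean": mean(values), "conf": half}


def gfsl_stats_sketch(acc_base: List[float], acc_novel: List[float]) -> Dict[str, Dict[str, float]]:
    """Per-episode harmonic mean, then averaged (an assumption, see lead-in)."""
    h = [harmonic_mean(b, n) for b, n in zip(acc_base, acc_novel)]
    return {"base": summarize(acc_base), "novel": summarize(acc_novel), "hmean": summarize(h)}


# Two toy episodes: the averaged per-episode h-mean is below the h-mean of the means.
stats_toy = gfsl_stats_sketch(acc_base=[0.2, 0.4], acc_novel=[0.8, 0.6])
hmean_of_means = harmonic_mean(stats_toy["base"]["mean"], stats_toy["novel"]["mean"])
print(stats_toy["hmean"]["mean"], hmean_of_means)
```

This reading is consistent with the reports above: the harmonic mean of the Stage 2 means (33.13% and 67.77%) is about 44.50%, while the reported H-mean is 43.96%, which is the direction expected when per-episode values are averaged. A base accuracy of zero in every episode also forces the H-mean to zero, as in the baseline report.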
