# **Memory-Robust Few-Shot Test-Time Adaptation on CIFAR-10-C (Q-MemBN+)**

This notebook implements my end-to-end pipeline for **robust test-time adaptation (TTA)** on small vision models.  
I pretrain a ResNet-18-style backbone on **CIFAR-10**, then adapt it online to **CIFAR-10-C** using my **Q-MemBN+** 'recipe':
- **Quantile + Memory BatchNorm** for robust, drift-aware normalization  
- **Few-shot support fine-tuning** to align to a new corruption  
- **Online entropy + prototype-guided updates** to adapt continuously  

**Datasets**
- CIFAR-10 (clean source)  
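- CIFAR-10-C (15 benchmark corruptions × 5 severities; the archive also ships 4 extra corruptions: `gaussian_blur`, `saturate`, `spatter`, `speckle_noise`)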

A **Pipeline Flowchart** is shown below (open the image link for a full-resolution view).

<img src="https://drive.google.com/uc?export=view&id=1BnJLXmTVz0GxjURdcOgZOB_b4DXiwrHP" width="1200">

## **Section 1: Setup & Data Acquisition**
### **Downloading CIFAR-10-C**

CIFAR-10 is pulled automatically by torchvision, but **CIFAR-10-C must be fetched manually**.  
To do that, I create a local data directory, download the official Zenodo tarball, and then unpack it.  
By keeping the raw corruption files intact, I skip re-downloading the dataset when I slice it by **corruption type** and **severity** later.


In [None]:
# CIFAR-10-C Dataset Download From Zenodo
# CIFAR-10 itself is downloaded automatically by torchvision; only CIFAR-10-C (C for Corruptions) needs a manual fetch.
!mkdir -p /content/data

!wget -O /content/data/CIFAR-10-C.tar \
  https://zenodo.org/api/records/2535967/files/CIFAR-10-C.tar/content

!tar -xf /content/data/CIFAR-10-C.tar -C /content/data/

!ls /content/data/CIFAR-10-C

--2025-11-23 21:46:05--  https://zenodo.org/api/records/2535967/files/CIFAR-10-C.tar/content
Resolving zenodo.org (zenodo.org)... 137.138.52.235, 188.185.48.75, 188.185.43.153, ...
Connecting to zenodo.org (zenodo.org)|137.138.52.235|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2918471680 (2.7G) [application/octet-stream]
Saving to: ‘/content/data/CIFAR-10-C.tar’


2025-11-23 21:48:38 (18.3 MB/s) - ‘/content/data/CIFAR-10-C.tar’ saved [2918471680/2918471680]

brightness.npy	       gaussian_noise.npy    saturate.npy
contrast.npy	       glass_blur.npy	     shot_noise.npy
defocus_blur.npy       impulse_noise.npy     snow.npy
elastic_transform.npy  jpeg_compression.npy  spatter.npy
fog.npy		       labels.npy	     speckle_noise.npy
frost.npy	       motion_blur.npy	     zoom_blur.npy
gaussian_blur.npy      pixelate.npy


### **Extracting and Verifying the Dataset**

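I extract the tar archive into my working `data/` folder and list its contents to confirm that they match the standard layout (one `.npy` file per corruption plus `labels.npy`).  
Because this cell only extracts, it can be rerun without downloading the 2.7 GB archive again, and checking the layout up front catches path problems before they silently break an expensive training and adaptation run.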


In [None]:
# CIFAR-10-C Dataset Extraction
!mkdir -p /content/data

!tar -xf /content/data/CIFAR-10-C.tar -C /content/data/

!ls /content/data/CIFAR-10-C

brightness.npy	       gaussian_noise.npy    saturate.npy
contrast.npy	       glass_blur.npy	     shot_noise.npy
defocus_blur.npy       impulse_noise.npy     snow.npy
elastic_transform.npy  jpeg_compression.npy  spatter.npy
fog.npy		       labels.npy	     speckle_noise.npy
frost.npy	       motion_blur.npy	     zoom_blur.npy
gaussian_blur.npy      pixelate.npy
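The `ls` output above is only a visual check. A programmatic check can stop the notebook before training if a file is missing. This is a minimal sketch, assuming the folder layout shown above; `missing_cifar10c_files` and `EXPECTED_FILES` are hypothetical names, not part of the pipeline.

```python
import os
from typing import List

# The 19 corruption files plus labels.npy, as listed by `ls` above.
EXPECTED_FILES = [
    "brightness", "contrast", "defocus_blur", "elastic_transform", "fog",
    "frost", "gaussian_blur", "gaussian_noise", "glass_blur", "impulse_noise",
    "jpeg_compression", "labels", "motion_blur", "pixelate", "saturate",
    "shot_noise", "snow", "spatter", "speckle_noise", "zoom_blur",
]


def missing_cifar10c_files(cifar10c_dir: str) -> List[str]:
    """Return the expected .npy files that are absent from cifar10c_dir."""
    # A missing directory counts as "everything is missing".
    present = set(os.listdir(cifar10c_dir)) if os.path.isdir(cifar10c_dir) else set()
    return [f"{name}.npy" for name in EXPECTED_FILES if f"{name}.npy" not in present]
```

Calling `missing_cifar10c_files("/content/data/CIFAR-10-C")` right after extraction should return an empty list; anything else points to an incomplete download or a wrong path.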


## **Section 2: Imports & Environment**
### **Libraries I Rely On**

This cell pulls in:
- **PyTorch + torchvision** for models, transforms, and dataloaders  
- **NumPy / PIL / Python stdlib** for stats, image decoding, randomness, and utilities  
- **`typing` helpers** for clearer function signatures and easier debugging.

Centralizing the imports here lets the rest of the notebook read as a clean pipeline.


In [None]:
# Imports.
import os
import math
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, Subset, random_split

### **Reproducibility Utilities**

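To compare my adaptation methods fairly, I set **all random seeds** (Python, NumPy, PyTorch CPU/GPU).  
This keeps training curves and TTA outcomes reproducible across reruns, up to GPU nondeterminism: cuDNN can still pick nondeterministic kernels unless deterministic mode is switched on.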


In [None]:
# Utilities.
def set_global_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
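Seeding alone does not make GPU runs bit-identical. A stricter variant is sketched below; `set_deterministic_mode` is a hypothetical name, and the `CUBLAS_WORKSPACE_CONFIG` value follows PyTorch's reproducibility notes. Expect some slowdown when it is enabled.

```python
import os
import random

import numpy as np
import torch


def set_deterministic_mode(seed: int = 42) -> None:
    """Seed every RNG and ask PyTorch / cuDNN for deterministic kernels.

    Hypothetical helper: an optional, stricter companion to set_global_seed.
    """
    # cuBLAS needs this workspace setting for deterministic GEMMs on CUDA >= 10.2;
    # it only takes effect if set before the first CUDA call.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Use deterministic convolution algorithms and disable cuDNN autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # warn_only=True: any op without a deterministic implementation emits a
    # warning instead of aborting the run.
    torch.use_deterministic_algorithms(True, warn_only=True)
```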

## **Section 3: Q-MemBN Core**
### **Quantile + Memory BatchNorm (Q-MemBN)**

This is the heart of my method. I replace the standard BN with a layer that:
- Uses **median + IQR** instead of mean/variance → robust to outliers & heavy corruption.
- Maintains a **FIFO memory** of recent stats → smooths online updates.
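- Can switch between **source mode** (memory frozen: outside training, no new stats are written) and **adapt mode** (streaming: every forward pass pushes its batch stats into the FIFO). In both modes, normalization itself uses the current batch's median / IQR, blended 50/50 with the memory average when memory is enabled.  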

My helper functions let me:
- Toggle adaptation and memory usage globally.
- Capture “source” robust stats after Stage-I alignment.  
- Reset stats if a drift event is detected.  

> **Design Purpose:** make normalization robust, fast, and safe under non-stationary corruptions.
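To see why median / IQR is more robust than mean / std, here is a toy comparison on one made-up channel containing a single outlier (the values are illustrative, not measured on CIFAR-10-C):

```python
import torch

# One toy channel: eight ordinary activations plus one extreme outlier,
# standing in for a heavily corrupted region. Values are made up.
x = torch.tensor([-1.0, -0.5, -0.2, 0.0, 0.1, 0.3, 0.6, 1.0, 50.0])

# Standard BN statistics: the outlier drags both of them.
mean = x.mean()
std = ((x - mean) ** 2).mean().sqrt()

# Q-MemBN statistics: median and IQR (linearly interpolated quantiles).
median = x.median()
iqr = x.quantile(0.75) - x.quantile(0.25)

print(f"mean={mean.item():.3f}  std={std.item():.3f}")
print(f"median={median.item():.3f}  iqr={iqr.item():.3f}")

# Normalize the eight ordinary values both ways. Mean/std squeezes them into a
# tiny range; median/IQR keeps them spread out.
bn_norm = (x[:-1] - mean) / std
robust_norm = (x[:-1] - median) / iqr
```

One scale detail: for Gaussian data the IQR is about 1.349σ, so Q-MemBN outputs are roughly 1/1.349 the scale of standard BN outputs; the learnable `weight` can absorb that constant.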

In [None]:
# Q-MemBN (Quantile + Memory BatchNorm)

class   QMemBatchNorm2d ( nn.Module ):

    # Q-MemBN uses medians + IQR, and keeps a small memory of past stats.
    def __init__ ( self ,
                   num_features : int ,
                   eps : float = 1e-5 ,
                   memory_size : int = 16 ):

        super().__init__()

        self.num_features = num_features
        self.eps          = eps
        self.memory_size  = memory_size

        # The learnable affine params.
        self.weight = nn.Parameter( torch.ones ( num_features ) )
        self.bias   = nn.Parameter( torch.zeros( num_features ) )

        # Stored “source” stats that I later reset to.
        self.register_buffer( "source_median" , torch.zeros( num_features ) )
        self.register_buffer( "source_iqr"    , torch.ones ( num_features ) )

        # The FIFO memory for recent medians / IQRs.
        self.register_buffer( "memory_medians" , torch.zeros( memory_size , num_features ) )
        self.register_buffer( "memory_iqrs"    , torch.ones ( memory_size , num_features ) )

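        # FIFO bookkeeping. These are plain Python ints, not buffers, so they are
        # not saved in state_dict() and restart at 0 after a save/load round trip.
        self.memory_filled : int = 0
        self.memory_index  : int = 0

        # use_memory: blend the batch stats 50/50 with the memory average.
        # adapt_mode: keep recording batch stats into memory outside of training.
        self.use_memory : bool = False
        self.adapt_mode : bool = False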


    def forward ( self , x : torch.Tensor ) -> torch.Tensor:

        if   x.dim() != 4 :
            raise ValueError( "NCHW format input only!" )

        N , C , H , W = x.shape

        # Per-channel flattening, to compute robust stats.
        x_flat = ( x.permute(1,0,2,3)
                     .contiguous()
                     .view( C , -1 ) )

        batch_median = x_flat.median( dim = 1 ).values
        q1           = x_flat.quantile( 0.25 , dim = 1 )
        q3           = x_flat.quantile( 0.75 , dim = 1 )
        batch_iqr    = ( q3 - q1 ).clamp_min( self.eps )

        # Fuse the memory stats only when I use them (efficient).
        if self.use_memory  and  self.memory_filled > 0 :
            mem_med = self.memory_medians[:self.memory_filled].mean( dim = 0 )
            mem_iqr = self.memory_iqrs   [:self.memory_filled].mean( dim = 0 )

            median  = 0.5 * batch_median + 0.5 * mem_med
            iqr     = 0.5 * batch_iqr    + 0.5 * mem_iqr
        else :
            median  = batch_median
            iqr     = batch_iqr

        # Normalize with the robust stats, then apply the learnable gamma / beta.
        x_norm = ( x - median.view(1,C,1,1) ) / ( iqr.view(1,C,1,1) + self.eps )
        out    = self.weight.view(1,C,1,1) * x_norm   +   self.bias.view(1,C,1,1)

        # I record memory during training or adaptation.
        if self.training or self.adapt_mode :
            with torch.no_grad():
                if self.memory_size > 0 :
                    idx = self.memory_index
                    self.memory_medians[idx].copy_( batch_median )
                    self.memory_iqrs   [idx].copy_( batch_iqr    )
                    self.memory_index   = ( idx + 1 ) % self.memory_size
                    self.memory_filled  = min( self.memory_filled + 1 , self.memory_size )

        return out

    @torch.no_grad()
    def get_memory_stats(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return aggregated memory median and IQR (mean over the FIFO)."""
        if self.memory_filled == 0:
            # If empty, fall back to source stats (or defaults)
            return self.source_median, self.source_iqr
        med = self.memory_medians[:self.memory_filled].mean(dim=0)
        iqr = self.memory_iqrs[:self.memory_filled].mean(dim=0)
        return med, iqr

    @torch.no_grad()
    def reset_to_source(self) -> None:
        """Reset memory to the source stats (used by drift detection resets)."""
        self.memory_medians.copy_(self.source_median.unsqueeze(0).repeat(self.memory_size, 1))
        self.memory_iqrs.copy_(self.source_iqr.unsqueeze(0).repeat(self.memory_size, 1))
        self.memory_filled = self.memory_size
        self.memory_index = 0


def set_qmem_adapt_mode(model: nn.Module,
                        adapt: bool,
                        use_memory: bool = True) -> None:
    """
    Turn adaptation mode on/off for all QMemBatchNorm2d layers.
    """
    for m in model.modules():
        if isinstance(m, QMemBatchNorm2d):
            m.adapt_mode = adapt
            m.use_memory = use_memory


@torch.no_grad()
def capture_bn_source_stats(model: nn.Module) -> None:
    # Capture the current Q-MemBN memory stats as "source" stats, to be used
    # as a safe reset point during drift.
    for m in model.modules():
        if isinstance(m, QMemBatchNorm2d):
            med, iqr = m.get_memory_stats()
            m.source_median.copy_(med.detach())
            m.source_iqr.copy_(iqr.detach())
            m.reset_to_source()


@torch.no_grad()
def reset_bn_stats(model: nn.Module) -> None:

    # Reset all Q-MemBN layers to their stored source stats (used on drift).
    for m in model.modules():
        if isinstance(m, QMemBatchNorm2d):
            m.reset_to_source()
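The forward pass above sorts each channel three times (one `median` and two `quantile` calls). `torch.quantile` also accepts a 1-D tensor of quantiles, so all three statistics can come from one call. This is a hypothetical variant, not the layer's current code, and it differs slightly: `quantile(0.5)` interpolates between the two middle values when the per-channel count is even, while `Tensor.median` returns the lower one.

```python
from typing import Tuple

import torch


def robust_channel_stats(x: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-channel median and IQR of an NCHW tensor from a single quantile call.

    Hypothetical variant of the stats block in QMemBatchNorm2d.forward.
    """
    c = x.shape[1]
    x_flat = x.transpose(0, 1).reshape(c, -1)  # (C, N*H*W)

    # With a 1-D q tensor the result has shape (len(q), C): rows are q1, median, q3.
    q = torch.tensor([0.25, 0.5, 0.75], dtype=x.dtype, device=x.device)
    q1, med, q3 = torch.quantile(x_flat, q, dim=1)

    iqr = (q3 - q1).clamp_min(eps)
    return med, iqr
```

Since N·H·W is even for every CIFAR batch, swapping this in would shift the medians by up to half the gap between the two middle values, so it is worth re-validating accuracy before adopting it.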

## **Section 4: Backbone Network**
### **ResNet-18 with Q-MemBN**

I implement a lightweight **ResNet-18-style** backbone where each BN is swapped for Q-MemBN.  
The stack layout stays faithful to ResNet so any improvements (fingers crossed) come from my adaptation, not some architecture trick.

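Key choices:
- Small model to match “compute-poor deployment” settings.  
- CIFAR-style stem (3 × 3 stride-1 conv, no max-pool), so 32 × 32 inputs keep full resolution through `layer1`.  
- Feature extractor exposed (`forward_features`) for prototype learning later (what I tried doing in my M.S.E.'s MCQs).  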

In [None]:
# The ResNet-18 (style) backbone (w. Q-MemBN).

class BasicBlock ( nn.Module ) :

    expansion = 1

    def __init__ ( self ,
                   in_planes : int ,
                   planes    : int ,
                   stride    : int = 1 ,
                   norm_layer = QMemBatchNorm2d ) :

        super().__init__()

        self.conv1 = nn.Conv2d( in_planes , planes ,
                                kernel_size = 3 ,
                                stride      = stride ,
                                padding     = 1 ,
                                bias        = False )

        self.bn1   = norm_layer( planes )
        self.relu  = nn.ReLU( inplace = True )

        self.conv2 = nn.Conv2d( planes , planes ,
                                kernel_size = 3 ,
                                stride      = 1 ,
                                padding     = 1 ,
                                bias        = False )

        self.bn2   = norm_layer( planes )

        self.downsample = None
        if stride != 1  or  in_planes != planes :
            self.downsample = nn.Sequential(
                nn.Conv2d( in_planes , planes ,
                           kernel_size = 1 ,
                           stride      = stride ,
                           bias        = False ),
                norm_layer( planes )
            )


    def forward ( self , x : torch.Tensor ) -> torch.Tensor :

        identity = x

        out = self.conv1( x )
        out = self.bn1 ( out )
        out = self.relu( out )

        out = self.conv2( out )
        out = self.bn2 ( out )

        if self.downsample is not None :
            identity = self.downsample( x )

        out = out + identity
        out = self.relu( out )

        return out



# I use this as a small ResNet-18 variant for 32 x 32 images with Q-MemBN everywhere.
class QMemResNet18 ( nn.Module ) :

    def __init__ ( self ,
                   num_classes : int = 10 ,
                   norm_layer  = QMemBatchNorm2d ) :

        super().__init__()

        self.in_planes = 64

        self.conv1 = nn.Conv2d( 3 , 64 ,
                                kernel_size = 3 ,
                                stride      = 1 ,
                                padding     = 1 ,
                                bias        = False )

        self.bn1  = norm_layer( 64 )
        self.relu = nn.ReLU( inplace = True )

        self.layer1 = self._make_layer(  64 , blocks = 2 , stride = 1 , norm_layer = norm_layer )
        self.layer2 = self._make_layer( 128 , blocks = 2 , stride = 2 , norm_layer = norm_layer )
        self.layer3 = self._make_layer( 256 , blocks = 2 , stride = 2 , norm_layer = norm_layer )
        self.layer4 = self._make_layer( 512 , blocks = 2 , stride = 2 , norm_layer = norm_layer )

        self.avgpool = nn.AdaptiveAvgPool2d( ( 1 , 1 ) )
        self.fc      = nn.Linear( 512 , num_classes )


    def _make_layer ( self ,
                      planes    : int ,
                      blocks    : int ,
                      stride    : int ,
                      norm_layer ) -> nn.Sequential :

        layers = []
        layers.append(
            BasicBlock( self.in_planes , planes ,
                        stride     = stride ,
                        norm_layer = norm_layer )
        )
        self.in_planes = planes

        for _ in range( 1 , blocks ) :
            layers.append(
                BasicBlock( self.in_planes , planes ,
                            stride     = 1 ,
                            norm_layer = norm_layer )
            )

        return nn.Sequential( *layers )


    def forward_features ( self , x : torch.Tensor ) -> torch.Tensor :

        x = self.conv1( x )
        x = self.bn1 ( x )
        x = self.relu( x )

        x = self.layer1( x )
        x = self.layer2( x )
        x = self.layer3( x )
        x = self.layer4( x )

        x = self.avgpool( x )
        x = torch.flatten( x , 1 )

        return x


    def forward ( self , x : torch.Tensor ) -> Tuple[ torch.Tensor , torch.Tensor ] :

        feats  = self.forward_features( x )
        logits = self.fc( feats )

        return logits , feats


## **Section 5: Memory for Semantics**
### **Prototype Bank**

During adaptation, I need stable semantic anchors.  
This prototype bank stores **one feature prototype per class** and updates each one with an exponential moving average (momentum).

It supports:
- **Initialization from the support set** (true labels).  
- **Pseudo-label updates** from confident stream samples.  
- A **consistency mask**: a sample is accepted only if its nearest prototype belongs to its predicted class and lies within a distance threshold.  

> This gates learning such that noisy pseudo-labels don’t collapse the model.
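A toy 2-D example of the two rules (momentum update and consistency gate), using made-up prototypes. The functions are illustrative and only mirror the logic of `PrototypeBank.update` and `PrototypeBank.consistency_mask`.

```python
import torch


def ema_update(proto: torch.Tensor, batch_mean: torch.Tensor, momentum: float = 0.1) -> torch.Tensor:
    # Same rule as PrototypeBank.update: keep (1 - m) of the old prototype.
    return (1.0 - momentum) * proto + momentum * batch_mean


def gate(features: torch.Tensor, prototypes: torch.Tensor,
         predicted: torch.Tensor, threshold: float = 2.0) -> torch.Tensor:
    # Accept a sample only if its nearest prototype is the predicted class
    # AND it lies within `threshold` (Euclidean) of that prototype.
    dists = torch.cdist(features, prototypes)          # (N, C)
    nearest_dist, nearest_class = dists.min(dim=1)
    return (nearest_class == predicted) & (nearest_dist <= threshold)


prototypes = torch.tensor([[0.0, 0.0], [10.0, 10.0]])
features = torch.tensor([[0.5, 0.0],    # near class 0, predicted 0   -> accept
                         [9.0, 10.0],   # near class 1, predicted 1   -> accept
                         [0.0, 0.0],    # near class 0, predicted 1   -> reject
                         [4.0, 4.0]])   # nearest is class 0, ~5.66 away -> reject
predicted = torch.tensor([0, 1, 1, 0])

print(gate(features, prototypes, predicted).tolist())           # [True, True, False, False]
print(ema_update(prototypes[0], torch.tensor([10.0, 0.0])).tolist())  # [1.0, 0.0]
```

The threshold is an absolute Euclidean distance in feature space, so the default of 2.0 only makes sense relative to the scale of the 512-dim backbone features; checking typical sample-to-prototype distances before trusting it is a cheap safeguard.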

In [None]:
# Prototype memory bank (few-shot + pseudo-label updates).

class PrototypeBank :

    def __init__ ( self ,
                   feature_dim  : int ,
                   num_classes  : int ,
                   device       : torch.device ,
                   momentum     : float = 0.1 ) :

        self.feature_dim = feature_dim
        self.num_classes = num_classes
        self.device      = device
        self.momentum    = momentum

        # I store one prototype per class only.
        self.prototypes  = torch.zeros( num_classes , feature_dim , device = device )
        self.counts      = torch.zeros( num_classes , dtype = torch.long , device = device )


    @torch.no_grad()
    def initialize_from_support ( self ,
                                  model  : nn.Module ,
                                  loader : DataLoader ) -> None :

        model.eval()

        for images , labels in loader :

            images = images.to( self.device )
            labels = labels.to( self.device )

            logits , feats = model( images )

            for c in range( self.num_classes ) :

                mask = ( labels == c )
                if not mask.any() :
                    continue

                mean_feat = feats[ mask ].mean( dim = 0 )

                if self.counts[ c ] == 0 :
                    self.prototypes[ c ] = mean_feat
                else :
                    self.prototypes[ c ] = ( 1.0 - self.momentum ) * self.prototypes[ c ]  \
                                           + self.momentum * mean_feat

                self.counts[ c ] += mask.sum()


    @torch.no_grad()
    def update ( self ,
                 features : torch.Tensor ,
                 labels   : torch.Tensor ) -> None :

        # I update the class prototypes using the newer, more confident samples here.
        for c in range( self.num_classes ) :

            mask = ( labels == c )
            if not mask.any() :
                continue

            mean_feat = features[ mask ].mean( dim = 0 )

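            # A class with no prototype yet starts from its first mean (mirrors
            # initialize_from_support), instead of blending with the all-zero init.
            if self.counts[ c ] == 0 :
                self.prototypes[ c ] = mean_feat
            else :
                self.prototypes[ c ] = ( 1.0 - self.momentum ) * self.prototypes[ c ]  \
                                       + self.momentum * mean_feat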

            self.counts[ c ] += mask.sum()


    @torch.no_grad()
    def consistency_mask ( self ,
                           features         : torch.Tensor ,
                           predicted_labels : torch.Tensor ,
                           distance_threshold : float = 2.0 ) -> torch.Tensor :

        # If the prototypes are empty, I reject everything (for safety).
        if ( self.counts == 0 ).all() :
            return torch.zeros( predicted_labels.shape[0] ,
                                dtype = torch.bool ,
                                device = self.device )

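        dists = torch.cdist( features , self.prototypes ) # (N, C)

        # Ignore classes that have no prototype yet (their row is still all zeros).
        dists[ : , self.counts == 0 ] = float( "inf" )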

        nearest_dist  , nearest_class = torch.min( dists , dim = 1 )

        mask = ( nearest_class == predicted_labels )  \
               & ( nearest_dist  <= distance_threshold )

        return mask


### **Drift Detection**

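Adaptation can go wrong if the stream shifts abruptly. So I detect drift in two steps. First, drift is only considered when the batch's mean prediction entropy reaches `min_entropy`, so confident batches are never flagged. Then, for every Q-MemBN layer, I compare the **latest recorded median** to the **memory median**, scale the gap by the memory IQR, and average over channels and layers. If that normalized shift exceeds `median_delta_multiplier`, I treat it as drift and later **reset BN stats**.  

This gives Q-MemBN+ a 'seatbelt' against catastrophic online updates, the kind that can break a small vision model deployed on, e.g., a drone.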
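A toy walk-through of the two checks, with made-up numbers. The helpers are hypothetical and only mirror `DriftDetector.detect`; for simplicity the memory rows are in insertion order with the newest last, whereas the real FIFO locates the newest row via `memory_index`.

```python
import torch


def mean_entropy(probs: torch.Tensor) -> float:
    # Mean Shannon entropy (nats) of a batch of softmax rows, as in DriftDetector.
    p = probs.clamp_min(1e-8)
    return float(-(p * p.log()).sum(dim=1).mean())


def scaled_median_shift(memory_medians: torch.Tensor, memory_iqrs: torch.Tensor) -> float:
    # memory_medians / memory_iqrs: (filled, C) FIFO contents, newest row last.
    mem_med = memory_medians.mean(dim=0)
    mem_iqr = memory_iqrs.mean(dim=0)
    delta = (memory_medians[-1] - mem_med).abs()
    return float((delta / (mem_iqr + 1e-5)).mean())


# Uniform predictions over 10 classes reach the maximum entropy ln(10) ≈ 2.303 nats,
# above the default min_entropy = 1.0, so the BN test would run.
uniform = torch.full((8, 10), 0.1)
print(round(mean_entropy(uniform), 3))  # 2.303

# One channel, full 16-slot memory: 15 past medians at 0.0, the newest jumps to 4.0.
medians = torch.zeros(16, 1)
medians[-1] = 4.0
iqrs = torch.ones(16, 1)
print(scaled_median_shift(medians, iqrs))  # ≈ 3.75, above the 3.0 default -> drift
```

Because the newest median is itself part of the memory average, the score depends on how full the memory is: with only 4 filled slots, the same jump from 0.0 to 4.0 scores just under 3.0 and would not be flagged at the default `median_delta_multiplier = 3.0`.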


In [None]:
# Drift Detection (BN memory + entropy).

class DriftDetector :

    # Combines a prediction-entropy gate with a BN median-shift test to detect drift.
    def __init__ ( self ,
                   median_delta_multiplier : float = 3.0 ,
                   min_entropy             : float = 1.0 ) :

        self.median_delta_multiplier = median_delta_multiplier
        self.min_entropy             = min_entropy


    @torch.no_grad()
    def detect ( self ,
                 model             : nn.Module ,
                 batch_confidences : torch.Tensor ) -> bool :

        # Here we compute the mean prediction entropy for the batch.
        probs = batch_confidences
        H     = -( probs.clamp_min(1e-8) * probs.clamp_min(1e-8).log() ).sum( dim = 1 )
        mean_entropy = H.mean().item()

        if mean_entropy < self.min_entropy :
            return False

        median_shift_norms : List[ float ] = []

        for m in model.modules() :

            if not isinstance( m , QMemBatchNorm2d ) :
                continue
            if m.memory_filled <= 0 :
                continue

            mem_med , mem_iqr = m.get_memory_stats()

            # This approximates the batch median (most recent one) from the FIFO tail.
            last_idx   = ( m.memory_index - 1 ) % max( 1 , m.memory_size )
            batch_med  = m.memory_medians[ last_idx ]

            delta      = ( batch_med - mem_med ).abs()
            scaled     = delta / ( mem_iqr + 1e-5 )

            median_shift_norms.append( scaled.mean().item() )

        if not median_shift_norms :
            return False

        avg_scaled_shift = float( np.mean( median_shift_norms ) )
        return avg_scaled_shift > self.median_delta_multiplier


## **Section 6: Datasets & Loaders**
### **CIFAR-10-C Wrapper + Few-Shot Split**

Here, I define a clean wrapper for CIFAR-10-C that slices **(corruption, severity)** into a normal PyTorch dataset.

Then, I build:
- **Support subset**: `shots_per_class` labeled samples per class.
- **Stream subset**: the remaining unlabeled data for online TTA.  

The transforms use CIFAR-10 normalization so that the source and target stay compatible.
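The slicing in `CIFAR10C.__init__` relies on each `.npy` file stacking the five severities in order, 10,000 test images each, with severity 1 first. The arithmetic, isolated in a hypothetical helper:

```python
from typing import Tuple


def severity_slice(severity: int, total_rows: int = 50_000, num_levels: int = 5) -> Tuple[int, int]:
    """Row range [start, end) of one severity inside a CIFAR-10-C .npy file."""
    if not 1 <= severity <= num_levels:
        raise ValueError(f"severity must be in 1..{num_levels}, got {severity}")
    n = total_rows // num_levels  # 10,000 images per severity
    return (severity - 1) * n, severity * n


for s in range(1, 6):
    print(s, severity_slice(s))
# 1 (0, 10000) ... 5 (40000, 50000)
```

The same row range applies to `labels.npy`, which is why the wrapper slices data and labels with identical `st : en` bounds.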


In [None]:
# My Datasets: CIFAR-10 and CIFAR-10-C.

CIFAR_MEAN = ( 0.4914 , 0.4822 , 0.4465 )
CIFAR_STD  = ( 0.2470 , 0.2435 , 0.2616 )


class CIFAR10C ( Dataset ) :

    # A wrapper that slices one corruption and one severity level.
    def __init__ ( self ,
                   root       : str ,
                   corruption : str ,
                   severity   : int = 5 ,
                   transform  = None ) :

        super().__init__()

        self.root       = root
        self.corruption = corruption
        self.severity   = severity
        self.transform  = transform

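        # Fail fast on a bad severity before loading ~150 MB from disk.
        assert 1 <= severity <= 5 , f"severity must be in 1..5, got {severity}"

        data_path   = os.path.join( root , "CIFAR-10-C" , f"{corruption}.npy" )
        labels_path = os.path.join( root , "CIFAR-10-C" , "labels.npy" )

        self.data   = np.load( data_path ) # shape (50000, 32, 32, 3): 5 severities x 10,000 images.
        self.labels = np.load( labels_path )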

        n = self.data.shape[0] // 5
        st = ( severity - 1 ) * n
        en = severity * n

        self.data   = self.data  [ st : en ]
        self.labels = self.labels[ st : en ]


    def __len__ ( self ) -> int :
        return len( self.data )


    def __getitem__ ( self , idx : int ) -> Tuple[ torch.Tensor , int ] :

        img = self.data[ idx ]
        img = Image.fromarray( img.astype( np.uint8 ) )

        if self.transform is not None :
            img = self.transform( img )

        label = int( self.labels[ idx ] )
        return img , label



def get_cifar10_dataloaders ( root        : str ,
                              batch_size  : int = 128 ,
                              num_workers : int = 2 ) -> Tuple[ Dataset , Dataset , DataLoader , DataLoader ] :

    transform_train = transforms.Compose([
        transforms.RandomCrop( 32 , padding = 4 ),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize( CIFAR_MEAN , CIFAR_STD ),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize( CIFAR_MEAN , CIFAR_STD ),
    ])

    train_ds = datasets.CIFAR10( root = root ,
                                 train = True ,
                                 download = True ,
                                 transform = transform_train )

    test_ds  = datasets.CIFAR10( root = root ,
                                 train = False ,
                                 download = True ,
                                 transform = transform_test )

    train_loader = DataLoader( train_ds ,
                               batch_size = batch_size ,
                               shuffle    = True ,
                               num_workers = num_workers )

    test_loader  = DataLoader( test_ds ,
                               batch_size = batch_size ,
                               shuffle    = False ,
                               num_workers = num_workers )

    return train_ds , test_ds , train_loader , test_loader



def get_cifar10c_loader ( root        : str ,
                          corruption  : str = "gaussian_noise" ,
                          severity    : int = 5 ,
                          batch_size  : int = 64 ,
                          num_workers : int = 2 ,
                          shuffle     : bool = False ) -> Tuple[ Dataset , DataLoader ] :

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize( CIFAR_MEAN , CIFAR_STD ),
    ])

    dataset = CIFAR10C( root       = root ,
                        corruption = corruption ,
                        severity   = severity ,
                        transform  = transform )

    loader  = DataLoader( dataset ,
                          batch_size  = batch_size ,
                          shuffle     = shuffle ,
                          num_workers = num_workers )

    return dataset , loader




def build_support_and_stream_subsets ( dataset        : Dataset ,
                                       num_classes    : int ,
                                       shots_per_class : int = 5 ) -> Tuple[ Subset , Subset ] :

    # I split the data into a few-shot labeled support set + an unlabeled adaptation stream.
    indices_per_class = { c : [] for c in range( num_classes ) }

    all_indices = list( range( len( dataset ) ) )
    random.shuffle( all_indices )

    for idx in all_indices :

        _ , label = dataset[ idx ]

        if len( indices_per_class[ label ] ) < shots_per_class :
            indices_per_class[ label ].append( idx )

        if all( len(v) >= shots_per_class for v in indices_per_class.values() ) :
            break

    support_indices : List[int] = []
    for c in range( num_classes ) :
        support_indices.extend( indices_per_class[ c ] )

    support_indices = sorted( support_indices )

    full_set    = set( range( len( dataset ) ) )
    support_set = set( support_indices )

    stream_indices = sorted( list( full_set - support_set ) )

    support_subset = Subset( dataset , support_indices )
    stream_subset  = Subset( dataset , stream_indices )

    return support_subset , stream_subset
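`build_support_and_stream_subsets` reads labels through `dataset[idx]`, which decodes and normalizes every image it visits. When the labels are available directly (as `CIFAR10C.labels` is), the split can be computed from them alone. A minimal sketch; `split_by_labels` is a hypothetical helper, and because it uses its own NumPy generator it will pick different support samples than the `random.shuffle` version.

```python
from typing import List, Tuple

import numpy as np


def split_by_labels(labels: np.ndarray, num_classes: int, shots_per_class: int,
                    seed: int = 0) -> Tuple[List[int], List[int]]:
    """Few-shot support / stream index split computed from labels alone."""
    rng = np.random.default_rng(seed)
    support: List[int] = []
    for c in range(num_classes):
        class_idx = np.flatnonzero(labels == c)
        take = min(shots_per_class, class_idx.size)  # rare classes give what they have
        if take > 0:
            support.extend(rng.choice(class_idx, size=take, replace=False).tolist())
    support = sorted(support)
    support_set = set(support)
    # Everything not used for support becomes the unlabeled stream, in original order.
    stream = [i for i in range(len(labels)) if i not in support_set]
    return support, stream
```

Usage would look like `support_idx, stream_idx = split_by_labels(dataset.labels, 10, shots_per_class=5)`, followed by `Subset(dataset, support_idx)` and `Subset(dataset, stream_idx)`.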


## **Section 7: Metrics**
### **Interpretable Evaluation**

I compute:
- Accuracy  
- Macro Precision / Recall / F1 (class-balanced)  
- RMSE between softmax probabilities and one-hot labels (a Brier-style calibration signal)

I chose these metrics because they summarize both **correctness** and **confidence behavior** under corruption.
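The probability RMSE is closely tied to the multi-class Brier score. Since the code averages squared errors over all N·C entries, RMSE = sqrt(Brier / C), where Brier averages each sample's squared error summed over classes. A two-sample check with made-up probabilities:

```python
import torch
import torch.nn.functional as F

# Two samples, two classes, both with true class 0 (made-up numbers).
probs = torch.tensor([[1.0, 0.0],    # confident and correct
                      [0.0, 1.0]])   # confident and wrong
targets = torch.tensor([0, 0])
num_classes = 2

y = F.one_hot(targets, num_classes=num_classes).float()

# RMSE as computed in compute_classification_metrics: mean over all N*C entries.
rmse = torch.sqrt(((probs - y) ** 2).mean()).item()

# Multi-class Brier score: per-sample squared error summed over classes, then averaged.
brier = ((probs - y) ** 2).sum(dim=1).mean().item()

print(rmse, (brier / num_classes) ** 0.5)  # both ≈ 0.7071
```

With C = 10 on CIFAR, the RMSE values therefore look smaller than the corresponding Brier scores, which is worth remembering when comparing against papers that report Brier directly.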


In [None]:
# My Metrics: accuracy, precision, recall, F1, RMSE.

def compute_classification_metrics ( logits : torch.Tensor ,
                                     targets : torch.Tensor ,
                                     num_classes : int ) -> Dict[ str , float ] :

    # Computes the core classification metrics, plus the RMSE between softmax probabilities and one-hot labels.
    with torch.no_grad() :

        probs = F.softmax( logits , dim = 1 )
        preds = probs.argmax( dim = 1 )

        total    = targets.numel()
        correct  = ( preds == targets ).sum().item()
        accuracy = correct / total

        TP = torch.zeros( num_classes )
        FP = torch.zeros( num_classes )
        FN = torch.zeros( num_classes )

        for c in range( num_classes ) :
            TP[ c ] = ( ( preds == c ) & ( targets == c ) ).sum()
            FP[ c ] = ( ( preds == c ) & ( targets != c ) ).sum()
            FN[ c ] = ( ( preds != c ) & ( targets == c ) ).sum()

        precision_per_class = TP / ( TP + FP + 1e-8 )
        recall_per_class    = TP / ( TP + FN + 1e-8 )
        f1_per_class        = 2 * precision_per_class * recall_per_class \
                              / ( precision_per_class + recall_per_class + 1e-8 )

        precision_macro = precision_per_class.mean().item()
        recall_macro    = recall_per_class.mean().item()
        f1_macro        = f1_per_class.mean().item()

        # Measures the RMSE between the predicted probabilities and the one-hot labels.
        y_one_hot = F.one_hot( targets , num_classes = num_classes ).float()
        rmse      = torch.sqrt( ( ( probs - y_one_hot ) ** 2 ).mean() ).item()

        return {
            "accuracy"        : accuracy ,
            "precision_macro" : precision_macro ,
            "recall_macro"    : recall_macro ,
            "f1_macro"        : f1_macro ,
            "rmse"            : rmse ,
        }


### **Train/Eval Routines**

This is a standard supervised loop:
- `train_one_epoch`: one epoch of cross-entropy training on clean CIFAR-10.  
- `evaluate`: the same forward pass and metrics, run under `no_grad` with the model in eval mode.  

I log full-dataset metrics each epoch so my progress remains interpretable later.

In [None]:
# A normal training and evaluation loop.

def train_one_epoch ( model      : nn.Module ,
                      loader     : DataLoader ,
                      optimizer  : torch.optim.Optimizer ,
                      device     : torch.device ,
                      num_classes : int ) -> Tuple[ float , Dict[ str , float ] ] :

    # I run one full training pass over the dataloader.
    model.train()

    total_loss  = 0.0
    all_logits  = []
    all_targets = []

    criterion = nn.CrossEntropyLoss()

    for images , labels in loader :

        images = images.to( device )
        labels = labels.to( device )

        optimizer.zero_grad()

        logits , feats = model( images )
        loss           = criterion( logits , labels )

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * labels.size( 0 )

        all_logits.append ( logits.detach().cpu()  )
        all_targets.append( labels.detach().cpu() )

    all_logits  = torch.cat( all_logits  , dim = 0 )
    all_targets = torch.cat( all_targets , dim = 0 )

    metrics  = compute_classification_metrics( all_logits , all_targets , num_classes )
    avg_loss = total_loss / len( loader.dataset )

    return avg_loss , metrics



@torch.no_grad()
def evaluate ( model       : nn.Module ,
               loader      : DataLoader ,
               device      : torch.device ,
               num_classes : int ) -> Tuple[ float , Dict[ str , float ] ] :

    # Evaluation runs without gradient tracking (saves memory and compute; gradients aren't needed to score predictions).
    model.eval()

    total_loss  = 0.0
    all_logits  = []
    all_targets = []

    criterion = nn.CrossEntropyLoss()

    for images , labels in loader :

        images = images.to( device )
        labels = labels.to( device )

        logits , feats = model( images )
        loss           = criterion( logits , labels )

        total_loss += loss.item() * labels.size( 0 )

        all_logits.append ( logits.cpu()  )
        all_targets.append( labels.cpu() )

    all_logits  = torch.cat( all_logits  , dim = 0 )
    all_targets = torch.cat( all_targets , dim = 0 )

    metrics  = compute_classification_metrics( all_logits , all_targets , num_classes )
    avg_loss = total_loss / len( loader.dataset )

    return avg_loss , metrics


## **Section 8: Source Training (Stage-0)**
### **Hold-Out Hyperparameter Search + Final Training**

Before any adaptation, I need a solid **source model**.  
I run a simple 80/20 hold-out search over learning rates and weight decays, pick the best-performing combination, then train fully.

Why this way?
- It's cheap and practical. Hyperparameter search should never consume the majority of your compute budget.
- It avoids over-engineering full cross-validation while still weeding out bad defaults (e.g., learning-rate / weight-decay pairs that diverge immediately).
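`random_split` in `hyperparam_search` draws from the global torch RNG, so the hold-out split depends on whatever consumed randomness before it. Passing a seeded `torch.Generator` through `random_split`'s `generator` argument pins the split down. A toy check on a 10-item stand-in dataset (on the 50,000-image CIFAR-10 train set the same 80/20 rule gives 40,000 / 10,000):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset of 10 items; in the notebook this would be the CIFAR-10 train set.
toy = TensorDataset(torch.arange(10))

n_val = int(0.2 * len(toy))
n_train = len(toy) - n_val

# A dedicated, seeded generator makes the split repeatable and independent of
# how much global randomness was consumed earlier.
split_a = random_split(toy, [n_train, n_val], generator=torch.Generator().manual_seed(42))
split_b = random_split(toy, [n_train, n_val], generator=torch.Generator().manual_seed(42))

print(len(split_a[0]), len(split_a[1]))          # 8 2
print(split_a[1].indices == split_b[1].indices)  # True
```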

In [None]:
# Simple hyperparameter "CV" (hold-out validation) for source training.

def hyperparam_search ( train_dataset  : Dataset ,
                        num_classes    : int ,
                        device         : torch.device ,
                        learning_rates : List[ float ] ,
                        weight_decays  : List[ float ] ,
                        batch_size     : int = 128 ,
                        max_epochs     : int = 5 ) -> Dict[ str , float ] :

    # An 80/20 train/val split (not the one I detailed in my M.S.E.!) is used to pick the best LR and weight decay.
    n_total  = len( train_dataset )
    n_val    = int( 0.2 * n_total )
    n_train  = n_total - n_val

    train_subset , val_subset = random_split( train_dataset , [ n_train , n_val ] )

    best_config = None
    best_acc    = 0.0

    for lr in learning_rates :
        for wd in weight_decays :

            model = QMemResNet18( num_classes = num_classes ).to( device )
            optimizer = torch.optim.SGD(
                model.parameters() ,
                lr       = lr ,
                momentum = 0.9 ,
                weight_decay = wd
            )

            train_loader = DataLoader( train_subset ,
                                       batch_size = batch_size ,
                                       shuffle    = True )

            val_loader = DataLoader( val_subset ,
                                     batch_size = batch_size ,
                                     shuffle    = False )

            for epoch in range( max_epochs ) :
                train_one_epoch( model , train_loader , optimizer ,
                                 device , num_classes )

            _ , val_metrics = evaluate( model , val_loader ,
                                        device , num_classes )

            val_acc = val_metrics[ "accuracy" ]

            if val_acc > best_acc :
                best_acc    = val_acc
                best_config = { "lr" : lr , "weight_decay" : wd }

    if best_config is None :
        best_config = { "lr" : learning_rates[ 0 ] ,
                        "weight_decay" : weight_decays[ 0 ] }

    return best_config




def train_source_model ( data_root : str ,
                         num_classes : int ,
                         device : torch.device ,
                         batch_size : int = 128 ,
                         epochs : int = 30 ) -> nn.Module :

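    # Tunes LR / weight decay on a hold-out split, then trains the source model on the full CIFAR-10 training set.
    # Note: the CIFAR-10 test split doubles as the "Val" set in the per-epoch logs below.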
    train_ds , test_ds , train_loader , test_loader = \
        get_cifar10_dataloaders( data_root , batch_size = batch_size )

    best = hyperparam_search(
        train_ds ,
        num_classes    = num_classes ,
        device         = device ,
        learning_rates = [ 0.1 , 0.05 , 0.01 ] ,
        weight_decays  = [ 5e-4 , 1e-4 ] ,
        batch_size     = batch_size ,
        max_epochs     = 3 ,
    )

    print( "Best hyperparameters for source training:" , best )

    model = QMemResNet18( num_classes = num_classes ).to( device )

    optimizer = torch.optim.SGD(
        model.parameters() ,
        lr          = best[ "lr" ] ,
        momentum    = 0.9 ,
        weight_decay = best[ "weight_decay" ] ,
    )

    for epoch in range( epochs ) :

        train_loss , train_metrics = train_one_epoch(
            model , train_loader , optimizer , device , num_classes
        )

        val_loss , val_metrics = evaluate(
            model , test_loader , device , num_classes
        )

        print(
            f"[Source] Epoch {epoch+1}/{epochs} "
            f"Train loss {train_loss:.4f} acc {train_metrics['accuracy']:.4f} "
            f"Val loss {val_loss:.4f} acc {val_metrics['accuracy']:.4f}"
        )

    return model


## **Section 9: Stage-I Few-Shot Alignment**
### **Support-Set Fine-Tuning with Feature Mixing**

Given a small labeled support set (5 shots per class in Cell 1's setup), I fine-tune only:
- The **final FC layer**.
- The Q-MemBN affine parameters (γ/β).

I apply **feature mixing** (a mixup-style blend of penultimate features and their labels, with λ ~ Beta(α, α)) as a light regularizer so the model doesn’t overfit the 5-shot labels.  
Afterwards, I **capture the aligned robust BN statistics** as the new “source” state for Stage-II.
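Written out, with penultimate features $z$, linear head $h$ (`model.fc`), a random batch permutation $\pi$, and $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$, the Stage I loss in `stage1_finetune` is $\mathcal{L} = \tfrac{1}{2}\big[\mathrm{CE}(h(z), y) + \lambda\,\mathrm{CE}(h(\tilde z), y) + (1-\lambda)\,\mathrm{CE}(h(\tilde z), y_\pi)\big]$ with $\tilde z = \lambda z + (1-\lambda) z_\pi$. The NumPy sketch below computes the same quantity; it is an illustration under assumptions, not the notebook's PyTorch code: a plain linear head `W` (shape `(D, C)`) and `b` stands in for `model.fc`, and `feature_mix_loss` / `softmax_cross_entropy` are hypothetical helpers.

```python
from typing import Optional

import numpy as np


def softmax_cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    # Mean cross-entropy over the batch, using a numerically stable log-softmax.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())


def feature_mix_loss(feats: np.ndarray, labels: np.ndarray, W: np.ndarray,
                     b: np.ndarray, alpha: float = 0.3,
                     rng: Optional[np.random.Generator] = None) -> float:
    # feats: (B, D) penultimate features; W: (D, C) and b: (C,) form a hypothetical linear head.
    rng = np.random.default_rng() if rng is None else rng
    loss_main = softmax_cross_entropy(feats @ W + b, labels)
    if feats.shape[0] < 2:
        return loss_main  # nothing to mix with
    perm = rng.permutation(feats.shape[0])   # random partner for each sample
    lam = float(rng.beta(alpha, alpha))      # lambda ~ Beta(alpha, alpha)
    logits_mix = (lam * feats + (1.0 - lam) * feats[perm]) @ W + b
    loss_mix = (lam * softmax_cross_entropy(logits_mix, labels)
                + (1.0 - lam) * softmax_cross_entropy(logits_mix, labels[perm]))
    return 0.5 * (loss_main + loss_mix)


rng = np.random.default_rng(0)
feats = rng.normal(size=(8, 16))
labels = rng.integers(0, 10, size=8)
W, b = rng.normal(size=(16, 10)) * 0.1, np.zeros(10)
print(feature_mix_loss(feats, labels, W, b, rng=rng))
```

With the default α = 0.3, Beta(0.3, 0.3) is U-shaped, so λ tends to land near 0 or 1 and many batches are only mildly mixed, which fits the "light regularizer" intent.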

In [None]:
# Stage I: Few-shot fine-tuning with feature mixing.

def stage1_finetune ( model         : nn.Module ,
                      support_loader : DataLoader ,
                      device        : torch.device ,
                      num_classes   : int ,
                      epochs        : int = 5 ,
                      lr            : float = 1e-4 ,
                      mix_alpha     : float = 0.3 ) -> None :

    # Keep the model in training mode, but disable Q-MemBN adaptation and memory.
    model.train()
    set_qmem_adapt_mode( model , adapt = False , use_memory = False )

    # Only the BN affine params and the final FC layer are updated.
    params : List[ nn.Parameter ] = []
    for m in model.modules() :
        if isinstance( m , QMemBatchNorm2d ) :
            params.append( m.weight )
            params.append( m.bias )
    params.extend( list( model.fc.parameters() ) )

    optimizer = torch.optim.AdamW( params , lr = lr )
    criterion = nn.CrossEntropyLoss()

    for epoch in range( epochs ) :

        epoch_loss = 0.0

        for images , labels in support_loader :

            images = images.to( device )
            labels = labels.to( device )

            optimizer.zero_grad()

            logits , feats = model( images )

            # Mix features within the batch (feature-level mixup) to increase diversity.
            if images.size( 0 ) > 1 :

                perm = torch.randperm( images.size( 0 ) , device = device )
                lam  = float( np.random.beta( mix_alpha , mix_alpha ) )

                feats_mix  = lam * feats + ( 1.0 - lam ) * feats[ perm ]
                logits_mix = model.fc( feats_mix )

                labels_a = labels
                labels_b = labels[ perm ]

                loss_main = criterion( logits     , labels_a )
                loss_mix  = lam * criterion( logits_mix , labels_a ) \
                            + ( 1.0 - lam ) * criterion( logits_mix , labels_b )

                loss = 0.5 * ( loss_main + loss_mix )

            else :
                loss = criterion( logits , labels )

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * labels.size( 0 )

        epoch_loss /= len( support_loader.dataset )
        print( f"[Stage I] Epoch {epoch+1}/{epochs}, loss {epoch_loss:.4f}" )

    # Captures the aligned BN stats as the new source state.
    capture_bn_source_stats( model )


## **Section 10: Stage-II Online TTA (Q-MemBN+)**
### **Entropy + Prototype-Guided Adaptation**

Now I adapt on the unlabeled stream:
- Q-MemBN runs in **adapt + memory mode**.  
- Only low-entropy predictions (entropy below `entropy_fraction × log C`) are kept as pseudo-labels, and the model is trained with cross-entropy on them, which stabilizes the decision boundaries.
- The prototype bank further filters these pseudo-labels by their feature-distance consistency with the class prototypes.  
- The drift detector can reset the BN statistics to the source state if the stream shifts too abruptly.  

This is the “always-on” deployment behavior I’m targeting.
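The two filters combine into one acceptance mask per batch. The sketch below follows the entropy rule in `stage2_adapt_stream` (clamped probabilities, threshold `entropy_fraction × log C`), but `PrototypeBank.consistency_mask` is not shown in this section, so its rule here is an assumption: a sample counts as consistent when its feature lies within `distance_threshold` (Euclidean) of the prototype of its own pseudo-label. `softmax` and `acceptance_mask` are hypothetical NumPy helpers.

```python
import math
from typing import Tuple

import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Row-wise, numerically stable softmax.
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def acceptance_mask(logits: np.ndarray, feats: np.ndarray, prototypes: np.ndarray,
                    entropy_fraction: float = 0.5,
                    distance_threshold: float = 2.0) -> Tuple[np.ndarray, np.ndarray]:
    # logits: (B, C); feats: (B, D); prototypes: (C, D), one hypothetical mean feature per class.
    probs = softmax(logits)
    p = np.clip(probs, 1e-8, None)
    entropy = -(p * np.log(p)).sum(axis=1)
    low_entropy = entropy < entropy_fraction * math.log(logits.shape[1])

    pseudo = probs.argmax(axis=1)
    # Assumed consistency rule: Euclidean distance to the pseudo-label's own prototype.
    dist = np.linalg.norm(feats - prototypes[pseudo], axis=1)
    consistent = dist < distance_threshold
    return low_entropy & consistent, pseudo


logits = np.array([[10.0, 0, 0], [0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]])
feats = np.array([[0.1, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 5.5]])
prototypes = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])

mask, pseudo = acceptance_mask(logits, feats, prototypes)
print(mask.tolist())  # [True, False, False, True]
```

Row 1 is rejected for high entropy (a uniform prediction) and row 2 for sitting far from its pseudo-label's prototype. Under this assumed rule, the "naive" settings used later (`entropy_fraction = 1.0`, `distance_threshold = 1e9`) accept essentially every sample.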


In [None]:
#  Stage II: Q-MemBN+ test-time adaptation (online).

def stage2_adapt_stream ( model          : nn.Module ,
                          prototype_bank : PrototypeBank ,
                          stream_loader  : DataLoader ,
                          device         : torch.device ,
                          num_classes    : int ,
                          lr             : float = 1e-3 ,
                          entropy_fraction   : float = 0.5 ,
                          distance_threshold : float = 2.0 ,
                          max_batches    : Optional[ int ] = None ) -> Dict[ str , float ] :

    # Now we switch the model into adaptation mode with memory enabled.
    model.train()
    set_qmem_adapt_mode( model , adapt = True , use_memory = True )

    # Then we update the BN affine params & the classifier head.
    params : List[ nn.Parameter ] = []
    for m in model.modules() :
        if isinstance( m , QMemBatchNorm2d ) :
            params.append( m.weight )
            params.append( m.bias )
    params.extend( list( model.fc.parameters() ) )

    optimizer = torch.optim.SGD( params , lr = lr , momentum = 0.9 )
    criterion = nn.CrossEntropyLoss()

    drift_detector = DriftDetector(
        median_delta_multiplier = 3.0 ,
        min_entropy             = math.log( num_classes ) * 0.75
    )

    all_logits  = []
    all_targets = []
    batch_counter = 0

    for images , labels in stream_loader :

        images = images.to( device )
        labels = labels.to( device )

        optimizer.zero_grad()

        logits , feats = model( images )
        probs          = F.softmax( logits , dim = 1 )

        entropy = -( probs.clamp_min(1e-8) * probs.clamp_min(1e-8).log() ).sum( dim = 1 )

        # Filter out high-entropy (low-confidence) predictions: their pseudo-labels are too unreliable to train on.
        entropy_threshold = entropy_fraction * math.log( num_classes )
        low_entropy_mask  = entropy < entropy_threshold

        pseudo_labels = probs.argmax( dim = 1 )

        # And I only keep the samples consistent with their nearest prototype.
        consistency_mask = prototype_bank.consistency_mask(
            feats , pseudo_labels , distance_threshold = distance_threshold
        )

        accepted_mask = low_entropy_mask & consistency_mask

        if accepted_mask.any() :

            x_conf = images[ accepted_mask ]
            y_conf = pseudo_labels[ accepted_mask ]

            logits_conf , feats_conf = model( x_conf )
            loss = criterion( logits_conf , y_conf )

            loss.backward()
            optimizer.step()

            # Updating the prototype memory with confident samples now.
            prototype_bank.update(
                feats_conf.detach() ,
                y_conf.detach()
            )
        else :
            optimizer.zero_grad()

        # Only if the drift detector fires are the BN stats reset to the source state.
        if drift_detector.detect( model , probs ) :
            print( "[Stage II] Drift detected: resetting BN stats to source." )
            reset_bn_stats( model )

        all_logits.append ( logits.detach().cpu()  )
        all_targets.append( labels.detach().cpu() )

        batch_counter += 1
        if max_batches is not None and batch_counter >= max_batches :
            break

    all_logits  = torch.cat( all_logits  , dim = 0 )
    all_targets = torch.cat( all_targets , dim = 0 )

    metrics = compute_classification_metrics( all_logits , all_targets , num_classes )
    print( "[Stage II] Final metrics on adapted stream:" , metrics )

    return metrics


## **Section 11: Deployment Utilities**
### **Saving, Loading, and Single-Image Prediction**

Once adapted, I serialize:
- The model weights (`state_dict`)  
- The number of classes

I also provide a small helper for **single-image inference** that applies the same CIFAR normalization constants (`CIFAR_MEAN`, `CIFAR_STD`).  
This makes my notebook output directly deployable in, say, a dashboard or an API.
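A dashboard or API may hand over raw pixel arrays instead of image files, in which case the preprocessing in `predict_single_image` has to be reproduced exactly. The sketch below shows what `ToTensor()` followed by `Normalize()` does to a uint8 RGB image, in NumPy terms. It assumes the image is already resized to 32×32; the mean/std values are commonly quoted CIFAR-10 statistics used here as an assumption, not necessarily the notebook's `CIFAR_MEAN`/`CIFAR_STD`, and `to_normalized_batch` is a hypothetical helper.

```python
import numpy as np

# Assumed CIFAR-10 channel statistics (commonly quoted values); the notebook's own
# CIFAR_MEAN / CIFAR_STD constants are defined elsewhere and may differ.
ASSUMED_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
ASSUMED_STD = np.array([0.2470, 0.2435, 0.2616], dtype=np.float32)


def to_normalized_batch(image_hwc: np.ndarray,
                        mean: np.ndarray = ASSUMED_MEAN,
                        std: np.ndarray = ASSUMED_STD) -> np.ndarray:
    # Mirrors ToTensor() + Normalize() for a uint8 RGB image (H, W, 3) -> (1, 3, H, W).
    x = image_hwc.astype(np.float32) / 255.0            # ToTensor: scale [0, 255] -> [0, 1]
    x = x.transpose(2, 0, 1)                             # ToTensor: HWC -> CHW
    x = (x - mean[:, None, None]) / std[:, None, None]   # Normalize: per-channel (x - mean) / std
    return x[None, ...]                                  # like .unsqueeze(0): add a batch axis


white = np.full((32, 32, 3), 255, dtype=np.uint8)
batch = to_normalized_batch(white)
print(batch.shape)  # (1, 3, 32, 32)
```

If the statistics or the channel order differ from training, the model silently sees shifted inputs, which is exactly the kind of distribution shift this notebook tries to correct for.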


In [None]:
# Deployment: saving / loading and single-image prediction.

def save_deployed_model ( model : nn.Module ,
                          path  : str ) -> None :

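    # Saves the adapted model for later deployment.
    # Only create the parent directory if the path has one (os.makedirs("") would raise).
    dir_name = os.path.dirname( path )
    if dir_name :
        os.makedirs( dir_name , exist_ok = True )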

    checkpoint = {
        "state_dict"  : model.state_dict() ,
        "num_classes" : model.fc.out_features ,
    }

    torch.save( checkpoint , path )
    print( f"[Deploy] Saved model to {path}" )



def load_deployed_model ( path   : str ,
                          device : torch.device ) -> nn.Module :

    # Loads a Q-MemResNet18 checkpoint and disables adaptation, because deployed predictions should be deterministic.
    checkpoint   = torch.load( path , map_location = device )
    num_classes  = checkpoint.get( "num_classes" , 10 )

    model = QMemResNet18( num_classes = num_classes ).to( device )
    model.load_state_dict( checkpoint[ "state_dict" ] )
    model.eval()

    set_qmem_adapt_mode( model , adapt = False , use_memory = False )
    return model



@torch.no_grad()
def predict_single_image ( image_path : str ,
                           model_path : str ,
                           device_str : str = "cpu" ) -> Tuple[ int , float ] :

    # Load the model and run inference on a single image (the checkpoint is reloaded on every call).
    device = torch.device( device_str )
    model  = load_deployed_model( model_path , device = device )

    img = Image.open( image_path ).convert( "RGB" )

    transform = transforms.Compose([
        transforms.Resize( ( 32 , 32 ) ),
        transforms.ToTensor(),
        transforms.Normalize( CIFAR_MEAN , CIFAR_STD ),
    ])

    tensor = transform( img ).unsqueeze( 0 ).to( device )

    logits , feats = model( tensor )
    probs          = F.softmax( logits , dim = 1 )

    pred       = probs.argmax( dim = 1 ).item()
    confidence = probs[ 0 , pred ].item()

    return pred , confidence


## **Section 12: End-to-End Pipeline: Training and Evaluation**

The three code cells below build the end-to-end pipeline from the functions defined above. Together, they:
- **Train and test the custom Q-MemBN ResNet-18 backbone on CIFAR-10** (converged at 77% accuracy). (Cell 1)
- **Test the backbone on CIFAR-10-C without adaptation** (65.91% accuracy; this is the static baseline). (Cell 2)
- **Fine-tune an oracle on the fully labeled CIFAR-10-C target set** (70.00% accuracy; the supervised reference used for Objective 2). (Cell 2)
- **Test the backbone on CIFAR-10-C with the complete Q-MemBN+ adaptation on a clean (unpoisoned) stream**, i.e. with entropy filtering, prototype filtering, and drift detection (66.78% accuracy). (Cell 2)
- **Test the backbone on CIFAR-10-C with naive adaptation**, i.e. with entropy and prototype filtering disabled (66.62% accuracy; so it is not only less safe but also slightly less accurate). (Cell 2)
- **Test robustness to a stream with 20% poisoned inputs, for both Q-MemBN+ and naive adaptation** (naive adaptation: **12.20 pp** accuracy drop; Q-MemBN+: **11.46 pp** accuracy drop; so Q-MemBN+ resists malicious inputs better than naive adaptation). (Cell 3)
- **Record efficiency metrics (model size and Stage-II latency) for Objective 4**. (Cell 3)
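Objective 2 in Cell 3 measures how much of the static-to-oracle gap adaptation closes, $(1 - \text{gap}_\text{adapt} / \text{gap}_\text{static}) \times 100$. As a worked example, the sketch below applies that same formula to the accuracies printed in Cell 2's output; `gap_reduction_pct` is a hypothetical standalone copy of Cell 3's inline computation.

```python
def gap_reduction_pct(acc_static: float, acc_adapt: float, acc_oracle: float) -> float:
    # Share of the static-to-oracle accuracy gap closed by adaptation, in percent.
    gap_static = acc_oracle - acc_static
    gap_adapt = acc_oracle - acc_adapt
    if gap_static <= 0:
        return float("nan")  # no gap to close, same fallback as Cell 3
    return (1.0 - gap_adapt / gap_static) * 100.0


# Accuracies as printed in Cell 2's output.
acc_static = 0.6590954773869346   # static source model, no adaptation
acc_adapt = 0.6678391959798995    # Q-MemBN+ on the clean stream
acc_oracle = 0.7                  # target-supervised oracle

print(f"{gap_reduction_pct(acc_static, acc_adapt, acc_oracle):.2f} %")  # 21.38 %
```

The gaps come out at about 4.09 pp (static) and 3.22 pp (adapted), so on this run adaptation closes roughly a fifth of the distance to the oracle.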



In [None]:
# Cell 1: Source training and target loaders.



import os, time, copy
import torch
from torch.utils.data import DataLoader, Dataset

set_global_seed(42)

data_root   = "./data"
num_classes = 10
device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("=== Training source model on CIFAR-10 ===")
source_model = train_source_model(
    data_root   = data_root,
    num_classes = num_classes,
    device      = device,
    batch_size  = 128,
    epochs      = 50,
)

base_state = copy.deepcopy(source_model.state_dict())
os.makedirs("./checkpoints", exist_ok=True)
save_deployed_model(source_model, "./checkpoints/qmembn_source_cifar10.pth")

print("=== Preparing CIFAR-10-C target domain ===")
target_dataset, _ = get_cifar10c_loader(
    data_root,
    corruption  = "gaussian_noise",
    severity    = 5,
    batch_size  = 64,
    num_workers = 2,
    shuffle     = False,
)

support_subset, stream_subset = build_support_and_stream_subsets(
    target_dataset,
    num_classes     = num_classes,
    shots_per_class = 5,
)

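support_loader     = DataLoader(support_subset, batch_size=32, shuffle=True)    # 5-shot labeled support set (Stage I)
stream_loader      = DataLoader(stream_subset,  batch_size=64, shuffle=False)   # unlabeled stream (Stage II / evaluation)
full_target_loader = DataLoader(target_dataset, batch_size=64, shuffle=True)    # full labeled target set (oracle only)
stream_loader_b1   = DataLoader(stream_subset,  batch_size=1,  shuffle=False)   # batch size 1 (latency measurement)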


=== Training source model on CIFAR-10 ===
Best hyperparameters for source training: {'lr': 0.1, 'weight_decay': 0.0001}
[Source] Epoch 1/50 Train loss 2.8810 acc 0.2038 Val loss 1.9816 acc 0.2705
[Source] Epoch 2/50 Train loss 1.9485 acc 0.2812 Val loss 1.8442 acc 0.3249
[Source] Epoch 3/50 Train loss 1.8121 acc 0.3329 Val loss 1.7125 acc 0.3692
[Source] Epoch 4/50 Train loss 1.7018 acc 0.3685 Val loss 1.6255 acc 0.4057
[Source] Epoch 5/50 Train loss 1.6402 acc 0.3968 Val loss 1.5713 acc 0.4328
[Source] Epoch 6/50 Train loss 1.5886 acc 0.4161 Val loss 1.5137 acc 0.4529
[Source] Epoch 7/50 Train loss 1.5432 acc 0.4342 Val loss 1.4606 acc 0.4658
[Source] Epoch 8/50 Train loss 1.5025 acc 0.4503 Val loss 1.4534 acc 0.4714
[Source] Epoch 9/50 Train loss 1.4550 acc 0.4703 Val loss 1.4042 acc 0.4894
[Source] Epoch 10/50 Train loss 1.4105 acc 0.4885 Val loss 1.3863 acc 0.4966
[Source] Epoch 11/50 Train loss 1.3802 acc 0.4971 Val loss 1.3505 acc 0.5125
[Source] Epoch 12/50 Train loss 1.3389 acc

In [None]:
# Cell 2: Static baseline, oracle, and clean-stream adaptation.



# The static source model (has no target supervision and no adaptation).
static_model = QMemResNet18(num_classes=num_classes).to(device)
static_model.load_state_dict(base_state)

set_qmem_adapt_mode(static_model, adapt=False, use_memory=False)
_, static_metrics = evaluate(static_model, stream_loader, device, num_classes)

print("Static (No Adaptation) metrics on the CIFAR-10-C stream:", static_metrics)


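# Oracle: fine-tunes (BN affine params + FC) on the full labeled CIFAR-10-C target set,
# which includes the evaluation stream, so it serves as a supervised upper reference.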
oracle_model = QMemResNet18(num_classes=num_classes).to(device)
oracle_model.load_state_dict(base_state)

stage1_finetune(
    oracle_model,
    support_loader=full_target_loader,
    device=device,
    num_classes=num_classes,
    epochs=5,
    lr=1e-4,
)

set_qmem_adapt_mode(oracle_model, adapt=False, use_memory=False)
_, oracle_metrics = evaluate(oracle_model, stream_loader, device, num_classes)
print("Oracle (Target Supervised) Metrics on CIFAR-10-C Stream:", oracle_metrics)
save_deployed_model(oracle_model, "./checkpoints/qmembn_oracle_cifar10c.pth")


# Q-MemBN's adaptation process on a clean stream (the full method).
adapt_model_clean = QMemResNet18(num_classes=num_classes).to(device)
adapt_model_clean.load_state_dict(base_state)
stage1_finetune(
    adapt_model_clean,
    support_loader=support_loader,
    device=device,
    num_classes=num_classes,
    epochs=5,
    lr=1e-4,
)
prototype_bank_clean = PrototypeBank(
    feature_dim=512,
    num_classes=num_classes,
    device=device,
)
prototype_bank_clean.initialize_from_support(adapt_model_clean, support_loader)

metrics_qmembn_clean = stage2_adapt_stream(
    adapt_model_clean,
    prototype_bank   = prototype_bank_clean,
    stream_loader    = stream_loader,
    device           = device,
    num_classes      = num_classes,
    lr               = 1e-3,
    entropy_fraction = 0.5,
    distance_threshold = 2.0,
    max_batches      = None,
)
save_deployed_model(adapt_model_clean, "./checkpoints/qmembn_adapt_clean_cifar10c.pth")


# Naive adaptation: same backbone and Stage I, but entropy and prototype filtering are effectively disabled.
adapt_model_naive_clean = QMemResNet18(num_classes=num_classes).to(device)
adapt_model_naive_clean.load_state_dict(base_state)
stage1_finetune(
    adapt_model_naive_clean,
    support_loader=support_loader,
    device=device,
    num_classes=num_classes,
    epochs=5,
    lr=1e-4,
)

prototype_bank_naive_clean = PrototypeBank(
    feature_dim=512,
    num_classes=num_classes,
    device=device,
)

prototype_bank_naive_clean.initialize_from_support(adapt_model_naive_clean, support_loader)


metrics_naive_clean = stage2_adapt_stream(
    adapt_model_naive_clean,
    prototype_bank   = prototype_bank_naive_clean,
    stream_loader    = stream_loader,
    device           = device,
    num_classes      = num_classes,
    lr               = 1e-3,
    entropy_fraction = 1.0, # threshold = log(C), the maximum possible entropy, so entropy filtering is effectively off.
    distance_threshold = 1e9, # so large that prototype filtering is effectively off.
    max_batches      = None,
)

save_deployed_model(adapt_model_naive_clean, "./checkpoints/qmembn_naive_clean_cifar10c.pth")


Static (No Adaptation) metrics on the CIFAR-10-C stream: {'accuracy': 0.6590954773869346, 'precision_macro': 0.6696963906288147, 'recall_macro': 0.6590954661369324, 'f1_macro': 0.6576816439628601, 'rmse': 0.21556910872459412}
[Stage I] Epoch 1/5, loss 1.1290
[Stage I] Epoch 2/5, loss 1.0829
[Stage I] Epoch 3/5, loss 1.0673
[Stage I] Epoch 4/5, loss 1.0531
[Stage I] Epoch 5/5, loss 1.0447
Oracle (target-supervised) metrics on CIFAR-10-C stream: {'accuracy': 0.7, 'precision_macro': 0.6992928981781006, 'recall_macro': 0.699999988079071, 'f1_macro': 0.6992831230163574, 'rmse': 0.20176437497138977}
[Deploy] Saved model to ./checkpoints/qmembn_oracle_cifar10c.pth
[Stage I] Epoch 1/5, loss 1.5752
[Stage I] Epoch 2/5, loss 1.5222
[Stage I] Epoch 3/5, loss 1.6831
[Stage I] Epoch 4/5, loss 1.6728
[Stage I] Epoch 5/5, loss 1.6601
[Stage II] Final metrics on adapted stream: {'accuracy': 0.6678391959798995, 'precision_macro': 0.6743624806404114, 'recall_macro': 0.6678391695022583, 'f1_macro': 0.666

In [None]:
# Cell 3: Robustness to malicious inputs, latency, and objective summaries.



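class PoisonedStreamDataset(Dataset):
    # Wraps a stream subset and corrupts a fixed random fraction of its samples with heavy Gaussian noise.
    def __init__(self, base_subset, poison_fraction=0.2):
        self.base_subset = base_subset
        self.n = len(base_subset)
        num_poison = int(self.n * poison_fraction)
        perm = torch.randperm(self.n)
        self.poison_idx = set(perm[:num_poison].tolist())  # poisoned indices are fixed at construction

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        img, label = self.base_subset[idx]
        if idx in self.poison_idx:
            # Additive noise with std 2.5, re-sampled on every access; the label is left unchanged.
            noise = torch.randn_like(img) * 2.5
            img = img + noise
        return img, label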

poisoned_stream_dataset = PoisonedStreamDataset(stream_subset, poison_fraction=0.2)
poisoned_stream_loader  = DataLoader(poisoned_stream_dataset, batch_size=64, shuffle=False)


# Q-MemBN+ on a poisoned stream. How does it do?
adapt_model_poison = QMemResNet18(num_classes=num_classes).to(device)
adapt_model_poison.load_state_dict(base_state)
stage1_finetune(
    adapt_model_poison,
    support_loader=support_loader,
    device=device,
    num_classes=num_classes,
    epochs=5,
    lr=1e-4,
)
prototype_bank_poison = PrototypeBank(
    feature_dim=512,
    num_classes=num_classes,
    device=device,
)
prototype_bank_poison.initialize_from_support(adapt_model_poison, support_loader)

metrics_qmembn_poison = stage2_adapt_stream(
    adapt_model_poison,
    prototype_bank   = prototype_bank_poison,
    stream_loader    = poisoned_stream_loader,
    device           = device,
    num_classes      = num_classes,
    lr               = 1e-3,
    entropy_fraction = 0.5,
    distance_threshold = 2.0,
    max_batches      = None,
)


# Naive adaptation on poisoned stream.
adapt_model_naive_poison = QMemResNet18(num_classes=num_classes).to(device)
adapt_model_naive_poison.load_state_dict(base_state)
stage1_finetune(
    adapt_model_naive_poison,
    support_loader=support_loader,
    device=device,
    num_classes=num_classes,
    epochs=5,
    lr=1e-4,
)
prototype_bank_naive_poison = PrototypeBank(
    feature_dim=512,
    num_classes=num_classes,
    device=device,
)
prototype_bank_naive_poison.initialize_from_support(adapt_model_naive_poison, support_loader)

metrics_naive_poison = stage2_adapt_stream(
    adapt_model_naive_poison,
    prototype_bank   = prototype_bank_naive_poison,
    stream_loader    = poisoned_stream_loader,
    device           = device,
    num_classes      = num_classes,
    lr               = 1e-3,
    entropy_fraction = 1.0,   # threshold = log(C): entropy filtering effectively off.
    distance_threshold = 1e9, # prototype filtering effectively off.
    max_batches      = None,
)


# Model size (for my Objective 4).
param_model = QMemResNet18(num_classes=num_classes)
total_params = sum(p.numel() for p in param_model.parameters())
total_params_m = total_params / 1e6


# Latency for Stage-II adaptation at B=1 (for my Objective 4).
lat_model = QMemResNet18(num_classes=num_classes).to(device)
lat_model.load_state_dict(base_state)
stage1_finetune(
    lat_model,
    support_loader=support_loader,
    device=device,
    num_classes=num_classes,
    epochs=1,
    lr=1e-4,
)


prototype_bank_lat = PrototypeBank(
    feature_dim=512,
    num_classes=num_classes,
    device=device,
)
prototype_bank_lat.initialize_from_support(lat_model, support_loader)

if torch.cuda.is_available():
    torch.cuda.synchronize()
t0 = time.perf_counter()
_ = stage2_adapt_stream(
    lat_model,
    prototype_bank   = prototype_bank_lat,
    stream_loader    = stream_loader_b1,
    device           = device,
    num_classes      = num_classes,
    lr               = 1e-3,
    entropy_fraction = 0.5,
    distance_threshold = 2.0,
    max_batches      = 50,
)
if torch.cuda.is_available():
    torch.cuda.synchronize()
t1 = time.perf_counter()
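# The timed region also covers stage2_adapt_stream's setup, the first-batch warm-up, and the
# final metric computation/print, so this somewhat overestimates the steady-state cost per batch.
avg_latency_ms = (t1 - t0) / 50.0 * 1000.0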


# Summaries
acc_static = static_metrics["accuracy"]
acc_adapt  = metrics_qmembn_clean["accuracy"]
acc_oracle = oracle_metrics["accuracy"]

print("\n=== Objective 1: Static vs Q-MemBN+ on CIFAR-10-C ===")
print(f"Static accuracy:      {acc_static:.4f}")
print(f"Adapted accuracy:     {acc_adapt:.4f}")
print(f"Absolute gain:        {(acc_adapt - acc_static)*100:.2f} percentage points")

gap_static = acc_oracle - acc_static
gap_adapt  = acc_oracle - acc_adapt
if gap_static > 0:
    gap_reduction_pct = (1.0 - gap_adapt / gap_static) * 100.0
else:
    gap_reduction_pct = float("nan")

print("\n=== Objective 2: Gap to oracle (CIFAR-10-C) ===")
print(f"Oracle accuracy:      {acc_oracle*100:.2f} %")
print(f"Gap (oracle - static):  {gap_static*100:.2f} pp")
print(f"Gap (oracle - adapted): {gap_adapt*100:.2f} pp")
print(f"Gap reduction:          {gap_reduction_pct:.2f} %")

acc_naive_clean   = metrics_naive_clean["accuracy"]
acc_naive_poison  = metrics_naive_poison["accuracy"]
acc_qmembn_clean  = metrics_qmembn_clean["accuracy"]
acc_qmembn_poison = metrics_qmembn_poison["accuracy"]

drop_naive  = (acc_naive_clean  - acc_naive_poison)  * 100.0
drop_qmembn = (acc_qmembn_clean - acc_qmembn_poison) * 100.0

print("\n=== Objective 3: Robustness to 20% poisoned stream ===")
print(f"Naive adaptation drop:   {drop_naive:.2f} percentage points")
print(f"Q-MemBN+ drop:           {drop_qmembn:.2f} percentage points")

print("\n=== Objective 4: Efficiency ===")
print(f"Model size:              {total_params_m:.2f}M parameters")
print(f"Stage-II latency (B=1):  {avg_latency_ms:.2f} ms per batch")


[Stage I] Epoch 1/5, loss 1.5399
[Stage I] Epoch 2/5, loss 1.6663
[Stage I] Epoch 3/5, loss 1.4609
[Stage I] Epoch 4/5, loss 1.5733
[Stage I] Epoch 5/5, loss 1.3674
[Stage II] Final metrics on adapted stream: {'accuracy': 0.5516582914572864, 'precision_macro': 0.6218993663787842, 'recall_macro': 0.5516583323478699, 'f1_macro': 0.564264714717865, 'rmse': 0.26018691062927246}
[Stage I] Epoch 1/5, loss 1.6842
[Stage I] Epoch 2/5, loss 1.5306
[Stage I] Epoch 3/5, loss 1.6029
[Stage I] Epoch 4/5, loss 1.3970
[Stage I] Epoch 5/5, loss 1.6983
[Stage II] Final metrics on adapted stream: {'accuracy': 0.5458291457286432, 'precision_macro': 0.6163973808288574, 'recall_macro': 0.545829176902771, 'f1_macro': 0.5571845173835754, 'rmse': 0.25905847549438477}
[Stage I] Epoch 1/1, loss 1.8463
[Stage II] Final metrics on adapted stream: {'accuracy': 0.36, 'precision_macro': 0.40734848380088806, 'recall_macro': 0.36718255281448364, 'f1_macro': 0.3135073482990265, 'rmse': 0.2699425220489502}

=== Objectiv