<a href="https://colab.research.google.com/github/avkornaev/ICML-2025/blob/main/UQ_CIFAR-10N_ProCoSphere.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Uncertainty Quantification with CIFAR-10N and Ensembling

By *First name* *Second name*.

*Month, Day, 2025.*

## Problem Statement

Re-annotated versions of the CIFAR-10 and CIFAR-100 data which contains real-world human annotation errors. The noise patterns deviate from the classically assumed ones and what the new challenges are. The website of CIFAR-N is available at [cifar-10-100n
](https://github.com/UCSC-REAL/cifar-10-100n/tree/main) project.

# Preparation of simulation models

## Import and Install Libraries

In [1]:
# %pip install --upgrade torch torchvision

In [2]:
# !pip install pytorch-lightning clearml

In [3]:
#Pytorch modules
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision.datasets import CIFAR10
from torchvision import datasets, transforms, models
from torchvision.transforms import RandAugment
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from torch.utils.data import Subset  # <-- Fix missing import
#scipy
from scipy.stats import mode
#sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
# from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torchmetrics import CalibrationError
#Numpy
import numpy as np
from numpy.core.multiarray import _reconstruct
#Pandas
import pandas as pd
#Lightning & logging
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
#Data observation
from PIL import Image
import random
import os
# import sys
# import pickle
import requests
from pathlib import Path
#Plotting
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
#Logging
from clearml import Task

In [4]:
# Essential safety settings
torch.serialization.add_safe_globals([_reconstruct, np.ndarray])

# Disable ClearML's torch.load patching temporarily
from clearml.binding.frameworks import _patched_call
original_torch_load = torch.load
_patched_call._original_torch_load = original_torch_load

## Set the Models

### Simulation Settings

Check the current directory

In [5]:
os.getcwd() #returns the current working directory

'/project/ICML-2025/ICML-2025'

In [6]:
# Go one level up (outside the current directory)
parent_dir = os.path.join(os.getcwd(), os.pardir)

# Set the checkpoint path to a folder in the parent directory
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", os.path.join(parent_dir, "saved_models/"))
print(f'CHECKPOINT_PATH: {CHECKPOINT_PATH}')

os.makedirs(CHECKPOINT_PATH, exist_ok=True)

CHECKPOINT_PATH: /project/ICML-2025/ICML-2025/../saved_models/


Set the reproducibility options

In [7]:
# Function for setting the seed to implement parallel tests
SEEDS =  [42]#[42, 0, 17, 9, 3, 16, 2]
SEED = 42 # random seed by default
# pl.seed_everything(SEED)

# Determine the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    pl.seed_everything(seed)
    g = torch.Generator()
    g.manual_seed(seed)
    return g


def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

### Logging

To configure ClearML in your Colab environment, follow these steps:

---

*Step 1: Create a ClearML Account*
1. Go to the [ClearML website](https://clear.ml/).
2. Sign up for a free account if you don’t already have one.
3. Once registered, log in to your ClearML account.

---

*Step 2: Get Your ClearML Credentials*
1. After logging in, navigate to the **Settings** page (click on your profile icon in the top-right corner and select **Settings**).
2. Under the **Workspace** section, find your **+ Create new credentials**.
3. Copy these credentials for a Jupiter notebook into the code cell below.

---

*Step 3: Accessing the ClearML Dashboard*
1. Go to your ClearML dashboard (https://app.clear.ml).
2. Navigate to the **Projects** section to see your experiments.
3. Click on the experiment (e.g., `Lab_1`) to view detailed metrics, logs, and artifacts.

---

In [8]:
#Enter your code here to implement Step 2 of the logging instruction as it is shown below
%env CLEARML_WEB_HOST=https://app.clear.ml/
%env CLEARML_API_HOST=https://api.clear.ml
%env CLEARML_FILES_HOST=https://files.clear.ml
%env CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
%env CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I

env: CLEARML_WEB_HOST=https://app.clear.ml/
env: CLEARML_API_HOST=https://api.clear.ml
env: CLEARML_FILES_HOST=https://files.clear.ml
env: CLEARML_API_ACCESS_KEY=ZP02U03C6V5ER4K9VWRNZT7EWA5ZTV
env: CLEARML_API_SECRET_KEY=BtA5GXZufr6QGpaqhX1GSKPTvaCt56OLqaNqUGLNoxx2Ye8Ctwbui0Ln5OXVnzUgH4I


### Dataset

Summary

In [9]:
DATASET = 'CIFAR10N' # dataset with the real-world noise
# Can be 'clean_label', 'worse_label', 'aggre_label', 'random_label1', 'random_label2', 'random_label3'
NOISE_TYPE = 'worse_label'
LS = 0.0

SIZE = 32 #image size 32 is original size, but a ViT needs 224
NUM_CLASSES = 10
CLASS_NAMES = ['plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

### Collect parameters

In [10]:
#Model parameters
LOSS_FUN = 'ProCoSphere' # 'CE','CELoss'(custom), 'N', 'B', etc.
ARCHITECTURE = 'ProCoSphereEncoder' # 'CNN, 'ResNet50', 'ViT', etc.
LAMBDA = 0.5
#Collect the parameters (hyperparams and others)
# im_size = SIZE if ARCHITECTURE == 'CNN' else 224
hparams = {
    "seed": SEED,
    "lr": 1E-3, #0.001,
    'weight_decay': 5.0E-4,
    "dropout": {'dropout_rate': 0.1, 'mc_samples': 10},
    "embed_dim": 10, #NUM_CLASSES, #128,   #Embedding size D
    "lambda_con": LAMBDA,
    "temperature": 0.1,  # For contrastive loss if needed
    "bs": 500, #256,#32,
    "num_workers": 10,
    "num_epochs": 10,
    "warmup_epochs": 0,
    "selective_sampling": True,
    "selection_threshold": 0.1,
    "criterion": LOSS_FUN,
    "architecture": ARCHITECTURE,
    "freeze": False,
    "train_ratio": 0.9,
    "im_size": SIZE,
    "mean": [0.4914, 0.4822, 0.4465],
    "std": [0.2470, 0.2435, 0.2616],
    'randResCrop': {'size': (SIZE, SIZE), 'scale': (0.8, 1.0), 'ratio': (0.9, 1.1)},
    'label_smoothing': LS,
    "n_classes": NUM_CLASSES,
    "noise_path": './data/CIFAR-10_human.pt',
    "noise_type": NOISE_TYPE,  # Can be 'clean_label', 'worse_label', 'aggre_label', etc.
    "resume_checkpoint": None,#'/project/ICML-2025/saved_models/arch_ProCoSphereEncoder_loss_ProCoSphere_lambda_0.5_seed_9_noise_clean_label-v1.ckpt',
}

#Visualization
vis_params = {
    'fig_size': 5,
    'num_samples': 5,
    'num_bins': 50,
}

## Functions

### Lightning

Data module

In [11]:
def download_file(url, save_path):
    """Download a file from a URL and save it to the specified path."""
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)  # Ensure directory exists
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"File downloaded and saved to {save_path}")
    else:
        raise Exception(f"Failed to download file from {url}. Status code: {response.status_code}")

In [12]:
class CIFAR10(datasets.CIFAR10):
    """CIFAR10 dataset with noisy labels and dual views"""
    def __init__(self, root, train=True, transform_clf=None, transform_agn=None,
                 target_transform=None, download=False, noise_type=None,
                 noise_path=None, is_human=True):
        super().__init__(root, train=train, transform=None,  # Disable default transform
                         target_transform=target_transform, download=download)
        self.noise_type = noise_type
        self.noise_path = noise_path
        self.is_human = is_human
        self.transform_clf = transform_clf
        self.transform_agn = transform_agn

        if self.train and self.noise_type is not None:
            self.load_noisy_labels()

    def load_noisy_labels(self):
        from numpy.core.multiarray import _reconstruct
        import torch.serialization
        torch.serialization.add_safe_globals([_reconstruct])

        noise_file = torch.load(self.noise_path, map_location='cpu', weights_only=False)
        if isinstance(noise_file, dict):
            if "clean_label" in noise_file.keys():
                clean_label = torch.tensor(noise_file['clean_label'])
                assert torch.sum(torch.tensor(self.targets) - clean_label) == 0
                print(f'Loaded {self.noise_type} from {self.noise_path}.')
                print(f'Noise rate: {1 - np.mean(clean_label.numpy() == noise_file[self.noise_type])}')
            self.noisy_labels = noise_file[self.noise_type].reshape(-1)
        else:
            raise Exception('Invalid noise file format')

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        # Apply dual transforms
        if self.transform_clf:
            clf_view = self.transform_clf(img)
        if self.transform_agn:
            agn_view = self.transform_agn(img)

        # Apply noisy labels if training
        if self.train and self.noise_type is not None:
            target = self.noisy_labels[index]

        return {
            'clf_view': clf_view,
            'agn_view': agn_view,
            'target': target,
            'index': index
        }

    def __len__(self):
        return len(self.data)

In [13]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, params):
        super().__init__()
        self.save_hyperparameters(params)
        self.seed = params['seed']
        self.batch_size = params['bs']
        self.num_workers = min(params['num_workers'], 4)
        self.mean = params['mean']
        self.std = params['std']
        self.train_ratio = params['train_ratio']
        self.rand_res_crop = params['randResCrop']
        self.noise_path = params.get('noise_path', './data/CIFAR-10_human.pt')
        self.noise_type = params.get('noise_type', 'worse_label')
        self.im_size = params.get('im_size', 32)
        self.embed_dim = params.get('embed_dim', 128)
        self.full_train = None
        self.original_train_indices = None
        self.original_val_indices = None
        self.train_mask = None
        self.val_mask = None
        self.clean_labels = None
        self.noisy_labels = None

        # Transforms for classifier (train/val)
        self.clf_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                size=self.rand_res_crop['size'],
                scale=self.rand_res_crop['scale'],
                ratio=self.rand_res_crop['ratio']
            ),
            transforms.RandomHorizontalFlip(),
            RandAugment(num_ops=2, magnitude=9),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        # Transforms for agnostic views (contrastive learning)
        self.agn_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                size=self.rand_res_crop['size'],
                scale=(0.3, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(3),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        # Simplified test transform: Only RandomResizedCrop + ToTensor
        self.test_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                size=self.rand_res_crop['size'],
                scale=(0.9, 1.0),
                ratio=self.rand_res_crop['ratio']
            ),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        os.makedirs(os.path.dirname(self.noise_path), exist_ok=True)
        self._download_data()

    def _download_data(self):
        if not os.path.exists(self.noise_path):
            download_file(
                "https://github.com/UCSC-REAL/cifar-10-100n/raw/main/data/CIFAR-10_human.pt",
                self.noise_path
            )

    def prepare_data(self):
        datasets.CIFAR10(root='./data', train=True, download=True)
        datasets.CIFAR10(root='./data', train=False, download=True)

    def setup(self, stage=None):
        # Train + val
        full_train = CIFAR10(
            root='./data',
            train=True,
            transform_clf=self.clf_transform,
            transform_agn=self.agn_transform,
            noise_type=self.noise_type,
            noise_path=self.noise_path
        )
        self.full_train = full_train
        self.clean_labels = torch.tensor(full_train.targets)
        self.noisy_labels = full_train.noisy_labels

        # Create original splits
        full_size = len(self.full_train)
        g = seed_everything(self.seed)
        indices = torch.randperm(full_size, generator=g)
        train_size = int(full_size * self.train_ratio)
        self.original_train_indices = indices[:train_size]
        self.original_val_indices = indices[train_size:]
    
        # Initialize masks (all True initially)
        self.train_mask = torch.ones_like(self.original_train_indices, dtype=torch.bool)
        self.val_mask = torch.ones_like(self.original_val_indices, dtype=torch.bool)
        self._create_splits()

        # Test uses simplified transform (no noise or normalization)
        self.cifar10_test = CIFAR10(
            root='./data',
            train=False,
            transform_clf=self.test_transform,  # Use simplified transform
            transform_agn=self.test_transform,  # Optional: Can be different
            noise_type=None,
            noise_path=self.noise_path
        )

    def _create_splits(self):
        """Create subsets using current masks"""
        active_train = Subset(self.full_train, 
                            self.original_train_indices[self.train_mask])
        active_val = Subset(self.full_train, 
                          self.original_val_indices[self.val_mask])
        
        self.cifar10_train = active_train
        self.cifar10_val = active_val
        print(f"\nActive training samples: {len(self.cifar10_train)}/{len(self.original_train_indices)}")
        print(f"Active validation samples: {len(self.cifar10_val)}/{len(self.original_val_indices)}")

    def update_masks(self, new_train_mask, new_val_mask):
        self.train_mask = new_train_mask.clone()
        self.val_mask = new_val_mask.clone()
        self._create_splits()

    def train_dataloader(self):
        return self._create_dataloader(self.cifar10_train, shuffle=True)

    def val_dataloader(self):
        return self._create_dataloader(self.cifar10_val)

    def test_dataloader(self):
        return self._create_dataloader(self.cifar10_test)

    def _create_dataloader(self, dataset, shuffle=False):
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=seed_worker,
            shuffle=shuffle,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
            drop_last=shuffle
        )

In [14]:
def robust_load_pl_model(checkpoint_path, model_class, device, **init_kwargs):
    """Universal PyTorch Lightning model loader that handles:
    - PyTorch 2.6+ security restrictions
    - ClearML interference
    - Numpy array compatibility
    """
    # Bypass all patching and load raw checkpoint
    checkpoint = original_torch_load(
        checkpoint_path,
        map_location=device,
        weights_only=False
    )
    
    # Reconstruct model manually
    model = model_class(**init_kwargs)
    
    # Handle both regular and PyTorch Lightning checkpoint formats
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    
    # Load state dict with strict=False for compatibility
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)

Training module

In [15]:
class train_model(pl.LightningModule):
    def __init__(self, model=None, loss=None, hparams=hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = model
        self.loss_fn = loss
        self.automatic_optimization = False
        self.training_losses = []
        self.training_indices = []
        self.validation_losses = []
        self.validation_indices = []
        self.current_phase = 'warmup'

     # Add this critical method PROPERLY INDENTED under the class
    def configure_optimizers(self):
        base_lr = self.hparams['lr']
        warmup_lr = self.hparams['lr']

        # 3 parameter groups
        params = [
            {'params': self.model.agnostic.parameters(), 'lr': warmup_lr},
            {'params': self.model.classifier.parameters(), 
            'lr': base_lr,
            'weight_decay': self.hparams['weight_decay']},
            {'params': self.loss_fn.parameters(), 'lr': base_lr}
        ]

        optimizer = torch.optim.AdamW(params)

        # # Define separate lambdas for each parameter group
        # def agnostic_lambda(epoch):
        #     return min(1.0, 0.01 + (0.99 * epoch / (self.hparams['warmup_epochs'] + 1.0e-6)))
    
        # classifier_lambda = lambda _: 1.0  # No warmup
        # loss_lambda = lambda _: 1.0  # No warmup

        # warmup = torch.optim.lr_scheduler.LambdaLR(
        #     optimizer,
        #     lr_lambda=[agnostic_lambda, classifier_lambda, loss_lambda]
        # )

        # cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     optimizer,
        #     T_max=self.hparams['num_epochs'] - self.hparams['warmup_epochs'],
        #     eta_min=self.hparams['lr']/100
        # )

        # scheduler = torch.optim.lr_scheduler.SequentialLR(
        #     optimizer,
        #     schedulers=[warmup, cosine],
        #     milestones=[self.hparams['warmup_epochs']]
        # )

        return [optimizer] #, [scheduler]
    
    def forward(self, x_clf, x_agn):
        return self.model.classifier(x_clf), self.model.agnostic(x_agn)

    def on_train_epoch_start(self):
        """Handle phase transitions and parameter freezing"""
        # Reset mask at beginning of training
        if self.current_epoch == 0:
            self.trainer.datamodule.train_mask = torch.ones_like(
                self.trainer.datamodule.train_mask, dtype=torch.bool
            )
            self.trainer.datamodule.val_mask = torch.ones_like(
                self.trainer.datamodule.val_mask, dtype=torch.bool
            )
            self.trainer.datamodule._create_splits()
        
        self.training_losses.clear()
        self.training_indices.clear()
        self.validation_losses.clear()
        self.validation_indices.clear()
        # Add initialization of sample indices list
        self.sample_indices = []  # <-- FIX HERE
        if self.current_epoch < self.hparams['warmup_epochs']:
            self.current_phase = 'warmup'
            # Freeze classifier, train agnostic encoder
            for param in self.model.classifier.parameters():
                param.requires_grad = False
            for param in self.model.agnostic.parameters():
                param.requires_grad = True
        else:
            self.current_phase = 'joint'
            # Unfreeze classifier for joint training
            for param in self.model.classifier.parameters():
                param.requires_grad = True
            for param in self.model.agnostic.parameters():
                param.requires_grad = True
        self.loss_fn.set_phase(self.current_phase)

    def _shared_step(self, batch, stage):
        """Unified processing for validation/test stages"""
        x_clf = batch['clf_view']
        x_agn = batch['agn_view']
        y = batch['target']

        # Always use full model for validation/test
        clf_out, agn_out = self(x_clf, x_agn)
        loss = self.loss_fn(clf_out, agn_out, y)

        with torch.no_grad():
            # Metrics calculation
            var_clf = clf_out['log_var'].exp().mean()
            var_agn = agn_out['log_var'].exp().mean()
            similarity = F.cosine_similarity(clf_out['embed'], agn_out['embed']).mean()

            logs = {
                f'{stage}_loss': loss,
                f'{stage}_var_clf': var_clf,
                f'{stage}_var_agn': var_agn,
                f'{stage}_sim': similarity
            }

            if stage == 'test':
                preds = clf_out['logits'].argmax(1)
                logs[f'{stage}_acc'] = (preds == y).float().mean()

        self.log_dict(logs, prog_bar=(stage == 'train'), on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()

        if self.current_phase == 'warmup':
            # Warmup phase: contrastive learning between dual views
            view1 = batch['agn_view']
            view2 = batch['clf_view']
            
            # Both views through agnostic encoder
            agn_out1 = self.model.agnostic(view1)
            agn_out2 = self.model.agnostic(view2)
            
            # Compute contrastive loss
            # total_loss = self.loss_fn(agn_out1, agn_out2, batch['target'])
            loss_con = self.loss_fn._contrastive_loss(agn_out1['embed'], 
                                        agn_out2['embed'],
                                        agn_out1['log_var'].exp())
            total_loss = loss_con
        else:
            # Joint training phase
            clf_out, agn_out = self(batch['clf_view'], batch['agn_view'])
            total_loss = self.loss_fn(clf_out, agn_out, batch['target'])
            
            # Compute per-sample classification loss for selective sampling
            with torch.no_grad():
                per_sample_cls_loss = self.loss_fn.classification_loss_per_sample(
                    clf_out['logits'], batch['target']
                )
            
            self.training_losses.append(per_sample_cls_loss.detach().cpu())
            self.sample_indices.append(batch['index'].cpu())

        self.manual_backward(total_loss.mean())
        opt.step()      
        return total_loss.mean()
    
    def on_train_epoch_end(self):
        if self.current_phase == 'joint' and self.hparams['selective_sampling']:
            if not self.training_losses:
                return
        
            # Process TRAIN mask first
            losses = torch.cat(self.training_losses)
            all_indices = torch.cat(self.sample_indices)
            gathered_losses = self.all_gather(losses).flatten()
            gathered_indices = self.all_gather(all_indices).flatten()

            new_mask = None
            new_val_mask = self.trainer.datamodule.val_mask.clone()

            if self.trainer.is_global_zero:
                dm = self.trainer.datamodule
                current_train_mask = dm.train_mask
                original_train_indices = dm.original_train_indices

                # Create mapping from original index to its position in original_train_indices
                index_to_pos = {idx.item(): pos for pos, idx in enumerate(original_train_indices)}

                # Convert subset indices (original indices) to positions in original_train_indices
                subset_indices_np = gathered_indices.cpu().numpy()
                subset_positions = [index_to_pos.get(idx, -1) for idx in subset_indices_np]

                # Filter valid positions (indices present in original_train_indices)
                valid_mask = np.array([pos != -1 for pos in subset_positions])
                valid_positions = np.array(subset_positions)[valid_mask]
                valid_losses = gathered_losses[valid_mask]

                # Initialize loss tensor with infinity (inactive samples have high loss)
                per_sample_loss = torch.full(
                    (len(original_train_indices),),  # Correct tuple format
                    float('inf'), 
                    device=gathered_losses.device
                )
                per_sample_loss[valid_positions] = valid_losses

                # Compute threshold based on active samples' losses
                threshold = torch.quantile(valid_losses, self.hparams['selection_threshold'])
            
                # Update train mask: keep samples where loss < threshold
                new_mask = (per_sample_loss < threshold).cpu()

                # Process VALIDATION mask
                if self.validation_losses:
                    val_losses = torch.cat(self.validation_losses)
                    val_indices = torch.cat(self.validation_indices)
                    gathered_val_losses = self.all_gather(val_losses).flatten()
                    gathered_val_indices = self.all_gather(val_indices).flatten()

                    val_subset_indices_np = gathered_val_indices.cpu().numpy()
                    val_subset_positions = [index_to_pos.get(idx, -1) for idx in val_subset_indices_np]
            
                    val_valid_mask = np.array([pos != -1 for pos in val_subset_positions])
                    val_valid_positions = np.array(val_subset_positions)[val_valid_mask]
                    val_valid_losses = gathered_val_losses[val_valid_mask]

                    # Key fix: Add empty tensor check for validation losses
                    if len(val_valid_losses) > 0:
                        per_val_loss = torch.full(
                            (len(dm.original_train_indices),),
                            float('inf'),
                            device=gathered_val_losses.device
                        )
                        per_val_loss[val_valid_positions] = val_valid_losses
                        val_thresh = torch.quantile(val_valid_losses, self.hparams['selection_threshold'])
                        new_val_mask = (per_val_loss < val_thresh).cpu()
                    else:
                        # Handle empty validation losses case
                        new_val_mask = self.val_mask.clone()
                        val_thresh = torch.tensor(float('inf'))
                else:
                    # No validation losses collected
                    new_val_mask = self.val_mask.clone()
                    val_thresh = torch.tensor(float('inf'))

                # Apply new masks
                dm.update_masks(new_mask, new_val_mask)
                self.log('train/selection_ratio', new_mask.float().mean())
                self.log('val/selection_ratio', new_val_mask.float().mean())

            # Reset collections
            self.training_losses.clear()
            self.training_indices.clear()
            self.validation_losses.clear()
            self.validation_indices.clear()

    def validation_step(self, batch, batch_idx):
        # Perform standard validation step
        loss = self._shared_step(batch, 'val')

        if self.current_phase == 'joint' and self.hparams['selective_sampling']:
            # Compute per-sample losses for validation set selection
            with torch.no_grad():
                x_clf = batch['clf_view']
                x_agn = batch['agn_view']
                y = batch['target']
        
                # Get model outputs
                clf_out, agn_out = self(x_clf, x_agn)
        
                # Calculate per-sample classification loss (same as training)
                per_sample_loss = self.loss_fn.classification_loss_per_sample(
                    clf_out['logits'], y
                )
                per_val_loss = torch.full(
                    (len(self.trainer.datamodule.original_train_indices),),  # Corrected line
                    float('inf'), 
                    device=per_sample_loss.device
                )
        
                # Store losses and indices for mask update
                self.validation_losses.append(per_sample_loss.detach().cpu())
                self.validation_indices.append(batch['index'].cpu())
    
        return loss
    # def validation_step(self, batch, batch_idx):
    #     return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, 'test')

In [16]:
hparams

{'seed': 42,
 'lr': 0.001,
 'weight_decay': 0.0005,
 'dropout': {'dropout_rate': 0.1, 'mc_samples': 10},
 'embed_dim': 10,
 'lambda_con': 0.5,
 'temperature': 0.1,
 'bs': 500,
 'num_workers': 10,
 'num_epochs': 10,
 'warmup_epochs': 0,
 'selective_sampling': True,
 'selection_threshold': 0.1,
 'criterion': 'ProCoSphere',
 'architecture': 'ProCoSphereEncoder',
 'freeze': False,
 'train_ratio': 0.9,
 'im_size': 32,
 'mean': [0.4914, 0.4822, 0.4465],
 'std': [0.247, 0.2435, 0.2616],
 'randResCrop': {'size': (32, 32), 'scale': (0.8, 1.0), 'ratio': (0.9, 1.1)},
 'label_smoothing': 0.0,
 'n_classes': 10,
 'noise_path': './data/CIFAR-10_human.pt',
 'noise_type': 'worse_label',
 'resume_checkpoint': None}

### Models

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def call_bn(bn, x):
    return bn(x)

class CNNFeatureExtractor(nn.Module):
    """
    Feature extractor based on your CNN, modified to output embeddings for ProCoSphere.
    """
    def __init__(self, input_channel=3, dropout_rate=0.25):
        super().__init__()
        self.dropout_rate = dropout_rate

        self.c1 = nn.Conv2d(input_channel, 128, kernel_size=3, stride=1, padding=1)
        self.c2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.c3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.c4 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.c5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.c6 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.c7 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=0)
        self.c8 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=0)
        self.c9 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=0)

        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(256)
        self.bn6 = nn.BatchNorm2d(256)
        self.bn7 = nn.BatchNorm2d(512)
        self.bn8 = nn.BatchNorm2d(256)
        self.bn9 = nn.BatchNorm2d(128)

    def forward(self, x):
        h = F.leaky_relu(call_bn(self.bn1, self.c1(x)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn2, self.c2(h)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn3, self.c3(h)), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = F.leaky_relu(call_bn(self.bn4, self.c4(h)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn5, self.c5(h)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn6, self.c6(h)), negative_slope=0.01)
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.dropout2d(h, p=self.dropout_rate)

        h = F.leaky_relu(call_bn(self.bn7, self.c7(h)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn8, self.c8(h)), negative_slope=0.01)
        h = F.leaky_relu(call_bn(self.bn9, self.c9(h)), negative_slope=0.01)

        h = F.avg_pool2d(h, kernel_size=h.data.shape[2])
        h = h.view(h.size(0), h.size(1))  # Flatten to (B, C)

        return h  # feature dimension is 128


In [18]:

class ProCoSphereEncoder(nn.Module):
    """
    Updated encoder using the CNNFeatureExtractor.
    """
    def __init__(self, input_channel=3, n_outputs=10, embed_dim=128, dropout_rate=0.25, hparams=None):
        super().__init__()
        self.hparams = hparams or {}
        self.backbone = CNNFeatureExtractor(input_channel=input_channel, dropout_rate=dropout_rate)

        self.feat_dim = 128  # after CNN feature extractor

        # Heads
        self.class_head = nn.Linear(embed_dim, n_outputs) if n_outputs > 0 else None
        self.var_head = nn.Linear(self.feat_dim, 1)
        self.embed_head = nn.Sequential(
            nn.Linear(self.feat_dim, embed_dim),
            nn.ReLU(),
            nn.LayerNorm(embed_dim)
        )

    def forward(self, x):
        features = self.backbone(x)

        outputs = {
            'log_var': self.var_head(features),
            'embed': F.normalize(self.embed_head(features), p=2, dim=-1)
        }
        if self.class_head is not None:
            outputs['logits'] = self.class_head(outputs['embed'])
        return outputs


class ProCoSphereDual(nn.Module):
    """
    Dual-encoder using ProCoSphereEncoder with CNN backbone.
    """
    def __init__(self, hparams):
        super().__init__()
        n_classes = hparams.get('n_classes', 10)
        embed_dim = hparams.get('embed_dim', 128)
        dropout_rate = hparams.get('dropout_rate', 0.25)

        self.classifier = ProCoSphereEncoder(
            input_channel=hparams.get('input_channel', 3),
            n_outputs=n_classes,
            embed_dim=embed_dim,
            dropout_rate=dropout_rate,
            hparams=hparams
        )
        self.agnostic = ProCoSphereEncoder(
            input_channel=hparams.get('input_channel', 3),
            n_outputs=0,
            embed_dim=embed_dim,
            dropout_rate=dropout_rate,
            hparams=hparams
        )

    def forward(self, x_j, x_i):
        out_J = self.agnostic(x_j)
        out_I = self.classifier(x_i)
        return out_J, out_I

### Loss functions

In [19]:
class ProCoSphereLoss(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.num_classes = hparams['n_classes']
        self.embed_dim = hparams['embed_dim']
        self.lambda_con = hparams['lambda_con']
        self.training_phase = 'warmup'  # Initialize with warmup phase
        
        # Temperature parameters with constrained initialization
        self.temp_gain = nn.Parameter(torch.tensor(0.0))  # Start with neutral gain
        self.temp_bias = nn.Parameter(torch.tensor(0.1))  # Start with low base temperature
        
        # Variance constraints
        self.log_var_min = -7
        self.log_var_max = 4
        
        # Label smoothing parameters
        self.smoothing = hparams.get('label_smoothing', 0.0)
        self.contrast_smoothing = 0.01
        
        # Loss components weights
        self.warmup_weight = 1.0
        self.cls_weight = 1.0
        self.con_weight = hparams['lambda_con']

    def set_phase(self, phase):
        """Dynamically adjust loss components based on training phase"""
        self.training_phase = phase
        if phase == 'warmup':
            self.warmup_weight = 1.0
            self.cls_weight = 0.0
            self.con_weight = 1.0
        else:
            self.warmup_weight = 0.0
            self.cls_weight = 1.0
            self.con_weight = hparams['lambda_con']

    def _contrastive_loss(self, emb1, emb2, variances):
        """General contrastive loss implementation"""
        # Normalize embeddings
        emb1 = F.normalize(emb1, p=2, dim=-1)
        emb2 = F.normalize(emb2, p=2, dim=-1)
        
        # Compute similarity matrix
        sim_matrix = emb1 @ emb2.t()
        
        # Adaptive temperature calculation
        temp_per_sample = (self.temp_gain.exp() * variances + self.temp_bias.abs())
        temp_per_sample = temp_per_sample.squeeze().clamp(0.01, 100.0)
        temp_matrix = torch.sqrt(temp_per_sample.unsqueeze(0) * temp_per_sample.unsqueeze(1))
        
        # Stabilized similarity
        scaled_sim = sim_matrix / temp_matrix.clamp(min=1e-6)
        logits_max, _ = torch.max(scaled_sim, dim=1, keepdim=True)
        logits_stable = scaled_sim - logits_max.detach()
        # logits_stable = sim_matrix
        
        # Symmetric contrastive loss
        targets = torch.arange(emb1.size(0), device=emb1.device)
        loss = (F.cross_entropy(logits_stable, targets, label_smoothing=self.contrast_smoothing) +
               F.cross_entropy(logits_stable.t(), targets, label_smoothing=self.contrast_smoothing)) / 2
               
        return loss #torch.clamp(loss, 0.0, 10.0)

    def classification_loss_per_sample(self, logits, y):  # Renamed from _classification_loss
        """Compute per-sample classification loss (without contrastive component)."""
        yoh = torch.zeros_like(logits).scatter(1, y.unsqueeze(1), 1)
        yoh = yoh * (1 - self.smoothing) + self.smoothing / self.num_classes
        
        # Forward CE per sample
        log_probs = F.log_softmax(logits, dim=1)
        ce_forward = -torch.sum(yoh * log_probs, dim=1)  # [batch_size]
        
        # Reverse CE per sample
        probs = F.softmax(logits, dim=1)
        log_yoh = torch.log(yoh.clamp(min=1e-8))
        reverse_ce = -torch.sum(probs * log_yoh, dim=1)  # [batch_size]
        
        return ce_forward + reverse_ce  # [batch_size]

    def forward(self, clf_out, agn_out, y, reduction='mean'):
        if self.training_phase == 'warmup':
            # Warmup phase: contrastive loss between dual agnostic views
            var_agn = torch.exp(torch.clamp(agn_out['log_var'], 
                                 self.log_var_min, self.log_var_max))
            loss_con = self._contrastive_loss(
                clf_out['embed'],  # agn_view1 embeddings
                agn_out['embed'],   # agn_view2 embeddings
                var_agn
            )
            total_loss = self.warmup_weight * loss_con
        else:
            # Joint training phase
            # Classification loss
            loss_cls = self.classification_loss_per_sample(clf_out['logits'], y)

            # Contrastive loss
            var_agn = torch.exp(torch.clamp(agn_out['log_var'],
                               self.log_var_min, self.log_var_max))
            loss_con = self._contrastive_loss(
                clf_out['embed'],  # classifier embeddings
                agn_out['embed'],  # agnostic embeddings
                var_agn
            )
            
            total_loss = (self.cls_weight * loss_cls + 
                         self.con_weight * loss_con)
        if reduction == 'none':
            return total_loss  # return per-sample losses
        else:
            return total_loss.mean()  # default behavior

### Models zoo

Architectures and loss functions

In [20]:
def get_arch_and_loss(hparams):
    if hparams['criterion'] == 'ProCoSphere':
        return ProCoSphereDual(hparams), ProCoSphereLoss(hparams)

### Metrics

In [21]:
def calc_metrics(dataloader, model, hparams):
    """Compute comprehensive metrics including uncertainty-aware scores"""
    model.eval()
    device = next(model.parameters()).device

    # Initialize collectors
    preds, labels, probs = [], [], []
    var_clfs, var_agns, similarities = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            x_clf = batch['clf_view'].to(device)
            x_agn = batch['agn_view'].to(device)
            y = batch['target'].to(device)

            # Forward pass through both encoders
            clf_out, agn_out = model(x_clf, x_agn)

            # Get predictions and probabilities
            logits = clf_out['logits']
            prob = torch.softmax(logits, dim=1)
            pred = torch.argmax(logits, dim=1)

            # Collect basic metrics
            preds.append(pred.cpu())
            labels.append(y.cpu())
            probs.append(prob.cpu())

            # Collect uncertainty metrics
            var_clfs.append(clf_out['log_var'].exp().cpu())
            var_agns.append(agn_out['log_var'].exp().cpu())

            # Compute embedding similarity
            sim = F.cosine_similarity(clf_out['embed'], agn_out['embed'])
            similarities.append(sim.cpu())

    # Concatenate all results
    preds = torch.cat(preds).numpy()
    labels = torch.cat(labels).numpy()
    probs = torch.cat(probs).numpy()
    var_clfs = torch.cat(var_clfs).numpy()
    var_agns = torch.cat(var_agns).numpy()
    similarities = torch.cat(similarities).numpy()

    # Calculate metrics
    results = {
        # Standard classification metrics
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, average='macro'),
        'recall': recall_score(labels, preds, average='macro'),
        'f1': f1_score(labels, preds, average='macro'),

        # Uncertainty metrics
        'var_clf_mean': var_clfs.mean(),
        'var_clf_std': var_clfs.std(),
        'var_agn_mean': var_agns.mean(),
        'var_agn_std': var_agns.std(),
        'similarity_mean': similarities.mean(),
        'similarity_std': similarities.std(),

        # Calibration metrics
        'ece': CalibrationError(task='multiclass', 
                                num_classes=hparams['n_classes'])(torch.tensor(probs), torch.tensor(labels)).item()
    }

    # Variance correlation analysis
    results['var_correlation'] = np.corrcoef(var_clfs.flatten(), var_agns.flatten())[0,1]

    return results, preds, similarities

def compute_confidence(logits, var=None):
    """Enhanced confidence computation with uncertainty awareness"""
    probs = torch.softmax(logits, dim=1)
    conf, _ = torch.max(probs, dim=1)

    if var is not None:  # Adjust confidence by uncertainty
        conf *= torch.exp(-var)  # Higher variance reduces confidence
    return conf

def compute_certainty(clf_out, agn_out, model_type, num_classes):
    """Uncertainty quantification combining both encoders"""
    if model_type == 'B':
        return torch.sigmoid(clf_out[:, num_classes:]).squeeze()
    elif model_type in ['N', 'ProCoSphere']:
        # Combine uncertainties from both encoders
        var_clf = clf_out['log_var'].exp()
        var_agn = agn_out['log_var'].exp()
        return 1 / (1 + var_clf + var_agn)  # Combined certainty
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

In [22]:
def compute_embeddings_and_predictions(model, dataloader, device):
    model.eval()
    all_agn_embeds = []
    all_clf_embeds = []
    all_clf_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in dataloader:
            x_clf = batch['clf_view'].to(device)
            x_agn = batch['agn_view'].to(device)
            y = batch['target'].to(device)
            
            # Get model outputs
            clf_out, agn_out = model(x_clf, x_agn)
            
            # Store embeddings and predictions
            all_agn_embeds.append(agn_out['embed'].cpu().numpy())
            all_clf_embeds.append(clf_out['embed'].cpu().numpy())
            all_clf_preds.append(clf_out['logits'].argmax(1).cpu().numpy())
            all_labels.append(y.cpu().numpy())
    
    return (
        np.concatenate(all_agn_embeds),
        np.concatenate(all_clf_embeds),
        np.concatenate(all_clf_preds),
        np.concatenate(all_labels)
    )

### Visualization
Note: needs collection of the loss values for the each sample

In [23]:
# Plot image samples with top loss values
def top_losses_vis(vis_params, images, preds, labels, losses):
    num_imgs = vis_params['num_samples']
    top_loss_indices = np.argsort(losses)[-num_imgs:]

    plt.figure(figsize=(num_imgs*2, 2))
    for i, idx in enumerate(top_loss_indices):
        plt.subplot(1, num_imgs, i + 1)
        plt.imshow(images[idx].squeeze(), cmap='gray')
        plt.title(f'True: {labels[idx]}\nPred: {preds[idx]}\nLoss: {losses[idx]:.2f}')
        plt.axis('off')
    plt.show()

# Plot confusion matrix
def conf_mat(figsize,class_names=None):
    plt.figure(figsize)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

# Ensembling
This approach is expected to give a robust ensemble model that leverages the diversity introduced by different seeds, potentially improving the overall accuracy on the test set.

## Create Dataset and Data Loaders

Initialization of the dataset, the dataloader, and the training module

In [24]:
data_module = CIFAR10DataModule(hparams)

## Train the Ensemble

Loop over different seeds

In [25]:
# List to store predictions from each model
all_predictions = []
all_confidences = []
all_certainties = []
all_best_model_paths = []

In [26]:
for seed in SEEDS:
    # Set seed for reproducibility
    seed_everything(seed)

    # Reinitialize model
    arch, loss_fn = get_arch_and_loss(hparams)
    fname = f'arch_{ARCHITECTURE}_loss_{LOSS_FUN}_lambda_{LAMBDA}_seed_{seed}_noise_{NOISE_TYPE}'

    checkpoint_callback_img = ModelCheckpoint(
        monitor='val_loss',
        dirpath=CHECKPOINT_PATH,
        filename=fname,
        save_top_k=1,
        mode='min'
    )

    task = Task.init(project_name="ICML-2025",
                     task_name=fname)

    model = train_model(model=arch, loss=loss_fn)
    task.connect(model.hparams)

    trainer = Trainer(max_epochs=hparams['num_epochs'],
                      callbacks=[checkpoint_callback_img],
                      accelerator="auto", 
                      devices=1 #"auto"
                      )
    trainer.fit(model,
                data_module,
                ckpt_path=hparams.get('resume_checkpoint', None)
                )

    # Save path for later testing
    best_model_path = checkpoint_callback_img.best_model_path
    all_best_model_paths.append(best_model_path)

    if seed != SEEDS[-1]:
        task.close()
        del[model]

Seed set to 42


ClearML Task: created new task id=6ecbb02fcfac4435aef25779fb7312ab
2025-05-27 12:55:01,380 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/ccaa059e6de442b6abe578eab9e214c8/experiments/6ecbb02fcfac4435aef25779fb7312ab/output/log


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


2025-05-27 12:55:09,590 - clearml.model - INFO - Selected model id: 8515be3004ca474eae4a86e50e244266


Seed set to 42


Loaded worse_label from ./data/CIFAR-10_human.pt.
Noise rate: 0.40208

Active training samples: 45000/45000
Active validation samples: 5000/5000



Checkpoint directory /project/ICML-2025/saved_models exists and is not empty.

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type            | Params | Mode 
----------------------------------------------------
0 | model   | ProCoSphereDual | 8.9 M  | train
1 | loss_fn | ProCoSphereLoss | 2      | train
----------------------------------------------------
8.9 M     Trainable params
0         Non-trainable params
8.9 M     Total params
35.478    Total estimated model params size (MB)


Epoch 0:   0%|          | 0/90 [00:00<?, ?it/s]                            
Active training samples: 45000/45000
Active validation samples: 5000/5000
Epoch 0: 100%|██████████| 90/90 [00:21<00:00,  4.10it/s, v_num=399]

AttributeError: 'train_model' object has no attribute 'val_mask'

## Test the models and the ensemble of the models

In [None]:
all_best_model_paths

In [None]:
all_predictions = []
all_confidences = []
all_certainties_clf = []
all_certainties_agn = []
all_similarities = []

for best_model_path in all_best_model_paths:
    # Rebuild model arch and loss
    arch, loss_fn = get_arch_and_loss(hparams)

    # Load best checkpoint
    best_model = robust_load_pl_model(best_model_path,
                                      train_model,
                                      model=arch,
                                      loss=loss_fn,
                                      device=device,
                                      )
    best_model = best_model.to(device)

    # Enable MC Dropout if needed
    mc_samples = hparams["dropout"]['mc_samples']
    if mc_samples > 0:
        best_model.model.train()  # Enable dropout
    else:
        best_model.model.eval()

    predictions = []
    confidences = []
    similarities = []
    certainties_clf = []
    certainties_agn = []
    

    with torch.no_grad():
        for batch in data_module.test_dataloader():
            # Get both views from batch dictionary
            x_clf = batch['clf_view'].to(device)
            x_agn = batch['agn_view'].to(device)
            y = batch['target'].to(device)

            if mc_samples > 0:
                logits_samples_clf = []
                embed_samples_clf = []
                embed_samples_agn = []
                var_samples_clf = []
                var_samples_agn = []
                
                for _ in range(mc_samples):
                    # Forward pass through both encoders
                    clf_out, agn_out = best_model(x_clf, x_agn)
                    logits_samples_clf.append(clf_out['logits'])
                    embed_samples_clf.append(clf_out['embed'])
                    embed_samples_agn.append(agn_out['embed']) 
                    var_samples_clf.append(clf_out['log_var'].exp())  
                    var_samples_agn.append(agn_out['log_var'].exp())

                # Average predictions
                logits = torch.stack(logits_samples_clf).mean(0)
                probs = torch.nn.functional.softmax(logits, dim=1)
                embed_clf = torch.stack(embed_samples_clf).mean(0)
                embed_agn = torch.stack(embed_samples_agn).mean(0)
                var_clf = torch.stack(var_samples_clf).mean(0)
                var_agn = torch.stack(var_samples_agn).mean(0)
                
            else:
                # Single forward pass
                clf_out, agn_out = best_model(x_clf, x_agn)
                logits = clf_out['logits']
                probs = torch.nn.functional.softmax(logits, dim=1)
                embed_clf = clf_out['embed']
                embed_agn = agn_out['embed']
                var_clf = clf_out['log_var'].exp()
                var_agn = agn_out['log_var'].exp()

            # Collect predictions and uncertainties
            preds = torch.argmax(logits, dim=1)
            confidence = probs.max(1)[0]
            similarity = F.cosine_similarity(embed_clf, embed_agn)

            predictions.append(preds.cpu().numpy())
            confidences.append(confidence.cpu().numpy())
            similarities.append(similarity.cpu().numpy())

            # Calculate variance-based certainty if using N/B loss
            if hparams['criterion'] in ['ProCoSphere']:
                certainties_clf.append(1.0 / (1.0 + var_clf.cpu().numpy().squeeze(-1)))
                certainties_agn.append(1.0 / (1.0 + var_agn.cpu().numpy().squeeze(-1)))

    # Post-process results
    predictions = np.concatenate(predictions)
    confidences = np.concatenate(confidences)
    similarities= np.concatenate(similarities)
    all_predictions.append(predictions)
    all_confidences.append(confidences)
    all_similarities.append(similarities)


    if hparams['criterion'] in ['ProCoSphere']:
        certainties_clf = np.concatenate(certainties_clf)
        certainties_agn = np.concatenate(certainties_agn)
        all_certainties_clf.append(certainties_clf)
        all_certainties_agn.append(certainties_agn)

In [None]:
from sklearn.neighbors import NearestNeighbors

def knn_majority_vote(agn_embeddings, clf_preds, k=5):
    nn = NearestNeighbors(n_neighbors=k+1, metric='cosine').fit(agn_embeddings)
    distances, indices = nn.kneighbors(agn_embeddings)
    
    # Exclude self-neighbor
    indices = indices[:, 1:]
    
    majority_preds = []
    for i in range(len(indices)):
        neighbor_preds = clf_preds[indices[i]]
        counts = np.bincount(neighbor_preds)
        majority = np.argmax(counts)
        majority_preds.append(majority)
        
    return np.array(majority_preds)

In [None]:
def plot_rejection_curve(preds, confidences, labels, method):
    """Plots accuracy vs. uncertainty threshold (rejection curve) for model predictions.
    
    Args:
        preds: Array of model predictions (class indices)
        confidences: Array of confidence scores (0-1)
        labels: Ground truth labels
        method: Name of the method (for plot legend)
    
    Returns:
        Matplotlib figure showing:
        - Accuracy at different confidence thresholds
        - Percentage of retained samples at each threshold
    """
    thresholds = np.linspace(0, 1, 20)
    accuracies = []
    keep_ratios = []
    
    # Calculate accuracy and sample retention at each threshold
    for thresh in thresholds:
        mask = confidences >= thresh
        if mask.sum() > 0:  # Avoid division by zero
            acc = accuracy_score(labels[mask], preds[mask])
        else:
            acc = 0
        accuracies.append(acc)
        keep_ratios.append(mask.mean())
    
    # Plot configuration
    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, accuracies, label=f'Accuracy ({method})', linewidth=2)
    plt.plot(thresholds, keep_ratios, label=f'Retention ({method})', linestyle='--')
    
    plt.xlabel('Confidence Threshold', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.title(f'Rejection Curve: {method}', fontsize=14, pad=20)
    plt.legend(fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    
    return plt.gcf()

In [None]:
# Convert to numpy arrays with proper shapes
all_predictions = np.array(all_predictions)        # (num_models, num_samples)
all_confidences = np.array(all_confidences)        # (num_models, num_samples)
all_similarities = np.array(all_similarities)      # (num_models, num_samples)
all_certainties_clf = np.array(all_certainties_clf) # (num_models, num_samples)
all_certainties_agn = np.array(all_certainties_agn) # (num_models, num_samples)

# Get true labels
test_labels = np.array(data_module.cifar10_test.targets)

# Initialize metrics storage
metrics = {
    'majority': {'preds': None, 'confidence': None},
    'confidence': {'preds': None, 'confidence': None},
    'similarity': {'preds': None, 'confidence': None},
    'certainty_clf': {'preds': None, 'confidence': None},
    'certainty_agn': {'preds': None, 'confidence': None}
}

# 1. Majority Voting
metrics['majority']['preds'] = mode(all_predictions, axis=0)[0].squeeze()
metrics['majority']['confidence'] = all_confidences.max(axis=0)

# 2. Highest Confidence Selection
model_indices = np.argmax(all_confidences, axis=0)
metrics['confidence']['preds'] = all_predictions[model_indices, np.arange(all_predictions.shape[1])]
metrics['confidence']['confidence'] = all_confidences.max(axis=0)

# 3. Highest Similarity Selection
model_indices = np.argmax(all_similarities, axis=0)
metrics['similarity']['preds'] = all_predictions[model_indices, np.arange(all_predictions.shape[1])]
metrics['similarity']['confidence'] = all_similarities.max(axis=0)

# 4. Highest Classifier Certainty
model_indices = np.argmax(all_certainties_clf, axis=0)
metrics['certainty_clf']['preds'] = all_predictions[model_indices, np.arange(all_predictions.shape[1])]
metrics['certainty_clf']['confidence'] = all_certainties_clf.max(axis=0)

# 5. Highest Agnostic Certainty
model_indices = np.argmax(all_certainties_agn, axis=0)
metrics['certainty_agn']['preds'] = all_predictions[model_indices, np.arange(all_predictions.shape[1])]
metrics['certainty_agn']['confidence'] = all_certainties_agn.max(axis=0)

# Calculate metrics
results = {}
ece = CalibrationError(task='multiclass', num_classes=len(CLASS_NAMES))

for method in metrics:
    preds = metrics[method]['preds']
    confs = metrics[method]['confidence']
    
    # Convert to tensors for ECE calculation
    probs = np.zeros((len(preds), len(CLASS_NAMES)))
    probs[np.arange(len(preds)), preds] = confs
    
    results[method] = {
        'accuracy': accuracy_score(test_labels, preds),
        'ece': ece(torch.tensor(probs), torch.tensor(test_labels)).item()
    }
    # plot_rejection_curve(preds, confs, test_labels, method)

# Print results
print("Ensemble Evaluation Results:")
for method, vals in results.items():
    print(f"\n{method.replace('_', ' ').title()}:")
    print(f"Accuracy: {vals['accuracy']:.4f}")
    print(f"ECE: {vals['ece']:.4f}")

In [None]:
# Add to results analysis
for method in ['certainty_clf', 'certainty_agn']:
    corr = np.corrcoef(
        metrics[method]['confidence'],
        (metrics[method]['preds'] == test_labels).astype(float)
    )[0,1]
    results[method]['accuracy_certainty_corr'] = corr

print('Uncertainty correlation = ', corr)

In [None]:
# Calculate ensemble diversity using pairwise disagreement
disagreement = np.zeros((len(test_labels),))
for i in range(len(test_labels)):
    unique, counts = np.unique(all_predictions[:, i], return_counts=True)
    disagreement[i] = 1 - counts.max()/len(all_best_model_paths)
    
print(f"Average Disagreement: {disagreement.mean():.4f}")

In [None]:
# After loading models in test loop
agn_embeds, clf_embeds, clf_preds, labels = compute_embeddings_and_predictions(best_model, data_module.test_dataloader(), device)

# Calculate original classifier accuracy
original_acc = accuracy_score(labels, clf_preds)

# Calculate k-NN majority vote accuracy
knn_preds = knn_majority_vote(agn_embeds, clf_preds, k=100)
knn_acc = accuracy_score(labels, knn_preds)

print(f"Original Classifier Accuracy: {original_acc:.4f}")
print(f"k-NN Majority Vote Accuracy: {knn_acc:.4f}")

In [None]:
# Calculate k-NN majority vote accuracy
knn_preds = knn_majority_vote(clf_embeds, clf_preds, k=100)
knn_acc = accuracy_score(labels, knn_preds)

print(f"Original Classifier Accuracy: {original_acc:.4f}")
print(f"k-NN Majority Vote Accuracy: {knn_acc:.4f}")

In [None]:
task.close()

Visualization with t-SNE

In [None]:
def plot_tsne(embeddings, labels, title, class_names=CLASS_NAMES):
    plt.figure(figsize=(15,10))
    tsne = TSNE(n_components=2, perplexity=30, n_iter=300)
    embeddings_2d = tsne.fit_transform(embeddings)
    
    # Create color map
    cmap = plt.get_cmap('tab10', len(class_names))
    
    # Plot with true labels (despite noise)
    for i, class_name in enumerate(class_names):
        idx = labels == i
        plt.scatter(embeddings_2d[idx, 0], embeddings_2d[idx, 1], 
                    c=[cmap(i)], label=class_name, alpha=0.6)
    
    plt.title(f't-SNE: {title}')
    plt.xlabel('TSNE-1')
    plt.ylabel('TSNE-2')
    plt.legend(bbox_to_anchor=(1.04,1), loc="upper left")
    plt.show()

def get_embeddings(model, dataloader, device):
    clf_embeds = []
    agn_embeds = []
    true_labels = []
    
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            x_clf = batch['clf_view'].to(device)
            x_agn = batch['agn_view'].to(device)
            y = batch['target'].to(device)
            
            # Get both embeddings
            clf_out, agn_out = model(x_clf, x_agn)
            
            clf_embeds.append(clf_out['embed'].cpu().numpy())
            agn_embeds.append(agn_out['embed'].cpu().numpy())
            true_labels.append(y.cpu().numpy())
    
    return (np.concatenate(clf_embeds),
            np.concatenate(agn_embeds),
            np.concatenate(true_labels))

# Get embeddings for test set
clf_emb, agn_emb, labels = get_embeddings(best_model, data_module.test_dataloader(), device)

# Random subset for visualization (faster)
subset = np.random.choice(len(labels), 1000, replace=False)
clf_sub = clf_emb[subset]
agn_sub = agn_emb[subset]
lab_sub = labels[subset]

# Visualize both embedding spaces
plot_tsne(clf_sub, lab_sub, "Classifier Encoder Embeddings")
plot_tsne(agn_sub, lab_sub, "Label-Agnostic Encoder Embeddings")

In [None]:
# Add to imports
from sklearn.manifold import TSNE

# After test loop, before plotting rejection curves
def plot_tsne(features, labels, title, class_names=CLASS_NAMES):
    """Plot t-SNE visualization with class coloring"""
    tsne = TSNE(n_components=2, random_state=SEED)
    features_2d = tsne.fit_transform(features)
    
    plt.figure(figsize=(10,8))
    scatter = plt.scatter(features_2d[:,0], features_2d[:,1], 
                         c=labels, cmap='tab10', alpha=0.6)
    plt.title(title)
    plt.legend(handles=scatter.legend_elements()[0],
               labels=class_names,
               title="Classes")
    plt.show()

# Collect embeddings and logits from test set
clf_embeds = []
agn_embeds = [] 
logits = []
probs = []
true_labels = []

with torch.no_grad():
    for batch in data_module.test_dataloader():
        x_clf = batch['clf_view'].to(device)
        x_agn = batch['agn_view'].to(device)
        y = batch['target'].to(device)
        
        clf_out, agn_out = best_model(x_clf, x_agn)
        
        clf_embeds.append(clf_out['embed'].cpu())
        agn_embeds.append(agn_out['embed'].cpu())
        logits.append(clf_out['logits'].cpu())
        probs.append(torch.softmax(clf_out['logits'], dim=1).cpu())
        true_labels.append(y.cpu())

# Concatenate all batches
clf_embeds = torch.cat(clf_embeds).numpy()
agn_embeds = torch.cat(agn_embeds).numpy() 
logits = torch.cat(logits).numpy()
probs = torch.cat(probs).numpy()
true_labels = torch.cat(true_labels).numpy()

# Create visualization grid
plt.figure(figsize=(20,16))

# 1. Classifier Embeddings
plt.subplot(221)
plot_tsne(clf_embeds, true_labels, "Classifier Embeddings Space")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")

# 2. Agnostic Embeddings  
plt.subplot(222)
plot_tsne(agn_embeds, true_labels, "Agnostic Embeddings Space")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")

# 3. Logits Space
plt.subplot(223)
plot_tsne(logits, true_labels, "Classifier Logits Space")
plt.xlabel("t-SNE 1") 
plt.ylabel("t-SNE 2")

# 4. Probability Space
plt.subplot(224)
plot_tsne(probs, true_labels, "Classifier Probability Space")
plt.xlabel("t-SNE 1")
plt.ylabel("t-SNE 2")

plt.tight_layout()
plt.show()