In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.func import functional_call, vmap, grad
from ResNet import ResNet
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gc
import os
import glob
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm
import plotly.graph_objects as go

device = "mps" if torch.backends.mps.is_built() \
    else "cuda" if torch.cuda.is_available() else "cpu"

print(device)

torch.manual_seed(3)

mps


<torch._C.Generator at 0x11d0cfe10>

In [2]:
model=ResNet(num_classes=100,n=9).to(device)

In [3]:
# Define standard data transforms for CIFAR100
# CIFAR100 mean and std:
# mean = [0.5071, 0.4867, 0.4408], std = [0.2675, 0.2565, 0.2761]

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])


In [4]:
# Load datasets
train_dataset=torchvision.datasets.CIFAR100(root='./data',train=True,download=True,transform=train_transform)
test_dataset=torchvision.datasets.CIFAR100(root='./data',train=False,download=True,transform=test_transform)

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
checkpoint_dir='checkpoints'
checkpoint_epoch = 30
checkpoint_path = os.path.join(checkpoint_dir, f'resnet_epoch_{checkpoint_epoch}.pth')

In [6]:
# Create one training example (x, y)
# Get a batch from the train_loader and pick the first example
x_batch, y_batch = next(iter(train_loader))
x = x_batch[0]
y = y_batch[0].item()

# move to device
x = x.unsqueeze(0).to(device)  # add batch dimension
y = torch.tensor([y], device=device)



In [7]:

# --- assume you already have these ---
# model: trained ResNet (e.g., ResNet-56 or ResNet-50)
# checkpoint_path: path to a saved checkpoint (e.g. "ckpt_epoch_60.pt")
# x, y: a single training example (tensor and its label)

# load weights from a checkpoint
model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
model.eval()  # important: use evaluation mode


# we want gradients w.r.t. last layer only
params = list(model.parameters())[-2:]  # usually weight, bias

# compute loss with gradient tracking enabled
with torch.set_grad_enabled(True):
    outputs = model(x)
    loss = F.cross_entropy(outputs, y, reduction='sum')

# gradient of loss w.r.t. last-layer parameters
grads = torch.autograd.grad(loss, params, create_graph=False, retain_graph=False)

# flatten and concatenate all gradients
flat_grads = torch.cat([g.reshape(-1) for g in grads if g is not None])

# compute squared L2 norm — the self-influence at this checkpoint
self_influence = (flat_grads ** 2).sum().item()

print(f"Self-influence for this example at checkpoint: {self_influence:.6f}")


Self-influence for this example at checkpoint: 51.145996


In [8]:
class SelfInfluence:
    """
    Computes TracInCP self-influence scores across multiple checkpoints.
    """
    def __init__(self, model, device='cpu', last_layer_only=True):
        self.model = model.to(device)
        self.device = device
        self.last_layer_only = last_layer_only

        # identify the target parameters (usually last layer)
        # params = list(model.parameters())
        # self.target_params = params[-2:] if last_layer_only else params

    def load_checkpoint(self, checkpoint_path):
        """Load model weights from a saved checkpoint."""
        state_dict = torch.load(checkpoint_path, map_location=self.device)['model_state_dict']
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # refresh params from current model weights
        # self.func_model, self.params, self.buffers = make_functional_with_buffers(self.model)

        self.params = dict(self.model.named_parameters())
        self.buffers = dict(self.model.named_buffers())
        if self.last_layer_only:
            self.target_params = {
                k: v for k, v in list(self.params.items())[-2:]
            }
        else:
            self.target_params = self.params


    def _per_sample_grad(self, x, y):
        """Compute flattened gradient vector for a single sample."""
        out = self.model(x.unsqueeze(0))
        loss = F.cross_entropy(out, y.unsqueeze(0), reduction='sum')
        grads = torch.autograd.grad(loss, self.target_params,
                                    retain_graph=False, create_graph=False)
        flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
        return flat

    def compute_batch_influence(self, inputs, labels):
        """
        Compute self-influence for each example in a batch.
        Returns: tensor of shape [B] with self-influence scores.
        """
        self.model.eval()
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        def loss_fn(params, buffers, x, y):
            out = functional_call(self.model, (params, buffers), (x.unsqueeze(0),))
            loss = F.cross_entropy(out, y.unsqueeze(0), reduction='sum')
            return loss

        grad_fn = grad(loss_fn)
        grads = vmap(grad_fn, in_dims=(None, None, 0, 0))(
            self.target_params, self.buffers, inputs, labels
        )

        # flatten gradients and compute squared sum
        if isinstance(grads, dict):
            flat_grads = torch.cat(
                [g.reshape(g.shape[0], -1) for g in grads.values()], dim=1
            )
        else:
            flat_grads = torch.cat([g.reshape(g.shape[0], -1) for g in grads], dim=1)

        influences = (flat_grads ** 2).sum(dim=1)
        return influences.detach().cpu()

        # def loss_fn(params, buffers, x, y):
        #     out = self.func_model(params, buffers, x.unsqueeze(0))
        #     loss = F.cross_entropy(out, y.unsqueeze(0), reduction='sum')
        #     return loss

        # # compute per-sample gradients of loss w.r.t params
        # grad_fn = grad(loss_fn)
        # grads = vmap(grad_fn, in_dims=(None, None, 0, 0))(self.params, self.buffers, inputs, labels)

        # # flatten last-layer grads only
        # if self.last_layer_only:
        #     grads = grads[-2:]  # pick last two
        # flat_grads = torch.cat([g.reshape(g.shape[0], -1) for g in grads], dim=1)
        # influences = (flat_grads ** 2).sum(dim=1)
        # return influences.detach().cpu()

        # # vectorized over batch
        # batch_grads = [self._per_sample_grad(x, y) for x, y in zip(inputs, labels)]
        # grads = torch.stack(batch_grads)
        # influences = (grads ** 2).sum(dim=1)
        # return influences.detach().cpu()

    def compute_tracin_self_influence(self, dataloader, checkpoint_paths, eta_list=None):
        """
        Aggregate self-influence across checkpoints (TracInCP).

        Args:
            dataloader: DataLoader over training data.
            checkpoint_paths: list of checkpoint file paths.
            eta_list: optional weighting factors (default = equal).
        Returns:
            tensor [N] of total self-influence scores for training set.
        """
        if eta_list is None:
            eta_list = [1.0 for _ in checkpoint_paths]

        # initialize empty vector for total influence
        num_samples = len(dataloader.dataset)
        total_influence = torch.zeros(num_samples)

        for eta_i, ckpt in zip(eta_list, checkpoint_paths):
            self.load_checkpoint(ckpt)

            offset = 0
            for inputs, labels in tqdm(dataloader, desc=f'Checkpoint {ckpt}'):
                batch_inf = self.compute_batch_influence(inputs, labels)
                total_influence[offset : offset + len(inputs)] += eta_i * batch_inf
                offset += len(inputs)

        return total_influence


In [50]:
# create SelfInfluenceComputer as before
sic = SelfInfluence(model=model, last_layer_only=True)

# define checkpoints (every 30 epochs)
checkpoints = [f'checkpoints/resnet_epoch_{i}.pth' for i in range(30, 301, 30)]

# compute total TracInCP self-influence (order-safe!)
total_influence = sic.compute_tracin_self_influence(train_loader, checkpoints)

Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 196/196 [01:40<00:00,  1.95it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 196/196 [01:37<00:00,  2.01it/s]
Checkpoint checkpoints/resnet_epoch_90.pth: 100%|██████████| 196/196 [01:40<00:00,  1.96it/s]
Checkpoint checkpoints/resnet_epoch_120.pth: 100%|██████████| 196/196 [01:37<00:00,  2.01it/s]
Checkpoint checkpoints/resnet_epoch_150.pth: 100%|██████████| 196/196 [01:41<00:00,  1.93it/s]
Checkpoint checkpoints/resnet_epoch_180.pth: 100%|██████████| 196/196 [01:41<00:00,  1.94it/s]
Checkpoint checkpoints/resnet_epoch_210.pth: 100%|██████████| 196/196 [01:37<00:00,  2.00it/s]
Checkpoint checkpoints/resnet_epoch_240.pth: 100%|██████████| 196/196 [01:37<00:00,  2.00it/s]
Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 196/196 [01:39<00:00,  1.97it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 196/196 [01:37<00:00,  2.00it/s]


In [47]:
def plot_self_influence_histogram(total_influence):
    fig = go.Figure()

    # Add histogram
    fig.add_trace(go.Histogram(
        x=total_influence, 
        name='Self-Influence Histogram',
        nbinsx=50
    ))

    # Add vertical line at 90th percentile
    percentile_90 = np.percentile(total_influence, 90)
    fig.add_vline(
        x=percentile_90, 
        line_width=3, 
        line_dash="dash", 
        line_color="green",
        annotation_text=f"90th Percentile: {percentile_90:.4f}",
        annotation_position="top"
    )

    fig.update_layout(
        title="Distribution of Total TracInCP Self-Influence",
        xaxis_title="Self-Influence Score",
        yaxis_title="Count",
        showlegend=True
    )

    fig.show()


In [51]:
plot_self_influence_histogram(total_influence)

In [10]:
# Create mapping for CIFAR100 from class name to numeric label
cifar100_classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
    'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
    'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
    'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
    'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
    'woman', 'worm'
]

class_to_idx = {cls: i for i, cls in enumerate(cifar100_classes)}
idx_to_class = {i: cls for i, cls in enumerate(cifar100_classes)}


In [11]:
meta = pd.read_csv("cifar-100-noisy.csv")
meta.head()

Unnamed: 0,image_path,class_name,orig_class_name,mislabeled
0,train/cattle/00000.png,cattle,cattle,False
1,train/dinosaur/00001.png,dinosaur,dinosaur,False
2,train/apple/00002.png,apple,apple,False
3,train/boy/00003.png,boy,boy,False
4,train/aquarium_fish/00004.png,aquarium_fish,aquarium_fish,False


In [16]:
# Normalize names once (lowercase, underscores)
meta['class_name'] = meta['class_name'].str.strip().str.lower().str.replace(" ", "_")
meta['orig_class_name'] = meta['orig_class_name'].str.strip().str.lower().str.replace(" ", "_")

# Verify CSV labels match official class names - check if there are any names that don’t appear in the official list
missing_labels = set(meta['class_name']) - set(cifar100_classes)
if missing_labels:
    print("Missing/invalid labels in CSV:", missing_labels)
else:
    print("All labels match CIFAR-100 fine names.")

All labels match CIFAR-100 fine names.


In [17]:
meta['label'] = meta['class_name'].map(class_to_idx)
meta['true_label'] = meta['orig_class_name'].map(class_to_idx)

In [18]:
meta_dict = meta.set_index('image_path').to_dict(orient='index')

def add_meta(example):
    info = meta_dict.get(".".join([example['__key__'], "png"]))
    if info is None:
        raise ValueError(f"Missing metadata for {example['__key__']}")
    example['label'] = info['label']
    example['true_label'] = info['true_label']
    example['mislabeled'] = info['mislabeled']
    return example



In [42]:
dataset = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')
dataset = dataset.map(add_meta)

In [43]:
dataset

Dataset({
    features: ['png', '__key__', '__url__', 'label', 'true_label', 'mislabeled'],
    num_rows: 50000
})

In [44]:
dataset[0]

{'png': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>,
 '__key__': 'train/hamster/04330',
 '__url__': '/Users/ronibendom/.cache/huggingface/hub/datasets--hirundo-io--Noisy-CIFAR-100/snapshots/f23a0a3545777df387b949bbddd43a7ca919b00d/train.tar.gz',
 'label': 36,
 'true_label': 36,
 'mislabeled': False}

In [45]:
print("Num samples:", len(dataset))
print("Fraction mislabeled:", np.mean(np.array(dataset['mislabeled'])))
print("Example:", dataset[0])

Num samples: 50000
Fraction mislabeled: 0.1
Example: {'png': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32 at 0x325598390>, '__key__': 'train/hamster/04330', '__url__': '/Users/ronibendom/.cache/huggingface/hub/datasets--hirundo-io--Noisy-CIFAR-100/snapshots/f23a0a3545777df387b949bbddd43a7ca919b00d/train.tar.gz', 'label': 36, 'true_label': 36, 'mislabeled': False}


In [None]:
def collate_fn(batch):
    # batch is a list of dicts (one per sample)
    imgs = [train_transform(x["png"].convert("RGB")) for x in batch]
    pixel_values = torch.stack(imgs)
    labels = torch.tensor([x["label"] for x in batch])
    # keys = [x["__key__"] for x in batch]  # optional
    # return keys, pixel_values, labels
    return pixel_values, labels


mislabeled_dataset_loader = DataLoader(dataset, batch_size=128, shuffle=False, collate_fn=collate_fn)
checkpoints = [f'checkpoints/resnet_epoch_{i}.pth' for i in range(30, 301, 30)]


sic = SelfInfluence(model=model, last_layer_only=True)

total_influence_mislabeled_dataset = sic.compute_tracin_self_influence(mislabeled_dataset_loader, checkpoints)

Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 391/391 [01:42<00:00,  3.82it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 391/391 [01:42<00:00,  3.80it/s]
Checkpoint checkpoints/resnet_epoch_90.pth: 100%|██████████| 391/391 [01:43<00:00,  3.79it/s]
Checkpoint checkpoints/resnet_epoch_120.pth: 100%|██████████| 391/391 [01:42<00:00,  3.80it/s]
Checkpoint checkpoints/resnet_epoch_150.pth: 100%|██████████| 391/391 [01:42<00:00,  3.80it/s]
Checkpoint checkpoints/resnet_epoch_180.pth: 100%|██████████| 391/391 [01:43<00:00,  3.77it/s]
Checkpoint checkpoints/resnet_epoch_210.pth: 100%|██████████| 391/391 [01:43<00:00,  3.78it/s]
Checkpoint checkpoints/resnet_epoch_240.pth: 100%|██████████| 391/391 [01:44<00:00,  3.76it/s]
Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 391/391 [01:43<00:00,  3.78it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 391/391 [01:47<00:00,  3.64it/s]


In [None]:
plot_self_influence_histogram(total_influence_mislabeled_dataset)

In [None]:
def plot_recovery_curve(tracin_scores, is_mislabeled):
    # Sort by descending TracInCP score
    sorted_idx = np.argsort(-tracin_scores)
    sorted_labels = np.asarray(is_mislabeled)[sorted_idx].astype(int)

    # Cumulative recall of mislabels and fraction examined
    total_mislabels = max(1, sorted_labels.sum())
    cum_recall = np.cumsum(sorted_labels) / total_mislabels
    frac_examined = np.arange(1, len(sorted_labels) + 1) / len(sorted_labels)

    # Downsample to one point per 5% of the data
    step = max(1, len(frac_examined) // 20)
    x = frac_examined[::step]
    y = cum_recall[::step]

    # Ensure last point is included
    if x[-1] != frac_examined[-1]:
        x = np.append(x, frac_examined[-1])
        y = np.append(y, cum_recall[-1])

    # Build figure
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='lines+markers',
        name='Recovery Curve'
    ))

    # Optional: random baseline (y = x)
    fig.add_trace(go.Scatter(
        x=x,
        y=x,
        mode='lines',
        name='Random Baseline',
        line=dict(dash='dash', color='gray')
    ))

    fig.update_layout(
        title="TracInCP Recovery Curve",
        xaxis=dict(title="Fraction of dataset examined", range=[0, 1]),
        yaxis=dict(title="Fraction of mislabels recovered", range=[0, 1]),
        template="plotly_white"
    )

    fig.show()

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3093791630.py, line 14)

In [59]:
is_mislabeled = np.array(dataset['mislabeled'])

plot_recovery_curve(total_influence_mislabeled_dataset, is_mislabeled)