In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
from ResNet import ResNet
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go
import plotly.express as px
from tracin_pipeline import TracInPipeline
from selection import select_clean_subset

# Hex colour codes kept in a module-level list.
colors = ['#1f77b4', '#ff7f0e', '#d62728', '#2ca02c', '#9467bd']

# Pick the compute device, preferring MPS, then CUDA, then CPU.
# `is_available()` is the correct runtime check: `is_built()` only reports that
# this PyTorch binary was compiled with MPS support, which can be true on a
# machine where no usable MPS device exists.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

# Seed PyTorch's global random number generator.
torch.manual_seed(3)


mps


<torch._C.Generator at 0x117e879f0>

# Example: Simple self-influence pipeline usage with CIFAR-10 from torchvision

In [2]:
# Self-influence on the torchvision CIFAR-10 training split.

# Convert images to tensors, then normalise every channel with mean 0.5 / std 0.5.
cifar10_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split stored under ./data (download enabled).
cifar10_train = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=cifar10_transform,
)

# Model and the checkpoint paths for epochs 30, 60, ..., 270.
model = ResNet(num_classes=10, n=9)
checkpoints = [f'checkpoints_CIFAR10/resnet_epoch_{i}.pth' for i in range(30, 271, 30)]

# Only the last two checkpoint paths (epochs 240 and 270) are given to the pipeline.
pipeline = TracInPipeline(
    model=model,
    checkpoints=checkpoints[-2:],
    batch_size=128,
    device=device,
)

total_influence, results, metrics = pipeline.run(cifar10_train)

print(results.head())


Computing self-influence on 50000 samples


Checkpoint checkpoints_CIFAR10/resnet_epoch_240.pth: 100%|██████████| 391/391 [00:46<00:00,  8.41it/s]
Checkpoint checkpoints_CIFAR10/resnet_epoch_270.pth: 100%|██████████| 391/391 [00:43<00:00,  9.08it/s]


Self-Influence computation complete.


      influence  label     rank  norm_rank
0  4.809395e-14      6  36201.0    0.72402
1  1.877739e-18      9  42505.0    0.85010
2  2.155781e-02      9  15177.0    0.30354
3  2.247105e-23      4  47050.0    0.94100
4  4.407151e-19      1  43277.0    0.86554


# Example: Self-Influence pipeline for CIFAR-100 with mislabeled data (Hirundo dataset) from HuggingFace

In [3]:
# ResNet with 100 output classes, moved to the selected device.
resnet_56 = ResNet(num_classes=100, n=9).to(device)

# Checkpoint paths for epochs 30, 60, ..., 300.
checkpoints = [f'checkpoints/resnet_epoch_{i}.pth' for i in range(30, 301, 30)]

In [4]:
# Transform pipeline for CIFAR-100 images: random 32-pixel crop after 4 pixels
# of padding, random horizontal flip, tensor conversion, and per-channel
# normalisation with the mean/std values listed below.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5071, 0.4867, 0.4408],
            std=[0.2675, 0.2565, 0.2761],
        ),
    ]
)

In [5]:
# CIFAR-100 class names. A name's position in this list is its integer label;
# the trailing comments give the index range covered by each row.
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',                # 0-4
    'bed', 'bee', 'beetle', 'bicycle', 'bottle',                       # 5-9
    'bowl', 'boy', 'bridge', 'bus', 'butterfly',                       # 10-14
    'camel', 'can', 'castle', 'caterpillar', 'cattle',                 # 15-19
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',              # 20-24
    'couch', 'crab', 'crocodile', 'cup', 'dinosaur',                   # 25-29
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox',                # 30-34
    'girl', 'hamster', 'house', 'kangaroo', 'keyboard',                # 35-39
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',                 # 40-44
    'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',          # 45-49
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',               # 50-54
    'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',         # 55-59
    'plain', 'plate', 'poppy', 'porcupine', 'possum',                  # 60-64
    'rabbit', 'raccoon', 'ray', 'road', 'rocket',                      # 65-69
    'rose', 'sea', 'seal', 'shark', 'shrew',                           # 70-74
    'skunk', 'skyscraper', 'snail', 'snake', 'spider',                 # 75-79
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',     # 80-84
    'tank', 'telephone', 'television', 'tiger', 'tractor',             # 85-89
    'train', 'trout', 'tulip', 'turtle', 'wardrobe',                   # 90-94
    'whale', 'willow_tree', 'wolf', 'woman', 'worm',                   # 95-99
]

# Lookup tables in both directions: class name -> index and index -> class name.
CLASS_TO_IDX = dict(zip(CIFAR100_CLASSES, range(len(CIFAR100_CLASSES))))
IDX_TO_CLASS = dict(enumerate(CIFAR100_CLASSES))

In [None]:
# Data Preprocessor
class CIFAR100DatasetPreprocessor:
    """Attach label metadata from a CSV file to dataset examples.

    The CSV is read once when the object is created and turned into a dict
    keyed by ``image_path``. ``add_meta`` then looks up each example by its
    ``__key__`` (with ``.png`` appended) and copies ``label``, ``true_label``
    and ``mislabeled`` onto the example.
    """

    def __init__(self, meta_csv):
        """Build the metadata lookup table from the CSV at ``meta_csv``."""
        self.meta_dict = self._build_meta_dict(meta_csv)

    def _build_meta_dict(self, meta_csv):
        """Read ``meta_csv`` and return ``{image_path: row_dict}``.

        The CSV must provide ``image_path``, ``class_name`` and
        ``orig_class_name`` columns. Both class-name columns are normalised
        (stripped, lower-cased, spaces replaced by underscores) and mapped
        through ``CLASS_TO_IDX`` into ``label`` and ``true_label``.
        """
        df = pd.read_csv(meta_csv)
        for column in ('class_name', 'orig_class_name'):
            # regex=False: the space is a literal, on every pandas version.
            df[column] = (
                df[column].str.strip().str.lower().str.replace(" ", "_", regex=False)
            )
        # Names missing from CLASS_TO_IDX become NaN here; add_meta rejects
        # them when the affected row is actually looked up.
        df['label'] = df['class_name'].map(CLASS_TO_IDX)
        df['true_label'] = df['orig_class_name'].map(CLASS_TO_IDX)
        return df.set_index('image_path').to_dict(orient='index')

    @staticmethod
    def _as_class_index(value, field, key):
        """Return ``value`` as a plain ``int``; raise ``ValueError`` if it is missing."""
        if pd.isna(value):
            raise ValueError(f"Unknown class name in metadata field '{field}' for {key}")
        return int(value)

    def add_meta(self, example):
        """Copy ``label``, ``true_label`` and ``mislabeled`` onto ``example``.

        Args:
            example: mapping with a ``__key__`` entry; it is modified in place.

        Returns:
            The same ``example`` object with the three fields set.

        Raises:
            ValueError: if no metadata row exists for the example, or if the
                row's ``label`` / ``true_label`` could not be mapped to a
                class index.
        """
        key = example['__key__']
        info = self.meta_dict.get(f"{key}.png")
        if info is None:
            raise ValueError(f"Missing metadata for {key}")
        # A single unmapped class name turns the whole pandas column into
        # floats, so always hand back real integers for the class indices.
        example['label'] = self._as_class_index(info['label'], 'label', key)
        example['true_label'] = self._as_class_index(info['true_label'], 'true_label', key)
        example['mislabeled'] = info['mislabeled']
        return example

    def process(self, dataset):
        """Apply ``add_meta`` to every example through ``dataset.map``."""
        return dataset.map(self.add_meta)

In [7]:
# Collate function factory (applies the optional transform per sample)
def make_collate_fn(transform):
    """Return a collate function that batches ``"png"`` images and ``"label"`` values.

    When ``transform`` is truthy, each sample's ``"png"`` entry is converted to
    RGB and passed through ``transform``; otherwise the ``"png"`` entry is used
    unchanged. The per-sample results are stacked with ``torch.stack`` and the
    labels are gathered into a tensor.
    """

    def _to_pixels(sample):
        # Without a transform the stored value goes to torch.stack as-is.
        if transform:
            return transform(sample["png"].convert("RGB"))
        return sample["png"]

    def collate_fn(batch):
        """Turn a list of samples into a ``(pixel_values, labels)`` pair."""
        stacked = torch.stack([_to_pixels(sample) for sample in batch])
        targets = torch.tensor([sample["label"] for sample in batch])
        return stacked, targets

    return collate_fn

In [8]:
# Example pipeline: metadata preprocessor + transform-aware collate function.
pipeline = TracInPipeline(
    model=resnet_56,
    checkpoints=checkpoints,
    preprocessor=CIFAR100DatasetPreprocessor('cifar-100-noisy.csv'),
    collate_fn=make_collate_fn(transform=train_transform),
    device=device,
)

# Train split of the noisy CIFAR-100 dataset from the Hugging Face hub.
df = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')

total_influence, influence_results, metrics = pipeline.run(train_dataset=df)

Computing self-influence on 50000 samples


Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 391/391 [01:00<00:00,  6.50it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 391/391 [01:01<00:00,  6.37it/s]
Checkpoint checkpoints/resnet_epoch_90.pth: 100%|██████████| 391/391 [01:02<00:00,  6.23it/s]
Checkpoint checkpoints/resnet_epoch_120.pth: 100%|██████████| 391/391 [01:02<00:00,  6.26it/s]
Checkpoint checkpoints/resnet_epoch_150.pth: 100%|██████████| 391/391 [01:02<00:00,  6.29it/s]
Checkpoint checkpoints/resnet_epoch_180.pth: 100%|██████████| 391/391 [01:02<00:00,  6.25it/s]
Checkpoint checkpoints/resnet_epoch_210.pth: 100%|██████████| 391/391 [01:05<00:00,  6.00it/s]
Checkpoint checkpoints/resnet_epoch_240.pth: 100%|██████████| 391/391 [01:03<00:00,  6.21it/s]
Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 391/391 [01:07<00:00,  5.83it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 391/391 [01:07<00:00,  5.77it/s]


Self-Influence computation complete.


# Example: Clean-subset selection — separating out the 90% of samples with the lowest self-influence

In [9]:
# Select the clean subset with the percentile method and frac=0.9.
clean_df, cutoff_info = select_clean_subset(
    influence_results,
    method='percentile',
    frac=0.9,
)

In [10]:
# Fixed random_state so the five displayed rows are the same on every run
# (only torch is seeded elsewhere in this notebook, not pandas/numpy).
clean_df.sample(5, random_state=3)

Unnamed: 0,influence,mislabeled,label,true_label,__key__,rank,norm_rank
44994,543.809753,False,10,10,train/bowl/17256,22226.0,0.44452
7550,1340.036377,False,40,40,train/lamp/07613,7498.0,0.14996
35415,787.023438,False,95,95,train/whale/18405,15466.0,0.30932
44513,232.311783,False,10,10,train/bowl/40354,33547.0,0.67094
41863,979.940918,False,62,62,train/poppy/14667,11668.0,0.23336


# Example: Cross Influence Calculation for CIFAR-10 from torchvision

In [11]:
# CIFAR-10 train and test splits stored under ./data (download enabled).
cifar10_train = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=cifar10_transform
)
cifar10_test = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=cifar10_transform
)

# Checkpoint paths for epochs 30, 60, ..., 270.
checkpoints = [f'checkpoints_CIFAR10/resnet_epoch_{i}.pth' for i in range(30, 271, 30)]

# Cross-influence (train -> test) pipeline using the last two checkpoint paths;
# no preprocessor or custom collate function is supplied.
cross_pipeline = TracInPipeline(
    model=model,
    checkpoints=checkpoints[-2:],
    preprocessor=None,
    collate_fn=None,
    device=device,
    mode="cross",
)

# Run cross influence with plotting disabled and keep the returned influence matrix.
influence_matrix = cross_pipeline.run(
    train_dataset=cifar10_train,
    test_dataset=cifar10_test,
    plot_results=False,
)


Computing cross-sample influence: 50000 train → 10000 test


Train grads: 100%|██████████| 391/391 [01:00<00:00,  6.51it/s]
Test grads: 100%|██████████| 79/79 [00:11<00:00,  6.87it/s]
Train grads: 100%|██████████| 391/391 [00:56<00:00,  6.88it/s]
Test grads: 100%|██████████| 79/79 [00:05<00:00, 13.40it/s]


cross-sample influence computation complete.


In [12]:
# Print the matrix shape, then its top-left 5x5 corner.
for caption, value in (
    ("Influence matrix shape:", influence_matrix.shape),
    ("Influence matrix sample:", influence_matrix[:5, :5]),
):
    print(caption, value)

Influence matrix shape: (50000, 10000)
Influence matrix sample: [[ 5.63193675e-20 -6.24173624e-15  1.39804062e-11  4.74185424e-09
   4.56277082e-14]
 [ 8.29762870e-24 -4.76624382e-13 -3.32236381e-13  5.97926900e-13
   2.55786641e-18]
 [-9.94434203e-16 -5.16514010e-05 -5.14618569e-05  3.29737086e-05
  -1.41568165e-12]
 [ 2.14157326e-24 -7.51366957e-20  1.21105462e-16  2.33599212e-13
   1.37225634e-18]
 [ 1.08142840e-22 -9.12268342e-14  2.21721428e-14  3.26288035e-14
   1.19560003e-18]]
