In [1]:
import torch
import torchvision.transforms as transforms
import torchvision
from ResNet import ResNet
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset
import plotly.graph_objects as go
import plotly.express as px
from tracin_pipeline import TracInPipeline
from selection import select_clean_subset

colors = ['#1f77b4', '#ff7f0e', '#d62728', '#2ca02c', '#9467bd']

device = "mps" if torch.backends.mps.is_built() \
    else "cuda" if torch.cuda.is_available() else "cpu"

print(device)

torch.manual_seed(3)


mps


<torch._C.Generator at 0x1172879f0>

# Example: Simple self-influence pipeline usage with CIFAR-10 from torchvision

In [None]:
# Example: Simple pipeline usage with CIFAR-10 from torchvision

# Preprocessing/transform for CIFAR-10
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-10 training set from torchvision
cifar10_train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar10_transform)

# Example model and dummy checkpoints
# (You can save your checkpoints or use model.state_dict() as needed)
model = ResNet(num_classes=10, n=9)
checkpoints = [f'checkpoints_CIFAR10/resnet_epoch_{i}.pth' for i in range(30, 271, 30)]

# Setup pipeline
pipeline = TracInPipeline(
    model=model, 
    checkpoints=checkpoints[-2:], 
    batch_size=128, 
    device=device
)

total_influence, results, metrics = pipeline.run(cifar10_train)

print(results.head())


# Example: Self-Influence pipeline for CIFAR-100 with mislabeled data (Hirundo dataset) from HuggingFace

In [2]:
resnet_56=ResNet(num_classes=100,n=9).to(device)

# define checkpoints (every 30 epochs)
checkpoints = [f'checkpoints/resnet_epoch_{i}.pth' for i in range(30, 301, 30)]

In [3]:
# standard data transforms for CIFAR100

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])

In [4]:
# CIFAR100 label mapping
CIFAR100_CLASSES = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
    'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
    'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
    'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
    'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
    'woman', 'worm'
]
CLASS_TO_IDX = {cls: i for i, cls in enumerate(CIFAR100_CLASSES)}
IDX_TO_CLASS = {i: cls for i, cls in enumerate(CIFAR100_CLASSES)}

In [5]:
# Data Preprocessor
class CIFAR100DatasetPreprocessor:
    def __init__(self, meta_csv):
        self.meta_dict = self._build_meta_dict(meta_csv)

    def _build_meta_dict(self, meta_csv):
        df = pd.read_csv(meta_csv)
        df['class_name'] = df['class_name'].str.strip().str.lower().str.replace(" ", "_")
        df['orig_class_name'] = df['orig_class_name'].str.strip().str.lower().str.replace(" ", "_")
        df['label'] = df['class_name'].map(CLASS_TO_IDX)
        df['true_label'] = df['orig_class_name'].map(CLASS_TO_IDX)
        # df['key'] = df['image_path'].apply(lambda x: x.replace(".png", "").replace("/", "."))
        # return df.set_index("key").to_dict(orient="index")
        return df.set_index('image_path').to_dict(orient='index')

    def add_meta(self, example):
        info = self.meta_dict.get(".".join([example['__key__'], "png"]))
        if info is None:
            raise ValueError(f"Missing metadata for {example['__key__']}")
        example['label'] = info['label']
        example['true_label'] = info['true_label']
        example['mislabeled'] = info['mislabeled']
        return example

    def process(self, dataset):
        return dataset.map(self.add_meta)




In [6]:
# Collate Function (handles transform)
def make_collate_fn(transform):
    def collate_fn(batch):
        pixel_values = torch.stack([
            transform(x["png"].convert("RGB")) if transform else x["png"]
            for x in batch
        ])
        labels = torch.tensor([x["label"] for x in batch])
        return pixel_values, labels
    return collate_fn

In [None]:
#  Example pipeline
pipeline = TracInPipeline(model = resnet_56, checkpoints=checkpoints, preprocessor=CIFAR100DatasetPreprocessor('cifar-100-noisy.csv'), collate_fn=make_collate_fn(transform = train_transform), device=device)
df = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')
total_influence, influence_results, metrics = pipeline.run(dataset=df)

Dataset ready. Computing influence on 50000 samples...


Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 391/391 [00:38<00:00, 10.05it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 391/391 [00:38<00:00, 10.04it/s]
Checkpoint checkpoints/resnet_epoch_90.pth: 100%|██████████| 391/391 [00:37<00:00, 10.53it/s]
Checkpoint checkpoints/resnet_epoch_120.pth: 100%|██████████| 391/391 [00:37<00:00, 10.56it/s]
Checkpoint checkpoints/resnet_epoch_150.pth: 100%|██████████| 391/391 [00:37<00:00, 10.45it/s]
Checkpoint checkpoints/resnet_epoch_180.pth: 100%|██████████| 391/391 [00:37<00:00, 10.56it/s]
Checkpoint checkpoints/resnet_epoch_210.pth: 100%|██████████| 391/391 [00:37<00:00, 10.53it/s]
Checkpoint checkpoints/resnet_epoch_240.pth: 100%|██████████| 391/391 [00:36<00:00, 10.58it/s]
Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 391/391 [00:36<00:00, 10.58it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 391/391 [00:37<00:00, 10.54it/s]


Influence computation complete.


# Example: Function for clean subset selection - separating data based on lowest 90% of self-influence

In [None]:
clean_df, cutoff_info = select_clean_subset(influence_results, method='percentile', frac=0.9)

In [None]:
clean_df.sample(5)

Unnamed: 0,influence,mislabeled,label,true_label,__key__,rank,norm_rank
29685,139.989838,False,52,52,train/oak_tree/20869,37546.0,0.75092
7437,360.524139,False,9,9,train/bottle/27197,28578.0,0.57156
14627,270.335876,False,48,48,train/motorcycle/11708,32031.0,0.64062
31269,14.801138,False,57,57,train/pear/42738,44996.0,0.89992
30209,218.699265,False,98,98,train/woman/45360,34128.0,0.68256


In [13]:
# Example: Cross Influence Calculation for CIFAR-10 (torchvision)

import torchvision

# Custom collate function to apply torchvision transforms and stack tensors
def make_cifar10_collate_fn(transform=None):
    def collate_fn(batch):
        imgs, labels = zip(*batch)  # Each item is (PIL_image, int_label)
        if transform is not None:
            imgs = [transform(img) for img in imgs]
        else:
            imgs = [torch.tensor(np.array(img)).permute(2, 0, 1) for img in imgs]
        imgs = torch.stack(imgs)
        labels = torch.tensor(labels)
        return imgs, labels
    return collate_fn


# Download and load CIFAR-10 train and test sets
cifar10_train = torchvision.datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = torchvision.datasets.CIFAR10(root="./data", train=False, download=True)

checkpoints = [f'checkpoints_CIFAR10/resnet_epoch_{i}.pth' for i in range(30, 271, 30)]

# Example setup for cross influence (train -> test)
cross_pipeline = TracInPipeline(
    model = model,
    checkpoints = checkpoints,
    preprocessor = None,  # No metadata/extra features for vanilla CIFAR10
    collate_fn = make_cifar10_collate_fn(transform=cifar10_transform),
    device = device,
    mode = "cross"
)

# Run cross influence; this returns the full influence matrix:
influence_matrix = cross_pipeline.run(
    train_dataset = cifar10_train, 
    test_dataset = cifar10_test,
    plot_results = False
)

# Inspect shape and a sample of the influence matrix
print("Influence matrix shape:", influence_matrix.shape)
print("Influence matrix sample:", influence_matrix[:5, :5])


NameError: name 'cifar10_transform' is not defined