In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torchvision.transforms as transforms
from datasets import load_dataset

from ResNet import ResNet
from selection import select_clean_subset
from tracin_pipeline import TracInPipeline

# Plot palette shared by the figures below.
colors = ['#1f77b4', '#ff7f0e', '#d62728', '#2ca02c', '#9467bd']

# Prefer Apple-GPU (MPS), then CUDA, then CPU. `is_available()` checks that the
# backend is usable at runtime; `is_built()` only says PyTorch was compiled with
# MPS support, which can select "mps" on a machine that cannot actually run it.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(device)

SEED = 3
torch.manual_seed(SEED);  # trailing ';' hides the Generator repr in the notebook


mps


<torch._C.Generator at 0x10ed879f0>

In [2]:
# Pick up edits to selection.py without restarting the kernel. Reload FIRST and
# re-import the name afterwards: a `from selection import ...` executed before the
# reload keeps pointing at the stale, pre-reload function object.
import importlib

import selection

importlib.reload(selection)
from selection import select_clean_subset

<module 'selection' from '/Users/ronibendom/Documents/hirundo_assigmnment/selection.py'>

In [3]:
# ResNet with a 100-way CIFAR-100 head. The variable name together with n=9
# suggests the 6n+2 = 56-layer CIFAR ResNet — NOTE(review): confirm in ResNet.py.
resnet_56 = ResNet(num_classes=100, n=9).to(device)

# TracIn checkpoints saved during training: one every CHECKPOINT_INTERVAL epochs,
# up to and including FINAL_EPOCH.
CHECKPOINT_INTERVAL = 30
FINAL_EPOCH = 300
checkpoints = [
    f'checkpoints/resnet_epoch_{epoch}.pth'
    for epoch in range(CHECKPOINT_INTERVAL, FINAL_EPOCH + 1, CHECKPOINT_INTERVAL)
]

# Fail fast: a missing file would otherwise only surface partway through a
# multi-minute influence run.
missing_checkpoints = [path for path in checkpoints if not Path(path).is_file()]
if missing_checkpoints:
    raise FileNotFoundError(f"Missing checkpoint files: {missing_checkpoints}")

In [4]:
# Standard CIFAR-100 training pipeline: random-crop / horizontal-flip augmentation,
# then tensor conversion and per-channel normalization with the CIFAR-100 statistics.
CIFAR100_MEAN = [0.5071, 0.4867, 0.4408]
CIFAR100_STD = [0.2675, 0.2565, 0.2761]

train_transform = transforms.Compose([
    # augmentation (random on every call)
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    # PIL image -> float tensor, then normalize each channel
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

In [5]:
# CIFAR-100 fine-label names, listed in label-index order (10 per row), plus
# name <-> index lookup tables.
CIFAR100_CLASSES = """
    apple aquarium_fish baby bear beaver bed bee beetle bicycle bottle
    bowl boy bridge bus butterfly camel can castle caterpillar cattle
    chair chimpanzee clock cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster house kangaroo keyboard
    lamp lawn_mower leopard lion lizard lobster man maple_tree motorcycle mountain
    mouse mushroom oak_tree orange orchid otter palm_tree pear pickup_truck pine_tree
    plain plate poppy porcupine possum rabbit raccoon ray road rocket
    rose sea seal shark shrew skunk skyscraper snail snake spider
    squirrel streetcar sunflower sweet_pepper table tank telephone television tiger tractor
    train trout tulip turtle wardrobe whale willow_tree wolf woman worm
""".split()
CLASS_TO_IDX = {name: idx for idx, name in enumerate(CIFAR100_CLASSES)}
IDX_TO_CLASS = dict(enumerate(CIFAR100_CLASSES))

In [6]:
class CIFAR100DatasetPreprocessor:
    """Attach label metadata from a noisy-CIFAR-100 CSV to dataset examples.

    The CSV must provide the columns ``image_path``, ``class_name``,
    ``orig_class_name`` and ``mislabeled``. An example is matched to the row whose
    ``image_path`` equals ``f"{example['__key__']}.png"``.
    """

    def __init__(self, meta_csv, class_to_idx=None):
        """
        Args:
            meta_csv: path (or buffer) accepted by ``pd.read_csv``.
            class_to_idx: mapping from normalized class name to label index;
                defaults to the notebook-level ``CLASS_TO_IDX``.
        """
        self.class_to_idx = CLASS_TO_IDX if class_to_idx is None else class_to_idx
        self.meta_dict = self._build_meta_dict(meta_csv)

    @staticmethod
    def _normalize_names(names):
        """Canonicalize class names: strip, lower-case, spaces -> underscores."""
        return names.str.strip().str.lower().str.replace(" ", "_", regex=False)

    def _build_meta_dict(self, meta_csv):
        """Read the CSV and return ``{image_path: row_dict}`` with integer labels.

        Raises:
            ValueError: if a ``class_name`` does not map to a known class (it would
                otherwise silently become a NaN label), or — from pandas'
                ``to_dict(orient='index')`` — if ``image_path`` is not unique.
        """
        df = pd.read_csv(meta_csv)
        df['class_name'] = self._normalize_names(df['class_name'])
        df['orig_class_name'] = self._normalize_names(df['orig_class_name'])
        df['label'] = df['class_name'].map(self.class_to_idx)
        # NOTE(review): orig_class_name is not validated — confirm whether clean rows
        # may leave it blank; unmapped values become a NaN true_label.
        df['true_label'] = df['orig_class_name'].map(self.class_to_idx)

        unknown = sorted(df.loc[df['label'].isna(), 'class_name'].astype(str).unique())
        if unknown:
            raise ValueError(f"Unknown class names in {meta_csv}: {unknown}")
        return df.set_index('image_path').to_dict(orient='index')

    def add_meta(self, example):
        """Copy ``label``, ``true_label`` and ``mislabeled`` from the CSV row onto ``example``.

        Raises:
            ValueError: if the example's key has no row in the metadata CSV.
        """
        key = example['__key__']
        info = self.meta_dict.get(f"{key}.png")
        if info is None:
            raise ValueError(f"Missing metadata for {key}")
        example['label'] = info['label']
        example['true_label'] = info['true_label']
        example['mislabeled'] = info['mislabeled']
        return example

    def process(self, dataset):
        """Return ``dataset.map(self.add_meta)`` (every example gains the metadata fields)."""
        return dataset.map(self.add_meta)




In [7]:
def make_collate_fn(transform):
    """Build a DataLoader ``collate_fn`` producing ``(pixel_values, labels)`` tensors.

    Args:
        transform: callable applied to each ``example["png"].convert("RGB")``;
            if falsy (e.g. ``None``), ``example["png"]`` is stacked unchanged.

    Returns:
        A function mapping a list of examples to the stacked images and a tensor
        built from each example's ``"label"``.
    """
    def collate_fn(batch):
        if transform:
            images = [transform(sample["png"].convert("RGB")) for sample in batch]
        else:
            images = [sample["png"] for sample in batch]
        pixel_values = torch.stack(images)
        labels = torch.tensor([sample["label"] for sample in batch])
        return pixel_values, labels

    return collate_fn

In [8]:
# Score self-influence on deterministic inputs: the random crop / flip in
# `train_transform` would show every checkpoint pass a different view of each
# image, adding augmentation noise to the per-sample scores and to the
# checkpoint-subset comparisons below. Keep only the non-random steps.
scoring_transform = transforms.Compose([
    t for t in train_transform.transforms
    if not isinstance(t, (transforms.RandomCrop, transforms.RandomHorizontalFlip))
])

pipeline = TracInPipeline(
    model=resnet_56,
    checkpoints=checkpoints,
    preprocessor=CIFAR100DatasetPreprocessor('cifar-100-noisy.csv'),
    collate_fn=make_collate_fn(transform=scoring_transform),
    device=device,
)
# Hugging Face dataset (not a pandas DataFrame, despite the name).
df = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')
total_influence, influence_results, metrics = pipeline.run(train_dataset=df)

Computing self-influence on 50000 samples


Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 391/391 [00:49<00:00,  7.85it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 391/391 [00:46<00:00,  8.37it/s]
Checkpoint checkpoints/resnet_epoch_90.pth: 100%|██████████| 391/391 [00:56<00:00,  6.88it/s]
Checkpoint checkpoints/resnet_epoch_120.pth: 100%|██████████| 391/391 [01:01<00:00,  6.37it/s]
Checkpoint checkpoints/resnet_epoch_150.pth: 100%|██████████| 391/391 [01:02<00:00,  6.23it/s]
Checkpoint checkpoints/resnet_epoch_180.pth: 100%|██████████| 391/391 [01:02<00:00,  6.27it/s]
Checkpoint checkpoints/resnet_epoch_210.pth: 100%|██████████| 391/391 [01:02<00:00,  6.28it/s]
Checkpoint checkpoints/resnet_epoch_240.pth: 100%|██████████| 391/391 [01:02<00:00,  6.24it/s]
Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 391/391 [01:04<00:00,  6.03it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 391/391 [01:05<00:00,  5.94it/s]


Self-Influence computation complete.


## __Compare end-of-training checkpoints only vs. all checkpoints__

In [9]:
# Deterministic scoring inputs (non-random steps of `train_transform` only), so
# that differences from the all-checkpoint run reflect the checkpoints used,
# not a fresh draw of random crops / flips.
scoring_transform = transforms.Compose([
    t for t in train_transform.transforms
    if not isinstance(t, (transforms.RandomCrop, transforms.RandomHorizontalFlip))
])

# Only the last two checkpoints (end of training).
pipeline = TracInPipeline(
    model=resnet_56,
    checkpoints=checkpoints[-2:],
    preprocessor=CIFAR100DatasetPreprocessor('cifar-100-noisy.csv'),
    collate_fn=make_collate_fn(transform=scoring_transform),
    device=device,
)
# Hugging Face dataset (not a pandas DataFrame, despite the name).
df = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')
total_influence_end_training, influence_results_end_training, metrics_end_training = pipeline.run(train_dataset=df)

Computing self-influence on 50000 samples


Checkpoint checkpoints/resnet_epoch_270.pth: 100%|██████████| 391/391 [01:07<00:00,  5.84it/s]
Checkpoint checkpoints/resnet_epoch_300.pth: 100%|██████████| 391/391 [01:08<00:00,  5.75it/s]


Self-Influence computation complete.


## __Compare beginning-of-training checkpoints only vs. all checkpoints__

In [10]:
# Deterministic scoring inputs (non-random steps of `train_transform` only), so
# that differences from the all-checkpoint run reflect the checkpoints used,
# not a fresh draw of random crops / flips.
scoring_transform = transforms.Compose([
    t for t in train_transform.transforms
    if not isinstance(t, (transforms.RandomCrop, transforms.RandomHorizontalFlip))
])

# Only the first two checkpoints (beginning of training).
pipeline = TracInPipeline(
    model=resnet_56,
    checkpoints=checkpoints[:2],
    preprocessor=CIFAR100DatasetPreprocessor('cifar-100-noisy.csv'),
    collate_fn=make_collate_fn(transform=scoring_transform),
    device=device,
)
# Hugging Face dataset (not a pandas DataFrame, despite the name).
df = load_dataset("hirundo-io/Noisy-CIFAR-100", split='train')
total_influence_beginingtraining, influence_results_begining_training, metrics_begining_training = pipeline.run(train_dataset=df)

Computing self-influence on 50000 samples


Checkpoint checkpoints/resnet_epoch_30.pth: 100%|██████████| 391/391 [01:04<00:00,  6.03it/s]
Checkpoint checkpoints/resnet_epoch_60.pth: 100%|██████████| 391/391 [01:02<00:00,  6.29it/s]


Self-Influence computation complete.


## __Compare clean-subset selection methods__

In [11]:
def evaluate_clean_selection_methods(select_fn, df, methods=("percentile", "knee", "gmm")):
    """
    Compare self-influence-based methods for selecting a 'clean' subset.

    For each method, the rows of ``df`` whose index is absent from the returned
    clean subset count as "removed". They are scored against the ground-truth
    ``mislabeled`` flags.

    Args:
        select_fn: callable ``select_fn(df, method=name) -> (clean_df, cutoff_info)``,
            e.g. ``selection.select_clean_subset``. ``clean_df`` must keep the
            index labels of ``df``.
        df: dataframe with a boolean-like ``mislabeled`` column, plus whatever
            score column ``select_fn`` reads ('influence' per the original
            contract — NOTE(review): confirm against selection.py).
        methods: names of the selection methods to compare, passed as ``method=``.

    Returns:
        pd.DataFrame with one row per method and these columns:
        ``method``, ``removed_fraction``, ``precision_removed`` (share of removed
        rows that are mislabeled), ``recall_removed`` (share of mislabeled rows
        that were removed), ``remaining_mislabeled_rate`` and ``cutoff_info``.
    """
    assert "mislabeled" in df.columns, "mislabeled column required for evaluation"

    total_mislabeled = df["mislabeled"].sum()
    total_samples = len(df)

    results = []
    for method in methods:
        clean_df, cutoff_info = select_fn(df, method=method)

        # Removed = rows of df not present (by index label) in the clean subset.
        removed_df = df.loc[~df.index.isin(clean_df.index)]
        n_removed = len(removed_df)

        precision = removed_df["mislabeled"].mean() if n_removed > 0 else 0.0
        recall = removed_df["mislabeled"].sum() / total_mislabeled if total_mislabeled > 0 else 0.0
        remaining_mislabeled = clean_df["mislabeled"].mean() if len(clean_df) > 0 else 0.0
        # Guard: an empty input would otherwise raise ZeroDivisionError.
        removed_fraction = n_removed / total_samples if total_samples > 0 else 0.0

        results.append({
            "method": method,
            "removed_fraction": removed_fraction,
            "precision_removed": precision,
            "recall_removed": recall,
            "remaining_mislabeled_rate": remaining_mislabeled,
            "cutoff_info": cutoff_info
        })

    return pd.DataFrame(results)


In [12]:
# Evaluate every influence run with the same selection methods, so the
# all / end-of-training / beginning-of-training comparisons set up above are
# actually computed (previously only the all-checkpoint run was evaluated).
influence_runs = {
    "all": influence_results,
    "end_of_training": influence_results_end_training,
    "beginning_of_training": influence_results_begining_training,
}
clean_df_results_by_run = {
    run_name: evaluate_clean_selection_methods(select_clean_subset, run_results)
    for run_name, run_results in influence_runs.items()
}

# The plotting cells below consume the all-checkpoint results.
clean_df_results = clean_df_results_by_run["all"]

checkpoint_comparison = pd.concat(
    [results.assign(checkpoints=run_name) for run_name, results in clean_df_results_by_run.items()],
    ignore_index=True,
)
checkpoint_comparison.drop(columns="cutoff_info")

In [13]:
def plot_clean_selection_comparison(results_df, palette=None):
    """Build Plotly figures comparing clean-subset selection methods.

    Args:
        results_df: one row per method with the columns ``method``,
            ``precision_removed``, ``recall_removed``,
            ``remaining_mislabeled_rate`` and ``removed_fraction``
            (as produced by ``evaluate_clean_selection_methods``).
        palette: sequence of colour strings (at least two); defaults to the
            notebook-level ``colors``.

    Returns:
        dict of figures keyed ``"precision_recall"``, ``"remaining_mislabeled"``
        and ``"fraction_removed"`` (in that order).
    """
    palette = colors if palette is None else palette
    common_layout = dict(template="plotly_white", width=900, height=500)
    figs = {}

    # Grouped horizontal bars: precision and recall side by side per method.
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=results_df["method"], x=results_df["precision_removed"],
        # Precision = share of the REMOVED samples that are actually mislabeled.
        name="Precision (fraction of removed samples that are mislabeled)",
        orientation='h',
        marker=dict(color=palette[0])
    ))
    fig.add_trace(go.Bar(
        y=results_df["method"], x=results_df["recall_removed"],
        name="Recall (fraction of mislabeled samples removed)",
        orientation='h',
        marker=dict(color=palette[1])
    ))
    fig.update_layout(
        barmode='group',
        title="Precision and Recall of Clean Subset Selection Methods",
        yaxis_title="Method",
        xaxis_title="Metric Value",
        **common_layout,
    )
    figs["precision_recall"] = fig

    # Single-metric vertical bar charts share one recipe:
    # (figure key, column, title, y-axis label).
    single_metric_charts = [
        ("remaining_mislabeled", "remaining_mislabeled_rate",
         "Remaining Mislabeled Rate in Clean Subset", "Mislabeled rate"),
        ("fraction_removed", "removed_fraction",
         "Fraction of Dataset Removed per Method", "Fraction removed"),
    ]
    for key, column, title, y_label in single_metric_charts:
        bar_fig = px.bar(
            results_df, x="method", y=column,
            title=title,
            labels={"method": "Method", column: y_label},
            color_discrete_sequence=[palette[0]],
        )
        bar_fig.update_layout(**common_layout)
        figs[key] = bar_fig

    return figs


In [14]:
# Build the comparison figures for the all-checkpoint selection results.
figs = plot_clean_selection_comparison(results_df=clean_df_results)

In [15]:
# Render each comparison figure in insertion order.
for name in figs:
    fig = figs[name]
    fig.show()

### Further possible tests:
- How do the results change when the network is trained on a similar, but not identical, dataset?
- What happens when the network is trained on data that contains the mislabeled examples (the real-world scenario)?
- How does self-influence correlate with model confidence?
- What is the effect of the model's final accuracy (on the same or similar data)?
- What is the impact on training of using the selected clean subset?
