In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functorch import vmap
from ResNet import ResNet
import matplotlib.pyplot as plt
import numpy as np
import gc
import os
import glob
from datasets import load_dataset
from datasets import Dataset
from tqdm import tqdm

device = "mps" if torch.backends.mps.is_built() \
    else "cuda" if torch.cuda.is_available() else "cpu"

print(device)

torch.manual_seed(3)

mps


<torch._C.Generator at 0x11d0cfe50>

In [2]:
model=ResNet(num_classes=100,n=9).to(device)

In [3]:
# Define standard data transforms for CIFAR100
# CIFAR100 mean and std:
# mean = [0.5071, 0.4867, 0.4408], std = [0.2675, 0.2565, 0.2761]

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5071, 0.4867, 0.4408],
        std=[0.2675, 0.2565, 0.2761]
    ),
])


In [4]:
# Load datasets
train_dataset=torchvision.datasets.CIFAR100(root='./data',train=True,download=True,transform=train_transform)
test_dataset=torchvision.datasets.CIFAR100(root='./data',train=False,download=True,transform=test_transform)

batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [11]:
checkpoint_dir='checkpoints'
checkpoint_epoch = 30
checkpoint_path = os.path.join(checkpoint_dir, f'resnet_epoch_{checkpoint_epoch}.pth')

In [19]:
# Create one training example (x, y)
# Get a batch from the train_loader and pick the first example
x_batch, y_batch = next(iter(train_loader))
x = x_batch[0]
y = y_batch[0].item()

# move to device
x = x.unsqueeze(0).to(device)  # add batch dimension
y = torch.tensor([y], device=device)



In [None]:

# --- assume you already have these ---
# model: trained ResNet (e.g., ResNet-56 or ResNet-50)
# checkpoint_path: path to a saved checkpoint (e.g. "ckpt_epoch_60.pt")
# x, y: a single training example (tensor and its label)

# load weights from a checkpoint
model.load_state_dict(torch.load(checkpoint_path)['model_state_dict'])
model.eval()  # important: use evaluation mode


# we want gradients w.r.t. last layer only
params = list(model.parameters())[-2:]  # usually weight, bias

# compute loss with gradient tracking enabled
with torch.set_grad_enabled(True):
    outputs = model(x)
    loss = F.cross_entropy(outputs, y, reduction='sum')

# gradient of loss w.r.t. last-layer parameters
grads = torch.autograd.grad(loss, params, create_graph=False, retain_graph=False)

# flatten and concatenate all gradients
flat_grads = torch.cat([g.reshape(-1) for g in grads if g is not None])

# compute squared L2 norm â€” the self-influence at this checkpoint
self_influence = (flat_grads ** 2).sum().item()

print(f"Self-influence for this example at checkpoint: {self_influence:.6f}")


torch.Size([6500])
Self-influence for this example at checkpoint: 52.753872


In [None]:
class SelfInfluence:
    """
    Computes TracInCP self-influence scores across multiple checkpoints.
    """
    def __init__(self, model, device='cpu', last_layer_only=True):
        self.model = model.to(device)
        self.device = device
        self.last_layer_only = last_layer_only

        # identify the target parameters (usually last layer)
        params = list(model.parameters())
        self.target_params = params[-2:] if last_layer_only else params

    def load_checkpoint(self, checkpoint_path):
        """Load model weights from a saved checkpoint."""
        state_dict = torch.load(checkpoint_path, map_location=self.device)['model_state_dict']
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _per_sample_grad(self, x, y):
        """Compute flattened gradient vector for a single sample."""
        out = self.model(x.unsqueeze(0))
        loss = F.cross_entropy(out, y.unsqueeze(0), reduction='sum')
        grads = torch.autograd.grad(loss, self.target_params,
                                    retain_graph=False, create_graph=False)
        flat = torch.cat([g.reshape(-1) for g in grads if g is not None])
        return flat

    def compute_batch_influence(self, inputs, labels):
        """
        Compute self-influence for each example in a batch.
        Returns: tensor of shape [B] with self-influence scores.
        """
        self.model.eval()
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        # vectorized over batch
        grads = vmap(self._per_sample_grad)(inputs, labels)
        influences = (grads ** 2).sum(dim=1)
        return influences.detach().cpu()

    def compute_tracin_self_influence(self, dataloader, checkpoint_paths, eta_list=None):
        """
        Aggregate self-influence across checkpoints (TracInCP).

        Args:
            dataloader: DataLoader over training data.
            checkpoint_paths: list of checkpoint file paths.
            eta_list: optional weighting factors (default = equal).
        Returns:
            tensor [N] of total self-influence scores for training set.
        """
        if eta_list is None:
            eta_list = [1.0 for _ in checkpoint_paths]

        # initialize empty vector for total influence
        num_samples = len(dataloader.dataset)
        total_influence = torch.zeros(num_samples)

        for eta_i, ckpt in zip(eta_list, checkpoint_paths):
            self.load_checkpoint(ckpt)

            offset = 0
            for inputs, labels in tqdm(dataloader, desc=f'Checkpoint {ckpt}'):
                batch_inf = self.compute_batch_influence(inputs, labels)
                total_influence[offset : offset + len(inputs)] += eta_i * batch_inf
                offset += len(inputs)

        return total_influence
