In [3]:
from pathlib import Path

import torch
import torchvision

from modeldiff import ModelDiff

### Helper functions

In [5]:
def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
    """Build a ``DataLoader`` over CIFAR-10 (downloaded to ``/tmp/cifar/`` if missing).

    Args:
        batch_size: Number of examples per batch.
        num_workers: Number of worker processes passed to the ``DataLoader``.
        split: ``'train'`` selects the CIFAR-10 training set; ``'val'`` or ``'test'``
            select the CIFAR-10 test set (CIFAR-10 ships no separate validation split).
        shuffle: Whether the ``DataLoader`` shuffles examples. Defaults to ``False``
            so the example order is fixed.
        augment: If ``True``, prepend random augmentation (horizontal flip) to the
            tensor conversion + normalization pipeline. Pass ``False`` to get
            deterministic inputs.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split.

    Raises:
        ValueError: If ``split`` is not ``'train'``, ``'val'`` or ``'test'``.
    """
    # Validate up front so a typo such as 'trian' raises instead of silently
    # loading the test set.
    if split not in ('train', 'val', 'test'):
        raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
    is_train = (split == 'train')

    # Tensor conversion + per-channel normalization shared by both pipelines
    # (defined once so the CIFAR-10 mean/std constants cannot drift apart).
    to_normalized_tensor = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.201)),
    ]
    if augment:
        # NOTE: RandomAffine(0) (degrees=0, no translate/scale/shear) is effectively
        # a no-op; kept so the augmented pipeline matches the original one.
        augmentations = [torchvision.transforms.RandomHorizontalFlip(),
                         torchvision.transforms.RandomAffine(0)]
    else:
        augmentations = []
    transforms = torchvision.transforms.Compose(augmentations + to_normalized_tensor)

    dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/',
                                           download=True,
                                           train=is_train,
                                           transform=transforms)

    return torch.utils.data.DataLoader(dataset=dataset,
                                       shuffle=shuffle,
                                       batch_size=batch_size,
                                       num_workers=num_workers)

### Initialize the models we want to compare.
For simplicity, we will use the same architecture for both models (though this is not necessary).

In [9]:
# CIFAR-10 has 10 classes; torchvision's resnet18() otherwise defaults to a
# 1000-way (ImageNet) classification head. Presumably the CIFAR-10 checkpoints
# loaded below use a 10-way head — TODO confirm against the checkpoint files.
NUM_CLASSES = 10
modelA = torchvision.models.resnet18(num_classes=NUM_CLASSES)
modelB = torchvision.models.resnet18(num_classes=NUM_CLASSES)

### Initialize the training set
The training set must be the same for both models (the loader keeps `shuffle=False`, so the example order is fixed).

In [6]:
# Fixed order (shuffle=False by default) and no random augmentation, so both
# models are attributed against exactly the same training inputs.
train_loader = get_dataloader(split='train', augment=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 65413585.20it/s]


Extracting /tmp/cifar/cifar-10-python.tar.gz to /tmp/cifar/


### Load checkpoints
Load multiple model checkpoints in order to compute TRAK attribution scores. For more details, check out TRAK's [repo](https://github.com/MadryLab/trak) (and TRAK's [quickstart](https://trak.readthedocs.io/en/latest/quickstart.html)).

For each model, the expected format is a list of `state_dict`s.

In [7]:
# TODO(@Harshay): add a `download_cifar_ckpts.sh` script.
# Placeholders: replace each `[...]` with that model's list of state_dicts.
ckptsA, ckptsB = [...], [...]

### Initialize the `ModelDiff` instance

In [11]:
md = ModelDiff(
    modelA, modelB,   # the two models being compared
    ckptsA, ckptsB,   # their checkpoint lists (state_dicts), in the same order
    train_loader=train_loader,
)

                             Report any issues at https://github.com/MadryLab/trak/issues


INFO:STORE:No existing model IDs in /mnt/xfs/home/krisgrg/projects/modeldiff/modeldiff_scores/modelA.
INFO:STORE:No existing TRAK scores in /mnt/xfs/home/krisgrg/projects/modeldiff/modeldiff_scores/modelA.
                             Report any issues at https://github.com/MadryLab/trak/issues
INFO:STORE:No existing model IDs in /mnt/xfs/home/krisgrg/projects/modeldiff/modeldiff_scores/modelB.
INFO:STORE:No existing TRAK scores in /mnt/xfs/home/krisgrg/projects/modeldiff/modeldiff_scores/modelB.


### Take any `diff` of choice!

Now you can compute `A-B` and `B-A` with just a single line of code!

In [None]:
# CIFAR-10 has no separate validation split: split='val' loads the test set.
# Augmentation is disabled so the target examples are deterministic.
val_loader1 = get_dataloader(split='val', augment=False)
diff1 = md.get_A_minus_B(val_loader=val_loader1, num_pca_comps=2)

Once we've initialized the `md` instance, we can use it to compute `diff`s with respect to many different target datasets (e.g., `val_loader1` above).

In [None]:
# For illustration this reuses the CIFAR-10 test set (split='val'); substitute
# the target dataset of interest. Augmentation is disabled for deterministic targets.
val_loader2 = get_dataloader(split='val', augment=False)
diff2 = md.get_B_minus_A(val_loader=val_loader2, num_pca_comps=4)

### Bring Your Own Scores

If you already have some attribution scores computed, you can still use the same API!

In [13]:
# Constructed without models, checkpoints or a train loader; in this notebook it is
# only used with precomputed attribution scores.
md_from_scores = ModelDiff()

In [16]:
# Run scripts/download_living17_checkpoints.sh first to populate `scores_dir`.
# Each file is loaded onto the CPU (so tensors saved from a GPU also load on
# CPU-only machines) and its 'weight' entry is used as that model's scores.
# NOTE: torch.load unpickles the file — only load files from a trusted source.
scores_dir = Path('./datamodels/')
scoresA = torch.load(scores_dir / 'living17_data-aug.pt', map_location='cpu')['weight']
scoresB = torch.load(scores_dir / 'living17_without-data-aug.pt', map_location='cpu')['weight']

In [17]:
# A-B computed directly from the precomputed score matrices.
diff = md_from_scores.get_A_minus_B_from_scores(
    scoresA,
    scoresB,
    num_pca_comps=2,
)

In [18]:
# Display the result: a dict with 'directions', 'projections' and per-model 'variances'.
diff

{'directions': array([[-0.00066547,  0.00221916,  0.00012629, ...,  0.00020184,
         -0.00036254, -0.00049859],
        [-0.00524965, -0.00065124, -0.00038203, ...,  0.00063359,
         -0.0018108 , -0.00165741]], dtype=float32),
 'projections': array([[ 5.1245297e-04, -2.4516261e-04],
        [-1.0042154e-04,  5.1465724e-04],
        [-3.7063823e-05,  8.0200400e-05],
        ...,
        [-1.7703998e-05,  2.1589889e-04],
        [-8.5801890e-05, -1.4266069e-04],
        [-5.9207014e-05,  5.5971013e-05]], dtype=float32),
 'variances': {'A': array([0.00033447, 0.00109441], dtype=float32),
  'B': array([0.00012626, 0.00037166], dtype=float32)}}

That's it!