# Train and save model checkpoints

In [1]:
import os
from pathlib import Path
import wget
from tqdm import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision

In [2]:
# Resnet9
class Mul(torch.nn.Module):
    """Multiply the input by a fixed scalar factor.

    ``weight`` is stored as a plain attribute rather than an ``nn.Parameter``,
    so it is not trained and does not appear in ``state_dict()``.
    """

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        return torch.mul(x, self.weight)


class Flatten(torch.nn.Module):
    """Collapse all dimensions after the leading (batch) one into a single axis."""

    def forward(self, x):
        batch_size = x.size(0)
        # ``view`` (not ``reshape``): no copy, requires a view-compatible layout.
        return x.view(batch_size, -1)


class Residual(torch.nn.Module):
    """Identity skip connection around ``module``: returns ``x + module(x)``."""

    def __init__(self, module):
        super().__init__()
        # Registered as a submodule, so its parameters get the "module." prefix.
        self.module = module

    def forward(self, x):
        shortcut = x
        return shortcut + self.module(x)


def construct_rn9(num_classes=10):
    """Build the ResNet-9 network used throughout this notebook.

    Args:
        num_classes: width of the final (bias-free) linear layer.

    Returns:
        A ``torch.nn.Sequential`` whose children, in index order, are: two
        conv-BN-ReLU stems, a residual block at 128 channels, a conv-BN-ReLU to
        256 channels, a 2x2 max-pool, a residual block at 256 channels, a
        conv-BN-ReLU back to 128 channels, global max-pooling, ``Flatten``, a
        linear classifier and a fixed ``Mul(0.2)`` output scaling.
    """
    def conv_bn(channels_in, channels_out, kernel_size=3, stride=1, padding=1, groups=1):
        # Conv (no bias; BatchNorm supplies the shift) -> BatchNorm -> ReLU.
        conv = torch.nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                               stride=stride, padding=padding, groups=groups, bias=False)
        return torch.nn.Sequential(conv,
                                   torch.nn.BatchNorm2d(channels_out),
                                   torch.nn.ReLU(inplace=True))

    def res_block(channels):
        # Two same-width conv-BN-ReLU layers wrapped in a skip connection.
        return Residual(torch.nn.Sequential(conv_bn(channels, channels),
                                            conv_bn(channels, channels)))

    # Layers are built in the same order (and at the same Sequential indices)
    # as before, so state_dict keys and parameter initialisation are unchanged.
    layers = [
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        res_block(128),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        torch.nn.MaxPool2d(2),
        res_block(256),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        torch.nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),
        torch.nn.Linear(128, num_classes, bias=False),
        Mul(0.2),
    ]
    return torch.nn.Sequential(*layers)

In [3]:
def get_dataloader(batch_size=256, num_workers=8, split='train', shuffle=False, augment=True):
    """Build a CIFAR-10 ``DataLoader`` (dataset cached under ``/tmp/cifar/``).

    Args:
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        split: ``'train'`` selects the training set; any other value selects
            the test set.
        shuffle: whether to reshuffle every epoch. Defaults to ``False``
            (previous behaviour). Keep it ``False`` whenever sample order must
            match dataset indices, e.g. for the TRAK featurization loops in this
            notebook, which pass no sample indices.
        augment: prepend random horizontal flips and ``RandomAffine(0)``.
            Defaults to ``True`` (previous behaviour). Pass ``False`` for
            deterministic inputs, e.g. when featurizing or scoring targets.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding normalized image/label batches.
    """
    normalize = [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201)),
    ]
    if augment:
        # NOTE(review): RandomAffine(0) has no rotation/translate/scale/shear and
        # appears to be a no-op; kept to preserve the original pipeline.
        augmentations = [torchvision.transforms.RandomHorizontalFlip(),
                         torchvision.transforms.RandomAffine(0)]
    else:
        augmentations = []
    transforms = torchvision.transforms.Compose(augmentations + normalize)

    is_train = (split == 'train')
    dataset = torchvision.datasets.CIFAR10(root='/tmp/cifar/', download=True, train=is_train, transform=transforms)
    return torch.utils.data.DataLoader(dataset=dataset, shuffle=shuffle,
                                       batch_size=batch_size, num_workers=num_workers)

In [4]:
def train(model, loader, lr=0.4, epochs=24, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
    """Train ``model`` in place with SGD, a triangular LR schedule and mixed precision.

    Args:
        model: network to train. Batches are moved with ``.cuda()``, so the
            model is expected to already be on the GPU.
        loader: iterable of ``(images, labels)`` batches that supports ``len()``.
        lr: base learning rate; the schedule multiplier peaks at 1, so this is
            the peak LR.
        epochs: number of passes over ``loader``.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        lr_peak_epoch: epoch at which the LR multiplier reaches 1.
        label_smoothing: forwarded to ``CrossEntropyLoss``.

    Returns:
        ``model`` itself (trained in place); returning it is only a convenience.
    """
    # Make sure BatchNorm layers use batch statistics even if the caller left
    # the model in eval mode.
    model.train()
    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with a single triangle: the multiplier rises linearly from 0 to 1
    # at ``lr_peak_epoch`` and decays back to 0 at ``epochs``, one entry per
    # optimizer step. The extra epoch of entries keeps the index evaluated by the
    # final ``scheduler.step()`` (``epochs * iters_per_epoch``) in range.
    lr_schedule = np.interp(np.arange((epochs + 1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    # GradScaler scales the loss for the mixed-precision (autocast) backward pass.
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

    for _ in range(epochs):
        for ims, labs in loader:
            ims = ims.float().cuda()
            labs = labs.cuda()
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            # Per-iteration schedule: the LambdaLR index counts optimizer steps.
            scheduler.step()

    return model

In [5]:
os.makedirs('./checkpoints', exist_ok=True)

# Every run trains on the same data, so build the loader once instead of
# re-instantiating the CIFAR-10 dataset on each iteration.
# NOTE(review): get_dataloader is called with shuffle=False behaviour here, so
# every epoch visits the training set in the same fixed order; consider
# shuffling for training.
loader_train = get_dataloader(batch_size=512, split='train')

for i in tqdm(range(3), desc='Training models..'):
    # Seed per model: the runs differ from one another but are repeatable
    # (up to cuDNN nondeterminism).
    torch.manual_seed(i)
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
    train(model, loader_train)

    torch.save(model.state_dict(), f'./checkpoints/sd_{i}.pt')

Training models..:   0%|          | 0/3 [00:00<?, ?it/s]

Files already downloaded and verified


Training models..:  33%|███▎      | 1/3 [01:27<02:55, 87.84s/it]

Files already downloaded and verified


Training models..:  67%|██████▋   | 2/3 [02:53<01:26, 86.50s/it]

Files already downloaded and verified


Training models..: 100%|██████████| 3/3 [04:19<00:00, 86.65s/it]


In [5]:
# rglob order depends on the filesystem; sort so that model_id i in the TRAK
# loops below always refers to the same checkpoint file (lexicographic order).
ckpt_files = sorted(Path('./checkpoints').rglob('*.pt'))
# torch.load unpickles: only load checkpoints you trust (here, written above).
ckpts = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

# Set up the TRAKer class

In [8]:
# Network instance handed to TRAKer below, on the GPU in channels-last layout
# and switched to eval mode (``eval()`` returns the module itself).
model = construct_rn9().to(memory_format=torch.channels_last).cuda().eval()

In [19]:
from trak import TRAKer

# NOTE(review): relies on ``loader_train`` created in the training cell; if that
# cell is skipped this raises NameError.
n_train = len(loader_train.dataset)
traker = TRAKer(
    model=model,
    task='image_classification',
    train_set_size=n_train,
    proj_dim=1024,
    save_dir='/tmp/trak_store',
)

Existing IDs in /tmp/trak_store: []


# Compute TRAK features for train data

In [20]:
batch_size = 128
loader_train = get_dataloader(batch_size=batch_size, split='train')
# With these arguments get_dataloader prepends RandomHorizontalFlip/RandomAffine.
# Strip those random transforms so the training images are featurized
# un-augmented and the TRAK features are reproducible. Sample order is
# unchanged (the loader does not shuffle), which featurization relies on.
RANDOM_AUGMENTATIONS = (torchvision.transforms.RandomHorizontalFlip,
                        torchvision.transforms.RandomAffine)
loader_train.dataset.transform = torchvision.transforms.Compose(
    [t for t in loader_train.dataset.transform.transforms
     if not isinstance(t, RANDOM_AUGMENTATIONS)])

Files already downloaded and verified


In [21]:
from tqdm import tqdm

# One pass over the (unshuffled) training set per checkpoint; the batch order
# implicitly defines the training-sample indices since none are passed.
for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.load_checkpoint(ckpt, model_id=model_id)
    for ims, labs in loader_train:
        traker.featurize(batch=[ims.cuda(), labs.cuda()], num_samples=ims.shape[0])

traker.finalize_features()

100%|██████████| 3/3 [01:11<00:00, 23.81s/it]
Finalizing features for all model IDs..: 100%|██████████| 3/3 [00:01<00:00,  2.58it/s]


# Compute TRAK scores for targets

In [24]:
loader_targets = get_dataloader(batch_size=batch_size, split='val')
# With these arguments get_dataloader also applies RandomHorizontalFlip and
# RandomAffine to the validation split. Drop those random transforms so every
# target image is scored un-augmented and the scores are reproducible.
RANDOM_AUGMENTATIONS = (torchvision.transforms.RandomHorizontalFlip,
                        torchvision.transforms.RandomAffine)
loader_targets.dataset.transform = torchvision.transforms.Compose(
    [t for t in loader_targets.dataset.transform.transforms
     if not isinstance(t, RANDOM_AUGMENTATIONS)])

Files already downloaded and verified


In [26]:
# Score every target against every checkpoint, then aggregate across models.
num_targets = len(loader_targets.dataset)
for model_id, ckpt in enumerate(tqdm(ckpts)):
    traker.start_scoring_checkpoint(ckpt, model_id=model_id, num_targets=num_targets)
    for ims, labs in loader_targets:
        traker.score(batch=[ims.cuda(), labs.cuda()], num_samples=ims.shape[0])

scores = traker.finalize_scores()

100%|██████████| 3/3 [00:15<00:00,  5.21s/it]
Finalizing scores for all model IDs..: 100%|██████████| 3/3 [00:00<00:00, 10.10it/s]


Saving scores in /tmp/trak_store/scores/scores_150a5fe1-7655-46c3-9bd0-ebee4feeea8d.npy


# Bonus: evaluate counterfactuals

We follow exactly the training recipe at https://docs.ffcv.io/ffcv_examples/cifar10.html, except that each model is trained on a *subset* of CIFAR-10, specified by one row of the `masks` array below. The validation-set outputs (margins) of each model retrained on a different subset (mask) are collected in a separate array, `margins`.

In total, we train 10,000 models. This is *not* needed to compute TRAK scores; it is only needed to obtain (very high quality) estimates of the linear datamodeling score (LDS), i.e. the rank correlation between predicted and actual model outputs.

In [53]:
from scipy.stats import spearmanr

In [54]:
def eval_correlations(scores, tmp_path):
    """Estimate the linear datamodeling score (LDS) of attribution ``scores``.

    Fetches (once) a matrix of training-subset masks and the validation margins
    of models retrained on those subsets. It predicts each margin as
    ``masks @ scores`` and reports the Spearman rank correlation between
    predicted and actual margins, averaged over targets.

    Args:
        scores: ``(num_train, num_targets)`` scores (torch tensor or array-like).
            They are cast to the mask dtype before the matrix product.
        tmp_path: directory in which the downloaded ``.npy`` files are cached.

    Returns:
        The mean Spearman correlation over all targets.
    """
    masks_url = '/url/to/masks'  # dropbox link coming soon
    margins_url = '/url/to/margins'

    def _fetch(url, filename):
        # Download only when the file is missing. This avoids re-fetching large
        # arrays on every run. Also, python-wget does not overwrite an existing
        # target: it saves a suffixed copy (e.g. "mask (1).npy"), so the stale
        # original would be loaded anyway.
        path = Path(tmp_path).joinpath(filename)
        if not path.exists():
            wget.download(url, out=str(path), bar=None)
        return path

    # num masks, num train samples
    masks = torch.as_tensor(np.load(_fetch(masks_url, 'mask.npy'), mmap_mode='r')).float()
    # num masks, num val samples
    margins = torch.as_tensor(np.load(_fetch(margins_url, 'val_margins.npy'), mmap_mode='r'))

    # Align dtypes: a float32 @ float16 matmul raises.
    preds = masks @ torch.as_tensor(scores).to(masks.dtype)
    # One correlation per target column (10,000 for the CIFAR-10 validation set).
    num_targets = preds.shape[1]
    rs = np.empty(num_targets)
    ps = np.empty(num_targets)
    for j in tqdm(range(num_targets)):
        rs[j], ps[j] = spearmanr(preds[:, j], margins[:, j])
    print(f'Correlation: {rs.mean()} (avg p value {ps.mean()})')
    return rs.mean()

In [None]:
# Downloaded mask/margin arrays are cached in the current directory.
eval_correlations(scores.cpu(), tmp_path='.')