# Train and save model checkpoints

In [None]:
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze

import os
from pathlib import Path
import wget
from tqdm import tqdm
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss, Conv2d, BatchNorm2d
from torch.optim import SGD, lr_scheduler
import torchvision

In [None]:
# Download URLs of the CIFAR-2 FFCV `.beton` files, keyed by split name.
BETONS = {
    'train': "https://www.dropbox.com/s/zn7jsp2rl09e0fh/train.beton?dl=1",
    'val': "https://www.dropbox.com/s/4p73milxxafv4cm/val.beton?dl=1",
}

# Per-channel image statistics (three channels; the magnitudes suggest the
# 0-255 pixel scale). They are used for normalization and, rounded to ints,
# as the fill colour of the translate/cutout augmentations.
STATS = {
    'mean': [125.307, 122.961, 113.8575],
    'std': [51.5865, 50.847, 51.255],
}

def get_dataloader(batch_size=256,
                   num_workers=8,
                   split='train',  # split \in [train, val]
                   aug_seed=0,
                   order='sequential',
                   subsample=False,
                   should_augment=True,
                   indices=None):
    """Build an FFCV ``Loader`` over the CIFAR-2 ``.beton`` file of ``split``.

    The file is downloaded from ``BETONS[split]`` to ``./{split}.beton`` the
    first time it is needed and reused on later calls.

    Args:
        batch_size: Number of samples per batch.
        num_workers: Number of FFCV loader workers.
        split: Key into ``BETONS`` (``'train'`` or ``'val'``).
        aug_seed: Passed to the FFCV ``Loader`` as ``seed``.
        order: ``'sequential'`` selects ``OrderOption.SEQUENTIAL``; any other
            value selects ``OrderOption.RANDOM``.
        subsample: If True and ``split == 'train'``, ``indices`` is replaced by
            5,000 indices drawn without replacement from ``range(10_000)``
            using NumPy's global RNG (not controlled by ``aug_seed``).
        should_augment: If True, apply random horizontal flip, random
            translate and cutout before normalization.
        indices: Optional subset of sample indices to load.

    Returns:
        An ``ffcv.loader.Loader`` whose 'image' pipeline ends on ``cuda:0`` as
        float32, ``STATS``-normalized tensors and whose 'label' pipeline ends
        on ``cuda:0`` as squeezed integer tensors.

    Raises:
        KeyError: If ``split`` is not a key of ``BETONS``.
    """
    label_pipeline: list[Operation] = [
        IntDecoder(),
        ToTensor(),
        ToDevice(torch.device('cuda:0')),
        Squeeze(),
    ]
    image_pipeline: list[Operation] = [SimpleRGBImageDecoder()]

    if should_augment:
        # Pixels exposed by translation / cutout are filled with the mean colour.
        fill = tuple(map(int, STATS['mean']))
        image_pipeline.extend([
            RandomHorizontalFlip(),
            RandomTranslate(padding=2, fill=fill),
            Cutout(4, fill),
        ])

    image_pipeline.extend([
        ToTensor(),
        ToDevice(torch.device('cuda:0'), non_blocking=True),
        ToTorchImage(),
        Convert(torch.float32),
        torchvision.transforms.Normalize(STATS['mean'], STATS['std']),
    ])

    beton_url = BETONS[split]
    beton_path = f'./{split}.beton'
    # Only download when missing: wget.download never overwrites an existing
    # file, it saves to "name (1).ext", "name (2).ext", ... instead, so an
    # unconditional download re-fetches the dataset on every call and leaves
    # copies on disk that the Loader below never reads.
    if not os.path.exists(beton_path):
        wget.download(beton_url, out=beton_path, bar=None)

    if subsample and split == 'train':
        indices = np.random.choice(np.arange(10_000), replace=False, size=5_000)

    order_option = OrderOption.SEQUENTIAL if order == 'sequential' else OrderOption.RANDOM

    return Loader(beton_path,
                  batch_size=batch_size,
                  num_workers=num_workers,
                  order=order_option,
                  drop_last=False,
                  seed=aug_seed,
                  indices=indices,
                  pipelines={'image': image_pipeline, 'label': label_pipeline})


# Resnet9
class Mul(torch.nn.Module):
    """Multiply the input by a fixed factor ``weight``.

    In this file it is built with a Python float (``Mul(0.2)``), which
    ``nn.Module`` keeps as an ordinary attribute: it is not a trainable
    parameter and is not part of ``state_dict``.
    """

    def __init__(self, weight):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        scaled = x * self.weight
        return scaled


class Flatten(torch.nn.Module):
    """Flatten all dimensions after the first: ``(N, ...) -> (N, -1)``."""

    def forward(self, x):
        # ``reshape`` rather than ``view``: ``view`` raises a RuntimeError when
        # the input's strides cannot be merged (e.g. a transposed or otherwise
        # non-contiguous tensor). ``reshape`` returns a view whenever ``view``
        # would have succeeded and copies only when it has to.
        return x.reshape(x.size(0), -1)


class Residual(torch.nn.Module):
    """Skip connection: return ``x + module(x)``."""

    def __init__(self, module):
        super().__init__()
        # The attribute name appears in state_dict keys ("...module..."), so
        # saved checkpoints depend on it.
        self.module = module

    def forward(self, x):
        branch = self.module(x)
        return x + branch


def construct_rn9(num_classes=2):
    """Build the ResNet-9 style classifier used throughout this notebook.

    The order of the top-level layers determines the integer prefixes of the
    ``state_dict`` keys, so it must not change if saved checkpoints are to
    keep loading.

    Args:
        num_classes: Output size of the final (bias-free) linear layer.

    Returns:
        A ``torch.nn.Sequential`` whose final linear outputs are multiplied
        by 0.2.
    """
    def conv_bn(c_in, c_out, kernel_size=3, stride=1, padding=1, groups=1):
        # Bias-free conv (BatchNorm supplies the shift) -> BatchNorm -> ReLU.
        conv = torch.nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride,
                               padding=padding, groups=groups, bias=False)
        return torch.nn.Sequential(conv,
                                   torch.nn.BatchNorm2d(c_out),
                                   torch.nn.ReLU(inplace=True))

    def res_block(channels):
        # Two shape-preserving conv_bn layers wrapped in a skip connection.
        return Residual(torch.nn.Sequential(conv_bn(channels, channels),
                                            conv_bn(channels, channels)))

    layers = [
        conv_bn(3, 64, kernel_size=3, stride=1, padding=1),
        conv_bn(64, 128, kernel_size=5, stride=2, padding=2),
        res_block(128),
        conv_bn(128, 256, kernel_size=3, stride=1, padding=1),
        torch.nn.MaxPool2d(2),
        res_block(256),
        conv_bn(256, 128, kernel_size=3, stride=1, padding=0),
        torch.nn.AdaptiveMaxPool2d((1, 1)),
        Flatten(),
        torch.nn.Linear(128, num_classes, bias=False),
        Mul(0.2),
    ]
    return torch.nn.Sequential(*layers)

def train(model, loader, lr=0.4, epochs=100, momentum=0.9, weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0):
    """Train ``model`` in place with SGD, a triangular LR schedule and AMP.

    The learning rate rises linearly from 0 to ``lr`` over the first
    ``lr_peak_epoch`` epochs, then decays linearly back to 0 at the end of
    epoch ``epochs``; it is updated after every batch.

    Args:
        model: Module to train; it is switched to training mode first.
        loader: Sized iterable yielding ``(inputs, labels)`` batches.
        lr: Peak learning rate.
        epochs: Number of passes over ``loader``.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        lr_peak_epoch: Epoch at which the learning rate peaks.
        label_smoothing: Passed to ``CrossEntropyLoss``.
    """
    # Make sure BatchNorm/dropout run in training mode even if the model was
    # previously put in eval mode (e.g. by ``evaluate``).
    model.train()

    opt = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    iters_per_epoch = len(loader)
    # Cyclic LR with single triangle: per-step multiplier for LambdaLR.
    lr_schedule = np.interp(np.arange((epochs + 1) * iters_per_epoch),
                            [0, lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                            [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(opt, lr_schedule.__getitem__)
    scaler = GradScaler()
    loss_fn = CrossEntropyLoss(label_smoothing=label_smoothing)

    for _ in range(epochs):
        for ims, labs in loader:
            opt.zero_grad(set_to_none=True)
            with autocast():
                out = model(ims)
                loss = loss_fn(out, labs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            scheduler.step()

def evaluate(model, loader_val):
    """Print and return the top-1 accuracy of ``model`` on ``loader_val``.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``;
            it is left in eval mode afterwards.
        loader_val: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Accuracy in percent, as a float.

    Raises:
        ValueError: If ``loader_val`` yields no samples.
    """
    model.eval()
    total_correct, total_num = 0., 0.
    with torch.no_grad():
        # Iterate the loader passed in (previously this read a global `loader`).
        for ims, labs in tqdm(loader_val):
            with autocast():
                out = model(ims)
            total_correct += out.argmax(1).eq(labs).sum().item()
            total_num += ims.shape[0]
    if total_num == 0:
        raise ValueError('evaluate: loader_val yielded no samples')
    accuracy = total_correct / total_num * 100
    print(f'Test accuracy: {accuracy:.1f}%')
    return accuracy

In [None]:
os.makedirs('./checkpoints_cifar2', exist_ok=True)

# Train five models. With subsample=True each one sees 5,000 indices drawn at
# random from the first 10,000 training examples; weights go to sd_<run>.pt.
for run in tqdm(range(5), desc='Training models..'):
    model = construct_rn9().to(memory_format=torch.channels_last).cuda()
    loader_train = get_dataloader(batch_size=512, split='train', order='random', subsample=True)
    train(model, loader_train)

    ckpt_path = f'./checkpoints_cifar2/sd_{run}.pt'
    torch.save(model.state_dict(), ckpt_path)

In [None]:
# Load the checkpoints written by the training cell (./checkpoints_cifar2/sd_<i>.pt).
# Sorting (lexicographically) makes the checkpoint -> model_id mapping used
# below deterministic; rglob makes no ordering guarantee.
ckpt_files = sorted(Path('./checkpoints_cifar2').rglob('sd_*.pt'))
if not ckpt_files:
    raise FileNotFoundError('No checkpoints found in ./checkpoints_cifar2 - run the training cell first.')
ckpts = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files]

# NOTE(review): machine-specific path to an earlier set of checkpoints; rglob
# yields nothing when the directory is absent, and `ckpts_old` is not used by
# any later cell.
ckpt_files_old = list(Path('/mnt/cfs/projects/better_tracin/checkpoints/resnet9_cifar2/50pct/debug').rglob('*.pt'))
ckpts_old = [torch.load(ckpt, map_location='cpu') for ckpt in ckpt_files_old]

In [None]:
# Validation loader without augmentation: random flips, translations and
# cutout would otherwise perturb the images being scored.
loader = get_dataloader(split='val', should_augment=False)
model = construct_rn9().to(memory_format=torch.channels_last).cuda()
model.load_state_dict(ckpts[1])

evaluate(model, loader)

# Set up the TRAKer class

In [None]:
# Freshly initialized model in eval mode; the TRAKer cells below load each
# checkpoint into it.
model = construct_rn9().to(memory_format=torch.channels_last).cuda().eval()

In [None]:
from trak.projectors import BasicProjector
from trak import TRAKer

In [None]:
# TRAK attributor over the 10,000-example training split; proj_dim sets the
# projection dimension and intermediate results are written to save_dir.
traker = TRAKer(model=model,
                task='image_classification',
                train_set_size=10_000,
                proj_dim=2048,
                save_dir='./trak_results_cifar_2_debug_2k')

# Compute TRAK features for train data

In [None]:
batch_size = 128
# Sequential-order loader over the full training split for featurization.
# NOTE(review): get_dataloader augments by default (should_augment=True);
# confirm augmented features are intended here.
loader_train = get_dataloader(split='train', batch_size=batch_size)

In [None]:
# Featurize every training batch once per checkpoint, then finalize.
for model_id, state_dict in enumerate(tqdm(ckpts)):
    traker.load_checkpoint(state_dict, model_id=model_id)
    for train_batch in loader_train:
        traker.featurize(batch=train_batch, num_samples=len(train_batch[0]))

traker.finalize_features()

# Compute TRAK scores for targets

In [None]:
loader_targets = get_dataloader(split='val', batch_size=batch_size, should_augment=False)  # un-augmented val split

In [None]:
# Score all validation targets against every checkpoint, then aggregate.
n_targets = len(loader_targets.indices)
for model_id, state_dict in enumerate(tqdm(ckpts)):
    traker.start_scoring_checkpoint(state_dict,
                                    model_id=model_id,
                                    num_targets=n_targets)
    for target_batch in loader_targets:
        traker.score(batch=target_batch, num_samples=len(target_batch[0]))

scores = traker.finalize_scores()

# Visualize the attributions

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Two validation images to inspect, loaded without augmentation.
targets = [85, 100]
loader_targets = get_dataloader(split='val', batch_size=2, indices=targets, should_augment=False)

In [None]:
for ims, _ in loader_targets:
    # Min-max rescale the whole batch to [0, 1] so imshow can display it.
    lo, hi = ims.min(), ims.max()
    ims = (ims - lo) / (hi - lo)
    for image in ims:
        plt.figure(figsize=(1.5, 1.5))
        plt.imshow(image.cpu().permute([1, 2, 0]).numpy())
        plt.axis('off')
        plt.show()

Next, the highest-scoring training examples for each target according to TRAK:

In [None]:
# Optional comparison with a previously computed score matrix. The path is
# machine-specific, so skip this cell's output instead of crashing the
# notebook when the file is not available.
prev_scores_path = Path('/mnt/cfs/projects/better_tracin/estimators/CIFAR2/ablation_jl_dim/100models_1epochs_jl1000/estimates.npy')
if prev_scores_path.exists():
    scores_prev = np.load(prev_scores_path)
    S = scores_prev

    for target in targets:
        print(f'Top scorers for target {target}')
        # argsort is ascending, so the last three indices are the top three.
        loader_top_scorer = get_dataloader(batch_size=3, split='train', indices=S[:, target].argsort()[-3:],
                                           should_augment=False)
        for batch in loader_top_scorer:
            ims, _ = batch
            ims = (ims - ims.min()) / (ims.max() - ims.min())
            for image in ims:
                plt.figure(figsize=(1.5, 1.5))
                plt.imshow(image.cpu().permute([1, 2, 0]).numpy()); plt.axis('off'); plt.show()
else:
    print(f'Skipping comparison with previous scores: {prev_scores_path} not found')

In [None]:
# Top-3 training examples for each target according to the TRAK scores
# computed above (S has one column per validation target).
S = scores.cpu()

for target in targets:
    print(f'Top scorers for target {target}')
    top3 = S[:, target].argsort()[-3:]  # ascending sort -> last three are highest
    loader_top_scorer = get_dataloader(batch_size=3, split='train', indices=top3,
                                       should_augment=False)
    for ims, _ in loader_top_scorer:
        ims = (ims - ims.min()) / (ims.max() - ims.min())
        for image in ims:
            plt.figure(figsize=(1.5, 1.5))
            plt.imshow(image.cpu().permute([1, 2, 0]).numpy())
            plt.axis('off')
            plt.show()

# Extra: evaluate counterfactuals

We follow exactly the training recipe in https://docs.ffcv.io/ffcv_examples/cifar10.html, except that we replace the CIFAR-10 dataloader with the CIFAR-2 dataloader above. In addition, each model is trained on a *subset* of CIFAR-2, specified by a row of the `masks` array below. For every retraining on a different subset (mask), we collect the model outputs on the validation set in a separate array, `margins`.

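We train a total of 10,000 models. This is not needed to compute TRAK scores; it is only needed to obtain (very high-quality) estimates of the linear datamodeling score (LDS), i.e. the correlation between the outputs TRAK predicts for each subset and the outputs actually observed after retraining.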

In [None]:
EVAL_DIR = Path('/mnt/cfs/home/spark/store/kernel/cifar2/50pct_new_augs_10x_per_mask')
indices = np.where(np.load(EVAL_DIR / '_completed.npy'))[0]

# Keep only groups of 10 consecutive runs (presumably 10 retrainings that share
# one mask) in which every run completed. Membership is tested against a set:
# `j in indices` on a NumPy array scans the whole array on every lookup.
completed = set(indices.tolist())
comp_indices = []

for i in tqdm(range(0, 99480, 10)):
    group = range(i, i + 10)
    if all(j in completed for j in group):
        comp_indices.extend(group)

# One mask per group (its first run) and the 10 runs' validation margins
# averaged, giving a (num_groups, 2000) array.
masks = np.load(EVAL_DIR / 'mask.npy')[comp_indices[::10]]
margins = np.load(EVAL_DIR / 'val_margins.npy')[comp_indices]
margins = margins.reshape(len(margins) // 10, 10, 2000).mean(1)

In [None]:
from scipy.stats import spearmanr

In [None]:
# SS = scores_prev
SS = scores.cpu().numpy()

tmp_path = '.'
# Alternative: download publicly hosted masks/margins instead of EVAL_DIR.
# masks_url = 'https://www.dropbox.com/s/2nmcjaftdavyg0m/mask.npy?dl=1'
# margins_url = 'https://www.dropbox.com/s/tc3r3c3kgna2h27/val_margins.npy?dl=1'

# masks_path = Path(tmp_path).joinpath('mask.npy')
# wget.download(masks_url, out=str(masks_path), bar=None)
# # num masks, num train samples
# masks = torch.as_tensor(np.load(masks_path, mmap_mode='r')).float()

# margins_path = Path(tmp_path).joinpath('val_margins.npy')
# wget.download(margins_url, out=str(margins_path), bar=None)
# # num , num val samples
# margins = torch.as_tensor(np.load(margins_path, mmap_mode='r'))

# For each validation example, Spearman-correlate the linear prediction
# `masks @ SS` across masks with the margins actually observed after
# retraining, then report the mean correlation and mean p-value.
val_inds = np.arange(2000)
preds = masks @ SS
results = [spearmanr(preds[:, col], margins[:, val_idx])
           for col, val_idx in tqdm(enumerate(val_inds))]
rs = np.array([r for r, _ in results])
ps = np.array([p for _, p in results])
print(f'Correlation: {rs.mean()} (avg p value {ps.mean()})')