In [None]:
import os

import h5py
import numpy as np
import torch
from tqdm import tqdm

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from data_utils.dataset_to_beton import get_dataset
from models.networks import get_model
from PyTorch_CIFAR10.cifar10_models.vgg import vgg13_bn
from utils.metrics import topk_acc, real_acc, AverageMeter

Dataset specifications

In [None]:
# Which dataset to read and where its .beton files live.
# NOTE(review): the VGG loaded below uses a CIFAR-10 checkpoint, so its
# accuracy numbers are only meaningful when dataset == 'cifar10'.
dataset = 'cifar10'        # one of: cifar10, cifar100, stl10, imagenet, imagenet21
data_path = './beton/'

# Image resolutions
data_resolution = 32       # resolution the images are stored at
crop_resolution = 64       # resolution of the fine-tuned model (64 for all provided models)

# Evaluation
eval_batch_size = 100

Load VGG

In [None]:
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build VGG13-BN and load the CIFAR-10 checkpoint from the PyTorch_CIFAR10 repo.
# NOTE(review): `pretrained=True` may already load these weights inside
# `vgg13_bn`; the explicit load keeps the checkpoint path visible here.
vgg = vgg13_bn(pretrained=True)
state_dict_path = os.path.join(
    "PyTorch_CIFAR10", "cifar10_models", "state_dicts", "vgg13_bn.pt"
)
# map_location="cpu" lets the checkpoint load on CPU-only machines regardless of
# the device it was saved from; the model is moved to `device` right after.
vgg.load_state_dict(torch.load(state_dict_path, map_location="cpu"))

# `get_loader` is given `dev=device` (presumably it places batches there), so the
# model must live on the same device; eval() disables dropout / BN updates.
vgg = vgg.to(device).eval()

Attach a forward hook that "harvests" the 4096-dimensional hidden-layer activations of `vgg.classifier[4]`

In [None]:
# Filled by the forward hook below with the activations of the most recent batch.
acts_vgg = {}


def save_vgg_act(module, input, output):
    """Forward hook: stash the hooked layer's output as a NumPy array.

    The hook is registered on ``vgg.classifier[4]``. The assertion checks that
    this layer emits 4096 features per sample. The result is stored in
    ``acts_vgg['act']``, overwriting the previous batch.
    """
    assert output.shape[1] == 4096
    # .cpu() is required before .numpy() when the model runs on CUDA.
    acts_vgg['act'] = output.clone().detach().cpu().numpy()


# Keep the handle (not the function name) so the hook can be removed after
# evaluation and this cell can be re-run safely. The former `mlp` hook line
# referenced an undefined model/function and has been dropped.
hook_vgg = vgg.classifier[4].register_forward_hook(save_vgg_act)

Get the data loader

In [None]:
# Test-split loader without augmentation or mixup.
loader = get_loader(
    dataset,
    mode="test",
    bs=eval_batch_size,
    dev=device,
    data_path=data_path,
    data_resolution=data_resolution,   # resolution the data is stored at
    crop_resolution=crop_resolution,   # presumably the resolution of yielded batches
    augment=False,
    mixup=0.0,
)

Run the evaluation and "harvest" the activations

In [None]:
# Define a test function that evaluates test accuracy
all_acts_vgg = []

@torch.no_grad()
def test(model_vgg, loader):
    model_vgg.eval()
    total_acc_vgg, total_top5_vgg = AverageMeter(), AverageMeter()
    downsample = torch.nn.MaxPool2d(2)

    for ims, targs in tqdm(loader, desc="Evaluation"):
        ims_small = downsample(ims).detach()
        preds_vgg = model_vgg(ims_small)

        images.append(ims_small.clone().detach())
        targets.append(targs.clone().detach())
        all_acts_vgg.append(acts_vgg['act'])

        acc_vgg, top5_vgg = topk_acc(preds_vgg, targs, k=5, avg=True)
        top5_vgg = 0

        total_acc_vgg.update(acc_vgg, ims.shape[0])
        total_top5_vgg.update(top5_vgg, ims.shape[0])

    return (
        total_acc_vgg.get_avg(percentage=True),
        total_top5_vgg.get_avg(percentage=True)
    )

In [None]:
# Run the evaluation, then always detach the forward hook, even if the run is
# interrupted, so the model is left without the activation hook attached.
try:
    test_acc_vgg, test_top5_vgg = test(vgg, loader)
finally:
    hook_vgg.remove()

# Print all the stats
print("Test Accuracy VGG:      ", "{:.4f}".format(test_acc_vgg))
print("Top 5 Test Accuracy VGG:      ", "{:.4f}".format(test_top5_vgg))

Save the collected activations to disk

In [None]:
# Stack the per-batch activations into one (num_samples, 4096) array.
# An explicit error here is clearer than np.concatenate's "need at least one
# array to concatenate" when evaluation has not been run.
if not all_acts_vgg:
    raise RuntimeError("No activations collected; run the evaluation cell first.")
acts_vgg_np = np.concatenate(all_acts_vgg, axis=0)

acts_path = 'acts_VGG13_bn_' + dataset + '_test.h5'
with h5py.File(acts_path, 'w') as hf:
    hf.create_dataset('activations', data=acts_vgg_np)