In [None]:
import os

import h5py
import numpy as np
import torch
from tqdm import tqdm

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset
# Seed PyTorch's default RNG for reproducibility. The return value is bound
# to `_` so it is not echoed as the cell's output.
_ = torch.manual_seed(0)

Load MLP

In [None]:
# --- Model selection ---
checkpoint = 'in21k_cifar10'        # Network pre-trained on ImageNet21k and fine-tuned on CIFAR10
architecture = 'B_12-Wi_1024'

# --- Dataset selection ---
dataset = 'cifar10'                 # Choose from: cifar10, cifar100, stl10, imagenet, imagenet21
num_classes = CLASS_DICT[dataset]   # Looked up from the dataset name
data_path = './beton/'

# --- Evaluation settings ---
eval_batch_size = 100
data_resolution = 32                # Resolution at which the data is stored
crop_resolution = 64                # Resolution the fine-tuned model expects (64 for all provided models)

In [None]:
# Allow TF32 matmuls on CUDA (only has an effect on hardware that supports it).
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model and specify the pre-trained weights.
# Use the values from the configuration cell (`num_classes`, `checkpoint`)
# rather than repeating literals, so editing the config actually takes effect.
mlp = get_model(architecture=architecture, resolution=crop_resolution, num_classes=num_classes,
                checkpoint=checkpoint)

# The loader is created with `dev=device`, so keep the model on the same device.
# NOTE(review): assumes `get_model` returns a torch.nn.Module and that the loader
# places batches on `device` -- confirm against data_utils.dataloader.
mlp = mlp.to(device)

Attach hook to "harvest" last-layer activations

In [None]:
# Holds the most recent batch of captured activations under the key 'act'.
acts_mlp = {}


def _capture_last_layer_input(module, input, output):
    """Forward hook that stores the hooked layer's input as a NumPy array.

    Args:
        module: The layer the hook is attached to (unused).
        input: Tuple of positional inputs passed to the layer; only the
            first element is captured.
        output: The layer's output (unused).

    Raises:
        AssertionError: If the captured tensor's second dimension is not 1024.
    """
    features = input[0]
    # 1024 is the width encoded in the default architecture name ('B_12-Wi_1024').
    assert features.shape[1] == 1024
    # Tensor.numpy() only works for CPU tensors, so move the copy to host
    # memory first; on a CPU tensor `.cpu()` is a no-op.
    acts_mlp['act'] = features.detach().clone().cpu().numpy()


# `hook_mlp` is the removable handle returned by PyTorch; call
# `hook_mlp.remove()` once the activations have been collected.
hook_mlp = mlp.linear_out.register_forward_hook(_capture_last_layer_input)

Get the test loader

In [None]:
# Evaluation loader: test split, with augmentation and mixup switched off.
loader = get_loader(
    dataset,
    mode="test",
    bs=eval_batch_size,
    data_path=data_path,
    data_resolution=data_resolution,
    crop_resolution=crop_resolution,
    dev=device,
    augment=False,
    mixup=0.0,
)

"Harvest" the activations

In [None]:
# One array of captured activations is appended here per evaluated batch.
all_acts_mlp = []


@torch.no_grad()
def test(model_mlp, loader):
    """Evaluate ``model_mlp`` over ``loader`` and collect hooked activations.

    Each batch is flattened to ``(batch, -1)`` before the forward pass. After
    every forward pass the current ``acts_mlp['act']`` entry is appended to
    the module-level list ``all_acts_mlp``.

    Args:
        model_mlp: The model to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, targets)`` batches.

    Returns:
        Tuple ``(top-1 accuracy, top-5 accuracy)`` as produced by
        ``AverageMeter.get_avg(percentage=True)``. For the
        ``'imagenet_real'`` dataset the top-5 meter is only fed zeros.
    """
    model_mlp.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        n_samples = ims.shape[0]
        flat_batch = ims.reshape(n_samples, -1)
        logits = model_mlp(flat_batch)

        # The forward hook has just refreshed acts_mlp['act'] for this batch.
        all_acts_mlp.append(acts_mlp['act'])

        if dataset == 'imagenet_real':
            batch_top1 = real_acc(logits, targs, k=5, avg=True)
            batch_top5 = 0
        else:
            batch_top1, batch_top5 = topk_acc(logits, targs, k=5, avg=True)

        top1_meter.update(batch_top1, n_samples)
        top5_meter.update(batch_top5, n_samples)

    top1_avg = top1_meter.get_avg(percentage=True)
    top5_avg = top5_meter.get_avg(percentage=True)
    return top1_avg, top5_avg

In [None]:
# Run the evaluation; the forward hook fills `all_acts_mlp` as a side effect.
test_acc_mlp, test_top5_mlp = test(mlp, loader)
# Detach the hook so later forward passes no longer capture activations.
hook_mlp.remove()

# Print all the stats.
# Fixes: the original printed undefined `train_*` names and used the invalid
# format spec "{d:.4f}" (a named field `d` with no matching keyword argument).
print("Test Accuracy MLP:      ", "{:.4f}".format(test_acc_mlp))
print("Top 5 Test Accuracy MLP:", "{:.4f}".format(test_top5_mlp))

Save the collected activations to disk

In [None]:
# Fail early with an actionable message (before the model is deleted) if
# nothing was harvested; np.concatenate would otherwise raise a generic
# "need at least one array to concatenate" error.
if not all_acts_mlp:
    raise ValueError(
        "No activations were collected; run the evaluation cell before saving."
    )

# Stack the per-batch arrays along the sample axis.
acts_mlp_np = np.concatenate(all_acts_mlp, axis=0)

# The model is no longer needed once the activations are in a single array.
del mlp

# Output file name encodes the architecture and dataset used above.
acts_file = f'acts_{architecture}_{dataset}_test_postskip.h5'
with h5py.File(acts_file, 'w') as hf:
    hf.create_dataset('activations', data=acts_mlp_np)