In [8]:
# Set your working directory. This must happen before the project-local imports
# below (data_utils, utils, models), which presumably resolve relative to it.
# Set the SCALING_MLPS_SRC environment variable instead of editing the placeholder.
import os
os.chdir(os.environ.get('SCALING_MLPS_SRC', '/Location_of_src_folder'))

import h5py
import numpy as np  # used by the save cell; otherwise only reachable via the star import below
import torch
from tqdm import tqdm
from data_utils.data_stats import *  # presumably provides CLASS_DICT used in the config cell
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from torchvision import datasets, transforms

# Fix the torch RNG seed for reproducibility.
_ = torch.manual_seed(0)

Load MLP

In [5]:
# ---- Experiment configuration ----
# NOTE: the evaluation loader in this notebook uses torchvision's CIFAR-10 test
# split with CIFAR-10 normalization, so keep `dataset` at 'cifar10' unless that
# cell is adapted as well.
dataset = 'cifar10'                 # options: cifar10, cifar100, stl10, imagenet, imagenet21
num_classes = CLASS_DICT[dataset]

# Network + weights: ImageNet-21k pre-trained, fine-tuned on CIFAR-10.
architecture = 'B_12-Wi_1024'
checkpoint = 'in21k_cifar10'

# Resolutions: how the data is stored vs. what the fine-tuned model consumes.
data_resolution = 32                # stored resolution of the images
crop_resolution = 64                # model input resolution (64 for every provided model)

# Data location and evaluation batch size.
data_path = './beton/'
eval_batch_size = 100

In [6]:
# Allow TF32 precision for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model and load the pre-trained weights selected in the config cell
# (`checkpoint`, `num_classes`) instead of repeating those values here.
# NOTE(review): `mlp` is not moved to `device` in this notebook, so evaluation runs on CPU.
mlp = get_model(architecture=architecture, resolution=crop_resolution,
                num_classes=num_classes, checkpoint=checkpoint)

Weights already downloaded
Load_state output <All keys matched successfully>


Attach a forward hook to "harvest" the features fed into the output layer (`mlp.linear_out`)

In [7]:
# Filled by the forward hook: acts_mlp['act'] holds the input to `mlp.linear_out`
# from the most recent forward pass, as a NumPy array.
acts_mlp = {}

def _store_linear_out_input(module, input, output):
    """Forward hook that stashes the input of `linear_out` in ``acts_mlp['act']``.

    Args:
        module: the hooked module (`mlp.linear_out`); unused.
        input: tuple of positional inputs to the module; only ``input[0]`` is kept.
        output: the module's output; unused.
    """
    features = input[0]
    # Sanity check for the B_12-Wi_1024 architecture (feature width 1024).
    assert features.shape[1] == 1024, f"unexpected feature width {features.shape[1]}"
    # .numpy() only works on CPU tensors, so move first; clone so the stored array
    # never aliases a tensor that could be reused by the model.
    acts_mlp['act'] = features.detach().cpu().clone().numpy()

# The handle gets its own binding instead of shadowing the hook function;
# detach later with hook_mlp.remove().
hook_mlp = mlp.linear_out.register_forward_hook(_store_linear_out_input)

Get the CIFAR-10 test loader

In [13]:
# This cell is CIFAR-10 specific (torchvision CIFAR10 + CIFAR-10 channel statistics);
# fail loudly rather than evaluate a differently configured model on the wrong data.
if dataset != 'cifar10':
    raise ValueError(f"The test loader only supports 'cifar10', got dataset={dataset!r}")

# Upsample to the fine-tuned model's input resolution, then normalize with
# per-channel mean/std (presumably the standard CIFAR-10 statistics on the [0, 1] scale).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(crop_resolution),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
])

test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# shuffle=False keeps the harvested activations in dataset order.
loader = torch.utils.data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                     shuffle=False, num_workers=2)

Files already downloaded and verified


Evaluate the model and "harvest" the activations

In [14]:
# Per-batch activation arrays collected by `test` (one entry per batch).
all_acts_mlp = []

@torch.no_grad()
def test(model_mlp, loader):
    """Evaluate `model_mlp` on `loader` and collect the hooked activations.

    Each image batch is flattened to (batch, -1) before the forward pass. The
    forward hook on `linear_out` must be registered: after every batch the array
    it stored in ``acts_mlp['act']`` is moved into the global ``all_acts_mlp``.

    Args:
        model_mlp: network to evaluate (switched to eval mode).
        loader: iterable of ``(images, targets)`` batches.

    Returns:
        Tuple ``(top1, top5)`` from ``AverageMeter.get_avg(percentage=True)``;
        top-5 is recorded as 0 when ``dataset == 'imagenet_real'``.

    Raises:
        KeyError: if the hook stored no activation for a batch (e.g. it was
            already removed), instead of silently reusing the previous batch's.
    """
    model_mlp.eval()
    # Reset so re-running the evaluation does not duplicate activations.
    all_acts_mlp.clear()
    total_acc_mlp, total_top5_mlp = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        ims_flat = torch.reshape(ims, (ims.shape[0], -1))
        preds_mlp = model_mlp(ims_flat)
        # pop() consumes the entry, so a stale activation can never be appended twice.
        all_acts_mlp.append(acts_mlp.pop('act'))

        if dataset != 'imagenet_real':
            acc_mlp, top5_mlp = topk_acc(preds_mlp, targs, k=5, avg=True)
        else:
            acc_mlp = real_acc(preds_mlp, targs, k=5, avg=True)
            top5_mlp = 0

        batch_size = ims_flat.shape[0]
        total_acc_mlp.update(acc_mlp, batch_size)
        total_top5_mlp.update(top5_mlp, batch_size)

    return (
        total_acc_mlp.get_avg(percentage=True),
        total_top5_mlp.get_avg(percentage=True),
    )

In [17]:
test_acc_mlp, test_top5_mlp = test(mlp, loader)
# The activations are collected; detach the forward hook from the model.
hook_mlp.remove()

# Print all the stats. Both numbers come from the test split, so both labels say
# "Test" (the trailing space keeps the values aligned with the first line).
print("Test Accuracy MLP:      ", "{:.4f}".format(test_acc_mlp))
print("Top 5 Test Accuracy MLP: ", "{:.4f}".format(test_top5_mlp))

Evaluation:   0%|          | 0/100 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 100/100 [00:55<00:00,  1.79it/s]

Test Accuracy MLP:       94.0900
Top 5 Train Accuracy MLP: 99.6000





Save the collected activations to disk

In [18]:
# Free the model; only the collected activations are needed from here on.
del mlp
# Stack the per-batch arrays along the batch axis.
acts_mlp_np = np.concatenate(all_acts_mlp, axis=0)
# Catch duplicated or missing batches (e.g. from re-running the evaluation cell)
# before anything is written to disk.
assert acts_mlp_np.shape[0] == len(loader.dataset), (
    f"collected {acts_mlp_np.shape[0]} activations for {len(loader.dataset)} samples")

out_path = f'acts_{architecture}_{dataset}_test_postskip.h5'
with h5py.File(out_path, 'w') as hf:
    hf.create_dataset('activations', data=acts_mlp_np)