In [1]:
import os
import h5py
import torch
from tqdm import tqdm

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset
from PyTorch_CIFAR10.cifar10_models.vgg import vgg13_bn

In [2]:
dataset = 'cifar10'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'
data_resolution = 32                # Resolution of data as it is stored
crop_resolution = 64                # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]
data_path = './beton/'
eval_batch_size = 100
checkpoint = 'in21k_cifar10'        # This means you want the network pre-trained on ImageNet21k and finetuned on CIFAR10

In [3]:
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model and specify the pre-trained weights
mlp = get_model(architecture=architecture, resolution=crop_resolution, num_classes=CLASS_DICT[dataset],
                  checkpoint='in21k_cifar10')
#model.cuda()
vgg = vgg13_bn(pretrained=True)
state_dict = os.path.join(
                "PyTorch_CIFAR10", "cifar10_models", "state_dicts", "vgg13_bn" + ".pt"
            )
vgg.load_state_dict(torch.load(state_dict))

Weights already downloaded
Load_state output <All keys matched successfully>


<All keys matched successfully>

In [4]:
acts_mlp = {}
acts_vgg = {}

def hook_mlp(module, input, output):
    assert output.shape[1] == 1024
    acts_mlp['act'] = (output.clone().detach().numpy())

def hook_vgg(module, input, output):
    assert output.shape[1] == 4096
    acts_vgg['act'] = (output.clone().detach().numpy())

hook_mlp = mlp.blocks[-1].register_forward_hook(hook_mlp)
hook_vgg = vgg._modules['classifier'][4].register_forward_hook(hook_vgg)

In [5]:
# Get the train loader
loader = get_loader(
    dataset,
    bs=eval_batch_size,
    mode="test",
    augment=False,
    dev=device,
    mixup=0.0,
    data_path=data_path,
    data_resolution=data_resolution,
    crop_resolution=crop_resolution,
)

Loading ./beton/cifar10/train/train_32.beton


In [6]:
# Define a test function that evaluates test accuracy
images = []
targets = []
all_acts_mlp = []
all_acts_vgg = []

@torch.no_grad()
def test(model_mlp, model_vgg, loader):
    model_mlp.eval()
    model_vgg.eval()
    total_acc_mlp, total_top5_mlp = AverageMeter(), AverageMeter()
    total_acc_vgg, total_top5_vgg = AverageMeter(), AverageMeter()
    downsample = torch.nn.MaxPool2d(2)

    for ims, targs in tqdm(loader, desc="Evaluation"):
        ims_flat = torch.reshape(ims, (ims.shape[0], -1))
        ims_small = downsample(ims).detach()
        preds_mlp = model_mlp(ims_flat)
        preds_vgg = model_vgg(ims_small)

        images.append(ims_small.clone().detach())
        targets.append(targs.clone().detach())
        all_acts_mlp.append(acts_mlp['act'])
        all_acts_vgg.append(acts_vgg['act'])

        # if dataset != 'imagenet_real':
        acc_mlp, top5_mlp = topk_acc(preds_mlp, targs, k=5, avg=True)
        acc_vgg, top5_vgg = topk_acc(preds_vgg, targs, k=5, avg=True)
        # else:
        acc_mlp = real_acc(preds_mlp, targs, k=5, avg=True)
        acc_vgg = real_acc(preds_vgg, targs, k=5, avg=True)
        top5_mlp = 0
        top5_vgg = 0

        total_acc_mlp.update(acc_mlp, ims_flat.shape[0])
        total_top5_mlp.update(top5_mlp, ims_flat.shape[0])
        total_acc_vgg.update(acc_vgg, ims.shape[0])
        total_top5_vgg.update(top5_vgg, ims.shape[0])

    return (
        total_acc_mlp.get_avg(percentage=True),
        total_top5_mlp.get_avg(percentage=True),
        total_acc_vgg.get_avg(percentage=True),
        total_top5_vgg.get_avg(percentage=True)
    )

In [7]:
test_acc_mlp, test_top5_mlp, test_acc_vgg, test_top5_vgg = test(mlp, vgg, loader)
hook_mlp.remove()
hook_vgg.remove()

# Print all the stats
print("Test Accuracy MLP:      ", "{:.4f}".format(test_acc_mlp), ",      VGG:      ", "{:.4f}".format(test_acc_vgg))
print("Top 5 Test Accuracy MLP:", "{:.4f}".format(test_top5_mlp), ",      VGG:      ", "{:.4f}".format(test_top5_vgg))

Evaluation:   0%|          | 0/500 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 500/500 [09:25<00:00,  1.13s/it]

Train Accuracy MLP:       100.0000 ,      VGG:       100.0000
Top 5 Train Accuracy MLP: 0.0000 ,      VGG:       0.0000





In [9]:
acts_mlp_np = np.concatenate(all_acts_mlp, axis=0)
acts_vgg_np = np.concatenate(all_acts_vgg, axis=0)
images_np = np.concatenate(images, axis=0)
targets_np = np.concatenate(targets, axis=0)
with h5py.File('acts_' + architecture + '_' + dataset + '_test.h5', 'w') as hf:
    hf.create_dataset('activations', data=acts_mlp_np)

with h5py.File('acts_VGG13_bn_' + dataset + '_test.h5', 'w') as hf:
    hf.create_dataset('activations', data=acts_vgg_np)

with h5py.File('ims_' + dataset + '_test.h5', 'w') as hf:
    hf.create_dataset('images', data=images_np)

with h5py.File('targs_' + dataset + '_test.h5', 'w') as hf:
    hf.create_dataset('targets', data=targets_np)

: 