In [7]:
# Set your working directory.
# The project-local packages imported below (data_utils, utils, models,
# CNN_models) and the relative paths used later ("./data",
# "CNN_models/state_dicts") rely on this directory being the cwd.
# Override the default with the PROJECT_SRC environment variable instead of
# editing a hard-coded, machine-specific path.
import os
PROJECT_SRC = os.environ.get(
    "PROJECT_SRC",
    "/Users/charleslego/my_documents/ETH/Classes/Sem3/Deep_learning/Project/src",
)
os.chdir(PROJECT_SRC)

import h5py
import numpy as np  # used when saving activations; previously presumably reachable only via the star import below
import torch
from tqdm import tqdm
from data_utils.data_stats import *
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from CNN_models.vgg import vgg13_bn
from torchvision import datasets, transforms

Load the pretrained VGG13-BN model and its weights from `CNN_models/state_dicts/vgg13_bn.pt`

In [4]:
# Build VGG13-BN and load its weights from CNN_models/state_dicts/vgg13_bn.pt
# (relative to the working directory set in the imports cell).
torch.backends.cuda.matmul.allow_tf32 = True  # only affects matmuls run on CUDA
# NOTE(review): `device` is not passed to the model or data anywhere in this
# notebook; the model is never moved with .to(), so it runs on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vgg = vgg13_bn(pretrained=True)
state_dict = os.path.join(
                "CNN_models", "state_dicts", "vgg13_bn" + ".pt"
            )
# map_location="cpu": the model lives on the CPU, so materialize the tensors
# there directly. Without it, a checkpoint saved from a GPU cannot be loaded
# on a CPU-only machine. torch.load unpickles the file, so only load trusted
# checkpoints.
vgg.load_state_dict(torch.load(state_dict, map_location="cpu"))

<All keys matched successfully>

Attach a forward hook to "harvest" the 4096-dimensional hidden activations at `classifier[4]` (before the output logits)

In [8]:
# Hook storage: after each forward pass, acts_vgg['act'] holds the output of
# vgg.classifier[4] for the current batch as a NumPy array.
acts_vgg = {}

def _harvest_vgg_acts(module, inputs, output):
    """Forward hook: copy the hooked layer's output into acts_vgg['act'].

    The output is expected to have 4096 features along dim 1. It is moved to
    the CPU before .numpy() so the hook also works when the model runs on a
    GPU, and cloned so the stored array does not share memory with a tensor
    that a later in-place op could overwrite.
    """
    assert output.shape[1] == 4096, f"expected 4096 features, got shape {tuple(output.shape)}"
    acts_vgg['act'] = output.detach().cpu().clone().numpy()

# Re-running this cell must not stack a second hook on the same layer:
# remove the handle registered by a previous run, if there is one.
if isinstance(globals().get('hook_vgg'), torch.utils.hooks.RemovableHandle):
    hook_vgg.remove()
# Keep the handle under the name `hook_vgg` so it can be removed after evaluation.
hook_vgg = vgg._modules['classifier'][4].register_forward_hook(_harvest_vgg_acts)

Build the data loader for the CIFAR-10 test split

In [18]:
# Test-time preprocessing: tensor conversion, Resize(32) (CIFAR-10 images are
# already 32x32, so this leaves them unchanged), then per-channel (R, G, B)
# normalization with the mean/std given below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(32),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2471, 0.2435, 0.2616),
        ),
    ]
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)
# shuffle=False keeps the harvested activations in dataset order.
loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

Files already downloaded and verified


Evaluate the model on the test set and "harvest" the hooked activations for every batch

In [20]:
# Per-batch activation arrays, in loader order; filled by `test` below.
all_acts_vgg = []

@torch.no_grad()
def test(model_vgg, loader):
    """Evaluate `model_vgg` on `loader` and collect the hooked activations.

    Runs the model in eval mode over every batch, accumulates accuracy and
    top-5 accuracy, and appends the activations that the forward hook stored
    in ``acts_vgg['act']`` to the module-level list ``all_acts_vgg``.
    ``all_acts_vgg`` is cleared first, so re-running the evaluation does not
    duplicate rows.

    Args:
        model_vgg: network to evaluate; a forward hook must write its
            activations to ``acts_vgg['act']`` on every forward pass.
        loader: iterable of ``(images, targets)`` batches.

    Returns:
        ``(accuracy, top5_accuracy)`` as given by
        ``AverageMeter.get_avg(percentage=True)``.

    Raises:
        RuntimeError: if a forward pass produced no new activations (e.g. the
            hook was already removed).
    """
    model_vgg.eval()
    all_acts_vgg.clear()
    acts_vgg.pop('act', None)  # discard activations left over from earlier passes
    # Send each batch to wherever the model's parameters live.
    model_device = next(model_vgg.parameters()).device
    total_acc_vgg, total_top5_vgg = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        ims, targs = ims.to(model_device), targs.to(model_device)
        preds_vgg = model_vgg(ims)
        # pop() forces every batch to produce fresh activations; reading the
        # key instead would silently re-append the previous batch's array.
        batch_acts = acts_vgg.pop('act', None)
        if batch_acts is None:
            raise RuntimeError(
                "forward hook recorded no activations for this batch; is hook_vgg still registered?"
            )
        all_acts_vgg.append(batch_acts)

        acc_vgg, top5_vgg = topk_acc(preds_vgg, targs, k=5, avg=True)

        total_acc_vgg.update(acc_vgg, ims.shape[0])
        total_top5_vgg.update(top5_vgg, ims.shape[0])

    return (
        total_acc_vgg.get_avg(percentage=True),
        total_top5_vgg.get_avg(percentage=True)
    )

In [21]:
# Run the evaluation. The hook is removed in `finally` so that a failed or
# interrupted run does not leave it registered on the model.
try:
    test_acc_vgg, test_top5_vgg = test(vgg, loader)
finally:
    hook_vgg.remove()

# Print all the stats
print("Test Accuracy VGG:      ", "{:.4f}".format(test_acc_vgg))
print("Top 5 Test Accuracy VGG:      ", "{:.4f}".format(test_top5_vgg))

Evaluation:   0%|          | 0/100 [00:00<?, ?it/s]

Evaluation: 100%|██████████| 100/100 [01:55<00:00,  1.16s/it]

Test Accuracy VGG:       94.2100
Top 5 Test Accuracy VGG:       99.7400





Save the collected activations to disk as an HDF5 file

In [24]:
# Stack the per-batch activation arrays along the batch axis and write them to
# acts_VGG13_bn_<dataset>_test.h5 under the key 'activations'.
acts_vgg_np = np.concatenate(all_acts_vgg, axis=0)
# Expect exactly one row per test image; this catches duplicated or missing
# batches caused by re-running the evaluation cells out of order.
assert acts_vgg_np.shape[0] == len(test_dataset), (
    f"expected {len(test_dataset)} activation rows, got {acts_vgg_np.shape[0]}"
)
dataset = 'cifar10'
with h5py.File(f'acts_VGG13_bn_{dataset}_test.h5', 'w') as hf:
    hf.create_dataset('activations', data=acts_vgg_np)