# Per-Sample Bottleneck

In [None]:
# Notebook setup: enable the IPython autoreload extension so edits to local
# modules (e.g. per_sample_bottleneck) are picked up without restarting.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import torchvision
import torch
from torch import nn

import torchvision.models 
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt 
import os
from tqdm import tqdm_notebook
import json
import itertools

from per_sample_bottleneck.per_sample_bottleneck import PerSampleBottleneck, insert_into_sequential
from per_sample_bottleneck.utils import get_output_shapes, plot_heatmap

## Loading Data and Model

In [None]:
# Location of the raw ImageNet folders ('train' and 'validation' subdirectories,
# one subfolder per class as required by ImageFolder).
imagenet_dir = '/mnt/ssd/data/imagenet/imagenet-raw/'

# Standard evaluation preprocessing for torchvision's pretrained ImageNet models:
# resize the shorter side to 256, then take the central 224x224 crop.
# (Cropping 256 px *before* resizing would cut a small patch out of the
# full-resolution image and pad images smaller than 256 px.)
# One shared pipeline keeps train and validation preprocessing identical.
transform = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valset = ImageFolder(os.path.join(imagenet_dir, 'validation'), transform=transform)
trainset = ImageFolder(os.path.join(imagenet_dir, 'train'), transform=transform)

# Shuffled training batches; a prefix of these is later fed to btln.estimate.
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)

# One sample image, used below to probe layer output shapes.
img, target = valset[0]
dev = torch.device('cuda:0')
# NOTE(review): newer torchvision versions deprecate `pretrained=` in favour of
# `weights=`; kept as-is because the installed version is unknown here.
model = torchvision.models.vgg19_bn(pretrained=True)

In [None]:
# Probe the network with a single-image batch (leading batch dim added) and
# collect output shapes for its nn.Conv2d modules, keyed by module name
# (e.g. 'features.17').
sizes = get_output_shapes(model, img.unsqueeze(0), nn.Conv2d)

In [None]:
# List every probed nn.Conv2d layer with its output shape, to help choose
# `layer_idx`. The loop uses its own names so that re-running this cell does
# not overwrite the global `size` passed to PerSampleBottleneck.
for layer_name, out_shape in sizes.items():
    print(layer_name, out_shape)

In [None]:
# Index (within model.features) of the Conv2d layer after which the bottleneck
# is inserted.
layer_idx = 17
# Build the key from layer_idx so the shape always matches the chosen layer.
size = sizes[f'features.{layer_idx}']

In [None]:
# Insert a single PerSampleBottleneck into model.features, directly after
# layer `layer_idx`. Re-running this cell reuses the bottleneck that is already
# in the model instead of inserting a second one.
# NOTE: if layer_idx is changed, reload the model first; an existing bottleneck
# is reused at its original position.
btln = next(
    (module for module in model.features if isinstance(module, PerSampleBottleneck)),
    None,
)
if btln is None:
    btln = PerSampleBottleneck(*size)
    model.features = insert_into_sequential(model.features, btln, layer_idx + 1)

In [None]:
# Show the module tree as the cell output, e.g. to check where the bottleneck
# was inserted in model.features.
model

In [None]:
model.to(dev)
# Switch to inference mode before estimation. In training mode, BatchNorm
# normalizes with per-batch statistics and updates (overwrites) the pretrained
# running statistics on every forward pass, and Dropout is active.
# Harmless if btln.estimate also sets this itself.
model.eval()
# Feed only the first 200 training batches to btln.estimate.
loader200 = itertools.islice(trainloader, 200)
btln.estimate(model, loader200, device=dev, progbar=True)

In [None]:
# Put the model in inference mode (BatchNorm uses running statistics, Dropout
# disabled) before computing heatmaps.
model.eval()

In [None]:
# Map integer class index -> human-readable class name. Entry values are
# indexed with [1]; presumably each is [wordnet_id, name] — verify against the file.
# Explicit encoding avoids depending on the platform's default text encoding.
with open('imagenet_class_index.json', encoding='utf-8') as f:
    idx2class = {int(k): v[1] for k, v in json.load(f).items()}

In [None]:
# Validation sample to explain.
sample_idx = 544
img, target = valset[sample_idx]
# Objective for the attribution: the model's raw output score (pre-softmax)
# for the ground-truth class, averaged over the batch of one image.
# This is not a cross-entropy loss.
heatmap = btln.heatmap(img[None].to(dev), lambda x: model(x)[:, target].mean())
ax = plot_heatmap(img, heatmap)
_ = ax.set_title(idx2class[target])