# Readout Bottleneck 

This notebook shows how to train the Readout Bottleneck and apply it to a pretrained ImageNet model. 

Ensure that `./data/imagenet` points to your copy of the ImageNet dataset. 

## Loading Data and Model

In [None]:
import torch
import torchvision.models 
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize
import json
import os

imagenet_dir = './data/imagenet'

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA.
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The pretrained network is only evaluated in this notebook, so put it in
# eval mode. Otherwise the Dropout layers of the VGG16 classifier stay active
# and make the statistics, the training signal and the heatmaps stochastic.
model = torchvision.models.vgg16(pretrained=True).to(dev).eval()

# Train and validation images share the same deterministic preprocessing.
preprocess = Compose([
    CenterCrop(256), Resize(224), ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

valset = ImageFolder(os.path.join(imagenet_dir, 'validation'), transform=preprocess)
trainset = ImageFolder(os.path.join(imagenet_dir, 'train'), transform=preprocess)

# Map the integer class index to its human-readable label (second entry of
# each JSON value).
with open('imagenet_class_index.json') as f:
    idx2class = {int(k): v[1] for k, v in json.load(f).items()}

trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)

# Example sample; later cells reassign `img` and `target`.
img, target = valset[0]

## Insert the bottleneck into the model

You can experiment with which layers are read and with where the bottleneck is inserted.

In [None]:
from IBA.pytorch_readout import IBA_Readout
from IBA.pytorch import tensor_to_np_img

# Layers whose feature maps are read out. The first entry doubles as the
# location where the bottleneck gets injected.
layers = [model.features[idx] for idx in (10, 14, 18, 28)]
layers.append(model.classifier)

# Build the Readout Bottleneck and inject it at the first selected layer.
btln = IBA_Readout(layers, model)
btln.attach(layers[0])

## Estimate Mean and Variance

Here, we estimate the mean and variance of the feature map. These statistics are needed to measure the amount of information transmitted.

In [None]:
# Run the bottleneck's estimation step on data drawn from the training loader.
# NOTE(review): n_samples=100 presumably limits how much training data is
# consumed, and progbar=True presumably shows a progress bar — confirm both
# against the IBA documentation.
btln.estimate(model, trainloader, device=dev, n_samples=100, progbar=True)

## Train the Readout Network
We train the mapping from the readout feature maps to the alphas on the training dataset.

In [None]:
# Prepare training: We only train the parameters of the
# Readout Bottleneck - the model remains frozen
# NOTE(review): lr=1 is unusually large for Adam — confirm it is intended.
optimizer = torch.optim.Adam(lr=1, params=btln.parameters())
beta = 10  # weight of the information loss relative to the model loss

# Train for 10 epochs, this may take some time.
# You may interrupt earlier to inspect intermediate results.
with btln.supress_information():
    for epoch in range(10):
        print(f"Training epoch {epoch}...")
        for x, target in trainloader:
            x, target = x.to(dev), target.to(dev)
            optimizer.zero_grad()
            # Negative log-likelihood of each sample's own label, averaged
            # over the batch. Indexing the log-probabilities with
            # `[:, target]` is wrong for a batch: it yields a (batch, batch)
            # matrix that scores every sample against every label.
            model_loss = torch.nn.functional.cross_entropy(model(x), target)
            information_loss = btln.capacity().mean()
            loss = model_loss + beta * information_loss
            loss.backward()
            optimizer.step()


## Display Heatmaps for some random samples



In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from IBA.utils import plot_heatmap

fig, axes = plt.subplots(2, 5, figsize=(20, 6))
np.random.seed(0)  # fixed seed so the same samples are shown on every run

# Draw one validation index per subplot. The bounds come from the dataset and
# the grid itself, so a smaller validation set or a different grid shape
# cannot produce out-of-range indices or a mismatched sample count.
sample_indices = np.random.choice(len(valset), axes.size)

for ax, sample_idx in zip(axes.flatten(), sample_indices):
    img, target = valset[sample_idx]
    img = img[None].to(dev)  # add the batch dimension

    # Execute the model on a given sample and return the target NLL.
    # `target` is a plain int here, so `[:, target]` selects one class column.
    model_loss_closure = lambda x: -torch.log_softmax(model(x), 1)[:, target].mean()

    # Generate the heatmap
    heatmap = btln.heatmap(img, model_loss_closure)

    # Reverse the data pre-processing for plotting the original image
    np_img = tensor_to_np_img(img[0])

    # Show the heatmap
    plot_heatmap(heatmap, np_img, ax=ax)
    ax.set_title(idx2class[target])

fig.suptitle("model: {}".format(type(model).__name__))
plt.show()

## Monkey image

In [None]:
from PIL import Image

target = 382  # 382: squirrel monkey

# Load and preprocess the image. The context manager closes the file once the
# pixel data has been converted to a tensor. Converting to RGB guarantees the
# three channels that Normalize expects.
with Image.open("./monkeys.jpg") as monkey_img:
    img = Compose([
        Resize(224), ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])(monkey_img.convert('RGB'))

# NLL of the chosen target class for the (single-image) batch.
model_loss_closure = lambda x: -torch.log_softmax(model(x), 1)[:, target].mean()
heatmap = btln.heatmap(img[None].to(dev), model_loss_closure)
ax = plot_heatmap(heatmap, tensor_to_np_img(img))
_ = ax.set_title(idx2class[target])
plt.show()