# PyTorch: IBA Readout Bottleneck 

This notebook shows how to train the Readout Bottleneck and apply it to a pretrained ImageNet model. 

Ensure that `./imagenet` points to your copy of the ImageNet dataset. 

You might want to create a symlink:

In [None]:
# ! ln -s /path/to/your/imagenet/folder/ imagenet 

## Loading Data and Model

In [None]:
# to set your CUDA device
# %env CUDA_VISIBLE_DEVICES=1

import torch
import torchvision.models 
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize
import json
import os
import sys
from tqdm.notebook import tqdm


try:
    import IBA
except ModuleNotFoundError:
    sys.path.insert(0, '..')
    import IBA
    
from IBA.pytorch_readout import IBAReadout
from IBA.pytorch import tensor_to_np_img

In [None]:
imagenet_dir = './imagenet'

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook can still be executed (slowly).
dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pretrained VGG-16. eval() puts the dropout layers of the classifier into
# inference mode, so the model gives deterministic outputs.
model = torchvision.models.vgg16(pretrained=True).to(dev).eval()

# Pre-processing shared by the training and validation sets.
# NOTE(review): the common ImageNet recipe is Resize(256) followed by
# CenterCrop(224); the order below is kept unchanged - confirm it is intended.
preprocess = Compose([
    CenterCrop(256), Resize(224), ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

valset = ImageFolder(os.path.join(imagenet_dir, 'validation'), transform=preprocess)
trainset = ImageFolder(os.path.join(imagenet_dir, 'train'), transform=preprocess)

# Map the integer class index to a label: the JSON keys are converted to int
# and the second entry of each value is used as the class name.
with open('imagenet_class_index.json') as f:
    idx2class = {int(k): v[1] for k, v in json.load(f).items()}

trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)

# Load the first validation sample (fails early if the dataset is not readable).
img, target = valset[0]

## Insert the bottleneck into the model

You can experiment with which layers are read out and with the location of the bottleneck.

In [None]:
# Layers whose outputs are read out: four layers picked from
# `model.features`, followed by the classifier.
readout_layers = [
    *[model.features[i] for i in (10, 14, 18, 28)],
    model.classifier,
]

# Create the Readout Bottleneck, attached to model.features[10]
iba = IBAReadout(model.features[10], readout_layers, model)

## Estimate Mean and Variance

Here, we estimate the mean and variance of the feature map. These statistics are important for measuring the amount of information transmitted.

In [None]:
# Estimate the feature-map statistics from the training loader
# (n_samples=10000), with a progress bar.
iba.estimate(
    model,
    trainloader,
    device=dev,
    n_samples=10_000,
    progbar=True,
)

## Train the Readout Network

We train the mapping from the readout feature maps to the alphas on the training dataset.

In [None]:
# Prepare training: only the parameters of the Readout Bottleneck are handed
# to the optimizer - the model remains frozen.
optimizer = torch.optim.Adam(lr=1e-5, params=iba.parameters())
# Weight of the information term in the total loss.
beta = 10

# Train for 10 epochs, this may take some time.
# You may interrupt earlier to inspect intermediate results.
with iba.restrict_flow():
    for epoch in range(10):
        for x, target in tqdm(trainloader, desc=f"Training epoch {epoch}"):
            x, target = x.to(dev), target.to(dev)
            optimizer.zero_grad()
            # Mean negative log-likelihood of each sample's own target class.
            # Indexing the log-softmax with `[:, target]` would select a
            # (batch, batch) matrix here, because `target` is a batch of
            # labels, and would score every sample against every label.
            model_loss = torch.nn.functional.cross_entropy(model(x), target)
            information_loss = iba.capacity().mean()
            loss = model_loss + beta * information_loss
            loss.backward()
            optimizer.step()


## Display Heatmaps for some random samples


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from IBA.utils import plot_saliency_map

fig, axes = plt.subplots(2, 5, figsize=(20, 6))
np.random.seed(0)
for ax, sample_idx in zip(axes.flatten(), np.random.choice(50000, 10)):
    img, target = valset[sample_idx]
    img = img[None].to(dev)
    
    # Execute the model on a given sample and return the target NLL
    model_loss_closure = lambda x: -torch.log_softmax(model(x), 1)[:, target].mean()
    
    # Generate the heatmap
    heatmap = iba.analyze(img, model_loss_closure)
    
    # Reverse the data pre-processing for plotting the original image
    np_img = tensor_to_np_img(img[0])
    
    # Show the heatmap
    plot_saliency_map(heatmap, np_img,  ax=ax)
    ax.set_title(idx2class[target])
    
fig.suptitle("model: {}".format(type(model).__name__))
plt.show()

## Monkey image

In [None]:
from PIL import Image

target = 382  # 382: squirrel monkey

# Preprocess the image. The file is opened once and closed again by the
# context manager; convert("RGB") guarantees the three channels that the
# Normalize statistics below expect.
with Image.open("./monkeys.jpg") as pil_img:
    img = Compose([
        Resize(224), ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])(pil_img.convert("RGB"))


# Negative log-likelihood of the target class for the given input
def model_loss_closure(x):
    return -torch.log_softmax(model(x), 1)[:, target].mean()


# Add a batch dimension, move the image to the device and compute the heatmap
heatmap = iba.analyze(img[None].to(dev), model_loss_closure)
ax = plot_saliency_map(heatmap, tensor_to_np_img(img))
_ = ax.set_title(idx2class[target])
plt.show()