In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before code is executed and edits to local modules are picked up.
%load_ext autoreload
%autoreload 2

This notebook needs the saved TensorFlow weights file `vgg_16_weights.npz`. Download it from https://drive.google.com/open?id=1HdSyCFzJlzJ2mFo_ZClSVCA1nJJpwAmg and verify its checksum:
```
$ sha256sum vgg_16_weights.npz                  
```
The command should print:
```
ff50e3f93d9cf158f31d1cc4275cfd477e37dcc4fdcdc8c9266decdcc561b049  vgg_16_weights.npz
```

In [None]:
import matplotlib.pyplot as plt

from tqdm.auto import tqdm as tqdmbar

import numpy as np

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Resize
from torchvision.models import vgg16



from load_tensorflow_vgg_weights import TensorflowVGGWeights, TensorflowTransform

from IBA.pytorch import IBA, tensor_to_np_img, get_imagenet_folder
from IBA.utils import plot_saliency_map, to_unit_interval


from IBA.utils import load_monkeys
import torch.nn.functional as F
import PIL

import pickle
import os
import glob

from assert_cache import assert_cache, get_assert_file

In [None]:
# NOTE(review): presumably the directory holding the cached assertion data, but
# no visible cell reads `assert_dir` — confirm whether it is still needed.
assert_dir = "asserts"

def to_nhwc(x):
    """Reorder the axes of a 4-D array from (0, 1, 2, 3) to (0, 2, 3, 1).

    This turns a channel-first batch (NCHW) into a channel-last one (NHWC).
    `x` has to provide a NumPy-style ``transpose(*axes)`` method.
    """
    channel_last_axes = (0, 2, 3, 1)
    return x.transpose(*channel_last_axes)

def to_np(x):
    """Return the values of tensor `x` as a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU
    before the conversion.
    """
    cpu_tensor = x.detach().cpu()
    return cpu_tensor.numpy()

def assert_(key, obj, assertion_fn, message_fn=None):
    """Normalize `obj` and forward it to `assert_cache` under the "pytorch" namespace.

    Torch tensors are converted to NumPy arrays, and 4-D arrays are reordered
    with `to_nhwc` (axes 0, 2, 3, 1) before being handed over.

    Args:
        key: identifier passed through to `assert_cache`.
        obj: value to check; tensors and arrays are normalized as described.
        assertion_fn: callable passed through to `assert_cache`.
        message_fn: optional callable passed through to `assert_cache`.
    """
    # isinstance (instead of an exact type comparison) also normalizes
    # subclasses such as torch.nn.Parameter or ndarray subclasses.
    if isinstance(obj, torch.Tensor):
        obj = obj.detach().cpu().numpy()
    if isinstance(obj, np.ndarray) and obj.ndim == 4:
        obj = to_nhwc(obj)
    assert_cache("pytorch", key, obj, assertion_fn, message_fn)
    

In [None]:
# Ask get_assert_file for the entry of namespace "pytorch", key "1" (presumably
# a file path — confirm). NOTE(review): `fname` is not read by any visible cell.
fname = get_assert_file("pytorch", "1")

In [None]:
# Smoke test of the assert_ helper: passes the value 2 under key "1" with an
# equality predicate and a formatted failure message.
assert_("1", 2, 
        lambda a, b: a == b, 
        lambda a, b: "seriously? {} != {}".format(a, b))

In [None]:
# Initialize some pre-trained model to analyze
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Load model
# Pass the detected device instead of a hard-coded 'cuda:0', so the model agrees
# with the inputs that are moved via `.to(dev)` and the cell also runs without CUDA.
# NOTE(review): assumes the constructor argument is the target device — confirm.
tf_weights = TensorflowVGGWeights(dev)
model = tf_weights.get_model()

#model = vgg16(pretrained=True)
#model.to(dev)

# setup data loader
# Single definition of the ImageNet validation directory shared by both datasets.
imagenet_val_dir = '/srv/public/leonsixt/data/imagenet/validation'

# Dataset with the default transform of get_imagenet_folder, one image per batch.
val_set = get_imagenet_folder(imagenet_val_dir)
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4)

# Same directory, preprocessed with TensorflowTransform.
pattern_val_set = get_imagenet_folder(
    imagenet_val_dir,
    transform=TensorflowTransform()
)

In [None]:
# Two loaders over the TensorflowTransform dataset with 50 images per batch:
# one in dataset order and one shuffled.
pattern_val_loader = DataLoader(pattern_val_set, batch_size=50, 
                                shuffle=False, num_workers=4)
pattern_val_loader_shuffle = DataLoader(pattern_val_set, batch_size=50, 
                                shuffle=True, num_workers=4)
# First batch in dataset order. NOTE(review): the second element is bound to
# `logits`, but the accuracy loop further down unpacks the same position as the
# class target — the name looks misleading.
imgs, logits = next(iter(pattern_val_loader))

In [None]:
# Show the first image of the batch: convert to NumPy, reorder the axes to
# (0, 2, 3, 1), take element 0 and pass it through to_unit_interval for plotting.
plt.imshow(to_unit_interval(to_nhwc(to_np(imgs))[0]))
plt.show()


In [None]:
type(imgs)

## Check Model and Data

In [None]:
# Switch the automatic post-mortem debugger off before running the checks.
%pdb off
# Check the first image batch: the predicate requires a mean absolute difference
# below 1e-4 (presumably against a stored reference held by assert_cache).
assert_("first_image_batch_equal", imgs, lambda a, b: np.abs(a - b).mean() < 1e-4)

In [None]:
# Call model.eval(); binding the return value to `_` keeps the cell output empty.
_ = model.eval()

In [None]:
# Forward the first image batch through the model without tracking gradients,
# then display the shape of the result.
with torch.no_grad():
    outputs = model(imgs.to(dev))
outputs.shape

In [None]:
outputs.min(), outputs.max()

In [None]:
# Check the model outputs of the first batch: the mean absolute difference must
# stay below 1e-4; on failure the message function reports that mean difference.
assert_("first_batch_outputs_equal", outputs, 
        lambda a, b: np.abs((a - b)).mean() < 1e-4,
        lambda a, b: np.abs((a - b)).mean())

In [None]:
# Collect the logits and per-image top-1 correctness for the first
# `max_batches` batches of pattern_val_loader.
max_batches = 100  # stop after this many batches
correct = []
logits = []
n_correct = 0
n_seen = 0
with torch.no_grad():
    progbar = tqdmbar(pattern_val_loader)
    for img, target in progbar:
        logit = model(img.to(dev))

        hits = torch.argmax(logit, 1).cpu() == target
        correct.append(hits)
        logits.append(logit.cpu().numpy())
        # Keep a running tally for the progress bar instead of concatenating
        # all previous batches again on every iteration.
        n_correct += int(hits.sum())
        n_seen += hits.numel()
        progbar.set_postfix(acc=n_correct / n_seen)
        if len(logits) == max_batches:
            break

# Flatten the per-batch lists into single arrays.
logits = np.concatenate(logits)
correct = np.concatenate(correct)

In [None]:
# Every per-image correctness flag has to agree. The predicate returns a real
# boolean: a mean agreement would already be truthy when a single entry matches.
# NOTE(review): assumes assert_cache tests the predicate's truthiness — confirm.
assert_("corrects_equal", correct, 
        lambda a, b: bool((a == b).all()),
        lambda a, b: "only {:.4f} of the entries agree".format((a == b).mean()))

In [None]:
# Fetch the first batch of the ordered loader again, keeping only the images,
# and display the batch shape.
imgs, _ = next(iter(pattern_val_loader))
imgs.shape

## Check IBA

In [None]:
# Entry 18 of model.features; the IBA bottleneck created below is attached to it.
explained_layer = model.features[18]

In [None]:
# When this part of the notebook is re-run, detach the bottleneck left over
# from the previous run before a new one is created.
if 'iba' in globals():
    iba.detach()

In [None]:
# Add a Per-Sample Bottleneck at layer conv4_1
# NOTE(review): `explained_layer` is model.features[18]; confirm that this index
# really is conv4_1 for the model returned by TensorflowVGGWeights.
iba = IBA(explained_layer)

In [None]:
pattern_val_loader.batch_size

In [None]:
# Estimate the mean and variance of the feature map at this layer.
# Requests only 5 samples from the ordered loader, without a progress bar.
iba.estimate(model, pattern_val_loader, n_samples=5, progbar=False)

In [None]:
# Check the estimated mean, with its first axis moved to the end via
# permute(1, 2, 0): mean absolute difference below 1e-4; the failure message
# reports the means of both sides.
assert_('estimated_mean_1', iba.estimator.mean().permute(1, 2, 0),
        lambda a, b: np.abs(a - b).mean() < 1e-4,
        lambda a, b: (a.mean(), b.mean()))

In [None]:
# Re-estimate from scratch with 45, 95, ..., 445 requested samples and check the
# resulting mean under the keys 'estimated_mean_50', ..., 'estimated_mean_450'.
for n_key in range(50, 500, 50):
    iba.estimate(model, pattern_val_loader,
                 n_samples=n_key - 5, progbar=False, reset=True)
    # Move the first axis of the estimated mean to the end before comparing.
    estimated_mean = iba.estimator.mean().permute(1, 2, 0)
    assert_('estimated_mean_{}'.format(n_key),
            estimated_mean,
            lambda a, b: np.abs(a - b).mean() < 1e-4,
            lambda a, b: (a.mean(), b.mean()))

In [None]:
# Restart the estimation (reset=True) on the shuffled loader, requesting
# 4999 samples, without a progress bar.
iba.estimate(model, pattern_val_loader_shuffle, 
             n_samples=5000 - 1,  progbar=False, reset=True)

In [None]:
# Closure that returns the loss for one batch
def model_loss_closure(x):
    """Negated mean log-softmax score of class `target` for the batch `x`.

    `model` and `target` are read from the notebook globals at call time.
    """
    log_probs = torch.log_softmax(model(x), dim=1)
    return -log_probs[:, target].mean()

In [None]:
# Explain class target for the given image

img, target = pattern_val_set[0]
# The target coming from the dataset is replaced by the fixed class index 2.
target = 2


def model_loss(x):
    """Cross-entropy of the model output for batch `x` against the global `target`.

    The same target index is used for every element of the batch.
    """
    batch_logits = model(x)
    repeated_target = torch.LongTensor([target] * len(batch_logits))
    return F.cross_entropy(batch_logits, repeated_target.to(dev))


# Run the bottleneck analysis on the single image, with a leading batch axis added.
saliency_map = iba.analyze(
    img.unsqueeze(0).to(dev),
    model_loss,
    beta=10,
)

# display result
np_img = to_unit_interval(tensor_to_np_img(img))
plot_saliency_map(saliency_map, np_img)

In [None]:
# Call model.eval() again before the next analysis; `_` keeps the cell output empty.
_ = model.eval()

In [None]:
def model_loss(x):
    """Cross-entropy of the model output for batch `x` against the global `target`.

    The same target index is used for every element of the batch.
    """
    batch_logits = model(x)
    repeated_target = torch.LongTensor([target] * len(batch_logits))
    return F.cross_entropy(batch_logits, repeated_target.to(dev))


# Load the example image as a PIL image together with its target and apply
# the TensorflowTransform preprocessing.
monkeys, target = load_monkeys(pil=True)
monkeys_trans = TensorflowTransform()(monkeys)

# Analyze the preprocessed image with explicit optimizer settings.
saliency_map = iba.analyze(
    monkeys_trans.unsqueeze(0).to(dev),
    model_loss,
    beta=10,
    lr=1,
    min_std=0,
    optimization_steps=10,
)

# display result
np_img = to_unit_interval(tensor_to_np_img(monkeys_trans))
plot_saliency_map(saliency_map, np_img)

In [None]:
iba._alpha_grads[0].shape

In [None]:
plt.hist(iba._alpha_grads[0].flatten(), bins=20, log=True)

In [None]:
# Check the first recorded alpha gradient, with its first axis moved to the end:
# mean absolute difference below 1e-6; the failure message reports mean and
# standard deviation of both sides.
assert_('grad_alpha_0',
        iba._alpha_grads[0].transpose(1, 2, 0), 
        lambda s, o: np.abs(s-o).mean() < 1e-6,
        lambda s, o: (s.mean(), s.std(), o.mean(), o.std())
       )

In [None]:
# Copy the bottleneck's private `_buffer_capacity` tensor into a NumPy array.
capacity = iba._buffer_capacity.cpu().detach().numpy()

In [None]:
# Plot entry 0 of the capacity array, summed over its first axis, with a colorbar.
plt.imshow(capacity[0].sum(0))
plt.colorbar()

In [None]:
# Estimated mean of the bottleneck's estimator as a NumPy array.
mean = iba.estimator.mean().cpu().numpy()

In [None]:
# Plot the estimated mean summed over its first axis, with a colorbar.
plt.imshow(mean.sum(0))
plt.colorbar()

In [None]:

# Read the bottleneck's private `_active_neurons` tensor and plot, per position,
# the sum of (1 - active) over the first axis.
# NOTE(review): presumably this counts the inactive neurons per location — confirm.
active = iba._active_neurons.cpu().numpy()
print(active.shape)
plt.imshow((1- active).sum(0))
plt.colorbar()

In [None]:
iba._model_loss

In [None]:
# Plot the values recorded in the bottleneck's `_information_loss` and
# `_model_loss` attributes as two labelled curves.
#plt.plot(pattern_iba._loss, label='loss')
plt.plot(iba._information_loss, label='info')
plt.plot(iba._model_loss, label='model')
plt.legend()

In [None]:
monkeys

In [None]:
# Histogram (log-scaled counts) of the bottleneck's alpha values.
# Reads `iba`, the bottleneck this notebook creates; `pattern_iba` is not
# defined by any cell and raised a NameError on a fresh kernel.
# NOTE(review): assumes `iba` exposes an `alpha` tensor — confirm.
alpha = iba.alpha.detach().cpu().numpy()
plt.hist(alpha.flatten(), log=True)

In [None]:
# Capacity of the bottleneck as a NumPy array.
# Reads `iba`, the bottleneck this notebook creates; `pattern_iba` is not
# defined by any cell and raised a NameError on a fresh kernel.
# NOTE(review): assumes `iba` provides a `capacity()` method — confirm.
capacity = iba.capacity().cpu().detach().numpy()

In [None]:
capacity.mean()

In [None]:
plt.hist(capacity.flatten(), log=True)

In [None]:
plt.hist(capacity.sum(0).flatten())

In [None]:
# Histogram of the flattened estimated mean. Reads `iba`, the bottleneck this
# notebook creates; `pattern_iba` is not defined by any cell.
plt.hist(iba.estimator.mean().cpu().numpy().flatten())

In [None]:
# Histogram of the flattened estimated standard deviation. Reads `iba`, the
# bottleneck this notebook creates; `pattern_iba` is not defined by any cell.
# NOTE(review): assumes the estimator provides `std()` next to `mean()` — confirm.
plt.hist(iba.estimator.std().cpu().numpy().flatten())

In [None]:
# PyTorch

In [None]:
# Add a Per-Sample Bottleneck at layer conv4_1
iba = IBA(model.features[17])

# Estimate the mean and variance of the feature map at this layer.
iba.estimate(model, val_loader, n_samples=5000, progbar=True)


# Closure that returns the loss for one batch
def model_loss_closure(x):
    """Negated mean log-softmax score of class `target` for the batch `x`.

    `model` and `target` are read from the notebook globals at call time, so
    `target` only has to exist once the closure is actually invoked.
    """
    log_probs = torch.log_softmax(model(x), dim=1)
    return -log_probs[:, target].mean()


# Explain class target for the given image
img, target = val_set[0]
saliency_map = iba.analyze(
    img.unsqueeze(0).to(dev),
    model_loss_closure,
    beta=10,
)

# display result
np_img = to_unit_interval(tensor_to_np_img(img))
plot_saliency_map(saliency_map, np_img)

In [None]:
# Show the target class of the explained image. `idx` is not defined by any
# cell of this notebook and raised a NameError, so only `target` is displayed.
target