# TensorFlow: IBA (Per-Sample Bottleneck)

This notebook shows how to apply the Per-Sample Bottleneck to pretrained ImageNet models. 

Ensure that `./imagenet` points to your copy of the ImageNet dataset. 

You might want to create a symlink:

In [None]:
# ! ln -s /path/to/your/imagenet/folder/ imagenet 

In [None]:
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Sets CUDA_VISIBLE_DEVICES to 1 for this kernel; adjust for your machine.
%env CUDA_VISIBLE_DEVICES=1

In [None]:
from IBA.tensorflow_v1 import IBACopyInnvestigate, model_wo_softmax, get_imagenet_generator
from IBA.utils import load_monkeys, plot_saliency_map
from keras.applications.vgg16 import VGG16, preprocess_input

from assert_cache import assert_cache, get_asserted_values
import numpy as np
from IBA.utils import to_unit_interval

import keras.backend as K
from tqdm.auto import tqdm as tqdmbar

import matplotlib.pyplot as plt

In [None]:
def assert_(key, obj, assertion_fn, message_fn=None):
    """Forward a value to ``assert_cache`` under the ``"tensorflow"`` label.

    Thin wrapper: passes ``"tensorflow"`` followed by all arguments to
    ``assert_cache`` and returns nothing.

    Args:
        key: name under which ``obj`` is registered.
        obj: the value to register.
        assertion_fn: two-argument callable; presumably compares this value
            with the one another framework stored under ``key`` — confirm
            against ``assert_cache``.
        message_fn: optional two-argument callable; presumably builds the
            failure message — confirm against ``assert_cache``.
    """
    assert_cache("tensorflow", key, obj, assertion_fn, message_fn)
    
# Smoke test of the helper: registers the value 2 under the key "1".
# NOTE(review): presumably the counterpart notebook registers its own value
# under the same key and the lambdas compare/report the pair — confirm.
assert_("1", 2, 
        lambda a, b: a == b, 
        lambda a, b: "seriously? {} != {}".format(a, b))

In [None]:
# load the VGG16 model with its 'imagenet' weights
model_softmax = VGG16(weights='imagenet')

# remove the final softmax layer
model = model_wo_softmax(model_softmax)

# select layer after which the bottleneck will be inserted
# (its output tensor name is passed to IBACopyInnvestigate further down)
feat_layer = model.get_layer(name='block4_conv1')



In [None]:

from torchvision import transforms 

class PatternTransform(object):
    """Image transform producing a channels-last, BGR, offset-subtracted array.

    Only works for VGG16: the three offsets are hard-coded per-channel values
    that are subtracted after the channel flip (presumably the BGR channel
    means the VGG16 weights expect — confirm against the weight source).
    """

    def __init__(self):
        # Resize to 256, then take the central 224 crop.
        self.scale = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ])
        # Shape (3, 1, 1) so it broadcasts over a channels-first image.
        self.offset = np.array([103.939, 116.779, 123.68])[:, np.newaxis, np.newaxis]

    def __call__(self, raw_img):
        """Return the transformed image as a contiguous float64 (H, W, 3) array.

        Args:
            raw_img: input accepted by ``self.scale``; the scaled result must
                convert to an (H, W, 3) array.
        """
        scaled_img = self.scale(raw_img)
        # ``np.float`` was only an alias for the builtin float and was removed
        # in NumPy 1.24; ``np.float64`` is the same dtype and always exists.
        ret = np.array(scaled_img, dtype=np.float64)
        # Channels first, so the (3, 1, 1) offset broadcasts.
        ret = ret.transpose((2, 0, 1))
        # To BGR.
        ret = ret[::-1, :, :]
        # Remove the per-channel offset (in place on the freshly created copy).
        ret -= self.offset
        # Back to channels last; make the memory layout contiguous again.
        return np.ascontiguousarray(ret.transpose(1, 2, 0))
    
def np_collate(batch):
    """Collate ``(image, target)`` samples into two stacked numpy arrays.

    Args:
        batch: sequence of samples; element 0 of each sample is the image,
            element 1 the target.

    Returns:
        Tuple ``(images, targets)``, each built with ``np.stack`` along a new
        leading axis.
    """
    def column(idx):
        # Collect field ``idx`` of every sample, preserving batch order.
        return [sample[idx] for sample in batch]

    images, labels = column(0), column(1)
    return np.stack(images), np.stack(labels)

In [None]:
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Resize, Lambda


def get_imagenet_folder(path, image_size=224, transform='default'):
    """
    Returns a ``torchvision.datasets.ImageFolder`` for the images under ``path``.

    Args:
        path: root directory handed to ``ImageFolder``.
        image_size: size passed to ``Resize`` in the default transform.
        transform: ``'default'`` builds the pipeline below (center crop to
            256, resize to ``image_size``, ``ToTensor``, normalization with
            the listed mean/std). Any other value is passed to
            ``ImageFolder`` unchanged.
    """
    from torchvision.datasets import ImageFolder
    from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, Normalize
    if transform == 'default':
        # NOTE(review): this crops first and resizes afterwards. The usual
        # torchvision evaluation recipe is Resize(256) followed by
        # CenterCrop(224). The order is kept as-is so existing results do not
        # change — confirm which one is intended.
        transform = Compose([
            CenterCrop(256), Resize(image_size), ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ])
    return ImageFolder(path, transform=transform)


# Location of the ImageNet validation images, defined once for both datasets.
# NOTE: this is a machine-specific absolute path — point it at your own copy.
imagenet_val_dir = '/srv/public/leonsixt/data/imagenet/validation'

# Validation set with the default transform of get_imagenet_folder.
val_set = get_imagenet_folder(imagenet_val_dir)

# Same images, preprocessed by PatternTransform for the Keras VGG16 model.
pattern_val_set = get_imagenet_folder(
    imagenet_val_dir,
    transform=PatternTransform(),
)

In [None]:
# Loader in fixed order; np_collate yields numpy arrays.
pattern_val_loader = DataLoader(pattern_val_set, batch_size=50, 
                                shuffle=False, num_workers=4,
                                collate_fn=np_collate)

# Same data, shuffled.
pattern_val_loader_shuffle = DataLoader(pattern_val_set, batch_size=50, 
                                shuffle=True, num_workers=4,
                                collate_fn=np_collate)
# Fetch the first batch twice from the unshuffled loader.
# NOTE(review): the second tuple element comes from np_collate's targets, so
# `logits` here actually holds labels; it is rebound in the accuracy cell.
imgs, logits = next(iter(pattern_val_loader))
imgs2, logits = next(iter(pattern_val_loader))

In [None]:
# Display the shape of the first image batch.
imgs.shape

In [None]:
# True when both fetches of the first batch are element-wise identical.
(imgs == imgs2).all()

## Check models are equal

In [None]:
# Register the first image batch. The check passes when the mean absolute
# difference between the two compared arrays is below 1e-4; the second lambda
# returns that difference (presumably used as the failure message).
assert_("first_image_batch_equal", imgs, 
        lambda a, b: np.abs((a - b)).mean() < 1e-4,
        lambda a, b: np.abs((a - b)).mean())

In [None]:

# Compare every Keras weight with the corresponding array in the .npz file.

# Indices into ``model.weights`` whose array is transposed a second time,
# which undoes the first transpose for exactly these entries.
transpose_idxs = [26, 28, 30]

npz_weights = np.load("vgg_16_weights.npz")

# Fetch the session once instead of once per evaluation.
session = K.get_session()

for i, weight in enumerate(model.weights):
    arr = npz_weights['arr_' + str(i)]
    arr = arr.T
    if i in transpose_idxs:
        arr = arr.T
    if len(arr.shape) == 4:
        arr = arr.transpose(1, 0, 2, 3)

    # Evaluate the Keras weight a single time and reuse it below.
    keras_arr = weight.eval(session)
    diff = np.abs(keras_arr - arr).mean()
    if diff != 0:
        # Print first, so the numbers are reported even if plotting fails.
        print(i, diff, arr.std())
        # imshow only accepts image-shaped data; conv kernels (4-D) and
        # biases (1-D) would raise, so only 2-D weights are plotted.
        if arr.ndim == 2:
            plt.figure(figsize=(10, 5))
            plt.imshow(keras_arr.T)
            plt.show()

            plt.figure(figsize=(10, 5))
            plt.imshow(arr.T)
            plt.show()

In [None]:
# For every framework that registered "first_image_batch_equal": print summary
# statistics of the batch and show its first image, rescaled with
# to_unit_interval for display.
for framework, val in get_asserted_values("first_image_batch_equal").items():
    print(framework, val.mean(), val.std())
    print(framework, val.min(), val.max())
    plt.imshow(to_unit_interval(val[0]))
    plt.show()

In [None]:
# Run the softmax-free model on the first image batch.
outputs = model.predict(imgs)

In [None]:
# Register the model outputs for the first batch. The check passes when the
# mean absolute difference between the two compared arrays is below 1e-4.
assert_("first_batch_outputs_equal", outputs, 
        lambda a, b: np.abs((a - b)).mean() < 1e-4,
        lambda a, b: np.abs((a - b)).mean())

In [None]:
# Per registered entry: print the output range and plot a histogram of all
# output values.
for name, val in get_asserted_values("first_batch_outputs_equal").items():
    print(name, val.min(), val.max())
    plt.hist(val.flatten())
    plt.show()

In [None]:
# Top-1 accuracy of the model on the (unshuffled) validation loader.
correct = []   # per-batch boolean arrays: prediction == target
logits = []    # per-batch model outputs (rebinds the earlier `logits`)
progbar = tqdmbar(pattern_val_loader)
for img, target in progbar:
    logit = model.predict(img)
    correct.append(np.argmax(logit, 1) == target)
    logits.append(logit)
    # Running accuracy over all batches seen so far.
    progbar.set_postfix(acc=np.concatenate(correct).mean())
    #if len(logits) == 100:
    # Currently stops after the first batch; the commented-out line above
    # is the 100-batch variant.
    if len(logits) == 1:
        break
# Merge the per-batch lists into single arrays.
logits = np.concatenate(logits)
correct = np.concatenate(correct)

In [None]:
# assert_("corrects_equal", correct, 
#         lambda a, b: (a == b).mean(),
#         lambda a, b: (a, b))

## IBA

In [None]:
# Display the kernel shape of the layer selected for the bottleneck.
feat_layer.kernel.shape

In [None]:
# copies the model; the feature tensor is identified by the name of
# feat_layer's output, and the neuron to explain is selected by index
iba = IBACopyInnvestigate(
    model,
    neuron_selection_mode='index',
    feature_name=feat_layer.output.name,
)

In [None]:
# estimate feature mean and std
# With batch_size=50 this is 50 // 50 = 1 step, i.e. a single batch.
n_samples = 50
iba.fit_generator(pattern_val_loader, 
                  steps_per_epoch=n_samples // pattern_val_loader.batch_size)


# Register the estimated feature mean; the check passes when the mean
# absolute difference between the two compared arrays is below 1e-4.
assert_('estimated_mean_1', iba._estimator.mean(),
        lambda a, b: np.abs(a - b).mean() < 1e-4,
        lambda a, b: (a.mean(), b.mean()))

In [None]:
# Re-estimate the feature mean from scratch with 1..9 batches and register
# each result. With batch_size=50 the keys are 'estimated_mean_50' up to
# 'estimated_mean_450'. Tolerance here is 1e-3.
for i in range(1, 10):
    # Start from an empty estimator so each run only sees its own i batches.
    iba._estimator.reset()
    iba.fit_generator(pattern_val_loader, 
                      steps_per_epoch=i, verbose=0)

    assert_(f'estimated_mean_{i*pattern_val_loader.batch_size}',
            iba._estimator.mean(),
            lambda a, b: np.abs(a - b).mean() < 1e-3,
            lambda a, b: (a.mean(), b.mean()))

In [None]:
# Histogram of the element-wise difference between the tensorflow and the
# pytorch estimate of the feature mean.
# NOTE(review): no cell in this notebook registers the bare key
# 'estimated_mean' (only 'estimated_mean_1' and 'estimated_mean_<n>') —
# confirm that get_asserted_values resolves this key. The 'pytorch' entry
# also requires the counterpart notebook to have been run.
estim_means = get_asserted_values('estimated_mean')
plt.hist((estim_means['tensorflow'] - estim_means['pytorch']).flatten())

In [None]:
# Distribution of the values stored under the 'tensorflow' entry.
plt.hist(estim_means['tensorflow'].flatten())

In [None]:
# estimate feature mean and std from scratch on shuffled validation batches
n_samples = 5000
iba._estimator.reset()
# Derive the number of steps from the loader that is actually iterated, so
# the sample count stays correct if the two loaders' batch sizes ever differ.
iba.fit_generator(pattern_val_loader_shuffle, 
                  steps_per_epoch=n_samples // pattern_val_loader_shuffle.batch_size)

In [None]:
# Example image and its target (passed as neuron_selection to analyze below).
monkeys, target = load_monkeys()

In [None]:
# Default settings for the analysis: beta=10, 10 steps, min_std=0 and
# smooth_std=0.
iba.set_default(beta=10, min_std=0, smooth_std=0, steps=10)
# Presumably enables recording of all intermediate values for get_report()
# — confirm against the IBA API.
iba.collect_all()

In [None]:
# Apply the Keras VGG16 preprocessing to the example image.
# NOTE(review): `monkeys` is reused for plotting below; confirm that
# preprocess_input does not modify it in place.
monkeys_scaled =  preprocess_input(monkeys)

# get the saliency map and plot
# monkeys_scaled[None] adds a leading batch axis of size 1.
saliency_map = iba.analyze(monkeys_scaled[None], neuron_selection=target)
plot_saliency_map(saliency_map, img=monkeys)

In [None]:
# Fetch the report; the following cells index it with 'init', integer keys
# and 'final'.
report = iba.get_report()

In [None]:
# Display the shape of the 'grad_loss_wrt_alpha' entry recorded at 'init'.
report['init']['grad_loss_wrt_alpha'].shape

In [None]:
# Histogram (20 bins, log-scaled counts) of the 'init' gradient values.
plt.hist(report['init']['grad_loss_wrt_alpha'].flatten(), bins=20, log=True)

In [None]:
# Register the 'grad_loss_wrt_alpha' entry stored under report key 0. The
# check passes when the mean absolute difference is below 1e-6; the second
# lambda returns mean and std of both arrays.
assert_('grad_alpha_0',
        report[0]['grad_loss_wrt_alpha'], 
        lambda s, o: np.abs(s-o).mean() < 1e-6,
        lambda s, o: (s.mean(), s.std(), o.mean(), o.std())
       )

In [None]:
# Number of NaN entries in the gradient stored under report key 0.
np.isnan(report[0]['grad_loss_wrt_alpha']).sum()

In [None]:
# 'model_loss' of every report entry, in report order.
[it['model_loss'] for it in report.values()]

In [None]:
# 'information_loss' stored under key 9 — presumably the last of the
# steps=10 iterations if keys are 0-based; confirm.
report[9]['information_loss']

In [None]:
# Inspect the capacity recorded at the 'final' report entry.
capacity = report['final']['capacity']
# Number of NaN entries.
print(np.isnan(capacity).sum())
# A bare expression in the middle of a cell is never displayed, so print it.
print(capacity.shape)
# Sum over the last axis of the first sample, treating NaNs as zero.
plt.imshow(np.nansum(capacity[0], -1))
plt.colorbar()

In [None]:
# Inspect the 'capacity_no_nans' entry stored under report key 0.
capacity = report[0]['capacity_no_nans']
# Bare expressions in the middle of a cell are never displayed, so print them.
print(np.isnan(capacity).sum())
print(capacity.shape)
# Per location of the first sample: number of NaN entries along the last axis.
plt.imshow(np.isnan(capacity[0]).sum(-1))

In [None]:
# Current feature mean held by the estimator.
mean = iba._estimator.mean()

In [None]:
# Show the mean summed over its last axis.
plt.imshow(mean.sum(-1))
plt.colorbar()

In [None]:
# Rebinds `mean` to the first element of the 'feature_mean' entry recorded at
# 'init'.
mean = report['init']['feature_mean'][0]

In [None]:
# Display the shape of the reported feature mean.
mean.shape

In [None]:
# Evaluate the `_active_neurons` tensor in the session held by `iba` and
# display the shape of the result.
active = iba._active_neurons.eval(iba._session)
active.shape

In [None]:
# Sum of (1 - active) over the last axis of the first element — the number of
# inactive neurons per location if `active` is a 0/1 mask (TODO confirm).
plt.imshow((1-active[0]).sum(-1))
plt.colorbar()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Print the layer overview of the softmax-free model.
model.summary()

In [None]:
import numpy as np

In [None]:
# Complement of the 'pass_mask' entry recorded at 'init'.
restrict_mask = 1 - report['init']['pass_mask']
# True when the mask sums to its number of elements, i.e. every entry is 1
# if the mask is binary (TODO confirm).
restrict_mask.sum() == np.prod(restrict_mask.shape)

In [None]:
# Mean of the capacity recorded at 'init'.
report['init']['capacity'].mean()

In [None]:
# Mean of the 'final' capacity for the first sample.
report['final']['capacity'][0].mean()

In [None]:
# Histogram (log-scaled counts) of the 'final' capacity of the first sample.
plt.hist(report['final']['capacity'][0].flatten(), log=True)

In [None]:
# Distribution of the estimated feature means.
plt.hist(iba._estimator.mean().flatten())

In [None]:
# Distribution of the estimated feature standard deviations.
plt.hist(iba._estimator.std().flatten())

In [None]:
# Plot both loss terms over all report entries, using the report keys as
# x-axis labels.
iters = list(report)
plt.plot(iters, [entry['information_loss'] for entry in report.values()], label='info')
plt.plot(iters, [entry['model_loss'] for entry in report.values()], label='model')
plt.legend()

In [None]:
# Histogram (log-scaled counts) of the 'alpha' values recorded at 'final'.
plt.hist(report['final']['alpha'].flatten(), log=True)

In [None]:
# 'final' capacity of the first sample, summed over its last axis.
report['final']['capacity'][0].sum(-1)

In [None]:
import matplotlib.pyplot as plt
# Value distribution of the preprocessed example image.
plt.hist(monkeys_scaled.flatten())

In [None]:
from skimage.color import rgb2gray

In [None]:
from IBA.utils import load_monkeys
import matplotlib.pyplot as plt

In [None]:
# Show the first value returned by load_monkeys (the example image).
plt.imshow(load_monkeys()[0])