# Collecting Samples for Activation Atlases with captum.optim

This notebook demonstrates how to collect the activation and corresponding attribution samples required for [Activation Atlases](https://distill.pub/2019/activation-atlas/) for the InceptionV1 model imported from Caffe.

In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import torchvision

from captum.optim.models import googlenet

import captum.optim as opt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Dataset Download & Setup

To begin, we'll need to download and set up the image dataset that our model was trained on. You can download ImageNet's ILSVRC2012 dataset from the [ImageNet website](http://www.image-net.org/challenges/LSVRC/2012/) or via BitTorrent from [Academic Torrents](https://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2).

In [None]:
collect_attributions = True  # Set to False for no attributions

# Setup basic transforms
# The model has the normalization step in its internal transform_input
# function, so we don't need to normalize our inputs here.
transform_list = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
]
transform_list = torchvision.transforms.Compose(transform_list)

To make it easier to load the ImageNet dataset, we can use [Torchvision](https://pytorch.org/vision/stable/datasets.html#imagenet)'s `torchvision.datasets.ImageNet` instead of the default `ImageFolder`.

In [None]:
# Load the dataset
image_dataset = torchvision.datasets.ImageNet(
    root="path/to/dataset", split="train", transform=transform_list
)

Now we wrap our dataset in a `torch.utils.data.DataLoader` instance, and set the desired batch size.

In [None]:
# Set desired batch size & load dataset with torch.utils.data.DataLoader
image_loader = torch.utils.data.DataLoader(
    image_dataset,
    batch_size=32,
    shuffle=True,
)
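To see what the loader yields without downloading ImageNet, the following minimal sketch wraps a small synthetic dataset in a `DataLoader`. The names `fake_images`, `fake_labels`, `fake_dataset`, and `fake_loader` are hypothetical stand-ins and are not part of the tutorial's pipeline; the random tensors only mimic the shape that `Resize((224, 224))` followed by `ToTensor()` would produce.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for ImageNet: 10 random "images" with integer labels.
fake_images = torch.rand(10, 3, 224, 224)
fake_labels = torch.randint(0, 1000, (10,))
fake_dataset = TensorDataset(fake_images, fake_labels)

# Same pattern as above, with a smaller batch size for illustration.
fake_loader = DataLoader(fake_dataset, batch_size=4, shuffle=True)

# Each iteration yields an (images, labels) pair; the last batch may be smaller.
batch_shapes = [tuple(images.shape) for images, _ in fake_loader]
print(batch_shapes)
```

With 10 items and a batch size of 4, the loader yields two full batches and one partial batch of 2 images.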

We load our model, then set the desired model target layers and corresponding file names.

In [None]:
# Model to collect samples from, what layers of the model to collect samples from,
# and the desired names to use for the target layers.
sample_model = (
    googlenet(
        pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True
    )
    .eval()
    .to(device)
)
sample_targets = [sample_model.mixed4c_relu]
sample_target_names = ["mixed4c_relu_samples"]

By default, the activation samples will not have the right class attributions, so we remedy this by loading a second instance of our model. In this second instance, we replace all `nn.MaxPool2d` layers with Captum's `MaxPool2dRelaxed` layer. The relaxed max pooling layer lets us estimate each sample's class attributions by measuring the rate at which increasing a neuron's activation affects the output classes.

In [None]:
# Optionally collect attributions from a copy of the first model that's
# been set up with relaxed pooling layers.
if collect_attributions:
    sample_model_attr = (
        googlenet(
            pretrained=True, replace_relus_with_redirectedrelu=False, bgr_transform=True
        )
        .eval()
        .to(device)
    )
    opt.models.replace_layers(
        sample_model_attr,
        torch.nn.MaxPool2d,
        opt.models.MaxPool2dRelaxed,
        transfer_vars=True,
    )
    sample_attr_targets = [sample_model_attr.mixed4c_relu]
    sample_logit_target = sample_model_attr.fc
else:
    sample_model_attr = None
    sample_attr_targets = None
    sample_logit_target = None
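To build intuition for the layer swap above, here is a minimal sketch of one way a relaxed pooling layer could work: the forward pass returns max-pooled values, while gradients flow as if the layer were average pooling. This is an illustration under that assumption, not Captum's actual `MaxPool2dRelaxed` or `replace_layers` code. `RelaxedMaxPoolSketch` and `swap_max_pools` are hypothetical names, and options such as `ceil_mode` are omitted for brevity.

```python
import torch
import torch.nn.functional as F


class RelaxedMaxPoolSketch(torch.nn.Module):
    """Illustrative only: max-pool forward values, avg-pool gradients."""

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        max_out = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
        avg_out = F.avg_pool2d(x, self.kernel_size, self.stride, self.padding)
        # The detached term carries no gradient, so the output equals max_out
        # while backpropagation only sees the average pooling path.
        return avg_out + (max_out - avg_out).detach()


def swap_max_pools(module):
    """Hypothetical helper: recursively replace nn.MaxPool2d children in place."""
    for name, child in list(module.named_children()):
        if isinstance(child, torch.nn.MaxPool2d):
            relaxed = RelaxedMaxPoolSketch(
                child.kernel_size, child.stride, child.padding
            )
            setattr(module, name, relaxed)
        else:
            swap_max_pools(child)


# Small demonstration on a 4x4 input with non-overlapping 2x2 windows.
x = torch.arange(16.0).reshape(1, 1, 4, 4).requires_grad_(True)
out = RelaxedMaxPoolSketch(kernel_size=2)(x)
out.sum().backward()

# Swap the pooling layer inside a tiny stand-in network.
net = torch.nn.Sequential(torch.nn.Conv2d(1, 2, 3, padding=1), torch.nn.MaxPool2d(2))
swap_max_pools(net)
```

In the demonstration, `out` matches ordinary max pooling, but every input position receives the same gradient share instead of only the maximum in each window, which gives a smoother attribution signal.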

With our dataset loaded and models ready to go, we can now start collecting our samples. To make sample collection easier, we can use Captum's `capture_activation_samples` function, which randomly samples an x and y position for every image from each of the specified target layers.

In [None]:
# Directory to save sample files to; create it if it doesn't already exist
sample_dir = "inceptionv1_samples"
os.makedirs(sample_dir, exist_ok=True)

# Collect samples & optionally attributions as well
opt.dataset.capture_activation_samples(
    loader=image_loader,
    model=sample_model,
    targets=sample_targets,
    target_names=sample_target_names,
    attr_model=sample_model_attr,
    attr_targets=sample_attr_targets,
    input_device=device,
    sample_dir=sample_dir,
    show_progress=True,
    collect_attributions=collect_attributions,
    logit_target=sample_logit_target,
)
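The random spatial sampling described above can be sketched as follows. This is only an illustration of the idea of picking one x and y position per image; it is not the internals of `capture_activation_samples`. The function name `sample_random_positions` and the example tensor shape are assumptions made for this sketch.

```python
import torch


def sample_random_positions(activations):
    """Pick one random (y, x) position per image from a [B, C, H, W] tensor."""
    batch, _, height, width = activations.shape
    ys = torch.randint(0, height, (batch,))
    xs = torch.randint(0, width, (batch,))
    batch_idx = torch.arange(batch)
    # Indexing with one (y, x) pair per image keeps the full channel vector
    # at that position, giving a tensor of shape [B, C].
    return activations[batch_idx, :, ys, xs]


# Stand-in layer activations; the shape is chosen purely for illustration.
activations = torch.rand(8, 512, 14, 14)
samples = sample_random_positions(activations)
print(samples.shape)
```

Each row of `samples` is the channel vector at a single spatial location of the corresponding image, so one image contributes one sample per target layer.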

Now that we've collected our samples, we need to combine them into a single tensor. Below we use the `consolidate_samples` function to load each list of tensor samples, and then concatenate them into a single tensor.

In [None]:
# Combine our newly collected samples into single tensors.
# We load the sample tensors from sample_dir and then
# concatenate them.

for name in sample_target_names:
    activation_samples = opt.dataset.consolidate_samples(
        sample_dir=sample_dir,
        sample_basename=name + "_activations",
        dim=1,
        show_progress=True,
    )
    if collect_attributions:
        sample_attributions = opt.dataset.consolidate_samples(
            sample_dir=sample_dir,
            sample_basename=name + "_attributions",
            dim=0,
            show_progress=True,
        )

    # Save the results, separating the target name from the suffix
    torch.save(activation_samples, name + "_activation_samples.pt")
    if collect_attributions:
        torch.save(sample_attributions, name + "_attribution_samples.pt")
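The load-and-concatenate step can be illustrated with plain PyTorch. The sketch below does not reproduce `consolidate_samples`; the file names (`demo_0.pt` and so on) and tensor shapes are hypothetical, and it only shows how the `dim` argument of `torch.cat` decides which axis the saved chunks are joined along.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp_dir:
    # Write three hypothetical sample files, each holding a [4, 5] tensor.
    for i in range(3):
        chunk = torch.full((4, 5), float(i))
        torch.save(chunk, os.path.join(tmp_dir, f"demo_{i}.pt"))

    # Load the files back in a stable order.
    paths = sorted(
        os.path.join(tmp_dir, f) for f in os.listdir(tmp_dir) if f.endswith(".pt")
    )
    chunks = [torch.load(p) for p in paths]

# Joining along dim=0 stacks the chunks as extra rows ...
combined_rows = torch.cat(chunks, dim=0)
# ... while dim=1 places them side by side as extra columns.
combined_cols = torch.cat(chunks, dim=1)
print(combined_rows.shape, combined_cols.shape)
```

Choosing the wrong `dim` still produces a valid tensor when the shapes happen to line up, so it is worth checking the consolidated shape before using the samples.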