# Feature Visualization Learning Journey

After reading [Zoom In: An Introduction to Circuits](https://distill.pub/2020/circuits/zoom-in/), I was really curious how some of its visuals were generated! It seemed like the authors were "asking" the model what it was looking for, and I wanted to know how I could have a similar "conversation" with a model. With this motivation, I started trying to generate visualizations similar to those in the paper. However, I don't have much experience with PyTorch or with neural networks in general, so I ran into some issues along the way. I figured a brief description of the "journey" I took could help others starting from a similar place get started conversing with models too!

Note: I won't cover neural network basics, so it may be best to learn the basics before following the steps outlined here!

## Boring Stuff

Just some boring imports to get started!

In [None]:
import random

import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

And here are some basic parameters used throughout the notebook. Also pretty boring.

In [None]:
img_size = 224
num_channels = 3
num_itrs = 512
device = "cpu"

The feature visualizations all start out as random images, which we'll generate like so:

In [None]:
img = torch.randn(1, num_channels, img_size, img_size).to(device)

To view images, we'll use the utility function below:

In [None]:
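def show_img(img):
    # Move to the CPU (matplotlib can't read GPU tensors), drop the batch
    # dimension, and reorder CHW -> HWC as imshow expects
    img = img.detach().cpu().squeeze(dim=0).permute(1, 2, 0)
    # Float images are displayed in the [0, 1] range, so clamp anything outside it
    img = torch.clamp(img, 0, 1)
    plt.imshow(img.numpy())
    plt.axis('off')
    plt.show()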

Now let's visualize the image we generated above!

In [None]:
show_img(img)

Okay sweet! Now that we can generate and view images, let's get to generating feature visualizations.

## Load Model

Here's the first interesting bit: loading the model we'll use throughout the notebook. The model is [ResNet18](https://pytorch.org/hub/pytorch_vision_resnet/) ([original paper](https://arxiv.org/abs/1512.03385)), pretrained on ImageNet. The techniques should apply to any pretrained image model, but we'll use ResNet18 because it's small and thus runs reasonably quickly.

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).to(device)
# Freeze the weights: we only want gradients with respect to the input image
for param in model.parameters():
    param.requires_grad_(False)

## Basic Feature Visualization

To visualize features, we'll iteratively "tweak" the image (adjust it based on the gradients) so that it activates a part of the network more strongly. To do this, we'll perform the following steps on each iteration:
1. Perform a forward pass of the model with the current image
2. Compute the negative mean activation of the channel of interest (this is our loss function; the optimizer minimizes the loss, so minimizing the negative activation maximizes the activation)
3. Calculate the gradients of the loss with respect to the image
4. Update the image using those gradients
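Before wiring this into ResNet18, here is a minimal, self-contained sketch of the four steps on a tiny stand-in network. The network `toy_model`, the channel index, and the image size are all hypothetical choices made so the snippet runs quickly without downloading weights; the loop itself mirrors the steps above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network standing in for ResNet18 (random weights, no download)
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
for param in toy_model.parameters():
    param.requires_grad_(False)  # only the image is optimized
toy_model.eval()

channel_idx = 2  # arbitrary channel to visualize
img = torch.randn(1, 3, 32, 32, requires_grad=True)
optimizer = torch.optim.Adam([img], lr=0.05)

def channel_mean(x):
    # Mean activation of the chosen channel over batch and spatial positions
    return toy_model(x)[:, channel_idx].mean()

with torch.no_grad():
    start = channel_mean(img).item()

for _ in range(100):
    optimizer.zero_grad()
    out = toy_model(img)                # 1. forward pass with the current image
    loss = -out[:, channel_idx].mean()  # 2. negative mean activation as the loss
    loss.backward()                     # 3. gradients with respect to the image
    optimizer.step()                    # 4. update the image

with torch.no_grad():
    end = channel_mean(img).item()

print(f"mean activation: {start:.3f} -> {end:.3f}")
```

The mean activation should go up over the loop; the exact numbers depend on the random weights.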

The first thing we need is a way to get the output of the layer we want to visualize. We'll store the outputs in a dictionary:

In [None]:
activations = {}

And we will use a hook to capture the outputs (we'll [register the hook](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) to the forward pass of the layer of interest later):

In [None]:
def get_layer_hook(layer_name):
    def hook(module, input, output):
        activations[layer_name] = output
    return hook
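To see what the hook actually captures, here is a small sketch on a hypothetical two-layer model (`net` and the `"conv"` key are illustrative only; `get_layer_hook` is re-declared so the snippet runs on its own). It also shows why the handle returned by `register_forward_hook` matters: after `handle.remove()`, forward passes stop writing to the dictionary.

```python
import torch
import torch.nn as nn

activations = {}

def get_layer_hook(layer_name):
    # Same pattern as the notebook: stash the layer's output under its name
    def hook(module, input, output):
        activations[layer_name] = output
    return hook

# Hypothetical model used only to show the mechanics
net = nn.Sequential(nn.Conv2d(3, 6, kernel_size=3, padding=1), nn.ReLU())
handle = net[0].register_forward_hook(get_layer_hook("conv"))

net(torch.randn(1, 3, 16, 16))
captured_shape = activations["conv"].shape
print(captured_shape)  # torch.Size([1, 6, 16, 16])

# Once the hook is removed, later forward passes no longer update the dictionary
handle.remove()
activations.clear()
net(torch.randn(1, 3, 16, 16))
print("conv" in activations)  # False
```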

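Now that we have a way to capture the outputs, we need a way to compute the loss. PyTorch optimizers minimize, so we return the *negative* mean activation of the channel; minimizing it maximizes the activation: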

In [None]:
def get_loss_fcn(layer, channel_idx):
    def loss_fcn():
        act = activations[layer]
        if act.dim() == 2:
            # Fully connected output: shape (batch, features)
            layer_activations = -act[:, channel_idx]
        else:
            # Convolutional output: shape (batch, channels, height, width)
            layer_activations = -act[:, channel_idx, :, :]
        return layer_activations.mean()
    return loss_fcn
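A quick worked example of what this loss computes, using made-up tensors shaped like a convolutional output and a fully connected output (the values are arbitrary). A higher channel activation gives a lower loss, which is exactly what the optimizer pushes toward.

```python
import torch

# Fake conv activation, shape (batch=1, channels=2, height=2, width=2)
act = torch.tensor([[[[0.0, 0.0],
                      [0.0, 0.0]],
                     [[1.0, 2.0],
                      [3.0, 4.0]]]])

# Channel 1 has mean (1 + 2 + 3 + 4) / 4 = 2.5, so the loss is -2.5
loss = -act[:, 1, :, :].mean()
print(loss.item())  # -2.5

# Fake fc output, shape (batch=1, classes=3): the loss is just the negated logit
logits = torch.tensor([[0.5, -1.0, 3.0]])
fc_loss = -logits[:, 2].mean()
print(fc_loss.item())  # -3.0
```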

With the loss function above, we can write a function for a single optimization step:

In [None]:
def opt_step(model, img, optimizer, loss_fcn):
    optimizer.zero_grad()

    # 1. Perform a forward pass of the model with the current image
    model(img)

    # 2. Compute the loss: the negative mean activation of the channel of interest
    loss = loss_fcn()
    # 3. Calculate the gradients
    loss.backward()

    # 4. Update the image using the gradients of the loss function
    optimizer.step()

Let's put it all together!

In [None]:
def opt_img(model, img, layer_name, channel_idx):
    # Register the hook for capturing the layer output
    layer = getattr(model, layer_name)
    handle = layer.register_forward_hook(get_layer_hook(layer_name))

    model.eval()
    # Run a forward pass of the model to populate the activations dictionary.
    # This is required so we can enable the gradients on the layer output.
    pred = model(img)
    
    # I initially didn't include this, and it took me a little while to figure out why
    # my images were not updating XD.
    activations[layer_name] = activations[layer_name].requires_grad_().to(device)

    # We'll use the Adam optimizer
    optimizer = torch.optim.Adam([img], lr=0.05)

    # Get loss function for the layer and channel of interest
    loss_fcn = get_loss_fcn(layer_name, channel_idx)

    # Iteratively tweak the image
    for i in range(num_itrs):
        opt_step(model, img, optimizer, loss_fcn)

    handle.remove()
    return img
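If the image doesn't seem to update, one thing worth checking is that the optimizer can only change `img` if it is a leaf tensor with `requires_grad=True`. The sketch below (tensor sizes and the stand-in loss `-img.mean()` are illustrative assumptions) shows that creating the tensor with `device=...` keeps it a leaf, that Adam rejects a non-leaf tensor, and that a single step really changes the image.

```python
import torch

device = "cpu"  # same setting as the notebook

# Creating the tensor directly on the target device keeps it a leaf tensor.
# (On a GPU, torch.randn(..., requires_grad=True).to("cuda") returns a
# non-leaf copy, which the optimizer refuses to update.)
img = torch.randn(1, 3, 8, 8, device=device, requires_grad=True)
print(img.is_leaf, img.requires_grad)  # True True

# Any operation on a grad-requiring tensor produces a non-leaf result
non_leaf = torch.randn(1, 3, 8, 8, requires_grad=True) * 2
caught = False
try:
    torch.optim.Adam([non_leaf], lr=0.05)
except ValueError as err:
    caught = True
    print("optimizer error:", err)

# One optimization step should actually change the leaf image
optimizer = torch.optim.Adam([img], lr=0.05)
before = img.detach().clone()
loss = -img.mean()  # stand-in loss, for illustration only
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(torch.equal(before, img.detach()))  # False
```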

Now that our basic feature visualization generator is complete, we can create our first visualization! We'll start with one of the neurons in the model's last layer, because these are often the easiest to interpret: each neuron corresponds to one of the output classes. First, we need to find the name of the last layer!

In [None]:
print(list(map(lambda x: x[0], model.named_children())))

Okay, it's "fc". Let's generate an image!

In [None]:
# Create the image directly on the device so it stays a leaf tensor the optimizer can update
img = torch.randn(1, num_channels, img_size, img_size, device=device, requires_grad=True)
out_img = opt_img(model, img, "fc", 9)
show_img(out_img)

This is probably not what you were hoping for. I found myself somewhat stuck at this point! The visualizations had some order to them, but they still looked mostly like noise. After a while spent trying to figure out why mine weren't nearly as interesting as the ones I had seen others generate, I found a particularly useful section of the [Feature Visualization](https://distill.pub/2017/feature-visualization/) paper: [The Enemy of Feature Visualization](https://distill.pub/2017/feature-visualization/#enemy-of-feature-vis). The technique I tried from that section was applying a slight random transformation to the image before each optimization step, so let's see how that impacts the visualizations.

## Transformation Robustness

To transform the image before each optimization step, we need to modify our `opt_img` function a bit:

In [None]:
def opt_img(model, img, layer_name, channel_idx, transforms=None):
    layer = getattr(model, layer_name)
    handle = layer.register_forward_hook(get_layer_hook(layer_name))

    model.eval()
    pred = model(img)
    activations[layer_name] = activations[layer_name].requires_grad_().to(device)

    optimizer = torch.optim.Adam([img], lr=0.05)

    loss_fcn = get_loss_fcn(layer_name, channel_idx)

    for i in range(num_itrs):
        # Here's the new bit!
        if transforms is not None:
            tform_img = transforms(img)
        else:
            tform_img = img
        opt_step(model, tform_img, optimizer, loss_fcn)

    handle.remove()
    return img

The paper includes three different transformations: jittering, rotation, and scaling. Let's use all three!

In [None]:
transforms = torchvision.transforms.Compose([
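    # jitter: pad 8 px, then crop back to img_size at a random offset.
    # (padding=random.randint(0, 8) would be drawn only once, when Compose is built,
    # and could even be 0, i.e. no jitter at all.)
    torchvision.transforms.RandomCrop(img_size, padding=8),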
    torchvision.transforms.RandomRotation((-45, 45)), # rotate
    torchvision.transforms.RandomResizedCrop(img_size, scale=(0.9, 1.2), ratio=(1.0, 1.0)) # scale
])
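As an alternative to the torchvision pipeline, jitter can also be written by hand. Below is a minimal sketch of a hypothetical `random_jitter` helper that shifts the image with `torch.roll`, drawing a fresh offset on every call. Note that it wraps pixels around the border instead of padding, a different choice from `RandomCrop`.

```python
import random

import torch

def random_jitter(img, max_shift=8):
    # Draw a new random offset on every call, so each optimization step
    # sees a differently shifted image
    dx = random.randint(-max_shift, max_shift)
    dy = random.randint(-max_shift, max_shift)
    # torch.roll wraps pixels around the border; it is differentiable,
    # so gradients flow back to the original image
    return torch.roll(img, shifts=(dy, dx), dims=(2, 3))

img = torch.randn(1, 3, 32, 32, requires_grad=True)
shifted = random_jitter(img)
shifted.mean().backward()
print(shifted.shape)             # torch.Size([1, 3, 32, 32])
print(img.grad.abs().sum() > 0)  # tensor(True)
```

Since `opt_img` simply calls `transforms(img)`, a function like this could be passed as the `transforms` argument.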

In [None]:
img = torch.randn(1, num_channels, img_size, img_size, device=device, requires_grad=True)
out_img = opt_img(model, img, "fc", 9, transforms)
show_img(out_img)

This seems much more interesting! Can you figure out what it is?

It's an ostrich! It doesn't look *exactly* like an ostrich, but you can at least see that it has some ostrich features: you can kind of make out the long legs; the dark, round body; and the long neck!

Now that we can generate visualizations like this, let's explore some of the hidden layers.

## Hidden Layer Visualization

Let's start out by generating visualizations without transformations, so we can see the impact of transformations on hidden layers!

In [None]:
img = torch.randn(1, num_channels, img_size, img_size, device=device, requires_grad=True)
out_img = opt_img(model, img, "layer2", 71)
show_img(out_img)

There's definitely some order to this image, but there's still a lot of noise, so let's try again with transformations.

In [None]:
img = torch.randn(1, num_channels, img_size, img_size, device=device, requires_grad=True)
out_img = opt_img(model, img, "layer2", 71, transforms)
show_img(out_img)

Again, it's much better!
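The notebook looks layers up with `getattr(model, layer_name)`, which only reaches top-level children such as `layer2`. To hook a module nested inside one of them, `nn.Module.get_submodule` accepts a dotted path. Here is a minimal sketch on a hypothetical nested model (the names `stem` and `layer1` are made up to mimic ResNet's naming). For ResNet18 itself, a path like `"layer2.1.conv2"` should point at the second convolution of the second block in `layer2`, but it's worth confirming against `model.named_modules()`.

```python
import torch
import torch.nn as nn

# Hypothetical nested model mimicking ResNet's "layerN.<block>.<module>" naming
net = nn.Sequential()
net.add_module("stem", nn.Conv2d(3, 4, kernel_size=3, padding=1))
net.add_module("layer1", nn.Sequential(
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
))

# Dotted names of every submodule (the root module has the empty name, so skip it)
names = [name for name, _ in net.named_modules() if name]
print(names)  # ['stem', 'layer1', 'layer1.0', 'layer1.1', 'layer1.2']

# getattr(net, "layer1.2") raises AttributeError; get_submodule resolves the path
target = net.get_submodule("layer1.2")
captured = {}
handle = target.register_forward_hook(lambda m, i, o: captured.update(out=o))
net(torch.randn(1, 3, 16, 16))
handle.remove()
print(captured["out"].shape)  # torch.Size([1, 8, 8, 8])
```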

## Where can this go?

This is great! We were able to generate some feature visualizations! With the code above, we can generate visualizations for different parts of the model and begin to "ask" the model what it's looking for, so we can interpret what the model knows and how it makes its decisions. To further improve the quality of the images, there are a number of other techniques in the [Feature Visualization](https://distill.pub/2017/feature-visualization) paper that can be applied. Beyond improving the visualizations themselves, these techniques can be combined with others to improve our understanding of neural networks; examples can be seen in [The Building Blocks of Interpretability](https://distill.pub/2018/building-blocks/) and in other papers. The feature visualizations we generated are just the start of interpreting image models! While generating them is interesting, what can be done with them is even more so!
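As one example of those other techniques, the Feature Visualization paper discusses frequency penalization, which targets high-frequency noise by penalizing differences between neighboring pixels (total variation). Here is a minimal sketch of such a penalty; the helper name `total_variation` and the use of mean absolute differences are my own choices, not taken from the paper's code.

```python
import torch

def total_variation(img):
    # Mean absolute difference between vertically and horizontally adjacent
    # pixels: noisy, high-frequency images score high, smooth ones score low
    dy = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean()
    dx = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean()
    return dx + dy

flat = torch.ones(1, 1, 2, 2)
stripes = torch.tensor([[[[0.0, 1.0],
                          [0.0, 1.0]]]])
print(total_variation(flat).item())     # 0.0
print(total_variation(stripes).item())  # 1.0
```

To use it, `opt_step` would need a small change so the loss becomes something like `loss_fcn() + tv_weight * total_variation(img)`, where `tv_weight` is a hypothetical hyperparameter that would need tuning.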