# Interpreting Neural Networks

## What?

Much of the machine learning you have studied so far in this course revolves around loss functions and evaluation metrics, i.e., objective and well-defined measures of performance.

Interpretability, however, is not easy to define objectively. [Miller (2017)](https://arxiv.org/abs/1706.07269) defines interpretability as the "degree to which a human can understand the cause of a decision (made by an ML model)".

There are both merits and demerits to this definition, but it is a good starting point. The overall goal does seem to be to make sense, in some desired way, of how and why a model yields some output. But why do we want to do that? It turns out there are a lot of good reasons.

## Why?

Here are some reasons why we'd like to interpret AI:

- Practically, AI is being deployed in a lot of high-stakes areas such as healthcare, law, and defense, and interpretability is important for improving trust in a model, as well as for establishing accountability.
  - 🧠: If a model incorrectly identifies a disease and suggests incorrect medication, who is at fault?
- To come up with better models. If we know why a model gives an incorrect answer, it is likely that we can "fill those gaps" more easily.
- To prevent long-term existential risk from advanced AI.
  - 🧠: Read up on [Open Philanthropy](https://www.openphilanthropy.org/) and the [Center on Long-Term Risk](https://longtermrisk.org/) and their work in this direction.
- As a somewhat long-term, meta goal, to understand intelligence and learning as a phenomenon in nature.

## How?

Neural networks are hard to interpret because of their large number of parameters and the dense connections between them. How can we approach these huge models, with millions, billions, or even trillions of parameters, and try to make sense of how they work? In all honesty, we can't yet do so in the most general sense. But a number of insightful and innovative methods have been proposed over the last few years, some of which we'll explore today.

🧠: Let's say someone says this: "It's easy to understand how a network works. It learns some weights through backpropagation (which is well-defined) and then during inference it multiplies the learned weights with the input and gets the output. QED." What would you say? Have they interpreted the network? Hint: Does this trivial argument help us answer the questions we were asking about why we want to do interpretability? Why or why not?

🧠: Note that while a model's predictions can be scored with an objective performance measure (e.g., accuracy), no such measure exists for interpretability. That is, we don't know how to decide whether one interpretation is better than another. Naturally, this will also depend on the type of interpretation. We will look at some cases. Based on "why" we want to do interpretability, can you think of some properties that any "interpretation evaluator" should have?

Let's get started! We'll load a simple ResNet model trained on ImageNet-1K (🧠: read up on these if needed; they are both very standard):

In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import ResNet50_Weights
import plotly.graph_objs as go
import plotly.express as px
import plotly.subplots as sp
model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
categories = ResNet50_Weights.DEFAULT.meta['categories']
print(f'Loaded ResNet50 with {round(sum(p.numel() for p in model.parameters()) / 1000000, 1)}M parameters! Happy Interpreting!!')


Loaded ResNet50 with 25.6M parameters! Happy Interpreting!!


## Saliency Maps!

One aspect of interpreting a model is to look at the input and ask the following: 

"What part(s) of this input X caused the model to predict Y?"

While this arguably does not explain how the entire computation happened, it is still quite informative in some cases. For instance, for a CNN that counts the number of spots on a medical scan, it can tell us which locations the model was looking at, which directly helps us understand its false positives and false negatives.

"Saliency" is a fancy word that means "conspicuous", "apparent", or "important". In this context, a saliency map is a map from input pixels (in the case of images) to their saliency or importance toward a particular prediction by a trained model. We will look at a neat implementation of the basic idea in [this paper](https://arxiv.org/abs/1312.6034) (🧠).

TODO: add some simple theory for saliency.
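
A brief sketch of the theory, following the idea of the paper linked above (the notation here is an assumption and may differ from the paper's): let $S_c(I)$ be the score the model assigns to class $c$ for an image $I$. Around a given image $I_0$, a first-order Taylor expansion approximates the score by a linear function of the pixels:

$$
S_c(I) \approx w^\top I + b, \qquad w = \left.\frac{\partial S_c}{\partial I}\right|_{I_0}
$$

The magnitude of each entry of $w$ indicates how strongly a small change in that pixel value changes the class score. For an RGB image, one saliency value per pixel $(i, j)$ can be obtained by taking the maximum magnitude over the colour channels $k$:

$$
M_{ij} = \max_{k \in \{R, G, B\}} \left| w_{k,i,j} \right|
$$

The `get_saliency` function further down computes this quantity with a single backward pass and then rescales the map to $[0, 1]$ for plotting.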

In [2]:
# some pre-processing (note that this converts an image to a pytorch tensor)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
inv_normalize = transforms.Normalize(mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], std=[1/0.229, 1/0.224, 1/0.225])
transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize])
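
The inverse transform can itself be written as a `Normalize` with mean `-mean/std` and std `1/std`. Here is a minimal NumPy sketch of that arithmetic, checking that the round trip recovers the input when the same `std` values are used in both directions. The helper names `normalize_np` and `inv_normalize_np` are hypothetical and only mimic the per-channel formula `(x - mean) / std`:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def normalize_np(x):
    # same arithmetic as a per-channel Normalize: (x - mean) / std
    return (x - mean[:, None, None]) / std[:, None, None]

# the inverse expressed as another "Normalize" with mean = -mean/std and std = 1/std
inv_mean, inv_std = -mean / std, 1.0 / std

def inv_normalize_np(y):
    return (y - inv_mean[:, None, None]) / inv_std[:, None, None]

rng = np.random.default_rng(0)
x = rng.random((3, 4, 4))  # a stand-in "image" with values in [0, 1)
roundtrip = inv_normalize_np(normalize_np(x))
print(np.allclose(roundtrip, x))
```

If any single constant differs between the forward and inverse transforms, the corresponding channel will come back slightly off, which shows up as a colour tint when the de-normalized image is plotted.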

In [3]:
def get_saliency(input, model):
    model.eval()
    # accept either a PIL image or an already pre-processed (1, 3, H, W) tensor
    if not isinstance(input, torch.Tensor):
        input = transform(input).unsqueeze(0)
    # work on a fresh leaf tensor so that input.grad gets populated
    input = input.detach().clone().requires_grad_(True)
    preds = model(input)
    # score (logit) of the top predicted class
    score, _ = torch.max(preds, 1)
    # get gradients of that score w.r.t. the image
    score.sum().backward()
    # get max absolute gradient along the channel axis
    saliency_map, _ = torch.max(torch.abs(input.grad[0]), dim=0)
    # normalize to [0, 1]
    saliency_map = (saliency_map - saliency_map.min()) / (saliency_map.max() - saliency_map.min())
    return saliency_map
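
As a sanity check of the gradient-based recipe, here is a minimal sketch on a hypothetical toy model (a single linear layer, not the ResNet used in this notebook). For a linear model the gradient of a class score with respect to the input is exactly that class's weight row, so the saliency map can be compared against a known answer:

```python
import torch

torch.manual_seed(0)

# hypothetical toy "model": one linear layer on a flattened 3x4x4 image
toy_model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 4 * 4, 5))
toy_model.eval()

x = torch.rand(1, 3, 4, 4, requires_grad=True)
logits = toy_model(x)
score, idx = torch.max(logits, 1)  # top-class score and its index
score.sum().backward()             # gradient of that score w.r.t. the input

# max of the absolute gradient over the channel axis -> one value per pixel
toy_saliency, _ = torch.max(torch.abs(x.grad[0]), dim=0)

# for a linear model, the gradient equals the weight row of the chosen class
w = toy_model[1].weight[idx.item()].detach().reshape(3, 4, 4)
expected = w.abs().max(dim=0).values
print(torch.allclose(toy_saliency, expected))
```

For a deep network the gradient depends on the input image, so no such closed form exists, but the mechanics (one forward pass, one backward pass, channel-wise max of the absolute gradient) are the same.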

In [11]:
from PIL import Image
img = Image.open('strawberry.jpeg').convert('RGB')
img = img.resize((224, 224))
img_array = np.array(img)
fig = go.Figure(go.Image(z=img_array))
fig.update_layout(title_text="Original Image", width=400, height=400)
fig.show()

In [12]:
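model.eval()  # inference mode; note that the "Confidence" printed below is the raw logit, not a probability
with torch.no_grad():
    vals, idxs = model(transform(img).unsqueeze_(0))[0].topk(5)
for i in range(5):  print(f'Prediction: {categories[idxs[i]]:<15}\t\tConfidence: {round(float(vals[i]), 2)}')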

Prediction: strawberry     		Confidence: 6.68
Prediction: chocolate sauce		Confidence: 1.06
Prediction: strainer       		Confidence: 0.94
Prediction: scale          		Confidence: 0.92
Prediction: banana         		Confidence: 0.87
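
The values printed above are raw model outputs (logits), which is why they are not confined to [0, 1]. To read them as probabilities, one would apply a softmax over all the class logits. A minimal sketch on hypothetical toy logits (three made-up numbers, not taken from the model above):

```python
import torch

# hypothetical logits for a 3-class toy problem
toy_logits = torch.tensor([2.0, 1.0, 0.1])
toy_probs = torch.softmax(toy_logits, dim=0)

print(toy_probs)        # non-negative values, largest for the largest logit
print(toy_probs.sum())  # sums to 1 (up to floating-point error)
```

For the ResNet, the softmax has to be taken over the full 1000-dimensional output before selecting the top 5; applying it to the five largest logits alone would overstate the probabilities.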


In [13]:
sal = get_saliency(transform(img).unsqueeze_(0), model)
fig2 = px.imshow(sal ** 0.6)
fig2.update_layout(title_text="Saliency Map", width=400, height=400)

🧠: Plot the image and its ResNet50 saliency map side by side (Plotly practice!). Try to find all the high-saliency areas and explain why they might be important for the prediction.

🧠: Try it out on other images! You can take inspiration from a thousand categories that ImageNet1K has.

🧠: Mix two different images (just take the average of the RGB values). What does the result look like? What do the predictions look like? What does the saliency map look like? Is everything as expected?
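
One practical gotcha for the mixing exercise: image arrays obtained from PIL are `uint8`, and adding two of them directly wraps around at 256. A minimal sketch with two stand-in arrays (hypothetical constant-colour images; real ones would come from `np.array(...)` on two PIL images of the same size):

```python
import numpy as np

# two stand-in uint8 "images" of the same size
img_a = np.full((224, 224, 3), 200, dtype=np.uint8)
img_b = np.full((224, 224, 3), 100, dtype=np.uint8)

# a naive img_a + img_b would wrap around in uint8 arithmetic (200 + 100 -> 44),
# so convert to float before averaging and back to uint8 afterwards
blend = ((img_a.astype(np.float32) + img_b.astype(np.float32)) / 2).astype(np.uint8)
print(blend[0, 0])
```

The blended array can be turned back into a PIL image with `Image.fromarray(blend)` and then passed through the same `transform` as any other image.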

🧠: Play around with the saliency function. Can you find the parts of this image that make the model think it is a dog? Or a spider?

## Visualizing CNNs!

Let us look at some parts of our learned model and try to figure out what they have learned.

In [14]:
model.conv1

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [15]:
model.conv1.weight.data.shape

torch.Size([64, 3, 7, 7])

In [16]:
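# move channels last so that each filter becomes a 7x7 RGB "image"; the * 1000 is a crude scale-up to make the small weights visible
conv1_visual = model.conv1.weight.data.permute(0, 2, 3, 1).reshape(64, 7, 7, 3) * 1000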
fig = sp.make_subplots(rows=8, cols=8, subplot_titles=[f"Filter {i}" for i in range(64)])
for i in range(64):
    row, col = i // 8 + 1, i % 8 + 1
    fig.add_trace(go.Image(z=conv1_visual[i].numpy()), row=row, col=col)
fig.update_layout(title_text="CNN Visualization", showlegend=False, height=1000, width=1000)
fig.show()

In [17]:
x = transform(img).unsqueeze_(0)
x = model.conv1(x)
# x = model.bn1(x)
# x = model.relu(x)
# x = model.maxpool(x)
# x = model.layer1(x)
# x = model.layer2(x)
# x = model.layer3(x)
# x = model.layer4(x)
# x = model.avgpool(x)
# x = torch.flatten(x, 1)
# x = model.fc(x)

print(x.shape)

torch.Size([1, 64, 112, 112])
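
The 112 × 112 spatial size follows from the standard output-size formula for a convolution without dilation. A minimal sketch (the helper name `conv_out_size` is hypothetical), using the `conv1` settings printed earlier:

```python
def conv_out_size(n, kernel_size, stride, padding):
    # floor((n + 2 * padding - kernel_size) / stride) + 1
    return (n + 2 * padding - kernel_size) // stride + 1

# conv1: kernel_size=7, stride=2, padding=3, applied to a 224x224 input
print(conv_out_size(224, kernel_size=7, stride=2, padding=3))  # 112
```

The 64 in the output shape is simply the number of filters in `conv1`, one output channel per filter.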


In [18]:
px.imshow(x[0][23].detach())

In [20]:
px.imshow(x[0][58].detach())

## What "concepts" do these filters learn to represent?

Let's try to consider one of them: Filter 58.



🧠: Look at some other interesting filters and try to figure out what they represent (get activated by), and then look at the output of the image when passed through those filters. Some pointers to interesting filters: 57, 58, 29.

🧠: Can you extend this practice to deeper layers? Can you visualize the filters in layer 1 of the model and look at their activation outputs on our image?

🧠: What sort of challenges would arise if one were to try and do this for the whole model?

🧠: Play around with [CNN Explainer](https://poloclub.github.io/cnn-explainer/).
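
For the exercises above on deeper layers, running the model layer by layer by hand gets tedious quickly. One possible alternative is PyTorch's forward hooks, which record a layer's output during a normal forward pass. A minimal sketch on a hypothetical small CNN (`toy_cnn`, `activations`, and `save_activation` are illustrative names, not part of this notebook):

```python
import torch
import torch.nn as nn

# hypothetical small CNN standing in for a real model
toy_cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)
toy_cnn.eval()

activations = {}

def save_activation(name):
    # returns a hook that stores the layer's output under the given name
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

handles = [
    layer.register_forward_hook(save_activation(name))
    for name, layer in toy_cnn.named_children()
]

with torch.no_grad():
    toy_cnn(torch.rand(1, 3, 32, 32))

for name, act in activations.items():
    print(name, tuple(act.shape))

# remove the hooks once done so they do not fire on later forward passes
for h in handles:
    h.remove()
```

The same pattern should carry over to the ResNet by registering a hook on a submodule such as `model.layer1`; each stored activation can then be plotted channel by channel with `px.imshow`, as was done for `conv1`.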

## Circuits!

🧠: Read [Zoom In: An Introduction to Circuits](https://distill.pub/2020/circuits/zoom-in/) by Chris Olah and others.

We'll just go there and read a bit about circuits, and I'll leave the rest for you to explore.

## Mech. Interp.

This method of reverse-engineering neural networks by understanding what human-interpretable concepts their learned weights and corresponding circuits represent is called "mechanistic interpretability".

🧠^🧠: (Optional) After you've studied attention and transformers, feel free to read up on Transformer Circuits and try things out as much as your curiosity pushes you to. One awesome resource for getting your hands dirty interpreting language models is the ARENA tutorials.

## Thank you!

- Hope you enjoyed this lecture and learned something new.

- Please go through all the TODOs marked with a "🧠". You will likely learn much more through them.

- For any further discussion, doubts, or feedback, please feel free to contact me (via [7vik.io](https://7vik.io)).

- Please do not distribute this notebook publicly though; this was heavily finetuned for the goals of this lecture.

- I'll be staying around a bit after the lecture in case anyone wants to talk! I'd love to hear some live feedback too.