# Interpreting CNNs

While deep learning, particularly CNNs, has achieved remarkable success across diverse applications like self-driving cars and facial recognition, a persistent lack of trust remains in their decision-making processes. Improving CNN interpretability is crucial for enhancing model performance and reliability.

This notebook introduces Gradient-weighted Class Activation Mapping (Grad-CAM), a technique for making CNN predictions more interpretable. Seeing what a model attends to makes it more transparent and trustworthy, and can guide better training strategies and improve generalization.

The following cell runs setup code needed to get the Kaggle dataset for use in this notebook.

In [None]:
!pip -q install grad-cam
from google.colab import drive
drive.mount('/content/drive')

# Change this to the folder containing your Kaggle API key (kaggle.json)
%env KAGGLE_KEY_FOLDER=MDST/RvF
!mkdir -p data
!export KAGGLE_CONFIG_DIR=/content/drive/MyDrive/$KAGGLE_KEY_FOLDER && wget -O - "https://raw.githubusercontent.com/MichiganDataScienceTeam/W24-RvF/main/data/download.sh" | bash -s rvf10k

## Grad-CAM

Grad-CAM was introduced in 2016 (Selvaraju et al.) as a method to visualize _what_ a convolutional neural network is paying attention to when it makes a certain prediction. This is useful: if the CNN is paying attention to the wrong things, the model is probably not learning the correct patterns and trends, and work needs to be done to reduce its generalization error.

The `pytorch-grad-cam` package offers a neat implementation for producing Grad-CAM visualizations for any convolutional neural network implemented in PyTorch.
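To give some intuition for what such a package computes, here is a minimal NumPy sketch of the core Grad-CAM combination step: each feature map of a chosen convolutional layer is weighted by the average gradient of the class score with respect to that map, the weighted maps are summed, and a ReLU keeps only the positive evidence. The function name `grad_cam_from_tensors` and the toy arrays are made up for illustration, and the activations and gradients are assumed to have been extracted already. This is not the code that `pytorch-grad-cam` runs internally.

```python
import numpy as np

def grad_cam_from_tensors(activations, gradients):
    """Combine feature maps and gradients into a Grad-CAM heatmap.

    activations: array of shape (K, H, W), the feature maps of the target layer.
    gradients:   array of shape (K, H, W), d(class score)/d(activations).
    """
    # Channel weights: global-average-pool the gradients over H and W.
    weights = gradients.mean(axis=(1, 2))                       # shape (K,)
    # Weighted sum of the feature maps, then ReLU to keep positive evidence only.
    cam = np.maximum((weights[:, None, None] * activations).sum(axis=0), 0.0)
    # Scale to [0, 1] for display; guard against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example with made-up numbers (not taken from the model in this notebook).
acts = np.zeros((2, 4, 4), dtype=np.float32)
acts[0, :2, :2] = 1.0   # channel 0 fires in the top-left corner
acts[1, 2:, 2:] = 1.0   # channel 1 fires in the bottom-right corner
grads = np.stack([np.full((4, 4), 0.5), np.full((4, 4), -0.5)]).astype(np.float32)

heatmap = grad_cam_from_tensors(acts, grads)
print(heatmap)
```

In this toy case channel 0 has a positive average gradient and channel 1 a negative one, so only the top-left corner survives the ReLU and lights up in the heatmap.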

To get started, load a trained model into your notebook. For this example, we will use the `model_4.pt` file from the [media](https://github.com/MichiganDataScienceTeam/W24-RvF/blob/main/media) folder of the GitHub repository and download it to this notebook environment. The following cell does this for us:

In [None]:
!wget https://github.com/MichiganDataScienceTeam/W24-RvF/raw/main/media/model_4.pt
!mkdir -p checkpoints/Net && mv model_4.pt checkpoints/Net
!wget https://raw.githubusercontent.com/MichiganDataScienceTeam/W24-RvF/main/starter_code/train.py

import torch
from train import load_model

class Net(torch.nn.Module):
    def __init__(self):
        """Constructor for the neural network."""
        super(Net, self).__init__()        # Call superclass constructor
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv4 = torch.nn.Conv2d(in_channels=32, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = torch.nn.ReLU()
        self.flatten = torch.nn.Flatten()
        self.fc = torch.nn.Linear(2048, 2)

    def forward(self, x):
        z1 = self.conv1(x)
        h1 = self.relu(z1)
        p1 = self.pool(h1)

        z2 = self.conv2(p1)
        h2 = self.relu(z2)
        p2 = self.pool(h2)

        z3 = self.conv3(p2)
        h3 = self.relu(z3)
        p3 = self.pool(h3)

        z4 = self.conv4(p3)
        h4 = self.relu(z4)
        p4 = self.pool(h4)

        flat = self.flatten(p4)
        z = self.fc(flat)

        return z

model = Net()
load_model(model, "checkpoints", 4, map_location="cpu")

🚨 **CRITICAL** 🚨: For this tutorial, we downloaded weights that were produced by training the architecture defined above. For your own model, you will have to load the checkpoint that matches the architecture you trained.
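One reason the checkpoint and the class definition must agree is that layer shapes are baked into the saved weights. As a small illustration, the arithmetic below shows where the `2048` in `torch.nn.Linear(2048, 2)` comes from, assuming 256x256 inputs (the size the visualization code in this notebook reshapes to). The helper `conv_out` is a hypothetical name introduced here, not part of the starter code.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 256                    # assumed input height and width
for _ in range(4):            # four conv + pool stages, as in Net
    size = conv_out(size, kernel=3, stride=1, padding=1)   # 3x3 conv with padding 1 keeps the size
    size = conv_out(size, kernel=2, stride=2)              # 2x2 max pool halves the size

flat_features = 8 * size * size   # conv4 outputs 8 channels
print(size, flat_features)        # 16 2048
```

If you change the number of pooling stages, the channel count of the last convolution, or the input resolution, the input size of the final linear layer changes too, and a checkpoint saved with the old shapes will no longer fit.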

From here, running visualizations on the dataset isn't very difficult! The following cell will:
- Load 100 images of fake faces
- Use the Grad-CAM method to find what the model pays attention to in _each_ of those fake faces when classifying the image
- Visualize these regions on the input image to give us intuition

In [None]:
import cv2
import numpy as np
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from pathlib import Path
from tqdm import tqdm
import torchvision.transforms as tf

visualizations = []
cam = GradCAM(model=model, target_layers=[model.conv2, model.conv3])

# Class 0 = figure out what parts of the image make the model think the face is fake
targets = [ClassifierOutputTarget(0)]

# Converts an HxWxC float32 array in [0, 1] into a CxHxW float32 tensor
pipeline = tf.Compose([
    tf.ToTensor(),
    tf.ConvertImageDtype(torch.float32),
])

for image_path in tqdm(Path("./data/rvf10k/train/fake").iterdir()):
    image = cv2.imread(str(image_path))             # OpenCV loads images in BGR order
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB is what use_rgb=True and plt.imshow expect
    image = np.float32(image) / 255                 # scale pixel values to [0, 1]
    input_tensor = pipeline(image).view(1, 3, 256, 256)  # add a batch dimension; assumes 256x256 images

    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
    visualization = show_cam_on_image(image, grayscale_cam[0, :], use_rgb=True)
    visualizations.append(visualization)
    if len(visualizations) == 100:
        break

We can now look at the visualized images to see which patterns the model pays attention to when predicting a fake face.

In [None]:
import matplotlib.pyplot as plt

plt.imshow(visualizations[23])

The warmer the color (red and yellow in the default JET colormap of `show_cam_on_image`), the more the model pays attention to that part of the input image. In the example above, the model seems to be paying attention mainly to the forehead and the nose, which seems reasonable for identifying a fake face.

However, Grad-CAM can also show that the model does not always fit to the right trends!

In [None]:
plt.imshow(visualizations[0])

This visualization is more interesting: for some reason, our model has identified the background as something that may indicate a fake face. It looks like this model is **overfitting**!

In [None]:
plt.imshow(visualizations[30])

Here too, the model seems to be looking more at the background and the edge of the image than at the actual face. This also might indicate overfitting!
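One rough way to put a number on "looking at the background" is to measure how much of the heatmap falls inside a centered box. The sketch below assumes faces are roughly centered in the image; the function name `center_attention_fraction` and the default box size are hypothetical choices, and the arrays are synthetic rather than real model output.

```python
import numpy as np

def center_attention_fraction(cam, center_frac=0.5):
    """Fraction of the total CAM mass that falls inside a centered box.

    cam: 2D array (H, W) of non-negative Grad-CAM values.
    center_frac: side length of the box relative to the image side.
    """
    h, w = cam.shape
    box_h, box_w = int(h * center_frac), int(w * center_frac)
    top, left = (h - box_h) // 2, (w - box_w) // 2
    total = cam.sum()
    if total == 0:
        return 0.0   # an empty map has no attention anywhere
    return float(cam[top:top + box_h, left:left + box_w].sum() / total)

# Synthetic maps for illustration only.
centered = np.zeros((8, 8))
centered[2:6, 2:6] = 1.0     # all attention in the middle
border = np.ones((8, 8))
border[2:6, 2:6] = 0.0       # all attention on the edges

print(center_attention_fraction(centered))  # 1.0
print(center_attention_fraction(border))    # 0.0
```

To try this on real data, you would need to keep each `grayscale_cam[0, :]` from the loop above and pass it in as `cam`. Consistently low values across many images would suggest the model leans on the background rather than the face.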

Explore some more! Can you find more general trends in how the model is overfitting (if at all)? And can you adapt this to your _own_ model to determine what it is learning?
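Looking at many heatmaps at once makes trends easier to spot than viewing them one by one. As a starting point, the sketch below tiles equally sized images into one large image using only NumPy. The helper `make_grid` is a hypothetical name, and random arrays stand in for the `visualizations` list.

```python
import numpy as np

def make_grid(images, n_cols=5):
    """Tile equally sized (H, W, 3) images into one image, row by row."""
    images = list(images)
    h, w, c = images[0].shape
    n_rows = -(-len(images) // n_cols)   # ceiling division
    grid = np.zeros((n_rows * h, n_cols * w, c), dtype=images[0].dtype)
    for i, img in enumerate(images):
        row, col = divmod(i, n_cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = img
    return grid

# Stand-in data: seven random 64x64 RGB images in place of `visualizations`.
rng = np.random.default_rng(0)
fake_visualizations = [
    rng.integers(0, 256, (64, 64, 3), dtype=np.uint8) for _ in range(7)
]
grid = make_grid(fake_visualizations, n_cols=5)
print(grid.shape)   # (128, 320, 3)
```

In this notebook, something like `plt.imshow(make_grid(visualizations[:20]))` should show twenty overlays in a single figure; unused slots in the last row stay black.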