# Influence functions for computer vision

This notebook explores the use of influence functions for convolutional neural networks. In the first part we will fine-tune ResNet18 on the [tiny-imagenet dataset](https://huggingface.co/datasets/Maysee/tiny-imagenet). This dataset was first created for the [Stanford Deep Learning for Computer Vision](http://cs231n.stanford.edu/) course, and it contains a subset of the [famous ImageNet dataset](https://image-net.org/challenges/LSVRC/2012/index): 200 classes instead of 1000, with images downsampled to a lower resolution, from 256x256 pixels to 64x64.

After training the last layers of the network, we will use pyDVL to find the training points that are most and least influential for the evaluation images. This can be used, e.g., to explain wrong predictions on new images or to guide the collection of new data. In the last part of the notebook we will also see that influence functions are an effective tool for finding anomalous or corrupted data points.

If you want to know more about the mathematical foundations of influence functions for neural networks, you can find a primer in the appendix to this notebook.

Let's now proceed with the code!

## Imports

In [None]:
%load_ext autoreload

In [None]:
%autoreload
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from copy import deepcopy
from notebook_support import (
    load_model,
    save_model,
    load_results,
    save_results,
    plot_sample_images,
    plot_top_bottom_if_images,
    plot_train_val_loss,
    get_corrupted_imagenet,
    plot_influence_distribution,
)

import os
import torch
from torch.optim import Adam
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from pydvl.utils.dataset import load_preprocess_imagenet
from pydvl.influence.model_wrappers import TorchModel
from pydvl.influence.general import compute_influences
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score

In [None]:
plt.rcParams["font.size"] = 12
plt.rcParams["xtick.labelsize"] = 12
plt.rcParams["ytick.labelsize"] = 10

## Constants

In [None]:
random_state = 42
is_CI = os.environ.get("CI")

run_model_trainings = True
calculate_influences = True

In [None]:
np.random.seed(random_state)

## Loading and preprocessing the dataset

The dataset is loaded through the load_preprocess_imagenet function. We will load the images of classes 90 and 100. This is an arbitrary choice and any other pair of classes would have worked. You can try selecting any other set of labels, even more than two (you could even select all 200 classes, though this will require longer training times).

In [None]:
labels_to_keep = [90, 100]
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    train_size=0.8,
    test_size=0.1,
    keep_labels=labels_to_keep,
    is_CI=is_CI,
)

Now that we have loaded the data, let's take a look at a sample of the images.

In [None]:
plot_sample_images(train_ds, labels_to_keep, n_images_per_class=3)

The first class is related to dining tables, the second to boats and to Venice! Let's now further pre-process the data and prepare it for model training.

In [None]:
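# Map the original tiny-imagenet labels (e.g. 90, 100) to the contiguous
# class indices (0, 1) used by the model, and back
ds_label_to_model_label = {ds_label: idx for idx, ds_label in enumerate(labels_to_keep)}
model_label_to_ds_label = {idx: ds_label for idx, ds_label in enumerate(labels_to_keep)}


def process_io(x, y):
    """Stack the per-image tensors into a single batch tensor and remap the labels."""
    x_nn = torch.stack(x.tolist())
    y_nn = [ds_label_to_model_label[yi] for yi in y]
    return x_nn, y_nn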


train_x, train_y = process_io(train_ds["normalized_images"], train_ds["labels"])
val_x, val_y = process_io(val_ds["normalized_images"], val_ds["labels"])
test_x, test_y = process_io(test_ds["normalized_images"], test_ds["labels"])

## Model definition

In this part we will proceed with the initialization of the model and of some helper methods for training and evaluation. 

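The model is defined by loading ResNet18 with weights pretrained on ImageNet, freezing all of its parameters, and replacing the pooling and fully connected layers at the end of the network, so that only the new output layer is trained to do binary classification on our selected classes.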
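The freeze-then-replace pattern used in initialize_model can be checked on a small stand-in network. In the sketch below, TinyNet is a hypothetical model (not ResNet18) that, like ResNet18, ends with a layer called fc. After freezing and replacing the head, only the weights and bias of the new layer are trainable. For ResNet18, whose fc layer has 512 input features, the same count would give 512 * 2 + 2 = 1026 trainable parameters.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a pretrained backbone ending in an `fc` layer."""

    def __init__(self, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyNet()
# Freeze every existing parameter
for param in model.parameters():
    param.requires_grad = False
# A freshly created layer has requires_grad=True, so it is the only trainable part
model.fc = nn.Linear(model.fc.in_features, 2)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
logits = model(torch.randn(4, 3, 16, 16))
print(trainable, frozen, tuple(logits.shape))  # 18 224 (4, 2)
```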

In [None]:
def initialize_model(output_size):
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

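    # Freeze all pretrained weights
    for param in model.parameters():
        param.requires_grad = False

    # Replace the head: the new fc layer is created with requires_grad=True,
    # so it is the only part of the network that is fine-tuned
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, output_size)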
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model


ce_loss = nn.CrossEntropyLoss()
model_ft = initialize_model(output_size=len(labels_to_keep))

Training is done through TorchModel, a PyTorch convenience wrapper which is part of pyDVL.

In [None]:
def train_model(model, num_epochs, training_data, lr=0.001):
    """Fine-tune model on training_data, monitoring the loss on the
    global validation set (val_x, val_y)."""
    optimizer = Adam(model.parameters(), lr=lr)

    train_x, train_y = process_io(
        training_data["normalized_images"], training_data["labels"]
    )

    train_loss, val_loss = TorchModel(model=model).fit(
        x_train=train_x,
        y_train=train_y,
        x_val=val_x,
        y_val=val_y,
        loss=ce_loss,
        optimizer=optimizer,
        num_epochs=num_epochs,
        batch_size=1000,
    )
    return train_loss, val_loss

## Model training and influence computation

We will train the model for 50 epochs and save the results. Then we will plot the train and validation loss curves.

In [None]:
if run_model_trainings:
    num_epochs = 50
    train_loss, val_loss = train_model(
        model_ft, num_epochs=num_epochs, training_data=train_ds
    )
    save_model(model_ft, train_loss, val_loss, model_name="model_ft")
else:
    train_loss, val_loss = load_model(model_ft, model_name="model_ft")

In [None]:
plot_train_val_loss(train_loss, val_loss)

The confusion matrix and F1 score are good, especially considering the low resolution of the images and their complexity (large diversity of objects).

In [None]:
pred_y_test = np.argmax(model_ft(test_x).detach(), axis=1)

cm = confusion_matrix(test_y, pred_y_test)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_to_keep)
disp.plot();

In [None]:
f1_score(test_y, pred_y_test, average="weighted")

And now let's calculate influences. The function compute_influences takes the trained model, a loss (typically, but not necessarily, the training loss), an input dataset with labels (typically the training data or a subset of it) and some test data (in this case, the validation set).

Another important parameter is the Hessian regularization term, which should be chosen as small as possible while still being large enough for the computation to converge. Details on why this is important can be found in the appendix, in the pyDVL documentation or in the [original paper](https://arxiv.org/pdf/1703.04730.pdf).

Since ResNet18 is quite big, the preferred inversion method is conjugate gradient ("cg"): the direct method would require a lot of memory. Finally, the influence type will be "up" (the other option, "perturbation", is beyond the scope of this notebook, but more info can be found in the influence_wine notebook or in the pyDVL documentation).
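To see why conjugate gradient avoids the memory cost, here is a minimal NumPy sketch of the textbook CG algorithm (not pyDVL's implementation). It solves $(H + \lambda \mathbb{I})\,x = g$ while accessing the matrix only through a function returning matrix-vector products, so the matrix never needs to be stored; for a neural network that function would compute Hessian-vector products with automatic differentiation. The explicit matrix H below is a random symmetric positive definite stand-in, used only to check the answer.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iter=None):
    """Solve A x = b for symmetric positive definite A, given only v -> A v."""
    x = np.zeros_like(b)
    r = b - matvec(x)  # residual
    p = r.copy()       # search direction
    rs_old = r @ r
    if np.sqrt(rs_old) < tol:
        return x
    for _ in range(max_iter or 10 * len(b)):
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
M = rng.normal(size=(20, 40))
H = M @ M.T / 40          # stand-in Hessian (symmetric positive definite)
lam = 1e-3                # plays the role of the Hessian regularization term
g = rng.normal(size=20)   # stand-in gradient

x = conjugate_gradient(lambda v: H @ v + lam * v, g)
print(np.allclose((H + lam * np.eye(20)) @ x, g, atol=1e-6))
```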

The output of compute_influences is a matrix of size (validation_set_length, training_set_length). Each row represents a validation data point, and each column a training data point. Each entry (i, j) represents the influence of training point j on validation point i.
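This layout follows from the formula derived in the appendix, $\mathcal{I}_{up}(z_j, z_i^{\text{test}}) = \nabla_\theta L(z_i^{\text{test}}, \hat{\theta})^\top H_{\hat{\theta}}^{-1} \nabla_\theta L(z_j, \hat{\theta})$. The NumPy sketch below, with small random arrays standing in for the per-point gradients and for the (regularized) Hessian, builds the whole matrix with one linear solve and a single matrix product. It only illustrates the shapes involved and is not how pyDVL computes it. Averaging over the rows, as done later in this notebook, yields one value per training point.

```python
import numpy as np

rng = np.random.default_rng(0)
n_val, n_train, p = 4, 6, 3
grads_val = rng.normal(size=(n_val, p))      # row i: gradient of the loss at validation point i
grads_train = rng.normal(size=(n_train, p))  # row j: gradient of the loss at training point j
H = np.eye(p) + 0.1 * np.ones((p, p))        # symmetric positive definite stand-in Hessian

# H is symmetric, so solve(H, grads_val.T).T == grads_val @ H^{-1}
influences = np.linalg.solve(H, grads_val.T).T @ grads_train.T
avg_influence = influences.mean(axis=0)      # one value per training point

i, j = 2, 5
print(influences.shape, avg_influence.shape)  # (4, 6) (6,)
```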

In [None]:
if calculate_influences:
    influences = compute_influences(
        model=model_ft,
        loss=ce_loss,
        x=train_x,
        y=train_y,
        x_test=val_x,
        y_test=val_y,
        hessian_regularization=1e-3,
        inversion_method="cg",
        influence_type="up",
    )
    save_results(influences, file_name="influences.pkl")
else:
    influences = load_results(file_name="influences.pkl")

## Analysing the influence on validation images

Let's take an image in the validation set. Among the training images with the same label, we will visualize those with the highest and lowest influence on it.
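Selecting these images amounts to masking one row of the influence matrix by label and sorting it. The helper top_bottom_indices below is a hypothetical sketch of that step on toy numbers; the notebook itself delegates it to plot_top_bottom_if_images from notebook_support, whose internals may differ.

```python
import numpy as np


def top_bottom_indices(influence_row, train_labels, label, k=3):
    """Hypothetical helper: indices (into the full training set) of the k most
    and the k least influential training points that have the given label."""
    candidates = np.flatnonzero(train_labels == label)
    order = candidates[np.argsort(influence_row[candidates])]  # ascending influence
    return order[::-1][:k], order[:k]


row = np.array([0.5, -1.2, 3.0, 0.1, -0.4, 2.2])   # influences on one validation image
labels = np.array([90, 100, 90, 90, 100, 90])       # labels of the training points
top, bottom = top_bottom_indices(row, labels, label=90, k=2)
print(top, bottom)  # [2 5] [3 0]
```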

In [None]:
val_image_idx = 118
plt.rcParams["figure.figsize"] = (5, 5)
plt.imshow(val_ds["images"][val_image_idx])

In [None]:
print(
    "Predicted label:",
    model_label_to_ds_label[
        np.argmax(model_ft(val_x[val_image_idx].unsqueeze(0)).detach(), axis=1).item()
    ],
)
print("Real label:", val_ds["labels"][val_image_idx])

In [None]:
plt.rcParams["figure.figsize"] = (8, 8)
for label in labels_to_keep:
    plt.hist(
        influences[val_image_idx][train_ds["labels"] == label], label=label, alpha=0.7
    )
plt.legend()
plt.show()

In [None]:
images_with_same_label = train_ds["labels"] == val_ds["labels"][val_image_idx]
if_same_label = influences[val_image_idx][images_with_same_label]
images_same_label = train_ds["images"][images_with_same_label].values
plot_top_bottom_if_images(if_same_label, subset_images=images_same_label, num_to_plot=3)

In [None]:
avg_influences = np.mean(influences, axis=0)

In [None]:
plt.rcParams["figure.figsize"] = (8, 8)
for label in labels_to_keep:
    plt.hist(avg_influences[train_ds["labels"] == label], label=label, alpha=0.7)
plt.legend()
plt.show()

In [None]:
label = 90
img_with_selected_label = train_ds["labels"] == label
if_selected_label = avg_influences[img_with_selected_label]
images_same_label = train_ds["images"][img_with_selected_label].values
plot_top_bottom_if_images(if_selected_label, images_same_label, num_to_plot=3)

## Calculating the influence of corrupted training data

In [None]:
model_corrupted = initialize_model(output_size=len(labels_to_keep))
corrupted_dataset, corrupted_indices = get_corrupted_imagenet(
    dataset=train_ds,
    labels_to_keep=labels_to_keep,
    fraction_to_corrupt=0.1,
    avg_influences=avg_influences,
)

if run_model_trainings:
    num_epochs = 50
    train_loss, val_loss = train_model(
        model_corrupted,
        num_epochs=num_epochs,
        training_data=corrupted_dataset,
    )
    save_model(model_corrupted, train_loss, val_loss, model_name="model_corrupted")
else:
    train_loss, val_loss = load_model(model_corrupted, model_name="model_corrupted")

In [None]:
plot_train_val_loss(train_loss, val_loss)

In [None]:
pred_y_test = np.argmax(model_corrupted(test_x).detach(), axis=1)
model_score = f1_score(test_y, pred_y_test, average="weighted")
print(model_score)

In [None]:
if calculate_influences:
    corrupted_train_x, corrupted_train_y = process_io(
        corrupted_dataset["normalized_images"],
        corrupted_dataset["labels"],
    )
    influences = compute_influences(
        model=model_corrupted,
        loss=ce_loss,
        x=corrupted_train_x,
        y=corrupted_train_y,
        x_test=val_x,
        y_test=val_y,
        hessian_regularization=1e-3,
        inversion_method="cg",
        influence_type="up",
    )
    save_results(influences, file_name="influences_corrupted.pkl")
else:
    influences = load_results(file_name="influences_corrupted.pkl")

In [None]:
avg_corrupted_influences = np.mean(influences, axis=0)

In [None]:
label = 100
img_with_selected_label = corrupted_dataset["labels"] == label
if_selected_label = avg_corrupted_influences[img_with_selected_label]
images_same_label = corrupted_dataset["images"][img_with_selected_label].values
plot_top_bottom_if_images(if_selected_label, images_same_label, num_to_plot=3)

In [None]:
avg_label_influence = plot_influence_distribution(
    corrupted_dataset, labels_to_keep, corrupted_indices, avg_corrupted_influences
)

In [None]:
avg_label_influence

## Appendix: Theory of Influence functions for neural networks

In this appendix we will briefly go through the essential formulas that lay the foundations of influence functions for neural networks. A more in-depth and expanded analysis can be found in the original paper: ["Understanding Black-box Predictions via Influence Functions"](https://arxiv.org/pdf/1703.04730.pdf).

### Upweighting points

Let's start by considering some input space $\mathcal{X}$ to a model (e.g. images) and an output space $\mathcal{Y}$ (e.g. labels). Let's take $z_i = (x_i, y_i)$ to be the $i$-th training point, and $\theta$ to be the (potentially highly) multi-dimensional parameters of the neural network (i.e. $\theta$ is a big array with very many parameters). We will indicate with $L(z, \theta)$ the loss of the model for point $z$ and parameters $\theta$. When training the model, we want to minimize the empirical risk, i.e. the average training loss. The optimal parameters are obtained by minimizing (typically through some variant of gradient descent) the following expression:
$$
\hat{\theta} = \arg \min_\theta \frac{1}{n}\sum_{i=1}^n L(z_i, \theta)
$$
where $n$ is the total number of training data-points.

Let's thus define
$$
\hat{\theta}_{-z} = \arg \min_\theta \frac{1}{n}\sum_{z_i \ne z} L(z_i, \theta) \ ,
$$
i.e. $\hat{\theta}_{-z}$ are the model parameters that minimize the total loss when $z$ is not in the training dataset. 

In order to check the impact of each training point on the model, we would need to calculate $\hat{\theta}_{-z}$ for each $z$ in the training dataset, thus re-training the model $n$ times. This is computationally very expensive, especially for big neural networks.

To circumvent this problem, we will calculate a first-order approximation of $\hat{\theta}_{-z}$, which can be computed without re-training the full model.

Let's define
$$
\hat{\theta}_{\epsilon, z} = \arg \min_\theta \frac{1}{n}\sum_{i=1}^n L(z_i, \theta) + \epsilon L(z, \theta) \ ,
$$
which is the optimal $\hat{\theta}$ if we were to upweight $z$ by an amount $\epsilon$. Removing $z$ from the training set corresponds to $\epsilon = -1/n$.

From a classical result (details in *Cook, R. D. and Weisberg, S. [Residuals and influence in regression](https://onlinelibrary.wiley.com/doi/abs/10.1002/bimj.4710270110). New York: Chapman and Hall, 1982*), we know that:
$$
\frac{d \ \hat{\theta}_{\epsilon, z}}{d \epsilon} = -H_{\hat{\theta}}^{-1} \nabla_\theta L(z, \hat{\theta})
$$
where $H_{\hat{\theta}} = \frac{1}{n} \sum_{i=1}^n \nabla_\theta^2 L(z_i, \hat{\theta})$ is the Hessian of the empirical risk. Importantly, notice that this expression is only valid when $\hat{\theta}$ is a minimum of the empirical risk with a positive definite Hessian; otherwise $H_{\hat{\theta}}$ may not be invertible!
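This formula can be checked numerically on a model where $\hat{\theta}_{\epsilon, z}$ has a closed form. The sketch below assumes ridge regression, with per-point loss $L(z_i, \theta) = \frac{1}{2}(x_i^\top \theta - y_i)^2 + \frac{\mu}{2}\|\theta\|^2$, so that the empirical risk is strictly convex and its Hessian is positive definite. It compares a finite-difference estimate of $d\hat{\theta}_{\epsilon, z}/d\epsilon$ with $-H_{\hat{\theta}}^{-1}\nabla_\theta L(z, \hat{\theta})$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, mu = 50, 3, 0.1
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)


def fit(weights):
    """Closed-form argmin of sum_i weights[i] * L(z_i, theta) for the ridge loss
    L(z_i, theta) = 0.5 * (x_i @ theta - y_i)**2 + 0.5 * mu * ||theta||**2."""
    A = X.T @ (weights[:, None] * X) + mu * weights.sum() * np.eye(d)
    return np.linalg.solve(A, X.T @ (weights * y))


def grad_loss(k, theta):
    """Gradient of L(z_k, theta) with respect to theta."""
    return X[k] * (X[k] @ theta - y[k]) + mu * theta


weights = np.full(n, 1.0 / n)             # the (1/n) * sum of the empirical risk
theta_hat = fit(weights)
H = X.T @ X / n + mu * np.eye(d)          # (1/n) * sum of the per-point Hessians

k, eps = 7, 1e-6
weights_eps = weights.copy()
weights_eps[k] += eps                     # upweight z_k by eps
finite_diff = (fit(weights_eps) - theta_hat) / eps
analytic = -np.linalg.solve(H, grad_loss(k, theta_hat))
print(np.max(np.abs(finite_diff - analytic)))  # small: the error shrinks with eps
```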

### Approximating points' influence

We define the influence of training point $z$ on test point $z_{\text{test}}$ as $\mathcal{I}(z, z_{\text{test}}) = L(z_{\text{test}}, \hat{\theta}_{-z}) - L(z_{\text{test}}, \hat{\theta})$, which is higher for points that positively impact the model score (i.e. if they are excluded, the test loss is higher). In practice, however, we will always use the infinitesimal approximation $\mathcal{I}_{up}(z, z_{\text{test}})$, defined as
$$
 \mathcal{I}_{up}(z, z_{\text{test}}) = - \frac{d L(z_{\text{test}}, \hat{\theta}_{\epsilon, z})}{d \epsilon} \Big|_{\epsilon=0}
$$

Using the chain rule and the results calculated above, we thus have:
$$
 \mathcal{I}_{up}(z, z_{\text{test}}) = - \nabla_\theta L(z_{\text{test}}, \hat{\theta})^\top \ \frac{d \hat{\theta}_{\epsilon, z}}{d \epsilon} \Big|_{\epsilon=0} = \nabla_\theta L(z_{\text{test}}, \hat{\theta})^\top \ H_{\hat{\theta}}^{-1} \ \nabla_\theta L(z, \hat{\theta})
$$
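Since removing $z$ corresponds to $\epsilon = -1/n$, a first-order Taylor expansion in $\epsilon$ gives $\mathcal{I}(z, z_{\text{test}}) \approx \frac{1}{n}\mathcal{I}_{up}(z, z_{\text{test}})$. The sketch below checks this on a toy ridge-regression problem (an assumption chosen so that exact leave-one-out retraining is cheap), with per-point loss $L(z, \theta) = \frac{1}{2}(x^\top \theta - y)^2 + \frac{\mu}{2}\|\theta\|^2$. It compares $\mathcal{I}_{up}/n$ with the exact loss differences obtained by refitting the model $n$ times.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, mu = 200, 3, 0.1
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.3 * rng.normal(size=n)
x_test, y_test = rng.normal(size=d), 0.7


def loss(x, t, theta):
    """Per-point ridge loss L(z, theta)."""
    return 0.5 * (x @ theta - t) ** 2 + 0.5 * mu * theta @ theta


def grad(x, t, theta):
    """Gradient of L(z, theta) with respect to theta."""
    return x * (x @ theta - t) + mu * theta


def fit(mask):
    """Closed-form argmin of the summed loss over the training points selected by mask."""
    Xm, ym = X[mask], y[mask]
    return np.linalg.solve(Xm.T @ Xm + mu * mask.sum() * np.eye(d), Xm.T @ ym)


everyone = np.ones(n, dtype=bool)
theta_hat = fit(everyone)
H = X.T @ X / n + mu * np.eye(d)
# H is symmetric, so H^{-1} grad L(z_test) is computed once and reused for every z
h_inv_grad_test = np.linalg.solve(H, grad(x_test, y_test, theta_hat))
i_up = np.array([grad(X[j], y[j], theta_hat) @ h_inv_grad_test for j in range(n)])

# Exact influence: retrain without each point (n refits, affordable only for toy models)
exact = np.empty(n)
for j in range(n):
    mask = everyone.copy()
    mask[j] = False
    exact[j] = loss(x_test, y_test, fit(mask)) - loss(x_test, y_test, theta_hat)

print(np.corrcoef(i_up / n, exact)[0, 1])  # close to 1
```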

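In order to calculate this expression we need the gradient of the loss with respect to the model parameters, evaluated at $\hat{\theta}$, and the inverse Hessian. The gradients are obtained with a single backpropagation pass per data point. The Hessian, on the other hand, has one entry per pair of parameters and is too large to form explicitly for a big network; iterative solvers such as conjugate gradient only need Hessian-vector products $H_{\hat{\theta}} v$, which can be computed with a second backpropagation pass through the gradient, without ever storing $H_{\hat{\theta}}$.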
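A minimal PyTorch sketch of a Hessian-vector product via double backpropagation, on a small hypothetical two-layer network with a cross-entropy loss. The first torch.autograd.grad call keeps the graph (create_graph=True) so that the gradient can be differentiated a second time. The full Hessian is built row by row only to check the result; that is exactly the step that is unaffordable for a real network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small network; float64 keeps the comparison below tight
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 2)).double()
x = torch.randn(16, 4, dtype=torch.float64)
y = torch.randint(0, 2, (16,))
params = list(model.parameters())
loss = nn.functional.cross_entropy(model(x), y)

# First backward pass, keeping the graph so the gradient can be differentiated again
grads = torch.autograd.grad(loss, params, create_graph=True)
flat_grad = torch.cat([g.reshape(-1) for g in grads])

# Second backward pass: the gradient of (grad . v) with respect to the parameters is H v
v = torch.randn(flat_grad.numel(), dtype=torch.float64)
hvp = torch.autograd.grad(flat_grad @ v, params, retain_graph=True)
flat_hvp = torch.cat([h.reshape(-1) for h in hvp])

# Reference only: explicit Hessian, one row per parameter (infeasible for large models)
hessian = torch.stack([
    torch.cat([r.reshape(-1) for r in torch.autograd.grad(flat_grad[i], params, retain_graph=True)])
    for i in range(flat_grad.numel())
])
print(torch.allclose(hessian @ v, flat_hvp))
```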

### Regularizing the Hessian

One very important assumption that we make when approximating influence is that $\hat{\theta}$ is (at least) a local minimum of the loss. However, when training neural networks, many factors, such as the noise in SGD or a learning rate that is not small enough, may prevent us from reaching an actual minimum. In this scenario the Hessian may have negative or near-zero eigenvalues, and inverting it may become infeasible (the computation diverges and returns meaningless values).

To prevent this from happening, instead of inverting the true Hessian $H_{\hat{\theta}}$, in our computation we invert $H_{\hat{\theta}} + \lambda \mathbb{I}$, with $\mathbb{I}$ the identity matrix of the same shape as $H_{\hat{\theta}}$. This shifts every eigenvalue of the Hessian by $\lambda$, which stabilizes the inversion but also biases the result. Therefore, the regularization parameter $\lambda$ should be chosen to be as small as possible, but big enough so that the inversion of $H_{\hat{\theta}} + \lambda \mathbb{I}$ is stable.
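A small NumPy illustration of both sides of this trade-off, using hand-picked diagonal matrices as stand-in Hessians. A singular Hessian cannot be inverted at all; adding $\lambda \mathbb{I}$ makes the system solvable, and as $\lambda$ shrinks the solution approaches the one on the well-determined directions. For a Hessian with a negative eigenvalue (a saddle point rather than a minimum), $\lambda$ must exceed the magnitude of that eigenvalue before $H + \lambda\mathbb{I}$ becomes positive definite, which conjugate gradient requires.

```python
import numpy as np

g = np.array([1.0, 1.0, 0.0])          # stand-in gradient
H_flat = np.diag([2.0, 0.5, 0.0])      # singular: one perfectly flat direction

try:
    np.linalg.solve(H_flat, g)
    singular = False
except np.linalg.LinAlgError:
    singular = True                    # exact inversion is impossible

# lam = 0.1 gives about [0.476, 1.667, 0.0]; smaller lam approaches [0.5, 2.0, 0.0]
solutions = {lam: np.linalg.solve(H_flat + lam * np.eye(3), g) for lam in (1e-1, 1e-3, 1e-6)}

H_saddle = np.diag([2.0, 0.5, -0.01])  # not a minimum: one negative eigenvalue
is_pos_def = {
    lam: bool(np.linalg.eigvalsh(H_saddle + lam * np.eye(3)).min() > 0)
    for lam in (1e-3, 1e-1)
}
print(singular, is_pos_def)  # True {0.001: False, 0.1: True}
```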