# Influence functions in computer vision

## Imports

In [None]:
%load_ext autoreload

In [None]:
%autoreload
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

import torch
from torch.optim import lr_scheduler, SGD
import torch.nn as nn
from torchvision.models import resnet18
from pydvl.utils.dataset import load_preprocess_imagenet
from pydvl.influence.model_wrappers import TorchModel
from pydvl.influence.general import compute_influences
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score

In [None]:
# Notebook-wide matplotlib defaults: large figures and readable tick labels.
plt.rcParams.update(
    {
        "figure.figsize": (16, 8),
        "font.size": 12,
        "xtick.labelsize": 12,
        "ytick.labelsize": 10,
    }
)

In [None]:
from pathlib import Path
from cloudpickle import pickle as pkl

# Directory (sibling "data/imgnet_model" of the current working directory's
# parent) where model weights, loss curves and influences are cached.
imgnet_model_data_path = Path(".").resolve().parent.joinpath("data", "imgnet_model")


def save_model(model, train_loss, val_loss):
    """Persist a model's weights and its loss history.

    Writes ``model_weights.pth`` (the model's ``state_dict``) and
    ``train_val_loss.pkl`` (a pickled ``[train_loss, val_loss]`` list) into
    ``imgnet_model_data_path``.

    Args:
        model: The ``torch.nn.Module`` whose ``state_dict`` is saved.
        train_loss: Training-loss history to store alongside the weights.
        val_loss: Validation-loss history to store alongside the weights.
    """
    # Create the cache directory on first use; otherwise torch.save / open fail
    # with FileNotFoundError on a fresh checkout.
    imgnet_model_data_path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), imgnet_model_data_path / "model_weights.pth")
    with open(imgnet_model_data_path / "train_val_loss.pkl", "wb") as file:
        pkl.dump([train_loss, val_loss], file)


def load_model(model):
    """Load weights saved by ``save_model`` into ``model`` and return its losses.

    Args:
        model: A ``torch.nn.Module`` with the same architecture as the saved
            one; its parameters are overwritten in place.

    Returns:
        The ``(train_loss, val_loss)`` pair stored next to the weights.
    """
    # map_location="cpu" allows weights saved from a GPU run to be loaded on a
    # CPU-only machine; load_state_dict then copies them onto whatever device
    # the model's parameters currently live on.
    # NOTE: torch.load and pickle can execute arbitrary code while loading —
    # only load artifacts you produced yourself.
    state_dict = torch.load(
        imgnet_model_data_path / "model_weights.pth", map_location="cpu"
    )
    model.load_state_dict(state_dict)
    with open(imgnet_model_data_path / "train_val_loss.pkl", "rb") as file:
        train_loss, val_loss = pkl.load(file)
    return train_loss, val_loss


def save_influences(influences):
    """Pickle ``influences`` to ``influences.pkl`` in ``imgnet_model_data_path``.

    Args:
        influences: The influence values to cache (any picklable object).
    """
    # Create the cache directory on first use so the write cannot fail with
    # FileNotFoundError.
    imgnet_model_data_path.mkdir(parents=True, exist_ok=True)
    with open(imgnet_model_data_path / "influences.pkl", "wb") as file:
        pkl.dump(influences, file)


def load_influences():
    """Return the influences previously cached by ``save_influences``.

    NOTE: unpickling can execute arbitrary code — only load files you created.
    """
    influences_file = imgnet_model_data_path / "influences.pkl"
    with influences_file.open("rb") as fh:
        return pkl.load(fh)

In [None]:
# Restrict the dataset to six class labels (50, 60, ..., 100) and split it into
# train / validation / test frames (train_size=0.8, test_size=0.1; presumably
# the remainder becomes the validation split — confirm in load_preprocess_imagenet).
labels_to_keep = [*range(50, 110, 10)]
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    keep_labels=labels_to_keep,
    train_size=0.8,
    test_size=0.1,
)

In [None]:
# Show the first `n_images_per_class` training images of every kept class,
# one class per column.
n_images_per_class = 3
fig, axes = plt.subplots(nrows=n_images_per_class, ncols=len(labels_to_keep))
fig.suptitle("Examples of training images")
for col, class_label in enumerate(labels_to_keep):
    class_rows = train_ds.loc[train_ds["labels"] == class_label].head(
        n_images_per_class
    )
    for row, image in enumerate(class_rows["images"]):
        ax = axes[row, col]
        ax.imshow(image)
        ax.axis("off")
        ax.set_title(f"img label: {class_label}")
plt.show()

In [None]:
# ResNet-18 with ImageNet-pretrained weights. Passing `weights=True` only works
# through torchvision's deprecated legacy-argument shim (with a warning); the
# explicit weights name selects the same IMAGENET1K_V1 checkpoint.
model_ft = resnet18(weights="IMAGENET1K_V1")

# Freeze the whole pretrained network; only the layer replaced below is trained.
for param in model_ft.parameters():
    param.requires_grad = False

# Adapt the head to the smaller input images (the original comment suggests
# Tiny ImageNet — confirm) and to the reduced label set. Adaptive pooling to 1x1
# (torchvision's default for resnet18) keeps the head independent of input
# resolution and has no parameters; the new fully connected layer is therefore
# the only trainable part of the model.
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(labels_to_keep))

In [None]:
ce_loss = nn.CrossEntropyLoss()
optimizer = SGD(model_ft.parameters(), lr=0.01, momentum=0.9)
# NOTE(review): this scheduler is never stepped nor handed to TorchModel.fit in
# this notebook, so it currently does not change the learning rate.
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Two-way mapping between the dataset's class labels (50, 60, ...) and the
# contiguous indices 0..K-1 produced by the classification head.
ds_label_to_model_label = {
    ds_label: model_label for model_label, ds_label in enumerate(labels_to_keep)
}
model_label_to_ds_label = dict(enumerate(labels_to_keep))


def get_model_io(x, y):
    """Turn a column of image tensors and a column of labels into model I/O.

    Args:
        x: Container with a ``tolist()`` method holding equally-shaped tensors
            (presumably a pandas Series of normalized images — confirm).
        y: Iterable of dataset labels; each must be a key of
            ``ds_label_to_model_label``.

    Returns:
        ``(x_nn, y_nn)``: the inputs stacked along a new leading dimension and
        a list with the model-space class index of every label.
    """
    targets = list(map(ds_label_to_model_label.__getitem__, y))
    inputs = torch.stack(x.tolist())
    return inputs, targets


# Convert each split into stacked input tensors and model-space targets.
(train_x, train_y), (val_x, val_y), (test_x, test_y) = (
    get_model_io(split["normalized_images"], split["labels"])
    for split in (train_ds, val_ds, test_ds)
)

In [None]:
# Set to False to reuse the weights and loss curves cached by a previous run.
train_model = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_ft.to(device)
# NOTE(review): only the model is moved to `device`; the input tensors stay
# where they were created — confirm that TorchModel.fit moves batches itself.

if not train_model:
    train_loss, val_loss = load_model(model_ft)
else:
    num_epochs = 15
    torch_model = TorchModel(model=model_ft)
    train_loss, val_loss = torch_model.fit(
        x_train=train_x,
        y_train=train_y,
        x_val=val_x,
        y_val=val_y,
        loss=ce_loss,
        optimizer=optimizer,
        num_epochs=num_epochs,
        batch_size=1000,
    )
    save_model(model_ft, train_loss, val_loss)

In [None]:
# Training vs. validation loss per epoch.
_, ax = plt.subplots()
for curve, curve_name in ((train_loss, "Train"), (val_loss, "Val")):
    ax.plot(curve, label=curve_name)
ax.legend()
plt.show()

In [None]:
# Predict the test set. Eval mode makes BatchNorm use (and not update) its
# running statistics; no_grad avoids building an autograd graph; inputs are sent
# to the model's device so this also works after model_ft.to("cuda").
# The previous train/eval mode is restored so later training cells are unaffected.
model_device = next(model_ft.parameters()).device
was_training = model_ft.training
model_ft.eval()
with torch.no_grad():
    pred_y_test = model_ft(test_x.to(model_device)).argmax(dim=1).cpu().numpy()
model_ft.train(was_training)

# Fix the label set so the matrix always has one row/column per entry of
# display_labels, even if some class is absent from both truth and predictions.
cm = confusion_matrix(test_y, pred_y_test, labels=list(range(len(labels_to_keep))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_to_keep)
disp.plot();

In [None]:
# Support-weighted F1 over all classes on the test split.
f1_score(y_true=test_y, y_pred=pred_y_test, average="weighted")

In [None]:
# Set to False to reuse influences cached by a previous run.
calculate_influences = True

if not calculate_influences:
    influences = load_influences()
else:
    # Influence of each training point (x, y) on each validation point
    # (x_test, y_test). "up" presumably selects upweighting influences and "cg"
    # a conjugate-gradient inverse-Hessian solve on a Hessian regularized by 0.1
    # — see the pyDVL documentation.
    influences = compute_influences(
        model_ft,
        ce_loss,
        x=train_x,
        y=train_y,
        x_test=val_x,
        y_test=val_y,
        inversion_method="cg",
        hessian_regularization=0.1,
        influence_type="up",
    )
    save_influences(influences)

In [None]:
# Pick one validation image to explain. It is addressed by *position*, like
# val_x[val_image_idx] and influences[val_image_idx], so use .iloc: plain []
# on a Series looks up by index label, which differs if the index is not 0..n-1.
val_image_idx = 110
plt.rcParams["figure.figsize"] = (5, 5)
plt.imshow(val_ds["images"].iloc[val_image_idx])

In [None]:
# Classify the selected validation image. Eval mode is essential here: with a
# batch of one, train-mode BatchNorm would normalize with that single image's
# statistics. The previous mode is restored afterwards.
model_device = next(model_ft.parameters()).device
was_training = model_ft.training
model_ft.eval()
with torch.no_grad():
    val_logits = model_ft(val_x[val_image_idx].unsqueeze(0).to(model_device))
model_ft.train(was_training)

print(
    "Predicted label:",
    model_label_to_ds_label[val_logits.argmax(dim=1).item()],
)
# Positional lookup so the label matches val_x[val_image_idx].
print("Real label:", val_ds["labels"].iloc[val_image_idx])

In [None]:
def plot_top_bottom_if_images(
    subset_influences, subset_images, num_to_plot, figsize=(8, 8)
):
    """Plot the images with the lowest and the highest influence values.

    The left column shows the ``num_to_plot`` most negative influences (most
    negative first); the right column shows the ``num_to_plot`` most positive
    influences (largest first).

    Args:
        subset_influences: 1-D sequence of influence values, one per image.
        subset_images: Images aligned *by position* with ``subset_influences``
            (list, NumPy array or pandas Series).
        num_to_plot: Number of images shown in each column.
        figsize: Size of the created figure in inches.
    """
    influence_values = np.asarray(subset_influences)
    # list() gives positional access for any container; indexing a pandas
    # Series with [] would look up index labels instead of positions.
    images = list(subset_images)

    order = np.argsort(influence_values)
    bottom_if_idxs = order[:num_to_plot]
    top_if_idxs = order[::-1][:num_to_plot]

    # figsize must be given when the figure is created (setting rcParams after
    # plt.subplots has no effect on it). squeeze=False keeps `axes` 2-D even
    # when num_to_plot == 1.
    fig, axes = plt.subplots(
        nrows=num_to_plot, ncols=2, figsize=figsize, squeeze=False
    )
    fig.suptitle("Bottom (left) and top (right) influences")

    for col, img_idxs in ((0, bottom_if_idxs), (1, top_if_idxs)):
        for row, img_idx in enumerate(img_idxs):
            ax = axes[row, col]
            ax.set_title(f"img influence: {influence_values[img_idx]:0f}")
            ax.imshow(images[img_idx])
            ax.axis("off")

    plt.show()

In [None]:
# Per-class distribution of the influence each training image has on the
# selected validation image. Semi-transparent bars keep overlapping classes
# visible; the boolean mask is converted to NumPy so it selects by position.
plt.rcParams["figure.figsize"] = (8, 8)
for label in labels_to_keep:
    class_mask = (train_ds["labels"] == label).to_numpy()
    plt.hist(influences[val_image_idx][class_mask], label=label, alpha=0.5)
plt.xlabel("influence")
plt.legend()

In [None]:
# Restrict to training images sharing the selected validation image's label.
# .iloc reads that label by position, consistent with val_x / influences rows.
images_with_same_label = train_ds["labels"] == val_ds["labels"].iloc[val_image_idx]
if_same_label = influences[val_image_idx][images_with_same_label]
images_same_label = train_ds["images"][images_with_same_label].values
plot_top_bottom_if_images(if_same_label, subset_images=images_same_label, num_to_plot=3)

In [None]:
# Average over axis 0 of `influences`. Rows are indexed by validation point
# (influences[val_image_idx] is masked with training-set labels elsewhere), so
# this yields one mean influence per training image. The exact container type
# returned by compute_influences is not visible here — confirm before relying
# on numpy-specific behavior.
avg_influences = np.mean(influences, axis=0)

In [None]:
# Per-class distribution of each training image's mean influence over the
# whole validation set. Semi-transparent bars keep overlapping classes visible;
# the boolean mask is converted to NumPy so it selects by position.
plt.rcParams["figure.figsize"] = (8, 8)
for label in labels_to_keep:
    class_mask = (train_ds["labels"] == label).to_numpy()
    plt.hist(avg_influences[class_mask], label=label, alpha=0.5)
plt.xlabel("average influence")
plt.legend()

In [None]:
# Pass the underlying array so images are picked by position (the argsort
# indices over avg_influences are positions, not DataFrame index labels).
plot_top_bottom_if_images(avg_influences, train_ds["images"].values, num_to_plot=3)

In [None]:
fraction_to_exclude = 0.7
indices_to_exclude = []

# Keep each training row's mean influence next to it so rows can be ranked.
train_ds["avg_influences"] = avg_influences

# Per class, drop the `fraction_to_exclude` share (rounded down) of rows with
# the lowest average influence.
for label in labels_to_keep:
    rows_of_class = train_ds[train_ds["labels"] == label]
    n_drop = int(fraction_to_exclude * len(rows_of_class))
    lowest_rows = rows_of_class.nsmallest(n_drop, "avg_influences")
    indices_to_exclude += lowest_rows.index.tolist()

keep_mask = ~train_ds.index.isin(indices_to_exclude)
reduced_train_ds = train_ds.loc[keep_mask]

In [None]:
# Train on the pruned training set with a fresh optimizer.
# NOTE(review): model_ft still carries the weights from the first fine-tuning
# run, so the resulting score mixes the effect of pruning with 5 extra epochs of
# training — re-initialize the head for a clean comparison.
num_epochs = 5
optimizer = SGD(model_ft.parameters(), lr=0.01, momentum=0.9)

reduced_images = reduced_train_ds["normalized_images"]
reduced_labels = reduced_train_ds["labels"]
red_train_x, red_train_y = get_model_io(reduced_images, reduced_labels)

torch_model = TorchModel(model=model_ft)
train_loss, val_loss = torch_model.fit(
    x_train=red_train_x,
    y_train=red_train_y,
    x_val=val_x,
    y_val=val_y,
    loss=ce_loss,
    optimizer=optimizer,
    num_epochs=num_epochs,
    batch_size=1000,
)

In [None]:
# Training vs. validation loss per epoch on the pruned training set.
_, ax = plt.subplots()
for curve, curve_name in ((train_loss, "Train"), (val_loss, "Val")):
    ax.plot(curve, label=curve_name)
ax.legend()
plt.show()

In [None]:
# Re-evaluate on the test set in eval mode (BatchNorm uses its running stats and
# does not update them), without autograd, with inputs on the model's device.
# The previous train/eval mode is restored afterwards.
model_device = next(model_ft.parameters()).device
was_training = model_ft.training
model_ft.eval()
with torch.no_grad():
    pred_y_test = model_ft(test_x.to(model_device)).argmax(dim=1).cpu().numpy()
model_ft.train(was_training)

f1_score(test_y, pred_y_test, average="weighted")