# Influence functions in computer vision

## Imports

In [None]:
%load_ext autoreload

In [None]:
%autoreload
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from copy import deepcopy
from PIL import Image

import torch
from torch.optim import Adam
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from pydvl.utils.dataset import load_preprocess_imagenet
from pydvl.influence.model_wrappers import TorchModel
from pydvl.influence.general import compute_influences
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score

In [None]:
# Default figure size and font sizes for every plot in this notebook.
plt.rcParams.update(
    {
        "figure.figsize": (16, 8),
        "font.size": 12,
        "xtick.labelsize": 12,
        "ytick.labelsize": 10,
    }
)

In [None]:
from pathlib import Path
from cloudpickle import pickle as pkl

# Directory holding cached model weights, loss curves and influence results:
# data/imgnet_model under the parent of the current working directory.
imgnet_model_data_path = Path().resolve().parent.joinpath("data", "imgnet_model")


def save_model(model, train_loss, val_loss, model_name):
    """Persist a model's weights together with its loss curves.

    Writes two files into ``imgnet_model_data_path``:
    ``<model_name>_weights.pth`` (the model's ``state_dict``) and
    ``<model_name>_train_val_loss.pkl`` (the list ``[train_loss, val_loss]``).
    The directory is created first if it does not exist yet.

    Args:
        model: A ``torch.nn.Module`` whose ``state_dict`` is saved.
        train_loss: Training-loss history to store alongside the weights.
        val_loss: Validation-loss history to store alongside the weights.
        model_name: Prefix used for both file names.
    """
    # Without this, the very first save on a fresh checkout fails with
    # FileNotFoundError because data/imgnet_model does not exist yet.
    imgnet_model_data_path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), imgnet_model_data_path / f"{model_name}_weights.pth")
    with open(
        imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "wb"
    ) as file:
        pkl.dump([train_loss, val_loss], file)


def load_model(model, model_name):
    """Restore weights saved by ``save_model`` and return the stored loss curves.

    Args:
        model: Module whose parameters are overwritten in place via
            ``load_state_dict``.
        model_name: Prefix used when the files were saved.

    Returns:
        Tuple ``(train_loss, val_loss)`` read from
        ``<model_name>_train_val_loss.pkl``.

    Security: both files are unpickled, so only load files you produced yourself.
    """
    # map_location="cpu" lets weights saved on a GPU machine load on a CPU-only
    # one; load_state_dict then copies them onto whatever device `model` uses.
    state_dict = torch.load(
        imgnet_model_data_path / f"{model_name}_weights.pth", map_location="cpu"
    )
    model.load_state_dict(state_dict)
    with open(
        imgnet_model_data_path / f"{model_name}_train_val_loss.pkl", "rb"
    ) as file:
        train_loss, val_loss = pkl.load(file)
    return train_loss, val_loss


def save_results(results, file_name):
    """Pickle ``results`` to ``imgnet_model_data_path / file_name``.

    The target directory is created if it does not exist yet, so a first run
    on a fresh checkout does not fail with FileNotFoundError.
    """
    imgnet_model_data_path.mkdir(parents=True, exist_ok=True)
    with open(imgnet_model_data_path / str(file_name), "wb") as file:
        pkl.dump(results, file)


def load_results(file_name):
    """Unpickle and return the object stored at ``imgnet_model_data_path / file_name``.

    Only load files you created yourself: unpickling can execute arbitrary code.
    """
    results_path = imgnet_model_data_path / f"{file_name}"
    with results_path.open("rb") as handle:
        return pkl.load(handle)

In [None]:
# Seed NumPy's global RNG (also used later to pick the corrupted labels) and
# torch's (new head initialisation) so runs, and therefore the cached models
# and influence files saved below, are reproducible.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Draw two *distinct* class ids out of the 200 available. replace=False is
# essential: with NumPy's default replace=True both draws can be the same class,
# which collapses the label maps below to one class while the model keeps two
# outputs. To choose classes explicitly instead, use e.g.
# labels_to_keep = np.array([90, 100]).
labels_to_keep = np.random.choice(200, size=2, replace=False)
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    train_size=0.8, test_size=0.1, keep_labels=labels_to_keep
)

In [None]:
n_images_per_class = 4
# squeeze=False keeps `axes` 2-D even when only one label is kept.
fig, axes = plt.subplots(
    nrows=n_images_per_class, ncols=len(labels_to_keep), squeeze=False
)
fig.suptitle("Examples of training images")
# Hide every frame up front so cells left empty (classes with fewer than
# n_images_per_class images) don't show blank axes.
for ax in axes.flat:
    ax.axis("off")
for class_idx, class_label in enumerate(labels_to_keep):
    class_images = train_ds.loc[train_ds["labels"] == class_label, "images"]
    for img_idx, image in enumerate(class_images.head(n_images_per_class)):
        axes[img_idx, class_idx].imshow(image)
        axes[img_idx, class_idx].set_title(f"img label: {class_label}")
plt.show()

In [None]:
def initialise_model(output_size):
    """Build an ImageNet-pretrained ResNet-18 with a fresh ``output_size``-way head.

    Every pretrained parameter is frozen, so only the newly created ``fc``
    layer is trainable. The model is moved to ``cuda:0`` when CUDA is
    available and otherwise stays on the CPU.

    NOTE(review): inputs built elsewhere (e.g. ``test_x``) are not moved by
    this function; confirm they live on the same device when CUDA is available.
    """
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    # Freeze the whole pretrained network.
    model.requires_grad_(False)

    # Pool to 1x1 whatever the input's spatial size, then replace the
    # classifier. The new Linear layer is trainable (requires_grad defaults to True).
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    model.fc = nn.Linear(model.fc.in_features, output_size)

    if torch.cuda.is_available():
        target_device = torch.device("cuda:0")
    else:
        target_device = torch.device("cpu")
    model.to(target_device)
    return model


model_ft = initialise_model(output_size=len(labels_to_keep))

In [None]:
# Translate between the dataset's class ids and the contiguous 0..k-1 indices
# produced by the classifier head (and back).
model_label_to_ds_label = dict(enumerate(labels_to_keep))
ds_label_to_model_label = {
    ds_label: model_idx for model_idx, ds_label in model_label_to_ds_label.items()
}


def get_model_io(x, y):
    """Turn dataset columns into a model input batch and target list.

    Args:
        x: Object with a ``tolist()`` method (e.g. a pandas Series) whose items
            are equally shaped image tensors.
        y: Iterable of dataset labels, each a key of ``ds_label_to_model_label``.

    Returns:
        ``(x_nn, y_nn)``: the images stacked along a new leading dimension,
        and a Python list with the model class index of every label.
    """
    stacked_images = torch.stack(x.tolist())
    model_targets = list(map(ds_label_to_model_label.__getitem__, y))
    return stacked_images, model_targets


ce_loss = nn.CrossEntropyLoss()

# Build (inputs, targets) for the train, validation and test splits, in that order.
(train_x, train_y), (val_x, val_y), (test_x, test_y) = (
    get_model_io(split["normalized_images"], split["labels"])
    for split in (train_ds, val_ds, test_ds)
)


def get_f1_score_on_test_set(model):
    """Return the weighted F1 score of ``model`` on the global test split.

    Reads the module-level ``test_x`` / ``test_y``. The forward pass runs in
    eval mode and without autograd. In train mode, BatchNorm layers would
    normalise with test-batch statistics and update their running statistics
    with test data. The model's previous train/eval mode is restored
    afterwards, and inputs are moved to the model's device.
    """
    was_training = model.training
    first_param = next(model.parameters(), None)
    inputs = test_x if first_param is None else test_x.to(first_param.device)
    model.eval()
    try:
        with torch.no_grad():
            # .cpu() so sklearn can convert the predictions to NumPy.
            pred_y_test = model(inputs).argmax(dim=1).cpu()
    finally:
        model.train(was_training)
    return f1_score(test_y, pred_y_test, average="weighted")


def plot_train_val_loss(train_loss, val_loss, figsize=(10, 8)):
    """Plot training and validation loss curves on a new figure.

    Args:
        train_loss: Sequence of training-loss values.
        val_loss: Sequence of validation-loss values.
        figsize: Figure size in inches. It is passed to this figure only,
            instead of overwriting ``plt.rcParams["figure.figsize"]``, so the
            notebook's later plots keep their configured size.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(train_loss, label="Train")
    ax.plot(val_loss, label="Val")
    # NOTE(review): assumes one loss value per epoch, as produced by
    # TorchModel.fit(num_epochs=...) — confirm.
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()


def train_model(model, num_epochs, training_data, lr=0.001, batch_size=1000):
    """Fine-tune ``model`` on ``training_data`` and return its loss curves.

    Validation uses the module-level ``val_x`` / ``val_y``, and the loss is the
    module-level ``ce_loss``.

    Args:
        model: Module to train in place.
        num_epochs: Number of epochs passed to ``TorchModel.fit``.
        training_data: DataFrame with ``"normalized_images"`` and ``"labels"``.
        lr: Adam learning rate.
        batch_size: Mini-batch size passed to ``TorchModel.fit``.

    Returns:
        Tuple ``(train_loss, val_loss)`` as returned by ``TorchModel.fit``.
    """
    optimizer = Adam(model.parameters(), lr=lr)

    # Converted once. Previously the result of an identical first call was
    # discarded, doubling the (expensive) stacking work.
    x_train, y_train = get_model_io(
        training_data["normalized_images"], training_data["labels"]
    )

    train_loss, val_loss = TorchModel(model=model).fit(
        x_train=x_train,
        y_train=y_train,
        x_val=val_x,
        y_val=val_y,
        loss=ce_loss,
        optimizer=optimizer,
        num_epochs=num_epochs,
        batch_size=batch_size,
    )
    return train_loss, val_loss

In [None]:
# Set to False to restore weights and loss curves saved by a previous run
# instead of fine-tuning again.
run_model_training = True

if not run_model_training:
    train_loss, val_loss = load_model(model_ft, model_name="model_ft")
else:
    num_epochs = 30
    train_loss, val_loss = train_model(
        model_ft, num_epochs=num_epochs, training_data=train_ds
    )
    save_model(model_ft, train_loss, val_loss, model_name="model_ft")

In [None]:
plot_train_val_loss(train_loss=train_loss, val_loss=val_loss)

In [None]:
# Predict in eval mode and without autograd. In train mode, BatchNorm would use
# (and record) statistics of the test batch. The previous mode is restored.
was_training = model_ft.training
model_ft.eval()
with torch.no_grad():
    test_logits = model_ft(test_x.to(next(model_ft.parameters()).device))
model_ft.train(was_training)
pred_y_test = test_logits.argmax(dim=1).cpu()

cm = confusion_matrix(test_y, pred_y_test)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels_to_keep)
disp.plot();

In [None]:
f1_score(y_true=test_y, y_pred=pred_y_test, average="weighted")

In [None]:
# Set to False to reuse the influences pickled by a previous run.
calculate_influences = True

if not calculate_influences:
    influences = load_results(file_name="influences.pkl")
else:
    influences = compute_influences(
        model=model_ft,
        loss=ce_loss,
        x=train_x,
        y=train_y,
        x_test=val_x,
        y_test=val_y,
        hessian_regularization=1e-3,
        inversion_method="cg",
        influence_type="up",
    )
    save_results(influences, file_name="influences.pkl")

In [None]:
# Position (not index label) of the validation image to explain. It is used
# positionally with val_x and with the rows of `influences`, hence `.iloc`.
val_image_idx = 2
# Size this figure only instead of changing the global default.
plt.figure(figsize=(5, 5))
plt.imshow(val_ds["images"].iloc[val_image_idx])
plt.axis("off")
plt.show()

In [None]:
# Predict a single image in eval mode. In train mode, BatchNorm would normalise
# this batch of one with its own statistics. The previous mode is restored.
was_training = model_ft.training
model_ft.eval()
with torch.no_grad():
    single_image = val_x[val_image_idx].unsqueeze(0)
    single_image = single_image.to(next(model_ft.parameters()).device)
    predicted_model_label = model_ft(single_image).argmax(dim=1).item()
model_ft.train(was_training)

print("Predicted label:", model_label_to_ds_label[predicted_model_label])
# .iloc: val_image_idx is a position, consistent with val_x above.
print("Real label:", val_ds["labels"].iloc[val_image_idx])

In [None]:
def plot_top_bottom_if_images(
    subset_influences, subset_images, num_to_plot, figsize=(8, 8)
):
    """Show the lowest- (left) and highest-influence (right) images.

    Args:
        subset_influences: 1-D sequence of influence values, one per image.
        subset_images: Sequence of images accepted by ``imshow``, aligned
            position by position with ``subset_influences``.
        num_to_plot: Images per column; clamped to the number of images available.
        figsize: Size of the created figure.

    Raises:
        ValueError: If there are no influence values or ``num_to_plot < 1``.
    """
    subset_influences = np.asarray(subset_influences)
    if subset_influences.size == 0:
        raise ValueError("No influence values to plot.")
    if num_to_plot < 1:
        raise ValueError(f"num_to_plot must be >= 1, got {num_to_plot}.")
    num_to_plot = min(num_to_plot, len(subset_influences))

    order = np.argsort(subset_influences)
    bottom_if_idxs = order[:num_to_plot]
    top_if_idxs = order[-num_to_plot:]

    # figsize must be given when the figure is created. Setting rcParams after
    # plt.subplots only affects the *next* figure. squeeze=False keeps `axes`
    # 2-D when num_to_plot == 1.
    fig, axes = plt.subplots(
        nrows=num_to_plot, ncols=2, figsize=figsize, squeeze=False
    )
    fig.suptitle("Bottom (left) and top (right) influences")

    for col, idxs in ((0, bottom_if_idxs), (1, top_if_idxs)):
        for row, img_idx in enumerate(idxs):
            ax = axes[row, col]
            ax.set_title(f"img influence: {subset_influences[img_idx]:0f}")
            ax.imshow(subset_images[img_idx])
            ax.axis("off")

    plt.show()

In [None]:
plt.rcParams["figure.figsize"] = (8, 8)
val_image_influences = influences[val_image_idx]
for label in labels_to_keep:
    # Semi-transparent bars so the per-class histograms don't hide each other.
    plt.hist(
        val_image_influences[(train_ds["labels"] == label).to_numpy()],
        label=label,
        alpha=0.7,
    )
plt.title(f"Influence of training images on validation image {val_image_idx}")
plt.xlabel("Influence")
plt.ylabel("Number of training images")
plt.legend();

In [None]:
# val_image_idx is a position (it indexes `influences` rows), so use .iloc.
val_image_label = val_ds["labels"].iloc[val_image_idx]
images_with_same_label = (train_ds["labels"] == val_image_label).to_numpy()
if_same_label = influences[val_image_idx][images_with_same_label]
images_same_label = train_ds["images"][images_with_same_label].values
plot_top_bottom_if_images(if_same_label, subset_images=images_same_label, num_to_plot=3)

In [None]:
# Average over axis 0 of `influences` (the axis indexed by val_image_idx above),
# giving one averaged score per training point.
# NOTE(review): assumes `influences` is a 2-D (n_val, n_train) array — confirm
# against compute_influences' return value.
avg_influences = np.mean(influences, axis=0)

In [None]:
plt.rcParams["figure.figsize"] = (8, 8)
for label in labels_to_keep:
    # Semi-transparent bars so the per-class histograms don't hide each other.
    plt.hist(
        avg_influences[(train_ds["labels"] == label).to_numpy()],
        label=label,
        alpha=0.7,
    )
plt.title("Average influence of training images over the validation set")
plt.xlabel("Average influence")
plt.ylabel("Number of training images")
plt.legend();

In [None]:
# Use a class that is guaranteed to be among the drawn labels. A hard-coded id
# such as 90 selects no images unless that class happened to be drawn, and the
# plot then fails on an empty selection.
label = labels_to_keep[0]
img_with_selected_label = (train_ds["labels"] == label).to_numpy()
if_selected_label = avg_influences[img_with_selected_label]
images_selected_label = train_ds["images"][img_with_selected_label].values
plot_top_bottom_if_images(if_selected_label, images_selected_label, num_to_plot=3)

In [None]:
def array_to_PIL(arr):
    """Convert ``arr`` to ``uint8`` with ``np.uint8`` and wrap it as a PIL image."""
    pixels = np.uint8(arr)
    return Image.fromarray(pixels)


def get_corrupted_dataset(dataset, fraction_to_corrupt, avg_influences, labels=None):
    """Flip the labels of each class's highest-average-influence training points.

    For every class in ``labels``, the ``int(fraction_to_corrupt * class_size)``
    rows with the largest ``avg_influences`` are relabelled with a class drawn
    by ``np.random.choice`` from the remaining labels.

    Args:
        dataset: DataFrame with a ``"labels"`` column. It is not modified.
        fraction_to_corrupt: Fraction of each class to relabel.
        avg_influences: One value per row of ``dataset``, in row (positional) order.
        labels: Classes to corrupt. Defaults to the notebook-level ``labels_to_keep``.

    Returns:
        ``(corrupted_dataset, corrupted_indices)``: a deep copy of ``dataset``
        with the flipped labels, and a dict mapping each label to the index
        labels of the rows that were relabelled *to* it.

    Raises:
        ValueError: If ``avg_influences`` and ``dataset`` differ in length, or
            points must be corrupted but there is no other label to flip them to.
    """
    if labels is None:
        labels = labels_to_keep
    avg_influences = np.asarray(avg_influences)
    if len(avg_influences) != len(dataset):
        raise ValueError(
            f"avg_influences has {len(avg_influences)} values but dataset has "
            f"{len(dataset)} rows."
        )

    corrupted_dataset = deepcopy(dataset)
    corrupted_indices = {label: [] for label in labels}

    # Build from raw arrays on the dataset's own index. Assigning the "labels"
    # Series into a RangeIndex frame would align on index and yield NaN or
    # misplaced labels whenever the dataset index is not 0..n-1.
    influence_frame = pd.DataFrame(
        {"avg_influences": avg_influences, "labels": dataset["labels"].to_numpy()},
        index=dataset.index,
    )

    for label in labels:
        class_data = influence_frame[influence_frame["labels"] == label]
        num_corrupt = int(fraction_to_corrupt * len(class_data))
        if num_corrupt == 0:
            continue
        wrong_labels = [other for other in labels if other != label]
        if not wrong_labels:
            raise ValueError(
                f"Cannot corrupt label {label!r}: no other label to flip to."
            )
        indices_to_corrupt = class_data.nlargest(
            num_corrupt, "avg_influences"
        ).index.tolist()
        for img_idx in indices_to_corrupt:
            sample_label = np.random.choice(wrong_labels)
            corrupted_dataset.at[img_idx, "labels"] = sample_label
            corrupted_indices[sample_label].append(img_idx)
    return corrupted_dataset, corrupted_indices

In [None]:
# Set to False to restore weights and loss curves saved by a previous run.
run_model_training = True
model_corrupted = initialise_model(output_size=len(labels_to_keep))
# Relabel the 10% of each class with the highest average influence.
corrupted_dataset, corrupted_indices = get_corrupted_dataset(
    dataset=train_ds, fraction_to_corrupt=0.1, avg_influences=avg_influences
)

if not run_model_training:
    train_loss, val_loss = load_model(model_corrupted, model_name="model_corrupted")
else:
    num_epochs = 30
    train_loss, val_loss = train_model(
        model_corrupted,
        num_epochs=num_epochs,
        training_data=corrupted_dataset,
        lr=0.001,
    )
    save_model(model_corrupted, train_loss, val_loss, model_name="model_corrupted")

In [None]:
plot_train_val_loss(train_loss=train_loss, val_loss=val_loss)

In [None]:
# Weighted test-set F1 of the model trained on the corrupted labels.
model_score = get_f1_score_on_test_set(model=model_corrupted)
print(model_score)

In [None]:
# Set to False to reuse the influences pickled by a previous run.
calculate_influences = True

if not calculate_influences:
    influences = load_results(file_name="influences_corrupted.pkl")
else:
    corrupted_train_x, corrupted_train_y = get_model_io(
        corrupted_dataset["normalized_images"], corrupted_dataset["labels"]
    )
    influences = compute_influences(
        model=model_corrupted,
        loss=ce_loss,
        x=corrupted_train_x,
        y=corrupted_train_y,
        x_test=val_x,
        y_test=val_y,
        hessian_regularization=1e-3,
        inversion_method="cg",
        influence_type="up",
    )
    save_results(influences, file_name="influences_corrupted.pkl")

In [None]:
# Use a class that is guaranteed to be among the drawn labels. A hard-coded id
# such as 100 selects no images unless that class happened to be drawn, and the
# plot then fails on an empty selection.
label = labels_to_keep[-1]
avg_corrupted_influences = np.mean(influences, axis=0)
img_with_selected_label = (corrupted_dataset["labels"] == label).to_numpy()
if_selected_label = avg_corrupted_influences[img_with_selected_label]
images_selected_label = corrupted_dataset["images"][img_with_selected_label].values
plot_top_bottom_if_images(if_selected_label, images_selected_label, num_to_plot=3)

In [None]:
# Build the per-point average influences once, on the dataset's own index. That
# index holds the values get_corrupted_dataset stores in corrupted_indices, and
# the boolean label masks below align against it.
avg_influences_series = pd.Series(
    avg_corrupted_influences, index=corrupted_dataset.index
)
for label in labels_to_keep:
    # Points that currently carry `label`, split into those flipped *to* it
    # and the untouched ones.
    class_influences = avg_influences_series[corrupted_dataset["labels"] == label]
    is_corrupted = class_influences.index.isin(corrupted_indices[label])
    corrupted_infl = class_influences[is_corrupted]
    non_corrupted_infl = class_influences[~is_corrupted]
    plt.hist(non_corrupted_infl, label="non corrupted data", density=True, alpha=0.7)
    plt.hist(corrupted_infl, label="corrupted data", density=True, alpha=0.7)
    plt.title(f"Average influence of training points labelled {label}")
    plt.xlabel("Average influence")
    plt.ylabel("Density")
    plt.legend()
    plt.show()
    print(
        f"Average influence of corrupted points for {label=}: ",
        np.mean(corrupted_infl),
    )
    print(
        f"Average influence of other points for {label=}: ",
        np.mean(non_corrupted_infl),
    )