# Influence functions in computer vision

## Imports

In [None]:
%load_ext autoreload

In [None]:
%autoreload
%matplotlib inline

import matplotlib.pyplot as plt
from pathlib import Path
from cloudpickle import pickle as pkl
import numpy as np
import torch
from pydvl.influence.general import compute_influences
from torch.optim import lr_scheduler, SGD
import torch
import torch.nn as nn
from torchvision.models import resnet18
from pydvl.utils.dataset import load_preprocess_imagenet
from pydvl.influence.model_wrappers import TorchModel
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score

In [None]:
# Number of classes kept for this experiment; label ids 0..n_classes_to_keep-1
# are passed to the loader as `keep_labels`.
n_classes_to_keep = 10
train_ds, val_ds, test_ds = load_preprocess_imagenet(
    keep_labels=[*range(n_classes_to_keep)],
    train_size=0.8,
    test_size=0.1,
)

In [None]:
# Number of label entries in the training split.
len(train_ds["labels"])

In [None]:
# Preview the first three training images, one figure per image.
preview_images = train_ds["images"][:3]
for image in preview_images:
    plt.imshow(image)
    plt.show()

In [None]:
# Start from a pretrained ResNet-18. Passing a boolean to `weights` is the
# deprecated legacy form; name the weight set explicitly instead
# ("IMAGENET1K_V1" is what the boolean form resolved to).
model_ft = resnet18(weights="IMAGENET1K_V1")

# Freeze every pretrained parameter; only modules created below stay trainable.
for param in model_ft.parameters():
    param.requires_grad = False

# Fine-tune the final few layers to adjust for the tiny ImageNet input:
# swap in an adaptive pooling layer and a fresh classification head with one
# output per kept class.
model_ft.avgpool = nn.AdaptiveAvgPool2d(1)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, n_classes_to_keep)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

model_ft.to(device)

# Set to True to fine-tune and save the model; False loads saved artefacts.
train_model = False
# Artefacts live in <parent of the current working directory>/data/imgnet_model.
model_data_path = Path().resolve().parent.joinpath("data/imgnet_model")
ce_loss = nn.CrossEntropyLoss()
optimizer = SGD(model_ft.parameters(), momentum=0.9, lr=0.01)
scheduler = lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=7)


def save_model(path, model, train_loss, val_loss):
    """Persist the model weights and the loss history inside ``path``.

    Writes two files: ``model_weights.pth`` (the model's ``state_dict``) and
    ``train_val_loss.pkl`` (a pickled ``[train_loss, val_loss]`` list).

    Args:
        path: Target directory (``str`` or ``Path``). It is created, including
            missing parents, if it does not exist yet.
        model: Module whose ``state_dict()`` is saved.
        train_loss: Training loss history, pickled as is.
        val_loss: Validation loss history, pickled as is.
    """
    path = Path(path)
    # Without this, a first run on a fresh checkout fails because the target
    # directory is missing.
    path.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path / "model_weights.pth")
    with open(path / "train_val_loss.pkl", "wb") as file:
        pkl.dump([train_loss, val_loss], file)


def load_model(path, model):
    """Restore weights into ``model`` and return the stored loss history.

    Reads the two files written by ``save_model`` from ``path``:
    ``model_weights.pth`` and ``train_val_loss.pkl``.

    Args:
        path: Directory (``str`` or ``Path``) containing the saved files.
        model: Module that receives the stored ``state_dict`` in place.

    Returns:
        The ``(train_loss, val_loss)`` pair that was pickled at save time.
    """
    path = Path(path)
    # Load onto the CPU first so that weights saved from a GPU run can be
    # restored on a CPU-only machine; `load_state_dict` copies the values
    # into the model's existing parameters.
    state_dict = torch.load(path / "model_weights.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    # Read from `path`, not from a notebook-level global, so the function
    # honours its argument.
    # NOTE(review): unpickling executes arbitrary code; only load files from
    # a trusted source.
    with open(path / "train_val_loss.pkl", "rb") as file:
        train_loss, val_loss = pkl.load(file)
    return train_loss, val_loss


# Either fine-tune the model and save the artefacts, or restore a previous
# run. Both branches must define `train_loss` and `val_loss`, which the loss
# plot below reads.
if train_model:
    num_epochs = 5
    train_loss, val_loss = TorchModel(model=model_ft).fit(
        x_train=torch.stack(train_ds["normalized_images"]),
        y_train=train_ds["labels"],
        x_val=torch.stack(val_ds["normalized_images"]),
        y_val=val_ds["labels"],
        loss=ce_loss,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        batch_size=1000,
    )
    save_model(model_data_path, model_ft, train_loss, val_loss)
else:
    # Keep the returned loss history; discarding it leaves `train_loss` and
    # `val_loss` undefined on a fresh kernel.
    train_loss, val_loss = load_model(model_data_path, model_ft)

In [None]:
# Draw the training and validation loss curves on one set of axes.
_, ax = plt.subplots()
for history, name in ((train_loss, "Train"), (val_loss, "Val")):
    ax.plot(history, label=name)
ax.legend()
plt.show()

In [None]:
# Predict on the test split.
# - eval mode: layers such as batch normalisation use their stored statistics
#   instead of the statistics of the current batch. The model stays in eval
#   mode for the cells that follow.
# - no_grad: no autograd graph is needed for inference.
# - the inputs are moved to the device the model was placed on.
model_ft.eval()
with torch.no_grad():
    test_logits = model_ft(torch.stack(test_ds["normalized_images"]).to(device))
# Convert to a NumPy array on the CPU so scikit-learn can consume it.
pred_y_test = test_logits.argmax(dim=1).cpu().numpy()

cm = confusion_matrix(test_ds["labels"], pred_y_test)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot();

In [None]:
# F1 score of the test predictions, averaged over classes with
# scikit-learn's "weighted" scheme.
f1_score(test_ds["labels"], pred_y_test, average="weighted")

In [None]:
# Influence of a 200-point training subset (indices 400..599) on the first
# 300 validation points.
influence_x = torch.stack(train_ds["normalized_images"][400:600])
influence_y = train_ds["labels"][400:600]
influence_x_test = torch.stack(val_ds["normalized_images"][:300])
influence_y_test = val_ds["labels"][:300]

influences = compute_influences(
    model_ft,
    ce_loss,
    x=influence_x,
    y=influence_y,
    x_test=influence_x_test,
    y_test=influence_y_test,
    hessian_regularization=0.01,
    inversion_method="cg",
    influence_type="up",
)

In [None]:
# Mean over axis 0 of the influence values. Later cells mask the result with
# the 200 training labels, so axis 0 presumably indexes the test points and the
# result holds one value per training point; confirm against
# `compute_influences`.
avg_influences = np.mean(influences, axis=0)

In [None]:
# Distribution of the averaged influence values.
plt.hist(avg_influences)

In [None]:
# One histogram of averaged influences per class, using the labels of the same
# training subset (indices 400..599) the influences were computed for.
subset_labels = np.array(train_ds["labels"][400:600])
for label in range(n_classes_to_keep):
    class_mask = subset_labels == label
    plt.hist(avg_influences[class_mask], label=label)
plt.legend()

In [None]:
# The ten smallest averaged influence values, in ascending order.
avg_influences[np.argsort(avg_influences)[:10]]

In [None]:
# Show one validation image and print its label.
img_idx = 20
val_image, val_label = val_ds["images"][img_idx], val_ds["labels"][img_idx]
plt.imshow(val_image)
print(val_label)

In [None]:
# Per-class histograms of the influence values for test point 20 (the
# validation image displayed in the cell above).
# The influences were computed for train_ds[400:600], so the class masks must
# be built from the labels of that same slice; any other slice pairs each
# influence value with the label of an unrelated training point.
subset_labels = np.array(train_ds["labels"][400:600])
test_point_influences = influences[20]
for label in range(n_classes_to_keep):
    # Class 5 is left out of the plot.
    # NOTE(review): presumably because it is the class of the test image; confirm.
    if label == 5:
        continue
    plt.hist(test_point_influences[subset_labels == label], label=label)
plt.legend()

In [None]:
# Labels of the ten training points (within train_ds[400:600]) that have the
# largest values in influences[8].
# NOTE(review): row 8 presumably corresponds to validation point 8; confirm.
np.array(train_ds["labels"][400:600])[np.argsort(influences[8])[-10:]]

In [None]:
# Walk through the ten training points with the smallest averaged influence:
# print the value and the label, then show the image.
subset_labels = train_ds["labels"][400:600]
subset_images = train_ds["images"][400:600]
lowest_first = np.argsort(avg_influences)
for img_idx in lowest_first[:10]:
    print(avg_influences[img_idx])
    print(subset_labels[img_idx])
    plt.imshow(subset_images[img_idx])
    plt.show()

In [None]:
# Show every image with label 5 among the first 200 training samples.
# NOTE(review): this looks at train_ds[:200], while the influences above were
# computed for train_ds[400:600]; confirm that this slice is intended.
first_labels = train_ds["labels"][:200]
for idx, img in enumerate(train_ds["images"][:200]):
    if first_labels[idx] != 5:
        continue
    plt.imshow(img)
    plt.show()

In [None]:
# Entries of influences[8] that are greater than 1.
influences[8][(influences[8] > 1)]

In [None]:
# For the training points whose value in influences[8] is NOT greater than 1,
# check whether their label is 5. The influences were computed for
# train_ds[400:600], so the labels must come from that same slice to line up
# with the mask.
np.asarray(train_ds["labels"][400:600])[~(influences[8] > 1)] == 5

In [None]:
# NOTE(review): `val_target` and `result` are not defined in any cell shown
# above, so this cell raises NameError on a fresh kernel. It looks like a
# leftover from an earlier experiment; define the missing names or remove the
# cell.
image_n = 40
for i in range(20):
    # Every fifth entry of the first 100 targets, by position i.
    print(val_target[:100:5][i])
    plt.imshow(result[i][image_n][0])
    print(np.mean(result[i][image_n][0]))
    plt.colorbar()
    plt.show()

In [None]:
# NOTE(review): `processed_tiny_imagenet` is not defined in any cell shown
# above, so this cell raises NameError on a fresh kernel. `image_n` comes from
# the previous cell.
# Shows entry `image_n` of every tenth item among the first 500, then prints
# its label.
plt.imshow(processed_tiny_imagenet["original_image"][:500:10][image_n])
print(processed_tiny_imagenet["labels"][:500:10][image_n])