# LAVA

This notebook explores the use of LAVA for data valuation.

<div class="alert alert-info">

If you are reading this in the documentation, some boilerplate has been omitted for convenience.

</div>

## Imports and setup

In [None]:
%load_ext autoreload

In [None]:
%autoreload
%matplotlib inline

import logging
import os
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, f1_score
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.models.resnet import BasicBlock, ResNet, ResNet18_Weights, resnet18
from torchvision.utils import make_grid
from tqdm.auto import tqdm

from support.common import (
    plot_losses,
    plot_sample_images,
)
from support.torch import (
    MODEL_PATH,
    TrainingManager,
    new_resnet_model,
)
from support.types import Losses

logging.basicConfig(level=logging.DEBUG)

# Plot defaults for every figure in the notebook (transparent backgrounds).
plt.rcParams.update(
    {
        "figure.figsize": (7, 7),
        "font.size": 12,
        "xtick.labelsize": 12,
        "ytick.labelsize": 10,
        "axes.facecolor": (1, 1, 1, 0),
        "figure.facecolor": (1, 1, 1, 0),
    }
)

# Seed every RNG the notebook relies on: NumPy draws the data subsets, while
# PyTorch draws the initial weights and the DataLoader shuffling order.
# (Some CUDA kernels can remain nondeterministic even when seeded.)
random_state = 42
np.random.seed(random_state)
torch.manual_seed(random_state)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
%autoreload
from pydvl.ot.lava import LAVA
from pydvl.utils.dataset import Dataset

## Loading and preprocessing the dataset

In [None]:
# ToTensor yields values in [0, 1]; Normalize then maps each RGB channel to
# [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
cifar_root = "/tmp/cifar10"
trainset = CIFAR10(root=cifar_root, train=True, download=True, transform=transform)
valset = CIFAR10(root=cifar_root, train=False, download=True, transform=transform)
classes = trainset.classes

In [None]:
# Keep a small random subset of each split. Indices are drawn *without
# replacement*: `np.random.randint` could pick the same index several times,
# which would put duplicate images into the training set being valued.
# NOTE: this cell rebinds `trainset`/`valset`; re-run the loading cell first
# if you execute it again.
N_SAMPLES = 100
trainset = Subset(
    trainset, np.random.choice(len(trainset), size=N_SAMPLES, replace=False)
)
valset = Subset(valset, np.random.choice(len(valset), size=N_SAMPLES, replace=False))

In [None]:
# Both loaders share the same settings: mini-batches of 4, reshuffled every
# epoch, loaded in the main process.
loader_kwargs = dict(batch_size=4, shuffle=True, num_workers=0)
trainloader = DataLoader(trainset, **loader_kwargs)
valloader = DataLoader(valset, **loader_kwargs)

Let's take a closer look at a few sample images.

In [None]:
def imshow(img):
    """Show a channels-first image tensor normalized with mean = std = 0.5.

    The normalization is undone (x * std + mean) before moving the channel
    axis last, which is the layout ``plt.imshow`` expects.
    """
    restored = img / 2 + 0.5
    channels_last = restored.permute(1, 2, 0).numpy()
    plt.imshow(channels_last)
    plt.show()


# Draw one (shuffled) mini-batch from the training loader and tile it into a
# single image grid.
dataiter = iter(trainloader)
images, labels = next(dataiter)
grid = make_grid(images)
imshow(grid)

## Model definition and training

We now train a model on the validation data, as is done in the LAVA paper, in order to use it as a feature extractor.

In [None]:
# A small ResNet: a single BasicBlock per stage and a 10-way output layer.
model = ResNet(block=BasicBlock, layers=[1, 1, 1, 1], num_classes=10)

num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(f"Model has {num_params} parameters")

In [None]:
# Fit the model. Following the LAVA paper, the model is trained on `valloader`
# (the validation split); `trainloader` is passed as the second loader,
# presumably for validation -- confirm against TrainingManager's signature.
mgr = TrainingManager(
    "model_lava_cifar10",
    model,
    nn.CrossEntropyLoss(),
    valloader,
    trainloader,
    MODEL_PATH,
    device=DEVICE,
)
# use_cache=False retrains the model from scratch; presumably use_cache=True
# reuses the model cached under MODEL_PATH instead.
train_loss, val_loss = mgr.train(n_epochs=10, use_cache=False)

In [None]:
# Learning curves returned by `mgr.train` above (presumably: loss on the data
# the model is fit on, and on the second loader passed to TrainingManager).
losses = Losses(train_loss, val_loss)
plot_losses(losses)

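The confusion matrix and $F_1$ score look good, especially considering the low resolution and complexity of the images (they contain different kinds of objects). Note that both are computed on `valset`, which is the data the model was trained on, so they reflect how well the model fits its training data rather than how well it generalizes.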

In [None]:
# Evaluate the trained classifier. NOTE: `valloader` holds the data the model
# was fit on, so this measures fit rather than generalization.
# eval(): BatchNorm must use its running statistics (and not update them with
# size-4 batches); no_grad(): inference only, no autograd graph needed.
model.eval()
y_test = []
y_pred = []

with torch.no_grad():
    for inputs, targets in tqdm(valloader, total=len(valloader)):
        y_test.append(targets.cpu().numpy().ravel())
        logits = model(inputs.to(DEVICE))
        y_pred.append(logits.argmax(dim=1).cpu().numpy().ravel())

y_test = np.concatenate(y_test)
y_pred = np.concatenate(y_pred)

f1 = f1_score(y_test, y_pred, average="macro")
print(f"Macro-averaged F1 score: {f1:.3f}")

# Pin the label set so the matrix always has one row/column per entry of
# `classes`, even if some class is absent from this small sample.
cm = confusion_matrix(y_test, y_pred, labels=np.arange(len(classes)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
disp.plot();

## Feature Extraction

We now replace the last (fully connected) layer with the identity, in order to use the model as a feature extractor.

In [None]:
# Swap the classification head for a pass-through layer, so that
# model(inputs) returns the features that used to be fed into `fc`.
model.fc = nn.Identity()

In [None]:
x_train = []
y_train = []

for inputs, targets in tqdm(trainloader, total=len(trainloader)):
    y_train.append(targets.cpu().numpy().ravel())
    inputs = inputs.to(DEVICE)
    pred = model(inputs).cpu().detach().numpy()
    x_train.append(pred)

x_train = np.concatenate(x_train)
y_train = np.concatenate(y_train)

In [None]:
x_test = []
y_test = []

for inputs, targets in tqdm(valloader, total=len(valloader)):
    y_test.append(targets.cpu().numpy().ravel())
    inputs = inputs.to(DEVICE)
    pred = model(inputs).cpu().detach().numpy()
    x_test.append(pred)

x_test = np.concatenate(x_test)
y_test = np.concatenate(y_test)

## Computing values

In [None]:
# Bundle the extracted features: the training points are the ones being
# valued, while the test arrays hold the features of `valset`.
dataset = Dataset(x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)

In [None]:
all_values = []

for regularization in [1, 0.5, 0.1, 0.01]:
    for lambda_ in [3.0, 1.0, 0.1, 0.01, 0]:
        print(f"{regularization=}, {lambda_=}")
        lava = LAVA(
            dataset,
            inner_ot_method="exact",
            regularization=regularization,
            lambda_=lambda_,
        )
        values = lava.compute_values()
        all_values.append(values)

In [None]:
values_df = pd.DataFrame(np.stack(all_values).T)

In [None]:
values_df.plot.boxplot();

In [None]:
# A single run with inner_ot_method="gaussian" instead of the "exact" solver
# used in the grid above.
regularization = 1.0
lambda_ = 1.0
lava = LAVA(
    dataset,
    inner_ot_method="gaussian",
    regularization=regularization,
    lambda_=lambda_,
)
values = lava.compute_values()

In [None]:
# Box plot of every (flattened) entry returned by LAVA's private
# `_compute_feature_cost` helper.
# NOTE(review): private API -- may change without notice.
feature_cost = lava._compute_feature_cost()
plt.boxplot(feature_cost.ravel());

In [None]:
# Display the result of LAVA's private `_compute_gaussian_label_distances`
# helper (private API -- may change without notice).
lava._compute_gaussian_label_distances()

In [None]:
# Display the result of LAVA's private `_compute_exact_label_distances`
# helper, for comparison with the Gaussian variant in the previous cell.
lava._compute_exact_label_distances()

In [None]:
# Distribution of the values from the last LAVA run; the trailing `;`
# suppresses the (counts, bins, patches) repr.
fig, ax = plt.subplots()
ax.hist(values)
ax.set(title="Distribution of LAVA values", xlabel="Value", ylabel="Count");

## Pre-Trained Model

What if we use a pre-trained model instead?

In [None]:
# ImageNet-pretrained ResNet-18 with its classification head replaced by an
# identity, presumably to serve as a feature extractor like the model above.
# NOTE(review): these weights were trained with ImageNet preprocessing (see
# `ResNet18_Weights.IMAGENET1K_V1.transforms()`), whereas `transform` above
# normalizes with mean = std = 0.5 -- confirm the mismatch is intended.
# NOTE(review): in the lines visible here the model is neither moved to DEVICE
# nor put into eval() mode.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Identity()