In [23]:
import torch
import torchvision
import os
from pathlib import Path
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, RandomSampler, Subset
from torchvision import transforms
from torch.nn import MSELoss
from captum.attr import IntegratedGradients

from lfxai.models.images import AutoEncoderMnist, EncoderMnist, DecoderMnist
from lfxai.models.pretext import Identity, RandomNoise
from lfxai.explanations.features import attribute_auxiliary
from lfxai.explanations.examples import SimplEx, InfluenceFunctions

In [24]:
# --- Experiment configuration ---------------------------------------------
random_seed: int = 1  # seed for torch's global RNG
batch_size: int = 200  # mini-batch size of the train/test loaders
dim_latent: int = 4  # size of the autoencoder's latent space
n_epochs: int = 100  # number of epochs handed to autoencoder.fit
subtrain_size: int = 1000  # size parameter read by later cells of this notebook

# Initialize seed and device
torch.random.manual_seed(random_seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST (downloaded into ./data/mnist on first use). The transforms are
# attached after construction, exactly as before, so every sample is returned
# as a tensor.
data_dir = Path.cwd() / "data/mnist"
train_dataset = MNIST(data_dir, train=True, download=True)
test_dataset = MNIST(data_dir, train=False, download=True)
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
train_dataset.transform = train_transform
test_dataset.transform = test_transform
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Initialize encoder, decoder and autoencoder wrapper
pert = RandomNoise()
encoder = EncoderMnist(encoded_space_dim=dim_latent)
decoder = DecoderMnist(encoded_space_dim=dim_latent)
autoencoder = AutoEncoderMnist(encoder, decoder, dim_latent, pert)
encoder.to(device)
decoder.to(device)
autoencoder.to(device)

# Train the denoising autoencoder, or reuse a checkpoint from a previous run
save_dir = Path.cwd() / "results/mnist/consistency_examples"
# mkdir with exist_ok avoids the check-then-create race of exists() + makedirs()
save_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = save_dir / (autoencoder.name + ".pt")
if not checkpoint_path.exists():
    autoencoder.fit(
        device, train_loader, test_loader, save_dir, n_epochs, checkpoint_interval=10
    )
else:
    # map_location lets a checkpoint written on a GPU machine be restored on a
    # CPU-only machine (and vice versa) instead of raising at load time.
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm that this leniency is intended.
    autoencoder.load_state_dict(
        torch.load(checkpoint_path, map_location=device), strict=False
    )
# NOTE(review): the model is deliberately left in training mode here; since the
# encoder contains a BatchNorm2d layer, confirm that train-mode statistics are
# what the downstream explainers expect.
autoencoder.train().to(device)

AutoEncoderMnist(
  (encoder): EncoderMnist(
    (encoder_cnn): Sequential(
      (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
      (5): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
      (6): ReLU(inplace=True)
    )
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (encoder_lin): Sequential(
      (0): Linear(in_features=288, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=128, out_features=4, bias=True)
    )
  )
  (decoder): DecoderMnist(
    (decoder_lin): Sequential(
      (0): Linear(in_features=4, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): Linear(in_features=128, out_features=288, bias=True)
      (3): ReLU(inplace=True)
    )
    (unflatten): Unflatten(dim

In [31]:
# Sanity check: number of examples in the full MNIST train and test splits.
tuple(len(split) for split in (train_dataset, test_dataset))

(60000, 10000)

In [30]:
# Build class-balanced index lists: position n holds the (n // 10)-th
# occurrence of digit (n % 10). The per-digit index tensors are computed once
# up front instead of re-scanning every target for each n.
n_digits = 10
train_idx_by_digit = [
    torch.nonzero(train_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
test_idx_by_digit = [
    torch.nonzero(test_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
idx_subtrain = [
    train_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]
# The test subset intentionally reuses subtrain_size as its length.
idx_subtest = [
    test_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]

train_subset = Subset(train_dataset, idx_subtrain)
test_subset = Subset(test_dataset, idx_subtest)
# Default DataLoader settings: one example per batch, no shuffling.
subtrain_loader = DataLoader(train_subset)
subtest_loader = DataLoader(test_subset)
# Read the labels straight from the target tensors; iterating the loaders
# would decode and transform every image only to throw it away.
labels_subtrain = train_dataset.targets[idx_subtrain]
labels_subtest = test_dataset.targets[idx_subtest]

# Create a training set sampler with replacement for computing influence functions
recursion_depth = 100
train_sampler = RandomSampler(
    train_dataset, replacement=True, num_samples=recursion_depth * batch_size
)
train_loader_replacement = DataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler
)

# Fitting explainers, computing the metric and saving everything
mse_loss = torch.nn.MSELoss()
explainer_list = [
    InfluenceFunctions(autoencoder, mse_loss, save_dir / "if_grads"),
]

frac_list = [0.05, 0.1, 0.2, 0.5, 0.7, 1.0]
n_top_list = [int(frac * len(idx_subtrain)) for frac in frac_list]
results_dict = {}
checkpoint_path = save_dir / (autoencoder.name + ".pt")
for explainer in explainer_list:
    attribution = explainer.attribute_loader(
        device,
        subtrain_loader,
        subtest_loader,
        train_loader_replacement=train_loader_replacement,
        recursion_depth=recursion_depth,
    )
    # Restore the saved weights after each explainer.
    # NOTE(review): presumably guards against an explainer altering the model's
    # parameters — confirm. map_location keeps the reload working when the
    # checkpoint was written on a different device type.
    autoencoder.load_state_dict(
        torch.load(checkpoint_path, map_location=device), strict=False
    )
    results_dict[str(explainer)] = attribution

                                                        

KeyboardInterrupt: 

In [32]:
# Build class-balanced index lists: position n holds the (n // 10)-th
# occurrence of digit (n % 10). The per-digit index tensors are computed once
# up front instead of re-scanning every target for each n.
n_digits = 10
train_idx_by_digit = [
    torch.nonzero(train_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
test_idx_by_digit = [
    torch.nonzero(test_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
idx_subtrain = [
    train_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]
# The test index list intentionally reuses subtrain_size as its length.
idx_subtest = [
    test_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]

In [44]:
# Display the test split's summary (number of datapoints, root location, split).
test_dataset

Dataset MNIST
    Number of datapoints: 10000
    Root location: c:\Users\albac\Desktop\PostGrad\Cambridge\minitaskVDSLab\data\mnist
    Split: Test

In [35]:
# Sanity check: lengths of the train and test index lists.
tuple(map(len, (idx_subtrain, idx_subtest)))

(1000, 1000)

In [None]:
# Build class-balanced index lists: position n holds the (n // 10)-th
# occurrence of digit (n % 10). The per-digit index tensors are computed once
# up front instead of re-scanning every target for each n.
n_digits = 10
train_idx_by_digit = [
    torch.nonzero(train_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
test_idx_by_digit = [
    torch.nonzero(test_dataset.targets == digit).flatten()
    for digit in range(n_digits)
]
idx_subtrain = [
    train_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]
# The test subset intentionally reuses subtrain_size as its length.
idx_subtest = [
    test_idx_by_digit[n % n_digits][n // n_digits].item()
    for n in range(subtrain_size)
]

train_subset = Subset(train_dataset, idx_subtrain)
test_subset = Subset(test_dataset, idx_subtest)
# Default DataLoader settings: one example per batch, no shuffling.
subtrain_loader = DataLoader(train_subset)
subtest_loader = DataLoader(test_subset)
# Read the labels straight from the target tensors; iterating the loaders
# would decode and transform every image only to throw it away.
labels_subtrain = train_dataset.targets[idx_subtrain]
labels_subtest = test_dataset.targets[idx_subtest]

# Create a training set sampler with replacement for computing influence functions
recursion_depth = 100
train_sampler = RandomSampler(
    train_dataset, replacement=True, num_samples=recursion_depth * batch_size
)
train_loader_replacement = DataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler
)

# Fitting explainers, computing the metric and saving everything
mse_loss = torch.nn.MSELoss()
explainer_list = [
    InfluenceFunctions(autoencoder, mse_loss, save_dir / "if_grads"),
]

frac_list = [0.05, 0.1, 0.2, 0.5, 0.7, 1.0]
n_top_list = [int(frac * len(idx_subtrain)) for frac in frac_list]
results_dict = {}
checkpoint_path = save_dir / (autoencoder.name + ".pt")
for explainer in explainer_list:
    attribution = explainer.attribute_loader(
        device,
        subtrain_loader,
        subtest_loader,
        train_loader_replacement=train_loader_replacement,
        recursion_depth=recursion_depth,
    )
    # Restore the saved weights after each explainer.
    # NOTE(review): presumably guards against an explainer altering the model's
    # parameters — confirm. map_location keeps the reload working when the
    # checkpoint was written on a different device type.
    autoencoder.load_state_dict(
        torch.load(checkpoint_path, map_location=device), strict=False
    )
    results_dict[str(explainer)] = attribution