In [5]:
# Load library
import sys
import os
# The project root is taken to be two directory levels above the current
# working directory; its `src` folder is appended to the import path so the
# project library imported below can be found.
_working_dir = os.path.abspath('')
_parent_dir = os.path.dirname(_working_dir)
PROJ_DIR = os.path.realpath(os.path.dirname(_parent_dir))
_src_dir = os.path.join(PROJ_DIR, 'src')
sys.path.append(_src_dir)

import xai_faithfulness_experiments_lib_edits as ff

In [59]:
DATASET_NAME = 'mnist'
MODEL_NAME = 'ood-mean'

# Load dataset
import torch
import torchvision

batch_size = 256

MNIST_PATH = os.path.join(PROJ_DIR, 'data', 'mnist')


def _mnist_loader(train):
    """Build a shuffled DataLoader over the MNIST train (True) or test (False) split.

    The split is downloaded into MNIST_PATH if it is not already there, and
    every image is converted to a tensor and normalized with mean 0.1307 and
    std 0.3081.
    """
    preprocessing = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    split = torchvision.datasets.MNIST(
        MNIST_PATH, train=train, download=True, transform=preprocessing)
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)


train_loader = _mnist_loader(train=True)
test_loader = _mnist_loader(train=False)

# Pull a single batch from the test loader; later cells read its shape.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

In [45]:
import torch
import numpy as np
class RandomMasker(torch.nn.Module):
    """Randomly replaces elements of each input sample with a masking value.

    For every sample in the batch a selection level is drawn uniformly from
    [0, 1). Each element of that sample is kept when an independent uniform
    draw is <= that level, and is otherwise replaced by the corresponding
    entry of the masking value.
    """
    def __init__(self, masking_value:torch.Tensor):
        """Store the value written into masked-out positions.

        Args:
            masking_value: value that must broadcast against the inputs passed
                to `forward` (e.g. a tensor shaped like a single sample).
        """
        super().__init__()
        self._masking_value = masking_value
    def forward(self, x): # Assumes inputs are (batch_size, ...) with at least 2 dims
        """Return a copy of `x` with a random subset of elements masked.

        Args:
            x: batch of inputs; the first dimension is the batch dimension.

        Returns:
            Tensor with the same shape as `x`, on the same device as `x`.
        """
        # Random tensors are created on x's device so the masker also works for
        # inputs that live on an accelerator (torch.rand defaults to the CPU).
        selection_levels = torch.rand((x.shape[0], 1), device=x.device) # A different selection level for each element of the batch
        # Append singleton dims until the levels broadcast against x.
        while selection_levels.dim() < x.dim():
            selection_levels = selection_levels.unsqueeze(dim=-1)
        selected_pixels = torch.le(torch.rand(x.shape, device=x.device), selection_levels) # True where the original value is kept
        masking_value = self._masking_value
        if torch.is_tensor(masking_value):
            # Follow the input's device; a no-op when both already match.
            masking_value = masking_value.to(x.device)
        return x * selected_pixels + masking_value * ~selected_pixels

# Fill value for masked-out positions: zeros by default, 0.1307 for the
# 'ood-mean' model. The array has the shape of one sample of `example_data`.
# NOTE(review): the loaders apply Normalize((0.1307,), (0.3081,)), so 0.1307 is
# being used in already-normalized space -- confirm this is the intended fill.
_fill = 0.1307 if MODEL_NAME == 'ood-mean' else 0.0
masking_value = np.full(example_data.shape[1:], _fill)
masker = RandomMasker(torch.tensor(masking_value).float())

In [57]:
# Declare classifier
MODEL_LABEL_NUM = 10  # size of the classifier's output layer
import torch.nn as nn
import torch.nn.functional as F
# Use the first CUDA device when one is available, otherwise the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f'Using {device}')
#https://nextjournal.com/gkoehler/pytorch-mnist
class MNISTClassifier(nn.Module):
    """Small convolutional classifier: two conv/pool stages and two linear layers.

    `forward` returns raw scores of size MODEL_LABEL_NUM per sample; no softmax
    is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=MODEL_LABEL_NUM)

    def forward(self, x):
        """Map a batch of single-channel images to per-class scores."""
        # Stage 1: convolution -> 2x2 max-pool -> ReLU
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: convolution -> channel dropout -> 2x2 max-pool -> ReLU
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        # Flatten to 320 features per row (20 channels x 4 x 4 for 28x28 inputs).
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        return self.fc2(hidden)

# Build the classifier, then place its parameters on the selected device.
network = MNISTClassifier()
network = network.to(device)

Using cpu


In [58]:
MODEL_LR = 1e-2
loss = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(network.parameters(), lr=MODEL_LR)#, weight_decay=1e-3)
MODEL_EPOCHS = 10


def evaluate_accuracy(model, loader, eval_device):
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    The model is switched to eval mode and gradients are disabled. Accuracy is
    counted per sample (not averaged per batch), so a smaller final batch does
    not skew the result. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            targets = targets.to(eval_device)
            predictions = model(inputs.to(eval_device)).argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.shape[0]
    return correct / total if total else 0.0


# One fixed test batch, used only to monitor progress during training.
_, (x_test, y_test) = next(enumerate(test_loader))
x_monitor, y_monitor = x_test.to(device), y_test.to(device)

if MODEL_NAME == 'untrained':
    MODEL_EPOCHS = 0 # Untrained
for epoch in range(MODEL_EPOCHS):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        # Monitoring below switches to eval mode, so re-enable dropout each step
        # (this also makes the cell safe to re-run after a previous evaluation).
        network.train()
        if 'ood' in MODEL_NAME:
            x_train_tensor_masked = masker(x_train) # A different set of RandomMasks for each batch
        else:
            x_train_tensor_masked = x_train # Unmasked
        # Masking is done before the transfer so it runs wherever the loader
        # produced the batch; the network itself lives on `device`.
        x_train_tensor_masked = x_train_tensor_masked.to(device)
        y_train = y_train.to(device)
        optimizer.zero_grad()

        preds = network(x_train_tensor_masked)
        # BCEWithLogitsLoss needs float targets shaped like the predictions.
        label_onehot = torch.nn.functional.one_hot(y_train, MODEL_LABEL_NUM).to(preds.dtype)
        loss_value = loss(preds, label_onehot)
        loss_value.backward()
        optimizer.step()

        train_accuracy = (preds.argmax(dim=1) == y_train).float().mean()

        # Monitor on the fixed test batch with dropout off and no autograd graph.
        network.eval()
        with torch.no_grad():
            test_preds = network(x_monitor)
        test_accuracy = (test_preds.argmax(dim=1) == y_monitor).float().mean()
        print(f'Epoch {epoch}/{MODEL_EPOCHS} - Loss: {loss_value.item():.4f} - Train accuracy: {train_accuracy:.2f} - Test accuracy: {test_accuracy:.2f}')
        #if test_accuracy > 0.6: # Undertrained
        #    break

network.eval()
# Final accuracy over the whole *test* split. It is kept as a 0-dim tensor that
# already holds the mean, so `test_accuracy.item()` yields the accuracy directly.
num_batches = len(test_loader)
test_accuracy = torch.tensor(evaluate_accuracy(network, test_loader, device), dtype=torch.float64)
print(test_accuracy.item())

Epoch 0/10 - Loss: 0.7140 - Train accuracy: 0.12 - Test accuracy: 0.11
Epoch 0/10 - Loss: 0.5387 - Train accuracy: 0.11 - Test accuracy: 0.09
Epoch 0/10 - Loss: 0.5310 - Train accuracy: 0.12 - Test accuracy: 0.09
Epoch 0/10 - Loss: 0.4603 - Train accuracy: 0.09 - Test accuracy: 0.11
Epoch 0/10 - Loss: 0.4320 - Train accuracy: 0.10 - Test accuracy: 0.08
Epoch 0/10 - Loss: 0.4437 - Train accuracy: 0.09 - Test accuracy: 0.13
Epoch 0/10 - Loss: 0.4318 - Train accuracy: 0.06 - Test accuracy: 0.12
Epoch 0/10 - Loss: 0.4044 - Train accuracy: 0.11 - Test accuracy: 0.09
Epoch 0/10 - Loss: 0.4052 - Train accuracy: 0.11 - Test accuracy: 0.10
Epoch 0/10 - Loss: 0.4036 - Train accuracy: 0.12 - Test accuracy: 0.16
Epoch 0/10 - Loss: 0.3851 - Train accuracy: 0.11 - Test accuracy: 0.14
Epoch 0/10 - Loss: 0.3900 - Train accuracy: 0.14 - Test accuracy: 0.11
Epoch 0/10 - Loss: 0.3925 - Train accuracy: 0.14 - Test accuracy: 0.09
Epoch 0/10 - Loss: 0.3758 - Train accuracy: 0.09 - Test accuracy: 0.09
Epoch 

In [61]:
# Save model
import json
MODELS_PATH = os.path.join(PROJ_DIR,'assets','models')
# Make sure the target directory exists so a fresh checkout can save its first model.
os.makedirs(MODELS_PATH, exist_ok=True)
torch.save(network.state_dict(), os.path.join(MODELS_PATH, f'{DATASET_NAME}-{MODEL_NAME}-mlp.pth'))

# Record this model's accuracy next to the others, keeping the existing entries.
ACCURACIES_PATH = os.path.join(MODELS_PATH, 'model-accuracies.json')
if os.path.exists(ACCURACIES_PATH):
    with open(ACCURACIES_PATH) as fIn:
        models = json.load(fIn)
else:
    # First model saved: start a new registry instead of failing on the missing file.
    models = {}
# NOTE(review): this stores whatever the evaluation cell left in `test_accuracy`;
# confirm that value is a mean accuracy and not a running sum over batches.
models[f'{DATASET_NAME}-{MODEL_NAME}'] = test_accuracy.item()
with open(ACCURACIES_PATH, 'w') as fOut:
    json.dump(models, fOut)