# MNIST

In [1]:
import sage
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.datasets as dsets

In [2]:
# Load train set
train = dsets.FashionMNIST('../data', train=True, download=True)
imgs = train.data.reshape(-1, 784) / 255.0
labels = train.targets

# Shuffle and split into train and val
inds = torch.randperm(len(train))
imgs = imgs[inds]
labels = labels[inds]
val, Y_val = imgs[:6000], labels[:6000]
train, Y_train = imgs[6000:], labels[6000:]

# Load test set
test = dsets.FashionMNIST('../data', train=False, download=True)
test, Y_test = test.data.reshape(-1, 784) / 255.0, test.targets

# Move test data to numpy
test_np = test.cpu().data.numpy()
Y_test_np = Y_test.cpu().data.numpy()

In [3]:
device = torch.device('cuda', 0)
model = torch.load('trained_models/mnist mlp.pt')
model = model.to(device)
model = nn.Sequential(model, nn.Softmax(dim=1))

In [4]:
print(test_np,Y_test_np)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] [7 2 1 ... 4 5 6]


In [5]:
# Setup and calculate
imputer = sage.MarginalImputer(model, test_np[:128])
estimator = sage.PermutationEstimator(imputer, 'cross entropy')
sage_values = estimator(test_np, Y_test_np, batch_size=512, thresh=0.07)

Setting up imputer for PyTorch model, assuming that any necessary output activations are applied properly. If not, please set up nn.Sequential with nn.Sigmoid or nn.Softmax


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))

it:0
StdDev Ratio = 0.1391 (Converge at 0.0700)
it:1
StdDev Ratio = 0.0757 (Converge at 0.0700)
it:2
StdDev Ratio = 0.0617 (Converge at 0.0700)
Detected convergence



In [6]:
sage_values.save('results/mnist_sage.pkl')