# Complete notebook with code to generate all the results

In [1]:
from __future__ import division, print_function, absolute_import
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from math import ceil
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
import pickle 
import scipy.ndimage
from PIL import Image as PILImage
from skimage.color import rgb2gray

## Helper functions

In [35]:
# borrowed from the original paper of Li et al. (2018)
def makedirs(path):
    '''
    if path does not exist in the file system, create it
    '''
    if not os.path.exists(path):
        os.makedirs(path)

def list_of_norms(X):
    '''
    X is a list of vectors X = [x_1, ..., x_n], we return
        [d(x_1, x_1), d(x_2, x_2), ... , d(x_n, x_n)], where the distance
    function is the squared euclidean distance.
    '''
    return torch.sum(torch.pow(X, 2), dim=1)

## Create necessary directories

In [36]:
# data folder
data_folder = os.path.join(os.getcwd(), "data")
makedirs(data_folder)

# Various model folders
model_original = os.path.join(os.getcwd(), "models", "mnist_original")
model_cifar = os.path.join(os.getcwd(), "models", "cifar")
model_mnist_color = os.path.join(os.getcwd(), "models", "mnist_color")
model_mnist_rgb2gray = os.path.join(os.getcwd(), "models", "mnist_rgb2gray")

model_folders_list = [model_original, model_cifar, model_mnist_color, model_mnist_rgb2gray]

# Image folder in every model folder
for model_folder in model_folders_list:
    makedirs(os.path.join(model_folder, "img"))

## Choose the dataset here

In [37]:
# uncomment one of the assignments below
# WHICH_DATA_FLAG = "mnist_original"
#WHICH_DATA_FLAG = "cifar"
#WHICH_DATA_FLAG = "mnist_color"
WHICH_DATA_FLAG = "mnist_rgb2gray"


## Parameters dependent on dataset, no input needed

In [38]:
batch_size = 250
image_size = 28
n_input_channels = 1
if WHICH_DATA_FLAG == "mnist_original":
    model_folder = model_original
    model_filename = "mnist_original"
elif WHICH_DATA_FLAG == "mnist_color":
    model_folder = model_mnist_color
    model_filename = "mnist_color"
    n_input_channels = 3
elif WHICH_DATA_FLAG == "mnist_rgb2gray":
    model_folder = model_mnist_rgb2gray
    model_filename = "mnist_rgb2gray"
elif WHICH_DATA_FLAG == "cifar":
    model_folder = model_cifar
    model_filename = "cifar"
    n_input_channels = 3
    image_size = 32

img_folder = os.path.join(model_folder, "img")

## Function to transform original MNIST to colored MNIST or rgb2gray MNIST

In [39]:
def color_dataset(raw_data, to_gray=False):
    N = len(raw_data)
    if to_gray:
        n_channels = 1
    else:
        n_channels = 3
    
    raw_data = raw_data.view(N, 28, 28, 1)
    
    try:
        lena = PILImage.open('./resources/lena.png')    
    except:
        print("Lena image could not be found, please check ./resources/lena.png")
        return 1
            

    # Extend to RGB
    data_rgb = np.concatenate([raw_data, raw_data, raw_data], axis=3)
    
    # Make binary
    data_binary = (data_rgb > 0.5)
    data_color = np.zeros((N, 28, 28, n_channels))
    
    for i in range(N):
        # Take a random crop of the Lena image (background)
        x_c = np.random.randint(0, lena.size[0] - 28)
        y_c = np.random.randint(0, lena.size[1] - 28)
        image = lena.crop((x_c, y_c, x_c + 28, y_c + 28))
        image = np.asarray(image) # / 255.0 REMOVED DIVISION HERE TO MAKE EVERY DATASET EQUAL

        ## COPIED IMAGE BECAUSE "READ-ONLY" ERROR
        new_image = image.copy()
        # Change color distribution # SHOULD THIS ONLY HAPPEN IF TO_GRAY == FALSE?
        for j in range(3):
            new_image[:, :, j] = (new_image[:, :, j] + np.random.uniform(0, 1)) / 2.0

        # Invert the colors at the location of the number
        new_image[data_binary[i]] = 1 - new_image[data_binary[i]]
        if to_gray:
            data_color[i] = np.reshape(rgb2gray(new_image), (28, 28, 1))
        else:
            data_color[i] = new_image
        
    return torch.from_numpy(data_color)

## Load the datasets

In [40]:
def get_data(data_flag):
    # check if flag is correct
    assert data_flag in ["mnist_original", "cifar", "mnist_color", "mnist_rgb2gray"]
    
    if data_flag == "cifar":
        # extract cifar data
        data = get_cifar10(os.path.join(data_folder, "cifar"), one_hot=False)

        # extract training,validation and test data
        X_train, Y_train = torch.from_numpy(data['train'].images), torch.from_numpy(data['train'].labels)
        X_validation, Y_validation = torch.from_numpy(data['validation'].images), torch.from_numpy(data['validation'].labels)
        X_test, Y_test = torch.from_numpy(data['test'].images), torch.from_numpy(data['test'].labels)

        # create datasets
        train_data = TensorDataset(X_train, Y_train)
        valid_data = TensorDataset(X_validation, Y_validation)
        test_data = TensorDataset(X_test, Y_test)
        return train_data, valid_data, test_data
        
    # Transforms to perform on loaded dataset. Normalize around mean 0.1307 and std 0.3081 for optimal pytorch results. 
    # source: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/4
    transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),(0.3081,))])

    # Load datasets into reproduction/data/mnist
    mnist_train = DataLoader(torchvision.datasets.MNIST(os.path.join(data_folder, "mnist_original"), train=True, download=True, transform=transforms))
    mnist_test = DataLoader(torchvision.datasets.MNIST(os.path.join(data_folder, "mnist_original"), train=False, download=True, 
                                                       transform=transforms))
        
    mnist_train_data = mnist_train.dataset.data
    mnist_train_targets = mnist_train.dataset.targets
    
    x_test = mnist_test.dataset.data
    y_test = mnist_test.dataset.targets
    
    if data_flag == "mnist_original":
        x_train = mnist_train_data[0:55000]
        y_train = mnist_train_targets[0:55000]

        x_valid = mnist_train_data[55000:60000]
        y_valid = mnist_train_targets[55000:60000]
        
        train_data = TensorDataset(x_train, y_train)
        valid_data = TensorDataset(x_valid, y_valid)
        test_data = TensorDataset(x_test, y_test)
        
        return train_data, valid_data, test_data

    to_gray = (data_flag == "mnist_rgb2gray")
    x_train_color = color_dataset(mnist_train_data, to_gray)
    x_test_color = color_dataset(x_test, to_gray)

    # ADDED THIS
    mnist_train = TensorDataset(x_train_color, mnist_train_targets)
    mnist_test = TensorDataset(x_test_color, y_test)

    # first 55000 examples for training
    x_train = mnist_train[0:55000][0]
    y_train = mnist_train[0:55000][1]
    # y_train = mnist_train_targets[0:55000]

    # 5000 examples for validation set
    x_valid = mnist_train[55000:60000][0]
    y_valid = mnist_train[55000:60000][1]

    # 10000 examples in test set
    x_test = mnist_test[:][0]
    y_test = mnist_test[:][1]

    train_data = TensorDataset(x_train, y_train)
    valid_data = TensorDataset(x_valid, y_valid)
    test_data = TensorDataset(x_test, y_test)
    
    return train_data, valid_data, test_data

In [41]:
train_data, valid_data, test_data = get_data(WHICH_DATA_FLAG)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/train-labels-idx1-ubyte.gz


113.5%

Extracting /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting /Users/TomLotze/Documents/Artificial Intelligence/Year1/FACT/FACT/Reproduction/data/mnist_original/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [42]:
# plt.imshow(train_data[4][0].detach().numpy().reshape(32, 32, 3))

## Model parameters

In [43]:
# COPIED FROM THE ORIGINAL IMPLEMENTATION
# training parameters
learning_rate = 0.002
training_epochs = 1500

# frequency of testing and saving
test_display_step = 5    # how many epochs we do evaluate on the test set once, default 100
save_step = 50            # how frequently do we save the model to disk

# elastic deformation parameters
sigma = 4
alpha = 20

# lambda's are the ratios between the four error terms
lambda_class = 20
lambda_ae = 1 # autoencoder
lambda_1 = 1 # push prototype vectors to have meaningful decodings in pixel space
lambda_2 = 1 # cluster training examples around prototypes in latent space


input_height = input_width =  28    # MNIST data input shape  
input_size = input_height * input_width * n_input_channels   # 784
n_classes = 10

# Network Parameters
n_prototypes = 15         # the number of prototypes
n_layers = 4

# height and width of each layers' filters
f_1 = 3
f_2 = 3
f_3 = 3
f_4 = 3

# stride size in each direction for each of the layers
s_1 = 2
s_2 = 2
s_3 = 2
s_4 = 2

# number of feature maps in each layer
n_map_1 = 32
n_map_2 = 32
n_map_3 = 32
n_map_4 = 10

# the shapes of each layer's filter
# [out channel, in_channel, 3, 3]
filter_shape_1 = [n_map_1, n_input_channels, f_1, f_1]
filter_shape_2 = [n_map_2, n_map_1, f_2, f_2]
filter_shape_3 = [n_map_3, n_map_2, f_3, f_3]
filter_shape_4 = [n_map_4, n_map_3, f_4, f_4]

# strides for each layer (changed to tuples)
stride_1 = [s_1, s_1]
stride_2 = [s_2, s_2]
stride_3 = [s_3, s_3]
stride_4 = [s_4, s_4]


## Model construction

In [44]:
class Encoder(nn.Module):
    '''Encoder'''
    def __init__(self):
        super(Encoder, self).__init__()
        
        # height and width of each layers' filters
        f_1 = 3
        f_2 = 3
        f_3 = 3
        f_4 = 3
        
        # define layers
        self.enc_l1 = nn.Conv2d(n_input_channels, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l4 = nn.Conv2d(32, 10, kernel_size=3, stride=2, padding=0)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        
    def pad_image(self, img):
        ''' Takes an input image (batch) and pads according to Tensorflows SAME padding'''
        input_h = img.shape[2]
        input_w = img.shape[3]
        stride = 2 
        filter_h = 3
        filter_w = 3

        output_h = int(ceil(float(input_h)) / float(stride))
        output_w = output_h

        if input_h % stride == 0:
            pad_height = max((filter_h - stride), 0)
        else:
            pad_height = max((filter_h - (input_h % stride), 0))

        pad_width = pad_height

        pad_top = pad_height // 2
        pad_bottom = pad_height - pad_top
        pad_left = pad_width // 2
        pad_right = pad_width - pad_left

        padded_img = torch.zeros(img.shape[0], img.shape[1], input_h + pad_height, input_w + pad_width)
        padded_img[:,:, pad_top:-pad_bottom, pad_left:-pad_right] = img

        return padded_img
        
    def forward(self, x):
        pad_x = self.pad_image(x)
        el1 = self.relu(self.enc_l1(pad_x))
        
        pad_el1 = self.pad_image(el1)
        el2 = self.relu(self.enc_l2(pad_el1))
    
        pad_el2 = self.pad_image(el2)
        el3 = self.relu(self.enc_l3(pad_el2))
        
        pad_el3 = self.pad_image(el3)
        el4 = self.relu(self.enc_l4(pad_el3))
        
        return el4
        

class nn_prototype(nn.Module):
    '''Model'''
    def __init__(self, n_prototypes=15, n_layers=4, n_classes=10):
        super().__init__()
        
        self.encoder = Encoder()
        
        # initialize prototype - currently not in correct spot
        
        # changed this for the colored mnist, from 40 to 160, the new shape would be 250*10*4*4
        n_features = 40 # size of encoded x - 250 x 10 x 2 x 2
        self.prototype_replacement = nn.Linear(n_features, n_prototypes)
        self.last_layer = nn.Linear(n_prototypes,10)
        self.relu = nn.ReLU()

    
    def forward(self, x):
    
        #encoder step
        enc_x = self.encoder(x)
        
        x = enc_x.view(enc_x.shape[0], -1)
    
    
        x = self.prototype_replacement(x)
        x = self.relu(x)
        
        # classification layer
        logits = self.last_layer(x)
        
        # Softmax to prob dist not needed as cross entropy loss is used?
        
        return logits
        

## Cost function

In [45]:
def loss_function(logits, Y):
    return F.cross_entropy(logits, Y, reduction="mean")

## Accuracy

In [46]:
def compute_acc(logits, labels):
    batch_size = labels.shape[0]
    predictions = logits.argmax(dim=1)
    total_correct = torch.sum(predictions == labels).item()
    accuracy = total_correct / batch_size
    
    return(accuracy)

## Training loop

In [48]:
model = nn_prototype(15,4,10)
batch_size_ = 250

# get validation and test set
valid_dl = DataLoader(valid_data, batch_size=5000, drop_last=False, shuffle=False)
test_dl = DataLoader(test_data, batch_size=10000, drop_last=False, shuffle=False)


# initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# initialize storage for results
train_accs = []
train_losses = []
test_accs = []
test_losses = []
valid_accs = []
valid_losses = []

# training loop
for epoch in range(training_epochs):
    print("\nEpoch:", epoch)

    # load the training data and reshuffle
    train_dl = DataLoader(train_data, batch_size=batch_size_, drop_last=False, shuffle=True)

    # loop over the batches
    for step, (x, Y) in enumerate(train_dl):
        optimizer.zero_grad()
        
        x_plot = x[0].clone()
        
        x = x.view(x.shape[0], n_input_channels, x.shape[1], x.shape[2]).float()

        #Y = Y.long()
        
        # perform forward pass
        logits = model(x)

        # compute the loss
        total_loss = loss_function(logits, Y)

        # backpropagate over the loss
        total_loss.backward()

        # update the weights
        optimizer.step()

        # compute and save accuracy and loss
        train_accuracy = compute_acc(logits, Y)
        train_accs.append(train_accuracy)
        train_losses.append(total_loss.item())


    # print information after a batch
    print('Last train loss of batch:', total_loss.item())
    print('Train acc on batch:', np.mean(train_accs[-step:]))
    print("Last train acc", train_accuracy)


    if epoch % test_display_step == 0:
        # save model and prototypes
        #torch.save(model, model_folder + "/" + model_filename + "_epoch_" + str(epoch) + '.pt')

        
        # perform testing
        with torch.no_grad():
            for (x_test, y_test) in test_dl:
                x_test = x_test.view(x_test.shape[0], n_input_channels, x_test.shape[1], x_test.shape[2]).float()

                # forward pass
                logits = model(x_test)

                # compute loss and accuracy and save
                test_accuracy = compute_acc(logits, y_test)
                test_loss = loss_function(logits, y_test)
                test_accs.append(test_accuracy)
                test_losses.append(test_loss)

            print('\nTest loss:', test_loss.item())
            print('Test acc:', test_accuracy)

    # validation
    with torch.no_grad():
        for (x_valid, y_valid) in valid_dl:
                x_valid = x_valid.view(x_valid.shape[0], n_input_channels, x_valid.shape[1], x_valid.shape[2]).float()
        
                logits = model(x_valid)

                # compute losses and accuracy and save
                valid_accuracy = compute_acc(logits, y_valid)
                valid_loss = loss_function(logits, y_valid)
                valid_accs.append(valid_accuracy)
                valid_losses.append(valid_loss)

        print('\nValid loss:', valid_loss.item())
        print('Valid acc:', valid_accuracy)
    



Epoch: 0
Last train loss of batch: 0.39582154154777527
Train acc on batch: 0.7039634703196347
Last train acc 0.868

Test loss: 0.3356102406978607
Test acc: 0.9027

Valid loss: 0.26358696818351746
Valid acc: 0.927

Epoch: 1
Last train loss of batch: 0.2350766956806183
Train acc on batch: 0.9190319634703198
Last train acc 0.936

Valid loss: 0.1647043526172638
Valid acc: 0.9504

Epoch: 2
Last train loss of batch: 0.1379500776529312
Train acc on batch: 0.9449497716894978
Last train acc 0.944

Valid loss: 0.12545432150363922
Valid acc: 0.9616

Epoch: 3
Last train loss of batch: 0.18207015097141266
Train acc on batch: 0.9561461187214613
Last train acc 0.944

Valid loss: 0.11830883473157883
Valid acc: 0.965

Epoch: 4
Last train loss of batch: 0.11367779225111008
Train acc on batch: 0.9623378995433789
Last train acc 0.96

Valid loss: 0.10067733377218246
Valid acc: 0.971

Epoch: 5
Last train loss of batch: 0.1305655688047409
Train acc on batch: 0.967324200913242
Last train acc 0.964

Test loss


Valid loss: 0.11296052485704422
Valid acc: 0.983

Epoch: 47
Last train loss of batch: 0.014467325061559677
Train acc on batch: 0.996310502283105
Last train acc 0.992

Valid loss: 0.1142958253622055
Valid acc: 0.9784

Epoch: 48
Last train loss of batch: 0.0016372472746297717
Train acc on batch: 0.9960365296803653
Last train acc 1.0

Valid loss: 0.09959948807954788
Valid acc: 0.9826

Epoch: 49
Last train loss of batch: 0.00225879717618227
Train acc on batch: 0.9972420091324201
Last train acc 1.0

Valid loss: 0.11439211666584015
Valid acc: 0.9804

Epoch: 50
Last train loss of batch: 0.010683768428862095
Train acc on batch: 0.9961461187214612
Last train acc 0.992

Test loss: 0.12147355824708939
Test acc: 0.9788

Valid loss: 0.11164817959070206
Valid acc: 0.9812

Epoch: 51
Last train loss of batch: 0.0036782934330403805
Train acc on batch: 0.9947945205479453
Last train acc 1.0

Valid loss: 0.1186271458864212
Valid acc: 0.9828

Epoch: 52
Last train loss of batch: 0.0036088177002966404
Train

KeyboardInterrupt: 