## Imports

In [62]:
from __future__ import division, print_function, absolute_import
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from math import ceil

## Helper functions
Helper functions borrowed from original paper by Li et al. 

In [220]:
def makedirs(path):
    '''
    if path does not exist in the file system, create it
    '''
    if not os.path.exists(path):
        os.makedirs(path)

def list_of_distances(X, Y):
    '''
    Given a list of vectors, X = [x_1, ..., x_n], and another list of vectors,
    Y = [y_1, ... , y_m], we return a list of vectors
            [[d(x_1, y_1), d(x_1, y_2), ... , d(x_1, y_m)],
             ...
             [d(x_n, y_1), d(x_n, y_2), ... , d(x_n, y_m)]],
    where the distance metric used is the sqared euclidean distance.
    The computation is achieved through a clever use of broadcasting.
    '''
    XX = torch.reshape(list_of_norms(X), shape=(-1, 1))
    YY = torch.reshape(list_of_norms(Y), shape=(1, -1))
    output = XX + YY - 2 * torch.mm(X, Y.t())

    return output

def list_of_norms(X):
    '''
    X is a list of vectors X = [x_1, ..., x_n], we return
        [d(x_1, x_1), d(x_2, x_2), ... , d(x_n, x_n)], where the distance
    function is the squared euclidean distance.
    '''
    return torch.sum(torch.pow(X, 2), dim=1)

def print_and_write(str, file):
    '''
    print str to the console and also write it to file
    '''
    print(str)
    file.write(str + '\n')

## Create necessary folders

In [17]:
# data folder
makedirs('./data/mnist')

# Models folder
model_folder = os.path.join(os.getcwd(), "saved_model", "mnist_model", "mnist_cae_1")
makedirs(model_folder)

# Image folder
img_folder = os.path.join(model_folder, "img")
makedirs(img_folder)

# Model filename
model_filename = "mnist_cae"

## Dataset - Pytorch
#### <font color='red'>Double check the normalization mean and stdev for dataset</font>
#### <font color='red'>Double check parameters Dataloader (e.g. shuffle on or off, different batch sizes for train/valid/test)</font>

In [23]:
# Transforms to perform on loaded dataset. Normalize around mean 0.1307 and std 0.3081 for optimal pytorch results. 
# source: https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/4
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),(0.3081,))])

# Load datasets into reproduction/data/mnist. Download if data not present. 
mnist_train = DataLoader(torchvision.datasets.MNIST('./data/mnist', train=True, download=True, transform=transforms))

mnist_train_data = mnist_train.dataset.data
mnist_train_targets = mnist_train.dataset.targets

# first 55000 examples for training
x_train = mnist_train_data[0:55000]
y_train = mnist_train_targets[0:55000]

# 5000 examples for validation set
x_valid = mnist_train_data[55000:60000]
y_valid = mnist_train_targets[55000:60000]

# 10000 examples in test set

mnist_test = DataLoader(torchvision.datasets.MNIST('./data/mnist', train=False, download=True, transform=transforms))

x_test = mnist_test.dataset.data
y_test = mnist_test.dataset.targets

train_data = TensorDataset(x_train, y_train)
valid_data = TensorDataset(x_valid, y_valid)
test_data = TensorDataset(x_test, y_test)

batch_size = 250

# Datasets in DataLoader, can be used in iteration: for (x, y) in train_dl...
# Check parameters
train_dl = DataLoader(train_data, batch_size=batch_size, drop_last=False, shuffle=False)
valid_dl = DataLoader(valid_data, batch_size=batch_size, drop_last=False, shuffle=False)
test_dl = DataLoader(test_data, batch_size=batch_size, drop_last=False, shuffle=False)

## Parameters

In [4]:
# COPIED FROM THE ORIGINAL IMPLEMENTATION
# training parameters
learning_rate = 0.002
training_epochs = 1500

# frequency of testing and saving
test_display_step = 100   # how many epochs we do evaluate on the test set once
save_step = 50            # how frequently do we save the model to disk

# elastic deformation parameters
sigma = 4
alpha = 20

# lambda's are the ratios between the four error terms
lambda_class = 20
lambda_ae = 1 # autoencoder
lambda_1 = 1 # push prototype vectors to have meaningful decodings in pixel space
lambda_2 = 1 # cluster training examples around prototypes in latent space


input_height = input_width =  28    # MNIST data input shape 
n_input_channel = 1     # the number of color channels; for MNIST is 1.
input_size = input_height * input_width * n_input_channel   # 784
n_classes = 10

# Network Parameters
n_prototypes = 15         # the number of prototypes
n_layers = 4

# height and width of each layers' filters
f_1 = 3
f_2 = 3
f_3 = 3
f_4 = 3

# stride size in each direction for each of the layers
s_1 = 2
s_2 = 2
s_3 = 2
s_4 = 2

# number of feature maps in each layer
n_map_1 = 32
n_map_2 = 32
n_map_3 = 32
n_map_4 = 10

# the shapes of each layer's filter
# [out channel, in_channel, 3, 3]
filter_shape_1 = [n_map_1, n_input_channel, f_1, f_1]
filter_shape_2 = [n_map_2, n_map_1, f_2, f_2]
filter_shape_3 = [n_map_3, n_map_2, f_3, f_3]
filter_shape_4 = [n_map_4, n_map_3, f_4, f_4]

# strides for each layer (changed to tuples)
stride_1 = [s_1, s_1]
stride_2 = [s_2, s_2]
stride_3 = [s_3, s_3]
stride_4 = [s_4, s_4]


## Initialize encoder and decoder

In [5]:
std_weights = 0.01

weights = {
    'enc_f1': nn.Parameter(std_weights * torch.randn(filter_shape_1,
                                           dtype=torch.float32)),
    'enc_f2': nn.Parameter(std_weights * torch.randn(filter_shape_2,
                                           dtype=torch.float32)), 
    'enc_f3': nn.Parameter(std_weights * torch.randn(filter_shape_3,
                                           dtype=torch.float32)), 
    'enc_f4': nn.Parameter(std_weights * torch.randn(filter_shape_4,
                                           dtype=torch.float32)), 
    'dec_f4': nn.Parameter(std_weights * torch.randn(filter_shape_4,
                                           dtype=torch.float32)), 
    'dec_f3': nn.Parameter(std_weights * torch.randn(filter_shape_3,
                                           dtype=torch.float32)), 
    'dec_f2': nn.Parameter(std_weights * torch.randn(filter_shape_2,
                                           dtype=torch.float32)),
    'dec_f1': nn.Parameter(std_weights * torch.randn(filter_shape_1,
                                           dtype=torch.float32)),
}


biases = {
    'enc_b1': nn.Parameter(torch.zeros([n_map_1], dtype=torch.float32)),
    'enc_b2': nn.Parameter(torch.zeros([n_map_2], dtype=torch.float32)),
    'enc_b3': nn.Parameter(torch.zeros([n_map_3], dtype=torch.float32)),
    'enc_b4': nn.Parameter(torch.zeros([n_map_4], dtype=torch.float32)),
    'dec_b4': nn.Parameter(torch.zeros([n_map_3], dtype=torch.float32)),
    'dec_b3': nn.Parameter(torch.zeros([n_map_2], dtype=torch.float32)),
    'dec_b2': nn.Parameter(torch.zeros([n_map_1], dtype=torch.float32)),
    'dec_b1': nn.Parameter(torch.zeros([n_input_channel], dtype=torch.float32)),
}

last_layer = {
    'w': nn.Parameter(torch.randn([n_prototypes, n_classes],
                                       dtype=torch.float32))
}


### Print shapes of all parameters

In [6]:
# Printing shapes of all parameters
print("weights")
for weight in weights.keys():
    print(weight, weights[weight].shape)
print("biases")
for b in biases.keys():
    print(b, biases[b].shape)
print("last_layer")
print(last_layer['w'].shape)

weights
enc_f1 torch.Size([32, 1, 3, 3])
enc_f2 torch.Size([32, 32, 3, 3])
enc_f3 torch.Size([32, 32, 3, 3])
enc_f4 torch.Size([10, 32, 3, 3])
dec_f4 torch.Size([10, 32, 3, 3])
dec_f3 torch.Size([32, 32, 3, 3])
dec_f2 torch.Size([32, 32, 3, 3])
dec_f1 torch.Size([32, 1, 3, 3])
biases
enc_b1 torch.Size([32])
enc_b2 torch.Size([32])
enc_b3 torch.Size([32])
enc_b4 torch.Size([10])
dec_b4 torch.Size([32])
dec_b3 torch.Size([32])
dec_b2 torch.Size([32])
dec_b1 torch.Size([1])
last_layer
torch.Size([15, 10])


## Layer functions
#### <font color='red'>Fix the stride and padding parameters, check if filter in tf is same as weight in pt</font>
Padding discussion pytorch: https://github.com/pytorch/pytorch/issues/3867

Blogpost: https://mmuratarat.github.io/2019-01-17/implementing-padding-schemes-of-tensorflow-in-python

### Padding helper
Based on blogpost above. 

In [107]:
def pad_img(img):
    ''' Takes an input image (batch) and pads according to Tensorflows SAME padding'''
    input_h = img.shape[2]
    input_w = img.shape[3]
    stride = 2 
    filter_h = 3
    filter_w = 3
    
    output_h = int(ceil(float(input_h)) / float(stride))
    output_w = output_h
    
    if input_h % stride == 0:
        pad_height = max((filter_h - stride), 0)
    else:
        pad_height = max((filter_h - (input_h % stride), 0))

    pad_width = pad_height
    
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left
    
    padded_img = torch.zeros(img.shape[0], img.shape[1], input_h + pad_height, input_w + pad_width)
    padded_img[:,:, pad_top:-pad_bottom, pad_left:-pad_right] = img

    return padded_img

In [199]:
def conv_layer(img, kernel, bias, strides, padding="VALID", nonlinearity = nn.ReLU()):
    img = pad_img(img)
    conv = F.conv2d(img, kernel, bias=bias, stride=strides, padding=padding)
    out = nonlinearity(conv)
    return out

#### stride must be tuple for torch, is a list in tf
#### padding is different, tf uses same/valid, torch a int or list of ints
#### is the filter the same ass weights argument for conv2d?

# tensorflow's conv2d_transpose needs to know the shape of the output
def deconv_layer(img, kernel, bias, strides, padding="VALID", nonlinearity=nn.ReLU(), out_padding = None):
#     img = pad_img(img)
    deconv = F.conv_transpose2d(img, kernel, bias=bias, stride=strides, padding=padding, output_padding=out_padding)
    out = nonlinearity(deconv)
    return out

# def deconv_layer(img, kernel, bias, strides, padding="VALID", nonlinearity=nn.ReLU(), out_size = None):
#     in_channel = img.shape[1]
#     out_channel = kernel.shape[1]
#     k_size = (kernel.shape[2], kernel.shape[3])
    
#     img = pad_img(img)
    
#     upsample = nn.ConvTranspose2d(in_channel, out_channel, k_size, strides, padding=(0,0))
    
#     out = upsample(img, output_size=out_size)
    

def fc_layer(input, weight, bias, nonlinearity=nn.ReLU()):
    return nonlinearity(torch.mm(input, weight) + bias)

## Model construction

In [200]:
# Dummy tensor for debugging
# X = torch.empty(batch_size, n_input_channel, input_width, input_height)
X,Y = next(iter(train_dl))
X = X.view(250,1,28,28).float()
# X in shape 250x28x28 (Batch x H x W)
print(X.shape)
# Y in shape 250 (B). Needs to be converted to one-hot for reproduction of paper

torch.Size([250, 1, 28, 28])


### Encoder

In [201]:
PADDING_FLAG = (0,0)
# eln means the output of the nth layer of the encoder
print('X: ', X.shape)
el1 = conv_layer(X, weights['enc_f1'], biases['enc_b1'], stride_1, padding=PADDING_FLAG)
print('EL1: ',el1.shape)
el2 = conv_layer(el1, weights['enc_f2'], biases['enc_b2'], stride_2, padding=PADDING_FLAG)
print('EL2: ',el2.shape)
el3 = conv_layer(el2, weights['enc_f3'], biases['enc_b3'], stride_3, padding=PADDING_FLAG)
print('EL3: ',el3.shape)
el4 = conv_layer(el3, weights['enc_f4'], biases['enc_b4'], stride_4, padding=PADDING_FLAG)
print('EL4: ',el4.shape)

input_shape = list(X.shape)
l1_shape = list(el1.shape)
l2_shape = list(el2.shape)
l3_shape = list(el3.shape)
l4_shape = list(el4.shape)

flatten_size = l4_shape[1] * l4_shape[2] * l4_shape[3]
n_features = flatten_size

# feature vectors is the flattened output of the encoder
feature_vectors = torch.reshape(el4, shape=[-1, flatten_size])
#print(feature_vectors.shape)
# initialize the prototype feature vectors
prototype_feature_vectors = nn.Parameter(torch.empty(size=
                                        [n_prototypes, n_features],
                                        dtype=torch.float32).uniform_())

deconv_batch_size = torch.eye(feature_vectors.shape[0])
# this is necessary for prototype images evaluation
reshape_feature_vectors = torch.reshape(feature_vectors, shape=[-1, l4_shape[1],
   l4_shape[2], l4_shape[3]])

X:  torch.Size([250, 1, 28, 28])
EL1:  torch.Size([250, 32, 14, 14])
EL2:  torch.Size([250, 32, 7, 7])
EL3:  torch.Size([250, 32, 4, 4])
EL4:  torch.Size([250, 10, 2, 2])


### Decoder
#### <font color='red'> Padding needs to be fixed. Currently 'hacked'</font>

In [206]:
print('Reshape_feature_vectors :', reshape_feature_vectors.shape)
print(pad_img(reshape_feature_vectors).shape)
dl4 = deconv_layer(reshape_feature_vectors, weights['dec_f4'], biases['dec_b4'],
                   strides=stride_4, padding=1, out_padding = 1)
# want to dl4 to have shape 250 x 32 x 4 x 4
print('DL4: ',dl4.shape)

dl3 = deconv_layer(dl4, weights['dec_f3'], biases['dec_b3'],
                   strides=stride_3, padding=1, out_padding=0)
# want to dl3 to have shape 250 x 32 x 7 x 7
print('DL3: ',dl3.shape)

dl2 = deconv_layer(dl3, weights['dec_f2'], biases['dec_b2'],
                   strides=stride_2, padding=1, out_padding=1)
# want to dl2 to have shape 250 x 32 x 14 x 14
print('DL2: ',dl2.shape)

dl1 = deconv_layer(dl2, weights['dec_f1'], biases['dec_b1'],
                   strides=stride_1, padding=1, out_padding=1,
                   nonlinearity=nn.Sigmoid())
# want to dl4 to have shape 250 x 1 x 28 x 28
print('DL1: ',dl1.shape)

Reshape_feature_vectors : torch.Size([250, 10, 2, 2])
torch.Size([250, 10, 3, 3])
DL4:  torch.Size([250, 32, 4, 4])
DL3:  torch.Size([250, 32, 7, 7])
DL2:  torch.Size([250, 32, 14, 14])
DL1:  torch.Size([250, 1, 28, 28])


In [232]:
'''
X_decoded is the decoding of the encoded feature vectors in X;
we reshape it to match the shape of the training input
X_true is the correct output for the autoencoder
'''

X_decoded = torch.reshape(dl1, shape=(-1, input_size))
X_true = X.view(-1, input_size)
print(X_decoded.shape)
print(X_true.shape)

torch.Size([250, 784])
torch.Size([250, 784])


## Prototype distances

In [233]:
'''
prototype_distances is the list of distances from each x_i to every prototype
in the latent space
feature_vector_distances is the list of distances from each prototype to every x_i
in the latent space
'''
prototype_distances = list_of_distances(feature_vectors,
                                        prototype_feature_vectors)
prototype_distances = torch.Tensor(prototype_distances)
feature_vector_distances = list_of_distances(prototype_feature_vectors,
                                             feature_vectors)
feature_vector_distances = torch.Tensor(feature_vector_distances)

# the logits are the weighted sum of distances from prototype_distances
logits = torch.mm(prototype_distances, last_layer['w'])
probability_distribution = F.softmax(logits, dim=1)

## Cost function

In [244]:
'''
the error function consists of 4 terms, the autoencoder loss,
the classification loss, and the two requirements that every feature vector in
X look like at least one of the prototype feature vectors and every prototype
feature vector look like at least one of the feature vectors in X.
'''
ae_error = torch.mean(list_of_norms(X_decoded - X_true))
class_error = F.cross_entropy(logits, Y)
error_1 = torch.mean(torch.min(feature_vector_distances, axis=1)[0])
error_2 = torch.mean(torch.min(prototype_distances, axis = 1)[0])

# total_error is the our minimization objective
total_error = lambda_class * class_error +\
              lambda_ae * ae_error + \
              lambda_1 * error_1 + \
              lambda_2 * error_2

print(total_error)

tensor(5550998., grad_fn=<AddBackward0>)


## Accuracy

In [245]:
# accuracy is not the classification error term; it is the percentage accuracy
correct_prediction = tf.equal(tf.argmax(logits, 1),
                              tf.argmax(Y, 1),
                              name='correct_prediction')
accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32),
                          name='accuracy')

NameError: name 'tf' is not defined