In [2]:
from __future__ import division, print_function, absolute_import
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
import pickle 


In [3]:
class Encoder(nn.Module):
    '''Encoder: four stride-2 3x3 convolutions with TensorFlow-style SAME padding.

    Each layer is ``pad_image`` -> ``Conv2d(kernel_size=3, stride=2, padding=0)``
    -> ReLU, so every layer maps a spatial size s to ceil(s / 2). For a 28x28
    input the output has shape (batch, 10, 2, 2).
    '''
    def __init__(self, in_channels=None):
        '''
        Args:
            in_channels: number of channels of the input images. When None (the
                default, matching the original behaviour) the notebook-level
                global ``n_input_channel`` is used.
        '''
        super(Encoder, self).__init__()

        if in_channels is None:
            in_channels = n_input_channel

        # define layers (all 3x3, stride 2; padding is done manually in pad_image)
        self.enc_l1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l2 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0)
        self.enc_l4 = nn.Conv2d(32, 10, kernel_size=3, stride=2, padding=0)
        self.relu = nn.ReLU()
        # not used in forward; kept so the module layout matches previously saved models
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _same_padding(size, stride, filter_size):
        '''Total padding TensorFlow's SAME mode adds along one spatial axis.'''
        if size % stride == 0:
            return max(filter_size - stride, 0)
        return max(filter_size - size % stride, 0)

    def pad_image(self, img, stride=2, filter_h=3, filter_w=3):
        ''' Takes an input image (batch) and pads according to Tensorflows SAME padding

        Args:
            img: 4-D tensor indexed as (batch, channels, height, width).
            stride, filter_h, filter_w: convolution hyper-parameters the padding
                is computed for; the defaults match the layers built in __init__.

        Returns:
            Zero-padded tensor with the same dtype and device as ``img``. When the
            total padding along an axis is odd, the extra row/column goes to the
            bottom/right (as in TensorFlow).
        '''
        # height and width are padded independently so non-square inputs work
        pad_height = self._same_padding(img.shape[2], stride, filter_h)
        pad_width = self._same_padding(img.shape[3], stride, filter_w)

        pad_top = pad_height // 2
        pad_left = pad_width // 2

        # F.pad takes (left, right, top, bottom) for the last two dims. Unlike
        # writing into a fresh torch.zeros it keeps img's dtype/device, and it
        # also handles a side that needs no padding (a -0 slice end is empty).
        return F.pad(img, (pad_left, pad_width - pad_left, pad_top, pad_height - pad_top))

    def forward(self, x):
        '''Apply pad -> conv -> ReLU for each of the four encoder layers.'''
        for layer in (self.enc_l1, self.enc_l2, self.enc_l3, self.enc_l4):
            x = self.relu(layer(self.pad_image(x)))
        return x
        
class Decoder(nn.Module):
    '''Decoder: four stride-2 3x3 transposed convolutions that upsample an
    encoding back to image space.

    Each ConvTranspose2d maps a spatial size s to 2*s - 1 + output_padding.
    With the default settings a (batch, 10, 2, 2) encoding is upsampled
    2 -> 4 -> 7 -> 14 -> 28. Hidden layers use ReLU and the last layer a
    sigmoid, so reconstructions lie in (0, 1).
    '''
    def __init__(self, out_channels=None, dec_l3_output_padding=0):
        '''
        Args:
            out_channels: channels of the reconstructed image. None (the
                default) falls back to the notebook-level global
                ``n_input_channel``, as before.
            dec_l3_output_padding: output_padding of ``dec_l3``. 0 (default)
                yields 28x28 outputs; 1 yields 32x32 (4 -> 8 -> 16 -> 32).
        '''
        super(Decoder, self).__init__()

        if out_channels is None:
            out_channels = n_input_channel

        # define layers
        self.dec_l4 = nn.ConvTranspose2d(10, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec_l3 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1,
                                         output_padding=dec_l3_output_padding)
        self.dec_l2 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dec_l1 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, enc_x):
        '''Upsample ``enc_x`` through dec_l4 .. dec_l1 (ReLU between, sigmoid last).'''
        dl4 = self.relu(self.dec_l4(enc_x))
        dl3 = self.relu(self.dec_l3(dl4))
        dl2 = self.relu(self.dec_l2(dl3))
        decoded_x = self.sigmoid(self.dec_l1(dl2))

        return decoded_x


class nn_prototype(nn.Module):
    '''Prototype classifier built on a convolutional autoencoder.

    The encoder output is flattened and compared (squared Euclidean distance)
    with ``n_prototypes`` learned prototype vectors; a linear layer maps those
    distances to class logits. The decoder reconstructs the input from the
    same encoding.
    '''
    def __init__(self, n_prototypes=15, n_layers=4, n_classes=10, n_features=40):
        '''
        Args:
            n_prototypes: number of learned prototype vectors.
            n_layers: accepted for backward compatibility; currently unused.
            n_classes: number of logits produced by ``last_layer`` (previously
                ignored with 10 hard-coded; the default is unchanged).
            n_features: length of a flattened encoding, i.e. C*H*W of the
                encoder output. The default 40 matches a (batch, 10, 2, 2)
                encoding; an earlier note mentions 160 (10*4*4) for the
                coloured-MNIST setup.
        '''
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()

        # prototypes live in the flattened encoding space, initialised uniformly in [0, 1)
        self.prototype_feature_vectors = nn.Parameter(
            torch.empty(size=(n_prototypes, n_features), dtype=torch.float32).uniform_())

        self.last_layer = nn.Linear(n_prototypes, n_classes)

    def list_of_distances(self, X, Y):
        '''
        Given a list of vectors, X = [x_1, ..., x_n], and another list of vectors,
        Y = [y_1, ... , y_m], we return a list of vectors
                [[d(x_1, y_1), d(x_1, y_2), ... , d(x_1, y_m)],
                 ...
                 [d(x_n, y_1), d(x_n, y_2), ... , d(x_n, y_m)]],
        where the distance metric used is the squared euclidean distance.
        The computation uses ||x||^2 + ||y||^2 - 2 x.y with broadcasting.
        '''
        XX = torch.reshape(self.list_of_norms(X), shape=(-1, 1))
        YY = torch.reshape(self.list_of_norms(Y), shape=(1, -1))
        output = XX + YY - 2 * torch.mm(X, Y.t())

        return output

    def list_of_norms(self, X):
        '''
        X is a list of vectors X = [x_1, ..., x_n], we return
            [d(x_1, x_1), d(x_2, x_2), ... , d(x_n, x_n)], i.e. the squared
        euclidean norm of every row.
        '''
        return torch.sum(torch.pow(X, 2), dim=1)

    def forward(self, x):
        '''
        Returns:
            dec_x: reconstruction produced by the decoder.
            logits: (batch, n_classes) output of ``last_layer``.
            feature_vector_distances: (n_prototypes, batch) squared distances.
            prototype_distances: (batch, n_prototypes) squared distances.
        '''
        enc_x = self.encoder(x)
        dec_x = self.decoder(enc_x)

        # flatten each encoding into a vector of length C*H*W
        n_features = enc_x.shape[1] * enc_x.shape[2] * enc_x.shape[3]
        feature_vectors_flat = torch.reshape(enc_x, shape=[-1, n_features])

        # squared distance from every encoding to every prototype
        prototype_distances = self.list_of_distances(feature_vectors_flat, self.prototype_feature_vectors)

        # the distance is symmetric, so the prototype-major matrix is just the
        # transpose (no second matrix product needed); .contiguous() gives a
        # standard row-major layout in case callers reshape it with .view
        feature_vector_distances = prototype_distances.t().contiguous()

        # distances go straight into the linear layer; no softmax is applied here
        logits = self.last_layer(prototype_distances)

        return dec_x, logits, feature_vector_distances, prototype_distances
        

In [4]:
# Load the complete pickled model object (not just a state_dict) and map every
# tensor it holds onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoints. Unpickling presumably needs the model classes above to be defined
# first, and recent PyTorch releases default to weights_only=True, which would
# likely reject a whole pickled nn.Module like this one — confirm against the
# installed version.
loaded_model = torch.load(
    "./saved_model/gray_mnist_model_color28_20_10_1_1"
    + "/gray_mnist_cae_color28_20_10_1_1_epoch_30.pt",
    map_location=torch.device('cpu'),
)


  "type " + container_type.__name__ + ". It won't be checked "
  "type " + container_type.__name__ + ". It won't be checked "
  "type " + container_type.__name__ + ". It won't be checked "


In [51]:
# LaTeX table of the classification weights. Each row is one prototype: its
# rendered image, its weight towards each class column 0-9 (rounded to two
# decimals), and finally the index of the class with the smallest weight.
print("& " + " & ".join(str(label) for label in range(10)) + " & min index\\\\\\hline")
for i, row in enumerate(loaded_model.last_layer.weight.t()):
    row = row.detach().numpy().round(2)
    image_cell = "\\includegraphics[width=0.035\\textwidth]{Images/%d.png}" % i
    weight_cells = " & ".join(map(str, row))
    print(image_cell + " & " + weight_cells, "&", np.argmin(row), "\\\\\\hline")

& 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & min index\\\hline
\includegraphics[width=0.035\textwidth]{Images/0.png} & 0.36 & 0.33 & -0.2 & -0.72 & 0.67 & -0.13 & 0.81 & -0.89 & 0.3 & 0.14 & 7 \\\hline
\includegraphics[width=0.035\textwidth]{Images/1.png} & -0.22 & -0.3 & -1.58 & -0.22 & 0.28 & 0.77 & 0.12 & 0.16 & 0.47 & 0.04 & 2 \\\hline
\includegraphics[width=0.035\textwidth]{Images/2.png} & 0.14 & -0.32 & 0.15 & -0.23 & 0.28 & 0.47 & 0.35 & 0.16 & -1.64 & 0.08 & 8 \\\hline
\includegraphics[width=0.035\textwidth]{Images/3.png} & -1.03 & 0.47 & -0.25 & 0.48 & 0.23 & -0.31 & 0.18 & 0.07 & -0.2 & 0.49 & 0 \\\hline
\includegraphics[width=0.035\textwidth]{Images/4.png} & 0.2 & -0.51 & 0.36 & 0.28 & -1.24 & 0.3 & -0.23 & 0.06 & 0.24 & 0.12 & 4 \\\hline
\includegraphics[width=0.035\textwidth]{Images/5.png} & -0.03 & -0.08 & -0.24 & -0.08 & 0.16 & -0.15 & -0.84 & 0.55 & 0.23 & 0.16 & 6 \\\hline
\includegraphics[width=0.035\textwidth]{Images/6.png} & -0.99 & 0.49 & 0.02 & 0.37 & -0.12 & 0.19 & 

In [16]:
# Load the model saved under ./saved_model/mnist_model onto the CPU and, for
# every prototype, print its rounded class weights, the index of the smallest
# weight and that weight's value.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
original = torch.load(
    "./saved_model/mnist_model/mnist_cae_epoch_30.pt" + "",
    map_location=torch.device('cpu'),
)

for i, row in enumerate(original.last_layer.weight.t()):
    row = row.detach().numpy().round(2)
    print(row, row.argmin(), row.min())

[ 0.64  0.37 -0.56  0.29 -1.15  0.45  0.36  0.54  0.02  0.03] 4 -1.15
[ 0.53  0.35  0.49 -0.48  0.21 -1.3  -0.2   0.24 -0.01 -0.37] 5 -1.3
[ 0.27 -0.53 -0.17  0.13  0.08 -0.13  0.2  -0.02  0.42 -0.54] 9 -0.54
[-1.63  0.42  0.04  0.44  0.16  0.11 -0.07  0.27  0.3  -0.25] 0 -1.63
[ 0.15  0.06 -0.16 -1.29  0.45 -0.06  0.62  0.02  0.34  0.11] 3 -1.29
[ 0.64 -0.36 -0.05 -0.26 -0.81  0.33 -0.06 -0.14  0.72  0.63] 4 -0.81
[-0.09  0.44  0.44  0.09 -0.85  0.02  0.24 -0.3  -0.14 -1.54] 9 -1.54
[ 0.01 -0.41 -0.23  0.38  0.67 -0.59 -0.15 -0.08 -0.03  0.45] 5 -0.59
[-0.12  0.23 -0.14  0.15  0.48  0.36  0.5  -1.42  0.51  0.14] 7 -1.42
[-0.35  0.17 -0.7  -0.74  1.09  0.2  -0.04  0.38 -0.82  0.46] 8 -0.82
[ 0.03 -0.77  0.34  0.17 -0.31  0.33  0.18 -0.06 -0.78 -0.26] 8 -0.78
[-0.05 -0.63  0.44  0.02  0.19 -0.12 -0.24  0.19 -0.13  0.16] 1 -0.63
[-0.3  -0.39  0.24  0.03 -0.33  0.38 -0.32 -0.35  0.05  0.4 ] 1 -0.39
[ 0.33  0.37  0.02  0.21 -0.22  0.24  0.06  0.38 -1.39  0.13] 8 -1.39
[-0.22  0.06  0.    0