In [None]:
# %% Deep learning - Section 16.157
#    Autoencoders with tied weights

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [29]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F
import pandas              as pd
import scipy.stats         as stats
import sklearn.metrics     as skm
import time
import sys

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from scipy.stats                      import zscore
from sklearn.decomposition            import PCA
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')
plt.style.use('default')


In [None]:
# %% A brief aside on Linear and Parameter classes

# --- Weights held directly in an nn.Parameter ---
X  = torch.rand(10,50)
W1 = nn.Parameter(torch.rand(128,50))

# Show the parameter itself, then a blank line
print(W1, end='\n\n')

# Shape of the matrix and of its transpose, then a blank line
print(W1.shape, W1.t().shape, sep='\n', end='\n\n')

# Output: inputs times the transposed weights (followed by two blank lines)
Y = torch.matmul(X, W1.t())
print(Y.shape, end='\n\n\n')

# --- Weights held inside an nn.Linear layer ---
W2 = nn.Linear(128,50)
print()

# Show the layer itself, then a blank line
print(W2, end='\n\n')

# Shape of the layer's weight matrix and of its transpose, then a blank line
print(W2.weight.shape, W2.weight.t().shape, sep='\n', end='\n\n')

# Output: here the stored weight matrix is used directly, without transposing
Y = torch.matmul(X, W2.weight)
print(Y.shape)

In [None]:
# %% Notice the swapped order

# W1 was created as a raw 128x50 matrix and is used as [output,input]. For
# nn.Linear(128,50) the constructor arguments are (in_features,out_features),
# but the layer stores its weight as [out_features,in_features], i.e. 50x128,
# so the two shapes printed below are transposes of each other
print(W1.shape, W2.weight.shape, sep='\n')


In [28]:
# %% See all attributes of Linear class

# List the attribute and method names available on the nn.Linear class.
# NOTE(review): this is not the last expression of the cell, so the list is
# presumably not displayed - confirm, or wrap it in print()
dir(nn.Linear)

# IPython `??` introspection (notebook-only syntax, not plain Python): request
# the source of nn.Linear's forward pass and constructor.
# NOTE(review): presumably meant to show how Linear declares and applies its
# weight Parameter, which would explain the swapped order noted above - confirm
??nn.Linear.forward
??nn.Linear.__init__


In [30]:
# %% Data

# Load data. The path is handed straight to np.loadtxt so that numpy opens AND
# closes the file itself; passing a bare open(...) handle left it unclosed
data = np.loadtxt('sample_data/mnist_train_small.csv',delimiter=',')

# Split labels (first column) from the data (all remaining columns)
labels = data[:,0]
data   = data[:,1:]

# Normalise data by its maximum value (original range is (0,255))
data_norm = data / np.max(data)

# Convert to tensors (float for the data, long for the labels)
data_tensor   = torch.tensor(data_norm).float()
labels_tensor = torch.tensor(labels).long()


In [37]:
# %% Model class

def gen_model():
    """Build a fresh tied-weights autoencoder plus its loss and optimizer.

    Layer sizes are 784 -> 128 -> 50 -> 128 -> 784. The outer two layers are
    independent nn.Linear layers; the 128 -> 50 encoder and the 50 -> 128
    decoder share a single weight matrix (used as-is, then transposed).

    Returns:
        tuple: (model instance, nn.MSELoss instance, Adam optimizer over the
        model parameters with lr=0.001).
    """

    class mnist_AE(nn.Module):
        """Autoencoder whose two bottleneck layers share one weight matrix."""

        def __init__(self):
            super().__init__()

            # Untied input layer
            self.input  = nn.Linear(784,128)

            # Shared bottleneck weights drawn from a standard normal; kept as
            # a bare Parameter so the same matrix can serve both directions
            self.encode = nn.Parameter(torch.randn(50,128))

            # Untied output layer
            self.decode = nn.Linear(128,784)

        def forward(self,x):
            """Encode and reconstruct inputs that carry 784 features each."""

            # Regular first layer
            hidden = torch.relu(self.input(x))

            # Transpose so the shared matrix can multiply from the left
            code = torch.relu(torch.matmul(self.encode, hidden.t()))

            # The same matrix, transposed, mirrors the encoding step; the
            # result is then transposed back to the original orientation
            expanded = torch.relu(torch.matmul(self.encode.t(), code)).t()

            # Regular last layer, squashed by a sigmoid
            return torch.sigmoid(self.decode(expanded))

    # Model instance, loss function and optimizer
    autoencoder = mnist_AE()
    criterion   = nn.MSELoss()
    adam        = torch.optim.Adam(autoencoder.parameters(),lr=0.001)

    return autoencoder,criterion,adam


In [None]:
# %% Test on some data

# Fresh model together with its loss function and optimizer
ANN,loss_fun,optimizer = gen_model()

# Pass the first five rows of the data through the model
X    = data_tensor[:5]
yHat = ANN(X)

# Shapes of the input and of the model output
print(X.shape, yHat.shape, sep='\n')


In [44]:
# %% Function to train the model

def train_model(ANN,loss_fun,optimizer,num_epochs=10000,batch_size=32,data=None):
    """Train a model to reproduce its own input.

    Every "epoch" is a single optimisation step on one random subset of rows;
    there is no minibatch loop over the whole dataset.

    Args:
        ANN: model to train; it is called as ANN(X) on the sampled rows.
        loss_fun: callable used as loss_fun(yHat,X), returning a scalar loss.
        optimizer: optimizer providing zero_grad() and step().
        num_epochs: number of optimisation steps (default 10000, the value
            that used to be hard-coded).
        batch_size: number of rows sampled per step (default 32, the value
            that used to be hard-coded). Rows are drawn by np.random.choice
            with its default settings, i.e. with replacement.
        data: tensor whose rows are the samples. When None, the notebook-level
            `data_tensor` is used, which keeps existing three-argument calls
            working unchanged.

    Returns:
        tuple: (losses, ANN) where losses is a list holding one float per
        step and ANN is the same model object, updated in place.
    """

    # Fall back on the notebook-level dataset for backward compatibility
    if data is None:
        data = data_tensor

    # Initialise vars
    losses = []

    # Loop over epochs (no minibatch loop)
    for epoch_i in range(num_epochs):

        # Select only a random subset of rows
        random_i = np.random.choice(data.shape[0],size=batch_size)
        X        = data[random_i,:]

        # Forward propagation and loss (the target is the input itself)
        yHat = ANN(X)
        loss = loss_fun(yHat,X)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Loss in this epoch
        losses.append(loss.item())

    return losses,ANN


In [61]:
# %% Train and fit

# Build a fresh model, then fit it; train_model hands back the per-epoch
# losses together with the model
model_parts            = gen_model()
ANN,loss_fun,optimizer = model_parts
losses,ANN             = train_model(*model_parts)


In [None]:
# %% Plotting

# Golden ratio, used for the figure's width-to-height proportion
phi    = (1 + np.sqrt(5)) / 2
fig,ax = plt.subplots(figsize=(phi*5,5))

# Loss curve with axis labels and title
ax.plot(losses,'-')
ax.set_xlabel('Epochs')
ax.set_ylabel('Model loss')
ax.set_title('Model loss over epochs\n(tied weights)')

# Save to disk, display, then hand the file to the Colab download helper
fig.savefig('figure47_tied_weights.png')
plt.show()
files.download('figure47_tied_weights.png')


In [None]:
# %% Plotting

# Run the first five rows of the data through the current model
X    = data_tensor[:5]
yHat = ANN(X)

# Golden ratio, used for the figure proportions
phi     = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(2,5,figsize=(1.5*phi*5,5))

# Top row: inputs; bottom row: model outputs (each reshaped to 28x28)
for row,images in enumerate((X,yHat)):
    for col in range(5):
        axs[row,col].imshow(images[col].detach().view(28,28),cmap='gray')
        axs[row,col].set_xticks([])
        axs[row,col].set_yticks([])

fig.suptitle('Post-training performance')

# Save to disk, display, then hand the file to the Colab download helper
fig.savefig('figure48_tied_weights.png')
plt.show()
files.download('figure48_tied_weights.png')


In [None]:
# %% Add noise

# Take the first ten rows and add uniform noise (a quarter of torch.rand_like)
# to simulate a noisy input
X       = data_tensor[:10]
X_noise = X + torch.rand_like(X)/4

# Cap values at 1 to maintain normalisation
X_noise = torch.clamp(X_noise,max=1)

# Golden ratio, used for the figure proportions
phi     = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(2,5,figsize=(1.5*phi*5,5))

# Top row: clean images; bottom row: their noisy versions (first five only)
for row,images in enumerate((X,X_noise)):
    for col in range(5):
        axs[row,col].imshow(images[col].detach().view(28,28),cmap='gray')
        axs[row,col].set_xticks([])
        axs[row,col].set_yticks([])

fig.suptitle('Noisy data')

# Save to disk, display, then hand the file to the Colab download helper
fig.savefig('figure49_tied_weights.png')
plt.show()
files.download('figure49_tied_weights.png')


In [None]:
# %% Run the model on simulated noisy data

# Feed the noisy images to the model
Y = ANN(X_noise)

# Golden ratio, used for the figure proportions
phi     = (1 + np.sqrt(5)) / 2
fig,axs = plt.subplots(3,10,figsize=(1.5*phi*5,5))

# Rows from top to bottom: clean images, noisy images, model outputs
for row,images in enumerate((X,X_noise,Y)):
    for col in range(10):
        axs[row,col].imshow(images[col].detach().view(28,28),cmap='gray')
        axs[row,col].set_xticks([])
        axs[row,col].set_yticks([])

fig.suptitle('Reconstruction of noisy data')

# Save to disk, display, then hand the file to the Colab download helper
fig.savefig('figure50_tied_weights.png')
plt.show()
files.download('figure50_tied_weights.png')


In [52]:
# %% Exercise 1
#    The network we built here is not a truly mirrored network: We tied the encoder/decoder layers, but left the input
#    and output layers separate. That's not wrong or bad or anything; it's just a choice. Modify the code to create
#    a truly mirrored network, where all decoding layers are tied to their corresponding encoding layers.

# Model class
def gen_model():
    """Build a fully mirrored (tied-weights) autoencoder, its loss and optimizer.

    Layer sizes are 784 -> 128 -> 50 -> 128 -> 784. Every decoding layer
    reuses the transposed weight matrix of its matching encoding layer, and
    there are no bias terms, so the model has only two parameters.

    The weights are drawn with Xavier-normal scaling instead of a plain
    standard normal. With unit-variance weights and no biases, the
    pre-activations grow at every layer and the final sigmoid outputs exact
    0/1 values, which zeroes the gradients and stalls training. The Xavier
    standard deviation, sqrt(2 / (fan_in + fan_out)), is symmetric in the two
    fans, so one scale suits both the forward and the transposed use of each
    tied matrix.

    Returns:
        tuple: (model instance, nn.MSELoss instance, Adam optimizer over the
        model parameters with lr=0.001).
    """

    class mnist_AE(nn.Module):
        """Autoencoder in which each decoder layer mirrors an encoder layer."""

        def __init__(self):
            super().__init__()

            # Architecture (tied weights with nn.Parameter(), Xavier-scaled)
            self.input  = nn.Parameter(nn.init.xavier_normal_(torch.empty(128,784)))
            self.encode = nn.Parameter(nn.init.xavier_normal_(torch.empty(50,128)))

        # Forward propagation (with tied weights)
        def forward(self,x):
            """Encode and reconstruct inputs that carry 784 features each."""

            # Encoder part (transpose so the weights multiply from the left)
            x = x.t()
            x = F.relu(self.input @ x)
            x = F.relu(self.encode @ x)

            # Decoder (mirrored) part: same matrices, transposed, in reverse
            x = F.relu(self.encode.t() @ x)
            x = torch.sigmoid(self.input.t() @ x)

            # Transpose back to the original orientation
            x = x.t()

            return x

    # Generate model instance
    ANN = mnist_AE()

    # Loss function
    loss_fun = nn.MSELoss()

    # Optimizer
    optimizer = torch.optim.Adam(ANN.parameters(),lr=0.001)

    return ANN,loss_fun,optimizer


In [60]:
# %% Exercise 2
#    You don't need to use nn.Parameter; you can still accomplish what we did by using nn.Linear and extracting the
#    weights matrices. Rewrite the code to use nn.Linear instead of nn.Parameter.

# Model class
def gen_model():
    """Build a tied-weights autoencoder from nn.Linear layers only.

    Layer sizes are 784 -> 128 -> 50 -> 128 -> 784. The 128 -> 50 encoder is
    a bias-free nn.Linear, and the 50 -> 128 decoder reuses that layer's
    weight matrix directly instead of owning weights of its own.

    Returns:
        tuple: (model instance, nn.MSELoss instance, Adam optimizer over the
        model parameters with lr=0.001).
    """

    class mnist_AE(nn.Module):
        """Autoencoder whose bottleneck decoder borrows the encoder weights."""

        def __init__(self):
            super().__init__()

            # Architecture (tied weights taken from an nn.Linear layer)
            self.input  = nn.Linear(784,128)
            self.encode = nn.Linear(128, 50, bias=False)
            self.decode = nn.Linear(128,784)

        def forward(self,x):
            """Encode and reconstruct inputs that carry 784 features each."""

            # Regular first layer
            hidden = torch.relu(self.input(x))

            # Encoder: ordinary call of the bias-free linear layer
            code = torch.relu(self.encode(hidden))

            # Mirrored decoder: multiply by the encoder's stored weight matrix
            expanded = torch.relu(torch.matmul(code, self.encode.weight))

            # Regular last layer, squashed by a sigmoid
            return torch.sigmoid(self.decode(expanded))

    # Model instance, loss function and optimizer
    autoencoder = mnist_AE()
    criterion   = nn.MSELoss()
    adam        = torch.optim.Adam(autoencoder.parameters(),lr=0.001)

    return autoencoder,criterion,adam
