In [1]:
import torch
import torch.nn as nn
import numpy as np
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt

In [2]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# rest of the notebook (every `.to(device)` call) still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Training settings
batch_size = 64

# MNIST Dataset
# Both splits live under ./data/ and are converted to tensors on access;
# only the training split passes download=True.
train_dataset = datasets.MNIST(
    root='./data/', train=True, transform=transforms.ToTensor(), download=True
)

test_dataset = datasets.MNIST(
    root='./data/', train=False, transform=transforms.ToTensor()
)

In [4]:
# Data Loader (Input Pipeline)
# Training batches are reshuffled on every pass; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False
)

In [5]:
# Visualize up to 25 images from one training batch in a 5x5 grid,
# each captioned with its digit label.
imgs, labels = next(iter(train_loader))
plt.figure(figsize=(10, 10))
# Guard against batches smaller than the 25 grid slots.
for i in range(min(25, len(imgs))):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(imgs[i].reshape(28, 28), cmap=plt.cm.binary)
    # labels[i] is a 0-d tensor; convert it so the caption reads "5"
    # rather than "tensor(5)".
    plt.xlabel(int(labels[i]))
plt.show()

In [6]:
def plot_subplot(img, idx, h=5, w=4, label=None):
    """Draw ``img`` into slot ``idx`` of an ``h`` x ``w`` subplot grid.

    Args:
        img: 2-D image data accepted by ``plt.imshow``.
        idx: 1-based position of the subplot inside the grid.
        h: Number of grid rows.
        w: Number of grid columns.
        label: Optional caption shown under the image. ``None`` means no
            caption; any other value (including ``0``) is displayed.
    """
    plt.subplot(h, w, idx)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(img, cmap=plt.cm.binary)
    # Compare against None explicitly: a truthiness test would silently drop
    # the caption for the digit 0 (both int 0 and tensor(0) are falsy).
    if label is not None:
        # Unwrap single-element tensors so the caption reads "5", not "tensor(5)".
        if torch.is_tensor(label) and label.numel() == 1:
            label = label.item()
        plt.xlabel(label)

In [20]:
# visualize some results
def test(model):
    """Plot 10 random test images next to their reconstructions.

    Takes the first batch from the global ``test_loader``, runs it through
    ``model`` and draws original/reconstruction pairs side by side in a grid
    (using ``plot_subplot``'s default 5x4 layout).

    Args:
        model: Module mapping flattened ``(N, 28*28)`` inputs to outputs of
            the same shape.
    """
    imgs, labels = next(iter(test_loader))
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        results = model(imgs.view(-1, 28 * 28).to(model_device)).cpu().numpy()
    plt.figure(figsize=(10, 10))
    for i in range(10):
        # Sample within the actual batch length instead of assuming 64.
        j = np.random.randint(len(imgs))
        # A string caption keeps the digit 0 visible and avoids "tensor(k)".
        caption = str(int(labels[j]))
        # normal image
        plot_subplot(imgs[j].reshape(28, 28), 2 * i + 1, label=caption)
        # after autoenc
        plot_subplot(results[j].reshape(28, 28), 2 * i + 2, label=caption)
    plt.show()

In [69]:
# Fully connected autoencoder with the following layer sizes:
# 28*28 -> h1 -> h2 -> h3 -> latSz -> h3 -> h2 -> h1 -> 28*28,
# where h1, h2, h3 and latSz are constructor arguments.
class Autoencoder(nn.Module):
    """Sigmoid-activated MLP autoencoder for flattened 28x28 images.

    Every linear layer (including the latent layer and the output layer) is
    followed by a sigmoid.

    Args:
        h1: Width of the first encoder / last decoder hidden layer.
        h2: Width of the second encoder / second-to-last decoder hidden layer.
        h3: Width of the hidden layer adjacent to the latent layer.
        latSz: Dimension of the latent code.
    """

    def __init__(self, h1=128, h2=48, h3=12, latSz=3):
        super().__init__()
        self.latSz = latSz
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, h1),
            nn.Sigmoid(),
            nn.Linear(h1, h2),
            nn.Sigmoid(),
            nn.Linear(h2, h3),
            nn.Sigmoid(),
            nn.Linear(h3, latSz),
            nn.Sigmoid(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latSz, h3),
            nn.Sigmoid(),
            nn.Linear(h3, h2),
            nn.Sigmoid(),
            nn.Linear(h2, h1),
            nn.Sigmoid(),
            nn.Linear(h1, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, X, print_lat_space=False):
        """Encode ``X`` and decode it again.

        Args:
            X: Tensor whose last dimension is ``28*28``.
            print_lat_space: When true, print the latent codes as a NumPy array.

        Returns:
            The decoder output for the latent codes of ``X``.
        """
        X = self.encoder(X)
        if print_lat_space:
            # detach() + cpu() are required: .numpy() raises on tensors that
            # track gradients or live on the GPU.
            print(X.detach().cpu().numpy())
        return self.decoder(X)

    def plot_latSpace(self):
        """Decode each one-hot latent vector and plot the resulting images."""
        # Build the inputs on the model's own device rather than a global one.
        model_device = next(self.parameters()).device
        lat_inputs = torch.eye(self.latSz, device=model_device)
        with torch.no_grad():
            outs = self.decoder(lat_inputs).cpu().numpy()
        w = 4
        h = (self.latSz + w - 1) // w  # ceil(latSz / w) rows
        plt.figure(figsize=(10, 10))
        for i in range(self.latSz):
            plot_subplot(outs[i].reshape(28, 28), i + 1, h, w)

In [70]:
# Model with the default layer sizes, mean-squared-error loss and an Adam
# optimizer. The model is moved to `device` before the optimizer is built.
model = Autoencoder()
model.to(device)
criterion = nn.MSELoss()
optim = torch.optim.Adam(params=model.parameters(), lr=5e-3)

In [71]:
def train(model, optim, epochs=20, data_loader=None, loss_fn=None):
    """Train ``model`` to reconstruct its own input.

    Args:
        model: Module mapping flattened ``(N, 28*28)`` inputs to outputs of
            the same shape.
        optim: Optimizer built over ``model.parameters()``.
        epochs: Number of passes over the data.
        data_loader: Iterable of ``(images, labels)`` batches. Defaults to
            the global ``train_loader``.
        loss_fn: Callable ``loss_fn(prediction, target)``. Defaults to the
            global ``criterion``.

    Returns:
        List with the mean batch loss of every epoch, in order. When the call
        is the last expression of a notebook cell, this list is echoed as the
        cell output.
    """
    if data_loader is None:
        data_loader = train_loader
    if loss_fn is None:
        loss_fn = criterion
    # Feed batches to whatever device the model's parameters live on.
    model_device = next(model.parameters()).device

    epoch_losses = []
    for epoch in range(1, 1 + epochs):
        batch_losses = []
        for images, _ in data_loader:
            inputs = images.view(-1, 28 * 28).to(model_device)
            outputs = model(inputs)
            # Conventional (prediction, target) order; the input is the target.
            loss = loss_fn(outputs, inputs)
            optim.zero_grad()
            loss.backward()
            optim.step()
            batch_losses.append(loss.item())
        epoch_loss = np.mean(batch_losses)
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch}: Train Loss: {epoch_loss}")
    return epoch_losses

In [72]:
# Plot reconstructions from the current model, then train it for one epoch.
# Re-running this cell shows how the reconstructions change epoch by epoch.
test(model)
train(model, optim, epochs=1)

In [73]:
# Train the model for 20 epochs (in addition to any epochs already run above).
train(model, optim, epochs=20)

In [93]:
# Plot random test images next to their reconstructions.
test(model)

In [75]:
# Decode each one-hot latent vector and plot the resulting images.
model.plot_latSpace()

In [89]:
def print_latent_vectors(model):
    """Print the latent code of 10 random test images and show each decoding.

    Takes the first batch from the global ``test_loader``. For every sampled
    image the encoder output is printed as a NumPy array and the decoder
    output is displayed as a 28x28 grayscale image.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` modules.
    """
    imgs, _ = next(iter(test_loader))
    # Run on whatever device the model's parameters live on.
    model_device = next(model.parameters()).device
    for _ in range(10):
        # Sample within the actual batch length instead of assuming 64.
        j = np.random.randint(len(imgs))
        with torch.no_grad():
            lat_vec = model.encoder(imgs[j].view(-1, 28 * 28).to(model_device))
            result = model.decoder(lat_vec).cpu().numpy()
        print(lat_vec.cpu().numpy())
        plt.imshow(result.reshape(28, 28), cmap="gray")
        plt.show()

In [91]:
# Print latent codes of random test images and show their decodings.
print_latent_vectors(model)

In [104]:
def plot_lat_detailed(model):
    """Decode a fixed set of 3-D latent vectors and plot the results.

    The vectors are the integer triples summing to 3 listed below, divided by
    3 so that every coordinate lies in [0, 1] and each vector sums to 1. The
    decoder must therefore accept 3-dimensional latent input.

    Args:
        model: Object exposing a ``decoder`` module with 3 latent inputs.
    """
    weights = np.array([(1, 1, 1), (2, 1, 0), (2, 0, 1), (1, 2, 0), (1, 0, 2),
                        (0, 2, 1), (0, 1, 2), (3, 0, 0), (0, 3, 0), (0, 0, 3)])
    # The original assigned this to a misspelled name ("imputs"), so the
    # unnormalised 0..3 values were fed to the decoder.
    inputs = weights / 3
    model_device = next(model.parameters()).device
    lat_inputs = torch.tensor(inputs, dtype=torch.float32, device=model_device)
    with torch.no_grad():
        outs = model.decoder(lat_inputs).cpu().numpy()
    w = 4
    n = inputs.shape[0]
    h = (n + w - 1) // w  # ceil(n / w) rows
    plt.figure(figsize=(10, 10))
    for i in range(n):
        plot_subplot(outs[i].reshape(28, 28), i + 1, h, w,
                     label=np.array2string(inputs[i], precision=2))

In [105]:
# Decode a fixed set of hand-picked 3-D latent vectors and plot the outputs.
plot_lat_detailed(model)

In [108]:
# try different models
# Every tuple is passed positionally to Autoencoder. Note that the loop
# rebinds the notebook-level names `model` and `optim` on each iteration.
hyperparams = [
    (128, 64, 32, 16),
    (128, 64, 24, 12),
    (128, 48, 16, 10),
    (128, 48, 16, 8),
    (128, 48, 16, 6),
    (128, 48, 12, 4),
    (128, 48, 12, 3),
]
models = []
for hyps in hyperparams:
    print(hyps)
    # Fresh network and optimizer per configuration.
    model = Autoencoder(*hyps).to(device)
    optim = torch.optim.Adam(params=model.parameters(), lr=5e-3)
    train(model, optim, epochs=15)
    # Show reconstructions and decoded one-hot latent vectors for this model.
    test(model)
    model.plot_latSpace()
    models.append(model)

In [None]:
# autoencoders with Sigmoid functions in every layer have significantly worse performance, but are much more interpretable
# (here is an autoencoder with ReLU; I had done some tests with it, but I didn't run them again to include them here)
# NOTE: this class reuses the name `Autoencoder`; running this cell replaces
# any previously defined class of that name, and this constructor takes
# (h1, h2, latSz) only.
class Autoencoder(nn.Module):
    """ReLU MLP autoencoder for flattened 28x28 images.

    Layer sizes: 28*28 -> h1 -> h2 -> latSz -> h2 -> h1 -> 28*28. Hidden
    layers use ReLU, the latent layer has no activation and the output layer
    uses a sigmoid.

    Args:
        h1: Width of the first encoder / last decoder hidden layer.
        h2: Width of the hidden layer adjacent to the latent layer.
        latSz: Dimension of the latent code.
    """

    def __init__(self, h1=128, h2=32, latSz=10):
        super().__init__()
        self.latSz = latSz
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, h1),
            nn.ReLU(inplace=True),
            nn.Linear(h1, h2),
            nn.ReLU(inplace=True),
            nn.Linear(h2, latSz),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latSz, h2),
            nn.ReLU(inplace=True),
            nn.Linear(h2, h1),
            nn.ReLU(inplace=True),
            nn.Linear(h1, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, X, print_lat_space=False):
        """Encode ``X`` and decode it again.

        Args:
            X: Tensor whose last dimension is ``28*28``.
            print_lat_space: When true, print the latent tensor.

        Returns:
            The decoder output for the latent codes of ``X``.
        """
        X = self.encoder(X)
        if print_lat_space:
            print(X)
        return self.decoder(X)

    def plot_latSpace(self):
        """Decode each one-hot latent vector and plot the resulting images."""
        # Build the inputs on the model's own device rather than a global one.
        model_device = next(self.parameters()).device
        lat_inputs = torch.eye(self.latSz, device=model_device)
        with torch.no_grad():
            outs = self.decoder(lat_inputs).cpu().numpy()
        w = 4
        h = (self.latSz + w - 1) // w  # ceil(latSz / w) rows
        plt.figure(figsize=(10, 10))
        for i in range(self.latSz):
            plot_subplot(outs[i].reshape(28, 28), i + 1, h, w)