<a href="https://colab.research.google.com/github/BSniegowski/ML-uni_course/blob/main/lab/09autoencoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%matplotlib inline

import torch

from torch import nn
from torch.utils.data import Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Lambda, Compose

import matplotlib.pyplot as plt

import numpy as np

from sklearn.decomposition import PCA

def plot_dataset(train_data, model):
    """Plot the first (up to) five images above their reconstructions.

    Args:
        train_data: dataset exposing ``.data``, a tensor of images that can be
            viewed as rows of 784 values. Pixels are divided by 255 here, so
            they are presumably raw values in [0, 255].
        model: module whose call returns an ``(encoded, decoded)`` pair for a
            float batch of shape ``(N, 784)``.
    """
    view_data = train_data.data[:5].view(-1, 28 * 28).float() / 255.

    # Run on whatever device the model lives on instead of assuming CUDA;
    # a parameter-free model is evaluated on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Plotting only: no autograd graph is needed.
    with torch.no_grad():
        _, decoded_data = model(view_data.to(device))
    decoded_data = decoded_data.cpu().numpy()
    originals = view_data.numpy()

    n_cols = len(view_data)
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(n_cols * 2, 2 * 2), squeeze=False)
    # Title the figure that was just created; calling plt.suptitle before
    # plt.subplots would put the title on a separate, empty figure.
    fig.suptitle("Reconstruction")

    # Row 0: inputs, row 1: reconstructions.
    for row, images in enumerate((originals, decoded_data)):
        for col in range(n_cols):
            ax = axes[row][col]
            ax.imshow(np.reshape(images[col], (28, 28)), cmap='gray')
            ax.set_xticks(())
            ax.set_yticks(())

    plt.show()

def plot_pca(data, model):
    """Scatter-plot a 2-component PCA projection of the model's latent codes.

    Args:
        data: dataset exposing ``.data`` (images viewable as rows of 784
            values), ``.targets`` (integer class index per image) and
            ``.classes`` (one legend label per class index).
        model: module providing ``encode(batch)`` for a float ``(N, 784)`` batch.
    """
    labels = data.classes

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Encode the dataset that was passed in (not a module-level global) and
    # scale pixels by 1/255, matching the scaling plot_dataset applies before
    # calling the model.
    inputs = data.data.view(-1, 784).float() / 255.
    with torch.no_grad():  # visualisation only, no gradients required
        z = model.encode(inputs.to(device))
    reduced_z = PCA(2).fit_transform(z.cpu().numpy())

    targets = torch.as_tensor(data.targets).cpu().numpy()

    fig = plt.figure(figsize=(10, 6))
    # Title the figure just created; plt.suptitle before plt.figure would
    # attach the title to a different (empty) figure.
    fig.suptitle("Reduction of latent space")

    for class_idx, label in enumerate(labels):
        mask = targets == class_idx
        plt.scatter(reduced_z[mask, 0], reduced_z[mask, 1], s=2., label=label)

    plt.legend()
    plt.show()


# Seed torch's RNG so weight initialisation and shuffling are repeatable.
torch.manual_seed(1337)
batch_size = 128

# Turn each image into a tensor, then flatten it to a single dimension.
transforms = Compose([
    ToTensor(),
    Lambda(lambda x: x.flatten()),
])

# MNIST training split (set download=False once the files are already under root).
train_data = MNIST(root='.', train=True, transform=transforms, download=True)

# Shuffled mini-batches; because of the flatten transform each image batch is
# (batch_size, 784) rather than (batch_size, 1, 28, 28).
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



In [4]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder for 784-dimensional inputs.

    The encoder maps 784 -> 128 -> 128 -> 64 -> ``latent_dim`` with ReLU between
    the linear layers and no activation on the latent code. The decoder mirrors
    it (``latent_dim`` -> 64 -> 128 -> 128 -> 784) and ends with a sigmoid.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Layer widths from input down to the latent code.
        widths = (784, 128, 128, 64, latent_dim)

        enc_layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            enc_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        enc_layers.pop()  # the latent code is left un-activated
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder walks the same widths in reverse.
        mirrored = widths[::-1]
        dec_layers = []
        for n_in, n_out in zip(mirrored, mirrored[1:]):
            dec_layers += [nn.Linear(n_in, n_out), nn.ReLU()]
        dec_layers[-1] = nn.Sigmoid()  # final activation is a sigmoid, not ReLU
        self.decoder = nn.Sequential(*dec_layers)

    def decode(self, encoded):
        """Map a latent code back to a 784-dimensional reconstruction."""
        return self.decoder(encoded)

    def encode(self, x):
        """Map a 784-dimensional input to its latent code."""
        return self.encoder(x)

    def forward(self, x):
        """Return ``(latent_code, reconstruction)`` for the batch ``x``."""
        latent = self.encode(x)
        return latent, self.decode(latent)

In [None]:
# Hyper-parameters for the reconstruction training run
epochs = 25
LR = 5e-3  # Adam learning rate

autoencoder = AutoEncoder(latent_dim=10).cuda()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
rec_loss_fn = torch.nn.MSELoss()

for epoch in range(epochs):
    epoch_losses = []  # per-batch losses, averaged for the epoch log line

    for step, (x, y) in enumerate(train_loader):
        x = x.cuda()

        # Reset gradients, then forward / backward / update.
        optimizer.zero_grad()
        encoded, decoded = autoencoder(x)
        loss_val = rec_loss_fn(decoded, x)  # reconstruction error against the input itself
        loss_val.backward()
        optimizer.step()

        epoch_losses.append(loss_val.item())

    mean_loss = np.mean(epoch_losses)
    print(f'Epoch: {epoch}  |  train loss: {mean_loss:.4f}')

    # Visual check on every tenth epoch (0, 10, 20, ...).
    if epoch % 10 == 0:
        plot_dataset(train_data, autoencoder)
        plot_pca(train_data, autoencoder)

In [6]:
# Only the first 100 training examples are treated as labeled data.
labeled_data = Subset(train_data, range(100))
labeled_loader = torch.utils.data.DataLoader(
    dataset=labeled_data,
    batch_size=32,
    shuffle=True,
)

# MNIST test split (set download=False once the files are already under root).
test_data = MNIST(root='.', train=False, transform=transforms, download=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=5000,
    shuffle=True,
)

In [11]:
class BaselineModel(nn.Module):
    """Plain MLP classifier: 784 -> 128 -> 64 -> 32 -> 10 with ReLU in between.

    The final layer has no activation, so the output is raw per-class scores.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(784, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class scores for each row of ``x``."""
        return self.net(x)

In [8]:
def show_results(logs, color_index=0):
    """Plot train/test accuracy and loss curves side by side.

    Args:
        logs: mapping with the keys ``'train_accuracy'``, ``'test_accuracy'``,
            ``'train_loss'`` and ``'test_loss'``, each a sequence with one
            value per epoch.
        color_index: index into matplotlib's colour cycle (``'C<index>'``)
            used for every curve. The original code read an undefined name
            ``i`` here, which raised NameError on a fresh kernel.
    """
    color = 'C%s' % color_index
    f, ax = plt.subplots(1, 2, figsize=(16, 5))

    # Left panel: accuracy, right panel: loss. Dashed line = train, solid = test.
    for panel, metric in zip(ax, ('accuracy', 'loss')):
        panel.plot(logs['train_' + metric], color=color, linestyle='--', label='train')
        panel.plot(logs['test_' + metric], color=color, label='test')
        panel.set_xlabel('epochs')
        panel.set_ylabel(metric)
        panel.legend()

In [35]:
def train(n_epochs, model, logs, learning_rate=0.05, momentum=0.1):
    """Train ``model`` with SGD on the labeled set and evaluate it every epoch.

    Batches are read from the module-level ``labeled_loader`` (training) and
    ``test_loader`` (evaluation).

    Args:
        n_epochs: number of passes over ``labeled_loader``.
        model: classifier mapping an input batch to per-class scores; it is
            moved to the selected device in place.
        logs: dict with list values under ``'train_loss'``, ``'train_accuracy'``,
            ``'test_loss'`` and ``'test_accuracy'``; one entry is appended to
            each list per epoch.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.

    Per-epoch losses are sample-weighted means over all batches, and accuracies
    are divided by the number of samples actually seen. An empty loader yields
    ``nan`` for its metrics instead of raising.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Keep model and data on the same device even if the caller forgot to move it.
    model.to(device)

    loss_fn = torch.nn.functional.cross_entropy
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

    nan = float('nan')

    for _ in range(n_epochs):
        # ---- training pass ----
        model.train()
        loss_sum, correct, seen = 0.0, 0, 0
        for x, y in labeled_loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            output = model(x)
            loss = loss_fn(output, y)
            loss.backward()
            optimizer.step()

            # Count the real batch size: the last batch may be smaller, and
            # the nominal batch size of some other loader is unrelated.
            n_batch = y.size(0)
            loss_sum += loss.item() * n_batch
            correct += torch.sum(torch.argmax(output, dim=1) == y).item()
            seen += n_batch

        logs['train_loss'].append(loss_sum / seen if seen else nan)
        logs['train_accuracy'].append(correct / seen if seen else nan)

        # ---- evaluation pass ----
        model.eval()
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for x_test, y_test in test_loader:
                x_test = x_test.to(device)
                y_test = y_test.to(device)

                output = model(x_test)
                n_batch = y_test.size(0)
                loss_sum += loss_fn(output, y_test).item() * n_batch
                correct += torch.sum(torch.argmax(output, dim=1) == y_test).item()
                seen += n_batch

        logs['test_loss'].append(loss_sum / seen if seen else nan)
        logs['test_accuracy'].append(correct / seen if seen else nan)
        print('test acc', logs['test_accuracy'][-1])

In [49]:
# Baseline: classifier trained from scratch on the small labeled subset.
model = BaselineModel()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# One list per tracked metric, filled in by train().
logs = {
    metric: []
    for metric in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy')
}

train(n_epochs=20, model=model, logs=logs)

test acc 0.1135
test acc 0.1135
test acc 0.1135
test acc 0.1135
test acc 0.1135
test acc 0.1315
test acc 0.1478
test acc 0.1294
test acc 0.1248
test acc 0.1161
test acc 0.1174
test acc 0.1158
test acc 0.1434
test acc 0.1679
test acc 0.163
test acc 0.1765
test acc 0.1759
test acc 0.1957
test acc 0.1933
test acc 0.2059


In [46]:
class ModelWithEncoder(nn.Module):
    """Classifier head stacked on top of an encoder.

    Args:
        latent_dim: size of the encoder's output, i.e. the input width of the
            classifier head (latent_dim -> 128 -> 64 -> 32 -> 10).
        encoder: module producing the ``latent_dim``-sized code. When omitted,
            the module-level ``autoencoder.encoder`` is used, as before.

    The encoder is stored by reference, not copied, so training this model
    also updates the encoder object that was passed in (or the global one).
    """

    def __init__(self, latent_dim, encoder=None):
        super().__init__()
        if encoder is None:
            # Backward-compatible default: fall back to the notebook's global
            # autoencoder, which must already be defined.
            encoder = autoencoder.encoder
        self.encoder = encoder
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 10),
        )

    def forward(self, x):
        """Encode ``x`` and return the 10 class scores for each row."""
        return self.net(self.encoder(x))

In [50]:
# Same experiment, but the classifier sits on top of the autoencoder's encoder.
model2 = ModelWithEncoder(latent_dim=10)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model2 = model2.to(device)

# One list per tracked metric, filled in by train().
logs2 = {
    metric: []
    for metric in ('train_loss', 'test_loss', 'train_accuracy', 'test_accuracy')
}

train(n_epochs=20, model=model2, logs=logs2)

test acc 0.2548
test acc 0.1222
test acc 0.3225
test acc 0.3559
test acc 0.3966
test acc 0.3871
test acc 0.5152
test acc 0.4101
test acc 0.4913
test acc 0.4723
test acc 0.3587
test acc 0.5253
test acc 0.4813
test acc 0.5994
test acc 0.6152
test acc 0.408
test acc 0.3109
test acc 0.4433
test acc 0.5791
test acc 0.6162
