<a href="https://colab.research.google.com/github/yandexdataschool/MLatImperial2021/blob/master/08_lab/autoencoder_seminar.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import scipy as sp
import scipy.misc
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline

In [None]:
import tensorflow as tf
# `tf` is used only here, to fetch MNIST; everything else in the notebook is PyTorch.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
# Scale the raw pixel intensities (presumably 0..255) into [0, 1] and store them as float32.
X_train, X_test = [(images / 255).astype('float32') for images in (X_train, X_test)]

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
def preprocess_data(X, y, classification):
  """Turn one numpy batch into a pair of tensors on `device`.

  `X` is cast to float32 and gets a channel axis inserted at dim 1
  (e.g. (B, 28, 28) -> (B, 1, 28, 28)).
  If `classification` is True, `y` becomes int64 class labels; otherwise it keeps
  its dtype and also gets a channel axis at dim 1 (used when the targets are images).
  """
  inputs = torch.tensor(X, dtype=torch.float).unsqueeze(1)
  targets = (torch.tensor(y, dtype=torch.long) if classification
             else torch.tensor(y).unsqueeze(1))
  return inputs.to(device), targets.to(device)

def get_batches(X, y, batch_size, shuffle=False, classification=False):
  """Lazily yield `(batch_X, batch_y)` tensor pairs, `batch_size` rows at a time.

  Args:
    X, y: arrays of equal length that support integer-array indexing (numpy here).
    batch_size: rows per batch; the final batch may be smaller.
    shuffle: if True, iterate in a random order drawn from numpy's global RNG.
    classification: forwarded to `preprocess_data` (integer labels vs. image targets).

  This is a generator: each batch is converted by `preprocess_data` only when requested.
  """
  if shuffle:
    # Indexing with an index array already returns a new array,
    # so no extra `.copy()` (and no second full-size allocation) is needed.
    order = np.random.permutation(len(X))
    X, y = X[order], y[order]
  for start in range(0, len(X), batch_size):
    stop = start + batch_size
    yield preprocess_data(X[start:stop], y[start:stop], classification)

In [None]:
from IPython.display import clear_output


class Logger:
  """Collects per-batch train/test losses and plots learning curves.

  Call `fill_train` / `fill_test` once per batch and `finish_epoch` once per
  epoch. `finish_epoch` does four things:
    - averages the batches recorded since the previous call,
    - clears the cell output,
    - prints a one-line summary,
    - redraws two plots (per-batch train loss; per-epoch train/test averages).
  """

  def __init__(self):
    self.train_loss_batch = []        # every train-batch loss, over all epochs
    self.train_loss_epoch = []        # mean train loss of each finished epoch
    self.test_loss_batch = []         # every test-batch loss, over all epochs
    self.test_loss_epoch = []         # mean test loss of each finished epoch
    self.train_batches_per_epoch = 0  # train batches recorded in the current epoch
    self.test_batches_per_epoch = 0   # test batches recorded in the current epoch
    self.epoch_counter = 0

  def fill_train(self, loss):
    """Record the loss of one training batch."""
    self.train_loss_batch.append(loss)
    self.train_batches_per_epoch += 1

  def fill_test(self, loss):
    """Record the loss of one evaluation batch."""
    self.test_loss_batch.append(loss)
    self.test_batches_per_epoch += 1

  @staticmethod
  def _recent_mean(values, count):
    """Return the mean of the last `count` items of `values`, or NaN when `count` is 0.

    The guard matters: `values[-0:]` is the WHOLE list, so a plain
    `values[-count:]` would average every previous epoch when nothing was
    recorded in the current one.
    """
    if count <= 0:
      return float('nan')
    return np.mean(values[-count:])

  def finish_epoch(self):
    """Close the current epoch: store the averages, print them and redraw the plots."""
    self.train_loss_epoch.append(
        self._recent_mean(self.train_loss_batch, self.train_batches_per_epoch))
    self.test_loss_epoch.append(
        self._recent_mean(self.test_loss_batch, self.test_batches_per_epoch))
    self.train_batches_per_epoch = 0
    self.test_batches_per_epoch = 0

    clear_output()

    print("epoch #{} \t train_loss: {:.8} \t test_loss: {:.8}".format(
        self.epoch_counter,
        self.train_loss_epoch[-1],
        self.test_loss_epoch[-1]
    ))

    self.epoch_counter += 1

    plt.figure(figsize=(11, 5))

    plt.subplot(1, 2, 1)
    plt.plot(self.train_loss_batch, label='train loss')
    plt.xlabel('# batch iteration')
    plt.ylabel('loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(self.train_loss_epoch, label='average train loss')
    plt.plot(self.test_loss_epoch, label='average test loss')
    plt.legend()
    plt.xlabel('# epoch')
    plt.ylabel('loss')
    plt.show()

In [None]:
class Reshape(torch.nn.Module):
  """Reshape each sample of a batch to `shape`, keeping the leading batch dimension.

  For example, `Reshape(32)` maps (B, 32, 1, 1) -> (B, 32) and
  `Reshape(32, 1, 1)` maps (B, 32) -> (B, 32, 1, 1).
  """

  def __init__(self, *shape):
    super().__init__()
    self.shape = shape

  def forward(self, x):
    batch_size = x.shape[0]
    return x.reshape((batch_size,) + self.shape)

In [None]:
def create_encoder():
    """Build the convolutional encoder mapping a (B, 1, 28, 28) batch to a (B, 32) code.

    The spatial sizes in the comments assume 28x28 inputs.
    """
    layers = [
        # 28x28 -> 14x14
        nn.Conv2d(1, 16, 3, padding=1),
        nn.LeakyReLU(),
        nn.MaxPool2d(2),
        # 14x14 -> 7x7
        nn.Conv2d(16, 32, 3, padding=1),
        nn.LeakyReLU(),
        nn.MaxPool2d(2),
        # Unpadded 3x3 convolutions: 7x7 -> 5x5 -> 3x3 -> 1x1
        nn.Conv2d(32, 64, 3),
        nn.LeakyReLU(),
        nn.Conv2d(64, 128, 3),
        nn.LeakyReLU(),
        nn.Conv2d(128, 256, 3),
        nn.LeakyReLU(),
        # 1x1 conv down to 32 channels, then drop the 1x1 spatial dims
        nn.Conv2d(256, 32, 1),
        Reshape(32),
    ]
    return torch.nn.Sequential(*layers)

def create_decoder():
    """Build the decoder mapping a (B, 32) code to a (B, 1, 28, 28) batch with values in (0, 1).

    Every ConvTranspose2d here has stride 1 and no padding. Each one therefore
    grows the spatial size by dilation * (kernel_size - 1):
    1 -> 5 -> 9 -> 13 -> 17 -> 21 -> 23 -> 27 -> 28.
    """
    return nn.Sequential(
        Reshape(32, 1, 1),                            # (B, 32) -> (B, 32, 1, 1)

        nn.ConvTranspose2d(32, 256, 3, dilation=2),   # 5x5
        nn.LeakyReLU(),

        nn.ConvTranspose2d(256, 128, 3, dilation=2),  # 9x9
        nn.LeakyReLU(),

        nn.ConvTranspose2d(128, 64, 3, dilation=2),   # 13x13
        nn.LeakyReLU(),

        nn.ConvTranspose2d(64, 32, 3, dilation=2),    # 17x17
        nn.LeakyReLU(),

        nn.ConvTranspose2d(32, 16, 3, dilation=2),    # 21x21
        nn.LeakyReLU(),
        nn.ConvTranspose2d(16, 3, 3, dilation=1),     # 23x23
        nn.LeakyReLU(),
        nn.ConvTranspose2d(3, 1, 3, dilation=2),      # 27x27
        nn.LeakyReLU(),
        nn.ConvTranspose2d(1, 1, 2, dilation=1),      # 28x28
        nn.Sigmoid(),                                 # pixel values in (0, 1)
    )


# Hyper-parameters for the autoencoder run.
num_epochs = 20
batch_size = 256

encoder = create_encoder()
decoder = create_decoder()
# `autoencoder` wraps these very module objects, so training it also trains
# `encoder` and `decoder` in place.
autoencoder = nn.Sequential(encoder, decoder).to(device)

optimiser = torch.optim.Adam(autoencoder.parameters(), lr=0.003)
loss_function = F.mse_loss

In [None]:
def fit(model, loss_function, optimizer, _X_train, _y_train, _X_test, _y_test, num_epochs, batch_size, classification=False):
  """Train `model` for `num_epochs` epochs, evaluating on the test split after each one.

  Args:
    model: torch module mapping an input batch to predictions.
    loss_function: callable `(predictions, targets) -> scalar loss tensor`.
    optimizer: optimizer over the trainable parameters of `model`.
    _X_train, _y_train: training inputs/targets, batched (and shuffled) by `get_batches`.
    _X_test, _y_test: evaluation inputs/targets, batched in order.
    num_epochs: number of passes over the training data.
    batch_size: batch size for both training and evaluation.
    classification: if True, targets are integer class labels; otherwise they are
      used as regression targets (e.g. the input images themselves for an autoencoder).

  Losses go to a fresh `Logger`, which prints and plots them at the end of every epoch.
  """
  logger = Logger()

  for _epoch in range(num_epochs):
    model.train()  # training mode
    for batch_X, batch_y in get_batches(_X_train, _y_train, batch_size=batch_size,
                                        shuffle=True, classification=classification):
      model.zero_grad()                           # clear gradients from the previous step
      loss = loss_function(model(batch_X), batch_y)
      loss.backward()                             # compute new gradients
      optimizer.step()                            # update the parameters
      logger.fill_train(loss.item())

    # Evaluate on the test split. No gradients are needed here, so skip building
    # the autograd graph (saves memory and time).
    model.eval()  # evaluation mode
    with torch.no_grad():
      for batch_X, batch_y in get_batches(_X_test, _y_test, batch_size=batch_size,
                                          classification=classification):
        logger.fill_test(loss_function(model(batch_X), batch_y).item())

    logger.finish_epoch()

In [None]:
# Autoencoder training: the images are both the inputs and the targets.
fit(autoencoder, loss_function, optimiser,
    X_train, X_train, X_test, X_test,
    num_epochs=num_epochs, batch_size=batch_size, classification=False)

In [None]:
def show_image_strip(images):
  """Plot a stack of images with shape (N, H, W) side by side as one H x (N*W) strip."""
  images = np.asarray(images)
  n_images, height, width = images.shape
  # (N, H, W) -> (H, N, W) -> (H, N*W): row r of the strip is row r of each image in turn.
  strip = np.transpose(images, (1, 0, 2)).reshape(height, n_images * width)
  plt.figure(figsize=(10, 10))
  plt.imshow(strip, cmap='Greys')
  plt.axis('off')


# First figure: 10 test digits. Second figure: their autoencoder reconstructions.
originals = X_test[:10]
autoencoder.eval()
with torch.no_grad():  # inference only: no autograd graph needed
  reconstructions = autoencoder(torch.tensor(originals).unsqueeze(1).to(device))
show_image_strip(originals)
# Drop the channel axis, (N, 1, H, W) -> (N, H, W), and move to numpy for plotting.
show_image_strip(reconstructions[:, 0].cpu().numpy())

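Now, let's make a classifier. We first reuse the autoencoder's trained encoder with its weights frozen, and train only a linear layer on top of it.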

In [None]:
# Freeze the autoencoder-trained encoder: only the new linear head gets trained.
encoder.requires_grad_(False)

classifier = nn.Sequential(
    encoder,
    nn.Linear(32, 10),  # 32-dim code -> 10 class logits
).to(device)

num_epochs = 70
batch_size = 256
loss_function = F.cross_entropy
optimiser = torch.optim.Adam(classifier.parameters(), lr=0.005)

# Train on only the first 300 labeled examples.
fit(classifier, loss_function, optimiser,
    X_train[:300], y_train[:300], X_test, y_test,
    num_epochs=num_epochs, batch_size=batch_size, classification=True)

In [None]:
## Test accuracy
def get_accuracy(model, X, y):
  """Return the fraction of samples in `X` whose predicted class (argmax) equals `y`.

  Args:
    model: torch module applied to `X` with a channel axis added, i.e. an
      (N, 1, H, W) float batch. If it returns a tuple (reconstruction, logits),
      the logits at index 1 are used; otherwise its output is taken as the logits.
    X: numpy array of N images, shape (N, H, W).
    y: numpy array of N integer class labels.

  All of `X` goes through the model in one forward pass under `torch.no_grad()`,
  so no autograd graph (and its activation memory) is kept. Inputs are moved to
  the device of the model's parameters, or to the notebook-level `device` if the
  model has no parameters.
  """
  first_param = next(model.parameters(), None)
  model_device = first_param.device if first_param is not None else device
  with torch.no_grad():
    output = model(torch.tensor(X).unsqueeze(1).to(model_device))
  logits = output[1] if isinstance(output, tuple) else output
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  return (predictions == y).mean()

# Accuracy on the full test set, then on the 300 training examples used for fitting.
test_accuracy = get_accuracy(classifier, X_test, y_test)
print(test_accuracy)
train_accuracy = get_accuracy(classifier, X_train[:300], y_train[:300])
print(train_accuracy)

In [None]:
# Same architecture, but with a freshly initialised encoder that is trained
# together with the linear head (nothing frozen).
encoder = create_encoder().requires_grad_(True)

classifier = nn.Sequential(
    encoder,
    nn.Linear(32, 10),  # 32-dim code -> 10 class logits
).to(device)

num_epochs = 70
batch_size = 256
loss_function = F.cross_entropy
optimiser = torch.optim.Adam(classifier.parameters(), lr=0.005)

# Same 300 labeled training examples as the frozen-encoder run.
fit(classifier, loss_function, optimiser,
    X_train[:300], y_train[:300], X_test, y_test,
    num_epochs=num_epochs, batch_size=batch_size, classification=True)

What do we observe on the training curve?

In [None]:
# Accuracy on the full test set, then on the 300 training examples used for fitting.
test_accuracy = get_accuracy(classifier, X_test, y_test)
print(test_accuracy)
train_accuracy = get_accuracy(classifier, X_train[:300], y_train[:300])
print(train_accuracy)

Semi-supervised

In [None]:
N_LABELED = 300  # only the first N_LABELED training images keep their labels
X_train_labeled = X_train[:N_LABELED]
y_train_labeled = y_train[:N_LABELED]
X_train_unlabeled = X_train[N_LABELED:]

In [None]:
def gen_untrained(batch_size, data=None):
  """Yield shuffled mini-batches that cover `data` exactly once, then stop.

  Args:
    batch_size: rows per batch; the final batch may be smaller.
    data: array to draw from; defaults to the notebook-level `X_train_unlabeled`.

  The shuffle uses numpy's global RNG and happens lazily, on the first `next()`.
  Only the rows of the current batch are gathered. Indexing the whole array by
  the permutation on every step would copy the full dataset once per batch.
  """
  if data is None:
    data = X_train_unlabeled
  ids = np.arange(len(data))
  np.random.shuffle(ids)
  for i in range(0, len(data), batch_size):
    yield data[ids[i:i + batch_size]]

In [None]:
# One shuffled pass over the unlabeled images, 256 per batch; the training loop
# below starts a new pass whenever this one is exhausted.
unlabeled_generator = gen_untrained(batch_size=256)

Remember: what we want here is a class that does two things. It acts both as an autoencoder and as a classifier, so it should return two outputs: the reconstructed image and the classification scores (one logit per class; `cross_entropy` applies the softmax internally).

In [None]:
class UnsupervisedAE(nn.Module):
  """Autoencoder with an extra linear classification head on the latent code.

  `forward` returns `(reconstruction, class_logits)`. Training uses the
  reconstruction on unlabeled images and the logits on labeled ones.

  Args:
    encoder_module: maps an image batch to a (B, latent_dim) code. Defaults to
      the notebook-level `encoder`, which matches the original zero-argument constructor.
    decoder_module: maps the code back to an image batch. Defaults to the
      notebook-level `decoder`.
    latent_dim: size of the code fed to the classification head.
    n_classes: number of output logits.
  """

  def __init__(self, encoder_module=None, decoder_module=None, latent_dim=32, n_classes=10):
    super().__init__()
    self.encoder = encoder if encoder_module is None else encoder_module
    self.decoder = decoder if decoder_module is None else decoder_module
    self.classifier = nn.Linear(latent_dim, n_classes)

  def forward(self, X):
    # Encode once; both heads share the same latent code.
    latent = self.encoder(X)
    x_reco = self.decoder(latent)
    x_class = self.classifier(latent)
    return x_reco, x_class

Create the model and its optimiser, and define our losses

In [None]:
# Build the semi-supervised model from a freshly created encoder and decoder.
unsup_ae = UnsupervisedAE(create_encoder(), create_decoder()).to(device)
optimiser = torch.optim.Adam(unsup_ae.parameters(), lr=0.003)

# Reconstruction loss (unlabeled batches) and classification loss (labeled batches).
mse_loss = F.mse_loss
ce_loss = F.cross_entropy

In [None]:
N_EPOCHS = 100
BATCH_SIZE = 16             # labeled images per optimisation step
UNLABELED_BATCH_SIZE = 256  # batch size used when the unlabeled generator is restarted
LAMBDA = 0.3                # weight of the reconstruction loss in the total loss

history_ae = []   # per-epoch mean reconstruction (MSE) loss
history_cl = []   # per-epoch mean classification (cross-entropy) loss
history_tot = []  # per-epoch mean total loss
for i_epoch in range(N_EPOCHS):
  print("Working on ep #", i_epoch)
  unsup_ae.train()

  # Shuffle the labeled set once per epoch and gather it once,
  # instead of re-indexing the whole array for every batch.
  ids = np.arange(len(X_train_labeled))
  np.random.shuffle(ids)
  X_labeled_epoch = X_train_labeled[ids]
  y_labeled_epoch = y_train_labeled[ids]

  # Per-epoch accumulators: reset before the batch loop so they sum over the whole epoch.
  epoch_ae_loss = 0.0
  epoch_cl_loss = 0.0
  epoch_total_loss = 0.0
  n_batches = 0

  for i_image in range(0, len(X_labeled_epoch), BATCH_SIZE):
    X_batch = torch.tensor(X_labeled_epoch[i_image:i_image + BATCH_SIZE]).unsqueeze(1).to(device)
    y_batch = torch.tensor(y_labeled_epoch[i_image:i_image + BATCH_SIZE], dtype=torch.long).to(device)

    # Next unlabeled batch; start a new shuffled pass once the generator is exhausted.
    try:
      unlabeled_images = next(unlabeled_generator)
    except StopIteration:
      unlabeled_generator = gen_untrained(UNLABELED_BATCH_SIZE)
      unlabeled_images = next(unlabeled_generator)
    X_batch_unlabeled = torch.tensor(unlabeled_images).unsqueeze(1).to(device)

    # Reconstruction: MSE between the UNLABELED images and their reconstructions.
    reco_image, _ = unsup_ae(X_batch_unlabeled)
    ae_loss = mse_loss(reco_image, X_batch_unlabeled)

    # Classification: cross-entropy on the LABELED batch.
    _, class_preds = unsup_ae(X_batch)
    class_loss = ce_loss(class_preds, y_batch)

    # Total objective: classification loss plus the LAMBDA-weighted reconstruction loss.
    loss = class_loss + LAMBDA * ae_loss

    unsup_ae.zero_grad()
    loss.backward()
    optimiser.step()

    epoch_ae_loss += ae_loss.item()
    epoch_cl_loss += class_loss.item()
    epoch_total_loss += loss.item()
    n_batches += 1

  # Record per-epoch means of the batch losses (guarding against an empty labeled set).
  n_batches = max(n_batches, 1)
  history_ae.append(epoch_ae_loss / n_batches)
  history_cl.append(epoch_cl_loss / n_batches)
  history_tot.append(epoch_total_loss / n_batches)

  clear_output(wait=True)
  plt.figure(figsize=(12, 8))
  plt.plot(history_ae, label='ae loss')
  plt.plot(history_cl, label='cl loss')
  plt.plot(history_tot, label='total')
  plt.legend()
  plt.xlabel('epoch')
  plt.ylabel('mean loss per batch')
  plt.show()

In [None]:
# Last value appended to `history_tot` by the training loop above
# (total loss = classification loss + LAMBDA * reconstruction loss).
history_tot[-1]

In [None]:
## Test accuracy
def get_accuracy(model, X, y):
  """Return the fraction of samples in `X` whose predicted class (argmax) equals `y`.

  This cell redefines `get_accuracy`. This version works both for models that
  return a (reconstruction, logits) tuple, such as `UnsupervisedAE`, and for
  plain classifiers that return logits directly. Re-running the earlier
  classifier cells therefore still works.

  Args:
    model: torch module applied to `X` with a channel axis added, i.e. an
      (N, 1, H, W) float batch.
    X: numpy array of N images, shape (N, H, W).
    y: numpy array of N integer class labels.

  All of `X` goes through the model in one forward pass under `torch.no_grad()`,
  so no autograd graph is kept. Inputs are moved to the device of the model's
  parameters, or to the notebook-level `device` if the model has no parameters.
  """
  first_param = next(model.parameters(), None)
  model_device = first_param.device if first_param is not None else device
  with torch.no_grad():
    output = model(torch.tensor(X).unsqueeze(1).to(model_device))
  logits = output[1] if isinstance(output, tuple) else output
  predictions = torch.argmax(logits, dim=1).cpu().numpy()
  return (predictions == y).mean()

# Test-set accuracy of the semi-supervised model's classification head.
unsup_test_accuracy = get_accuracy(unsup_ae, X_test, y_test)
print(unsup_test_accuracy)