In [1]:
# Written by Jay Jaewon Yoo (UofT Student Number 1002939671)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [3]:
class autoencoder_dataset(Dataset):
    """Map-style dataset pairing single-channel images with their targets.

    ``data`` is given a channel axis at position 1 on construction, so every
    sample returned by indexing has a leading channel dimension of size 1.
    ``labels`` is stored untouched and indexed in parallel with ``data``.
    """

    def __init__(self, data, labels):
        """Store the inputs and targets.

        Args:
            data: tensor of images; the original author documents the shape
                as (dataset size, m, n) for m x n images.
            labels: targets, documented as having the same shape as ``data``.
        """
        super().__init__()

        # The dataset length is taken from the inputs only.
        self.num_data = len(data)
        # (N, m, n) -> (N, 1, m, n): insert the channel axis.
        self.train_data = torch.unsqueeze(data, 1)
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return self.num_data

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        sample = self.train_data[idx]
        target = self.labels[idx]
        return sample, target

In [4]:
def autoencoder_layer(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=True, batchnorm=True, activation='relu',
                      upsample=None):
    """Build one convolutional stage as an ``nn.Sequential``.

    The stage is assembled in this order: optional upsampling, the 2-D
    convolution, optional batch normalisation, optional activation.
    With the defaults (kernel 3, stride 1, padding 1) the convolution keeps
    the spatial size of its input.

    Args:
        in_channels: number of channels entering the convolution.
        out_channels: number of channels leaving the convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
        batchnorm: when true, add ``nn.BatchNorm2d(out_channels)`` after the
            convolution.
        activation: one of 'relu', 'sigmoid', 'leaky_relu' or 'none'.
        upsample: when not None, an ``nn.Upsample`` with this scale factor is
            placed before the convolution.

    Returns:
        ``nn.Sequential`` holding the requested modules.

    Raises:
        AssertionError: if ``activation`` is not one of the accepted names.
    """
    # Accepted activation names mapped to their module class (None = no layer).
    activation_choices = (
        ('relu', nn.ReLU),
        ('sigmoid', nn.Sigmoid),
        ('leaky_relu', nn.LeakyReLU),
        ('none', None),
    )

    # Optional upsampling comes first so the convolution sees the enlarged map.
    stages = [] if upsample is None else [nn.Upsample(scale_factor=upsample)]

    stages.append(nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding,
                            bias=bias))

    if batchnorm:
        stages.append(nn.BatchNorm2d(out_channels))

    # Pick the activation; an unknown name is rejected only after the layers
    # above have been constructed.
    for choice_name, module_cls in activation_choices:
        if activation == choice_name:
            if module_cls is not None:
                stages.append(module_cls())
            break
    else:
        assert False, "Invalid activation function."

    return nn.Sequential(*stages)

class autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    Pipeline: two convolution stages at full resolution, a 2x2 max-pool,
    one convolution stage at half resolution, a 2x upsample followed by a
    convolution, and a final sigmoid convolution producing one channel.

    Note: ``nn.MaxPool2d(2)`` floors odd spatial sizes and the upsample then
    doubles them, so an input with an odd height or width comes back one
    pixel smaller along that axis.

    Args:
        num_hidden_channels: channel width of every intermediate stage.
    """

    def __init__(self, num_hidden_channels=64):
        super().__init__()

        hidden = num_hidden_channels

        # The stages are created in the same order as before, so parameter
        # initialisation and ``state_dict`` keys (net.0 ... net.5) are unchanged.
        self.net = nn.Sequential(
            autoencoder_layer(in_channels=1,
                              out_channels=hidden,
                              activation='leaky_relu',
                              batchnorm=True,
                              upsample=None),
            autoencoder_layer(in_channels=hidden,
                              out_channels=hidden,
                              activation='leaky_relu',
                              batchnorm=False,
                              upsample=None),
            nn.MaxPool2d(2),  # Pool to half of shape
            autoencoder_layer(in_channels=hidden,
                              out_channels=hidden,
                              activation='leaky_relu',
                              batchnorm=True,
                              upsample=None),
            autoencoder_layer(in_channels=hidden,
                              out_channels=hidden,
                              activation='leaky_relu',
                              batchnorm=False,
                              upsample=2),  # back to the pre-pool resolution
            autoencoder_layer(in_channels=hidden,
                              out_channels=1,
                              activation='sigmoid',
                              batchnorm=False,
                              upsample=None),
        )

    def forward(self, concatenated_inputs):
        """Run the network and drop the singleton channel axis.

        Args:
            concatenated_inputs: tensor of shape (batch, 1, height, width).

        Returns:
            Tensor of shape (batch, height', width') with values in [0, 1]
            (sigmoid output).
        """
        output = self.net(concatenated_inputs)
        # Squeeze only the channel axis. A bare ``squeeze()`` would also drop
        # the batch axis for a batch of one (and any size-1 spatial axis),
        # leaving the output out of step with (batch, m, n) labels.
        return output.squeeze(1)

In [5]:
# Training configuration; the cells below read these module-level names.

# Optimisation
learning_rate = 1e-4          # step size handed to the Adam optimiser
batch_size = 2                # samples per DataLoader batch

# Schedule / reporting
num_epochs = 20               # full passes over the dataset
summary_epoch_interval = 10   # print the accumulated loss every N epochs

In [6]:
# Placeholder dataset: 50 random 300x350 images paired with random targets.
# Replace both tensors with real data; they must stay (dataset size, m, n).
dummy_data = torch.rand(50, 300, 350) # Change
dummy_labels = torch.rand(50, 300, 350) # Change

# Preparing dataloader
# NOTE: despite its name, this holds the number of samples (50 here).
num_input_types = len(dummy_data)
dataset = autoencoder_dataset(dummy_data, dummy_labels)
# Pinned (page-locked) host memory only helps host-to-GPU copies, so request
# it only when CUDA is present; this avoids a warning on CPU-only machines.
dataloader = DataLoader(dataset, batch_size=batch_size,
                        pin_memory=torch.cuda.is_available(), num_workers=0)

In [7]:
# Build the network and move it to the GPU when one is available; fall back
# to the CPU instead of raising on machines without CUDA.
autoencoder_model = autoencoder()
# ``Module.to`` returns the module, so the cell still displays its structure.
autoencoder_model.to("cuda" if torch.cuda.is_available() else "cpu")

autoencoder(
  (net): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
    )
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4): Sequential(
      (0): Upsample(scale_factor=2.0, mode=nearest)
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): LeakyReLU(negative_slope=0.01)
    )
    (5): Sequential(
      (0): Conv2d(64, 1, kernel_size=(3, 3),

In [8]:
# Train the autoencoder with Adam on a root-mean-squared-error objective.
optimizer = torch.optim.Adam(lr=learning_rate, params=autoencoder_model.parameters())
mse_loss_function = nn.MSELoss()

# Send batches to wherever the model's parameters live rather than hard-coding
# CUDA, so the loop also runs when the model is on the CPU.
model_device = next(autoencoder_model.parameters()).device

for epoch in range(1, num_epochs + 1):
    # Sum of the per-batch RMSE values seen during this epoch (not a mean).
    cumulated_loss = 0

    for model_input, labels in dataloader:
        model_input = model_input.to(model_device)
        labels = labels.to(model_device)
        model_output = autoencoder_model(model_input)
        # RMSE: square root of the batch-mean squared error.
        loss = torch.sqrt(mse_loss_function(model_output, labels))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        cumulated_loss += loss.item()

    # Report once every ``summary_epoch_interval`` epochs.
    if epoch % summary_epoch_interval == 0:
        print("Epoch %d, Total loss %0.6f" % (epoch, cumulated_loss))

Epoch 10, Total loss 7.211551
Epoch 20, Total loss 7.202534
