In [1]:
# Written by Jay Jaewon Yoo (UofT Student Number 1002939671)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [3]:
class stacker_dataset(Dataset):
    """Dataset that stacks several image sets along a new channel axis.

    Elements of ``data_to_stack`` are all (dataset size, m, n) where images
    are m x n. They are combined into ``train_data`` of shape
    (dataset size, N, m, n), where N is the number of inputs to stack.

    Args:
        data_to_stack: non-empty sequence of equally shaped tensors.
        labels: indexable collection with one target per sample.

    Raises:
        ValueError: if ``data_to_stack`` is empty, or if the number of labels
            differs from the number of samples.
    """

    def __init__(self, data_to_stack, labels):
        super().__init__()

        if len(data_to_stack) == 0:
            raise ValueError("data_to_stack must contain at least one tensor.")

        self.num_data = len(data_to_stack[0])

        # Catch a label/sample mismatch here instead of letting it surface as an
        # IndexError (or silently unused labels) in the middle of training.
        if len(labels) != self.num_data:
            raise ValueError(
                "Expected %d labels (one per sample) but got %d."
                % (self.num_data, len(labels))
            )

        # torch.stack inserts the new axis at dim=1 and concatenates in one call,
        # giving dataset size x N x m x n.
        self.train_data = torch.stack(list(data_to_stack), dim=1)
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return self.num_data

    def __getitem__(self, idx):
        """Return the (stacked inputs, label) pair stored at ``idx``."""
        return self.train_data[idx], self.labels[idx]

In [4]:
def stacker_layer(in_channels, out_channels, kernel_size=3, stride=1,
                  padding=1, bias=True, batchnorm=True, activation='relu'):
    """Build a Conv2d block with optional batchnorm and activation.

    With the defaults (stride = 1, kernel = 3, padding = 1) the convolution
    maintains the spatial shape of its input.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size, stride, padding, bias: forwarded to ``nn.Conv2d``.
        batchnorm: if True, append ``nn.BatchNorm2d(out_channels)`` after the
            convolution.
        activation: one of 'relu', 'sigmoid', or 'none'.

    Returns:
        ``nn.Sequential`` holding the convolution, then the optional batchnorm,
        then the optional activation, in that order.

    Raises:
        AssertionError: if ``activation`` is not one of the accepted values.
    """
    # Reject bad input before building anything. The error is raised explicitly
    # (not via an ``assert`` statement) so the check still runs under
    # ``python -O``; the AssertionError type is kept for existing callers.
    if activation not in ('relu', 'sigmoid', 'none'):
        raise AssertionError("Invalid activation function.")

    # Adding convolutional layer
    layers = [nn.Conv2d(in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                        bias=bias)]

    # Adding batchnorm
    if batchnorm:
        layers.append(nn.BatchNorm2d(out_channels))

    # Adding activation ('none' adds nothing)
    if activation == 'relu':
        layers.append(nn.ReLU())
    elif activation == 'sigmoid':
        layers.append(nn.Sigmoid())

    return nn.Sequential(*layers)

def check_power_2(val):
    """Return whether ``val`` is a power of 2."""
    # A power of two has exactly one set bit, so subtracting one flips that bit
    # and every bit below it; AND-ing the two values then leaves nothing.
    shares_no_bits = (val & (val - 1)) == 0

    # Zero also passes the bit test above, so it has to be ruled out separately.
    return shares_no_bits and val != 0

class stacker(nn.Module):
    """Convolutional encoder-decoder that maps stacked inputs to one output map.

    The network is a chain of ``stacker_layer`` blocks:

    * a first layer going from ``num_input_types`` to ``initial_channels``,
    * encoding layers that halve the channel count down to ``encode_channels``,
    * decoding layers that double it back up to ``initial_channels``
      (no batchnorm in the final decoding layer),
    * a final layer going to a single channel followed by a sigmoid.

    With the defaults and 5 input types the channel counts are
    5 -> 64 -> 32 -> 16 -> 32 -> 64 -> 1.

    Args:
        num_input_types: number of stacked inputs (input channels).
        initial_channels: the initial channels to convolve to.
        encode_channels: the number of channels at the narrowest point.

    Raises:
        AssertionError: if either channel count is not a power of 2, or if
            ``initial_channels`` is smaller than ``encode_channels``.
    """

    def __init__(self, num_input_types, initial_channels=64, encode_channels=16):
        super().__init__()

        # Raised explicitly (not via ``assert`` statements) so the validation
        # still runs under ``python -O``; the exception type is unchanged.
        if not (check_power_2(initial_channels) and check_power_2(encode_channels)):
            raise AssertionError("initial_channels and encode_channels must be powers of 2.")

        # Without this check the ratio below is 0 and the encode/decode stages
        # would be silently dropped, ignoring encode_channels altogether.
        if initial_channels < encode_channels:
            raise AssertionError("initial_channels must not be smaller than encode_channels.")

        # Both values are powers of 2, so their ratio is too; its bit length
        # gives log2 exactly, without going through floating point.
        num_encode_decode_layers = int(initial_channels // encode_channels).bit_length() - 1

        # Channel widths from widest to narrowest, e.g. [64, 32, 16].
        widths = [initial_channels // (2 ** step)
                  for step in range(num_encode_decode_layers + 1)]

        layers = []

        # First layer
        layers.append(stacker_layer(in_channels=num_input_types,
                                    out_channels=initial_channels,
                                    activation='relu',
                                    batchnorm=True))

        # Encoding: walk the widths downwards
        for wide, narrow in zip(widths, widths[1:]):
            layers.append(stacker_layer(in_channels=wide,
                                        out_channels=narrow,
                                        activation='relu',
                                        batchnorm=True))

        # Decoding: walk the widths back upwards
        # Note that no batchnorm in the final decoding layer
        for layer_idx in range(num_encode_decode_layers):
            is_last_decoding_layer = layer_idx == num_encode_decode_layers - 1
            layers.append(stacker_layer(in_channels=encode_channels * (2 ** layer_idx),
                                        out_channels=encode_channels * (2 ** (layer_idx + 1)),
                                        activation='relu',
                                        batchnorm=not is_last_decoding_layer))

        # Final layer
        layers.append(stacker_layer(in_channels=initial_channels,
                                    out_channels=1,
                                    activation='sigmoid',
                                    batchnorm=False))

        self.net = nn.Sequential(*layers)

    def forward(self, concatenated_inputs):
        """Run the network and drop the single output channel.

        Args:
            concatenated_inputs: tensor with ``num_input_types`` channels at
                dim 1 (batch x N x m x n as produced by ``stacker_dataset``).

        Returns:
            The network output with the channel axis removed.
        """
        output = self.net(concatenated_inputs)
        # Squeeze only the channel axis: a bare squeeze() would also drop the
        # batch axis for a batch of one and any spatial axis of size one.
        return output.squeeze(dim=1)

In [5]:
# Training configuration (tune these values here)
learning_rate = 1e-4           # step size handed to the optimizer
num_epochs = 20                # number of passes over the dataloader
batch_size = 2                 # samples per batch drawn by the dataloader
summary_epoch_interval = 10    # print the loss summary every this many epochs

In [6]:
# Testing with 5 input types using images of 300x350
num_dummy_inputs = 5                          # Change
num_dummy_samples = 50                        # Change
dummy_image_shape = (300, 350)                # Change

# Seed torch's global RNG so the dummy tensors are the same on every
# Restart & Run All.
torch.manual_seed(0)

dummy_data = [torch.rand(num_dummy_samples, *dummy_image_shape)
              for _ in range(num_dummy_inputs)]
dummy_labels = torch.rand(num_dummy_samples, *dummy_image_shape)

# Preparing dataloader
num_input_types = len(dummy_data)
dataset = stacker_dataset(dummy_data, dummy_labels)
# Pinned memory only speeds up host-to-GPU copies, so ask for it only when a
# CUDA device is actually present.
dataloader = DataLoader(dataset, batch_size=batch_size,
                        pin_memory=torch.cuda.is_available(), num_workers=0)

In [7]:
stacker_model = stacker(num_input_types=num_input_types)
# Use the GPU when one is available and fall back to the CPU otherwise, rather
# than crashing on machines without CUDA. ``to`` returns the module, so the
# cell still displays the model summary.
stacker_model.to("cuda" if torch.cuda.is_available() else "cpu")

stacker(
  (net): Sequential(
    (0): Sequential(
      (0): Conv2d(5, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (4): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
   

In [8]:
optimizer = torch.optim.Adam(lr=learning_rate, params=stacker_model.parameters())
mse_loss_function = nn.MSELoss()

# Send each batch to whichever device the model's parameters live on instead of
# hard-coding CUDA, so the loop follows the model wherever it was placed.
device = next(stacker_model.parameters()).device

# Make sure batchnorm is in training mode even if the model was switched to
# eval() before this cell was (re-)run.
stacker_model.train()

for epoch in range(1, num_epochs + 1):
    # Sum of the per-batch RMSE values over the epoch
    cumulated_loss = 0.0

    for model_input, labels in dataloader:
        model_input = model_input.to(device)
        labels = labels.to(device)

        model_output = stacker_model(model_input)
        # Root-mean-square error between prediction and target
        loss = torch.sqrt(mse_loss_function(model_output, labels))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        cumulated_loss += loss.item()

    if epoch % summary_epoch_interval == 0:
        print("Epoch %d, Total loss %0.6f" % (epoch, cumulated_loss))

Epoch 10, Total loss 7.218555
Epoch 20, Total loss 7.212712
