In [None]:
import os

os.environ['KMP_DUPLICATE_LIB_OK']='True' # only potentially necessary if you have MacBook with M1/M2 chip

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import copy
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import Dataset

In [None]:
# noisy MNIST dataset
class NoisyMNIST(Dataset):
    '''
    MNIST images paired with a noisy copy, for training a de-noiser.

    The additive Gaussian noise is sampled once in the constructor, so a given
    index always yields the same noisy image for the lifetime of the dataset.

    train: (bool) True for training data, False for test/validation data
    noise_strength: (float) standard deviation of the additive Gaussian noise
                    (the noise variance is noise_strength**2), keep between [0, 0.5]
    seed: (int or None) seed for the noise; None (default) draws from torch's
          global random number generator
    '''
    def __init__(self, train, noise_strength, seed=None):
        mnist_data = torchvision.datasets.MNIST('./', train=train, download=True)
        self.images = (mnist_data.data/255).float() # convert to [0, 1] range and make float
        self.images = self.images.unsqueeze(1) # make shape (N_images, 1, H , W)
        self.noise_strength = noise_strength
        self.noise = self._sample_noise(self.images.shape, self.noise_strength, seed)

    @staticmethod
    def _sample_noise(shape, noise_strength, seed=None):
        '''Return zero-mean Gaussian noise of the given shape, scaled by noise_strength.'''
        generator = None
        if seed is not None:
            # private generator: reproducible noise without touching the global RNG state
            generator = torch.Generator().manual_seed(seed)
        return torch.randn(shape, generator=generator)*noise_strength

    def __len__(self):
        '''Number of images in the dataset.'''
        return len(self.images)

    def __getitem__(self, idx):
        '''
        Return (noisy_image, clean_image) for one index.

        noisy_image keeps its channel dimension, shape (1, H, W), and is clipped to [0, 1];
        clean_image has the channel dimension removed, shape (H, W).
        '''
        noisy_image = torch.clip(self.images[idx]+self.noise[idx], min=0, max=1)
        return noisy_image, self.images[idx].squeeze(0)

# noise level shared by both splits
noise_strength = 0.3
# training split
noisy_mnist_training = NoisyMNIST(train=True, noise_strength=noise_strength)
# validation split (train=False)
noisy_mnist_validation = NoisyMNIST(train=False, noise_strength=noise_strength)

In [None]:
# preview four random training examples: noisy input on the top row, clean target below it
fig, axes = plt.subplots(2, 4, figsize=(12,6), squeeze=False)
for i in range(4):
    idx = np.random.choice(np.arange(len(noisy_mnist_training)))
    noisy_image, true_image = noisy_mnist_training[idx]
    for row, image in enumerate((noisy_image, true_image)):
        ax = axes[row, i]
        ax.imshow(image.squeeze(0).numpy(), 'gray')
        ax.axis(False)

In [None]:
# model code
class Autoencoder(nn.Module):
    '''
    Convolutional autoencoder that maps a (batch_size, channel_widths[0], H, W)
    input to a single-channel (batch_size, H, W) prediction.

    channel_widths: (sequence of int) channels per stage, at least length-2.
                    channel_widths[0] is the input channel count; each following
                    entry adds one stride-2 convolution on the encoder side.
    image_height, image_width: (int) spatial size the output is resized to
    nonlinearity: (nn.Module) activation template; an independent copy is placed
                  after every convolution in the backbone
    '''
    # constructor
    def __init__(self, channel_widths, image_height, image_width, nonlinearity=nn.ReLU()):
        super().__init__()
        assert len(channel_widths) >= 2, "channel_widths should be at least length-2"
        self.image_height = image_height
        self.image_width = image_width
        layers = []
        # make encoder side: every stage is a stride-2 convolution (downsampling) + non-linearity
        for i in range(len(channel_widths)-1):
            layers.append(nn.Conv2d(channel_widths[i], channel_widths[i+1],
                                    kernel_size=5, padding=2, stride=2, bias=False)) # conv layer
            # deep copy so that a stateful activation (e.g. nn.PReLU) is not shared between
            # layers, nor with other models built from the same (default) instance
            layers.append(copy.deepcopy(nonlinearity)) # non-linearity

        # make decoder side: walks the widths backwards and stops at channel_widths[1]
        for i in range(len(channel_widths)-1, 1, -1):
            layers.append(nn.Conv2d(channel_widths[i], channel_widths[i-1],
                                    kernel_size=5, padding=2, stride=1, bias=False)) # conv layer
            layers.append(nn.Upsample(scale_factor=2, mode='nearest')) # upscale image by factor of 2
            layers.append(copy.deepcopy(nonlinearity))
        # compose the encoder and decoder back-to-back
        self.backbone = nn.Sequential(*layers)
        # make sure output matches size of input image
        self.final_upsample = nn.Upsample(size=(self.image_height, self.image_width), mode='nearest')
        # final convolution layer to give 1 output channel like the input image
        # this layer also serves to provide the final estimate of the pixel values
        self.final_conv = nn.Conv2d(channel_widths[1], 1,
                                    kernel_size=5, padding=2, stride=1, bias=False)

    # forward pass
    def forward(self, x):
        '''Return pixel predictions of shape (batch_size, H, W) for a batched input x.'''
        features = self.backbone(x) # get feature maps
        resized_features = self.final_upsample(features) #(batch_size, N_feature_maps, H, W)
        pixel_predictions = self.final_conv(resized_features) # (batch_size, 1, H, W)
        pixel_predictions = pixel_predictions.squeeze(1) # batch_size, H, W
        return pixel_predictions

In [None]:
# example: build an autoencoder for 28x28 images and print its layers
channel_widths = [1, 8, 16] # first entry must be 1
image_height = 28
image_width = 28
model = Autoencoder(channel_widths=channel_widths,
                    image_height=image_height,
                    image_width=image_width)
print(model)

In [None]:
def train(model, criterion, training_data, validation_data,
          training_indices, validation_indices, config, verbose=False):
    '''
    Train `model` in place with SGD and record the loss after every epoch.

    model: (nn.Module) model to optimize; it is updated in place
    criterion: loss called as criterion(predictions, targets)
    training_data, validation_data: datasets yielding (input, target) pairs
    training_indices, validation_indices: indices of the examples to use from each dataset
    config: (dict) must contain 'lr', 'n_epochs' and 'batch_size';
            may contain 'weight_decay' (defaults to 1e-4)
    verbose: (bool) print the losses roughly 20 times over the course of training

    Returns (train_losses, val_losses, model): two numpy arrays with one entry per
    epoch (loss averaged over that epoch's batches) and the same model object.

    Raises ValueError if either index set yields no batches.
    '''
    # unpack configuration parameters
    lr = config['lr'] # learning rate
    n_epochs = config['n_epochs'] # number of passes (epochs) through the training data
    batch_size = config['batch_size']
    weight_decay = config.get('weight_decay', 1e-4) # L2 penalty applied by the optimizer

    # set up optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

    # set up dataloaders
    train_sampler = torch.utils.data.SubsetRandomSampler(training_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(validation_indices)
    trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size, sampler=train_sampler)
    valloader = torch.utils.data.DataLoader(validation_data, batch_size=batch_size, sampler=val_sampler)
    if len(trainloader) == 0 or len(valloader) == 0:
        # the per-epoch averages below divide by the number of batches
        raise ValueError('training_indices and validation_indices must both be non-empty')

    # log about 20 times in total; at least every epoch so short runs never take a modulo by zero
    log_every = max(1, n_epochs // 20)

    # training loop
    train_losses = []
    val_losses = []
    for n in range(n_epochs):
        # set model to training mode (validate() switches it to eval mode every epoch)
        model.train()
        epoch_loss = 0
        for images, targets in trainloader:
            optimizer.zero_grad() # zero out gradients
            predictions = model(images)
            loss = criterion(predictions, targets)
            loss.backward() # backpropagate to compute gradients
            optimizer.step() # update parameters using stochastic gradient descent
            # update epoch statistics
            epoch_loss += loss.item() # batch loss

        # average the summed batch losses
        epoch_loss /= len(trainloader)
        # validation (validate() returns a sum over batches)
        val_loss = validate(model, valloader, criterion)
        val_loss /= len(valloader)

        # log epoch information
        train_losses.append(epoch_loss)
        val_losses.append(val_loss)
        if verbose and (n+1) % log_every == 0:
            print('Epoch {}/{}: (Train) Loss = {:.4e}, (Val) Loss = {:.4e}'.format(
                   n+1,
                   n_epochs,
                   epoch_loss,
                   val_loss))

    return (np.array(train_losses),
            np.array(val_losses),
            model)

def validate(model, dataloader, criterion):
    '''
    Return the SUM of the per-batch losses of `model` over `dataloader`.

    The result is not averaged; divide by len(dataloader) for the mean batch loss.
    The model is switched to eval mode and left in that mode.
    '''
    val_loss = 0
    # set model to eval mode
    model.eval()
    # don't compute gradients since we are not updating the model, saves a lot of computation
    with torch.no_grad():
        for images, targets in dataloader:
            predictions = model(images)
            loss = criterion(predictions, targets)
            val_loss += loss.item()
    return val_loss

In [None]:
# reproducibility: seed every random number generator this cell relies on
random_seed = 1 # random seed for reproducibility
np.random.seed(random_seed) # index sampling below
torch.manual_seed(random_seed) # additive noise and model weight initialization

# data
# make training set
noise_strength = 0.2
noisy_mnist_training = NoisyMNIST(True, noise_strength)
# make validation set
noisy_mnist_validation = NoisyMNIST(False, noise_strength)
# you can try playing around with more or less training data
N_training_examples = 500
N_validation_examples = 500
# replace=False: sample distinct examples (the default samples with replacement,
# which silently puts duplicates into the subsets)
training_indices = np.random.choice(np.arange(len(noisy_mnist_training)),
                                    size=N_training_examples, replace=False)
validation_indices = np.random.choice(np.arange(len(noisy_mnist_validation)),
                                      size=N_validation_examples, replace=False)

# configuration parameters, you can play around with these
config = {'lr': 1e-1,
          'n_epochs': 300,
          'batch_size': 100}

loss_function = 'mse' # mse (mean-square error) or mae (mean-absolute error)
if loss_function == 'mse':
    criterion = nn.MSELoss()
elif loss_function == 'mae':
    criterion = nn.L1Loss()
else:
    # fail loudly on a typo instead of silently training with the L1 loss
    raise ValueError("loss_function must be 'mse' or 'mae', got {!r}".format(loss_function))

# model
channel_widths = [1, 16, 16] # must start with a 1 and be at least length-2
image_height, image_width = 28, 28
model = Autoencoder(channel_widths, image_height, image_width)

# train
verbose = True # print metrics during training, False for no printing
train_losses, val_losses, trained_model = train(model,
                                                criterion,
                                                noisy_mnist_training,
                                                noisy_mnist_validation,
                                                training_indices,
                                                validation_indices,
                                                config,
                                                verbose=verbose)

In [None]:
# training and validation loss per epoch, side by side on a logarithmic y-axis
fig, loss_axes = plt.subplots(1, 2, figsize=(12, 6))
loss_curves = ((train_losses, 'Training loss'),
               (val_losses, 'Validation loss'))
for ax, (losses, title) in zip(loss_axes, loss_curves):
    ax.semilogy(losses, color='royalblue')
    ax.set_xlabel('Epoch')
    ax.set_title(title)
    ax.grid(True)

In [None]:
# compare noisy input, clean target and model output for a few random training images
# (one column per image; no gradients are needed for inference)
with torch.no_grad():
    N = 5
    fig, axes = plt.subplots(3, N, figsize=(12,6), squeeze=False)
    for i in range(N):
        idx = np.random.choice(np.arange(len(noisy_mnist_training)))
        noisy_image, true_image = noisy_mnist_training[idx]
        # add a batch dimension for the model, then drop it again
        denoised_image = model(noisy_image.unsqueeze(0)).squeeze(0)
        rows = (('Noisy image', noisy_image),
                ('True image', true_image),
                ('De-noised image', denoised_image))
        for row, (title, image) in enumerate(rows):
            ax = axes[row, i]
            ax.set_title(title)
            ax.imshow(image.squeeze(0).numpy(), 'gray')
            ax.axis(False)