In [None]:
# Mount Google Drive at /content/drive so the data files stored there can be read.
# This cell depends on the google.colab package, i.e. it is meant to be run
# inside a Colab runtime; the data paths used further down start with this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install idx2numpy package for extracting data
# (it is used below to decode the gzip-compressed *-idx*-ubyte dataset files).
# NOTE(review): the version is not pinned; consider pinning it so that a
# fresh runtime installs the same release every time.
!pip install idx2numpy



In [None]:
# Import packages
import os
import json
import gzip
import torch
import torchvision
import numpy as np 
import pandas as pd

import idx2numpy
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as transforms
import torchvision.transforms as trans
import matplotlib.pyplot as plt

from collections import OrderedDict 

In [None]:
def load_one_dataset(path):
    '''
    Convenience function to load a single dataset

    Opens the gzip-compressed file at `path`, decodes it with
    idx2numpy and converts the resulting array to a torch tensor.

    Parameters
    ----------
    path : path of the gzip-compressed file to read.

    Returns
    -------
    torch.Tensor with dtype float64 holding the decoded array.
    '''
    # The context manager closes the file even when decoding raises,
    # so the handle is never leaked
    with gzip.open(path, 'rb') as f:
        data = torch.from_numpy(idx2numpy.convert_from_file(f).astype('float64'))

    return(data)


def load_all_datasets(train_imgs, train_labs, test_imgs, test_labs, batch_size):
    '''
    Build the data loaders for the training and the test split.

    Parameters
    ----------
    train_imgs, train_labs : paths of the training image / label files.
    test_imgs, test_labs : paths of the test image / label files.
    batch_size : batch size used by both loaders.

    Returns
    -------
    (train_loader, test_loader) : the training loader shuffles its
    samples, the test loader does not; both use 2 worker processes.
    '''
    def as_pairs(imgs_path, labs_path):
        # Images are cast to float32 and divided by 255.0, labels are cast to long
        images = load_one_dataset(imgs_path).type(torch.float32)/255.0
        labels = load_one_dataset(labs_path).type(torch.long)

        # One (image, label) tuple per sample
        return list(zip(images, labels))

    # Training split first, then the test split
    train_pairs = as_pairs(train_imgs, train_labs)
    test_pairs = as_pairs(test_imgs, test_labs)

    make_loader = torch.utils.data.DataLoader
    loaders = (
        make_loader(train_pairs, batch_size=batch_size, shuffle=True, num_workers=2),
        make_loader(test_pairs, batch_size=batch_size, shuffle=False, num_workers=2),
    )

    return loaders

In [None]:
def add_noise(img, quadrants):
  '''
  Randomly remove 1 or 2 quadrants
  from the input image.

  Parameters
  ----------
  img : image tensor. It is left untouched; the erasing is done on a copy.
  quadrants : dict with the keys 1, 2, 3 and 4. Each value is a sequence
      that is unpacked into transforms.erase as the arguments following
      the image.

  Returns
  -------
  A copy of `img` with the selected quadrants erased.
  '''
  # Get the number of quadrants to erase
  n_quads_to_erase = np.random.choice([1, 2])

  # Get which quadrants to erase. Sampling without replacement makes sure
  # that asking for 2 quadrants really erases 2 distinct ones (the default,
  # sampling with replacement, can return the same quadrant twice)
  quads_to_erase = np.random.choice([1, 2, 3, 4], size = n_quads_to_erase, replace = False)

  # Create a copy of the image
  noisy_img = img.clone()

  # Now erase the quadrants
  for quad in quads_to_erase:
    noisy_img = transforms.erase(noisy_img, *quadrants[quad])

  # Return statement
  return(noisy_img)

In [None]:
def plot_image(img):
  '''
  Take an image stored as a Torch
  tensor and display it in the notebook

  Parameters
  ----------
  img : torch tensor holding the image; it is converted to a NumPy
      array and drawn with a gray colour map.
  '''
  # .numpy() raises for tensors that live on the GPU or that require grad
  # (e.g. network outputs), so detach and move to the CPU first. Both calls
  # are no-ops for a plain CPU tensor.
  pixels = img.detach().cpu().numpy()

  # Display the img
  plt.imshow(pixels, cmap= 'gray')

In [None]:
def test_noise():
  '''
  Test the noise function

  Loads the datasets from the hard-coded Drive folder, takes the first
  image of the first training batch and runs it through add_noise.

  Returns
  -------
  (img, noisy_img) : the clean image and its noisy copy.
  '''
  # Just for testing out noise function
  data_dir = '/content/drive/MyDrive/data'

  # Set paths
  paths = {
        'train_imgs': os.path.join(data_dir, 'train-images-idx3-ubyte.gz'),
        'train_labs': os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'),
        'test_imgs': os.path.join(data_dir,'t10k-images-idx3-ubyte.gz'),
        'test_labs': os.path.join(data_dir,'t10k-labels-idx1-ubyte.gz')
  }

  # Load datasets (only the training loader is needed here)
  train_loader, _ = load_all_datasets(**paths, batch_size = 32)

  # Get the next batch from the train loader. The built-in next() is used
  # because the iterator's .next() method is the Python 2 protocol and is
  # not available on current DataLoader iterators.
  images, _ = next(iter(train_loader))

  # Quadrant definitions: each list is unpacked by add_noise into
  # transforms.erase after the image.
  # NOTE(review): the fill value here is 255, while train() erases with 0 and
  # the images are divided by 255.0 on load - confirm which value is intended.
  quadrants = {

      1: [0, 0, 14, 14, 255],
      2: [0, 14, 14, 14, 255],
      3: [14, 0, 14, 14, 255],
      4: [14, 14, 14, 14, 255],
  }

  # Take the first image in the batch
  img = images[0]

  # Get noisy image
  noisy_img = add_noise(img, quadrants)

  # Return statement
  return(img, noisy_img)

In [None]:
class DenoisingEncoder(nn.Module):
  '''
  Fully connected encoder/decoder network.

  The encoder is a stack of Linear + ReLU pairs. The decoder is a stack
  of Linear + ReLU pairs followed by a final Linear + Sigmoid pair.

  Parameters
  ----------
  encoder_units : list with the output width of every encoder layer.
  decoder_units : list with the output width of every hidden decoder layer.
  input_dim : width of the input fed to the encoder.
  output_dim : width of the output produced by the decoder.
  '''

  def __init__(self, encoder_units, decoder_units, input_dim, output_dim):

    super().__init__()

    # Full list of layer widths, including the input/output widths.
    # The decoder starts from the width of the last encoder layer.
    self.encoder_units = [input_dim] + encoder_units
    self.decoder_units = [encoder_units[-1]] + decoder_units + [output_dim]

    # Number of entries in each width list
    self.encoder_layers = len(self.encoder_units)
    self.decoder_layers = len(self.decoder_units)

    # Encoder: every consecutive pair of widths becomes Linear -> ReLU
    encoder_stack = OrderedDict()
    encoder_pairs = zip(self.encoder_units[:-1], self.encoder_units[1:])
    for idx, (n_in, n_out) in enumerate(encoder_pairs, start=1):
      encoder_stack['Linear{}'.format(idx)] = nn.Linear(n_in, n_out)
      encoder_stack['RELU{}'.format(idx)] = nn.ReLU(True)

    # Decoder: same pattern, except that the last Linear layer is
    # followed by a Sigmoid instead of a ReLU
    decoder_stack = OrderedDict()
    final_idx = self.decoder_layers - 1
    decoder_pairs = zip(self.decoder_units[:-1], self.decoder_units[1:])
    for idx, (n_in, n_out) in enumerate(decoder_pairs, start=1):
      decoder_stack['Linear{}'.format(idx)] = nn.Linear(n_in, n_out)
      if idx == final_idx:
        decoder_stack['Sigmoid{}'.format(idx)] = nn.Sigmoid()
      else:
        decoder_stack['RELU{}'.format(idx)] = nn.ReLU(True)

    # Register both stacks as sequential containers
    self.encoder = nn.Sequential(encoder_stack)
    self.decoder = nn.Sequential(decoder_stack)

  def forward(self,x):
    '''
    Encode the input and decode the result.
    '''
    return self.decoder(self.encoder(x))

In [None]:
def _prepare_batch(clean_images, quadrants, device):
    '''
    Turn one batch of clean images into the (noisy input, clean target)
    pair used by the network: both flattened to one row per image,
    cast to float and moved to `device`.
    '''
    # Corrupt every image of the batch independently
    noisy_images = torch.stack([add_noise(img, quadrants) for img in clean_images], dim=0)

    # Flatten every image into a single row
    noisy_flat = noisy_images.view(noisy_images.size(0), -1).type(torch.FloatTensor)
    clean_flat = clean_images.view(clean_images.size(0), -1).type(torch.FloatTensor)

    # Send both tensors to the memory of the device
    return noisy_flat.to(device), clean_flat.to(device)


def train(encoder_units, decoder_units, 
          epochs=100, batch_size=8, 
          input_dim = 784, output_dim = 784, 
          lr = 0.01, momentum= 0.09, weight_decay = 0,  
          data_dir = '/content/drive/MyDrive/data'):
    '''
    This is the main training loop

    Trains a DenoisingEncoder to reconstruct clean images from copies
    that had quadrants erased by add_noise, and evaluates it on the
    test set after every epoch.

    Parameters
    ----------
    encoder_units, decoder_units : layer widths passed to DenoisingEncoder.
    epochs : number of passes over the training data.
    batch_size : batch size of both data loaders.
    input_dim, output_dim : input / output widths passed to DenoisingEncoder.
    lr, momentum, weight_decay : settings of the SGD optimizer.
    data_dir : folder containing the four gzip-compressed dataset files.

    Returns
    -------
    dict with the keys 'train_loss' and 'val_loss', each a list holding
    one float per epoch.

    NOTE(review): the trained network itself is not returned, so it cannot
    be used after this function finishes - confirm whether that is intended.
    '''
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set paths to datasets
    paths = {
        'train_imgs': os.path.join(data_dir, 'train-images-idx3-ubyte.gz'),
        'train_labs': os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'),
        'test_imgs': os.path.join(data_dir,'t10k-images-idx3-ubyte.gz'),
        'test_labs': os.path.join(data_dir,'t10k-labels-idx1-ubyte.gz')
    }

    # Quadrant definitions: each list is unpacked by add_noise into
    # transforms.erase after the image; the last entry (0) is the fill value.
    # Presumably the images are 28x28, which matches the default input_dim
    # of 784 - TODO confirm for other datasets.
    quadrants = {
        1: [0, 0, 14, 14, 0],
        2: [0, 14, 14, 14, 0],
        3: [14, 0, 14, 14, 0],
        4: [14, 14, 14, 14, 0],
    }

    # Load datasets
    train_loader, test_loader = load_all_datasets(**paths, batch_size = batch_size)

    # Build the network and send it to device memory
    net = DenoisingEncoder(encoder_units, decoder_units, input_dim, output_dim)
    net.to(device)

    # We use the mean squared error between the prediction and the clean image
    criterion = nn.MSELoss()

    # We use mini-batch stochastic gradient descent with momentum.
    # weight_decay is forwarded from the caller (it used to be hard-coded
    # to 0, which silently ignored the argument).
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum,
                          weight_decay=weight_decay)

    # Store results here
    results = {
        'train_loss': [],
        'val_loss': [],
    }

    # Loop over the dataset multiple times
    for epoch in range(epochs):

        # Initialize the running losses
        running_loss = 0.0
        val_running_loss = 0.0

        # Iterate through the training data; the labels are not needed
        for clean_images, _ in train_loader:

            noisy_images, flat_clean_imgs = _prepare_batch(clean_images, quadrants, device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward
            pred_images = net(noisy_images)

            # Calculate loss
            loss = criterion(pred_images, flat_clean_imgs)

            # Backward
            loss.backward()

            # Optimize
            optimizer.step()

            # Add to running loss
            running_loss += loss.item()

        # No need to calculate gradients for the validation set
        with torch.no_grad():

            # Loop through the validation data
            for val_clean_images, _ in test_loader:

                val_noisy_images, val_flat_clean_imgs = _prepare_batch(val_clean_images, quadrants, device)

                # Send the data item through the network to get output
                val_pred_images = net(val_noisy_images)

                # Compute the loss and add it to the running loss
                val_loss = criterion(val_pred_images, val_flat_clean_imgs)
                val_running_loss += val_loss.item()

        # Rescale the training and validation losses: the sum of the per-batch
        # losses is multiplied by batch_size and divided by the number of batches.
        # NOTE(review): criterion already averages over the batch, so the extra
        # batch_size factor only rescales the reported number - confirm it is wanted.
        running_loss = (running_loss*batch_size)/len(train_loader)
        val_running_loss = (val_running_loss*batch_size)/len(test_loader)

        # Append to the results tracker. The built-in float is used because
        # the np.float alias has been removed from NumPy.
        results['train_loss'].append(float(running_loss))
        results['val_loss'].append(float(val_running_loss))

        # Make print message format string
        msg = "Epoch:{} | Training Loss:{} | Validation Loss: {}" "\n"

        # Print performance
        print(msg.format(epoch, running_loss, val_running_loss))

    # Print message
    print('Done training...')

    # Return statement
    return(results)

In [None]:
# Settings for this training run.
# Output width of every encoder layer; the last entry (64) is also the
# width of the input to the first decoder layer.
encoder_units = [512, 256, 128, 64]
# Output width of every hidden decoder layer; the final output layer is
# added by the network itself.
decoder_units = [128, 256, 512]
# Learning rate and momentum of the SGD optimizer
lr = 0.7
momentum = 0.99
# Batch size of the training and test loaders
batch_size=64
# Every argument not given here (epochs, weight_decay, data_dir, ...) keeps
# the default declared by train. The call is the last expression of the
# cell, so the returned results dict is displayed as the cell output.
train(encoder_units = encoder_units, 
      decoder_units = decoder_units, 
      lr = lr, 
      momentum = momentum, 
      batch_size = batch_size)

Epoch:0 | Training Loss:4.400472933549617 | Validation Loss: 2.664925663334549

Epoch:1 | Training Loss:2.3291591166941594 | Validation Loss: 2.1533975350628993

Epoch:2 | Training Loss:1.9344924315969065 | Validation Loss: 1.8280896031932465

Epoch:3 | Training Loss:1.7463820988435481 | Validation Loss: 1.6809110831303202

Epoch:4 | Training Loss:1.6578196712902613 | Validation Loss: 1.617053161760804

Epoch:5 | Training Loss:1.5753121121860008 | Validation Loss: 1.5708209572324328

Epoch:6 | Training Loss:1.5230939613222314 | Validation Loss: 1.5320665244084255

Epoch:7 | Training Loss:1.4892876296917767 | Validation Loss: 1.4726838852949202

Epoch:8 | Training Loss:1.4504721077012102 | Validation Loss: 1.446104490832918

Epoch:9 | Training Loss:1.4069776593495025 | Validation Loss: 1.4079469844793817

Epoch:10 | Training Loss:1.3930564839194324 | Validation Loss: 1.3880532448458824

Epoch:11 | Training Loss:1.3747770641404173 | Validation Loss: 1.369812119538617

Epoch:12 | Training

{'train_loss': [4.400472933549617,
  2.3291591166941594,
  1.9344924315969065,
  1.7463820988435481,
  1.6578196712902613,
  1.5753121121860008,
  1.5230939613222314,
  1.4892876296917767,
  1.4504721077012102,
  1.4069776593495025,
  1.3930564839194324,
  1.3747770641404173,
  1.336608231957279,
  1.330042612323883,
  1.3075802820577804,
  1.2880706707678877,
  1.2711993126726862,
  1.253846832175753,
  1.2471241117921719,
  1.2454916231794906,
  1.2276326776948818,
  1.216615198517659,
  1.1992682273835262,
  1.194727134602919,
  1.1999855635008578,
  1.2032718612059856,
  1.1818903061880994,
  1.1727543294048512,
  1.1667524077999059,
  1.167631498150734,
  1.1654906853048532,
  1.1543885955551285,
  1.147919415410902,
  1.137895145777192,
  1.1226476450591707,
  1.128743597566446,
  1.1261036674351073,
  1.1322430649927175,
  1.1016626389168982,
  1.091081514986339,
  1.1015383614532983,
  1.0951728832238772,
  1.087632057763366,
  1.0857154800693618,
  1.0789207442482907,
  1.0644