Implementing an Autoencoder in PyTorch
===

This is the PyTorch equivalent of my previous article on implementing an autoencoder in TensorFlow 2.0, which you can read [here](https://towardsdatascience.com/implementing-an-autoencoder-in-tensorflow-2-0-5e86126e9f7).

First, to install PyTorch, you can use the following pip command:

```
$ pip install torch torchvision
```

The `torchvision` package contains the image data sets that are ready for use in PyTorch.

More details on installation are available in [this guide](https://pytorch.org/get-started/locally/) from [pytorch.org](https://pytorch.org).

## Setup

We begin by importing our dependencies.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

import scipy

## PCASL angio model

Define the functions used by the signal model.

In [None]:
# CAPRIA RF attenuation model (originally written by Tom Okell, June 2022).

def CAPRIAAttenuation(t, t0, Alpha):
    """Return the cumulative signal attenuation caused by the imaging RF pulses.

    Usage: R = CAPRIAAttenuation(t, t0, Alpha)

    t     : numpy array of time points, assumed to be separated by the TR.
    t0    : start of imaging. Attenuation only accumulates at samples with
            t > t0 (strict comparison); elsewhere R is reset to 1.
    Alpha : flip angles in degrees, indexed like t. Alpha[i-1] is the pulse
            that attenuates the signal seen at sample i.

    R[0] is always 1. For i >= 1, R[i] = R[i-1] * cos(Alpha[i-1]) when
    t[i] > t0, otherwise R[i] = 1.
    """
    R = np.zeros(shape=t.shape)
    R[0] = 1.0

    # Carry the running attenuation forward instead of re-reading R[i-1].
    level = 1.0
    for idx in range(1, len(t)):
        if t[idx] > t0:
            # Each earlier pulse leaves only cos(flip angle) of the signal.
            level = level * np.cos(np.deg2rad(Alpha[idx - 1]))
        else:
            # Imaging has not started yet: no attenuation.
            level = 1.0
        R[idx] = level

    return R

In [None]:
# Test the attenuation function: constant 30-degree pulses every TR, with
# imaging starting at t0 = 2*TR. R stays at 1 until t exceeds t0, then
# decays by a factor cos(30 deg) per sample.
TR = 9.0 # ms
N = 20
t = np.linspace(0,N*TR,N+1)  # N+1 samples spaced TR apart, starting at 0
t0 = 2*TR
Alpha = np.ones(shape=t.shape)*30  # flip angle in degrees
R = CAPRIAAttenuation(t,t0,Alpha)

plt.plot(t,R)

In [None]:
# Regularised lower incomplete gamma function, with negative arguments
# clamped to zero first.
def togammainc(X,A):
    """Return scipy.special.gammainc(A, X) after replacing negative inputs with 0.

    X : array_like or scalar
        Upper limit(s) of the incomplete gamma integral.
    A : array_like or scalar
        Shape parameter(s).

    Negative entries of X and A are treated as 0; NaNs are left unchanged.
    The caller's arrays are not modified, and plain Python scalars are
    accepted as well as numpy arrays.
    """
    # Import the submodule explicitly: older SciPy releases do not load
    # scipy.special on a bare `import scipy`.
    from scipy import special

    # np.where builds new arrays instead of writing into the caller's data.
    X = np.where(np.asarray(X) < 0, 0.0, X)
    A = np.where(np.asarray(A) < 0, 0.0, A)

    return special.gammainc(A, X)

# Theoretical dynamic-angiography signal for a CAPRIA-style acquisition,
# assuming all the blood sees all the RF pulses (relevant for 3D acquisitions
# with the bottom edge of the FOV close to the labelling plane).

def CAPRIAAngioSigAllRFAnalytic(t,tau,T1b,Alpha,delta_t,s,p,t0):
    """Return the modelled angiographic signal at each time in t.

    The bolus is dispersed by a gamma-variate kernel with sharpness s and time
    to peak p. It is attenuated by T1 decay during transit, by the cumulative
    RF attenuation from CAPRIAAttenuation, and scaled by sin(flip angle) for
    a spoiled GRE readout.

    Parameters
    ----------
    t : array_like
        Sample times, one per RF pulse (assumed separated by TR).
    tau : float
        Labelling duration.
    T1b : float
        T1 of blood.
    Alpha : array_like
        Flip angle of each RF pulse in degrees, indexed like t.
    delta_t : float
        Arrival time from the labelling plane to the voxel.
    s : float
        Sharpness of the gamma-variate dispersion kernel (units of 1/time).
    p : float
        Time to peak of the gamma-variate dispersion kernel.
    t0 : float
        Time at which imaging starts; passed to CAPRIAAttenuation.

    All times must use one consistent unit (this notebook uses seconds).

    Returns
    -------
    ndarray
        Signal S = SF * R * G * E, broadcast over t.
    """
    t = np.array(t)
    Alpha = np.asarray(Alpha)
    # Gamma-variate shape parameter, kept as a 0-d array for togammainc.
    a = np.array(1+p*s)

    # Cumulative attenuation due to earlier RF pulses
    R = CAPRIAAttenuation(t,t0,Alpha)

    # Rate combining the dispersion kernel with T1 decay of blood
    sprime = s + 1.0/T1b

    # Overall scaling factor (T1 decay over the arrival time times kernel normalisation)
    SF = 2 * np.exp(-delta_t/T1b) * (s/sprime)**a

    # Difference of incomplete gamma integrals: bolus arrival minus bolus end (tau later)
    G = togammainc(sprime*(t-delta_t),a) - togammainc(sprime*(t-delta_t-tau),a)

    # Excitation scaling for each pulse
    E = np.sin(np.deg2rad(Alpha))

    return SF * R * G * E

In [None]:
# Test Angio signal for one set of physiological parameters (all times in s)
TR = 9.0e-3 # s
T = 2016.0e-3 # end of the simulated window (s)
N = int(np.round(T/TR))
t = np.linspace(TR,T,N)
Alpha = np.ones(shape=t.shape)*6  # constant 6-degree flip angle
tau = 1.8  # labelling duration (s)
t0 = tau  # imaging starts at the end of labelling
T1b = 1.65  # T1 of blood (s)
delta_t = 0.7  # arrival time (s)
s = 5  # dispersion kernel sharpness (1/s)
p = 100e-3  # dispersion kernel time to peak (s)

S = CAPRIAAngioSigAllRFAnalytic(t,tau,T1b,Alpha,delta_t,s,p,t0)

plt.plot(t,S)

## Define a flip angle schedule function


In [None]:
# Calculate flip angle schedules for CAPRIA acquisitions
#
# Tom Okell, June 2022

def CalcCAPRIAFAs(FAMode,FAParams,t,t0):
    """Return the flip angle (degrees) of the RF pulse at each time in t.

    Usage:
        Alpha = CalcCAPRIAFAs(FAMode, FAParams, t, t0)

    Parameters
    ----------
    FAMode : str
        'CFA', 'Quadratic' or 'Maintain' (case-insensitive).
    FAParams : scalar or sequence of float
        CFA:       FAParams[0] (or a bare scalar) is the constant flip angle.
        Quadratic: the flip angle varies quadratically from FAParams[0] at the
                   first pulse to FAParams[1] at the last pulse.
        Maintain:  intended to use a backwards recursive formula so that the
                   magnetisation lost in the previous TR is counteracted by a
                   higher flip angle in the next TR, with FAParams[0] the
                   final flip angle. Not implemented yet.
    t : array_like
        Times to simulate (assumed separated by TR).
    t0 : float
        Time at which imaging commences; pulses are played at every t >= t0.

    Returns
    -------
    Alpha : ndarray, same shape as t
        Zero before t0, the scheduled flip angle from t0 onwards.

    Raises
    ------
    NotImplementedError
        For FAMode 'Maintain'.
    ValueError
        For an unrecognised FAMode.
    (Both are subclasses of Exception.)
    """
    t = np.asarray(t)

    # Initialise
    Alpha = np.zeros(t.shape)
    Idx = (t >= t0)
    N = int(np.count_nonzero(Idx))  # Number of pulses played out
    mode = FAMode.upper()

    # CFA (FAParams = FA, scalar or 1-element sequence)
    if mode == 'CFA':
        Alpha[Idx] = np.atleast_1d(FAParams)[0]

    # VFA quadratic (FAParams = [FAMin FAMax])
    elif mode == 'QUADRATIC':
        params = np.atleast_1d(FAParams)
        # Fraction of the way through the readout: 0 at the first pulse, 1 at
        # the last. np.arange (not range) so the division is element-wise; a
        # single pulse has no ramp and just gets FAParams[0].
        frac = np.arange(N) / (N - 1) if N > 1 else np.zeros(N)
        Alpha[Idx] = params[0] + (params[1] - params[0]) * frac**2

    # VFA Maintain (FAParams = FAMax)
    elif mode == 'MAINTAIN':
        raise NotImplementedError('Maintain not yet implemented')

    # Unknown
    else:
        raise ValueError('Unknown FAMode!')

    return Alpha

# Test: plot a quadratic flip-angle ramp from 2 to 9 degrees over the readout
FAMode = 'Quadratic'
FAParams = [2,9]
TR = 9.0e-3 # s
T = 2016.0e-3 # end of the simulated window (s)
t0 = tau  # relies on tau defined in an earlier cell
N = int(np.round((T-t0)/TR))  # number of TRs between t0 and T
t = np.linspace(t0,T,N)

Alpha = CalcCAPRIAFAs(FAMode,FAParams,t,t0)

plt.plot(t,Alpha)

## Dataset

Run the model for a range of physiological parameters.

In [None]:
# Set sequence parameters (all times in seconds)
TR = 9.0e-3 # s
T = 2016.0e-3 # readout duration after t0 (s)
tau = 1.8 # labelling duration (s)
T1b = 1.65 # T1 of blood (s)
t0 = tau # imaging starts at the end of labelling
N = int(np.round(T/TR))
t = np.linspace(t0,t0+T,N)
FAMode = 'Quadratic'
FAParams = [2,9]
Alpha = CalcCAPRIAFAs(FAMode,FAParams,t,t0)

# Physio params
delta_ts = np.linspace(0.1,1.8,30)
ss = np.linspace(1,100,30)
ps = np.linspace(1e-3,500e-3,30)

# Initialise output: one row per (delta_t, s, p) combination, one column per time point
S = np.zeros((len(delta_ts)*len(ss)*len(ps),len(t)))

# Loop through the parameters
ii = 0
for delta_t in delta_ts:
    for s in ss:
        for p in ps:
            S[ii,:] = CAPRIAAngioSigAllRFAnalytic(t,tau,T1b,Alpha,delta_t,s,p,t0)
            ii = ii + 1

# Plotting all ~27,000 curves one at a time is very slow; plot an evenly
# spaced subset in a single call instead.
n_plot = min(100, S.shape[0])
plot_idx = np.linspace(0, S.shape[0] - 1, n_plot).astype(int)
plt.plot(t, S[plot_idx].T)

# Define the dataset class

In [None]:
# Adapted from the DL workshop code
from torch.utils.data import Dataset
class numpy_dataset(Dataset):
    """Map-style dataset over paired NumPy arrays.

    Both arrays are converted once to float32 torch tensors. Indexing returns
    an (input, target) pair; the optional `transform` (where augmentations
    would go) is applied to the input only.
    """

    def __init__(self, data, target, transform=None):
        self.data, self.target = (torch.from_numpy(arr).float() for arr in (data, target))
        self.transform = transform

    def __getitem__(self, index):
        sample, label = self.data[index], self.target[index]
        return (self.transform(sample) if self.transform else sample), label

    def __len__(self):
        return len(self.data)

# Define the dataset and data loader

In [None]:
batch_size = 1000

# No transform is needed: numpy_dataset already yields float32 tensors, and
# torchvision's ToTensor is intended for images rather than 1-D signals.
# The autoencoder reconstructs its input, so input and target are both S.
train_dataset = numpy_dataset(S,S)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

## Autoencoder

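Define an autoencoder class built from fully connected layers. The encoder maps the input through an intermediate layer and two latent-sized layers to a low-dimensional code, and the decoder mirrors this structure back to the input size.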

In [None]:
class AE(nn.Module):
    """Fully connected autoencoder for 1-D signals.

    Required keyword arguments:
        input_shape           -- length of each input vector
        intermediate_features -- width of the layer adjacent to input/output
        latent_features       -- width of the latent code

    Encoder: input -> intermediate (ReLU) -> latent (ReLU) -> latent code (sigmoid)
    Decoder: latent (ReLU) -> intermediate (ReLU) -> input size (sigmoid)
    The constructor prints the size of each layer as it is created.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # ---- Encoder ----
        n_in, n_mid = kwargs["input_shape"], kwargs["intermediate_features"]
        print('Creating encoder_hidden_layer with in_features=',n_in,',out_features=',n_mid)
        self.encoder_hidden_layer = nn.Linear(n_in, n_mid)

        n_lat = kwargs["latent_features"]
        print('Creating encoder_hidden_layer2 with in_features=',n_mid,',out_features=',n_lat)
        self.encoder_hidden_layer2 = nn.Linear(n_mid, n_lat)

        print('Creating encoder_output_layer with in_features=',n_lat,', out_features=',n_lat)
        self.encoder_output_layer = nn.Linear(n_lat, n_lat)

        # ---- Decoder (mirror image of the encoder) ----
        print('Creating decoder_hidden_layer with in_features=',n_lat,', out_features=',n_lat)
        self.decoder_hidden_layer = nn.Linear(n_lat, n_lat)

        print('Creating decoder_hidden_layer2 with in_features=',n_lat,', out_features=',n_mid)
        self.decoder_hidden_layer2 = nn.Linear(n_lat, n_mid)

        print('Creating decoder_output_layer with in_features=',n_mid,',out_features=',n_in)
        self.decoder_output_layer = nn.Linear(n_mid, n_in)

    def forward(self, features):
        # Compress to the latent code, squashed into (0, 1)
        hidden = torch.relu(self.encoder_hidden_layer(features))
        hidden = torch.relu(self.encoder_hidden_layer2(hidden))
        code = torch.sigmoid(self.encoder_output_layer(hidden))

        # Expand back to the input size, also squashed into (0, 1)
        hidden = torch.relu(self.decoder_hidden_layer(code))
        hidden = torch.relu(self.decoder_hidden_layer2(hidden))
        return torch.sigmoid(self.decoder_output_layer(hidden))

## Set up training parameters

We set the number of training epochs and the learning rate (the batch size was already set when the data loader was created), then fix the random seed and other configurations for reproducibility.

In [None]:
# Training configuration: number of passes over train_loader, and the
# learning rate passed to the optimizer.
epochs = 1000
learning_rate = 5e-3

In [None]:
# Fix the torch RNG seed so weight initialisation (and, presumably, the
# DataLoader shuffling, which draws from the default torch generator) is
# repeatable. Also ask cuDNN for deterministic, non-benchmarked algorithms.
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Before using our defined autoencoder class, we have the following things to do:
    1. We configure which device we want to run on.
    2. We instantiate an `AE` object.
    3. We define our optimizer.
    4. We define our reconstruction loss.

In [None]:
#  use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create a model from `AE` autoencoder class
# load it to the specified device, either gpu or cpu
# Layer sizes: the input length is the number of time points per simulated
# signal (S.shape[1]); the intermediate width is roughly midway between that
# and 10; the latent code has 4 features.
model = AE(input_shape=S.shape[1],intermediate_features=round((S.shape[1]+10)/2),latent_features=4).to(device)

# create an optimizer object
# Adam optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# mean-squared error loss
criterion = nn.MSELoss()

We train our autoencoder for our specified number of epochs.

In [None]:
# Train for `epochs` passes over train_loader. The autoencoder's target is its
# own input, so the target tensor returned by the dataset is ignored.
for epoch in range(epochs):
    loss = 0
    for batch_features, _ in train_loader:
        # reshape mini-batch data to [N, S.shape[1]] matrix
        # load it to the active device
        batch_features = batch_features.view(-1, S.shape[1]).to(device)
        
        # reset the gradients back to zero
        # PyTorch accumulates gradients on subsequent backward passes
        optimizer.zero_grad()
        
        # compute reconstructions
        outputs = model(batch_features)
        
        # compute training reconstruction loss
        train_loss = criterion(outputs, batch_features)
        
        # compute accumulated gradients
        train_loss.backward()
        
        # perform parameter update based on current gradients
        optimizer.step()
        
        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()
    
    # compute the epoch training loss (unweighted mean of the per-batch losses)
    loss = loss / len(train_loader)
    
    # display the epoch training loss (one line per epoch)
    print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))

Let's extract some examples to reconstruct using our trained autoencoder. Note that these come from the same simulated signals used for training; there is no held-out test set.

In [None]:
# Same simulated signals as the training set: these are not held-out examples.
test_dataset = numpy_dataset(S,S)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=10, shuffle=True
)

with torch.no_grad():
    # Take one shuffled batch; each item is an (input, target) pair.
    batch_features, _ = next(iter(test_loader))
    # Keep the inputs on the CPU for plotting.
    test_examples = batch_features.view(-1, S.shape[1])
    # Run on the model's device (fails otherwise when the model is on CUDA),
    # then bring the reconstructions back to the CPU so .numpy() works later.
    reconstruction = model(test_examples.to(device)).cpu()

## Visualize Results

Let's try to reconstruct some examples using our trained autoencoder.

In [None]:
with torch.no_grad():
    # One column per example: original on the top row, reconstruction below.
    number = min(10, len(test_examples))
    plt.figure(figsize=(20, 4))
    for index in range(number):
        # display original
        plt.subplot(2, number, index + 1)
        plt.plot(t, test_examples[index].cpu().numpy())

        # display reconstruction (.cpu() so this also works if it lives on a GPU)
        plt.subplot(2, number, index + 1 + number)
        plt.plot(t, reconstruction[index].cpu().numpy())

    plt.show()