<img src=https://brand.uark.edu/_resources/images/UA_Logo_Horizontal.jpg width="400" height="96">

### _Biomedical Image Analysis & Artificial Intelligence_

# Notebook 3.3 Convolutional Neural Networks
---
##### The purpose of this notebook is to motivate and discuss the use of convolutional neural networks (CNNs) for biomedical applications.



### Required packages
---
##### **_NOTE: This notebook runs faster with GPU hardware, but a GPU is not required. Please refer to notebook 2.4_ParallelProcessing if you need a refresher on how to enable one._**
##### **_Run this code chunk first. If you encounter an error when trying to run code chunks in this notebook, then first try re-running this chunk._**


In [None]:
# Import all of the necessary packages
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from tqdm.notebook import tqdm
from IPython.display import clear_output
from collections import OrderedDict

In [None]:
# Load the MNIST digits dataset
import os

# Check to see if the training file exists...
filename = "/content/MNIST/processed/training.pt"
if not os.path.isfile(filename):
  # ...and if it does not, then download it,...
  !gdown --id 1jNqLo0LBYrPcGzaDBBA60sqGlclS_nX_

  # ...check to see if the destination path exists,...
  if not os.path.isdir("/content/MNIST/processed/"):
    os.makedirs("/content/MNIST/processed/")

  # ...and move the file.
  os.rename("/content/training.pt",filename)

# Check to see if the test file exists...
filename = "/content/MNIST/processed/test.pt"
if not os.path.isfile(filename):
  # ...and if it does not, then download it,...
  !gdown --id 178t2_ul2VvnAHuClPhlE27qHkpE5I8S1

  # ...and move the file.
  os.rename("/content/test.pt",filename)

# Create the dataset
mnist_data = torchvision.datasets.MNIST(root='/content/', transform=T.ToTensor())
clear_output()

# Convolutional Neural Networks
---
##### In the previous notebook, we saw that a deep learning MLP model can be trained to predict a number from an image of a handwritten digit. This can be very useful for many applications, but an MLP has an important weakness: each of its weights is tied to a fixed pixel position, so it handles _spatial features_ poorly.
##### At the end of the last notebook, we saw that the weights of a trained MLP tended to be much larger in the center of the input image, where the digits are centered, and we asked these two questions about images in which the digit is not centered:
1. Do you think we would get different results?
2. How do you think we can get around this issue?

##### To answer the first question, we would expect to get mostly incorrect predictions for the digit in the image, since the network was not trained on images where the digit lies near the edge of the input image.
##### One workaround for this issue is to _augment_ the dataset by translating (or moving around), rotating, or even scaling the digits in the images. However, there are infinitely many combinations of these transformations, so an MLP trained this way will still not always be correct.
##### Another thing we can do, however, is _scan_ the entire image and _extract features_ that are characteristic of each digit. As we learned in the previous notebooks, this can be accomplished for images through _convolution_. Using convolution results in a more robust network that can handle many more cases where handwritten digits or other features are located away from the center of the image. These features can be highlighted using many small _kernels_ (or _filters_) whose values are initially chosen at random and are then updated through deep learning. Additionally, features can be _pooled_ together to decide which features matter most, which is done in a _pooling layer_. This kind of feature extraction network is called a _convolutional neural network (CNN)_.
##### The code chunk below shows an example of one convolutional _layer_ that can be used in a CNN.

In [None]:
# Define some parameters for a convolutional layer
#@title Parameters for the convolutional layer. {run:"auto"}
number_of_filters = 9 #@param {type:"slider", min:1, max:20, step:1}
filter_size = 5 #@param {type:"slider", min:3, max:11, step:2}

# Get the input image
input_image = mnist_data.data[0,:,:].unsqueeze(0).unsqueeze(0).float()

# Define the single convolutional layer
conv_layer = nn.Conv2d(in_channels=1,out_channels=number_of_filters, kernel_size=filter_size)
conv_layer.requires_grad_(False)
weights = conv_layer.weight

# Show the weights (filters)
fig = plt.figure(figsize=(10,10))
ncols = 5
if weights.shape[0] % ncols == 0: 
  nrows = weights.shape[0] // ncols
else: 
  nrows = (weights.shape[0] // ncols) + 1

count = 0
dim_size = weights.shape[2]
full_image = np.zeros((dim_size*nrows,dim_size*ncols))
for i in range(nrows):
  plt.plot([-0.5,dim_size*ncols],[i*filter_size-0.5,i*filter_size-0.5],'r-')
  for j in range(ncols):
    plt.plot([j*filter_size-0.5,j*filter_size-0.5],[-0.5,dim_size*nrows],'r-')
    if count < weights.shape[0]:
      full_image[i*dim_size:i*dim_size+dim_size,j*dim_size:j*dim_size+dim_size] = weights[count,0,:,:].cpu()
      count+=1
plt.plot([-0.5,dim_size*ncols],[nrows*filter_size-0.5,nrows*filter_size-0.5],'r-')
plt.plot([ncols*filter_size-0.5,ncols*filter_size-0.5],[-0.5,dim_size*nrows],'r-')
plt.imshow(full_image)
plt.gca().set_title('Layer 1 filters',fontsize='x-large');
plt.gca().set_axis_off()
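
##### To see why scanning the image helps with off-center digits, the minimal sketch below passes a stand-in image and a shifted copy of it through one randomly initialized convolutional layer. The synthetic image and the 6-pixel shift are assumptions made for illustration; the layer settings mirror the default slider values above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One convolutional layer: 9 filters of size 5x5, no padding
conv = nn.Conv2d(in_channels=1, out_channels=9, kernel_size=5)

# Stand-in image with a small bright patch (batch=1, channel=1, 28x28)
image = torch.zeros(1, 1, 28, 28)
image[0, 0, 8:12, 8:12] = 1.0

# The same patch moved 6 pixels down and 6 pixels to the right
shifted = torch.roll(image, shifts=(6, 6), dims=(2, 3))

with torch.no_grad():
  features = conv(image)
  features_shifted = conv(shifted)

# Each filter yields one feature map of size 28 - 5 + 1 = 24
print(features.shape)

# Moving the input moves the feature maps by the same amount
same = torch.allclose(features[:, :, :18, :18],
                      features_shifted[:, :, 6:, 6:], atol=1e-5)
print(same)
```

##### Because the same filters are applied at every position, a feature is detected wherever it appears; only its location in the feature map changes.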

### Fully connected layers
---
##### Although convolutional layers are very powerful at extracting features from an image, they cannot by themselves predict a class the way an MLP can. To complete a CNN that predicts a digit based on the extracted features, we can add an activation function after each convolutional layer and attach an MLP to the end of the feature extraction portion of the network.
##### Similar to the previous notebook, the model is first initialized with random weights, so its initial output shows that the model is not confident in guessing the digit.

In [None]:
# Define the network architecture
filter_size = 3
number_of_filters = 32

test_model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(in_channels=1,out_channels=number_of_filters, kernel_size=filter_size)),
    ('relu1',nn.ReLU()),
    ('mp1',nn.MaxPool2d(kernel_size=2)),
    ('flat',nn.Flatten(start_dim=1,end_dim=-1)),
    ('fc1',nn.Linear(5408,100)),
    ('relu2',nn.ReLU()),
    ('fc2',nn.Linear(100,10)),
]))
clear_output()

# Get the first image in the MNIST dataset
input_image = mnist_data.data[0,:,:].unsqueeze(0).unsqueeze(0).float() / 255.

# Compute the forward pass of the image
output = test_model.forward(input_image)

# Get the probabilities by passing the output through a softmax
prob = torch.softmax(output,1)*100

# Show the image
fig = plt.figure()
img = plt.imshow(input_image.squeeze(0).squeeze(0))
ax = plt.gca()
ax.set_title('Input image',fontsize='x-large');
plt.show()

# Generate a report
print('\nNumber - Probability [%]\n')
for i in range(10):
  print(str(i) + '   -    ' + str(prob[0,i].detach().numpy()) + '\n')
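
##### The input size of `5408` for the first fully connected layer follows from the shapes of the layers before it. The sketch below works through that arithmetic and cross-checks it with a dummy image; `conv_output_size` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel_size, stride=1, padding=0):
  # Output height/width of a convolution or pooling window (hypothetical helper)
  return (size + 2 * padding - kernel_size) // stride + 1

# 28x28 input -> 3x3 convolution (stride 1, no padding) -> 26x26
after_conv = conv_output_size(28, kernel_size=3)

# 26x26 -> 2x2 max pooling (stride defaults to the kernel size) -> 13x13
after_pool = conv_output_size(after_conv, kernel_size=2, stride=2)

# 32 feature maps of 13x13 pixels are flattened into one vector
flat_features = 32 * after_pool * after_pool

# Cross-check by passing a dummy image through the same layers
feature_extractor = nn.Sequential(
  nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten())
with torch.no_grad():
  measured = feature_extractor(torch.zeros(1, 1, 28, 28)).shape[1]

print(after_conv, after_pool, flat_features, measured)
```

##### If the filter size, number of filters, or pooling size changes, the first argument of `fc1` has to change accordingly.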

### Training a CNN
---
##### The _training loop_ for a CNN has exactly the same steps as for an MLP:
1. Forward pass
2. Loss calculation
3. Back-propagation
4. Gradient descent

##### For a CNN, each _pixel_ of each filter in a convolutional layer plays the same role as a weight in an MLP, and can be adjusted in a similar fashion.

##### Since the CNN is trained on a single image, the expected output from this code chunk is for the CNN to have slightly more confidence in the correct digit. You can also see that the filters in the CNN have changed.

In [None]:
# Define which loss function we will be using
loss = nn.CrossEntropyLoss()

# Get the correct label for the image
input_label = mnist_data.targets[0].unsqueeze(0)

# Calculate the loss based on the output from the model
model_loss = loss(output,input_label)

# Back-propagate the loss function through the network
model_loss.backward()

# Specify the optimizer (or gradient descent algorithm)
optimizer = torch.optim.Adam(test_model.parameters())

# Gradient descent iteration
optimizer.step()

# Compute the forward pass of the image
output = test_model.forward(input_image)

# Get the probabilities by passing the output through a softmax
prob = torch.softmax(output,1)*100

# Show the image
fig = plt.figure()
img = plt.imshow(input_image.squeeze(0).squeeze(0))
ax = plt.gca()
ax.set_title('Input image',fontsize='x-large');
plt.show()

# Generate a report
print('\nNumber - Probability [%]\n')
for i in range(10):
  print(str(i) + '   -    ' + str(prob[0,i].detach().numpy()) + '\n')
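
##### The cell above reports the probabilities but does not display the filters, so one way to confirm that they changed is to compare the `conv1` weights before and after a single optimizer step. The sketch below does this on a freshly built copy of the same architecture; the random stand-in image and the label `5` are assumptions made so the example runs on its own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same layer sizes as test_model above, rebuilt so this sketch is standalone
model = nn.Sequential(
  nn.Conv2d(1, 32, kernel_size=3), nn.ReLU(), nn.MaxPool2d(2),
  nn.Flatten(), nn.Linear(5408, 100), nn.ReLU(), nn.Linear(100, 10))

# Every pixel of every filter is a trainable weight: 32 filters * 1 * 3 * 3
conv = model[0]
n_filter_weights = conv.weight.numel()
filters_before = conv.weight.detach().clone()

# Stand-in image and label (assumed values, not taken from MNIST)
image = torch.rand(1, 1, 28, 28)
label = torch.tensor([5])

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# One pass through the four steps of the training loop
optimizer.zero_grad()
output = model(image)                # 1. forward pass
loss_value = loss_fn(output, label)  # 2. loss calculation
loss_value.backward()                # 3. back-propagation
optimizer.step()                     # 4. gradient descent

# Largest change in any filter pixel after the step
max_change = (conv.weight.detach() - filters_before).abs().max().item()
print(n_filter_weights, max_change)
```

##### A nonzero `max_change` indicates that at least one filter pixel was adjusted by the single training step.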

# Making CNNs deep
---
##### In general, the more layers a model has, the _deeper_ it is said to be. For CNNs, deeper convolutional layers are able to extract _higher level_ (more descriptive) features.
##### In the code chunk below, we train a CNN on the full digits dataset. Note that the network defined in this chunk uses one convolutional layer, the same architecture as above. **_A GPU is highly recommended for training a CNN, but training should only take a few minutes without one._**

In [None]:
#@title --- Hidden code (double-click to show code) ---

# Give some context of what is going on
print('Preparing the network and dataset...')

# Define the network architecture
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(in_channels=1,out_channels=32, kernel_size=3)),
    ('relu1',nn.ReLU()),
    ('mp1',nn.MaxPool2d(kernel_size=2)),
    ('flat',nn.Flatten(start_dim=1,end_dim=-1)),
    ('fc1',nn.Linear(5408,100)),
    ('relu2',nn.ReLU()),
    ('fc2',nn.Linear(100,10)),
]))
# Move the model to the GPU if one is available
canUseGPU = torch.cuda.is_available()
if canUseGPU:
  model = model.cuda()
else:
  print('Warning: The GPU has not been enabled. Training may take longer than expected.')

# Load the MNIST digits dataset
batchsize = 10
data_loader = torch.utils.data.DataLoader(mnist_data,batch_size=batchsize,shuffle=True)

# Specify the number of epochs
number_of_epochs = 3

# Pull out a small subset of the dataset (one image of each digit, 0 through 9) that will be used for benchmarking purposes
test_images = mnist_data.data[[1,3,5,7,9,0,13,15,17,4],:,:].float().unsqueeze(1) / 255.

if canUseGPU:
  test_images = test_images.cuda()

with torch.no_grad():
  test_output = model.forward(test_images).cpu()
  # Rows are the test images (actual digits), columns are the predicted digits
  test_prob = torch.softmax(test_output,1)*100

print('...Done')

# Plot the results
fig = plt.figure(figsize=(10,5))
plt.imshow(test_prob)
ax = plt.gca()
plt.xticks([])
secax = ax.secondary_xaxis('top')
secax.set_xticks(np.arange(0,10))
secax.set_xlabel('Predicted digit',fontsize='x-large')
plt.yticks(np.arange(0,10))
ax.set_ylabel('Actual digit',fontsize='x-large')
plt.clim(0,100)
cbar = plt.colorbar()
cbar.set_label('Confidence in prediction [%]',fontsize='x-large')
plt.title('Output from epoch: 0 of ' + str(number_of_epochs),fontsize='x-large')
plt.show()

# Define which loss function we will be using
loss = nn.CrossEntropyLoss()

# Specify the optimizer (or gradient descent algorithm)
optimizer = torch.optim.Adam(model.parameters())

# Go through each epoch
for ep in range(number_of_epochs):

  # Set up a progress bar for training
  with tqdm(total=len(mnist_data), desc=f'Epoch {ep + 1}/{number_of_epochs}', unit='img') as pbar:
    
    # Go through all of the images in the dataset
    for i, data in enumerate(data_loader, 0):
      # Get the current minibatch
      inputs, labels = data

      if canUseGPU:
        inputs = inputs.cuda()
        labels = labels.cuda()

      # Clear the gradients from the previous minibatch (PyTorch accumulates them otherwise)
      optimizer.zero_grad()

      # Perform forward pass
      output = model.forward(inputs)

      # Calculate loss
      loss_value = loss(output,labels)

      # Back-propagate
      loss_value.backward()

      # Gradient descent
      optimizer.step()  

      # Update progress bar
      pbar.update(batchsize)

  # Clear the output
  clear_output(wait=True)

  with torch.no_grad():
    test_output = model.forward(test_images).cpu()
    test_prob = torch.softmax(test_output,1)*100

  # Plot the results
  fig = plt.figure(figsize=(10,5))
  plt.imshow(test_prob)
  ax = plt.gca()
  plt.xticks([])
  secax = ax.secondary_xaxis('top')
  secax.set_xticks(np.arange(0,10))
  secax.set_xlabel('Predicted digit',fontsize='x-large')
  plt.yticks(np.arange(0,10))
  ax.set_ylabel('Actual digit',fontsize='x-large')
  plt.clim(0,100)
  cbar = plt.colorbar()
  cbar.set_label('Confidence in prediction [%]',fontsize='x-large')
  plt.title('Output from epoch: '+ str(ep+1) + ' of ' + str(number_of_epochs),fontsize='x-large')
  plt.show()

  #display_filters(model.conv1.weight)

# Show each digit in the small test batch and display the predicted digit with the model's confidence
f = plt.figure(figsize=(15,15))

for i in range(10):
  f.add_subplot(2,5,i+1)
  plt.imshow(test_images[i,:,:].squeeze(0).cpu())
  plt.gca().set_title("Predicted: " + str(torch.argmax(test_prob[i,:]).detach().numpy()) + " (" + str(torch.max(test_prob[i,:]).detach().numpy()) + "%)")
f.subplots_adjust(bottom=0.55)
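
##### The network trained above has a single convolutional layer. As a minimal sketch of what a deeper variant might look like, the model below stacks three convolutional layers before the fully connected layers. The channel counts (32, 64, 64) and layer names are assumptions chosen for illustration, not the configuration used in this notebook.

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Hypothetical deeper architecture; shapes are for 28x28 MNIST images
deeper_model = nn.Sequential(OrderedDict([
  ('conv1', nn.Conv2d(1, 32, kernel_size=3)),   # 28x28 -> 26x26
  ('relu1', nn.ReLU()),
  ('mp1', nn.MaxPool2d(kernel_size=2)),         # 26x26 -> 13x13
  ('conv2', nn.Conv2d(32, 64, kernel_size=3)),  # 13x13 -> 11x11
  ('relu2', nn.ReLU()),
  ('mp2', nn.MaxPool2d(kernel_size=2)),         # 11x11 -> 5x5
  ('conv3', nn.Conv2d(64, 64, kernel_size=3)),  # 5x5 -> 3x3
  ('relu3', nn.ReLU()),
  ('flat', nn.Flatten()),                       # 64 * 3 * 3 = 576 features
  ('fc1', nn.Linear(576, 100)),
  ('relu4', nn.ReLU()),
  ('fc2', nn.Linear(100, 10)),
]))

# Check the shapes with a dummy minibatch of four blank images
batch = torch.zeros(4, 1, 28, 28)
with torch.no_grad():
  scores = deeper_model(batch)
print(scores.shape)
```

##### This model could be substituted for `model` in the training chunk above, since it accepts the same inputs and produces one score per digit.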

# Finishing up
---
##### **_If you used a GPU and are finished with the notebook, please make sure to end your session with Google Colab by selecting:_**
> ### Runtime > Manage sessions
##### **_A window will pop up and you need to locate the current notebook and select:_**
> ### TERMINATE

# Ready for the next notebook?
---
##### You can click [here](https://colab.research.google.com/github/aewoessn/biomedical-image-analysis-and-ai/blob/main/notebooks/4.1_PredictingAgeFromFaces.ipynb) to take you to the next notebook.