# Segmentation
---
Segmentation is the process of using a deep learning neural network to produce a pixel-wise map that groups pixels of the same class together. In this submodule, we will use a neural network with the U-Net architecture to produce a segmentation map and demonstrate the network's ability to segment the different layers of skin.

## Why segmentation?
---
In the previous submodules, we demonstrated how a deep learning neural network can bin images into distinct groups based on features within the image. However, these network architectures only allow for a _single_ classification output and lack any _spatial information_. This limitation may not be ideal for large images that contain more than one class.

Consider a self-driving car that is powered by artificial intelligence. Self-driving cars carry a large number of cameras, sensors, etc. that inform the onboard computer about the surrounding environment. When a self-driving car acquires a snapshot of its surroundings (which happens ~30 times a second), the computer must evaluate the collected information and decide whether the vehicle should accelerate, brake, turn, etc. One important step of this process is identifying the objects in the surrounding environment, which is difficult to do with the single output of a traditional neural network.

Alternatively, when an image is taken, a trained neural network can be scanned, or _convolved_, across the entire image to eventually build up a spatial classification map of its surroundings. This type of process has previously been used in many applications, including segmenting neuronal membranes in biomedical images [<a href="#Reference1" class="intrnllnk">1</a>]. Overall, however, this approach is relatively time-consuming and cannot be completed before the next snapshot arrives for a self-driving car.
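
To make the cost of this scanning approach concrete, here is a minimal NumPy sketch. The function `toy_patch_classifier` is a hypothetical stand-in for a trained classification network (it simply thresholds the mean brightness of a patch), and `sliding_window_segment` is an illustrative helper, not code from the cited work.

```python
import numpy as np

def toy_patch_classifier(patch):
    """Hypothetical stand-in for a trained classifier: class 1 if the patch is bright, else 0."""
    return int(patch.mean() > 0.5)

def sliding_window_segment(image, patch_size=3):
    """Classify every pixel from the patch centered on it (one classifier call per pixel)."""
    pad = patch_size // 2
    padded = np.pad(image, pad, mode="edge")
    label_map = np.zeros(image.shape, dtype=int)
    n_calls = 0
    for row in range(image.shape[0]):
        for col in range(image.shape[1]):
            patch = padded[row:row + patch_size, col:col + patch_size]
            label_map[row, col] = toy_patch_classifier(patch)
            n_calls += 1
    return label_map, n_calls

# Toy image: dark left half, bright right half
image = np.zeros((8, 8))
image[:, 4:] = 1.0

label_map, n_calls = sliding_window_segment(image)
print(label_map[0])
print("Classifier calls:", n_calls)
```

Even this 8 x 8 toy image needs 64 classifier calls. A 512 x 512 image would need 262,144 calls per snapshot, which illustrates why this approach is slow.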

A more elegant approach is to generate a _fully convolutional_ neural network. As a brief refresher from the classification submodule: deep learning neural networks _encode_ information from the input image, typically with a _max pooling_ layer, which takes the maximum of an output map within a small region. This process halves the height and width of the output map while keeping the most important features. When this process is done multiple times, spatial information from the input image is encoded and used to predict a class from the image. However, once features from an image are encoded, they can also be _decoded_, and by performing the same number of encoding and decoding steps, a pixel-wise classification map can be produced. When this network architecture is drawn out, the encoding/decoding steps form a "U" shape, which is why the network is called a "U-Net" [<a href="#Reference2" class="intrnllnk">2</a>]. 
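
As a minimal sketch of the encode/decode idea, the NumPy snippet below applies a 2 x 2 max pool to a small feature map and then upsamples the result back to the original size by repeating values. The helper names `max_pool_2x2` and `upsample_2x` are illustrative and are not PyTorch functions; they only mimic what a pooling layer and a nearest-neighbor upsampling step do.

```python
import numpy as np

def max_pool_2x2(fmap):
    """Take the maximum over each non-overlapping 2x2 region (halves height and width)."""
    h, w = fmap.shape
    return fmap.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

def upsample_2x(fmap):
    """Nearest-neighbor upsampling: repeat every value twice along each axis."""
    return fmap.repeat(2, axis=0).repeat(2, axis=1)

feature_map = np.array([[1, 3, 0, 2],
                        [4, 2, 1, 1],
                        [0, 0, 5, 6],
                        [1, 2, 7, 8]])

encoded = max_pool_2x2(feature_map)   # shape (2, 2)
decoded = upsample_2x(encoded)        # shape (4, 4) again

print(encoded)
print(decoded.shape)
```

The decoded map has the original height and width, but the fine detail lost during pooling is not recovered by upsampling alone.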

## Required packages
---
For this submodule, we will be using the following packages:
- `PyTorch` (torch) and `torchvision` for neural network generation and training
- `tqdm` for progress bars
- `imageio` for data loading
- `matplotlib` and `numpy` for data visualization

Please note that after running the next code section, the following warning will appear and can be ignored:
> /opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
  
This section will also install `jupyterquiz` if it is not already installed.

In [None]:
# Import all of the necessary packages
import torch 
import torch.nn as nn
import torchvision.transforms as T
import torchvision
from tqdm import tqdm
import imageio.v3 as iio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
import os
import os.path as osp
from IPython.display import clear_output

!pip install jupyterquiz
from jupyterquiz import display_quiz

**NOTE: Due to the size of the U-Net network, training _WILL_ take a very long time without utilizing a GPU. A notebook configured with a GPU is _HIGHLY RECOMMENDED_ for this submodule!**

In preparation for this submodule, we will initialize a GPU and mount the bucket that contains the image data used in this submodule. This section of code will create a new folder named `segmentation_bucket`, and populate it with all of the image data needed for this submodule.

To get all of the image data, the `gcsfuse` command is called, which pulls a specific folder from a _bucket_, which is a specific location on the Google Cloud Platform where the data are stored. This is similar to downloading a public dataset as seen in previous submodules, but is particularly useful for larger datasets.

In [None]:
# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu");

# Prepare the folder system that will be used for mounting portions of a bucket
for folderList in ['training','validation','testing']:
    for subList in ['images','masks']:
        os.system('mkdir -p /home/jupyter/segmentation_bucket/%s/%s' % (folderList, subList))

# Mount subfolders of the main bucket
rootBucket = 'nosi-uams-alml/segmentation_data_small';
saveDir = '/home/jupyter/segmentation_bucket/';

for folderList in ['training','validation','testing']:
    for subList in ['images','masks']:
        os.system('gcsfuse --only-dir %s/%s/%s nigms-sandbox %s/%s/%s >/dev/null 2>&1' % (rootBucket,folderList,subList,saveDir,folderList,subList));

## Creating a U-Net architecture
---
The original U-Net architecture is made up of 4 encoding and decoding blocks, with each consisting of two convolutional layers. Feature map encoding is performed with max pooling operations, and an upscaling operation is used in the decoding path. Finally, to add more context during the decoding path, the feature map from the corresponding encoding block is concatenated with the input of the decoding block. This architecture ultimately forms a "U" shape. Here is a schematic of the network architecture [<a href="#Reference2" class="intrnllnk">2</a>]:

<p style="text-align: center"> <img src="images/unet.png" alt="Unet architecture"/> </p>

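The channel counts in the decoding path can be checked with a small sketch. The NumPy arrays below are stand-ins for feature maps laid out as (channels, height, width), and the sizes are illustrative assumptions rather than values taken from the schematic.

```python
import numpy as np

# Feature maps are laid out as (channels, height, width); the sizes below are illustrative
bottleneck = np.random.rand(512, 8, 8)    # output of the deepest block
skip = np.random.rand(256, 16, 16)        # saved output of the matching encoding block

# Nearest-neighbor upsampling doubles the height and width
upsampled = bottleneck.repeat(2, axis=1).repeat(2, axis=2)

# Joining the two maps stacks their channels: 512 + 256 = 768
merged = np.concatenate([upsampled, skip], axis=0)
print(merged.shape)
```

A decoding block that receives this merged map therefore needs 768 input channels, even though it may output only 256.
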
The following code block generates a U-Net architecture.

In [None]:
def double_conv(in_channels, out_channels):
    # This corresponds to the two convolution layers that make up each encoding/decoding block
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class UNet(nn.Module):

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class

        # This is just initializing each encoding/decoding block, as well as the max pooling layer that will be used
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        #--- Encoding Path ---#
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        
        x = self.maxpool(conv3) # Encoding

        x = self.dconv_down4(x)
        
        #--- Decoding Path ---#
        x = nn.functional.interpolate(x, scale_factor=2, mode='nearest') #  Decoding
     
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        return out

<div class="alert alert-block alert-info" style="color:black"> <b>EXERCISE</b> The default U-Net architecture performs encoding (and decoding) 4 times. Following the pattern of the network architecture, try adding an additional encoding and decoding step and see how that affects the output of the network. </div>

## Seeing U-Net in action
---
To verify that the network is correctly producing the expected output, we can initialize the network with random weights and biases, apply the network to an image, and observe the output.

In [None]:
# Load a sample image from the local images folder
# In this case, an example image of immunohistochemistry (IHC) staining is used
img = iio.imread("images/immunohistochemistry.png")

# Create a figure to show the original image and the network output
fig = plt.figure(figsize=(10,10))

# Initial Image
ax = fig.add_subplot(1, 2, 1)
ax.set_title('Original Image')
ax.axis('off')
plot1 = plt.imshow(img)

# Initialize the U-Net network model with just a single class. This will not matter since it is just to verify the network output
model = UNet(n_class=1)

# Convert the test image to the expected input type for the network
img_torch = torch.tensor(img / 255).permute(2,0,1).unsqueeze(0).float()

# Pass the input image into the network
output = model(img_torch)

# Show the output image
ax = fig.add_subplot(1, 2, 2)
ax.set_title('Output Image from U-Net')
ax.axis('off')
plot2 = plt.imshow(output.squeeze().detach().numpy(),cmap='gray',clim=(torch.min(output),torch.max(output)))

# Print the height and width of the input and output image
print("Shape of input image: ",np.shape(np.mean(img,axis=2)))
print("Shape of output image: ",np.shape(output.squeeze().detach().numpy()))

Although the actual values of the output image do not mean anything, we can clearly see that the output of the network has the same height and width as the input image. This means that we can use this network architecture to semantically segment an image.
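
Whether the sizes match depends on the image dimensions. As a rough sketch, assuming a network with three 2 x 2 max pooling steps and three 2x upsampling steps (as in the `UNet` class above), the hypothetical helper `unet_output_size` tracks a single image dimension through the network and reports whether every skip connection lines up.

```python
def unet_output_size(size, n_pool=3):
    """Track one spatial dimension through n_pool 2x poolings and 2x upsamplings.

    Returns the output size, or None if a skip connection would not line up.
    """
    skip_sizes = []
    for _ in range(n_pool):
        skip_sizes.append(size)   # size saved for the skip connection
        size = size // 2          # 2x2 max pooling rounds down
    for skip in reversed(skip_sizes):
        size = size * 2           # 2x upsampling
        if size != skip:          # concatenation needs identical sizes
            return None
    return size

print(unet_output_size(512))  # 512
print(unet_output_size(500))  # None
```

Under these assumptions, image dimensions that are divisible by 8 (such as 512) pass through cleanly, while a dimension such as 500 would need to be cropped or padded first.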

<div class="alert alert-block alert-info"> <b>Knowledge Check</b> </div>

In [None]:
!pip install jupyterquiz --quiet
from jupyterquiz import display_quiz
display_quiz('../quiz_files/submodule_03/kc1.json')

## Training a segmentation network
---
Similar to the previous classification submodules, the dataset used to train a segmentation network must have an input image as well as some type of label. However, instead of a single label for an input image, a pixel-wise labeled image must be used to determine the accuracy of the network.
For example, here is a labeled image from the [cityscapes dataset](https://www.cityscapes-dataset.com) where the pixels are color-coded as:

<div>
    <table>
        <thead>
            <tr>
                <th>Color</th>
                <th>Label</th>
            </tr>
        </thead>
        <tbody>
            <tr>
                <td> <span style="color:#B06BB0;">purple</span> </td>
                <td> road </td>
            </tr>
            <tr>
                <td> <span style="color:#FF39F5;">pink</span> </td>
                <td> sidewalk </td>
            </tr>
            <tr>
                <td> <span style="color:#FF1A15;">red</span> </td>
                <td> human </td>
            </tr>
            <tr>
                <td> <span style="color:#2024FF;">blue</span> </td>
                <td> car </td>
            </tr>
            <tr>
                <td> <span style="color:#EFF215;">yellow</span> </td>
                <td> traffic sign </td>
            </tr>
            <tr>
                <td> <span style="color:#949695;">light gray</span> </td>
                <td> pole </td>
            </tr>
            <tr>
                <td> <span style="color:#62665F;">dark gray</span> </td>
                <td> building </td>
            </tr>
            <tr>
                <td> <span style="color:#000000;">black</span> </td>
                <td> clutter </td>
            </tr>
            <tr>
                <td> <span style="color:#997112;">dark yellow</span> </td>
                <td> street sign </td>
            </tr>
            <tr>
                <td> <span style="color:#76B1DB;">light blue</span> </td>
                <td> sky </td>
            </tr>
            <tr>
                <td> <span style="color:#781C2B;">dark red</span> </td>
                <td> bicycle </td>
            </tr>
            <tr>
                <td> <span style="color:#7AA12E;">green</span> </td>
                <td> vegetation </td>
            </tr>        
        </tbody>
    </table>

<p style="text-align: center"> <img src="images/tuebingen00.png" alt="labeled image of city" width="1000"/> </p>
    
</div>

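Before a color-coded label image like this can be used for training, the colors are typically converted to integer class indices. The sketch below shows one way to do this with NumPy. The hex colors are copied from the table above, while the class indices and the helper name `colors_to_indices` are illustrative assumptions and are not part of the Cityscapes tooling, which defines its own palette.

```python
import numpy as np

# A few (R, G, B) colors taken from the table above, mapped to illustrative class indices
color_to_class = {
    (0xB0, 0x6B, 0xB0): 0,  # road
    (0xFF, 0x39, 0xF5): 1,  # sidewalk
    (0x20, 0x24, 0xFF): 2,  # car
    (0x76, 0xB1, 0xDB): 3,  # sky
}

def colors_to_indices(label_rgb, mapping, unknown=-1):
    """Convert an (H, W, 3) color-coded label image into an (H, W) array of class indices."""
    index_map = np.full(label_rgb.shape[:2], unknown, dtype=int)
    for color, class_idx in mapping.items():
        matches = np.all(label_rgb == np.array(color), axis=-1)
        index_map[matches] = class_idx
    return index_map

# Tiny 2x2 example label image: sky on top, road and car below
label_rgb = np.array([[[0x76, 0xB1, 0xDB], [0x76, 0xB1, 0xDB]],
                      [[0xB0, 0x6B, 0xB0], [0x20, 0x24, 0xFF]]], dtype=np.uint8)

index_map = colors_to_indices(label_rgb, color_to_class)
print(index_map)
```
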
## Application of a U-Net architecture
---
As previously stated, segmentation is a powerful tool, particularly for biomedical image datasets. For an example application of segmentation with a U-Net, we will be segmenting images of skin that were acquired using fluorescence microscopy. The dataset we will be using for this application has previously been published [<a href="#Reference3" class="intrnllnk">3</a>,<a href="#Reference4" class="intrnllnk">4</a>].

Skin is made up of three primary layers (_epidermis, dermis, hypodermis_) each with its own structure and function. These layers aggregate to form a protective boundary for the human body.
<p style="text-align: center"> <img src="images/skinlayers.png" alt="layers of skin" width="600" /></p>

<p> For this application, <em>en face</em> images of skin autofluorescence were acquired where the sources of contrast are two molecular cofactors associated with cellular metabolism (<span style="color:#00FF00;">green</span> and <span style="color:#0000FF;">blue</span>, localized to cells), as well as collagen fibers found within the dermal layer (<span style="color:#FF0000;">red</span>). </p>

Within skin, the cells in question are mostly localized in the epidermis, and the relative amount of the two metabolic cofactors can inform us about their cellular function (i.e. oxidative phosphorylation or glucose catabolism). Here is an example of one of these images: 
<p style="text-align: center"> <img src="images/mpmexample_annotated.png" alt="en face image of skin" width="400" /> </p>

Accurate segmentation of the layers of skin in the _en face_ images is important for summarizing the metabolic state of the cells, but this is difficult and tedious to do by hand. For this submodule, a U-Net architecture will be trained to semantically segment different layers and features within skin. 

### Loading and viewing the dataset
---
Typically, in image processing workflows, a lot of future issues can be mitigated by just observing what the data look like prior to any processing. To verify that the images used for training load correctly from the bucket, and to visualize the images, we will choose a small subset of the training dataset to display. Additionally, we will also output the corresponding labeled _ground truth_ image, which is a hand-traced pixel-wise classification map for the image. The goal is to teach the network to generate its own pixel-wise map.

Please note that the following code block will output the following warnings:

> /opt/conda/lib/python3.7/site-packages/IPython/core/events.py:89: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  func(*args, **kwargs)
>
> /opt/conda/lib/python3.7/site-packages/IPython/core/pylabtools.py:151: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.
  fig.canvas.print_figure(bytes_io, **kw)
  
These can be ignored.

In [None]:
# Get a list of all of the training images
bucket_loc = '/home/jupyter'
filelist = os.listdir(bucket_loc + '/segmentation_bucket/training/images/')

# Pick 4 images at random
random_int = np.random.randint(low=0, high=len(filelist), size=4)

# Display the images
fig, axs = plt.subplots(2, 4, layout="tight", figsize=(15, 8))
for i in range(4):
    axs[0,i].imshow(iio.imread(bucket_loc + '/segmentation_bucket/training/images/'+filelist[random_int[i]]))
    axs[0,i].set_axis_off()

for i in range(4):
    axs[1,i].imshow(iio.imread(bucket_loc + '/segmentation_bucket/training/masks/'+filelist[random_int[i]]), cmap='jet',clim=(36,216))
    axs[1,i].set_axis_off()

# Set up a colormap for reference
cax = plt.axes([1.01, 0.1, 0.05, 0.8]);
cmap = cm.get_cmap('jet', 6) 
ticks = np.linspace(0,1,13)

cbar = plt.colorbar(mappable = cm.ScalarMappable(norm=None, cmap=cmap), cax=cax);
cbar.ax.set_yticks(ticks[1:len(ticks)-1:2]);
cbar.ax.set_yticklabels(['epidermis','dermis','granulation','keratin','hair','background']);

There are 6 different labels in this dataset: epidermis, dermis, granulation, keratin, hair, background. However, for the purpose of this submodule, we will be generalizing these labels, which will simplify the training process. The new labels will be the following: epidermis, dermis, hair, and background.

Here are the same four images, but with the altered classes.

In [None]:
# Setup the new classes
classes = np.array([36,72,108,144,180,216])
actualClass = np.array([1,2,1,1,3,4])

# Display the images
fig, axs = plt.subplots(2, 4, layout="tight", figsize=(15, 8))
for i in range(4):
    axs[0,i].imshow(iio.imread(bucket_loc + '/segmentation_bucket/training/images/'+filelist[random_int[i]]))
    axs[0,i].set_axis_off()

for i in range(4):
    mask = iio.imread(bucket_loc + '/segmentation_bucket/training/masks/'+filelist[random_int[i]])
    
    for j in range(len(classes)):
        mask[np.where(mask==classes[j])] = actualClass[j]  
        
    axs[1,i].imshow(mask, cmap='jet',clim=(1,4))
    axs[1,i].set_axis_off()
    
# Set up a colormap for reference
cax = plt.axes([1.01, 0.1, 0.05, 0.8])
cmap = cm.get_cmap('jet', 4) 
ticks = np.linspace(0,1,9)

cbar = plt.colorbar(mappable = cm.ScalarMappable(norm=None, cmap=cmap), cax=cax);
cbar.ax.set_yticks(ticks[1:len(ticks)-1:2]);
cbar.ax.set_yticklabels(['epidermis','dermis','hair','background']);

### Creating a dataset class
---
Datasets come in all different shapes, sizes, and formats, requiring the end user (us) to load and prepare the data in a meaningful way. The purpose of this `UNetDataset` class is to prepare images as they are needed during the training process. This includes converting the images to a PyTorch tensor, performing random X/Y flipping augmentations, and combining the input image and the corresponding mask so they can be quickly referenced.

In [None]:
class UNetDataset(torch.utils.data.Dataset):
    # Define what will be run at initialization of the UNetDataset class
    def __init__(self, image_array, mask_array):
        # Attach the arrays of images and masks to the class
        # so that they can be accessed by the other methods in the class
        self.image = image_array
        self.mask = mask_array

        # Initialize a transform that will be used later
        self.tform = T.ToTensor()

    # There are two required methods for a class that inherits from torch.utils.data.Dataset:
    # __len__()
    # __getitem__()

    # Return the length of the dataset
    def __len__(self):
        return self.image.shape[3]

    # Return a single image from the dataset, as well as the mask associated with the image
    def __getitem__(self,idx):
        # Return a single variable (dict) that contains both the image and the mask
        out = {
            'image': self.tform(self.image[:,:,:,idx]).float(),
            'mask': self.tform(self.mask[:,:,idx]).long()
        }
        
        # Random vflip and hflip
        if np.random.rand() > 0.5:
            out['image'] = torchvision.transforms.functional.hflip(out['image'])
            out['mask'] = torchvision.transforms.functional.hflip(out['mask'])
            
        if np.random.rand() > 0.5:
            out['image'] = torchvision.transforms.functional.vflip(out['image'])
            out['mask'] = torchvision.transforms.functional.vflip(out['mask'])
 
        return out
    

<div class="alert alert-block alert-info" style="color:black"> <b>EXERCISE</b> In the previous submodule, augmentation was shown to improve the robustness of image classification. Try implementing other types of augmentation aside from horizontal and vertical flipping and assess how these augmentations affect the network accuracy. A list of augmentations can be found <a href=https://pytorch.org/vision/stable/transforms.html#functional-transforms>here</a>. </div>

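When adding augmentations for segmentation, the key constraint is that the image and its mask must receive exactly the same geometric transform, otherwise the labels no longer line up with the pixels. As a minimal sketch of that idea (using NumPy arrays rather than the tensors handled inside `UNetDataset`, and a hypothetical helper named `random_rot90`), a random 90-degree rotation could look like this:

```python
import numpy as np

def random_rot90(image, mask, rng):
    """Rotate an (H, W, C) image and its (H, W) mask by the same random multiple of 90 degrees."""
    k = int(rng.integers(0, 4))              # 0, 1, 2, or 3 quarter turns
    image_rot = np.rot90(image, k, axes=(0, 1))
    mask_rot = np.rot90(mask, k, axes=(0, 1))
    return image_rot.copy(), mask_rot.copy()

rng = np.random.default_rng(seed=0)

# Toy example: the mask marks the single bright pixel in the image
image = np.zeros((4, 4, 3))
mask = np.zeros((4, 4), dtype=int)
image[0, 1, :] = 1.0
mask[0, 1] = 1

image_aug, mask_aug = random_rot90(image, mask, rng)

# The bright pixel and its label still coincide after the rotation
print(np.array_equal(image_aug[:, :, 0] > 0, mask_aug == 1))
```

Drawing the random number once and reusing it for both arrays is what keeps the pair aligned; drawing it separately for the image and the mask would silently corrupt the training labels.
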
## Preparing datasets
---
Once the U-Net architecture and dataset class are ready to go, we can load (or _buffer_) all of the images into memory. After loading a particular input image, the image is first _normalized_ such that the intensity values range between 0 and 1, and the annotated masks are adjusted to the limited number of classes as previously described.

Finally, a `UNetDataset` is made for each of the three datasets (training, validation, and testing), and the dataset class is wrapped within a `DataLoader` that allows us to easily modify parameters regarding training such as shuffling and batch size.

In [None]:
# Get a list of all of the training images
filelist = os.listdir(bucket_loc + '/segmentation_bucket/training/images/')
classes = np.array([36,72,108,144,180,216])
actualClass = np.array([1,2,1,1,3,4])

train_images = np.zeros((512,512,3,len(filelist)))
train_masks = np.zeros((512,512,len(filelist)))

for i in tqdm(range(len(filelist)), desc="Loading training set", unit="img"):
    if ".tif" in filelist[i]:
        train_images[:,:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/training/images/'+filelist[i]) / 255
        train_masks[:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/training/masks/'+filelist[i])
        
for i in tqdm(range(len(classes)), desc="Processing training masks", unit="class"):
    train_masks[np.where(train_masks==classes[i])] = actualClass[i]-1
    
# Get a list of all of the validation images
filelist = os.listdir(bucket_loc + '/segmentation_bucket/validation/images/')

valid_images = np.zeros((512,512,3,len(filelist)))
valid_masks = np.zeros((512,512,len(filelist)))
for i in tqdm(range(len(filelist)), desc="Loading validation set", unit="img"):
    if ".tif" in filelist[i]:
        valid_images[:,:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/validation/images/'+filelist[i]) / 255
        valid_masks[:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/validation/masks/'+filelist[i])

for i in tqdm(range(len(classes)), desc="Processing validation masks", unit="class"):
    valid_masks[np.where(valid_masks==classes[i])] = actualClass[i]-1
    
# Get a list of all of the testing images
filelist = os.listdir(bucket_loc + '/segmentation_bucket/testing/images/')

test_images = np.zeros((512,512,3,len(filelist)))
test_masks = np.zeros((512,512,len(filelist)))

for i in tqdm(range(len(filelist)), desc="Loading testing set", unit="img"):
    if ".tif" in filelist[i]:
        test_images[:,:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/testing/images/'+filelist[i]) / 255
        test_masks[:,:,i] = iio.imread(bucket_loc + '/segmentation_bucket/testing/masks/'+filelist[i])

for i in tqdm(range(len(classes)), desc="Processing testing masks", unit="class"):
    test_masks[np.where(test_masks==classes[i])] = actualClass[i]-1
    
# Generate dataset classes
trainSet = UNetDataset(image_array=train_images, mask_array=train_masks)
validSet = UNetDataset(image_array=valid_images, mask_array=valid_masks)
testSet  = UNetDataset(image_array=test_images,  mask_array=test_masks)

# Create dataloaders for training process
trainLoader = torch.utils.data.DataLoader(dataset=trainSet, shuffle=True, batch_size=1)
validLoader = torch.utils.data.DataLoader(dataset=validSet, shuffle=False, batch_size=1)
testLoader =  torch.utils.data.DataLoader(dataset=testSet, shuffle=False, batch_size=1)

### Verification of dataset class with images
---
To verify that everything is working as expected prior to the training loop, we can use our newly established `UNetDataset` class to load in random images from the training dataset.

In [None]:
# Pick 4 images at random from the training set
random_int = np.random.randint(low=0, high=len(trainSet), size=4)

# Display the images
fig, axs = plt.subplots(2, 4, layout="tight", figsize=(15, 8))
for i in range(4):
    out = trainSet[random_int[i]]
    # Tensors are (channels, height, width); imshow expects (height, width, channels)
    axs[0,i].imshow(out['image'].permute(1,2,0))
    axs[0,i].set_axis_off()
    
    axs[1,i].imshow(out['mask'].squeeze(0),cmap='jet',clim=(0,3))
    axs[1,i].set_axis_off()   

# Set up a colormap for reference
cax = plt.axes([1.01, 0.1, 0.05, 0.8])
cmap = cm.get_cmap('jet', 4) 
ticks = np.linspace(0,1,9)

cbar = plt.colorbar(mappable = cm.ScalarMappable(norm=None, cmap=cmap), cax=cax);
cbar.ax.set_yticks(ticks[1:len(ticks)-1:2]);
cbar.ax.set_yticklabels(['epidermis','dermis','hair','background']);

## Train the U-Net
---

Prior to creating a training loop, we will first create a function that will be used to visualize the network training (`update_plots`).

In [None]:
# Define method for updating plot
def update_plots(ep_range,train_acc,train_loss,valid_acc,valid_loss):
    clear_output(wait=True)

    fig, axs = plt.subplots(2, 1, layout="tight", figsize=(10, 10))
    axs[0].plot(ep_range,train_acc,'ro-',ep_range,valid_acc,'bo-')
    axs[0].set_xlabel('Epoch');
    axs[0].set_ylabel('Accuracy [%]');

    line = axs[1].plot(ep_range,train_loss,'ro-',ep_range,valid_loss,'bo-')
    axs[1].set_xlabel('Epoch');
    axs[1].set_ylabel('Loss');
    axs[1].legend([line[0],line[1]],['training','validation'], loc='upper right',fontsize='x-large')
    plt.show()

Now we can initialize the network, set training parameters such as the loss function and the optimizer, and define the total number of training epochs. For this submodule, a total of 10 epochs has been chosen based on previous tests with this dataset. In general, there is no fixed rule for the number of epochs; it is chosen based on how well the network learns.

In your own work, you can determine the number of epochs by observing the accuracy and loss from the training and validation datasets every epoch as the network is trained. If the number of epochs is too high (typically >50), then _overfitting_, also known as memorization, of the training dataset can occur, resulting in a less robust network. Alternatively, if the number of epochs is too low, the network has not had enough passes over the training data to become accurate.
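
One common way to turn this observation into a rule is to stop once the validation loss has not improved for a few epochs. The sketch below is an illustration of that idea only: the helper `best_epoch`, the `patience` parameter, and the loss values are all made up for the example and are not used elsewhere in this submodule.

```python
def best_epoch(valid_losses, patience=3):
    """Return (best_epoch, stop_epoch) for a list of per-epoch validation losses.

    Training is considered finished once the validation loss has failed to improve
    for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    best_ep = 0
    for ep, loss in enumerate(valid_losses):
        if loss < best_loss:
            best_loss, best_ep = loss, ep
        elif ep - best_ep >= patience:
            return best_ep, ep
    return best_ep, len(valid_losses) - 1

# Made-up validation losses that improve and then start to rise (overfitting)
valid_losses = [1.20, 0.80, 0.60, 0.50, 0.52, 0.55, 0.58, 0.61]
print(best_epoch(valid_losses))  # (3, 6)
```

Here the lowest validation loss occurs at epoch 3, and the rule would stop training at epoch 6 after three epochs without improvement.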

In [None]:
# Initialize the model and set training parameters
model = UNet(n_class=4).to(device);
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
          
# Set the number of training epochs
total_ep = 10

# Set up variable used for plotting
ep_range = np.array(0)

<div class="alert alert-block alert-info"> <b>Knowledge Check</b> </div>

In [None]:
display_quiz('../quiz_files/submodule_03/kc2.json')

Next, we will perform an initial pass of the training and validation datasets with the randomly initialized network, which will give us a baseline accuracy. For this pass, the `with torch.no_grad():` line disables gradient tracking; together with the absence of any optimizer step, this ensures that the network is not learning anything during this step.

In [None]:
#--- Initial check of accuracy ---#
with torch.no_grad():
    
    # Training images
    tot_loss = 0
    tot_acc = 0
    with tqdm(total=len(trainSet), desc=f'Training (0 / {total_ep})', unit='img') as pbar:
        for batch in trainLoader:
            # Get a batch of training images/masks from the data loader and send them to the gpu
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc

            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])
        pbar.set_postfix(**{'Average loss': tot_loss/len(trainSet),'Average accuracy': (tot_acc/len(trainSet))*100})
        
        train_acc = np.array((tot_acc/len(trainSet))*100)
        train_loss = tot_loss/len(trainSet)
        
    # Validation loop
    tot_loss = 0
    tot_acc = 0
    with tqdm(total=len(validSet), desc=f'Validation (0 / {total_ep})', unit='img') as pbar:
        for batch in validLoader:
            # Get a batch of validation images/masks from the data loader and send them to the device
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc
            
            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])

        pbar.set_postfix(**{'Average loss': tot_loss/len(validSet),'Average accuracy': (tot_acc/len(validSet))*100})
        
        valid_acc = np.array((tot_acc/len(validSet))*100)
        valid_loss = tot_loss/len(validSet)  

# Clear the output and print the initial baseline accuracy
clear_output(wait=True)
print('Initial training accuracy: ' + str(np.round(train_acc,2)) + '%')
print('Initial validation accuracy: ' + str(np.round(valid_acc,2)) + '%')

To help interpret these accuracy numbers, an accuracy of 50% means that the network assigned the correct class to half of the pixels. Because there are four classes here, this is not quite a coin flip: guessing one of the four classes uniformly at random would be correct about 25% of the time.

Since the network has just been initialized with random numbers, the expected accuracy should be low (<50%). This provides a good baseline measurement, since the accuracy of a network that is learning should rise above this number as training proceeds.
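
The accuracy calculation itself can be illustrated on a toy example. In the NumPy sketch below, `pixel_accuracy` is an illustrative helper that mirrors the calculation used in the training code (pick the highest-scoring class per pixel, then compare with the mask); the arrays are made up for the example.

```python
import numpy as np

def pixel_accuracy(scores, mask):
    """Percentage of pixels whose highest-scoring class matches the ground-truth mask.

    scores: (n_class, H, W) array of per-class scores for one image
    mask:   (H, W) array of integer class labels
    """
    predicted = np.argmax(scores, axis=0)
    return 100 * np.mean(predicted == mask)

# Toy ground truth: the top half is class 0, the bottom half is class 3
mask = np.zeros((4, 4), dtype=int)
mask[2:, :] = 3

# A "network" that always scores class 0 highest is right wherever the mask is 0
scores = np.zeros((4, 4, 4))
scores[0] = 1.0
print(pixel_accuracy(scores, mask))  # 50.0
```

This also shows why the baseline depends on the data: a network that predicts a single class everywhere scores exactly the share of pixels that belong to that class.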

Now we can perform the network training and visualize the performance of the network. It is important to keep in mind that the network is only learning from the training dataset, and thus the `with torch.no_grad():` line is used during the assessment of the validation dataset.

**NOTE: This network will train for 10 epochs. <em><u>If you are using a CPU, this will take about 10 hours to complete</u></em>. Alternatively, this will take about 10 minutes to complete with a GPU.**

In [None]:
#--- Actual training loop ---#
for ep in range(total_ep):

    # Update plots
    update_plots(ep_range, train_acc, train_loss, valid_acc, valid_loss)
    
    # Training loop
    tot_loss = 0
    tot_acc = 0
    with tqdm(total=len(trainSet), desc=f'Training ({ep + 1} / {total_ep})', unit='img') as pbar:
        for batch in trainLoader:
            # Get a batch of training images/masks from the data loader and send them to the gpu
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # Zero the optimizer, backpropagate, and step the optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc
            
            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])
            
        # Loss and accuracy are accumulated once per batch, so average over the number of batches
        pbar.set_postfix(**{'Average loss': tot_loss/len(trainLoader),'Average accuracy': (tot_acc/len(trainLoader))*100})
        
        # Prepare for next update of plots
        ep_range = np.append(ep_range, ep+1)
        train_acc = np.append(train_acc, np.array((tot_acc/len(trainLoader))*100))
        train_loss = np.append(train_loss, tot_loss/len(trainLoader))
        
    # Validation loop
    tot_loss = 0
    tot_acc = 0
    with tqdm(total=len(validSet), desc=f'Validation ({ep + 1} / {total_ep})', unit='img') as pbar:
        for batch in validLoader:
            with torch.no_grad():
                # Get a batch of validation images/masks from the data loader and send them to the device
                img = batch['image'].to(device)
                mask = batch['mask'].squeeze(1).to(device)

                # Forward pass for the network
                masks_pred = model(img)

                # Calculate loss
                loss = criterion(masks_pred, mask)
                tot_loss += loss.item()

                # See what the network actually predicted for each pixel using the softmax function
                masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
                mask = mask.cpu().detach().numpy()

                # Calculate an accuracy for the image
                acc = np.sum(mask == masks_pred) / np.size(masks_pred)
                tot_acc += acc
            
                # Update progress bar
                pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
                pbar.update(batch['image'].shape[0])
                
        # Loss and accuracy are accumulated once per batch, so average over the number of batches
        pbar.set_postfix(**{'Average loss': tot_loss/len(validLoader),'Average accuracy': (tot_acc/len(validLoader))*100})
        
        # Prepare for next update of plots
        valid_acc = np.append(valid_acc, np.array((tot_acc/len(validLoader))*100))
        valid_loss = np.append(valid_loss, tot_loss/len(validLoader))

Now that training has completed, we will calculate a final accuracy for the training, validation, and most importantly, the testing data set.

In [None]:
#--- Final check of accuracy ---#
with torch.no_grad():

    # Update plots
    update_plots(ep_range, train_acc, train_loss, valid_acc, valid_loss)
    final_train_acc = 0
    final_train_loss = 0
    final_valid_acc = 0
    final_valid_loss = 0
    final_test_acc = 0
    final_test_loss = 0
    
    #--- Training dataset ---#
    with tqdm(total=len(trainSet), desc=f'Final Training', unit='img') as pbar:
        tot_loss = 0
        tot_acc = 0
    
        for batch in trainLoader:
            # Get a batch of training images/masks from the data loader and send them to the gpu
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc
                
            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])
            
        # Loss and accuracy are accumulated once per batch, so average over the number of batches
        pbar.set_postfix(**{'Average loss': tot_loss/len(trainLoader),'Average accuracy': (tot_acc/len(trainLoader))*100})
        
        # Calculate the final loss and accuracy
        final_train_loss = tot_loss/len(trainLoader)
        final_train_acc = (tot_acc/len(trainLoader))*100
      
    #--- Validation dataset ---#
    with tqdm(total=len(validSet), desc=f'Final Validation', unit='img') as pbar:
        tot_loss = 0
        tot_acc = 0
        
        for batch in validLoader:
            # Get a batch of validation images/masks from the data loader and send them to the device
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc
                
            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])
            
        # Loss and accuracy are accumulated once per batch, so average over the number of batches
        pbar.set_postfix(**{'Average loss': tot_loss/len(validLoader),'Average accuracy': (tot_acc/len(validLoader))*100})
        
        # Calculate the final loss and accuracy
        final_valid_loss = tot_loss/len(validLoader)
        final_valid_acc = (tot_acc/len(validLoader))*100
        
    #--- Testing dataset ---#
    with tqdm(total=len(testSet), desc=f'Final Testing', unit='img') as pbar:
        tot_loss = 0
        tot_acc = 0
        
        for batch in testLoader:
            # Get a batch of testing images/masks from the data loader and send them to the device
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass for the network
            masks_pred = model(img)

            # Calculate loss
            loss = criterion(masks_pred, mask)
            tot_loss += loss.item()

            # See what the network actually predicted for each pixel using the softmax function
            masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu().detach().numpy()
            mask = mask.cpu().detach().numpy()

            # Calculate an accuracy for the image
            acc = np.sum(mask == masks_pred) / np.size(masks_pred)
            tot_acc += acc
            
            # Update progress bar
            pbar.set_postfix(**{'Loss (batch)': loss.item(),'Accuracy (batch)': acc*100})
            pbar.update(batch['image'].shape[0])
            
        # Loss and accuracy are accumulated once per batch, so average over the number of batches
        pbar.set_postfix(**{'Average loss': tot_loss/len(testLoader),'Average accuracy': (tot_acc/len(testLoader))*100})
        
        # Calculate the final loss and accuracy
        final_test_loss = tot_loss/len(testLoader)
        final_test_acc = (tot_acc/len(testLoader))*100

# Final output containing plots of final accuracy/loss
clear_output(wait=True)

fig, axs = plt.subplots(2, 1, layout="tight", figsize=(10, 10))
axs[0].plot(ep_range,train_acc,'ro-',ep_range,valid_acc,'bo-')
axs[0].plot(total_ep,final_test_acc,'go')
axs[0].set_xlabel('Epoch');
axs[0].set_ylabel('Accuracy [%]');

line = axs[1].plot(ep_range,train_loss,'ro-',ep_range,valid_loss,'bo-',total_ep,final_test_loss,'go-')
axs[1].set_xlabel('Epoch');
axs[1].set_ylabel('Loss');
axs[1].legend([line[0],line[1],line[2]],['training','validation','testing'], loc='upper right',fontsize='x-large')
plt.show()
    
print('Final training accuracy: ' + str(np.round(final_train_acc,2)) + '%')
print('Final validation accuracy: ' + str(np.round(final_valid_acc,2)) + '%')
print('Final testing accuracy: ' + str(np.round(final_test_acc,2)) + '%')
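
The three evaluation loops above are nearly identical, so they could be folded into one helper. The following is a minimal sketch of that idea, not the notebook's own code: `evaluate`, `toy_model`, and `toy_loader` are hypothetical names, and the toy data only mimics the `{'image', 'mask'}` batch layout used above. The helper averages over batches and takes the `argmax` of the raw network output directly, which selects the same class as the `argmax` of the softmax.

```python
import torch
from torch import nn

def evaluate(model, loader, criterion, device):
    """Return (mean loss, mean pixel accuracy in %) averaged over the batches in `loader`."""
    tot_loss, tot_acc, n_batches = 0.0, 0.0, 0
    with torch.no_grad():  # evaluation only, so no gradients are tracked
        for batch in loader:
            img = batch['image'].to(device)
            mask = batch['mask'].squeeze(1).to(device)

            # Forward pass and loss for this batch
            logits = model(img)
            tot_loss += criterion(logits, mask).item()

            # Predicted class per pixel, then the fraction of correct pixels
            pred = torch.argmax(logits, dim=1)
            tot_acc += (pred == mask).float().mean().item()
            n_batches += 1
    return tot_loss / n_batches, 100 * tot_acc / n_batches

# Toy usage: a 1x1 convolution as a stand-in network and random 4-class data
torch.manual_seed(0)
toy_model = nn.Conv2d(3, 4, kernel_size=1)
toy_loader = [{'image': torch.rand(2, 3, 8, 8),
               'mask': torch.randint(0, 4, (2, 1, 8, 8))} for _ in range(3)]
toy_loss, toy_acc = evaluate(toy_model, toy_loader, nn.CrossEntropyLoss(), torch.device('cpu'))
print(toy_loss, toy_acc)
```

In the notebook, the same helper could in principle be called once each with `trainLoader`, `validLoader`, and `testLoader`, at the cost of losing the per-batch progress bar shown above.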

## Check the output
---
After the training is complete, it is a good idea to inspect some of the network's outputs to check that they match what you expect. In the following grid of images, drawn at random from the training set, the first row is the input image, the second row is the manually segmented mask, and the third row is the segmented mask generated by the trained CNN.

In [None]:
# Pick 4 images at random from the training set
random_int = np.random.randint(low=0, high=len(trainSet), size=4)

# Display the images
fig, axs = plt.subplots(3, 4, layout="tight", figsize=(20, 15))
for i in range(4):
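    out = trainSet[random_int[i]]
    
    img = out['image'].to(device)
    mask = out['mask'].squeeze(1)

    # Run the network without tracking gradients (a batch dimension is added first)
    with torch.no_grad():
        masks_pred = model(img.unsqueeze(0))

    # Convert the network output into a predicted class for each pixel
    masks_pred = torch.argmax(torch.softmax(masks_pred,1),dim=1).cpu()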
    
    axs[0,i].imshow(img.cpu().permute(2,1,0))
    axs[0,i].set_axis_off()
    
    axs[1,i].imshow(mask.permute(2,1,0),cmap='jet',clim=(0,3))
    axs[1,i].set_axis_off()

    axs[2,i].imshow(masks_pred.permute(2,1,0),cmap='jet',clim=(0,3))
    axs[2,i].set_axis_off()

<div class="alert alert-block alert-info" style="color:black"> <b>EXERCISE</b> In the original set of outputs, the network performs relatively well. However, it can be improved by adjusting the training hyperparameters and the loss function, as well as by introducing additional training components such as learning rate schedulers. Try making adjustments to the training loop and assess how they affect the final output. </div>
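
As one possible starting point for the scheduler part of this exercise, the sketch below attaches PyTorch's `StepLR` scheduler to an optimizer. The stand-in model, the SGD optimizer, the learning rate, and the schedule values are all illustrative assumptions, not the settings used in this notebook.

```python
import torch

# Hypothetical stand-in model and optimizer; the learning rate is arbitrary
toy_model = torch.nn.Linear(4, 2)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-3)

# Halve the learning rate every 3 epochs (both values are illustrative choices)
scheduler = torch.optim.lr_scheduler.StepLR(toy_optimizer, step_size=3, gamma=0.5)

lrs = []
for ep in range(10):
    lrs.append(toy_optimizer.param_groups[0]['lr'])  # learning rate used this epoch

    # ... the per-batch training steps for this epoch would go here ...
    toy_optimizer.step()   # placeholder so the optimizer steps before the scheduler

    scheduler.step()       # advance the schedule once per epoch

print(lrs)
```

In the training loop above, the `scheduler.step()` call would sit at the end of each epoch, after all batches have been processed.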

<div class="alert alert-warning" style="color:black"> <b>CHALLENGE</b> One major drawback of this dataset is that the number of pixels in each class is not balanced (i.e., some classes contain more pixels overall than others). Within this training dataset, the classes are distributed as follows:
<ul>
    <li> epidermis: 37.64% </li>
    <li> dermis: 18.47% </li>
    <li> hair: 3.05% </li>
    <li> background: 40.84% </li>
</ul>

Can you adjust some training parameters to take this distribution into consideration?

<i>HINT:</i> The CrossEntropyLoss function documentation can be found <a href="https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html">here</a>.
</div>
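
One possible direction for this challenge, sketched under stated assumptions, is to pass inverse-frequency class weights to the loss function through its `weight` argument. The order of the classes in the tensor below is an assumption (it simply follows the list above), so the label encoding of the masks should be checked before reusing it. The rescaling so that the weights average to 1 is also just one of several reasonable choices.

```python
import torch
from torch import nn

# Pixel fractions from the list above. ASSUMPTION: the class indices in the
# masks follow this order (epidermis, dermis, hair, background); check the
# dataset's label encoding before reusing this.
class_freq = torch.tensor([0.3764, 0.1847, 0.0305, 0.4084])

# Inverse-frequency weights, rescaled so that they average to 1
class_weights = 1.0 / class_freq
class_weights = class_weights / class_weights.mean()

# `weight` is a documented argument of CrossEntropyLoss; in the notebook the
# tensor would also need to be moved to the same device as the model
weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
print(class_weights)
```

With these fractions, the rarest class (hair) receives the largest weight, so mistakes on hair pixels contribute more to the loss than mistakes on background pixels.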

## Conclusion
---
In this submodule, we explored segmentation, a powerful tool for generating pixel-wise classification maps rather than a single classification for an entire input image. To demonstrate this, we built a CNN with the U-Net architecture and trained it to segment different regions of skin within fluorescence images.

### References
---
<div> 
    <span id="Reference1"> 
        <p> 
            [1] Cireşan DC, et al.
            <i> Deep neural networks segment neuronal membranes in electron microscopy images</i>.
            2012.
            <a href="https://papers.nips.cc/paper/2012/hash/459a4ddcb586f24efd9395aa7662bc7c-Abstract.html">Link</a> 
        </p> 
    </span>   
    <span id="Reference2"> 
        <p> 
            [2] Ronneberger O, et al. 
            <i> U-Net: Convolutional Networks for Biomedical Image Segmentation</i>.
            2015.
            <a href="https://arxiv.org/abs/1505.04597">Link</a> 
        </p> 
    </span>
</div>
<div>
    <span id="Reference3"> 
        <p> 
            [3] Jones JD, Quinn KP.
            <i> Automated Quantitative Analysis of Wound Histology Using Deep-Learning Neural Networks</i>.
            2020.
            <a href="https://doi.org/10.1016/j.jid.2020.10.010">Link</a> 
        </p> 
    </span>
</div>
<div>
    <span id="Reference4"> 
        <p> 
            [4] Jones JD, et al.
            <i> Quantifying Age-Related Changes in Skin Wound Metabolism Using In Vivo Multiphoton Microscopy</i>.
            2020.
            <a href="https://doi.org/10.1089/wound.2019.1030">Link</a> 
        </p> 
    </span>
</div>