# Lightweight LB-FCNN on MNIST

### Import Libraries

In [None]:
!pip install torchstat

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from io import StringIO
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torchstat import stat
from torchsummary import summary

### Define Variables (Random Seed)

In [None]:
# Fix the seed of PyTorch's global RNG so that weight initialisation (and any
# torch-driven data shuffling) repeats identically from run to run.
random_seed = 5
torch.random.manual_seed(random_seed)

### Load Data

In [None]:
# Both MNIST splits share one preprocessing pipeline: replicate the single
# grayscale channel into 3 identical channels (the model works on 3 channels)
# and convert each PIL image to a float tensor.
mnist_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

training_data = datasets.MNIST(root='data', train=True, download=True, transform=mnist_transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=mnist_transform)

### Define Model

In [None]:
# Training hyperparameters, taken from Magboo & Abu.
# NOTE: the trailing "100" appears to be the paper's epoch count; it is
# shortened to 3 here for faster experimentation.
batch_size = 16
epochs = 3  # 100

In [None]:
# Channel width used by the convolution modules; it matches the 3 identical
# channels produced by the Grayscale(3) transform applied to MNIST.
channels: int = 3

In [None]:
# Depthwise Separable Convolution
class DSConv(nn.Module):
    """Depthwise-separable convolution that collapses its input to one channel.

    A per-channel (``groups=channels``) ``kernel_size`` x ``kernel_size``
    convolution with 'same' padding is followed by a 1x1 pointwise convolution
    mapping ``channels`` -> 1 channel. Both convolutions are bias-free and each
    is followed by a leaky ReLU, so the spatial size is preserved.

    Args:
        kernel_size: side length of the depthwise kernel.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Depthwise step: one filter per input channel.
        self.depthwise_conv = nn.Conv2d(
            channels, channels, kernel_size,
            padding='same', groups=channels, bias=False,
        )
        # Pointwise step: mix the channels down to a single output map.
        self.pointwise_conv = nn.Conv2d(channels, 1, kernel_size=1, bias=False)

    def forward(self, x):
        depthwise = F.leaky_relu(self.depthwise_conv(x))
        return F.leaky_relu(self.pointwise_conv(depthwise))

In [None]:
# Multiscale Depthwise Separable Convolution module
class MDSConv(nn.Module):
    """Multiscale depthwise-separable convolution (MDSC) module.

    Pipeline: 1x1 conv -> leaky ReLU -> batch norm, then three parallel
    ``DSConv`` branches with 3x3, 5x5 and 7x7 kernels (each producing one
    channel and normalised by its own batch norm), channel-wise concatenation,
    batch norm, and a final 1x1 conv -> leaky ReLU. Spatial size is preserved.

    Every normalisation/convolution position owns its own module. A single
    ``BatchNorm2d`` applied to differently distributed tensors would blend
    their running mean/variance, making eval-mode outputs inconsistent with
    training. Reusing one 1x1 conv at the input and output would also
    silently tie their weights.
    """

    def __init__(self):
        super().__init__()
        # Input projection.
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, bias=False)
        self.norm1 = nn.BatchNorm2d(num_features=channels)
        # Multiscale branches (3x3, 5x5, 7x7), each with its own 1-channel batch norm.
        self.ds_conv1 = DSConv(kernel_size=3)
        self.ds_conv2 = DSConv(kernel_size=5)
        self.ds_conv3 = DSConv(kernel_size=7)
        self.norm2 = nn.BatchNorm2d(num_features=1)
        self.norm3 = nn.BatchNorm2d(num_features=1)
        self.norm4 = nn.BatchNorm2d(num_features=1)
        # Output stage, independent of the input projection above.
        self.norm_out = nn.BatchNorm2d(num_features=channels)
        self.conv_out = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, bias=False)

    def forward(self, x):
        x = self.norm1(F.leaky_relu(self.conv(x)))

        x1 = self.norm2(self.ds_conv1(x))  # 3x3 branch
        x2 = self.norm3(self.ds_conv2(x))  # 5x5 branch
        x3 = self.norm4(self.ds_conv3(x))  # 7x7 branch

        # Each branch yields 1 channel, so the concatenation has 3 channels.
        # That only matches `channels` (expected by norm_out/conv_out) when channels == 3.
        x = torch.cat((x1, x2, x3), dim=1)
        x = self.norm_out(x)
        return F.leaky_relu(self.conv_out(x))

In [None]:
# Residual Connection module
class ResConnection(nn.Module):
    """Residual (shortcut) branch: 1x1 conv -> leaky ReLU -> batch norm.

    The channel count (``channels``) and spatial size are unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.norm = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.norm(F.leaky_relu(self.conv(x)))

In [None]:
# The main building block of LB-FCNN light architecture
class LBFCNNLightBlock(nn.Module):
    """Main LB-FCNN light block.

    The input feeds two parallel branches: a multiscale depthwise-separable
    convolution (``MDSConv``) and a residual connection (``ResConnection``).
    Their outputs are summed element-wise, then passed through a 1x1 conv,
    leaky ReLU and batch norm.
    """

    def __init__(self):
        super().__init__()
        self.mdsc = MDSConv()
        self.rc = ResConnection()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1, bias=False)
        self.norm = nn.BatchNorm2d(channels)

    def forward(self, x):
        # The MDSC branch is evaluated before the residual branch, as before.
        merged = self.mdsc(x) + self.rc(x)
        return self.norm(F.leaky_relu(self.conv(merged)))

In [None]:
# Create LB-FCNN light model
class LBFCNNLight(nn.Module):
    """LB-FCNN light classifier with 10 outputs.

    Structure:
    - an MDSC stem;
    - three ``LBFCNNLightBlock`` stages;
    - a learned 2x2 / stride-2 convolutional downsampling layer after the
      stem and after each stage;
    - global average pooling;
    - a linear layer to 10 outputs, then softmax.

    Input must have ``channels`` channels, and height/width of at least 16
    (four stride-2 downsamplings). ``forward`` returns softmax
    *probabilities* of shape (batch, 10), not logits or log-probabilities.
    Take the log before feeding them to ``nn.NLLLoss``.

    Every stage and every downsampling layer is a distinct module. Reusing one
    block/pool instance at several depths would tie their weights and blend
    the BatchNorm running statistics of differently distributed feature maps.
    """

    # Define layers
    def __init__(self):
        super().__init__()
        self.mdsc = MDSConv()
        self.pool = nn.Conv2d(in_channels = channels, out_channels = channels,
                              kernel_size = 2, stride = 2, bias = False)
        self.lbfcnn_block = LBFCNNLightBlock()
        self.pool2 = nn.Conv2d(in_channels = channels, out_channels = channels,
                               kernel_size = 2, stride = 2, bias = False)
        self.lbfcnn_block2 = LBFCNNLightBlock()
        self.pool3 = nn.Conv2d(in_channels = channels, out_channels = channels,
                               kernel_size = 2, stride = 2, bias = False)
        self.lbfcnn_block3 = LBFCNNLightBlock()
        self.pool4 = nn.Conv2d(in_channels = channels, out_channels = channels,
                               kernel_size = 2, stride = 2, bias = False)
        self.fc = nn.Linear(in_features = channels, out_features = 10)
        # TODO: Fix pool; paper says kernel size=3, but kernel=3 results in tensor shapes that don't match the paper diagram
        # Some other value somewhere must be excess by 1

        # TODO: Fix feature maps here don't match feature maps in paper diagram

    # Apply layers
    def forward(self, x):
        x = self.pool(self.mdsc(x))
        x = self.pool2(self.lbfcnn_block(x))
        x = self.pool3(self.lbfcnn_block2(x))
        x = self.pool4(self.lbfcnn_block3(x))
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(input = x, start_dim = 1)
        x = self.fc(x)
        return F.softmax(input = x, dim = 1)

In [None]:
# Pick the device all tensors and the model are allocated on: the GPU when
# CUDA is available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Performing torch operations on {device} device")

In [None]:
# Print the output shape and parameter count of every layer for a freshly
# built model.
# NOTE(review): the probe input is 3x646x220, not the 3x28x28 MNIST images the
# model is trained on; presumably chosen to compare with the paper diagram. Confirm.
probe_input_size = (3, 646, 220)
model = LBFCNNLight().to(device)
summary(model, probe_input_size)

### Train Model

In [None]:
# Containers for evaluation metrics, one entry per training run/fold.
results_accuracy, results_precision, results_sensitivity = [], [], []
results_specificity, results_f1 = [], []
# Each entry of train_losses / test_losses is a list of per-epoch average losses.
train_losses, test_losses = [], []

In [None]:
# Train the model for `epochs` epochs on a single train/test split, recording
# the average training and test loss of every epoch, then measure test accuracy.
trainloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)  # reshuffle every epoch
testloader = DataLoader(test_data, batch_size=batch_size)

# Per-epoch losses for this run. Appending the lists (rather than indexing [0])
# keeps a re-run of this cell from writing into an earlier run's history.
run_train_losses = []
run_test_losses = []
train_losses.append(run_train_losses)
test_losses.append(run_test_losses)

# Instantiate model
model = LBFCNNLight().to(device)

# LBFCNNLight ends in F.softmax, i.e. it outputs probabilities, whereas
# nn.NLLLoss expects log-probabilities. Take the log first, clamping so an
# underflowed probability of 0 cannot produce -inf loss / NaN gradients.
criterion = nn.NLLLoss().to(device)
log_eps = 1e-12
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ---------- Run for x epochs on training set and test set ----------
for epoch in range(epochs):
    for phase in ['train', 'test']:
        is_train = phase == 'train'
        if is_train:
            model.train()
            dataloader = trainloader
        else:
            model.eval()
            dataloader = testloader

        running_loss = 0.0
        total_batch_count = 0

        # Only build the autograd graph when we are going to backpropagate.
        with torch.set_grad_enabled(is_train):
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(torch.log(outputs.clamp_min(log_eps)), labels)

                if is_train:
                    # Backpropagation
                    optimizer.zero_grad()  # Reset gradients from previous passes
                    loss.backward()  # Compute gradients of the loss
                    optimizer.step()  # Update parameters using the gradients

                total_batch_count += 1
                running_loss += loss.item()

        # Average loss per batch for this epoch (guard against an empty loader).
        running_loss /= max(total_batch_count, 1)

        if is_train:
            run_train_losses.append(running_loss)
            print(f"Epoch {epoch+1}/{epochs} Training Loss: {running_loss}")
        else:
            run_test_losses.append(running_loss)
            print(f"Epoch {epoch+1}/{epochs} Test Loss: {running_loss}")

# ---------- Get performance metrics for this run ----------
correct = 0
total = 0

model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images = images.to(device)
        labels = labels.to(device)

        # Predicted class = index of the highest output per sample.
        predicted = model(images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

incorrect = total - correct
print(f"correct: {correct}, incorrect: {incorrect}, total: {total}")

# Get evaluation metrics
accuracy = correct/total if total != 0 else 0
print(f"Accuracy: {accuracy}")
results_accuracy.append(accuracy)

### Save Model

In [None]:
# Print the shape of every tensor (parameters and buffers) in the model's
# state_dict. The state_dict is built once; calling model.state_dict() per key
# would rebuild the whole dict on every iteration.
print("Model's state_dict:")
for param_tensor, tensor in model.state_dict().items():
    print(param_tensor, "\t", tensor.size())

In [None]:
# Print every top-level entry of the optimizer's state_dict, building the
# state_dict once instead of once per key.
print("Optimizer's state_dict:")
for var_name, value in optimizer.state_dict().items():
    print(var_name, "\t", value)

In [None]:
# Persist the learned weights. torch.save writes a PyTorch serialized archive;
# despite the '.h5' extension this is not an HDF5 file.
# torch.save does not create missing directories, so make sure 'weights/' exists.
os.makedirs('weights', exist_ok=True)
torch.save(model.state_dict(), 'weights/lbfcnn_weights.h5')

### Evaluate Model

In [None]:
# Plot the average training loss of every epoch, one line per recorded run.
# Epochs are numbered from 1 to match the training log.
plt.figure()
for fold_losses in train_losses:
    plt.plot(range(1, len(fold_losses) + 1), fold_losses)
plt.title('Training Loss Evaluation')
plt.xlabel('Epoch')
plt.ylabel('Magnitude')
plt.show()

In [None]:
# Plot the average test loss of every epoch, one line per recorded run.
# Epochs are numbered from 1 to match the training log.
plt.figure()
for fold_losses in test_losses:
    plt.plot(range(1, len(fold_losses) + 1), fold_losses)
plt.title('Test Loss Evaluation')
plt.xlabel('Epoch')
plt.ylabel('Magnitude')
plt.show()

In [None]:
# Summarise performance: average each recorded metric over all runs/folds.
mean_accuracy = np.mean(results_accuracy)
print("Final Performance Metrics")
print(f"Accuracy: {mean_accuracy}")

### Show Predictions on Sample Images

In [None]:
# NOTE(review): this section appears to be carried over from a binary
# metastasis (.tif) project. It does not use MNIST, and `data_dir` is not
# defined anywhere in this notebook as shown. Confirm before running.
sample_img_names = ["0000-0-A.tif", "0163-0-P.tif", "0198-0-A.tif"]
sample_img_paths = [os.path.join(data_dir, name) for name in sample_img_names]
# The class label is the 6th character of the file name, e.g. "0000-0-A.tif" -> "0".
sample_classes = [name[5] for name in sample_img_names]
# sample_classes holds *strings*, so compare with '0'. Comparing with the int 0
# was always False and labelled every image "Metastasis".
sample_ground_truths = ["No Metastasis" if cls == '0' else "Metastasis" for cls in sample_classes]

# Open each file once, inside a context manager so the handle is released:
# an RGB copy for the model and an unconverted copy for display.
sample_imgs = []
sample_imgs_show = []
for img_path in sample_img_paths:
    with Image.open(img_path) as img:
        sample_imgs.append(img.convert('RGB'))
        sample_imgs_show.append(img.copy())

In [None]:
# Preprocess each sample image, move it to the compute device and stack the
# results into a single batch tensor.
# NOTE(review): `preprocess` is not defined anywhere in this notebook as shown.
validation_batch = torch.stack([tensor.to(device) for tensor in map(preprocess, sample_imgs)])

In [None]:
# Predict on the sample batch in eval mode without tracking gradients, then
# copy the per-image output rows to a NumPy array for display.
model.eval()
with torch.no_grad():
    sample_preds = model(validation_batch).cpu().numpy()
sample_preds

In [None]:
# Show each sample image with the model's first two output probabilities and
# the ground-truth label. squeeze=False keeps `axs` 2-D, so indexing also
# works when there is only a single sample image.
fig, axs = plt.subplots(1, len(sample_imgs_show), figsize=(20, 5), squeeze=False)
for i, (ax, img) in enumerate(zip(axs[0], sample_imgs_show)):
    ax.axis('off')
    ax.set_title("Prediction: {:.0f}% No Metastasis, {:.0f}% Metastasis \n Ground Truth: {}"
                 .format(100*sample_preds[i,0], 100*sample_preds[i,1], sample_ground_truths[i]))
    ax.imshow(img)

In [None]:
# Report the model's computational complexity with torchstat.
# torchstat lacks CUDA support, so the model is moved to the CPU first; this
# also leaves `model` on the CPU for any later cells.
# NOTE(review): the probe size 3x646x220 differs from the 3x28x28 MNIST inputs
# used for training; confirm which input size the complexity should reflect.
model = model.cpu()
stat_input_size = (3, 646, 220)
stat(model, stat_input_size)