## Import

In [10]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import medmnist
from medmnist import INFO, Evaluator

## Visualisation
`dataset`: gives description of dataset\
`dataset[n][0]` the image of patient n\
`dataset[n][1]` binary classification class: `[0]` = normal, `[1]` = pneumonia

In this section, we load the 'test' split of the dataset for visualisation only. **Note:** none of these variables are used after this section.

In [None]:
from medmnist import PneumoniaMNIST

# Untransformed test split, loaded only for the quick look at the data below.
dataset = PneumoniaMNIST(download=False, split='test')

In [None]:
# Count how many test images fall into each class. Each label is a
# length-1 array: [0] = normal, [1] = pneumonia.
normal = 0
pneumonia = 0
for _, label in dataset:
    if label[0] == 0:
        normal += 1
    elif label[0] == 1:
        pneumonia += 1

# The bar chart shows both classes, so title and y-axis say so explicitly.
plt.bar(('normal', 'pneumonia'), (normal, pneumonia))
plt.title('Class balance of the PneumoniaMNIST test split')
plt.ylabel('Number of images')
plt.show()

# Show one randomly chosen image together with its class, so the reader can see
# what each label looks like. Re-running the cell picks a different image.
idx = np.random.randint(len(dataset))
image, label = dataset[idx]
plt.imshow(image, cmap='gray')
plt.title(f"Image {idx}: {'pneumonia' if label[0] == 1 else 'normal'}")
plt.axis('off')
plt.show()  # explicit show also avoids echoing the AxesImage repr as cell output

## Dataset Variables

`INFO` is a dictionary in `medmnist` that holds the metadata of every MedMNIST dataset (task type, number of channels, label names, name of the Python dataset class, ...). We extract the entries for our dataset for later use in the training process.

Hyperparameters -- crucial to control for good results:
1. `NUM_EPOCHS`: number of times the neural network is trained on the entire dataset.
2. `BATCH_SIZE`: number of images processed before the parameters of the NN are updated.
3. `lr`: learning rate, controls how much the network's parameters are adjusted based on the errors during training.

In [11]:
data_flag = 'pneumoniamnist'
download = False  # passed as `download=` to the dataset constructors below

NUM_EPOCHS = 3   # no. of passes over the entire training set
BATCH_SIZE = 32  # no. of images per parameter update
lr = 0.001       # SGD learning rate

# Fix the seed so that weight initialisation and the training-loader shuffle
# order (both drawn from torch's global RNG) are reproducible between runs.
SEED = 42
torch.manual_seed(SEED)

info = INFO[data_flag]
task = info['task']              # task type stored in the metadata (e.g. binary vs multi-class)
n_channels = info['n_channels']  # number of colour channels per image
n_classes = len(info['label'])   # number of classes = number of entries in the label map

# Dataset class (e.g. medmnist.PneumoniaMNIST) looked up by the name stored in INFO.
DataClass = getattr(medmnist, info['python_class'])

## Preprocessing
1. Define a data transform that converts the PIL images into tensors and then normalises them (mean 0.5, standard deviation 0.5) for better training behaviour.
2. Use the `DataClass` class looked up above to create the training and test datasets.
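3. `DataLoader`: split the data into batches. The training loader is reshuffled every epoch so the network does not see the images in a fixed order; the evaluation loaders are not shuffled.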

In [17]:
# 1. Per-image preprocessing: PIL image -> tensor (ToTensor), then
#    (x - 0.5) / 0.5 per channel (Normalize).
data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5], std=[.5]),
    ]
)

# 2. Transformed train/test splits, plus a copy of the training split with no
#    transform (presumably yielding raw PIL images).
train_dataset, test_dataset = (
    DataClass(split=split_name, transform=data_transform, download=download)
    for split_name in ('train', 'test')
)
pil_dataset = DataClass(split='train', download=download)

# 3. Batched loaders: only the training loader shuffles; the two evaluation
#    loaders keep dataset order and use batches twice as large.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval, test_loader = (
    data.DataLoader(dataset=split_ds, batch_size=BATCH_SIZE * 2, shuffle=False)
    for split_ds in (train_dataset, test_dataset)
)

TypeError: 'DataLoader' object is not subscriptable

## CNN Structure

First, define layers of CNN. Sequential Structure. 

1. `nn.Conv2d:` Applies 2D convolutions to extract features from images.
2. `nn.BatchNorm2d:` Normalizes input to stabilize training.
3. `nn.ReLU:` Type of activation function: introduces non-linearity for better learning.
4. `nn.MaxPool2d:` Downsamples feature maps to reduce parameters and computational cost.

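Class `Net` inherits from PyTorch's `nn.Module`, which provides parameter registration (`model.parameters()`), the `train()`/`eval()` modes, and calling the model like a function (which runs `forward`).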

Next, define the optimiser and loss function.
1. **Loss Function**: A loss function measures the difference between the model's predictions and the actual labels. A common choice for classification is Cross-Entropy Loss, which penalizes the model for assigning a low predicted probability to the correct class.
2. **Optimiser:** updates the model's internal parameters (weights and biases) based on the gradients of the calculated loss. Here we use stochastic gradient descent (SGD) with momentum.

In [13]:
class Net(nn.Module):
    """Five convolutional stages followed by a three-layer fully connected head.

    The head's input size (64 * 4 * 4) fixes the expected image size; for a
    28x28 input (presumably the MedMNIST default) the spatial size evolves as
    28 -> 26 (layer1) -> 12 (layer2) -> 10 (layer3) -> 8 (layer4) -> 4 (layer5).

    Args:
        in_channels: number of channels in each input image.
        num_classes: number of logits produced per image by `forward`.
    """

    @staticmethod
    def _conv_stage(c_in, c_out, padding=0, pool=False):
        """Return Conv2d(3x3) -> BatchNorm2d -> ReLU, optionally followed by a 2x2 max-pool."""
        modules = [
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=padding),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self, in_channels, num_classes):
        super().__init__()
        stage = self._conv_stage

        # Attribute names (and module order inside each stage) define the
        # state_dict keys, so they are kept as layer1..layer5 and fc.
        self.layer1 = stage(in_channels, 16)
        self.layer2 = stage(16, 16, pool=True)
        self.layer3 = stage(16, 64)
        self.layer4 = stage(64, 64)
        self.layer5 = stage(64, 64, padding=1, pool=True)

        hidden = 128
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Map a (batch, in_channels, H, W) tensor to (batch, num_classes) raw logits."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            features = stage(features)
        # Flatten everything except the batch dimension before the linear head.
        return self.fc(features.view(features.size(0), -1))

# Build the CNN sized for this dataset's channel and class counts.
model = Net(num_classes=n_classes, in_channels=n_channels)

# Cross-entropy over the network's raw logits is the training objective.
criterion = nn.CrossEntropyLoss()

# SGD with momentum 0.9 at the configured learning rate.
optimizer = optim.SGD(params=model.parameters(), momentum=0.9, lr=lr)

## Training
*How does the training process work?*
1. Model takes a batch of images and their corresponding labels.
2. It predicts class probabilities for each image.
3. The loss function calculates the difference between these predictions and the true labels.
4. The gradients of the loss function with respect to each model parameter are computed (backpropagation).
5. The chosen optimizer uses these gradients to update the model's parameters in a direction that reduces the loss.
6. This process repeats for each batch of data over multiple epochs (complete passes through the entire dataset).

*Code*: for each epoch:
1. Set the model to training mode.
2. Loop over `train_loader` (defined above), which splits the training dataset into batches. `inputs` = input images, `targets` = image labels.
3. For each batch, reset the optimiser's gradients to zero, feed the inputs through the network and evaluate the loss function, then perform backpropagation and update the parameters.

In [19]:
for epoch in range(NUM_EPOCHS):
    model.train()  # set model to 'training mode' (affects BatchNorm layers)
    for inputs, targets in tqdm(train_loader, desc=f'epoch {epoch + 1}/{NUM_EPOCHS}'):
        # Each label is a length-1 array, so a batch of labels is (batch, 1);
        # CrossEntropyLoss needs a 1-D tensor of class indices. reshape(-1)
        # keeps a final batch of size 1 as shape (1,), whereas a bare
        # squeeze() would collapse it to a 0-d tensor and the loss would fail.
        targets = targets.reshape(-1).long()

        optimizer.zero_grad()               # clear gradients from the previous step
        outputs = model(inputs)             # forward pass: logits, one row per image
        loss = criterion(outputs, targets)  # compare predicted logits with true classes
        loss.backward()                     # backpropagation
        optimizer.step()                    # update parameters

100%|██████████| 148/148 [00:02<00:00, 52.76it/s]
100%|██████████| 148/148 [00:02<00:00, 53.76it/s]
100%|██████████| 148/148 [00:02<00:00, 56.70it/s]


## Quality Evaluation

In [20]:
def test(split):
    """Evaluate the global `model` on one dataset split and print AUC and accuracy.

    Parameters
    ----------
    split : str
        'train' (scored with `train_loader_at_eval`) or 'test' (scored with
        `test_loader`).

    The softmax scores are handed to medmnist's `Evaluator(data_flag, split)`,
    which is given only the scores and therefore supplies the ground-truth
    labels for `split` itself. This relies on the loader yielding samples in
    dataset order (i.e. not shuffled) so row i of the scores matches label i.

    Prints one line: ``'<split>  auc: <auc>  acc:<acc>'``.

    Raises
    ------
    ValueError
        If `split` has no matching DataLoader. Previously any other value
        (e.g. 'val') silently scored the test loader against that split's labels.
    """
    loaders = {'train': train_loader_at_eval, 'test': test_loader}
    if split not in loaders:
        raise ValueError(f"split must be one of {sorted(loaders)}, got {split!r}")
    data_loader = loaders[split]

    model.eval()  # set model to 'evaluation' mode (BatchNorm uses running statistics)
    batch_scores = []
    with torch.no_grad():  # no gradients needed for inference
        for inputs, _targets in data_loader:
            # Softmax turns each row of logits into per-class probabilities.
            batch_scores.append(model(inputs).softmax(dim=-1))

    # Concatenate once instead of growing a tensor inside the loop.
    y_score = torch.cat(batch_scores, dim=0).numpy()

    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score)

    print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

        
test('test')  # report AUC / accuracy of the trained model on the test split

test  auc: 0.966  acc:0.816
