<a href='https://ai.meng.duke.edu'><img align="left" style="padding-top:10px;" src="https://storage.googleapis.com/aipi_datasets/Duke-AIPI-Logo.png"></a>

# 3D Classification of Medical Images
In this example notebook we demonstrate how to develop a classification model for 3-dimensional medical image files.  We will use the open-source [MONAI framework](https://monai.io/index.html) to help us.  MONAI is a community-supported, PyTorch-based framework for deep learning in healthcare imaging.

Our objective in this exercise is to classify MR images as male or female.  The dataset we will be using is a very small subset of images from the [IXI Dataset](https://brain-development.org/ixi-dataset/), which contains ~600 MR images from healthy subjects collected at three different hospitals in London.  The images are in [NIfTI](https://radiopaedia.org/articles/nifti-file-format?lang=us) format.

**Notes:**
- This notebook must be run on GPU

**References:**
- This notebook is based on one of the [MONAI tutorials](https://github.com/Project-MONAI/tutorials).  Please see the other tutorials for additional examples of how to use the framework for various medical imaging tasks


In [None]:
# Run this cell only if working in Colab
# Connects to any needed files from GitHub and Google Drive
import os

# Remove Colab default sample_data
!rm -r ./sample_data

# Clone GitHub files to colab workspace
repo_name = "AIPI540-Deep-Learning-Applications" # Enter repo name
git_path = 'https://github.com/AIPI540/AIPI540-Deep-Learning-Applications.git'
!git clone "{git_path}"

# Install dependencies from requirements.txt file
#!pip install -r "{os.path.join(repo_name,'requirements.txt')}"

# Change working directory to location of notebook
notebook_dir = '2_computer_vision/CNNs'
path_to_notebook = os.path.join(repo_name,notebook_dir)
%cd "{path_to_notebook}"
%ls

In [3]:
import os
import urllib.request
import zipfile

import matplotlib.pyplot as plt
import torch
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, EnsureType

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

torch:  1.10 ; cuda:  1.10.1


## Data preparation
### Download data
We are going to use a small subset of the [IXI Dataset](https://brain-development.org/ixi-dataset/).  Note that this data is made available under the Creative Commons [CC BY-SA 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/legalcode).  Start by running the cell below to download the needed image files.

In [5]:
# Download the data (skipped if it has already been extracted)
os.makedirs('data', exist_ok=True)
if not os.path.exists('./data/ixi'):
    url = 'https://storage.googleapis.com/aipi540-datasets/ixi.zip'
    destfile = 'data/ixi.zip'
    urllib.request.urlretrieve(url, filename=destfile)
    # Unzip file to path
    with zipfile.ZipFile(destfile, 'r') as zip_ref:
        zip_ref.extractall('data/')

### Add labels
Since we are working with a very small subsample, we will enter the labels manually.

In [6]:
images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]

datapath = os.path.join('data','ixi')
images = [os.sep.join([datapath, f]) for f in images]

# Create binary labels for man/woman classification
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
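
With only 20 scans, it helps to check how the two classes are distributed across the split before training. The sketch below copies the label values from the cell above into a plain list (named `all_labels` here so it does not overwrite the notebook's `labels` array) and assumes the split used later in this notebook: the first 10 images for training and the last 10 for validation. The helper names are illustrative only.

```python
from collections import Counter

# Label values copied from the cell above
all_labels = [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]

def class_counts(split_labels):
    """Return {class: count} for one split, ordered by class."""
    return dict(sorted(Counter(split_labels).items()))

def majority_baseline(split_labels):
    """Accuracy obtained by always predicting the most frequent class."""
    counts = Counter(split_labels)
    return max(counts.values()) / len(split_labels)

# Assumed split: first 10 images for training, last 10 for validation
train_labels, val_labels = all_labels[:10], all_labels[-10:]

print("train counts:", class_counts(train_labels))
print("val counts:  ", class_counts(val_labels))
print("val majority-class baseline:", majority_baseline(val_labels))
```

The baseline is a useful reference point: a model whose validation accuracy does not exceed it has not learned anything beyond the class frequencies.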

### Create PyTorch Datasets from data and load DataLoaders

In [7]:
# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])

# Create training Dataset and DataLoader using first 10 images
train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

# Create validation Dataset and DataLoader using the rest of the images
val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

# Set up dict for dataloaders
dataloaders = {'train':train_loader,'val':val_loader}

# Store size of training and validation sets
dataset_sizes = {'train':len(train_ds),'val':len(val_ds)}

Each batch of inputs produced by a PyTorch DataLoader here has shape [N, C, H, W, D], where:  
- N = batch size  
- C = number of channels (1 in this case for grayscale)  
- H = image height  
- W = image width
- D = image depth
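
The shape convention can be checked without the IXI files. The sketch below is a rough plain-PyTorch approximation of the `ScaleIntensity`, `AddChannel` and `Resize` steps applied to random volumes; it is not MONAI's implementation. The `preprocess` helper, the raw volume shapes and the `demo_` names are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset

def preprocess(volume, size=(96, 96, 96)):
    """Rough stand-in for ScaleIntensity -> AddChannel -> Resize on one [H, W, D] volume."""
    volume = volume.float()
    # Min-max scale intensities to the range [0, 1]
    volume = (volume - volume.min()) / (volume.max() - volume.min())
    # Add a channel axis: [H, W, D] -> [1, H, W, D]
    volume = volume.unsqueeze(0)
    # Resample to a fixed grid; interpolate expects a batch axis, so add and remove one
    return F.interpolate(volume.unsqueeze(0), size=size, mode="area").squeeze(0)

# Four synthetic "scans" with differing raw shapes (random values, not real MR data)
raw_shapes = [(128, 128, 75), (100, 120, 80), (128, 128, 75), (90, 110, 70)]
volumes = torch.stack([preprocess(torch.rand(s) * 1000) for s in raw_shapes])
targets = torch.tensor([0, 1, 0, 1])

demo_loader = TorchDataLoader(TensorDataset(volumes, targets), batch_size=2)
demo_inputs, demo_labels = next(iter(demo_loader))
print(demo_inputs.shape)  # [N, C, H, W, D]
print(demo_labels.shape)  # [N]
```

With the real data, the equivalent check is to pull one batch from `train_loader` and print its shape.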

### Define our model architecture
We will use MONAI's 3D DenseNet121 in this example.  Its constructor arguments adapt the network to our task:  
- `spatial_dims=3` builds the network with 3D convolutions, since our inputs are volumes rather than 2D images  
- `out_channels=2` gives the final fully connected layer 2 output units, one per class  
- `in_channels=1` gives the first convolutional layer a single input channel, since our images are grayscale rather than 3-channel color

In [8]:
# Instantiate a 3D DenseNet121 (randomly initialized; no pre-trained weights are loaded)
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

### Train the model

In [None]:
def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=5):

    model = model.to(device) # Send model to GPU if available

    iter_num = {'train':0,'val':0} # Track total number of iterations

    best_metric = -1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Get the input images and labels, and send to GPU if available
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the weight gradients
                optimizer.zero_grad()

                # Forward pass to get outputs and calculate loss
                # Track gradient only for training data
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backpropagation to get the gradients with respect to each weight
                    # Only if in train
                    if phase == 'train':
                        loss.backward()
                        # Update the weights
                        optimizer.step()

                # Convert loss into a scalar and add it to running_loss
                running_loss += loss.item() * inputs.size(0)
                # Track number of correct predictions
                running_corrects += torch.sum(preds == labels.data)

                # Iterate count of iterations
                iter_num[phase] += 1

            # Calculate and display average loss and accuracy for the epoch
            # (sample count is taken from the loader rather than a global variable)
            n_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / n_samples
            epoch_acc = running_corrects.double() / n_samples
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Save weights if validation accuracy is the best seen so far
            if phase == 'val' and epoch_acc > best_metric:
                best_metric = epoch_acc.item()
                os.makedirs('models', exist_ok=True)
                torch.save(model.state_dict(), 'models/3d_classification_model.pth')
                print('Saved new best model')

    print(f'Training complete. Best validation set accuracy was {best_metric:.4f}')
    
    return
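
The training loop multiplies each batch loss by the batch size before summing, because `CrossEntropyLoss` returns the mean loss over a batch by default. The sketch below uses made-up per-batch numbers (not results from this notebook) to show how that sample-weighted average differs from a plain mean of batch means when the last batch is smaller.

```python
# Hypothetical per-batch results: (mean loss over the batch, batch size, number correct)
batches = [(0.70, 2, 1), (0.50, 2, 2), (0.90, 1, 0)]  # the last batch is smaller

running_loss = 0.0
running_corrects = 0
n_samples = 0
for mean_loss, batch_size, n_correct in batches:
    running_loss += mean_loss * batch_size  # undo the per-batch averaging
    running_corrects += n_correct
    n_samples += batch_size

epoch_loss = running_loss / n_samples  # sample-weighted average
epoch_acc = running_corrects / n_samples
naive_loss = sum(b[0] for b in batches) / len(batches)  # ignores batch sizes

print(f"sample-weighted loss: {epoch_loss:.4f}")
print(f"mean of batch means:  {naive_loss:.4f}")
print(f"accuracy:             {epoch_acc:.4f}")
```

In this notebook both splits hold 10 images and the batch size is 2, so every batch is full and the two averages coincide. The weighting matters once the dataset size is not a multiple of the batch size.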

In [None]:
# Use cross-entropy loss function
criterion = torch.nn.CrossEntropyLoss()
# criterion = torch.nn.BCEWithLogitsLoss()  # alternative, but needs float targets shaped like the outputs (e.g. one-hot)

# Use Adam adaptive optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Train the model
epochs = 5
train_model(model, criterion, optimizer, dataloaders, device, num_epochs=epochs)
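
Once training has finished, weights saved with `torch.save(model.state_dict(), path)` can be loaded back into a freshly built network of the same architecture. The sketch below shows the round trip with a small stand-in layer and a temporary file so that it runs on its own. The `net`, `restored` and `ckpt_path` names are hypothetical; in this notebook the network would be the DenseNet121 built earlier and the path would be the checkpoint file written during training.

```python
import os
import tempfile

import torch

# Hypothetical stand-in network; in the notebook this would be the DenseNet121 built earlier
net = torch.nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "demo_model.pth")
    # Save only the weights; note the call state_dict(), not the bare attribute
    torch.save(net.state_dict(), ckpt_path)

    # Rebuild the same architecture, then load the saved weights into it
    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

restored.eval()  # switch to evaluation mode before inference
x = torch.ones(1, 4)
with torch.no_grad():
    same = torch.equal(net(x), restored(x))
print("outputs match:", same)
```

Passing `map_location` lets a checkpoint saved on a GPU be loaded on a machine without one.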