<a href='https://ai.meng.duke.edu'> = <img align="left" style="padding-top:10px;" src=https://storage.googleapis.com/aipi_datasets/Duke-AIPI-Logo.png>

# 3D Classification of Medical Images
In this example notebook we will demonstrate how to develop a classification model which uses 3-dimensional medical image files as inputs.  We will use the open-source [MONAI framework](https://monai.io/index.html, a community-supported PyTorch-based framework for deep learning in healthcare imaging, which contains some nice functionality for working with 3D medical images.

Our objective in this exercise will be to try to classify MRI images into male/female.  The dateset we will be using is a very small subset of images from the [IXI Dataset](https://brain-development.org/ixi-dataset/) which contains ~600 MRI images from healthy subjects collected from three different hospitals in London.  The images are in the [NIFTI](https://radiopaedia.org/articles/nifti-file-format?lang=us) format.

**Notes:**
- This notebook should be run on GPU, but can also be run on CPU in ~10 minutes if using a small number of training epochs for demonstration purposes

**References:**
- This notebook is based on one of the [MONAI tutorials](https://github.com/Project-MONAI/tutorials).  Please see the other tutorials for additional examples of how to use the framework for various medical imaging tasks


In [None]:
# Run this cell only if working in Colab
# Connects to any needed files from GitHub and Google Drive
import os

# Remove Colab default sample_data
!rm -r ./sample_data

# Clone GitHub files to colab workspace
repo_name = "AIPI540-Deep-Learning-Applications" # Enter repo name
git_path = 'https://github.com/AIPI540/AIPI540-Deep-Learning-Applications.git'
!git clone "{git_path}"

# Install dependencies from requirements.txt file
!pip install -r "{os.path.join(repo_name,'requirements.txt')}"

# Change working directory to location of notebook
notebook_dir = '2_computer_vision/CNNs'
path_to_notebook = os.path.join(repo_name,notebook_dir)
%cd "{path_to_notebook}"
%ls

In [2]:
import urllib
import zipfile
import os

import matplotlib.pyplot as plt
import torch
import numpy as np

import monai
from monai.data import DataLoader, ImageDataset
from monai.transforms import AddChannel, Compose, RandRotate90, Resize, ScaleIntensity, EnsureType

pin_memory = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)

torch:  1.10 ; cuda:  1.10.1


## Data preparation
### Download data
We are going to use a small subset of the [IXI Dataset](https://brain-development.org/ixi-dataset/).  Note that this data is made available under the Creative Commons [CC BY-SA 3.0 license](https://creativecommons.org/licenses/by-sa/3.0/legalcode).  Start by running the cell below to download the needed image files.

In [3]:
# Download the data
if not os.path.exists('./data'):
    os.mkdir('./data')
if not os.path.exists('./data/ixi'):
    url = 'https://storage.googleapis.com/aipi540-datasets/ixi.zip'
    destfile = 'data/ixi.zip'
    urllib.request.urlretrieve(url,filename=destfile)
    #Unzip file to path
    zip_ref = zipfile.ZipFile(destfile, 'r')
    zip_ref.extractall('data/')
    zip_ref.close()

### Add labels
Since we are working with a very small subsample we will manually input the labels.

In [4]:
images = [
        "IXI314-IOP-0889-T1.nii.gz",
        "IXI249-Guys-1072-T1.nii.gz",
        "IXI609-HH-2600-T1.nii.gz",
        "IXI173-HH-1590-T1.nii.gz",
        "IXI020-Guys-0700-T1.nii.gz",
        "IXI342-Guys-0909-T1.nii.gz",
        "IXI134-Guys-0780-T1.nii.gz",
        "IXI577-HH-2661-T1.nii.gz",
        "IXI066-Guys-0731-T1.nii.gz",
        "IXI130-HH-1528-T1.nii.gz",
        "IXI607-Guys-1097-T1.nii.gz",
        "IXI175-HH-1570-T1.nii.gz",
        "IXI385-HH-2078-T1.nii.gz",
        "IXI344-Guys-0905-T1.nii.gz",
        "IXI409-Guys-0960-T1.nii.gz",
        "IXI584-Guys-1129-T1.nii.gz",
        "IXI253-HH-1694-T1.nii.gz",
        "IXI092-HH-1436-T1.nii.gz",
        "IXI574-IOP-1156-T1.nii.gz",
        "IXI585-Guys-1130-T1.nii.gz",
    ]

datapath = os.path.join('data','ixi')
images = [os.sep.join([datapath, f]) for f in images]

# Create binary labels for man/woman classification
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

### Create PyTorch Datasets from data and load DataLoaders
Rather than using PyTorch's base implementations of the Dataset and DataLoader, we will use MONAI's DataSet and DataLoader which are wrappers around their PyTorch equivalents which add some additional functionality.

In [11]:
# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), RandRotate90(), EnsureType()])
val_transforms = Compose([ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])

# Create training Dataset and DataLoader using first 10 images
train_ds = ImageDataset(image_files=images[:10], labels=labels[:10], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=torch.cuda.is_available())

# Create validation Dataset and DataLoader using the rest of the images
val_ds = ImageDataset(image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available())

# Set up dict for dataloaders
dataloaders = {'train':train_loader,'val':val_loader}

# Store size of training and validation sets
dataset_sizes = {'train':len(train_ds),'val':len(val_ds)}

im, label = monai.utils.misc.first(train_loader)
print(f'Image type: {type(im)}')
print(f'Input batch shape: {im.shape}')
print(f'Label batch shape: {label.shape}')

Image type: <class 'torch.Tensor'>
Input batch shape: torch.Size([2, 1, 96, 96, 96])
Label batch shape: torch.Size([2])


As we see above, the 3D inputs in our DataLoader are of shape [N,C,H,W,D) where:  
- N = batch size  
- C = number of channels (1 in this case for grayscale)  
- H = image height  
- W = image width
- D = image depth

### Define our model architecture
We will used a pre-trained DenseNet 121 model for this task.

In [12]:
# Load a pre-trained DenseNet121
# We have a signle input channel, and we have 2 output classes
# We set spatial_dims=3 to indicate we want to use the version suitable for 3D input images
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)

### Train the model

In [7]:
def train_model(model, criterion, optimizer, dataloaders, device, num_epochs=5):

    model = model.to(device) # Send model to GPU if available

    iter_num = {'train':0,'val':0} # Track total number of iterations

    best_metric = -1

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Get the input images and labels, and send to GPU if available
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the weight gradients
                optimizer.zero_grad()

                # Forward pass to get outputs and calculate loss
                # Track gradient only for training data
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backpropagation to get the gradients with respect to each weight
                    # Only if in train
                    if phase == 'train':
                        loss.backward()
                        # Update the weights
                        optimizer.step()

                # Convert loss into a scalar and add it to running_loss
                running_loss += loss.item() * inputs.size(0)
                # Track number of correct predictions
                running_corrects += torch.sum(preds == labels.data)

                # Iterate count of iterations
                iter_num[phase] += 1

            # Calculate and display average loss and accuracy for the epoch
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Save weights if accuracy is best
            if phase=='val':
                if epoch_acc > best_metric:
                    best_metric = epoch_acc
                    if not os.path.exists('./models'):
                        os.mkdir('./models')
                    torch.save(model.state_dict,'models/3d_classification_model.pth')
                    print('Saved best new model')

    print(f'Training complete. Best validation set accuracy was {best_metric}')
    
    return

In [8]:
# Use cross-entropy loss function
criterion = torch.nn.CrossEntropyLoss()
# loss_function = torch.nn.BCEWithLogitsLoss()  # also works with this data

# Use Adam adaptive optimizer
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# Train the model
epochs=5
train_model(model, criterion, optimizer, dataloaders, device, num_epochs=epochs)

Epoch 0/4
----------
train Loss: 0.7231 Acc: 0.5000
val Loss: 0.6714 Acc: 0.6000
Saved best new model
Epoch 1/4
----------
train Loss: 0.5776 Acc: 0.7000
val Loss: 0.6721 Acc: 0.6000
Epoch 2/4
----------
train Loss: 0.6822 Acc: 0.7000
val Loss: 0.6755 Acc: 0.6000
Epoch 3/4
----------
train Loss: 0.5370 Acc: 0.7000
val Loss: 0.6817 Acc: 0.5000
Epoch 4/4
----------
train Loss: 0.5039 Acc: 0.8000
val Loss: 0.6553 Acc: 0.6000
Training complete. Best validation set accuracy was 0.6
