[![Licence](https://img.shields.io/badge/license-MIT-blue)](https://opensource.org/license/mit/)

# MONAI Core on AWS Workshop

Set up the notebook environment using the "PyTorch 1.12 Python 3.8 GPU optimized" kernel with the "g4dn.4xlarge" instance type. If your quota is not high enough to start the instance, please [follow the instructions](https://docs.aws.amazon.com/sagemaker/latest/dg/canvas-requesting-quota-increases.html) to raise the quota limit for ml.g4dn.2xlarge notebook first.

This 3D classification model training notebook is based on this existing [DenseNet Training MONAI tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_classification/densenet_training_array.ipynb)
## Download and install libraries

In [None]:
%%sh
# Upgrade pip itself before installing the notebook's dependencies.
pip install -q --upgrade pip
# AWS SDK for Python (boto3 and botocore are imported in the next cell).
pip install -q --upgrade boto3 botocore
# Imaging and utility packages, including joblib, which the notebook uses for threaded downloads.
pip install -q tqdm nibabel pydicom numpy pathlib2 pylibjpeg-openjpeg joblib
# ITK, itkwidgets and the MONAI weekly build with its nibabel/matplotlib/tqdm extras.
pip install -q "itk>=5.3rc4" "itkwidgets[all]>=1.0a23" "itk-io" "monai-weekly[nibabel, matplotlib, tqdm]"

### Import Libraries and Set up AWS HealthImaging Client

In [None]:
import json
import logging 
import boto3
import io
import sys
import time
import os
import pandas as pd
import sagemaker
from sagemaker import get_execution_role
from botocore.exceptions import ClientError
import torch
import monai
from openjpeg import decode
from src.Api import MedicalImaging 
import warnings
warnings.filterwarnings('ignore')
logging.basicConfig( level="INFO" )
# logging.basicConfig( level="DEBUG" )


# Clients used by the notebook: S3 and the project's HealthImaging wrapper.
s3 = boto3.client('s3')
medicalimaging = MedicalImaging()

# Account, region and default bucket of the current credentials.
account_id = boto3.client("sts").get_caller_identity()["Account"]
session = sagemaker.session.Session()
region = boto3.Session().region_name
# Reuse the session created above rather than constructing a second one
# just to look up the default bucket.
bucket = session.default_bucket()
role = f"arn:aws:iam::{account_id}:role/HealthImagingImportJobRole"  ## use this role if you have deployed the CloudFormation template described above
# role = get_execution_role()                ## use this role if you want to use SageMaker Execution role to import image into AWS HealthImaging
print(f"S3 Bucket is {bucket}")
print(f"IAM role for image import job is {role}")

In [None]:
# Restore the variables previously saved with IPython's %store magic.
# NOTE(review): `datastoreId` and `imageSetIds` are used in later cells but are
# not defined anywhere in this notebook - presumably they are restored here; confirm.
%store -r

## Prepare Labels

We prepare binary classification labels for the synthetic brain MRIs. The positive labels for Alzheimer's disease come from the Coherent clinical information.

In [None]:
# Build a one-hot label for every image set and split the image sets into a
# training and a validation dictionary, both keyed by image set ID.
train_imagesets = {}
val_imagesets = {}

alzheimers_studies = pd.read_csv('src/coherent_alzheimers_studyinstanceuid.csv')
positive_list = alzheimers_studies['studyinstanceuid'].to_list()
# Set copy of the positive StudyInstanceUIDs so that each membership test in
# the loop is a hash lookup instead of a scan over the whole list.
positive_studies = set(positive_list)

# Image sets whose 1-based position is below this value go to training; all
# later ones go to validation.
TRAIN_CUTOFF = 200  ## train:validation ratio as 2:1 (NOTE(review): only holds for roughly 300 image sets - confirm)

for counter, imagesetid in enumerate(imageSetIds, start=1):
    json_dicom_header = medicalimaging.getMetadata(datastoreId, imagesetid)
    studyinstanceuid = json_dicom_header['Study']['DICOM']['StudyInstanceUID']
    # One-hot label: [0, 1] when the study is in the positive list, [1, 0] otherwise.
    # A fresh tensor is created for every image set, so no label is shared.
    if studyinstanceuid in positive_studies:
        label = torch.tensor([[0., 1.]])
    else:
        label = torch.tensor([[1., 0.]])
    split = train_imagesets if counter < TRAIN_CUTOFF else val_imagesets
    split[imagesetid] = label

## Prepare Tensors

Two helper functions are used: one retrieves the image frame IDs of a given ImageSet in the AWS HealthImaging datastore, and the other retrieves the pixels of a single frame rescaled to the range 0 to 1.

In [None]:
from joblib import Parallel, delayed
from monai.data import NumpyReader
import time


def getImageFrameIds(datastoreId, imagesetId):
    """Collect the ID of every image frame that belongs to one image set.

    The image set metadata is fetched with ``medicalimaging.getMetadata`` and
    walked as Study -> Series -> Instances -> ImageFrames.

    Args:
        datastoreId: ID of the datastore, passed through to ``getMetadata``.
        imagesetId: ID of the image set whose frames are listed.

    Returns:
        dict with the keys ``'imagesetId'`` (the ID that was passed in) and
        ``'frameIds'`` (list of frame IDs in metadata traversal order).
    """
    study = medicalimaging.getMetadata(datastoreId, imagesetId)["Study"]
    frame_ids = [
        image_frame["ID"]
        for series_entry in study["Series"].values()
        for instance_entry in series_entry["Instances"].values()
        for image_frame in instance_entry["ImageFrames"]
    ]
    return {'imagesetId': imagesetId, 'frameIds': frame_ids}


def getRescaledPixels(datastoreId,  imagesetId, frameId):
    """Fetch one image frame and rescale its intensities to the range [0, 1].

    Args:
        datastoreId: ID of the datastore, passed to ``medicalimaging.getFramePixels``.
        imagesetId: ID of the image set that contains the frame.
        frameId: ID of the image frame to fetch.

    Returns:
        A new floating-point array holding ``(pixel - min) / max(pixel - min)``.
        A constant frame (min == max) comes back as all zeros, so it cannot
        carry raw, unscaled intensities into a stack of normalised frames.
    """
    pixel = medicalimaging.getFramePixels(datastoreId,  imagesetId, frameId)
    # Out-of-place subtraction: the array handed back by getFramePixels is
    # left untouched.
    shifted = pixel - pixel.min()
    peak = shifted.max()
    # For a constant frame the peak is 0; dividing by 1 instead avoids a
    # division by zero while still producing a floating-point result.
    return shifted / (peak if peak > 0 else 1)

Retrieve the imageset frameIds for both training and validation datasets

In [None]:
# Look up the frame IDs of every training image set; the per-image-set calls
# are dispatched through joblib's threading backend.
tic = time.time()
train_jobs = (delayed(getImageFrameIds)(datastoreId, imagesetId) for imagesetId in train_imagesets)
train_imagesets_frameids = Parallel(n_jobs=-1, backend='threading')(train_jobs)
toc = time.time()
print(f"time to retrieve train set metadata: {toc - tic:0.4f} seconds")

# Same lookup for the validation image sets.
tic = time.time()
val_jobs = (delayed(getImageFrameIds)(datastoreId, imagesetId) for imagesetId in val_imagesets)
val_imagesets_frameids = Parallel(n_jobs=-1, backend='threading')(val_jobs)
toc = time.time()
print(f"time to retrieve validation set metadata: {toc - tic:0.4f} seconds")

In [None]:
from monai.data import NumpyReader
import numpy as np
# Shared reader that turns the list of per-frame arrays into one array.
image_reader = NumpyReader()


def getImageTensors(datastoreId, imagesetId, frameIds):
    """Assemble the frames of one image set into a single float tensor.

    Every frame in ``frameIds`` is fetched with ``getRescaledPixels`` (the
    calls are dispatched through joblib's threading backend), the frames are
    handed to ``image_reader.get_data``, and two leading singleton dimensions
    are added to the resulting array.

    Args:
        datastoreId: ID of the datastore that holds the image set.
        imagesetId: ID of the image set to load.
        frameIds: IDs of the frames to fetch, in the order they are combined.

    Returns:
        A ``torch`` float tensor of shape ``(1, 1, *img_data.shape)``.
    """
    fetch_jobs = (delayed(getRescaledPixels)(datastoreId,  imagesetId, frame_id) for frame_id in frameIds)
    frames = Parallel(n_jobs=-1, backend='threading')(fetch_jobs)
    # The metadata returned alongside the array is not used.
    volume, _ = image_reader.get_data(frames)
    # Two new leading axes - presumably channel and batch for the 3D network; confirm.
    with_leading_axis = np.expand_dims(volume, 0)
    return torch.tensor(with_leading_axis).float().unsqueeze(0)
    

In [None]:
# Train on the GPU when CUDA is available, otherwise on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"device is: {device}")

# 3D DenseNet121 with one input channel and two output classes, trained with
# cross-entropy loss and the Adam optimizer (learning rate 1e-4).
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2)
model = model.to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Bookkeeping for the training loop: validation runs whenever the 1-based
# epoch number is a multiple of val_interval.
max_epochs = 3
val_interval = 2
best_metric, best_metric_epoch = -1, -1
epoch_loss_values, metric_values = [], []

In [None]:
# Train for `max_epochs` epochs with one image set per optimisation step, and
# measure accuracy on the validation image sets every `val_interval` epochs.
# The weights of the best-scoring epoch are saved to disk.

# Steps per epoch, used only for the progress output; it does not change
# inside the loops, so it is computed once.
epoch_len = len(train_imagesets)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0

    for te in train_imagesets_frameids:
        step += 1
        # Pixels are downloaded on demand for each image set.
        inputTensor = getImageTensors(datastoreId, te['imagesetId'], te['frameIds'])
        inputs, labels = inputTensor.to(device), train_imagesets[te['imagesetId']].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        step_loss = loss.item()
        epoch_loss += step_loss
        # Report this step's loss; the running total is averaged after the loop.
        print(f"{step}/{epoch_len}, train_loss: {step_loss:.4f}")

    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()

        num_correct = 0.0
        metric_count = 0
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            # Iterate the validation image sets: their labels live in
            # `val_imagesets`, which has no entries for training image sets.
            for ve in val_imagesets_frameids:
                inputTensor = getImageTensors(datastoreId, ve['imagesetId'], ve['frameIds'])
                val_images, val_labels = inputTensor.to(device), val_imagesets[ve['imagesetId']].to(device)
                val_outputs = model(val_images)
                # Compare the predicted class index with the index of the one-hot label.
                value = torch.eq(val_outputs.argmax(dim=1), val_labels.argmax(dim=1))
                metric_count += len(value)
                num_correct += value.sum().item()

        metric = num_correct / metric_count
        metric_values.append(metric)

        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), "best_metric_model_classification3d_array.pth")
            print("saved new best metric model")

        print(f"Current epoch: {epoch+1} current accuracy: {metric:.4f} ")
        print(f"Best accuracy: {best_metric:.4f} at epoch {best_metric_epoch}")

print(f"Training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")