Classification of Diseased and Healthy 3D Coronary Artery Shapes Using MONAI (Minimal Reproducible Example)

Installation

In [3]:
# Pin the versions resolved in the recorded install output so the example stays reproducible.
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install monai==1.3.0
%pip install MedShapeNetCore==0.1.0

Collecting monai
  Downloading monai-1.3.0-202310121228-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: monai
Successfully installed monai-1.3.0
Collecting MedShapeNetCore
  Downloading MedShapeNetCore-0.1.0-py3-none-any.whl (6.1 kB)
Collecting fire (from MedShapeNetCore)
  Downloading fire-0.5.0.tar.gz (88 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.3/88.3 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting trimesh (from MedShapeNetCore)
  Downloading trimesh-4.0.9-py3-none-any.whl (689 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m689.0/689.0 kB[0m [31m24.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting SimpleITK (from MedShapeNetCore)
  Downloading SimpleITK-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [9

Download the ASOCA dataset

In [4]:
# Fetch the ASOCA subset of MedShapeNetCore; the recorded run saved it to
# ./medshapenetcore_npz/medshapenetcore_ASOCA.npz.
# Run the kernel's own interpreter so the package installed above is the one invoked
# (a bare `python` on PATH may belong to a different environment).
import sys
!"{sys.executable}" -m MedShapeNetCore download ASOCA

downloading...
[################################] 42842/42842 - 00:00:12
download complete...
file directory: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz


Import the necessary packages

In [5]:
import os
import sys
import monai
import torch
import numpy as np
from MedShapeNetCore.MedShapeNetCore import MyDict,MSNLoader,MSNVisualizer,MSNSaver,MSNTransformer
# Fix the RNG seeds so weight initialisation (and any NumPy-based sampling) is
# repeatable under Restart & Run All. torch.manual_seed seeds all devices.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# Not used in the visible cells; presumably intended for a DataLoader — verify.
pin_memory = torch.cuda.is_available()
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('using:', device)



using: cuda


Load and prepare the shape data

In [7]:
# Load the ASOCA subset; the loader reports the keys 'mask', 'point', 'mesh', 'labels'.
msn_loader = MSNLoader()
ASOCA_DATA = msn_loader.load('ASOCA')
shape_data = ASOCA_DATA['mask']      # one voxel volume per shape (recorded: (40, 256, 256, 256))
shape_labels = ASOCA_DATA['labels']  # one float class label (0./1.) per shape
print(shape_data.shape)
print(shape_labels)
# Every volume must have exactly one label.
assert len(shape_data) == len(shape_labels), "mask/label count mismatch"
# Pass num_classes explicitly: otherwise one_hot infers the width from the largest
# label present and silently yields a 1-column encoding for an all-zero subset.
# Keep this in sync with the classifier's out_channels=2.
labels = torch.nn.functional.one_hot(
    torch.as_tensor(shape_labels).to(torch.int64), num_classes=2
).float()  # one-hot encoding, shape (num_shapes, 2)
print(labels.shape)

current dataset: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz
available keys in the dataset: ['mask', 'point', 'mesh', 'labels']
(40, 256, 256, 256)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
torch.Size([40, 2])


Define a DenseNet-based classification model, the loss function, and the optimizer

In [8]:
# 3D DenseNet-121 classifier: single-channel volume in, two class scores out.
model = monai.networks.nets.DenseNet121(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
).to(device)

# CrossEntropyLoss applies log-softmax internally, so it is fed the model's
# unnormalised scores together with integer class-index targets.
loss_function = torch.nn.CrossEntropyLoss()

# Adam with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

Train the model for 200 epochs

In [None]:
import gc

max_epochs = 200  # full passes over the dataset

# Reclaim GPU memory left over from earlier executions before training starts.
# Collect unreachable Python objects first so their CUDA tensors return to the
# caching allocator; only then can empty_cache() release those blocks.
gc.collect()
torch.cuda.empty_cache()
# Train for max_epochs passes; each step feeds a single volume (batch size 1).
for epoch in range(max_epochs):
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0  # running sum of the scaled per-sample losses in this epoch
    # NOTE(review): samples are visited in a fixed order, and the printed labels show all
    # class-1 shapes before all class-0 shapes — consider shuffling the indices each epoch.
    # NOTE(review): no held-out split is visible; every shape is used for training.
    for i in range(len(shape_labels)):
        # NOTE(review): emptying the CUDA cache on every step (here and again below) adds
        # overhead without freeing live tensors — consider removing from the inner loop.
        torch.cuda.empty_cache()
        # Add batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
        inputs = torch.tensor(np.expand_dims(np.expand_dims(shape_data[i],axis=0),axis=0),dtype=torch.float32).to(device)
        # Class-index target of shape (1,), cast to int64 (LongTensor) for CrossEntropyLoss.
        # NOTE(review): this rebinds `labels`, overwriting the one-hot tensor built at load time.
        labels = torch.tensor(np.expand_dims(shape_labels[i],axis=0),dtype=torch.float32).type(torch.LongTensor).to(device)
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        outputs = model(inputs)
        # NOTE(review): the factor 20 rescales the loss and its gradients; with Adam a uniform
        # gradient scale largely cancels out — confirm this constant is intended.
        loss = 20*loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"train_loss: {loss.item():.4f}")