# Classification of Diseased and Healthy 3D Coronary Artery Shapes Using MONAI (PyTorch) (Minimal Reproducible Example)
NOTE:   
- The dataset is very small (40 shapes), the model is made to be trained in +- 20 min on a CPU, thus results are not expected to be good.<br>
- The Notebook can however be used to create a better classifier on larger datasets.<br>
- You can download entire large datasets soon on COSCINE, or create your own with the 'search_and_download_by_name' method, see the PointCloudCompletor_AEmodel, or GettingStarted in the [Samples](https://github.com/GLARKI/MedShapeNet2.0/tree/main/Samples)

## Tested on
- Windows 11, VScode (Jupyter Notebook), GeForce RTX 3050, (CUDA 12.3 - guess not important), environment below.
- A node of the Work Cluster (IKIM - Essen -> KITE project): Dockerized/Kubernetes environment CPU: min(8) max(32), Memory Gi min(16) max(128), GPU: single NVIDIA MIG 20GB.

## Steps (do step 1 manually)
1. Create the virtual environment and activate it within this notebook.
    - conda create -n your_env python=3.10 ipykernel.
    - Install the Pytorch Version, required for your system.
    - Select the env for this ipynb (in vscode the upper right corner 'select kernel' or cntrl+p >python:select interpreter) - You may need to restart the notebook/vscode.
2. PIP installs
3. Imports
4. Search the datasets, and download/inspect the ASOCA dataset.
5. Create labels and numpy masks for training purposes, visualize shapes
6. Create Model
7. Prep data for training (split in validation and training set)
8. Compile and Train the model
9. Run inference
<br><br>
10. Cleaning (manually)
    - Remove the environment: conda remove --name your_env --all
    - Optional: Remove the folders 'model_and_plots', 'msn_downloads'

# Step 2 & 3. PIP installs and Imports

### PIP installs

In [1]:
# install additional requirements
# Note do manually: conda create -n your_env python=3.10 ipykernel # then activate the kernel in the jypnb

# medshapenet to interact with the database
!pip install MedShapeNet # Version  0.1.20

# For notebook functionality
#!pip install jupyter # Version 1.1.1

# for plotting purposes
!pip install matplotlib # Version 3.9.2
!pip install plotly # Version 5.24.1

!pip install monai

^C


In [None]:
!pip list

Package              Version
-------------------- -----------
appnope              0.1.2
argon2-cffi          23.1.0
argon2-cffi-bindings 21.2.0
asttokens            2.0.5
certifi              2024.8.30
cffi                 1.17.1
charset-normalizer   3.3.2
comm                 0.2.1
contourpy            1.3.0
cycler               0.12.1
debugpy              1.6.7
decorator            5.1.1
executing            0.8.3
filelock             3.16.1
fire                 0.7.0
fonttools            4.54.1
fsspec               2024.9.0
idna                 3.10
ipykernel            6.28.0
ipython              8.27.0
jedi                 0.19.1
Jinja2               3.1.4
jupyter_client       8.6.0
jupyter_core         5.7.2
kiwisolver           1.4.7
MarkupSafe           2.1.5
matplotlib           3.9.2
matplotlib-inline    0.1.6
MedShapeNet          0.1.25
minio                7.2.9
monai                1.3.2
mpmath               1.3.0
nest-asyncio         1.6.0
networkx             3.3
numpy 

### Imports

In [1]:
# Imports

# Connection with database MedShapeNet and transformations of 3D files
from MedShapeNet import MedShapeNet as msn
from MedShapeNet import Transformations

# Handle json files (labels)
import json

# Handle paths and operating systems
from pathlib import Path
import os 

# Handle numpy arrays
import numpy as np

#
import torch

# To get a Model
import monai

# for progress bars
from tqdm import tqdm

# load the STL files
from stl import mesh

import plotly.graph_objects as go
import matplotlib.pyplot as plt

# 
from tqdm import tqdm




        This message only displays once when importing MedShapeNet for the first time.

        MedShapeNet API is under construction, more functionality will come soon!

        For information use MedShapeNet.msn_help().
        Alternatively, check the GitHub Page: https://github.com/GLARKI/MedShapeNet2.0

        PLEASE CITE US If you used MedShapeNet API for your (research) project:
        
        @article{li2023medshapenet,
        title={MedShapeNet--A Large-Scale Dataset of 3D Medical Shapes for Computer Vision},
        author={Li, Jianning and Pepe, Antonio and Gsaxner, Christina and Luijten, Gijs and Jin, Yuan and Ambigapathy, Narmada, and others},
        journal={arXiv preprint arXiv:2308.16139},
        year={2023}
        }

        PLEASE USE the def dataset_info(self, bucket_name: str) to find the proper citation alongside MedShapeNet when utilizing a dataset for your resarch project.
        


In [5]:
# Create a MedShapeNet object
msn_instance = msn()

# Print a list of the datasets available
list_of_datasets = msn_instance.datasets(True)

Connection to MinIO server successful.

Download directory already exists at: /Users/jana/Desktop/MedShapeNet2.0-main/Samples/ClassificationWIthMonaiAndTensorflow/msn_downloads
________________________________________________________________________________
1. medshapenetcore/3DTeethSeg
2. medshapenetcore/ASOCA
3. medshapenetcore/AVT
4. medshapenetcore/AutoImplantCraniotomy
5. medshapenetcore/CoronaryArteries
6. medshapenetcore/FLARE
7. medshapenetcore/FaceVR
8. medshapenetcore/KITS
9. medshapenetcore/PULMONARY
10. medshapenetcore/SurgicalInstruments
11. medshapenetcore/ThoracicAorta_Saitta
12. medshapenetcore/ToothFairy
________________________________________________________________________________


# Step 4. Search the datasets, and download/inspect the ASOCA dataset.

### List available datasets, download the ASOCA dataset (part of medshapenetcore), display dataset info (Citation and Licence)

In [7]:
# Create a MedShapeNet object
msn_instance = msn()

# Print a list of the datasets available
list_of_datasets = msn_instance.datasets(True)

# Download the ASOCA dataset as STL (original)
msn_instance.download_dataset(dataset_name='medshapenetcore/ASOCA', download_dir=None, num_threads=10, print_output=False)

# NOTE: Licence and citations are printed automatically when the dataset is downloaded, but can also be requested.
msn_instance.dataset_info('medshapenetcore/ASOCA')

Connection to MinIO server successful.

Download directory already exists at: /Users/jana/Desktop/MedShapeNet2.0-main/Samples/ClassificationWIthMonaiAndTensorflow/msn_downloads
________________________________________________________________________________
1. medshapenetcore/3DTeethSeg
2. medshapenetcore/ASOCA
3. medshapenetcore/AVT
4. medshapenetcore/AutoImplantCraniotomy
5. medshapenetcore/CoronaryArteries
6. medshapenetcore/FLARE
7. medshapenetcore/FaceVR
8. medshapenetcore/KITS
9. medshapenetcore/PULMONARY
10. medshapenetcore/SurgicalInstruments
11. medshapenetcore/ThoracicAorta_Saitta
12. medshapenetcore/ToothFairy
________________________________________________________________________________


Downloading files:   0%|          | 0/43 [00:03<?, ?it/s]


### Load a list of the files within the dataset

In [6]:
# Store all SLT files in a list
msn_instance = msn()
stl_files = msn_instance.dataset_files('medshapenetcore/ASOCA', file_extension='.stl', print_output=True)

Connection to MinIO server successful.

Download directory already exists at: /Users/jana/Desktop/MedShapeNet2.0-main/Samples/ClassificationWIthMonaiAndTensorflow/msn_downloads
________________________________________________________________________________
Files and overview of dataset: medshapenetcore/ ASOCA
 
File: ASOCA/0_CoronaryArtery.stl
File: ASOCA/10_CoronaryArtery.stl
File: ASOCA/11_CoronaryArtery.stl
File: ASOCA/12_CoronaryArtery.stl
File: ASOCA/13_CoronaryArtery.stl
File: ASOCA/14_CoronaryArtery.stl
File: ASOCA/15_CoronaryArtery.stl
File: ASOCA/16_CoronaryArtery.stl
File: ASOCA/17_CoronaryArtery.stl
File: ASOCA/18_CoronaryArtery.stl
File: ASOCA/19_CoronaryArtery.stl
File: ASOCA/1_CoronaryArtery.stl
File: ASOCA/20_CoronaryArtery.stl
File: ASOCA/21_CoronaryArtery.stl
File: ASOCA/22_CoronaryArtery.stl
File: ASOCA/23_CoronaryArtery.stl
File: ASOCA/24_CoronaryArtery.stl
File: ASOCA/25_CoronaryArtery.stl
File: ASOCA/26_CoronaryArtery.stl
File: ASOCA/27_CoronaryArtery.stl
File: AS

# Step 5. Create labels and numpy masks for training purposes, visualize shapes

### Load labels

In [7]:
# Load the JSON data into a Python list (of dictionaries)
path_labels = msn_instance.download_dir / Path('ASOCA') / Path('asoca_labels.json')

try:
    with open(path_labels, 'r') as file:
        labels_data = json.load(file)
        print("Keys in the JSON data:", labels_data[0].keys())
except FileNotFoundError:
    print("Error: The specified labels file does not exist. Check the file path and if there is a json file inside the dataset, e.g. msn_instance.dataset_info(dataset_name)")
except Exception as e:
    print(f"An unexpected error occurred: {e}")

# Inspect the labels
print(f'\nTotal entries: {len(labels_data)} Example of content for first entry: {labels_data[0]}')

Keys in the JSON data: dict_keys(['shape_id', 'label', 'disease'])

Total entries: 40 Example of content for first entry: {'shape_id': 0, 'label': 1, 'disease': 'Pathological'}


In [8]:
bool_create_labels_and_stacked_numpy_masks = True
# create msn_tranform_instance from class Transformation to transform STL files to binary masks
msn_transform_instance = Transformations()

# Path to STL files
path_stls=msn_instance.download_dir / Path('ASOCA') 

### Convert STL dataset to binary masks (Can be skipped)
- But this automates the selection/preperation of test-, validation-, training- sets later.
- The other cells "### Create binary masks and labels for training" is multithreaded conversion of the STL to Binary mask with MedShapenet.Transformations().dataset_to_binary_masks()

In [None]:
# You can skip this: 
# To not run the unneeded cells and make the notebook runnable in one go.
bool_create_labels_and_stacked_numpy_masks = True

# create msn_tranform_instance from class Transformation to transform STL files to binary masks
msn_transform_instance = Transformations()

# Path to STL files
path_stls=msn_instance.download_dir / Path('ASOCA') 

# create and save binary mask for each stl and store it in msn_downloads/asoca
with tqdm(total=len(list(path_stls.glob("*.stl"))), desc="Processing STL files", unit="files") as pbar:
    for stl_file in path_stls.glob("*.stl"):
        msn_transform_instance.stl_to_binary_mask(stl_file, save_npz=True)
        pbar.update(1)

Processing STL files: 100%|██████████| 40/40 [01:18<00:00,  1.95s/files]


In [9]:
# You can skip this: The step 'Create binary masks and labels for training the classifier' does it all at once and multithreaded.

# Create binar mask as a one larg numpy, create labels -> for training

# define path to mask
path_masks = path_stls / Path('binary_masks')

# create list for binary_masks and labels
binary_masks = []
labels = []
shape_ids = []

# from the json file (loaded as labels_data in earlier step) -> map shape_id to label
shape_id_to_label = {entry['shape_id']: entry['label'] for entry in labels_data}

# loop over all paths
for path_mask in path_masks.glob('*.npz'):
    # # add mask to binary_masks_data
    mask = np.load(path_mask)
    binary_masks.append(mask['binary_mask'])

    # get corresponding label and add to labels
    shape_id = int(path_mask.stem.split('_')[0])
    shape_ids.append(shape_id) # if we want to call a specific shape and id later.
    label = shape_id_to_label[shape_id]
    labels.append(label)

# create the stacked binary mask numpy
binary_masks_data = np.stack(binary_masks, axis=0)

# print overview of results
print("Shape of the combined binary masks: ", binary_masks_data.shape, " , with ", len(labels),' labels.')


Shape of the combined binary masks:  (40, 256, 256, 256)  , with  40  labels.


In [10]:
# create indices for the test / validation / training sets

# Total number of labels
total_indices = len(labels)

# Calculate sizes for each set
validation_size = round(total_indices * 0.10)  # 10% for validation
training_size = total_indices - validation_size  # 90% Remaining for training

# Get the indices
indices = list(range(total_indices))

# Shuffle indices to randomize the selection (optional)
np.random.seed(7)
np.random.shuffle(indices)

# Split indices
validation_indices = indices[:validation_size]  # First 10% for validation
training_indices = indices[validation_size:]    # Remaining 90% for training

# Print the indices and corresponding labels
print("Validation indices:", validation_indices, "with labels:", [labels[i] for i in validation_indices])
print("Training indices: ", training_indices, ", \nwith labels:", [labels[i] for i in training_indices],'\n')
print(f'All labels: {labels} \nLabels -> 1: Pathological | 0: healthy')


Validation indices: [17, 37, 34, 18] with labels: [0, 0, 1, 0]
Training indices:  [32, 1, 22, 2, 9, 36, 29, 21, 13, 27, 5, 15, 20, 24, 11, 0, 12, 30, 16, 39, 6, 7, 31, 26, 10, 33, 38, 8, 35, 14, 28, 23, 19, 3, 25, 4] , 
with labels: [1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1] 

All labels: [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0] 
Labels -> 1: Pathological | 0: healthy


### Create binary masks and labels for training 
- *(multithreaded approach using MedShapeNet's Transformations().dataset_to_binary_masks() method)*
- Should be skipped if the cells from 'Convert STL dataset to binary masks' were executed.

In [11]:
if not bool_create_labels_and_stacked_numpy_masks:
    # create one large numpy binary mask representing all shapes, and a list of their names in correct order
    msn_transform_instance = Transformations()
    binary_masks_data, file_names_binary_masks = msn_transform_instance.dataset_to_binary_masks(dataset='medshapenetcore/ASOCA', num_threads=5, grid_size=(256, 256, 256))

    # print shape of masks and lenght of file names for confirmation
    print("\nShape of the combined binary masks: ", binary_masks_data.shape, " , with length of file names: ", len(file_names_binary_masks),'\n')

    # create labels that correspond to the binary masks by mapping shape_id to label
    shape_id_to_label = {entry['shape_id']: entry['label'] for entry in labels_data}

    # Initialize an empty list labels and order these to match entries of binary_masks
    labels =[]
    for filename in file_names_binary_masks:
        # Extract shape_id from the filename
        shape_id = int(filename.split('_')[0])  

        # Get the corresponding label, defaulting to None if shape_id not found
        label = shape_id_to_label.get(shape_id, None)
        labels.append(label)

    # print the labels
    print(f'labels in corresponding order: {labels} \n1: Pathological | 0: healthy')
else:
    pass


In [12]:
# to manually check and make indices for validation- and training- set.
if not bool_create_labels_and_stacked_numpy_masks:
    # Check which labels to use for train/validation/test set. 
    print(labels) 
    print(labels[31:35]) # Validation
    print(labels[10:14]) # Test
else:
    pass

### Visualation of 3D shapes (binary mask / STL)

In [13]:
# function to plot a binary mask as a pointcloud
def plot_binary_mask_scatter(binary_mask):
    """Plots a binary mask as a 3D scatter plot using Plotly."""

    # Get coordinates of the 1s in the binary mask
    x, y, z = np.where(binary_mask)

    # Create and show the scatter plot
    go.Figure(data=go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=2, color='rgba(138, 0, 0, 0.8)', opacity=0.8))).update_layout(
        scene=dict(xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis', aspectmode='data'),
        title='3D Scatter Plot of Binary Mask', width=800, height=800
    ).show()

# function to plot an stl as pointcloud
def plot_stl_file(stl_file):
    """Plots the vertices of an STL file as a 3D scatter plot."""

    # Load the STL file and get vertices
    vertices = mesh.Mesh.from_file(stl_file).vectors.reshape(-1, 3)

    # Create and show the scatter plot
    go.Figure(data=go.Scatter3d(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
                                  mode='markers', marker=dict(size=2, color='rgba(138, 0, 0, 0.8)', opacity=0.8))
              ).update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
                              title='3D Scatter Plot of STL File', width=800, height=800).show()

# Plot a shape from the binary masks
idx = 0
plot_binary_mask_scatter(binary_masks_data[idx])

# Plot the same shape from the downloaded STL files
stl_path = Path('msn_downloads') / Path('ASOCA') /  Path(f'{shape_ids[idx]}_CoronaryArtery.stl')
plot_stl_file(stl_path)


# Step 6. Create Model

In [14]:
model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2)

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

### Check is we use CPU or GPU

In [15]:
# Check if CUDA is available
if torch.cuda.is_available():
    device = 'cuda:0'
    print("CUDA Built:", torch.cuda.is_available())
else:
    device = 'cpu'
    print('Cuda is not available.')

# Step 7. Prep data for training (split in validation and training set)

In [16]:
# Split the dataset 10% for testing, 10% for validation, 90% for training
if bool_create_labels_and_stacked_numpy_masks:
    # Create validation set from specified indices
    x_val = np.array([binary_masks_data[i] for i in validation_indices])
    y_val = np.array([labels[i] for i in validation_indices])

    # Create training set from specified indices
    x_train = np.array([binary_masks_data[i] for i in training_indices])
    y_train = np.array([labels[i] for i in training_indices])
    
else:
    print('You have to check manually which labels fit best from earlier cells - alternatively you can just go with this')
    # 40 samples, so we take 4 (10%) samples for validation.
    x_val = np.array(binary_masks_data[31:35]) # shapes
    y_val = np.array(labels[31:35]) # labels

    # 40 samples, so we take 36 (90%) samples for training.
    # Create the training set by excluding indices 31 to 34
    x_train = np.concatenate((binary_masks_data[:10], binary_masks_data[14:31], binary_masks_data[35:]), axis=0)  # Shapes
    y_train = np.concatenate((labels[:10], labels[14:31], labels[35:]), axis=0)

# print
print(f'\nShape of the validation set: {x_val.shape}')
print(f'Shape of the training set: {x_train.shape}')


Shape of the validation set: (4, 256, 256, 256)
Shape of the training set: (36, 256, 256, 256)


# Step 8. Compile and Train the model

In [21]:
max_epochs = 1

for epoch in range(max_epochs):
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    train_index = np.random.permutation(range(len(x_train)))
    for i in train_index:                                    
        torch.cuda.empty_cache()
        inputs = torch.tensor(np.expand_dims(np.expand_dims(x_train[i],axis=0),axis=0),dtype=torch.float32).to(device)
        labels = torch.tensor(np.expand_dims(y_train[i],axis=0),dtype=torch.float32).type(torch.LongTensor).to(device)
        torch.cuda.empty_cache()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        print(f"train_loss: {loss.item():.4f}",end='\r')

epoch 1/1
train_loss: 0.4331

# Step 9. Run inference

In [20]:
model.eval()
with torch.no_grad():
    for i in range(len(x_val)):
        torch.cuda.empty_cache()
        inputs = torch.tensor(np.expand_dims(np.expand_dims(x_val[i],axis=0),axis=0),dtype=torch.float32).to(device)
        labels = torch.tensor(np.expand_dims(y_val[i],axis=0),dtype=torch.float32).type(torch.LongTensor).to(device)
        torch.cuda.empty_cache()
        outputs = torch.argmin(torch.nn.LogSoftmax(dim=1)(model(inputs)))

        print(f"Prediction: {outputs}, Label: {y_val[i]}")

torch.Size([1, 2])
Prediction: 0, Label: 0
torch.Size([1, 2])
Prediction: 0, Label: 0
torch.Size([1, 2])
Prediction: 0, Label: 1
torch.Size([1, 2])
Prediction: 0, Label: 0


# Step 10 Cleaning (manually)
- Remove the environment: conda remove --name your_env --all
- Optional: Remove the folders 'model_and_plots', 'msn_downloads'