
Classification of Diseased and Healthy 3D Coronary Artery Shapes Using TensorFlow (Minimal Reproducible Example)

In [None]:
!pip install MedShapeNetCore
!pip install tensorflow[and-cuda]

Download the ASOCA dataset

In [None]:
!python -m MedShapeNetCore download ASOCA

Import the necessary packages

In [4]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from MedShapeNetCore.MedShapeNetCore import MyDict,MSNLoader,MSNVisualizer,MSNSaver,MSNTransformer

Load and prepare the dataset

In [9]:
msn_loader = MSNLoader()
ASOCA_DATA = msn_loader.load('ASOCA')
shape_data = ASOCA_DATA['mask']
shape_labels = ASOCA_DATA['labels']
print(shape_data.shape)
print(shape_labels)

# Sanity checks: one label per mask volume, binary targets for the sigmoid head.
# As the printed labels show, they arrive grouped by class (all 1s, then all 0s),
# so any train/validation split downstream must not just take the last samples.
assert len(shape_data) == len(shape_labels), "each mask volume needs exactly one label"
assert set(np.unique(shape_labels)) <= {0, 1}, "expected binary labels"

# Append a trailing single-channel axis, (samples, W, H, D) -> (samples, W, H, D, 1),
# matching the model's (width, height, depth, 1) input. axis=-1 stays correct
# regardless of the input rank.
x_train = np.expand_dims(shape_data, axis=-1)
y_train = shape_labels

current dataset: ./medshapenetcore_npz/medshapenetcore_ASOCA.npz
available keys in the dataset: ['mask', 'point', 'mesh', 'labels']
(40, 256, 256, 256)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Construct the classification model

In [13]:
def get_model(width=256, height=256, depth=256):
    """Build a small 3D CNN for binary classification of single-channel volumes.

    The network stacks five Conv3D (3x3x3 kernel, ReLU, 'same' padding) +
    BatchNormalization stages, two of which downsample with stride 2. It then
    applies global average pooling, a 64-unit ReLU dense layer with 40% dropout,
    and a single sigmoid output unit.

    Args:
        width, height, depth: spatial size of the input volume; the input
            tensor shape is (width, height, depth, 1).

    Returns:
        An uncompiled ``keras.Model`` named "3dcnn".
    """
    # (number of filters, strides) for each convolutional stage, in order.
    conv_stages = (
        (64, (2, 2, 2)),
        (64, (1, 1, 1)),
        (128, (2, 2, 2)),
        (128, (1, 1, 1)),
        (64, (1, 1, 1)),
    )

    volume = keras.Input((width, height, depth, 1))

    features = volume
    for n_filters, stride in conv_stages:
        features = layers.Conv3D(
            filters=n_filters,
            kernel_size=3,
            strides=stride,
            padding="same",
            activation="relu",
        )(features)
        features = layers.BatchNormalization()(features)

    # Classification head.
    features = layers.GlobalAveragePooling3D()(features)
    features = layers.Dense(units=64, activation="relu")(features)
    features = layers.Dropout(0.4)(features)
    probability = layers.Dense(units=1, activation="sigmoid")(features)

    return keras.Model(volume, probability, name="3dcnn")


# Build the model with its input size taken from the prepared data, so the
# network cannot silently disagree with the volumes it will be trained on.
# x_train is (samples, width, height, depth, 1) after the channel axis is added.
vol_width, vol_height, vol_depth = x_train.shape[1:4]
model = get_model(width=vol_width, height=vol_height, depth=vol_depth)
model.summary()

Model: "3dcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 256, 256, 256,    0         
                             1)]                                 
                                                                 
 conv3d_10 (Conv3D)          (None, 128, 128, 128, 6   1792      
                             4)                                  
                                                                 
 batch_normalization_10 (Ba  (None, 128, 128, 128, 6   256       
 tchNormalization)           4)                                  
                                                                 
 conv3d_11 (Conv3D)          (None, 128, 128, 128, 6   110656    
                             4)                                  
                                                                 
 batch_normalization_11 (Ba  (None, 128, 128, 128, 6   256   

Compile and train the model

In [None]:
# compile model
initial_learning_rate = 0.0001
# NOTE: with staircase=True the rate only drops after every 100000 optimizer
# steps. This run (~40 samples, at most 200 epochs) takes far fewer steps, so the
# learning rate effectively stays at initial_learning_rate.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
)

SEED = 42
VAL_FRACTION = 0.20
# Activations are large at this input size: the first conv output alone is
# 128^3 x 64 floats per sample. Keras' default batch_size of 32 would not fit in
# typical GPU memory. Raise this if your hardware allows.
BATCH_SIZE = 2

# Seeds Python, NumPy and TensorFlow for shuffling/dropout in this run.
# The model weights were already initialised when the model was built.
keras.utils.set_random_seed(SEED)


def stratified_holdout(labels, val_fraction, seed):
    """Return (train_idx, val_idx), holding out ~val_fraction of each class.

    Keras' ``validation_split`` always takes the *last* samples, before any
    shuffling. The labels here are grouped by class, so that would produce a
    single-class validation set. Sampling per class keeps both classes in it.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    val_parts = []
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.flatnonzero(labels == cls))
        n_val = int(round(val_fraction * len(cls_idx)))
        val_parts.append(cls_idx[:n_val])
    if val_parts:
        val_idx = np.sort(np.concatenate(val_parts))
    else:
        val_idx = np.array([], dtype=int)
    train_idx = np.setdiff1d(np.arange(len(labels)), val_idx)
    return train_idx, val_idx


train_idx, val_idx = stratified_holdout(y_train, VAL_FRACTION, SEED)
assert len(train_idx) > 0 and len(val_idx) > 0, "train/validation split left one side empty"

checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification.h5", save_best_only=False
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
epochs = 200

# training
# Fancy indexing copies the selected volumes once. `shuffle=True` then shuffles
# the training subset every epoch.
history = model.fit(
    x_train[train_idx],
    y_train[train_idx],
    validation_data=(x_train[val_idx], y_train[val_idx]),
    batch_size=BATCH_SIZE,
    epochs=epochs,
    shuffle=True,
    verbose=1,
    callbacks=[checkpoint_cb, early_stopping_cb],
)