OrganMNIST3D Basic Rundown:

What is OrganMNIST3D?

OrganMNIST3D, or OM3D for short, is a collection of CT scans of 11 various human organs such as the liver, spleen, pancreas. These scans are stored as 28x28x28 3D images sliced along the three major medical axies of axial, coronal, sagittal as grey scale pixel images.

What is the medical issue within OM3D?

Our medical issue is to make an AI that can correctly determine what human organ is being displayed in each CT scan with high accuracy. This data set is one of few that use 3D cross sections of humans to train AI models due to the effort getting these scan models and the ethical issues concerning the aquistion. 

In [1]:


%matplotlib qt
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # sometimes needed to register 3D
import numpy as np
from matplotlib import colors
from matplotlib.widgets import Slider




from medmnist import OrganMNIST3D

# Or NoduleMNIST3D, AdrenalMNIST3D, etc.
# Change these to match which dataset you've been assigned


train_dataset = OrganMNIST3D(split='train', size=28, download=True)
trainx = []
trainy = []

test_dataset = OrganMNIST3D(split='test', size=28, download=True)
testx = []
testy = []

val_dataset = OrganMNIST3D(split='train', size=28, download=True)
valx = []
valy = []

for i in range(len(train_dataset)):
    trainx.append(train_dataset[i][0])
    trainy.append(train_dataset[i][1])

for i in range(len(test_dataset)):
    testx.append(test_dataset[i][0])
    testy.append(test_dataset[i][1])

for i in range(len(val_dataset)):
    valx.append(val_dataset[i][0])
    valy.append(val_dataset[i][1])


trainx_tensor = tf.convert_to_tensor(trainx, dtype=tf.float16)
trainy_tensor = tf.convert_to_tensor(trainy, dtype=tf.float16)
testx_tensor = tf.convert_to_tensor(testx, dtype=tf.float16)
testy_tensor = tf.convert_to_tensor(testy, dtype=tf.float16)
valx_tensor = tf.convert_to_tensor(valx, dtype=tf.float16)
valy_tensor = tf.convert_to_tensor(valy, dtype=tf.float16)
# float16 doesn't run any faster on the 4090s, but it cuts memory usage in half!



In [2]:
def MyNet():
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1, 28, 28, 28), name='input'), 
        tf.keras.layers.Conv3D(filters=32, kernel_size=3, input_shape=(1, 28, 28, 28), activation='relu', padding='same', name='conv1', data_format='channels_first'),      
        tf.keras.layers.MaxPool3D(pool_size=2, data_format='channels_first'),
        tf.keras.layers.Conv3D(64, kernel_size=3, padding='same', activation='relu', data_format='channels_first'),
        tf.keras.layers.MaxPool3D(pool_size=2, data_format='channels_first'),
        tf.keras.layers.Conv3D(128, kernel_size=3, padding='same', activation='relu', data_format='channels_first'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu', name='dense1'),  
        tf.keras.layers.Dropout(0.4),
        tf.keras.layers.Dense(11, activation='softmax', name='dense2') 

    ])
    model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
    return model

In [3]:

model = MyNet()
model.summary()
training_history = model.fit(trainx_tensor, trainy_tensor, epochs=5, validation_data=(valx_tensor, valy_tensor))

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/5
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 3s/step - accuracy: 0.2019 - loss: 2.1611 - val_accuracy: 0.5386 - val_loss: 1.3781
Epoch 2/5
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 2s/step - accuracy: 0.5160 - loss: 1.2691 - val_accuracy: 0.8074 - val_loss: 0.7617
Epoch 3/5
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 2s/step - accuracy: 0.7405 - loss: 0.7657 - val_accuracy: 0.8702 - val_loss: 0.4959
Epoch 4/5
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 2s/step - accuracy: 0.8054 - loss: 0.6163 - val_accuracy: 0.8929 - val_loss: 0.4252
Epoch 5/5
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 2s/step - accuracy: 0.8960 - loss: 0.3783 - val_accuracy: 0.9104 - val_loss: 0.2697
