Preamble: Import Libraries, Download Dataset, Preprocess Data

In [None]:
%matplotlib qt
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # sometimes needed to register 3D
import numpy as np
from matplotlib import colors
import h5py
from matplotlib.widgets import Slider




from medmnist import AdrenalMNIST3D

# This notebook uses AdrenalMNIST3D. To use a different MedMNIST 3D dataset
# (e.g. NoduleMNIST3D), change both the import above and DATASET_CLS below.
DATASET_CLS = AdrenalMNIST3D
IMG_SIZE = 28  # edge length (in voxels) of each cubic volume requested from MedMNIST

train_dataset = DATASET_CLS(split='train', size=IMG_SIZE, download=True)
val_dataset = DATASET_CLS(split='val', size=IMG_SIZE, download=True)


def _split_samples(dataset):
    """Collect every (image, label) pair of a dataset split into two parallel lists.

    Each sample is fetched with a single ``dataset[i]`` call. Indexing may run
    per-sample preprocessing, so looking a sample up twice (once for the image,
    once for the label) would repeat that work.

    Returns:
        (images, labels): lists in dataset order.
    """
    images, labels = [], []
    for i in range(len(dataset)):
        sample = dataset[i]
        images.append(sample[0])
        labels.append(sample[1])
    return images, labels


def _to_model_arrays(images, labels):
    """Convert per-sample lists into float16 model inputs and targets.

    The channel axis (axis 1, size 1 here) is moved last:
    (N, 1, D, H, W) -> (N, D, H, W, 1), the channels-last layout the Conv3D
    model below is built for. np.transpose hands back a NumPy array, so the
    returned x is an ndarray despite the *_tensor names used below.
    """
    x = tf.convert_to_tensor(images, dtype=tf.float16)
    x = np.transpose(x, (0, 2, 3, 4, 1))
    y = tf.convert_to_tensor(labels, dtype=tf.float16)
    return x, y


# Raw per-sample lists (trainx is also used directly by the visualization cells).
trainx, trainy = _split_samples(train_dataset)
valx, valy = _split_samples(val_dataset)

# Training set
trainx_tensor, trainy_tensor = _to_model_arrays(trainx, trainy)

# Validation set
valx_tensor, valy_tensor = _to_model_arrays(valx, valy)

# float16 (vs. float32) halves the memory used by these stored arrays; per the
# original author it gave no speedup on the RTX 4090s.





2025-12-09 13:00:25.974628: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-09 13:00:37.664413: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
2025-12-09 13:00:37.667195: W external/local_xla/xla/tsl/framework/cpu_allocator_impl.cc:84] Allocation of 52157952 exceeds 10% of free system memory.


Preamble: Understand and Visualize Data

In [2]:
SAMPLE_IDX = 1  # which training volume to inspect in the visualization cells

# Drop the leading singleton channel axis: (1, D, H, W) -> (D, H, W),
# i.e. (28, 28, 28) at size=28. `vol` is reused by the following cells.
vol = np.squeeze(trainx[SAMPLE_IDX], axis=0)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Only voxels with a strictly positive value are drawn.
filled = vol > 0

# Map every voxel value to an RGBA colour; cmap(norm(vol)) has shape vol.shape + (4,).
norm = colors.Normalize(vmin=vol.min(), vmax=vol.max())
cmap = plt.cm.viridis  # any Matplotlib colormap works: 'plasma', 'inferno', 'magma', ...
facecolors = cmap(norm(vol))

# Use the raw voxel value as opacity (0 -> transparent, 1 -> opaque).
# The clip assumes values already lie in [0, 1]; for other ranges use norm(vol).
facecolors[..., 3] = np.clip(vol, 0, 1)

# Unfilled voxels are not drawn anyway; force them fully transparent regardless.
facecolors[~filled, 3] = 0.0

ax.voxels(filled, facecolors=facecolors)

ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')
ax.set_zlabel('Z-axis')
ax.set_title('Voxel colors and transparency by magnitude')
plt.show()

In [3]:
# Middle index along each of the three spatial axes of `vol`.
i_mid, j_mid, k_mid = (n // 2 for n in vol.shape[:3])

# (plane name, index letter, array axis, slice index).
# NOTE(review): the axial/coronal/sagittal names assume axes 0/1/2 of the
# MedMNIST volume follow that anatomical order — confirm against the dataset docs.
views = [
    ('Axial', 'i', 0, i_mid),
    ('Coronal', 'j', 1, j_mid),
    ('Sagittal', 'k', 2, k_mid),
]

fig, axes = plt.subplots(1, len(views), figsize=(12, 4))

for panel, (plane, letter, axis, idx) in zip(axes, views):
    # np.take(vol, idx, axis=a) is vol[idx, :, :] / vol[:, idx, :] / vol[:, :, idx].
    panel.imshow(np.take(vol, idx, axis=axis), cmap='gray')
    panel.set_title(f'{plane} ({letter}={idx})')
    panel.axis('off')

plt.tight_layout()
plt.show()

In [4]:
num_slices = vol.shape[0]
cols = 4
# Enough rows to show every slice (7 rows for 28-slice volumes); the old fixed
# 7x4 grid silently dropped slices for larger volumes (e.g. size=64).
rows = max(1, -(-num_slices // cols))

# One grey range for the whole volume so brightness is comparable across slices.
vmin, vmax = vol.min(), vol.max()

# squeeze=False keeps `axes` 2-D even when rows == 1, so .flat always works.
fig, axes = plt.subplots(rows, cols, figsize=(2.5 * cols, 2.6 * rows), squeeze=False)

for i, ax in enumerate(axes.flat):
    ax.axis('off')  # hide ticks on every panel, including unused trailing ones
    if i < num_slices:
        ax.imshow(vol[i], cmap='gray', vmin=vmin, vmax=vmax)
        ax.set_title(f"Slice {i}")

plt.tight_layout()
plt.show()

In [5]:

# Initial slice index. Any slice works now that the grey range is pinned below.
init_idx = vol.shape[0] // 2

fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.2)  # leave space at the bottom for the slider

# Pin the grey range to the whole volume. Without vmin/vmax, imshow autoscales
# to the first slice shown, and im.set_data() never rescales. Starting on an
# empty (constant) slice therefore made every later slice render blank, which
# is why the initial slice previously had to contain data.
im = ax.imshow(vol[init_idx], cmap='gray', vmin=vol.min(), vmax=vol.max())
ax.set_title(f"Slice {init_idx}")
ax.axis('off')

# Slider axes: [left, bottom, width, height] in figure coordinates.
ax_slider = fig.add_axes([0.2, 0.05, 0.6, 0.03])

# Integer slider over slice indices 0 .. num_slices - 1. Keep this global
# reference alive, or the widget stops responding.
slider = Slider(
    ax=ax_slider,
    label='Slice',
    valmin=0,
    valmax=vol.shape[0] - 1,
    valinit=init_idx,
    valstep=1,  # step in whole-number slices
)


def update(val):
    """Redraw the image for the slice index chosen on the slider."""
    idx = int(val)
    im.set_data(vol[idx])
    ax.set_title(f"Slice {idx}")
    fig.canvas.draw_idle()


slider.on_changed(update)

plt.show()

Our dataset (AdrenalMNIST3D) is a collection of 1,584 left and right adrenal glands taken from 792 patients. Each sample is labeled as either healthy or having an adrenal mass, so this is a binary classification task. Each sample is a 28 × 28 × 28 voxel volume resampled from a 64 mm × 64 mm × 64 mm region. The training/validation/test splits contain 1,188/98/298 samples, respectively.

In [4]:
SEED = 42
EPOCHS = 3
BATCH_SIZE = 12

# Seed Python, NumPy and TensorFlow RNGs so weight init and shuffling repeat
# across runs (bit-exact GPU determinism additionally needs op determinism).
tf.keras.utils.set_random_seed(SEED)

# Inputs and labels must line up one-to-one.
assert trainx_tensor.shape[0] == trainy_tensor.shape[0]
assert valx_tensor.shape[0] == valy_tensor.shape[0]

# Per-sample input shape taken from the data: (28, 28, 28, 1) at size=28.
input_shape = tuple(trainx_tensor.shape[1:])
print(input_shape)

model = tf.keras.Sequential(
    layers=[
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Conv3D(16, 3, activation="relu"),
        tf.keras.layers.Conv3D(16, 2, activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),  # single probability: binary task
    ]
)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# NOTE(review): in the recorded run, train accuracy reached ~0.96 by epoch 3 while
# val accuracy fell back to ~0.82 — consider pooling/dropout or early stopping.
history = model.fit(
    trainx_tensor,
    trainy_tensor,
    validation_data=(valx_tensor, valy_tensor),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)
model.save_weights("model.ckpt.weights.h5")


(28, 28, 28, 1)
Epoch 1/3
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 469ms/step - accuracy: 0.7710 - loss: 0.5116 - val_accuracy: 0.8061 - val_loss: 0.4525
Epoch 2/3
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 471ms/step - accuracy: 0.8569 - loss: 0.3224 - val_accuracy: 0.8571 - val_loss: 0.4332
Epoch 3/3
[1m99/99[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 481ms/step - accuracy: 0.9604 - loss: 0.1406 - val_accuracy: 0.8163 - val_loss: 0.4713


[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 138ms/step - accuracy: 0.7919 - loss: 0.9995


[0.9995102882385254, 0.791946291923523]