In [None]:

import numpy as np
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate
from sklearn.model_selection import train_test_split

# Reproducibility: a single seed drives the train/test split below and, through
# Keras, Python's `random`, NumPy and TensorFlow (weight init, batch shuffling).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Set paths
PREPROCESSED_DIR = Path("./experiments/preprocessed")
MODEL_DIR = Path("./experiments/models/unet3d")
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Load preprocessed MRI data
mri_data = np.load(PREPROCESSED_DIR / "mri_data.npy")
mri_masks = np.load(PREPROCESSED_DIR / "mri_masks.npy")

print("MRI data shape:", mri_data.shape)
print("MRI masks shape:", mri_masks.shape)

# Fail early with a readable message: the rest of the notebook feeds these
# arrays to Conv3D layers and indexes them as (sample, axis1, axis2, axis3, channel).
if mri_data.ndim != 5 or mri_masks.ndim != 5:
    raise ValueError(
        "Expected 5-D arrays (samples, three spatial axes, channels); got "
        f"mri_data {mri_data.shape} and mri_masks {mri_masks.shape}"
    )

# Split train/test (20% held out; shuffling is seeded for repeatability)
X_train, X_test, y_train, y_test = train_test_split(
    mri_data, mri_masks, test_size=0.2, random_state=SEED
)

print("Train shape:", X_train.shape, y_train.shape)
print("Test shape:", X_test.shape, y_test.shape)


In [None]:

def _double_conv(tensor, filters):
    """Apply two successive 3x3x3 ReLU convolutions ('same' padding) to `tensor`."""
    for _ in range(2):
        tensor = Conv3D(filters, (3,3,3), activation='relu', padding='same')(tensor)
    return tensor


def unet3d(input_shape=(64,64,64,1), num_classes=1):
    """Build a two-level 3D U-Net as an uncompiled Keras ``Model``.

    The encoder applies double convolutions with 32 and 64 filters, each
    followed by 2x2x2 max pooling; the bottleneck uses 128 filters; the
    decoder upsamples by 2 twice, concatenating the matching encoder
    features (skip connections) before each double convolution.

    Args:
        input_shape: Per-sample shape with four entries: three spatial sizes
            followed by the channel count. Known spatial sizes must be
            divisible by 4 because the volume is halved twice and the
            upsampled features must line up with the skip connections.
        num_classes: Number of output channels. A sigmoid is applied to each
            output channel independently.

    Returns:
        A Keras ``Model`` mapping the input volume to a per-voxel output with
        ``num_classes`` channels.

    Raises:
        ValueError: If ``input_shape`` does not have four entries or a known
            spatial size is not divisible by 4.
    """
    if len(input_shape) != 4:
        raise ValueError(
            "input_shape must have 4 entries (three spatial sizes + channels), "
            f"got {tuple(input_shape)}"
        )
    # Dimensions left as None are only known at run time, so they are not checked.
    misaligned = [d for d in input_shape[:3] if d is not None and d % 4 != 0]
    if misaligned:
        raise ValueError(
            "Spatial dimensions must be divisible by 4 (two 2x poolings), "
            f"got {tuple(input_shape[:3])}"
        )

    inputs = Input(shape=input_shape)

    # Encoder
    c1 = _double_conv(inputs, 32)
    p1 = MaxPooling3D((2,2,2))(c1)

    c2 = _double_conv(p1, 64)
    p2 = MaxPooling3D((2,2,2))(c2)

    # Bottleneck
    c3 = _double_conv(p2, 128)

    # Decoder: upsample, then merge with the encoder features of equal resolution
    u2 = UpSampling3D((2,2,2))(c3)
    u2 = concatenate([u2, c2])
    c4 = _double_conv(u2, 64)

    u1 = UpSampling3D((2,2,2))(c4)
    u1 = concatenate([u1, c1])
    c5 = _double_conv(u1, 32)

    # 1x1x1 convolution maps the features to per-voxel outputs
    outputs = Conv3D(num_classes, (1,1,1), activation='sigmoid')(c5)

    return Model(inputs, outputs)

# Size the network from the data: everything after the leading sample axis
# of the training array is the per-sample input shape.
input_shape = X_train.shape[1:]
model = unet3d(input_shape=input_shape)

# Binary cross-entropy pairs with the sigmoid output of the network.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()


In [None]:

# Training hyperparameters
batch_size = 1
epochs = 5

# validation_split holds out a fraction of the training arrays for
# per-epoch validation metrics.
history = model.fit(
    X_train, y_train,
    validation_split=0.2,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1
)

# Save trained model. An explicit `.keras` extension is used because Keras 3
# rejects extension-less paths in `Model.save`.
model_path = MODEL_DIR / "unet3d_model.keras"
model.save(model_path)
print(f"3D U-Net model saved to {model_path}")


In [None]:

import matplotlib.pyplot as plt

# Evaluate on test set
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

# Visualize predictions on a test sample
sample_idx = 0
sample_vol = X_test[sample_idx:sample_idx+1]  # slice (not index) keeps the batch axis for predict()
sample_mask = y_test[sample_idx]

# Drop the batch axis and take the first output channel.
pred_mask = model.predict(sample_vol)[0,...,0]

# Show the middle slice along the third spatial axis; deriving it from the
# volume keeps the index valid for any volume size (32 for 64-deep volumes).
slice_idx = sample_vol.shape[3] // 2

panels = [
    ("MRI Slice", sample_vol[0,:,:,slice_idx,0]),
    ("Ground Truth Mask", sample_mask[:,:,slice_idx,0]),
    ("Predicted Mask", pred_mask[:,:,slice_idx]),
]

fig, axes = plt.subplots(1, 3, figsize=(12,4))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.show()
