In [1]:
import os
import sys
import json
from pathlib import Path
from typing import List

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.utils import to_categorical

import nibabel as nib
import SimpleITK as sitk
from scipy.ndimage import rotate as scipy_rotate

from PIL import Image, ImageOps, ImageFilter
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


# Switch Keras to the 'mixed_bfloat16' dtype policy globally. This runs in the
# import cell, i.e. before any of the layers defined further down are created.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

2025-12-17 00:53:09.360592: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-17 00:53:09.389911: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-17 00:53:10.006654: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
def double_conv_block_3d(x, n_filters, kernel_size=3):
    """Run the input through two identical Conv3D -> BatchNorm -> ReLU stages.

    Args:
        x: Input Keras tensor.
        n_filters: Number of filters used by both convolutions.
        kernel_size: Convolution kernel size (default 3).

    Returns:
        The tensor produced by the second ReLU.
    """
    # Both stages are identical; layers are still created in the order
    # conv, batch-norm, activation, conv, batch-norm, activation.
    for _ in range(2):
        x = layers.Conv3D(n_filters, kernel_size, padding="same", kernel_initializer="he_normal")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    return x


def downsample_block_3d(x, n_filters, dropout_rate=0.3):
    """Encoder stage: double convolution, then 2x2x2 max-pooling and dropout.

    Args:
        x: Input Keras tensor.
        n_filters: Filter count for the double convolution.
        dropout_rate: Dropout rate applied after pooling.

    Returns:
        A ``(skip, pooled)`` pair: the pre-pooling features (used as the skip
        connection) and the pooled, dropout-regularised tensor.
    """
    skip_features = double_conv_block_3d(x, n_filters)
    pooled = layers.MaxPool3D(pool_size=(2, 2, 2))(skip_features)
    pooled = layers.Dropout(dropout_rate)(pooled)
    return skip_features, pooled


def upsample_block_3d(x, skip_features, n_filters, dropout_rate=0.5):
    """Decoder stage: 2x transposed-conv upsampling, skip merge, dropout, double conv.

    Args:
        x: Tensor coming from the deeper level.
        skip_features: Encoder features concatenated with the upsampled tensor.
        n_filters: Filter count for the transposed conv and the double convolution.
        dropout_rate: Dropout rate applied right after the concatenation.

    Returns:
        The output tensor of the double convolution.
    """
    upsampled = layers.Conv3DTranspose(n_filters, kernel_size=2, strides=2, padding="same")(x)
    merged = layers.concatenate([upsampled, skip_features])
    merged = layers.Dropout(dropout_rate)(merged)
    return double_conv_block_3d(merged, n_filters)


def build_3d_unet(input_shape=(128, 128, 128, 1), num_classes=3, base_filters=32):
    """Assemble a four-level 3D U-Net that ends in a per-voxel softmax.

    The encoder applies four downsample blocks (each halving the spatial size),
    followed by a bottleneck and four upsample blocks that consume the encoder
    skip features in reverse order. Because of the four 2x poolings, the
    spatial dimensions presumably need to be divisible by 16 for the skip
    concatenations to line up.

    Args:
        input_shape: Shape of one input volume, including the channel axis.
        num_classes: Number of channels produced by the final 1x1x1 convolution.
        base_filters: Filter count of the first encoder level; deeper levels use
            2x, 4x, 8x and (bottleneck) 16x this value.

    Returns:
        An uncompiled ``models.Model`` named "3D-UNet".
    """
    inputs = layers.Input(shape=input_shape)

    # Encoder: keep each level's pre-pooling features for the skip connections.
    f1, p1 = downsample_block_3d(inputs, base_filters)
    f2, p2 = downsample_block_3d(p1, base_filters * 2)
    f3, p3 = downsample_block_3d(p2, base_filters * 4)
    f4, p4 = downsample_block_3d(p3, base_filters * 8)

    bottleneck = double_conv_block_3d(p4, base_filters * 16)

    # Decoder: mirror of the encoder, deepest skip first.
    u1 = upsample_block_3d(bottleneck, f4, base_filters * 8)
    u2 = upsample_block_3d(u1, f3, base_filters * 4)
    u3 = upsample_block_3d(u2, f2, base_filters * 2)
    u4 = upsample_block_3d(u3, f1, base_filters)

    # The output head is pinned to float32. This file enables the
    # 'mixed_bfloat16' global policy, under which the softmax would otherwise be
    # computed in bfloat16; the Keras mixed-precision guidance is to keep model
    # outputs in float32 for numeric stability. The layer's variables and name
    # are unaffected, so existing weight files still load. With a plain float32
    # policy this argument is a no-op.
    outputs = layers.Conv3D(
        num_classes,
        kernel_size=1,
        padding="same",
        activation="softmax",
        dtype="float32",
    )(u4)

    model = models.Model(inputs, outputs, name="3D-UNet")
    return model

In [3]:
DATA_DIR = 'preprocessed_patches_v2'
NUM_CLASSES = 3
SEED = 42

# Split fractions. TEST_RATIO is informational only: the test split simply
# receives every file left over after the train and validation cuts.
TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Sorted listing so the seeded permutation below is reproducible regardless of
# the order in which the filesystem returns directory entries.
all_files = sorted(
    os.path.join(DATA_DIR, name)
    for name in os.listdir(DATA_DIR)
    if name.endswith('.npz')
)

np.random.seed(SEED)
indices = np.random.permutation(len(all_files))

train_end = int(len(all_files) * TRAIN_RATIO)
val_end = train_end + int(len(all_files) * VAL_RATIO)

# Cut the shuffled index array at the two boundaries and map back to paths.
train_files = [all_files[i] for i in indices[:train_end]]
val_files = [all_files[i] for i in indices[train_end:val_end]]
test_files = [all_files[i] for i in indices[val_end:]]

In [None]:
def augment_rotation_3d(volume, segmentation, max_angle=30):
    """Rotate a volume and its label map by the same random angles.

    Three angles are drawn uniformly from ``[-max_angle, max_angle]`` (degrees)
    and applied one after another in the planes spanned by axes (0, 1), (0, 2)
    and (1, 2). The volume is resampled with linear interpolation (order=1);
    the segmentation uses nearest-neighbour (order=0) so no new label values
    are invented. The output keeps the input shape and is zero-filled outside
    the rotated data.

    Args:
        volume: Image array to rotate.
        segmentation: Label array rotated with identical angles.
        max_angle: Maximum absolute rotation per plane, in degrees.

    Returns:
        ``(rotated_volume, rotated_segmentation)``.
    """
    # Draw order is x, y, z; application order below is z, y, x.
    angle_x, angle_y, angle_z = (
        np.random.uniform(-max_angle, max_angle) for _ in range(3)
    )
    rotation_plan = (
        (angle_z, (0, 1)),
        (angle_y, (0, 2)),
        (angle_x, (1, 2)),
    )

    vol_rotated, seg_rotated = volume, segmentation
    for angle, plane in rotation_plan:
        vol_rotated = scipy_rotate(
            vol_rotated, angle, axes=plane, reshape=False, order=1, mode='constant', cval=0
        )
        seg_rotated = scipy_rotate(
            seg_rotated, angle, axes=plane, reshape=False, order=0, mode='constant', cval=0
        )

    return vol_rotated, seg_rotated


def augment_gamma(volume, gamma_range=(0.7, 1.5)):
    """Apply a random gamma curve to the volume.

    The volume is first clipped to [0, 1], then raised to an exponent drawn
    uniformly from ``gamma_range``.

    Args:
        volume: Intensity array.
        gamma_range: ``(low, high)`` bounds for the sampled exponent.

    Returns:
        The gamma-adjusted array.
    """
    lower = gamma_range[0]
    upper = gamma_range[1]
    exponent = np.random.uniform(lower, upper)
    clipped = np.clip(volume, 0, 1)
    return np.power(clipped, exponent)


def augment_gaussian_noise(volume, sigma_range=(0, 0.05)):
    """Add zero-mean Gaussian noise with a randomly chosen standard deviation.

    Args:
        volume: Intensity array.
        sigma_range: ``(low, high)`` bounds for the sampled standard deviation.

    Returns:
        The noisy array, clipped to [0, 1].
    """
    lower = sigma_range[0]
    upper = sigma_range[1]
    sigma = np.random.uniform(lower, upper)
    noisy = volume + np.random.normal(0, sigma, volume.shape)
    return np.clip(noisy, 0, 1)


def augment_brightness(volume, delta_range=(-0.1, 0.1)):
    """Shift all intensities by one random offset and clip to [0, 1].

    Args:
        volume: Intensity array.
        delta_range: ``(low, high)`` bounds for the sampled additive offset.

    Returns:
        The shifted, clipped array.
    """
    lower = delta_range[0]
    upper = delta_range[1]
    shift = np.random.uniform(lower, upper)
    shifted = volume + shift
    return np.clip(shifted, 0, 1)


def apply_augmentation(volume, segmentation, augment=True):
    """Randomly apply the geometric and intensity augmentations to one sample.

    Each augmentation is gated by its own draw from ``np.random.random()``:
    rotation (p=0.5, applied to both arrays), then gamma (p=0.3), Gaussian
    noise (p=0.1) and brightness (p=0.3), which touch only the volume.

    Args:
        volume: Intensity array.
        segmentation: Label array belonging to ``volume``.
        augment: When false the inputs are returned untouched.

    Returns:
        ``(volume, segmentation)`` after whichever augmentations fired.
    """
    if augment:
        # Geometric step: the only one that also transforms the labels.
        if np.random.random() < 0.5:
            volume, segmentation = augment_rotation_3d(volume, segmentation, max_angle=30)

        # Intensity augmentations, each gated independently and in this order.
        intensity_steps = (
            (0.3, lambda v: augment_gamma(v, gamma_range=(0.7, 1.5))),
            (0.1, lambda v: augment_gaussian_noise(v, sigma_range=(0, 0.05))),
            (0.3, lambda v: augment_brightness(v, delta_range=(-0.1, 0.1))),
        )
        for probability, transform in intensity_steps:
            if np.random.random() < probability:
                volume = transform(volume)

    return volume, segmentation


class VolumeGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of 3D patches stored in ``.npz`` files.

    Each file is expected to contain a ``'patches'`` array and a matching
    ``'segmentations'`` array. A batch never spans two files: the index built
    in ``__init__`` holds one ``(file_idx, start)`` entry per ``batch_size``-sized
    slice of each file, so the last batch of a file may be smaller.

    Args:
        files: Sequence of ``.npz`` file paths.
        batch_size: Maximum number of patches per batch.
        num_classes: Number of classes used for one-hot encoding the labels.
        shuffle: Whether to reshuffle the batch order at the end of every epoch.
        augment: Whether to run ``apply_augmentation`` on every sample.
        **kwargs: Forwarded to the base-class initializer (e.g. ``workers``,
            ``use_multiprocessing``, ``max_queue_size`` where supported).
    """

    def __init__(self, files, batch_size=2, num_classes=3, shuffle=True, augment=False, **kwargs):
        # The base initializer must run; skipping it is what triggers Keras'
        # "_warn_if_super_not_called" warning when fitting.
        super().__init__(**kwargs)
        self.files = files
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.augment = augment

        # One entry per batch: (index into self.files, first patch of the batch).
        self.batch_indices = []
        for file_idx, filepath in enumerate(files):
            # Context manager closes the archive handle once the count is known.
            with np.load(filepath) as data:
                n_patches = len(data['patches'])
            for start in range(0, n_patches, batch_size):
                self.batch_indices.append((file_idx, start))

        # Builds self.indices (and shuffles it if requested) before first use.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch."""
        return len(self.batch_indices)

    def __getitem__(self, idx):
        """Load, optionally augment, and encode the batch at position ``idx``.

        Returns:
            ``(x, y)`` where ``x`` is float32 with a trailing channel axis added
            and ``y`` is the one-hot encoding of the labels.
        """
        file_idx, patch_start = self.batch_indices[self.indices[idx]]
        patch_end = patch_start + self.batch_size

        # Arrays read from the archive are materialised in memory, so they stay
        # valid after the handle is closed.
        with np.load(self.files[file_idx]) as data:
            # Division by 255 presumably maps 8-bit intensities to [0, 1] - confirm
            # against the preprocessing that wrote these files.
            x = data['patches'][patch_start:patch_end].astype(np.float32) / 255.0
            y = data['segmentations'][patch_start:patch_end]

        if self.augment:
            x_aug = []
            y_aug = []
            for i in range(len(x)):
                vol_aug, seg_aug = apply_augmentation(x[i], y[i], augment=True)
                x_aug.append(vol_aug)
                y_aug.append(seg_aug)
            # Some augmentations return float64; cast back to float32.
            x = np.array(x_aug, dtype=np.float32)
            y = np.array(y_aug)

        x = x[..., np.newaxis]
        y = to_categorical(y, num_classes=self.num_classes)

        return x, y

    def on_epoch_end(self):
        """Reset the batch order, shuffling it when ``shuffle`` is enabled."""
        self.indices = np.arange(len(self.batch_indices))
        if self.shuffle:
            np.random.shuffle(self.indices)


BATCH_SIZE = 4

# Settings shared by all three splits.
_shared_gen_kwargs = {'batch_size': BATCH_SIZE, 'num_classes': NUM_CLASSES}

# Only the training split is shuffled and augmented; validation and test are
# served in a fixed order and unmodified.
train_gen = VolumeGenerator(train_files, shuffle=True, augment=True, **_shared_gen_kwargs)
val_gen = VolumeGenerator(val_files, shuffle=False, augment=False, **_shared_gen_kwargs)
test_gen = VolumeGenerator(test_files, shuffle=False, augment=False, **_shared_gen_kwargs)

In [5]:
# Per-class loss weights, indexed by class id: 0 = background, 1 = liver, 2 = tumor.
# weighted_dice_loss divides by their sum, so only the ratios between them matter.
CLASS_WEIGHTS = [0.1, 0.8, 20.0]


def dice_coefficient_per_class(y_true, y_pred, class_idx, smooth=1e-6):
    """Soft Dice score of one class channel, computed over the whole batch.

    Args:
        y_true: Ground-truth tensor whose last axis indexes classes.
        y_pred: Prediction tensor with the same layout.
        class_idx: Index on the last axis selecting the class to score.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor ``(2*|A.B| + smooth) / (|A| + |B| + smooth)``.
    """
    K = tf.keras.backend

    # Cast first so the reductions run in float32 regardless of the input dtype.
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)

    truth_flat = K.flatten(truth[..., class_idx])
    probs_flat = K.flatten(probs[..., class_idx])

    overlap = K.sum(truth_flat * probs_flat)
    denominator = K.sum(truth_flat) + K.sum(probs_flat) + smooth
    return (2. * overlap + smooth) / denominator


def weighted_dice_loss(y_true, y_pred):
    """Weighted average of per-class Dice losses, using ``CLASS_WEIGHTS``.

    Each class contributes ``weight * (1 - dice)``; the total is divided by the
    sum of the weights.

    Args:
        y_true: Ground-truth tensor whose last axis indexes classes.
        y_pred: Prediction tensor with the same layout.

    Returns:
        Scalar loss tensor.
    """
    truth = tf.cast(y_true, tf.float32)
    probs = tf.cast(y_pred, tf.float32)

    weighted_terms = [
        weight * (1 - dice_coefficient_per_class(truth, probs, class_idx))
        for class_idx, weight in enumerate(CLASS_WEIGHTS)
    ]

    # Start the fold at 0.0 so the accumulation matches a plain running total.
    return sum(weighted_terms, 0.0) / sum(CLASS_WEIGHTS)

def dice_coefficient(y_true, y_pred):
    """Soft Dice score over all class channels flattened together.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor of the same shape.

    Returns:
        Scalar tensor ``(2*|A.B| + 1e-6) / (|A| + |B| + 1e-6)``.
    """
    K = tf.keras.backend

    truth_flat = K.flatten(tf.cast(y_true, tf.float32))
    probs_flat = K.flatten(tf.cast(y_pred, tf.float32))

    overlap = K.sum(truth_flat * probs_flat)
    denominator = K.sum(truth_flat) + K.sum(probs_flat) + 1e-6
    return (2. * overlap + 1e-6) / denominator


def dice_liver(y_true, y_pred):
    """Metric: Dice score of the liver channel (class index 1)."""
    liver_score = dice_coefficient_per_class(y_true, y_pred, class_idx=1)
    return liver_score


def dice_tumor(y_true, y_pred):
    """Metric: Dice score of the tumor channel (class index 2)."""
    tumor_score = dice_coefficient_per_class(y_true, y_pred, class_idx=2)
    return tumor_score


# Echo the loss configuration so it is recorded alongside the training output.
print("Loss function: Weighted Dice Loss")
print("Class weights: Background={}, Liver={}, Tumor={}".format(
    CLASS_WEIGHTS[0], CLASS_WEIGHTS[1], CLASS_WEIGHTS[2]
))

Loss function: Weighted Dice Loss
Class weights: Background=0.1, Liver=0.8, Tumor=20.0


In [6]:
# Checkpoint used to warm-start this run.
PRETRAINED_PATH = 'checkpoints/best_model_v2_24_final.keras'

# The architecture must match the checkpoint (note base_filters=24 here, not
# the builder's default).
model = build_3d_unet(
    input_shape=(128, 128, 128, 1),
    num_classes=NUM_CLASSES,
    base_filters=24,
)

# Load pretrained weights
print(f"Loading pretrained weights from {PRETRAINED_PATH}...")
model.load_weights(PRETRAINED_PATH)
print("Weights loaded successfully!")

# Overall, liver and tumor Dice are tracked separately; the training callbacks
# monitor the validation tumor score.
model.compile(
    loss=weighted_dice_loss,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[dice_coefficient, dice_liver, dice_tumor],
)

model.summary()

I0000 00:00:1765950811.034270    6356 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9102 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4080 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


Loading pretrained weights from checkpoints/best_model_v2_24_final.keras...
Weights loaded successfully!


In [7]:
EPOCHS = 60
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Custom callback to track learning rate
class LearningRateLogger(tf.keras.callbacks.Callback):
    """Write the optimizer's current learning rate into the epoch logs.

    The value is stored under the ``'learning_rate'`` key, so it ends up in the
    ``History`` object returned by ``model.fit``.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Record the learning rate at the end of ``epoch``.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Mutable metrics dict for this epoch; may be ``None``.
        """
        # `logs` defaults to None, and item assignment on None raises TypeError.
        # With nowhere to record the value there is nothing to do.
        if logs is None:
            return
        lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        logs['learning_rate'] = lr

# Every metric-driven callback watches the same quantity: validation tumor Dice,
# where larger is better.
_tumor_dice_watch = {'monitor': 'val_dice_tumor', 'mode': 'max'}

callbacks = [
    # Keep only the checkpoint with the best validation tumor Dice so far.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, 'best_model_tumor_dice.keras'),
        save_best_only=True,
        verbose=1,
        **_tumor_dice_watch,
    ),
    # Overwrite a rolling "latest" checkpoint after every epoch.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, 'latest_checkpoint.keras'),
        save_best_only=False,
        verbose=0,
    ),
    # Halve the learning rate after 12 epochs without improvement.
    tf.keras.callbacks.ReduceLROnPlateau(
        factor=0.5,
        patience=12,
        min_lr=1e-7,
        verbose=1,
        **_tumor_dice_watch,
    ),
    # Stop after 25 epochs without improvement and roll back to the best weights.
    tf.keras.callbacks.EarlyStopping(
        patience=25,
        restore_best_weights=True,
        verbose=1,
        **_tumor_dice_watch,
    ),
    # Record the learning rate in the history (kept last in the list).
    LearningRateLogger(),
]

history = model.fit(
    train_gen,
    epochs=EPOCHS,
    validation_data=val_gen,
    callbacks=callbacks,
    verbose=1,
)

# Save final model and history
model.save('final_model.keras')
print("\nFinal model saved to final_model.keras")

# Convert every logged value to a plain Python float so the history is
# JSON-serialisable.
history_dict = {
    metric_name: list(map(float, per_epoch_values))
    for metric_name, per_epoch_values in history.history.items()
}
with open('training_history.json', 'w') as f:
    json.dump(history_dict, f, indent=2)
print("Training history saved to training_history.json")

  self._warn_if_super_not_called()
2025-12-17 00:53:32.827248: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91700


Epoch 1/60


2025-12-17 00:53:39.135907: I external/local_xla/xla/service/service.cc:163] XLA service 0x78795c058e10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-17 00:53:39.135919: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 4080 Laptop GPU, Compute Capability 8.9
2025-12-17 00:53:39.268724: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-12-17 00:53:43.724648: E external/local_xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv (bf16[4,24,128,128,128]{4,3,2,1,0}, u8[0]{0}) custom-call(bf16[4,24,128,128,128]{4,3,2,1,0}, bf16[24,24,3,3,3]{4,3,2,1,0}, bf16[24]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cud

[1m455/455[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 931ms/step - dice_coefficient: 0.9193 - dice_liver: 0.8098 - dice_tumor: 0.5211 - loss: 0.4658
Epoch 1: val_dice_tumor improved from None to 0.42323, saving model to checkpoints/best_model_tumor_dice.keras
[1m455/455[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m514s[0m 988ms/step - dice_coefficient: 0.9208 - dice_liver: 0.8170 - dice_tumor: 0.5186 - loss: 0.4679 - val_dice_coefficient: 0.9463 - val_dice_liver: 0.8686 - val_dice_tumor: 0.4232 - val_loss: 0.5571 - learning_rate: 1.0000e-04
Epoch 2/60
[1m455/455[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 945ms/step - dice_coefficient: 0.9256 - dice_liver: 0.8199 - dice_tumor: 0.4986 - loss: 0.4870
Epoch 2: val_dice_tumor did not improve from 0.42323
[1m455/455[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m454s[0m 997ms/step - dice_coefficient: 0.9227 - dice_liver: 0.8201 - dice_tumor: 0.5047 - loss: 0.4811 - val_dice_coefficient: 0.9369 - val_dice_liver: 

In [None]:
# Plot training history
with open('training_history.json', 'r') as f:
    history_data = json.load(f)

epochs_range = range(1, len(history_data['loss']) + 1)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plots 1-3 share one layout (train in blue, validation in red); only the
# metric and its labels differ. Order: loss, tumor Dice (most important),
# liver Dice.
metric_panels = [
    (axes[0, 0], 'loss', 'Train Loss', 'Val Loss',
     'Loss', 'Training vs Validation Loss'),
    (axes[0, 1], 'dice_tumor', 'Train Tumor Dice', 'Val Tumor Dice',
     'Dice Score', 'Tumor Dice Score (Primary Metric)'),
    (axes[1, 0], 'dice_liver', 'Train Liver Dice', 'Val Liver Dice',
     'Dice Score', 'Liver Dice Score'),
]
for panel, metric_key, train_label, val_label, y_label, panel_title in metric_panels:
    panel.plot(epochs_range, history_data[metric_key], 'b-', label=train_label)
    panel.plot(epochs_range, history_data['val_' + metric_key], 'r-', label=val_label)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True)

# Plot 4: Learning Rate (log scale), or a placeholder when it was not logged.
lr_panel = axes[1, 1]
if 'learning_rate' in history_data:
    lr_panel.plot(epochs_range, history_data['learning_rate'], 'g-', marker='o', markersize=3)
    lr_panel.set_xlabel('Epoch')
    lr_panel.set_ylabel('Learning Rate')
    lr_panel.set_title('Learning Rate Schedule')
    lr_panel.set_yscale('log')
    lr_panel.grid(True)
else:
    lr_panel.text(0.5, 0.5, 'No LR data', ha='center', va='center', transform=lr_panel.transAxes)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Print best metrics: everything is reported at the epoch with the highest
# validation tumor Dice (list index is zero-based, epochs are one-based).
best_val_tumor_idx = np.argmax(history_data['val_dice_tumor'])
print(f"\nBest Validation Tumor Dice: {history_data['val_dice_tumor'][best_val_tumor_idx]:.4f} at epoch {best_val_tumor_idx + 1}")
print(f"  - Train Tumor Dice: {history_data['dice_tumor'][best_val_tumor_idx]:.4f}")
print(f"  - Val Liver Dice: {history_data['val_dice_liver'][best_val_tumor_idx]:.4f}")
print(f"  - Val Loss: {history_data['val_loss'][best_val_tumor_idx]:.4f}")


Best Validation Tumor Dice: 0.4849 at epoch 52
  - Train Tumor Dice: 0.6500
  - Val Liver Dice: 0.9127
  - Val Loss: 0.4963


: 