# Model Evaluation - Confusion Matrix
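Evaluates the trained 3D U-Net segmentation model on the held-out test set: accumulates a voxel-wise confusion matrix, plots it (raw counts and row-normalized percentages), and reports per-class precision, recall, F1, and Dice along with overall accuracy.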

In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import confusion_matrix

# Report the runtime environment up front so a missing GPU is noticed before
# the (slow) evaluation cells are run.
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {gpu_devices}")

In [None]:
# =============================================================================
# Define Loss Functions (needed for model loading)
# =============================================================================

# Per-class loss weights: entry i weights class channel i in weighted_dice_loss,
# and the loss is normalized by the sum of these weights.
# NOTE(review): presumably ordered background, liver, tumor - confirm against
# the training notebook; the values must match the ones used for training.
CLASS_WEIGHTS = [0.1, 1, 20.0]

def dice_coefficient_per_class(y_true, y_pred, class_idx):
    """Return the Dice coefficient of a single class channel.

    Both tensors are cast to float32, channel ``class_idx`` of the last axis
    is selected and flattened, and the score is computed as
    ``2 * sum(t * p) / (sum(t) + sum(p))``.

    Args:
        y_true: Ground-truth tensor whose last axis is indexed by class
            (presumably one-hot encoded - confirm against the data pipeline).
        y_pred: Prediction tensor with the same layout as ``y_true``
            (presumably softmax probabilities - confirm).
        class_idx: Index of the class channel on the last axis.

    Returns:
        A scalar tensor. There is no smoothing term, so the result is NaN
        when both the ground-truth and predicted channel sum to zero.
    """
    backend = tf.keras.backend

    def _flat_channel(tensor):
        # Cast first, then slice the requested class channel and flatten it.
        return backend.flatten(tf.cast(tensor, tf.float32)[..., class_idx])

    truth_flat = _flat_channel(y_true)
    pred_flat = _flat_channel(y_pred)

    overlap = backend.sum(truth_flat * pred_flat)
    denominator = backend.sum(truth_flat) + backend.sum(pred_flat)
    return (2. * overlap) / denominator

def weighted_dice_loss(y_true, y_pred):
    """Return the class-weighted Dice loss.

    For every class index ``i`` in ``CLASS_WEIGHTS`` the term
    ``CLASS_WEIGHTS[i] * (1 - dice_i)`` is accumulated, where ``dice_i`` comes
    from ``dice_coefficient_per_class``. The accumulated value is divided by
    the sum of the weights.

    Args:
        y_true: Ground-truth tensor with one channel per class on the last axis.
        y_pred: Prediction tensor with the same layout as ``y_true``.

    Returns:
        A scalar tensor holding the weighted loss.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    weighted_terms = (
        weight * (1 - dice_coefficient_per_class(truth, pred, class_idx))
        for class_idx, weight in enumerate(CLASS_WEIGHTS)
    )
    # Start the accumulation at 0.0 and add the terms in class order.
    return sum(weighted_terms, 0.0) / sum(CLASS_WEIGHTS)

def dice_coefficient(y_true, y_pred):
    """Return the Dice coefficient computed over all elements at once.

    Both tensors are cast to float32 and fully flattened (every axis,
    including the class axis), then scored as
    ``2 * sum(t * p) / (sum(t) + sum(p))``.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        A scalar tensor. There is no smoothing term, so the result is NaN
        when both inputs sum to zero.
    """
    backend = tf.keras.backend

    truth_flat = backend.flatten(tf.cast(y_true, tf.float32))
    pred_flat = backend.flatten(tf.cast(y_pred, tf.float32))

    overlap = backend.sum(truth_flat * pred_flat)
    denominator = backend.sum(truth_flat) + backend.sum(pred_flat)
    return (2. * overlap) / denominator

def dice_liver(y_true, y_pred):
    """Return the Dice coefficient of class channel 1 (the liver metric)."""
    liver_channel = 1
    return dice_coefficient_per_class(y_true, y_pred, liver_channel)

def dice_tumor(y_true, y_pred):
    """Return the Dice coefficient of class channel 2 (the tumor metric)."""
    tumor_channel = 2
    return dice_coefficient_per_class(y_true, y_pred, tumor_channel)

# Status line so the notebook output shows this cell has been executed.
print("Loss functions defined.")

In [None]:
# =============================================================================
# Setup Test Files
# =============================================================================

DATA_DIR = 'preprocessed_patches_v2'
NUM_CLASSES = 3
SEED = 42
BATCH_SIZE = 4

# Collect every .npz archive in DATA_DIR in sorted order, so the shuffled
# split below depends only on the directory contents and SEED.
all_files = sorted(
    os.path.join(DATA_DIR, name)
    for name in os.listdir(DATA_DIR)
    if name.endswith('.npz')
)
n_files = len(all_files)

# Seed the global NumPy RNG and shuffle the file indices. This is intended to
# reproduce the split used during training, so the seed, the RNG call and the
# 70% / 15% fractions must stay in sync with the training notebook.
np.random.seed(SEED)
indices = np.random.permutation(n_files)

# The first 70% of the shuffled files are train, the next 15% validation, and
# whatever remains is the test set.
train_end = int(n_files * 0.70)
val_end = train_end + int(n_files * 0.15)
test_files = [all_files[idx] for idx in indices[val_end:]]

print(f"Total files: {n_files}")
# NOTE(review): the patch count assumes 20 patches per file - confirm.
print(f"Test set: {len(test_files)} files ({len(test_files) * 20} patches)")

In [None]:
# =============================================================================
# Load Best Model + Warmup
# =============================================================================

MODEL_PATH = 'checkpoints/best_model.keras'

# The custom loss and metrics must be supplied by name so Keras can resolve
# the references stored in the saved model.
print(f"Loading model from {MODEL_PATH}...")
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        'weighted_dice_loss': weighted_dice_loss,
        'dice_coefficient': dice_coefficient,
        'dice_liver': dice_liver,
        'dice_tumor': dice_tumor,
    },
)
print("Model loaded successfully!")
print(f"Model parameters: {model.count_params():,}")

# Warmup prediction to trigger XLA compilation (this takes 1-2 minutes).
# XLA specializes compiled code on concrete input shapes, so the dummy batch
# uses BATCH_SIZE - the batch size the evaluation loop feeds to predict().
# Warming up with a batch of 1 would leave the slow compile for the first
# real batch and skew the per-file timings and ETA.
# NOTE(review): the 128x128x128x1 patch shape is assumed to match the
# preprocessed patches - confirm against the data.
print("\nWarming up model (XLA compilation - this takes 1-2 min)...")
warmup_start = time.time()
dummy_input = np.zeros((BATCH_SIZE, 128, 128, 128, 1), dtype=np.float32)
_ = model.predict(dummy_input, verbose=0)
print(f"Warmup done in {time.time()-warmup_start:.1f}s - ready for evaluation!")

In [None]:
# =============================================================================
# Compute Confusion Matrix
# =============================================================================

print("Computing confusion matrix on test set...")
# NOTE(review): the header assumes 20 patches per file; the actual number of
# patches processed is reported once the loop has finished.
print(f"Processing {len(test_files)} files ({len(test_files)*20} patches)...\n")

# Label ids are derived from NUM_CLASSES so they always agree with cm's shape.
class_labels = list(range(NUM_CLASSES))
cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
total_patches = 0
start_time = time.time()

for file_idx, filepath in enumerate(test_files):
    file_start = time.time()

    # np.load on an .npz archive returns an NpzFile that holds the file open;
    # the context manager closes it as soon as both arrays have been read, so
    # one handle is not leaked per test file.
    with np.load(filepath) as data:
        # Scale intensities to [0, 1] (assumes patches are stored in the
        # 0-255 range - TODO confirm against the preprocessing step).
        patches = data['patches'].astype(np.float32) / 255.0
        # NOTE(review): assumes integer class labels with the same shape as
        # the argmax of the model output - confirm.
        segs = data['segmentations']

    # Process in batches
    for i in range(0, len(patches), BATCH_SIZE):
        # Add the trailing channel axis expected by the model input.
        x = patches[i:i+BATCH_SIZE][..., np.newaxis]
        y_true = segs[i:i+BATCH_SIZE]

        pred = model.predict(x, verbose=0)
        # Most likely class per voxel.
        y_pred = np.argmax(pred, axis=-1)

        # Accumulate the voxel-wise confusion matrix (rows: true class,
        # columns: predicted class).
        cm += confusion_matrix(y_true.ravel(), y_pred.ravel(), labels=class_labels)
        total_patches += len(x)

    file_time = time.time() - file_start
    total_time = time.time() - start_time
    # ETA = mean time per finished file * number of files still to process.
    eta = (total_time / (file_idx + 1)) * (len(test_files) - file_idx - 1)
    print(f"  [{file_idx+1:2d}/{len(test_files)}] {os.path.basename(filepath)}: {file_time:.1f}s (ETA: {eta:.0f}s)")

print(f"\n{'='*50}")
print(f"DONE! Total time: {time.time()-start_time:.1f}s")
print(f"Total patches evaluated: {total_patches:,}")
print(f"Total voxels evaluated: {cm.sum():,}")

In [None]:
# =============================================================================
# Plot Confusion Matrix
# =============================================================================

class_names = ['Background', 'Liver', 'Tumor']
n_classes = len(class_names)

# Row-normalize to percentages of each true class. A class that never occurs
# in the ground truth has a zero row sum; its row is left at 0 instead of
# producing NaN and a divide-by-zero warning.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
) * 100

# One entry per panel: (matrix, title, cell text format, fixed color limits).
panels = [
    (cm, 'Confusion Matrix (Counts)', '{:,}', {}),
    (cm_norm, 'Confusion Matrix (% by True Class)', '{:.1f}%', {'vmin': 0, 'vmax': 100}),
]

fig, axes = plt.subplots(1, len(panels), figsize=(14, 5))

for ax, (matrix, title, cell_format, color_limits) in zip(axes, panels):
    image = ax.imshow(matrix, cmap='Blues', **color_limits)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Predicted', fontsize=12)
    ax.set_ylabel('True', fontsize=12)
    ax.set_xticks(range(n_classes))
    ax.set_yticks(range(n_classes))
    ax.set_xticklabels(class_names)
    ax.set_yticklabels(class_names)

    # Cells in the darker half of the color range get white text so the
    # numbers stay readable against the dark blue background.
    low = color_limits.get('vmin', matrix.min())
    high = color_limits.get('vmax', matrix.max())
    midpoint = (low + high) / 2
    for row in range(n_classes):
        for col in range(n_classes):
            value = matrix[row, col]
            ax.text(
                col, row, cell_format.format(value),
                ha='center', va='center', fontsize=10,
                color='white' if value > midpoint else 'black',
            )
    plt.colorbar(image, ax=ax)

plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nConfusion matrix saved to confusion_matrix.png")

In [None]:
# =============================================================================
# Print Per-Class Metrics
# =============================================================================

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


print("="*60)
print("PER-CLASS METRICS")
print("="*60)
print(f"{'Class':<12} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Dice':>10}")
print("-"*60)

for i, name in enumerate(class_names):
    # Rows of cm are true classes and columns are predicted classes, so for
    # class i the column sum gives TP + FP and the row sum gives TP + FN.
    tp = cm[i, i]
    fp = cm[:, i].sum() - tp
    fn = cm[i, :].sum() - tp

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    # Dice expressed in counts; algebraically this equals F1, so the two
    # columns are expected to agree up to rounding.
    dice = _safe_ratio(2 * tp, 2 * tp + fp + fn)

    print(f"{name:<12} {precision:>10.4f} {recall:>10.4f} {f1:>10.4f} {dice:>10.4f}")

print("="*60)

# Overall accuracy: fraction of voxels on the diagonal. Guarded like the
# per-class ratios so an empty confusion matrix reports 0 instead of NaN.
accuracy = _safe_ratio(np.trace(cm), cm.sum())
print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")