# Model Evaluation - Confusion Matrix
Evaluates the trained 3D U-Net model on the test set.

In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn.metrics import confusion_matrix

2025-12-15 18:11:55.625855: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-15 18:11:55.658660: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-12-15 18:11:56.283174: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
# =============================================================================
# Define Loss Functions (needed for model loading)
# =============================================================================

CLASS_WEIGHTS = [0.1, 1, 20.0]

def dice_coefficient_per_class(y_true, y_pred, class_idx):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_c = y_true[..., class_idx]
    y_pred_c = y_pred[..., class_idx]
    y_true_f = tf.keras.backend.flatten(y_true_c)
    y_pred_f = tf.keras.backend.flatten(y_pred_c)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f))

def weighted_dice_loss(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    total_loss = 0.0
    for class_idx, weight in enumerate(CLASS_WEIGHTS):
        dice = dice_coefficient_per_class(y_true, y_pred, class_idx)
        total_loss += weight * (1 - dice)
    return total_loss / sum(CLASS_WEIGHTS)

def dice_coefficient(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f))

def dice_liver(y_true, y_pred):
    return dice_coefficient_per_class(y_true, y_pred, 1)

def dice_tumor(y_true, y_pred):
    return dice_coefficient_per_class(y_true, y_pred, 2)

print("Loss functions defined.")

Loss functions defined.


In [3]:
# =============================================================================
# Setup Test Files
# =============================================================================

DATA_DIR = 'preprocessed_patches_v2'
NUM_CLASSES = 3
SEED = 42
BATCH_SIZE = 4

# Get test files (same split as training)
all_files = sorted([os.path.join(DATA_DIR, f) for f in os.listdir(DATA_DIR) if f.endswith('.npz')])
np.random.seed(SEED)
indices = np.random.permutation(len(all_files))
train_end = int(len(all_files) * 0.70)
val_end = train_end + int(len(all_files) * 0.15)
test_files = [all_files[i] for i in indices[val_end:]]

print(f"Total files: {len(all_files)}")
print(f"Test set: {len(test_files)} files ({len(test_files) * 20} patches)")

Total files: 131
Test set: 21 files (420 patches)


In [4]:
# =============================================================================
# Load Best Model + Warmup
# =============================================================================

MODEL_PATH = 'checkpoints/best_model_v2_augment.keras'

print(f"Loading model from {MODEL_PATH}...")
model = tf.keras.models.load_model(
    MODEL_PATH,
    custom_objects={
        'weighted_dice_loss': weighted_dice_loss,
        'dice_coefficient': dice_coefficient,
        'dice_liver': dice_liver,
        'dice_tumor': dice_tumor
    }
)
print("Model loaded successfully!")
print(f"Model parameters: {model.count_params():,}")

# Warmup prediction to trigger XLA compilation (this takes 1-2 minutes)
print("\nWarming up model (XLA compilation - this takes 1-2 min)...")
warmup_start = time.time()
dummy_input = np.zeros((1, 128, 128, 128, 1), dtype=np.float32)
_ = model.predict(dummy_input, verbose=0)
print(f"Warmup done in {time.time()-warmup_start:.1f}s - ready for evaluation!")

Loading model from checkpoints/best_model_v2_augment.keras...


I0000 00:00:1765840317.097091  280811 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9064 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4080 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.9


Model loaded successfully!
Model parameters: 12,708,315

Warming up model (XLA compilation - this takes 1-2 min)...


2025-12-15 18:11:58.334416: E tensorflow/core/util/util.cc:131] oneDNN supports DT_BFLOAT16 only on platforms with AVX-512. Falling back to the default Eigen-based implementation if present.
2025-12-15 18:11:58.340392: I external/local_xla/xla/service/service.cc:163] XLA service 0x7c80a402dee0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-12-15 18:11:58.340406: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 4080 Laptop GPU, Compute Capability 8.9
2025-12-15 18:11:58.352536: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-12-15 18:11:58.471221: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91700


Warmup done in 5.2s - ready for evaluation!


I0000 00:00:1765840323.241902  280944 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


In [5]:
# =============================================================================
# Compute Confusion Matrix
# =============================================================================

print("Computing confusion matrix on test set...")
print(f"Processing {len(test_files)} files ({len(test_files)*20} patches)...\n")

cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
total_patches = 0
start_time = time.time()

for file_idx, filepath in enumerate(test_files):
    file_start = time.time()
    data = np.load(filepath)
    patches = data['patches'].astype(np.float32) / 255.0
    segs = data['segmentations']
    
    # Process in batches
    for i in range(0, len(patches), BATCH_SIZE):
        x = patches[i:i+BATCH_SIZE][..., np.newaxis]
        y_true = segs[i:i+BATCH_SIZE]
        
        pred = model.predict(x, verbose=0)
        y_pred = np.argmax(pred, axis=-1)
        
        # Accumulate confusion matrix
        cm += confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=[0, 1, 2])
        total_patches += len(x)
    
    file_time = time.time() - file_start
    total_time = time.time() - start_time
    eta = (total_time / (file_idx + 1)) * (len(test_files) - file_idx - 1)
    print(f"  [{file_idx+1:2d}/{len(test_files)}] {os.path.basename(filepath)}: {file_time:.1f}s (ETA: {eta:.0f}s)")

print(f"\n{'='*50}")
print(f"DONE! Total time: {time.time()-start_time:.1f}s")
print(f"Total voxels evaluated: {cm.sum():,}")

Computing confusion matrix on test set...
Processing 21 files (420 patches)...



2025-12-15 18:12:07.304204: E external/local_xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv (bf16[4,24,128,128,128]{4,3,2,1,0}, u8[0]{0}) custom-call(bf16[4,24,128,128,128]{4,3,2,1,0}, bf16[24,24,3,3,3]{4,3,2,1,0}, bf16[24]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]} is taking a while...
2025-12-15 18:12:07.425005: E external/local_xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.121013414s
Trying algorithm eng0{} for conv (bf16[4,24,128,128,128]{4,3,2,1,0}, u8[0]{0}) custom-call(bf16[4,24,128,128,128]{4,3,2,1,0}, bf16[24,24,3,3,3]{4,3,2,1,0}, bf16[24]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012

  [ 1/21] volume_052.npz: 18.2s (ETA: 363s)
  [ 2/21] volume_021.npz: 2.8s (ETA: 199s)
  [ 3/21] volume_002.npz: 2.7s (ETA: 142s)
  [ 4/21] volume_023.npz: 2.7s (ETA: 112s)
  [ 5/21] volume_103.npz: 2.7s (ETA: 93s)
  [ 6/21] volume_099.npz: 2.7s (ETA: 79s)
  [ 7/21] volume_116.npz: 2.7s (ETA: 69s)
  [ 8/21] volume_087.npz: 2.7s (ETA: 60s)
  [ 9/21] volume_119.npz: 2.7s (ETA: 53s)
  [10/21] volume_074.npz: 2.7s (ETA: 47s)
  [11/21] volume_086.npz: 2.6s (ETA: 41s)
  [12/21] volume_082.npz: 2.6s (ETA: 36s)
  [13/21] volume_121.npz: 2.7s (ETA: 31s)
  [14/21] volume_130.npz: 2.7s (ETA: 26s)
  [15/21] volume_020.npz: 2.7s (ETA: 22s)
  [16/21] volume_060.npz: 2.7s (ETA: 18s)
  [17/21] volume_071.npz: 2.6s (ETA: 14s)
  [18/21] volume_106.npz: 2.6s (ETA: 11s)
  [19/21] volume_014.npz: 2.6s (ETA: 7s)
  [20/21] volume_092.npz: 2.7s (ETA: 3s)
  [21/21] volume_102.npz: 2.6s (ETA: 0s)

DONE! Total time: 71.4s
Total voxels evaluated: 880,803,840


In [None]:
# =============================================================================
# Plot Confusion Matrix
# =============================================================================

class_names = ['Background', 'Liver', 'Tumor']
cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True) * 100

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Raw counts
im1 = axes[0].imshow(cm, cmap='Blues')
axes[0].set_title('Confusion Matrix (Counts)', fontsize=14)
axes[0].set_xlabel('Predicted', fontsize=12)
axes[0].set_ylabel('True', fontsize=12)
axes[0].set_xticks(range(3))
axes[0].set_yticks(range(3))
axes[0].set_xticklabels(class_names)
axes[0].set_yticklabels(class_names)
for i in range(3):
    for j in range(3):
        axes[0].text(j, i, f'{cm[i,j]:,}', ha='center', va='center', fontsize=10)
cbar1 = plt.colorbar(im1, ax=axes[0])
cbar1.set_label('Voxel Count', fontsize=11)

# Normalized (percentages)
im2 = axes[1].imshow(cm_norm, cmap='Blues', vmin=0, vmax=100)
axes[1].set_title('Confusion Matrix (% by True Class)', fontsize=14)
axes[1].set_xlabel('Predicted', fontsize=12)
axes[1].set_ylabel('True', fontsize=12)
axes[1].set_xticks(range(3))
axes[1].set_yticks(range(3))
axes[1].set_xticklabels(class_names)
axes[1].set_yticklabels(class_names)
for i in range(3):
    for j in range(3):
        axes[1].text(j, i, f'{cm_norm[i,j]:.1f}%', ha='center', va='center', fontsize=10)
cbar2 = plt.colorbar(im2, ax=axes[1])
cbar2.set_label('Percentage (%)', fontsize=11)

plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nConfusion matrix saved to confusion_matrix.png")

In [None]:
# =============================================================================
# Print Per-Class Metrics
# =============================================================================

print("="*60)
print("PER-CLASS METRICS")
print("="*60)
print(f"{'Class':<12} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Dice':>10}")
print("-"*60)

for i, name in enumerate(class_names):
    tp = cm[i, i]
    fp = cm[:, i].sum() - tp
    fn = cm[i, :].sum() - tp
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
    
    print(f"{name:<12} {precision:>10.4f} {recall:>10.4f} {f1:>10.4f} {dice:>10.4f}")

print("="*60)

# Overall accuracy
accuracy = np.trace(cm) / cm.sum()
print(f"\nOverall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

PER-CLASS METRICS
Class         Precision     Recall         F1       Dice
------------------------------------------------------------
Background       0.9807     0.9662     0.9734     0.9734
Liver            0.8702     0.9397     0.9036     0.9036
Tumor            0.8618     0.5269     0.6539     0.6539

Overall Accuracy: 0.9501 (95.01%)


: 

In [None]:
# =============================================================================
# Visualization: Original Patch, Ground Truth, Predicted Mask
# =============================================================================

def visualize_predictions(model, test_files, num_patients=5, slice_idx=64):
    """
    Visualize predictions for multiple patients.
    
    Parameters:
    - model: Trained model
    - test_files: List of test file paths
    - num_patients: Number of patients (rows) to display
    - slice_idx: Which slice of the 3D volume to display (default: middle slice)
    """
    num_patients = min(num_patients, len(test_files))
    
    fig, axes = plt.subplots(num_patients, 3, figsize=(12, 4 * num_patients))
    
    # Handle single patient case
    if num_patients == 1:
        axes = axes.reshape(1, -1)
    
    # Column titles
    col_titles = ['Original Patch', 'Ground Truth', 'Predicted Mask']
    
    # Color map for segmentation: 0=black (bg), 1=green (liver), 2=red (tumor)
    seg_cmap = plt.cm.colors.ListedColormap(['black', 'green', 'red'])
    
    for row, filepath in enumerate(test_files[:num_patients]):
        # Load data
        data = np.load(filepath)
        patches = data['patches'].astype(np.float32) / 255.0
        segs = data['segmentations']
        
        # Use first patch from this patient (or one with tumor if available)
        patch_idx = 0
        for i in range(len(segs)):
            if np.any(segs[i] == 2):  # Find patch with tumor
                patch_idx = i
                break
        
        # Get patch and segmentation
        patch = patches[patch_idx]
        gt = segs[patch_idx]
        
        # Predict
        x = patch[np.newaxis, ..., np.newaxis]
        pred = model.predict(x, verbose=0)
        pred_mask = np.argmax(pred[0], axis=-1)
        
        # Get the slice (use middle or specified slice)
        s = min(slice_idx, patch.shape[0] - 1)
        
        # Find slice with most tumor content if available
        tumor_counts = [np.sum(gt[i] == 2) for i in range(gt.shape[0])]
        if max(tumor_counts) > 0:
            s = np.argmax(tumor_counts)
        
        # Plot original patch
        axes[row, 0].imshow(patch[s], cmap='gray', vmin=0, vmax=1)
        axes[row, 0].set_ylabel(f'Patient {row + 1}\n{os.path.basename(filepath)[:15]}...', fontsize=10)
        axes[row, 0].set_xticks([])
        axes[row, 0].set_yticks([])
        
        # Plot ground truth
        axes[row, 1].imshow(patch[s], cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[row, 1].imshow(gt[s], cmap=seg_cmap, vmin=0, vmax=2, alpha=0.5)
        axes[row, 1].set_xticks([])
        axes[row, 1].set_yticks([])
        
        # Plot prediction
        axes[row, 2].imshow(patch[s], cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[row, 2].imshow(pred_mask[s], cmap=seg_cmap, vmin=0, vmax=2, alpha=0.5)
        axes[row, 2].set_xticks([])
        axes[row, 2].set_yticks([])
        
        # Add column titles on first row
        if row == 0:
            for col, title in enumerate(col_titles):
                axes[row, col].set_title(title, fontsize=12, fontweight='bold')
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='black', label='Background'),
        Patch(facecolor='green', label='Liver'),
        Patch(facecolor='red', label='Tumor')
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3, fontsize=10, 
               bbox_to_anchor=(0.5, -0.02))
    
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.08)
    plt.savefig('prediction_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("Visualization saved to prediction_visualization.png")


# Run visualization
print("Generating visualization for test patients...")
visualize_predictions(model, test_files, num_patients=5)

Generating visualization for test patients...


In [None]:
# =============================================================================
# Visualization: Single Patient - Multiple Slices (Axial View)
# =============================================================================

def visualize_single_patient_slices(model, test_files, patient_idx=0, num_slices=8):
    """
    Visualize multiple slices from a single patient.
    
    Parameters:
    - model: Trained model
    - test_files: List of test file paths
    - patient_idx: Which patient to visualize
    - num_slices: Number of slices (rows) to display
    
    Arrangement: Upper rows = superior slices (toward head)
                 Lower rows = inferior slices (toward feet)
    """
    # Load patient data
    filepath = test_files[patient_idx]
    data = np.load(filepath)
    patches = data['patches'].astype(np.float32) / 255.0
    segs = data['segmentations']
    
    # Find patch with tumor (for more interesting visualization)
    patch_idx = 0
    max_tumor = 0
    for i in range(len(segs)):
        tumor_count = np.sum(segs[i] == 2)
        if tumor_count > max_tumor:
            max_tumor = tumor_count
            patch_idx = i
    
    patch = patches[patch_idx]
    gt = segs[patch_idx]
    
    # Predict
    x = patch[np.newaxis, ..., np.newaxis]
    pred = model.predict(x, verbose=0)
    pred_mask = np.argmax(pred[0], axis=-1)
    
    # Get slice indices evenly distributed through the volume
    # In CT: higher slice index = superior (head), lower = inferior (feet)
    depth = patch.shape[0]
    slice_indices = np.linspace(depth - 1, 0, num_slices, dtype=int)  # Superior to inferior
    
    # Create figure
    fig, axes = plt.subplots(num_slices, 3, figsize=(10, 3 * num_slices))
    
    # Column titles
    col_titles = ['Original Patch', 'Ground Truth', 'Predicted Mask']
    
    # Color map for segmentation
    seg_cmap = plt.cm.colors.ListedColormap(['black', 'green', 'red'])
    
    for row, s in enumerate(slice_indices):
        # Plot original patch
        axes[row, 0].imshow(patch[s], cmap='gray', vmin=0, vmax=1)
        axes[row, 0].set_ylabel(f'Slice {s}\n({"Superior" if row < num_slices//2 else "Inferior"})', 
                                 fontsize=9)
        axes[row, 0].set_xticks([])
        axes[row, 0].set_yticks([])
        
        # Plot ground truth overlay
        axes[row, 1].imshow(patch[s], cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[row, 1].imshow(gt[s], cmap=seg_cmap, vmin=0, vmax=2, alpha=0.5)
        axes[row, 1].set_xticks([])
        axes[row, 1].set_yticks([])
        
        # Plot prediction overlay
        axes[row, 2].imshow(patch[s], cmap='gray', vmin=0, vmax=1, alpha=0.5)
        axes[row, 2].imshow(pred_mask[s], cmap=seg_cmap, vmin=0, vmax=2, alpha=0.5)
        axes[row, 2].set_xticks([])
        axes[row, 2].set_yticks([])
        
        # Add column titles on first row
        if row == 0:
            for col, title in enumerate(col_titles):
                axes[row, col].set_title(title, fontsize=12, fontweight='bold')
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='black', label='Background'),
        Patch(facecolor='green', label='Liver'),
        Patch(facecolor='red', label='Tumor')
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=3, fontsize=10,
               bbox_to_anchor=(0.5, -0.02))
    
    # Add main title
    patient_name = os.path.basename(filepath)
    fig.suptitle(f'Patient: {patient_name}\nPatch {patch_idx} (Tumor voxels: {max_tumor:,})\n↑ Superior (Head)  |  ↓ Inferior (Feet)', 
                 fontsize=12, fontweight='bold', y=1.02)
    
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.05, top=0.93)
    plt.savefig('single_patient_slices.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Visualization saved to single_patient_slices.png")
    print(f"Showing {num_slices} slices from depth 0-{depth-1}")


# Run visualization for first test patient
print("Generating slice visualization for single patient...")
visualize_single_patient_slices(model, test_files, patient_idx=2, num_slices=8)

In [None]:
# =============================================================================
# 3D Visualization: Isosurface Rendering of Liver and Tumor
# =============================================================================

from skimage import measure
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def visualize_3d_segmentation(model, test_files, patient_idx=0, downsample=2):
    """
    Visualize 3D isosurface rendering of liver and tumor segmentation.
    
    Parameters:
    - model: Trained model
    - test_files: List of test file paths
    - patient_idx: Which patient to visualize
    - downsample: Factor to downsample volume for faster rendering (default=2)
    """
    # Load patient data
    filepath = test_files[patient_idx]
    data = np.load(filepath)
    patches = data['patches'].astype(np.float32) / 255.0
    segs = data['segmentations']
    
    # Find patch with most tumor content
    patch_idx = 0
    max_tumor = 0
    for i in range(len(segs)):
        tumor_count = np.sum(segs[i] == 2)
        if tumor_count > max_tumor:
            max_tumor = tumor_count
            patch_idx = i
    
    patch = patches[patch_idx]
    gt = segs[patch_idx]
    
    # Predict
    x = patch[np.newaxis, ..., np.newaxis]
    pred = model.predict(x, verbose=0)
    pred_mask = np.argmax(pred[0], axis=-1)
    
    # Downsample for faster rendering
    gt_ds = gt[::downsample, ::downsample, ::downsample]
    pred_ds = pred_mask[::downsample, ::downsample, ::downsample]
    
    # Create figure with 2 subplots (GT and Prediction)
    fig = plt.figure(figsize=(16, 7))
    
    titles = ['Ground Truth', 'Prediction']
    volumes = [gt_ds, pred_ds]
    
    for idx, (title, vol) in enumerate(zip(titles, volumes)):
        ax = fig.add_subplot(1, 2, idx + 1, projection='3d')
        
        # Extract and plot liver surface (class 1)
        liver_mask = (vol >= 1).astype(np.float32)  # Liver includes tumor region
        if liver_mask.sum() > 0:
            try:
                verts, faces, _, _ = measure.marching_cubes(liver_mask, level=0.5)
                mesh = Poly3DCollection(verts[faces], alpha=0.3, linewidths=0)
                mesh.set_facecolor('green')
                mesh.set_edgecolor('darkgreen')
                ax.add_collection3d(mesh)
            except:
                pass  # Skip if no surface found
        
        # Extract and plot tumor surface (class 2)
        tumor_mask = (vol == 2).astype(np.float32)
        if tumor_mask.sum() > 0:
            try:
                verts, faces, _, _ = measure.marching_cubes(tumor_mask, level=0.5)
                mesh = Poly3DCollection(verts[faces], alpha=0.9, linewidths=0)
                mesh.set_facecolor('red')
                mesh.set_edgecolor('darkred')
                ax.add_collection3d(mesh)
            except:
                pass  # Skip if no surface found
        
        # Set axis properties
        ax.set_xlim(0, vol.shape[0])
        ax.set_ylim(0, vol.shape[1])
        ax.set_zlim(0, vol.shape[2])
        ax.set_xlabel('X (Depth)', fontsize=10)
        ax.set_ylabel('Y (Height)', fontsize=10)
        ax.set_zlabel('Z (Width)', fontsize=10)
        ax.set_title(f'{title}\nLiver (green), Tumor (red)', fontsize=12, fontweight='bold')
        
        # Set viewing angle
        ax.view_init(elev=20, azim=45)
    
    # Add main title
    patient_name = os.path.basename(filepath)
    fig.suptitle(f'3D Segmentation Visualization\nPatient: {patient_name}, Patch {patch_idx}\n'
                 f'Tumor voxels: {max_tumor:,} | Downsampled {downsample}x for rendering',
                 fontsize=12, fontweight='bold', y=1.02)
    
    plt.tight_layout()
    plt.savefig('3d_visualization.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"3D visualization saved to 3d_visualization.png")
    print(f"Original volume: {gt.shape}, Downsampled: {gt_ds.shape}")


# Run 3D visualization
print("Generating 3D isosurface visualization...")
visualize_3d_segmentation(model, test_files, patient_idx=2, downsample=2)

In [None]:
# =============================================================================
# 3D Visualization: Multi-View Rendering (6 angles)
# =============================================================================

def visualize_3d_multiview(model, test_files, patient_idx=0, downsample=2):
    """
    Visualize 3D segmentation from 6 different viewing angles.
    Shows Ground Truth vs Prediction side by side for each angle.
    
    Parameters:
    - model: Trained model
    - test_files: List of test file paths
    - patient_idx: Which patient to visualize
    - downsample: Factor to downsample volume for faster rendering
    """
    from skimage import measure
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
    
    # Load patient data
    filepath = test_files[patient_idx]
    data = np.load(filepath)
    patches = data['patches'].astype(np.float32) / 255.0
    segs = data['segmentations']
    
    # Find patch with most tumor content
    patch_idx = 0
    max_tumor = 0
    for i in range(len(segs)):
        tumor_count = np.sum(segs[i] == 2)
        if tumor_count > max_tumor:
            max_tumor = tumor_count
            patch_idx = i
    
    patch = patches[patch_idx]
    gt = segs[patch_idx]
    
    # Predict
    x = patch[np.newaxis, ..., np.newaxis]
    pred = model.predict(x, verbose=0)
    pred_mask = np.argmax(pred[0], axis=-1)
    
    # Downsample for faster rendering
    gt_ds = gt[::downsample, ::downsample, ::downsample]
    pred_ds = pred_mask[::downsample, ::downsample, ::downsample]
    
    # Define viewing angles: (elevation, azimuth, name)
    views = [
        (20, 45, 'Front-Right'),
        (20, 135, 'Front-Left'),
        (20, 225, 'Back-Left'),
        (20, 315, 'Back-Right'),
        (90, 0, 'Top (Axial)'),
        (0, 0, 'Front (Coronal)'),
    ]
    
    # Create figure: 6 rows (views) x 2 columns (GT, Pred)
    fig = plt.figure(figsize=(10, 24))
    
    volumes = [('Ground Truth', gt_ds), ('Prediction', pred_ds)]
    
    for view_idx, (elev, azim, view_name) in enumerate(views):
        for col_idx, (title, vol) in enumerate(volumes):
            ax = fig.add_subplot(6, 2, view_idx * 2 + col_idx + 1, projection='3d')
            
            # Extract and plot liver surface
            liver_mask = (vol >= 1).astype(np.float32)
            if liver_mask.sum() > 0:
                try:
                    verts, faces, _, _ = measure.marching_cubes(liver_mask, level=0.5)
                    mesh = Poly3DCollection(verts[faces], alpha=0.3, linewidths=0)
                    mesh.set_facecolor('green')
                    ax.add_collection3d(mesh)
                except:
                    pass
            
            # Extract and plot tumor surface
            tumor_mask = (vol == 2).astype(np.float32)
            if tumor_mask.sum() > 0:
                try:
                    verts, faces, _, _ = measure.marching_cubes(tumor_mask, level=0.5)
                    mesh = Poly3DCollection(verts[faces], alpha=0.9, linewidths=0)
                    mesh.set_facecolor('red')
                    ax.add_collection3d(mesh)
                except:
                    pass
            
            # Set axis properties
            ax.set_xlim(0, vol.shape[0])
            ax.set_ylim(0, vol.shape[1])
            ax.set_zlim(0, vol.shape[2])
            ax.set_xlabel('X', fontsize=8)
            ax.set_ylabel('Y', fontsize=8)
            ax.set_zlabel('Z', fontsize=8)
            ax.view_init(elev=elev, azim=azim)
            
            # Title for top row only
            if view_idx == 0:
                ax.set_title(title, fontsize=11, fontweight='bold')
            
            # View name for first column only
            if col_idx == 0:
                ax.text2D(-0.15, 0.5, view_name, transform=ax.transAxes, 
                         fontsize=10, fontweight='bold', rotation=90,
                         ha='center', va='center')
    
    # Add main title
    patient_name = os.path.basename(filepath)
    fig.suptitle(f'3D Multi-View Segmentation\nPatient: {patient_name}, Patch {patch_idx}\n'
                 f'Green=Liver, Red=Tumor | Tumor voxels: {max_tumor:,}',
                 fontsize=12, fontweight='bold', y=0.995)
    
    plt.tight_layout()
    plt.subplots_adjust(top=0.94, left=0.12)
    plt.savefig('3d_multiview.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Multi-view 3D visualization saved to 3d_multiview.png")


# Run multi-view 3D visualization
print("Generating multi-view 3D visualization...")
visualize_3d_multiview(model, test_files, patient_idx=2, downsample=2)