In [None]:
import os
import sys
import json
from pathlib import Path
from typing import List

# Set XLA flag to use fallback convolution algorithms (prevents OOM during autotuning)
os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import cv2
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.utils import to_categorical

import nibabel as nib
import SimpleITK as sitk
from scipy.ndimage import rotate as scipy_rotate

from PIL import Image, ImageOps, ImageFilter
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report


tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

In [None]:
# =============================================================================
# 3D U-Net for Liver CT Segmentation
# =============================================================================
# Input: (128, 128, 128, 1) - single channel CT patch
# Output: (128, 128, 128, num_classes) - segmentation mask
# Classes: 0=background, 1=liver, 2=tumor

def double_conv_block_3d(x, n_filters, kernel_size=3):
    """Two consecutive 3D convolutions with BatchNorm and ReLU."""
    x = layers.Conv3D(n_filters, kernel_size, padding="same", kernel_initializer="he_normal")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.Conv3D(n_filters, kernel_size, padding="same", kernel_initializer="he_normal")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    return x


def downsample_block_3d(x, n_filters, dropout_rate=0.3):
    """Encoder block: double conv + max pool + dropout."""
    f = double_conv_block_3d(x, n_filters)
    p = layers.MaxPool3D(pool_size=(2, 2, 2))(f)
    p = layers.Dropout(dropout_rate)(p)
    return f, p


def upsample_block_3d(x, skip_features, n_filters, dropout_rate=0.5):
    """Decoder block: upsample + concat skip + double conv."""
    x = layers.Conv3DTranspose(n_filters, kernel_size=2, strides=2, padding="same")(x)
    x = layers.concatenate([x, skip_features])
    x = layers.Dropout(dropout_rate)(x)
    x = double_conv_block_3d(x, n_filters)
    return x


def build_3d_unet(input_shape=(128, 128, 128, 1), num_classes=3, base_filters=32):
    """
    Build a 3D U-Net model for volumetric segmentation.
    
    Parameters:
    - input_shape: (D, H, W, C) - default (128, 128, 128, 1)
    - num_classes: Number of output classes (3: background, liver, tumor)
    - base_filters: Starting number of filters (doubles each level)
    
    Architecture (base_filters=32):
        Encoder: 32 -> 64 -> 128 -> 256
        Bottleneck: 512
        Decoder: 256 -> 128 -> 64 -> 32
        
    Returns:
    - model: Keras Model
    """
    inputs = layers.Input(shape=input_shape)
    
    # Encoder path
    f1, p1 = downsample_block_3d(inputs, base_filters)        
    f2, p2 = downsample_block_3d(p1, base_filters * 2)        
    f3, p3 = downsample_block_3d(p2, base_filters * 4)        
    f4, p4 = downsample_block_3d(p3, base_filters * 8)
    
    # Bottleneck
    bottleneck = double_conv_block_3d(p4, base_filters * 16)  
    
    # Decoder path
    u1 = upsample_block_3d(bottleneck, f4, base_filters * 8)  
    u2 = upsample_block_3d(u1, f3, base_filters * 4)         
    u3 = upsample_block_3d(u2, f2, base_filters * 2)         
    u4 = upsample_block_3d(u3, f1, base_filters)            
    

    outputs = layers.Conv3D(num_classes, kernel_size=1, padding="same", activation="softmax")(u4)
    
    model = models.Model(inputs, outputs, name="3D-UNet")
    return model

In [4]:
# =============================================================================
# Get File List and Split into Train/Val/Test
# =============================================================================

DATA_DIR = 'preprocessed_patches_v2'  # Updated to use tumor-centered patches
NUM_CLASSES = 3
SEED = 42

TRAIN_RATIO = 0.70
VAL_RATIO = 0.15
TEST_RATIO = 0.15

# Get all .npz files
all_files = sorted([os.path.join(DATA_DIR, f) for f in os.listdir(DATA_DIR) if f.endswith('.npz')])
print(f"Found {len(all_files)} volume files in {DATA_DIR}/")

# Shuffle and split
np.random.seed(SEED)
indices = np.random.permutation(len(all_files))

train_end = int(len(all_files) * TRAIN_RATIO)
val_end = train_end + int(len(all_files) * VAL_RATIO)

train_files = [all_files[i] for i in indices[:train_end]]
val_files = [all_files[i] for i in indices[train_end:val_end]]
test_files = [all_files[i] for i in indices[val_end:]]

print(f"\nSplit (by patient):")
print(f"  Train: {len(train_files)} files ({len(train_files)*20} patches)")
print(f"  Val:   {len(val_files)} files ({len(val_files)*20} patches)")
print(f"  Test:  {len(test_files)} files ({len(test_files)*20} patches)")

Found 131 volume files in preprocessed_patches_v2/

Split (by patient):
  Train: 91 files (1820 patches)
  Val:   19 files (380 patches)
  Test:  21 files (420 patches)


In [None]:
# =============================================================================
# Online Data Augmentation Functions
# =============================================================================

def augment_rotation_3d(volume, segmentation, max_angle=15):
    """
    Apply random 3D rotation to volume and segmentation.
    Rotates around a random axis by a random angle within ±max_angle degrees.
    """
    angle_x = np.random.uniform(-max_angle, max_angle)
    angle_y = np.random.uniform(-max_angle, max_angle)
    angle_z = np.random.uniform(-max_angle, max_angle)
    
    vol_rotated = scipy_rotate(volume, angle_z, axes=(0, 1), reshape=False, order=1, mode='constant', cval=0)
    vol_rotated = scipy_rotate(vol_rotated, angle_y, axes=(0, 2), reshape=False, order=1, mode='constant', cval=0)
    vol_rotated = scipy_rotate(vol_rotated, angle_x, axes=(1, 2), reshape=False, order=1, mode='constant', cval=0)
    
    seg_rotated = scipy_rotate(segmentation, angle_z, axes=(0, 1), reshape=False, order=0, mode='constant', cval=0)
    seg_rotated = scipy_rotate(seg_rotated, angle_y, axes=(0, 2), reshape=False, order=0, mode='constant', cval=0)
    seg_rotated = scipy_rotate(seg_rotated, angle_x, axes=(1, 2), reshape=False, order=0, mode='constant', cval=0)
    
    return vol_rotated, seg_rotated


def augment_gamma(volume, gamma_range=(0.7, 1.5)):
    gamma = np.random.uniform(gamma_range[0], gamma_range[1])
    return np.power(np.clip(volume, 0, 1), gamma)


def augment_gaussian_noise(volume, sigma_range=(0, 0.05)):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    noise = np.random.normal(0, sigma, volume.shape)
    return np.clip(volume + noise, 0, 1)


def augment_brightness(volume, delta_range=(-0.1, 0.1)):
    delta = np.random.uniform(delta_range[0], delta_range[1])
    return np.clip(volume + delta, 0, 1)


def apply_augmentation(volume, segmentation, augment=True):
    if not augment:
        return volume, segmentation
    
    if np.random.random() < 0.5:
        volume, segmentation = augment_rotation_3d(volume, segmentation, max_angle=15)
    
    if np.random.random() < 0.5:
        volume = augment_gamma(volume, gamma_range=(0.7, 1.5))
    
    if np.random.random() < 0.5:
        volume = augment_gaussian_noise(volume, sigma_range=(0, 0.05))
    
    if np.random.random() < 0.3:
        volume = augment_brightness(volume, delta_range=(-0.1, 0.1))
    
    return volume, segmentation


# =============================================================================
# File-Based Data Generator (Original Working Version)
# =============================================================================

class VolumeGenerator(tf.keras.utils.Sequence):
    """
    Loads patches from .npz files with configurable batch size.
    Each file contains: patches (20, 128, 128, 128), segmentations (20, 128, 128, 128)
    """
    
    def __init__(self, files, batch_size=2, num_classes=3, shuffle=True, augment=False):
        self.files = files
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.augment = augment
        
        self.batch_indices = []
        for file_idx, filepath in enumerate(files):
            data = np.load(filepath)
            n_patches = len(data['patches'])
            for start in range(0, n_patches, batch_size):
                self.batch_indices.append((file_idx, start))
        
        self.on_epoch_end()
    
    def __len__(self):
        return len(self.batch_indices)
    
    def __getitem__(self, idx):
        file_idx, patch_start = self.batch_indices[self.indices[idx]]
        data = np.load(self.files[file_idx])
        
        patch_end = patch_start + self.batch_size
        
        x = data['patches'][patch_start:patch_end].astype(np.float32) / 255.0
        y = data['segmentations'][patch_start:patch_end]
        
        if self.augment:
            x_aug = []
            y_aug = []
            for i in range(len(x)):
                vol_aug, seg_aug = apply_augmentation(x[i], y[i], augment=True)
                x_aug.append(vol_aug)
                y_aug.append(seg_aug)
            x = np.array(x_aug, dtype=np.float32)
            y = np.array(y_aug)
        
        x = x[..., np.newaxis]
        y = to_categorical(y, num_classes=self.num_classes)
        
        return x, y
    
    def on_epoch_end(self):
        self.indices = np.arange(len(self.batch_indices))
        if self.shuffle:
            np.random.shuffle(self.indices)


# Create generators
BATCH_SIZE = 2

train_gen = VolumeGenerator(train_files, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES, shuffle=True, augment=True)
val_gen = VolumeGenerator(val_files, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES, shuffle=False, augment=False)
test_gen = VolumeGenerator(test_files, batch_size=BATCH_SIZE, num_classes=NUM_CLASSES, shuffle=False, augment=False)

print(f"Batch size: {BATCH_SIZE}")
print(f"Train: {len(train_gen)} batches (with augmentation: rotation ±15°, gamma, noise)")
print(f"Val:   {len(val_gen)} batches (no augmentation)")
print(f"Test:  {len(test_gen)} batches (no augmentation)")

In [None]:
# =============================================================================
# Combined Loss Function: Dice + Focal + Tversky
# =============================================================================

# Class weights: [background, liver, tumor]
CLASS_WEIGHTS = [0.1, 1.0, 30.0]  # Increased tumor weight

# Loss combination weights
LOSS_WEIGHTS = {
    'dice': 1.0,
    'focal': 1.0,
    'tversky': 1.0,
}

# Tversky parameters (alpha > beta favors recall over precision)
TVERSKY_ALPHA = 0.3  # Weight for false positives
TVERSKY_BETA = 0.7   # Weight for false negatives (higher = better recall)

# Focal loss parameters
FOCAL_GAMMA = 2.0    # Focus on hard examples (higher = more focus)


def dice_coefficient_per_class(y_true, y_pred, class_idx, smooth=1e-6):
    """Dice coefficient for a single class."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_c = y_true[..., class_idx]
    y_pred_c = y_pred[..., class_idx]
    y_true_f = tf.keras.backend.flatten(y_true_c)
    y_pred_f = tf.keras.backend.flatten(y_pred_c)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + smooth)


def weighted_dice_loss(y_true, y_pred):
    """Weighted dice loss - higher weight for tumor class."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    total_loss = 0.0
    for class_idx, weight in enumerate(CLASS_WEIGHTS):
        dice = dice_coefficient_per_class(y_true, y_pred, class_idx)
        total_loss += weight * (1 - dice)
    
    return total_loss / sum(CLASS_WEIGHTS)


def focal_loss(y_true, y_pred, gamma=FOCAL_GAMMA):
    """
    Focal loss for handling class imbalance.
    Focuses training on hard-to-classify examples.
    
    FL(p) = -alpha * (1-p)^gamma * log(p)
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    # Clip predictions to prevent log(0)
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    
    # Calculate focal weight: (1 - p)^gamma
    focal_weight = tf.pow(1 - y_pred, gamma)
    
    # Cross entropy
    ce = -y_true * tf.math.log(y_pred)
    
    # Apply class weights
    class_weights_tensor = tf.constant(CLASS_WEIGHTS, dtype=tf.float32)
    class_weights_tensor = class_weights_tensor / tf.reduce_sum(class_weights_tensor)
    
    # Weighted focal loss
    focal = focal_weight * ce * class_weights_tensor
    
    return tf.reduce_mean(tf.reduce_sum(focal, axis=-1))


def tversky_loss(y_true, y_pred, alpha=TVERSKY_ALPHA, beta=TVERSKY_BETA, smooth=1e-6):
    """
    Tversky loss for controlling precision/recall tradeoff.
    
    Tversky Index = TP / (TP + alpha*FP + beta*FN)
    
    - alpha = beta = 0.5 -> equivalent to Dice
    - alpha < beta -> penalizes FN more (improves recall)
    - alpha > beta -> penalizes FP more (improves precision)
    
    For tumor segmentation, we want higher recall, so beta > alpha.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    total_loss = 0.0
    for class_idx, weight in enumerate(CLASS_WEIGHTS):
        y_true_c = y_true[..., class_idx]
        y_pred_c = y_pred[..., class_idx]
        
        y_true_f = tf.keras.backend.flatten(y_true_c)
        y_pred_f = tf.keras.backend.flatten(y_pred_c)
        
        # True positives, false positives, false negatives
        tp = tf.keras.backend.sum(y_true_f * y_pred_f)
        fp = tf.keras.backend.sum((1 - y_true_f) * y_pred_f)
        fn = tf.keras.backend.sum(y_true_f * (1 - y_pred_f))
        
        # Tversky index
        tversky_idx = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
        total_loss += weight * (1 - tversky_idx)
    
    return total_loss / sum(CLASS_WEIGHTS)


def combo_loss(y_true, y_pred):
    """
    Combined loss: Dice + Focal + Tversky
    
    This combination provides:
    - Dice: Overall segmentation quality
    - Focal: Focus on hard examples (small tumors)
    - Tversky: Better recall for minority class (tumors)
    """
    dice = weighted_dice_loss(y_true, y_pred)
    focal = focal_loss(y_true, y_pred)
    tversky = tversky_loss(y_true, y_pred)
    
    total = (LOSS_WEIGHTS['dice'] * dice + 
             LOSS_WEIGHTS['focal'] * focal + 
             LOSS_WEIGHTS['tversky'] * tversky)
    
    return total / sum(LOSS_WEIGHTS.values())


# =============================================================================
# Metrics
# =============================================================================

def dice_coefficient(y_true, y_pred):
    """Overall dice coefficient."""
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    return (2. * intersection + 1e-6) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + 1e-6)


def dice_liver(y_true, y_pred):
    """Dice coefficient for liver (class 1)."""
    return dice_coefficient_per_class(y_true, y_pred, 1)


def dice_tumor(y_true, y_pred):
    """Dice coefficient for tumor (class 2)."""
    return dice_coefficient_per_class(y_true, y_pred, 2)


print("Loss functions defined:")
print(f"  - Dice Loss (weight: {LOSS_WEIGHTS['dice']})")
print(f"  - Focal Loss (weight: {LOSS_WEIGHTS['focal']}, gamma: {FOCAL_GAMMA})")
print(f"  - Tversky Loss (weight: {LOSS_WEIGHTS['tversky']}, alpha: {TVERSKY_ALPHA}, beta: {TVERSKY_BETA})")
print(f"\nClass weights: Background={CLASS_WEIGHTS[0]}, Liver={CLASS_WEIGHTS[1]}, Tumor={CLASS_WEIGHTS[2]}")

In [None]:
# =============================================================================
# Build and Compile Model (Fine-tuning from pre-trained weights)
# =============================================================================

model = build_3d_unet(input_shape=(128, 128, 128, 1), num_classes=NUM_CLASSES, base_filters=24)

# Load pre-trained weights from ORIGINAL best model (not the overwritten one!)
PRETRAINED_PATH = 'checkpoints/best_model_v2_augment.keras'  # Original 65% dice tumor model
print(f"Loading pre-trained weights from {PRETRAINED_PATH}...")
model.load_weights(PRETRAINED_PATH)
print("Weights loaded successfully!")

# Compile with combo loss and lower learning rate for fine-tuning
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),  # Lower LR for fine-tuning
    loss=combo_loss,  # Combined: Dice + Focal + Tversky
    metrics=[dice_coefficient, dice_liver, dice_tumor]
)

print(f"\nFine-tuning with Combo Loss (Dice + Focal + Tversky)")
print(f"Learning rate: 1e-5 (reduced for fine-tuning)")
model.summary()

In [None]:
# =============================================================================
# Train Model with Crash Protection
# =============================================================================

EPOCHS = 30
CHECKPOINT_DIR = 'checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, 'best_model_combo_loss.keras'),
        monitor='val_dice_tumor',
        mode='max',
        save_best_only=True,
        verbose=1
    ),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(CHECKPOINT_DIR, 'latest_checkpoint.keras'),
        save_best_only=False,
        verbose=0
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor='val_dice_tumor',
        mode='max',
        patience=10,
        restore_best_weights=True,
        verbose=1
    )
]

history = None
training_completed = False

try:
    print("Starting training with combo loss...")
    print(f"Best model will be saved to: {CHECKPOINT_DIR}/best_model_combo_loss.keras")
    print("-" * 60)
    
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        callbacks=callbacks,
        verbose=1
    )
    training_completed = True
    print("\nTraining completed successfully!")
    
except KeyboardInterrupt:
    print("\n" + "=" * 60)
    print("TRAINING INTERRUPTED BY USER")
    print("=" * 60)
    print("Saving emergency checkpoint...")
    model.save(os.path.join(CHECKPOINT_DIR, 'interrupted_model.keras'))
    print(f"Model saved to {CHECKPOINT_DIR}/interrupted_model.keras")
    
except Exception as e:
    print("\n" + "=" * 60)
    print(f"TRAINING CRASHED: {type(e).__name__}")
    print("=" * 60)
    print(f"Error: {e}")
    print("\nAttempting emergency save...")
    try:
        model.save(os.path.join(CHECKPOINT_DIR, 'crash_recovery_model.keras'))
        print(f"Model saved to {CHECKPOINT_DIR}/crash_recovery_model.keras")
    except Exception as save_error:
        print(f"Emergency save failed: {save_error}")

print("\n" + "=" * 60)
print("CHECKPOINT SUMMARY")
print("=" * 60)
if os.path.exists(CHECKPOINT_DIR):
    for f in sorted(os.listdir(CHECKPOINT_DIR)):
        fpath = os.path.join(CHECKPOINT_DIR, f)
        size_mb = os.path.getsize(fpath) / 1024**2
        print(f"  {f}: {size_mb:.1f} MB")

In [None]:
# =============================================================================
# Save Final Model
# =============================================================================

# Check if training completed and save final model
if training_completed:
    try:
        model.save('final_model.keras')
        print("Final model saved to final_model.keras")
        
        # Also save training history
        if history is not None:
            history_dict = {key: [float(v) for v in values] for key, values in history.history.items()}
            with open('training_history.json', 'w') as f:
                json.dump(history_dict, f, indent=2)
            print("Training history saved to training_history.json")
    except Exception as e:
        print(f"Error saving final model: {e}")
        print("Best model should be available in checkpoints/best_model.keras")
else:
    print("Training did not complete normally.")
    print("Check checkpoints/ directory for saved models:")
    print("  - best_model.keras (best validation tumor dice)")
    print("  - latest_checkpoint.keras (last completed epoch)")
    print("  - interrupted_model.keras or crash_recovery_model.keras (if crashed)")

In [9]:
# =============================================================================
# Confusion Matrix (Memory-Efficient Incremental Computation)
# =============================================================================

def compute_confusion_matrix_incremental(model, generator, num_classes=3, num_batches=None):
    """
    Compute confusion matrix incrementally to avoid memory issues.
    Instead of storing all predictions (~11GB), accumulates a 3x3 matrix (~72 bytes).
    """
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    
    n_batches = num_batches if num_batches else len(generator)
    
    for i in range(n_batches):
        x, y = generator[i]
        pred = model.predict(x, verbose=0)
        
        # Get class labels
        y_true = np.argmax(y, axis=-1).flatten()
        y_pred = np.argmax(pred, axis=-1).flatten()
        
        # Accumulate confusion matrix directly (no storage of all predictions)
        cm += confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
        
        if (i + 1) % 50 == 0:
            print(f"  Processed {i+1}/{n_batches} batches...")
    
    return cm


def plot_confusion_matrix(cm, class_names=['Background', 'Liver', 'Tumor']):
    """Plot confusion matrix with percentages."""
    # Normalize by row (true labels)
    cm_norm = cm.astype('float') / cm.sum(axis=1, keepdims=True) * 100
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Raw counts
    im1 = axes[0].imshow(cm, cmap='Blues')
    axes[0].set_title('Confusion Matrix (Counts)')
    axes[0].set_xlabel('Predicted')
    axes[0].set_ylabel('True')
    axes[0].set_xticks(range(len(class_names)))
    axes[0].set_yticks(range(len(class_names)))
    axes[0].set_xticklabels(class_names)
    axes[0].set_yticklabels(class_names)
    
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            axes[0].text(j, i, f'{cm[i, j]:,}', ha='center', va='center', fontsize=10)
    
    plt.colorbar(im1, ax=axes[0])
    
    # Normalized (percentages)
    im2 = axes[1].imshow(cm_norm, cmap='Blues', vmin=0, vmax=100)
    axes[1].set_title('Confusion Matrix (% by True Class)')
    axes[1].set_xlabel('Predicted')
    axes[1].set_ylabel('True')
    axes[1].set_xticks(range(len(class_names)))
    axes[1].set_yticks(range(len(class_names)))
    axes[1].set_xticklabels(class_names)
    axes[1].set_yticklabels(class_names)
    
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            axes[1].text(j, i, f'{cm_norm[i, j]:.1f}%', ha='center', va='center', fontsize=10)
    
    plt.colorbar(im2, ax=axes[1])
    
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=150)
    plt.show()
    
    # Print per-class metrics
    print("\nPer-Class Metrics:")
    print("-" * 50)
    for i, name in enumerate(class_names):
        tp = cm[i, i]
        fn = cm[i, :].sum() - tp
        fp = cm[:, i].sum() - tp
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        print(f"{name:12s}: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
    
    # Print dice-like metrics
    print("\nDice Coefficients (from confusion matrix):")
    print("-" * 50)
    for i, name in enumerate(class_names):
        tp = cm[i, i]
        fp = cm[:, i].sum() - tp
        fn = cm[i, :].sum() - tp
        dice = (2 * tp) / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
        print(f"{name:12s}: Dice={dice:.4f}")


# Compute and plot
print("Computing confusion matrix on test set (memory-efficient)...")
print(f"Processing {len(test_gen)} batches...")
cm = compute_confusion_matrix_incremental(model, test_gen, num_classes=NUM_CLASSES)
plot_confusion_matrix(cm)

Computing confusion matrix on test set (memory-efficient)...
Processing 105 batches...


NameError: name 'model' is not defined