In [3]:
# Medical Image Segmentation for Sunnybrook Cardiac Left Ventricle Dataset
# This notebook implements an enhanced U-Net model for segmenting the left ventricle in cardiac MRI images
# from the Sunnybrook dataset, with improvements to prevent masks from defaulting to zero and enhance image clarity.

# Cell 1: Environment Setup
import os
import numpy as np
import pandas as pd
import pydicom
import cv2
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from skimage.draw import polygon
from tensorflow.keras import mixed_precision
import datetime
import logging

# Configure logging for detailed diagnostics
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable mixed precision for TPU performance
mixed_precision.set_global_policy('mixed_float16')

# Enable XLA optimization
tf.config.optimizer.set_jit(True)

# Constants
CSV_PATH = '/kaggle/input/sunnybrook-latest/sunny_brook/scd_patientdata.csv'
DICOM_ROOT = '/kaggle/input/sunnybrook-latest/sunny_brook/dicom'
CONTOUR_ROOT = '/kaggle/input/sunnybrook-latest/sunny_brook/SCD_ManualContours'
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 16
EPOCHS = 33
CLIP_NORM = 1.0
# Cell 1: Add
import psutil
logger.info(f"Memory available: {psutil.virtual_memory().available / 1e9:.2f} GB")


In [4]:
# Cell 2: Data Loading
df = pd.read_csv(CSV_PATH)
logger.info(f"Loaded CSV with {len(df)} entries")
print("CSV Head:\n", df.head(10))

# Map PatientID to OriginalID
contour_mapping = {row['PatientID']: row['OriginalID'] for _, row in df.iterrows()}

def map_to_contour_id(patient_id):
    contour_id = contour_mapping.get(patient_id, patient_id)
    contour_dir = os.path.join(CONTOUR_ROOT, contour_id, "contours-manual", "IRCCI-expert")
    if not os.path.exists(contour_dir):
        logger.warning(f"Contour directory not found for {patient_id} at {contour_dir}")
        return None
    contour_files = [f for f in os.listdir(contour_dir) if f.endswith('-icontour-manual.txt')]
    if not contour_files:
        logger.warning(f"No contour files for {patient_id} in {contour_dir}")
        return None
    return contour_id

# Split patient IDs
patient_ids = df['PatientID'].unique()
train_ids, temp_ids = train_test_split(patient_ids, test_size=0.3, random_state=42)
val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=42)
logger.info(f"Train: {len(train_ids)}, Val: {len(val_ids)}, Test: {len(test_ids)}")
def get_dicom_files(patient_ids, dicom_root):
    dicom_files = []
    for pid in patient_ids:
        patient_dir = os.path.join(dicom_root, pid)
        if not os.path.exists(patient_dir):
            logger.warning(f"DICOM directory not found for {pid}")
            continue
        for root, _, files in os.walk(patient_dir):
            for file in files:
                if file.endswith('.dcm'):
                    dicom_files.append((os.path.join(root, file), pid))
    logger.info(f"Found {len(dicom_files)} DICOM files for {len(patient_ids)} patients")
    return dicom_files

train_files = get_dicom_files(train_ids, DICOM_ROOT)
val_files = get_dicom_files(val_ids, DICOM_ROOT)
test_files = get_dicom_files(test_ids, DICOM_ROOT)


CSV Head:
     PatientID  OriginalID Gender  Age                   Pathology
0  SCD0000101  SC-HF-I-01   Male   53  Heart failure with infarct
1  SCD0000201  SC-HF-I-02   Male   48  Heart failure with infarct
2  SCD0000301  SC-HF-I-04   Male   79  Heart failure with infarct
3  SCD0000401  SC-HF-I-05   Male   45  Heart failure with infarct
4  SCD0000501  SC-HF-I-06   Male   60  Heart failure with infarct
5  SCD0000601  SC-HF-I-07   Male   74  Heart failure with infarct
6  SCD0000701  SC-HF-I-08   Male   46  Heart failure with infarct
7  SCD0000801  SC-HF-I-09   Male   57  Heart failure with infarct
8  SCD0000901  SC-HF-I-10   Male   69  Heart failure with infarct
9  SCD0001001  SC-HF-I-11   Male   55  Heart failure with infarct


**new cell3**

In [5]:
# Cell 3: Debugging Empty Masks
def debug_contour_files(patient_ids):
    for pid in patient_ids:
        contour_id = map_to_contour_id(pid)
        if contour_id:
            contour_dir = os.path.join(CONTOUR_ROOT, contour_id, "contours-manual", "IRCCI-expert")
            if os.path.exists(contour_dir):
                files = [f for f in os.listdir(contour_dir) if f.endswith('-icontour-manual.txt')]
                logger.info(f"Patient {pid}: {len(files)} contour files")
            else:
                logger.warning(f"Contour directory missing for patient {pid}")
        else:
            logger.warning(f"No contour ID mapped for patient {pid}")
debug_contour_files(patient_ids)

In [6]:
# Cell 3: Debugging Empty Masks
def debug_contour_files(patient_ids):
    for pid in patient_ids:
        contour_id = map_to_contour_id(pid)
        if contour_id:
            contour_dir = os.path.join(CONTOUR_ROOT, contour_id, "contours-manual", "IRCCI-expert")
            if os.path.exists(contour_dir):
                files = [f for f in os.listdir(contour_dir) if f.endswith('-icontour-manual.txt')]
                logger.info(f"Patient {pid}: {len(files)} contour files")
            else:
                logger.warning(f"Contour directory missing for patient {pid}")
        else:
            logger.warning(f"No contour ID mapped for patient {pid}")

# Run debugging on all patient IDs
debug_contour_files(patient_ids)

In [7]:
# Cell 4: Preprocessing with Enhanced Image Clarity and Mask Validation
def load_and_preprocess_dicom(file_info):
    try:
        dicom_path, patient_id = file_info
        dicom_path = dicom_path.numpy().decode('utf-8') if isinstance(dicom_path, tf.Tensor) else dicom_path
        patient_id = patient_id.numpy().decode('utf-8') if isinstance(patient_id, tf.Tensor) else patient_id
        
        ds = pydicom.dcmread(dicom_path)
        img = ds.pixel_array.astype(np.float32)
        img = np.nan_to_num(img, nan=0.0, posinf=0.0, neginf=0.0)
        
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        img = clahe.apply(img.astype(np.uint8)).astype(np.float32)
        
        img_max = np.max(img)
        if img_max == 0:
            logger.debug(f"Image {dicom_path} has max value 0")
            return np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), False
        img = img / (img_max + 1e-7)
        img = cv2.resize(img, IMAGE_SIZE, interpolation=cv2.INTER_AREA)
        img = img[..., np.newaxis]
        
        contour_id = map_to_contour_id(patient_id)
        mask = np.zeros(IMAGE_SIZE + (1,), dtype=np.float32)
        valid_mask = False
        if contour_id:
            contour_dir = os.path.join(CONTOUR_ROOT, contour_id, "contours-manual", "IRCCI-expert")
            contour_files = [f for f in os.listdir(contour_dir) if f.endswith('-icontour-manual.txt')]
            if contour_files:
                frame_number = ds.get('FrameNumber', None)
                matched_contour = None
                for contour_file in contour_files:
                    if frame_number and str(frame_number).zfill(4) in contour_file:
                        matched_contour = contour_file
                        break
                if not matched_contour:
                    logger.debug(f"No contour match for {dicom_path}, frame {frame_number}")
                    return np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), False
                contour_path = os.path.join(contour_dir, matched_contour)
                contours = np.loadtxt(contour_path)
                if contours.ndim == 2 and contours.shape[1] == 2 and len(contours) > 3:
                    scale_x = IMAGE_SIZE[0] / ds.Columns
                    scale_y = IMAGE_SIZE[1] / ds.Rows
                    scaled_contours = contours * np.array([scale_x, scale_y])
                    rr, cc = polygon(scaled_contours[:, 1], scaled_contours[:, 0], shape=IMAGE_SIZE)
                    rr = np.clip(rr, 0, IMAGE_SIZE[0] - 1)
                    cc = np.clip(cc, 0, IMAGE_SIZE[1] - 1)
                    mask[rr, cc, 0] = 1.0
                    valid_mask = np.sum(mask) > 0
                else:
                    logger.debug(f"Invalid contour data in {contour_path}")
            else:
                logger.debug(f"No contour files for {dicom_path}")
        else:
            logger.debug(f"No contour ID for {patient_id}")
        
        return img, mask, valid_mask
    except Exception as e:
        logger.error(f"Error processing {dicom_path}: {e}")
        return np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), np.zeros(IMAGE_SIZE + (1,), dtype=np.float32), False

def tf_load_and_preprocess(file_info):
    img, mask, valid = tf.py_function(
        load_and_preprocess_dicom,
        [file_info],
        [tf.float32, tf.float32, tf.bool]
    )
    img.set_shape([224, 224, 1])
    mask.set_shape([224, 224, 1])
    return img, mask, valid

In [8]:
# Cell 5: Data Augmentation with Zoom
def rotate_image_py(image, angle):
    image = image.numpy()
    angle = angle.numpy()
    rows, cols = image.shape[:2]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    rotated = cv2.warpAffine(image, M, (cols, rows), flags=cv2.INTER_LINEAR)
    return rotated[..., np.newaxis]

def rotate_mask_py(mask, angle):
    mask = mask.numpy()
    angle = angle.numpy()
    rows, cols = mask.shape[:2]
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    rotated = cv2.warpAffine(mask, M, (cols, rows), flags=cv2.INTER_NEAREST)
    return rotated[..., np.newaxis]

def augment_tf(image, mask):
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)
    angle = tf.random.uniform((), minval=-15, maxval=15)
    image = tf.py_function(rotate_image_py, [image, angle], tf.float32)
    mask = tf.py_function(rotate_mask_py, [mask, angle], tf.float32)
    image.set_shape([224, 224, 1])
    mask.set_shape([224, 224, 1])
    # Tuned zoom augmentation
    zoom = tf.random.uniform((), 0.95, 1.05)  # Reduced range for less distortion
    image = tf.image.resize(image, [int(224 * zoom), int(224 * zoom)])
    image = tf.image.resize_with_crop_or_pad(image, 224, 224)
    mask = tf.image.resize(mask, [int(224 * zoom), int(224 * zoom)], method='nearest')
    mask = tf.image.resize_with_crop_or_pad(mask, 224, 224)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, mask

In [9]:
# Cell 6: Dataset Pipeline with Enhanced Filtering
def create_dataset(dicom_files):
    if not dicom_files:
        logger.warning("No DICOM files provided. Returning empty dataset.")
        return tf.data.Dataset.from_tensor_slices([])
    
    dataset = tf.data.Dataset.from_tensor_slices(dicom_files)
    dataset = dataset.map(tf_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.filter(lambda img, mask, valid: valid)
    dataset = dataset.map(lambda img, mask, valid: (img, mask), num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()
    dataset = dataset.map(augment_tf, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    try:
        sample_count = sum(1 for _ in dataset.unbatch())
        logger.info(f"Dataset created with {sample_count} samples ({sample_count // BATCH_SIZE} batches)")
        if sample_count == 0:
            logger.error("Empty dataset after filtering. Check data and preprocessing.")
    except Exception as e:
        logger.error(f"Error counting samples: {e}")
    
    return dataset

train_dataset = create_dataset(train_files)
val_dataset = create_dataset(val_files)
test_dataset = create_dataset(test_files)

I0000 00:00:1744777398.886481      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1744777398.887170      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


In [10]:
# Temporary Cell: Debug Dataset Size
def count_samples(dataset, name):
    count = sum(1 for _ in dataset.unbatch())
    logger.info(f"{name} dataset has {count} samples")
    return count

logger.info("Counting samples...")
train_count = count_samples(train_dataset, "Training")
val_count = count_samples(val_dataset, "Validation")
test_count = count_samples(test_dataset, "Test")

In [11]:
# Cell 7 (fixed): Data Validation
import tensorflow as tf
import logging
import time
import os
import json
import numpy as np   # ← make sure numpy is available

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def validate_dataset(dataset, name, cache_file, max_samples=None):
    start_time = time.time()
    
    # If we already ran, just load the cached counts
    if os.path.exists(cache_file):
        with open(cache_file, 'r') as f:
            counts = json.load(f)
        logger.info(f"{name} validation (cached): "
                    f"{counts['valid']}/{counts['total']} non-empty masks")
        logger.info(f"{name} validation took {time.time() - start_time:.2f} seconds")
        return counts['valid'], counts['total']
    
    valid_samples = 0   # Python int
    total_samples = 0   # Python int

    # If the user asked to limit to max_samples, compute how many batches that is
    if max_samples is not None:
        # BATCH_SIZE must be defined elsewhere in your notebook
        batches = (max_samples + BATCH_SIZE - 1) // BATCH_SIZE
        dataset_iter = dataset.take(batches)
    else:
        dataset_iter = dataset
    
    # Loop in eager mode, convert mask sums to NumPy, then do Python arithmetic
    for img_batch, mask_batch in dataset_iter:
        # sum over H, W, C → shape [batch_size]
        mask_sums = tf.reduce_sum(mask_batch, axis=[1,2,3]).numpy()
        valid_samples += int(np.sum(mask_sums > 0))
        total_samples += mask_sums.shape[0]
    
    logger.info(f"{name} validation: {valid_samples}/{total_samples} "
                "samples have non-empty masks")
    logger.info(f"{name} validation took {time.time() - start_time:.2f} seconds")
    
    # Cache the result for next time
    with open(cache_file, 'w') as f:
        json.dump({'valid': valid_samples, 'total': total_samples}, f)
    
    return valid_samples, total_samples

logger.info("Starting dataset validation...")
start_total = time.time()

train_valid, train_total = validate_dataset(
    train_dataset, "Training", "train_validation.json"
)
if train_valid < train_total * 0.5:
    logger.warning(f"Training dataset has only {train_valid}/{train_total} "
                   "valid samples. Check mask generation.")

val_valid, val_total = validate_dataset(
    val_dataset, "Validation", "val_validation.json"
)
if val_valid < val_total * 0.5:
    logger.warning(f"Validation dataset has only {val_valid}/{val_total} "
                   "valid samples. Check mask generation.")

test_valid, test_total = validate_dataset(
    test_dataset, "Test", "test_validation.json"
)
if test_valid < test_total * 0.5:
    logger.warning(f"Test dataset has only {test_valid}/{test_total} "
                   "valid samples. Check mask generation.")

logger.info(f"Total validation took {time.time() - start_total:.2f} seconds")


In [24]:
# Cell 8: U-Net Model with Batch Normalization and Class Weighting
def unet_model(input_shape):
    inputs = layers.Input(shape=input_shape)
    def conv_block(x, filters):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        return x
    
    # Encoder
    c1 = conv_block(inputs, 32)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    c2 = conv_block(p1, 64)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    c3 = conv_block(p2, 128)
    p3 = layers.MaxPooling2D((2, 2))(c3)
    c4 = conv_block(p3, 256)
    p4 = layers.MaxPooling2D((2, 2))(c4)
    c5 = conv_block(p4, 512)
    
    # Decoder
    u6 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = layers.concatenate([u6, c4])
    c6 = conv_block(u6, 256)
    c6 = layers.Dropout(0.2)(c6)
    u7 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = layers.concatenate([u7, c3])
    c7 = conv_block(u7, 128)
    c7 = layers.Dropout(0.2)(c7)
    u8 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = layers.concatenate([u8, c2])
    c8 = conv_block(u8, 64)
    c8 = layers.Dropout(0.2)(c8)
    u9 = layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = layers.concatenate([u9, c1])
    c9 = conv_block(u9, 32)
    c9 = layers.Dropout(0.2)(c9)
    
    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid', dtype='float32')(c9)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

def dice_coefficient(y_true, y_pred, smooth=1.0):
    y_true_f = tf.cast(y_true, tf.float32)
    y_pred_f = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coefficient(y_true, y_pred)

# Modified weighted combined loss to handle dimensions correctly
class_weights = {0: 1.0, 1: 10.0}
def weighted_combined_loss(y_true, y_pred):
    # Calculate binary cross entropy
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    # Create weights tensor
    weights = y_true * class_weights[1] + (1 - y_true) * class_weights[0]
    # Ensure weights have same shape as bce
    weights = tf.squeeze(weights, axis=-1)  # Remove last dimension if needed
    # Apply weights
    weighted_bce = tf.reduce_mean(bce * weights)
    # Calculate dice loss
    dice = dice_loss(y_true, y_pred)
    return 0.5 * weighted_bce + dice

# Model Compilation
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
    logger.info("Using TPU strategy")
except ValueError:
    strategy = tf.distribute.get_strategy()
    logger.info("Using default strategy (GPU/CPU)")

with strategy.scope():
    model = unet_model((224, 224, 1))
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, clipvalue=1.0)
    model.compile(optimizer=optimizer, 
                 loss=weighted_combined_loss, 
                 metrics=['accuracy', dice_coefficient])

model.summary()

In [33]:
# Cell 9: Training
class NaNStopper(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if any(np.isnan(v) for v in logs.values()):
            logger.warning(f"NaN detected in epoch {epoch + 1}. Stopping training.")
            self.model.stop_training = True

class GradientMonitor(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        with tf.GradientTape() as tape:
            for batch in train_dataset.take(1):
                images, masks = batch
                predictions = self.model(images, training=True)
                loss = weighted_combined_loss(masks, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        grad_norm = tf.sqrt(sum(tf.reduce_sum(tf.square(g)) for g in gradients if g is not None))
        logger.info(f"Epoch {epoch + 1}: Gradient norm = {grad_norm.numpy()}")


# Remove the OptimizerReplacementScheduler and use a built-in LR scheduler instead
with strategy.scope():
    model = unet_model((224, 224, 1))
    
    # Use a LearningRateSchedule object instead of replacing the optimizer
    initial_lr = 1e-4
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[5 * steps_per_epoch],  # After 5 epochs
        values=[initial_lr * 0.1, initial_lr]  # Start with lower LR, then go to target LR
    )
    
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, clipvalue=1.0)
    
    model.compile(optimizer=optimizer, 
                 loss=weighted_combined_loss, 
                 metrics=['accuracy', dice_coefficient])

# Keep the other callbacks but remove the OptimizerReplacementScheduler
history = model.fit(
    train_dataset,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    validation_steps=validation_steps,
    epochs=EPOCHS,
    callbacks=[
        tf.keras.callbacks.ModelCheckpoint('best_model.keras', 
                                        save_best_only=True, 
                                        monitor='val_dice_coefficient', 
                                        mode='max'),
        tf.keras.callbacks.EarlyStopping(patience=10, 
                                      restore_best_weights=True, 
                                      monitor='val_dice_coefficient', 
                                      mode='max'),
        tf.keras.callbacks.ReduceLROnPlateau(patience=4, 
                                          monitor='val_dice_coefficient', 
                                          mode='max', 
                                          factor=0.2, 
                                          min_lr=1e-6),
        NaNStopper(),
        GradientMonitor(),
        SafeLrLogger(),
        tensorboard_callback
    ]
)

Epoch 1/33


I0000 00:00:1744779557.016925      94 service.cc:148] XLA service 0x7af2800016a0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1744779557.019664      94 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1744779557.019688      94 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1744779557.087688      94 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1744779557.227446      94 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
  self.gen.throw(typ, value, traceback)
  self._save_model(epoch=epoch, batch=None, logs=logs)
  current = self.get_monitor_value(logs)
  callback.on_epoch_end(epoch, logs)


UnboundLocalError: cannot access local variable 'loss' where it is not associated with a value

In [None]:
# Cell 10: Evaluation
def evaluate_model(dataset, steps):
    metrics = model.evaluate(dataset, steps=steps, return_dict=True)
    logger.info(f"Test Metrics: {metrics}")
    print(f"Test Loss: {metrics['loss']:.4f}, Test Accuracy: {metrics['accuracy']:.4f}, Test Dice: {metrics['dice_coefficient']:.4f}")
    return metrics

test_steps = max(1, len(test_files) // BATCH_SIZE)
test_metrics = evaluate_model(test_dataset, test_steps)

# Plot training history
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Loss Over Epochs')
plt.legend()
plt.subplot(1, 3, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Accuracy Over Epochs')
plt.legend()
plt.subplot(1, 3, 3)
plt.plot(history.history['dice_coefficient'], label='Training Dice')
plt.plot(history.history['val_dice_coefficient'], label='Validation Dice')
plt.title('Dice Coefficient Over Epochs')
plt.legend()
plt.savefig('training_plots.png')
plt.close()

# Visualize predictions with overlays
def visualize_predictions(dataset, num_samples=4):
    for batch in dataset.take(1):
        images, masks = batch
        preds = model.predict(images, verbose=0)
        preds = (preds > 0.5).astype(np.float32)
        for i in range(min(num_samples, len(images))):
            plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            plt.imshow(images[i, ..., 0], cmap='gray')
            plt.title('Image')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(images[i, ..., 0], cmap='gray')
            plt.contour(masks[i, ..., 0], colors='red', levels=[0.5])
            plt.title('True Mask')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(images[i, ..., 0], cmap='gray')
            plt.contour(preds[i, ..., 0], colors='blue', levels=[0.5])
            plt.title('Predicted Mask')
            plt.axis('off')
            plt.savefig(f'prediction_{i}.png')
            plt.close()

visualize_predictions(test_dataset)

# Save metrics
with open('test_metrics.txt', 'w') as f:
    f.write(str(test_metrics))

In [None]:
# Cell 11: Model Deployment
model.save('unet_left_ventricle_segmentation.keras')
logger.info("Model saved as unet_left_ventricle_segmentation.keras")

# Cell 11: Hyperparameter Tuning Suggestions
"""
Hyperparameter Tuning Recommendations:
- Learning Rate: Test 0.001, 0.0001 for convergence vs. stability.
- Batch Size: Try 16 or 32 if memory allows.
- Augmentation: Increase rotation range to ±30° or add elastic deformations.
- Dropout Rate: Experiment with 0.3 or 0.5 if overfitting occurs.
- Model Complexity: Adjust filter counts or layers based on performance.
"""