In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
from model import Generator


# Select the `cuda_malloc_async` GPU allocator through TensorFlow's
# TF_GPU_ALLOCATOR environment variable.
os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"

# Pipeline / model constants.
# NOTE: BUFFER_SIZE is not used by the shuffle below, which uses a buffer of 1000.
BUFFER_SIZE = 548000
BATCH_SIZE = 64
# Spatial size of every image: 256 x 256.
IMG_WIDTH = 256
IMG_HEIGHT = 256
OUTPUT_CHANNELS = 6
LAMBDA = 100

# Directory holding the .npy arrays (kept as a plain string with a trailing slash).
data_path = '/home/besanhalwa/Eshan/project1_PMRI/Data/npy_tech_pmri_no_aug_leftRightSplit/'
train_data_x = np.load(os.path.join(data_path, 'train_images.npy'))
train_data_y = np.load(os.path.join(data_path, 'train_masks_hot_encoded.npy'))

# Build the input pipeline on the CPU device.
with tf.device('/cpu:0'):
    # Wrap each array in a Dataset; the names are rebound from arrays to Datasets.
    train_data_x = tf.data.Dataset.from_tensor_slices(train_data_x)
    train_data_y = tf.data.Dataset.from_tensor_slices(train_data_y)

    # Pair images with masks, shuffle with a 1000-element buffer, then batch.
    train_dataset = (
        tf.data.Dataset.zip((train_data_x, train_data_y))
        .shuffle(1000)
        .batch(BATCH_SIZE)
    )

## Similarly load the validation and test data; `val_dataset`, which is passed to `generator.fit` below, must be built here.

def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """
    Compute the Dice coefficient per sample and return its mean.

    y_true: Ground-truth masks; sums are taken over axes 1, 2 and 3, so a 4-D tensor is expected
    y_pred: Predicted masks with the same shape as y_true
    smooth: Small constant added to numerator and denominator to avoid division by zero

    Returns a scalar tensor: the Dice coefficient averaged over the first axis.
    """
    # Cast to float32, as the loss functions in this file do, so that integer or
    # boolean masks can be multiplied with floating-point predictions.
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Per-sample overlap between truth and prediction (element-wise product).
    intersection = tf.reduce_sum(y_true * y_pred, axis=(1, 2, 3))
    # Dice denominator: sum of both mask totals (not a set union).
    denominator = tf.reduce_sum(y_true, axis=(1, 2, 3)) + tf.reduce_sum(y_pred, axis=(1, 2, 3))
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return tf.reduce_mean(dice)  # Average Dice coefficient over the samples
    

def dice_loss_channel_wise(y_true, y_pred, smooth=1e-6):
    """
    Dice loss evaluated independently for every channel, then averaged.

    y_true: Ground-truth masks whose last axis indexes the classes
            (documented by the author as one-hot, shape (batch_size, height, width, num_classes))
    y_pred: Predicted masks with the same layout as y_true
    smooth: Small constant added to numerator and denominator to avoid division by zero

    Returns a scalar tensor: mean over channels of (1 - Dice).
    """
    # Work in float32 regardless of the incoming dtypes.
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)

    # Collapse every axis except the last one: rows are pixels, columns are channels.
    truth_flat = tf.reshape(truth, (-1, truth.shape[-1]))
    pred_flat = tf.reshape(pred, (-1, pred.shape[-1]))

    # Summing down the rows (axis=0) leaves one value per channel.
    overlap_per_channel = tf.reduce_sum(truth_flat * pred_flat, axis=0)
    total_per_channel = tf.reduce_sum(truth_flat, axis=0) + tf.reduce_sum(pred_flat, axis=0)

    # Smoothed Dice coefficient for each channel.
    dice_per_channel = (2. * overlap_per_channel + smooth) / (total_per_channel + smooth)

    # Loss is the complement of Dice, averaged across channels.
    return tf.reduce_mean(1 - dice_per_channel)

def pixel_wise_binary_crossentropy_loss(y_true, y_pred):
    """
    Binary cross-entropy between target masks and predicted probabilities.

    y_true: Ground-truth masks
            (documented by the author as one-hot, shape (batch_size, height, width, num_classes))
    y_pred: Predicted probabilities (not logits) with the same shape as y_true

    Returns a scalar tensor holding the mean loss.
    """
    # Work in float32 regardless of the incoming dtypes.
    targets = tf.cast(y_true, tf.float32)
    probabilities = tf.cast(y_pred, tf.float32)

    # A fresh Keras loss object with its default reduction; inputs are probabilities.
    criterion = tf.keras.losses.BinaryCrossentropy(from_logits=False)

    # NOTE(review): with the default reduction the loss object presumably already
    # returns a scalar, in which case this reduce_mean leaves it unchanged - confirm
    # against the installed Keras version.
    return tf.reduce_mean(criterion(targets, probabilities))


def combined_loss(y_true, y_pred, alpha=0.5):
    """
    Weighted sum of binary cross-entropy and channel-wise Dice loss.

    y_true: Ground-truth masks, passed unchanged to both component losses
    y_pred: Predicted masks, passed unchanged to both component losses
    alpha: Weight of the cross-entropy term; the Dice term is weighted by (1 - alpha)

    Returns alpha * BCE + (1 - alpha) * Dice loss.
    """
    # Cross-entropy contribution, scaled by alpha.
    bce_term = alpha * pixel_wise_binary_crossentropy_loss(y_true, y_pred)

    # Dice contribution, scaled by the remaining weight.
    dice_term = (1 - alpha) * dice_loss_channel_wise(y_true, y_pred)

    return bce_term + dice_term


weights_path = './training_checkpoints/enhanced_unet_Generator_model_weights_ckpt47.h5'

# Build the model and restore previously saved weights before training.
generator = Generator()
generator.load_weights(weights_path)
print(f"Loaded weights from {weights_path}")


# Train with the combined BCE + Dice loss and report the Dice coefficient as a metric.
generator.compile(optimizer='adam', loss=combined_loss, metrics=[dice_coefficient])

# Define the checkpoint callback to save weights after each epoch
checkpoint_callback = ModelCheckpoint(
    'checkpoint_epoch_{epoch:02d}.h5',  # File path pattern; the epoch number is filled in
    save_weights_only=True,              # Save only the weights, not the entire model
    save_best_only=False,                # Save weights at every epoch (not just the best model)
    verbose=1                            # Display saving status
)

# Train the model with validation data and the checkpoint callback.
# `batch_size` is deliberately not passed: `train_dataset` is a tf.data.Dataset that is
# already batched with BATCH_SIZE, and Keras documents that `batch_size` must not be
# specified for dataset inputs.
# NOTE: `val_dataset` is not defined in this cell; it must be created beforehand.
history = generator.fit(
    train_dataset,  # Batched tf.data.Dataset of (image, mask) pairs
    epochs=500,      # Number of epochs to train
    validation_data=val_dataset,  # Validation dataset
    callbacks=[checkpoint_callback]  # Include the checkpoint callback
)

Loaded weights from ./training_checkpoints/enhanced_unet_Generator_model_weights_ckpt47.h5
Epoch 1/500


2025-05-28 16:30:21.392071: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inmodel_3/sequential_52/dropout_24/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2025-05-28 16:30:21.732350: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
2025-05-28 16:30:21.797038: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2025-05-28 16:30:22.320898: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2025-05-28 16:30:27.539752: I external/local_xla/xla/service/service.cc:168] XLA service 0x701e449e24c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-05-28 16:30:27.539779: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA 

