In [1]:
import tensorflow as tf
import numpy as np
import os
import datetime
import logging
from pathlib import Path
from model import Generator  # Import Generator from model.py
import time

# Logging: every record is written both to finetune.log and to the console
# (logging.StreamHandler defaults to stderr).
logging.basicConfig(
    handlers=[logging.FileHandler('finetune.log'), logging.StreamHandler()],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Image geometry (pixels) and channel layout.
IMG_HEIGHT = IMG_WIDTH = 256
INPUT_CHANNELS = 1   # grayscale input images carry a single channel
OUTPUT_CHANNELS = 6  # ROI target: one channel per class (one-hot, 6 classes)

# Input-pipeline and training-schedule settings.
BATCH_SIZE, BUFFER_SIZE, EPOCHS = 128, 1000, 500

# Filesystem locations: checkpoints, TensorBoard logs, and the .npy data.
WEIGHTS_DIR = './training_checkpoints'
LOG_DIR = './logs'
DATA_DIR = '/home/besanhalwa/Eshan/project1_PMRI/Data/npy_tech_pmri_no_aug_leftRightSplit/'

# Loss functions
def dice_loss_channel_wise(y_true, y_pred, smooth=1e-6):
    """
    Return the mean over channels of (1 - Dice coefficient).

    All leading axes (batch and spatial) are flattened together, so each
    channel's Dice score is computed over the whole batch at once, not per
    image. `smooth` keeps the ratio defined for channels that are empty in
    both the target and the prediction.
    """
    truth = tf.cast(y_true, tf.float32)
    pred = tf.cast(y_pred, tf.float32)
    # Collapse everything except the last (channel) axis into one axis.
    truth_flat = tf.reshape(truth, (-1, truth.shape[-1]))
    pred_flat = tf.reshape(pred, (-1, pred.shape[-1]))
    overlap = tf.reduce_sum(truth_flat * pred_flat, axis=0)
    denom = tf.reduce_sum(truth_flat, axis=0) + tf.reduce_sum(pred_flat, axis=0)
    per_channel_dice = (2. * overlap + smooth) / (denom + smooth)
    return tf.reduce_mean(1 - per_channel_dice)

def pixel_wise_binary_crossentropy_loss(y_true, y_pred):
    """
    Return the mean binary cross-entropy over every pixel and channel.

    Each channel is treated as an independent binary target. Because
    `from_logits=False`, `y_pred` is expected to hold probabilities, not
    logits. NOTE(review): the Generator's output activation is not visible
    here; confirm it ends in a sigmoid/softmax.

    With its default reduction the Keras loss object already yields a
    scalar, so the outer reduce_mean leaves the value unchanged.
    """
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
    reduced = bce(tf.cast(y_true, tf.float32), tf.cast(y_pred, tf.float32))
    return tf.reduce_mean(reduced)

def combined_loss(y_true, y_pred, alpha=0.5):
    """
    Blend the two segmentation losses: alpha * BCE + (1 - alpha) * Dice.

    alpha=0.5 (the default) weights pixel-wise binary cross-entropy and
    channel-wise Dice loss equally.
    """
    bce_weight, dice_weight = alpha, 1 - alpha
    bce_term = pixel_wise_binary_crossentropy_loss(y_true, y_pred)
    dice_term = dice_loss_channel_wise(y_true, y_pred)
    return bce_weight * bce_term + dice_weight * dice_term

# Load datasets
def load_npy_datasets(data_dir=None):
    """
    Load the train/val image and ROI-mask arrays from .npy files and validate them.

    Files read from `data_dir`: train_images.npy, train_masks_hot_encoded.npy,
    val_images.npy and val_masks_hot_encoded.npy.

    Args:
        data_dir: directory containing the four files. Defaults to DATA_DIR.

    Returns:
        ((train_images, train_roi), (val_images, val_roi)).
        - Image arrays that are 3-D get a trailing channel axis and are cast
          to float32.
        - ROI arrays are returned exactly as loaded.

    Raises:
        ValueError: if any array contains NaN, or any ROI value lies outside
            [0, OUTPUT_CHANNELS).
        OSError: propagated from np.load (e.g. FileNotFoundError for a missing file).
    """
    data_dir = DATA_DIR if data_dir is None else data_dir
    try:
        # NOTE(review): masks are presumably (N, H, W, OUTPUT_CHANNELS) one-hot
        # arrays given the file names; the shapes are logged below - confirm.
        train_images = np.load(os.path.join(data_dir, 'train_images.npy'))
        train_roi = np.load(os.path.join(data_dir, 'train_masks_hot_encoded.npy'))
        val_images = np.load(os.path.join(data_dir, 'val_images.npy'))
        val_roi = np.load(os.path.join(data_dir, 'val_masks_hot_encoded.npy'))

        # Give (N, H, W) image stacks an explicit single channel axis.
        if train_images.ndim == 3:
            train_images = np.expand_dims(train_images, axis=-1)
        if val_images.ndim == 3:
            val_images = np.expand_dims(val_images, axis=-1)

        # Reject NaNs up front; they would otherwise poison the loss silently.
        if np.any(np.isnan(train_images)) or np.any(np.isnan(val_images)):
            logger.error("NaN values found in input images")
            raise ValueError("NaN values in input images")
        if np.any(np.isnan(train_roi)) or np.any(np.isnan(val_roi)):
            logger.error("NaN values found in ROI masks")
            raise ValueError("NaN values in ROI masks")

        # Log ROI details
        logger.info(f"Train images shape: {train_images.shape}")
        logger.info(f"Train ROI shape: {train_roi.shape}")
        logger.info(f"Val images shape: {val_images.shape}")
        logger.info(f"Val ROI shape: {val_roi.shape}")
        logger.info(f"Train ROI min: {np.min(train_roi)}, max: {np.max(train_roi)}")
        logger.info(f"Val ROI min: {np.min(val_roi)}, max: {np.max(val_roi)}")
        unique_values = np.unique(train_roi)
        logger.info(f"Train ROI unique values: {unique_values[:10]}{'...' if len(unique_values) > 10 else ''}")

        # Valid class indices are 0..OUTPUT_CHANNELS-1, so the upper bound is
        # exclusive. The previous `<=` accepted the out-of-range label
        # OUTPUT_CHANNELS. One-hot values (0/1) also satisfy this check.
        if not np.all((train_roi >= 0) & (train_roi < OUTPUT_CHANNELS)):
            logger.error("Train ROI contains invalid class labels")
            raise ValueError("Invalid class labels in train_roi")
        if not np.all((val_roi >= 0) & (val_roi < OUTPUT_CHANNELS)):
            logger.error("Val ROI contains invalid class labels")
            raise ValueError("Invalid class labels in val_roi")

        # The masks are read from the *_hot_encoded files and used as-is.
        # No tf.one_hot conversion is applied here.
        return (train_images.astype(np.float32), train_roi), (val_images.astype(np.float32), val_roi)
    except Exception as e:
        logger.error(f"Failed to load datasets: {e}")
        raise

# Data preprocessing
def normalize(input_image, real_image):
    """
    Rescale the input image linearly so that 0 -> -1 and 255 -> 1.

    The target (`real_image`, the ROI mask) passes through untouched.
    NOTE(review): this presumes 8-bit-range [0, 255] input intensities; confirm.
    """
    scaled = tf.cast(input_image, tf.float32) / 127.5
    return scaled - 1, real_image

# Create tf.data datasets
def create_datasets(train_data, val_data, batch_size=None, buffer_size=None):
    """
    Build the batched tf.data pipelines for training and validation.

    `normalize` is mapped *before* `.cache()`, so the normalized tensors are
    cached once. The previous order (cache then map) re-ran normalization on
    every epoch.

    Args:
        train_data: (images, roi) array pair for training.
        val_data: (images, roi) array pair for validation.
        batch_size: batch size for both pipelines. Defaults to BATCH_SIZE.
        buffer_size: shuffle buffer for the training pipeline. Defaults to
            BUFFER_SIZE.

    Returns:
        (train_dataset, val_dataset). Only the training pipeline is shuffled.
    """
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    buffer_size = BUFFER_SIZE if buffer_size is None else buffer_size

    train_images, train_roi = train_data
    val_images, val_roi = val_data

    train_dataset = (tf.data.Dataset.from_tensor_slices((train_images, train_roi))
                     .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
                     .cache()
                     .shuffle(buffer_size)
                     .batch(batch_size)
                     .prefetch(tf.data.AUTOTUNE))

    val_dataset = (tf.data.Dataset.from_tensor_slices((val_images, val_roi))
                   .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
                   .cache()
                   .batch(batch_size)
                   .prefetch(tf.data.AUTOTUNE))

    return train_dataset, val_dataset

# Training step for fine-tuning (generator only)
@tf.function
def train_step(input_image, target, generator, optimizer):
    """
    Run one graph-compiled update of `generator` on a single batch.

    The forward pass runs in training mode under a GradientTape. The
    combined loss is differentiated w.r.t. the generator's trainable
    variables, and `optimizer` applies the gradients.

    Returns:
        The batch loss as a scalar tensor.
    """
    with tf.GradientTape() as tape:
        prediction = generator(input_image, training=True)
        batch_loss = combined_loss(target, prediction)
    # Read the variables after the forward pass, so a lazily-built model
    # has already created them.
    weights = generator.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    return batch_loss

# Validation function
def validate(val_ds, generator):
    """
    Return the validation loss: the unweighted mean of per-batch combined_loss.

    The generator runs in inference mode (training=False). A smaller final
    batch counts as much as a full one, which mirrors how `fit` averages the
    training loss.

    Raises:
        ValueError: if `val_ds` yields no batches. Previously this surfaced
            as a bare ZeroDivisionError.
    """
    val_loss = 0.0
    num_batches = 0
    for input_image, target in val_ds:
        gen_output = generator(input_image, training=False)
        val_loss += combined_loss(target, gen_output).numpy()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("Validation dataset is empty")
    return val_loss / num_batches

# Fine-tuning function
def _train_one_epoch(train_ds, generator, optimizer, steps_per_epoch):
    """
    Run `train_step` over every batch of `train_ds`; return the mean batch loss.

    The mean is taken over the batches actually seen, not the dataset's
    reported cardinality, which can be unknown (negative) for some pipelines.
    `steps_per_epoch` is used only in the progress log.

    Raises:
        ValueError: if `train_ds` yields no batches.
    """
    total_loss = 0.0
    num_steps = 0
    for step, (input_image, target) in enumerate(train_ds):
        loss_value = train_step(input_image, target, generator, optimizer).numpy()
        total_loss += loss_value
        num_steps += 1
        if (step + 1) % 100 == 0:
            logger.info(f"Step {step + 1}/{steps_per_epoch}: Loss = {loss_value:.4f}")
    if num_steps == 0:
        raise ValueError("Training dataset is empty")
    return total_loss / num_steps


def fit(train_ds, val_ds, generator, optimizer, epochs, *, patience=25,
        best_weights_dir='./finetuned_weights'):
    """
    Fine-tune `generator` on `train_ds`, validating on `val_ds` after each epoch.

    Each epoch:
      - trains on every batch of `train_ds`;
      - computes the validation loss;
      - logs both losses to the logger and to TensorBoard under
        LOG_DIR/fit/<timestamp>;
      - handles checkpoints:
          * a "best" checkpoint at <best_weights_dir>/fineTuned_best whenever
            the validation loss improves (NaN never counts as an improvement);
          * a periodic checkpoint at WEIGHTS_DIR/ckpt_epoch_<k> every 5 epochs.

    Checkpoint.save appends an incrementing "-<n>" suffix, so best
    checkpoints accumulate; the highest-numbered fineTuned_best-* is the best.

    Args:
        train_ds, val_ds: batched (input, target) datasets.
        generator: the Keras model being fine-tuned.
        optimizer: optimizer used by `train_step`; saved alongside the model.
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without improvement.
        best_weights_dir: directory for the best checkpoint (created if missing).
    """
    best_val_loss = float('inf')
    wait = 0

    steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
    checkpoint_prefix = os.path.join(WEIGHTS_DIR, "ckpt")
    # The original prefix was './finetuned_weights' + "fineTuned_best" (no
    # separator), which wrote './finetuned_weightsfineTuned_best-N' into the cwd.
    Path(best_weights_dir).mkdir(parents=True, exist_ok=True)
    best_prefix = os.path.join(best_weights_dir, "fineTuned_best")
    checkpoint = tf.train.Checkpoint(generator_optimizer=optimizer, generator=generator)

    summary_writer = tf.summary.create_file_writer(
        os.path.join(LOG_DIR, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    )

    try:
        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch + 1}/{epochs}")
            start = time.time()

            avg_train_loss = _train_one_epoch(train_ds, generator, optimizer, steps_per_epoch)
            val_loss = validate(val_ds, generator)
            logger.info(f"Epoch {epoch + 1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}, Time = {time.time() - start:.2f} sec")

            with summary_writer.as_default():
                tf.summary.scalar('train_loss', avg_train_loss, step=epoch)
                tf.summary.scalar('val_loss', val_loss, step=epoch)

            # Save the best model. The explicit NaN guard keeps a NaN loss
            # from ever being treated as an improvement.
            if val_loss < best_val_loss and not np.isnan(val_loss):
                best_val_loss = val_loss
                checkpoint.save(file_prefix=best_prefix)
                logger.info(f"Saved best checkpoint for epoch {epoch + 1}")
                wait = 0
            else:
                wait += 1

            # Periodic checkpoint every 5 epochs.
            if (epoch + 1) % 5 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix + f"_epoch_{epoch + 1}")
                logger.info(f"Saved periodic checkpoint for epoch {epoch + 1}")

            # Early stopping.
            if wait >= patience:
                logger.info(f"Early stopping triggered after {patience} epochs without improvement")
                break
    finally:
        # Flush pending TensorBoard events and release the event file.
        summary_writer.close()

# Main function
def main():
    """
    Fine-tune the generator restored from WEIGHTS_DIR/ckpt-47 and save the result.

    Steps:
      1. Create the output directories and build the generator.
      2. Restore the generator weights from the checkpoint.
      3. Load and validate the data once, and build the datasets.
      4. Check that the restored generator does not emit NaN.
      5. Fine-tune with `fit` and save the final model to HDF5.

    Raises:
        FileNotFoundError: if the checkpoint index file is missing.
        ValueError: on invalid data or NaN generator output.
    Any exception is logged with its traceback and re-raised.
    """
    try:
        # Create directories
        Path(WEIGHTS_DIR).mkdir(parents=True, exist_ok=True)
        Path(LOG_DIR).mkdir(parents=True, exist_ok=True)

        # Load model
        generator = Generator()
        logger.info("Initialized generator")

        # Restore pre-trained weights first, so a missing checkpoint fails
        # fast, before the slow data loading.
        checkpoint = tf.train.Checkpoint(generator=generator)
        checkpoint_path = os.path.join(WEIGHTS_DIR, 'ckpt-47')
        if not Path(checkpoint_path + '.index').exists():
            raise FileNotFoundError(f"Checkpoint file {checkpoint_path} not found")
        # expect_partial() silences warnings about checkpointed objects
        # (e.g. optimizer state) that this Checkpoint does not track.
        checkpoint.restore(checkpoint_path).expect_partial()
        logger.info(f"Restored generator weights from {checkpoint_path}")

        # Load the .npy files once. Previously they were read three times.
        train_data, val_data = load_npy_datasets()
        train_ds, val_ds = create_datasets(train_data, val_data)

        # Check the *restored* weights for NaN output. Previously this ran on
        # the random initialisation, before restore. A slice of the arrays is
        # used, so the cached tf.data pipeline is not partially consumed.
        sample_input, _ = normalize(train_data[0][:BATCH_SIZE], train_data[1][:BATCH_SIZE])
        gen_output = generator(sample_input, training=False)
        if tf.math.reduce_any(tf.math.is_nan(gen_output)):
            logger.error("NaN found in generator output")
            raise ValueError("NaN found in generator output")

        # Optimizer
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

        # Fine-tune generator
        fit(train_ds, val_ds, generator, optimizer, EPOCHS)

        # Save final model.
        # NOTE(review): this writes into the data directory in the legacy
        # HDF5 format; consider WEIGHTS_DIR and the native .keras format.
        generator.save(os.path.join(DATA_DIR, 'generator_finetuned.h5'))
        logger.info("Fine-tuning completed and generator saved")

    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception(f"Process failed: {e}")
        raise

if __name__ == "__main__":
    main()

2025-05-28 15:42:58.267014: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-28 15:42:58.267044: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-28 15:42:58.268044: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-28 15:42:58.273353: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-05-28 15:42:59.456276: I external/local_xla/xla/



  saving_api.save_model(
2025-05-28 15:52:48,454 - INFO - Fine-tuning completed and generator saved
