# In [None]:
import datetime
import logging
import os
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# Configure root logging once, at import time: every record is written both to
# 'finetune.log' (relative to the current working directory) and to stderr.
_log_handlers = [
    logging.FileHandler('finetune.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

# Constants
IMG_WIDTH = 256
IMG_HEIGHT = 256
INPUT_CHANNELS = 1  # Single-channel grayscale input
OUTPUT_CHANNELS = 6  # One-hot encoded ROI
BATCH_SIZE = 16
BUFFER_SIZE = 1000  # Shuffle buffer for the training pipeline
EPOCHS = 50  # Reduced for fine-tuning
LAMBDA = 100  # Weight of the L1 term in the generator loss
# Root directory holding the .npy datasets and the pretrained checkpoint.
# It can be overridden with the SAMEPATH_DATA_DIR environment variable; the
# default keeps the original hard-coded location, so existing runs are unchanged.
DATA_DIR = os.environ.get('SAMEPATH_DATA_DIR', '/home/besanhalwa/samepath')
CHECKPOINT_DIR = os.path.join(DATA_DIR, 'checkpoints')
LOG_DIR = os.path.join(DATA_DIR, 'logs')

# Model definitions (from provided code)
def forwardblock(filters, size, apply_batchnorm=True, apply_dropout=False):
    """Build a resolution-preserving convolution unit.

    Layer order: Conv2D (stride 1, 'same' padding, no bias, N(0, 0.02) kernel
    init) -> optional Dropout(0.1) -> optional BatchNormalization -> LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        size: convolution kernel size.
        apply_batchnorm: append BatchNormalization after the (optional) dropout.
        apply_dropout: insert Dropout(0.1) directly after the convolution.

    Returns:
        A ``tf.keras.Sequential`` holding the layers above, in that order.
    """
    stack = [
        tf.keras.layers.Conv2D(
            filters, size, strides=1, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.1))
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

def downsample(filters, size, apply_batchnorm=True):
    """Build a stride-2 encoder unit (halves height and width, rounding up).

    Layer order: Conv2D (stride 2, 'same' padding, no bias, N(0, 0.02) kernel
    init) -> optional BatchNormalization -> LeakyReLU.

    Returns:
        A ``tf.keras.Sequential`` holding those layers.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    stack = [conv]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

def upsample(filters, size, apply_dropout=False, apply_batchnorm=False):
    """Build a stride-2 decoder unit (doubles height and width).

    Layer order: Conv2DTranspose (stride 2, 'same' padding, no bias,
    N(0, 0.02) kernel init) -> optional BatchNormalization -> optional
    Dropout(0.1) -> ReLU. Note that batch norm precedes dropout here, the
    reverse of ``forwardblock``.

    Returns:
        A ``tf.keras.Sequential`` holding those layers.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(
            filters, size, strides=2, padding='same',
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False),
    ]
    if apply_batchnorm:
        stack.append(tf.keras.layers.BatchNormalization())
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.1))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

def Generator():
    """Build the encoder-decoder generator with skip connections (U-Net style).

    Input: (IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS).
    Output: same spatial size with OUTPUT_CHANNELS channels and a tanh
    activation, so values lie in [-1, 1].

    Encoder: five stride-2 ``downsample`` blocks, each halving H and W, with a
    stride-1 ``forwardblock`` (dropout 0.1) after each of the first four. For
    256x256 inputs this ends in an 8x8x1024 bottleneck.

    Decoder: four stride-2 ``upsample`` blocks, each followed by a
    ``forwardblock``. After every decoder block, the encoder activation of
    matching resolution is concatenated on the channel axis. A final stride-2
    Conv2DTranspose restores full resolution.

    NOTE(review): this layer layout presumably has to match the model that
    wrote the ckpt-47 checkpoint restored in ``main``. Changing or reordering
    layers would likely break that restore, so confirm before editing.
    """
    inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS])
    down_stack = [
        downsample(64, 4, apply_batchnorm=True),
        forwardblock(64, 4, apply_batchnorm=True, apply_dropout=True),
        downsample(128, 4, apply_batchnorm=True),
        forwardblock(128, 4, apply_batchnorm=True, apply_dropout=True),
        downsample(256, 4, apply_batchnorm=True),
        forwardblock(256, 4, apply_batchnorm=True, apply_dropout=True),
        downsample(512, 4, apply_batchnorm=True),
        forwardblock(512, 4, apply_batchnorm=True, apply_dropout=True),
        downsample(1024, 4, apply_batchnorm=True),
    ]
    up_stack = [
        upsample(512, 4, apply_batchnorm=True),
        forwardblock(512, 4, apply_batchnorm=True, apply_dropout=True),
        upsample(256, 4, apply_batchnorm=True),
        forwardblock(256, 4, apply_batchnorm=True, apply_dropout=True),
        upsample(128, 4, apply_batchnorm=True),
        forwardblock(128, 4, apply_batchnorm=True, apply_dropout=True),
        upsample(64, 4, apply_batchnorm=True),
        forwardblock(64, 4, apply_batchnorm=True, apply_dropout=True),
    ]
    initializer = tf.random_normal_initializer(0., 0.02)
    # Final stride-2 transpose conv: back to full resolution, one channel per
    # ROI class, tanh keeps outputs in [-1, 1] (targets are scaled to match).
    last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                          strides=2,
                                          padding='same',
                                          kernel_initializer=initializer,
                                          activation='tanh')
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)
    # The bottleneck (last encoder output) is not used as a skip. The other 8
    # encoder activations are paired deepest-first with the 8 decoder blocks,
    # so each concatenation joins tensors of equal spatial size.
    skips = reversed(skips[:-1])
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    x = last(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

def Discriminator():
    """Build the conditional patch discriminator.

    Inputs: ``input_image`` (IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS) and
    ``target_image`` (IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS), concatenated on
    the channel axis.

    Output: a single-channel map of raw scores. The final Conv2D has no
    activation, which matches the ``from_logits=True`` BCE loss.

    For 256x256 inputs the map is 30x30. Three stride-2 downsamples give
    32x32. Each ZeroPadding2D (+2) followed by a 'valid' 4x4 conv (-3) then
    shrinks it by one.

    NOTE(review): presumably must match the architecture saved in the ckpt-47
    checkpoint restored in ``main``. Confirm before changing layers.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, INPUT_CHANNELS], name='input_image')
    tar = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])
    # First block without batch norm (third positional arg is apply_batchnorm).
    down1 = downsample(64, 4, False)(x)
    down2 = downsample(128, 4)(down1)
    down3 = downsample(256, 4)(down2)
    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)
    # Default padding='valid': this conv trims the 1-pixel padding border.
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)
    # One logit per spatial patch; no activation.
    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

# Load datasets
def _add_channel_axis(images):
    """Return ``images`` with a trailing channel axis if it is 3-D (N, H, W).

    Arrays of any other rank are returned unchanged.
    """
    return np.expand_dims(images, axis=-1) if images.ndim == 3 else images


def load_npy_datasets():
    """Load the train/val/test image and ROI arrays from ``DATA_DIR``.

    Reads ``{split}_images.npy`` and ``{split}_roi.npy`` for each split in
    train, val and test.

    Each image array gets a trailing channel axis if it is 3-D. The check is
    done per array, so a split already stored with a channel axis is never
    expanded twice, and a 3-D split is expanded even if another split was not.

    ROI arrays are cast to float32 and mapped with ``x * 2 - 1``, so a {0, 1}
    mask lands in {-1, 1} to match the generator's tanh range.

    Returns:
        ``((train_images, train_roi), (val_images, val_roi),
        (test_images, test_roi))``, with images as NumPy arrays and ROIs as
        float32 ``tf.Tensor``.

    Raises:
        Any error raised while loading (e.g. a missing file). It is logged
        with its traceback and re-raised.
    """
    try:
        train_images = np.load(os.path.join(DATA_DIR, 'train_images.npy'))  # Shape: (2348, 256, 256)
        train_roi = np.load(os.path.join(DATA_DIR, 'train_roi.npy'))  # Shape: (2348, 256, 256, 6)
        val_images = np.load(os.path.join(DATA_DIR, 'val_images.npy'))  # Shape: (416, 256, 256)
        val_roi = np.load(os.path.join(DATA_DIR, 'val_roi.npy'))  # Shape: (416, 256, 256, 6)
        test_images = np.load(os.path.join(DATA_DIR, 'test_images.npy'))  # Shape: (416, 256, 256)
        test_roi = np.load(os.path.join(DATA_DIR, 'test_roi.npy'))  # Shape: (416, 256, 256, 6)

        # Ensure images have 1 channel, checking every split independently.
        train_images = _add_channel_axis(train_images)
        val_images = _add_channel_axis(val_images)
        test_images = _add_channel_axis(test_images)

        # Scale ROI to [-1, 1] to match tanh output
        train_roi = tf.cast(train_roi, tf.float32) * 2.0 - 1.0
        val_roi = tf.cast(val_roi, tf.float32) * 2.0 - 1.0
        test_roi = tf.cast(test_roi, tf.float32) * 2.0 - 1.0

        return (train_images, train_roi), (val_images, val_roi), (test_images, test_roi)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Failed to load datasets: {e}")
        raise

# Data preprocessing
def normalize(input_image, real_image):
    """Map the input image linearly into [-1, 1]; pass the target through.

    The input is cast to float32 and rescaled with ``x / 127.5 - 1``, which
    assumes raw pixel values in [0, 255]. ``real_image`` is returned
    untouched: the ROI targets are already rescaled in ``load_npy_datasets``.
    """
    scaled = tf.cast(input_image, tf.float32) / 127.5 - 1
    return scaled, real_image

# Create tf.data datasets
def _batched_split(images, roi, shuffle):
    """Build one pipeline: slice -> cache -> normalize -> [shuffle] -> batch -> prefetch."""
    pipeline = tf.data.Dataset.from_tensor_slices((images, roi)).cache()
    pipeline = pipeline.map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        pipeline = pipeline.shuffle(BUFFER_SIZE)
    return pipeline.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


def create_datasets(train_data, val_data, test_data):
    """Wrap the three ``(images, roi)`` splits in batched ``tf.data`` pipelines.

    Every split is cached before ``normalize`` is mapped, so normalization
    re-runs on each pass. Every split is batched by BATCH_SIZE and prefetched.
    Only the training split is shuffled, using a BUFFER_SIZE-element buffer.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``
    """
    train_images, train_roi = train_data
    val_images, val_roi = val_data
    test_images, test_roi = test_data
    return (
        _batched_split(train_images, train_roi, shuffle=True),
        _batched_split(val_images, val_roi, shuffle=False),
        _batched_split(test_images, test_roi, shuffle=False),
    )

# Loss functions
# Shared binary cross-entropy. The discriminator's last Conv2D has no
# activation, so its outputs are raw logits.
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

def generator_loss(disc_generated_output, gen_output, target):
    """Conditional-GAN generator objective.

    Returns a tuple ``(total, adversarial, l1)``:

    * ``adversarial`` is the BCE of the discriminator's scores on generated
      samples against all-ones labels.
    * ``l1`` is the mean absolute difference between ``target`` and
      ``gen_output``.
    * ``total`` is ``adversarial + LAMBDA * l1``.
    """
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1 = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + (LAMBDA * l1), adversarial, l1

def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator objective.

    Returns the BCE of real scores against ones plus the BCE of generated
    scores against zeros.

    Not called in this script, because the discriminator stays frozen during
    fine-tuning.
    """
    real_term = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    fake_term = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return real_term + fake_term

# Training step for fine-tuning (generator only)
@tf.function
def train_step(input_image, target, generator, discriminator, generator_optimizer):
    """Run one generator-only optimisation step.

    The discriminator is called with ``training=False``. Gradients are taken
    only with respect to ``generator.trainable_variables``, so only the
    generator is updated.

    Returns:
        ``(total_loss, gan_loss, l1_loss)`` exactly as ``generator_loss``
        returns them.
    """
    with tf.GradientTape() as tape:
        fake = generator(input_image, training=True)
        fake_scores = discriminator([input_image, fake], training=False)
        losses = generator_loss(fake_scores, fake, target)

    weights = generator.trainable_variables
    grads = tape.gradient(losses[0], weights)
    generator_optimizer.apply_gradients(zip(grads, weights))
    return losses

# Fine-tuning function
def _mean_validation_loss(val_ds, generator, discriminator):
    """Return the mean total generator loss over every batch of ``val_ds``.

    Both networks run with ``training=False``.

    Raises:
        ValueError: if ``val_ds`` yields no batches, so no mean exists.
    """
    total = 0.0
    batches = 0
    for input_image, target in val_ds:
        gen_output = generator(input_image, training=False)
        disc_generated_output = discriminator([input_image, gen_output], training=False)
        gen_total_loss, _, _ = generator_loss(disc_generated_output, gen_output, target)
        total += gen_total_loss
        batches += 1
    if batches == 0:
        raise ValueError("Validation dataset yielded no batches")
    return float(total) / batches


def fit(train_ds, val_ds, generator, discriminator, generator_optimizer, epochs):
    """Fine-tune ``generator`` against a frozen ``discriminator``.

    Each epoch does the following:

    * runs ``train_step`` over ``train_ds``;
    * computes the mean validation generator loss over ``val_ds``;
    * logs both losses to TensorBoard under ``LOG_DIR/fit/<timestamp>``;
    * checkpoints the generator and its optimizer:
      - ``<CHECKPOINT_DIR>/ckpt_best`` is overwritten whenever validation
        loss improves (``Checkpoint.write`` writes one fixed path);
      - a ``<CHECKPOINT_DIR>/ckpt_epoch_<k>-<n>`` checkpoint is saved every
        5 epochs (``Checkpoint.save`` appends the ``-<n>`` counter).

    Training stops early after 5 consecutive epochs without validation
    improvement. The generator keeps its last-epoch weights, not the best ones.

    Raises:
        ValueError: if ``val_ds`` yields no batches.
    """
    best_val_loss = float('inf')
    patience = 5
    wait = 0

    # Used only in progress messages. cardinality() can be negative (unknown
    # or infinite) for some pipelines, so averages use the counted batches.
    steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")
    checkpoint = tf.train.Checkpoint(
        generator_optimizer=generator_optimizer,
        generator=generator
    )

    summary_writer = tf.summary.create_file_writer(
        os.path.join(LOG_DIR, "fit", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    )

    try:
        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch + 1}/{epochs}")
            start = time.time()

            # Training: accumulate the loss as a tensor (no per-step host sync)
            # and convert to a Python float once per epoch.
            total_gen_loss = 0.0
            train_steps = 0
            for input_image, target in train_ds:
                gen_total_loss, _, _ = train_step(
                    input_image, target, generator, discriminator, generator_optimizer
                )
                total_gen_loss += gen_total_loss
                train_steps += 1

                if train_steps % 100 == 0:
                    logger.info(f"Step {train_steps}/{steps_per_epoch}: Gen Loss = {float(gen_total_loss):.4f}")
            mean_train_loss = float(total_gen_loss) / max(train_steps, 1)

            # Validation
            val_gen_loss = _mean_validation_loss(val_ds, generator, discriminator)
            logger.info(f"Epoch {epoch + 1}: Val Gen Loss = {val_gen_loss:.4f}, Time = {time.time() - start:.2f} sec")

            # Logging to TensorBoard
            with summary_writer.as_default():
                tf.summary.scalar('gen_total_loss', mean_train_loss, step=epoch)
                tf.summary.scalar('val_gen_loss', val_gen_loss, step=epoch)

            # Checkpointing: keep a single best model. write() overwrites one
            # path instead of creating a new numbered copy on every improvement.
            if val_gen_loss < best_val_loss:
                best_val_loss = val_gen_loss
                checkpoint.write(checkpoint_prefix + "_best")
                logger.info(f"Saved best checkpoint for epoch {epoch + 1}")
                wait = 0
            else:
                wait += 1

            # Checkpointing: Save every 5 epochs
            if (epoch + 1) % 5 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix + f"_epoch_{epoch + 1}")
                logger.info(f"Saved periodic checkpoint for epoch {epoch + 1}")

            # Early stopping
            if wait >= patience:
                logger.info(f"Early stopping triggered after {patience} epochs without improvement")
                break
    finally:
        # Release the event-file handle even if training raises.
        summary_writer.close()

# Convert one-hot encoded mask to single-channel integer mask
def one_hot_to_integer_mask(one_hot_mask):
    """
    Convert a per-class score map (..., H, W, C) to an integer label map (..., H, W).

    Each pixel gets the index of its highest-scoring channel (argmax over the
    last axis), so labels lie in [0, C - 1], i.e. 0..5 for the 6-channel ROI.

    This works for tanh generator outputs in [-1, 1] and for ground-truth masks
    rescaled to {-1, 1}. No rescaling to [0, 1] is applied first: argmax is
    unchanged by monotonic maps such as (x + 1) / 2, and skipping that step
    avoids float32 rounding that can turn near-equal scores into ties.

    Returns:
        An int32 tensor with the last axis removed.
    """
    return tf.argmax(one_hot_mask, axis=-1, output_type=tf.int32)

# Display input image with overlaid segmentation mask
def display_overlay(input_image, predicted_mask, ground_truth_mask=None, num_samples=5):
    """
    Display input images with predicted segmentation masks overlaid.
    Optionally display ground truth masks if provided.

    Args:
        input_image: batch indexed as ``[i, ..., 0]`` (channel 0 is shown).
            Values are expected in [-1, 1], the range ``normalize`` produces,
            and are mapped to [0, 1] for display.
        predicted_mask: batch of integer label maps, one 2-D map per sample.
        ground_truth_mask: optional batch laid out like ``predicted_mask``.
        num_samples: maximum number of rows to draw (capped at the batch size).

    Tensors and NumPy arrays are both accepted. The figure is saved to
    ``DATA_DIR/test_results.png``, shown, and then closed, so repeated calls do
    not accumulate open figures. If there are no samples, nothing is drawn.
    """
    # np.asarray accepts eager tensors as well as plain arrays.
    input_image = np.asarray(input_image)
    predicted_mask = np.asarray(predicted_mask)
    if ground_truth_mask is not None:
        ground_truth_mask = np.asarray(ground_truth_mask)

    # Ensure we process only up to num_samples
    num_samples = min(num_samples, input_image.shape[0])
    if num_samples <= 0:
        logger.warning("No samples to display; skipping test results visualization")
        return

    # Fix the color scale to label ids 0..OUTPUT_CHANNELS-1 so each class has
    # the same color in every panel.
    cmap = plt.get_cmap('jet')
    norm = plt.Normalize(vmin=0, vmax=OUTPUT_CHANNELS - 1)

    cols = 3 if ground_truth_mask is not None else 2
    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, cols, figsize=(cols * 5, num_samples * 5),
                             squeeze=False)

    for i in range(num_samples):
        # Denormalize input image to [0, 1] for display
        img = (input_image[i, ..., 0] + 1) * 0.5

        axes[i, 0].imshow(img, cmap='gray')
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(img, cmap='gray')
        axes[i, 1].imshow(predicted_mask[i], cmap=cmap, norm=norm, alpha=0.5)
        axes[i, 1].set_title('Input + Predicted Mask')
        axes[i, 1].axis('off')

        if ground_truth_mask is not None:
            axes[i, 2].imshow(img, cmap='gray')
            axes[i, 2].imshow(ground_truth_mask[i], cmap=cmap, norm=norm, alpha=0.5)
            axes[i, 2].set_title('Input + Ground Truth Mask')
            axes[i, 2].axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(DATA_DIR, 'test_results.png'))
    plt.show()
    plt.close(fig)
    logger.info("Saved test results visualization to test_results.png")

def main():
    """Restore the pretrained GAN, fine-tune its generator, save it, and plot a test batch.

    Steps:

    1. Create CHECKPOINT_DIR and LOG_DIR.
    2. Build both networks and freeze the discriminator.
    3. Restore weights from ``DATA_DIR/ckpt-47``. Raise FileNotFoundError if
       its ``.index`` file is missing.
    4. Build the datasets and run ``fit``.
    5. Save the generator to ``DATA_DIR/generator_finetuned.h5``.
    6. Plot predictions for the first test batch.

    Any exception is logged and re-raised.
    """
    try:
        for directory in (CHECKPOINT_DIR, LOG_DIR):
            Path(directory).mkdir(parents=True, exist_ok=True)

        generator, discriminator = Generator(), Discriminator()
        discriminator.trainable = False  # only the generator is fine-tuned
        logger.info("Initialized generator and frozen discriminator")

        restore_path = os.path.join(DATA_DIR, 'ckpt-47')
        if not Path(restore_path + '.index').exists():
            raise FileNotFoundError(f"Checkpoint file {restore_path} not found")
        # expect_partial(): the checkpoint may hold objects (e.g. optimizer
        # state) that are not attached here, and those are silently skipped.
        tf.train.Checkpoint(
            generator=generator,
            discriminator=discriminator,
        ).restore(restore_path).expect_partial()
        logger.info(f"Restored weights from {restore_path}")

        train_data, val_data, test_data = load_npy_datasets()
        train_ds, val_ds, test_ds = create_datasets(train_data, val_data, test_data)

        optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        fit(train_ds, val_ds, generator, discriminator, optimizer, EPOCHS)

        generator.save(os.path.join(DATA_DIR, 'generator_finetuned.h5'))
        logger.info("Fine-tuning completed and generator saved")

        # take(1) yields at most one batch, so this body runs at most once.
        for input_images, ground_truth_roi in test_ds.take(1):
            predictions = generator(input_images, training=False)
            display_overlay(
                input_images,
                one_hot_to_integer_mask(predictions),
                one_hot_to_integer_mask(ground_truth_roi),
                num_samples=5,
            )

    except Exception as e:
        logger.error(f"Process failed: {e}")
        raise

# Run the fine-tuning pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()