# In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import os
import datetime
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

# =======================
# GPU Configuration (CRITICAL!)
# =======================
# Ask TensorFlow to grow GPU memory on demand for every visible GPU instead of
# reserving it all up front. A RuntimeError from set_memory_growth is printed
# rather than raised (presumably the case where the GPUs were already
# initialized -- confirm).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"‚úÖ GPU Found: {tf.config.list_physical_devices('GPU')}")
    except RuntimeError as err:
        print(err)

# Switch the global Keras dtype policy to 'mixed_float16' for every layer
# built after this point.
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
print("‚úÖ Mixed precision enabled")

# Turn on XLA JIT compilation globally.
tf.config.optimizer.set_jit(True)
print("‚úÖ XLA compilation enabled")

# =======================
# Hyperparameters
# =======================
# Every image is resized to IMG_HEIGHT x IMG_WIDTH; the generators emit
# OUTPUT_CHANNELS channels (single-channel images throughout this script).
IMG_HEIGHT, IMG_WIDTH = 256, 256
OUTPUT_CHANNELS = 1

# Training schedule.
BATCH_SIZE = 1   # Keeping batch size at 1 as requested
EPOCHS = 50

# Let tf.data tune parallelism and prefetch depth dynamically.
AUTOTUNE = tf.data.AUTOTUNE

# =======================
# Dataset Loader (OPTIMIZED!)
# =======================
def load_image(path):
    """Decode one PNG file into a normalized single-channel float tensor.

    Args:
        path: scalar string tensor holding the file path (one element of the
            path list built by make_dataset).

    Returns:
        Tensor of shape [IMG_HEIGHT, IMG_WIDTH, 1] whose pixel values are
        mapped linearly from [0, 255] to [-1, 1] (assumes 8-bit PNG data).
    """
    encoded = tf.io.read_file(path)
    decoded = tf.image.decode_png(encoded, channels=1)
    resized = tf.image.resize(decoded, [IMG_HEIGHT, IMG_WIDTH])
    # x / 127.5 - 1 sends 0 -> -1 and 255 -> +1.
    return resized / 127.5 - 1.0

def make_dataset(folder):
    """Build an endless, shuffled, batched tf.data pipeline over one image folder.

    Every regular, non-hidden file directly inside *folder* is decoded with
    `load_image`. Hidden entries (e.g. `.DS_Store`, `.ipynb_checkpoints`) and
    sub-directories are skipped: they would otherwise reach `load_image` and
    fail to read/decode, aborting training.

    Args:
        folder: directory containing the images of one domain.

    Returns:
        A repeating `tf.data.Dataset` yielding batches of BATCH_SIZE images.

    Raises:
        ValueError: if *folder* contains no usable files.
    """
    paths = [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
    ]
    if not paths:
        raise ValueError(f"No image files found in {folder!r}")

    dataset = tf.data.Dataset.from_tensor_slices(paths)
    dataset = dataset.map(load_image, num_parallel_calls=AUTOTUNE)
    # Cache decoded images in memory so decoding/resizing happens only once.
    dataset = dataset.cache()
    # Shuffle after caching so the cached data is not frozen in one order.
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(BATCH_SIZE)
    # Repeat forever; the training loop decides how many batches form an epoch.
    dataset = dataset.repeat()
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

print("Loading datasets...")
# One independently shuffled, repeating pipeline per domain.
trainA, trainB = (
    make_dataset("/content/dataset/dataset/trainA"),  # CT
    make_dataset("/content/dataset/dataset/trainB"),  # PET
)

print("‚úÖ Datasets loaded and optimized")

# =======================
# Generator (ResNet-based)
# =======================
def resnet_block(x, filters, size=3):
    """Residual block: (conv -> BN -> ReLU -> conv -> BN) added back onto *x*.

    Both convolutions use 'same' padding, so the spatial size is preserved.
    The Add requires *filters* to equal the channel count of *x*.
    """
    init = tf.random_normal_initializer(0., 0.02)
    branch = x
    for stage in range(2):
        branch = layers.Conv2D(filters, size, padding='same', kernel_initializer=init)(branch)
        branch = layers.BatchNormalization()(branch)
        if stage == 0:
            # ReLU only between the two convolutions, not after the second.
            branch = layers.ReLU()(branch)
    return layers.Add()([x, branch])

def Generator():
    """Build the ResNet-style CycleGAN generator for single-channel images.

    Layout: c7s1-64 -> d128 -> d256 -> 6 residual blocks (256 channels) ->
    u128 -> u64 -> c7s1-OUTPUT_CHANNELS with tanh. Every hidden convolution is
    followed by BatchNormalization and ReLU.

    Returns:
        tf.keras.Model mapping [IMG_HEIGHT, IMG_WIDTH, 1] inputs to
        [IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS] outputs in [-1, 1].
    """
    inputs = layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 1])
    init = tf.random_normal_initializer(0., 0.02)

    def conv_bn_relu(tensor, conv_layer, filters, kernel, stride):
        # One conv (or transposed conv) stage followed by BN and ReLU.
        tensor = conv_layer(filters, kernel, strides=stride, padding='same',
                            kernel_initializer=init)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    x = inputs
    # Encoder: c7s1-64, then two stride-2 downsampling convs (d128, d256).
    for filters, kernel, stride in ((64, 7, 1), (128, 3, 2), (256, 3, 2)):
        x = conv_bn_relu(x, layers.Conv2D, filters, kernel, stride)

    # Transformer: six residual blocks at 256 channels.
    for _ in range(6):
        x = resnet_block(x, 256)

    # Decoder: two stride-2 transposed convs (u128, u64) restore the input size.
    for filters in (128, 64):
        x = conv_bn_relu(x, layers.Conv2DTranspose, filters, 3, 2)

    # c7s1-OUTPUT_CHANNELS: tanh keeps outputs in [-1, 1], the same range
    # load_image normalizes inputs to.
    outputs = layers.Conv2D(OUTPUT_CHANNELS, 7, strides=1, padding='same',
                            kernel_initializer=init, activation='tanh')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

# =======================
# Discriminator (PatchGAN)
# =======================
def Discriminator():
    """Build a PatchGAN discriminator for single-channel images.

    Three stride-2 4x4 convolutions (64, 128, 256 filters) and one stride-1
    convolution (512 filters), each followed by LeakyReLU(0.2), with
    BatchNormalization on all but the first. A final 1-filter 4x4 convolution
    with no activation produces a grid of unbounded real/fake scores. The MSE
    losses in this script consume these scores directly.

    The final convolution is forced to float32. Under the global
    'mixed_float16' policy it would otherwise emit float16 scores. The
    adversarial loss would then run in half precision and, depending on the
    Keras version, could not be summed with the float32 cycle/identity losses.

    Returns:
        tf.keras.Model mapping [IMG_HEIGHT, IMG_WIDTH, 1] images to a float32
        score map of shape [ceil(IMG_HEIGHT/8), ceil(IMG_WIDTH/8), 1]
        (32x32 for 256x256 inputs).
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 1])
    x = layers.Conv2D(64, 4, strides=2, padding='same', kernel_initializer=initializer)(inp)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(128, 4, strides=2, padding='same', kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(256, 4, strides=2, padding='same', kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(512, 4, strides=1, padding='same', kernel_initializer=initializer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    # Output layer kept in float32 (see docstring); variables are unchanged,
    # so existing checkpoints still load.
    x = layers.Conv2D(1, 4, strides=1, padding='same', kernel_initializer=initializer,
                      dtype='float32')(x)
    return tf.keras.Model(inputs=inp, outputs=x)

# =======================
# Instantiate models
# =======================
print("Building models...")
# Two generators translate between the domains; each domain has its own
# discriminator. Construction order: G, F, D_CT, D_PET.
G, F = Generator(), Generator()          # G: CT → PET, F: PET → CT
D_CT, D_PET = Discriminator(), Discriminator()
print("‚úÖ Models built")

# =======================
# Loss Functions (FIXED for Mixed Precision)
# =======================
# Least-squares adversarial loss: mean squared error between discriminator
# scores and constant 1 (real) / 0 (fake) targets.
loss_obj = tf.keras.losses.MeanSquaredError()

def discriminator_loss(real, generated):
    """Return the discriminator's least-squares loss as a float32 scalar.

    Args:
        real: discriminator scores for real images (target 1).
        generated: discriminator scores for generated images (target 0).

    Returns:
        0.5 * (MSE(real, 1) + MSE(generated, 0)), computed in float32.
    """
    # Under the 'mixed_float16' policy the scores may arrive as float16; cast
    # so the loss is computed in float32 and can be summed with the other
    # float32 loss terms.
    real = tf.cast(real, tf.float32)
    generated = tf.cast(generated, tf.float32)
    real_loss = loss_obj(tf.ones_like(real), real)
    generated_loss = loss_obj(tf.zeros_like(generated), generated)
    return (real_loss + generated_loss) * 0.5

def generator_loss(generated):
    """Return a generator's least-squares adversarial loss as a float32 scalar.

    The generator wants the discriminator to score its outputs as real, so
    every score is pulled toward the target 1.

    Args:
        generated: discriminator scores for generated images; may be float16
            under the mixed-precision policy.
    """
    # Cast to float32 so the MSE is not computed in half precision and the
    # result can be added to the float32 cycle/identity losses.
    generated = tf.cast(generated, tf.float32)
    return loss_obj(tf.ones_like(generated), generated)

# Weight of the cycle-consistency term.
LAMBDA = 10

def cycle_loss(real_image, cycled_image):
    """Return LAMBDA * mean absolute error between an image and its round trip.

    Both tensors are cast to float32 before subtracting, so a float16
    generator output can be compared against a float32 input batch.
    """
    difference = tf.cast(real_image, tf.float32) - tf.cast(cycled_image, tf.float32)
    return LAMBDA * tf.reduce_mean(tf.abs(difference))

def identity_loss(real_image, same_image):
    """Return 0.5 * LAMBDA * mean absolute error between an image and the
    generator's output when fed that same image.

    Both tensors are cast to float32 before subtracting (mixed precision).
    """
    real_f32, same_f32 = tf.cast(real_image, tf.float32), tf.cast(same_image, tf.float32)
    mae = tf.reduce_mean(tf.abs(real_f32 - same_f32))
    return LAMBDA * 0.5 * mae

# =======================
# Optimizers (with mixed precision)
# =======================
# Adam(lr=2e-4, beta_1=0.5) for each of the four networks. In train_step the
# "x" discriminator optimizer updates D_CT and the "y" one updates D_PET.
generator_g_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_x_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# Wrap every optimizer in a LossScaleOptimizer for float16 training.
(generator_g_optimizer, generator_f_optimizer,
 discriminator_x_optimizer, discriminator_y_optimizer) = (
    mixed_precision.LossScaleOptimizer(generator_g_optimizer),
    mixed_precision.LossScaleOptimizer(generator_f_optimizer),
    mixed_precision.LossScaleOptimizer(discriminator_x_optimizer),
    mixed_precision.LossScaleOptimizer(discriminator_y_optimizer),
)

# =======================
# Checkpoints
# =======================
# Track all four networks and their optimizers; the manager keeps only the
# five most recent checkpoints under checkpoint_path.
# NOTE(review): nothing in this section restores
# ckpt_manager.latest_checkpoint, so each run starts from scratch -- add a
# restore if resuming training is intended.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(
    G=G, F=F, D_CT=D_CT, D_PET=D_PET,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# =======================
# Training Step (FIXED for MODERN KERAS API)
# =======================
def _uses_legacy_loss_scaling(optimizer):
    """True for the tf.keras (Keras 2) LossScaleOptimizer API, False for Keras 3."""
    return hasattr(optimizer, "get_scaled_loss")


def _prepare_optimizer(optimizer, variables):
    """Build a Keras 3 optimizer before tracing so scale_loss() reads its live loss scale."""
    if not _uses_legacy_loss_scaling(optimizer) and not getattr(optimizer, "built", True):
        optimizer.build(variables)


def _scale_loss(optimizer, loss):
    """Multiply *loss* by the optimizer's current loss scale."""
    if _uses_legacy_loss_scaling(optimizer):
        return optimizer.get_scaled_loss(loss)
    return optimizer.scale_loss(loss)


def _unscale_gradients(optimizer, grads):
    """Undo loss scaling on *grads* when the optimizer does not do it itself."""
    if _uses_legacy_loss_scaling(optimizer):
        # Keras 2: apply_gradients() expects already-unscaled gradients.
        return optimizer.get_unscaled_gradients(grads)
    # Keras 3: apply_gradients() divides by the loss scale internally.
    return grads


@tf.function
def train_step(real_CT, real_PET):
    """Run one CycleGAN update on a batch of unpaired CT and PET images.

    G (CT → PET) and F (PET → CT) minimise a least-squares adversarial loss,
    the *total* cycle-consistency loss and an identity loss. Each generator
    appears in both round trips (F(G(CT)) and G(F(PET))), so both cycle terms
    belong to both generator objectives. D_CT and D_PET minimise the
    least-squares discriminator loss.

    All four optimizers are LossScaleOptimizers. Each loss is therefore
    multiplied by the current loss scale before differentiation, and the
    gradients are unscaled before the update. Skipping the scaling step is
    wrong: Keras 3's apply_gradients() still divides the gradients by the
    scale, which shrinks every update.

    Args:
        real_CT: batch of CT images from the CT pipeline.
        real_PET: batch of PET images from the PET pipeline.

    Returns:
        (gen_g_loss, gen_f_loss, disc_CT_loss, disc_PET_loss) as unscaled
        float32 scalars.
    """
    pairs = (
        (generator_g_optimizer, G),
        (generator_f_optimizer, F),
        (discriminator_x_optimizer, D_CT),
        (discriminator_y_optimizer, D_PET),
    )
    for optimizer, model in pairs:
        _prepare_optimizer(optimizer, model.trainable_variables)

    with tf.GradientTape(persistent=True) as tape:
        # Translations and round trips.
        fake_PET = G(real_CT, training=True)
        cycled_CT = F(fake_PET, training=True)

        fake_CT = F(real_PET, training=True)
        cycled_PET = G(fake_CT, training=True)

        # Identity mapping: each generator fed an image already in its target domain.
        same_CT = F(real_CT, training=True)
        same_PET = G(real_PET, training=True)

        # Discriminator scores.
        disc_real_CT = D_CT(real_CT, training=True)
        disc_real_PET = D_PET(real_PET, training=True)
        disc_fake_CT = D_CT(fake_CT, training=True)
        disc_fake_PET = D_PET(fake_PET, training=True)

        # Both cycles depend on both generators, so both generators get both terms.
        total_cycle_loss = cycle_loss(real_CT, cycled_CT) + cycle_loss(real_PET, cycled_PET)

        # Adversarial terms are cast to float32 in case the discriminators emit float16.
        gen_g_loss = (tf.cast(generator_loss(disc_fake_PET), tf.float32)
                      + total_cycle_loss
                      + identity_loss(real_PET, same_PET))
        gen_f_loss = (tf.cast(generator_loss(disc_fake_CT), tf.float32)
                      + total_cycle_loss
                      + identity_loss(real_CT, same_CT))

        disc_CT_loss = tf.cast(discriminator_loss(disc_real_CT, disc_fake_CT), tf.float32)
        disc_PET_loss = tf.cast(discriminator_loss(disc_real_PET, disc_fake_PET), tf.float32)

        # Scaled copies are used only for differentiation.
        scaled_gen_g_loss = _scale_loss(generator_g_optimizer, gen_g_loss)
        scaled_gen_f_loss = _scale_loss(generator_f_optimizer, gen_f_loss)
        scaled_disc_CT_loss = _scale_loss(discriminator_x_optimizer, disc_CT_loss)
        scaled_disc_PET_loss = _scale_loss(discriminator_y_optimizer, disc_PET_loss)

    gradients_g = _unscale_gradients(
        generator_g_optimizer, tape.gradient(scaled_gen_g_loss, G.trainable_variables))
    gradients_f = _unscale_gradients(
        generator_f_optimizer, tape.gradient(scaled_gen_f_loss, F.trainable_variables))
    gradients_d_CT = _unscale_gradients(
        discriminator_x_optimizer, tape.gradient(scaled_disc_CT_loss, D_CT.trainable_variables))
    gradients_d_PET = _unscale_gradients(
        discriminator_y_optimizer, tape.gradient(scaled_disc_PET_loss, D_PET.trainable_variables))
    # A persistent tape holds its resources until released.
    del tape

    generator_g_optimizer.apply_gradients(zip(gradients_g, G.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(gradients_f, F.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(gradients_d_CT, D_CT.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(gradients_d_PET, D_PET.trainable_variables))

    return gen_g_loss, gen_f_loss, disc_CT_loss, disc_PET_loss

def generate_images(G, F, real_CT, real_PET, epoch=None):
    """Plot CT → PET → CT and PET → CT → PET translations, save, and show them.

    Only the first image of each batch (index [0, :, :, 0]) is plotted. All
    six panels are mapped from [-1, 1] to [0, 1] for display.

    Args:
        G: CT → PET generator, called with training=False.
        F: PET → CT generator, called with training=False.
        real_CT: batch of CT images; presumably NHWC in [-1, 1] -- confirm.
        real_PET: batch of PET images, same layout as real_CT.
        epoch: zero-based epoch index. The figure is saved as
            ``output_epoch_{epoch+1}.png``. If omitted, the module-level
            ``epoch`` variable (the training loop's counter) is used, as
            before. If that does not exist either, the figure is saved as
            ``output_latest.png`` instead of raising NameError.
    """
    fake_PET = G(real_CT, training=False)
    cycled_CT = F(fake_PET, training=False)

    fake_CT = F(real_PET, training=False)
    cycled_PET = G(fake_CT, training=False)

    def denorm(img):
        # Convert [-1,1] to [0,1] for visualization
        return (img + 1) / 2

    # (title, batch) in subplot order: top row follows CT, bottom row PET.
    panels = (
        ("Real CT", real_CT),
        ("Fake PET (CT ‚Üí PET)", fake_PET),
        ("Cycled CT (CT ‚Üí PET ‚Üí CT)", cycled_CT),
        ("Real PET", real_PET),
        ("Fake CT (PET ‚Üí CT)", fake_CT),
        ("Cycled PET (PET ‚Üí CT ‚Üí PET)", cycled_PET),
    )

    plt.figure(figsize=(12, 8))
    for position, (title, batch) in enumerate(panels, start=1):
        plt.subplot(2, 3, position)
        plt.title(title)
        plt.imshow(denorm(batch)[0, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.tight_layout()

    if epoch is None:
        # Backward compatibility: earlier versions read the loop's global `epoch`.
        epoch = globals().get("epoch")
    filename = "output_latest.png" if epoch is None else f'output_epoch_{epoch+1}.png'
    plt.savefig(filename)
    plt.show()
    # Close only AFTER show (closing first would leave nothing to display).
    # With some backends show() may not close the figure, so without this
    # one figure per epoch would accumulate.
    plt.close()

# =======================
# Training Loop (OPTIMIZED with progress bar)
# =======================
print(f"\n{'='*50}")
print(f"Starting Training: {EPOCHS} epochs, batch size {BATCH_SIZE}")
print(f"{'='*50}\n")

# Steps per epoch.
# NOTE(review): 1400 is a hard-coded image count; if the folders hold a
# different number of images, an "epoch" here is not one pass over the data.
STEPS_PER_EPOCH = 1400 // BATCH_SIZE  # 1400 images

# Pair one CT batch with one PET batch per step. Both pipelines repeat
# forever, so take() below decides how long an epoch lasts.
dataset = tf.data.Dataset.zip((trainA, trainB))

for epoch in range(EPOCHS):
    start = time.time()
    print(f"\nEpoch {epoch+1}/{EPOCHS}")

    with tqdm(total=STEPS_PER_EPOCH, desc=f"Epoch {epoch+1}") as pbar:
        for step, (real_CT, real_PET) in enumerate(dataset.take(STEPS_PER_EPOCH)):
            gen_g_loss, gen_f_loss, disc_CT_loss, disc_PET_loss = train_step(real_CT, real_PET)
            pbar.update(1)

            # Refresh the displayed losses only on every 50th step.
            if step % 50:
                continue
            pbar.set_postfix({
                label: f'{value:.2f}'
                for label, value in (
                    ('G_loss', gen_g_loss),
                    ('F_loss', gen_f_loss),
                    ('D_CT', disc_CT_loss),
                    ('D_PET', disc_PET_loss),
                )
            })

    epoch_time = time.time() - start
    print(f"‚úÖ Epoch {epoch+1} completed in {epoch_time/60:.2f} minutes")

    # Checkpoint on every 10th epoch (10, 20, ...).
    if (epoch + 1) % 10 == 0:
        ckpt_manager.save()
        print(f"üíæ Checkpoint saved at epoch {epoch+1}")

    # Plot and save one sample pair after every epoch.
    for real_CT_batch, real_PET_batch in dataset.take(1):
        generate_images(G, F, real_CT_batch, real_PET_batch)
        print(f"üñºÔ∏è  Sample images saved")


print("\n" + "="*50)
print("üéâ Training Complete!")
print("="*50)