In [None]:
# train_generator.py

import tensorflow as tf
import os
import time
import matplotlib.pyplot as plt

# --- Configuration ---
EPOCHS = 150                                   # number of full passes over the training dataset
OUTPUT_DIR = 'gan_training_output'             # per-epoch sample images are saved here
CHECKPOINT_DIR = './gan_training_checkpoints'  # prefix directory for tf.train.Checkpoint files
LAMBDA = 100                                   # weight of the L1 (reconstruction) term in the generator loss

# Make sure the sample-image folder exists; exist_ok=True makes this a no-op on reruns.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Loss Functions and Optimizers ---
# Shared binary cross-entropy used by both the generator and discriminator losses.
# from_logits=True: the predictions handed to it are treated as unnormalized logits
# (no sigmoid is expected to have been applied beforehand).
loss_object = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

def discriminator_loss(disc_real_output, disc_generated_output):
    """Return the discriminator's adversarial loss.

    Scores on real pairs are compared against an all-ones label and scores on
    generated pairs against an all-zeros label, both with the shared
    ``loss_object`` (binary cross-entropy on logits). The two terms are summed.

    Args:
        disc_real_output: Discriminator output for (input, real target) pairs.
        disc_generated_output: Discriminator output for (input, generated) pairs.

    Returns:
        The sum of the real and generated cross-entropy terms.
    """
    loss_on_real = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_fake

def generator_loss(disc_generated_output, gen_output, target):
    """Return the generator loss and its two components.

    The adversarial term rewards the generator when the discriminator labels its
    output as real (all-ones target). The reconstruction term is the mean
    absolute error between ``target`` and ``gen_output``, weighted by ``LAMBDA``.

    Args:
        disc_generated_output: Discriminator output for (input, generated) pairs.
        gen_output: Images produced by the generator.
        target: Ground-truth images paired with the generator's inputs.

    Returns:
        Tuple ``(total, adversarial, l1)`` where
        ``total = adversarial + LAMBDA * l1``.
    """
    adversarial = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    # Mean absolute error keeps the generated image structurally close to the target.
    reconstruction = tf.reduce_mean(tf.abs(target - gen_output))
    return adversarial + LAMBDA * reconstruction, adversarial, reconstruction

# Independent Adam instances (learning rate 2e-4, beta_1 = 0.5), one per network,
# so the two models keep separate optimizer state.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

# --- Checkpoint Manager ---
# Generator / Discriminator are defined outside this section (not visible here).
generator = Generator()
discriminator = Discriminator()

# Both models and their optimizers are tracked, so saving captures weights and
# optimizer state together. Files are written as CHECKPOINT_DIR/ckpt-<n>.
checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# --- Image Generation and Plotting ---
def generate_images(model, test_input, tar, epoch):
    """Generate a prediction and save a side-by-side plot of input, real, and predicted images.

    Only the first element of each batch is plotted. The PNG is written to
    ``OUTPUT_DIR/image_at_epoch_XXXX.png``, where ``XXXX`` is ``epoch + 1``
    zero-padded to four digits.

    Args:
        model: Generator callable, invoked as ``model(test_input, training=True)``.
        test_input: Batch of generator inputs.
        tar: Batch of ground-truth images paired with ``test_input``.
        epoch: Zero-based epoch index used to name the output file.
    """
    # training=True is passed as in the original code; this runs the model in
    # training mode (relevant for layers such as dropout / batch-norm, if present).
    prediction = model(test_input, training=True)

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Label Map', 'Ground Truth', 'Predicted Image']

    fig = plt.figure(figsize=(15, 5))
    try:
        for i, (image, label) in enumerate(zip(display_list, title)):
            ax = fig.add_subplot(1, 3, i + 1)
            ax.set_title(label)
            # Denormalize image from [-1, 1] to [0, 1] for display
            # (assumes inputs were normalized to [-1, 1] upstream — confirm).
            ax.imshow(image * 0.5 + 0.5)
            ax.axis('off')

        # Save the plot
        fig.savefig(os.path.join(OUTPUT_DIR, f'image_at_epoch_{epoch+1:04d}.png'))
    finally:
        # Always release this figure, even if plotting or saving fails, so that
        # calling this once per epoch does not accumulate open figures.
        plt.close(fig)

# --- The Core Training Step ---
@tf.function
def train_step(input_image, target, epoch):
    """Run one joint optimization step for the generator and the discriminator.

    A single forward pass is recorded on two gradient tapes. Each network is then
    updated with gradients of its own loss with respect to its own trainable
    variables.

    Args:
        input_image: Batch of conditioning inputs fed to both networks.
        target: Batch of real images paired with ``input_image``.
        epoch: Not used inside the step. Because this is a ``tf.function``,
            passing a different Python int on each call triggers a retrace;
            pass a tensor instead.

    Returns:
        Tuple ``(disc_loss, gen_total_loss)``.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake = generator(input_image, training=True)

        # The discriminator sees (condition, image) pairs: once with the real
        # target and once with the generator's output.
        logits_real = discriminator([input_image, target], training=True)
        logits_fake = discriminator([input_image, fake], training=True)

        g_total, _, _ = generator_loss(logits_fake, fake, target)
        d_total = discriminator_loss(logits_real, logits_fake)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables

    g_grads = g_tape.gradient(g_total, g_vars)
    d_grads = d_tape.gradient(d_total, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))

    return d_total, g_total

# --- Main Training Function ---
def train(dataset, epochs, checkpoint_every=20):
    """Train the generator/discriminator pair for a fixed number of epochs.

    One batch is taken from ``dataset`` before training starts and reused after
    every epoch to write a sample image via ``generate_images``. Mean losses and
    wall-clock time are printed per epoch.

    Args:
        dataset: ``tf.data.Dataset`` yielding ``(input_image, target)`` batches.
        epochs: Number of full passes over ``dataset``.
        checkpoint_every: Save a checkpoint every this many epochs. Values
            <= 0 disable periodic saves. The final epoch is always
            checkpointed, so trained weights are kept even when ``epochs``
            is not a multiple of this value.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    # Take one batch from the dataset to use for visualization throughout training
    try:
        example_input, example_target = next(iter(dataset.take(1)))
    except StopIteration:
        raise ValueError("train() received an empty dataset") from None

    for epoch in range(epochs):
        start = time.time()
        print(f"Epoch {epoch + 1}/{epochs}")

        # train_step is a tf.function. A new Python int per call would force a
        # retrace every epoch, so pass the epoch as a tensor instead.
        epoch_tensor = tf.constant(epoch, dtype=tf.int64)

        disc_loss_epoch = []
        gen_loss_epoch = []

        # NOTE(review): tqdm is not imported in this view; presumably it is
        # imported elsewhere (e.g. an earlier notebook cell) — confirm.
        for input_image, target in tqdm(dataset, desc="  Training..."):
            disc_loss, gen_loss = train_step(input_image, target, epoch_tensor)
            disc_loss_epoch.append(disc_loss)
            gen_loss_epoch.append(gen_loss)

        # Generate and save a sample image at the end of the epoch
        generate_images(generator, example_input, example_target, epoch)

        # Save periodically, and always after the last epoch so the final
        # weights are never lost.
        is_last_epoch = epoch + 1 == epochs
        is_periodic = checkpoint_every > 0 and (epoch + 1) % checkpoint_every == 0
        if is_periodic or is_last_epoch:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print (f'Time taken for epoch {epoch + 1} is {time.time()-start:.2f} sec')
        print (f'  -> Discriminator Loss: {tf.reduce_mean(disc_loss_epoch):.4f}, Generator Loss: {tf.reduce_mean(gen_loss_epoch):.4f}')

# --- Run the Training ---
if __name__ == '__main__':
    # Data root handed to get_gan_dataset (defined elsewhere in the project).
    chipped_data_dir = 'chipped_data_512'

    # Augmented, shuffled training pipeline, then the full training loop.
    train_dataset = get_gan_dataset(
        chipped_data_dir,
        augment=True,
        shuffle=True,
    )
    train(train_dataset, EPOCHS)