In [None]:
# train_generator.py (New Robust Version)

import tensorflow as tf
import os
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
from google.colab import drive # NEW: Import for Google Drive

# --- Configuration ---
# Number of training epochs. Not referenced in this cell; presumably handed to
# train(..., epochs=EPOCHS) by a caller — confirm.
EPOCHS = 80
# Multiplier applied to the L1 (mean absolute error) term of the generator loss.
LAMBDA = 100

# --- Google Drive locations ---
DRIVE_MOUNT_PATH = '/content/drive'
# Root folder on Drive that receives every artifact written by train().
DRIVE_OUTPUT_PATH = '/content/drive/MyDrive/Aerial Segmentation Machine Learning/data_gen'
# Sub-folders of the output root: saved model files, then per-epoch preview plots.
DRIVE_CHECKPOINT_DIR, DRIVE_IMAGE_DIR = (
    os.path.join(DRIVE_OUTPUT_PATH, sub_folder)
    for sub_folder in ('checkpoints', 'image_samples')
)

# --- Helpers ---
def _prepare_drive_dirs():
    """Mount Google Drive if it is not mounted yet and create the output folders.

    Reads the module-level DRIVE_* paths. ``os.makedirs(..., exist_ok=True)``
    makes this safe to call when the folders already exist.
    """
    if os.path.ismount(DRIVE_MOUNT_PATH):
        print("✅ Drive already mounted.")
    else:
        print("💽 Mounting Google Drive...")
        drive.mount(DRIVE_MOUNT_PATH)

    os.makedirs(DRIVE_CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(DRIVE_IMAGE_DIR, exist_ok=True)
    print(f"✅ Outputs will be saved to: {DRIVE_OUTPUT_PATH}")


def _save_models(generator, discriminator, tag):
    """Save both models to DRIVE_CHECKPOINT_DIR as ``gen_<tag>.keras`` and ``disc_<tag>.keras``."""
    generator.save(os.path.join(DRIVE_CHECKPOINT_DIR, f'gen_{tag}.keras'))
    discriminator.save(os.path.join(DRIVE_CHECKPOINT_DIR, f'disc_{tag}.keras'))


# --- Main Training Function ---
def train(dataset, epochs, save_every=20):
    """Train a generator/discriminator pair on ``dataset`` and save outputs to Google Drive.

    Loss setup:
        - The generator minimises an adversarial loss plus ``LAMBDA`` times the mean
          absolute (L1) error between its output and the target.
        - The discriminator is trained to score ``[input_image, target]`` as real and
          ``[input_image, generated]`` as fake.

    Outputs:
        - After every epoch, a preview plot of the first dataset element is written
          to DRIVE_IMAGE_DIR.
        - Model files are written to DRIVE_CHECKPOINT_DIR.

    ``Generator`` and ``Discriminator`` are model factories that must exist in the
    global scope (defined in another cell).

    Args:
        dataset: Object supporting ``.take(1)`` and iteration, presumably a
            ``tf.data.Dataset``, yielding ``(input_image, target)`` pairs.
            NOTE(review): the preview indexes ``[0]``, so elements are assumed to be
            batched — confirm.
        epochs: Number of passes over ``dataset``.
        save_every: Write a periodic checkpoint pair every ``save_every`` epochs.
            The last epoch is skipped because the final save writes the same
            weights. A value ``<= 0`` disables periodic checkpoints.

    Raises:
        ValueError: If ``dataset`` yields no elements.
    """
    # Fail fast, before mounting Drive or building models, when there is nothing
    # to train on (otherwise this surfaces as a bare StopIteration).
    first_element = next(iter(dataset.take(1)), None)
    if first_element is None:
        raise ValueError("dataset is empty; nothing to train on")
    # Fixed example reused every epoch so the preview plots are comparable.
    example_input, example_target = first_element

    _prepare_drive_dirs()

    # Initialize models and optimizers (Generator/Discriminator come from the global scope).
    generator = Generator()
    discriminator = Discriminator()
    generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # from_logits=True: the loss applies the sigmoid itself, so discriminator
    # outputs are treated as unnormalized logits.
    loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def discriminator_loss(disc_real_output, disc_generated_output):
        """Return BCE(real -> 1) + BCE(generated -> 0)."""
        real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
        return real_loss + generated_loss

    def generator_loss(disc_generated_output, gen_output, target):
        """Return BCE(generated -> 1) + LAMBDA * mean |target - gen_output|."""
        gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
        return gan_loss + (LAMBDA * l1_loss)

    @tf.function
    def train_step(input_image, target):
        """Run one simultaneous generator + discriminator update; return (disc_loss, gen_loss)."""
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            gen_output = generator(input_image, training=True)
            disc_real_output = discriminator([input_image, target], training=True)
            disc_generated_output = discriminator([input_image, gen_output], training=True)
            gen_loss = generator_loss(disc_generated_output, gen_output, target)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        generator_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
        return disc_loss, gen_loss

    def generate_and_save_images(model, test_input, tar, epoch):
        """Plot input / ground truth / prediction for batch item 0 and save the PNG to Drive."""
        # training=True is kept as in the original. The TF pix2pix tutorial does this
        # deliberately to use batch statistics; it would also keep any dropout active.
        prediction = model(test_input, training=True)
        display_list = [test_input[0], tar[0], prediction[0]]
        titles = ['Input Label Map', 'Ground Truth', 'Generated Image']

        fig = plt.figure(figsize=(15, 5))
        try:
            for position, (image, title) in enumerate(zip(display_list, titles), start=1):
                plt.subplot(1, 3, position)
                plt.title(title)
                # Undo a presumed [-1, 1] normalization for viewing.
                plt.imshow(image * 0.5 + 0.5)
                plt.axis('off')

            save_path = os.path.join(DRIVE_IMAGE_DIR, f'image_at_epoch_{epoch+1:04d}.png')
            fig.savefig(save_path)
        finally:
            # Release the figure even if savefig fails (e.g. a Drive I/O error), so
            # figures don't accumulate across epochs.
            plt.close(fig)

    # --- The Main Training Loop ---
    start_time_total = time.time()

    for epoch in range(epochs):
        epoch_num = epoch + 1
        start_time_epoch = time.time()
        print(f"--- Starting Epoch {epoch_num}/{epochs} ---")

        disc_loss_epoch, gen_loss_epoch = [], []
        for input_image, target in tqdm(dataset, desc="  Training..."):
            disc_loss, gen_loss = train_step(input_image, target)
            disc_loss_epoch.append(disc_loss)
            gen_loss_epoch.append(gen_loss)

        generate_and_save_images(generator, example_input, example_target, epoch)
        print(f"✅ Sample image for epoch {epoch_num} saved to Drive.")

        # Periodic checkpoint. The last epoch is excluded because the final save
        # below writes identical weights; otherwise, whenever `epochs` is a multiple
        # of `save_every`, the last epoch would be saved to Drive twice.
        if save_every > 0 and epoch_num % save_every == 0 and epoch_num < epochs:
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            _save_models(generator, discriminator, f'{timestamp}_epoch{epoch_num}')
            print(f"✅ Saved models to Drive for epoch {epoch_num}.")

        print(f'Time for epoch {epoch_num} is {time.time()-start_time_epoch:.2f} sec')
        avg_disc_loss = float(tf.reduce_mean(disc_loss_epoch))
        avg_gen_loss = float(tf.reduce_mean(gen_loss_epoch))
        print(f'  -> Avg Discriminator Loss: {avg_disc_loss:.4f}, Avg Generator Loss: {avg_gen_loss:.4f}')

    print(f'Total training time: {time.time() - start_time_total:.2f} sec')

    # --- Final Save ---
    print("--- Training finished. Saving final models. ---")
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    _save_models(generator, discriminator, f'final_{timestamp}_epoch{epochs}')
    print(f"✅ Final models saved to: {DRIVE_CHECKPOINT_DIR}")