In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from google.colab import drive
from PIL import Image
import time

# ========== SETUP ========== #
# Mount Google Drive
# Any existing mount is flushed and released first, then Drive is mounted
# again with force_remount=True (presumably to guarantee a clean mount when
# the cell is re-run -- confirm).
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)

# Configuration
# NOTE(review): the "dfu " path component below has a space before the slash;
# confirm that this matches the real folder name on Drive.
dataset_path = "/content/drive/MyDrive/dfu /PartB_DFU_Dataset/Infection/Aug-Negative"  # UPDATE THIS
save_path = "/content/drive/MyDrive/ViTGAN_outputs"         # UPDATE THIS
os.makedirs(save_path, exist_ok=True)  # create the output folder if it does not exist yet

# Hyperparameters
img_size = 64        # height/width (pixels) that training images are resized to
channels = 3         # channels in the generator output / discriminator input
latent_dim = 128     # length of the noise vector fed to the generator
batch_size = 64      # images per batch in the data pipeline and per noise draw
epochs = 500         # number of passes over the dataset
save_interval = 25   # sample images are written every this many epochs

# ========== DATA PIPELINE ========== #
def load_dataset(dataset_path, img_size):
    """Build a shuffled, batched ``tf.data`` pipeline of normalised images.

    Every file directly inside ``dataset_path`` whose name ends in ``.jpg``,
    ``.jpeg`` or ``.png`` (in any letter case) is decoded to 3 channels,
    resized to ``img_size`` x ``img_size`` and scaled to ``[-1, 1]``.

    Args:
        dataset_path: Directory that contains the training images.
        img_size: Target height and width in pixels.

    Returns:
        A ``tf.data.Dataset`` yielding image batches of up to ``batch_size``
        images, where ``batch_size`` is the module-level setting.

    Raises:
        ValueError: If ``dataset_path`` contains no matching image files.
    """
    # Compare extensions case-insensitively so files such as "IMG_01.JPG" are
    # not silently skipped; sort so the file list is deterministic.
    image_paths = [
        os.path.join(dataset_path, f)
        for f in sorted(os.listdir(dataset_path))
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    if not image_paths:
        # Fail early with a clear message instead of an opaque dtype error
        # from the pipeline (or a division by zero in the training loop).
        raise ValueError(
            f"No .jpg/.jpeg/.png images found in {dataset_path!r}"
        )

    def preprocess_image(img_path):
        img = tf.io.read_file(img_path)
        # decode_jpeg is documented to accept PNG data as well, so one op
        # covers every extension accepted above.
        img = tf.image.decode_jpeg(img, channels=3)
        img = tf.image.resize(img, [img_size, img_size])
        img = (img - 127.5) / 127.5  # Normalize to [-1, 1]
        return img

    dataset = tf.data.Dataset.from_tensor_slices(image_paths)
    # Shuffle the (cheap) path strings before decoding, with a buffer that
    # covers the whole list so each epoch sees a full random permutation.
    dataset = dataset.shuffle(len(image_paths))
    dataset = dataset.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# Build the input pipeline once; train() iterates over it on every epoch.
dataset = load_dataset(dataset_path, img_size)

# ========== MODEL ARCHITECTURE ========== #
def build_generator(latent_dim):
    """Build the generator that maps a latent vector to an image.

    The latent vector is projected to an 8x8x256 feature map and then
    upsampled by three stride-2 transposed convolutions (8 -> 16 -> 32 -> 64).
    The last one has ``channels`` filters and a tanh activation, so pixel
    values lie in [-1, 1], the same range the data pipeline normalises to.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``keras.Model`` named ``"generator"`` whose output has spatial size
        ``img_size`` x ``img_size`` and ``channels`` channels (both are
        module-level settings).
    """
    inputs = layers.Input(shape=(latent_dim,))

    # Foundation
    x = layers.Dense(8*8*256)(inputs)
    x = layers.Reshape((8, 8, 256))(x)

    # Upsampling blocks: each stride-2 transposed conv doubles height/width.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, 4, 2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    # Output: the final layer also upsamples (32 -> 64), so the image is
    # synthesised at full resolution instead of being produced at 32x32 and
    # stretched by interpolation afterwards.
    x = layers.Conv2DTranspose(channels, 4, 2, padding='same', activation='tanh')(x)

    # Fallback so the output still matches the discriminator input when
    # img_size is set to something other than the native 64 pixels.
    native_size = 64
    if img_size != native_size:
        x = layers.Resizing(img_size, img_size)(x)

    return keras.Model(inputs, x, name="generator")

def build_discriminator():
    """Build the discriminator that scores an image as real or generated.

    Two stride-2 convolutions (64 then 128 filters), each followed by a
    LeakyReLU, reduce the image; the result is flattened and mapped to a
    single unit with no activation, i.e. a raw logit.

    Returns:
        A ``keras.Model`` named ``"discriminator"`` taking images of shape
        ``(img_size, img_size, channels)`` (module-level settings).
    """
    image_in = layers.Input(shape=(img_size, img_size, channels))

    # Downsampling stack
    features = image_in
    for n_filters in (64, 128):
        features = layers.Conv2D(n_filters, 4, 2, padding='same')(features)
        features = layers.LeakyReLU(0.2)(features)

    # Classification head
    flat = layers.Flatten()(features)
    logit = layers.Dense(1)(flat)

    return keras.Model(image_in, logit, name="discriminator")

# ========== TRAINING SETUP ========== #
# Initialize models
generator = build_generator(latent_dim)
discriminator = build_discriminator()

# Optimizers
# Each network gets its own Adam instance (same settings: 0.0002, beta_1=0.5),
# so the generator and discriminator updates are applied independently.
g_optimizer = Adam(0.0002, beta_1=0.5)
d_optimizer = Adam(0.0002, beta_1=0.5)

# Loss function
# from_logits=True because the discriminator ends in Dense(1) with no
# activation, so its outputs are raw logits rather than probabilities.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

# ========== TRAINING LOOP ========== #
@tf.function
def train_step(real_images):
    """Run one discriminator update followed by one generator update.

    The discriminator is trained to label ``real_images`` as 1 and freshly
    generated images as 0; the generator is then trained, on a new noise
    draw, to make the discriminator label its images as 1.

    Args:
        real_images: Batch of real training images; the first axis is the
            batch axis.

    Returns:
        Tuple ``(d_loss, g_loss)`` with the discriminator and generator
        losses for this step.
    """
    # Size the fake batch from the incoming batch rather than the global
    # batch_size, so the final (smaller) batch of an epoch is paired with an
    # equally sized fake batch and the step works for any batch size.
    n_samples = tf.shape(real_images)[0]

    # Train discriminator
    noise = tf.random.normal([n_samples, latent_dim])
    with tf.GradientTape() as d_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)
        # Average of the "real -> 1" and "fake -> 0" cross-entropy terms.
        d_loss = (cross_entropy(tf.ones_like(real_output), real_output) +
                 cross_entropy(tf.zeros_like(fake_output), fake_output)) / 2

    # Only the discriminator's variables are updated here.
    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))

    # Train generator
    noise = tf.random.normal([n_samples, latent_dim])
    with tf.GradientTape() as g_tape:
        generated_images = generator(noise, training=True)
        fake_output = discriminator(generated_images, training=True)
        # Generated images are scored against the "real" label (ones).
        g_loss = cross_entropy(tf.ones_like(fake_output), fake_output)

    # Only the generator's variables are updated here.
    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))

    return d_loss, g_loss

def save_images(epoch):
    """Generate 25 samples and write them to ``save_path``.

    Writes one 5x5 overview figure (``grid_<epoch>.png``) and each sample as
    its own file (``img_<epoch>_<index>.png``).

    Args:
        epoch: Epoch number used, zero-padded, in the output file names.
    """
    os.makedirs(save_path, exist_ok=True)

    latent_batch = tf.random.normal([25, latent_dim])
    samples = generator(latent_batch, training=False)
    # Undo the [-1, 1] normalisation and convert to 8-bit pixel values.
    samples = samples * 127.5 + 127.5
    samples = samples.numpy().astype('uint8')

    # Overview grid: 5 rows x 5 columns, axes hidden
    plt.figure(figsize=(10,10))
    for cell in range(1, 26):
        plt.subplot(5, 5, cell)
        plt.imshow(samples[cell - 1])
        plt.axis('off')
    plt.savefig(f"{save_path}/grid_{epoch:04d}.png")
    plt.close()

    # One PNG per generated sample
    for idx in range(len(samples)):
        single = Image.fromarray(samples[idx])
        single.save(f"{save_path}/img_{epoch:04d}_{idx:02d}.png")

    print(f"✅ Saved outputs to {save_path}")

def train(dataset, epochs, save_interval):
    """Run the training loop and report the mean losses after every epoch.

    For each epoch, ``train_step`` is called on every batch of ``dataset``.
    Sample images are written with ``save_images`` on every
    ``save_interval``-th epoch (starting with the first) and on the last
    epoch.

    Args:
        dataset: Iterable of image batches; iterated once per epoch.
        epochs: Number of passes over ``dataset``.
        save_interval: Number of epochs between calls to ``save_images``.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    for epoch in range(epochs):
        start = time.time()

        total_d_loss = 0
        total_g_loss = 0
        num_batches = 0

        for batch in dataset:
            d_loss, g_loss = train_step(batch)
            total_d_loss += d_loss
            total_g_loss += g_loss
            num_batches += 1

        # Without any batch nothing was trained and the averages below would
        # divide by zero; stop with a clear message before writing samples.
        if num_batches == 0:
            raise ValueError(
                "dataset yielded no batches; check the dataset path and contents"
            )

        # Save progress
        if epoch % save_interval == 0 or epoch == epochs - 1:
            save_images(epoch)

        print(f"⏱️ Epoch {epoch+1}/{epochs} | "
              f"Disc Loss: {total_d_loss/num_batches:.4f} | "
              f"Gen Loss: {total_g_loss/num_batches:.4f} | "
              f"Time: {time.time()-start:.2f}s")

# ========== START TRAINING ========== #
print("🚀 Training started - monitoring progress...")
train(dataset, epochs, save_interval)

# Final save
# Only the generator is written to disk here; the discriminator is not saved.
# NOTE(review): the ".h5" suffix presumably selects Keras' HDF5 format -- confirm.
generator.save(f"{save_path}/final_generator.h5")
print(f"🎉 Training complete! Models saved to {save_path}")

Mounted at /content/drive
🚀 Training started - monitoring progress...
✅ Saved outputs to /content/drive/MyDrive/ViTGAN_outputs
⏱️ Epoch 1/500 | Disc Loss: 0.5914 | Gen Loss: 1.3614 | Time: 95.43s
⏱️ Epoch 2/500 | Disc Loss: 0.3521 | Gen Loss: 2.2827 | Time: 10.56s
⏱️ Epoch 3/500 | Disc Loss: 0.6712 | Gen Loss: 0.7224 | Time: 8.23s
⏱️ Epoch 4/500 | Disc Loss: 0.6246 | Gen Loss: 0.8602 | Time: 9.22s
⏱️ Epoch 5/500 | Disc Loss: 0.7118 | Gen Loss: 0.7818 | Time: 9.42s
⏱️ Epoch 6/500 | Disc Loss: 0.7035 | Gen Loss: 0.8148 | Time: 8.53s
⏱️ Epoch 7/500 | Disc Loss: 0.7060 | Gen Loss: 0.7119 | Time: 9.22s
⏱️ Epoch 8/500 | Disc Loss: 0.6978 | Gen Loss: 0.7704 | Time: 9.42s
⏱️ Epoch 9/500 | Disc Loss: 0.6950 | Gen Loss: 0.7112 | Time: 8.70s
⏱️ Epoch 10/500 | Disc Loss: 0.6931 | Gen Loss: 0.7573 | Time: 8.92s
⏱️ Epoch 11/500 | Disc Loss: 0.6961 | Gen Loss: 0.6995 | Time: 9.59s
⏱️ Epoch 12/500 | Disc Loss: 0.6928 | Gen Loss: 0.7218 | Time: 9.17s
⏱️ Epoch 13/500 | Disc Loss: 0.6953 | Gen Loss: 0.71