In [1]:
import os, sys, json
import tensorflow as tf
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from pathlib import Path
import time
from tensorflow.keras.models import load_model
from PIL import Image

In [2]:
# Make the sibling project folders (one level up) importable from this notebook.
sys.path.extend(
    os.path.abspath(os.path.join('..', folder))
    for folder in ('data_processing', 'models')
)

In [3]:
# Fix TensorFlow's global random seed so the tf.random draws used below are repeatable.
tf.random.set_seed(42)

In [4]:
from gan_preprocessing import dataset 

In [5]:
from gan2 import create_gan2
from gan3 import create_gan3

Reference ranges for the training metrics (D = discriminator, G = generator):

- D accuracy (real): 0.6-0.9
- D accuracy (fake): 0.6-0.9
- D loss: 0.3-0.7
- G loss: 0.7-2.0

In [6]:
def train_gan(
        generator,
        discriminator,
        gan,
        dataset,
        epochs=100,
        batch_size=32,
        latent_dim=120,
        save_dir='gan_training',
        starting_epoch=0,
        checkpoint_interval=10,  # Save models every `checkpoint_interval` epochs
        sample_interval=5        # Generate sample images every `sample_interval` epochs
):
    """Train a GAN by alternating discriminator and generator updates.

    For every batch the discriminator is trained once on the real batch and
    once on freshly generated images (with smoothed labels), then the
    generator is trained through ``gan`` on twice as many noise vectors.

    Args:
        generator: Model called as ``generator(noise, training=...)`` and
            saved with ``.save(...)`` at checkpoints.
        discriminator: Model exposing ``train_on_batch``; element ``[0]`` of
            its result is used as the loss (assumes it was compiled with at
            least one metric so a sequence is returned -- TODO confirm).
        gan: Combined model; ``gan.train_on_batch(noise, ones)`` must return
            a value convertible with ``float()``.
        dataset: Iterable of real-image batches, iterated once per epoch.
            The leading dimension of each batch is used as its size.
        epochs: Number of epochs to run in this call.
        batch_size: Only recorded in the history; the actual size is read
            from each batch.
        latent_dim: Length of the noise vectors fed to the generator.
        save_dir: Root output directory (``models/``, ``samples/`` and
            ``training_history.json`` are created inside it).
        starting_epoch: Number of epochs already completed (used to continue
            epoch numbering when resuming).
        checkpoint_interval: Save models and epoch statistics every this
            many epochs.
        sample_interval: Save a 5x5 grid of generated samples every this
            many epochs.

    Returns:
        dict: ``{'training_config': {...}, 'epochs': [...]}`` where
        ``'epochs'`` holds one statistics entry per saved checkpoint.
    """
    save_dir = Path(save_dir)
    models_dir = save_dir / 'models'
    samples_dir = save_dir / 'samples'
    history_path = save_dir / 'training_history.json'

    for dir_path in [save_dir, models_dir, samples_dir]:
        dir_path.mkdir(parents=True, exist_ok=True)

    history = {
        'training_config': {
            'batch_size': int(batch_size),
            'latent_dim': int(latent_dim),
            'starting_epoch': int(starting_epoch),
            'total_epochs': int(epochs),
            'checkpoint_interval': int(checkpoint_interval),
            'sample_interval': int(sample_interval),
            'start_time': datetime.now().isoformat()
        },
        'epochs': []
    }

    def write_history():
        """Persist the current history dict to disk."""
        with open(history_path, 'w') as f:
            json.dump(history, f, indent=4)

    def summarize(losses):
        """Return (mean, std) of the losses, or (None, None) when empty."""
        # An epoch without batches has no statistics; store nulls rather than
        # NaN so the JSON file stays standards-compliant.
        if not losses:
            return None, None
        return float(np.mean(losses)), float(np.std(losses))

    def save_images(epoch):
        """Save a 5x5 grid of generated samples for the given epoch."""
        noise = tf.random.normal([25, latent_dim])
        generated = generator(noise, training=False)
        # Assumes generator output lies in [-1, 1]; map it to [0, 1] for imshow.
        generated = (generated + 1) / 2.0

        plt.figure(figsize=(10, 10))
        for i in range(25):
            plt.subplot(5, 5, i + 1)
            plt.imshow(generated[i])
            plt.axis('off')
        plt.savefig(samples_dir / f'epoch_{epoch}.png')
        plt.close()

    def save_checkpoint(epoch, d_losses, g_losses):
        """Record epoch statistics, save both models and rewrite the history."""
        mean_d, std_d = summarize(d_losses)
        mean_g, std_g = summarize(g_losses)
        history['epochs'].append({
            'epoch_number': int(epoch),
            'epoch_completed': datetime.now().isoformat(),
            'mean_d_loss': mean_d,
            'mean_g_loss': mean_g,
            'std_d_loss': std_d,
            'std_g_loss': std_g
        })

        # Save current state
        checkpoint_dir = models_dir / f'checkpoint_epoch_{epoch}'
        checkpoint_dir.mkdir(exist_ok=True)

        # Save full models
        generator.save(checkpoint_dir / 'generator.h5', include_optimizer=True)
        discriminator.save(checkpoint_dir / 'discriminator.h5', include_optimizer=True)

        # Save training history
        write_history()

    print("Starting training...")
    start_time = time.time()

    final_epoch = starting_epoch + epochs
    last_checkpoint_epoch = None
    # Bound up front so the final checkpoint logic is safe even if no epoch runs.
    epoch_d_losses = []
    epoch_g_losses = []

    for epoch in range(starting_epoch, final_epoch):
        epoch_number = epoch + 1
        print(f"\nEpoch {epoch_number}")
        epoch_d_losses = []
        epoch_g_losses = []
        batch_count = 0

        for batch in dataset:
            # The last batch may be smaller, so read the size from the data
            # (kept separate from the `batch_size` argument on purpose).
            current_batch_size = tf.shape(batch)[0]

            # Train discriminator
            noise = tf.random.normal([current_batch_size, latent_dim])
            generated_images = generator(noise, training=True)

            # Smoothed labels: real in [0.8, 1.0), fake in [0.0, 0.2).
            real_labels = tf.random.uniform([current_batch_size, 1], 0.8, 1.0)
            fake_labels = tf.random.uniform([current_batch_size, 1], 0.0, 0.2)

            d_loss_real = discriminator.train_on_batch(batch, real_labels)
            loss_real = d_loss_real[0]

            # Train discriminator on fake images
            d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
            loss_fake = d_loss_fake[0]

            # Compute the average loss
            d_loss = 0.5 * (loss_real + loss_fake)

            # Train generator on twice as many samples, labelled as real.
            noise = tf.random.normal([current_batch_size * 2, latent_dim])
            g_loss = gan.train_on_batch(noise, tf.ones([current_batch_size * 2, 1]))

            epoch_d_losses.append(float(d_loss))
            epoch_g_losses.append(float(g_loss))

            if batch_count % 50 == 0:
                print(f"Batch {batch_count}: d_loss={d_loss:.4f}, g_loss={g_loss:.4f}")
            batch_count += 1

        # Save samples every sample_interval epochs
        if epoch_number % sample_interval == 0:
            print(f"Generating sample images at epoch {epoch_number}")
            save_images(epoch_number)

        # Save checkpoint every checkpoint_interval epochs
        if epoch_number % checkpoint_interval == 0:
            print(f"Saving checkpoint at epoch {epoch_number}")
            save_checkpoint(epoch_number, epoch_d_losses, epoch_g_losses)
            last_checkpoint_epoch = epoch_number
        else:
            # Keep the history file on disk current even between checkpoints.
            write_history()

    if final_epoch > starting_epoch:
        # Save a final checkpoint unless the last epoch was already saved by
        # the interval rule (avoids a duplicate history entry and re-save).
        if last_checkpoint_epoch != final_epoch:
            save_checkpoint(final_epoch, epoch_d_losses, epoch_g_losses)
    else:
        # No epoch ran: nothing to checkpoint, but still leave a history file.
        write_history()

    print(f"\nTraining completed in {time.time() - start_time:.1f} seconds")
    return history

In [8]:
def load_gan_checkpoint(checkpoint_dir, latent_dim=120, learning_rate=2e-8):
    """Load a saved generator/discriminator pair and rebuild the combined GAN.

    Both models are loaded with ``compile=False`` (any optimizer state stored
    in the files is therefore not restored) and recompiled with fresh Adam
    optimizers.

    Args:
        checkpoint_dir: Directory containing ``generator.h5`` and
            ``discriminator.h5``.
        latent_dim: Length of the noise vector used as the GAN input.
        learning_rate: Adam learning rate for both optimizers.
            NOTE(review): the default 2e-8 is kept from the original code but
            is unusually small -- confirm it matches the settings used when
            the models were first created.

    Returns:
        tuple: ``(generator, discriminator, gan)``.
    """
    checkpoint_dir = Path(checkpoint_dir)

    # Load full models
    generator = tf.keras.models.load_model(checkpoint_dir / 'generator.h5', compile=False)
    discriminator = tf.keras.models.load_model(checkpoint_dir / 'discriminator.h5', compile=False)

    # Compile the discriminator while it is trainable. Keras keeps the
    # `trainable` state that was in effect at compile() time, so freezing it
    # first would leave the standalone discriminator unable to learn. The
    # flag is set explicitly because the saved model may carry a frozen state.
    discriminator.trainable = True
    discriminator.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate, beta_1=0.5, clipvalue=1.0),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # Recreate GAN: freeze the discriminator only inside the combined model
    # so generator updates do not change discriminator weights.
    discriminator.trainable = False
    gan_input = tf.keras.Input(shape=(latent_dim,))
    gan_output = discriminator(generator(gan_input))
    gan = tf.keras.models.Model(gan_input, gan_output)

    gan.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate, beta_1=0.5, clipvalue=1.0),
        loss='binary_crossentropy'
    )

    return generator, discriminator, gan

In [10]:
def generate_and_save_images(generator_path, save_dir, num_images=50, latent_dim=100):
    """Generate images with a saved generator and write them as PNG files.

    Args:
        generator_path: Path of the saved generator model.
        save_dir: Output directory (created if missing).
        num_images: Number of images to generate; values <= 0 produce none.
        latent_dim: Length of the noise vectors. If the loaded model reports
            a different input size, the model's size is used instead and a
            notice is printed.

    Output files are named ``image_<index>_<timestamp>.png``. Generator
    output is assumed to lie in [-1, 1]; values outside that range are
    clipped before the conversion to 8-bit pixels.
    """
    # Inference only: skip restoring the optimizer / training configuration.
    generator = load_model(generator_path, compile=False)
    print(f"Loaded generator model from {generator_path}")

    # Trust the latent size reported by the model over the argument, since a
    # mismatch would otherwise fail inside the model with a shape error.
    model_input_shape = getattr(generator, 'input_shape', None)
    if (
        isinstance(model_input_shape, tuple)
        and len(model_input_shape) == 2
        and isinstance(model_input_shape[-1], int)
        and model_input_shape[-1] != latent_dim
    ):
        print(
            f"latent_dim={latent_dim} does not match the generator input size "
            f"{model_input_shape[-1]}; using {model_input_shape[-1]}."
        )
        latent_dim = model_input_shape[-1]

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"Images will be saved to {save_dir.resolve()}")

    batch_size = 32
    image_count = 0

    def save_image(pixels, sample_number):
        """Write one uint8 pixel array to a uniquely named PNG file."""
        img = Image.fromarray(pixels)

        # Timestamp down to milliseconds, combined with the running index.
        current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        img_filename = save_dir / f"image_{sample_number:05d}_{current_datetime}.png"

        img.save(img_filename)

    # One loop covers both the full batches and the final partial batch.
    remaining = num_images
    while remaining > 0:
        current_batch = min(batch_size, remaining)
        noise = tf.random.normal([current_batch, latent_dim])
        generated_images = generator(noise, training=False)

        # Scale from [-1, 1] to [0, 255]; clip so out-of-range values cannot
        # wrap around in the uint8 cast.
        pixels = np.clip(np.asarray(generated_images) * 127.5 + 127.5, 0, 255).astype('uint8')
        # PIL does not accept (H, W, 1) arrays; drop a single grayscale channel.
        if pixels.ndim == 4 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]

        for image_pixels in pixels:
            image_count += 1
            save_image(image_pixels, image_count)

            if image_count % 100 == 0:
                print(f"{image_count} images generated and saved.")

        remaining -= current_batch

    print(f"Image generation complete. {image_count} images saved to {save_dir.resolve()}")
