In [None]:
import os, sys, json
import tensorflow as tf
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from pathlib import Path
import time
from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model
from PIL import Image

In [None]:
# Make the sibling project folders importable from this notebook
# (order matters: data_processing is appended before models).
sys.path.extend(
    os.path.abspath(os.path.join('..', folder_name))
    for folder_name in ('data_processing', 'models')
)

In [None]:
from cgan_preprocessing import dataset


In [None]:
from cgan import create_cgan
from cgan2 import create_cgan2

In [None]:
from cgan3 import create_cgan3

## cGAN (64x64)

In [None]:
def train_cgan(
        generator,
        discriminator,
        cgan,
        dataset,
        epochs=100,
        batch_size=32,
        latent_dim=100,
        save_dir='cgan_training',
        starting_epoch=0,
        checkpoint_interval=10,
        sample_interval=5
):
    """Train (or resume training) a conditional GAN and persist its progress.

    For every batch the discriminator is updated twice (real images, then
    generated images) against noisy soft targets, after which the generator is
    updated once through the combined ``cgan`` model on a doubled batch.

    Args:
        generator: Model called as ``generator([noise, labels], training=...)``.
        discriminator: Compiled model trained with
            ``train_on_batch([images, labels], targets)``.
        cgan: Compiled combined model trained with
            ``train_on_batch([noise, labels], targets)``.
        dataset: Iterable yielding ``(batch_images, batch_labels)`` pairs. It is
            iterated once per epoch.
        epochs: Number of epochs to run in this call.
        batch_size: Only recorded in the history's ``training_config``; the
            actual batch size is read from every batch.
        latent_dim: Length of the noise vector fed to the generator.
        save_dir: Output root. ``models/``, ``samples/`` and
            ``training_history.json`` are created below it.
        starting_epoch: Number of epochs already completed. When it is > 0 and
            a history file exists, that history is loaded and extended.
        checkpoint_interval: Save both models every this many epochs
            (0 disables periodic checkpoints).
        sample_interval: Save a sample grid every this many epochs
            (0 disables sample grids).

    Returns:
        dict: The history, with a ``training_config`` entry and an ``epochs``
        list holding one statistics record per trained epoch.
    """
    save_dir = Path(save_dir)
    models_dir = save_dir / 'models'
    samples_dir = save_dir / 'samples'

    # Create directories if they dont exist
    for dir_path in (save_dir, models_dir, samples_dir):
        dir_path.mkdir(parents=True, exist_ok=True)

    # Load existing history if resuming training
    history_file = save_dir / 'training_history.json'
    if starting_epoch > 0 and history_file.exists():
        with open(history_file, 'r') as f:
            history = json.load(f)
        print(f"Loaded existing training history from epoch {starting_epoch}")
    else:
        history = {
            'training_config': {
                'batch_size': int(batch_size),
                'latent_dim': int(latent_dim),
                'starting_epoch': int(starting_epoch),
                'total_epochs': int(epochs),
                'checkpoint_interval': int(checkpoint_interval),
                'sample_interval': int(sample_interval),
                'start_time': datetime.now().isoformat()
            },
            'epochs': []
        }

    def scalar_loss(result):
        """Return the loss from a train_on_batch result.

        train_on_batch yields a bare scalar when the model has no metrics and
        a ``[loss, metric, ...]`` sequence otherwise; the loss is always first.
        """
        return float(np.ravel(result)[0])

    def write_history():
        """Write the in-memory history to ``training_history.json``."""
        with open(history_file, 'w') as f:
            json.dump(history, f, indent=4)

    def record_epoch(epoch, d_losses, g_losses):
        """Append the loss statistics of one finished epoch to the history."""
        history.setdefault('epochs', []).append({
            'epoch_number': int(epoch),
            'epoch_completed': datetime.now().isoformat(),
            'mean_d_loss': float(np.mean(d_losses)),
            'mean_g_loss': float(np.mean(g_losses)),
            'std_d_loss': float(np.std(d_losses)),
            'std_g_loss': float(np.std(g_losses))
        })

    def save_images(epoch):
        """Save a grid of generated samples: one row per condition.

        Best effort by design: a plotting failure is reported but must not
        abort a long training run.
        """
        fig = None
        try:
            print(f"Attempting to generate images for epoch {epoch}")
            print(f"Saving to directory: {samples_dir}")

            # Generate images for different conditions
            conditions = [
                [0, 0],  # Female, Not Smiling
                [0, 1],  # Female, Smiling
                [1, 0],  # Male, Not Smiling
                [1, 1]   # Male, Smiling
            ]
            rows = len(conditions)  # One row for each condition
            cols = 4                # Number of samples per condition

            fig, axes = plt.subplots(rows, cols, figsize=(12, 12), squeeze=False)

            for condition_idx, condition in enumerate(conditions):
                noise = tf.random.normal([cols, latent_dim])
                # float32 matches the condition tensors used at generation time.
                condition_batch = tf.tile(
                    tf.constant([condition], dtype=tf.float32), [cols, 1]
                )

                generated = generator([noise, condition_batch], training=False)
                generated = (generated + 1) / 2.0  # Normalize to [0,1]

                condition_label = f"{'Male' if condition[0] else 'Female'}, {'Smiling' if condition[1] else 'Not Smiling'}"

                for col in range(cols):
                    ax = axes[condition_idx][col]
                    ax.imshow(generated[col])
                    # Hide ticks and frame instead of turning the axis off:
                    # axis('off') would also hide the row label set below.
                    ax.set_xticks([])
                    ax.set_yticks([])
                    for spine in ax.spines.values():
                        spine.set_visible(False)
                    if col == 0:
                        ax.set_ylabel(condition_label, fontsize=8)

            save_path = samples_dir / f'epoch_{epoch}.png'
            print(f"Attempting to save figure to: {save_path}")
            fig.savefig(save_path)
            print(f"Successfully saved figure to: {save_path}")

            if save_path.exists():
                print(f"Verified: File exists at {save_path}")
                print(f"File size: {save_path.stat().st_size} bytes")
            else:
                print(f"Warning: File was not found at {save_path} after saving")

        except Exception as e:
            print(f"Error in save_images: {str(e)}")
            print(f"Error type: {type(e)}")
            import traceback
            print(f"Traceback: {traceback.format_exc()}")
        finally:
            # Always release the figure, even when plotting failed half-way.
            if fig is not None:
                plt.close(fig)

    def save_checkpoint(epoch):
        """Save both models for ``epoch`` and flush the history to disk."""
        checkpoint_dir = models_dir / f'checkpoint_epoch_{epoch}'
        checkpoint_dir.mkdir(exist_ok=True)

        generator.save(checkpoint_dir / 'generator.h5', include_optimizer=True)
        discriminator.save(checkpoint_dir / 'discriminator.h5', include_optimizer=True)

        write_history()

    print(f"Starting/Resuming training from epoch {starting_epoch + 1}...")
    start_time = time.time()

    final_epoch = starting_epoch + epochs
    last_checkpoint_epoch = None

    for epoch in range(starting_epoch, final_epoch):
        absolute_epoch = epoch + 1
        print(f"\nEpoch {absolute_epoch}")
        epoch_d_losses = []
        epoch_g_losses = []

        for batch_count, (batch_images, batch_labels) in enumerate(dataset):
            # The last batch may be smaller, so read the size from the data
            # (kept separate from the `batch_size` argument on purpose).
            current_batch_size = tf.shape(batch_images)[0]

            # Train discriminator
            noise = tf.random.normal([current_batch_size, latent_dim])
            generated_images = generator([noise, batch_labels], training=True)

            # Add label noise for better training stability
            real_labels = tf.random.uniform([current_batch_size, 1], 0.8, 1.0)
            fake_labels = tf.random.uniform([current_batch_size, 1], 0.0, 0.2)

            # Train discriminator on real images
            loss_real = scalar_loss(discriminator.train_on_batch(
                [batch_images, batch_labels],
                real_labels
            ))

            # Train discriminator on fake images
            loss_fake = scalar_loss(discriminator.train_on_batch(
                [generated_images, batch_labels],
                fake_labels
            ))

            d_loss = 0.5 * (loss_real + loss_fake)

            # Train generator on a doubled batch (labels repeated twice)
            noise = tf.random.normal([current_batch_size * 2, latent_dim])
            g_labels = tf.tile(batch_labels, [2, 1])
            g_loss = scalar_loss(cgan.train_on_batch(
                [noise, g_labels],
                tf.ones([current_batch_size * 2, 1])
            ))

            epoch_d_losses.append(d_loss)
            epoch_g_losses.append(g_loss)

            if batch_count % 50 == 0:
                print(f"Batch {batch_count}: d_loss={d_loss:.4f}, g_loss={g_loss:.4f}")

        # Record statistics for every epoch, not only for checkpoint epochs.
        record_epoch(absolute_epoch, epoch_d_losses, epoch_g_losses)

        # Save samples every sample_interval epochs
        if sample_interval and absolute_epoch % sample_interval == 0:
            print(f"Generating sample images at epoch {absolute_epoch}")
            save_images(absolute_epoch)

        # Save checkpoint every checkpoint_interval epochs
        if checkpoint_interval and absolute_epoch % checkpoint_interval == 0:
            print(f"Saving checkpoint at epoch {absolute_epoch}")
            save_checkpoint(absolute_epoch)
            last_checkpoint_epoch = absolute_epoch

        # Always update history
        write_history()

    # Save a final checkpoint only if training ran and the last epoch was not
    # already checkpointed (avoids a duplicate save and duplicate history row).
    if epochs > 0 and last_checkpoint_epoch != final_epoch:
        print(f"Saving checkpoint at epoch {final_epoch}")
        save_checkpoint(final_epoch)

    print(f"\nTraining completed in {time.time() - start_time:.1f} seconds")
    return history

In [None]:
def load_gan_checkpoint(checkpoint_dir, latent_dim=100, learning_rate=2e-9):
    """Rebuild a trainable cGAN triple from a saved checkpoint directory.

    Loads ``generator.h5`` and ``discriminator.h5`` without their saved
    compile/optimizer state, creates fresh Adam optimizers, and re-creates the
    combined generator -> discriminator model.

    Args:
        checkpoint_dir: Directory containing ``generator.h5`` and
            ``discriminator.h5``.
        latent_dim: Length of the generator's noise input.
        learning_rate: Adam learning rate used for both optimizers.
            NOTE(review): the default 2e-9 is kept for backward compatibility
            but is extremely small - confirm it is intentional.

    Returns:
        tuple: ``(generator, discriminator, cgan)`` where ``discriminator`` and
        ``cgan`` are compiled and ``generator`` is left uncompiled.
    """
    checkpoint_dir = Path(checkpoint_dir)

    # Load full models (compile=False: optimizer state from the file is dropped)
    generator = tf.keras.models.load_model(checkpoint_dir / 'generator.h5', compile=False)
    discriminator = tf.keras.models.load_model(checkpoint_dir / 'discriminator.h5', compile=False)

    optimizer_g = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)
    optimizer_d = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5, beta_2=0.999)

    # Compile the discriminator while it is trainable. compile() captures the
    # trainable state, so freezing it first would leave the standalone
    # discriminator with nothing to update during train_on_batch.
    discriminator.trainable = True
    discriminator.compile(
        optimizer=optimizer_d,
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False, label_smoothing=0.1),
        metrics=['accuracy']
    )

    # Recreate GAN: freeze the discriminator only inside the combined model so
    # that generator updates do not change discriminator weights.
    discriminator.trainable = False

    noise_input = layers.Input(shape=(latent_dim,))
    label_input = layers.Input(shape=(2,))

    generated_images = generator([noise_input, label_input])
    validity = discriminator([generated_images, label_input])

    cgan = models.Model([noise_input, label_input], validity)

    cgan.compile(
        optimizer=optimizer_g,
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=False)
    )

    return generator, discriminator, cgan

In [None]:
    # To continue training:
# Load the last checkpoint
last_epoch = 150  # or whatever your last epoch was
checkpoint_path = f'../results/cgan2/models/checkpoint_epoch_150'
generator, discriminator, cgan = load_gan_checkpoint(checkpoint_path)

# Continue training
his = train_cgan(
    generator=generator,
    discriminator=discriminator,
    cgan=cgan,
    dataset=dataset,
    epochs=50,  # additional epochs
    save_dir='../results/cgan2',
    starting_epoch=last_epoch  # continue from last epoch
)

In [None]:
# Block the notebook for 45 minutes (60 s * 45) before the next cell runs.
# NOTE(review): presumably a cool-down pause between long training runs - confirm it is still needed.
time.sleep(60*45)

In [None]:
    # To continue training:
# Load the last checkpoint
last_epoch = 200  # or whatever your last epoch was
checkpoint_path = f'../results/cgan2/models/checkpoint_epoch_200'
generator, discriminator, cgan = load_gan_checkpoint(checkpoint_path)

# Continue training
his2 = train_cgan(
    generator=generator,
    discriminator=discriminator,
    cgan=cgan,
    dataset=dataset,
    epochs=50,  # additional epochs
    save_dir='../results/cgan2',
    starting_epoch=last_epoch  # continue from last epoch
)

In [None]:
# Block the notebook for 45 minutes (60 s * 45) before the next cell runs.
# NOTE(review): presumably a cool-down pause between long training runs - confirm it is still needed.
time.sleep(60*45)

In [None]:
    # To continue training:
# Load the last checkpoint
last_epoch = 250  # or whatever your last epoch was
checkpoint_path = f'../results/cgan2/models/checkpoint_epoch_250'
generator, discriminator, cgan = load_gan_checkpoint(checkpoint_path)

# Continue training
his7v2 = train_cgan(
    generator=generator,
    discriminator=discriminator,
    cgan=cgan,
    dataset=dataset,
    epochs=50,  # additional epochs
    save_dir='../results/cgan2',
    starting_epoch=last_epoch  # continue from last epoch
)

# cGAN 2

In [None]:
# Build a fresh generator / discriminator / combined model with create_cgan2.
generator, discriminator, cgan = create_cgan2()

# Train from scratch; all artefacts are written below ../results/cgan2_updated.
history3 = train_cgan(
    generator=generator,
    discriminator=discriminator,
    cgan=cgan,
    dataset=dataset,
    epochs=100,
    batch_size=32,
    save_dir='../results/cgan2_updated',
)

# cGAN 3

In [None]:
# Build a fresh generator / discriminator / combined model with create_cgan3.
generator3, discriminator3, cgan3 = create_cgan3()

# Train from scratch; all artefacts are written below ../results/cgan3.
history3v2 = train_cgan(
    generator=generator3,
    discriminator=discriminator3,
    cgan=cgan3,
    dataset=dataset,
    epochs=100,
    batch_size=32,
    save_dir='../results/cgan3',
)

In [None]:
def generate_and_save_condition_images(generator_path, save_dir, num_images_per_condition=100, latent_dim=120, batch_size=32):
    """Generate images for each of the four conditions and save them as PNGs.

    One sub-directory per condition is created under ``save_dir``
    (``Female_Not_Smiling``, ``Female_Smiling``, ``Male_Not_Smiling``,
    ``Male_Smiling``), each receiving ``num_images_per_condition`` images.

    Args:
        generator_path: Path of the saved generator model.
        save_dir: Root directory for the generated images.
        num_images_per_condition: Number of images to create per condition.
        latent_dim: Length of the noise vector.
            NOTE(review): the default is 120 while train_cgan defaults to 100;
            pass the value the generator was trained with.
        batch_size: Maximum number of images generated per forward pass.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Inference only, so the saved training/optimizer state is not needed.
    generator = load_model(generator_path, compile=False)
    print(f"Loaded generator model from {generator_path}")

    # Define the conditions
    conditions = [
        {'label': 'Female_Not_Smiling', 'attributes': [0, 0]},
        {'label': 'Female_Smiling', 'attributes': [0, 1]},
        {'label': 'Male_Not_Smiling', 'attributes': [1, 0]},
        {'label': 'Male_Smiling', 'attributes': [1, 1]}
    ]

    # Ensure the base save directory exists
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"Images will be saved to {save_dir.resolve()}")

    def save_image(image, condition_label, sample_number):
        """Convert one generated image from [-1, 1] to uint8 and write it as PNG."""
        pixels = (np.asarray(image).astype(np.float32) + 1.0) * 127.5
        # Clip before the cast: out-of-range values would wrap around in uint8.
        pixels = np.clip(pixels, 0.0, 255.0).astype('uint8')
        # PIL cannot build an image from (H, W, 1); drop a single channel axis.
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            pixels = pixels[..., 0]
        img = Image.fromarray(pixels)

        # Millisecond timestamp keeps file names unique across repeated runs.
        current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
        img_filename = save_dir / condition_label / f"image_{sample_number:05d}_{current_datetime}.png"
        img.save(img_filename)

    # Iterate over each condition and generate images
    for condition in conditions:
        condition_label = condition['label']
        condition_attributes = condition['attributes']

        condition_dir = save_dir / condition_label
        condition_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving images for condition '{condition_label}' to directory: {condition_dir.resolve()}")

        image_count = 0
        remaining_images = num_images_per_condition

        # Full batches first, then one smaller batch for any remainder.
        while remaining_images > 0:
            current_batch = min(batch_size, remaining_images)

            noise = tf.random.normal([current_batch, latent_dim])
            condition_batch = tf.tile(
                tf.constant([condition_attributes], dtype=tf.float32),
                [current_batch, 1]
            )
            generated_images = np.asarray(
                generator([noise, condition_batch], training=False)
            )

            for img_array in generated_images[:current_batch]:
                image_count += 1
                save_image(img_array, condition_label, image_count)

                if image_count % 100 == 0:
                    print(f"{image_count} images generated and saved for condition '{condition_label}'.")

            remaining_images -= current_batch

        print(f"Completed generating {image_count} images for condition '{condition_label}'.")

    print(f"\nImage generation complete. All images saved to {save_dir.resolve()}")


In [None]:
# Generate 94 images per condition from the epoch-100 cgan3 generator.
generate_and_save_condition_images(
    generator_path='../results/cgan3/models/checkpoint_epoch_100/generator.h5',
    save_dir='../../data/cgan2_images/',
    num_images_per_condition=94,
    latent_dim=100,
)