# Summary

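1. Collected and preprocessed the image data, trained the generator and discriminator models adversarially, and evaluated training progress with the tracked metrics (discriminator and generator losses) and generated sample images.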

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from PIL import Image
from pathlib import Path  


In [None]:
# Parameters shared by data loading, both networks, and the training loop.
IMG_SIZE: int = 64      # images are resized to IMG_SIZE x IMG_SIZE pixels
BATCH_SIZE: int = 32    # images per tf.data batch
LATENT_DIM: int = 100   # length of each generator noise vector
EPOCHS: int = 10_000    # full passes over the dataset in train()

In [None]:
# Load and preprocess dataset
def load_images(dataset_path, pattern="*.jpg"):
    """Load every image in *dataset_path* matching *pattern* as a float array.

    Each file is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE`` and
    scaled to ``[-1, 1]`` so real samples share the value range of the
    generator's ``tanh`` output (and of the ``(x + 1) / 2`` display rescale).

    Args:
        dataset_path: Directory (str or path-like) containing the images.
        pattern: Glob pattern selecting files inside *dataset_path*.

    Returns:
        ``np.ndarray`` of shape ``(N, IMG_SIZE, IMG_SIZE, 3)``, dtype float32.

    Raises:
        FileNotFoundError: if no file in *dataset_path* matches *pattern*.
    """
    # Sorted so the pre-shuffle order is reproducible across runs/filesystems.
    files = sorted(glob.glob(os.path.join(dataset_path, pattern)))
    if not files:
        # Fail loudly here; an empty array would only break much later in training.
        raise FileNotFoundError(
            f"No images matching {pattern!r} found in {dataset_path}"
        )

    images = []
    for file in files:
        # The context manager closes the underlying file handle promptly;
        # convert()/resize() return new in-memory images independent of it.
        with Image.open(file) as img:
            rgb = img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
        # uint8 [0, 255] -> float32 [-1, 1], matching the generator's tanh range.
        images.append(np.asarray(rgb, dtype=np.float32) / 127.5 - 1.0)
    return np.stack(images)

# Locate <parent of working dir>/data/GAN_data.
# NOTE(review): assumes the notebook runs from a sibling directory of `data`.
current_directory = Path(".").resolve()
data_directory = current_directory.parent.joinpath("data")
dataset_path = data_directory.joinpath("GAN_data")
images = load_images(dataset_path)

# Shuffle with a 1000-element buffer, then group into BATCH_SIZE batches
# (the last batch may be smaller).
dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(buffer_size=1000)
    .batch(batch_size=BATCH_SIZE)
)

In [None]:
# Generator model
def build_generator():
    """Build the convolutional generator that maps latent noise to RGB images.

    A ``LATENT_DIM`` noise vector is projected by a dense layer to a
    ``base x base x 256`` feature map (``base = IMG_SIZE // 8``). Three
    stride-2 transposed convolutions then upsample it (256 -> 128 -> 64 -> 3
    channels). The final ``tanh`` keeps pixel values in ``[-1, 1]``.

    Returns:
        An uncompiled ``keras.Sequential`` mapping ``(batch, LATENT_DIM)`` to
        ``(batch, IMG_SIZE, IMG_SIZE, 3)``.

    Raises:
        ValueError: if ``IMG_SIZE`` is not a multiple of 8. Three 2x
            upsamplings can only produce sizes that are multiples of 8.
    """
    if IMG_SIZE % 8 != 0:
        raise ValueError(f"IMG_SIZE must be a multiple of 8, got {IMG_SIZE}")
    # Derived from IMG_SIZE (previously hard-coded to 8), so the output always
    # matches the discriminator's IMG_SIZE x IMG_SIZE input.
    base = IMG_SIZE // 8

    model = keras.Sequential([
        keras.Input(shape=(LATENT_DIM,)),
        layers.Dense(base * base * 256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((base, base, 256)),
        # base -> 2 * base
        layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # 2 * base -> 4 * base
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        # 4 * base -> 8 * base == IMG_SIZE; tanh bounds pixels to [-1, 1]
        layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', activation='tanh'),
    ])
    return model

# Discriminator model
def build_discriminator():
    """Build the binary real/fake image classifier.

    The network has two stride-2 5x5 convolutions (64, then 128 filters).
    Each is followed by LeakyReLU and 30% dropout. The result is flattened
    into a single sigmoid unit that scores an ``IMG_SIZE x IMG_SIZE x 3``
    input (train() labels real images 1 and generated images 0).

    Returns:
        An uncompiled ``keras.Sequential`` model.
    """
    net = keras.Sequential()

    # Stage 1: 64 filters, halves the spatial size.
    net.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                          input_shape=[IMG_SIZE, IMG_SIZE, 3]))
    net.add(layers.LeakyReLU())
    net.add(layers.Dropout(0.3))

    # Stage 2: 128 filters, halves the spatial size again.
    net.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    net.add(layers.LeakyReLU())
    net.add(layers.Dropout(0.3))

    # Classification head.
    net.add(layers.Flatten())
    net.add(layers.Dense(1, activation='sigmoid'))
    return net

# Instantiate both networks. The discriminator is compiled on its own, while it
# is still trainable, because train() updates it directly via train_on_batch.
# NOTE(review): build_gan() later sets discriminator.trainable = False; this
# relies on Keras keeping the trainable state from compile time. Confirm the
# discriminator still learns under the installed Keras version.
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [None]:
def build_gan(generator, discriminator):
    """Stack *generator* and *discriminator* into the model used to train the generator.

    Side effect: sets ``discriminator.trainable = False`` before compiling,
    so the combined model's weight updates exclude the discriminator.

    Args:
        generator: Model mapping ``(batch, LATENT_DIM)`` noise to images.
        discriminator: Model mapping images to a single real/fake score.

    Returns:
        A compiled ``keras.Model`` from latent noise to the discriminator score.
    """
    discriminator.trainable = False  # freeze D inside the stacked model
    latent = keras.Input(shape=(LATENT_DIM,))
    validity = discriminator(generator(latent))
    combined = keras.Model(latent, validity)
    combined.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy',
    )
    return combined


gan = build_gan(generator, discriminator)

In [None]:
# Training loop
def train(dataset, epochs, log_every=1000):
    """Alternate discriminator and generator updates over *dataset*.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models.
    Each batch runs two steps:

    1. The discriminator trains on the real batch (label 1) and on an
       equally sized batch of generated images (label 0).
    2. The generator trains through ``gan``, pushing the discriminator to
       output 1 for fresh fakes.

    Args:
        dataset: Iterable of real-image batches (e.g. a batched ``tf.data.Dataset``).
        epochs: Number of passes over *dataset*.
        log_every: Print losses and show samples every this many epochs.
            The final epoch is always reported.

    Raises:
        ValueError: if *dataset* yields no batches.
    """
    for epoch in range(epochs):
        d_loss = g_loss = None
        for batch in dataset:
            # Size noise and labels from the actual batch. The last batch of a
            # tf.data pipeline batched without drop_remainder can be smaller
            # than BATCH_SIZE, and mismatched label counts make train_on_batch fail.
            batch_size = len(batch)
            real_labels = np.ones((batch_size, 1))
            fake_labels = np.zeros((batch_size, 1))

            # Call the generator directly in inference mode, which is what
            # predict() does, but without predict's per-call overhead and
            # progress bar.
            noise = np.random.normal(0, 1, (batch_size, LATENT_DIM)).astype(np.float32)
            generated_images = generator(noise, training=False)

            # Train discriminator: real -> 1, fake -> 0.
            d_loss_real = discriminator.train_on_batch(batch, real_labels)
            d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Train generator through the frozen-discriminator stack by asking
            # it to label fresh fakes as real.
            noise = np.random.normal(0, 1, (batch_size, LATENT_DIM)).astype(np.float32)
            misleading_labels = np.ones((batch_size, 1))
            g_loss = gan.train_on_batch(noise, misleading_labels)

        if d_loss is None:
            # Otherwise the losses below would be unbound/None.
            raise ValueError("dataset yielded no batches; nothing to train on")

        if epoch % log_every == 0 or epoch == epochs - 1:
            # train_on_batch returns a scalar or [loss, *metrics]; report the loss.
            d_value = np.ravel(d_loss)[0]
            g_value = np.ravel(g_loss)[0]
            print(f"Epoch {epoch}/{epochs}, D Loss: {d_value}, G Loss: {g_value}")
            generate_and_save_images(generator, epoch)

def generate_and_save_images(model, epoch, save_dir=None):
    """Plot a 4x4 grid of samples from *model* and optionally save it to disk.

    Args:
        model: Generator taking ``(n, LATENT_DIM)`` noise. Its output is
            expected in ``[-1, 1]`` and is rescaled with ``(x + 1) / 2``.
        epoch: Epoch number, shown as the figure title and used in the file name.
        save_dir: If given, the grid is also written to
            ``<save_dir>/image_at_epoch_<epoch>.png``. The directory is created
            if needed.
    """
    rows, cols = 4, 4
    noise = np.random.normal(0, 1, (rows * cols, LATENT_DIM))
    samples = model.predict(noise, verbose=0)
    samples = (samples + 1) / 2.0  # rescale [-1, 1] -> [0, 1] for imshow

    fig, axs = plt.subplots(rows, cols, figsize=(6, 6))
    for ax, image in zip(axs.flat, samples):
        ax.imshow(image)
        ax.axis('off')
    fig.suptitle(f"Epoch {epoch}")

    if save_dir is not None:
        # Save before show(): some backends clear the figure once shown.
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, f"image_at_epoch_{epoch}.png"))
    plt.show()
    # Close explicitly. Outside an inline backend, figures otherwise stay open
    # and accumulate over a long training run.
    plt.close(fig)


train(dataset, EPOCHS)