# DELE ST1504 CA2
# PART A: GAN

<hr>

**NAME**: Irman Zafyree, Adam Tan

**ADMIN NO**: `2300546`, `2300575`

**CLASS**: DAAA/FT/2B/07

<hr>

**Objective:**

Code a GAN model that is able to generate 260 small black-and-white images of the given dataset in 26 distinct classes.

**Background:**
A Generative Adversarial Network (GAN) is a type of deep learning model consisting of two neural networks: a generator and a discriminator. The primary purpose of a GAN is to generate new data instances that resemble the training data, through a competitive process between a generator and a discriminator. It has revolutionized the field of generative modeling and continues to be a vibrant area of research and application in artificial intelligence.





# Initial Set Up

In [None]:
# Basic imports
from dotenv import load_dotenv
import os

os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["KERAS_BACKEND"] = "tensorflow"

import pandas as pd
import numpy as np
import numpy
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from scipy.linalg import sqrtm
from numpy.random import randn, randint,random
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm
from IPython import display

# Tensorflow imports
from tensorflow.keras.callbacks import Callback, EarlyStopping
from tensorflow.keras.metrics import Mean
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.layers import Dense, Reshape, Conv2D, LeakyReLU, Dropout, Flatten, Conv2DTranspose, BatchNormalization, Input, Embedding, Concatenate, Lambda, MaxPooling2D, UpSampling2D, Activation, GlobalAveragePooling2D, AveragePooling2D, GlobalMaxPooling2D
from tensorflow.train import Checkpoint
from tensorflow.keras.utils import to_categorical


# Sklearn imports
from sklearn.model_selection import train_test_split

# Keras
import keras
from keras import backend as K
from keras import metrics
from keras import layers

# Only let TensorFlow's v1 logger emit messages at ERROR level.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Read variables from a local .env file into the process environment.
load_dotenv()

# Path of the dataset CSV, taken from the DATASET_PATH environment variable
# (None when the variable is not set).
dataset_path = os.environ.get("DATASET_PATH")

print(dataset_path)


In [None]:

# Report every visible GPU and switch each one to on-demand memory growth.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    gpu_details = tf.config.experimental.get_device_details(gpu)
    print(gpu_details)
    tf.config.experimental.set_memory_growth(gpu, True)
print(f"There are {len(gpus)} GPUs")

In [None]:
# header=None makes pandas treat the first row as data and number the columns from 0.
df = pd.read_csv(dataset_path, header=None)
df

In [None]:
# Count missing values per column.
df.isnull().sum()

Background Research

**CSV Dataset:**
- The dataset consists of 99040 samples (rows).
- The first column holds the class label: -1, or 1 to 26.
- Remaining 784 columns are the pixel data

**Classes:**

Total of 27 distinct classes 
1. -1 (all black)
2. 1 (a)
3. 2 (b)
4. 3 (c)
5. 4 (d)
6. 5 (e)
7. 6 (f)
8. 7 (g)
9. 8 (h)
10. 9 (i)
11. 10 (j)
12. 11 (k)
13. 12 (l)
14. 13 (m)
15. 14 (n)
16. 15 (o)
17. 16 (p)
18. 17 (q)
19. 18 (r)
20. 19 (s)
21. 20 (t)
22. 21 (u)
23. 22 (v)
24. 23 (w)
25. 24 (x)
26. 25 (y)
27. 26 (z)


**Images:**

The images are of size 28x28

In [None]:
# Column 0 is the class label of each row.
labels = df[0]

In [None]:
# Everything except the first (label) column is kept as pixel data.
data = df.drop(df.columns[0], axis=1)
data.head()

In [None]:

def display_images(images, label):
    """Draw up to the first 20 images of one class side by side.

    Each entry of ``images`` is reshaped to 28x28, rotated 90 degrees clockwise
    and mirrored left-to-right before being shown in grayscale. Every subplot
    is titled with ``label``. The figure is created but not shown; the caller
    is expected to call ``plt.show()``.
    """
    plt.figure(figsize=(20, 4))
    for slot, flat_pixels in enumerate(images[:20], start=1):
        plt.subplot(1, 20, slot)
        square = np.array(flat_pixels).reshape(28, 28)
        upright = np.fliplr(np.rot90(square, k=-1))
        plt.imshow(upright, cmap='gray')
        plt.axis('off')
        plt.title(f"{label}")

# Order labels so that -1 comes first, followed by the remaining labels in ascending order
ordered_labels = sorted(labels.unique(), key=lambda x: (x != -1, x))

# Iterate through each ordered label and display its first 20 images (display_images caps at 20)
for label in ordered_labels:
    label_images = data[labels == label].values
    display_images(label_images, label)
    plt.show()

In [None]:
import numpy as np

# Tally how many rows carry each label.
unique_labels, label_counts = np.unique(labels, return_counts=True)

# Text summary of the distribution.
print("Class distribution:")
for label, count in zip(unique_labels, label_counts):
    print(f"Label {label}: {count} instances")

# Bar chart of the same counts.
plt.figure(figsize=(8, 6))
plt.bar(unique_labels, label_counts, color='skyblue')
plt.xlabel('Classes')
plt.ylabel('Counts')
plt.title('Class Distribution')
plt.grid(True)

plt.show()

In [None]:
# Keep only the rows whose label is not -1.
df = df[df[0] != -1]
# reset_index without drop=True keeps the old index as a new leading 'index' column.
df.reset_index(inplace=True)

In [None]:
df_labels = df[0]
# Drop the first two columns: the 'index' column added by reset_index and the label column.
df_data = df.drop(df.columns[0:2], axis=1)
df_data.head()

In [None]:
# Re-orient every image (rotate 90 degrees clockwise, then mirror left-to-right).
# The whole (n, 28, 28) stack is transformed in one vectorised pass instead of
# writing rows back one at a time with `.loc`, which is very slow on a frame
# of this size. Index, columns and values match the row-wise version.
_image_stack = df_data.to_numpy().reshape(-1, 28, 28)
_image_stack = np.flip(np.rot90(_image_stack, k=-1, axes=(1, 2)), axis=2)
augmented_data = pd.DataFrame(
    _image_stack.reshape(_image_stack.shape[0], 28 * 28),
    index=df_data.index,
    columns=df_data.columns,
)

In [None]:
# Inspect the re-oriented pixel table.
print(augmented_data)

In [None]:
# Preview the rows labelled 0..19 of the re-oriented data on a 4x5 grid.
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.imshow(np.array(augmented_data.loc[i]).reshape(28, 28), cmap='gray')

In [None]:
def average_images_per_class(data, labels):
    """Plot the pixel-wise mean image of every class.

    Rows of ``data`` are grouped by ``labels``; the column means of each group
    are reshaped to 28x28 and drawn in grayscale on a 5x6 grid of axes (axes
    beyond the number of classes are left untouched).
    """
    # One column per class, one row per pixel.
    class_means = {name: rows.mean() for name, rows in data.groupby(labels)}
    avg_image_df = pd.DataFrame(class_means)

    fig, axes = plt.subplots(5, 6, figsize=(20, 15))
    for ax, class_name in zip(axes.flat, avg_image_df.columns):
        mean_pixels = np.array(avg_image_df[class_name])
        ax.imshow(mean_pixels.reshape(28, 28), cmap='gray')
        ax.set_title(f"Average image for {class_name}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()


average_images_per_class(augmented_data, df_labels)


# Model Research

From researching online, there seem to be 3 GAN Variations, Deep Convolutional GAN (DCGAN), Conditional GAN (CGAN), Wasserstein GAN (WGAN).

### Deep Convolutional GAN (DCGAN)

DCGANs use deep convolutional neural networks to learn the features of the input data, which allows them to generate high-resolution images that are similar to the training data. The generator network in a DCGAN typically consists of transposed convolutional layers, while the discriminator network consists of convolutional layers. The use of convolutional layers allows DCGANs to take advantage of the spatial relationships in the input data, which results in high-quality generated images.

### Conditional GAN (CGAN)

CGANs allow the user to control the generated data by adding an additional input to the generator network. This input specifies the desired characteristics of the generated data, such as the color or shape of an image. CGANs can be thought of as a combination of a GAN and a conditional generative model, where the generator and discriminator are trained to take into account both the input data and the condition. This allows CGANs to generate data that is more in line with the user’s expectations.

### Wasserstein GAN (WGAN)

WGANs address some of the stability issues that are commonly encountered when training GANs. WGANs use the Wasserstein distance metric to evaluate the quality of the generated data, which has been shown to provide improved stability during training compared to other metrics. The Wasserstein distance metric measures the earth mover’s distance between the generated data and the training data, which provides a robust measure of the quality of the generated data. WGANs have been shown to be particularly effective for generating high-quality images.

We will train all three of these GAN variants and compare them.

# Evaluation

Firstly, evaluating a GAN model is hard. Unlike classification, there is no ground-truth label to score against: we need to compare the generated images to real images. But how exactly can we compare or quantify the realism of a generated image?

The two main evaluation criteria for GAN models are:
- **Fidelity**: Our GAN should generate _high_ quality images
- **Diversity**: Our GAN should generate the full variety of images inherent in the training dataset

There are 2 approaches to compare the images:
- **Pixel Distance**: The naive distance measure, where we just subtract the two images' pixel values. However, this approach is not reliable.
- **Feature Distance**: We use a pre-trained image classification model and take the activation of an intermediate layer. This vector is a high-level representation of our image. Computing the distance metric on this representation gives a stable and reliable result.

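## Fréchet Inception Distance (FID)
This is one of the most popular metrics for measuring feature distance. The Fréchet distance is a measure of similarity between curves that takes into account the location and ordering of the points along the curves. For two sets of feature vectors with means $\mu_1, \mu_2$ and covariances $\Sigma_1, \Sigma_2$, the score computed below is $\lVert \mu_1 - \mu_2 \rVert^2 + \mathrm{Tr}\left(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\right)$.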

### Creating the Base VAE Model to train


In [None]:
class VariationalAutoencoder(keras.Model):
    """Convolutional variational autoencoder for single-channel images.

    The encoder maps an image to ``(z_mean, z_log_var, z)``; the decoder maps a
    latent vector back to a sigmoid-activated image. The decoder upsamples a
    7x7 feature map twice, so it produces 28x28x1 outputs.

    Args:
        latent_dim: Size of the latent vector.
        input_shape: Shape of one input image, e.g. ``(28, 28, 1)``.
    """

    def __init__(self, latent_dim, input_shape):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = self.build_encoder(input_shape)
        self.decoder = self.build_decoder()
        # Running means reported by `metrics` / `train_step`.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    def build_encoder(self, input_shape):
        """Build the encoder model returning ``[z_mean, z_log_var, z]``."""
        inputs = keras.Input(shape=input_shape)
        x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(inputs)
        x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Flatten()(x)
        x = layers.Dense(16, activation="relu")(x)
        z_mean = layers.Dense(self.latent_dim, name="z_mean")(x)
        z_log_var = layers.Dense(self.latent_dim, name="z_log_var")(x)
        z = layers.Lambda(self.sampling, output_shape=(self.latent_dim,), name='z')([z_mean, z_log_var])
        model = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
        return model

    def build_decoder(self):
        """Build the decoder model mapping a latent vector to a 28x28x1 image."""
        latent_inputs = keras.Input(shape=(self.latent_dim,))
        x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
        x = layers.Reshape((7, 7, 64))(x)
        x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
        x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
        decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
        model = keras.Model(latent_inputs, decoder_outputs, name="decoder")
        return model

    def sampling(self, args):
        """Reparameterisation trick: ``z = z_mean + exp(0.5 * z_log_var) * eps``."""
        z_mean, z_log_var = args
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    @property
    def metrics(self):
        """Trackers Keras resets at the start of every epoch."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch of images.

        The loss is the per-image summed binary cross-entropy between input and
        reconstruction plus the KL term of the latent distribution, both
        averaged over the batch.

        Returns:
            Dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # binary_crossentropy reduces the channel axis; sum over the two
            # remaining spatial axes, then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def plot_latent_space(self, n=30, figsize=15):
        """Decode an ``n`` x ``n`` grid of latent points in [-1, 1]^2 and plot it.

        Raises:
            ValueError: If the model's latent space is not two-dimensional; the
                grid is built from (x, y) pairs only.
        """
        if self.latent_dim != 2:
            raise ValueError(
                f"plot_latent_space requires latent_dim == 2, got {self.latent_dim}"
            )
        digit_size = 28
        scale = 1.0
        figure = np.zeros((digit_size * n, digit_size * n))
        # Linearly spaced latent coordinates; y is reversed so it increases upwards.
        grid_x = np.linspace(-scale, scale, n)
        grid_y = np.linspace(-scale, scale, n)[::-1]

        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z_sample = np.array([[xi, yi]])
                x_decoded = self.decoder.predict(z_sample, verbose=0)
                digit = x_decoded[0].reshape(digit_size, digit_size)
                figure[
                    i * digit_size : (i + 1) * digit_size,
                    j * digit_size : (j + 1) * digit_size,
                ] = digit

        plt.figure(figsize=(figsize, figsize))
        # Put one tick at the centre of every decoded tile.
        start_range = digit_size // 2
        end_range = n * digit_size + start_range
        pixel_range = np.arange(start_range, end_range, digit_size)
        sample_range_x = np.round(grid_x, 1)
        sample_range_y = np.round(grid_y, 1)
        plt.xticks(pixel_range, sample_range_x)
        plt.yticks(pixel_range, sample_range_y)
        plt.xlabel("z[0]")
        plt.ylabel("z[1]")
        plt.imshow(figure, cmap="Greys_r")
        plt.show()

    def save_models(self, epoch):
        """Save encoder and decoder to ``vae_encoder_<epoch>.h5`` / ``vae_decoder_<epoch>.h5``."""
        self.encoder.save(f'vae_encoder_{epoch}.h5')
        self.decoder.save(f'vae_decoder_{epoch}.h5')

    def restore_weights(self, epoch):
        """Reload the encoder and decoder written by ``save_models(epoch)``.

        The file names are derived from ``epoch`` so they mirror
        ``save_models``; ``restore_weights(500)`` loads the ``_500`` files.
        """
        custom_objects = {'sampling': self.sampling}
        self.encoder = keras.models.load_model(f"vae_encoder_{epoch}.h5", custom_objects=custom_objects)
        self.decoder = keras.models.load_model(f"vae_decoder_{epoch}.h5", custom_objects=custom_objects)

    def generate_images(self, n=10):
        """Decode ``n`` latent points drawn from a standard normal distribution."""
        latent_points = np.random.normal(size=(n, self.latent_dim))
        generated_images = self.decoder.predict(latent_points)
        return generated_images

In [None]:

# Rebuild (n, 28, 28, 1) images scaled to [0, 1] from the flattened 784-pixel rows.
# Adding a trailing axis alone would give (n, 784, 1), which does not match the
# (28, 28, 1) input shape the VAE below is built with.
vae_data = np.array(augmented_data).reshape(-1, 28, 28, 1).astype("float32") / 255

early_stopping = keras.callbacks.EarlyStopping(monitor="loss", patience=10)

vaeNew = VariationalAutoencoder(2, (28, 28, 1))
'''
vaeNew.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001))
vaeNew.fit(vae_data, epochs=500, batch_size=64, callbacks=[early_stopping])
vaeNew.plot_latent_space()
'''

In [None]:
# Load the previously saved encoder/decoder instead of retraining.
vaeNew.restore_weights(500)

### Doing FID calculations

In [None]:
def calculate_fid(encoder, real_data, fake_data):
    """Frechet distance between the encoder's latent means of two image sets.

    ``encoder.predict`` must return a 3-tuple whose first element is a 2-D
    array of latent means (one row per sample). The score is the squared
    distance between the two mean vectors plus
    ``trace(S1 + S2 - 2 * sqrt(S1 @ S2))`` for the two covariance matrices.
    """

    def latent_statistics(images):
        # Mean vector and covariance matrix (columns are variables) of z_mean.
        z_mean, _, _ = encoder.predict(images, verbose=0)
        return z_mean.mean(axis=0), cov(z_mean, rowvar=False)

    mu_real, sigma_real = latent_statistics(real_data)
    mu_fake, sigma_fake = latent_statistics(fake_data)

    mean_term = numpy.sum((mu_real - mu_fake) ** 2.0)

    # The matrix square root can pick up a small imaginary part; keep the real part.
    sqrt_product = sqrtm(sigma_real.dot(sigma_fake))
    if iscomplexobj(sqrt_product):
        sqrt_product = sqrt_product.real

    return mean_term + trace(sigma_real + sigma_fake - 2.0 * sqrt_product)

In [None]:
df = pd.read_csv(dataset_path, header=None)

# Column 0 is the label; the remaining columns are the flattened pixels.
train_labels = df[0]
train_data = df.drop(0, axis=1)

# Re-orient each row as a 28x28 image: rotate 90 degrees clockwise, then mirror left-to-right.
augmented_data = [
    np.fliplr(np.rot90(np.array(train_data.loc[i].values).reshape(28, 28), k=-1))
    for i in train_data.index
]

# Stack to (n, 28, 28, 1) float32 and normalize the images to [-1, 1].
train_data = np.array(augmented_data)
train_data = train_data.reshape(train_data.shape[0], 28, 28, 1).astype('float32')
train_data = (train_data - 127.5) / 127.5

In [None]:
# test the calculation of fid
# NOTE(review): real_data is scaled to [-1, 1] (see the normalisation of train_data above) while
# fake_data holds raw integers in [0, 254] (randint's upper bound is exclusive), so this is a
# smoke test that calculate_fid runs, not a like-for-like comparison.
real_data = train_data[:100]
fake_data = randint(0, 255, 28 * 28 * 100)
fake_data = fake_data.reshape((100, 28, 28, 1)).astype('float32')
calculate_fid(vaeNew.encoder, real_data, fake_data)


## Inception Score
The Inception Score, or IS for short, is an objective metric for evaluating the quality of generated images, specifically synthetic images output by generative adversarial network models. However, one issue is that it requires a classification model.

### Classifier

In [None]:

class Classifier:
    """Thin wrapper around a small Keras CNN image classifier.

    The network is two 3x3 convolutions, max-pooling, dropout, a 128-unit dense
    layer and a softmax over ``num_classes``, compiled with categorical
    cross-entropy and Adadelta. Training stops early when ``val_loss`` has not
    improved for 5 epochs.
    """

    def __init__(self, input_shape, num_classes):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = self.build_model()
        self.callbacks = [EarlyStopping(monitor='val_loss', patience=5)]

    def build_model(self):
        """Create and compile the CNN."""
        network = Sequential()
        for layer in (
            Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=self.input_shape),
            Conv2D(64, (3, 3), activation='relu'),
            MaxPooling2D(pool_size=(2, 2)),
            Dropout(0.25),
            Flatten(),
            Dense(128, activation='relu'),
            Dropout(0.5),
            Dense(self.num_classes, activation='softmax'),
        ):
            network.add(layer)
        network.compile(
            loss=keras.losses.categorical_crossentropy,
            optimizer=keras.optimizers.Adadelta(),
            metrics=['accuracy'],
        )
        return network

    def train(self, x_train, y_train, x_test, y_test, epochs=10, batch_size=128):
        """Fit the model, validating on ``(x_test, y_test)`` after every epoch."""
        self.model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=epochs,
            verbose=1,
            validation_data=(x_test, y_test),
            callbacks=self.callbacks,
        )

    def evaluate(self, x_test, y_test):
        """Return the model's ``evaluate`` result (loss and accuracy) on the given data."""
        return self.model.evaluate(x_test, y_test, verbose=0)

    def save_model(self, model_name):
        """Save the underlying Keras model to ``model_name``."""
        self.model.save(model_name)

    def load_model(self, model_name):
        """Replace the underlying Keras model with the one stored at ``model_name``."""
        self.model = keras.models.load_model(model_name)

In [None]:
def augment_images(data):
    """Return every row of ``data`` as a re-oriented 28x28 array.

    Each row (784 pixel values) is reshaped to 28x28, rotated 90 degrees
    clockwise and mirrored left-to-right. The result is a list with one array
    per index label of ``data``, in index order.
    """
    return [
        np.fliplr(np.rot90(np.array(data.loc[row_id].values).reshape(28, 28), k=-1))
        for row_id in data.index
    ]


In [None]:
classifier = Classifier((28, 28, 1), 26)

# Reload the CSV, drop the rows labelled -1 and split 80/20.
df = pd.read_csv(dataset_path, header=None)
df = df[df[0] != -1]
train, test = train_test_split(df, test_size=0.2)

# Training split: shift labels from 1..26 to 0..25 and one-hot encode them;
# re-orient the images and scale the pixels to [0, 1].
train_labels = [label - 1 for label in train[0]]
train_labels = to_categorical(train_labels, num_classes=26)
train_data = augment_images(train.drop(train.columns[0], axis=1))
train_data = np.expand_dims(train_data, -1).astype("float32") / 255

# Test split: same preparation.
test_labels = [label - 1 for label in test[0]]
test_labels = to_categorical(test_labels, num_classes=26)
test_data = augment_images(test.drop(test.columns[0], axis=1))
test_data = np.expand_dims(test_data, -1).astype("float32") / 255

# Use the previously saved weights; the training call is kept for reference.
classifier.load_model("classifier.h5")
#classifier.train(train_data, train_labels, test_data, test_labels, epochs=400, batch_size=64)


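The Inception Score proved too hard to implement here, so we give up on it and use the histogram-based KL divergence and Jensen–Shannon divergence (JSD) below instead.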

In [None]:
def calculate_kl(p, q):
    """Return the discrete KL divergence ``sum(p * log(p / q))``.

    ``p`` and ``q`` are array-likes of equal shape. Entries where ``p == 0``
    contribute nothing. If ``q`` is zero anywhere ``p`` is non-zero the
    divergence is infinite and ``inf`` is returned. Only the entries that
    contribute are evaluated, so no ``log(0)`` or ``0/0`` warnings are raised.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p != 0
    if np.any(q[support] == 0):
        return np.inf
    return np.sum(p[support] * np.log(p[support] / q[support]))


def estimate_distribution(data, bins=100, value_range=None):
    """Histogram ``data`` as a density.

    Args:
        data: Array-like of values (flattened by ``np.histogram``).
        bins: Number of bins (or explicit bin edges).
        value_range: Optional ``(low, high)`` limits for the bins; by default
            the range of ``data`` is used.

    Returns:
        ``(hist, bin_edges)`` as returned by ``np.histogram(..., density=True)``.
    """
    hist, bin_edges = np.histogram(data, bins=bins, range=value_range, density=True)
    return hist, bin_edges


def _shared_bin_probabilities(real_data, fake_data, bins):
    """Histogram both datasets on the SAME bin edges and return probability masses.

    Independent histograms each span their own data range, so bin ``i`` of one
    would not cover the same values as bin ``i`` of the other and the
    comparison would be meaningless.
    """
    real = np.asarray(real_data, dtype=float).ravel()
    fake = np.asarray(fake_data, dtype=float).ravel()
    common_range = (min(real.min(), fake.min()), max(real.max(), fake.max()))
    density_real, edges = estimate_distribution(real, bins, value_range=common_range)
    density_fake, _ = estimate_distribution(fake, bins, value_range=common_range)
    # density * bin width = probability mass per bin (each sums to 1).
    widths = np.diff(edges)
    return density_real * widths, density_fake * widths


def kl_divergence(real_data, fake_data, bins=100):
    """KL(real || fake) between the value histograms of two datasets.

    Both datasets are binned on common edges spanning their joint range.
    Returns ``inf`` when ``fake_data`` has no mass in a bin where ``real_data``
    does.
    """
    prob_real, prob_fake = _shared_bin_probabilities(real_data, fake_data, bins)
    return calculate_kl(prob_real, prob_fake)


def calculate_jsd(real_data, fake_data, bins=100):
    """Jensen-Shannon divergence between the value histograms of two datasets.

    Uses the same common binning as ``kl_divergence``. The result lies in
    ``[0, ln 2]``: 0 for identical histograms, ``ln 2`` for disjoint ones.
    """
    prob_real, prob_fake = _shared_bin_probabilities(real_data, fake_data, bins)
    mixture = 0.5 * (prob_real + prob_fake)
    return 0.5 * (calculate_kl(prob_real, mixture) + calculate_kl(prob_fake, mixture))

In [None]:
# Smoke test of the KL helper: real-vs-real should be 0, real-vs-noise should be larger.
# NOTE(review): the fake samples are raw integers in [0, 254] (randint's upper bound is
# exclusive) while train_data has been rescaled in an earlier cell — confirm the intended scale.
real_data_kl = train_data[:100]
fake_data_kl = randint(0, 255, 28 * 28 * 100)
fake_data_kl = fake_data_kl.reshape((100, 28, 28, 1)).astype('float32')
print(f"KL Real vs Real: {kl_divergence(real_data_kl, real_data_kl)}")
print(f"KL Real vs Fake: {kl_divergence(real_data_kl, fake_data_kl)}")

# Same check for the Jensen-Shannon divergence, with a fresh batch of random noise.
real_data_jsd = train_data[:100]
fake_data_jsd = randint(0, 255, 28 * 28 * 100)
fake_data_jsd = fake_data_jsd.reshape((100, 28, 28, 1)).astype('float32')
print(f"JSD Real vs Real: {calculate_jsd(real_data_jsd, real_data_jsd)}")
print(f"JSD Real vs Fake: {calculate_jsd(real_data_jsd, fake_data_jsd)}")

# Initial Modelling

## Background on GAN

A generative adversarial network (GAN) has two parts:
- Generator
  - learns to create "real" data
  - generated data becomes the fake training examples
- Discriminator
  - distinguish between real and fake data
  - penalizes the generator if discriminator detects the fake generated data

When the training begins, the generator fake data is easily detectable by the discriminator. However, if the generator training goes well, the fake data becomes indistinguishable to the discriminator.

Below is an example of how the generator improves over time:
1. ![Beginning](https://developers.google.com/static/machine-learning/gan/images/bad_gan.svg)
2. ![Middle](https://developers.google.com/static/machine-learning/gan/images/ok_gan.svg)
3. ![End](https://developers.google.com/static/machine-learning/gan/images/good_gan.svg)

So this is how the entire system will look.
![Entire GAN system](https://developers.google.com/static/machine-learning/gan/images/gan_diagram.svg)

### Discriminator
The main goal of the discriminator is to distinguish real data from the fake generated data. It could use any network architecture appropriate to the type of data it's classifying.

The discriminator training data come from 2 sources.
- Real data comes from the initial dataset whose images you would like to generate; in this case, it would be our `emnist-letter.csv`
- Fake data comes from the generator.

The discriminator trains by:
1. Classifying real and fake
2. The loss penalizes the discriminator for misclassifying
3. The discriminator updates its weights through backpropagation

### Generator
The main goal of the generator is to trick the discriminator into thinking its fake data is real.

The generator has a random input. It is usually random noise, and the generator turns the noise into meaningful data. Experiments suggest that the distribution of the noise doesn't matter much, so we can choose something that's easy to sample from, like a uniform distribution.

The generator trains by:
1. Sample noise
2. Produce output
3. Get discriminator to classify
4. Get the loss from the discriminator
5. Backpropagate through the generator network
6. Updates the weight

## Set Up Dataset

In [None]:
df = pd.read_csv(dataset_path, header=None)

# Column 0 is the label; the remaining columns are the flattened pixels.
train_labels = df[0]
train_data = df.drop(0, axis=1)

# Re-orient each row as a 28x28 image: rotate 90 degrees clockwise, then mirror left-to-right.
augmented_data = [
    np.fliplr(np.rot90(np.array(train_data.loc[i].values).reshape(28, 28), k=-1))
    for i in train_data.index
]

# Stack to (n, 28, 28, 1) float32 and normalize the images to [-1, 1].
train_data = np.array(augmented_data)
train_data = train_data.reshape(train_data.shape[0], 28, 28, 1).astype('float32')
train_data = (train_data - 127.5) / 127.5


In [None]:
def denormalize_image(image):
    """Map pixel values from the [-1, 1] range to the [0, 1] range."""
    return (image + 1) / 2

## DCGAN

In [None]:
class DCGAN:
    """Unconditional DCGAN for 28x28x1 images with per-epoch evaluation.

    The class owns a discriminator, a generator and the stacked generator +
    frozen-discriminator model used to update the generator. After every epoch
    it evaluates the discriminator, computes FID / KL / JSD on a fixed batch of
    latent points, and (every ``save_interval`` epochs) saves a sample plot and
    both models under ``save_dir``.

    Args:
        latent_dim: Size of the generator's noise input.
        dataset: Array of real images, indexed along axis 0; the generator ends
            in ``tanh``, so images are expected in [-1, 1].
        save_interval: Save plots/models every this many epochs.
        save_dir: Output directory (created if missing).
        clear_dir: If True, delete the files already in ``save_dir``.

    Note:
        ``on_epoch_end`` relies on the module-level ``vaeNew``,
        ``calculate_fid``, ``kl_divergence``, ``calculate_jsd`` and
        ``denormalize_image``.
    """

    def __init__(self, latent_dim=100, dataset=None, save_interval=1, save_dir='generated_images', clear_dir=False):
        self.latent_dim = latent_dim
        self.d_model = self.define_discriminator()
        self.g_model = self.define_generator()
        self.gan_model = self.define_gan()
        self.dataset = dataset
        self.save_interval = save_interval
        self.save_dir = save_dir
        self.clear_dir = clear_dir
        # Fixed noise reused at every epoch so that samples/metrics are comparable.
        self.seeded_latent_points = self.generate_latent_points(1000)
        self.logging = {}
        # Total epoch count; set by train() and only used for progress messages.
        self.epochs = None
        print(self.seeded_latent_points.shape)
        os.makedirs(save_dir, exist_ok=True)
        if clear_dir:
            for f in os.listdir(save_dir):
                path = os.path.join(save_dir, f)
                # Only plain files are removed; os.remove would fail on a sub-directory.
                if os.path.isfile(path):
                    os.remove(path)

    # Define the standalone discriminator model
    def define_discriminator(self,in_shape=(28,28,1)):
        """Build and compile the convolutional real/fake classifier."""
        model = Sequential()
        model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(256, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Flatten())
        model.add(Dropout(0.4))
        model.add(Dense(1, activation='sigmoid'))
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
        return model

    # Define the standalone generator model
    def define_generator(self):
        """Build the generator: noise -> 7x7x128 -> two 2x upsamplings -> 28x28x1 (tanh)."""
        latent_dim = self.latent_dim
        model = Sequential()
        n_nodes = 128 * 7 * 7
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((7, 7, 128)))
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(1, (3,3), activation='tanh', padding='same'))
        return model

    # Define the combined generator and discriminator model, for updating the generator
    def define_gan(self):
        """Stack generator and discriminator; the discriminator is marked non-trainable here."""
        d_model = self.d_model
        g_model = self.g_model
        d_model.trainable = False
        model = Sequential()
        model.add(g_model)
        model.add(d_model)
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt)
        return model

    # Select real samples
    def generate_real_samples(self, n_samples):
        """Draw ``n_samples`` random real images (with replacement) labelled 1."""
        dataset = self.dataset
        ix = randint(0, dataset.shape[0], n_samples)
        X = dataset[ix]
        y = np.ones((n_samples, 1))
        return X, y

    # Generate points in latent space as input for the generator
    def generate_latent_points(self, n_samples):
        """Return an ``(n_samples, latent_dim)`` array of standard-normal noise."""
        latent_dim = self.latent_dim
        x_input = randn(latent_dim * n_samples)
        x_input = x_input.reshape(n_samples, latent_dim)
        return x_input

    # Use the generator to generate n fake examples, with class labels
    def generate_fake_samples(self, n_samples, seeded=False):
        """Generate fake images labelled 0.

        With ``seeded=True`` the fixed ``seeded_latent_points`` are used (and
        ``n_samples`` is ignored); otherwise ``n_samples`` fresh latent points
        are drawn. The label array always has one row per generated image.
        """
        if seeded:
            x_input = self.seeded_latent_points
        else:
            x_input = self.generate_latent_points(n_samples)
        X = self.g_model.predict(x_input,verbose=0)
        y = np.zeros((X.shape[0], 1))
        return X, y

    # Train the generator and discriminator
    def train(self,n_epochs=200, n_batch=128, start_epoch=0, threshold=0.5):
        """Alternate discriminator and generator updates for up to ``n_epochs`` epochs.

        Each step trains the discriminator on half a batch of real and half a
        batch of fake images, then the generator (through the stacked model) on
        a full batch of noise labelled as real. Training stops early once the
        generator loss of the last batch changes by less than ``threshold``
        between consecutive epochs (the first epoch is compared against 0).

        Raises:
            ValueError: If the dataset holds fewer than ``n_batch`` images, so
                no batch could be run.
        """
        d_model = self.d_model
        gan_model = self.gan_model
        dataset = self.dataset
        self.epochs = n_epochs
        bat_per_epo = int(dataset.shape[0] / n_batch)
        if bat_per_epo == 0:
            raise ValueError(
                f"dataset has {dataset.shape[0]} samples, fewer than one batch of {n_batch}"
            )
        half_batch = int(n_batch / 2)
        old_g_loss = 0

        for i in range(start_epoch, n_epochs):
            for j in tqdm(range(bat_per_epo), desc=f'Epoch {i + 1}/{n_epochs}', total=bat_per_epo):
                X_real, y_real = self.generate_real_samples(half_batch)
                d_loss1, _ = d_model.train_on_batch(X_real, y_real)
                X_fake, y_fake = self.generate_fake_samples(half_batch)
                d_loss2, _ = d_model.train_on_batch(X_fake, y_fake)
                # Generator update: noise labelled "real" pushed through the stacked model.
                X_gan = self.generate_latent_points(n_batch)
                y_gan = np.ones((n_batch, 1))
                g_loss = gan_model.train_on_batch(X_gan, y_gan)

            # Stopping condition
            diff = abs(g_loss - old_g_loss)
            if diff < threshold:
                self.on_epoch_end(i, g_loss=g_loss, d_loss1=d_loss1, d_loss2=d_loss2, diff=diff)
                break
            old_g_loss = g_loss

            # Call the callback at the end of each epoch
            self.on_epoch_end(i, g_loss=g_loss, d_loss1=d_loss1, d_loss2=d_loss2, diff=diff)

    def on_epoch_end(self, epoch, logs=None, g_loss=None, d_loss1=None, d_loss2=None, diff=None):
        """Evaluate, checkpoint and log the state after ``epoch`` (0-based).

        Stores discriminator accuracies, FID, KL, JSD and the passed-in losses
        in ``self.logging[epoch]`` and prints a one-line summary. ``logs`` is
        accepted but unused.
        """
        display.clear_output(wait=True)
        n_eval = self.seeded_latent_points.shape[0]
        X_real, y_real = self.generate_real_samples(n_eval)
        _, acc_real = self.d_model.evaluate(X_real, y_real, verbose=0)
        x_fake, y_fake = self.generate_fake_samples(n_eval, seeded=True)
        _, acc_fake = self.d_model.evaluate(x_fake, y_fake, verbose=0)
        if (epoch + 1) % self.save_interval == 0:
            self.save_plot(x_fake, epoch)
            g_filename = os.path.join(self.save_dir, 'generator_model_%03d.h5' % (epoch+1))
            self.g_model.save(g_filename)
            d_filename = os.path.join(self.save_dir, 'discriminator_model_%03d.h5' % (epoch+1))
            self.d_model.save(d_filename)

        # logging
        x_fake_denorm = denormalize_image(x_fake)
        x_real_denorm = denormalize_image(X_real)
        # FID: the VAE encoder was trained on images scaled to [0, 1], so it is
        # fed the denormalised images rather than the [-1, 1] GAN tensors.
        fid = calculate_fid(vaeNew.encoder, x_real_denorm, x_fake_denorm)
        kl = kl_divergence(x_real_denorm, x_fake_denorm)
        jsd = calculate_jsd(x_real_denorm, x_fake_denorm)
        self.logging[epoch] = {'acc_real': acc_real, 'acc_fake': acc_fake, 'fid': fid, 'kl': kl, 'jsd': jsd, 'g_loss': g_loss, 'd_loss1': d_loss1, 'd_loss2': d_loss2}

        print(f'FID: {fid}, KL: {kl}, JSD: {jsd} | G Loss: {g_loss}, D Loss Real: {d_loss1}, D Loss Fake: {d_loss2}, Diff: {diff} | Epoch {epoch + 1}/{self.epochs}')

    def save_plot(self, examples, epoch, n=7):
        """Save an ``n`` x ``n`` grid of the first generated images as a PNG in ``save_dir``."""
        examples = (examples + 1) / 2.0 # scale from [-1,1] to [0,1]
        for i in range(min(n * n, examples.shape[0])):
            plt.subplot(n, n, 1 + i)
            plt.axis('off')
            plt.imshow(examples[i], cmap='gray')
        filename = os.path.join(self.save_dir, 'generated_plot_e%03d.png' % (epoch+1))
        plt.savefig(filename)
        plt.close()

    def restore_start_training(self, epoch):
        """Reload the generator/discriminator saved for ``epoch`` (1-based, as in the file names)
        and rebuild the stacked model around them."""
        self.g_model = tf.keras.models.load_model(os.path.join(self.save_dir, f'generator_model_{epoch:03d}.h5'))
        self.d_model = tf.keras.models.load_model(os.path.join(self.save_dir, f'discriminator_model_{epoch:03d}.h5'))
        self.gan_model = self.define_gan()
        

In [None]:
# Build the DCGAN on the normalised training images; clear_dir=True empties the output folder first.
dcgan = DCGAN(latent_dim=100, dataset=train_data, save_interval=1, save_dir='dcgan_gen_img_v2', clear_dir=True)


In [None]:
# Train for up to 250 epochs; stops early once the generator loss changes by < 0.01 between epochs.
dcgan.train(n_epochs=250, n_batch=128, start_epoch=0, threshold=0.01)

## CGAN

In [None]:
class CDCGAN:
    """Conditional DCGAN: generator and discriminator both receive a class label.

    The integer label is embedded, projected by a Dense layer and reshaped to a
    single-channel feature map, which is concatenated with the image
    (discriminator) or with the reshaped latent projection (generator).
    """

    # Define the standalone discriminator model
    def define_discriminator(self, in_shape=(28, 28, 1)):
        """Build and compile the label-conditioned discriminator.

        Args:
            in_shape: Shape of one input image as (height, width, channels).

        Returns:
            A compiled Keras model taking ``[image, label]`` and producing one
            sigmoid output per sample.
        """
        # Label branch: embed the class id and expand it to an image-sized channel.
        in_label = Input(shape=(1,))
        li = Embedding(self.n_classes, 50)(in_label)
        li = Dense(in_shape[0] * in_shape[1])(li)
        li = Reshape((in_shape[0], in_shape[1], 1))(li)

        # Image branch.
        in_image = Input(shape=in_shape)

        # Concatenate the label map onto the image as an extra channel.
        merge = Concatenate()([in_image, li])

        model = Conv2D(64, (3, 3), padding='same')(merge)
        model = LeakyReLU(alpha=0.2)(model)
        model = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Conv2D(256, (3, 3), strides=(2, 2), padding='same')(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Flatten()(model)
        model = Dropout(0.4)(model)
        out_layer = Dense(1, activation='sigmoid')(model)

        model = Model([in_image, in_label], out_layer)
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
        return model

    # Define the standalone generator model
    def define_generator(self):
        """Build the label-conditioned generator (left uncompiled).

        Returns:
            A Keras model taking ``[latent_vector, label]``. The 7x7 feature
            maps pass through two stride-2 transposed convolutions and a final
            single-filter ``tanh`` convolution.
        """
        # Label branch: embed the class id and expand it to a 7x7 channel.
        in_label = Input(shape=(1,))
        li = Embedding(self.n_classes, 50)(in_label)
        li = Dense(7 * 7)(li)
        li = Reshape((7, 7, 1))(li)

        # Latent branch: project the noise vector to 128 feature maps of 7x7.
        in_lat = Input(shape=(self.latent_dim,))
        gen = Dense(128 * 7 * 7)(in_lat)
        gen = LeakyReLU(alpha=0.2)(gen)
        gen = Reshape((7, 7, 128))(gen)

        # Merge latent features with the label map.
        merge = Concatenate()([gen, li])

        model = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(merge)
        model = LeakyReLU(alpha=0.2)(model)
        model = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(model)
        model = LeakyReLU(alpha=0.2)(model)
        model = Conv2D(1, (3, 3), activation='tanh', padding='same')(model)
        return Model([in_lat, in_label], model)

    # Define the combined generator and discriminator model, for updating the generator
    def define_gan(self):
        """Stack generator and discriminator into the model used to train the generator.

        Returns:
            A compiled Keras model taking ``[latent_vector, label]`` and
            producing the discriminator's output for the generated image.
        """
        d_model = self.d_model
        g_model = self.g_model
        # The discriminator is marked non-trainable before the stacked model is compiled.
        d_model.trainable = False
        g_noise, g_label = g_model.input
        # The same label tensor is fed to both the generator and the discriminator.
        gan_output = d_model([g_model.output, g_label])
        model = Model([g_noise, g_label], gan_output)
        opt = Adam(learning_rate=0.0002, beta_1=0.5)
        model.compile(loss='binary_crossentropy', optimizer=opt)
        return model

    # Select real samples
    def generate_real_samples(self, n_samples):
        """Draw ``n_samples`` random (image, label) pairs from the dataset.

        Sampling is with replacement.

        Returns:
            ``([images, labels], y)`` where ``y`` is an ``(n_samples, 1)`` array
            of ones (the "real" target).
        """
        images, labels = self.dataset
        ix = randint(0, images.shape[0], n_samples)
        y = np.ones((n_samples, 1))
        return [images[ix], labels[ix]], y

    # Generate points in latent space as input for the generator
    def generate_latent_points(self, n_samples):
        """Sample generator inputs.

        Returns:
            ``[z, labels]`` where ``z`` has shape ``(n_samples, latent_dim)``
            drawn from a standard normal, and ``labels`` holds ``n_samples``
            random integers in ``[0, n_classes)``.
        """
        z_input = randn(self.latent_dim * n_samples).reshape(n_samples, self.latent_dim)
        labels = randint(0, self.n_classes, n_samples)
        return [z_input, labels]

    # Use the generator to generate n fake examples, with class labels
    def generate_fake_samples(self, n_samples, seeded=False):
        """Generate fake images with the current generator.

        Args:
            n_samples: Number of images to generate when ``seeded`` is false.
            seeded: When true, use the fixed ``self.seeded_latent_points``
                instead of fresh noise; ``n_samples`` is then ignored.

        Returns:
            ``([images, labels], y)`` where ``y`` is a column of zeros (the
            "fake" target) with one row per generated image.
        """
        if seeded:
            z_input, labels_input = self.seeded_latent_points
        else:
            z_input, labels_input = self.generate_latent_points(n_samples)
        images = self.g_model.predict([z_input, labels_input], verbose=0)
        # Size the targets from the batch actually used so they always line up
        # with the images, even when the seeded batch differs from n_samples.
        y = np.zeros((len(z_input), 1))
        return [images, labels_input], y

    def save_plot(self, examples, epoch, n=7):
        """Write a grid of generated images to ``self.save_dir`` as a PNG.

        Args:
            examples: Batch of images with values in [-1, 1].
            epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
            n: Grid side length. At most ``n * n`` images are drawn, fewer
                when the batch is smaller.
        """
        examples = (examples + 1) / 2.0  # scale from [-1,1] to [0,1]
        # Clamp to the batch size so a small batch cannot raise IndexError.
        for i in range(min(n * n, examples.shape[0])):
            plt.subplot(n, n, 1 + i)
            plt.axis('off')
            plt.imshow(examples[i], cmap='gray')
        filename = os.path.join(self.save_dir, 'generated_plot_e%03d.png' % (epoch+1))
        plt.savefig(filename)
        plt.close()

    def on_epoch_end(self, epoch, logs=None):
        """Report discriminator accuracy and periodically save a plot and both models.

        Args:
            epoch: Zero-based epoch index.
            logs: Unused.
        """
        # Clear the previous cell output so only the latest report is shown.
        display.clear_output(wait=True)

        [X_real, labels_real], y_real = self.generate_real_samples(100)
        # The targets must be passed: the model is compiled with a loss and an
        # accuracy metric, neither of which can be computed without them.
        _, acc_real = self.d_model.evaluate([X_real, labels_real], y_real, verbose=0)
        [x_fake, labels], y_fake = self.generate_fake_samples(100, seeded=True)
        _, acc_fake = self.d_model.evaluate([x_fake, labels], y_fake, verbose=0)
        print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))

        if (epoch + 1) % self.save_interval == 0:
            self.save_plot(x_fake, epoch)
            g_filename = os.path.join(self.save_dir, 'generator_model_%03d.h5' % (epoch+1))
            self.g_model.save(g_filename)
            d_filename = os.path.join(self.save_dir, 'discriminator_model_%03d.h5' % (epoch+1))
            self.d_model.save(d_filename)

    # Train the generator and discriminator
    def train(self, n_epochs=200, n_batch=128):
        """Run the adversarial training loop.

        Each step trains the discriminator on half a batch of real and half a
        batch of fake samples, then trains the generator through the combined
        model on a full batch of latent points with "real" targets.

        Args:
            n_epochs: Number of passes; ``on_epoch_end`` runs after each one.
            n_batch: Batch size for the generator update.
        """
        d_model = self.d_model
        gan_model = self.gan_model
        bat_per_epo = self.dataset[0].shape[0] // n_batch
        half_batch = n_batch // 2

        for i in range(n_epochs):
            for _ in tqdm(range(bat_per_epo), desc=f'Epoch {i + 1}/{n_epochs}', total=bat_per_epo):
                # Discriminator update on real samples.
                [X_real, labels_real], y_real = self.generate_real_samples(half_batch)
                d_model.train_on_batch([X_real, labels_real], y_real)
                # Discriminator update on generated samples.
                [X_fake, labels], y_fake = self.generate_fake_samples(half_batch)
                d_model.train_on_batch([X_fake, labels], y_fake)
                # Generator update: generated samples are given "real" targets.
                [z_input, labels_input] = self.generate_latent_points(n_batch)
                y_gan = np.ones((n_batch, 1))
                gan_model.train_on_batch([z_input, labels_input], y_gan)

            # Call the callback at the end of each epoch
            self.on_epoch_end(i)

    # Init the class
    def __init__(self, latent_dim=100, n_classes=10, dataset=None, save_interval=1, save_dir='generated_images', clear_dir=False):
        """Build the models and prepare the output directory.

        Args:
            latent_dim: Size of the generator's latent input vector.
            n_classes: Number of distinct class labels.
            dataset: ``(images, labels)`` pair used for training.
            save_interval: Save a plot and both models every this many epochs.
            save_dir: Directory for plots and checkpoints; created if missing.
            clear_dir: When true, delete the files already in ``save_dir``.
        """
        # These must be set before the models are built: the builders read them.
        self.latent_dim = latent_dim
        self.dataset = dataset
        self.n_classes = n_classes
        self.save_interval = save_interval
        self.save_dir = save_dir
        self.clear_dir = clear_dir
        self.d_model = self.define_discriminator()
        self.g_model = self.define_generator()
        self.gan_model = self.define_gan()
        # Fixed latent points reused by on_epoch_end for its evaluation and plots.
        self.seeded_latent_points = self.generate_latent_points(100)

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)
        if clear_dir:
            for entry in os.listdir(save_dir):
                entry_path = os.path.join(save_dir, entry)
                # os.remove raises on directories, so only delete files/symlinks.
                if os.path.isfile(entry_path) or os.path.islink(entry_path):
                    os.remove(entry_path)
    

In [None]:
#train_dataset = (train_data, train_labels)
#cdcgan = CDCGAN(latent_dim=100, n_classes=26, dataset=train_dataset, save_interval=3, save_dir='cdcgan_gen_img', clear_dir=True)
#cdcgan.train(n_epochs=1000, n_batch=64)