In [5]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from PIL import UnidentifiedImageError
import matplotlib.pyplot as plt

# Destination directory for every figure this script writes. The whole nested
# path is created on demand; an already-existing directory is left untouched.
output_folder = './Data/fits_filtered2/ConditionalGAN'
os.makedirs(name=output_folder, exist_ok=True)

# Step 1: Load and preprocess the dataset with labels
def load_images_from_csv(csv_file, folder, image_size=(64, 64)):
    """Load the images listed in a CSV index together with their labels.

    The CSV is read with pandas and must provide an ``output`` column (a file
    name that is joined onto ``folder``) and a ``label`` column. A file whose
    loading raises ``UnidentifiedImageError`` or ``OSError`` is reported on
    stdout and left out of both results, so images and labels stay aligned.

    Args:
        csv_file: Path of the CSV index file.
        folder: Directory the ``output`` entries are relative to.
        image_size: Passed to ``load_img`` as ``target_size``.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays. Pixel values are
        rescaled with ``(x - 127.5) / 127.5`` (i.e. to [-1, 1] for inputs in
        the 0-255 range).
    """
    table = pd.read_csv(csv_file)
    pixel_arrays, kept_labels = [], []

    for _, record in table.iterrows():
        image_path = os.path.join(folder, record['output'])
        try:
            pixels = img_to_array(load_img(image_path, target_size=image_size))
        except (UnidentifiedImageError, OSError):
            print(f"Skipping file {image_path}, as it is not a valid image.")
            continue
        # Only reached for readable files: record image and label together.
        pixel_arrays.append(pixels)
        kept_labels.append(record['label'])

    images = (np.array(pixel_arrays) - 127.5) / 127.5
    return images, np.array(kept_labels)

# Load the training set: the CSV index and the image files it names both live
# under ./Data/fits_filtered2. Report the resulting array shapes.
dataset, labels = load_images_from_csv(
    csv_file='./Data/fits_filtered2/dictionary_0.csv',
    folder='./Data/fits_filtered2',
)
print(f"Dataset shape: {dataset.shape}, Labels shape: {labels.shape}")

# Step 2: Build the Generator model
def build_generator(latent_dim, num_classes):
    """Assemble the conditional generator.

    The class label is embedded into a ``latent_dim``-sized vector and
    multiplied element-wise with the noise vector; the product is fed to a
    Dense -> Reshape(16, 16, 256) stack with two ``UpSampling2D`` stages and
    a final 3-filter convolution with ``tanh`` activation.

    Args:
        latent_dim: Length of the noise vector (and of the label embedding).
        num_classes: Number of rows in the label embedding table.

    Returns:
        A ``tf.keras.Model`` taking ``[noise, label]`` and returning the
        generated image tensor.
    """
    # Conditioning branch: one label id -> flat embedding of size latent_dim.
    label_input = layers.Input(shape=(1,))
    embedded_label = layers.Embedding(num_classes, latent_dim)(label_input)
    label_vector = layers.Flatten()(embedded_label)

    # The label modulates the noise via an element-wise product.
    noise_input = layers.Input(shape=(latent_dim,))
    conditioned_noise = layers.multiply([noise_input, label_vector])

    upsampler = tf.keras.Sequential([
        layers.Dense(256 * 16 * 16, activation="relu", input_dim=latent_dim),
        layers.Reshape((16, 16, 256)),
        layers.BatchNormalization(momentum=0.8),
        # First upsampling stage.
        layers.UpSampling2D(),
        layers.Conv2D(128, kernel_size=4, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.Activation("relu"),
        # Second upsampling stage.
        layers.UpSampling2D(),
        layers.Conv2D(64, kernel_size=4, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.Activation("relu"),
        # Project to 3 channels; tanh keeps outputs in [-1, 1].
        layers.Conv2D(3, kernel_size=4, padding="same"),
        layers.Activation("tanh"),
    ])

    generated = upsampler(conditioned_noise)
    return tf.keras.Model([noise_input, label_input], generated)

# Step 3: Build the Discriminator model
def build_discriminator(img_shape, num_classes):
    """Assemble the conditional discriminator.

    The class label is embedded into ``prod(img_shape)`` values, reshaped to
    ``img_shape`` and multiplied element-wise with the input image. The
    product goes through two stride-2 convolutions (each followed by
    LeakyReLU and Dropout) and a single sigmoid unit.

    Args:
        img_shape: Shape of one input image, e.g. ``(64, 64, 3)``.
        num_classes: Number of rows in the label embedding table.

    Returns:
        A ``tf.keras.Model`` taking ``[image, label]`` and returning one
        sigmoid score per sample.
    """
    img_input = layers.Input(shape=img_shape)

    # Conditioning branch: one label id -> image-shaped embedding.
    label_input = layers.Input(shape=(1,))
    embedded_label = layers.Embedding(num_classes, np.prod(img_shape))(label_input)
    label_plane = layers.Reshape(img_shape)(embedded_label)

    # The label modulates the image via an element-wise product.
    conditioned_img = layers.multiply([img_input, label_plane])

    classifier = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=4, strides=2, input_shape=img_shape, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid'),
    ])

    score = classifier(conditioned_img)
    return tf.keras.Model([img_input, label_input], score)

# Step 4: Set training parameters
epochs = 10
batch_size = 64
latent_dim = 100
half_batch = batch_size // 2

# Both networks look labels up in an Embedding table with `num_classes` rows,
# so every label fed to them must be an integer in 0 .. num_classes - 1.
# The raw CSV labels carry no such guarantee (they may be strings, or numbers
# that do not start at 0 / are not contiguous), so remap them to dense
# indices. Labels that already are 0 .. K-1 map to themselves.
class_values, labels = np.unique(labels, return_inverse=True)
labels = labels.reshape(-1)  # keep a flat (n_samples,) index vector
num_classes = len(class_values)  # Number of unique labels

# Step 5: Build and compile the models
# Adam arguments are positional: learning rate 0.0002, beta_1 0.5.
discriminator = build_discriminator((64, 64, 3), num_classes)
discriminator.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0002, 0.5), metrics=['accuracy'])

generator = build_generator(latent_dim, num_classes)

# The generator takes noise and labels as input and generates images
z = layers.Input(shape=(latent_dim,))
label = layers.Input(shape=(1,))
img = generator([z, label])

# For the combined model, only the generator is trained. The flag is flipped
# after the discriminator's own compile() call and before combined.compile().
discriminator.trainable = False

# The discriminator takes generated images as input and determines validity
valid = discriminator([img, label])

# The combined model (stacked generator and discriminator)
combined = tf.keras.Model([z, label], valid)
combined.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.0002, 0.5))

# Step 6: Train the model
# Every pass performs one discriminator update (half a batch of real images,
# then half a batch of generated ones) followed by one generator update
# through the combined model.
for epoch in range(epochs):
    # --- Discriminator step: real samples and their labels ---
    real_idx = np.random.randint(0, dataset.shape[0], half_batch)
    imgs, labels_batch = dataset[real_idx], labels[real_idx]

    # Fakes are generated for the same labels as the real half-batch.
    noise = np.random.normal(0, 1, (half_batch, latent_dim))
    gen_imgs = generator.predict([noise, labels_batch])

    d_loss_real = discriminator.train_on_batch(
        [imgs, labels_batch], np.ones((half_batch, 1))
    )
    d_loss_fake = discriminator.train_on_batch(
        [gen_imgs, labels_batch], np.zeros((half_batch, 1))
    )
    # Mean of the [loss, accuracy] pairs from the real and fake updates.
    d_loss = 0.5 * (np.asarray(d_loss_real) + np.asarray(d_loss_fake))

    # --- Generator step: full batch of noise, labels drawn from the data ---
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    label_idx = np.random.randint(0, dataset.shape[0], batch_size)
    batch_labels = labels[label_idx]

    # Target of 1 ("real") for every generated sample.
    valid_y = np.ones((batch_size, 1))
    g_loss = combined.train_on_batch([noise, batch_labels], valid_y)

    print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")

    # Only every 10th iteration writes a sample grid.
    if epoch % 10 != 0:
        continue

    noise = np.random.normal(0, 1, (25, latent_dim))
    # Map generator output from [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * generator.predict([noise, labels[:25]]) + 0.5

    # 5 x 5 grid, filled row by row.
    fig, axs = plt.subplots(5, 5)
    for ax, sample in zip(axs.flat, gen_imgs):
        ax.imshow(sample)
        ax.axis('off')
    plt.savefig(os.path.join(output_folder, f'epoch_{epoch}.png'))
    plt.close()  # release the figure's memory

# Step 7: Generate New Data
# Ten samples conditioned on the first ten dataset labels, each saved to its
# own PNG in the output folder.
noise = np.random.normal(0, 1, (10, latent_dim))
# Map generator output from [-1, 1] back to [0, 1] for display.
gen_imgs = 0.5 * generator.predict([noise, labels[:10]]) + 0.5

for i, sample in enumerate(gen_imgs):
    plt.imshow(sample)
    plt.axis('off')
    plt.savefig(os.path.join(output_folder, f'final_{i}.png'))
    plt.close()  # release the figure's memory


Skipping file ./Data/fits_filtered2\tic12.fit, as it is not a valid image.
Skipping file ./Data/fits_filtered2\tic13.fit, as it is not a valid image.
Skipping file ./Data/fits_filtered2\tic14.fit, as it is not a valid image.
Dataset shape: (140, 64, 64, 3), Labels shape: (140,)
0 [D loss: 0.6945118010044098, acc.: 3.125] [G loss: 0.6928286552429199]
1 [D loss: 0.6899002194404602, acc.: 50.0] [G loss: 0.6798086166381836]
2 [D loss: 0.6863475441932678, acc.: 50.0] [G loss: 0.6683561205863953]
3 [D loss: 0.6826144456863403, acc.: 50.0] [G loss: 0.650938868522644]
4 [D loss: 0.6786059439182281, acc.: 50.0] [G loss: 0.6336326599121094]
5 [D loss: 0.673251748085022, acc.: 50.0] [G loss: 0.6142532825469971]
6 [D loss: 0.6668217182159424, acc.: 50.0] [G loss: 0.5897880792617798]
7 [D loss: 0.6620357036590576, acc.: 50.0] [G loss: 0.5669037103652954]
8 [D loss: 0.6509718596935272, acc.: 50.0] [G loss: 0.5269441604614258]
9 [D loss: 0.6444802284240723, acc.: 50.0] [G loss: 0.503311812877655]
