In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from PIL import UnidentifiedImageError

In [4]:


# Create output folder for generated images
output_folder = './Data/fits_filtered2/StyleGAN1000epochs'
# output_folder = './Data/fits_filtered2/augmented/StyleGAN_generated_images1000epochs'
os.makedirs(output_folder, exist_ok=True)

# Load and preprocess dataset
def load_images_from_folder(folder, image_size=(64, 64)):
    images = []
    valid_files = 0
    invalid_files = 0
    for filename in os.listdir(folder):
        try:
            img = load_img(os.path.join(folder, filename), target_size=image_size)
            if img is not None:
                images.append(img_to_array(img))
                valid_files += 1
            else:
                invalid_files += 1
        except (UnidentifiedImageError, OSError):
            print(f"Skipping file {filename}, as it is not a valid image.")
            invalid_files += 1
    print(f"Loaded {valid_files} valid images, skipped {invalid_files} invalid images.")
    return np.array(images)

dataset = load_images_from_folder('./Data/fits_filtered2')
if dataset.size == 0:
    raise ValueError("No valid images found in the dataset.")
dataset = (dataset - 127.5) / 127.5  # Normalize to [-1, 1]

# Mapping Network
def mapping_network(latent_dim=100, num_layers=8):
    inputs = layers.Input(shape=(latent_dim,))
    x = inputs
    for _ in range(num_layers):
        x = layers.Dense(latent_dim, activation="relu")(x)
    return tf.keras.Model(inputs, x, name="MappingNetwork")

# Adaptive Instance Normalization (AdaIN)
def adain(x, w):
    mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    std = tf.sqrt(var + 1e-8)
    gamma = layers.Dense(x.shape[-1])(w)
    beta = layers.Dense(x.shape[-1])(w)
    gamma = tf.reshape(gamma, [-1, 1, 1, x.shape[-1]])
    beta = tf.reshape(beta, [-1, 1, 1, x.shape[-1]])
    return gamma * (x - mean) / std + beta

# StyleGAN Generator
def build_stylegan_generator(latent_dim=100, initial_resolution=8, target_resolution=64):
    w_input = layers.Input(shape=(latent_dim,))
    resolution = initial_resolution
    x = layers.Dense(resolution * resolution * 256, activation="relu")(w_input)
    x = layers.Reshape((resolution, resolution, 256))(x)

    while resolution < target_resolution:
        resolution *= 2
        x = layers.UpSampling2D()(x)
        x = layers.Conv2D(256 // (resolution // 8), kernel_size=3, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
        x = adain(x, w_input)

    output = layers.Conv2D(3, kernel_size=1, activation="tanh")(x)
    return tf.keras.Model(w_input, output, name="StyleGANGenerator")

# Discriminator
def build_discriminator(target_resolution=64):
    inputs = layers.Input(shape=(target_resolution, target_resolution, 3))
    x = inputs
    while x.shape[1] > 4:
        x = layers.Conv2D(128, kernel_size=3, strides=2, padding="same")(x)
        x = layers.LeakyReLU(0.2)(x)
    x = layers.Flatten()(x)
    output = layers.Dense(1, activation="sigmoid")(x)
    return tf.keras.Model(inputs, output, name="StyleGANDiscriminator")

# Initialize networks
latent_dim = 100
mapping = mapping_network(latent_dim=latent_dim)
generator = build_stylegan_generator(latent_dim=latent_dim)
discriminator = build_discriminator()

# Compile discriminator
discriminator.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(0.0002, 0.5), metrics=["accuracy"])

# Combined model
z = layers.Input(shape=(latent_dim,))
w = mapping(z)
img = generator(w)
discriminator.trainable = False
valid = discriminator(img)
combined = tf.keras.Model(z, valid)
combined.compile(loss="binary_crossentropy", optimizer=tf.keras.optimizers.Adam(0.0002, 0.5))

# Train StyleGAN
epochs = 1000
batch_size = 64
save_interval = 999
half_batch = batch_size // 2

for epoch in range(epochs):
    # Train discriminator
    idx = np.random.randint(0, dataset.shape[0], half_batch)
    imgs = dataset[idx]

    noise = np.random.normal(0, 1, (half_batch, latent_dim))
    w = mapping.predict(noise)
    gen_imgs = generator.predict(w)

    d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
    d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train generator
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))

    print(f"{epoch}/{epochs} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]:.2f}%] [G loss: {g_loss}]")

    # Save generated images at intervals
    if epoch % save_interval == 0:
        noise = np.random.normal(0, 1, (25, latent_dim))
        w = mapping.predict(noise)
        gen_imgs = generator.predict(w)

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        # Create a 5x5 grid to save the images
        fig, axs = plt.subplots(5, 5)
        cnt = 0
        for i in range(5):
            for j in range(5):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].axis('off')
                cnt += 1
        
        # Save the grid as a PNG file
        fig.savefig(os.path.join(output_folder, f"epoch_{epoch}.png"))
        plt.close()  # Close the figure to free up memory

# Save final single images
noise = np.random.normal(0, 1, (10, latent_dim))
w = mapping.predict(noise)
gen_imgs = generator.predict(w)

# Rescale images from [-1, 1] to [0, 1]
gen_imgs = 0.5 * gen_imgs + 0.5

for i in range(10):
    plt.imshow(gen_imgs[i])
    plt.axis('off')
    plt.savefig(os.path.join(output_folder, f'final_{i}.png'))  # Save final individual images
    plt.close()  # Close the figure to free up memory


Skipping file augmented, as it is not a valid image.
Skipping file ConditionalGAN, as it is not a valid image.
Skipping file ConditionalGAN100, as it is not a valid image.
Skipping file ConditionalGAN1000, as it is not a valid image.
Skipping file dictionary_0.csv, as it is not a valid image.
Skipping file DoubleDCGan, as it is not a valid image.
Skipping file DoubleDCGan1000epochs, as it is not a valid image.
Skipping file DoubleDCGan100epochs, as it is not a valid image.
Skipping file generated_images, as it is not a valid image.
Skipping file generated_images2, as it is not a valid image.
Skipping file generated_images3, as it is not a valid image.
Skipping file generated_images4, as it is not a valid image.
Skipping file generated_images5, as it is not a valid image.
Skipping file generated_images6, as it is not a valid image.
Skipping file generated_imagesPG-GAN, as it is not a valid image.
Skipping file generated_imagesProgressiveDCGAN, as it is not a valid image.
Skipping file P