In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import os
from PIL import UnidentifiedImageError
from tensorflow.keras import backend as K
from tensorflow.keras.applications import VGG19

# Create a nested folder for the generated images
# output_folder = './Data/fits_filtered9/augmented_images/gan_output2'
output_folder = './Data/fits_filtered4/gan_output1'
os.makedirs(output_folder, exist_ok=True)

# Step 1: Load and Preprocess Dataset
def load_images_from_folder(folder, image_size=(128, 128)):
    images = []
    valid_files, invalid_files = 0, 0
    for filename in os.listdir(folder):
        try:
            img = load_img(os.path.join(folder, filename), target_size=image_size)
            if img is not None:
                images.append(img_to_array(img))
                valid_files += 1
        except (UnidentifiedImageError, OSError) as e:
            print(f"Error loading {filename}: {e}")
            invalid_files += 1
    print(f"Loaded {valid_files} valid images, skipped {invalid_files} invalid images.")
    return np.array(images)

# dataset = load_images_from_folder('./Data/fits_filtered9/augmented_images')
dataset = load_images_from_folder('./Data/fits_filtered4')
print(f"Dataset shape: {dataset.shape}")
if dataset.size == 0:
    raise ValueError("No valid images found in the dataset. Please check the image files.")
dataset = (dataset - 127.5) / 127.5  # Normalize to [-1, 1]

# Step 2: Build the GAN
# Generator
def build_generator():
    model = tf.keras.Sequential([
        layers.Dense(512 * 8 * 8, activation="relu", input_dim=100),
        layers.Reshape((8, 8, 512)),
        layers.BatchNormalization(momentum=0.8),
        layers.UpSampling2D(),  # 16x16
        layers.Conv2D(256, kernel_size=3, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.Activation("relu"),
        layers.UpSampling2D(),  # 32x32
        layers.Conv2D(128, kernel_size=3, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.Activation("relu"),
        layers.UpSampling2D(),  # 64x64
        layers.Conv2D(64, kernel_size=3, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.Activation("relu"),
        layers.UpSampling2D(),  # 128x128
        layers.Conv2D(3, kernel_size=3, padding="same"),
        layers.Activation("tanh")
    ])
    return model

# Discriminator
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Conv2D(64, kernel_size=3, strides=2, input_shape=(128, 128, 3), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(128, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(256, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
        layers.Conv2D(512, kernel_size=3, strides=2, padding="same"),
        layers.BatchNormalization(momentum=0.8),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

# Wasserstein loss
def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

# Load VGG19 model for perceptual loss
vgg = VGG19(include_top=False, weights='imagenet', input_shape=(128, 128, 3))
vgg.trainable = False
perceptual_model = tf.keras.Model(inputs=vgg.input, outputs=vgg.get_layer('block5_conv4').output)

# Perceptual loss
def perceptual_loss(y_true, y_pred):
    y_true_features = perceptual_model(y_true)
    y_pred_features = perceptual_model(y_pred)
    return K.mean(K.square(y_true_features - y_pred_features))

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss=wasserstein_loss, optimizer=tf.keras.optimizers.Adam(0.0001, 0.5), metrics=['accuracy'])

# Build the generator
generator = build_generator()

# The generator takes noise as input and generates images
z = layers.Input(shape=(100,))
img = generator(z)

# For the combined model, only the generator is trained
discriminator.trainable = False

# The discriminator takes generated images as input and determines validity
valid = discriminator(img)

# The combined model (stacked generator and discriminator)
combined = tf.keras.Model(z, valid)
combined.compile(loss=wasserstein_loss, optimizer=tf.keras.optimizers.Adam(0.0001, 0.5))

# Step 3: Train the GAN
epochs = 10
batch_size = 64
save_interval = 10

X_train = dataset
half_batch = max(1, batch_size // 2)

for epoch in range(epochs):
    # Train Discriminator
    idx = np.random.randint(0, X_train.shape[0], half_batch)
    imgs = X_train[idx]

    noise = np.random.normal(0, 1, (half_batch, 100))
    gen_imgs = generator.predict(noise)

    d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
    d_loss_fake = discriminator.train_on_batch(gen_imgs, -np.ones((half_batch, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train Generator
    noise = np.random.normal(0, 1, (batch_size, 100))
    valid_y = np.ones((batch_size, 1))
    g_loss = combined.train_on_batch(noise, valid_y)

    print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")

    if epoch % save_interval == 0:
        noise = np.random.normal(0, 1, (25, 100))
        gen_imgs = generator.predict(noise)
        gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images 0 - 1

        fig, axs = plt.subplots(5, 5)
        cnt = 0
        for i in range(5):
            for j in range(5):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].axis('off')
                cnt += 1
        plt.suptitle(f'Epoch {epoch}')
        plt.tight_layout()
        plt.savefig(os.path.join(output_folder, f'epoch_{epoch}.png'))
        plt.close()

        generator.save(f'generator_epoch_{epoch}.h5')
        discriminator.save(f'discriminator_epoch_{epoch}.h5')

# Final image generation
noise = np.random.normal(0, 1, (10, 100))
gen_imgs = generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images 0 - 1

for i in range(10):
    plt.imshow(gen_imgs[i])
    plt.axis('off')
    plt.savefig(os.path.join(output_folder, f'final_{i}.png'))
    plt.close()


Error loading augmented: [Errno 13] Permission denied: './Data/fits_filtered4\\augmented'
Error loading data: [Errno 13] Permission denied: './Data/fits_filtered4\\data'
Error loading dictionary_0.csv: cannot identify image file <_io.BytesIO object at 0x0000011A498F9F80>
Error loading gan_output1: [Errno 13] Permission denied: './Data/fits_filtered4\\gan_output1'
Error loading gan_output2: [Errno 13] Permission denied: './Data/fits_filtered4\\gan_output2'
Loaded 47 valid images, skipped 5 invalid images.
Dataset shape: (47, 128, 128, 3)
0 [D loss: 0.07077676057815552, acc.: 31.25] [G loss: 0.5034868717193604]
1 [D loss: -0.199498750269413, acc.: 0.0] [G loss: 0.5362001657485962]
2 [D loss: -0.3538418374955654, acc.: 0.0] [G loss: 0.6114389896392822]
3 [D loss: -0.4283123165369034, acc.: 0.0] [G loss: 0.6722100973129272]
4 [D loss: -0.45474932342767715, acc.: 0.0] [G loss: 0.7048105001449585]
5 [D loss: -0.4708592966198921, acc.: 0.0] [G loss: 0.708672046661377]
6 [D loss: -0.4773478526