In [13]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from tensorflow.keras.optimizers import Adam
from scipy.linalg import sqrtm

In [None]:
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. Because ``encoding='bytes'`` is used, 8-bit
        strings written by Python 2 come back as ``bytes`` (so dictionary
        keys look like ``b'data'``).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only call this on files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # Local is named `batch` rather than `dict` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch


In [None]:
def load_cifar10_from_folder(folder_path, batch_size=128, shuffle_buffer=None):
    """Build a shuffled, batched tf.data pipeline from local CIFAR-10 batch files.

    Reads ``data_batch_1`` .. ``data_batch_5`` from ``folder_path``, stacks their
    ``b'data'`` arrays, reshapes them to channel-last 32x32x3 images and scales
    the pixel values to the range [-1, 1].

    Args:
        folder_path: Directory containing the five ``data_batch_{i}`` files.
        batch_size: Number of images per batch. Defaults to 128.
        shuffle_buffer: Size of the shuffle buffer. ``None`` (default) uses the
            number of loaded images, i.e. a full shuffle. This matches the
            previously hard-coded 50000 when the five files hold 50000 images
            in total (expected for standard CIFAR-10 -- TODO confirm).

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of shape
        ``(<=batch_size, 32, 32, 3)``; the last batch may be smaller.
    """
    batch_arrays = [
        unpickle(os.path.join(folder_path, f'data_batch_{i}'))[b'data']
        for i in range(1, 6)
    ]
    # Each row is reshaped as (channels, height, width); move channels last for Keras.
    train_images = np.vstack(batch_arrays).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    train_images = (train_images.astype('float32') - 127.5) / 127.5  # Normalize to [-1, 1]
    if shuffle_buffer is None:
        shuffle_buffer = len(train_images)
    train_dataset = (
        tf.data.Dataset.from_tensor_slices(train_images)
        .shuffle(shuffle_buffer)
        .batch(batch_size)
    )
    return train_dataset

In [None]:
def build_generator():
    """Assemble the generator network.

    The model takes a 100-dimensional latent vector, projects it to an
    8x8x256 feature map, and applies two stride-2 transposed convolutions
    followed by a final 3-channel transposed convolution with tanh activation.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    net = tf.keras.Sequential()

    # Dense projection of the latent vector, folded into an 8x8 map with 256 channels.
    net.add(layers.Dense(8 * 8 * 256, use_bias=False, input_shape=(100,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((8, 8, 256)))

    # Two identical upsampling stages that differ only in their filter count.
    for channels in (128, 64):
        net.add(layers.Conv2DTranspose(channels, (4, 4), strides=(2, 2), padding='same', use_bias=False))
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Output layer: 3 channels squashed into [-1, 1] by tanh.
    net.add(layers.Conv2DTranspose(3, (3, 3), activation='tanh', padding='same'))
    return net

In [None]:
def build_discriminator():
    """Assemble the discriminator network.

    Two stride-2 convolution stages (64 then 128 filters), each followed by
    LeakyReLU and 30% dropout, feed a flattened single-unit sigmoid output.
    The first convolution declares the expected input shape of 32x32x3.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    # Per-stage filter count plus any extra keyword arguments for the conv layer;
    # only the first stage carries the input shape.
    stage_specs = (
        (64, {'input_shape': [32, 32, 3]}),
        (128, {}),
    )

    stack = []
    for filters, extra_kwargs in stage_specs:
        stack.append(layers.Conv2D(filters, (3, 3), strides=(2, 2), padding='same', **extra_kwargs))
        stack.append(layers.LeakyReLU())
        stack.append(layers.Dropout(0.3))

    # Classification head: a single sigmoid unit over the flattened features.
    stack.append(layers.Flatten())
    stack.append(layers.Dense(1, activation='sigmoid'))

    return tf.keras.Sequential(stack)

In [None]:
def generator_loss(fake_output):
    """Binary cross-entropy of the discriminator's fake scores against all-ones labels.

    Args:
        fake_output: Discriminator outputs for generated images.

    Returns:
        The cross-entropy loss value as computed by
        ``tf.keras.losses.BinaryCrossentropy`` with default settings.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    # The generator is rewarded when its fakes are scored as "real" (label 1).
    target_labels = tf.ones_like(fake_output)
    return bce(target_labels, fake_output)

def discriminator_loss(real_output, fake_output):
    """Sum of binary cross-entropies for real (label 1) and fake (label 0) scores.

    Args:
        real_output: Discriminator outputs for real images.
        fake_output: Discriminator outputs for generated images.

    Returns:
        ``real_loss + fake_loss``, each computed by
        ``tf.keras.losses.BinaryCrossentropy`` with default settings.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    loss_on_real = bce(tf.ones_like(real_output), real_output)
    loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

In [None]:
def train_gan(generator, discriminator, dataset, epochs=50, noise_dim=100):
    """Train the generator and discriminator adversarially.

    For every batch, one simultaneous update is applied to both networks:
    the generator is optimized with ``generator_loss`` and the discriminator
    with ``discriminator_loss``, each using its own Adam optimizer with a
    learning rate of 1e-4.

    Args:
        generator: Model mapping noise of shape ``(batch, noise_dim)`` to images.
        discriminator: Model mapping images to real/fake scores.
        dataset: Iterable of image batches (e.g. a ``tf.data.Dataset``).
        epochs: Number of full passes over ``dataset``. Defaults to 50.
        noise_dim: Length of the latent vector fed to the generator. Defaults
            to 100 and must match the generator's input size.

    Returns:
        None. The models are updated in place; progress is printed once per epoch.
    """
    generator_optimizer = Adam(1e-4)
    discriminator_optimizer = Adam(1e-4)

    @tf.function
    def train_step(images):
        # Draw as many noise vectors as there are real images in this batch,
        # so a smaller final batch is not paired with a fixed 128 fakes.
        batch_size = tf.shape(images)[0]
        noise = tf.random.normal(tf.stack([batch_size, noise_dim]))
        # Separate tapes so each network's gradients are taken independently.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = generator(noise, training=True)
            real_output = discriminator(images, training=True)
            fake_output = discriminator(generated_images, training=True)
            gen_loss = generator_loss(fake_output)
            disc_loss = discriminator_loss(real_output, fake_output)
        gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        print(f'Epoch {epoch + 1} completed')

In [None]:
def generate_and_show_images(generator):
    """Sample 16 images from the generator and display them in a 4x4 grid.

    Args:
        generator: Model called with noise of shape ``(16, 100)`` in
            inference mode (``training=False``).

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    latent_batch = tf.random.normal([16, 100])
    samples = generator(latent_batch, training=False)
    # Shift and scale pixel values by (x + 1) / 2 before plotting; this maps
    # [-1, 1] onto [0, 1] (assumes the generator output lies in [-1, 1]).
    rescaled = (samples + 1) / 2

    figure, grid = plt.subplots(4, 4, figsize=(6, 6))
    panels = grid.flatten()
    for idx in range(len(panels)):
        panel = panels[idx]
        panel.imshow(rescaled[idx])
        panel.axis('off')
    plt.show()

In [None]:
def calculate_fid(model, real_images, fake_images):
    """Compute the Frechet distance between feature statistics of two image sets.

    Both image sets are passed through ``model.predict``; the mean and
    covariance of the resulting activations are compared with

        ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 . sigma2))

    Args:
        model: Object exposing ``predict(images)`` that returns one feature
            row per image.
        real_images: Input passed to ``model.predict`` for the reference set.
        fake_images: Input passed to ``model.predict`` for the generated set.

    Returns:
        The distance as a floating-point scalar.

    Raises:
        ValueError: If either set yields fewer than two activation rows, since
            a covariance cannot be estimated from a single sample.
    """
    # float64 keeps the mean/covariance arithmetic in a single precision.
    act1 = np.asarray(model.predict(real_images), dtype=np.float64)
    act2 = np.asarray(model.predict(fake_images), dtype=np.float64)
    if act1.shape[0] < 2 or act2.shape[0] < 2:
        raise ValueError(
            "calculate_fid needs at least 2 samples per image set to estimate a covariance"
        )

    # np.cov collapses to a 0-d array for a single feature; sqrtm needs a square matrix.
    mu1, sigma1 = act1.mean(axis=0), np.atleast_2d(np.cov(act1, rowvar=False))
    mu2, sigma2 = act2.mean(axis=0), np.atleast_2d(np.cov(act2, rowvar=False))
    ssdiff = np.sum((mu1 - mu2)**2.0)

    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # A (near-)singular product can yield inf/nan; retry with a small
        # diagonal offset added to both covariances.
        offset = np.eye(sigma1.shape[0]) * 1e-6
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Discard the imaginary part that sqrtm may introduce.
        covmean = covmean.real

    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [None]:
# Main Execution
# Runs the full pipeline end to end: load data, build both networks, train, show samples.

# Location of the CIFAR-10 batch files, relative to the current working directory.
folder_path = "./cifar-10"
train_dataset = load_cifar10_from_folder(folder_path)

# Fresh, untrained networks.
generator = build_generator()
discriminator = build_discriminator()

# Train for 50 epochs; train_gan prints one progress line per epoch.
# NOTE(review): no random seed is set anywhere visible, so results will differ between runs.
train_gan(generator, discriminator, train_dataset, epochs=50)

# Display a 4x4 grid of images sampled from the trained generator.
generate_and_show_images(generator)