In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from scipy import linalg
import os
import time

## Hyperparameters

In [None]:
# Global settings read by the cells below.
latent_dim = 100                    # length of the generator's input noise vector
batch_size = 64
epochs = 50
lr, beta1 = 2e-4, 0.5               # Adam learning rate and beta_1 for both networks
output_dir = "/kaggle/working/"     # where sample images and model files are written
# NOTE(review): the FID cell below defines a function that is also named
# `compute_fid`; once that cell runs, this flag is replaced by the function.
compute_fid = False

## Load and preprocess CIFAR-10 dataset


In [None]:
def load_cifar10():
    """Return the CIFAR-10 training images as float32 values scaled to [-1, 1].

    Labels and the test split are discarded; only the training images are used.
    """
    train_split, _test_split = tf.keras.datasets.cifar10.load_data()
    images = train_split[0].astype(np.float32)
    # Map pixel values [0, 255] onto [-1, 1], the range of the generator's tanh output.
    return (images - 127.5) / 127.5


## Build Generator


In [None]:
def build_generator():
    """Build the generator: a `latent_dim` noise vector -> a 32x32x3 image in [-1, 1].

    A dense projection is reshaped to a 4x4x256 feature map. Three stride-2
    transposed convolutions then upsample it 4 -> 8 -> 16 -> 32. Each one is
    preceded by BatchNorm + LeakyReLU(0.2). The last produces 3 channels through tanh.
    """
    stack = [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(4 * 4 * 256),
        layers.Reshape((4, 4, 256)),
    ]
    for filters, activation in ((128, None), (64, None), (3, 'tanh')):
        stack += [
            layers.BatchNormalization(),
            layers.LeakyReLU(negative_slope=0.2),
            layers.Conv2DTranspose(filters, 4, strides=2, padding='same', activation=activation),
        ]
    return models.Sequential(stack)

## Build Discriminator


In [None]:
def build_discriminator():
    """Build the discriminator: a 32x32x3 image -> one unbounded logit (no sigmoid).

    Three stride-2 4x4 convolutions (64, 128, 256 filters) each use LeakyReLU(0.2).
    Dropout(0.3) follows the first two. The logit is consumed by
    sigmoid_cross_entropy_with_logits in the relativistic loss.
    """
    stack = [layers.Input(shape=(32, 32, 3))]
    for filters, with_dropout in ((64, True), (128, True), (256, False)):
        stack.append(layers.Conv2D(filters, 4, strides=2, padding='same'))
        stack.append(layers.LeakyReLU(negative_slope=0.2))
        if with_dropout:
            stack.append(layers.Dropout(0.3))
    stack += [layers.Flatten(), layers.Dense(1)]
    return models.Sequential(stack)

## Relativistic average GAN loss


In [None]:
def relativistic_loss(real_logits, fake_logits):
    """Relativistic-average BCE terms for a discriminator.

    Each real logit is compared with the mean fake logit (target 1). Each fake logit
    is compared with the mean real logit (target 0).

    Returns:
        (real_loss, fake_loss): the two mean BCE terms as scalar tensors.
    """
    def _mean_bce(logits, target_is_one):
        labels = tf.ones_like(logits) if target_is_one else tf.zeros_like(logits)
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))

    real_vs_fake = real_logits - tf.reduce_mean(fake_logits)
    fake_vs_real = fake_logits - tf.reduce_mean(real_logits)
    return _mean_bce(real_vs_fake, True), _mean_bce(fake_vs_real, False)

## Compute discriminator accuracy


In [None]:
def compute_discriminator_accuracy(real_logits, fake_logits):
    """Fraction of samples the relativistic discriminator ranks correctly.

    The relativistic losses only see the differences real - mean(fake) and
    fake - mean(real). Adding the same constant to every logit therefore changes
    nothing in training, and thresholding the raw logits at 0 measures an
    arbitrary offset rather than the discriminator's decision.

    Instead, the comparisons are the same ones the loss is trained on:
    - a real sample is correct when its logit is above the mean fake logit;
    - a fake sample is correct when its logit is at or below the mean real logit.

    Returns:
        Scalar float32 tensor: the mean of the real-side and fake-side accuracies.
    """
    real_vs_fake = real_logits - tf.reduce_mean(fake_logits)
    fake_vs_real = fake_logits - tf.reduce_mean(real_logits)
    # sigmoid(x) > 0.5 is equivalent to x > 0, so the sigmoid is unnecessary.
    real_correct = tf.reduce_mean(tf.cast(real_vs_fake > 0.0, tf.float32))
    fake_correct = tf.reduce_mean(tf.cast(fake_vs_real <= 0.0, tf.float32))
    return (real_correct + fake_correct) / 2.0

## Training step (one RaGAN update)


In [None]:
@tf.function
def train_step(real_images, generator, discriminator, g_optimizer, d_optimizer):
    """Run one simultaneous relativistic-average GAN update of D and G.

    Args:
        real_images: batch of real images. Presumably (B, 32, 32, 3) in [-1, 1]
            to match build_discriminator's input — TODO confirm for other data.
        generator, discriminator: the Keras models being trained.
        g_optimizer, d_optimizer: optimizers for the generator / discriminator weights.

    Returns:
        (d_loss, g_loss, accuracy) as scalar tensors.
    """
    # Generate as many fakes as there are reals (the final batch may be shorter).
    n_real = tf.shape(real_images)[0]
    noise = tf.random.normal([n_real, latent_dim])

    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = generator(noise, training=True)
        real_logits = discriminator(real_images, training=True)
        fake_logits = discriminator(fake_images, training=True)

        # D: reals should score above the average fake, fakes below the average real.
        d_real_loss, d_fake_loss = relativistic_loss(real_logits, fake_logits)
        d_loss = d_real_loss + d_fake_loss

        # G: the same objective with the roles swapped. The RaGAN generator loss has
        # both terms. The real-sample term still carries gradient to G through
        # mean(fake_logits), so dropping it would discard half of G's signal.
        g_fake_loss, g_real_loss = relativistic_loss(fake_logits, real_logits)
        g_loss = g_fake_loss + g_real_loss

    d_gradients = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_gradients = g_tape.gradient(g_loss, generator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_gradients, discriminator.trainable_variables))
    g_optimizer.apply_gradients(zip(g_gradients, generator.trainable_variables))

    accuracy = compute_discriminator_accuracy(real_logits, fake_logits)
    return d_loss, g_loss, accuracy

## Compute FID (Fréchet Inception Distance)


In [None]:
def _inception_activations(inception_model, images, batch_size=200):
    """Return pooled InceptionV3 features for `images` as a float64 array.

    Images are resized to 299x299 one batch at a time. Resizing the whole set at
    once would need ~1 GB per 1000 images in float32.
    `images` is assumed to be in [-1, 1] (as produced by load_cifar10 and the
    tanh generator) — TODO confirm for other callers.
    """
    features = []
    for start in range(0, len(images), batch_size):
        batch = tf.image.resize(images[start:start + batch_size], [299, 299])
        # [-1, 1] -> [0, 255], the pixel range preprocess_input expects.
        batch = preprocess_input(batch * 127.5 + 127.5)
        features.append(np.asarray(inception_model(batch, training=False)))
    return np.concatenate(features, axis=0).astype(np.float64)


def compute_fid(real_images, generator, num_images=2000, timeout=300, batch_size=200):
    """Estimate the Frechet Inception Distance between real and generated images.

    Args:
        real_images: real images in [-1, 1]; the first `num_images` are used.
        generator: model mapping (N, latent_dim) noise to images in [-1, 1].
        num_images: number of real and of generated samples to compare. It is capped
            at len(real_images) so both sets have the same size.
        timeout: time budget in seconds. It is checked between stages (model load,
            real features, fake features), so later stages are skipped once it is
            exceeded. A stage already running is not interrupted.
        batch_size: number of images pushed through InceptionV3 at a time.

    Returns:
        The FID as a float clamped at 0, or None on timeout or on any error.
        Errors are printed rather than raised because this metric is best-effort.

    Note: this function shares its name with the `compute_fid` hyperparameter flag,
    so defining it replaces that flag in the notebook namespace.
    """
    try:
        start_time = time.time()

        def timed_out():
            if time.time() - start_time > timeout:
                print("FID computation timed out")
                return True
            return False

        inception_model = InceptionV3(include_top=False, pooling='avg', input_shape=(299, 299, 3))
        if timed_out():
            return None

        # Compare equally sized real and fake sets, limited by the real data available.
        n = min(num_images, len(real_images))
        real_activations = _inception_activations(inception_model, real_images[:n], batch_size)
        if timed_out():
            return None

        noise = tf.random.normal([n, latent_dim])
        fake_images = generator(noise, training=False)
        fake_activations = _inception_activations(inception_model, fake_images, batch_size)
        if timed_out():
            return None

        mu_real = real_activations.mean(axis=0)
        mu_fake = fake_activations.mean(axis=0)
        # Small diagonal ridge guards against singular covariance matrices in sqrtm.
        ridge = np.eye(real_activations.shape[1]) * 1e-6
        sigma_real = np.cov(real_activations, rowvar=False) + ridge
        sigma_fake = np.cov(fake_activations, rowvar=False) + ridge

        diff = mu_real - mu_fake
        covmean = linalg.sqrtm(sigma_real.dot(sigma_fake), disp=False)[0]
        if np.iscomplexobj(covmean):
            # Numerical error can introduce tiny imaginary parts; keep the real part.
            covmean = covmean.real
        fid_score = diff.dot(diff) + np.trace(sigma_real + sigma_fake - 2 * covmean)
        return max(float(fid_score), 0.0)
    except Exception as e:
        print(f"FID computation failed: {e}")
        return None

## Train the GAN


In [None]:
def _save_sample_grid(generator, noise, path):
    """Save a 4x4 grid of generator(noise) samples to `path`.

    Expects 16 noise rows. The generator's [-1, 1] output is mapped back to uint8 pixels.
    """
    samples = generator(noise, training=False)
    pixels = (samples * 127.5 + 127.5).numpy().astype(np.uint8)
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for ax, image in zip(axes.flat, pixels):
        ax.imshow(image)
        ax.axis('off')
    fig.savefig(path)
    # Close explicitly so one figure per epoch does not accumulate in memory.
    plt.close(fig)


def _save_models(generator, discriminator, suffix, kind, epochs_trained, final=False):
    """Save both models to output_dir as `<name><suffix>.keras` and log each path.

    `kind` ("checkpoint" or "model") and `final` only affect the log wording.
    """
    prefix = "final " if final else ""
    for name, model in (("generator", generator), ("discriminator", discriminator)):
        path = os.path.join(output_dir, f"{name}{suffix}.keras")
        # Keras 3 infers its native format from the .keras extension; the
        # deprecated save_format argument only produced a warning.
        model.save(path)
        print(f"Saved {prefix}{name} {kind}: {path} (trained for {epochs_trained} epochs)")


def _print_download_instructions():
    """Print how to retrieve the saved models and sample images from Kaggle."""
    print("\nTo download models and images:")
    print("1. Go to the 'Output' tab (right sidebar).")
    print("2. Navigate to /kaggle/working/.")
    print("3. Download 'generator.keras', 'discriminator.keras', and any 'generator_epoch_X.keras' or 'discriminator_epoch_X.keras' files.")
    print("Alternatively, create a Kaggle dataset:")
    print("Run the following in a new cell after training:")
    print("```bash")
    print("!mkdir -p /kaggle/working/gan-models")
    print("!cp /kaggle/working/*.keras /kaggle/working/gan-models/")
    print("!cp /kaggle/working/*.png /kaggle/working/gan-models/")
    # `kaggle datasets create` requires a dataset-metadata.json in the folder;
    # `init` writes a template whose title and id must be filled in first.
    print("!kaggle datasets init -p /kaggle/working/gan-models")
    print("# edit 'title' and 'id' in /kaggle/working/gan-models/dataset-metadata.json, then:")
    print("!kaggle datasets create -p /kaggle/working/gan-models -u -r zip")
    print("```")
    print("Then download the dataset from your Kaggle profile.")


def train_gan(dataset, run_fid=None):
    """Train the RaGAN on `dataset`; save per-epoch samples, checkpoints and final models.

    Args:
        dataset: array of training images in [-1, 1] (as returned by load_cifar10).
        run_fid: whether to compute FID after training. When None, the global
            `compute_fid` hyperparameter is used, but only if it is still a bool.
            The FID cell defines a function with the same name, which replaces the
            flag. A plain truthiness test then always passed and FID ran regardless
            of the setting. A replaced flag is now treated as False (its configured
            default). Pass run_fid=True to force FID.

    Returns:
        (generator, discriminator) after training.
    """
    generator = build_generator()
    discriminator = build_discriminator()
    g_optimizer = tf.keras.optimizers.Adam(lr, beta_1=beta1)
    d_optimizer = tf.keras.optimizers.Adam(lr, beta_1=beta1)

    # A buffer as large as the dataset gives a full shuffle each epoch.
    dataset_tf = (
        tf.data.Dataset.from_tensor_slices(dataset)
        .shuffle(max(len(dataset), 1))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Fixed latent vectors so successive sample grids show the same points evolving.
    sample_noise = tf.random.normal([16, latent_dim])

    for epoch in range(epochs):
        epoch_num = epoch + 1
        print(f"Epoch {epoch_num}/{epochs}")
        accuracy_sum = 0.0
        steps = 0
        for step, real_images in enumerate(dataset_tf):
            d_loss, g_loss, accuracy = train_step(real_images, generator, discriminator, g_optimizer, d_optimizer)
            if step % 100 == 0:
                print(f"Step {step}, D Loss: {float(d_loss):.4f}, G Loss: {float(g_loss):.4f}, D Accuracy: {float(accuracy):.4f}")
            accuracy_sum += accuracy
            steps += 1

        if steps:
            print(f"Epoch {epoch_num} Average D Accuracy: {float(accuracy_sum) / steps:.4f}")

        # Named by the number of completed epochs, consistent with the checkpoint files.
        _save_sample_grid(generator, sample_noise, os.path.join(output_dir, f"output_epoch_{epoch_num}.png"))

        if epoch_num % 10 == 0:
            _save_models(generator, discriminator, f"_epoch_{epoch_num}", "checkpoint", epoch_num)

    _save_models(generator, discriminator, "", "model", epochs, final=True)
    if epochs % 10 != 0:
        # The last epoch was not covered by the every-10-epochs checkpoints above.
        _save_models(generator, discriminator, f"_epoch_{epochs}", "checkpoint", epochs, final=True)

    if run_fid is None:
        run_fid = isinstance(compute_fid, bool) and compute_fid
    if run_fid and not callable(compute_fid):
        print("FID requested but compute_fid() is not defined; skipping")
        run_fid = False
    if run_fid:
        print("Computing FID...")
        fid_score = compute_fid(dataset, generator)
        if fid_score is not None:
            print(f"FID Score: {fid_score:.2f}")
        else:
            print("FID computation skipped due to error or timeout")

    _print_download_instructions()

    return generator, discriminator

## Main execution


In [None]:

# Load the normalized CIFAR-10 training images, then train the GAN on them.
x_train = load_cifar10()
generator, discriminator = train_gan(dataset=x_train)