# 13. GANs with TensorFlow Keras
**Location: TensorVerseHub/notebooks/05_generative_models/13_gans_with_tensorflow_keras.ipynb**

In [None]:
import tensorflow as tf
import numpy as np
print(f"TensorFlow version: {tf.__version__}")

# GANs with TensorFlow/tf.keras

**File Location:** `notebooks/05_generative_models/13_gans_with_tensorflow_keras.ipynb`

Master Generative Adversarial Networks (GANs) using tf.keras Sequential and Functional APIs. Build, train, and optimize various GAN architectures including DCGAN, conditional GANs, and Wasserstein GANs for image generation and data synthesis.

## Learning Objectives
- Understand GAN architecture and adversarial training principles
- Implement DCGAN using tf.keras Sequential and Functional APIs
- Build conditional GANs for controlled generation
- Master Wasserstein GAN with gradient penalty (WGAN-GP)
- Apply advanced GAN training techniques and stabilization methods
- Generate high-quality synthetic images and data

---

## 1. GAN Fundamentals and Basic Implementation

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow import keras
from tensorflow.keras import layers
import os
from sklearn.preprocessing import LabelEncoder
import time
import warnings
# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')

print(f"TensorFlow version: {tf.__version__}")
# Seed TensorFlow's and NumPy's global random generators with a fixed value.
tf.random.set_seed(42)
np.random.seed(42)

# Basic GAN implementation using tf.keras Sequential API
class BasicGAN:
    """Minimal GAN made of two Sequential networks.

    The discriminator is compiled on its own for direct training, and a
    stacked generator -> discriminator model (``combined``) is compiled for
    the generator update.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        img_shape: ``(height, width, channels)`` expected by the discriminator.
    """

    def __init__(self, latent_dim=100, img_shape=(28, 28, 1)):
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        self.img_height, self.img_width, self.channels = img_shape

        # Generator is created before the discriminator.
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

        # Stand-alone discriminator: binary cross-entropy plus accuracy.
        self.discriminator.compile(
            optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
            loss='binary_crossentropy',
            metrics=['accuracy'],
        )

        # The discriminator is flagged non-trainable before the stacked model
        # is assembled and compiled.
        self.discriminator.trainable = False
        latent_in = tf.keras.Input(shape=(self.latent_dim,))
        generated = self.generator(latent_in)
        verdict = self.discriminator(generated)

        self.combined = tf.keras.Model(latent_in, verdict)
        self.combined.compile(
            optimizer=tf.keras.optimizers.Adam(0.0002, 0.5),
            loss='binary_crossentropy',
        )

    def build_generator(self):
        """Return the Sequential generator (7x7 seed, two stride-2 upsamplings)."""
        net = tf.keras.Sequential(name='generator')

        # Dense projection reshaped into a 7x7x256 feature map.
        net.add(layers.Dense(7 * 7 * 256, input_shape=(self.latent_dim,)))
        net.add(layers.Reshape((7, 7, 256)))
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU(alpha=0.2))

        # Two stride-2 stages: 7x7 -> 14x14 -> 28x28.
        for n_filters in (128, 64):
            net.add(layers.Conv2DTranspose(n_filters, (5, 5), strides=(2, 2),
                                           padding='same'))
            net.add(layers.BatchNormalization())
            net.add(layers.LeakyReLU(alpha=0.2))

        # tanh keeps the output pixels in [-1, 1].
        net.add(layers.Conv2DTranspose(self.channels, (5, 5), strides=(1, 1),
                                       padding='same', activation='tanh'))
        return net

    def build_discriminator(self):
        """Return the Sequential discriminator ending in a sigmoid unit."""
        net = tf.keras.Sequential(name='discriminator')

        net.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                              input_shape=self.img_shape))
        net.add(layers.LeakyReLU(alpha=0.2))
        net.add(layers.Dropout(0.3))

        for n_filters in (128, 256):
            net.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same'))
            net.add(layers.LeakyReLU(alpha=0.2))
            net.add(layers.Dropout(0.3))

        net.add(layers.Flatten())
        net.add(layers.Dense(1, activation='sigmoid'))
        return net

    def train_step(self, real_images, batch_size):
        """Run one discriminator update pair followed by one generator update.

        Returns:
            ``(d_loss, g_loss)`` where ``d_loss`` is the element-wise average of
            the discriminator's results on the real and the fake batch, and
            ``g_loss`` is the result of training the stacked model.
        """
        # Fakes for the discriminator update.
        z = tf.random.normal([batch_size, self.latent_dim])
        fakes = self.generator(z, training=False)

        target_real = tf.ones((batch_size, 1))
        target_fake = tf.zeros((batch_size, 1))

        # Discriminator: one batch of reals, one batch of fakes.
        on_real = self.discriminator.train_on_batch(real_images, target_real)
        on_fake = self.discriminator.train_on_batch(fakes, target_fake)
        d_loss = 0.5 * np.add(on_real, on_fake)

        # Generator: fresh noise, labelled as real, through the stacked model.
        z = tf.random.normal([batch_size, self.latent_dim])
        target_valid = tf.ones((batch_size, 1))
        g_loss = self.combined.train_on_batch(z, target_valid)

        return d_loss, g_loss

    def generate_images(self, n_samples=25):
        """Sample ``n_samples`` images and map them from [-1, 1] to [0, 1]."""
        z = tf.random.normal([n_samples, self.latent_dim])
        samples = self.generator(z, training=False)
        return 0.5 * samples + 0.5  # Rescale to [0,1]

# Load MNIST data
def load_mnist_data():
    """Return the MNIST training images scaled to [-1, 1] and their labels.

    The images come back as float32 with a trailing channel axis added; the
    test split is discarded.
    """
    (images, digits), _ = tf.keras.datasets.mnist.load_data()

    # Centre on 127.5 and divide by 127.5 so pixels match the tanh range.
    images = (images.astype(np.float32) - 127.5) / 127.5
    # Append the channel axis.
    images = np.expand_dims(images, axis=3)

    low, high = images.min(), images.max()
    print(f"MNIST loaded: {images.shape}, range: [{low:.2f}, {high:.2f}]")
    return images, digits

# Test basic GAN
# X_train / y_train are module-level and reused by the later training cells.
X_train, y_train = load_mnist_data()
basic_gan = BasicGAN(latent_dim=100, img_shape=(28, 28, 1))

print("Generator Architecture:")
basic_gan.generator.summary()
print("\nDiscriminator Architecture:")
basic_gan.discriminator.summary()

# Test generation before training
# generate_images already rescales to [0, 1], so the result can be shown as-is.
sample_images = basic_gan.generate_images(9)

# 3x3 grid, one generated sample per cell (single channel shown in grayscale).
plt.figure(figsize=(8, 8))
for i in range(9):
    plt.subplot(3, 3, i+1)
    plt.imshow(sample_images[i, :, :, 0], cmap='gray')
    plt.axis('off')
plt.suptitle('Random Noise Generated Images (Before Training)')
plt.tight_layout()
plt.show()

## 2. DCGAN Implementation

In [None]:
# Deep Convolutional GAN (DCGAN) - improved architecture
class DCGAN:
    """Deep Convolutional GAN trained with explicit gradient tapes.

    The discriminator emits raw logits and both losses use
    ``BinaryCrossentropy(from_logits=True)``. The generator is sized from
    ``img_shape`` so that its output always has the shape the discriminator
    expects; for 64x64 images it is the classic 4x4 -> 8 -> 16 -> 32 -> 64
    stack with 1024/512/256/128 filters.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        img_shape: ``(height, width, channels)`` of real and generated images.
    """

    # Upper bound on the number of stride-2 upsampling stages in the generator.
    MAX_UPSAMPLING_STAGES = 4
    # Filters of the last hidden generator stage; each earlier stage doubles it.
    MIN_GENERATOR_FILTERS = 128

    def __init__(self, latent_dim=100, img_shape=(64, 64, 3)):
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        self.img_height, self.img_width, self.channels = img_shape

        # One Adam instance per network (both use the same hyper-parameters).
        self.g_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

        # Build models
        self.generator = self.build_dcgan_generator()
        self.discriminator = self.build_dcgan_discriminator()

        # The discriminator's last layer has no sigmoid, hence from_logits=True.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def _generator_layout(self):
        """Return ``(base_height, base_width, stages)`` for the generator.

        ``stages`` counts how many times both spatial dimensions can be halved
        exactly, capped at ``MAX_UPSAMPLING_STAGES``. The base size times
        ``2 ** stages`` equals the image size, e.g. 64x64 -> (4, 4, 4) and
        28x28 -> (7, 7, 2).
        """
        height, width, stages = self.img_height, self.img_width, 0
        while (stages < self.MAX_UPSAMPLING_STAGES
               and height % 2 == 0 and width % 2 == 0):
            height //= 2
            width //= 2
            stages += 1
        return height, width, stages

    def build_dcgan_generator(self):
        """Build a generator whose output is ``img_shape``.

        A dense projection is reshaped to the base feature map, followed by
        ``stages - 1`` hidden stride-2 transposed convolutions and one final
        transposed convolution with ``tanh`` that produces the image. When the
        image dimensions are odd (no exact halving possible) the final layer
        uses stride 1 and no upsampling happens.
        """
        base_h, base_w, stages = self._generator_layout()

        # Hidden stages, widest first; for 4 stages this is [512, 256, 128].
        hidden_filters = [self.MIN_GENERATOR_FILTERS * 2 ** i
                          for i in reversed(range(max(stages - 1, 0)))]
        # The projection is twice as wide as the first hidden stage (1024 for 64x64).
        seed_filters = self.MIN_GENERATOR_FILTERS * 2 ** len(hidden_filters)
        final_stride = 2 if stages > 0 else 1

        model = tf.keras.Sequential(name='dcgan_generator')

        # Project and reshape
        model.add(layers.Dense(base_h * base_w * seed_filters, use_bias=False,
                               input_shape=(self.latent_dim,)))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU())
        model.add(layers.Reshape((base_h, base_w, seed_filters)))

        # Each hidden stage doubles height and width ('same' padding, stride 2).
        for filters in hidden_filters:
            model.add(layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2),
                                             padding='same', use_bias=False))
            model.add(layers.BatchNormalization())
            model.add(layers.LeakyReLU())

        # Final stage reaches the target resolution; tanh keeps pixels in [-1, 1].
        model.add(layers.Conv2DTranspose(self.channels, (5, 5),
                                         strides=(final_stride, final_stride),
                                         padding='same', use_bias=False,
                                         activation='tanh'))

        return model

    def build_dcgan_discriminator(self):
        """Build the discriminator: four stride-2 conv stages and a logit output."""

        model = tf.keras.Sequential([
            # Stage 1 (no batch norm on the input stage); 64x64 -> 32x32 for 64x64 input
            layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                         input_shape=self.img_shape),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            # Stage 2; 32x32 -> 16x16 for 64x64 input
            layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            # Stage 3; 16x16 -> 8x8 for 64x64 input
            layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            # Stage 4; 8x8 -> 4x4 for 64x64 input
            layers.Conv2D(512, (5, 5), strides=(2, 2), padding='same'),
            layers.BatchNormalization(),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),

            layers.Flatten(),
            layers.Dense(1)  # No sigmoid - using from_logits=True
        ], name='dcgan_discriminator')

        return model

    def discriminator_loss(self, real_output, fake_output):
        """Return cross-entropy of reals vs. ones plus fakes vs. zeros."""
        real_loss = self.cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
        return real_loss + fake_loss

    def generator_loss(self, fake_output):
        """Return cross-entropy of the fake logits against all-ones targets."""
        return self.cross_entropy(tf.ones_like(fake_output), fake_output)

    @tf.function
    def train_step(self, real_images, batch_size):
        """Update generator and discriminator once on the same batch.

        Args:
            real_images: Batch of real images shaped like ``img_shape``.
            batch_size: Number of noise vectors (fake images) to sample.

        Returns:
            ``(gen_loss, disc_loss)`` as scalar tensors.
        """

        noise = tf.random.normal([batch_size, self.latent_dim])

        # Two tapes so each network gets gradients of its own loss only.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(real_images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        # Calculate gradients
        gradients_of_generator = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        gradients_of_discriminator = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        # Apply gradients
        self.g_optimizer.apply_gradients(zip(gradients_of_generator, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(gradients_of_discriminator, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

# Prepare CIFAR-10 data for DCGAN
def load_cifar10_data():
    """Return the CIFAR-10 training images at 64x64 in [-1, 1] and their labels.

    The test split is discarded. Images are scaled first and resized second,
    and come back as a NumPy array.
    """
    (images, classes), _ = tf.keras.datasets.cifar10.load_data()

    # Centre on 127.5 and divide by 127.5 so pixels fall in [-1, 1].
    images = (images.astype(np.float32) - 127.5) / 127.5

    # Upscale every image to 64x64 and convert the tensor back to NumPy.
    resized = tf.image.resize(images, [64, 64])
    images = resized.numpy()

    low, high = images.min(), images.max()
    print(f"CIFAR-10 loaded: {images.shape}, range: [{low:.2f}, {high:.2f}]")
    return images, classes

# Training function for DCGAN
def train_dcgan(dcgan, dataset, epochs=50, batch_size=128, save_interval=10):
    """Train a DCGAN-style model and print per-epoch losses.

    The caller's ``dataset`` is never reordered: every epoch draws a fresh
    index permutation and batches are gathered through it. (Shuffling the
    array itself would also reorder the array a slice such as
    ``X_train[:5000]`` was taken from, breaking its pairing with labels.)

    Args:
        dcgan: Object exposing ``latent_dim``, ``generator`` and
            ``train_step(real_images, batch_size)`` returning two tensors
            ``(g_loss, d_loss)``.
        dataset: Array of training images, first axis is the sample axis.
        epochs: Number of passes over ``dataset``.
        batch_size: Samples per step; a trailing partial batch is dropped.
        save_interval: Show generated samples every this many epochs.

    Returns:
        Dict with per-epoch mean losses under ``'g_loss'`` and ``'d_loss'``.
    """

    # Training history
    history = {'g_loss': [], 'd_loss': []}

    # Fixed noise for consistent monitoring
    fixed_noise = tf.random.normal([16, dcgan.latent_dim])

    # No copy for ndarray input; lets index-array gathering work below.
    samples = np.asarray(dataset)
    num_samples = len(samples)
    batches_per_epoch = num_samples // batch_size

    for epoch in range(epochs):
        start_time = time.time()

        # New visiting order each epoch, without touching the data itself.
        order = np.random.permutation(num_samples)

        epoch_g_loss = []
        epoch_d_loss = []

        for batch_idx in range(batches_per_epoch):
            # Get batch
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            real_images = samples[order[start_idx:end_idx]]

            # Train step
            g_loss, d_loss = dcgan.train_step(real_images, batch_size)

            epoch_g_loss.append(g_loss.numpy())
            epoch_d_loss.append(d_loss.numpy())

        # Record history
        history['g_loss'].append(np.mean(epoch_g_loss))
        history['d_loss'].append(np.mean(epoch_d_loss))

        # Print progress
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs} - {epoch_time:.2f}s - "
              f"G_loss: {history['g_loss'][-1]:.4f}, D_loss: {history['d_loss'][-1]:.4f}")

        # Generate sample images
        if (epoch + 1) % save_interval == 0:
            generate_and_save_images(dcgan.generator, epoch + 1, fixed_noise)

    return history

def generate_and_save_images(generator, epoch, test_input):
    """Run ``generator`` on ``test_input`` and display the images in a grid.

    Despite the name, nothing is written to disk; the figure is only shown.
    The grid is the smallest square that holds every image, so 16 inputs give
    a 4x4 layout and larger batches no longer overflow the grid.

    Args:
        generator: Callable model mapping ``test_input`` to images in [-1, 1].
        epoch: Epoch number used in the figure title.
        test_input: Batch of generator inputs; one image is drawn per row.
    """
    predictions = generator(test_input, training=False)
    predictions = 0.5 * predictions + 0.5  # Rescale to [0,1]

    num_images = predictions.shape[0]
    grid_size = max(1, int(np.ceil(np.sqrt(num_images))))
    # Single-channel images are drawn as 2-D grayscale arrays.
    is_grayscale = predictions.shape[-1] == 1

    plt.figure(figsize=(8, 8))

    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        if is_grayscale:
            plt.imshow(predictions[i, :, :, 0], cmap='gray')
        else:
            plt.imshow(predictions[i])
        plt.axis('off')

    plt.suptitle(f'Generated Images - Epoch {epoch}')
    plt.tight_layout()
    plt.show()

# Test DCGAN
print("=== DCGAN Implementation ===")

# Use MNIST for faster training demonstration
dcgan = DCGAN(latent_dim=100, img_shape=(28, 28, 1))

print("DCGAN Generator Architecture:")
dcgan.generator.summary()
print("\nDCGAN Discriminator Architecture:")
dcgan.discriminator.summary()

# Train for a few epochs (demo)
print("\nTraining DCGAN (demo with 3 epochs)...")
# Hand over a copy: X_train[:5000] is a view, and any reordering done during
# training must not reach X_train, which later cells pair with y_train.
dcgan_history = train_dcgan(dcgan, X_train[:5000].copy(), epochs=3, batch_size=64, save_interval=1)

## 3. Conditional GAN (cGAN)

In [None]:
# Conditional GAN implementation
class ConditionalGAN:
    """Label-conditioned GAN built with the Functional API.

    Both networks receive a class label next to their main input: the
    generator multiplies the noise by a label embedding, the discriminator
    concatenates an image-shaped label embedding to the image as extra
    channels. The discriminator outputs logits.

    Args:
        latent_dim: Length of the noise vector.
        num_classes: Number of distinct labels the embeddings cover.
        img_shape: ``(height, width, channels)`` of the images.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_shape=(28, 28, 1)):
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.img_height, self.img_width, self.channels = img_shape

        # One Adam instance per network.
        self.g_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)
        self.d_optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

        # Generator is created before the discriminator.
        self.generator = self.build_conditional_generator()
        self.discriminator = self.build_conditional_discriminator()

        # Logit outputs, hence from_logits=True.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def build_conditional_generator(self):
        """Return the generator model taking ``[noise, label]``."""
        z_in = layers.Input(shape=(self.latent_dim,))
        y_in = layers.Input(shape=(1,))

        # Label -> vector of the same length as the noise.
        y_vec = layers.Embedding(self.num_classes, self.latent_dim)(y_in)
        y_vec = layers.Flatten()(y_vec)

        # Element-wise product conditions the noise on the label.
        h = layers.Multiply()([z_in, y_vec])

        # 7x7x256 seed feature map.
        h = layers.Dense(7 * 7 * 256)(h)
        h = layers.Reshape((7, 7, 256))(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU(alpha=0.2)(h)

        # Two stride-2 upsampling stages.
        for n_filters in (128, 64):
            h = layers.Conv2DTranspose(n_filters, (5, 5), strides=(2, 2), padding='same')(h)
            h = layers.BatchNormalization()(h)
            h = layers.LeakyReLU(alpha=0.2)(h)

        image_out = layers.Conv2DTranspose(self.channels, (5, 5), strides=(1, 1),
                                           padding='same', activation='tanh')(h)

        return tf.keras.Model([z_in, y_in], image_out, name='conditional_generator')

    def build_conditional_discriminator(self):
        """Return the discriminator model taking ``[image, label]``."""
        image_in = layers.Input(shape=self.img_shape)
        y_in = layers.Input(shape=(1,))

        # Label -> tensor with the image's shape, stacked on the channel axis.
        y_map = layers.Embedding(self.num_classes, np.prod(self.img_shape))(y_in)
        y_map = layers.Flatten()(y_map)
        y_map = layers.Reshape(self.img_shape)(y_map)

        h = layers.Concatenate(axis=-1)([image_in, y_map])

        # Three stride-2 conv stages with LeakyReLU and dropout.
        for n_filters in (64, 128, 256):
            h = layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same')(h)
            h = layers.LeakyReLU(alpha=0.2)(h)
            h = layers.Dropout(0.3)(h)

        h = layers.Flatten()(h)
        logit_out = layers.Dense(1)(h)

        return tf.keras.Model([image_in, y_in], logit_out, name='conditional_discriminator')

    def discriminator_loss(self, real_output, fake_output):
        """Return cross-entropy of reals vs. ones plus fakes vs. zeros."""
        loss_on_real = self.cross_entropy(tf.ones_like(real_output), real_output)
        loss_on_fake = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
        return loss_on_real + loss_on_fake

    def generator_loss(self, fake_output):
        """Return cross-entropy of the fake logits against all-ones targets."""
        all_real = tf.ones_like(fake_output)
        return self.cross_entropy(all_real, fake_output)

    @tf.function
    def train_step(self, real_images, real_labels, batch_size):
        """Update both networks once; fakes are generated for ``real_labels``.

        Returns:
            ``(gen_loss, disc_loss)`` as scalar tensors.
        """
        z = tf.random.normal([batch_size, self.latent_dim])

        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fakes = self.generator([z, real_labels], training=True)

            logits_real = self.discriminator([real_images, real_labels], training=True)
            logits_fake = self.discriminator([fakes, real_labels], training=True)

            gen_loss = self.generator_loss(logits_fake)
            disc_loss = self.discriminator_loss(logits_real, logits_fake)

        g_vars = self.generator.trainable_variables
        d_vars = self.discriminator.trainable_variables
        g_grads = g_tape.gradient(gen_loss, g_vars)
        d_grads = d_tape.gradient(disc_loss, d_vars)

        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))
        self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

    def generate_class_samples(self, class_labels, n_samples_per_class=5):
        """Generate ``n_samples_per_class`` images for each requested label.

        Returns:
            ``(images, labels)``: images rescaled to [0, 1] and a column
            array of the label used for each image, grouped label by label.
        """
        n_total = len(class_labels) * n_samples_per_class
        z = tf.random.normal([n_total, self.latent_dim])

        # Each label repeated consecutively, as a column vector.
        label_column = np.repeat(class_labels, n_samples_per_class).reshape(-1, 1)

        images = self.generator([z, label_column], training=False)
        images = 0.5 * images + 0.5  # Rescale

        return images, label_column

# Train conditional GAN
def train_conditional_gan(cgan, images, labels, epochs=20, batch_size=128):
    """Train a conditional GAN and print per-epoch losses.

    Images and labels are re-indexed with one shared permutation per epoch,
    so every image keeps its label. A trailing partial batch is dropped.
    Class-conditional samples are displayed every fifth epoch.

    Returns:
        Dict with per-epoch mean losses under ``'g_loss'`` and ``'d_loss'``.
    """
    history = {'g_loss': [], 'd_loss': []}
    n_batches = len(images) // batch_size

    for epoch in range(epochs):
        tic = time.time()

        # One permutation applied to both arrays keeps the pairs aligned.
        order = np.random.permutation(len(images))
        epoch_images = images[order]
        epoch_labels = labels[order]

        g_values, d_values = [], []
        for lo in range(0, n_batches * batch_size, batch_size):
            hi = lo + batch_size
            g_loss, d_loss = cgan.train_step(epoch_images[lo:hi],
                                             epoch_labels[lo:hi],
                                             batch_size)
            g_values.append(g_loss.numpy())
            d_values.append(d_loss.numpy())

        history['g_loss'].append(np.mean(g_values))
        history['d_loss'].append(np.mean(d_values))

        elapsed = time.time() - tic
        print(f"Epoch {epoch+1}/{epochs} - {elapsed:.2f}s - "
              f"G_loss: {history['g_loss'][-1]:.4f}, D_loss: {history['d_loss'][-1]:.4f}")

        # Preview one sample per class every fifth epoch.
        if (epoch + 1) % 5 == 0:
            generate_class_specific_images(cgan, epoch + 1)

    return history

def generate_class_specific_images(cgan, epoch):
    """Show one generated image per class (at most the first ten classes).

    Images are laid out on a 2x5 grid, each titled with the label it was
    generated for; ``epoch`` only appears in the figure title.
    """
    # Show first 10 classes
    wanted = list(range(min(10, cgan.num_classes)))
    images, image_labels = cgan.generate_class_samples(wanted, n_samples_per_class=1)

    _, grid = plt.subplots(2, 5, figsize=(12, 6))
    cells = grid.flatten()

    for idx, (picture, tag) in enumerate(zip(images[:10], image_labels[:10])):
        cell = cells[idx]
        if images.shape[-1] == 1:
            # Single channel: draw the 2-D slice in grayscale.
            cell.imshow(picture[:, :, 0], cmap='gray')
        else:
            cell.imshow(picture)
        cell.set_title(f'Class {tag[0]}')
        cell.axis('off')

    plt.suptitle(f'Class-Conditional Generation - Epoch {epoch}')
    plt.tight_layout()
    plt.show()

# Test Conditional GAN
print("=== Conditional GAN Implementation ===")

# Initialize conditional GAN
cgan = ConditionalGAN(latent_dim=100, num_classes=10, img_shape=(28, 28, 1))

print("Conditional Generator Architecture:")
cgan.generator.summary()
print("\nConditional Discriminator Architecture:")
cgan.discriminator.summary()

# Train for a few epochs (demo)
# With epochs=3 the every-fifth-epoch preview inside train_conditional_gan
# never triggers; samples are plotted explicitly below instead.
print("\nTraining Conditional GAN (demo with 3 epochs)...")
cgan_history = train_conditional_gan(cgan, X_train[:5000], y_train[:5000], epochs=3, batch_size=64)

# Test class-specific generation
print("\nTesting class-specific generation:")
test_classes = [0, 1, 2, 3, 4]  # Generate digits 0-4
# 5 classes x 2 samples = 10 images, returned grouped class by class.
generated_samples, sample_labels = cgan.generate_class_samples(test_classes, n_samples_per_class=2)

# 2x5 grid; each cell is titled with the label the image was generated for.
plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i+1)
    plt.imshow(generated_samples[i, :, :, 0], cmap='gray')
    plt.title(f'Class {sample_labels[i][0]}')
    plt.axis('off')
plt.suptitle('Conditional Generation Results')
plt.tight_layout()
plt.show()

## 4. Wasserstein GAN with Gradient Penalty (WGAN-GP)

In [None]:
# Wasserstein GAN with Gradient Penalty
class WGANGP:
    """Wasserstein GAN whose critic is regularised with a gradient penalty.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        img_shape: ``(height, width, channels)`` of the images.
        critic_iterations: Critic updates performed per generator update.
        lambda_gp: Weight of the gradient penalty in the critic loss.
    """

    def __init__(self, latent_dim=100, img_shape=(28, 28, 1), critic_iterations=5, lambda_gp=10):
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        self.critic_iterations = critic_iterations
        self.lambda_gp = lambda_gp

        # Networks first (generator, then critic), optimizers afterwards.
        self.generator = self.build_wgan_generator()
        self.critic = self.build_wgan_critic()

        self.g_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.0, beta_2=0.9)
        self.c_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.0, beta_2=0.9)

    def build_wgan_generator(self):
        """Return the Sequential generator (7x7 seed, two stride-2 upsamplings)."""
        net = tf.keras.Sequential(name='wgan_generator')

        net.add(layers.Dense(7 * 7 * 256, input_shape=(self.latent_dim,)))
        net.add(layers.Reshape((7, 7, 256)))
        net.add(layers.BatchNormalization())
        net.add(layers.ReLU())

        for n_filters in (128, 64):
            net.add(layers.Conv2DTranspose(n_filters, (5, 5), strides=(2, 2), padding='same'))
            net.add(layers.BatchNormalization())
            net.add(layers.ReLU())

        # Channel count comes from the last entry of img_shape.
        net.add(layers.Conv2DTranspose(self.img_shape[-1], (5, 5), strides=(1, 1),
                                       padding='same', activation='tanh'))
        return net

    def build_wgan_critic(self):
        """Return the Sequential critic; its last layer has no activation."""
        net = tf.keras.Sequential(name='wgan_critic')

        net.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                              input_shape=self.img_shape))
        net.add(layers.LeakyReLU(alpha=0.2))

        for n_filters in (128, 256):
            net.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same'))
            net.add(layers.LeakyReLU(alpha=0.2))

        net.add(layers.Flatten())
        net.add(layers.Dense(1))  # No activation - output raw score
        return net

    def gradient_penalty(self, real_images, fake_images, batch_size):
        """Return the mean squared deviation of the critic's gradient norm from 1.

        The gradient is taken with respect to random per-sample blends of
        ``real_images`` and ``fake_images``.
        """
        # One blend factor per sample, broadcast over height, width, channels.
        mix = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        blended = mix * real_images + (1 - mix) * fake_images

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(blended)
            scores = self.critic(blended, training=True)

        grads = gp_tape.gradient(scores, blended)
        # Per-sample L2 norm over all non-batch axes.
        squared_sum = tf.reduce_sum(tf.square(grads), axis=[1, 2, 3])
        norms = tf.sqrt(squared_sum)
        return tf.reduce_mean((norms - 1.0) ** 2)

    def generator_loss(self, fake_output):
        """Return the negated mean critic score of the fakes."""
        mean_fake_score = tf.reduce_mean(fake_output)
        return -mean_fake_score

    def critic_loss(self, real_output, fake_output):
        """Return mean fake score minus mean real score."""
        mean_fake_score = tf.reduce_mean(fake_output)
        mean_real_score = tf.reduce_mean(real_output)
        return mean_fake_score - mean_real_score

    @tf.function
    def train_critic_step(self, real_images, batch_size):
        """Apply one critic update.

        Returns:
            ``(loss, gp)``: the penalised critic loss and the raw penalty.
        """
        z = tf.random.normal([batch_size, self.latent_dim])

        with tf.GradientTape() as c_tape:
            # Generator runs in inference mode while the critic is updated.
            fakes = self.generator(z, training=False)

            scores_real = self.critic(real_images, training=True)
            scores_fake = self.critic(fakes, training=True)

            wasserstein_term = self.critic_loss(scores_real, scores_fake)
            gp = self.gradient_penalty(real_images, fakes, batch_size)
            loss = wasserstein_term + self.lambda_gp * gp

        c_grads = c_tape.gradient(loss, self.critic.trainable_variables)
        self.c_optimizer.apply_gradients(zip(c_grads, self.critic.trainable_variables))

        return loss, gp

    @tf.function
    def train_generator_step(self, batch_size):
        """Apply one generator update and return its loss."""
        z = tf.random.normal([batch_size, self.latent_dim])

        with tf.GradientTape() as g_tape:
            fakes = self.generator(z, training=True)
            scores_fake = self.critic(fakes, training=False)
            loss = self.generator_loss(scores_fake)

        g_grads = g_tape.gradient(loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))

        return loss

    def train_step(self, real_images, batch_size):
        """Run ``critic_iterations`` critic updates, then one generator update.

        Returns:
            ``(mean critic loss, generator loss, mean gradient penalty)`` as
            NumPy values.
        """
        critic_results = [self.train_critic_step(real_images, batch_size)
                          for _ in range(self.critic_iterations)]
        critic_losses = [loss.numpy() for loss, _ in critic_results]
        gradient_penalties = [gp.numpy() for _, gp in critic_results]

        # Train generator once
        g_loss = self.train_generator_step(batch_size)

        return np.mean(critic_losses), g_loss.numpy(), np.mean(gradient_penalties)

# Training function for WGAN-GP
def train_wgan_gp(wgan, dataset, epochs=20, batch_size=128):
    """Train a WGAN-GP model and print per-epoch statistics.

    The caller's ``dataset`` is never reordered: every epoch draws a fresh
    index permutation and batches are gathered through it, so an array that
    ``dataset`` is a view of keeps its original order.

    Args:
        wgan: Object exposing ``latent_dim``, ``generator`` and
            ``train_step(real_images, batch_size)`` returning
            ``(critic_loss, generator_loss, gradient_penalty)``.
        dataset: Array of training images, first axis is the sample axis.
        epochs: Number of passes over ``dataset``.
        batch_size: Samples per step; a trailing partial batch is dropped.

    Returns:
        Dict with per-epoch means under ``'c_loss'``, ``'g_loss'`` and ``'gp'``.
    """

    history = {'c_loss': [], 'g_loss': [], 'gp': []}
    # Same noise every time so the preview images are comparable across epochs.
    fixed_noise = tf.random.normal([16, wgan.latent_dim])

    # No copy for ndarray input; lets index-array gathering work below.
    samples = np.asarray(dataset)
    num_samples = len(samples)
    batches_per_epoch = num_samples // batch_size

    for epoch in range(epochs):
        start_time = time.time()

        # New visiting order each epoch, without touching the data itself.
        order = np.random.permutation(num_samples)

        epoch_c_loss = []
        epoch_g_loss = []
        epoch_gp = []

        for batch_idx in range(batches_per_epoch):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            real_images = samples[order[start_idx:end_idx]]

            c_loss, g_loss, gp = wgan.train_step(real_images, batch_size)

            epoch_c_loss.append(c_loss)
            epoch_g_loss.append(g_loss)
            epoch_gp.append(gp)

        history['c_loss'].append(np.mean(epoch_c_loss))
        history['g_loss'].append(np.mean(epoch_g_loss))
        history['gp'].append(np.mean(epoch_gp))

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs} - {epoch_time:.2f}s - "
              f"C_loss: {history['c_loss'][-1]:.4f}, G_loss: {history['g_loss'][-1]:.4f}, "
              f"GP: {history['gp'][-1]:.4f}")

        # Preview generated samples every fifth epoch.
        if (epoch + 1) % 5 == 0:
            generate_and_save_images(wgan.generator, epoch + 1, fixed_noise)

    return history

# Test WGAN-GP
print("=== WGAN-GP Implementation ===")

wgan_gp = WGANGP(latent_dim=100, img_shape=(28, 28, 1), critic_iterations=5, lambda_gp=10)

print("WGAN-GP Generator Architecture:")
wgan_gp.generator.summary()
print("\nWGAN-GP Critic Architecture:")
wgan_gp.critic.summary()

# Train for a few epochs (demo)
print("\nTraining WGAN-GP (demo with 3 epochs)...")
# Hand over a copy: X_train[:5000] is a view, and any reordering done during
# training must not reach X_train, which is paired index-by-index with y_train.
wgan_history = train_wgan_gp(wgan_gp, X_train[:5000].copy(), epochs=3, batch_size=64)

## 5. Advanced Training Techniques and Stabilization

In [None]:
# Advanced GAN training techniques
class StableGANTrainer:
    """Collection of GAN training-stabilization helpers.

    The methods are independent utilities (learning-rate balancing, label
    smoothing, auxiliary losses, a resolution schedule); the instance only
    stores the GAN type and an empty metrics history.
    """

    def __init__(self, gan_type='dcgan'):
        self.gan_type = gan_type
        # Per-epoch metric lists; nothing in this class fills them in.
        self.training_history = {
            'g_loss': [], 'd_loss': [], 'fid_scores': [], 'is_scores': []
        }

    def spectral_normalization(self, layer):
        """Wrap ``layer`` in a spectral-normalization wrapper when available.

        Looks for ``tf.keras.layers.SpectralNormalization`` first, then for a
        class registered under ``'SpectralNormalization'`` in the Keras custom
        object registry.

        Args:
            layer: A Keras layer with a kernel (e.g. ``Dense``, ``Conv2D``).

        Returns:
            The wrapped layer, or ``layer`` unchanged when no wrapper class
            can be found in this TensorFlow installation.
        """
        wrapper_cls = getattr(tf.keras.layers, 'SpectralNormalization', None)
        if wrapper_cls is None:
            wrapper_cls = tf.keras.utils.get_custom_objects().get('SpectralNormalization')
        if wrapper_cls is None:
            return layer
        return wrapper_cls(layer)

    def progressive_growing_schedule(self, current_epoch, total_epochs):
        """Return the training resolution for ``current_epoch``.

        The run is split evenly across the phases 4, 8, 16 and 32; epochs
        beyond the last boundary stay at the final resolution.

        Args:
            current_epoch: Zero-based epoch index.
            total_epochs: Planned number of epochs.

        Returns:
            One of 4, 8, 16 or 32.
        """
        phases = [4, 8, 16, 32]  # Resolution phases
        # Clamp to one epoch per phase so short runs (fewer epochs than
        # phases) do not divide by zero.
        phase_length = max(1, total_epochs // len(phases))
        current_phase = min(current_epoch // phase_length, len(phases) - 1)
        return phases[current_phase]

    def adaptive_learning_rate(self, g_loss, d_loss, base_lr=0.0002):
        """Pick generator/discriminator learning rates from the loss balance.

        Args:
            g_loss: Current generator loss (scalar).
            d_loss: Current discriminator loss (scalar).
            base_lr: Learning rate used when the two networks are balanced.

        Returns:
            ``(g_lr, d_lr)``. The dominant network gets ``0.5 * base_lr`` and
            the other ``1.5 * base_lr``; otherwise both get ``base_lr``.
        """

        # If discriminator is too strong, reduce its learning rate
        if d_loss < 0.1 and g_loss > 2.0:
            d_lr = base_lr * 0.5
            g_lr = base_lr * 1.5
        # If generator is too strong, reduce its learning rate
        elif g_loss < 0.1 and d_loss > 2.0:
            g_lr = base_lr * 0.5
            d_lr = base_lr * 1.5
        else:
            g_lr = base_lr
            d_lr = base_lr

        return g_lr, d_lr

    def label_smoothing(self, labels, smoothing=0.1):
        """Apply label smoothing to reduce discriminator overconfidence.

        Args:
            labels: Target labels (scalar, array or tensor).
            smoothing: Smoothing strength; values ``<= 0`` disable smoothing.

        Returns:
            ``labels * (1 - smoothing) + 0.5 * smoothing``, i.e. labels are
            pulled towards 0.5 (1 -> 1 - smoothing/2, 0 -> smoothing/2).
        """

        if smoothing > 0:
            # Symmetric smoothing towards 0.5
            labels = labels * (1 - smoothing) + 0.5 * smoothing

        return labels

    def feature_matching_loss(self, real_features, fake_features):
        """Return the mean absolute gap between batch-mean feature vectors.

        Args:
            real_features: Features of real samples; axis 0 is averaged out.
            fake_features: Features of generated samples, same layout.

        Returns:
            Scalar tensor.
        """

        return tf.reduce_mean(tf.abs(tf.reduce_mean(real_features, axis=0) -
                                   tf.reduce_mean(fake_features, axis=0)))

    def diversity_loss(self, generated_images, batch_size):
        """Encourage diversity by rewarding a large nearest-neighbour distance.

        Args:
            generated_images: Batch of generated images; reshaped to
                ``[batch_size, -1]``.
            batch_size: Number of images in the batch.

        Returns:
            Scalar tensor: the negated mean, over the batch, of each image's
            L2 distance to its closest other image.
        """

        # Compute pairwise distances
        flattened = tf.reshape(generated_images, [batch_size, -1])
        deltas = flattened[:, None] - flattened[None, :]
        squared = tf.reduce_sum(tf.square(deltas), axis=2)

        # The diagonal of `squared` is exactly zero. A plain norm/sqrt there
        # has an infinite derivative, which turns the whole generator
        # gradient into NaN; the small epsilon keeps it finite.
        distances = tf.sqrt(squared + 1e-12)

        # Mask self-distances so they never win the minimum
        self_mask = tf.eye(batch_size, dtype=distances.dtype) * 1e6
        min_distances = tf.reduce_min(distances + self_mask, axis=1)
        diversity_loss = -tf.reduce_mean(min_distances)

        return diversity_loss

# Model evaluation metrics
class GANEvaluator:
    """Simplified GAN evaluation helpers (FID, Inception Score, image grid).

    Both scores use a small, randomly initialised network in place of a
    pre-trained InceptionV3, so the numbers illustrate the computation but
    are not comparable with published FID/IS values.
    """

    def __init__(self):
        pass

    @staticmethod
    def _covariance(features):
        """Return the sample covariance of ``features`` (rows are samples)."""
        n_rows = tf.cast(tf.shape(features)[0], features.dtype)
        centered = features - tf.reduce_mean(features, axis=0, keepdims=True)
        # Unbiased (n - 1) normalisation; clamp so one sample does not divide by 0.
        return tf.matmul(centered, centered, transpose_a=True) / tf.maximum(n_rows - 1.0, 1.0)

    @staticmethod
    def _trace_sqrt_product(sigma_a, sigma_b):
        """Return ``trace(sqrt(sigma_a @ sigma_b))`` for symmetric PSD inputs.

        ``sigma_a @ sigma_b`` has the same eigenvalues as the symmetric matrix
        ``sqrt(sigma_a) @ sigma_b @ sqrt(sigma_a)``, so the trace is computed
        from a symmetric eigendecomposition. That stays finite for the
        rank-deficient covariances a small batch produces.
        """
        eigvals_a, eigvecs_a = tf.linalg.eigh(sigma_a)
        root_a = tf.matmul(
            eigvecs_a * tf.sqrt(tf.maximum(eigvals_a, 0.0)),
            eigvecs_a,
            transpose_b=True,
        )
        inner = root_a @ sigma_b @ root_a
        inner = 0.5 * (inner + tf.transpose(inner))  # remove round-off asymmetry
        eigvals = tf.maximum(tf.linalg.eigvalsh(inner), 0.0)
        return tf.reduce_sum(tf.sqrt(eigvals))

    def calculate_fid_score(self, real_images, generated_images, batch_size=50):
        """Calculate Fréchet Inception Distance (simplified).

        Args:
            real_images: Batch of real images (channels-last).
            generated_images: Batch of generated images, same layout.
            batch_size: Number of leading images taken from each input.

        Returns:
            The distance as a NumPy scalar.
        """

        # In practice, use pre-trained InceptionV3
        # This is a simplified version for demonstration

        # Build the extractor once so real and generated images are embedded
        # by the SAME network; separate random networks would make the
        # distance meaningless.
        extractor = tf.keras.Sequential([
            layers.Conv2D(32, 3, activation='relu'),
            layers.GlobalAveragePooling2D(),
            layers.Dense(128, activation='relu')
        ])

        real_features = extractor(tf.cast(real_images[:batch_size], tf.float32))
        fake_features = extractor(tf.cast(generated_images[:batch_size], tf.float32))

        # Statistics in float64 for a better-conditioned eigendecomposition
        real_features = tf.cast(real_features, tf.float64)
        fake_features = tf.cast(fake_features, tf.float64)

        # Calculate statistics
        mu_real = tf.reduce_mean(real_features, axis=0)
        mu_fake = tf.reduce_mean(fake_features, axis=0)

        sigma_real = self._covariance(real_features)
        sigma_fake = self._covariance(fake_features)

        # FID = |mu_r - mu_f|^2 + Tr(S_r + S_f - 2 * sqrt(S_r S_f))
        diff = mu_real - mu_fake
        fid = (tf.reduce_sum(diff ** 2)
               + tf.linalg.trace(sigma_real)
               + tf.linalg.trace(sigma_fake)
               - 2.0 * self._trace_sqrt_product(sigma_real, sigma_fake))

        return fid.numpy()

    def calculate_inception_score(self, generated_images, batch_size=50):
        """Calculate Inception Score (simplified).

        Args:
            generated_images: Batch of generated images (channels-last).
            batch_size: Number of leading images that are scored.

        Returns:
            ``exp(mean_x KL(p(y|x) || p(y)))`` as a NumPy scalar.
        """

        # Simplified version - in practice use InceptionV3
        classifier = tf.keras.Sequential([
            layers.Conv2D(32, 3, activation='relu'),
            layers.GlobalAveragePooling2D(),
            layers.Dense(10, activation='softmax')  # 10 classes for demo
        ])

        predictions = classifier(tf.cast(generated_images[:batch_size], tf.float32))

        # Calculate IS; the epsilon keeps log() finite if a probability
        # underflows to exactly zero.
        eps = 1e-12
        py = tf.reduce_mean(predictions, axis=0)
        kl_div = predictions * (tf.math.log(predictions + eps) - tf.math.log(py + eps))
        is_score = tf.exp(tf.reduce_mean(tf.reduce_sum(kl_div, axis=1)))

        return is_score.numpy()

    def visual_quality_assessment(self, generated_images, grid_size=(5, 5)):
        """Show a grid of generated images for visual inspection.

        Args:
            generated_images: Batch of images (channels-last); single-channel
                images are drawn with a gray colormap.
            grid_size: ``(rows, cols)`` of the grid. When fewer images than
                cells are supplied, the remaining cells are left empty.
        """

        rows, cols = grid_size[0], grid_size[1]
        samples = generated_images[:rows * cols]
        # Never index past the images that were actually provided.
        n_samples = min(rows * cols, int(samples.shape[0]))

        plt.figure(figsize=(12, 12))
        for i in range(n_samples):
            plt.subplot(rows, cols, i + 1)

            if samples.shape[-1] == 1:
                plt.imshow(samples[i, :, :, 0], cmap='gray')
            else:
                plt.imshow(samples[i])

            plt.axis('off')

        plt.suptitle('Generated Images Quality Assessment')
        plt.tight_layout()
        plt.show()

# Training with advanced techniques
def train_with_stabilization(gan, dataset, epochs=50, batch_size=128):
    """Train a GAN with label smoothing, a diversity loss and adaptive LRs.

    Args:
        gan: Model wrapper exposing ``latent_dim``, ``generator``,
            ``discriminator``, ``g_optimizer`` and ``d_optimizer``.
        dataset: Array-like of training images indexable along axis 0.
            It is never modified: every epoch draws a fresh index
            permutation instead of shuffling the data in place.
        epochs: Number of passes over ``dataset``.
        batch_size: Samples per step. A trailing partial batch is dropped.

    Returns:
        dict with keys ``'g_loss'``, ``'d_loss'``, ``'lr_g'`` and ``'lr_d'``.
        The loss lists hold per-epoch means; the learning-rate lists hold the
        rates chosen on the last step of each epoch.

    Raises:
        ValueError: If ``dataset`` holds fewer than ``batch_size`` samples.
    """

    data = np.asarray(dataset)
    batches_per_epoch = len(data) // batch_size
    if batches_per_epoch == 0:
        raise ValueError(
            f"dataset has {len(data)} samples, fewer than batch_size={batch_size}"
        )

    trainer = StableGANTrainer('dcgan')
    evaluator = GANEvaluator()

    # Training parameters
    base_lr = 0.0002
    label_smoothing = 0.1

    # Targets are identical for every step, so build them once.
    real_labels = trainer.label_smoothing(tf.ones((batch_size, 1)), label_smoothing)
    fake_labels = tf.zeros((batch_size, 1))

    history = {'g_loss': [], 'd_loss': [], 'lr_g': [], 'lr_d': []}
    g_lr, d_lr = base_lr, base_lr

    for epoch in range(epochs):
        start_time = time.time()

        # Shuffle indices, not the caller's array.
        order = np.random.permutation(len(data))

        epoch_g_loss = []
        epoch_d_loss = []

        for batch_idx in range(batches_per_epoch):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            real_images = data[order[start_idx:end_idx]]

            # Generate fake images (outside the tape: the discriminator step
            # must not back-propagate into the generator)
            noise = tf.random.normal([batch_size, gan.latent_dim])
            fake_images = gan.generator(noise, training=False)

            # Train discriminator
            # NOTE(review): from_logits=True assumes the discriminator's last
            # layer has no sigmoid — confirm against the model definition.
            with tf.GradientTape() as disc_tape:
                real_output = gan.discriminator(real_images, training=True)
                fake_output = gan.discriminator(fake_images, training=True)

                d_loss_real = tf.keras.losses.binary_crossentropy(real_labels, real_output, from_logits=True)
                d_loss_fake = tf.keras.losses.binary_crossentropy(fake_labels, fake_output, from_logits=True)
                d_loss = tf.reduce_mean(d_loss_real) + tf.reduce_mean(d_loss_fake)

            # Train generator (same noise, regenerated inside the tape)
            with tf.GradientTape() as gen_tape:
                fake_images = gan.generator(noise, training=True)
                fake_output = gan.discriminator(fake_images, training=False)

                g_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(
                    tf.ones_like(fake_output), fake_output, from_logits=True))

                # Add diversity loss
                diversity_loss = trainer.diversity_loss(fake_images, batch_size)
                g_loss += 0.1 * diversity_loss

            # Adaptive learning rate
            g_lr, d_lr = trainer.adaptive_learning_rate(g_loss.numpy(), d_loss.numpy(), base_lr)

            # Both gradients are taken before either optimizer runs, so the
            # generator gradient sees the pre-update discriminator weights.
            d_gradients = disc_tape.gradient(d_loss, gan.discriminator.trainable_variables)
            g_gradients = gen_tape.gradient(g_loss, gan.generator.trainable_variables)

            # Update with adaptive learning rates
            gan.d_optimizer.learning_rate = d_lr
            gan.g_optimizer.learning_rate = g_lr

            gan.d_optimizer.apply_gradients(zip(d_gradients, gan.discriminator.trainable_variables))
            gan.g_optimizer.apply_gradients(zip(g_gradients, gan.generator.trainable_variables))

            epoch_g_loss.append(g_loss.numpy())
            epoch_d_loss.append(d_loss.numpy())

        # Record history
        history['g_loss'].append(np.mean(epoch_g_loss))
        history['d_loss'].append(np.mean(epoch_d_loss))
        history['lr_g'].append(g_lr)
        history['lr_d'].append(d_lr)

        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs} - {epoch_time:.2f}s - "
              f"G_loss: {history['g_loss'][-1]:.4f}, D_loss: {history['d_loss'][-1]:.4f}, "
              f"G_LR: {g_lr:.6f}, D_LR: {d_lr:.6f}")

        # Periodic evaluation
        if (epoch + 1) % 10 == 0:
            # Generate samples for evaluation
            test_noise = tf.random.normal([50, gan.latent_dim])
            generated_samples = gan.generator(test_noise, training=False)
            # Affine map x -> 0.5 * x + 0.5 before plotting
            # (presumably generator output is in [-1, 1] — confirm)
            generated_samples = 0.5 * generated_samples + 0.5

            # Visual assessment
            evaluator.visual_quality_assessment(generated_samples, (5, 5))

    return history

# Plot comprehensive training history
def plot_training_analysis(histories, model_names):
    """Plot a 2x3 dashboard comparing the training runs of several GANs.

    Panels: generator losses, discriminator/critic losses, G/D loss ratio,
    adaptive learning rates (first history only), epoch-to-epoch change of
    the generator loss, and a bar chart of final losses.

    Args:
        histories: List of history dicts. Each needs ``'g_loss'``; the
            adversary loss is read from ``'d_loss'`` or, for WGAN-style
            runs, ``'c_loss'``. ``'lr_g'``/``'lr_d'`` are optional.
        model_names: Display names, parallel to ``histories``.
    """

    def _adversary_loss(history):
        """Return ``(values, label)`` for the discriminator or critic loss."""
        if 'd_loss' in history:
            return history['d_loss'], 'Discriminator'
        if 'c_loss' in history:
            return history['c_loss'], 'Critic'
        return None, None

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Generator losses
    axes[0, 0].set_title('Generator Losses')
    for history, name in zip(histories, model_names):
        axes[0, 0].plot(history['g_loss'], label=f'{name} Generator')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Discriminator losses (critic losses for WGAN-style histories)
    axes[0, 1].set_title('Discriminator Losses')
    for history, name in zip(histories, model_names):
        values, role = _adversary_loss(history)
        if values is not None:
            axes[0, 1].plot(values, label=f'{name} {role}')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Loss ratio (stability indicator)
    # Only histories with a true discriminator loss are included: a critic
    # loss is unbounded and may be negative, so the ratio is not meaningful.
    axes[0, 2].set_title('Loss Ratio (G_loss / D_loss)')
    for history, name in zip(histories, model_names):
        if 'd_loss' in history:
            ratio = np.array(history['g_loss']) / (np.array(history['d_loss']) + 1e-8)
            axes[0, 2].plot(ratio, label=name)
    axes[0, 2].axhline(y=1.0, color='red', linestyle='--', alpha=0.5, label='Ideal Ratio')
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].set_ylabel('Ratio')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)

    # Learning rates (if available)
    if histories and 'lr_g' in histories[0]:
        axes[1, 0].set_title('Adaptive Learning Rates')
        axes[1, 0].plot(histories[0]['lr_g'], label='Generator LR')
        axes[1, 0].plot(histories[0]['lr_d'], label='Discriminator LR')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
    else:
        # Nothing to draw: hide the panel instead of showing an empty frame.
        axes[1, 0].set_visible(False)

    # Training stability
    axes[1, 1].set_title('Training Stability')
    for history, name in zip(histories, model_names):
        stability = np.abs(np.diff(history['g_loss']))
        axes[1, 1].plot(stability, label=f'{name} G_loss variance')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss Variance')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    # Model comparison
    axes[1, 2].set_title('Final Performance Comparison')
    final_g_losses = [history['g_loss'][-1] for history in histories]
    final_d_losses = []
    for history in histories:
        values, _ = _adversary_loss(history)
        # NaN leaves a gap in the bar chart when a run has no adversary loss.
        final_d_losses.append(values[-1] if values is not None else np.nan)

    x = np.arange(len(model_names))
    width = 0.35

    axes[1, 2].bar(x - width/2, final_g_losses, width, label='Generator Loss', alpha=0.8)
    axes[1, 2].bar(x + width/2, final_d_losses, width, label='Discriminator Loss', alpha=0.8)
    axes[1, 2].set_xlabel('Models')
    axes[1, 2].set_ylabel('Final Loss')
    axes[1, 2].set_xticks(x)
    axes[1, 2].set_xticklabels(model_names)
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Compare all trained models
print("=== Training Analysis and Comparison ===")

# Collect all training histories
# dcgan_history and cgan_history come from earlier cells of this notebook.
# NOTE(review): wgan_history stores its adversary loss under 'c_loss'
# (see train_wgan_gp), not 'd_loss'.
all_histories = [dcgan_history, cgan_history, wgan_history]
# Display names, in the same order as all_histories.
model_names = ['DCGAN', 'Conditional GAN', 'WGAN-GP']

# Plot comprehensive analysis
plot_training_analysis(all_histories, model_names)

## Summary

This comprehensive notebook demonstrated the complete spectrum of GAN implementations using tf.keras:

### Key Implementations

**1. Basic GAN Architecture:**
- Fundamental generator and discriminator concepts
- tf.keras Sequential API implementation
- Basic adversarial training loop

**2. Deep Convolutional GAN (DCGAN):**
- Industry-standard architecture with best practices
- Optimized training with tf.function decorators
- Proper normalization and activation choices

**3. Conditional GAN (cGAN):**
- tf.keras Functional API for complex inputs
- Class-conditional generation capabilities
- Label embedding and conditioning techniques

**4. Wasserstein GAN-GP (WGAN-GP):**
- Gradient penalty for training stability
- Critic instead of discriminator (no sigmoid)
- Earth Mover's distance optimization

**5. Advanced Training Techniques:**
- Label smoothing and adaptive learning rates
- Feature matching and diversity losses
- Spectral normalization and stabilization methods

### Technical Achievements

- **Stable Training**: WGAN-GP provides most stable training dynamics
- **Controlled Generation**: Conditional GANs enable precise control over outputs
- **Quality Metrics**: Simplified FID and IS scores for objective evaluation (swap in a pre-trained InceptionV3 for numbers comparable with published results)
- **Production Ready**: Optimized implementations with tf.function

### Performance Insights

- **DCGAN**: Fast convergence, good for standard image generation
- **cGAN**: Excellent controllability, slight complexity overhead
- **WGAN-GP**: Most stable, higher computational cost due to gradient penalty
- **Advanced Techniques**: Significant improvements in training stability

### Applications Demonstrated

- Digit generation with MNIST
- Class-conditional synthesis
- High-resolution image creation
- Comprehensive model evaluation and comparison

### Next Steps

Continue to notebook 14 (VAEs and Advanced GANs) to explore Variational Autoencoders and cutting-edge GAN architectures like StyleGAN and Progressive GANs, building upon the foundational techniques mastered here.

The GAN implementations provide a solid foundation for generating high-quality synthetic data across various domains and applications.