# Week 9, Day 2: Advanced GANs and VAEs

## Learning Objectives
- Understand advanced GAN architectures
- Learn conditional generation
- Master advanced VAE techniques
- Practice implementing advanced models

## Topics Covered
1. Conditional GANs
2. CycleGAN
3. Conditional VAEs
4. Beta-VAE

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

## 1. Conditional GAN

In [None]:
class ConditionalGAN:
    """Label-conditioned GAN: a generator/discriminator pair plus optimizers.

    Both networks take an integer class label as a second input and pass it
    through a 50-dimensional embedding before mixing it with the main input.
    Only model construction is implemented here; no training step is defined
    in this class.

    Attributes:
        latent_dim: Size of the noise vector fed to the generator.
        num_classes: Number of distinct labels (embedding vocabulary size).
        generator: Keras model mapping ``[noise, label]`` to a 28x28x1 image
            with tanh-activated values.
        discriminator: Keras model mapping ``[image, label]`` to a single
            sigmoid score.
        g_optimizer, d_optimizer: Adam optimizers (lr=0.0002, beta_1=0.5).
    """

    def __init__(self, latent_dim, num_classes):
        self.latent_dim, self.num_classes = latent_dim, num_classes

        # Networks are built first, then one Adam optimizer per network.
        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

        adam_settings = dict(learning_rate=0.0002, beta_1=0.5)
        self.g_optimizer = tf.keras.optimizers.Adam(**adam_settings)
        self.d_optimizer = tf.keras.optimizers.Adam(**adam_settings)

    def build_generator(self):
        """Build the generator: ``[noise, label]`` -> 28x28x1 tanh image."""
        noise_in = layers.Input(shape=(self.latent_dim,))
        label_in = layers.Input(shape=(1,))

        # Embed the label and flatten it so it can be joined with the noise.
        label_vec = layers.Embedding(self.num_classes, 50)(label_in)
        label_vec = layers.Flatten()(label_vec)

        h = layers.Concatenate()([noise_in, label_vec])

        # Project to a 7x7x256 feature map, then upsample 7 -> 14 -> 28.
        h = layers.Dense(7 * 7 * 256)(h)
        h = layers.Reshape((7, 7, 256))(h)
        for n_filters in (128, 64):
            h = layers.Conv2DTranspose(n_filters, 4, strides=2, padding='same')(h)
            h = layers.BatchNormalization()(h)
            h = layers.LeakyReLU(alpha=0.2)(h)

        image_out = layers.Conv2D(1, 4, padding='same', activation='tanh')(h)

        return tf.keras.Model(inputs=[noise_in, label_in], outputs=image_out)

    def build_discriminator(self):
        """Build the discriminator: ``[image, label]`` -> sigmoid score."""
        image_in = layers.Input(shape=(28, 28, 1))
        label_in = layers.Input(shape=(1,))

        # Turn the label into a 28x28x1 plane so it can be stacked onto the
        # image as an extra channel.
        label_plane = layers.Embedding(self.num_classes, 50)(label_in)
        label_plane = layers.Flatten()(label_plane)
        label_plane = layers.Dense(28 * 28)(label_plane)
        label_plane = layers.Reshape((28, 28, 1))(label_plane)

        h = layers.Concatenate()([image_in, label_plane])

        # Two stride-2 convolutions halve the resolution twice (28 -> 14 -> 7).
        for n_filters in (64, 128):
            h = layers.Conv2D(n_filters, 4, strides=2, padding='same')(h)
            h = layers.LeakyReLU(alpha=0.2)(h)

        h = layers.Flatten()(h)
        score = layers.Dense(1, activation='sigmoid')(h)

        return tf.keras.Model(inputs=[image_in, label_in], outputs=score)

## 2. CycleGAN

In [None]:
class CycleGAN:
    """Container for the four networks of a CycleGAN and their optimizers.

    Builds two generators and two discriminators for 256x256x3 images. Judging
    by the attribute names, ``g_AB``/``g_BA`` presumably translate between
    domains A and B, and ``d_A``/``d_B`` judge images of each domain. Only
    model construction is implemented here; no losses or training step are
    defined in this class.
    """

    def __init__(self):
        # Generators first, then discriminators.
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()

        adam_settings = dict(learning_rate=0.0002, beta_1=0.5)
        self.g_optimizer = tf.keras.optimizers.Adam(**adam_settings)
        self.d_optimizer = tf.keras.optimizers.Adam(**adam_settings)

    def build_generator(self):
        """Build a U-Net style generator: 256x256x3 in, 256x256x3 tanh out."""
        image_in = layers.Input(shape=(256, 256, 3))

        # Encoder: four stride-2 stages; only the first one skips batch norm.
        # Every stage output is remembered for the skip connections.
        stage_outputs = []
        h = image_in
        for stage, n_filters in enumerate((64, 128, 256, 512)):
            h = self.encoder_block(h, n_filters, batchnorm=(stage > 0))
            stage_outputs.append(h)

        # Decoder: start from the deepest feature map and walk back up,
        # pairing each upsampling step with the encoder output of equal size.
        h = stage_outputs[-1]
        shallower_first = reversed(stage_outputs[:-1])
        for skip, n_filters in zip(shallower_first, (256, 128, 64)):
            h = self.decoder_block(h, skip, n_filters)

        # Final upsampling back to the input resolution with 3 channels.
        image_out = layers.Conv2DTranspose(
            3, 4, strides=2, padding='same', activation='tanh'
        )(h)

        return tf.keras.Model(inputs=image_in, outputs=image_out)

    def build_discriminator(self):
        """Build a PatchGAN discriminator returning a 1-channel score map.

        The last convolution has no activation, so the outputs are raw scores.
        """
        image_in = layers.Input(shape=(256, 256, 3))

        # First stage has no batch normalization.
        h = layers.Conv2D(64, 4, strides=2, padding='same')(image_in)
        h = layers.LeakyReLU(alpha=0.2)(h)

        for n_filters in (128, 256):
            h = layers.Conv2D(n_filters, 4, strides=2, padding='same')(h)
            h = layers.BatchNormalization()(h)
            h = layers.LeakyReLU(alpha=0.2)(h)

        patch_scores = layers.Conv2D(1, 4, strides=1, padding='same')(h)

        return tf.keras.Model(inputs=image_in, outputs=patch_scores)

    def encoder_block(self, x, filters, batchnorm=True):
        """Downsample ``x`` by 2: conv -> optional batch norm -> LeakyReLU."""
        h = layers.Conv2D(filters, 4, strides=2, padding='same')(x)
        if batchnorm:
            h = layers.BatchNormalization()(h)
        return layers.LeakyReLU(alpha=0.2)(h)

    def decoder_block(self, x, skip, filters):
        """Upsample ``x`` by 2, join it with ``skip``, then apply ReLU.

        The ReLU comes after the concatenation, so it is applied to the skip
        features as well as to the upsampled ones.
        """
        upsampled = layers.Conv2DTranspose(filters, 4, strides=2, padding='same')(x)
        normalized = layers.BatchNormalization()(upsampled)
        merged = layers.Concatenate()([normalized, skip])
        return layers.ReLU()(merged)

## 3. Conditional VAE

In [None]:
class ConditionalVAE(tf.keras.Model):
    """Convolutional VAE whose decoder is conditioned on a class vector.

    The encoder maps a 28x28x1 image to the mean and log-variance of a
    diagonal Gaussian over ``latent_dim`` dimensions. The decoder receives the
    sampled latent vector concatenated with the condition ``c`` and returns a
    28x28x1 image of sigmoid-activated values.

    Note that only the decoder sees the condition: ``encode`` accepts ``c`` for
    a uniform interface but the encoder network itself takes the image alone.

    Args:
        latent_dim: Dimensionality of the latent space.
        num_classes: Width of the condition vector appended to the latent
            code (presumably a one-hot class label - confirm with callers).
    """

    def __init__(self, latent_dim, num_classes):
        super(ConditionalVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder: two stride-2 convolutions, then a dense layer that emits
        # mean and log-variance side by side (hence 2 * latent_dim units).
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2D(64, 3, activation='relu', strides=2, padding='same'),
            layers.Flatten(),
            layers.Dense(latent_dim + latent_dim)
        ])

        # Decoder: input is [z, c]; upsamples 7 -> 14 -> 28 and ends with a
        # sigmoid, so the output is in (0, 1) rather than raw logits.
        self.decoder = tf.keras.Sequential([
            layers.Input(shape=(latent_dim + num_classes,)),
            layers.Dense(7*7*32, activation='relu'),
            layers.Reshape((7, 7, 32)),
            layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')
        ])

    def encode(self, x, c):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``.

        ``c`` is accepted for interface symmetry with ``decode`` but is not
        used: the encoder network is fed the image only.
        """
        stats = self.encoder(x)
        mean, logvar = tf.split(stats, num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        # Use the dynamic shape: the static ``mean.shape`` contains ``None``
        # for the batch axis when traced with an unknown batch size, which
        # ``tf.random.normal`` cannot handle.
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, c):
        """Decode latent ``z`` together with condition ``c`` into an image.

        ``z`` and ``c`` are concatenated along axis 1, so ``c`` must be a
        rank-2 tensor of width ``num_classes`` with the same dtype as ``z``.
        """
        z_c = tf.concat([z, c], axis=1)
        return self.decoder(z_c)

    def call(self, inputs):
        """Run the full model on ``inputs = (x, c)``.

        Returns:
            ``(reconstruction, mean, logvar)`` where ``reconstruction`` is the
            sigmoid-activated decoder output.
        """
        x, c = inputs
        mean, logvar = self.encode(x, c)
        z = self.reparameterize(mean, logvar)
        x_recon = self.decode(z, c)
        return x_recon, mean, logvar

## 4. Beta-VAE

In [None]:
class BetaVAE(tf.keras.Model):
    """Convolutional beta-VAE for 28x28x1 images.

    Identical in structure to a plain VAE, except that the KL term of the loss
    is multiplied by ``beta``.

    Args:
        latent_dim: Dimensionality of the latent space.
        beta: Weight applied to the KL divergence in ``compute_loss``.
    """

    def __init__(self, latent_dim, beta=4.0):
        super(BetaVAE, self).__init__()
        self.latent_dim = latent_dim
        self.beta = beta

        # Encoder: emits mean and log-variance side by side (2 * latent_dim).
        self.encoder = tf.keras.Sequential([
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2D(64, 3, activation='relu', strides=2, padding='same'),
            layers.Flatten(),
            layers.Dense(256, activation='relu'),
            layers.Dense(latent_dim + latent_dim)
        ])

        # Decoder: upsamples 7 -> 14 -> 28 and ends with a sigmoid, so its
        # output is a per-pixel probability in (0, 1), not a logit.
        self.decoder = tf.keras.Sequential([
            layers.Input(shape=(latent_dim,)),
            layers.Dense(256, activation='relu'),
            layers.Dense(7*7*64, activation='relu'),
            layers.Reshape((7, 7, 64)),
            layers.Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same'),
            layers.Conv2DTranspose(1, 3, activation='sigmoid', padding='same')
        ])

    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        # Use the dynamic shape: the static ``mean.shape`` contains ``None``
        # for the batch axis when traced with an unknown batch size, which
        # ``tf.random.normal`` cannot handle.
        eps = tf.random.normal(shape=tf.shape(mean), dtype=mean.dtype)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z):
        """Decode latent ``z`` into a 28x28x1 image of pixel probabilities."""
        return self.decoder(z)

    def call(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(reconstruction, mean, logvar)`` where ``reconstruction`` is the
            sigmoid-activated decoder output.
        """
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_prob = self.decode(z)
        return x_prob, mean, logvar

    def compute_loss(self, x):
        """Return the scalar beta-VAE loss for a batch ``x``.

        The loss is the batch mean of ``reconstruction + beta * KL``, where the
        reconstruction term is the binary cross-entropy summed over all pixels
        and the KL term is the divergence from a standard normal prior summed
        over latent dimensions.

        NOTE(review): recent ``tf.keras.Model`` versions define their own
        ``compute_loss`` with a different signature - confirm this override is
        only used from a custom training loop, not via ``Model.fit``.
        """
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_prob = self.decode(z)

        # Reconstruction loss. The decoder already applies a sigmoid, so its
        # output must be treated as probabilities; passing it to
        # sigmoid_cross_entropy_with_logits would apply the sigmoid twice.
        x = tf.cast(x, x_prob.dtype)
        tiny = 1e-7  # keeps log() finite when a probability saturates
        x_prob = tf.clip_by_value(x_prob, tiny, 1.0 - tiny)
        cross_ent = -(x * tf.math.log(x_prob)
                      + (1.0 - x) * tf.math.log(1.0 - x_prob))
        logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])

        # KL divergence between N(mean, exp(logvar)) and N(0, I).
        kl_div = -0.5 * (1 + logvar - tf.square(mean) - tf.exp(logvar))
        kl_div = tf.reduce_sum(kl_div, axis=1)

        # Beta-VAE loss: negative of (log-likelihood - beta * KL).
        return -tf.reduce_mean(logpx_z - self.beta * kl_div)

## Practical Exercises

In [None]:
# Exercise 1: Conditional Generation

def conditional_generation_exercise():
    """Print the task description for the conditional-generation exercise."""
    instructions = (
        "Task: Implement conditional generation",
        "1. Create conditional model",
        "2. Process condition input",
        "3. Train model",
        "4. Generate samples",
    )
    for line in instructions:
        print(line)

    # Your code here


conditional_generation_exercise()

In [None]:
# Exercise 2: Style Transfer

def style_transfer_exercise():
    """Print the task description for the style-transfer exercise."""
    instructions = (
        "Task: Implement style transfer",
        "1. Create transfer model",
        "2. Process style and content",
        "3. Train model",
        "4. Generate transfers",
    )
    for line in instructions:
        print(line)

    # Your code here


style_transfer_exercise()

## MCQ Quiz

1. What is conditional generation?
   - a) Random generation
   - b) Controlled generation
   - c) Data processing
   - d) Model training

2. What is CycleGAN?
   - a) Basic GAN
   - b) Unpaired translation
   - c) Classification model
   - d) Regression model

3. What is a conditional VAE?
   - a) Basic VAE
   - b) Conditional generation
   - c) Classification model
   - d) Regression model

4. What is Beta-VAE?
   - a) Basic VAE
   - b) Disentanglement model
   - c) Classification model
   - d) Regression model

5. What does cycle consistency enforce in a CycleGAN?
   - a) Data processing
   - b) Translation constraint
   - c) Model architecture
   - d) Loss function

6. What is disentanglement?
   - a) Data processing
   - b) Feature separation
   - c) Model architecture
   - d) Loss function

7. What is style transfer?
   - a) Data processing
   - b) Style application
   - c) Model architecture
   - d) Loss function

8. What is the purpose of skip connections?
   - a) Model training
   - b) Detail preservation
   - c) Loss calculation
   - d) Data processing

9. What is perceptual loss?
   - a) Basic loss
   - b) Feature-based loss
   - c) Classification loss
   - d) Regression loss

10. What is domain adaptation?
    - a) Data processing
    - b) Domain transfer
    - c) Model architecture
    - d) Loss function

Answers: 1-b, 2-b, 3-b, 4-b, 5-b, 6-b, 7-b, 8-b, 9-b, 10-b