# Generative AI 
Generative AI refers to a subset of artificial intelligence models and algorithms designed to generate new data samples that closely mirror the statistical properties of a target dataset. Unlike traditional classification and regression models, which predict labels or values from input features, generative models aim to learn the distribution of the input data itself. This lets them produce entirely new data points that are similar to, or even indistinguishable from, those in the training set.

## Main Classes of Generative AI Models

### 1. **Generative Adversarial Networks (GANs):**
- **Overview:** Consist of two neural networks, the generator and the discriminator, that are trained simultaneously in an adversarial process.
- **Applications:** Widely used for tasks like image generation, super-resolution, style transfer, and data augmentation.

### 2. **Variational Autoencoders (VAEs):**
- **Overview:** Based on Bayesian inference principles, VAEs encode input data into a distribution in latent space and decode from this space to reconstruct the input.
- **Applications:** Employed in image generation, anomaly detection, and semi-supervised learning tasks.
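
As a rough illustration of "encoding into a distribution", the sketch below samples a latent vector with the reparameterization trick and computes the KL term that pulls the encoded distribution toward a standard normal. The values of `mu` and `log_var` are made up for the example; in a real VAE they would come from the encoder network.

```python
import math
import random

def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), element-wise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))

rng = random.Random(0)
mu = [0.5, -1.0]        # hypothetical encoder output: latent means
log_var = [0.0, -2.0]   # hypothetical encoder output: latent log-variances

z = reparameterize(mu, log_var, rng)     # a random point in latent space
kl = kl_to_standard_normal(mu, log_var)  # zero only when mu = 0 and sigma = 1
print("latent sample:", z)
print("KL term:", kl)
```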

### 3. **Autoregressive Models:**
- **Overview:** These models generate sequences by predicting each new piece of data conditioned on the previously generated pieces.
- **Applications:** Utilized for text generation, image synthesis, and time series forecasting.
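
A minimal sketch of autoregressive generation, assuming a toy first-order model over DNA bases: each new base is sampled conditioned only on the previous one. The transition table is invented for illustration; real autoregressive models condition on the whole history using a neural network.

```python
import random

# Hypothetical bigram model: each row gives P(next base | previous base).
TRANSITIONS = {
    "A": {"A": 0.10, "C": 0.30, "G": 0.40, "T": 0.20},
    "C": {"A": 0.25, "C": 0.25, "G": 0.25, "T": 0.25},
    "G": {"A": 0.40, "C": 0.10, "G": 0.10, "T": 0.40},
    "T": {"A": 0.30, "C": 0.30, "G": 0.20, "T": 0.20},
}

def sample_sequence(length, start="A", seed=0):
    """Generate a sequence one base at a time, conditioning on the last base."""
    rng = random.Random(seed)
    seq = [start]
    while len(seq) < length:
        probs = TRANSITIONS[seq[-1]]
        bases = list(probs)
        weights = [probs[b] for b in bases]
        seq.append(rng.choices(bases, weights=weights, k=1)[0])
    return "".join(seq)

print(sample_sequence(20))
```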

### 4. **Diffusion Models:**
- **Overview:** Generate data by gradually denoising a sample from a simple distribution over a series of steps, guided by a trained neural network.
- **Applications:** Have shown impressive results in high-fidelity image generation, audio synthesis, and molecular design.
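
To make the denoising idea concrete, the sketch below shows the corruption process that a diffusion model learns to undo, under the assumption of a DDPM-style linear noise schedule (the schedule, `T`, and the function names are illustrative choices, not something specified above). If the noise is predicted perfectly, the clean sample can be recovered exactly; a trained network only approximates that prediction, which is why generation proceeds over many small steps.

```python
import math
import random

T = 1000
# Assumed linear schedule of per-step noise variances (a common DDPM choice).
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

# alpha_bars[t] is the fraction of the original signal variance left after step t.
alpha_bars = []
prod = 1.0
for beta in betas:
    prod *= 1.0 - beta
    alpha_bars.append(prod)

def add_noise(x0, t, eps):
    """Jump directly to step t: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps."""
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]

def recover_x0(xt, t, eps_pred):
    """Invert add_noise given a prediction of the noise that was added."""
    a = alpha_bars[t]
    return [(x - math.sqrt(1.0 - a) * e) / math.sqrt(a) for x, e in zip(xt, eps_pred)]

rng = random.Random(0)
x0 = [1.0, -0.5, 0.25]                   # a tiny stand-in for a data sample
eps = [rng.gauss(0.0, 1.0) for _ in x0]  # the noise a network would learn to predict
xt = add_noise(x0, T - 1, eps)           # almost pure noise at the last step
x0_hat = recover_x0(xt, T - 1, eps)      # perfect noise prediction recovers x0
print("signal fraction at final step:", alpha_bars[-1])
print("recovered:", x0_hat)
```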

Each class has its strengths and is chosen based on specific task requirements, such as sample quality, training stability, and computational efficiency. The field continues to evolve rapidly, introducing new models and improving existing ones.

## Differences from discriminative models
In this course, at least so far, we have been focusing on a class of models referred to as `discriminative models`. Deep learning classifiers and regressors are two examples of such models. Both map input data to known outputs or labels, which makes them excellent for predictive tasks but unsuited to generating new data.

## Potential in Biology
Generative models have significant potential in biology, offering innovative ways to tackle complex problems:

* __Drug Discovery__: Generative models can propose novel molecular structures with desired properties, accelerating the identification of potential new drugs. See [examples](https://blogs.nvidia.com/blog/generative-ai-proteins-evozyne/)
* __Synthetic Biology__: These models can design new genetic sequences or synthetic organisms with specific functions, supporting advances in bioengineering.
* __Data Augmentation__: Generative AI can create additional training samples for rare conditions or species, enhancing the performance of classification models in biology.
* __Understanding Complex Systems__: By generating data under different simulated conditions, generative models help in understanding complex biological systems and interactions.

Today, we will concentrate on more mundane tasks, such as increasing the resolution of bioimages or generating synthetic images. To do that, we will use a class of generative models called '_Generative Adversarial Networks_'.

## Generative Adversarial Networks (GANs)
GANs are a powerful class of generative models consisting of two neural networks, the generator and the discriminator, which are trained simultaneously in a competitive setting:

* Generator: Learns to produce data that mimics the training dataset. In biology, this could involve generating synthetic genomic sequences or realistic cell images.
* Discriminator: Learns to distinguish between real data from the training set and fake data produced by the generator. Its feedback helps improve the generator.

The adversarial process leads to the generator creating highly realistic data, making GANs particularly effective for tasks requiring high-quality synthetic data generation. The diagram below, from [Google](https://developers.google.com/machine-learning/gan/gan_structure), summarizes this structure.

![GAN](images/gan_diagram.svg)
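
The competition between the two networks is usually written as a minimax game. The formulation below is the standard one from the original GAN paper, where $D(x)$ is the probability the discriminator assigns to $x$ being real and $G(z)$ is the image generated from noise $z$; the symbols are not defined elsewhere in this notebook and are introduced here only for illustration.

$$
\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice the generator is often trained to maximize $\log D(G(z))$ instead, which gives stronger gradients early in training. This is the form a binary cross-entropy loss with "real" targets on fake images corresponds to.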

## Generating random digits

Today, our first task will be to create a simple GAN implementation and train it to generate digit images that resemble those from the MNIST dataset, a collection of handwritten digits widely used in machine learning. The code we'll explore consists of two main components: the Generator, which will learn to produce images akin to MNIST digits, and the Discriminator, whose job is to distinguish between real MNIST images and the fakes produced by our Generator. As we dive into the code, keep in mind that the Generator and Discriminator are essentially in a continuous game of cat and mouse, improving through each iteration of our training loop. This dynamic interaction is what makes GANs both challenging and incredibly fascinating. 


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

# Set keras backend to PyTorch
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import layers

To start, let's load the MNIST dataset from torchvision:

In [None]:
# Load MNIST dataset
# Define the transform to normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1]
])

# Download and load the MNIST dataset
mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Define batch size
BATCH_SIZE = 256

# Create the DataLoader
train_loader = DataLoader(mnist_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
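
As a quick sanity check on the normalization, the sketch below reproduces the arithmetic in plain Python: `ToTensor` scales 8-bit pixels to [0, 1], `Normalize(mean=0.5, std=0.5)` then maps them to [-1, 1], and multiplying by 127.5 and adding 127.5 maps values back to the 0 to 255 range for display. The helper names are made up for this example.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as transforms.Normalize for a single value in [0, 1]."""
    return (pixel - mean) / std

def to_display(value):
    """Map a value in [-1, 1] back to the 0-255 display range."""
    return value * 127.5 + 127.5

print(normalize(0.0), normalize(1.0))      # -1.0 1.0
print(to_display(-1.0), to_display(1.0))   # 0.0 255.0
```

A generator with a `tanh` output produces values in the same [-1, 1] range, so real and generated images are directly comparable.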

We can then define a simple generator model using Keras with the PyTorch backend:

In [None]:
def make_generator_model():
    model = keras.Sequential([
        layers.Dense(7*7*128, use_bias=False, input_shape=(200,)),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 128)),
        # Upsample to 28x28x1
        layers.Conv2DTranspose(1, (4, 4), strides=(4, 4), padding='same', use_bias=False, activation='tanh')
    ])
    return model

generator = make_generator_model()

In [None]:
generator.summary()

We follow the generator by defining a discriminator:


In [None]:
def make_discriminator_model():
    model = keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

discriminator = make_discriminator_model()

In [None]:
discriminator.summary()
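
The parameter counts reported by the two summaries can be checked by hand. The sketch below redoes the arithmetic for the layers defined above, assuming Keras counts the BatchNormalization moving mean and variance as (non-trainable) parameters and that a stride-2 `'same'` convolution halves the 28x28 input to 14x14.

```python
# Generator: Dense -> BatchNormalization -> Conv2DTranspose
latent_dim = 200
dense_units = 7 * 7 * 128                # 6272
gen_dense = latent_dim * dense_units     # no bias term (use_bias=False)
gen_bn = 4 * dense_units                 # gamma, beta, moving mean, moving variance
gen_tconv = 4 * 4 * 128 * 1              # 4x4 kernel, 128 input channels, 1 output channel
gen_total = gen_dense + gen_bn + gen_tconv

# Discriminator: Conv2D -> Flatten -> Dense
disc_conv = 5 * 5 * 1 * 64 + 64          # 5x5 kernel, 1 input channel, 64 filters + biases
disc_flat = 14 * 14 * 64                 # feature map size after the stride-2 convolution
disc_dense = disc_flat * 1 + 1           # one output logit + bias
disc_total = disc_conv + disc_dense

print(gen_total)   # 1281536
print(disc_total)  # 14209
```

The generator is roughly 90 times larger than the discriminator, almost entirely because of its first dense layer.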

We can now define our losses (i.e., the error metrics we minimize). We use binary cross-entropy because the discriminator's task is a simple binary classification: each sample is either "real" or "fake". We also create one Adam optimizer per network.

In [None]:
# Define loss functions
def discriminator_loss(real_output, fake_output):
    # Use BCEWithLogitsLoss which has better numerical stability
    bce = nn.BCEWithLogitsLoss()
    real_loss = bce(real_output, torch.ones_like(real_output))
    fake_loss = bce(fake_output, torch.zeros_like(fake_output))
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    bce = nn.BCEWithLogitsLoss()
    return bce(fake_output, torch.ones_like(fake_output))

# Setup optimizers
generator_optimizer = optim.Adam(generator.parameters(), lr=5e-4)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=5e-4)
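
To build intuition for the loss values printed during training, here is a plain-Python sketch of the per-sample quantity that binary cross-entropy on logits computes (the function name is made up; `nn.BCEWithLogitsLoss` averages this over the batch). When the discriminator is completely unsure and outputs a logit of 0 for every image, each term equals ln 2.

```python
import math

def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit and a 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

# A discriminator that cannot tell real from fake outputs logit 0 (probability 0.5).
d_loss = bce_with_logits(0.0, 1.0) + bce_with_logits(0.0, 0.0)  # real term + fake term
g_loss = bce_with_logits(0.0, 1.0)                              # generator wants "real"

print(round(d_loss, 4))  # 1.3863
print(round(g_loss, 4))  # 0.6931
```

Losses hovering near these values suggest the two networks are roughly balanced; a discriminator loss near zero usually means the generator is being overpowered.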



We can also create some code to display images during training.

In [None]:
def generate_and_display_images(model, test_input):
    # Notice training=False so BatchNorm runs in inference mode
    predictions = model(test_input, training=False)
    predictions = predictions.detach().cpu().numpy()
    
    fig = plt.figure(figsize=(4, 4))
    
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
        
    plt.show()

Finally, we can define our approach for training the model and storing the losses/data as needed.

In [None]:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

# Lists to store losses
generator_losses = []
discriminator_losses = []

def train(dataloader, epochs, seed, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Move seed to device
    seed = seed.to(device)
    
    # Training loop
    for epoch in tqdm(range(epochs), desc='Epochs'):
        epoch_gen_loss = []
        epoch_disc_loss = []
        
        for batch in dataloader:
            images, _ = batch
            
            # Ensure correct shape for MNIST images and move to device
            images = images.reshape(-1, 28, 28, 1).to(device)
            batch_size = images.shape[0]
            
            # Generate noise on device
            noise = torch.randn(batch_size, 200, device=device)
            
            # ------- Combined forward pass (like TensorFlow version) -------
            with autocast():
                # Generate fake images once
                fake_images = generator(noise, training=True)
                
                # Get discriminator outputs for both real and fake
                real_output = discriminator(images, training=True)
                fake_output = discriminator(fake_images.detach(), training=True)
                
                # Calculate discriminator loss
                real_loss = nn.BCEWithLogitsLoss()(real_output, torch.ones_like(real_output))
                fake_loss = nn.BCEWithLogitsLoss()(fake_output, torch.zeros_like(fake_output))
                disc_loss = real_loss + fake_loss
            
            # ------- Update discriminator -------
            discriminator_optimizer.zero_grad()
            scaler.scale(disc_loss).backward()
            scaler.step(discriminator_optimizer)
            
            # ------- Train generator with the same fake images -------
            with autocast():
                # Recompute discriminator output with the same fake images
                # but allow gradients to flow to generator
                fake_output = discriminator(fake_images, training=True)
                
                # Generator wants discriminator to think its images are real
                gen_loss = nn.BCEWithLogitsLoss()(fake_output, torch.ones_like(fake_output))
            
            generator_optimizer.zero_grad()
            scaler.scale(gen_loss).backward()
            scaler.step(generator_optimizer)
            
            # Update scaler for next iteration
            scaler.update()
            
            # Store losses
            epoch_gen_loss.append(gen_loss.item())
            epoch_disc_loss.append(disc_loss.item())
        
        # Average loss for the epoch
        generator_losses.append(np.mean(epoch_gen_loss))
        discriminator_losses.append(np.mean(epoch_disc_loss))
        
        # Display progress
        if (epoch + 1) % 10 == 0:
            print(f'\nEpoch {epoch + 1} completed')
            print(f'Generator loss: {generator_losses[-1]}, Discriminator loss: {discriminator_losses[-1]}')
            
            # Use autocast here too for consistency
            with torch.no_grad(), autocast():
                generate_and_display_images(generator, seed)
    
    # Final display
    if epochs % 10 != 0:
        with torch.no_grad(), autocast():
            generate_and_display_images(generator, seed)

Let's train!!

In [None]:
# Create fixed noise for image generation
seed = torch.randn(16, 200)  # 16 examples for a 4x4 grid
EPOCHS = 100  # Set the number of epochs

train(train_loader, EPOCHS, seed)


## Applications: Super resolution!


This part of the notebook demonstrates super-resolution using a GAN. We'll use a PyTorch implementation of ESRGAN (Enhanced Super-Resolution Generative Adversarial Networks) to recover high-resolution (HR) images from their low-resolution counterparts.

The original model used in the TensorFlow version upsamples a 50x50 low resolution image to a 200x200 high resolution image (scale factor=4). For this PyTorch version, we'll implement a compatible interface.

Let's start by defining some convenience functions and importing the dependencies:

In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import requests
import matplotlib.pyplot as plt
from io import BytesIO
import torch.nn.functional as F

class ResidualDenseBlock(nn.Module):
    def __init__(self, nf=64, gc=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=True)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=True)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=True)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDB(nn.Module):
    def __init__(self, nf=64):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(nf)
        self.rdb2 = ResidualDenseBlock(nf)
        self.rdb3 = ResidualDenseBlock(nf)
        
    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        return out * 0.2 + x

class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23):
        super(RRDBNet, self).__init__()
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        
        # RRDB blocks
        self.body = nn.ModuleList()
        for _ in range(nb):
            self.body.append(RRDB(nf))
            
        self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        
        # Upsampling
        self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        
    def forward(self, x):
        feat = self.conv_first(x)
        body_feat = feat.clone()
        
        for block in self.body:
            body_feat = block(body_feat)
            
        body_feat = self.conv_body(body_feat)
        feat = feat + body_feat
        
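        # Upsampling: interpolate first, then convolve. This is the order used by
        # the reference ESRGAN implementation that the pretrained weights come from.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))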
        
        feat = self.lrelu(self.conv_hr(feat))
        out = self.conv_last(feat)
        
        return out
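
The overall 4x scale factor comes from the two nearest-neighbour 2x upsampling steps in `forward`. The sketch below mimics that operation on a tiny single-channel image stored as nested lists (the helper is illustrative and not part of the model): every pixel is repeated twice along each axis, so applying it twice turns a 50x50 input into 200x200.

```python
def upsample_nearest_2x(image):
    """Nearest-neighbour 2x upsampling of a 2D list: repeat each pixel 2x2."""
    out = []
    for row in image:
        wide = [value for value in row for _ in range(2)]  # repeat along width
        out.append(wide)
        out.append(list(wide))                             # repeat along height
    return out

tiny = [[1, 2],
        [3, 4]]
once = upsample_nearest_2x(tiny)
twice = upsample_nearest_2x(once)

print(once)                         # [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
print(len(twice), len(twice[0]))    # 8 8
```

On its own this only produces blocky enlargements; in the network, the convolutions around each upsampling step are what fill in plausible detail.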



Let's define our super-resolution wrapper class:

In [None]:
class ESRGANUpscaler:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")
        
        # Initialize the model with correct architecture
        self.model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23).to(self.device)
        
        # Download weights from GitHub if not present
        model_path = 'RealESRGAN_x4plus.pth'
        if not os.path.exists(model_path):
            print("Downloading pre-trained ESRGAN model...")
            url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
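            r = requests.get(url)
            # Fail early on a bad response instead of caching an error page as the weights file
            r.raise_for_status()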
            with open(model_path, 'wb') as f:
                f.write(r.content)
            print(f"Downloaded model to {model_path}")
        
        try:
            # Load pre-trained weights
            weights = torch.load(model_path, map_location=self.device)
            
            # Extract the appropriate part of the state dict
            if 'params_ema' in weights:
                state_dict = weights['params_ema']
            elif 'params' in weights:
                state_dict = weights['params']
            else:
                state_dict = weights
                
            self.model.load_state_dict(state_dict, strict=True)
            self.model.eval()
            print("Successfully loaded pre-trained ESRGAN model!")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Using model without pre-trained weights (results will be poor)")
        
    def upscale(self, img):
        """
        Upscale an image using the ESRGAN model
        Args:
            img: PIL Image or numpy array
        Returns:
            super-resolution image as numpy array
        """
        # Convert input to proper format
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype(np.uint8))
        
        # Ensure it's RGB
        img = img.convert('RGB')
        
        # Convert to tensor and normalize to [0, 1]
        img_tensor = torch.from_numpy(np.array(img)).float().div(255.0)
        # Change from HWC to CHW format
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(self.device)
        
        # Process with the model
        with torch.no_grad():
            if self.device.type == 'cuda':
                with torch.amp.autocast('cuda'):
                    output = self.model(img_tensor)
            else:
                output = self.model(img_tensor)
        
        # Convert output tensor to numpy array
        output = output.squeeze().float().cpu().clamp_(0, 1).permute(1, 2, 0).numpy()
        
        # Scale to [0, 255] and convert to uint8
        output = (output * 255.0).round().astype(np.uint8)
        
        return output

Now let's download and display a test image:

In [None]:
# Load the image
url = 'https://raw.githubusercontent.com/tensorflow/examples/master/lite/examples/super_resolution/android/app/src/main/assets/lr-1.jpg'
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert('RGB')

# Display the image
plt.figure(figsize=(1,1))
plt.imshow(img)
plt.axis('off')  # Hide axes ticks
plt.show()

Let's first try to upsample this image using interpolation techniques

In [None]:
w, h = img.size
bicubic = np.array(img.resize((w*4, h*4), Image.BICUBIC))

# Plot the bicubic result (the array is already in HWC order, as matplotlib expects)
plt.title('Bicubic')
plt.imshow(bicubic)
plt.axis('off')
plt.show()

The result doesn't look that great. Finally, we can run the model and see whether we obtain a better image:

In [None]:
# Create the ESRGAN upscaler
print("Initializing ESRGAN model...")
esrgan_model = ESRGANUpscaler()

# Upscale with ESRGAN
print("Upscaling with ESRGAN...")
esrgan_result = esrgan_model.upscale(img)

# Create a bicubic upscaled version for comparison
w, h = img.size
bicubic = np.array(img.resize((w*4, h*4), Image.BICUBIC))

# Display results
plt.figure(figsize=(15, 8))

plt.subplot(1, 3, 1)
plt.title('Original')
plt.imshow(img)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Bicubic (4x)')
plt.imshow(bicubic)
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('ESRGAN (4x)')
plt.imshow(esrgan_result)
plt.axis('off')

plt.tight_layout()
plt.show()

## Applications: Image Synthesis
Generative Adversarial Networks (GANs) are powerful models for synthesizing new, realistic images from random noise inputs. In this section, we demonstrate how an old-school (2015) pretrained GAN generator can be used to produce synthetic images, illustrating the generative capability central to GAN-based synthesis tasks.

DCGAN (Deep Convolutional GAN), introduced by Radford et al. in 2015, demonstrated that GANs could generate high-quality, coherent images by using convolutional neural networks (CNNs) in both the generator and the discriminator. This marked a significant advance over earlier fully connected GAN architectures, which struggled to capture spatial hierarchies in image data.

### Why DCGAN?

DCGAN introduced key architectural innovations that made GANs more stable and efficient to train:

* Convolutional layers without pooling: Both the generator and discriminator use strided convolutions, removing the need for pooling layers and allowing the network to learn its own spatial downsampling or upsampling.

* Batch normalization: Applied in both networks (except in the output layers), it stabilizes training by normalizing feature distributions.

* ReLU and LeakyReLU activations: The generator uses ReLU in every layer except the output, while the discriminator uses LeakyReLU so that gradients still flow for negative activations, avoiding dead neurons.

* Tanh output: The generator outputs images in the normalized range [-1, 1], matching the scaling of the training images and improving convergence.





In [None]:
import torch
import torch.nn as nn
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import urllib.request

# Define your Generator 
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),   # main.0
            nn.BatchNorm2d(ngf * 8),                               # main.1
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), # main.3
            nn.BatchNorm2d(ngf * 4),                               # main.4
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), # main.6
            nn.BatchNorm2d(ngf * 2),                               # main.7
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),     # main.9
            nn.BatchNorm2d(ngf),                                   # main.10
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),          # main.12
            nn.Tanh()
        )
    def forward(self, input):
        return self.main(input)

# Download the model
model_url = 'https://github.com/Natsu6767/DCGAN-PyTorch/raw/master/model/model_final.pth'
model_filename = 'model_final.pth'
urllib.request.urlretrieve(model_url, model_filename)

# Load the checkpoint
checkpoint = torch.load(model_filename, map_location='cpu')
gen_weights = checkpoint['generator']

# Remap keys
new_state_dict = {}
key_map = {
    'tconv1.weight': 'main.0.weight',
    'bn1.weight': 'main.1.weight',
    'bn1.bias': 'main.1.bias',
    'bn1.running_mean': 'main.1.running_mean',
    'bn1.running_var': 'main.1.running_var',
    'tconv2.weight': 'main.3.weight',
    'bn2.weight': 'main.4.weight',
    'bn2.bias': 'main.4.bias',
    'bn2.running_mean': 'main.4.running_mean',
    'bn2.running_var': 'main.4.running_var',
    'tconv3.weight': 'main.6.weight',
    'bn3.weight': 'main.7.weight',
    'bn3.bias': 'main.7.bias',
    'bn3.running_mean': 'main.7.running_mean',
    'bn3.running_var': 'main.7.running_var',
    'tconv4.weight': 'main.9.weight',
    'bn4.weight': 'main.10.weight',
    'bn4.bias': 'main.10.bias',
    'bn4.running_mean': 'main.10.running_mean',
    'bn4.running_var': 'main.10.running_var',
    'tconv5.weight': 'main.12.weight'
}

for k, v in gen_weights.items():
    if k in key_map:
        new_state_dict[key_map[k]] = v
    else:
        print(f"Skipping key: {k}")

# Initialize generator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
netG = Generator().to(device)
netG.load_state_dict(new_state_dict)
netG.eval()
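
To see how the generator turns a 1x1 noise "pixel" into a full image, the sketch below applies the standard output-size formula for a transposed convolution (with no dilation or output padding) to the five `ConvTranspose2d` layers defined above. The helper name and the `layers` list are written out here for illustration; the kernel, stride, and padding values are copied from the `Generator` class.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for main.0, main.3, main.6, main.9, main.12
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector is shaped (nz, 1, 1)
sizes = []
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [4, 8, 16, 32, 64]
```

Each stride-2 layer doubles the spatial resolution, so this generator produces 64x64 RGB images.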



Now let's plot it!

In [None]:
# Generate images
nz = 100
num_images = 4
noise = torch.randn(num_images, nz, 1, 1, device=device)
with torch.no_grad():
    fake_images = netG(noise).cpu()

# Plot
grid = vutils.make_grid(fake_images, nrow=4, normalize=True)
plt.figure(figsize=(16,4))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.show()
