In [7]:
import lightning as L
import torch
import os
import torchvision.transforms.v2 as v2
import datetime
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from pathlib import Path
from datetime import datetime

In [8]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

WGAN with gradient penalty paper
https://arxiv.org/abs/1704.00028

GAN implementations
https://github.com/eriklindernoren/PyTorch-GAN/tree/master

"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nWGAN with gradient penalty paper\nhttps://arxiv.org/abs/1704.00028\n\nGAN implementations\nhttps://github.com/eriklindernoren/PyTorch-GAN/tree/master\n\n'

In [9]:
class Generator(nn.Module):
    """MLP generator mapping latent vectors to images with values in [-1, 1].

    Layout: latent_dim -> 1024 -> 2048 -> 4096 -> prod(img_shape), with
    LeakyReLU(0.2) after every hidden Linear, BatchNorm1d on every hidden
    block except the first, and a final Tanh.

    Args:
        img_shape: Shape of one output image, e.g. ``(C, H, W)``.
        latent_dim: Size of the input noise vector ``z``.
    """

    def __init__(self, img_shape=(1, 28, 28), latent_dim: int = 256):
        super(Generator, self).__init__()

        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return ``[Linear, (BatchNorm1d), LeakyReLU(0.2)]`` layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is ``eps`` (not
                # ``momentum``), so the original ``nn.BatchNorm1d(out_feat, 0.8)``
                # meant eps=0.8. Spelled out explicitly here; the value is kept
                # because changing eps would alter the outputs of models already
                # trained with this architecture.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Expand straight from latent_dim to 1024 features (no initial
            # bottleneck, no BatchNorm on the first block).
            *block(latent_dim, 1024, normalize=False),
            *block(1024, 2048),
            *block(2048, 4096),
            nn.Linear(4096, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to (batch, *img_shape)."""
        img = self.model(z)
        img = img.view(img.shape[0], *self.img_shape)
        return img
    
class Discriminator(nn.Module):
    """MLP critic that returns one raw (unbounded) score per input image.

    Widths: prod(img_shape) -> 4096 -> 2048 -> 1024 -> 512 -> 1, with
    LeakyReLU(0.2) after every hidden Linear, no normalization layers and no
    output activation.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.img_shape = img_shape

        widths = [int(np.prod(self.img_shape)), 4096, 2048, 1024, 512]
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        # Scoring head: a single linear unit, no squashing.
        stack.append(nn.Linear(widths[-1], 1))

        # Modules are appended in the same order as an explicit listing would
        # give, so Sequential indices (and hence state_dict keys) are unchanged.
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Flatten each sample and return critic scores of shape (batch, 1)."""
        flat = img.view(img.shape[0], -1)
        return self.model(flat)

In [10]:
class GAN(L.LightningModule):
    """WGAN-GP (Wasserstein GAN with gradient penalty) using the MLP
    ``Generator`` and ``Discriminator`` (the critic), trained with manual
    optimization.

    Each ``training_step`` performs ``n_critic`` critic updates followed by a
    single generator update.

    Args:
        img_size: Height and width of the square single-channel images.
        latent_dim: Size of the generator's noise vector.
        lr: Generator learning rate; the critic uses ``lr / 3``.
        b1: Adam beta1 for both optimizers.
        b2: Adam beta2 for both optimizers.
        n_critic: Critic updates per generator update.
        lambda_gp: Weight of the gradient-penalty term in the critic loss.
            Defaults to 10.0, the value previously hard-coded, so checkpoints
            saved without this hyperparameter load with identical behavior.
    """

    def __init__(
            self,
            img_size: int = 28,
            latent_dim: int = 256,
            lr: float = 3e-4,
            b1: float = 0,
            b2: float = 0.999,
            n_critic: int = 5,
            lambda_gp: float = 10.0,
            ):
        super().__init__()
        # Stores the constructor arguments in self.hparams; they are read back
        # by configure_optimizers, by load_from_checkpoint, and are also picked
        # up by loggers (e.g. wandb).
        self.save_hyperparameters()
        # Manual optimization: training_step drives both optimizers itself,
        # which the n_critic-to-1 update schedule requires.
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.n_critic = n_critic
        self.lambda_gp = lambda_gp

        self.generator = Generator(img_shape=(1, self.img_size, self.img_size), latent_dim=self.latent_dim)
        self.discriminator = Discriminator(img_shape=(1, self.img_size, self.img_size))

    def forward(self, z):
        """Generate images from noise ``z`` via the generator."""
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        """Critic loss ``mean(D(fake)) - mean(D(real))`` (penalty excluded).

        Despite the parameter names, both arguments are critic *scores*
        (outputs of ``self.discriminator``), not images.
        """
        return -torch.mean(real_img) + torch.mean(gen_img)

    def loss_Generator(self, gen_img):
        """Generator loss ``-mean(D(fake))``; ``gen_img`` holds critic scores."""
        return -torch.mean(gen_img)

    def gradient_penalty(self, critic, real_samples, fake_samples):
        """
        Calculate the gradient penalty for WGAN-GP (Wasserstein GAN with gradient penalty).

        Computes ``mean((||grad_x critic(x_hat)||_2 - 1) ** 2)`` at random
        points ``x_hat`` on the straight lines between paired real and fake
        samples.

        Args:
            critic (nn.Module): The critic network
            real_samples (torch.Tensor): Batch of real samples, shape (N, ...)
            fake_samples (torch.Tensor): Batch of generated samples, same shape
                as ``real_samples``

        Returns:
            torch.Tensor: Gradient penalty term (scalar)
        """
        # One interpolation weight per sample, broadcast over every remaining
        # dimension. A fixed (N, 1, 1, 1) shape only broadcasts correctly for
        # 4-D batches; for other ranks it silently produces the wrong shape.
        alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
        # Drawn on the default generator and then moved, like the original, so
        # the random stream is unchanged.
        alpha = torch.rand(alpha_shape).type_as(real_samples)

        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

        d_interpolates = critic(interpolates)

        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            # A seed of ones (ones_like already matches dtype and device) is the
            # same as differentiating sum(d_interpolates).
            grad_outputs=torch.ones_like(d_interpolates),
            # create_graph=True makes the penalty itself differentiable, so the
            # critic loss's backward pass reaches the critic's parameters through it.
            create_graph=True,
            retain_graph=True,
        )[0]

        # reshape (not view) tolerates non-contiguous gradient tensors.
        gradients = gradients.reshape(gradients.size(0), -1)

        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def training_step(self, batch, batch_idx):
        """Run ``n_critic`` critic updates followed by one generator update.

        Args:
            batch: ``(images, labels)``; the labels are ignored.
            batch_idx: Unused; part of the Lightning hook signature.
        """
        imgs, _ = batch
        optimizer_g, optimizer_d = self.optimizers()
        batch_size = imgs.shape[0]

        # --- Critic: n_critic updates per generator update. ---
        # NOTE(review): the same real batch ``imgs`` is reused for every critic
        # iteration, whereas Algorithm 1 of the WGAN-GP paper samples a fresh
        # real batch each time.
        for _ in range(self.n_critic):
            # Restrict gradient computation to optimizer_d's parameters.
            self.toggle_optimizer(optimizer_d)

            z_d = torch.randn(batch_size, self.latent_dim).type_as(imgs)
            # The critic update must never backpropagate into the generator.
            fake_imgs_d = self(z_d).detach()

            real_score = self.discriminator(imgs)
            fake_score = self.discriminator(fake_imgs_d)

            gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)
            loss_D = self.loss_Discriminator(real_score, fake_score) + self.lambda_gp * gp

            # The negative of the un-penalized critic loss: the critic's
            # estimate of the Wasserstein distance.
            wasserstein_distance = torch.mean(real_score) - torch.mean(fake_score)
            self.log("wasserstein_distance", wasserstein_distance)

            self.log("d_loss", loss_D)
            self.manual_backward(loss_D)
            optimizer_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)

        # --- Generator: one update after the critic updates. ---
        self.toggle_optimizer(optimizer_g)

        z_g = torch.randn(batch_size, self.latent_dim).type_as(imgs)
        fake_imgs_g = self(z_g)

        g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
        self.log("g_loss", g_loss)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

    def configure_optimizers(self):
        """Create Adam optimizers for the generator and the critic.

        Returns:
            ``([opt_g, opt_d], [])``. training_step unpacks the optimizers in
            this order; no LR schedulers are used.
        """
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        # NOTE(review): the critic runs at lr / 3. This is a tuning choice of
        # this notebook; the WGAN-GP paper's Algorithm 1 uses one Adam setting
        # for both networks. Confirm intent.
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr / 3, betas=betas)
        return [opt_g, opt_d], []

In [11]:
class NoiseDataset(Dataset):
    """Map-style dataset of standard-normal latent vectors for GAN inference.

    The index is ignored: every access draws a fresh sample, so reading the
    same index twice returns different vectors.

    Args:
        num_samples: Number of items the dataset reports via ``len``.
        latent_dim: Length of each noise vector (256 matches the Generator default).
    """

    def __init__(self, num_samples, latent_dim=256):
        self.num_samples = num_samples
        self.latent_dim = latent_dim

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # A 1-D tensor of length latent_dim; the DataLoader's default collate
        # stacks these into a (batch, latent_dim) tensor.
        noise = torch.randn(self.latent_dim)
        return noise
    
class GANPredictor(L.LightningModule):
    """Inference-only LightningModule that samples images from a generator
    through ``Trainer.predict``.

    Args:
        generator: The generator module to sample from (e.g. a trained
            ``GAN.generator``). It is put into eval mode and its parameters are
            frozen *in place*, so the caller's module is modified too.
        num_samples: Total number of noise vectors, and therefore images.
        batch_size: Noise vectors per predict batch.
        latent_dim: Noise size; must match the generator's input width.
    """

    def __init__(self, generator, num_samples=256, batch_size=32, latent_dim=256):
        super().__init__()
        self.generator = generator
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.latent_dim = latent_dim

        # Freeze the generator for inference (mutates the passed-in module).
        self.generator.eval()
        for param in self.generator.parameters():
            param.requires_grad = False

    def forward(self, z):
        """Generate images from noise ``z``."""
        return self.generator(z)

    def predict_step(self, batch, batch_idx):
        """Return the generated images for one batch of noise vectors."""
        fake_images = self(batch)
        return fake_images

    def predict_dataloader(self):
        """Build an unshuffled DataLoader over ``NoiseDataset`` for prediction."""
        return DataLoader(
            NoiseDataset(
                num_samples=self.num_samples,
                latent_dim=self.latent_dim
            ),
            batch_size=self.batch_size,
            # os.cpu_count() returns None when the count cannot be determined,
            # and DataLoader needs an int: fall back to the main process.
            # NOTE(review): one worker per CPU is a lot of process start-up
            # for generating tiny noise tensors; consider a small fixed count.
            num_workers=os.cpu_count() or 0,
            # Pinned memory only speeds up host->CUDA copies; requesting it
            # without CUDA just emits a warning.
            pin_memory=torch.cuda.is_available(),
            shuffle=False
        )


def save_and_display_images(samples, save_dir="generated_images", nrow: int = 16):
    """
    Save a grid of generated images into a new timestamped directory.

    Despite the function name, nothing is displayed; the grid is only written
    to disk and its location printed.

    Args:
        samples (torch.Tensor): Batch of generated images, presumably
            (N, C, H, W) as accepted by ``torchvision.utils.make_grid``.
        save_dir (str): Parent directory. A ``YYYYmmdd_HHMMSS`` subdirectory is
            created inside it.
        nrow (int): Number of images per grid row (default 16).

    Returns:
        pathlib.Path: The timestamped directory containing
        ``WGANGP_MNIST_grid.png``.
    """
    # A per-run timestamped subdirectory keeps earlier results. Calls made within
    # the same second share the directory and overwrite the grid.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = Path(save_dir) / timestamp
    save_path.mkdir(parents=True, exist_ok=True)

    # normalize=True rescales the whole grid to [0, 1] using its min/max,
    # which handles the generator's Tanh output range of [-1, 1].
    grid = torchvision.utils.make_grid(samples, nrow=nrow, normalize=True)
    torchvision.utils.save_image(grid, save_path / "WGANGP_MNIST_grid.png")

    print(f"Images saved to {save_path}")
    return save_path

def generate_samples(checkpoint_path, num_samples=256, batch_size=32, save_dir="generated_images"):
    """
    Load a trained GAN checkpoint, generate images, and save them as a grid.

    Args:
        checkpoint_path (str | Path): Lightning checkpoint of a ``GAN`` model.
        num_samples (int): Number of images to generate.
        batch_size (int): Noise vectors per prediction batch.
        save_dir (str): Parent directory passed to ``save_and_display_images``.

    Returns:
        tuple: ``(all_samples, save_path)``. ``all_samples`` is every generated
        image concatenated along dim 0. ``save_path`` is the directory the grid
        was written to.
    """
    model = GAN.load_from_checkpoint(checkpoint_path)

    # Take latent_dim from the checkpoint's saved hyperparameters rather than
    # hard-coding 256, so generators trained with another latent size work.
    predictor = GANPredictor(
        generator=model.generator,
        num_samples=num_samples,
        batch_size=batch_size,
        latent_dim=model.hparams.latent_dim,
    )

    trainer = L.Trainer(
        accelerator='auto',
        logger=False,
        enable_checkpointing=False
    )

    # One predict_step output (a tensor of images) per batch.
    predictions = trainer.predict(predictor)
    all_samples = torch.cat(predictions, dim=0)

    save_path = save_and_display_images(all_samples, save_dir)

    return all_samples, save_path

In [12]:
# Sample 256 images (in batches of 64) from the final WGAN-GP MNIST checkpoint.
samples, save_path = generate_samples(
    checkpoint_path="Weights/WGANGP_MNIST_final/WGANGP_final.ckpt",
    num_samples=256,
    batch_size=64,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 4/4 [00:00<00:00, 49.38it/s]
Images saved to generated_images/20250119_145517
