In [None]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split



In [None]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

StyleGAN2 paper
https://arxiv.org/abs/1912.04958



"""

In [None]:
# Global training configuration
BATCH_SIZE: int = 64  # default mini-batch size for the data loaders; also recorded in the wandb config
EPOCHS: int = 1000  # passed to the trainer as the maximum number of epochs

In [None]:
# Log in to Weights & Biases so the logger created below can open a run.
wandb.login()

In [None]:
# SECURITY: a W&B API key used to be pasted here in a comment. Secrets must never
# live in the notebook; treat that key as compromised and revoke/rotate it.
# Authenticate via `wandb.login()` or the WANDB_API_KEY environment variable instead.

# Lightning logger that sends metrics to the "JANGAN3" W&B project.
wandb_logger = WandbLogger(project="JANGAN3")
# Store the batch size alongside the run so experiments can be compared later.
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE


In [None]:
# Root directory for the downloaded datasets; override with the PATH_DATASETS
# environment variable, defaults to the current working directory.
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

# Use half of the available CPU cores as DataLoader workers.
# os.cpu_count() returns None when the count cannot be determined, so fall back
# to a single core (which yields 0 workers, i.e. loading in the main process).
NUM_WORKERS = (os.cpu_count() or 1) // 2

class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule serving MNIST for GAN training.

    Images are converted to tensors and rescaled to ``[-1, 1]`` so that real
    samples share the value range of a ``Tanh``-terminated generator.

    Args:
        data_dir: Directory where MNIST is downloaded to / read from.
        batch_size: Mini-batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                # Map pixels from [0, 1] to [-1, 1]. The generator ends in Tanh and
                # can only emit values in [-1, 1]; normalising with the MNIST
                # mean/std instead would put real images in roughly [-0.42, 2.82],
                # a range the generator can never reach.
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` if they are not present."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets required for ``stage`` (all of them when ``stage`` is None)."""
        # Train/validation datasets: 60k training images split 55k / 5k.
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Test dataset.
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Training loader; reshuffled every epoch so batches differ between epochs."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
# pick one of the architectures, comment out the other one

In [None]:
"""
Architecture 1

WGAN - Fully connected GAN with a better loss function
"""

class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    Args:
        img_shape: Shape ``(C, H, W)`` of the images to produce.
        latent_dim: Size of the input noise vector.
    """

    def __init__(self, img_shape=(1, 28, 28), latent_dim: int = 256):
        super().__init__()

        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return ``[Linear, (BatchNorm1d), LeakyReLU]`` as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d is
                # `eps`, not `momentum`. The keyword is spelled out to make the
                # existing behaviour explicit; if momentum=0.8 was intended,
                # change it deliberately and re-tune.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Expand straight from the latent vector, no initial bottleneck.
            *block(latent_dim, 1024, normalize=False),
            *block(1024, 2048),
            *block(2048, 4096),
            nn.Linear(4096, int(np.prod(img_shape))),
            nn.Tanh(),  # outputs lie in [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` of shape ``(N, latent_dim)``.

        Returns:
            Tensor of shape ``(N, *img_shape)``.
        """
        img = self.model(z)
        # `shape` is a torch.Size (not callable): index it via size(0).
        return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected critic that scores images with one unbounded scalar each.

    There is no final activation: the last layer is a plain ``Linear(512, 1)``.

    Args:
        img_shape: Shape ``(C, H, W)`` of the input images.
    """

    def __init__(self, img_shape):
        super().__init__()

        self.img_shape = img_shape

        self.model = nn.Sequential(
            # Gradual compression from the flattened image down to one score.
            nn.Linear(int(np.prod(self.img_shape)), 4096),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(4096, 2048),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(N, *img_shape)``.

        Returns:
            Tensor of shape ``(N, 1)`` with one score per image.
        """
        # `shape` is a torch.Size (not callable); flatten everything but the batch dim.
        img_flat = img.flatten(start_dim=1)
        return self.model(img_flat)

In [None]:
"""
Architecture 2

StyleGAN2 - Convolution based GAN
"""


In [None]:
class GAN(L.LightningModule):
    """Wasserstein GAN with gradient penalty, trained with manual optimization.

    Each training step updates the critic ``n_critic`` times and the generator once.

    Args:
        img_size: Side length of the square, single-channel images.
        latent_dim: Size of the generator's noise input.
        lr: Learning rate used by both Adam optimizers.
        b1: Adam beta1 used by both optimizers.
        b2: Adam beta2 used by both optimizers.
        n_critic: Number of critic updates per generator update.
        gp_weight: Multiplier of the gradient penalty in the critic loss.
    """

    def __init__(
        self,
        img_size: int = 28,
        latent_dim: int = 256,
        lr: float = 3e-4,
        b1: float = 0,
        b2: float = 0.999,
        n_critic: int = 5,
        gp_weight: float = 10.0,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers with an uneven update schedule -> manual optimization.
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.n_critic = n_critic
        self.gp_weight = gp_weight

        self.generator = Generator(img_shape=(1, self.img_size, self.img_size), latent_dim=self.latent_dim)
        self.discriminator = Discriminator(img_shape=(1, self.img_size, self.img_size))

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        """Critic loss ``mean(fake) - mean(real)``.

        Despite the parameter names, both arguments are critic *scores*
        (outputs of the discriminator), not images.
        """
        return -torch.mean(real_img) + torch.mean(gen_img)

    def loss_Generator(self, gen_img):
        """Generator loss ``-mean(score)``; ``gen_img`` holds critic scores of generated images."""
        return -torch.mean(gen_img)

    def gradient_penalty(self, critic, real_samples, fake_samples, device=None):
        """
        Calculate the gradient penalty for WGAN-GP (Wasserstein GAN with gradient penalty).

        Args:
            critic (nn.Module): The critic network
            real_samples (torch.Tensor): Batch of real samples
            fake_samples (torch.Tensor): Batch of generated samples
            device: Device for the interpolation weights. Defaults to the device
                of ``real_samples`` so the method works on CPU, CUDA and MPS alike.

        Returns:
            torch.Tensor: Gradient penalty term (scalar)
        """
        if device is None:
            device = real_samples.device

        # One interpolation weight per sample, broadcast over all remaining dims
        # (works for any sample rank, not just 4-D image batches).
        alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
        alpha = torch.rand(alpha_shape, device=device, dtype=real_samples.dtype)

        # Random points on the lines between real and fake samples. The inputs are
        # detached: the penalty only needs gradients w.r.t. the interpolates.
        interpolates = (
            alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
        ).requires_grad_(True)

        d_interpolates = critic(interpolates)

        # create_graph=True keeps the graph so the penalty itself is differentiable
        # w.r.t. the critic's parameters.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Per-sample L2 norm of the gradient, pushed towards 1.
        gradients = gradients.view(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def _sample_latent(self, like):
        """Draw one standard-normal latent vector per sample in ``like``, on its device/dtype."""
        return torch.randn(like.size(0), self.latent_dim, device=like.device, dtype=like.dtype)

    def training_step(self, batch, batch_idx):
        """Run ``n_critic`` critic updates followed by one generator update."""
        # Labelled datasets such as MNIST yield (images, labels); labels are unused.
        imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
        optimizer_g, optimizer_d = self.optimizers()

        # ---- critic updates ----
        for _ in range(self.n_critic):
            self.toggle_optimizer(optimizer_d)

            # Fresh fakes for every critic step; no gradient to the generator is needed here.
            fake_imgs_d = self(self._sample_latent(imgs)).detach()

            real_score = self.discriminator(imgs)
            fake_score = self.discriminator(fake_imgs_d)

            gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)
            loss_D = self.loss_Discriminator(real_score, fake_score) + self.gp_weight * gp

            # The negated critic loss (without the penalty) estimates the Wasserstein distance.
            wasserstein_distance = torch.mean(real_score) - torch.mean(fake_score)
            self.log("wasserstein_distance", wasserstein_distance)

            self.log("d_loss", loss_D)
            self.manual_backward(loss_D)
            optimizer_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)

        # ---- generator update ----
        self.toggle_optimizer(optimizer_g)

        fake_imgs_g = self(self._sample_latent(imgs))

        # Log sample images once per logging epoch (first batch only); logging on
        # every batch would upload 50 images per step for the whole epoch.
        if batch_idx == 0 and self.current_epoch % 5 == 0:
            wandb.log({
                "generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_g[:25].detach()],
                "real_images": [wandb.Image(img) for img in imgs[:25]]
            })

        g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
        self.log("g_loss", g_loss)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

    def validation_step(self, batch, batch_idx):
        """Intentionally a no-op: no validation metric is computed."""
        pass

    def configure_optimizers(self):
        """Return ``[generator_optimizer, critic_optimizer]`` (Adam, no schedulers).

        The order matters: ``training_step`` unpacks ``self.optimizers()`` as
        ``(optimizer_g, optimizer_d)``.
        """
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

In [None]:
# Training loop

In [None]:
# Trainer settings
model = GAN(latent_dim=256)

dm = MNISTDataModule()

trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True
    
)

In [None]:
# Run
# Start training. The data module goes through the dedicated `datamodule`
# argument; `train_dataloaders` is meant for plain dataloaders.
trainer.fit(model, datamodule=dm)

In [None]:
# Close the active Weights & Biases run once training is done.
wandb.finish()