In [10]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision.datasets import MNIST

In [11]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

WGAN with gradient penalty paper
https://arxiv.org/abs/1704.00028

GAN implementations
https://github.com/eriklindernoren/PyTorch-GAN/tree/master

"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nWGAN with gradient penalty paper\nhttps://arxiv.org/abs/1704.00028\n\nGAN implementations\nhttps://github.com/eriklindernoren/PyTorch-GAN/tree/master\n\n'

In [12]:
# Hyperparameters
BATCH_SIZE = 64
EPOCHS = 1000

In [13]:
os.environ["WANDB_API_KEY"] = "47080269e7b1b5a51a89830cb24c495498237e77"
wandb.login()
wandb_logger = WandbLogger(project="JANGAN2")
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE


/home/fil/miniconda3/envs/ML/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


In [14]:
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")

NUM_WORKERS = int(os.cpu_count())

class MNISTDataModule(L.LightningDataModule):
    def __init__(
        self,
        data_dir: str = PATH_DATASETS,
        batch_size: int = BATCH_SIZE,
        num_workers: int = NUM_WORKERS,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose(
            [
                v2.ToTensor(),
                v2.Normalize((0.1307,), (0.3081,)),
            ]
        )

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
        )

In [15]:
class Generator(nn.Module):
    def __init__(self, img_shape=(1, 28, 28), latent_dim: int = 256):
        super(Generator, self).__init__()

        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # Start from latent_dim=512, no initial bottleneck
            *block(latent_dim, 1024, normalize=False),  # Expand immediately
            *block(1024, 2048),
            *block(2048, 4096),
            nn.Linear(4096, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *self.img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.img_shape = img_shape

        self.model = nn.Sequential(
            # More gradual compression for better feature extraction
            nn.Linear(int(np.prod(self.img_shape)), 4096),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(4096, 2048),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

In [16]:
class GAN(L.LightningModule):
    def __init__(
            self,
            img_size: int = 28,
            latent_dim: int = 256,
            lr: float = 3e-4,
            b1: float = 0,
            b2: float = 0.999,
            n_critic: int = 5  
            ):
        super().__init__()
        # This is partialy for wandb logging
        self.save_hyperparameters()
        # This is important
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.n_critic = n_critic

        self.generator = Generator(img_shape=(1, self.img_size, self.img_size), latent_dim=self.latent_dim)
        self.discriminator = Discriminator(img_shape=(1, self.img_size, self.img_size))

    def forward(self, z):
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        return -torch.mean(real_img) + torch.mean(gen_img)

    def loss_Generator(self, gen_img):
        return -torch.mean(gen_img)

    def gradient_penalty(self, critic, real_samples, fake_samples):
        """
        Calculate the gradient penalty for WGAN-GP (Wasserstein GAN with gradient penalty).
        
        Args:
            critic (nn.Module): The critic network
            real_samples (torch.Tensor): Batch of real samples
            fake_samples (torch.Tensor): Batch of generated samples
            
        Returns:
            torch.Tensor: Gradient penalty term (scalar)
        """
        # Random weight for interpolation between real and fake samples
        alpha = torch.rand((real_samples.size(0), 1, 1, 1)).type_as(real_samples)
        
        # Get random interpolation between real and fake samples
        interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        
        # Calculate critic scores for interpolated images
        d_interpolates = critic(interpolates)
        
        # Calculate gradients of scores with respect to interpolates
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            # this is equivalent to .to(device) but instead creates the tensor on the device that the images are on
            grad_outputs=torch.ones_like(d_interpolates).type_as(d_interpolates),
            create_graph=True,
            # this is important
            retain_graph=True,
            only_inputs=True,
        )[0]
        
        # Flatten gradients to easily calculate the norm
        gradients = gradients.view(gradients.size(0), -1)
        
        # Calculate gradient penalty
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        
        return gradient_penalty

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        optimizer_g, optimizer_d = self.optimizers()

        # Train discriminator n_critic times
        for _ in range(self.n_critic):
            self.toggle_optimizer(optimizer_d)
            
            # Generate new images for discriminator training
            z_d = torch.randn(imgs.shape[0], self.hparams.latent_dim)
            z_d = z_d.type_as(imgs)
            fake_imgs_d = self(z_d)

            real_score = self.discriminator(imgs)
            fake_score = self.discriminator(fake_imgs_d)
            
            gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)
            loss_D = self.loss_Discriminator(real_score, fake_score) + 10 * gp
            
            # Calculate and log Wasserstein distance
            # The negative of the discriminator loss (before gradient penalty) is an estimate of the Wasserstein distance
            wasserstein_distance = torch.mean(real_score) - torch.mean(fake_score)
            self.log("wasserstein_distance", wasserstein_distance)
            
            self.log("d_loss", loss_D)
            self.manual_backward(loss_D)
            optimizer_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)

        # Train generator once every n_critic iterations
        self.toggle_optimizer(optimizer_g)
        
        # Generate images for generator training
        z_g = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z_g = z_g.type_as(imgs)
        fake_imgs_g = self(z_g)
        
        # Log images periodically
        wandb.log({"generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_g[:1]]})

        g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
        self.log("g_loss", g_loss)
        self.manual_backward(g_loss)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr/3, betas=(b1, b2))
        return [opt_g, opt_d], []

In [17]:
# Trainer settings
model = GAN()

dm = MNISTDataModule()

# could try reduced precision but I had problems with it earlier
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True   
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [18]:
# Run
trainer.fit(model, train_dataloaders=dm)

Restoring states from the checkpoint path at Weights/WGANGP_MNIST_final/epoch=570-step=2498696.ckpt


RuntimeError: Error(s) in loading state_dict for GAN:
	Missing key(s) in state_dict: "generator.model.0.weight", "generator.model.0.bias", "generator.model.2.weight", "generator.model.2.bias", "generator.model.3.weight", "generator.model.3.bias", "generator.model.3.running_mean", "generator.model.3.running_var", "generator.model.5.weight", "generator.model.5.bias", "generator.model.6.weight", "generator.model.6.bias", "generator.model.6.running_mean", "generator.model.6.running_var", "generator.model.8.weight", "generator.model.8.bias", "discriminator.model.2.weight", "discriminator.model.2.bias", "discriminator.model.4.weight", "discriminator.model.4.bias", "discriminator.model.8.weight", "discriminator.model.8.bias". 
	Unexpected key(s) in state_dict: "generator.l1.0.weight", "generator.l1.0.bias", "generator.conv_blocks.0.weight", "generator.conv_blocks.0.bias", "generator.conv_blocks.0.running_mean", "generator.conv_blocks.0.running_var", "generator.conv_blocks.0.num_batches_tracked", "generator.conv_blocks.2.weight", "generator.conv_blocks.2.bias", "generator.conv_blocks.3.weight", "generator.conv_blocks.3.bias", "generator.conv_blocks.3.running_mean", "generator.conv_blocks.3.running_var", "generator.conv_blocks.3.num_batches_tracked", "generator.conv_blocks.6.weight", "generator.conv_blocks.6.bias", "generator.conv_blocks.7.weight", "generator.conv_blocks.7.bias", "generator.conv_blocks.7.running_mean", "generator.conv_blocks.7.running_var", "generator.conv_blocks.7.num_batches_tracked", "generator.conv_blocks.9.weight", "generator.conv_blocks.9.bias", "discriminator.adv_layer.0.weight", "discriminator.adv_layer.0.bias", "discriminator.model.10.weight", "discriminator.model.10.bias", "discriminator.model.10.running_mean", "discriminator.model.10.running_var", "discriminator.model.10.num_batches_tracked", "discriminator.model.11.weight", "discriminator.model.11.bias", "discriminator.model.14.weight", "discriminator.model.14.bias", "discriminator.model.14.running_mean", "discriminator.model.14.running_var", "discriminator.model.14.num_batches_tracked", "discriminator.model.3.weight", "discriminator.model.3.bias", "discriminator.model.6.running_mean", "discriminator.model.6.running_var", "discriminator.model.6.num_batches_tracked", "discriminator.model.7.weight", "discriminator.model.7.bias". 
	size mismatch for discriminator.model.0.weight: copying a param with shape torch.Size([16, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([4096, 784]).
	size mismatch for discriminator.model.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4096]).
	size mismatch for discriminator.model.6.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([512, 1024]).
	size mismatch for discriminator.model.6.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([512]).

In [None]:
wandb.finish()