In [42]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [43]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

WGAN with gradient penalty paper
https://arxiv.org/abs/1704.00028

GAN implementations
https://github.com/eriklindernoren/PyTorch-GAN/tree/master

"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nWGAN with gradient penalty paper\nhttps://arxiv.org/abs/1704.00028\n\nGAN implementations\nhttps://github.com/eriklindernoren/PyTorch-GAN/tree/master\n\n'

In [44]:
# Hyperparameters
BATCH_SIZE = 32
EPOCHS = 2000

In [45]:
os.environ["WANDB_API_KEY"] = "47080269e7b1b5a51a89830cb24c495498237e77"
wandb.login()
wandb_logger = WandbLogger(project="GAN2")
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

In [46]:
class FFHQDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image

# 256 batch size, 128x128 images, 8 cpu cores for batches
class FFHQDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "cats/Data", img_size: int = 128, batch_size: int = 16, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([            
            # v2.CenterCrop(self.img_size),
            v2.ToImage(),
            v2.ToDtype(torch.float32),
            v2.Resize(size=(64, 64)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        self.dataset = FFHQDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True)

In [47]:
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = 64 // 4
        self.latent_dim = 1024

        self.l1 = torch.nn.Sequential(torch.nn.Linear(self.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = torch.nn.Sequential(
            torch.nn.BatchNorm2d(128),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128, 0.8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(128, 64, 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64, 0.8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            torch.nn.Conv2d(64, 3, 3, stride=1, padding=1),
            torch.nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [torch.nn.Conv2d(in_filters, out_filters, 3, 2, 1), 
                    torch.nn.LeakyReLU(0.2, inplace=True), 
                    torch.nn.Dropout2d(0.25)]
            if bn:
                block.append(torch.nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = torch.nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        ds_size = 64 // 2 ** 4
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * ds_size ** 2, 1), 
            torch.nn.Sigmoid()
        )

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        return validity

In [48]:
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
        
#         # Increased initial size remains same to maintain 64x64 output
#         self.init_size = 64 // 4  # Starting with 16x16
        
#         # Increased latent dimension for more expressive generation
#         self.latent_dim = 1024
        
#         # Increased number of features in initial linear layer
#         self.l1 = nn.Sequential(
#             nn.Linear(self.latent_dim, 256 * self.init_size ** 2),
#             nn.LeakyReLU(0.2, inplace=True)
#         )

#         # Enhanced convolution blocks with more features and additional refinement layers
#         self.conv_blocks = nn.Sequential(
#             nn.BatchNorm2d(256),
            
#             # First upsampling block (16x16 -> 32x32)
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(256, 256, 3, stride=1, padding=1),
#             nn.BatchNorm2d(256, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             # Additional refinement layer
#             nn.Conv2d(256, 256, 3, stride=1, padding=1),
#             nn.BatchNorm2d(256, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             # Second upsampling block (32x32 -> 64x64)
#             nn.Upsample(scale_factor=2),
#             nn.Conv2d(256, 128, 3, stride=1, padding=1),
#             nn.BatchNorm2d(128, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             # Additional refinement layer
#             nn.Conv2d(128, 128, 3, stride=1, padding=1),
#             nn.BatchNorm2d(128, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             # Final refinement and output layers
#             nn.Conv2d(128, 64, 3, stride=1, padding=1),
#             nn.BatchNorm2d(64, 0.8),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Conv2d(64, 3, 3, stride=1, padding=1),
#             nn.Tanh()
#         )

#     def forward(self, z):
#         out = self.l1(z)
#         out = out.view(out.shape[0], 256, self.init_size, self.init_size)
#         img = self.conv_blocks(out)
#         return img


# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()

#         def discriminator_block(in_filters, out_filters, bn=True, kernel_size=3):
#             block = [
#                 nn.Conv2d(in_filters, out_filters, kernel_size, 2, kernel_size//2),
#                 nn.LeakyReLU(0.2, inplace=True),
#                 nn.Dropout2d(0.25)
#             ]
#             if bn:
#                 block.append(nn.BatchNorm2d(out_filters, 0.8))
#             return block

#         # Enhanced discriminator with more features and additional layers
#         self.model = nn.Sequential(
#             # Initial feature extraction (64x64 -> 32x32)
#             *discriminator_block(3, 32, bn=False),
#             nn.Conv2d(32, 32, 3, 1, 1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             # Second block (32x32 -> 16x16)
#             *discriminator_block(32, 64),
#             nn.Conv2d(64, 64, 3, 1, 1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             # Third block (16x16 -> 8x8)
#             *discriminator_block(64, 128),
#             nn.Conv2d(128, 128, 3, 1, 1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             # Fourth block (8x8 -> 4x4)
#             *discriminator_block(128, 256),
#             nn.Conv2d(256, 256, 3, 1, 1),
#             nn.LeakyReLU(0.2, inplace=True),
#         )

#         # The height and width of downsampled image
#         ds_size = 64 // 2 ** 4  # Still 4x4 as before
        
#         # Enhanced final classification layers
#         self.adv_layer = nn.Sequential(
#             nn.Linear(256 * ds_size ** 2, 512),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(512, 1),
#             nn.Sigmoid()
#         )

#     def forward(self, img):
#         out = self.model(img)
#         out = out.view(out.shape[0], -1)
#         validity = self.adv_layer(out)
#         return validity

In [49]:
class GAN(L.LightningModule):
    def __init__(
            self,
            img_size: int = 64,
            latent_dim: int = 1024,
            lr: float = 5e-5,
            b1: float = 0.5,
            b2: float = 0.999,
            gamma: float = 0.999995,
            n_crtic: int = 5

            ):
        super().__init__()
        # This is partially for wandb logging
        self.save_hyperparameters()
        # Maintain automatic optimization off for manual control
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.gamma = gamma
        self.n_critic = n_crtic



        self.generator = Generator()
        self.discriminator = Discriminator()
        
        # Add a binary cross-entropy loss for classification
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, z):
        return self.generator(z)

    def loss_Discriminator(self, real_imgs, fake_imgs):
        """
        Compute cross-entropy loss for the discriminator.
        
        Args:
            real_imgs (torch.Tensor): Batch of real images
            fake_imgs (torch.Tensor): Batch of generated images
        
        Returns:
            torch.Tensor: Discriminator loss
        """
        # Create labels (0 for fake, 1 for real)
        real_labels = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
        fake_labels = torch.zeros(fake_imgs.size(0), 1).type_as(fake_imgs)
        
        # Get discriminator scores
        real_scores = self.discriminator(real_imgs)
        fake_scores = self.discriminator(fake_imgs)
        
        # Compute cross-entropy loss
        real_loss = self.criterion(real_scores, real_labels)
        fake_loss = self.criterion(fake_scores, fake_labels)
        
        # Total discriminator loss
        return real_loss + fake_loss

    def loss_Generator(self, fake_imgs):
        """
        Compute generator loss using cross-entropy.
        
        Args:
            fake_imgs (torch.Tensor): Batch of generated images
        
        Returns:
            torch.Tensor: Generator loss
        """
        # Generator tries to fool the discriminator
        # So we use labels as if fake images were real
        labels = torch.ones(fake_imgs.size(0), 1).type_as(fake_imgs)
        
        # Get discriminator scores for fake images
        fake_scores = self.discriminator(fake_imgs)
        
        # Compute loss (generator wants discriminator to classify fake images as real)
        return self.criterion(fake_scores, labels)

    def training_step(self, batch, batch_idx):
        imgs = batch
        optimizer_g, optimizer_d = self.optimizers()
        scheduler_g, scheduler_d = self.configure_schedulers(opt_g=optimizer_g, opt_d=optimizer_d)


        current_lr_g = optimizer_g.param_groups[0]['lr']
        current_lr_d = optimizer_d.param_groups[0]['lr']
        
        # Log learning rates using both wandb and Lightning's logger
        self.log('lr_generator', current_lr_g)
        self.log('lr_discriminator', current_lr_d)

        # self.toggle_optimizer(optimizer_g)
        # for _ in range(self.n_critic):
           
        #     # Generate images for generator training
        #     # for _ in range(self.n_critic):
        #     z_g = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        #     z_g = z_g.type_as(imgs)
        #     fake_imgs_g = self(z_g)
        #     # Calculate generator loss
        #     g_loss = self.loss_Generator(fake_imgs_g)
            
        #     self.log("g_loss", g_loss)
        #     self.manual_backward(g_loss)
        #     optimizer_g.step()

        #     optimizer_g.zero_grad()

        # scheduler_g.step()
        # self.untoggle_optimizer(optimizer_g)
        for _ in range(self.n_critic):
            self.toggle_optimizer(optimizer_d)
            
            # Generate new images for discriminator training
            z_d = torch.randn(imgs.shape[0], self.hparams.latent_dim)
            z_d = z_d.type_as(imgs)
            fake_imgs_d = self(z_d)

            # Calculate discriminator loss
            loss_D = self.loss_Discriminator(imgs, fake_imgs_d)
            
            self.log("d_loss", loss_D)
            self.manual_backward(loss_D)
            optimizer_d.step()
            scheduler_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)


        # for _ in range(self.n_critic):
        self.toggle_optimizer(optimizer_g)
        # for _ in range(self.n_critic):        

        # Generate images for generator training
        z_g = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z_g = z_g.type_as(imgs)
        fake_imgs_g = self(z_g)
        # Calculate generator loss
        g_loss = self.loss_Generator(fake_imgs_g)
        
        self.log("g_loss", g_loss)
        self.manual_backward(g_loss)
        optimizer_g.step()
        scheduler_g.step()
        optimizer_g.zero_grad()
        scheduler_g.step()    
        self.untoggle_optimizer(optimizer_g)

        if batch_idx % 900 == 0:
            wandb.log({"generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_g[:20]]})

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr/10, betas=(b1, b2))

        return [opt_g, opt_d], []
    
    def configure_schedulers(self, opt_g, opt_d):
        gamma = self.hparams.gamma


        sch_g = torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=gamma)
        sch_d = torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=gamma)

        return [sch_g, sch_d]


In [50]:
# Trainer settings
model = GAN()

# model.eval()

dm = FFHQDataModule(batch_size=BATCH_SIZE)

torch.set_float32_matmul_precision("highest")

# could try reduced precision but I had problems with it earlier
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True,
    precision="32-true"
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [51]:
# # Run
trainer.fit(model, train_dataloaders=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params | Mode 
------------------------------------------------------------
0 | generator     | Generator         | 33.8 M | train
1 | discriminator | Discriminator     | 99.9 K | train
2 | criterion     | BCEWithLogitsLoss | 0      | train
------------------------------------------------------------
33.9 M    Trainable params
0         Non-trainable params
33.9 M    Total params
135.644   Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Epoch 0:   3%|▎         | 29/933 [17:04<8:52:02,  0.03it/s, v_num=493x]
Epoch 10:  97%|█████████▋| 903/933 [00:59<00:01, 15.28it/s, v_num=jlc5]

  lambda data: self._console_raw_callback("stdout", data),


Epoch 21:  32%|███▏      | 294/933 [00:21<00:46, 13.71it/s, v_num=jlc5]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [52]:
wandb.finish()

0,1
d_loss,█▇▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
g_loss,▁▁██████████████████████████████████████
lr_discriminator,█████▇▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▅▄▃▃▃▃▃▃▂▂▂▂▂▁▁▁
lr_generator,████▇▇▇▇▇▆▆▆▆▆▅▅▅▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▁▁▁▁
trainer/global_step,▁▁▁▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████

0,1
d_loss,1.00641
epoch,21.0
g_loss,0.69315
lr_discriminator,0.0
lr_generator,4e-05
trainer/global_step,19849.0
