In [1]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

WGAN with gradient penalty paper
https://arxiv.org/abs/1704.00028

GAN implementations
https://github.com/eriklindernoren/PyTorch-GAN/tree/master

"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nWGAN with gradient penalty paper\nhttps://arxiv.org/abs/1704.00028\n\nGAN implementations\nhttps://github.com/eriklindernoren/PyTorch-GAN/tree/master\n\n'

In [3]:
# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 2000



In [4]:
os.environ["WANDB_API_KEY"] = "47080269e7b1b5a51a89830cb24c495498237e77"
wandb.login()
wandb_logger = WandbLogger(project="JAN_WDCGAN_FFHQ")
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
class FFHQDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.images = [f for f in os.listdir(img_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        return image

# 256 batch size, 128x128 images, 8 cpu cores for batches
class FFHQDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "archive", img_size: int = 128, batch_size: int = 16, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([            
            # v2.CenterCrop(self.img_size),
            v2.ToImage(),
            v2.ToDtype(torch.float32),
            v2.Resize(size=(64, 64)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        self.dataset = FFHQDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True)

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = 64 // 4

        self.latent_dim = 512

        self.l1 = nn.Sequential(nn.Linear(self.latent_dim, 128 * self.init_size ** 2))

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(3, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = 64 // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)

        return validity

In [7]:
class GAN(L.LightningModule):
    def __init__(
            self,
            img_size: int = 64,
            latent_dim: int = 256,
            lr: float = 1e-5,
            b1: float = 0.5,
            b2: float = 0.999,
            n_crtic: int = 5

            ):
        super().__init__()
        # This is partially for wandb logging
        self.save_hyperparameters()
        # Maintain automatic optimization off for manual control
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.n_critic = n_crtic


        self.generator = Generator()
        self.discriminator = Discriminator()
        
        # Add a binary cross-entropy loss for classification
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, z):
        return self.generator(z)

    def loss_Discriminator(self, real_imgs, fake_imgs):
        """
        Compute cross-entropy loss for the discriminator.
        
        Args:
            real_imgs (torch.Tensor): Batch of real images
            fake_imgs (torch.Tensor): Batch of generated images
        
        Returns:
            torch.Tensor: Discriminator loss
        """
        # Create labels (0 for fake, 1 for real)
        real_labels = torch.ones(real_imgs.size(0), 1).type_as(real_imgs)
        fake_labels = torch.zeros(fake_imgs.size(0), 1).type_as(fake_imgs)
        
        # Get discriminator scores
        real_scores = self.discriminator(real_imgs)
        fake_scores = self.discriminator(fake_imgs)
        
        # Compute cross-entropy loss
        real_loss = self.criterion(real_scores, real_labels)
        fake_loss = self.criterion(fake_scores, fake_labels)
        
        # Total discriminator loss
        return real_loss + fake_loss

    def loss_Generator(self, fake_imgs):
        """
        Compute generator loss using cross-entropy.
        
        Args:
            fake_imgs (torch.Tensor): Batch of generated images
        
        Returns:
            torch.Tensor: Generator loss
        """
        # Generator tries to fool the discriminator
        # So we use labels as if fake images were real
        labels = torch.ones(fake_imgs.size(0), 1).type_as(fake_imgs)
        
        # Get discriminator scores for fake images
        fake_scores = self.discriminator(fake_imgs)
        
        # Compute loss (generator wants discriminator to classify fake images as real)
        return self.criterion(fake_scores, labels)

    def training_step(self, batch, batch_idx):
        imgs = batch
        optimizer_g, optimizer_d = self.optimizers()
        


        self.toggle_optimizer(optimizer_d)
        
        # Generate new images for discriminator training
        z_d = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z_d = z_d.type_as(imgs)
        fake_imgs_d = self(z_d)

        # Calculate discriminator loss
        loss_D = self.loss_Discriminator(imgs, fake_imgs_d)
        
        self.log("d_loss", loss_D)
        self.manual_backward(loss_D)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)


        for _ in range(self.n_critic):
            self.toggle_optimizer(optimizer_g)
            
            # Generate images for generator training
            z_g = torch.randn(imgs.shape[0], self.hparams.latent_dim)
            z_g = z_g.type_as(imgs)
            fake_imgs_g = self(z_g)
            # Calculate generator loss
            g_loss = self.loss_Generator(fake_imgs_g)
            
            self.log("g_loss", g_loss)
            self.manual_backward(g_loss)
            optimizer_g.step()
            optimizer_g.zero_grad()
            self.untoggle_optimizer(optimizer_g)

        if batch_idx % 4000 == 0:
            wandb.log({"generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_g[:20]]})

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        # sch_g = torch.optim.lr_scheduler.ExponentialLR(opt_g, gamma=0.99)
        # sch_d = torch.optim.lr_scheduler.ExponentialLR(opt_d, gamma=0.99)
        return [opt_g, opt_d], []


In [8]:
# Trainer settings
model = GAN(latent_dim=512)

dm = FFHQDataModule(batch_size=BATCH_SIZE)

torch.set_float32_matmul_precision("highest")

# could try reduced precision but I had problems with it earlier
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True,
    precision="32-true"
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Run
trainer.fit(model, train_dataloaders=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params | Mode 
------------------------------------------------------------
0 | generator     | Generator         | 17.0 M | train
1 | discriminator | Discriminator     | 99.9 K | train
2 | criterion     | BCEWithLogitsLoss | 0      | train
------------------------------------------------------------
17.1 M    Trainable params
0         Non-trainable params
17.1 M    Total params
68.535    Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Epoch 15:  23%|██▎       | 1009/4375 [00:56<03:07, 17.98it/s, v_num=64b5]

[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.



Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [None]:
wandb.finish()