In [1]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


In [204]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

StyleGAN2 paper
https://arxiv.org/abs/1912.04958



"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nStyleGAN2 paper\nhttps://arxiv.org/abs/1912.04958\n\n\n\n'

In [None]:
# Hyperparameters
# Batch size; this value is recorded in the wandb run config.
BATCH_SIZE: int = 256
# Upper bound on training epochs; handed to the Trainer as max_epochs.
EPOCHS: int = 1_000

In [None]:
# Authenticate with Weights & Biases.
# SECURITY: never paste an API key into the notebook (not even in a comment).
# Provide it through the WANDB_API_KEY environment variable or an interactive
# `wandb login`; wandb.login() picks the credential up from there.
# NOTE(review): a key was previously committed in this cell - it must be
# revoked/rotated in the W&B account settings.
wandb.login()

# Lightning logger that sends metrics to the "JANGAN3" project.
wandb_logger = WandbLogger(project="JANGAN3")
# Record the batch size alongside the run so experiments stay comparable.
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE


In [207]:
class FFHQDataset(Dataset):
    """Image dataset backed by a single flat directory.

    Every file directly inside ``img_dir`` whose extension is one of
    ``IMAGE_EXTENSIONS`` (matched case-insensitively) becomes one sample.

    Args:
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        # Lower-casing the name keeps files such as "0001.PNG" from being
        # silently skipped.
        self.images = sorted(
            f for f in os.listdir(img_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it (transformed if configured)."""
        img_path = os.path.join(self.img_dir, self.images[idx])
        # convert() returns a new in-memory image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image


# Defaults: batches of 256 images, resized to 64x64, loaded by 8 worker processes.
class FFHQDataModule(L.LightningDataModule):
    """Lightning data module serving training batches of FFHQ images.

    Args:
        data_dir: Directory holding the image files (read by ``FFHQDataset``).
        img_size: Side length, in pixels, that every image is resized to.
            The default is 64, which is the size this module always produced
            (the resize used to be hard-coded and ignored this argument).
        batch_size: Number of images per training batch.
        num_workers: Number of DataLoader worker processes.
    """

    def __init__(self, data_dir: str = "archive", img_size: int = 64, batch_size: int = 256, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([
            v2.ToImage(),
            # Resize/flip while the image is still uint8.
            v2.Resize(size=(self.img_size, self.img_size)),
            v2.RandomHorizontalFlip(p=0.5),
            # scale=True maps uint8 [0, 255] to float [0, 1]; without it the
            # pixel values stay in 0..255 and the Normalize below produces
            # values in the hundreds.
            v2.ToDtype(torch.float32, scale=True),
            # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], the range of the
            # generator's final Tanh, so real and generated images are
            # directly comparable by the critic.
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def setup(self, stage=None):
        """Create the dataset; ``stage`` is accepted but not used."""
        self.dataset = FFHQDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the dataset built in ``setup``."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

In [208]:
# pick one of the architectures, comment out the other one

In [209]:
"""
Architecture 1

WGAN - Fully connected GAN with a better loss function
"""

class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to an image.

    The network widens from ``latent_dim`` through 128, 256, ..., 4096 units
    and ends in a ``Tanh``, so every output value lies in [-1, 1].

    Args:
        img_shape: Shape ``(channels, height, width)`` of the produced images.
        latent_dim: Size of the input noise vector.
    """

    def __init__(self, img_shape = (3, 64, 64), latent_dim: int = 100):
        super().__init__()

        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2), as a list."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is `eps`, so
                # the previous `BatchNorm1d(out_feat, 0.8)` added 0.8 to the
                # variance instead of configuring momentum. Use the default
                # eps and set the momentum by keyword.
                # NOTE(review): assumes the 0.8 was meant as a running-average
                # decay; PyTorch's momentum is the weight of the *new* batch
                # statistic, hence 1 - 0.8 = 0.2 - confirm.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            *block(1024, 2048),
            *block(2048, 4096),
            nn.Linear(4096, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise ``z`` of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
        """
        img = self.model(z)
        return img.view(img.shape[0], *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected critic that scores a batch of images.

    The flattened image passes through Linear layers of width 2048, 1024,
    512 and 256, each followed by LeakyReLU(0.2), and a final Linear layer
    that emits one unbounded score per image (no sigmoid).

    Args:
        img_shape: Shape of a single input image; its product is the input width.
    """

    def __init__(self, img_shape):
        super().__init__()

        self.img_shape = img_shape

        # Layer widths from the flattened image down to the last hidden layer.
        widths = [int(np.prod(self.img_shape)), 2048, 1024, 512, 256]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Scalar score head.
        stack.append(nn.Linear(widths[-1], 1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores for ``img``."""
        batch = img.shape[0]
        return self.model(img.view(batch, -1))

In [210]:
"""
Architecture 2

StyleGAN2 - Convolution based GAN
"""


'\nArchitecture 2\n\nStyleGAN2 - Convolution based GAN\n'

In [None]:
class GAN(L.LightningModule):
    """Wasserstein GAN with gradient penalty, trained with manual optimization.

    Each training step performs one generator update followed by one critic
    (discriminator) update.

    Args:
        img_size: Side length of the square RGB images.
        latent_dim: Size of the generator's input noise vector.
        lr: Learning rate used by both Adam optimizers.
        b1: Adam beta1 for both optimizers.
        b2: Adam beta2 for both optimizers.
        gp_lambda: Weight of the gradient-penalty term in the critic loss.

    NOTE(review): the WGAN-GP paper's reference settings are lr=1e-4 and
    betas=(0, 0.9); the defaults here are kept as they were - confirm they
    are intentional.
    """

    # Sample images are sent to wandb on the first batch of every N-th epoch.
    IMAGE_LOG_EVERY_N_EPOCHS = 10
    # Number of generated / real images included in each image log.
    N_LOGGED_IMAGES = 25

    def __init__(
            self,
            img_size: int = 64,
            latent_dim: int = 100,
            lr: float = 3e-3,
            b1: float = 0,
            b2: float = 0.999,
            gp_lambda: float = 10.0,
            ):
        super().__init__()
        self.save_hyperparameters()
        # Two optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size

        img_shape = (3, self.img_size, self.img_size)
        self.generator = Generator(img_shape=img_shape, latent_dim=self.latent_dim)
        self.discriminator = Discriminator(img_shape=img_shape)

    def forward(self, z):
        """Generate images from latent vectors ``z``."""
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        """Critic loss (without penalty): mean fake score minus mean real score.

        Despite their names, ``real_img`` and ``gen_img`` are the critic's
        *scores* for real and generated images, not the images themselves.
        """
        return -torch.mean(real_img) + torch.mean(gen_img)

    def loss_Generator(self, gen_img):
        """Generator loss: negated mean critic score of generated images.

        ``gen_img`` holds the critic scores of the generated batch.
        """
        return -torch.mean(gen_img)

    def gradient_penalty(self, critic, real_samples, fake_samples, device=None):
        """
        Calculate the gradient penalty for WGAN-GP (Wasserstein GAN with gradient penalty).

        Args:
            critic (nn.Module): The critic network
            real_samples (torch.Tensor): Batch of real samples
            fake_samples (torch.Tensor): Batch of generated samples, same shape
                as ``real_samples``
            device: Device for the interpolation weights. Defaults to the
                device of ``real_samples`` (previously hard-coded to "cuda",
                which failed on any other accelerator).

        Returns:
            torch.Tensor: Gradient penalty term (scalar)
        """
        if device is None:
            device = real_samples.device

        # The penalty only constrains the critic; detach so no gradient can
        # flow back into whatever produced the samples.
        real = real_samples.detach()
        fake = fake_samples.detach()

        # One random interpolation weight per sample, broadcast over all
        # remaining dimensions (works for any input rank, not just 4-D).
        alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
        alpha = torch.rand(alpha_shape, device=device, dtype=real.dtype)

        # Random points on the segments between real and fake samples.
        interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)

        # Critic scores for the interpolated samples.
        d_interpolates = critic(interpolates)

        # Gradient of the scores w.r.t. the interpolates. create_graph=True
        # keeps the graph so the penalty itself can be back-propagated.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Flatten per sample so the norm is taken over all input dimensions.
        gradients = gradients.reshape(gradients.size(0), -1)

        # Penalize deviation of each sample's gradient L2 norm from 1.
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def _sample_latent(self, like):
        """Draw one latent vector per sample in ``like``, on its device/dtype."""
        return torch.randn(
            like.shape[0],
            self.hparams.latent_dim,
            device=like.device,
            dtype=like.dtype,
        )

    def _log_sample_images(self, fake_imgs, real_imgs):
        """Send the first ``N_LOGGED_IMAGES`` fake and real images to wandb."""
        # Skip silently when no wandb run is active (e.g. a different logger).
        if wandb.run is None:
            return
        n = self.N_LOGGED_IMAGES
        wandb.log({
            "generated_images": [wandb.Image(img) for img in fake_imgs[:n].detach()],
            "real_images": [wandb.Image(img) for img in real_imgs[:n]],
        })

    def training_step(self, batch, batch_idx=0):
        """Run one generator update and one critic update on ``batch``.

        Args:
            batch: Tensor of real images.
            batch_idx: Index of the batch within the epoch; used only to limit
                image logging to the first batch of an epoch.
        """
        imgs = batch
        optimizer_g, optimizer_d = self.optimizers()

        # ---- train G ----
        self.toggle_optimizer(optimizer_g)

        fake_imgs_g = self(self._sample_latent(imgs))

        # Log images once per logging epoch. Logging on every batch of such
        # an epoch would upload 50 images per step.
        if batch_idx == 0 and self.current_epoch % self.IMAGE_LOG_EVERY_N_EPOCHS == 0:
            self._log_sample_images(fake_imgs_g, imgs)

        g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
        self.log("g_loss", g_loss)
        optimizer_g.zero_grad()
        self.manual_backward(g_loss)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)

        # ---- train D ----
        self.toggle_optimizer(optimizer_d)

        # Fresh fakes for the critic; the generator is not updated here, so
        # no graph through it is needed.
        with torch.no_grad():
            fake_imgs_d = self(self._sample_latent(imgs))

        real_score = self.discriminator(imgs)
        fake_score = self.discriminator(fake_imgs_d)

        gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)
        loss_D = self.loss_Discriminator(real_score, fake_score) + self.hparams.gp_lambda * gp
        self.log("d_loss", loss_D)
        optimizer_d.zero_grad()
        self.manual_backward(loss_D)
        optimizer_d.step()
        self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        """Intentionally a no-op; kept so the hook stays part of the class."""
        pass

    def configure_optimizers(self):
        """Return ``[generator_optimizer, critic_optimizer]`` (both Adam) and no schedulers."""
        lr = self.hparams.lr
        betas = (self.hparams.b1, self.hparams.b2)

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=betas)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betas)
        return [opt_g, opt_d], []

In [212]:
# Training loop

In [213]:
# Trainer settings
model = GAN()

# Pass BATCH_SIZE explicitly: it is the value recorded in the wandb config,
# so the data module must use the same number rather than its own default.
dm = FFHQDataModule(batch_size=BATCH_SIZE)

trainer = L.Trainer(
    logger=wandb_logger,      # metrics go to the wandb run created earlier
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Run
# Train `model` with the `trainer` configured above; `dm` supplies the
# training batches.
trainer.fit(model, dm)

/home/fil/miniconda3/envs/ML/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params | Mode 
--------------------------------------------------------
0 | generator     | Generator     | 61.6 M | train
1 | discriminator | Discriminator | 27.9 M | train
--------------------------------------------------------
89.5 M    Trainable params
0         Non-trainable params
89.5 M    Total params
357.908   Total estimated model params size (MB)
32        Modules in train mode
0         Modules in eval mode


Epoch 0:   0%|          | 0/274 [11:33<?, ?it/s]6.01it/s, v_num=43mt]
Epoch 0:   0%|          | 0/274 [11:22<?, ?it/s]
Epoch 0:   0%|          | 0/274 [03:41<?, ?it/s]
Epoch 0:   7%|▋         | 19/274 [00:03<00:46,  5.50it/s, v_num=43mt]

  lambda data: self._console_raw_callback("stdout", data),
  lambda data: self._console_raw_callback("stdout", data),
  lambda data: self._console_raw_callback("stdout", data),


Epoch 20:  47%|████▋     | 130/274 [00:17<00:19,  7.37it/s, v_num=43mt]

In [2]:
# Close the active wandb run once training is over (or was interrupted).
wandb.finish()