In [None]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.utils import spectral_norm


In [None]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

StyleGAN2 paper
https://arxiv.org/abs/1912.04958



"""

In [None]:
# Hyperparameters
BATCH_SIZE, EPOCHS = (
    64,    # images per training batch; also recorded in the W&B run config
    1000,  # passed to the Trainer as max_epochs
)

In [None]:
# Authenticate with Weights & Biases. Never hardcode the API key in the notebook:
# wandb.login() without arguments uses existing credentials (e.g. the
# WANDB_API_KEY environment variable) or prompts interactively.
# SECURITY: an API key was previously committed in this cell -- revoke/rotate it.
wandb.login()

wandb_logger = WandbLogger(project="JANGAN3")
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

In [None]:
class FFHQDataset(Dataset):
    """Unlabelled image dataset over a flat (non-recursive) directory of image files.

    Args:
        img_dir: Directory scanned for ``.png`` / ``.jpg`` / ``.jpeg`` files.
        transform: Optional callable applied to each RGB PIL image.
    """

    # Matched case-insensitively so files such as "IMG_01.JPG" are not silently skipped.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs and platforms
        # (os.listdir returns entries in arbitrary order).
        self.images = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = os.path.join(self.img_dir, self.images[idx])
        # The context manager closes the underlying file; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


# Defaults: batch size 256, 8 DataLoader worker processes. Images are resized to
# 64x64 to match the Generator / Discriminator defined below.
class FFHQDataModule(L.LightningDataModule):
    """Lightning data module serving images from a flat directory via FFHQDataset.

    Args:
        data_dir: Directory of ``.png`` / ``.jpg`` / ``.jpeg`` images.
        img_size: Stored on the module but NOT used by the transform, which always
            resizes to 64x64 (the resolution the networks are built for).
        batch_size: Images per training batch.
        num_workers: DataLoader worker processes.
    """

    def __init__(self, data_dir: str = "archive", img_size: int = 128, batch_size: int = 256, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([
            v2.ToImage(),
            # scale=True maps uint8 [0, 255] to float [0, 1]; without it the pixels
            # stay in [0, 255] and the normalisation below produces huge values.
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(size=(64, 64)),
            v2.RandomHorizontalFlip(p=0.5),
            # Map [0, 1] -> [-1, 1], the output range of the generator's final Tanh,
            # so real and generated images live in the same space for the critic.
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def setup(self, stage=None):
        self.dataset = FFHQDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            # Keep every batch full-sized: a trailing batch of a single image makes
            # a batch-wise standard deviation (as in the critic's minibatch-std
            # layer) undefined.
            drop_last=True,
            # Reuse worker processes across the many epochs instead of respawning them.
            persistent_workers=self.num_workers > 0,
        )

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to a 3x64x64 image in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector ``z``.
    """

    def __init__(self, latent_dim=100):
        super().__init__()

        # Project z to a 512-channel 4x4 feature map; four stride-2 transposed
        # convolutions then upsample 4x4 -> 64x64.
        self.initial = nn.Linear(latent_dim, 512 * 4 * 4)

        # Main convolutional structure
        self.conv_blocks = nn.Sequential(
            # 4x4 -> 8x8
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 32x32 -> 64x64; Tanh bounds the output to [-1, 1]
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to images of shape (batch, 3, 64, 64)."""
        # The linear layer yields a flat vector; reshape it into the 4x4 seed map.
        x = self.initial(z).view(z.size(0), 512, 4, 4)
        return self.conv_blocks(x)

class Discriminator(nn.Module):
    """WGAN-GP critic for 3x64x64 images, returning one unbounded score per image.

    Layout: spectral-normalised strided convolutions (64x64 -> 4x4), a
    self-attention layer on the 16x16 feature map, a minibatch-stddev feature map,
    and a final 4x4 convolution down to the score.

    Args:
        channels: Input image channels.
        features_d: Channel width of the first conv; each block doubles it.
        num_classes: Output channels of the final conv (forward() assumes 1).
    """

    def __init__(self, channels=3, features_d=64, num_classes=1):
        super(Discriminator, self).__init__()

        # Initial layer has no normalisation. Input: 64x64xC -> 32x32xfeatures_d
        self.initial = spectral_norm(
            nn.Conv2d(channels, features_d, kernel_size=4, stride=2, padding=1)
        )
        self.activation = nn.LeakyReLU(0.2, inplace=True)

        # Main convolutional blocks; each halves the spatial size and doubles channels.
        self.conv_blocks = nn.ModuleList([
            # Block 1: 32x32x64 -> 16x16x128
            self._make_block(features_d, features_d * 2),

            # Block 2: 16x16x128 -> 8x8x256
            self._make_block(features_d * 2, features_d * 4),

            # Block 3: 8x8x256 -> 4x4x512
            self._make_block(features_d * 4, features_d * 8),
        ])

        # Self-attention on the 16x16x128 output of block 1 (global structure).
        self.attention = SelfAttention(features_d * 2)

        # Appends ONE extra channel (a batch standard-deviation statistic).
        self.minibatch_std = MiniBatchStdDev()

        # 4x4x(512 + 1) -> 1x1xnum_classes. The "+ 1" accounts for the minibatch-std
        # channel; without it this conv rejects its 513-channel input at runtime.
        self.final = spectral_norm(
            nn.Conv2d(features_d * 8 + 1, num_classes, kernel_size=4, stride=1, padding=0)
        )

    def _make_block(self, in_channels, out_channels):
        """Strided 4x4 spectral-norm conv (halves H and W) + norm + LeakyReLU + dropout."""
        return nn.Sequential(
            spectral_norm(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False
                )
            ),
            # Per-sample (layer-norm style) normalisation instead of BatchNorm: the
            # WGAN-GP penalty constrains the critic's gradient w.r.t. each input
            # independently, which batch statistics would couple across samples.
            nn.GroupNorm(1, out_channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25)  # regularisation
        )

    def forward(self, x):
        """Return critic scores of shape (batch,) for images of shape (batch, C, 64, 64)."""
        # 64x64 -> 32x32
        out = self.activation(self.initial(x))

        # 32x32 -> 16x16
        out = self.conv_blocks[0](out)

        # Self-attention on the 16x16 map
        out = self.attention(out)

        # 16x16 -> 8x8 -> 4x4
        for block in self.conv_blocks[1:]:
            out = block(out)

        # +1 channel of batch statistics
        out = self.minibatch_std(out)

        # 4x4 -> 1x1 score
        out = self.final(out)

        return out.view(-1, 1).squeeze(1)


class SelfAttention(nn.Module):
    """SAGAN-style self-attention with a zero-initialised learnable residual gate.

    Query/key projections use ``in_channels // 8`` channels; the value projection
    keeps ``in_channels``. Output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8  # smaller query/key width keeps the N x N map cheap
        self.query = spectral_norm(nn.Conv2d(in_channels, reduced, 1))
        self.key = spectral_norm(nn.Conv2d(in_channels, reduced, 1))
        self.value = spectral_norm(nn.Conv2d(in_channels, in_channels, 1))
        # Starts at 0, so the layer is initially an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, h, w = x.size()
        positions = h * w

        # Flatten spatial positions: q is (n, positions, c'), k is (n, c', positions).
        q = self.query(x).reshape(n, -1, positions).transpose(1, 2)
        k = self.key(x).reshape(n, -1, positions)
        v = self.value(x).reshape(n, -1, positions)

        # weights[b, i, j]: how much position i attends to position j.
        weights = torch.matmul(q, k).softmax(dim=-1)

        # Each output position is a weighted sum of value vectors.
        attended = torch.matmul(v, weights.transpose(1, 2)).reshape(n, c, h, w)

        return self.gamma * attended + x


class MiniBatchStdDev(nn.Module):
    """Append one feature map holding the batch-wide mean standard deviation.

    For input of shape (batch, C, H, W) the per-element standard deviation is taken
    across the batch dimension, averaged to a single scalar and broadcast into an
    extra channel, giving output shape (batch, C + 1, H, W).

    Args:
        group_size: Stored but currently unused; statistics span the whole batch.
        eps: Added to the variance before the square root so the gradient stays
            finite when all samples are identical (d sqrt(v)/dv is infinite at 0).
    """

    def __init__(self, group_size=4, eps=1e-8):
        super(MiniBatchStdDev, self).__init__()
        self.group_size = group_size
        self.eps = eps

    def forward(self, x):
        batch_size, _, height, width = x.size()

        if batch_size < 2:
            # The unbiased std of a single sample is NaN; report zero spread instead.
            mean_std = x.new_zeros(())
        else:
            std = torch.sqrt(torch.var(x, dim=0, unbiased=True) + self.eps)
            mean_std = torch.mean(std)

        # Broadcast the scalar on x's own device/dtype (no host allocation + copy).
        std_channel = mean_std.expand(batch_size, 1, height, width)

        return torch.cat([x, std_channel], dim=1)

In [None]:
class GANv2(L.LightningModule):
    """WGAN-GP training module pairing ``Generator`` with the ``Discriminator`` critic.

    Uses manual optimisation: every training step performs ``n_critic`` critic
    updates followed by one generator update.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
        img_size: Stored as a hyperparameter; not used by the networks built here.
        lr: Generator learning rate (the critic uses ``lr / 3``).
        b1: Adam beta1 for both optimisers.
        b2: Adam beta2 for both optimisers.
        n_critic: Critic updates per generator update.
        lambda_gp: Weight of the gradient-penalty term in the critic loss.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 img_size: int = 64,
                 lr: float = 3e-3,
                 b1: float = 0,
                 b2: float = 0.999,
                 n_critic: int = 5,
                 lambda_gp: float = 10.0):
        super().__init__()

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.n_critic = n_critic
        self.lambda_gp = lambda_gp
        self.save_hyperparameters()
        # Two optimisers are stepped alternately, so optimisation is done by hand.
        self.automatic_optimization = False

        self.generator = Generator(latent_dim=latent_dim)
        self.discriminator = Discriminator()

    def forward(self, z):
        """Generate images from latent vectors ``z`` of shape (batch, latent_dim)."""
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        """Wasserstein critic loss ``mean(D(fake)) - mean(D(real))``.

        Args:
            real_img: Critic scores for real images (not the images themselves).
            gen_img: Critic scores for generated images.
        """
        return -torch.mean(real_img) + torch.mean(gen_img)

    def loss_Generator(self, gen_img):
        """Wasserstein generator loss ``-mean(D(fake))``; ``gen_img`` holds critic scores."""
        return -torch.mean(gen_img)

    def gradient_penalty(self, critic, real_samples, fake_samples, device=None):
        """
        Calculate the WGAN-GP gradient penalty ``mean((||grad D(x_hat)||_2 - 1)^2)``.

        ``x_hat`` are random per-sample interpolations between real and fake samples.

        Args:
            critic (nn.Module): The critic network
            real_samples (torch.Tensor): Batch of real samples, shape (B, C, H, W)
            fake_samples (torch.Tensor): Batch of generated samples, same shape
            device (str | torch.device | None): Device for the interpolation
                weights; defaults to ``real_samples.device``.

        Returns:
            torch.Tensor: Gradient penalty term (scalar), differentiable w.r.t. the
            critic's parameters (``create_graph=True``).
        """
        if device is None:
            device = real_samples.device
        batch_size = real_samples.size(0)

        # One interpolation weight per sample, broadcast over C, H, W.
        alpha = torch.rand((batch_size, 1, 1, 1), device=device, dtype=real_samples.dtype)
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

        d_interpolates = critic(interpolates)

        # create_graph=True so the penalty itself can be backpropagated into the critic.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Per-sample L2 norm over all of C, H, W.
        gradients = gradients.reshape(batch_size, -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def training_step(self, batch, batch_idx=None):
        """Run ``n_critic`` critic updates, then one generator update, on ``batch``."""
        imgs = batch
        optimizer_g, optimizer_d = self.optimizers()
        batch_size = imgs.shape[0]

        # ---- Critic updates ----
        self.toggle_optimizer(optimizer_d)
        for _ in range(self.hparams.n_critic):
            optimizer_d.zero_grad()

            # The generator is not trained here: detach so neither the critic loss
            # nor the gradient penalty builds a graph back through it.
            z_d = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
            fake_imgs_d = self(z_d).detach()

            real_score = self.discriminator(imgs)
            fake_score = self.discriminator(fake_imgs_d)
            gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)

            d_loss = self.loss_Discriminator(real_score, fake_score) + self.hparams.lambda_gp * gp

            self.manual_backward(d_loss)
            optimizer_d.step()

            self.log("d_loss", d_loss)
            self.log("gp", gp)
            self.log("real_score_mean", real_score.mean())
            self.log("fake_score_mean", fake_score.mean())
            # Logged as a tensor to avoid a host sync (.item()) on every critic step.
            self.log("wasserstein_distance", (real_score.mean() - fake_score.mean()).detach())
        self.untoggle_optimizer(optimizer_d)

        # ---- Generator update ----
        self.toggle_optimizer(optimizer_g)
        optimizer_g.zero_grad()

        z_g = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)
        fake_imgs_g = self(z_g)
        g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))

        self.manual_backward(g_loss)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)

        self.log("g_loss", g_loss)

        # Upload samples once (first batch) every 10th epoch -- not on every step
        # of that epoch, which would upload 50 images per step.
        if batch_idx == 0 and self.current_epoch % 10 == 0:
            self.log_images(fake_imgs_g[:25], imgs[:25])

    def log_images(self, fake_imgs, real_imgs):
        """Upload generated and real sample images to the active W&B run."""
        wandb.log({
            "generated_images": [wandb.Image(img) for img in fake_imgs],
            "real_images": [wandb.Image(img) for img in real_imgs]
        })

    def validation_step(self, batch, batch_idx):
        """No validation metric is computed."""
        pass

    def configure_optimizers(self):
        """Adam for both networks; the critic runs at one third of the generator's lr."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr / 3, betas=(b1, b2))
        return [opt_g, opt_d], []



In [None]:
# Trainer settings
model = GANv2(latent_dim=512)

# Pass the configured batch size explicitly: otherwise the data module's own
# default (256) is used while BATCH_SIZE is what gets recorded in the W&B config.
dm = FFHQDataModule(batch_size=BATCH_SIZE)

trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True,
)

In [None]:
# Run training; the data module supplies the training dataloader.
trainer.fit(model, datamodule=dm)

In [None]:
# End the W&B run used for logging in this notebook session.
wandb.finish()