In [1]:
import wandb
import lightning as L
import torch
import PIL
import os
import torchvision.transforms.v2 as v2
from PIL import Image
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Keep full FP32 matmul precision (no TF32 on Tensor Cores). Lightning warns about this
# at fit time, but reduced precision was tried and deliberately not adopted (see the
# trainer-settings cell).
torch.set_float32_matmul_precision("highest")

In [2]:
"""
Sources:
Lightning GAN implementation
https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

WGAN paper
https://arxiv.org/abs/1701.07875

WGAN with gradient penalty paper
https://arxiv.org/abs/1704.00028

GAN implementations
https://github.com/eriklindernoren/PyTorch-GAN/tree/master

"""

'\nSources:\nLightning GAN implementation\nhttps://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html\n\nWGAN paper\nhttps://arxiv.org/abs/1701.07875\n\nWGAN with gradient penalty paper\nhttps://arxiv.org/abs/1704.00028\n\nGAN implementations\nhttps://github.com/eriklindernoren/PyTorch-GAN/tree/master\n\n'

In [3]:
# Hyperparameters shared by the W&B config, the data module and the Trainer.
BATCH_SIZE, EPOCHS = 32, 2000

In [4]:
# W&B credentials are resolved by wandb.login() itself (WANDB_API_KEY environment
# variable, ~/.netrc, or an interactive prompt). Never hard-code the key in the
# notebook: it ends up in version control and shared copies. A key was previously
# committed in this cell and should be revoked in the W&B account settings.
wandb.login()
wandb_logger = WandbLogger(project="JANGAN4")
wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfilip-szczepanski[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
class FFHQDataset(Dataset):
    """Flat-directory image dataset returning RGB images, optionally transformed.

    Args:
        img_dir: Directory whose ``.png`` / ``.jpg`` / ``.jpeg`` files (any letter
            case) form the dataset. Sub-directories are not searched.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
    """

    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs and platforms
        # (os.listdir order is arbitrary). The extension match is case-insensitive so
        # files such as "x.JPG" are not silently skipped.
        self.images = sorted(
            f for f in os.listdir(img_dir) if f.lower().endswith(self.IMG_EXTENSIONS)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one was given."""
        img_path = os.path.join(self.img_dir, self.images[idx])
        # The context manager closes the underlying file handle. convert() loads the
        # pixels and returns a new image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image

class FFHQDataModule(L.LightningDataModule):
    """LightningDataModule serving FFHQ images as float tensors in [-1, 1].

    Images are resized to ``img_size x img_size`` and randomly flipped horizontally.
    They are then scaled to [0, 1] and normalized with mean = std = 0.5, so real
    images share the [-1, 1] range of the generator's final ``Tanh``.

    Args:
        data_dir: Directory containing the image files.
        img_size: Output height/width. Defaults to 64, the size the networks expect.
        batch_size: Images per training batch.
        num_workers: DataLoader worker processes.
    """

    def __init__(self, data_dir: str = "archive", img_size: int = 64, batch_size: int = 32, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = v2.Compose([
            v2.ToImage(),
            # Resize while still uint8 (cheaper). The size used to be hard-coded to 64,
            # which silently ignored ``img_size``.
            v2.Resize(size=(self.img_size, self.img_size)),
            v2.RandomHorizontalFlip(p=0.5),
            # scale=True maps uint8 [0, 255] -> float [0, 1]. Without it the values
            # stay in [0, 255].
            v2.ToDtype(torch.float32, scale=True),
            # [0, 1] -> [-1, 1] to match the generator's Tanh output range.
            # ImageNet mean/std statistics would not produce this range.
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def setup(self, stage=None):
        """Index the image directory; Lightning calls this before training."""
        self.dataset = FFHQDataset(img_dir=self.data_dir, transform=self.transform)

    def train_dataloader(self):
        """Shuffled training loader over the whole dataset."""
        return DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True,
            # Keep worker processes alive between epochs instead of re-spawning them.
            # This is only valid when multiprocessing is used.
            persistent_workers=self.num_workers > 0,
        )

In [6]:
# class Generator(nn.Module):
#     def __init__(self, img_shape=(3, 64, 64), latent_dim: int = 512):
#         super(Generator, self).__init__()

#         self.img_shape = img_shape

#         def block(in_feat, out_feat, normalize=True):
#             layers = [nn.Linear(in_feat, out_feat)]
#             if normalize:
#                 layers.append(nn.BatchNorm1d(out_feat, 0.8))
#             layers.append(nn.LeakyReLU(0.2, inplace=True))
#             return layers

#         self.model = nn.Sequential(
#             # Start from latent_dim=512, no initial bottleneck
#             *block(latent_dim, 1024, normalize=False),  # Expand immediately
#             *block(1024, 2048),
#             *block(2048, 4096),
#             *block(4096, 8192),
#             nn.Linear(8192, int(np.prod(img_shape))),
#             nn.Tanh()
#         )

#     def forward(self, z):
#         img = self.model(z)
#         img = img.view(img.shape[0], *self.img_shape)
#         return img


# class Discriminator(nn.Module):
#     def __init__(self, img_shape = (3, 64, 64)):
#         super(Discriminator, self).__init__()

#         self.img_shape = img_shape

#         self.model = nn.Sequential(
#             # More gradual compression for better feature extraction
#             nn.Linear(int(np.prod(self.img_shape)), 2048),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(2048, 1024),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 512),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(512, 1),
#         )

#     def forward(self, img):
#         img_flat = img.view(img.shape[0], -1)
#         validity = self.model(img_flat)
#         return validity

In [7]:
class Generator(nn.Module):
    """Convolutional generator mapping latent vectors to images in [-1, 1].

    A linear layer projects ``z`` to a ``128 x (img_size // 2) x (img_size // 2)``
    feature map. Convolutions refine it, and a single x2 upsample brings it to
    ``img_size x img_size``. The final ``Tanh`` bounds pixel values to [-1, 1].
    With the defaults, layer shapes and state_dict keys match the previous
    hard-coded version (512-d latent, 64x64 RGB output).

    Args:
        latent_dim: Width of the input noise vector ``z``.
        img_size: Output height/width. Must be a positive even number because of
            the single x2 upsample.
        channels: Number of output image channels.

    Raises:
        ValueError: If ``img_size`` is not a positive even number.
    """

    def __init__(self, latent_dim: int = 512, img_size: int = 64, channels: int = 3):
        super().__init__()
        if img_size <= 0 or img_size % 2:
            raise ValueError(f"img_size must be a positive even number, got {img_size}")

        self.latent_dim = latent_dim
        self.img_size = img_size
        # There is exactly one x2 upsample below, so the projected map starts at half size.
        self.init_size = img_size // 2
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # Layer order is unchanged, so the state_dict indices stay the same.
        # BatchNorm layers use the default eps. The earlier positional ``0.8`` was
        # interpreted by BatchNorm2d as ``eps`` (not momentum), which weakens the
        # normalization whenever the activation variance is small.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 1024, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(1024, 512, 3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),  # init_size -> img_size
            nn.Conv2d(512, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape ``(batch, latent_dim)`` to ``(batch, channels, img_size, img_size)``."""
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)


class Discriminator(nn.Module):
    """Convolutional discriminator that returns one score per image.

    Four stride-2 conv blocks each halve the spatial size (rounding up). A linear
    head then maps the flattened features to a single value per image, giving an
    output of shape ``(batch, 1)``.

    Args:
        img_size: Input height/width (square images assumed).
        channels: Number of input image channels.
        use_sigmoid: Squash the score into (0, 1). Defaults to ``True`` for backward
            compatibility. The least-squares GAN formulation normally uses the raw,
            unbounded score, so pass ``False`` for that.

    Raises:
        ValueError: If ``img_size`` is not positive.
    """

    def __init__(self, img_size: int = 64, channels: int = 3, use_sigmoid: bool = True):
        super().__init__()
        if img_size <= 0:
            raise ValueError(f"img_size must be positive, got {img_size}")

        def discriminator_block(in_filters, out_filters, bn=True):
            # Conv(k=3, stride=2, pad=1) -> LeakyReLU -> Dropout2d [-> BatchNorm].
            # BatchNorm uses the default eps. The earlier positional 0.8 was
            # interpreted as eps, not momentum.
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Spatial size after the four stride-2 convs above. Each one maps n -> ceil(n / 2)
        # (for k=3, s=2, p=1), which equals img_size // 16 when img_size is a multiple of 16.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2

        head = [nn.Linear(128 * ds_size ** 2, 1)]
        if use_sigmoid:
            head.append(nn.Sigmoid())
        # Kept as a Sequential so the state_dict key stays "adv_layer.0.*".
        self.adv_layer = nn.Sequential(*head)

    def forward(self, img):
        """Score a batch of images ``(batch, channels, img_size, img_size)`` -> ``(batch, 1)``."""
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

In [8]:
# class GAN(L.LightningModule):
#     def __init__(
#             self,
#             img_size: int = 64,
#             latent_dim: int = 256,
#             lr: float = 3e-4,
#             b1: float = 0,
#             b2: float = 0.999,
#             n_critic: int = 1
#             ):
#         super().__init__()
#         # This is partially for wandb logging
#         self.save_hyperparameters()
#         # This is important
#         self.automatic_optimization = False

#         self.latent_dim = latent_dim
#         self.lr = lr
#         self.b1 = b1
#         self.b2 = b2
#         self.img_size = img_size
#         self.n_critic = n_critic

#         self.generator = Generator()
#         self.discriminator = Discriminator()

#     def forward(self, z):
#         return self.generator(z)

#     def loss_Discriminator(self, real_img, gen_img):
#         return -torch.mean(real_img) + torch.mean(gen_img)

#     def loss_Generator(self, gen_img):
#         return -torch.mean(gen_img)

#     def gradient_penalty(self, critic, real_samples, fake_samples):
#         """
#         Calculate the gradient penalty for WGAN-GP (Wasserstein GAN with gradient penalty).
        
#         Args:
#             critic (nn.Module): The critic network
#             real_samples (torch.Tensor): Batch of real samples
#             fake_samples (torch.Tensor): Batch of generated samples
            
#         Returns:
#             torch.Tensor: Gradient penalty term (scalar)
#         """
#         # Random weight for interpolation between real and fake samples
#         alpha = torch.rand((real_samples.size(0), 1, 1, 1)).type_as(real_samples)
        
#         # Get random interpolation between real and fake samples
#         interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
        
#         # Calculate critic scores for interpolated images
#         d_interpolates = critic(interpolates)
        
#         # Calculate gradients of scores with respect to interpolates
#         gradients = torch.autograd.grad(
#             outputs=d_interpolates,
#             inputs=interpolates,
#             # this is equivalent to .to(device) but instead creates the tensor on the device that the images are on
#             grad_outputs=torch.ones_like(d_interpolates).type_as(d_interpolates),
#             create_graph=True,
#             # this is important
#             retain_graph=True,
#             only_inputs=True,
#         )[0]
        
#         # Flatten gradients to easily calculate the norm
#         gradients = gradients.view(gradients.size(0), -1)
        
#         # Calculate gradient penalty
#         gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        
#         return gradient_penalty

#     def training_step(self, batch, batch_idx):
#         imgs = batch
#         optimizer_g, optimizer_d = self.optimizers()
        

#         # Train discriminator n_critic times

#         self.toggle_optimizer(optimizer_d)
        
#         # Generate new images for discriminator training
#         z_d = torch.randn(imgs.shape[0], self.hparams.latent_dim)
#         #z_d = torch.randn(imgs.shape[0])
#         z_d = z_d.type_as(imgs)
#         fake_imgs_d = self(z_d)

#         real_score = self.discriminator(imgs)
#         fake_score = self.discriminator(fake_imgs_d)
        
#         gp = self.gradient_penalty(self.discriminator, imgs, fake_imgs_d)
#         loss_D = self.loss_Discriminator(real_score, fake_score) + 10 * gp
        
#         # Calculate and log Wasserstein distance
#         # The negative of the discriminator loss (before gradient penalty) is an estimate of the Wasserstein distance
#         #wasserstein_distance = torch.mean(real_score) - torch.mean(fake_score)
#         #self.log("wasserstein_distance", wasserstein_distance)
        
#         self.log("d_loss", loss_D)
#         self.manual_backward(loss_D)
#         optimizer_d.step()
#         optimizer_d.zero_grad()
#         self.untoggle_optimizer(optimizer_d)

#         # for _ in range(self.n_critic):
#         # Train generator once every n_critic iterations
#         self.toggle_optimizer(optimizer_g)
        
#         # Generate images for generator training
#         z_g = torch.randn(imgs.shape[0], self.hparams.latent_dim)
#         #z_g = torch.randn(imgs.shape[0])
#         z_g = z_g.type_as(imgs)
#         fake_imgs_g = self(z_g)
        
#         g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
#         self.log("g_loss", g_loss)
#         self.manual_backward(g_loss)
#         optimizer_g.step()
#         optimizer_g.zero_grad()
#         self.untoggle_optimizer(optimizer_g)

#         if batch_idx % 2000 == 0:
#             wandb.log({"generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_g[:10]]})

#     def configure_optimizers(self):
#         lr = self.hparams.lr
#         b1 = self.hparams.b1
#         b2 = self.hparams.b2

#         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
#         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
#         return [opt_g, opt_d], []

In [9]:
class GAN(L.LightningModule):
    """Least-squares GAN (MSE adversarial loss) trained with manual optimization.

    Each training batch updates the discriminator once. The generator is updated on
    every ``n_critic``-th batch. Generated samples are logged to Weights & Biases
    every ``image_log_interval`` batches.

    Args:
        img_size: Recorded as a hyperparameter only. ``Generator`` and
            ``Discriminator`` are built with their own defaults.
        latent_dim: Width of the noise vector fed to the generator. It must equal the
            input width of the generator's first ``nn.Linear``.
        lr: Adam learning rate for both networks.
        b1: Adam beta1 for both networks.
        b2: Adam beta2 for both networks.
        n_critic: Discriminator updates per generator update (1 = alternate every batch).
        image_log_interval: Log generated images when ``batch_idx`` is a multiple of this.

    Raises:
        ValueError: If ``n_critic`` or ``image_log_interval`` is below 1, or if
            ``latent_dim`` does not match the generator's input layer.
    """

    def __init__(
            self,
            img_size: int = 64,
            latent_dim: int = 256,
            lr: float = 1e-5,
            b1: float = 0,
            b2: float = 0.9,
            n_critic: int = 1,
            image_log_interval: int = 2000,
            ):
        super().__init__()
        if n_critic < 1:
            raise ValueError(f"n_critic must be >= 1, got {n_critic}")
        if image_log_interval < 1:
            raise ValueError(f"image_log_interval must be >= 1, got {image_log_interval}")
        # Save hyperparameters (these are also sent to the W&B logger).
        self.save_hyperparameters()
        # Generator and discriminator have separate optimizers, stepped manually below.
        self.automatic_optimization = False

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.img_size = img_size
        self.n_critic = n_critic
        self.image_log_interval = image_log_interval

        self.generator = Generator()
        self.discriminator = Discriminator()
        self._check_latent_dim()

        # Least-squares adversarial loss against 1 (real) / 0 (fake) targets.
        self.criterion = torch.nn.MSELoss()

    def _check_latent_dim(self):
        """Fail fast if ``latent_dim`` cannot be fed to the generator.

        This assumes the first ``nn.Linear`` in the generator is its input layer.
        Without the check, a mismatch only surfaces as a shape error at the first
        training step.
        """
        first_linear = next((m for m in self.generator.modules() if isinstance(m, nn.Linear)), None)
        if first_linear is not None and first_linear.in_features != self.latent_dim:
            raise ValueError(
                f"latent_dim={self.latent_dim} does not match the generator's input "
                f"width ({first_linear.in_features})"
            )

    def _sample_latent(self, like):
        """Draw standard-normal noise ``(batch, latent_dim)`` on ``like``'s device and dtype."""
        return torch.randn(like.shape[0], self.latent_dim, device=like.device, dtype=like.dtype)

    def forward(self, z):
        return self.generator(z)

    def loss_Discriminator(self, real_img, gen_img):
        """MSE loss pushing discriminator outputs to 1 on real and 0 on generated images.

        Args:
            real_img: Discriminator outputs for real images, shape ``(batch, 1)``.
            gen_img: Discriminator outputs for generated images, shape ``(batch, 1)``.

        Returns:
            The mean of the real and fake MSE terms.
        """
        real_labels = torch.ones(real_img.size(0), 1).type_as(real_img)
        fake_labels = torch.zeros(gen_img.size(0), 1).type_as(gen_img)

        real_loss = self.criterion(real_img, real_labels)
        fake_loss = self.criterion(gen_img, fake_labels)

        return (real_loss + fake_loss) / 2

    def loss_Generator(self, gen_img):
        """MSE loss pushing discriminator outputs on generated images towards 1 ("real").

        Args:
            gen_img: Discriminator outputs for generated images, shape ``(batch, 1)``.
        """
        target_labels = torch.ones(gen_img.size(0), 1).type_as(gen_img)
        return self.criterion(gen_img, target_labels)

    def training_step(self, batch, batch_idx):
        imgs = batch
        optimizer_g, optimizer_d = self.optimizers()

        # --- Discriminator update (every batch) ---
        self.toggle_optimizer(optimizer_d)
        # Detached so the discriminator loss can never backpropagate into the generator.
        fake_imgs_d = self(self._sample_latent(imgs)).detach()

        real_score = self.discriminator(imgs)
        fake_score = self.discriminator(fake_imgs_d)
        loss_D = self.loss_Discriminator(real_score, fake_score)

        self.log("d_loss", loss_D)
        self.manual_backward(loss_D)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

        # --- Generator update (every n_critic-th batch) ---
        if batch_idx % self.n_critic == 0:
            self.toggle_optimizer(optimizer_g)
            fake_imgs_g = self(self._sample_latent(imgs))

            g_loss = self.loss_Generator(self.discriminator(fake_imgs_g))
            self.log("g_loss", g_loss)
            self.manual_backward(g_loss)
            optimizer_g.step()
            optimizer_g.zero_grad()
            self.untoggle_optimizer(optimizer_g)

        if batch_idx % self.image_log_interval == 0:
            # fake_imgs_d comes from the same generator weights as this step's generator
            # pass, and it exists even on batches where the generator step is skipped.
            wandb.log({"generated_images": [wandb.Image(fake_img) for fake_img in fake_imgs_d[:10]]})

    def configure_optimizers(self):
        """One Adam optimizer per network, returned as (generator, discriminator)."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

In [10]:
# Trainer settings
# Seed Python / NumPy / torch RNGs (and dataloader workers) so runs are reproducible.
SEED = 42
L.seed_everything(SEED, workers=True)

# latent_dim must match the Generator's 512-d input layer.
# At batch size 512 the model collapses.
model = GAN(latent_dim=512, lr=3e-4)

dm = FFHQDataModule(batch_size=BATCH_SIZE)

# Reduced precision ("medium" matmul precision, bf16) was tried earlier. It gave only
# small speedups and caused problems, so full FP32 is kept.
trainer = L.Trainer(
    logger=wandb_logger,
    max_epochs=EPOCHS,
    accelerator="gpu",
    enable_checkpointing=True,
    precision="32-true",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Train; the data module provides the training dataloader via its train_dataloader hook.
trainer.fit(model, datamodule=dm)

You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at Weights/WGANGP_MNIST_final/epoch=570-step=2498696.ckpt


RuntimeError: Error(s) in loading state_dict for GAN:
	Missing key(s) in state_dict: "generator.conv_blocks.1.weight", "generator.conv_blocks.1.bias", "generator.conv_blocks.4.weight", "generator.conv_blocks.4.bias", "generator.conv_blocks.4.running_mean", "generator.conv_blocks.4.running_var", "generator.conv_blocks.8.weight", "generator.conv_blocks.8.bias", "generator.conv_blocks.8.running_mean", "generator.conv_blocks.8.running_var", "generator.conv_blocks.10.weight", "generator.conv_blocks.10.bias", "generator.conv_blocks.12.weight", "generator.conv_blocks.12.bias". 
	Unexpected key(s) in state_dict: "generator.conv_blocks.2.weight", "generator.conv_blocks.2.bias", "generator.conv_blocks.3.running_mean", "generator.conv_blocks.3.running_var", "generator.conv_blocks.3.num_batches_tracked", "generator.conv_blocks.6.weight", "generator.conv_blocks.6.bias", "generator.conv_blocks.7.running_mean", "generator.conv_blocks.7.running_var", "generator.conv_blocks.7.num_batches_tracked", "generator.conv_blocks.9.weight", "generator.conv_blocks.9.bias". 
	size mismatch for generator.l1.0.weight: copying a param with shape torch.Size([32768, 512]) from checkpoint, the shape in current model is torch.Size([131072, 512]).
	size mismatch for generator.l1.0.bias: copying a param with shape torch.Size([32768]) from checkpoint, the shape in current model is torch.Size([131072]).
	size mismatch for generator.conv_blocks.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512, 1024, 3, 3]).
	size mismatch for generator.conv_blocks.3.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for generator.conv_blocks.7.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for generator.conv_blocks.7.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).

In [None]:
# Mark the W&B run as finished and flush any pending uploads before the kernel is reused.
wandb.finish()