In [1]:
import os
from typing import Dict, List, Optional, OrderedDict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.seed import seed_everything

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
from torchvision.utils import save_image


In [2]:
ALL_IMAGES = []

img_size = 64
batch_size = 128
normalize = [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]
latent_size = 128
data_dir = "./data/train"

train_dataset = datasets.ImageFolder(data_dir, transforms.Compose(
    [
        transforms.Resize(img_size), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize(*normalize)
    ]
))
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True
)

In [3]:
def denormalize(input_image_tensors: torch.Tensor) -> torch.Tensor:
    """
    Denormalizes the input image tensors.

    Parameters
    ----------
    input_image_tensors : torch.Tensor
        The input image tensors.
    
    Returns
    -------
    torch.Tensor
        The denormalized image tensors.
    """
    return input_image_tensors.mul(normalize[1][0]).add(normalize[0][0])


def save_samples(index: int, sample_images: torch.Tensor) -> None:
    """
    Saves the generated samples.

    Parameters
    ----------
    index : int
        The index of the sample.
    sample_images : torch.Tensor
        The generated sample images.
    """
    fake_name = f"generated-images-{index}.png"
    save_image(denormalize(sample_images[-64:]), os.path.join("generated", fake_name), nrow=8)

In [4]:
class BirdDiscriminator(nn.Module):
    def __init__(
        self,
        input_size: int,
        channel: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        negative_slope: Optional[float] = 0.2,
        bias: Optional[bool] = False,
    ):
        """
        Initializes the discriminator.

        Parameters
        ----------
        input_size : int
            The input image size.
        channel : int, optional
            The number of channels. (default: 3)
        kernel_size : int, optional
            The kernal size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        negative_slope : float, optional
            The negative slope. (default: 0.2)
        bias : bool, optional
            Whether to use bias. (default: False)
        """
        super().__init__()
        self.input_size = input_size
        self.channel = channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.negative_slope = negative_slope
        self.bias = bias

        self.model = nn.Sequential(
            # input size: (3, 64, 64)
            nn.Conv2d(
                self.channel, self.input_size, 
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (64, 32, 32)
            nn.Conv2d(64, 128, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (128, 16, 16)
            nn.Conv2d(128, 256, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (256, 8, 8)
            nn.Conv2d(256, 512, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(self.negative_slope, inplace=True),

            # input size: (512, 4, 4)
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=self.bias), # output size: (1, 1, 1)
            nn.Flatten(),
            nn.Sigmoid()
        )

    
    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The input image.

        Returns
        -------
        torch.Tensor
            The output.
        """
        return self.model(input_img)

In [5]:
class BirdGenerator(nn.Module):
    def __init__(
        self,
        latent_size: Optional[int] = 128,
        channel: Optional[int] = 3,
        kernel_size: Optional[int] = 4,
        stride: Optional[int] = 2,
        padding: Optional[int] = 1,
        bias: Optional[bool] = False,
    ):
        """
        Initializes the generator.

        Parameters
        ----------
        latent_size : int, optional
            The latent size. (default: 128)
        channel : int, optional
            The number of channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        bias : bool, optional
            Whether to use bias. (default: False)
        """
        super().__init__()
        self.latent_size = latent_size
        self.channel = channel
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        self.model = nn.Sequential(
            # input size: (latent_size=128, 1, 1)
            nn.ConvTranspose2d(
                self.latent_size, 512, kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias
            ),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),

            # input size: (512, 4, 4)
            nn.ConvTranspose2d(
                512, 256, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            # input size: (256, 8, 8)
            nn.ConvTranspose2d(
                256, 128, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # input size: (128, 16, 16)
            nn.ConvTranspose2d(
                128, 64, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            # input size: (64, 32, 32)
            nn.ConvTranspose2d(
                64, self.channel, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
            ),
            nn.Tanh() # output size: (3, 64, 64)
        )

    
    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The input image.

        Returns
        -------
        torch.Tensor
            The output.
        """
        return self.model(input_img)

In [6]:
class BirdLightningGan(pl.LightningModule):
    def __init__(
        self, 
        latent_size: Optional[int] = 128,
        learning_rate: Optional[float] = 0.0002,
        batch_size: Optional[int] = 128,
        bias1: Optional[float] = 0.5,
        bias2: Optional[float] = 0.999,
    ):
        """
        Initializes the LightningGan.

        Parameters
        ----------
        latent_size : int, optional
            The latent size. (default: 128)
        learning_rate : float, optional
            The learning rate. (default: 0.0002)
        batch_size : int, optional
            The batch size. (default: 128)
        bias1 : float, optional
            The bias1. (default: 0.5)
        bias2 : float, optional
            The bias2. (default: 0.999)
        """
        super().__init__()
        self.latent_size = latent_size
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.bias1 = bias1
        self.bias2 = bias2
        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
        self.save_hyperparameters()

        self.generator = BirdGenerator(latent_size=self.latent_size)
        self.discriminator = BirdDiscriminator(input_size=64)


    def adversarial_loss(self, preds: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculates the adversarial loss.

        Parameters
        ----------
        preds : torch.Tensor
            The predictions.
        labels : torch.Tensor
            The labels.

        Returns
        -------
        torch.Tensor
            The adversarial loss.
        """
        return F.binary_cross_entropy(preds, labels)

    
    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
        """
        Configures the optimizers.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], List]
            The optimizers and the LR schedulers.
        """
        opt_generator = torch.optim.Adam(
            self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        return [opt_generator, opt_discriminator], []

    
    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        z : torch.Tensor
            The latent vector.

        Returns
        -------
        torch.Tensor
            The output.
        """
        return self.generator(z)

    
    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, optimizer_idx: int
    ) -> Dict:
        """
        Training step.

        Parameters
        ----------
        batch : Tuple[torch.Tensor, torch.Tensor]
            The batch.
        batch_idx : int
            The batch index.
        optimizer_idx : int
            The optimizer index.

        Returns
        -------
        Dict
            The training loss.
        """
        real_images, _ = batch

        if optimizer_idx == 0: # Only train the generator
            fake_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
            fake_random_noise = fake_random_noise.type_as(real_images)
            fake_images = self(fake_random_noise)

            # Try to fool the discriminator
            preds = self.discriminator(fake_images)
            targets = torch.ones(self.batch_size, 1)
            targets = targets.type_as(real_images)
            loss = self.adversarial_loss(preds, targets)
            self.log("g_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"g_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output

        elif optimizer_idx == 1: # Only train the discriminator
            real_preds = self.discriminator(real_images)
            real_targets = torch.ones(real_images.size(0), 1)
            real_targets = real_targets.type_as(real_images)
            real_loss = self.adversarial_loss(real_preds, real_targets)

            # Generate fake images
            real_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1)
            real_random_noise = real_random_noise.type_as(real_images)
            fake_images = self(real_random_noise)

            # Pass fake images though discriminator
            fake_targets = torch.zeros(fake_images.size(0), 1)
            fake_targets = fake_targets.type_as(real_images)
            fake_preds = self.discriminator(fake_images)
            fake_loss = self.adversarial_loss(fake_preds, fake_targets)

            # Update discriminator weights
            loss = real_loss + fake_loss
            self.log("d_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"d_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output

    
    def on_epoch_end(self):
        """
        Called at the end of an epoch.
        """
        z = self.validation.type_as(self.generator.model[0].weight)
        sample_images = self(z)
        ALL_IMAGES.append(sample_images.detach().cpu())
        save_samples(self.current_epoch, sample_images)

In [7]:
seed_everything(42)
gpus = 1 if torch.cuda.is_available() else 0

logger = TensorBoardLogger("logs", name="bird_lightning_gan")

model = BirdLightningGan()

trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=500,
    progress_bar_refresh_rate=25,
    # callbacks=[early_stopping, checkpointer],
    logger=logger,
)
trainer.fit(model, train_dataloader)

Global seed set to 42
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | generator     | BirdGenerator     | 3.8 M 
1 | discriminator | BirdDiscriminator | 2.8 M 
----------------------------------------------------
6.6 M     Trainable params
0         Non-trainable params
6.6 M     Total params
26.287    Total estimated model params size (MB)


Epoch 499: 100%|██████████| 457/457 [3:22:44<00:00, 26.62s/it, loss=4.34, v_num=3]    
