<a href="https://colab.research.google.com/github/Vijay-K-2003/Deep_Learning_Models/blob/main/CycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CycleGAN



In [61]:
!pip install lightning



In [62]:
import lightning as L

import torch
import torchvision
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from torchvision.utils import make_grid
from PIL import Image

import itertools

from torchvision.datasets import MNIST
from torchvision import transforms

import matplotlib.pyplot as plt

import os

In [63]:
random_seed = 42
# torch.manual_seed(random_seed)
L.seed_everything(random_seed)

INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


42

In [64]:
BATCH_SIZE = 128
NUM_WORKERS = int(os.cpu_count()/2)
AVAILABLE_GPUS = torch.cuda.device_count()

In [65]:
class MNISTDataModule(L.LightningDataModule):

    def __init__(self,
                 data_dir="./data",
                 batch_size=BATCH_SIZE,
                 num_workers=NUM_WORKERS) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.Resize(32),
            transforms.RandomCrop(28),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str) -> None:
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

In [66]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        conv_block = [  nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features),
                        nn.ReLU(inplace=True),
                        nn.ReflectionPad2d(1),
                        nn.Conv2d(in_features, in_features, 3),
                        nn.InstanceNorm2d(in_features)  ]

        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        return x + self.conv_block(x)

In [67]:
class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, padding=1)
        )

    def forward(self, x):
        x = self.model(x)

        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

In [68]:
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        in_features = 64
        out_features = in_features*2

        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2

        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        out_features = in_features//2

        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [69]:
class GAN(L.LightningModule):
    def __init__(self, input_nc, output_nc, lr=0.0002, decay_epoch=100, batch_size=1, img_size=256, buffer_size=50):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False

        # Models
        self.netG_A2B = Generator(input_nc, output_nc)
        self.netG_B2A = Generator(output_nc, input_nc)
        self.netD_A = Discriminator(input_nc)
        self.netD_B = Discriminator(output_nc)

        # Losses
        self.criterion_GAN = torch.nn.MSELoss()
        self.criterion_cycle = torch.nn.L1Loss()
        self.criterion_identity = torch.nn.L1Loss()

        # Replay buffers
        self.fake_A_buffer = []
        self.fake_B_buffer = []

        # Input tensor size
        self.img_size = img_size

    def forward(self, z):
        pass

    def adversarial_loss(self, predicted_labels, actual_labels):
        return self.criterion_GAN(predicted_labels, actual_labels)

    def training_step(self, batch, batch_idx):
        optimizer_G, optimizer_D_A, optimizer_D_B = self.optimizers()

        real_A, real_B = batch

        # Generators A2B and B2A
        optimizer_G.zero_grad()

        # Identity loss
        same_B = self.netG_A2B(real_B)
        loss_identity_B = self.criterion_identity(same_B, real_B) * 5.0
        same_A = self.netG_B2A(real_A)
        loss_identity_A = self.criterion_identity(same_A, real_A) * 5.0

        # GAN loss
        fake_B = self.netG_A2B(real_A)
        pred_fake_B = self.netD_B(fake_B)
        loss_GAN_A2B = self.adversarial_loss(pred_fake_B, torch.ones_like(pred_fake_B))

        fake_A = self.netG_B2A(real_B)
        pred_fake_A = self.netD_A(fake_A)
        loss_GAN_B2A = self.adversarial_loss(pred_fake_A, torch.ones_like(pred_fake_A))

        # Cycle loss
        recovered_A = self.netG_B2A(fake_B)
        loss_cycle_ABA = self.criterion_cycle(recovered_A, real_A) * 10.0

        recovered_B = self.netG_A2B(fake_A)
        loss_cycle_BAB = self.criterion_cycle(recovered_B, real_B) * 10.0

        # Total generator loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        self.manual_backward(loss_G)
        optimizer_G.step()

        # Discriminator A
        optimizer_D_A.zero_grad()
        pred_real_A = self.netD_A(real_A)
        loss_D_real_A = self.adversarial_loss(pred_real_A, torch.ones_like(pred_real_A))

        if len(self.fake_A_buffer) >= self.hparams.buffer_size:
            self.fake_A_buffer.pop(0)
        self.fake_A_buffer.append(fake_A.detach())

        pred_fake_A = self.netD_A(fake_A.detach())
        loss_D_fake_A = self.adversarial_loss(pred_fake_A, torch.zeros_like(pred_fake_A))

        loss_D_A = (loss_D_real_A + loss_D_fake_A) * 0.5
        self.manual_backward(loss_D_A)
        optimizer_D_A.step()

        # Discriminator B
        optimizer_D_B.zero_grad()
        pred_real_B = self.netD_B(real_B)
        loss_D_real_B = self.adversarial_loss(pred_real_B, torch.ones_like(pred_real_B))

        if len(self.fake_B_buffer) >= self.hparams.buffer_size:
            self.fake_B_buffer.pop(0)
        self.fake_B_buffer.append(fake_B.detach())

        pred_fake_B = self.netD_B(fake_B.detach())
        loss_D_fake_B = self.adversarial_loss(pred_fake_B, torch.zeros_like(pred_fake_B))

        loss_D_B = (loss_D_real_B + loss_D_fake_B) * 0.5
        self.manual_backward(loss_D_B)
        optimizer_D_B.step()

        # Logging losses
        self.log('loss_G', loss_G, prog_bar=True, logger=True)
        self.log('loss_D_A', loss_D_A, prog_bar=True, logger=True)
        self.log('loss_D_B', loss_D_B, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer_G = torch.optim.Adam(
            itertools.chain(self.netG_A2B.parameters(), self.netG_B2A.parameters()),
            lr=self.hparams.lr, betas=(0.5, 0.999))

        optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))
        optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))

        lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda epoch: 1 - max(0, epoch - self.hparams.decay_epoch) / float(self.hparams.decay_epoch))
        lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lambda epoch: 1 - max(0, epoch - self.hparams.decay_epoch) / float(self.hparams.decay_epoch))
        lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lambda epoch: 1 - max(0, epoch - self.hparams.decay_epoch) / float(self.hparams.decay_epoch))

        return [optimizer_G, optimizer_D_A, optimizer_D_B], [lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B]

    def plot_losses(self):
        # Implement logic to plot or log losses during training
        pass

    def plot_images(self):
        # Implement logic to visualize images during training
        pass

    def on_train_epoch_end(self):
        self.plot_images()
        self.plot_losses()


In [70]:
dm = MNISTDataModule()
model = GAN(input_nc=1, output_nc=1, lr=0.0002, decay_epoch=100, batch_size=1, img_size=28)

In [71]:
model.plot_images()

In [72]:
trainer = L.Trainer(max_epochs=50, accelerator='gpu', devices=AVAILABLE_GPUS)
trainer.fit(model, dm)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name               | Type          | Params | Mode 
-------------------------------------------------------------
0 | netG_A2B           | Generator     | 11.4 M | train
1 | netG_B2A           | Generator     | 11.4 M | train
2 | netD_A             | Discriminator | 2.8 M  | train
3 | netD_B             | Discriminator | 2.8 M  | train
4 | criterion_GAN      | MSELoss       | 0      | train
5 | criterion_cycle    | L1Loss        | 0      | train
6 | criterion_identity | L1Lo

Training: |          | 0/? [00:00<?, ?it/s]

RuntimeError: Padding length should be less than or equal to two times the input dimension but got padding length 4 and input of dimension 1

In [None]:
model.plot_images()