# Cycle GAN (pytorch lightning + Weights and Biases)    

---  

* *Comfort comes from familiarity: knowing precisely where everything is!*

* So, why not use the same framework for all your CV/AI projects, with near-perfect experiment logging, while being lightning fast?


> Aryan Garg   


**Note:**
1. You'll need an account on [Weights & Biases](https://wandb.ai/site). Don't worry, it's free! 

**References:**
1. [Tutorial: Cycle GAN from Scratch by Song Seung Won](https://www.kaggle.com/code/songseungwon/cyclegan-tutorial-from-scratch-monet-to-photo)

In [1]:
# Master switch for the whole notebook:
#   True  -> submission run: the wandb install/login cells are skipped and the Trainer
#            is built with no logger and no callbacks.
#   False -> experiment run: wandb is installed, logged into and used as the logger.
SUBMIT_NB = True

## Imports

### wandb (WARNING: interactive cell)

In [2]:
if not SUBMIT_NB:
    !pip -qqq install wandb pytorch-lightning torchmetrics

    import wandb
    from pytorch_lightning.loggers import WandbLogger

    wandb.login()

### Lightning

In [3]:
try:
    import lightning.pytorch as pl
except:
    print("[!] Couldn't find pytorch-lightning.\nInstalling it...\n")
    !pip install lightning
    import lightning.pytorch as pl

[!] Couldn't find pytorch-lightning.
Installing it...

Collecting lightning
  Downloading lightning-2.0.2-py3-none-any.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fastapi<0.89.0,>=0.69.0
  Downloading fastapi-0.88.0-py3-none-any.whl (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.5/55.5 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting croniter<1.4.0,>=1.3.0
  Downloading croniter-1.3.14-py2.py3-none-any.whl (18 kB)
Collecting deepdiff<8.0,>=5.7.0
  Downloading deepdiff-6.3.0-py3-none-any.whl (69 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Collecting dateutils<2.0
  Downloading dateutils-0.6.12-py2.py3-none-any.whl (5.7 kB)
Collecting inquirer<5.0,>=2.10.0
  Downloading inquirer-3.1.3-py3-none-any.whl (18 kB)
Collecting lightning-cloud>=0.5.34
  Dow



In [4]:
from lightning.pytorch.utilities.model_summary import ModelSummary

In [5]:
# Import from `lightning.pytorch` (installed by the cell above) so the notebook does not
# silently depend on a separately installed `pytorch_lightning` package.
from lightning.pytorch import seed_everything

### Standard imports

In [6]:
import os
import shutil
import pathlib

import PIL
from PIL import Image
import numpy as np
import cv2 as cv
import random
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from torch.utils.data import Dataset, DataLoader, random_split

import torchvision
from torchvision import datasets

In [7]:
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.utils import make_grid

### Torchmetrics 

*Not needed.*     

But just in case we decide to use a metric to draw some insights!      

In particular FID, since the competition is evaluated on **Memorization-Informed FID**.

In [8]:
try:
    import torchmetrics
except:
    print(f"[!] Torchmetrics couldn't be imported.\nInstalling...")
    !pip install torchmetrics

### Custom Utilities (Not many lol)

In [9]:
# Folder Utilities ----------------------------

## Create dir if it doesn't exist
def create_dir(dir_name, root='/content'):
    """Create the directory ``<root>/<dir_name>`` if nothing exists at that path.

    Args:
        dir_name: Directory name relative to ``root``; may contain sub-directories.
        root: Base directory. Defaults to ``/content``, the base that used to be
            hard-coded, so existing calls behave as before.

    Missing parent directories (including ``root`` itself) are created as well.
    """
    path = f'{root}/{dir_name}'
    if not os.path.exists(path):
        # makedirs instead of mkdir: does not fail when the parent is missing, and
        # exist_ok guards against the directory appearing between check and creation.
        os.makedirs(path, exist_ok=True)

## Delete dir: checkpoints
def delete_dir(dir_name, root='/content'):
    """Recursively delete the directory ``<root>/<dir_name>`` if it exists.

    Args:
        dir_name: Directory name relative to ``root``.
        root: Base directory. Defaults to ``/content``, the base that used to be
            hard-coded, so existing calls behave as before.

    Nothing happens when the path is missing or is not a directory.
    """
    path = f'{root}/{dir_name}'
    if os.path.isdir(path):
        shutil.rmtree(path)

--- 

## Config File (logged to wandb), Seeds & Devices

In [10]:
# Single source of truth for the run settings; handed to the W&B logger as the run config.
CONFIG = {
    'seed': 42,
    'DATA_ROOT': '/kaggle/input/gan-getting-started/',
    'BATCH_SIZE': 32,
    'WORKERS': 2,
    'IMG_SIZE': (256, 256, 3),   # (height, width, channels)
    'NUM_EPOCHS': 20,
    # Adam settings used for every optimizer
    'lr': 0.0002,
    'b1': 0.5,
    'b2': 0.999,
    'disc_steps': 1,             # discriminator updates per generator update
    'checkpoint_path': '/kaggle/working/',
}

In [11]:
# Fix the global seed (value taken from CONFIG) for repeatable runs. MonetDataset picks
# its photo window with np.random, so this presumably also fixes which photos are used.
seed_everything(CONFIG['seed'])

42

In [12]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cuda')

---   

## Transforms

In [13]:
# Resize to the (height, width) declared in CONFIG['IMG_SIZE'] - instead of repeating the
# literal - and convert the image to a tensor.
train_transform_album = Compose([T.Resize(CONFIG['IMG_SIZE'][:2]), ToTensor()])
# Disabled augmentation ideas (presumably albumentations `A.*` ops), kept for reference:
#         A.SmallestMaxSize(max_size=160),
#         A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
#         A.RandomCrop(height=128, width=128),
#         A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
#         A.RandomBrightnessContrast(p=0.5),
#         A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

---  

## Dataset   

### !! Use LightningDataModule this time? !!


**Reqs:**
1. Sort images and load aligned dataset (domA_img, domB_img)

In [14]:
class MonetDataset(Dataset):
    """Photo / Monet image pairs for CycleGAN training.

    Expects ``root_dir`` to contain the folders ``monet_jpg/`` and ``photo_jpg/``.
    All Monet files are used; from the photo folder a random contiguous window of
    at most 300 files is drawn once, at construction time, with ``np.random``.

    Args:
        root_dir: Folder holding ``monet_jpg/`` and ``photo_jpg/`` (trailing slash optional).
        cust_transform: Callable applied to each PIL image; defaults to ``ToTensor`` only.
        isTrain: Stored on the instance; not used by the dataset itself.

    Each item is the tuple ``(photo, monet)``. The two file lists are sorted so the
    pairing is reproducible, and when they differ in length the shorter one wraps around.

    Raises:
        ValueError: if ``root_dir`` is ``None``.
        RuntimeError: if either image folder is empty.
    """

    # Window-sampling constants of the original implementation. The start index is
    # additionally clamped to the number of photos actually present.
    _MAX_WINDOW_START = 6701
    _WINDOW_SIZE = 300

    def __init__(self, 
                 root_dir: str = None, 
                 cust_transform = Compose([ToTensor()]), 
                 isTrain: bool = True):
        
        super().__init__()
        if root_dir is None:
            raise ValueError("root_dir must point to the folder containing 'monet_jpg/' and 'photo_jpg/'")

        self.root = root_dir
        self.dir_Y = os.path.join(self.root, "monet_jpg")
        self.dir_X = os.path.join(self.root, "photo_jpg")

        # os.listdir returns entries in arbitrary order; sort for a reproducible pairing.
        self.files_Y = sorted(os.listdir(self.dir_Y))
        all_photos = sorted(os.listdir(self.dir_X))
        if not self.files_Y or not all_photos:
            raise RuntimeError(f"No images found under {self.dir_Y!r} and/or {self.dir_X!r}")

        # Same draw as before for the full Kaggle photo folder (upper bound 6701), but the
        # bound shrinks for smaller folders so the window never runs off the end.
        last_valid_start = max(len(all_photos) - self._WINDOW_SIZE, 0)
        start = np.random.randint(0, min(self._MAX_WINDOW_START, last_valid_start + 1))
        end = start + self._WINDOW_SIZE
        self.files_X = all_photos[start:end]

        self.transform = cust_transform
        self.isTrain = isTrain

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return a 3-channel RGB copy and close the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        n_items = len(self)
        # Explicit bounds check: plain iteration over the dataset relies on IndexError
        # to stop, which the wrap-around indexing below would otherwise never raise.
        if not -n_items <= idx < n_items:
            raise IndexError(f"index {idx} is out of range for dataset of size {n_items}")

        # __len__ is the longer of the two lists, so wrap the index of the shorter one.
        name_X = self.files_X[idx % len(self.files_X)]
        name_Y = self.files_Y[idx % len(self.files_Y)]

        item_X = self.transform(self._load_rgb(os.path.join(self.dir_X, name_X)))
        item_Y = self.transform(self._load_rgb(os.path.join(self.dir_Y, name_Y)))

        return (item_X, item_Y) # Photo , Monet

    def __len__(self):
        return max(len(self.files_Y), len(self.files_X))

In [15]:
# Sanity check (experiment runs only): load the first (photo, monet) pair and show
# the two images stacked in one figure.
if not SUBMIT_NB:
    all_data = MonetDataset(CONFIG['DATA_ROOT'], cust_transform=train_transform_album)
    # iter() on a Dataset that only defines __getitem__/__len__ yields item 0 first.
    real, monet = next(iter(all_data))

    # Concatenate along dim 1 - assuming (C, H, W) tensors this stacks the photo on top
    # of the painting, matching the title below.
    image_grid = torch.cat((real, monet), 1)
    # matplotlib wants channels last, hence the permute.
    plt.imshow(image_grid.cpu().permute(1,2,0))
    plt.title("top: real | bottom: monet")
    plt.axis('off')
    plt.show()

In [16]:
# Experiment runs only: report the tensor shape/dtype of one sample and the dataset length.
# Relies on `real` and `all_data` created by the preview cell above.
if not SUBMIT_NB:
    print(real.shape, real.dtype)
    print(len(all_data))

In [17]:
class MonetDataModule(pl.LightningDataModule):
    """LightningDataModule wrapper around :class:`MonetDataset`.

    Args:
        data_dir: Folder containing ``monet_jpg/`` and ``photo_jpg/``. When ``None``
            the module falls back to ``CONFIG['DATA_ROOT']`` (the previous behaviour).
        batch_size: Batch size used by every dataloader this module returns.
    """

    def __init__(self, data_dir: str = None, batch_size: int = 1):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        
    def prepare_data(self):
        """Nothing to download or pre-process."""
        pass

    def setup(self, stage: str):
        """Build the dataset for ``stage`` (``'fit'`` or ``'test'``); other stages are ignored."""
        # Honour the directory passed to the constructor instead of always reading CONFIG.
        root = self.data_dir if self.data_dir is not None else CONFIG['DATA_ROOT']
        if stage == 'fit':
            self.train_data = MonetDataset(root, isTrain=True)
        elif stage == 'test': 
            # Not implemented anything different from above (training dataset)... 
            # This is just to highlight how the datamodule can be used for general purposes.
            self.test_data = MonetDataset(root, isTrain=False)
            
    def train_dataloader(self):
        """Shuffled loader over the dataset created in ``setup('fit')``."""
        return DataLoader(self.train_data, shuffle=True, batch_size=self.batch_size)
    
    def test_dataloader(self):
        """Loader over the dataset created in ``setup('test')``; kept in order for repeatable evaluation."""
        return DataLoader(self.test_data, shuffle=False, batch_size=self.batch_size)

    def teardown(self, stage: str):
        """No resources to release."""
        pass

---   

## Architecture

In [18]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
[0m

In [19]:
from torchvision import models
from torchsummary import summary

### Residual Generator

Note to self:    

nn.Conv2d(in_f, out_f, k, s, p)    

k -> kernel_size    
s -> stride (default: 1)    
p -> padding (default: 0)  

In [20]:
class ResBlk(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + instance-norm stages with a skip connection.

    The number of channels and the spatial size are preserved, so the block output
    can be added directly to its input.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_norm():
            # ReflectionPad2d(1) + unpadded 3x3 conv keeps the spatial size unchanged.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=0),
                nn.InstanceNorm2d(in_features),
            ]

        # Layer order (and so the Sequential indices / state_dict keys) is:
        # pad, conv, norm, relu, pad, conv, norm.
        stages = conv_norm() + [nn.ReLU(inplace=True)] + conv_norm()
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``F(x) + x`` where ``F`` is the two-stage conv path."""
        transformed = self.model(x)
        return transformed + x

In [21]:
class ResGen(nn.Module):
    """Residual generator: 7x7 stem, two stride-2 downsamplers, residual blocks,
    two upsamplers and a 7x7 Tanh output layer.

    Args:
        channels: Number of image channels (input and output).
        out_features: Width of the stem convolution; doubled by each downsampler.
        num_residual_blocks: Number of :class:`ResBlk` blocks at the bottleneck.

    The output has the same shape as the input (spatial size is halved twice and
    doubled twice) and lies in [-1, 1] because of the final Tanh.

    Note:
        As in the original implementation, ``self.out_f`` ends up holding the
        bottleneck width (``4 * out_features``) once construction has finished.
    """

    def __init__(self, channels, out_features=64, num_residual_blocks=9):
        super().__init__()
        self.c = channels
        self.out_f = out_features
        self.num_resBlks = num_residual_blocks

        # A 7x7 convolution needs 7 // 2 = 3 pixels of padding to keep the spatial size.
        # The padding used to be `channels`, which is only correct for 3-channel images.
        pad_7x7 = 3

        # Alpha blk
        layers = [nn.ReflectionPad2d(pad_7x7), 
                  nn.Conv2d(self.c, self.out_f, 7, 1, 0), 
                  nn.InstanceNorm2d(self.out_f),
                  nn.ReLU(inplace=True)]

        # Downsampler: 2 Beta blocks, each halves the resolution and doubles the width
        width = self.out_f
        for _ in range(2):
            layers += [nn.Conv2d(width, 2*width, 3, 2, 1), 
                       nn.InstanceNorm2d(2*width), 
                       nn.ReLU(inplace=True)]
            width *= 2
        self.out_f = width

        # Residual blocks at the bottleneck width
        for _ in range(self.num_resBlks):
            layers += [ResBlk(in_features=width)]

        # Upsampler: 2 Gamma blocks, each doubles the resolution and halves the width.
        # The layer list is deliberately left as it was (no norm layer here) so the
        # Sequential indices - and the state_dict keys of saved checkpoints - do not move.
        for _ in range(2):
            layers += [nn.Upsample(scale_factor=2),
                       nn.Conv2d(width, width // 2, 3, 1, 1),
                       nn.ReLU(inplace=True)]
            width //= 2

        # Output layer
        layers += [nn.ReflectionPad2d(pad_7x7), 
                   nn.Conv2d(width, self.c, 7, 1, 0), 
                   nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate a batch of images ``x`` of shape (N, channels, H, W)."""
        return self.model(x)

In [22]:
# Build a 3-channel generator on the selected device and print torchsummary's
# per-layer output shapes / parameter counts for a 3x256x256 input.
resgen = ResGen(channels=3).to(device)
summary(resgen, (3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1          [-1, 3, 262, 262]               0
            Conv2d-2         [-1, 64, 256, 256]           9,472
    InstanceNorm2d-3         [-1, 64, 256, 256]               0
              ReLU-4         [-1, 64, 256, 256]               0
            Conv2d-5        [-1, 128, 128, 128]          73,856
    InstanceNorm2d-6        [-1, 128, 128, 128]               0
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8          [-1, 256, 64, 64]         295,168
    InstanceNorm2d-9          [-1, 256, 64, 64]               0
             ReLU-10          [-1, 256, 64, 64]               0
  ReflectionPad2d-11          [-1, 256, 66, 66]               0
           Conv2d-12          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-13          [-1, 256, 64, 64]               0
             ReLU-14          [-1, 256,

### PatchGAN Discriminator

In [23]:
class patchDisc(nn.Module):
    """PatchGAN discriminator: four stride-2 conv blocks followed by a 1-channel conv head.

    Args:
        channels: Number of channels of the input images.
        img_size: Spatial size of the input images, either an int (square images) or a
            ``(height, width)`` pair. Only used to derive ``output_shape``; defaults to
            256, which gives the previously hard-coded ``(1, 16, 16)``.

    Attributes:
        output_shape: ``(1, H // 16, W // 16)`` - shape of the patch map produced for one
            image. Each stride-2 block halves the size (floor); the zero-pad plus final
            4x4 conv leaves it unchanged.
    """

    def __init__(self, channels, img_size=256):
        super().__init__()
        
        self.c = channels

        def disc_block(in_f, out_f):
            # 4x4 conv with stride 2 and padding 1: halves the spatial resolution.
            return [nn.Conv2d(in_f, out_f, 4,2,1), nn.InstanceNorm2d(out_f), nn.LeakyReLU(0.2, inplace=True)]
        
        height, width = (img_size, img_size) if isinstance(img_size, int) else img_size
        self.output_shape = (1, height // 16, width // 16)
        
        self.model = nn.Sequential(
            *disc_block(self.c, 64),
            *disc_block(64, 128),
            *disc_block(128, 256),
            *disc_block(256, 512),
            # Asymmetric pad (left=1, top=1) so the 4x4/stride-1 conv below keeps the size.
            nn.ZeroPad2d((1,0,1,0)),
            nn.Conv2d(512, 1, 4,1,1)
        )
    
        
    def forward(self, x):
        """Return the patch-wise real/fake score map for the image batch ``x``."""
        return self.model(x)

In [24]:
# Experiment runs only: print torchsummary's layer table for the discriminator
# on a 3x256x256 input.
if not SUBMIT_NB:
    pdisc = patchDisc(3).to(device)
    summary(pdisc, (3,256,256))

---   

## Lightning Recipe

In [25]:
class LIT_cycle(pl.LightningModule):
    """CycleGAN training recipe: two generators and two discriminators, manual optimization.

    Domain X is the photo domain and domain Y the Monet domain: batches are unpacked
    as ``(real_X, real_Y)`` and the translated images are logged as "Gen. Photo"
    (``fake_X``) and "Gen. Monet" (``fake_Y``).

    Args:
        Gen_XY: Generator translating X -> Y (required).
        Gen_YX: Generator translating Y -> X (required).
        D_X: Discriminator for domain X (required).
        D_Y: Discriminator for domain Y (required).
        lr: Adam learning rate shared by all three optimizers.
        b1: Adam beta1.
        b2: Adam beta2.
        channels: Number of image channels (stored on the instance only).
        disc_steps: Discriminator updates per generator update.

    All constructor arguments - including the four networks - are stored through
    ``save_hyperparameters``; ``load_from_checkpoint`` is called elsewhere in this
    notebook without passing the networks, so that behaviour is left untouched.
    """
    
    def __init__(self, Gen_XY=None, Gen_YX=None, D_X=None, D_Y=None, 
                 lr: float = 1e-3, b1: float = 0.5, b2: float = 0.999, channels=3, disc_steps=1):
        super().__init__()
        
        assert Gen_XY is not None, "Pass the generator: X -> Y !"
        assert Gen_YX is not None, "Pass the generator: Y -> X !"
        assert D_X is not None, "Pass the patch-discriminator: X !"
        assert D_Y is not None, "Pass the patch-discriminator: Y !"
        
        self.save_hyperparameters(ignore=[])
        # Three optimizers are stepped by hand inside training_step.
        self.automatic_optimization = False
        
        self.g_xy = Gen_XY
        self.g_yx = Gen_YX
        self.dx = D_X
        self.dy = D_Y
        self.channels = channels
        
        # The loss modules hold no parameters or buffers, so they need no device placement.
        self.criterion_GAN = torch.nn.MSELoss()
        self.cycle_loss = torch.nn.L1Loss()
        self.identity_loss = torch.nn.L1Loss()
        
        
    def configure_optimizers(self):
        """Return ``[generators, D_X, D_Y]`` Adam optimizers and no schedulers."""
        lr = self.hparams.lr # .hparams is accessed from the args passed to this core Lightning module :)
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        
        # chain the two generator optimizers so that:
        # both generators are updated together (using the same optimizer instance)!
        from itertools import chain
        optim_g = torch.optim.Adam(chain(self.g_xy.parameters(), self.g_yx.parameters()), lr=lr, betas=(b1,b2))
        
        optim_dx = torch.optim.Adam(self.dx.parameters(), lr=lr, betas=(b1,b2))
        optim_dy = torch.optim.Adam(self.dy.parameters(), lr=lr, betas=(b1,b2))
        
        # Second list is for returning any lr-schedulers that you might want to use!
        return [optim_g, optim_dx, optim_dy], []
    
    
    def forward(self, z, toDomain='x'):
        """Translate ``z`` into domain X (``toDomain == 'x'``, via G_YX) or into Y (otherwise, via G_XY)."""
        if toDomain == 'x':
            return self.g_yx(z)
        else:
            return self.g_xy(z)


    def _gan_loss(self, prediction, target_is_real):
        """Least-squares GAN loss of ``prediction`` against an all-ones (real) or all-zeros (fake) target.

        The target is created with the prediction's own shape, dtype and device, so it
        always matches the discriminator output and sits on the module's device.
        """
        if target_is_real:
            target = torch.ones_like(prediction)
        else:
            target = torch.zeros_like(prediction)
        return self.criterion_GAN(prediction, target)
        

    def training_step(self, batch, idx):
        """One manual-optimization step: update both generators, then each discriminator.

        Args:
            batch: Tuple ``(real_X, real_Y)`` of image batches.
            idx: Batch index (unused).

        Returns:
            None - the backward passes and optimizer steps happen inside this method.
        """
        # No explicit .to(device): the tensors are used on whatever device the batch
        # arrives on (presumably already the module's device when run via the Trainer).
        real_X, real_Y = batch
        
        opt_g, opt_dx, opt_dy = self.optimizers()
        
        # ---------------- Train G ----------------
        self.toggle_optimizer(opt_g)
        
        # Identity loss: a generator fed an image of its own target domain should return it unchanged.
        loss_id_X = self.identity_loss(self.g_yx(real_X), real_X)
        loss_id_Y = self.identity_loss(self.g_xy(real_Y), real_Y)
        loss_identity = (loss_id_X + loss_id_Y)/2
        
        # GAN loss: the generators try to make the discriminators answer "real".
        fake_Y = self.g_xy(real_X)
        loss_GAN_XY = self._gan_loss(self.dy(fake_Y), True) # tricking the 'fake-Y' into 'real-Y'
        fake_X = self.g_yx(real_Y)
        loss_GAN_YX = self._gan_loss(self.dx(fake_X), True) # tricking the 'fake-X' into 'real-X'
        
        loss_GAN = (loss_GAN_XY + loss_GAN_YX)/2
        
        # Cycle loss: translating there and back should reproduce the input.
        recov_X = self.g_yx(fake_Y) # photo recovered from the generated Monet
        loss_cycle_X = self.cycle_loss(recov_X, real_X)
        recov_Y = self.g_xy(fake_X)
        loss_cycle_Y = self.cycle_loss(recov_Y, real_Y)
        
        loss_cycle = (loss_cycle_X + loss_cycle_Y)/2.
        
        # Total loss:         
        g_loss = loss_GAN + (10.0*loss_cycle) + (5.0*loss_identity) # multiply by weights suggested by paper-authors
        
        # Gradient step:
        self.manual_backward(g_loss)
        opt_g.step()
        opt_g.zero_grad()
        
        self.untoggle_optimizer(opt_g)
        
        # ---------------- Train the discriminators ----------------
        # Stay None when disc_steps == 0 so the logging below does not hit undefined names.
        dx_loss = dy_loss = tot_d_loss = None
        for _ in range(self.hparams.disc_steps): # is kept at 1, usually
            # D_X: real photos -> real, generated photos -> fake
            self.toggle_optimizer(opt_dx)
            loss_real = self._gan_loss(self.dx(real_X), True)
            # detach(): the discriminator update must not back-propagate into the generator
            loss_fake = self._gan_loss(self.dx(fake_X.detach()), False)
        
            dx_loss = (loss_real + loss_fake)/2.
            
            self.manual_backward(dx_loss)
            opt_dx.step()
            opt_dx.zero_grad()
            self.untoggle_optimizer(opt_dx)
            
            # D_Y: real Monets -> real, generated Monets -> fake
            self.toggle_optimizer(opt_dy)
            loss_real = self._gan_loss(self.dy(real_Y), True)
            loss_fake = self._gan_loss(self.dy(fake_Y.detach()), False)
        
            dy_loss = (loss_real + loss_fake)/2.
            
            self.manual_backward(dy_loss)
            opt_dy.step()
            opt_dy.zero_grad()
            self.untoggle_optimizer(opt_dy)
            
            tot_d_loss = (dx_loss + dy_loss)/2.
        
        # Logging is only active for experiment runs (wandb is not even imported otherwise).
        if not SUBMIT_NB:
            self.logger.experiment.log({"Gen. Monet":[wandb.Image(make_grid(fake_Y[0].cpu()), caption="Gen Monet")]})
            self.logger.experiment.log({"Gen. Photo":[wandb.Image(make_grid(fake_X[0].cpu()), caption="Gen Photo")]})
            metrics = {"g_loss": g_loss}
            if tot_d_loss is not None:
                metrics.update({"dx_loss": dx_loss, "dy_loss": dy_loss, "tot_d_loss": tot_d_loss})
            self.log_dict(metrics, 
                      on_step=True, 
                      on_epoch=True, 
                      prog_bar=True, 
                      logger=True)

In [26]:
# Init Recipe
cycle_gan = LIT_cycle(Gen_XY=ResGen(3), 
                      Gen_YX=ResGen(3), 
                      D_X=patchDisc(3), 
                      D_Y=patchDisc(3), 
                      lr=CONFIG['lr'], 
                      b1=CONFIG['b1'],
                      b2=CONFIG['b2'],
                      channels=CONFIG['IMG_SIZE'][2],
                      disc_steps=CONFIG['disc_steps']
                     )

# summary = ModelSummary(cycle_gan, max_depth=-1)
# print(summary)

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


---   

## Logger: Project Name, Run Name, Config file logging etc.

In [27]:
# Experiment runs only. wandb.login() was already called in the imports section;
# this second call just precedes the logger creation.
if not SUBMIT_NB:
    wandb.login()

In [28]:
# Experiment runs only: create the W&B logger that the Trainer will use.
# The whole CONFIG dict is attached to the run as its configuration.
if not SUBMIT_NB:
    wandb_logger = WandbLogger(project='M8-cycleGAN', 
                           name='exp-2_20eps',
                           config=CONFIG,
                           job_type='train',
                           log_model="all")

---    

## Trainer Callbacks

In [29]:
# Take `Callback` from the same package as the other callbacks and the Trainer
# (`lightning.pytorch`) instead of mixing in the separate `pytorch_lightning` package.
from lightning.pytorch import Callback
from lightning.pytorch.callbacks import DeviceStatsMonitor, TQDMProgressBar, ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Checkpoint: weights-only checkpoints, none discarded (save_top_k=-1) plus a `last`
# checkpoint, written to CONFIG['checkpoint_path'] and named after epoch and g_loss.
checkpoint_callback = ModelCheckpoint(dirpath=CONFIG['checkpoint_path'],
                                      filename='{epoch}-{g_loss:.3f}',
                                      monitor='g_loss',
                                      save_top_k=-1,
                                      save_last=True,
                                      save_weights_only=True,
                                      verbose=True,
                                      mode='min')

# Exp2: Learning Rate Monitor
lr_monitor = LearningRateMonitor(logging_interval='step', log_momentum=False)

# Earlystopping
# earlystopping = EarlyStopping(monitor='val_d_acc', patience=3, mode='min')

---   

## Trainer

In [30]:
# Submission runs use neither a logger nor callbacks; experiment runs get the W&B logger
# plus progress-bar, checkpoint and LR-monitor callbacks. The conditional expressions are
# lazy, so the experiment-only names are never touched during a submission run.
logger_name = False if SUBMIT_NB else wandb_logger
callbacks_lst = [] if SUBMIT_NB else [TQDMProgressBar(refresh_rate=100), checkpoint_callback, lr_monitor]

trainer = pl.Trainer(
    accelerator='auto',                # Precedence: tpu > gpu >> cpu
    devices="auto",                    # all
    max_epochs=CONFIG['NUM_EPOCHS'],
    logger=logger_name,                # wandb <3 OR False for NB submission
    callbacks=callbacks_lst,
    log_every_n_steps=1,               # set the logging frequency
    fast_dev_run=False,                # For debugging purposes
)

---  

## Training

**TODO:** Replace the direct `MonetDataset` + `DataLoader` construction below with the `MonetDataModule` implementation.

In [31]:
# Training data: MonetDataset with its default transform (no cust_transform is passed),
# served one (photo, monet) pair per step in shuffled order.
# NOTE(review): CONFIG['BATCH_SIZE'] and CONFIG['WORKERS'] are not used here - confirm intended.
monet_data = MonetDataset(CONFIG['DATA_ROOT'], isTrain=True)
train_loader = DataLoader(monet_data, batch_size=1, shuffle=True)

In [32]:
# trainer.fit(cycle_gan, train_loader)

## Finish Logging

In [33]:
# Experiment runs only: close the W&B run.
if not SUBMIT_NB:
    wandb.finish()
    print("Finished W&B session")

---


## Submission (eval steps, logging & zipping results)

In [34]:
# Load the 20-epoch checkpoint from the attached Kaggle dataset onto `device`.
# No networks are passed here; presumably they are restored from the hyperparameters
# stored in the checkpoint (LIT_cycle saves all of its constructor arguments).
cycle_20 = LIT_cycle.load_from_checkpoint("/kaggle/input/cycle-gan-checkpoint-20-epochs/cycle_checkpoint_20.ckpt", 
                                          map_location=device)
# Optional Lightning model summary (disabled):
# summary = ModelSummary(cycle_20, max_depth=-1)
# print(summary)



In [35]:
from torchvision import models
from torchsummary import summary
# Layer table of the loaded photo -> Monet generator for a 3x256x256 input.
summary(cycle_20.g_xy, (3,256,256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ReflectionPad2d-1          [-1, 3, 262, 262]               0
            Conv2d-2         [-1, 64, 256, 256]           9,472
    InstanceNorm2d-3         [-1, 64, 256, 256]               0
              ReLU-4         [-1, 64, 256, 256]               0
            Conv2d-5        [-1, 128, 128, 128]          73,856
    InstanceNorm2d-6        [-1, 128, 128, 128]               0
              ReLU-7        [-1, 128, 128, 128]               0
            Conv2d-8          [-1, 256, 64, 64]         295,168
    InstanceNorm2d-9          [-1, 256, 64, 64]               0
             ReLU-10          [-1, 256, 64, 64]               0
  ReflectionPad2d-11          [-1, 256, 66, 66]               0
           Conv2d-12          [-1, 256, 64, 64]         590,080
   InstanceNorm2d-13          [-1, 256, 64, 64]               0
             ReLU-14          [-1, 256,

In [36]:
# Start from a clean slate: remove any previous output folder and archive, then
# recreate the folder. `../images` is relative to the working directory; the archive
# step later reads /kaggle/images, so this assumes the cwd is /kaggle/working - TODO confirm.
! rm -rf ../images
! rm -rf /kaggle/working/images.zip
! mkdir ../images

In [37]:
# Inference preprocessing: resize to 256x256 and convert to a tensor
# (the same two steps as train_transform_album).
inf_transform = Compose([T.Resize((256,256)), ToTensor()])

In [38]:
photo_dir = "/kaggle/input/gan-getting-started/photo_jpg"

# Put the model that actually produces the predictions into eval mode
# (previously the untrained `cycle_gan` was switched instead of `cycle_20`).
cycle_20.eval()
to_pil = T.ToPILImage()

with torch.inference_mode():
    for i, img_file in enumerate(tqdm(os.listdir(photo_dir))):
        # Context manager closes the file handle; convert() guarantees 3 channels.
        with Image.open(os.path.join(photo_dir, img_file)) as photo:
            img = inf_transform(photo.convert('RGB')).to(device)
        img = img[None, :, :, :]   # add the batch dimension
        prediction = cycle_20.g_xy(img).cpu()
        # The generator ends in Tanh, so values can fall below 0, while ToPILImage expects
        # floats in [0, 1] (out-of-range values wrap around when cast to uint8).
        # The dataset feeds un-normalised ToTensor images, i.e. targets in [0, 1],
        # so clamping (rather than rescaling) is the matching post-processing.
        prediction = prediction[0].clamp(0, 1)
        
        im = to_pil(prediction).convert('RGB')
        im.save("../images/" + str(i+1) + ".jpg")   

  0%|          | 0/7038 [00:00<?, ?it/s]

In [39]:
import shutil

In [40]:
# Zip the generated images into /kaggle/working/images.zip (make_archive appends ".zip").
# The source folder is given as an absolute path while the images were saved to the
# relative "../images" - the two match only if the cwd is /kaggle/working (TODO confirm).
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")

'/kaggle/working/images.zip'