In [None]:
# Shell escape: ask the NVIDIA driver which GPUs are visible to this runtime
# before any training code runs.
!nvidia-smi

## Notebook Setup

In [None]:
# Re-import modified modules automatically before each cell executes, so edits
# to the local .py files (cyclegan, dataset, utils) are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Imports

In [None]:
import os
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image

from cyclegan import CycleGenerator, CycleDiscriminator
from dataset.reside import ResideDataset
from utils.checkpoints import save_checkpoint

## Hyperparameters

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
## Hyperparameters
LR = 2e-4  # learning rate shared by both Adam optimizers
BATCH_SIZE = 8  # batch size used by the train and test DataLoaders
NUM_EPOCHS = 12  # number of passes over the training loader made by main()
LAMBDA_CYCLE = 10  # weight of the cycle loss in the generator objective
LAMBDA_ID = 0.5 * LAMBDA_CYCLE  # weight of the identity loss in the generator objective

SAVE_MODEL = True  # when True, main() writes checkpoints after every epoch

## Dataset locations
# NOTE(review): TRAIN_DIR is a relative path while TEST_DIR starts with "/"
# (absolute) — confirm both resolve correctly in the runtime being used.
TRAIN_DIR = "indoor-training-set-its-residestandard"
TEST_DIR = "/synthetic-objective-testing-set-sots-reside/indoor"

## Discriminators / Generators / Optimizers / Losses

In [None]:
# One discriminator per image domain (H and C), moved to the selected device.
disc_H = CycleDiscriminator().to(device)
disc_C = CycleDiscriminator().to(device)

# One generator per image domain, each built with 9 residual blocks.
gen_H = CycleGenerator(num_residuals=9).to(device)
gen_C = CycleGenerator(num_residuals=9).to(device)

# A single Adam instance steps both discriminators together.
optimizer_D = optim.Adam(
    [*disc_H.parameters(), *disc_C.parameters()],
    lr=LR,
    betas=(0.5, 0.999),
)

# Likewise, a single Adam instance steps both generators together.
optimizer_G = optim.Adam(
    [*gen_H.parameters(), *gen_C.parameters()],
    lr=LR,
    betas=(0.5, 0.999),
)

# Criteria handed to train(): L1 for cycle/identity terms, MSE for adversarial terms.
L1 = nn.L1Loss()
MSE = nn.MSELoss()

## Dataset / Dataloader

In [None]:
# Training pipeline: resize to 128x128, random horizontal flip for augmentation,
# then convert to a tensor and normalise each channel with mean 0.5 / std 0.5
# (train() undoes this with `* 0.5 + 0.5` before saving sample images).
train_transforms = T.Compose([
    T.Resize((128, 128)),
    T.RandomHorizontalFlip(0.5),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Test pipeline: identical preprocessing but without the random flip.
test_transforms = T.Compose([
    T.Resize((128, 128)),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_set = ResideDataset(TRAIN_DIR, img_transform=train_transforms)
test_set = ResideDataset(TEST_DIR, img_transform=test_transforms)

# Shuffle only the training data. The test loader keeps a fixed order so that
# evaluation batches are the same from run to run.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

## Training function (one pass over the loader)

In [None]:
def train(disc_h, disc_c, gen_h, gen_c, loader, opt_disc, opt_gen, l1, mse, perceptual=None):
    """Run one pass over ``loader``, alternating discriminator and generator updates.

    "H" and "C" denote the two image domains (sample images are written as
    ``haze_*.png`` and ``clear_*.png`` respectively). ``gen_h`` produces
    H-domain images and ``gen_c`` produces C-domain images; ``disc_h`` and
    ``disc_c`` score images of the matching domain.

    Args:
        disc_h, disc_c: Discriminators for the H and C domains.
        gen_h, gen_c: Generators whose outputs belong to the H and C domains.
        loader: Iterable yielding ``(realH, realC)`` batches.
        opt_disc: Optimizer stepping the discriminator parameters.
        opt_gen: Optimizer stepping the generator parameters.
        l1: Callable ``(prediction, target) -> loss`` used for the cycle and
            identity terms.
        mse: Callable ``(prediction, target) -> loss`` used for the
            adversarial terms.
        perceptual: Optional callable ``(prediction, target) -> loss``. When
            given, its value on the identity outputs is added (unweighted) to
            the generator objective; when ``None`` the term is zero.

    Side effects:
        Reads the module-level ``device``, ``LAMBDA_CYCLE`` and ``LAMBDA_ID``.
        Every 500th batch it prints the current losses and writes
        ``outputs/haze_{idx}.png`` / ``outputs/clear_{idx}.png``; ``idx``
        restarts at 0 on every call, so later calls overwrite earlier files.
    """
    loop = tqdm(loader, leave=True)

    for idx, (realH, realC) in enumerate(loop):
        realH, realC = realH.to(device), realC.to(device)

        # Cross-domain translation: each generator receives an image from the
        # *other* domain and must produce one in its own domain.
        fakeH, fakeC = gen_h(realC), gen_c(realH)
        # Round trip: translating the fake back must recover the original.
        backH, backC = gen_h(fakeC), gen_c(fakeH)
        # Identity mapping: an image already in the target domain should be
        # left unchanged by that domain's generator.
        sameH, sameC = gen_h(realH), gen_c(realC)

        # == Discriminator Loss ==
        # Fakes are detached so this update does not reach the generators.
        pred_realH, pred_realC = disc_h(realH), disc_c(realC)
        pred_fakeH_true, pred_fakeC_true = disc_h(fakeH.detach()), disc_c(fakeC.detach())

        loss_adversarialDH = (mse(pred_realH, torch.ones_like(pred_realH)) +
                              mse(pred_fakeH_true, torch.zeros_like(pred_fakeH_true)))
        loss_adversarialDC = (mse(pred_realC, torch.ones_like(pred_realC)) +
                              mse(pred_fakeC_true, torch.zeros_like(pred_fakeC_true)))

        lossD = (loss_adversarialDH + loss_adversarialDC) / 2

        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # == Generator Loss ==
        # Re-score the (non-detached) fakes with the freshly updated
        # discriminators so gradients flow back into the generators.
        pred_fakeH_false, pred_fakeC_false = disc_h(fakeH), disc_c(fakeC)

        # Adversarial Loss: generators want their fakes to be scored as real.
        loss_adv_H = mse(pred_fakeH_false, torch.ones_like(pred_fakeH_false))
        loss_adv_C = mse(pred_fakeC_false, torch.ones_like(pred_fakeC_false))
        loss_adv_G = loss_adv_H + loss_adv_C

        # Cycle Loss: the round trip is compared against the real input.
        loss_cycleH = l1(backH, realH)
        loss_cycleC = l1(backC, realC)
        loss_cycle_G = (loss_cycleH + loss_cycleC) / 2

        # Identity Loss
        loss_identity_H = l1(sameH, realH)
        loss_identity_C = l1(sameC, realC)
        loss_identity_G = (loss_identity_H + loss_identity_C) / 2

        # Perceptual Loss (optional). `is not None` avoids relying on the
        # truthiness of the callable object.
        if perceptual is not None:
            loss_perceptual_H = perceptual(sameH, realH)
            loss_perceptual_C = perceptual(sameC, realC)
            loss_perceptual_G = loss_perceptual_H + loss_perceptual_C
        else:
            # Float scalar on the batch device so it adds cleanly to lossG.
            loss_perceptual_G = torch.zeros((), device=realH.device)

        # Total Loss. The perceptual term is added unweighted because no
        # weight constant is defined for it; it is exactly zero when no
        # perceptual criterion is supplied.
        lossG = (loss_adv_G
                 + LAMBDA_CYCLE * loss_cycle_G
                 + LAMBDA_ID * loss_identity_G
                 + loss_perceptual_G)

        # Backprop
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if idx % 500 == 0:
            # Only needed for logging, so compute it here rather than every batch.
            loss_total = (lossD + lossG).item()

            print("--------------------------------------------------")
            print(f"Loss/ Total: {loss_total:.4f}")
            print(f"Loss/ Adversarial_Disc: {lossD.item():.4f}")
            print(f"Loss/ Adversarial_Gen: {loss_adv_G.item():.4f}")
            print(f"Loss/ Cycle: {loss_cycle_G.item():.4f}")
            print(f"Loss/ Identity: {loss_identity_G.item():.4f}")
            print(f"Loss/ Perceptual: {loss_perceptual_G.item():.4f}")
            print("--------------------------------------------------")

            os.makedirs("outputs", exist_ok=True)
            # `* 0.5 + 0.5` undoes the mean-0.5 / std-0.5 normalisation.
            save_image(fakeH.detach() * 0.5 + 0.5, f"outputs/haze_{idx}.png")
            save_image(fakeC.detach() * 0.5 + 0.5, f"outputs/clear_{idx}.png")
## Main training loop (epochs and checkpointing)

In [None]:
def main():
    """Train for ``NUM_EPOCHS`` epochs, checkpointing after each when ``SAVE_MODEL`` is set.

    Each epoch makes one ``train`` call over ``train_loader`` using the
    module-level networks, optimizers and loss criteria. The checkpoint file
    names are fixed, so every epoch overwrites the previous epoch's files.
    """
    # (network, optimizer that updates it, destination file), in save order.
    checkpoint_plan = (
        (gen_H, optimizer_G, "genH.pth.tar"),
        (gen_C, optimizer_G, "genC.pth.tar"),
        (disc_H, optimizer_D, "criticH.pth.tar"),
        (disc_C, optimizer_D, "criticC.pth.tar"),
    )

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch: [{epoch}/{NUM_EPOCHS}]")
        train(disc_H, disc_C, gen_H, gen_C, train_loader, optimizer_D, optimizer_G, L1, MSE)

        if not SAVE_MODEL:
            continue

        for network, optimizer, filename in checkpoint_plan:
            save_checkpoint(network, optimizer, file=filename)

## Model Training

In [None]:
# Start the full training run: NUM_EPOCHS epochs, with checkpoints written
# after each epoch when SAVE_MODEL is True.
main()