In [None]:
# Reload edited project modules automatically before each cell is executed,
# so changes to the local `gan` / `utils` packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [1]:
from pathlib import Path

import torch
from torch import optim
from torchvision import models
from torchvision.utils import make_grid, save_image

from __datasets__ import ITSDataset, DenseHazeCVPR2019Dataset
from gan import CycleGANConfig, Generator, Discriminator, PerceptualLoss, get_cycle_gan_trainer
from utils.checkpoints import load_checkpoint
from utils.display import display_images
from utils.train_test import train, test
from utils.datasets import DomainDataset

In [2]:
# Hyper-parameters common to both experiments. Only the dataset location and
# the image resolution differ between the two configurations below.
_shared_hparams = dict(
    latent_dim=64,
    dropout=0.3,
    num_epochs=1,
    batch_size=8,
    lr=2e-4,
    betas=(0.5, 0.999),
    lambdas=(10, 0.5),  # read later as (lambda_cycle, lambda_identity)
    residuals=5,
    blocks=(64, 128, 256, 512),
    writer=True,
)

# Positional arguments: dataset location followed by two string labels
# (presumably model name and version - confirm against CycleGANConfig).
config1 = CycleGANConfig(
    "../../commons/datasets/its/",
    "HazeGan",
    "v1",
    image_shape=(3, 64, 64),
    **_shared_hparams,
)
config2 = CycleGANConfig(
    "../../commons/datasets/dense_haze_cvpr2019/",
    "HazeGan",
    "v1",
    image_shape=(3, 128, 128),
    **_shared_hparams,
)

In [3]:
# Each DomainDataset pairs two image sets built from the same dataset class.
# The (SET, sub_sample) pairs are listed in the order they are handed to
# DomainDataset: first domain, then second domain.
ds1 = DomainDataset(*[
    ITSDataset(
        config1.dataset_path,
        SET=subset,
        download=True,
        image_transform=config1.transforms,
        sub_sample=fraction,
    )
    for subset, fraction in (("hazy", 0.2), ("clear", 1))
])
ds2 = DomainDataset(*[
    DenseHazeCVPR2019Dataset(
        config2.dataset_path,
        SET=subset,
        download=True,
        image_transform=config2.transforms,
        sub_sample=1,
    )
    for subset in ("hazy", "GT")
])
# Show the size of both paired datasets as the cell output.
len(ds1), len(ds2)

(2798, 55)

In [4]:
def _new_generator():
    """Build one generator from the config1 hyper-parameters and move it to config1.device."""
    net = Generator(
        config1.image_shape[0],
        config1.latent_dim,
        config1.residuals,
        p=config1.dropout,
        coder_len=config1.coder_len,
    )
    return net.to(config1.device)


def _new_discriminator():
    """Build one discriminator from the config1 hyper-parameters and move it to config1.device."""
    net = Discriminator(config1.image_shape[0], list(config1.blocks), p=config1.dropout)
    return net.to(config1.device)


# Construction order (generatorA, generatorB, discriminatorA, discriminatorB)
# is kept so that weight initialisation consumes the RNG in the same sequence.
generatorA = _new_generator()
generatorB = _new_generator()
discriminatorA = _new_discriminator()
discriminatorB = _new_discriminator()

# One Adam optimizer per network pair, sharing learning rate and betas.
optimizerG = optim.Adam(
    [*generatorA.parameters(), *generatorB.parameters()],
    lr=config1.lr,
    betas=config1.betas,
)
optimizerD = optim.Adam(
    [*discriminatorA.parameters(), *discriminatorB.parameters()],
    lr=config1.lr,
    betas=config1.betas,
)

In [5]:
# Feature extractor for the perceptual loss: the first 35 layers of the
# ImageNet-pretrained VGG-19 `features` stack, placed on the training device.
vgg_weights = models.VGG19_Weights.IMAGENET1K_V1
perceptual_model = models.vgg19(weights=vgg_weights).features[:35].to(config1.device)
perceptual_loss = PerceptualLoss(perceptual_model)

# A fixed batch taken from the first four entries of ds1; its "image" tensors
# are handed to the trainer as fixedA / fixedB.
fixedA, fixedB = ds1[:4].values()

trainer = get_cycle_gan_trainer(
    generatorA,
    generatorB,
    discriminatorA,
    discriminatorB,
    optimizerG,
    optimizerD,
    save_path=config1.checkpoint_path,
    perceptual_loss=perceptual_loss,
    lambda_cycle=config1.lambdas[0],
    lambda_identity=config1.lambdas[1],
    writer=config1.writer,
    period=100,
    fixedA=fixedA["image"],
    fixedB=fixedB["image"],
)

In [6]:
# Global step counter: passed to train() as `step_offset` and overwritten with its return value.
step = 0

In [None]:
# Optional: resume from a saved checkpoint.
# `file_path` is a placeholder - point it at a real checkpoint file to resume.
# While it does not name an existing file the cell is a no-op, so a full
# "Restart & Run All" continues into training from the current `step`
# instead of stopping here.
file_path = ".pt"
if Path(file_path).is_file():
    others = load_checkpoint(
        file_path,
        {
            "generatorA": generatorA,
            "generatorB": generatorB,
            "discriminatorA": discriminatorA,
            "discriminatorB": discriminatorB,
        },
        {"optimizerG": optimizerG, "optimizerD": optimizerD},
    )
    # Continue counting from the step stored in the checkpoint.
    step = others["step"]
else:
    print(f"Checkpoint {file_path!r} not found - nothing loaded.")

In [7]:
# Train on ds1 for the configured number of epochs and batch size.
# `step` goes in as the offset and is replaced by the value train() returns,
# so re-running this cell continues from where the previous run stopped.
step = train(trainer, ds1, ne=config1.num_epochs, bs=config1.batch_size, step_offset=step)

Epoch: 0/1 | Batch 0/350|          |  0% [00:00<?, ?it/s, loss=?]

KeyboardInterrupt: 

In [8]:
_, fixedB = ds2[:24].values()
generatorA = generatorA.eval()
with torch.inference_mode():
    %timeit generatorA(fixedB["image"])

4.65 s ± 268 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%%time

fixedA, fixedB = ds2[:9].values()
with torch.inference_mode():
    grid_realA = make_grid(fixedA, nrow=1, normalize=True)
    grid_realB = make_grid(fixedB, nrow=1, normalize=True)
    grid_fakeA = make_grid(fakeA := generatorA(fixedB), nrow=1, normalize=True)
    grid_fakeB = make_grid(fakeB := generatorB(fixedA), nrow=1, normalize=True)
    grid_cycleA = make_grid(generatorA(fakeB), nrow=1, normalize=True)
    grid_cycleB = make_grid(generatorB(fakeA), nrow=1, normalize=True)
    grid_identityA = make_grid(identityA := generatorA(fixedA), nrow=1, normalize=True)
    grid_identityB = make_grid(identityB := generatorB(fixedB), nrow=1, normalize=True)
    grid_doubleA = make_grid(generatorA(identityB), nrow=1, normalize=True)
    grid_doubleB = make_grid(generatorB(identityA), nrow=1, normalize=True)
    gridA = make_grid(torch.stack([grid_realA, grid_fakeB, grid_doubleB, grid_cycleA, grid_identityA]), nrow=5, normalize=True)
    gridB = make_grid(torch.stack([grid_realB, grid_fakeA, grid_doubleA, grid_cycleB, grid_identityB]), nrow=5, normalize=True)
display_images(torch.stack([gridA, gridB]).permute(0, 2, 3, 1).cpu())

In [None]:
# Write both comparison grids to PNG files in the working directory.
# Two separate statements (rather than a tuple expression) keep the cell from
# echoing a meaningless `(None, None)` as its output.
save_image(gridA, "gridA.png")
save_image(gridB, "gridB.png")