In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from pathlib import Path

import torch
from torch import optim
from torchvision import models
from torchvision.utils import make_grid, save_image

from __datasets__ import ITSDataset, DenseHazeCVPR2019Dataset, DomainDataset
from gan.utils import CycleGANConfig, save_checkpoint, load_checkpoint, display_images
from gan.generator import Generator
from gan.discriminator import Discriminator
from gan.trainer import PerceptualLoss, train, get_cycle_gan_trainer

In [10]:
# Hyper-parameters shared by both experiments; only the dataset root and the image size differ.
# Keeping them in a single dict stops the two configs from silently drifting apart.
shared_hparams = dict(
    latent_dim=64,
    epochs=1, batch_size=8,
    lr=2e-4,
    betas=(0.5, 0.999),   # Adam (beta1, beta2) — passed to optim.Adam below
    lambdas=(10, 0.5),    # (lambda_cycle, lambda_identity) — passed to get_cycle_gan_trainer below
    residuals=5,
    blocks=(64, 128, 256, 512),
)

# ITS (indoor training set) at 64x64.
config1 = CycleGANConfig(
    "../../commons/datasets/its/",
    "HazeGan",
    "v1",
    image_shape=(3, 64, 64),
    **shared_hparams,
)
# Dense-Haze (CVPR 2019) at 128x128.
config2 = CycleGANConfig(
    "../../commons/datasets/dense_haze_cvpr2019/",
    "HazeGan",
    "v1",
    image_shape=(3, 128, 128),
    **shared_hparams,
)

In [11]:
# Unpaired two-domain datasets: the first source is the hazy domain, the second the clean one
# (presumably DomainDataset yields them as domain A / domain B — confirm in __datasets__).
its_options = dict(download=True, img_transform=config1.transforms)
ds1 = DomainDataset(
    # sub_sample=0.2 keeps only a fraction of the ITS hazy images; the clear set is kept whole.
    ITSDataset(config1.dataset_path, SET="hazy", sub_sample=0.2, **its_options),
    ITSDataset(config1.dataset_path, SET="clear", sub_sample=1, **its_options),
).to(config1.device)

dense_options = dict(download=True, img_transform=config2.transforms, sub_sample=1)
# ds2 goes to config1.device on purpose: the networks are built from config1 and consume ds2
# in the evaluation cells, so the data must live on the same device as the models.
ds2 = DomainDataset(
    DenseHazeCVPR2019Dataset(config2.dataset_path, SET="hazy", **dense_options),
    DenseHazeCVPR2019Dataset(config2.dataset_path, SET="GT", **dense_options),
).to(config1.device)

len(ds1), len(ds2)

(2798, 55)

In [12]:
# Two generators (one per translation direction) and two discriminators, all sized from config1.
# Construction order stays generatorA, generatorB, discriminatorA, discriminatorB so random
# weight initialisation draws from the RNG in a stable sequence.
in_channels = config1.image_shape[0]
generatorA, generatorB = (
    Generator(in_channels, config1.latent_dim, config1.residuals).to(config1.device)
    for _ in range(2)
)
discriminatorA, discriminatorB = (
    Discriminator(in_channels, list(config1.blocks)).to(config1.device)
    for _ in range(2)
)

# One Adam over both generators, one over both discriminators, with shared lr/betas.
adam_kwargs = dict(lr=config1.lr, betas=config1.betas)
optimizerG = optim.Adam([*generatorA.parameters(), *generatorB.parameters()], **adam_kwargs)
optimizerD = optim.Adam([*discriminatorA.parameters(), *discriminatorB.parameters()], **adam_kwargs)

In [13]:
# Pretrained VGG19 convolutional features (first 35 layers) serve as a fixed feature extractor
# for the perceptual loss. Its parameters belong to neither optimizerG nor optimizerD, so freeze
# them: otherwise autograd still computes and accumulates .grad for every VGG weight on each
# backward pass, and nothing ever zeroes it. Freezing does not block gradients from flowing
# through VGG back to the generator outputs. (PerceptualLoss may already do this internally —
# not visible here; doing it again is harmless.)
perceptual_model = (
    models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
    .features[:35]
    .to(config1.device)
    .eval()
    .requires_grad_(False)
)
perceptual_loss = PerceptualLoss(perceptual_model)

# First 4 samples of each domain, handed to the trainer as fixed examples.
fixedA, fixedB = ds1[:4].values()
trainer = get_cycle_gan_trainer(generatorA, generatorB, discriminatorA, discriminatorB, optimizerG, optimizerD,
                                perceptual_loss=perceptual_loss, lambda_cycle=config1.lambdas[0],
                                lambda_identity=config1.lambdas[1],
                                writer=config1.writer, period=100,
                                fixedA=fixedA, fixedB=fixedB)

In [14]:
# Resume from the checkpoint the save cell writes, if one exists; otherwise start from scratch.
# The dict keys must be the same ones save_checkpoint is given ("generatorA", ..., "optimizerG",
# ...) — presumably both helpers use them as the stored entry names; confirm in gan.utils.
file_path = config1.checkpoint_path
others = None
if Path(file_path).exists():
    others = load_checkpoint(
        file_path,
        {"generatorA": generatorA, "generatorB": generatorB, "discriminatorA": discriminatorA, "discriminatorB": discriminatorB},
        {"optimizerG": optimizerG, "optimizerD": optimizerD},
    )
else:
    print(f"No checkpoint found at {file_path}; training from scratch.")

In [33]:
# The inspection cells below switch the generators to eval mode. Put every network back in
# training mode first so re-running this cell afterwards trains with training-time behaviour
# (train() is idempotent, so this is harmless if the trainer also does it).
for net in (generatorA, generatorB, discriminatorA, discriminatorB):
    net.train()

train(
    trainer, ds1,
    ne=config1.epochs, bs=config1.batch_size,
)

Epoch: 0/1 | Batch 0/350|          |  0% [00:00<?, ?it/s, loss=?]

KeyboardInterrupt: 

In [None]:
# Persist all four networks and both optimizers to config1.checkpoint_path.
checkpoint_models = {
    "generatorA": generatorA,
    "generatorB": generatorB,
    "discriminatorA": discriminatorA,
    "discriminatorB": discriminatorB,
}
checkpoint_optimizers = {"optimizerG": optimizerG, "optimizerD": optimizerD}
save_checkpoint(config1.checkpoint_path, checkpoint_models, checkpoint_optimizers)

In [37]:
_, fixedB = ds2[:24].values()
generatorA = generatorA.eval()
with torch.inference_mode():
    %timeit generatorA(fixedB)

928 ms ± 45.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%%time

fixedA, fixedB = ds2[:9].values()
with torch.inference_mode():
    grid_realA = make_grid(fixedA, nrow=1, normalize=True)
    grid_realB = make_grid(fixedB, nrow=1, normalize=True)
    grid_fakeA = make_grid(fakeA := generatorA(fixedB), nrow=1, normalize=True)
    grid_fakeB = make_grid(fakeB := generatorB(fixedA), nrow=1, normalize=True)
    grid_cycleA = make_grid(generatorA(fakeB), nrow=1, normalize=True)
    grid_cycleB = make_grid(generatorB(fakeA), nrow=1, normalize=True)
    grid_identityA = make_grid(identityA := generatorA(fixedA), nrow=1, normalize=True)
    grid_identityB = make_grid(identityB := generatorB(fixedB), nrow=1, normalize=True)
    grid_doubleA = make_grid(generatorA(identityB), nrow=1, normalize=True)
    grid_doubleB = make_grid(generatorB(identityA), nrow=1, normalize=True)
    gridA = make_grid(torch.stack([grid_realA, grid_fakeB, grid_doubleB, grid_cycleA, grid_identityA]), nrow=5, normalize=True)
    gridB = make_grid(torch.stack([grid_realB, grid_fakeA, grid_doubleA, grid_cycleB, grid_identityB]), nrow=5, normalize=True)
display_images(torch.stack([gridA, gridB]).permute(0, 2, 3, 1).cpu())

CPU times: user 1min 50s, sys: 19.6 s, total: 2min 10s
Wall time: 1min 6s


In [48]:
# Write the comparison grids (already normalized by make_grid(normalize=True)) to disk.
# Separate statements avoid echoing a meaningless "(None, None)" as the cell output.
save_image(gridA, "gridA.png")
save_image(gridB, "gridB.png")

(None, None)