In [2]:
import albumentations
import torch
from dataset import vissketchDataset
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator

def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (sketch, vis) in enumerate(loop):
        sketch = sketch.to(config.DEVICE)
        vis = vis.to(config.DEVICE)

        # Train Discriminators H and Z
        with torch.cuda.amp.autocast():
            fake_vis = gen_H(sketch)
            D_H_real = disc_H(vis)
            D_H_fake = disc_H(fake_vis.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_sketch = gen_Z(vis)
            D_Z_real = disc_Z(sketch)
            D_Z_fake = disc_Z(fake_sketch.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it togethor
            D_loss = (D_H_loss + D_Z_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train Generators H and Z
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_vis)
            D_Z_fake = disc_Z(fake_sketch)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss
            cycle_sketch = gen_Z(fake_vis)
            cycle_vis = gen_H(fake_sketch)
            cycle_sketch_loss = l1(sketch, cycle_sketch)
            cycle_vis_loss = l1(vis, cycle_vis)

            # identity loss (remove these for efficiency if you set lambda_identity=0)
            identity_sketch = gen_Z(sketch)
            identity_vis = gen_H(vis)
            identity_sketch_loss = l1(sketch, identity_sketch)
            identity_vis_loss = l1(vis, identity_vis)

            # add all togethor
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_sketch_loss * config.LAMBDA_CYCLE
                + cycle_vis_loss * config.LAMBDA_CYCLE
                + identity_vis_loss * config.LAMBDA_IDENTITY
                + identity_sketch_loss * config.LAMBDA_IDENTITY
            )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 25 == 0:
            save_image(fake_vis*0.5+0.5, f"saved/i/{idx}.png")
            save_image(fake_sketch*0.5+0.5, f"saved/s/{idx}.png")

        loop.set_postfix(H_real=H_reals/(idx+1), H_fake=H_fakes/(idx+1))



def main():
    disc_H = Discriminator(in_channels=3).to(config.DEVICE)
    disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN_H, gen_H, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_GEN_Z, gen_Z, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_CRITIC_H, disc_H, opt_disc, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc, config.LEARNING_RATE,
        )

    dataset = vissketchDataset(
        root_vis=config.TRAIN_DIR+"/train_sketches", root_sketch=config.TRAIN_DIR+"/train_photos", transform=config.transforms
    )
    val_dataset = vissketchDataset(
       root_vis=config.VAL_DIR+"/test_sketches", root_sketch=config.VAL_DIR+"/test_photos", transform=config.transforms
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        print(epoch+1,"/",config.NUM_EPOCHS, "-----batch_size:", config.BATCH_SIZE, )
        
        train_fn(disc_H, disc_Z, gen_Z, gen_H, val_loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)

        if config.SAVE_MODEL:
            save_checkpoint(gen_H, opt_gen, filename=config.CHECKPOINT_GEN_H)
            save_checkpoint(gen_Z, opt_gen, filename=config.CHECKPOINT_GEN_Z)
            save_checkpoint(disc_H, opt_disc, filename=config.CHECKPOINT_CRITIC_H)
            save_checkpoint(disc_Z, opt_disc, filename=config.CHECKPOINT_CRITIC_Z)



In [3]:
torch.cuda.is_available()

True

In [None]:
if __name__ == "__main__":
    main()

1 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [02:53<00:00,  1.04it/s, H_fake=0.284, H_real=0.619]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
2 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:09<00:00,  1.05s/it, H_fake=0.333, H_real=0.609]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
3 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:24<00:00,  1.13s/it, H_fake=0.345, H_real=0.616]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
4 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:28<00:00,  1.16s/it, H_fake=0.298, H_real=0.694]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
5 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:31<00:00,  1.17s/it, H_fake=0.281, H_real=0.707]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
6 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:34<00:00,  1.19s/it, H_fake=0.241, H_real=0.739]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
7 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:35<00:00,  1.20s/it, H_fake=0.223, H_real=0.773]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
8 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:39<00:00,  1.22s/it, H_fake=0.197, H_real=0.802]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
9 / 100 -----batch_size: 1


100%|██████████████████████████████████████████████████████| 180/180 [03:38<00:00,  1.22s/it, H_fake=0.189, H_real=0.8]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
10 / 100 -----batch_size: 1


100%|█████████████████████████████████████████████████████| 180/180 [03:36<00:00,  1.20s/it, H_fake=0.19, H_real=0.803]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
11 / 100 -----batch_size: 1


100%|█████████████████████████████████████████████████████| 180/180 [03:35<00:00,  1.20s/it, H_fake=0.17, H_real=0.829]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
12 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:37<00:00,  1.21s/it, H_fake=0.169, H_real=0.828]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
13 / 100 -----batch_size: 1


100%|█████████████████████████████████████████████████████| 180/180 [03:37<00:00,  1.21s/it, H_fake=0.17, H_real=0.826]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
14 / 100 -----batch_size: 1


100%|████████████████████████████████████████████████████| 180/180 [03:39<00:00,  1.22s/it, H_fake=0.172, H_real=0.825]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
15 / 100 -----batch_size: 1


 43%|██████████████████████▋                              | 77/180 [01:33<02:06,  1.23s/it, H_fake=0.139, H_real=0.861]