In [None]:
# Mount Google Drive (Colab only) so the project directory under /content/drive/MyDrive is reachable.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import sys

# Project root on Google Drive; the project modules imported later in this notebook
# (dataset, utils, config, discriminator_model, generator_model) are expected to live here.
PROJECT_DIR = '/content/drive/MyDrive/GAN/PixelGAN'

# Put the project first on the import path. Remove any earlier copy first so that
# re-running this cell does not keep prepending duplicate entries to sys.path.
while PROJECT_DIR in sys.path:
    sys.path.remove(PROJECT_DIR)
sys.path.insert(0, PROJECT_DIR)

In [None]:
import os

# Work from the project root on Drive so relative paths used later
# (such as the saved_images/ sample folder) resolve inside the project.
project_root = '/content/drive/MyDrive/GAN/PixelGAN'
os.chdir(project_root)
print(os.getcwd())  # confirm the working directory actually changed

/content/drive/MyDrive/GAN/PixelGAN


In [None]:
import torch
from dataset import PixelSceneryDataset
import sys
from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import config
from tqdm import tqdm
from torchvision.utils import save_image
from discriminator_model import Discriminator
from generator_model import Generator


In [None]:
def train_fn(
    disc_H,
    disc_Z,
    gen_Z,
    gen_H,
    loader,
    opt_disc,
    opt_gen,
    l1,
    mse,
    d_scaler,
    g_scaler,
    sample_dir="saved_images",
    sample_every=200,
):
    """Run one CycleGAN training epoch over ``loader``.

    Domain naming used throughout: ``H`` is the scenery domain and ``Z`` is the
    pixel domain. ``gen_H`` maps pixel -> scenery, ``gen_Z`` maps
    scenery -> pixel, ``disc_H`` scores scenery images and ``disc_Z`` scores
    pixel images.

    Each batch performs one discriminator update followed by one generator
    update, both under CUDA autocast with their own gradient scaler.

    Args:
        disc_H, disc_Z: discriminators for the scenery / pixel domains.
        gen_Z, gen_H: generators producing pixel / scenery images.
        loader: iterable yielding ``(pixel, scenery)`` batches.
        opt_disc: optimizer stepped with the discriminator loss.
        opt_gen: optimizer stepped with the generator loss.
        l1: loss used for the cycle-consistency and identity terms.
        mse: loss used for the least-squares adversarial terms (targets 1/0).
        d_scaler, g_scaler: AMP gradient scalers for the discriminator and
            generator steps respectively.
        sample_dir: directory where sample images are written; created if
            missing. Defaults to ``"saved_images"``.
        sample_every: write samples every this many batches (positive int).
            Defaults to 200.

    Side effects:
        Updates the models through the two optimizers. Every ``sample_every``
        batches it writes ``scenery_{idx}.png`` / ``pixel_{idx}.png`` into
        ``sample_dir``. The names contain only the batch index, so each epoch
        overwrites the previous epoch's samples. The tqdm bar shows running
        means of ``disc_H``'s outputs on real and fake scenery.
    """
    import os

    # save_image does not create parent directories; without this a fresh
    # checkout fails on the first sample write.
    os.makedirs(sample_dir, exist_ok=True)

    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (pixel, scenery) in enumerate(loop):
        pixel = pixel.to(config.DEVICE)
        scenery = scenery.to(config.DEVICE)

        # ---- Train discriminators H (scenery) and Z (pixel) ----
        with torch.cuda.amp.autocast():
            fake_scenery = gen_H(pixel)
            D_H_real = disc_H(scenery)
            # detach(): the discriminator loss must not backprop into gen_H.
            D_H_fake = disc_H(fake_scenery.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_pixel = gen_Z(scenery)
            D_Z_real = disc_Z(pixel)
            D_Z_fake = disc_Z(fake_pixel.detach())
            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            # put it together
            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Train generators H (pixel -> scenery) and Z (scenery -> pixel) ----
        with torch.cuda.amp.autocast():
            # adversarial loss: each generator wants its discriminator to output 1 on fakes
            D_H_fake = disc_H(fake_scenery)
            D_Z_fake = disc_Z(fake_pixel)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # cycle loss: pixel -> scenery -> pixel and scenery -> pixel -> scenery
            cycle_pixel = gen_Z(fake_scenery)
            cycle_scenery = gen_H(fake_pixel)
            cycle_pixel_loss = l1(pixel, cycle_pixel)
            cycle_scenery_loss = l1(scenery, cycle_scenery)

            # identity loss (remove these for efficiency if you set lambda_identity=0)
            identity_pixel = gen_Z(pixel)
            identity_scenery = gen_H(scenery)
            identity_pixel_loss = l1(pixel, identity_pixel)
            identity_scenery_loss = l1(scenery, identity_scenery)

            # add all together
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_pixel_loss * config.LAMBDA_CYCLE
                + cycle_scenery_loss * config.LAMBDA_CYCLE
                + identity_scenery_loss * config.LAMBDA_IDENTITY
                + identity_pixel_loss * config.LAMBDA_IDENTITY
            )

        # This backward also populates gradients on the discriminators' parameters.
        # Only opt_gen is stepped here, and opt_disc.zero_grad() runs before the
        # next discriminator backward (assuming opt_disc owns the discriminator params).
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % sample_every == 0:
            # Presumably the generators output values in [-1, 1] (tanh); map to [0, 1] for saving.
            save_image(fake_scenery * 0.5 + 0.5, f"{sample_dir}/scenery_{idx}.png")
            save_image(fake_pixel * 0.5 + 0.5, f"{sample_dir}/pixel_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))


In [None]:
def main():
    """Build the CycleGAN models, optimizers and data, then train for config.NUM_EPOCHS epochs.

    Domain naming: ``H`` is scenery, ``Z`` is pixel. ``gen_H`` produces
    scenery, ``gen_Z`` produces pixel images, and ``disc_H`` / ``disc_Z``
    judge the respective domains.

    Both generators share one Adam optimizer, and so do both discriminators.
    When ``config.LOAD_MODEL`` is set, all four checkpoints are restored before
    training. When ``config.SAVE_MODEL`` is set, all four are written after
    every epoch.
    """
    disc_H = Discriminator(in_channels=3).to(config.DEVICE)
    disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
    gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)

    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    l1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint file, model, optimizer) in the order they are loaded and saved.
    # The shared optimizers appear twice on purpose: each checkpoint file stores the
    # optimizer state alongside its model.
    checkpoints = [
        (config.CHECKPOINT_GEN_H, gen_H, opt_gen),
        (config.CHECKPOINT_GEN_Z, gen_Z, opt_gen),
        (config.CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
    ]

    if config.LOAD_MODEL:
        for path, model, optimizer in checkpoints:
            load_checkpoint(path, model, optimizer, config.LEARNING_RATE)

    dataset = PixelSceneryDataset(
        root_scenery=config.TRAIN_DIR + "/scenery",
        root_pixel=config.TRAIN_DIR + "/pixel",
        transform=config.transforms,
    )
    # NOTE(review): a validation split (hard-coded to dataset/val) used to be built
    # here but was never consumed; re-add it via a config path once validation exists.
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )

    # Separate scalers: the discriminator and generator steps use different losses.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            l1,
            mse,
            d_scaler,
            g_scaler,
        )

        if config.SAVE_MODEL:
            for path, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=path)

In [None]:
# In a notebook kernel __name__ is "__main__", so running this cell starts training via main().
is_entry_point = __name__ == "__main__"
if is_entry_point:
    main()

100%|██████████| 880/880 [03:48<00:00,  3.85it/s, H_fake=0.455, H_real=0.544]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:44<00:00,  3.93it/s, H_fake=0.434, H_real=0.565]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:44<00:00,  3.93it/s, H_fake=0.439, H_real=0.562]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.434, H_real=0.568]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.431, H_real=0.57]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.428, H_real=0.572]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.423, H_real=0.577]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.416, H_real=0.583]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:43<00:00,  3.94it/s, H_fake=0.41, H_real=0.589]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.406, H_real=0.593]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:40<00:00,  3.98it/s, H_fake=0.405, H_real=0.592]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.405, H_real=0.593]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.399, H_real=0.598]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:40<00:00,  3.98it/s, H_fake=0.399, H_real=0.598]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.398, H_real=0.6]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.398, H_real=0.601]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.397, H_real=0.601]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.395, H_real=0.604]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.391, H_real=0.607]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.388, H_real=0.61]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.385, H_real=0.613]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.384, H_real=0.613]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.389, H_real=0.608]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.381, H_real=0.616]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.373, H_real=0.626]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.375, H_real=0.624]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.378, H_real=0.616]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:43<00:00,  3.93it/s, H_fake=0.381, H_real=0.614]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:43<00:00,  3.93it/s, H_fake=0.381, H_real=0.617]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:44<00:00,  3.93it/s, H_fake=0.373, H_real=0.621]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:44<00:00,  3.93it/s, H_fake=0.374, H_real=0.624]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.373, H_real=0.624]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.373, H_real=0.624]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.378, H_real=0.619]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.375, H_real=0.619]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.37, H_real=0.627]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.373, H_real=0.626]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.372, H_real=0.623]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.373, H_real=0.622]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.366, H_real=0.63]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.372, H_real=0.626]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.375, H_real=0.622]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.373, H_real=0.621]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.379, H_real=0.62]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.374, H_real=0.622]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.373, H_real=0.622]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.98it/s, H_fake=0.377, H_real=0.618]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.388, H_real=0.607]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.383, H_real=0.616]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.369, H_real=0.626]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.378, H_real=0.622]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.385, H_real=0.61]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.389, H_real=0.606]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:41<00:00,  3.97it/s, H_fake=0.388, H_real=0.611]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.96it/s, H_fake=0.388, H_real=0.608]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.384, H_real=0.613]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.39, H_real=0.604]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 880/880 [03:42<00:00,  3.95it/s, H_fake=0.389, H_real=0.61]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


 64%|██████▎   | 560/880 [02:22<01:24,  3.81it/s, H_fake=0.386, H_real=0.612]