<a href="https://colab.research.google.com/github/YashDeepp/GAN-/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Put the project folder on Google Drive at the front of the module search
# path so the project-local imports in the import cell resolve.
# NOTE(review): presumably utils.py, config.py, dataset.py, generator.py and
# discriminator.py live in this folder — confirm.
import sys
sys.path.insert(0,'/content/drive/MyDrive/facades')

In [None]:
# Mount Google Drive at /content/drive; the project path added to sys.path and
# the dataset/evaluation paths used in this notebook all sit below this mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import FacadesDataset
from generator import Generator
from discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
import os

In [None]:
# Enable cuDNN's benchmark mode (see the PyTorch docs for
# torch.backends.cudnn.benchmark for its effect on convolution algorithm selection).
torch.backends.cudnn.benchmark = True

In [None]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Run one pass over ``loader``, alternating discriminator and generator steps.

    For every batch the discriminator is updated first (real pair vs. detached
    generated pair), then the generator is updated with an adversarial term
    plus an L1 term weighted by ``config.L1_LAMBDA``. Both forward passes run
    under ``torch.cuda.amp.autocast`` and both backward passes go through
    their own gradient scaler.

    Args:
        disc: Discriminator, called as ``disc(input, candidate)``.
        gen: Generator, called as ``gen(input)``.
        loader: Iterable yielding ``(input, target)`` batches.
        opt_disc: Optimizer stepped for the discriminator.
        opt_gen: Optimizer stepped for the generator.
        l1_loss: Callable ``l1_loss(prediction, target)``.
        bce: Callable ``bce(scores, labels)`` applied to the raw discriminator
            outputs.
        g_scaler: Gradient scaler used for the generator step.
        d_scaler: Gradient scaler used for the discriminator step.

    Returns:
        None. Progress is reported through the tqdm bar only.
    """
    progress = tqdm(loader, leave=True)

    for step, (source, target) in enumerate(progress):
        source, target = source.to(config.DEVICE), target.to(config.DEVICE)

        # ---- Discriminator step ----
        with torch.cuda.amp.autocast():
            generated = gen(source)
            real_scores = disc(source, target)
            real_term = bce(real_scores, torch.ones_like(real_scores))
            # Detached: this backward pass must not propagate into the generator.
            fake_scores = disc(source, generated.detach())
            fake_term = bce(fake_scores, torch.zeros_like(fake_scores))
            disc_loss = (real_term + fake_term) / 2

        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----
        with torch.cuda.amp.autocast():
            # Re-score the (non-detached) generated batch with the freshly
            # updated discriminator so gradients reach the generator.
            fake_scores = disc(source, generated)
            adversarial_term = bce(fake_scores, torch.ones_like(fake_scores))
            reconstruction_term = l1_loss(generated, target) * config.L1_LAMBDA
            gen_loss = adversarial_term + reconstruction_term

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Refresh the progress-bar statistics only on every 10th batch.
        if step % 10 != 0:
            continue
        progress.set_postfix(
            D_real=torch.sigmoid(real_scores).mean().item(),
            D_fake=torch.sigmoid(fake_scores).mean().item(),
        )

In [None]:
def main(examples_dir="/content/drive/MyDrive/facades/evaluation"):
    """Build the models, optionally restore checkpoints, and run training.

    Creates the discriminator and generator on ``config.DEVICE``, one Adam
    optimizer per network, and the BCE-with-logits and L1 losses. When
    ``config.LOAD_MODEL`` is set, both networks and their optimizers are
    restored via ``load_checkpoint``. Training then runs for
    ``config.NUM_EPOCHS`` epochs; after every epoch ``save_some_examples`` is
    called with the validation loader.

    When ``config.SAVE_MODEL`` is set, checkpoints are written on every fifth
    epoch and additionally after the final epoch, so the last trained weights
    are never lost when ``NUM_EPOCHS - 1`` is not a multiple of 5.

    Args:
        examples_dir: Folder handed to ``save_some_examples``. It is created
            up front if it does not exist. Defaults to the notebook's original
            Drive location.
    """
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    train_dataset = FacadesDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    # Separate scalers so the generator and discriminator losses are scaled
    # independently.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = FacadesDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # NOTE(review): save_some_examples presumably writes image files into this
    # folder; create it once so a missing directory cannot abort a long run.
    os.makedirs(examples_dir, exist_ok=True)

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        is_last_epoch = epoch == config.NUM_EPOCHS - 1
        if config.SAVE_MODEL and (epoch % 5 == 0 or is_last_epoch):
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder=examples_dir)


# Start training when this cell is executed (in a notebook kernel __name__ is
# "__main__"); the guard keeps main() from running if the code is imported.
if __name__ == "__main__":
    main()

=> Loading checkpoint
=> Loading checkpoint


100%|██████████| 25/25 [00:06<00:00,  3.78it/s, D_fake=0.314, D_real=0.649]
100%|██████████| 25/25 [00:05<00:00,  4.32it/s, D_fake=0.145, D_real=0.83]
100%|██████████| 25/25 [00:06<00:00,  3.79it/s, D_fake=0.104, D_real=0.905]
100%|██████████| 25/25 [00:05<00:00,  4.32it/s, D_fake=0.201, D_real=0.667]
100%|██████████| 25/25 [00:06<00:00,  3.81it/s, D_fake=0.191, D_real=0.771]
100%|██████████| 25/25 [00:05<00:00,  4.42it/s, D_fake=0.17, D_real=0.745]
100%|██████████| 25/25 [00:06<00:00,  3.81it/s, D_fake=0.202, D_real=0.751]
100%|██████████| 25/25 [00:05<00:00,  4.43it/s, D_fake=0.286, D_real=0.768]
100%|██████████| 25/25 [00:06<00:00,  3.92it/s, D_fake=0.137, D_real=0.75]
100%|██████████| 25/25 [00:05<00:00,  4.46it/s, D_fake=0.228, D_real=0.872]
100%|██████████| 25/25 [00:06<00:00,  3.82it/s, D_fake=0.12, D_real=0.825]
100%|██████████| 25/25 [00:05<00:00,  4.42it/s, D_fake=0.0909, D_real=0.978]
100%|██████████| 25/25 [00:06<00:00,  4.07it/s, D_fake=0.213, D_real=0.839]
100%|██████████

In [None]:
# Sanity check: show the entries of the training image folder on Drive.
# NOTE(review): presumably the same directory as config.TRAIN_DIR — confirm.
os.listdir('/content/drive/MyDrive/facades/data/train')

['342.jpg',
 '35.jpg',
 '356.jpg',
 '346.jpg',
 '345.jpg',
 '357.jpg',
 '326.jpg',
 '338.jpg',
 '89.jpg',
 '88.jpg',
 '54.jpg',
 '51.jpg',
 '77.jpg',
 '42.jpg',
 '74.jpg',
 '397.jpg',
 '398.jpg',
 '396.jpg',
 '399.jpg',
 '371.jpg',
 '370.jpg',
 '362.jpg',
 '361.jpg',
 '4.jpg',
 '376.jpg',
 '395.jpg',
 '394.jpg',
 '384.jpg',
 '383.jpg',
 '39.jpg',
 '340.jpg',
 '62.jpg',
 '140.jpg',
 '246.jpg',
 '203.jpg',
 '53.jpg',
 '294.jpg',
 '363.jpg',
 '120.jpg',
 '273.jpg',
 '59.jpg',
 '60.jpg',
 '8.jpg',
 '34.jpg',
 '225.jpg',
 '307.jpg',
 '95.jpg',
 '170.jpg',
 '237.jpg',
 '116.jpg',
 '167.jpg',
 '259.jpg',
 '274.jpg',
 '310.jpg',
 '280.jpg',
 '366.jpg',
 '196.jpg',
 '252.jpg',
 '318.jpg',
 '21.jpg',
 '322.jpg',
 '118.jpg',
 '114.jpg',
 '330.jpg',
 '195.jpg',
 '55.jpg',
 '302.jpg',
 '205.jpg',
 '32.jpg',
 '47.jpg',
 '156.jpg',
 '5.jpg',
 '380.jpg',
 '229.jpg',
 '61.jpg',
 '343.jpg',
 '105.jpg',
 '261.jpg',
 '241.jpg',
 '87.jpg',
 '72.jpg',
 '76.jpg',
 '235.jpg',
 '279.jpg',
 '142.jpg',
 '308.jpg