In [3]:
import config
from models.generator import Generator
from models.discriminator import Discriminator
from utils import save_checkpoint, load_checkpoint, save_some_examples
from dataset import MapDataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

import numpy as np 
from tqdm import tqdm
import matplotlib.pyplot as plt 

In [None]:
status = torch.cuda.is_available()
print(status)

In [4]:
disc = Discriminator(in_channels=3).to(config.DEVICE)
gen = Generator(in_channels=3).to(config.DEVICE)

opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(config.BETA_1, config.BETA_2))
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(config.BETA_1, config.BETA_2))

BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

In [5]:
if config.LOAD_MODEL:
    try:
        
        load_checkpoint(config.GENERATOR_CHECKPOINTS, gen, opt_gen, config.LEARNING_RATE)
        load_checkpoint(config.DISCRIMINATOR_CHECKPOINTS, disc, opt_disc, config.LEARNING_RATE)
    except:
        print('weights not found')

=> Loading checkpoint
weights not found


In [6]:
root_dir = config.ROOT_DIR
dataset = MapDataset(root_dir)

val_size = 100
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()


In [7]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        for _ in range(2): 
            with torch.cuda.amp.autocast():
                y_fake = gen(x)  
                D_fake = disc(x, y_fake)
                G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
                L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
                G_loss = G_fake_loss + L1

            opt_gen.zero_grad()
            g_scaler.scale(G_loss).backward()
            g_scaler.step(opt_gen)
            g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [8]:

for epoch in range(config.EPOCHS):
    train_fn(
        disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
    )
    
    if config.SAVE_MODEL and epoch % 10 == 0 :
        save_checkpoint(gen, opt_gen, filename=config.GENERATOR_CHECKPOINTS)
        save_checkpoint(disc, opt_disc, filename=config.DISCRIMINATOR_CHECKPOINTS)
    
    save_some_examples(gen, val_loader, epoch, folder='generatedImages')

100%|██████████| 44/44 [00:33<00:00,  1.31it/s, D_fake=0.518, D_real=0.581]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.514, D_real=0.582]
100%|██████████| 44/44 [00:31<00:00,  1.40it/s, D_fake=0.523, D_real=0.564]
100%|██████████| 44/44 [00:30<00:00,  1.45it/s, D_fake=0.519, D_real=0.566]
100%|██████████| 44/44 [00:30<00:00,  1.45it/s, D_fake=0.525, D_real=0.531]
100%|██████████| 44/44 [00:31<00:00,  1.40it/s, D_fake=0.507, D_real=0.575]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.522, D_real=0.576]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.537, D_real=0.578]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.505, D_real=0.578]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.511, D_real=0.579]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.503, D_real=0.58] 


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.501, D_real=0.579]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.582]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.583]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.582]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.586]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.586]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.587]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.583]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.588]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.504, D_real=0.589]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.51, D_real=0.588] 
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.5, D_real=0.591]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.592]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.591]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.509, D_real=0.593]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.51, D_real=0.573] 
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.594]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.511, D_real=0.595]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.5, D_real=0.597]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.504, D_real=0.598]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.502, D_real=0.598]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.499, D_real=0.599]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.6]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.599]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.601]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.499, D_real=0.601]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.6]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.602]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.501, D_real=0.601]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.603]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.502, D_real=0.605]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.5, D_real=0.607]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.607]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.6]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.518, D_real=0.604]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.606]
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.604]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.603]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.61] 
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.61] 


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.609] 
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.499, D_real=0.612]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.5, D_real=0.613]  
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.5, D_real=0.612]  
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.499, D_real=0.614]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.518, D_real=0.59] 
100%|██████████| 44/44 [00:32<00:00,  1.36it/s, D_fake=0.5, D_real=0.613]  
100%|██████████| 44/44 [00:32<00:00,  1.36it/s, D_fake=0.499, D_real=0.614]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.499, D_real=0.616]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.499, D_real=0.616]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.499, D_real=0.618]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.499, D_real=0.618]
100%|██████████| 44/44 [00:32<00:00,  1.37it/s, D_fake=0.5, D_real=0.619]  
100%|██████████| 44/44 [00:32<00:00,  1.36it/s, D_fake=0.499, D_real=0.619]
100%|██████████| 44/44 [00:31<00:00,  1.39it/s, D_fake=0.5, D_real=0.618]  
100%|██████████| 44/44 [00:31<00:00,  1.38it/s, D_fake=0.5, D_real=0.619]  
100%|██████████| 44/44 [00:30<00:00,  1.46it/s, D_fake=0.5, D_real=0.621]  
100%|██████████| 44/44 [00:30<00:00,  1.46it/s, D_fake=0.5, D_real=0.622]  
 43%|████▎     | 19/44 [00:16<00:13,  1.81it/s, D_fake=0.5, D_real=0.621]