In [1]:
import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import SlideDataset
from generator import Generator
from discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image

# Let cuDNN benchmark and cache the fastest convolution algorithms for the
# input shapes it sees (pays off when shapes stay fixed between iterations;
# may reduce run-to-run reproducibility).
torch.backends.cudnn.benchmark = True

In [4]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
    log_every=10,
):
    """Run one epoch of conditional-GAN training.

    For every ``(x, y)`` batch from ``loader`` the discriminator is updated
    first, on the real pair ``(x, y)`` (target 1) and the fake pair
    ``(x, gen(x))`` (target 0). The generator is then updated with an
    adversarial loss plus an L1 reconstruction loss scaled by
    ``config.L1_LAMBDA``. Both steps run under CUDA autocast, and each has
    its own gradient scaler.

    Args:
        disc: discriminator, called as ``disc(x, y)``. Its outputs are
            passed through ``torch.sigmoid`` for display, so it is expected
            to return logits.
        gen: generator, called as ``gen(x)``.
        loader: iterable yielding ``(x, y)`` tensor pairs.
        opt_disc: optimizer stepping ``disc``'s parameters.
        opt_gen: optimizer stepping ``gen``'s parameters.
        l1_loss: reconstruction loss, called as ``l1_loss(y_fake, y)``.
        bce: adversarial loss applied to the discriminator's raw outputs.
        g_scaler: GradScaler used for the generator step.
        d_scaler: GradScaler used for the discriminator step.
        log_every: refresh the progress-bar postfix every ``log_every``
            batches (must be a positive int).
    """
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # --- Discriminator step: real pair -> 1, fake pair -> 0 ---
        with torch.autocast(device_type="cuda"):
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach() so the discriminator loss does not backprop into gen.
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        # Clear disc's grads before its backward pass. This also discards the
        # gradients the previous generator step's backward left in disc.
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # --- Generator step: fool disc on (x, y_fake) + L1 towards y ---
        with torch.autocast(device_type="cuda"):
            # Re-score with the just-updated disc. y_fake is NOT detached here,
            # so gradients flow back into gen.
            D_fake_gen = disc(x, y_fake)
            G_fake_loss = bce(D_fake_gen, torch.ones_like(D_fake_gen))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % log_every == 0:
            # Both scores come from the same (pre-update) discriminator pass,
            # so the displayed real/fake numbers are directly comparable.
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )


def main(save_every=5):
    """Build models, optimizers and data loaders, then run the epoch loop.

    All hyper-parameters and paths are read from ``config``. When
    ``config.LOAD_MODEL`` is set, weights are restored via ``load_checkpoint``
    (which receives both each model and its optimizer) and training is
    skipped. Only validation examples are written each epoch.

    Args:
        save_every: when ``config.SAVE_MODEL`` is true, checkpoints are
            written every ``save_every`` epochs. The final epoch is always
            saved, so training done after the last periodic save is kept.
    """
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
        load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE)

    train_dataset = SlideDataset(root_dir=config.TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )
    # One GradScaler per loss/optimizer pair (generator and discriminator).
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = SlideDataset(root_dir=config.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    last_epoch = config.NUM_EPOCHS - 1
    for epoch in range(config.NUM_EPOCHS):
        # NOTE(review): with LOAD_MODEL set, no training happens and the same
        # loaded generator is evaluated NUM_EPOCHS times. This is presumably an
        # "inference only" mode; confirm it is not meant to resume training.
        if not config.LOAD_MODEL:
            train_fn(
                disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
            )

            # Periodic checkpoints, plus the final epoch. Otherwise up to
            # save_every - 1 trailing epochs of training would never be saved.
            if config.SAVE_MODEL and (epoch % save_every == 0 or epoch == last_epoch):
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="evaluation")



In [None]:
# Entry point. Inside a Jupyter kernel __name__ is "__main__", so executing
# this cell starts the full training/evaluation run immediately.
if __name__ == "__main__":
    main()

100%|█████████████████| 1/1 [00:08<00:00,  8.64s/it, D_fake=0.495, D_real=0.487]
100%|██████████████████| 1/1 [00:09<00:00,  9.25s/it, D_fake=0.453, D_real=0.57]
100%|██████████████████| 1/1 [00:10<00:00, 10.35s/it, D_fake=0.417, D_real=0.59]
100%|█████████████████| 1/1 [00:09<00:00,  9.54s/it, D_fake=0.362, D_real=0.582]
100%|█████████████████| 1/1 [00:09<00:00,  9.63s/it, D_fake=0.366, D_real=0.579]
100%|█████████████████| 1/1 [00:08<00:00,  8.80s/it, D_fake=0.337, D_real=0.635]
100%|█████████████████| 1/1 [00:09<00:00,  9.28s/it, D_fake=0.323, D_real=0.642]
100%|█████████████████| 1/1 [00:09<00:00,  9.58s/it, D_fake=0.268, D_real=0.679]
100%|█████████████████| 1/1 [00:09<00:00,  9.65s/it, D_fake=0.261, D_real=0.691]
100%|█████████████████| 1/1 [00:08<00:00,  8.82s/it, D_fake=0.242, D_real=0.714]
100%|█████████████████| 1/1 [00:08<00:00,  8.80s/it, D_fake=0.221, D_real=0.739]
100%|█████████████████| 1/1 [00:08<00:00,  8.76s/it, D_fake=0.247, D_real=0.732]
100%|█████████████████| 1/1 

100%|█████████████████| 1/1 [00:10<00:00, 10.14s/it, D_fake=0.501, D_real=0.501]
100%|█████████████████| 1/1 [00:10<00:00, 10.19s/it, D_fake=0.497, D_real=0.498]
100%|█████████████████| 1/1 [00:09<00:00,  9.94s/it, D_fake=0.497, D_real=0.509]
100%|█████████████████| 1/1 [00:10<00:00, 10.02s/it, D_fake=0.498, D_real=0.499]
100%|█████████████████| 1/1 [00:11<00:00, 11.41s/it, D_fake=0.492, D_real=0.503]
100%|█████████████████| 1/1 [00:10<00:00, 10.89s/it, D_fake=0.504, D_real=0.495]
100%|█████████████████| 1/1 [00:10<00:00, 10.74s/it, D_fake=0.497, D_real=0.503]
100%|█████████████████| 1/1 [00:10<00:00, 10.13s/it, D_fake=0.499, D_real=0.508]
100%|█████████████████| 1/1 [00:10<00:00, 10.30s/it, D_fake=0.499, D_real=0.498]
100%|███████████████████| 1/1 [00:09<00:00,  9.91s/it, D_fake=0.494, D_real=0.5]
100%|███████████████████| 1/1 [00:10<00:00, 10.63s/it, D_fake=0.502, D_real=0.5]
100%|█████████████████| 1/1 [00:10<00:00, 10.47s/it, D_fake=0.495, D_real=0.514]
100%|██████████████████| 1/1

100%|█████████████████| 1/1 [00:09<00:00,  9.07s/it, D_fake=0.479, D_real=0.536]
100%|██████████████████| 1/1 [00:08<00:00,  8.99s/it, D_fake=0.494, D_real=0.49]
100%|██████████████████| 1/1 [00:09<00:00,  9.03s/it, D_fake=0.514, D_real=0.51]
100%|█████████████████| 1/1 [00:09<00:00,  9.07s/it, D_fake=0.461, D_real=0.528]
100%|█████████████████| 1/1 [00:09<00:00,  9.24s/it, D_fake=0.541, D_real=0.463]
100%|██████████████████| 1/1 [00:09<00:00,  9.09s/it, D_fake=0.473, D_real=0.55]
100%|█████████████████| 1/1 [00:09<00:00,  9.08s/it, D_fake=0.484, D_real=0.487]
100%|█████████████████| 1/1 [00:09<00:00,  9.04s/it, D_fake=0.505, D_real=0.491]
100%|█████████████████| 1/1 [00:09<00:00,  9.06s/it, D_fake=0.502, D_real=0.501]
100%|██████████████████| 1/1 [00:09<00:00,  9.08s/it, D_fake=0.49, D_real=0.517]
100%|█████████████████| 1/1 [00:09<00:00,  9.02s/it, D_fake=0.501, D_real=0.496]
100%|█████████████████| 1/1 [00:09<00:00,  9.04s/it, D_fake=0.465, D_real=0.523]
100%|█████████████████| 1/1 