In [7]:
import torch
from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator_model import Generator
from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
# Let cuDNN benchmark and cache the fastest convolution algorithm per input shape.
# This only pays off when batch/image shapes stay constant across iterations
# (presumably true for this dataset's transforms — TODO confirm).
torch.backends.cudnn.benchmark = True
from torch.utils.data import Subset

In [11]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Run one epoch of conditional-GAN (pix2pix-style) training.

    For every ``(x, y)`` batch the discriminator is updated first on the real
    pair ``(x, y)`` vs. the fake pair ``(x, gen(x))``. The generator is then
    updated with an adversarial BCE term plus an L1 reconstruction term
    weighted by ``config.L1_LAMBDA``. Both forward passes run under CUDA
    autocast, and each network has its own GradScaler.

    Args:
        disc: Discriminator, called as ``disc(x, y)``; its output is fed to
            ``bce`` and ``torch.sigmoid``, so it is expected to return logits.
        gen: Generator, called as ``gen(x)``.
        loader: Iterable yielding ``(x, y)`` tensor pairs (input, target).
        opt_disc: Optimizer stepping the discriminator parameters.
        opt_gen: Optimizer stepping the generator parameters.
        l1_loss: Reconstruction loss applied to ``(gen(x), y)``.
        bce: Adversarial loss applied to discriminator logits.
        g_scaler: GradScaler used for the generator update.
        d_scaler: GradScaler used for the discriminator update.

    Returns:
        ``(avg_D_loss, avg_G_loss)``: mean discriminator / generator loss over
        the batches actually processed, or ``(0.0, 0.0)`` if ``loader`` yielded
        no batches.
    """
    loop = tqdm(loader, leave=True)

    D_loss_total = 0.0
    G_loss_total = 0.0
    # Count batches as they are consumed instead of trusting len(loader): this
    # avoids a ZeroDivisionError on an empty loader and works for loaders
    # that have no __len__.
    num_batches = 0

    for idx, (x, y) in enumerate(loop):
        x = x.cuda()
        y = y.cuda()

        # Train Discriminator
        with torch.autocast("cuda"):
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach(): the discriminator update must not backprop into gen.
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        # The generator backward below also writes .grad on disc parameters,
        # so clearing disc grads here (not just after the step) matters.
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.autocast("cuda"):
            # Re-score the (non-detached) fake with the just-updated disc.
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Accumulate the losses
        D_loss_total += D_loss.item()
        G_loss_total += G_loss.item()
        num_batches += 1

        if idx % 10 == 0:
            # D_fake here is the post-update discriminator's score on the fake.
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

    if num_batches == 0:
        return 0.0, 0.0

    # Return average losses for the epoch
    avg_D_loss = D_loss_total / num_batches
    avg_G_loss = G_loss_total / num_batches
    return avg_D_loss, avg_G_loss


In [12]:
def main(num_train=2000, num_val=400, save_every=5):
    """Build models, optimizers and data loaders, then train for config.NUM_EPOCHS.

    Args:
        num_train: Maximum number of training items, taken from the start of
            ``MapDataset(config.TRAIN_DIR)`` in dataset order (not a random
            sample). Clamped to the dataset length.
        num_val: Maximum number of validation items from
            ``MapDataset(config.VAL_DIR)``, also clamped. They are used only
            by ``save_some_examples``.
        save_every: When ``config.SAVE_MODEL`` is set, checkpoints are written
            every ``save_every`` epochs (values <= 0 disable the periodic
            save). The final epoch is always saved.
    """
    disc = Discriminator(in_channels=3).cuda()
    gen = Generator(in_channels=3, features=64).cuda()
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
        load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE)

    full_train_dataset = MapDataset(config.TRAIN_DIR)
    # Subset indexes lazily, so an out-of-range index would only fail mid-epoch;
    # clamp the requested size to what the dataset actually holds.
    train_indices = list(range(min(num_train, len(full_train_dataset))))
    train_dataset = Subset(full_train_dataset, train_indices)
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
    g_scaler = torch.GradScaler('cuda')  # Specify device explicitly
    d_scaler = torch.GradScaler('cuda')  # Specify device explicitly

    full_val_dataset = MapDataset(config.VAL_DIR)
    val_indices = list(range(min(num_val, len(full_val_dataset))))
    val_dataset = Subset(full_val_dataset, val_indices)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(config.NUM_EPOCHS):
        avg_D_loss, avg_G_loss = train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler
        )

        # Print the losses at the end of each epoch
        print(f"Epoch {epoch}/{config.NUM_EPOCHS} - D Loss: {avg_D_loss:.4f}, G Loss: {avg_G_loss:.4f}")

        # Without the last-epoch check, epochs after the final multiple of
        # save_every (e.g. 46-49 of 50) would never be checkpointed.
        is_periodic = save_every > 0 and epoch % save_every == 0
        is_last_epoch = epoch == config.NUM_EPOCHS - 1
        if config.SAVE_MODEL and (is_periodic or is_last_epoch):
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="generated_images")


In [13]:
# Launch training: runs config.NUM_EPOCHS epochs of train_fn inside main(),
# blocking this cell until they finish or the kernel is interrupted.
main()

100%|██████████| 32/32 [07:37<00:00, 14.30s/it, D_fake=0.302, D_real=0.655]


Epoch 0/50 - D Loss: 0.5866, G Loss: 34.2244
High image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\high\dense_30001.jpg
Low image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\low\dense_30001.jpg


100%|██████████| 32/32 [10:57<00:00, 20.55s/it, D_fake=0.301, D_real=0.69] 


Epoch 1/50 - D Loss: 0.5043, G Loss: 24.4358
High image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\high\dense_30001.jpg
Low image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\low\dense_30001.jpg


100%|██████████| 32/32 [11:21<00:00, 21.30s/it, D_fake=0.265, D_real=0.712]


Epoch 2/50 - D Loss: 0.4552, G Loss: 22.3860
High image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\high\dense_30001.jpg
Low image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\low\dense_30001.jpg


100%|██████████| 32/32 [10:45<00:00, 20.19s/it, D_fake=0.147, D_real=0.852]


Epoch 3/50 - D Loss: 0.3140, G Loss: 22.7485
High image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\high\dense_30001.jpg
Low image path: D:\Notebooks\gan\pixtopix-gan-pytorch\LoLI-Street Dataset\Val\low\dense_30001.jpg


  0%|          | 0/32 [05:04<?, ?it/s]


KeyboardInterrupt: 