In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os
from PIL import Image
from tqdm import tqdm
from torchvision.utils import save_image
# local module imports
import config
from utils import save_checkpoint, load_checkpoint, save_some_examples
from Pix_Generator import Generator
from Pix_Discriminator import Discriminator

torch.backends.cudnn.benchmark = True

In [4]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [1]:
!unzip -q Sample_dataset.zip

In [7]:
class PairedImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.input_images = sorted(os.listdir(input_dir))
        self.target_images = sorted(os.listdir(target_dir))
        self.transform = transform

        assert len(self.input_images) == len(self.target_images), "Mismatch between input and target images!"

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        input_image_path = os.path.join(self.input_dir, self.input_images[idx])
        target_image_path = os.path.join(self.target_dir, self.target_images[idx])

        input_image = Image.open(input_image_path).convert("RGB")
        target_image = Image.open(target_image_path).convert("RGB")

        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)

        return input_image, target_image


input_dir = '/content/Sample_dataset/input'
target_dir = '/content/Sample_dataset/target'

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

dataset = PairedImageDataset(input_dir=input_dir, target_dir=target_dir, transform=transform)

# Split dataset into training and validation (80-20 split)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader for training and validation
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Example to iterate over the data loader
for batch_idx, (input_images, target_images) in enumerate(train_loader):
    print(f"Batch {batch_idx + 1}")
    print(f"Input batch size: {input_images.shape}")
    print(f"Target batch size: {target_images.shape}")


Batch 1
Input batch size: torch.Size([16, 3, 256, 256])
Target batch size: torch.Size([16, 3, 256, 256])
Batch 2
Input batch size: torch.Size([4, 3, 256, 256])
Target batch size: torch.Size([4, 3, 256, 256])


In [9]:
def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if config.SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder="evaluation")


if __name__ == "__main__":
    main()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 2/2 [00:00<00:00,  4.78it/s, D_fake=0.567, D_real=0.461]
100%|██████████| 2/2 [00:00<00:00,  7.22it/s, D_fake=0.506, D_real=0.589]
100%|██████████| 2/2 [00:00<00:00,  7.15it/s, D_fake=0.416, D_real=0.517]
100%|██████████| 2/2 [00:00<00:00,  7.17it/s, D_fake=0.47, D_real=0.542]
100%|██████████| 2/2 [00:00<00:00,  7.02it/s, D_fake=0.452, D_real=0.535]
100%|██████████| 2/2 [00:00<00:00,  7.21it/s, D_fake=0.439, D_real=0.567]
100%|██████████| 2/2 [00:00<00:00,  7.22it/s, D_fake=0.406, D_real=0.576]
100%|██████████| 2/2 [00:00<00:00,  7.03it/s, D_fake=0.395, D_real=0.589]
100%|██████████| 2/2 [00:00<00:00,  7.16it/s, D_fake=0.376, D_real=0.603]
100%|██████████| 2/2 [00:00<00:00,  7.15it/s, D_fake=0.385, D_real=0.616]
100%|██████████| 2/2 [00:00<00:00,  7.17it/s, D_fake=0.344, D_real=0.597]
100%|██████████| 2/2 [00:00<00:00,  7