In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os
from PIL import Image
from tqdm import tqdm
from torchvision.utils import save_image
# local module imports
import config
from utils import save_checkpoint, load_checkpoint, save_some_examples
from UvU_Net_Generator import OuterUNet as generator
from Pix_Discriminator import Discriminator

torch.backends.cudnn.benchmark = True

  check_for_updates()


torch.Size([1, 3, 512, 512])


In [2]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )

In [3]:
!unzip -q Sample_dataset.zip

In [4]:
class PairedImageDataset(Dataset):
    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.input_images = sorted(os.listdir(input_dir))
        self.target_images = sorted(os.listdir(target_dir))
        self.transform = transform

        assert len(self.input_images) == len(self.target_images), "Mismatch between input and target images!"

    def __len__(self):
        return len(self.input_images)

    def __getitem__(self, idx):
        input_image_path = os.path.join(self.input_dir, self.input_images[idx])
        target_image_path = os.path.join(self.target_dir, self.target_images[idx])

        input_image = Image.open(input_image_path).convert("RGB")
        target_image = Image.open(target_image_path).convert("RGB")

        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)

        return input_image, target_image


input_dir = '/content/Sample_dataset/sobel_images'
target_dir = '/content/Sample_dataset/input_images'

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

dataset = PairedImageDataset(input_dir=input_dir, target_dir=target_dir, transform=transform)

# Split dataset into training and validation (80-20 split)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader for training and validation
batch_size = 16

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Example to iterate over the data loader
for batch_idx, (input_images, target_images) in enumerate(train_loader):
    print(f"Batch {batch_idx + 1}")
    print(f"Input batch size: {input_images.shape}")
    print(f"Target batch size: {target_images.shape}")


Batch 1
Input batch size: torch.Size([16, 3, 256, 256])
Target batch size: torch.Size([16, 3, 256, 256])
Batch 2
Input batch size: torch.Size([4, 3, 256, 256])
Target batch size: torch.Size([4, 3, 256, 256])


In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [8]:
# Assuming the directory to save checkpoints
save_dir = '/content/drive/MyDrive/model_checkpoints/'  # Change this to your desired directory

def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        # Save models every 50 epochs to Google Drive
        if epoch % 50 == 49:
            print(f"Saving model at epoch {epoch}")
            torch.save({
                'epoch': epoch,
                'generator_state_dict': gen.state_dict(),
                'optimizer_gen_state_dict': opt_gen.state_dict(),
            }, f"{save_dir}generator_epoch_{epoch}.pth")

            torch.save({
                'epoch': epoch,
                'discriminator_state_dict': disc.state_dict(),
                'optimizer_disc_state_dict': opt_disc.state_dict(),
            }, f"{save_dir}discriminator_epoch_{epoch}.pth")

            save_some_examples(gen, val_loader, epoch, folder="evaluation")


if __name__ == "__main__":
    main()


  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 2/2 [00:00<00:00,  4.94it/s, D_fake=0.537, D_real=0.478]
100%|██████████| 2/2 [00:00<00:00,  6.08it/s, D_fake=0.478, D_real=0.59]
100%|██████████| 2/2 [00:00<00:00,  6.14it/s, D_fake=0.478, D_real=0.496]
100%|██████████| 2/2 [00:00<00:00,  6.15it/s, D_fake=0.474, D_real=0.546]
100%|██████████| 2/2 [00:00<00:00,  5.98it/s, D_fake=0.455, D_real=0.558]
100%|██████████| 2/2 [00:00<00:00,  6.06it/s, D_fake=0.433, D_real=0.568]
100%|██████████| 2/2 [00:00<00:00,  6.04it/s, D_fake=0.42, D_real=0.572]
100%|██████████| 2/2 [00:00<00:00,  5.80it/s, D_fake=0.412, D_real=0.568]
100%|██████████| 2/2 [00:00<00:00,  6.04it/s, D_fake=0.396, D_real=0.603]
100%|██████████| 2/2 [00:00<00:00,  6.08it/s, D_fake=0.396, D_real=0.592]
100%|██████████| 2/2 [00:00<00:00,  6.08it/s, D_fake=0.388, D_real=0.596]
100%|██████████| 2/2 [00:00<00:00,  5.

Saving model at epoch 49


100%|██████████| 2/2 [00:00<00:00,  5.09it/s, D_fake=0.0698, D_real=0.919]
100%|██████████| 2/2 [00:00<00:00,  5.54it/s, D_fake=0.0696, D_real=0.926]
100%|██████████| 2/2 [00:00<00:00,  5.78it/s, D_fake=0.0556, D_real=0.937]
100%|██████████| 2/2 [00:00<00:00,  5.34it/s, D_fake=0.182, D_real=0.908]
100%|██████████| 2/2 [00:00<00:00,  5.83it/s, D_fake=0.0546, D_real=0.873]
100%|██████████| 2/2 [00:00<00:00,  5.68it/s, D_fake=0.0551, D_real=0.86]
100%|██████████| 2/2 [00:00<00:00,  5.46it/s, D_fake=0.0631, D_real=0.919]
100%|██████████| 2/2 [00:00<00:00,  5.91it/s, D_fake=0.0605, D_real=0.951]
100%|██████████| 2/2 [00:00<00:00,  5.63it/s, D_fake=0.0651, D_real=0.937]
100%|██████████| 2/2 [00:00<00:00,  5.84it/s, D_fake=0.0354, D_real=0.932]
100%|██████████| 2/2 [00:00<00:00,  5.48it/s, D_fake=0.0499, D_real=0.936]
100%|██████████| 2/2 [00:00<00:00,  5.78it/s, D_fake=0.0467, D_real=0.945]
100%|██████████| 2/2 [00:00<00:00,  5.71it/s, D_fake=0.119, D_real=0.939]
100%|██████████| 2/2 [00:00<

Saving model at epoch 99


100%|██████████| 2/2 [00:00<00:00,  5.00it/s, D_fake=0.0374, D_real=0.954]
100%|██████████| 2/2 [00:00<00:00,  5.93it/s, D_fake=0.0269, D_real=0.965]
100%|██████████| 2/2 [00:00<00:00,  5.59it/s, D_fake=0.0411, D_real=0.929]
100%|██████████| 2/2 [00:00<00:00,  5.43it/s, D_fake=0.0407, D_real=0.952]
100%|██████████| 2/2 [00:00<00:00,  5.49it/s, D_fake=0.0408, D_real=0.965]
100%|██████████| 2/2 [00:00<00:00,  5.57it/s, D_fake=0.0392, D_real=0.965]
100%|██████████| 2/2 [00:00<00:00,  5.25it/s, D_fake=0.0331, D_real=0.962]
100%|██████████| 2/2 [00:00<00:00,  5.33it/s, D_fake=0.0323, D_real=0.967]
100%|██████████| 2/2 [00:00<00:00,  5.55it/s, D_fake=0.0282, D_real=0.972]
100%|██████████| 2/2 [00:00<00:00,  5.53it/s, D_fake=0.0277, D_real=0.97]
100%|██████████| 2/2 [00:00<00:00,  5.51it/s, D_fake=0.0251, D_real=0.969]
100%|██████████| 2/2 [00:00<00:00,  5.61it/s, D_fake=0.0229, D_real=0.968]
100%|██████████| 2/2 [00:00<00:00,  5.54it/s, D_fake=0.0205, D_real=0.974]
100%|██████████| 2/2 [00:0

Saving model at epoch 149


100%|██████████| 2/2 [00:00<00:00,  5.02it/s, D_fake=0.0174, D_real=0.974]
100%|██████████| 2/2 [00:00<00:00,  5.82it/s, D_fake=0.0199, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.74it/s, D_fake=0.0254, D_real=0.979]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s, D_fake=0.0201, D_real=0.975]
100%|██████████| 2/2 [00:00<00:00,  5.85it/s, D_fake=0.0254, D_real=0.979]
100%|██████████| 2/2 [00:00<00:00,  5.91it/s, D_fake=0.0208, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.87it/s, D_fake=0.0239, D_real=0.974]
100%|██████████| 2/2 [00:00<00:00,  5.84it/s, D_fake=0.0239, D_real=0.98]
100%|██████████| 2/2 [00:00<00:00,  5.85it/s, D_fake=0.0241, D_real=0.979]
100%|██████████| 2/2 [00:00<00:00,  5.89it/s, D_fake=0.0161, D_real=0.98]
100%|██████████| 2/2 [00:00<00:00,  5.98it/s, D_fake=0.0168, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.39it/s, D_fake=0.0211, D_real=0.979]
100%|██████████| 2/2 [00:00<00:00,  5.46it/s, D_fake=0.0154, D_real=0.982]
100%|██████████| 2/2 [00:00

Saving model at epoch 199


100%|██████████| 2/2 [00:00<00:00,  5.04it/s, D_fake=0.0162, D_real=0.986]
100%|██████████| 2/2 [00:00<00:00,  5.80it/s, D_fake=0.0105, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.22it/s, D_fake=0.0107, D_real=0.987]
100%|██████████| 2/2 [00:00<00:00,  5.43it/s, D_fake=0.0107, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.82it/s, D_fake=0.0125, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.73it/s, D_fake=0.0112, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.72it/s, D_fake=0.0163, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.74it/s, D_fake=0.0138, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.75it/s, D_fake=0.00971, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.81it/s, D_fake=0.00826, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.0111, D_real=0.987]
100%|██████████| 2/2 [00:00<00:00,  5.61it/s, D_fake=0.0107, D_real=0.991]
100%|██████████| 2/2 [00:00<00:00,  5.46it/s, D_fake=0.0194, D_real=0.991]
100%|██████████| 2/2 [00:

Saving model at epoch 249


100%|██████████| 2/2 [00:00<00:00,  4.93it/s, D_fake=0.0263, D_real=0.981]
100%|██████████| 2/2 [00:00<00:00,  5.75it/s, D_fake=0.0271, D_real=0.976]
100%|██████████| 2/2 [00:00<00:00,  5.83it/s, D_fake=0.0143, D_real=0.982]
100%|██████████| 2/2 [00:00<00:00,  5.66it/s, D_fake=0.0202, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.34it/s, D_fake=0.0172, D_real=0.975]
100%|██████████| 2/2 [00:00<00:00,  5.48it/s, D_fake=0.0169, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.37it/s, D_fake=0.0241, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.41it/s, D_fake=0.0275, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.43it/s, D_fake=0.0285, D_real=0.966]
100%|██████████| 2/2 [00:00<00:00,  5.13it/s, D_fake=0.025, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.01it/s, D_fake=0.035, D_real=0.983]
100%|██████████| 2/2 [00:00<00:00,  5.46it/s, D_fake=0.0189, D_real=0.981]
100%|██████████| 2/2 [00:00<00:00,  5.34it/s, D_fake=0.0162, D_real=0.978]
100%|██████████| 2/2 [00:00

Saving model at epoch 299


100%|██████████| 2/2 [00:00<00:00,  4.87it/s, D_fake=0.0112, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.92it/s, D_fake=0.00653, D_real=0.987]
100%|██████████| 2/2 [00:00<00:00,  5.64it/s, D_fake=0.0134, D_real=0.985]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s, D_fake=0.00887, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.78it/s, D_fake=0.0111, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.90it/s, D_fake=0.0126, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.69it/s, D_fake=0.0114, D_real=0.985]
100%|██████████| 2/2 [00:00<00:00,  5.75it/s, D_fake=0.00746, D_real=0.991]
100%|██████████| 2/2 [00:00<00:00,  5.85it/s, D_fake=0.0112, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.83it/s, D_fake=0.0112, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.93it/s, D_fake=0.0192, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.90it/s, D_fake=0.00708, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.92it/s, D_fake=0.00988, D_real=0.989]
100%|██████████| 2/2 [

Saving model at epoch 349


100%|██████████| 2/2 [00:00<00:00,  4.50it/s, D_fake=0.0262, D_real=0.967]
100%|██████████| 2/2 [00:00<00:00,  5.40it/s, D_fake=0.0151, D_real=0.973]
100%|██████████| 2/2 [00:00<00:00,  5.48it/s, D_fake=0.0154, D_real=0.972]
100%|██████████| 2/2 [00:00<00:00,  5.21it/s, D_fake=0.0169, D_real=0.971]
100%|██████████| 2/2 [00:00<00:00,  5.73it/s, D_fake=0.0148, D_real=0.975]
100%|██████████| 2/2 [00:00<00:00,  5.85it/s, D_fake=0.0134, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.0171, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.90it/s, D_fake=0.0184, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.86it/s, D_fake=0.0172, D_real=0.98]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s, D_fake=0.0228, D_real=0.978]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.0146, D_real=0.987]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s, D_fake=0.031, D_real=0.974]
100%|██████████| 2/2 [00:00<00:00,  5.91it/s, D_fake=0.0137, D_real=0.976]
100%|██████████| 2/2 [00:00

Saving model at epoch 399


100%|██████████| 2/2 [00:00<00:00,  4.93it/s, D_fake=0.013, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.81it/s, D_fake=0.013, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.84it/s, D_fake=0.0157, D_real=0.985]
100%|██████████| 2/2 [00:00<00:00,  5.80it/s, D_fake=0.0117, D_real=0.991]
100%|██████████| 2/2 [00:00<00:00,  5.84it/s, D_fake=0.0103, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.0126, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.89it/s, D_fake=0.00977, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.85it/s, D_fake=0.00834, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.00937, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.77it/s, D_fake=0.00993, D_real=0.991]
100%|██████████| 2/2 [00:00<00:00,  5.47it/s, D_fake=0.0127, D_real=0.988]
100%|██████████| 2/2 [00:00<00:00,  5.36it/s, D_fake=0.0145, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.49it/s, D_fake=0.0198, D_real=0.989]
100%|██████████| 2/2 [00:0

Saving model at epoch 449


100%|██████████| 2/2 [00:00<00:00,  4.81it/s, D_fake=0.0157, D_real=0.985]
100%|██████████| 2/2 [00:00<00:00,  5.82it/s, D_fake=0.0106, D_real=0.972]
100%|██████████| 2/2 [00:00<00:00,  5.96it/s, D_fake=0.012, D_real=0.989]
100%|██████████| 2/2 [00:00<00:00,  5.86it/s, D_fake=0.0135, D_real=0.99]
100%|██████████| 2/2 [00:00<00:00,  5.92it/s, D_fake=0.0135, D_real=0.992]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.0172, D_real=0.991]
100%|██████████| 2/2 [00:00<00:00,  5.75it/s, D_fake=0.0183, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.83it/s, D_fake=0.364, D_real=0.977]
100%|██████████| 2/2 [00:00<00:00,  5.79it/s, D_fake=0.109, D_real=0.77]
100%|██████████| 2/2 [00:00<00:00,  5.69it/s, D_fake=0.269, D_real=0.538]
100%|██████████| 2/2 [00:00<00:00,  5.80it/s, D_fake=0.118, D_real=0.807]
100%|██████████| 2/2 [00:00<00:00,  5.88it/s, D_fake=0.115, D_real=0.799]
100%|██████████| 2/2 [00:00<00:00,  5.89it/s, D_fake=0.11, D_real=0.601]
100%|██████████| 2/2 [00:00<00:00, 

Saving model at epoch 499


In [5]:
def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = generator(in_channels=3, features=64).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if config.SAVE_MODEL and epoch % 50 == 0:
            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
            save_some_examples(gen, val_loader, epoch, folder="evaluation")


if __name__ == "__main__":
    main()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 2/2 [01:12<00:00, 36.01s/it, D_fake=0.448, D_real=0.512]


FileNotFoundError: [Errno 2] No such file or directory: '/content/evaluation/y_gen_0.png'