<a href="https://colab.research.google.com/github/Aiden-Ross-Dsouza/Generative-Models/blob/main/Generative-Adversarial-Networks/notebooks/Pix2Pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

# Discriminator

In [2]:
class CNNBlock(nn.Module):
  def __init__(self, in_chans, out_chans, stride=2):
    super().__init__()
    self.conv = nn.Sequential(
      nn.Conv2d(in_chans, out_chans, 4, stride, bias=False, padding=1, padding_mode="reflect"),
      nn.InstanceNorm2d(out_chans, affine=True),
      nn.LeakyReLU(0.2),
    )

  def forward(self, x):
    return self.conv(x)

In [3]:
class Discriminator(nn.Module):
  def __init__(self, in_chans=3, features=[64, 128, 256, 512]):
    super().__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_chans*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )
    layers = []
    in_chans = features[0]
    for feature in features[1:]:
      layers.append(
          CNNBlock(in_chans, feature, stride=1 if feature==features[-1] else 2)
      )
      in_chans = feature
    layers.append(
        nn.Conv2d(in_chans, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
    )
    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    x = torch.cat((x, y), dim=1)
    x = self.initial(x)
    return self.model(x)

In [4]:
def test():
  x = torch.randn((1, 3, 256, 256)) # (batch_dim, chan, img_dim, img_dim)
  y = torch.randn((1, 3, 256, 256))
  model = Discriminator()
  preds = model(x,y)
  print(preds.shape)

test()

torch.Size([1, 1, 30, 30])


# Generator

In [5]:
class Block(nn.Module):
  def __init__(self, in_chans, out_chans, down=True, act="relu", use_dropout=False):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_chans, out_chans, 4, 2, 1, bias=False, padding_mode="reflect")
        if down
        else nn.ConvTranspose2d(in_chans, out_chans, 4, 2, 1, bias=False),
        nn.InstanceNorm2d(out_chans, affine=True),
        nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),
    )
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    x = self.conv(x)
    return self.dropout(x) if self.use_dropout else x

In [6]:
class Generator(nn.Module):
  def __init__(self, in_chans=3, features=64):
    super().__init__()
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_chans, features, 4, 2, 1, padding_mode="reflect"), # 256 -> 128
        nn.LeakyReLU(0.2)
    )
    self.down1 = Block(features, features*2, down=True, act="leaky", use_dropout=False) # 128 -> 64
    self.down2 = Block(features*2, features*4, down=True, act="leaky", use_dropout=False) # 64 -> 32
    self.down3 = Block(features*4, features*8, down=True, act="leaky", use_dropout=False) # 32 -> 16
    self.down4 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) # 16 -> 8
    self.down5 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) # 8 -> 4
    self.down6 = Block(features*8, features*8, down=True, act="leaky", use_dropout=False) # 4 -> 2
    self.bottleneck = nn.Sequential(
        nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"), # 2 -> 1
        nn.ReLU()
    )
    self.up1 = Block(features*8, features*8, down=False, act="relu", use_dropout=True)
    self.up2 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up3 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=True)
    self.up4 = Block(features*8*2, features*8, down=False, act="relu", use_dropout=False)
    self.up5 = Block(features*8*2, features*4, down=False, act="relu", use_dropout=False)
    self.up6 = Block(features*4*2, features*2, down=False, act="relu", use_dropout=False)
    self.up7 = Block(features*2*2, features, down=False, act="relu", use_dropout=False)
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_chans, 4, 2, 1),
        nn.Tanh()
    )

  def forward(self, x):
    d1 = self.initial_down(x)
    d2 = self.down1(d1)
    d3 = self.down2(d2)
    d4 = self.down3(d3)
    d5 = self.down4(d4)
    d6 = self.down5(d5)
    d7 = self.down6(d6)
    bottleneck = self.bottleneck(d7)
    up1 = self.up1(bottleneck)
    up2 = self.up2(torch.cat([up1, d7], 1))
    up3 = self.up3(torch.cat([up2, d6], 1))
    up4 = self.up4(torch.cat([up3, d5], 1))
    up5 = self.up5(torch.cat([up4, d4], 1))
    up6 = self.up6(torch.cat([up5, d3], 1))
    up7 = self.up7(torch.cat([up6, d2], 1))
    return self.final_up(torch.cat([up7, d1], 1))

In [7]:
def test():
  x = torch.randn((1, 3, 256, 256))
  model = Generator(in_chans=3, features=64)
  preds = model(x)
  print(preds.shape)

test()

torch.Size([1, 3, 256, 256])


# Config

In [8]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

  check_for_updates()


In [39]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
lr = 2e-4
batch_size = 16
num_workers = 2
img_size = 256
chans_img = 3
l1_lambda = 100
num_epochs = 100
load_model = False
save_model = True
checkpoint_disc = "/content/drive/MyDrive/Pix2PixGAN_results/disc.pth.tar"
checkpoint_gen = "/content/drive/MyDrive/Pix2PixGAN_results/gen.pth.tar"

In [10]:
both_transform = A.Compose(
    [A.Resize(width=256, height=256)], additional_targets={"image0": "image"},
)

transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2()
    ]
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2()
    ]
)

# Dataset

In [11]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

In [12]:
class MapDataset(Dataset):
  def __init__(self, root_dir):
    super().__init__()
    self.root_dir = root_dir
    self.list_files = os.listdir(self.root_dir)

  def __len__(self):
    return len(self.list_files)

  def __getitem__(self, idx):
    img_file = self.list_files[idx]
    img_path = os.path.join(self.root_dir, img_file)
    image = np.array(Image.open(img_path))
    input_image = image[:, : image.shape[1]//2, :]
    target_image = image[:, image.shape[1]//2: , :]

    augmentations = both_transform(image=input_image, image0=target_image)
    input_image, target_image = augmentations["image"], augmentations["image0"]

    input_image = transform_only_input(image=input_image)["image"]
    target_image = transform_only_mask(image=target_image)["image"]

    return input_image, target_image

# Utils

In [13]:
import torchaudio
from torchvision.utils import save_image

In [14]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(device), y.to(device)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()

In [15]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

In [16]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

# Train

In [17]:
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim

In [18]:
def train(disc, gen, train_loader, opt_disc, opt_gen, L1_loss, BCE, g_scaler, d_scaler):

  for idx, (x, y) in enumerate(tqdm(train_loader, leave=True)):
    x, y = x.to(device), y.to(device)

    # train discriminator
    with torch.cuda.amp.autocast():
      y_fake = gen(x)
      D_real = disc(x, y)
      D_fake = disc(x, y_fake.detach())
      D_real_loss = BCE(D_real, torch.ones_like(D_real))
      D_fake_loss = BCE(D_fake, torch.zeros_like(D_fake))
      D_loss = (D_real_loss + D_fake_loss) / 2

    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    #train generator
    with torch.cuda.amp.autocast():
      D_fake = disc(x, y_fake)
      G_fake_loss = BCE(D_fake, torch.ones_like(D_fake))
      L1 = L1_loss(y_fake, y) * l1_lambda
      G_loss = G_fake_loss + L1

    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()

# Download Dataset

In [19]:
import shutil
import zipfile

In [20]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [21]:
# Create Kaggle directory
os.makedirs('/root/.kaggle', exist_ok=True)

# Define source and destination paths
source = "/content/drive/MyDrive/Kaggle_API/kaggle.json"
destination = "/root/.kaggle/kaggle.json"

# Copy the file
shutil.copy(source, destination)

# Set correct permissions
!chmod 600 /root/.kaggle/kaggle.json

In [22]:
!kaggle datasets download -d alincijov/pix2pix-maps

Dataset URL: https://www.kaggle.com/datasets/alincijov/pix2pix-maps
License(s): CC0-1.0
Downloading pix2pix-maps.zip to /content
 80% 192M/239M [00:00<00:00, 673MB/s]  
100% 239M/239M [00:00<00:00, 657MB/s]


In [23]:
# Unzip the dataset
with zipfile.ZipFile("pix2pix-maps.zip", 'r') as zip_ref:
    zip_ref.extractall("pix2pix-maps")

# List extracted files
!ls pix2pix-maps

train  val


# Main

In [37]:
def main():
  disc = Discriminator(in_chans=3).to(device)
  gen = Generator(in_chans=3).to(device)
  opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))
  opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
  BCE = nn.BCEWithLogitsLoss()
  l1_loss = nn.L1Loss()

  if load_model:
    load_checkpoint(checkpoint_gen, gen, opt_gen, lr)
    load_checkpoint(checkpoint_disc, disc, opt_disc, lr)

  train_dataset = MapDataset(root_dir="pix2pix-maps/train")
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  val_dataset = MapDataset(root_dir='pix2pix-maps/val')
  val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

  for epoch in range(num_epochs):
    train(disc, gen, train_loader, opt_disc, opt_gen, l1_loss, BCE, g_scaler, d_scaler)
    if save_model and epoch % 5 == 0:
      save_checkpoint(gen, opt_gen, filename=checkpoint_gen)
      save_checkpoint(disc, opt_disc, filename=checkpoint_disc)

    save_some_examples(gen, val_loader, epoch, folder='/content/drive/MyDrive/Pix2PixGAN_results')

In [40]:
if __name__ == "__main__":
  main()

  g_scaler = torch.cuda.amp.GradScaler()
  d_scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.63it/s]
100%|██████████| 69/69 [00:14<00:00,  4.63it/s]
100%|██████████| 69/69 [00:14<00:00,  4.79it/s]
100%|██████████| 69/69 [00:13<00:00,  4.96it/s]
100%|██████████| 69/69 [00:13<00:00,  4.96it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.67it/s]
100%|██████████| 69/69 [00:15<00:00,  4.59it/s]
100%|██████████| 69/69 [00:14<00:00,  4.87it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.69it/s]
100%|██████████| 69/69 [00:14<00:00,  4.70it/s]
100%|██████████| 69/69 [00:14<00:00,  4.70it/s]
100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.72it/s]
100%|██████████| 69/69 [00:14<00:00,  4.76it/s]
100%|██████████| 69/69 [00:14<00:00,  4.75it/s]
100%|██████████| 69/69 [00:14<00:00,  4.82it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.74it/s]
100%|██████████| 69/69 [00:14<00:00,  4.70it/s]
100%|██████████| 69/69 [00:14<00:00,  4.73it/s]
100%|██████████| 69/69 [00:14<00:00,  4.75it/s]
100%|██████████| 69/69 [00:14<00:00,  4.73it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.60it/s]
100%|██████████| 69/69 [00:14<00:00,  4.71it/s]
100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]
100%|██████████| 69/69 [00:14<00:00,  4.84it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.70it/s]
100%|██████████| 69/69 [00:15<00:00,  4.59it/s]
100%|██████████| 69/69 [00:14<00:00,  4.67it/s]
100%|██████████| 69/69 [00:14<00:00,  4.73it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.60it/s]
100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:15<00:00,  4.60it/s]
100%|██████████| 69/69 [00:14<00:00,  4.75it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.52it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]
100%|██████████| 69/69 [00:14<00:00,  4.61it/s]
100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.79it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.67it/s]
100%|██████████| 69/69 [00:14<00:00,  4.71it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.59it/s]
100%|██████████| 69/69 [00:14<00:00,  4.60it/s]
100%|██████████| 69/69 [00:15<00:00,  4.54it/s]
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.77it/s]
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]
100%|██████████| 69/69 [00:14<00:00,  4.67it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]
100%|██████████| 69/69 [00:14<00:00,  4.84it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.62it/s]
100%|██████████| 69/69 [00:14<00:00,  4.60it/s]
100%|██████████| 69/69 [00:14<00:00,  4.73it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]
100%|██████████| 69/69 [00:14<00:00,  4.84it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.53it/s]
100%|██████████| 69/69 [00:15<00:00,  4.47it/s]
100%|██████████| 69/69 [00:15<00:00,  4.56it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:15<00:00,  4.56it/s]
100%|██████████| 69/69 [00:15<00:00,  4.51it/s]
100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.64it/s]
100%|██████████| 69/69 [00:14<00:00,  4.64it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]
100%|██████████| 69/69 [00:14<00:00,  4.85it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:14<00:00,  4.65it/s]
100%|██████████| 69/69 [00:14<00:00,  4.71it/s]
100%|██████████| 69/69 [00:14<00:00,  4.71it/s]
100%|██████████| 69/69 [00:15<00:00,  4.57it/s]
100%|██████████| 69/69 [00:14<00:00,  4.75it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.52it/s]
100%|██████████| 69/69 [00:14<00:00,  4.63it/s]
100%|██████████| 69/69 [00:14<00:00,  4.63it/s]
100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.77it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.55it/s]
100%|██████████| 69/69 [00:14<00:00,  4.66it/s]
100%|██████████| 69/69 [00:15<00:00,  4.55it/s]
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]
100%|██████████| 69/69 [00:14<00:00,  4.74it/s]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 69/69 [00:15<00:00,  4.55it/s]
100%|██████████| 69/69 [00:14<00:00,  4.68it/s]
100%|██████████| 69/69 [00:14<00:00,  4.70it/s]
100%|██████████| 69/69 [00:14<00:00,  4.78it/s]
