<a href="https://colab.research.google.com/github/SarveshD7/Pix2Pix-Pytorch/blob/main/Pix2Pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ***Important Points on the Architecture of Pix2Pix***

Generator is very much inspired by U-Net

Discriminator is just a couple of Convolutional Layers

Discriminator (PatchGAN): the output of the discriminator is not a single value indicating whether the whole image is real or fake.

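Rather, it is a grid of values (26x26 for the discriminator defined below on 256x256 inputs; other PatchGAN configurations give e.g. 30x30, and 70x70 is the commonly quoted patch size), where each value — once passed through a sigmoid, since the network itself emits raw logits — lies in [0, 1] and indicates whether the corresponding patch of the original image is real or fake.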

---




In [6]:
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.optim as optim

In [3]:
# --- Optimisation ---
LEARNING_RATE = 2e-4  # step size passed to both Adam optimizers
L1_LAMBDA = 100       # weight of the L1 reconstruction term in the generator loss
NUM_EPOCHS = 500      # number of passes over the training loader

# --- Data loading ---
BATCH_SIZE = 16       # training batch size
NUM_WORKERS = 2       # DataLoader worker processes for the training loader

# --- Image geometry ---
# NOTE(review): neither constant is read by the cells visible here
# (the transforms hard-code 256 and the models hard-code 3) — confirm usage.
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# --- Checkpointing flags ---
# NOTE(review): not consumed by any visible cell — presumably meant for
# load/save logic that is not implemented yet.
LOAD_MODEL = False
SAVE_MODEL = True


In [4]:
# Applied jointly to the input image ("image") and its target ("image0").
# Spatial augmentations must live here: Compose replays the same random
# parameters on every registered target, so the pair stays pixel-aligned.
# (Previously the horizontal flip was applied to the input only, which could
# mirror the input while leaving its target un-mirrored.)
both_transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Input-only pipeline: photometric jitter (does not move pixels, so it is safe
# to apply to one side of the pair), then scale to [-1, 1] and convert to a
# CHW tensor.
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

# Target-only pipeline: no augmentation, just scale to [-1, 1] (the range of
# the generator's Tanh output) and convert to a CHW tensor.
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

In [None]:
class MapDataset(Dataset):
  """Paired image dataset for pix2pix.

  Every file in ``root_dir`` is expected to contain the input image and the
  target image joined side by side along the width: columns ``[0, split_col)``
  are the input, columns ``[split_col, W)`` are the target.

  Args:
    root_dir: directory containing the combined images.
    split_col: column at which the combined image is cut (default 600).

  Indexing returns ``(input_img, target_img)`` as produced by
  ``transform_only_input`` / ``transform_only_mask``.
  """

  def __init__(self, root_dir, split_col=600):
    self.root_dir = root_dir
    self.split_col = split_col
    # Sorted so the index -> file mapping does not depend on the arbitrary
    # order returned by os.listdir.
    self.list_files = sorted(os.listdir(self.root_dir))

  def __len__(self):
    return len(self.list_files)

  def __getitem__(self, index):
    img_file = self.list_files[index]
    img_path = os.path.join(self.root_dir, img_file)
    # Context manager closes the file handle; convert() guarantees the
    # 3-channel layout the slicing and Normalize calls below rely on.
    with Image.open(img_path) as pil_img:
      img = np.array(pil_img.convert("RGB"))

    input_img = img[:, :self.split_col, :]   # left part: input image
    target_img = img[:, self.split_col:, :]  # right part: target image

    # Shared transforms are applied to both halves with identical parameters.
    augmentations = both_transform(image=input_img, image0=target_img)
    input_img, target_img = augmentations["image"], augmentations["image0"]

    input_img = transform_only_input(image=input_img)["image"]
    # The target must be built from target_img (not from the already
    # transformed input tensor).
    target_img = transform_only_mask(image=target_img)["image"]

    return input_img, target_img

  # Backward-compatible alias for the original (misspelled) method name.
  __getItem__ = __getitem__

In [2]:
class CNNBlock(nn.Module):
  """Conv2d(4x4) -> BatchNorm2d -> LeakyReLU(0.2) unit used by the discriminator.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the convolution.
    stride: stride of the convolution (default 2).
  """

  def __init__(self, in_channels, out_channels, stride=2):
    super().__init__()
    # NOTE(review): no padding argument is passed, so Conv2d uses its default
    # padding of 0 and padding_mode="reflect" has nothing to pad — confirm
    # whether padding=1 was intended.
    stages = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            bias=False,  # bias omitted; BatchNorm2d follows the convolution
            padding_mode="reflect",
        ),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2),
    ]
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    out = self.conv(x)
    return out

class Discriminator(nn.Module):
  """PatchGAN-style discriminator conditioned on the input image.

  The input image ``x`` and a candidate output ``y`` are concatenated along
  the channel axis and reduced by a stack of strided convolutions to a
  one-channel grid. No sigmoid is applied, so the grid holds raw logits.

  Args:
    in_channels: channels of ONE image; the first convolution receives
      ``in_channels * 2`` because ``x`` and ``y`` are concatenated.
    features: output widths of the successive convolutional stages. The
      first entry configures the initial (norm-free) convolution, each
      further entry adds one ``CNNBlock``.

  Raises:
    ValueError: if ``features`` is empty.
  """

  def __init__(self, in_channels, features=(64, 128, 256, 512)):
    super().__init__()
    features = list(features)  # accept any sequence; default is immutable
    if not features:
      raise ValueError("features must contain at least one entry")

    self.initial = nn.Sequential(
        # in_channels*2: input image and target/generated image are
        # concatenated on the channel axis in forward().
        nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
        nn.LeakyReLU(0.2)
    )

    layers = []
    prev_channels = features[0]
    last_index = len(features) - 1
    for i, feature in enumerate(features[1:], start=1):
      # Only the final stage uses stride 1. The decision is made by position,
      # not by value, so repeated widths (e.g. [..., 512, 512]) behave as
      # expected.
      layers.append(
          CNNBlock(prev_channels, feature, stride=1 if i == last_index else 2)
      )
      prev_channels = feature
    # Projection to the one-channel patch grid.
    layers.append(
        nn.Conv2d(prev_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
    )
    self.model = nn.Sequential(*layers)

  def forward(self, x, y):
    """Return the patch logits for the pair ``(x, y)``."""
    x = torch.cat([x, y], dim=1)
    x = self.initial(x)
    return self.model(x)

def test():
  """Smoke test: push one random image pair through the discriminator and print the output shape."""
  sample_shape = (1, 3, 256, 256)
  x = torch.randn(sample_shape)
  y = torch.randn(sample_shape)
  discriminator = Discriminator(3)
  print(discriminator(x, y).shape)
test()

torch.Size([1, 1, 26, 26])


In [7]:
class Block(nn.Module):
  """Convolution -> BatchNorm -> activation unit of the U-Net generator.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels produced by the block.
    down: if True use a stride-2 ``Conv2d`` (halves the spatial size),
      otherwise a stride-2 ``ConvTranspose2d`` (doubles it).
    act: ``"relu"`` selects ``ReLU``; any other value selects
      ``LeakyReLU(0.2)``.
    use_dropout: if True, apply ``Dropout(0.5)`` to the block output.
  """

  def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
    super().__init__()
    if down:
      resample = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
    else:
      resample = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

    self.conv = nn.Sequential(
        resample,
        # The convolutions are created with bias=False, which only makes sense
        # when a normalisation layer supplies the affine shift; without it the
        # block had no bias term at all.
        nn.BatchNorm2d(out_channels),
        nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),
    )
    self.use_dropout = use_dropout
    self.dropout = nn.Dropout(0.5)

  def forward(self, x):
    x = self.conv(x)
    return self.dropout(x) if self.use_dropout else x

class Generator(nn.Module):
  """U-Net style encoder/decoder generator.

  Eight stride-2 stages (initial conv, six ``Block`` downs, bottleneck)
  shrink the input, and eight stride-2 transposed stages grow it back. Each
  decoder stage after the first receives the matching encoder activation
  concatenated on the channel axis. The output has ``in_channels`` channels
  and passes through ``Tanh``.

  Args:
    in_channels: channels of the input (and output) image.
    features: base channel width; deeper stages use 2x, 4x and 8x this value.
  """

  def __init__(self, in_channels, features=64):
    super().__init__()
    f2, f4, f8 = features*2, features*4, features*8

    def encoder(c_in, c_out):
      # Downsampling stage: LeakyReLU, no dropout.
      return Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

    def decoder(c_in, c_out, dropout):
      # Upsampling stage: ReLU, dropout only where requested.
      return Block(c_in, c_out, down=False, act="relu", use_dropout=dropout)

    # Spatial sizes in the comments assume a 256 x 256 input.
    self.initial_down = nn.Sequential(
        nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
        nn.LeakyReLU(0.2),
    )                                   # 128 x 128
    self.down1 = encoder(features, f2)  # 64 x 64
    self.down2 = encoder(f2, f4)        # 32 x 32
    self.down3 = encoder(f4, f8)        # 16 x 16
    self.down4 = encoder(f8, f8)        # 8 x 8
    self.down5 = encoder(f8, f8)        # 4 x 4
    self.down6 = encoder(f8, f8)        # 2 x 2

    self.bottleneck = nn.Sequential(
        nn.Conv2d(f8, f8, 4, 2, 1, padding_mode="reflect"),
        nn.ReLU(),
    )                                   # 1 x 1

    # From up2 onwards the input width is doubled by the skip concatenation.
    self.up1 = decoder(f8, f8, dropout=True)
    self.up2 = decoder(f8*2, f8, dropout=True)
    self.up3 = decoder(f8*2, f8, dropout=True)
    self.up4 = decoder(f8*2, f8, dropout=False)
    self.up5 = decoder(f8*2, f4, dropout=False)
    self.up6 = decoder(f4*2, f2, dropout=False)
    self.up7 = decoder(f2*2, features, dropout=False)
    self.final_up = nn.Sequential(
        nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
        nn.Tanh(),
    )

  def forward(self, x):
    # Encoder: keep every activation for the skip connections.
    skips = [self.initial_down(x)]
    for stage in (self.down1, self.down2, self.down3,
                  self.down4, self.down5, self.down6):
      skips.append(stage(skips[-1]))

    out = self.up1(self.bottleneck(skips[-1]))

    # Decoder: pair up2..up7 with the encoder outputs from deepest to
    # second-shallowest; the shallowest one feeds final_up.
    decoder_stages = (self.up2, self.up3, self.up4,
                      self.up5, self.up6, self.up7)
    for stage, skip in zip(decoder_stages, reversed(skips[1:])):
      out = stage(torch.cat([out, skip], 1))

    return self.final_up(torch.cat([out, skips[0]], 1))


In [8]:
def test():
  """Smoke test: run one random 3x256x256 image through the generator and print the output shape."""
  sample = torch.randn((1, 3, 256, 256))
  generator = Generator(in_channels=3, features=64)
  print(generator(sample).shape)

In [9]:
# Run the generator smoke test defined in the cell above; it prints the output shape.
test()

torch.Size([1, 3, 256, 256])


In [None]:
def train(disc, gen, loader, opt_disc, opt_gen, l1, bce, l1_lambda=None):
  """Run one epoch of conditional-GAN training over ``loader``.

  For every batch the discriminator is updated first (real pairs -> 1,
  generated pairs -> 0), then the generator is updated to fool the
  discriminator while staying close to the target under the L1 loss.

  Args:
    disc: discriminator, called as ``disc(x, y)``.
    gen: generator, called as ``gen(x)``.
    loader: iterable yielding ``(x, y)`` batches (input, target).
    opt_disc: optimizer holding the discriminator parameters.
    opt_gen: optimizer holding the generator parameters.
    l1: reconstruction loss, called as ``l1(prediction, target)``.
    bce: adversarial loss, called as ``bce(logits, labels)``.
    l1_lambda: weight of the reconstruction term; when ``None`` the
      module-level ``L1_LAMBDA`` is used.
  """
  if l1_lambda is None:
    l1_lambda = L1_LAMBDA

  for x, y in loader:

    # ---- Train Discriminator ----
    y_fake = gen(x)
    D_real = disc(x, y)
    # detach(): the discriminator step must not back-propagate into gen.
    D_fake = disc(x, y_fake.detach())
    D_real_loss = bce(D_real, torch.ones_like(D_real))
    # Generated pairs are labelled 0 for the discriminator.
    D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
    D_loss = (D_real_loss + D_fake_loss) / 2

    # Gradients are cleared and applied through the optimizer, not the module.
    opt_disc.zero_grad()
    D_loss.backward()
    opt_disc.step()

    # ---- Train Generator ----
    # Fresh forward pass through the just-updated discriminator, this time
    # keeping the graph back to the generator.
    D_fake = disc(x, y_fake)
    G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
    L1 = l1(y_fake, y) * l1_lambda
    G_loss = G_fake_loss + L1

    opt_gen.zero_grad()
    G_loss.backward()
    opt_gen.step()



In [None]:
# Networks: both operate on 3-channel images.
disc = Discriminator(in_channels=3)
gen = Generator(in_channels=3)

# One Adam optimizer per network, with identical hyper-parameters.
opt_disc = optim.Adam(
    disc.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_gen = optim.Adam(
    gen.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

# Adversarial loss works on raw logits; reconstruction loss is plain L1.
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

# Training data.
train_dataset = MapDataset(root_dir="/content/data/maps/train")
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Validation data, one sample per batch.
# NOTE(review): val_loader is not consumed by any cell visible here.
val_dataset = MapDataset(root_dir="/content/data/maps/val")
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=True,
)

# One call to train() per epoch.
for epoch in range(NUM_EPOCHS):
  train(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE)