# Pix2Pix Paper Implementation 


## Discriminator Architecture

In [3]:
import torch
import torch.nn as nn

- We will use CNNBlock 4 times only as we want our features scaled from 64 to 512{64->128->256->512}

In [4]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride = 2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding_mode="reflect", bias=False),  #"reflect":- pads with relfections(mirroring) of the input; Usefult to reduce edge artifacts
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )
    
    def forward(self, x):
        return self.conv(x)
        
        

- We are giving (Image + Target) as the input to the Discriminator hence we will multiply in_channels by 2 as the total number of channels are 6 i.e 3 RGB channels from Image + 3 RGB channels from Target_image

In [5]:
class Discriminator(nn.Module):
    def __init__(self, in_channels = 3, features = [64,128,256,512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels=in_channels, out_channels=feature, stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature
            
        #Add a final_convolution layer to output a single channel instead of 512
        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )
        
        #We have packed the CNNBlocks into a list; Now we will unpack it 
        self.model = nn.Sequential(*layers)
        
    def forward(self, x, target):
        x = torch.cat([x, target], dim=1)
        return self.model(self.initial(x))
        
        

In [6]:
def test():
    x = torch.randn((1,3,256,256))
    y = torch.randn((1,3,256,256))
    model = Discriminator()
    preds = model(x,y)
    print(preds.shape)
    
test()

torch.Size([1, 1, 26, 26])


## Generator Architecture

In [7]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, use_dropout = False, act="relu"):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels, kernel_size=4, stride=2, padding=1, bias=False ,padding_mode="reflect")
            if down #It means Conv2d() will act when we are in the Down phase of our UNET like Block
            else nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False), #It means ConvTranspose2d() will act when we are in the Up phase of our UNET like Block
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),   
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x
    

In [8]:
class Generator(nn.Module):
    def __init__(self, in_channels = 3, features = 64):
        super().__init__()
        #We will not use BatchNormalization in the first layer
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1,padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        ) #256----->128
        self.down1 = Block(features, features*2, act="leaky", use_dropout=False)#128----->64
        self.down2 = Block(features*2, features*4, act="leaky", use_dropout=False)#64----->32
        self.down3 = Block(features*4, features*8, act="leaky", use_dropout=False)#32----->16
        self.down4 = Block(features*8, features*8, act="leaky", use_dropout=False)#16----->8
        self.down5 = Block(features*8, features*8, act="leaky", use_dropout=False)#8----->4
        self.down6 = Block(features*8, features*8, act="leaky", use_dropout=False)#4----->2
        
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"),
            nn.ReLU(),
        )
        
        self.up1 = Block(features*8, features*8, down=False ,act="relu", use_dropout=True)
        self.up2 = Block(features*8*2, features*8, down=False ,act="relu", use_dropout=True)
        self.up3 = Block(features*8*2, features*8, down=False ,act="relu", use_dropout=True)
        self.up4 = Block(features*8*2, features*8, down=False ,act="relu", use_dropout=False)
        self.up5 = Block(features*8*2, features*4, down=False ,act="relu", use_dropout=False)
        self.up6 = Block(features*4*2, features*2, down=False ,act="relu", use_dropout=False)
        self.up7 = Block(features*2*2, features, down=False ,act="relu", use_dropout=False)
        
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=features*2, out_channels=in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(), #As we want the pixels values to be between -1 and 1
        )
        
    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        
        bottleneck = self.bottleneck(d7)
        
        #For the Up Phase we need to concatenate the up_convolution with the respective down_convolution
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], dim=1))
        up3 = self.up3(torch.cat([up2, d6], dim=1))
        up4 = self.up4(torch.cat([up3, d5], dim=1))
        up5 = self.up5(torch.cat([up4, d4], dim=1))
        up6 = self.up6(torch.cat([up5, d3], dim=1))
        up7 = self.up7(torch.cat([up6, d2], dim=1))
        
        return self.final_up(torch.cat([up7, d1], dim=1))

In [10]:
def test():
    x = torch.randn((1,3,256,256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)
    
test()

torch.Size([1, 3, 256, 256])


## Dataset Loading 

#### Albumenations For Image Pre-Processing

In [None]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

both_transform = A.Compose([
    A.Resize(width=256,height=256)],
    additional_targets={"image0":"image"},
)

transform_only_input = A.Compose(
    [
        A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)
transform_only_mask = A.Compose([
    A.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5], max_pixel_value=255.0),
    ToTensorV2(),
])



In [None]:
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset

In [None]:
class SatDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)
        print(self.list_files)
        
    def __len__(self):
        return len(self.list_files)
    
    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        
        #Since our image contains both the Input_Image and the Target_Image so we will slice the imamge from middle
        input_image = image[:, :600, :]  
        target_image = image[:, 600:, :]
        
        # augmentations = config.both_transform(image = input_image, image0 = target_image)
        # input_image , target_image = augmentations["image"], augmentations["image0"]
        # input_image = config.transform_only_input(image = input_image)["image"]
        # target_image = config.transform_only_mask(image= target_image)["image"]
        
        augmentations = both_transform(image = input_image, image0 = target_image)
        input_image , target_image = augmentations["image"], augmentations["image0"]
        input_image = transform_only_input(image = input_image)["iamge"]
        target_image = transform_only_mask(image = target_image)["image"]
        
        return input_image, target_image
        
        
        

## Train Loop

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils import save_checkpoint,save_image,save_some_examples,load_checkpoint
import config
from dataset import SatDataset
from genenrator import Generator
from discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm


In [None]:
def train_fun():
    pass

In [None]:
import torch.amp


def main():
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gene = Generator(in_channels=3).to(config.DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_gene = optim.Adam(gene.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    BCE_Loss = nn.BCEWithLogitsLoss()
    L1_Loss = nn.L1Loss()
    
    if config.LOAD_MODEL:
        load_checkpoint(config.CHECKPOINT_GEN, gene, opt_gene, lr=config.LEARNING_RATE)
        load_checkpoint(config.CHECKPOINT_DISC, disc, opt_disc, lr=config.LEARNING_RATE)
    
    train_dataset = SatDataset(root_dir="data/maps/train")   
    train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
    
    #We will do float16 training which requires less compute
    g_scaler = torch.amp.GradScaler()
    d_scaler = torch.amp.GradScaler()
    
    val_dataset = SatDataset(root_dir="data/maps/val")
    val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
    
    for epoch in tqdm(range(config.NUM_EPOCHS)):
        train_fun(disc, gene, train_loader, val_loader,opt_disc, opt_gene, L1_Loss, BCE_Loss, g_scaler, d_scaler)
        
        if config.SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint()
        
        
    