In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard
import albumentations as A
import time

# Generator

In [None]:
# Layer plan for the generator, listed in forward order.
# Each entry is (name, spec). The spec tuple appears to be
# (out_channels, kernel_size, stride, padding) -- NOTE(review): confirm; the
# Generator in this notebook only reads the name and derives channel counts and
# kernel sizes itself (e.g. it builds "U" layers with kernel size 4).
# Name prefixes as consumed by Generator: C = conv block, D = downsample,
# B = bottleneck, U = upsample.
model_architecture = [
    # Encoder
    ("C1", (64, 3, 1, 1)),
    ("D1", (64, 3, 2, 1)),
    ("C2", (128, 3, 1, 1)),
    ("D2", (128, 3, 2, 1)),
    ("C3", (256, 3, 1, 1)),
    ("D3", (256, 3, 2, 1)),
    ("C4", (512, 3, 1, 1)),
    ("D4", (512, 3, 2, 1)),
    # Bottleneck connection
    ("B", (1024, 3, 1, 1)),
    # Decoder
    ("U1", (512, 3, 2, 1)),
    ("C5", (512, 3, 1, 1)),  # input channels = 512 + 512 (from U1 and C4)
    ("U2", (256, 3, 2, 1)),
    ("C6", (256, 3, 1, 1)),  # input channels = 256 + 256 (from U2 and C3)
    ("U3", (128, 3, 2, 1)),
    ("C7", (128, 3, 1, 1)),  # input channels = 128 + 128 (from U3 and C2)
    ("U4", (64, 3, 2, 1)),
    ("C8", (64, 3, 1, 1)),  # input channels = 64 + 64 (from U4 and C1)
    # Output image
    ("C9", (3, 1, 1, 0)),
]

In [None]:
class CNNBlock(nn.Module):
  """Configurable convolutional building block.

  ``layer_type`` selects the stack stored in ``self.layers``:

  * falsy (``None``, the default): two Conv2d -> BatchNorm2d -> ReLU stages
    with reflect padding; both convolutions use the given
    ``kernel_size``/``stride``/``padding``.
  * ``"last"``: a 1x1 Conv2d (stride 1, padding 0) followed by Tanh. The
    ``kernel_size``/``stride``/``padding`` arguments are ignored.
  * ``"up"``: ConvTranspose2d -> BatchNorm2d -> ReLU. The stride is fixed at 2;
    the ``stride`` argument is ignored.
  * ``"down"``: Conv2d -> BatchNorm2d -> ReLU. The stride is fixed at 2; the
    ``stride`` argument is ignored.

  Raises:
    ValueError: if ``layer_type`` is none of the values above. Without this
      check ``self.layers`` would never be assigned and the mistake would only
      surface later as an AttributeError inside ``forward``.
  """

  def __init__(self,in_channels,out_channels,kernel_size,stride,padding,layer_type=None):
    super(CNNBlock,self).__init__()
    self.layer_type=layer_type
    if not layer_type:
      # Default block: two conv/norm/activation stages in a row.
      self.layers = nn.Sequential(
          nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,padding_mode="reflect"),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(),
          nn.Conv2d(out_channels,out_channels,kernel_size,stride,padding,padding_mode="reflect"),
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    elif layer_type=="last":
      # Output projection: 1x1 conv, Tanh keeps values in [-1, 1].
      self.layers = nn.Sequential(
          nn.Conv2d(in_channels,out_channels,1,1,0),
          nn.Tanh()
      )
    elif layer_type=="up":
      self.layers = nn.Sequential(
          nn.ConvTranspose2d(in_channels,out_channels,kernel_size,2,padding),
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    elif layer_type=="down":
      self.layers = nn.Sequential(
          nn.Conv2d(in_channels,out_channels,kernel_size,2,padding),
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    else:
      raise ValueError(
          f"Unsupported layer_type {layer_type!r}; expected None, 'last', 'up' or 'down'"
      )

  def forward(self,x):
    """Apply the configured layer stack to ``x`` and return the result."""
    return self.layers(x)

class Generator(nn.Module):
  """Encoder-decoder generator with skip connections, built from a layer plan.

  ``model_architecture`` is a sequence of entries whose first element is a
  layer name. Only the name is read here; channel counts are derived from
  ``img_channels`` and ``feature_g``:

  * ``C1``: first conv block, ``img_channels`` -> ``feature_g``.
  * ``C`` with a digit up to 4 (other than 1): conv block that doubles the
    channel count.
  * ``D*``: stride-2 downsampling block, channel count unchanged.
  * ``B*``: bottleneck conv block that doubles the channel count.
  * ``U*``: stride-2 transposed-conv block that halves the channel count.
  * ``C5``..``C8``: conv block fed with the upsampled tensor concatenated with
    a stored encoder output (hence ``2 * channels`` inputs).
  * ``C9``: final 1x1 conv + Tanh back to ``img_channels``.

  Only the single character after ``C`` is parsed as the block index.

  Args:
    img_channels: number of channels of the input (and output) image.
    model_architecture: the layer plan described above.
    feature_g: number of channels produced by the first conv block.

  Raises:
    ValueError: if an entry's name starts with a letter other than C, D, B
      or U. Such entries used to be skipped while ``forward`` still indexed
      ``self.layers`` by architecture position, so the two went out of step.
  """

  def __init__(self,img_channels,model_architecture,feature_g):
    super(Generator,self).__init__()
    self.model_architecture = model_architecture
    self.img_channels= img_channels
    self.feature_g = feature_g
    self.layers = nn.ModuleList()
    self.create_network()

  def create_network(self):
    """Append one CNNBlock to ``self.layers`` per architecture entry."""
    img_channels = self.img_channels
    for layer in self.model_architecture:
      name = layer[0]
      kind = name[0]
      if kind=="C":
        block_index = int(name[1])
        if block_index==1: # Initial layer
          self.layers.append(CNNBlock(self.img_channels,self.feature_g,3,1,1))
          img_channels = self.feature_g
        elif block_index<=4: # Encoder block: double the channels
          self.layers.append(CNNBlock(img_channels,img_channels*2,3,1,1))
          img_channels = img_channels*2
        elif block_index==9: # Last layer
          self.layers.append(CNNBlock(img_channels,self.img_channels,1,1,0,layer_type="last"))
        else: # Decoder block: input is skip connection + upsampled tensor
          self.layers.append(CNNBlock(img_channels*2,img_channels,3,1,1))
      elif kind=="D":
        self.layers.append(CNNBlock(img_channels,img_channels,3,2,1,layer_type="down"))
      elif kind=="B":
        self.layers.append(CNNBlock(img_channels,img_channels*2,3,1,1))
        img_channels = img_channels*2
      elif kind=="U":
        self.layers.append(CNNBlock(img_channels,img_channels//2,4,2,1,layer_type="up"))
        img_channels = img_channels//2
      else:
        # Keep self.layers aligned one-to-one with model_architecture.
        raise ValueError(f"Unknown layer name {name!r} in model_architecture")

  def forward(self,x):
    """Run ``x`` through the layer plan and return the generated image.

    Outputs of C1..C4 are pushed on a stack; C5..C8 pop the most recent one
    and concatenate it with the current tensor along the channel dimension.
    """
    skip_connections=[]
    for layer, block in zip(self.model_architecture, self.layers):
      name = layer[0]
      if name in ["C1","C2","C3","C4"]:
        x = block(x)
        skip_connections.append(x)
      elif name in ["C5","C6","C7","C8"]:
        x = torch.cat([skip_connections.pop(),x],dim=1)
        x = block(x)
      else: # down / bottleneck / up / last layer
        x = block(x)
    return x

In [None]:
# import torch
# import torch.nn as nn


# class Block(nn.Module):
#     def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
#         super(Block, self).__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
#             if down
#             else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(out_channels),
#             nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
#         )

#         self.use_dropout = use_dropout
#         self.dropout = nn.Dropout(0.5)
#         self.down = down

#     def forward(self, x):
#         x = self.conv(x)
#         return self.dropout(x) if self.use_dropout else x


# class Generator(nn.Module):
#     def __init__(self, in_channels=3, features=64):
#         super().__init__()
#         self.initial_down = nn.Sequential(
#             nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
#             nn.LeakyReLU(0.2),
#         )
#         self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
#         self.down2 = Block(
#             features * 2, features * 4, down=True, act="leaky", use_dropout=False
#         )
#         self.down3 = Block(
#             features * 4, features * 8, down=True, act="leaky", use_dropout=False
#         )
#         self.down4 = Block(
#             features * 8, features * 8, down=True, act="leaky", use_dropout=False
#         )
#         self.down5 = Block(
#             features * 8, features * 8, down=True, act="leaky", use_dropout=False
#         )
#         self.down6 = Block(
#             features * 8, features * 8, down=True, act="leaky", use_dropout=False
#         )
#         self.bottleneck = nn.Sequential(
#             nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
#         )

#         self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
#         self.up2 = Block(
#             features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
#         )
#         self.up3 = Block(
#             features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
#         )
#         self.up4 = Block(
#             features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False
#         )
#         self.up5 = Block(
#             features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False
#         )
#         self.up6 = Block(
#             features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False
#         )
#         self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
#         self.final_up = nn.Sequential(
#             nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
#             nn.Tanh(),
#         )

#     def forward(self, x):
#         d1 = self.initial_down(x)
#         d2 = self.down1(d1)
#         d3 = self.down2(d2)
#         d4 = self.down3(d3)
#         d5 = self.down4(d4)
#         d6 = self.down5(d5)
#         d7 = self.down6(d6)
#         bottleneck = self.bottleneck(d7)
#         up1 = self.up1(bottleneck)
#         up2 = self.up2(torch.cat([up1, d7], 1))
#         up3 = self.up3(torch.cat([up2, d6], 1))
#         up4 = self.up4(torch.cat([up3, d5], 1))
#         up5 = self.up5(torch.cat([up4, d4], 1))
#         up6 = self.up6(torch.cat([up5, d3], 1))
#         up7 = self.up7(torch.cat([up6, d2], 1))
#         return self.final_up(torch.cat([up7, d1], 1))

In [None]:
# Smoke test: a batch of two random 3x128x128 images should come back from the
# generator with an unchanged shape.
x = torch.rand(2, 3, 128, 128)
gen = Generator(img_channels=3, model_architecture=model_architecture, feature_g=64)
gen(x).shape

torch.Size([2, 3, 128, 128])

# Discriminator

In [None]:
class CNN_Block(nn.Module):
    """Conv2d (4x4 kernel, reflect padding, no bias) -> BatchNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the convolution.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=False,
            padding_mode="reflect",
        )
        stages = [convolution, nn.BatchNorm2d(out_channels), nn.LeakyReLU(0.2)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` after convolution, batch norm and activation."""
        out = self.conv(x)
        return out


class Discriminator(nn.Module):
    """Conditional discriminator scoring an (input, candidate) image pair.

    The two images are concatenated along the channel dimension, so the first
    convolution expects ``in_channels * 2`` channels. The network ends in a
    single-channel convolution with no activation, i.e. it returns raw scores.

    Args:
        in_channels: channels of each of the two images.
        features: channel widths of the successive stages; the first entry is
            used by the initial convolution, the rest by CNN_Block stages.
    """

    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        first_conv = nn.Conv2d(
            in_channels * 2,
            features[0],
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode="reflect",
        )
        self.initial = nn.Sequential(first_conv, nn.LeakyReLU(0.2))

        stages = []
        previous = features[0]
        for width in features[1:]:
            # Stages whose width equals the final width keep resolution (stride 1).
            stride = 1 if width == features[-1] else 2
            stages.append(CNN_Block(previous, width, stride=stride))
            previous = width

        # Last layer: project to one score channel.
        stages.append(
            nn.Conv2d(previous, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*stages)

    def forward(self, x, y):
        """Return the score map for input ``x`` paired with image ``y``."""
        pair = torch.cat([x, y], dim=1)
        return self.model(self.initial(pair))

# Dataloader

In [None]:
import tensorflow as tf
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from albumentations.pytorch import ToTensorV2
import pathlib

# Download and extract the chosen pix2pix dataset archive via Keras' file cache.
dataset_name = "facades"
_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'

# get_file returns the location of the downloaded archive; wrap it in a Path.
path_to_zip = pathlib.Path(
    tf.keras.utils.get_file(
        fname=f"{dataset_name}.tar.gz",
        origin=_URL,
        extract=True,
    )
)

# Dataset directory next to the downloaded archive.
PATH = path_to_zip.parent / dataset_name

# _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/maps.tar.gz'
# path_to_zip = tf.keras.utils.get_file('maps.tar.gz',
#                                       origin=_URL,
#                                       extract=True)

# PATH = os.path.join(os.path.dirname(path_to_zip), 'maps/')

Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz


In [None]:
!ls /root/.keras/datasets/facades/

test  train  val


In [None]:
# Inspect the array shape (height, width, channels) of one sample image.
np.array(Image.open("/root/.keras/datasets/facades/test/11.jpg")).shape

(256, 512, 3)

In [None]:
class MapDataset(Dataset):
    """Paired-image dataset where each file stores input and target side by side.

    Every image in ``root_dir`` is split down the middle: the left half is the
    input, the right half is the target. Geometric augmentation (resize and
    horizontal flip) is applied jointly to both halves so they stay aligned;
    colour jitter is applied to the input only. Both halves are normalised to
    [-1, 1] and returned as tensors.

    Args:
        root_dir: directory containing the side-by-side images.
        resize: side length both halves are resized to.
    """

    def __init__(self, root_dir,resize):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)

        # Applied to both halves in one call ("image" and "image0") so random
        # parameters such as the flip decision are shared between them.
        self.both_transform = A.Compose(
            [
                A.Resize(width=resize, height=resize),
                A.HorizontalFlip(p=0.5),
            ],
            additional_targets={"image0": "image"},
        )

        # Photometric augmentation only; a flip here would break the pairing
        # with the target image.
        self.transform_only_input = A.Compose(
            [
                A.ColorJitter(p=0.2),
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )

        self.transform_only_mask = A.Compose(
            [
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )

    def __len__(self):
        """Return the number of files found in ``root_dir``."""
        return len(self.list_files)

    def __getitem__(self, index):
        """Return the ``(input_image, target_image)`` tensors for file ``index``."""
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))

        # Split at the horizontal midpoint instead of a fixed column; taking
        # exactly `half` columns on each side keeps both halves the same width.
        half = image.shape[1] // 2
        input_image = image[:, :half, :]
        target_image = image[:, half:2 * half, :]

        augmentations = self.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = self.transform_only_input(image=input_image)["image"]
        target_image = self.transform_only_mask(image=target_image)["image"]

        return input_image, target_image

# Utils and Config

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image

# Let cuDNN benchmark and pick convolution algorithms for the input shapes.
torch.backends.cudnn.benchmark = True
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset locations (where the Keras download cell extracts the facades data).
TRAIN_DIR = "/root/.keras/datasets/facades/train"
VAL_DIR = "/root/.keras/datasets/facades/val"

# Optimisation settings.
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2   # worker processes for the training DataLoader
IMAGE_SIZE = 256  # passed to MapDataset as the resize target
CHANNELS_IMG = 3
L1_LAMBDA = 100   # weight of the L1 term in the generator loss
LAMBDA_GP = 10    # NOTE(review): not referenced by the code visible in this notebook
NUM_EPOCHS = 50

# Output directory for the sample images written after each epoch.
SAVE_IMG_DIR = "/content/evaluation_1"
# exist_ok=True makes this idempotent without a separate existence check.
os.makedirs(SAVE_IMG_DIR, exist_ok=True)

# Checkpointing.
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"

def save_some_examples(gen, val_loader, epoch, folder):
    """Write generator samples for the first validation batch to ``folder``.

    Saves the generated image and the input image for every call, and the
    ground-truth label image only when ``epoch == 1``. All images are mapped
    from [-1, 1] back to [0, 1] before saving. The generator is switched to
    eval mode for the forward pass and back to train mode afterwards.
    """
    inputs, targets = next(iter(val_loader))
    inputs = inputs.to(DEVICE)
    targets = targets.to(DEVICE)

    def undo_normalization(batch):
        # Inverse of Normalize(mean=0.5, std=0.5).
        return batch * 0.5 + 0.5

    gen.eval()
    with torch.no_grad():
        generated = undo_normalization(gen(inputs))
        save_image(generated, folder + f"/y_gen_{epoch}.png")
        save_image(undo_normalization(inputs), folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(undo_normalization(targets), folder + f"/label_{epoch}.png")
    gen.train()

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer).
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from ``checkpoint_file``.

    The file is expected to hold a dict with ``"state_dict"`` and
    ``"optimizer"`` entries. After loading, every optimizer parameter group
    has its learning rate set to ``lr``.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # The restored optimizer state carries the checkpoint's learning rate;
    # overwrite it so the requested value is the one actually used.
    for group in optimizer.param_groups:
        group["lr"] = lr

# Training Loop

In [None]:
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """Run one pass over ``loader``, updating discriminator then generator.

    For each batch the discriminator is trained on the real pair and on the
    detached generated pair (average of both ``bce`` losses). The generator is
    then trained on ``bce`` against an all-ones target plus ``l1_loss`` scaled
    by ``L1_LAMBDA``. Both updates run under autocast with their own gradient
    scaler. The progress bar postfix is refreshed every 10 batches.
    """
    progress = tqdm(loader, leave=True)

    for step, (inputs, targets) in enumerate(progress):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        # --- Discriminator update ---
        with torch.cuda.amp.autocast():
            generated = gen(inputs)
            real_scores = disc(inputs, targets)
            real_loss = bce(real_scores, torch.ones_like(real_scores))
            # detach(): no gradient flows into the generator during this step.
            fake_scores = disc(inputs, generated.detach())
            fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))
            disc_loss = (real_loss + fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # --- Generator update ---
        with torch.cuda.amp.autocast():
            fake_scores = disc(inputs, generated)
            adversarial_loss = bce(fake_scores, torch.ones_like(fake_scores))
            weighted_l1 = l1_loss(generated, targets) * L1_LAMBDA
            gen_loss = adversarial_loss + weighted_l1

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if step % 10 != 0:
            continue
        # Note: the "G_loss" field shows the adversarial term only.
        progress.set_postfix(
            D_loss=disc_loss.item(),
            G_loss=adversarial_loss.item(),
            G_L1_loss=weighted_l1.item(),
        )


def main():
    """Build the models, optimizers and data loaders, then train.

    Trains for ``NUM_EPOCHS`` epochs. When ``LOAD_MODEL`` is set, both
    networks and their optimizers are restored from the checkpoint files
    first. When ``SAVE_MODEL`` is set, checkpoints are written every fifth
    epoch. Sample images from the validation set are saved after every epoch.
    """
    disc = Discriminator(in_channels=3).to(DEVICE)
    # The Generator defined in this notebook takes
    # (img_channels, model_architecture, feature_g); the in_channels/features
    # keywords belong to the commented-out alternative and raise TypeError here.
    gen = Generator(img_channels=3, model_architecture=model_architecture, feature_g=64).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999),)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    # The discriminator returns raw scores, so use the logits variant of BCE.
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(
            CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE,
        )
        load_checkpoint(
            CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE,
        )

    train_dataset = MapDataset(TRAIN_DIR,IMAGE_SIZE)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    # Separate gradient scalers for the two mixed-precision updates.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    val_dataset = MapDataset(VAL_DIR,IMAGE_SIZE)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    for epoch in range(NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
        )

        if SAVE_MODEL and epoch % 5 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

        save_some_examples(gen, val_loader, epoch, folder=SAVE_IMG_DIR)

In [None]:
main()

100%|██████████| 25/25 [00:06<00:00,  3.95it/s, D_loss=0.455, G_L1_loss=54.9, G_loss=1.26]
100%|██████████| 25/25 [00:04<00:00,  5.55it/s, D_loss=0.336, G_L1_loss=42, G_loss=1.88]
100%|██████████| 25/25 [00:04<00:00,  5.42it/s, D_loss=0.175, G_L1_loss=40.5, G_loss=2.29]
100%|██████████| 25/25 [00:04<00:00,  5.28it/s, D_loss=0.139, G_L1_loss=42.2, G_loss=2.62]
100%|██████████| 25/25 [00:04<00:00,  5.51it/s, D_loss=0.452, G_L1_loss=41, G_loss=1.3]
100%|██████████| 25/25 [00:05<00:00,  4.85it/s, D_loss=0.544, G_L1_loss=41.6, G_loss=2.17]
100%|██████████| 25/25 [00:04<00:00,  5.44it/s, D_loss=0.179, G_L1_loss=35.8, G_loss=2.02]
100%|██████████| 25/25 [00:04<00:00,  5.47it/s, D_loss=0.576, G_L1_loss=33.3, G_loss=0.734]
100%|██████████| 25/25 [00:04<00:00,  5.16it/s, D_loss=0.0927, G_L1_loss=39.2, G_loss=3.03]
100%|██████████| 25/25 [00:04<00:00,  5.40it/s, D_loss=0.11, G_L1_loss=39.7, G_loss=3.45]
100%|██████████| 25/25 [00:04<00:00,  5.05it/s, D_loss=0.201, G_L1_loss=38.1, G_loss=2.77]
100

In [None]:
main()

100%|██████████| 25/25 [00:15<00:00,  1.62it/s, D_loss=0.39, G_loss=60]
100%|██████████| 25/25 [00:15<00:00,  1.61it/s, D_loss=0.282, G_loss=53.6]
100%|██████████| 25/25 [00:15<00:00,  1.58it/s, D_loss=0.285, G_loss=48.4]
100%|██████████| 25/25 [00:16<00:00,  1.55it/s, D_loss=0.167, G_loss=46.6]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s, D_loss=0.474, G_loss=46.7]
100%|██████████| 25/25 [00:16<00:00,  1.53it/s, D_loss=0.56, G_loss=42.8]
100%|██████████| 25/25 [00:16<00:00,  1.51it/s, D_loss=0.352, G_loss=38.4]
100%|██████████| 25/25 [00:16<00:00,  1.48it/s, D_loss=0.138, G_loss=42.8]
100%|██████████| 25/25 [00:16<00:00,  1.48it/s, D_loss=0.312, G_loss=37.4]
100%|██████████| 25/25 [00:16<00:00,  1.47it/s, D_loss=0.107, G_loss=38.8]
100%|██████████| 25/25 [00:17<00:00,  1.46it/s, D_loss=0.545, G_loss=33.6]
100%|██████████| 25/25 [00:17<00:00,  1.45it/s, D_loss=0.412, G_loss=38.3]
100%|██████████| 25/25 [00:17<00:00,  1.45it/s, D_loss=0.139, G_loss=38.4]
100%|██████████| 25/25 [00:17

In [None]:
# !rm *.png