<a href="https://colab.research.google.com/github/Redcxx/ucl-master-project/blob/master/pix2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Settings

## Set Up Environment

In [13]:
%pip install pydrive2 > /dev/null
%pip install torchinfo > /dev/null

In [30]:
import os
import sys
import time
import random
import functools
from datetime import datetime
from pathlib import Path
from pprint import pprint

import numpy as np
from PIL import Image

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchinfo import summary

## Configuration

In [32]:
class SessionConfig(dict):
    """Settings for one training session, readable as attributes or as dict items.

    Every setting is stored in the underlying dict, so ``cfg.lr`` and ``cfg['lr']``
    always agree. Values passed to the constructor or to ``update`` (interpreted
    like ``dict(*args, **kwargs)``) override the defaults assigned below, and the
    whole configuration is captured when the object is pickled (for example by
    ``torch.save`` inside a checkpoint).
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

        # House keeping
        self.run_id = datetime.now().strftime('%Y-%m-%d-%A-%Hh-%Mm-%Ss')
        self.random_seed = 42
        self.working_folder = 'MasterProject'  # will be created on google drive at root
        self.pydrive2_setting_file = 'settings.yaml'

        # Dataset
        self.batch_size = 1
        self.shuffle = False
        self.num_workers = 4
        self.pin_memory = False

        # Training
        self.start_epoch = 1
        self.end_epoch = 100
        self.lr = 0.0002
        self.eval_freq = 1   # eval frequency
        self.log_freq = 1    # log training losses etc interval
        self.save_freq = 10  # save training model interval

        # model
        self.generator_config = None
        self.discriminator_config = None

        # Optimizer
        self.optimizer_beta1 = 0.5
        self.optimizer_beta2 = 0.999

        # Loss
        self.l1_lambda = 100.0

        # caller-supplied values override the defaults above
        self.update(*args, **kwargs)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, i.e. for settings.
        # Raise AttributeError (not KeyError) so hasattr(), getattr(..., default)
        # and pickle's optional-hook probing behave normally.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, val):
        # attribute assignment writes the dict item, keeping both views in sync
        self[key] = val

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __repr__(self):
        dictrepr = dict.__repr__(self)
        return '%s(%s)' % (type(self).__name__, dictrepr)

    def update(self, *args, **kwargs):
        """Set every key/value from ``dict(*args, **kwargs)`` on this config."""
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

# U-Net generator layout. 'blocks' is ordered from the outermost level (right after
# the input layer) to the innermost level; the Generator builds them in reverse.
#   filters         -- channel count at that level
#   dropout         -- add Dropout(p=0.5) to that level's decoder
#   skip_connection -- concatenate the level's input onto its output (U-Net skip)
generator_config = {
    'in_channels': 3,
    'out_channels': 3,
    'blocks': [
    {
        'filters': 64,
        'dropout': False,
        'skip_connection': True
    },
    {
        'filters': 128,
        'dropout': False,
        'skip_connection': True
    },
    {
        'filters': 256,
        'dropout': False,
        'skip_connection': True
    },
    {
        'filters': 512,
        'dropout': False,
        'skip_connection': True
    },
    {
        'filters': 512,
        'dropout': True,
        'skip_connection': True
    },
    {
        'filters': 512,
        'dropout': True,
        'skip_connection': True
    },
    {
        'filters': 512,
        'dropout': True,
        'skip_connection': False
    }]
}

discriminator_config = {
    # Conditional GAN: the discriminator sees the input image stacked channel-wise
    # with either the real or the generated output image.
    'in_channels': generator_config['in_channels'] + generator_config['out_channels'],
    'blocks': [
    {
        'filters': 64,
    },
    {
        'filters': 128,
    },
    {
        'filters': 256,
    },
    {
        'filters': 512,
    }]
}


# Pass the model configs by keyword so they are stored under the
# 'generator_config' / 'discriminator_config' setting names; passing two
# positional mappings would be forwarded to dict() and raise a TypeError.
sconfig = SessionConfig(
    generator_config=generator_config,
    discriminator_config=discriminator_config,
)

print(f'RUN_ID: {sconfig.run_id}')
print(f'RANDOM_SEED: {sconfig.random_seed}')
print(f'WORKING_FOLDER: {sconfig.working_folder}')

RUN_ID: 2022-05-27-Friday-13h-45m-22s
RANDOM_SEED: 42
WORKING_FOLDER: MasterProject


## Set up `save_file` and `load_file` for saving and loading checkpoints

In [34]:
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    import shutil

    from google.colab import files, drive

    drive_dir = '/content/drive'
    drive.mount(drive_dir)

    # Checkpoints are kept in this folder of the mounted Google Drive.
    working_dir = os.path.join(drive_dir, 'My Drive', sconfig.working_folder)
    Path(working_dir).mkdir(parents=True, exist_ok=True)  # create directory if not exists on google drive

    def save_file(file_name):
        """Offer local `file_name` as a browser download and copy it into `working_dir` on Drive."""
        # save locally
        files.download(file_name)

        # save on google drive (it is mounted, so a plain file copy is enough)
        shutil.copyfile(file_name, os.path.join(working_dir, file_name))

    def load_file(file_name):
        """Copy `file_name` from `working_dir` on Drive into the current directory.

        Does nothing if the file already exists locally.
        Raises FileNotFoundError if the file is not on Drive either.
        """
        if os.path.isfile(file_name):
            print(f'"{file_name}" already exists, not downloading')
            return
        shutil.copyfile(os.path.join(working_dir, file_name), file_name)


else:
    from pydrive2.auth import GoogleAuth
    from pydrive2.drive import GoogleDrive

    def ensure_folder_on_drive(drive, folder_name):
        """Return the non-trashed Drive folder titled `folder_name`, creating it at the root if absent.

        Raises AssertionError if several such folders exist.
        """
        folders = drive.ListFile({
            # see https://developers.google.com/drive/api/guides/search-files
            # Drive search includes trashed items unless they are explicitly excluded.
            'q': "mimeType = 'application/vnd.google-apps.folder' and trashed = false"
        }).GetList()

        folders = [candidate for candidate in folders if candidate['title'] == folder_name]

        if len(folders) == 1:
            return folders[0]

        if len(folders) > 1:
            pprint(folders)
            raise AssertionError('Multiple Folders of the same name detected')

        # folder not found, create a new one at root
        print(f'Folder: {folder_name} not found, creating at root')

        folder = drive.CreateFile({
            'title': folder_name,
            "mimeType": "application/vnd.google-apps.folder"
        })
        folder.Upload()
        return folder


    g_auth = GoogleAuth(settings_file=sconfig.pydrive2_setting_file, http_timeout=None)
    g_auth.LocalWebserverAuth(host_name="localhost", port_numbers=None, launch_browser=True)
    drive = GoogleDrive(g_auth)

    folder = ensure_folder_on_drive(drive, sconfig.working_folder)

    def save_file(file_name):
        """Upload local `file_name` into the Drive working folder (a new Drive file on every call)."""
        file = drive.CreateFile({
            'title': file_name,
            'parents': [{
                'id': folder['id']
            }]
        })
        file.SetContentFile(file_name)
        # save to google drive; the file is already local, so nothing is downloaded back
        file.Upload()

    def load_file(file_name):
        """Download `file_name` from the Drive working folder into the current directory.

        Does nothing if the file already exists locally.
        Raises FileNotFoundError if no non-trashed file with that title is in the folder.
        """
        if os.path.isfile(file_name):
            print(f'"{file_name}" already exists, not downloading')
            return
        candidates = drive.ListFile({
            'q': f"'{folder['id']}' in parents and trashed = false"
        }).GetList()
        matches = [candidate for candidate in candidates if candidate['title'] == file_name]
        if not matches:
            raise FileNotFoundError(
                f'"{file_name}" not found in Google Drive folder "{sconfig.working_folder}"'
            )
        # download the first match
        drive.CreateFile({'id': matches[0]['id']}).GetContentFile(file_name)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Miscellaneous

In [17]:
# Reproducibility: seed Python's `random` (used by the data augmentation),
# NumPy and PyTorch from the session config, then CUDA when it is present.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(sconfig.random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(sconfig.random_seed)

# Training device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cpu


In [18]:
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



# Models

## UnetBlock

In [19]:
class UnetBlock(nn.Module):
    """One level of a U-Net: stride-2 encoder, optional nested submodule, stride-2 decoder.

    Layout of ``self.model``::

        Conv2d(in_filters -> sub_in_filters, k=4, s=2, p=1) [-> in_norm] [-> in_act]
        [-> submodule]
        ConvTranspose2d(decoder_in -> out_filters, k=4, s=2, p=1)
        [-> out_norm] [-> dropout] [-> out_act]

    where ``decoder_in`` is ``2 * sub_out_filters`` when the submodule concatenates
    its input onto its output (``sub_skip_connection``), else ``sub_out_filters``.

    Args:
        in_filters: channels of the block input.
        out_filters: channels produced by the decoder.
        submodule: inner block placed between encoder and decoder, or None for the
            innermost block (then ``sub_in_filters``/``sub_out_filters`` are forced to
            ``in_filters`` and ``sub_skip_connection`` to False).
        sub_in_filters: channels the encoder produces (the submodule input).
        sub_out_filters: submodule output channels before its own skip concatenation.
        sub_skip_connection: whether the submodule returns cat([input, output]).
        skip_connection: if True, ``forward`` returns ``cat([x, model(x)], dim=1)``.
        dropout, in_norm, out_norm, in_act, out_act: layer factories (classes or
            ``functools.partial`` objects); pass a falsy value to omit the layer.
    """

    def __init__(
        self,
        in_filters, out_filters,

        submodule=None,
        sub_in_filters=None,
        sub_out_filters=None,
        sub_skip_connection=False,

        skip_connection=True,
        dropout=nn.Dropout,
        in_norm=nn.BatchNorm2d, out_norm=nn.BatchNorm2d,
        in_act=nn.LeakyReLU, out_act=nn.ReLU,
    ):
        super().__init__()

        if submodule is None:
            sub_in_filters = in_filters
            sub_out_filters = in_filters
            sub_skip_connection = False

        conv_common_args = {
            'kernel_size': 4,
            'stride': 2,
            'padding': 1,
        }

        layers = []

        # encoder
        # A conv feeding straight into BatchNorm needs no bias: BatchNorm adds its
        # own learnable shift. Each conv checks the norm that actually follows it.
        layers.append(nn.Conv2d(
            in_channels=in_filters, out_channels=sub_in_filters,
            bias=not self._is_batch_norm(in_norm), **conv_common_args,
        ))

        if in_norm:
            layers.append(in_norm(sub_in_filters))

        if in_act:
            layers.append(in_act())

        # submodule
        if submodule is not None:
            layers.append(submodule)

        # decoder
        decoder_in = sub_out_filters * 2 if sub_skip_connection else sub_out_filters
        layers.append(nn.ConvTranspose2d(
            in_channels=decoder_in, out_channels=out_filters,
            bias=not self._is_batch_norm(out_norm), **conv_common_args,
        ))

        if out_norm:
            layers.append(out_norm(out_filters))

        if dropout:
            layers.append(dropout())

        if out_act:
            layers.append(out_act())

        self.model = nn.Sequential(*layers)

        self.skip_connection = skip_connection

    @staticmethod
    def _is_batch_norm(norm_factory):
        """True if `norm_factory` is nn.BatchNorm2d or a functools.partial wrapping it."""
        while isinstance(norm_factory, functools.partial):
            norm_factory = norm_factory.func
        return norm_factory is nn.BatchNorm2d

    def forward(self, x):
        if self.skip_connection:
            return torch.cat([x, self.model(x)], dim=1)
        else:
            return self.model(x)

In [20]:
# summary(
#     UnetBlock(in_filters=64, out_filters=64, submodule=None), 
#     input_size=(16, 64, 16, 16),
#     col_names=['output_size', 'num_params', 'mult_adds']
# )

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds
UnetBlock                                --                        --                        --
├─Sequential: 1-1                        [16, 64, 16, 16]          --                        --
│    └─Conv2d: 2-1                       [16, 64, 8, 8]            65,536                    67,108,864
│    └─BatchNorm2d: 2-2                  [16, 64, 8, 8]            128                       2,048
│    └─LeakyReLU: 2-3                    [16, 64, 8, 8]            --                        --
│    └─ConvTranspose2d: 2-4              [16, 64, 16, 16]          65,536                    268,435,456
│    └─BatchNorm2d: 2-5                  [16, 64, 16, 16]          128                       2,048
│    └─Dropout: 2-6                      [16, 64, 16, 16]          --                        --
│    └─ReLU: 2-7                         [16, 64, 16, 16]          --                        --
Total para

## Generator

In [21]:
class Generator(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock``s.

    ``config['blocks']`` is ordered from the outermost level (right after the input
    layer) to the innermost level. Each entry supplies ``filters``, ``dropout`` (add
    Dropout(p=0.5) to that level's decoder) and ``skip_connection`` (concatenate the
    level's input onto its output). An extra outermost block maps
    ``config['in_channels']`` to ``config['out_channels']`` and ends in Tanh.
    """

    def __init__(self, config):
        super().__init__()

        # dependency injection
        batch_norm = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        relu = functools.partial(nn.ReLU, inplace=True)
        leaky_relu = functools.partial(nn.LeakyReLU, inplace=True, negative_slope=0.2)
        dropout = functools.partial(nn.Dropout, p=0.5)
        tanh = nn.Tanh

        # build model recursively inside-out
        blocks = config['blocks'][::-1]

        # build innermost block
        # Its skip flag comes from the config, because the next block out sizes its
        # decoder input from the same flag (sub_skip_connection below).
        model = UnetBlock(
            in_filters=blocks[0]['filters'],
            out_filters=blocks[0]['filters'],

            submodule=None,
            sub_in_filters=None,
            sub_out_filters=None,
            sub_skip_connection=False,

            skip_connection=blocks[0]['skip_connection'],
            dropout=dropout if blocks[0]['dropout'] else None,
            in_norm=None, out_norm=None,
            in_act=relu, out_act=relu
        )

        # build between blocks, each wrapping the one built before it
        for i, layer in enumerate(blocks[1:], 1):
            model = UnetBlock(
                in_filters=layer['filters'],
                out_filters=layer['filters'],

                submodule=model,
                sub_in_filters=blocks[i-1]['filters'],
                sub_out_filters=blocks[i-1]['filters'],
                sub_skip_connection=blocks[i-1]['skip_connection'],

                skip_connection=layer['skip_connection'],
                dropout=dropout if layer['dropout'] else None,
                in_norm=batch_norm, out_norm=batch_norm,
                in_act=leaky_relu, out_act=relu
            )

        # build outermost block: image channels in, image channels out, no skip
        self.model = UnetBlock(
            in_filters=config['in_channels'],
            out_filters=config['out_channels'],

            submodule=model,
            sub_in_filters=blocks[-1]['filters'],
            sub_out_filters=blocks[-1]['filters'],
            sub_skip_connection=blocks[-1]['skip_connection'],

            skip_connection=False,
            dropout=None,
            in_norm=None, out_norm=None,
            in_act=leaky_relu, out_act=tanh
        )

    def forward(self, x):
        return self.model(x)

In [None]:
# summary(
#     Generator(generator_config),
#     input_size=(16, 3, 256, 256),
#     col_names=['output_size', 'num_params', 'mult_adds'],
#     depth=24
# )

## Discriminator

In [23]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a 1-channel map of per-patch scores.

    config = {'in_channels': int, 'blocks': [{'filters': int}, ...]}

    Every block is Conv2d(kernel 4, padding 1) followed by LeakyReLU(0.2). All blocks
    except the first also get BatchNorm, and their conv drops its bias because
    BatchNorm supplies one. All blocks except the last use stride 2; the last uses
    stride 1 and is followed by a Conv2d down to one output channel. Filter counts
    are capped at 8x the first block's filters.
    """

    def __init__(self, config):
        super().__init__()

        make_norm = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        make_act = functools.partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
        conv_kwargs = dict(kernel_size=4, padding=1)

        specs = config['blocks']
        base_width = specs[0]['filters']
        max_width = base_width * 8

        # first block: no normalisation, conv keeps its bias
        layers = [
            nn.Conv2d(config['in_channels'], base_width, stride=2, **conv_kwargs),
            make_act(),
        ]

        # intermediate blocks: stride-2 down-sampling
        width = base_width
        for spec in specs[1:-1]:
            next_width = min(spec['filters'], max_width)
            layers.extend([
                nn.Conv2d(width, next_width, bias=False, stride=2, **conv_kwargs),
                make_norm(next_width),
                make_act(),
            ])
            width = next_width

        # last block keeps stride 1, then project to a single score channel
        final_width = min(specs[-1]['filters'], max_width)
        layers.extend([
            nn.Conv2d(width, final_width, stride=1, bias=False, **conv_kwargs),
            make_norm(final_width),
            make_act(),
            nn.Conv2d(final_width, 1, stride=1, **conv_kwargs),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [24]:
# summary(
#     Discriminator(discriminator_config),
#     input_size=(16, 3, 256, 256),
#     col_names=['output_size', 'num_params', 'mult_adds'],
#     depth=24
# )

Layer (type:depth-idx)                   Output Shape              Param #                   Mult-Adds
Discriminator                            --                        --                        --
├─Sequential: 1-1                        [16, 1, 30, 30]           --                        --
│    └─Conv2d: 2-1                       [16, 64, 128, 128]        3,136                     822,083,584
│    └─LeakyReLU: 2-2                    [16, 64, 128, 128]        --                        --
│    └─Conv2d: 2-3                       [16, 128, 64, 64]         131,072                   8,589,934,592
│    └─BatchNorm2d: 2-4                  [16, 128, 64, 64]         256                       4,096
│    └─LeakyReLU: 2-5                    [16, 128, 64, 64]         --                        --
│    └─Conv2d: 2-6                       [16, 256, 32, 32]         524,288                   8,589,934,592
│    └─BatchNorm2d: 2-7                  [16, 256, 32, 32]         512                       8,

# Data

In [None]:
%%bash

# Stop at the first failing command (e.g. a failed download) instead of
# continuing with a missing or partial archive.
set -euo pipefail

FILE="facades"

if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
  echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
  exit 1
fi

if [[ $FILE == "cityscapes" ]]; then
    echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
    echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
    exit 1
fi

echo "Specified [$FILE]"

URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
TAR_FILE=./datasets/$FILE.tar.gz
TARGET_DIR=./datasets/$FILE/

# Re-running the notebook should not download the archive again.
if [[ -d "${TARGET_DIR}train" ]]; then
  echo "[$TARGET_DIR] already exists, skipping download"
  exit 0
fi

# wget -O cannot create missing directories, so create ./datasets/<name>/ first.
mkdir -p "$TARGET_DIR"
# -N (timestamping) has no effect together with -O, so it is not used.
wget -nv "$URL" -O "$TAR_FILE"
tar -zxf "$TAR_FILE" -C ./datasets/
rm "$TAR_FILE"

## Utilities

In [26]:
# File extensions (compared case-insensitively) treated as images.
IMG_EXTENSIONS = [
    '.jpg', '.jpeg',
    '.png', '.ppm', '.bmp',
    '.tif', '.tiff',
]

def is_image_file(filename):
    """Return True if `filename` ends with one of IMG_EXTENSIONS, ignoring case."""
    return filename.lower().endswith(tuple(IMG_EXTENSIONS))

def get_all_image_paths(root):
    """Recursively collect the paths of all image files under directory `root`.

    Directories are visited in sorted path order and files are sorted within each
    directory, so the result is deterministic.

    Raises:
        NotADirectoryError: if `root` is not an existing directory.
    """
    # explicit check instead of `assert`, which is stripped when running with -O
    if not os.path.isdir(root):
        raise NotADirectoryError(f'{root!r} is not a directory')

    paths = []
    for dirpath, _dirnames, filenames in sorted(os.walk(root)):
        for filename in sorted(filenames):
            if is_image_file(filename):
                paths.append(os.path.join(dirpath, filename))

    return paths

## Dataset

In [27]:
class MyDataset(Dataset):
    """Paired-image dataset in the pix2pix side-by-side layout.

    Each image file under `root` holds image A on its left half and image B on its
    right half. Both halves get the same random resize/crop/flip so they stay
    aligned, then become tensors normalised from [0, 1] to [-1, 1].
    """

    def __init__(self, root):
        self.paths = sorted(get_all_image_paths(root))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        A, B = self._split_input_output(self._read_im(self.paths[i]))

        # One transform (one set of random parameters) shared by A and B.
        transform = self._generate_transform()

        return transform(A), transform(B)

    def _read_im(self, path):
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as im:
            return im.convert('RGB')

    @staticmethod
    def _split_input_output(AB):
        """Split a side-by-side image into its left (A) and right (B) halves."""
        w, h = AB.size
        w2 = w // 2
        A = AB.crop((0, 0, w2, h))
        B = AB.crop((w2, 0, w, h))

        return A, B

    def _generate_transform(self):
        """Build a transform with freshly drawn random crop offsets and flip decision."""
        new_size = 286
        old_size = 256

        rand_x = random.randint(0, new_size - old_size)
        rand_y = random.randint(0, new_size - old_size)
        flip = random.random() > 0.5

        return transforms.Compose([
            transforms.Resize((new_size, new_size), interpolation=InterpolationMode.BICUBIC, antialias=True),
            transforms.Lambda(lambda im: self._crop(im, (rand_x, rand_y), (old_size, old_size))),
            transforms.Lambda(lambda im: self._flip(im, flip)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])


    def _flip(self, im, flip):
        if flip:
            return im.transpose(Image.FLIP_LEFT_RIGHT)
        return im

    def _crop(self, im, pos, size):
        return im.crop((pos[0], pos[1], pos[0] + size[0], pos[1] + size[1]))

# Criterion

In [None]:
class GANLoss(nn.Module):
    """Least-squares GAN loss: MSE between discriminator outputs and a constant label.

    Args:
        real_label: target value for outputs that should be judged real.
        fake_label: target value for outputs that should be judged fake.
    """

    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()

        # Buffers follow the module through .to(device) and are saved in state_dict.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))

        self.loss = nn.MSELoss()

    def forward(self, model_output, target_is_real):
        """Return the MSE between `model_output` and the real/fake label broadcast to its shape."""
        label = self.real_label if target_is_real else self.fake_label
        # expand_as returns a new view, so its result must be kept. Matching
        # device/dtype also works when this module was never moved with .to(device).
        label = label.to(device=model_output.device, dtype=model_output.dtype).expand_as(model_output)

        return self.loss(model_output, label)

# Training

## Utilities

In [None]:
def format_time(seconds):
    """Format a duration in seconds as 'HHh:MMm:SSs' (fractional seconds truncated).

    Hours keep counting past 24 instead of wrapping around, so long runs are
    reported correctly.
    """
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f'{hours:02d}h:{minutes:02d}m:{secs:02d}s'

def save_checkpoint(net_G, net_D, optimizer_G, optimizer_D, epoch, upload=True):
    """Write both networks, their optimisers, the session config and `epoch` to a checkpoint.

    The file is named '<run_id>_epoch_<epoch>.ckpt' and written to the current
    directory. When `upload` is True it is also handed to `save_file`, because
    `load_checkpoint` retrieves checkpoints through `load_file`.

    Returns:
        The checkpoint file name.
    """
    file_name = f'{sconfig.run_id}_epoch_{epoch}.ckpt'
    torch.save({
        'net_G_state_dict': net_G.state_dict(),
        'net_D_state_dict': net_D.state_dict(),
        'net_G_optimizer_state_dict': optimizer_G.state_dict(),
        'net_D_optimizer_state_dict': optimizer_D.state_dict(),
        'session_config': sconfig,
        'epoch': epoch
    }, file_name)

    if upload:
        save_file(file_name)

    return file_name

def load_checkpoint(run_id, epoch, device=None):
    """Restore networks, optimisers and session config from '<run_id>_epoch_<epoch>.ckpt'.

    The file is fetched with `load_file` if it is not already present locally.

    Args:
        run_id: run id the checkpoint was saved under.
        epoch: epoch number the checkpoint was saved at.
        device: optional device to move both networks to *before* the optimisers
            are rebuilt, so the restored optimiser state lands on the same device.

    Returns:
        (session_config, net_G, net_D, optimizer_G, optimizer_D, epoch)
    """
    file_name = f'{run_id}_epoch_{epoch}.ckpt'
    load_file(file_name)  # ensure exists locally
    # map_location='cpu' lets checkpoints written on a GPU runtime load on a CPU-only one.
    checkpoint = torch.load(file_name, map_location='cpu')

    saved_config = checkpoint['session_config']

    def setting(key):
        # Settings may be stored as dict items or only as attributes; prefer the item.
        if isinstance(saved_config, dict) and key in saved_config:
            return saved_config[key]
        return getattr(saved_config, key)

    net_G = Generator(setting('generator_config'))
    net_D = Discriminator(setting('discriminator_config'))
    net_G.load_state_dict(checkpoint['net_G_state_dict'])
    net_D.load_state_dict(checkpoint['net_D_state_dict'])

    if device is not None:
        net_G.to(device)
        net_D.to(device)

    adam_kwargs = {
        'lr': setting('lr'),
        'betas': (setting('optimizer_beta1'), setting('optimizer_beta2')),
    }
    optimizer_G = optim.Adam(net_G.parameters(), **adam_kwargs)
    optimizer_D = optim.Adam(net_D.parameters(), **adam_kwargs)
    optimizer_G.load_state_dict(checkpoint['net_G_optimizer_state_dict'])
    optimizer_D.load_state_dict(checkpoint['net_D_optimizer_state_dict'])

    return saved_config, net_G, net_D, optimizer_G, optimizer_D, checkpoint['epoch']

def set_requires_grad(net, requires_grad):
    """Set `requires_grad` on every parameter of `net`, e.g. to keep D out of the generator's backward pass."""
    for _name, weight in net.named_parameters():
        weight.requires_grad = requires_grad


In [None]:
def train_batch(net_G, net_D, optimizer_G, optimizer_D, real_A, real_B, criterion_gan, criterion_l1):
    """Run one pix2pix step: update the discriminator first, then the generator.

    Args:
        net_G, net_D: generator and discriminator.
        optimizer_G, optimizer_D: their optimisers.
        real_A: input images; real_B: matching target images.
        criterion_gan: called as criterion_gan(prediction, target_is_real).
        criterion_l1: pixel loss between generated and target images,
            weighted by sconfig.l1_lambda.

    Returns:
        (generator loss, discriminator loss) as Python floats.
    """

    def paired_with_input(image_B):
        # the conditional discriminator judges (input, output) pairs stacked on channels
        return torch.cat((real_A, image_B), dim=1)

    # --- discriminator: D(A, G(A)) -> fake, D(A, B) -> real ---
    set_requires_grad(net_D, True)
    fake_B = net_G(real_A)

    # detach(): this step must not send gradients into the generator
    d_fake_loss = criterion_gan(net_D(paired_with_input(fake_B).detach()), False)
    d_real_loss = criterion_gan(net_D(paired_with_input(real_B)), True)

    optimizer_D.zero_grad()
    loss_D = (d_fake_loss + d_real_loss) * 0.5
    loss_D.backward()
    optimizer_D.step()

    # --- generator: fool D (its parameters excluded from autograd) and match B in L1 ---
    set_requires_grad(net_D, False)
    g_adv_loss = criterion_gan(net_D(paired_with_input(fake_B)), True)
    g_l1_loss = criterion_l1(fake_B, real_B) * sconfig.l1_lambda

    optimizer_G.zero_grad()
    loss_G = g_adv_loss + g_l1_loss
    loss_G.backward()
    optimizer_G.step()

    return loss_G.item(), loss_D.item()

In [29]:
dataset = MyDataset('./datasets/facades/train')
dataloader = DataLoader(
    dataset,
    batch_size=sconfig.batch_size,
    shuffle=sconfig.shuffle,
    num_workers=sconfig.num_workers,
    pin_memory=sconfig.pin_memory,
)

net_G = Generator(generator_config)
net_D = Discriminator(discriminator_config)

criterion_l1 = nn.L1Loss()
criterion_gan = GANLoss()

optimizer_G = optim.Adam(net_G.parameters(), lr=sconfig.lr, betas=(sconfig.optimizer_beta1, sconfig.optimizer_beta2))
optimizer_D = optim.Adam(net_D.parameters(), lr=sconfig.lr, betas=(sconfig.optimizer_beta1, sconfig.optimizer_beta2))

net_G.to(device)
net_D.to(device)
# GANLoss keeps its real/fake labels in buffers; put them next to the model outputs.
criterion_gan.to(device)

training_start_time = time.time()
for epoch in range(sconfig.start_epoch, sconfig.end_epoch + 1):
    epoch_start_time = time.time()
    is_first_epoch = epoch == sconfig.start_epoch

    net_G.train()
    net_D.train()
    epoch_train_G_losses = []
    epoch_train_D_losses = []
    for inp, tar in dataloader:
        inp = inp.to(device)
        tar = tar.to(device)

        loss_G, loss_D = train_batch(
            net_G, net_D, optimizer_G, optimizer_D, inp, tar, criterion_gan, criterion_l1
        )

        epoch_train_G_losses.append(loss_G)
        epoch_train_D_losses.append(loss_D)

    epoch_eval_loss = None
    if epoch % sconfig.eval_freq == 0 or is_first_epoch:
        net_G.eval()
        net_D.eval()
        # TODO: evaluation is not implemented yet (no validation loader exists);
        # assign epoch_eval_loss here once it is.

    if epoch % sconfig.log_freq == 0 or is_first_epoch:
        log_parts = [
            f'[Epoch={epoch}]',
            f'[TrainLossG={np.mean(epoch_train_G_losses):.4f}]',
            f'[TrainLossD={np.mean(epoch_train_D_losses):.4f}]',
        ]
        if epoch_eval_loss is not None:
            log_parts.append(f'[EvalLoss={epoch_eval_loss:.4f}]')
        log_parts.append(f'[EpochTime={format_time(time.time() - epoch_start_time)}]')
        log_parts.append(f'[TrainTime={format_time(time.time() - training_start_time)}]')
        print(' '.join(log_parts))

    # also save after the final epoch so the finished model is never lost
    if epoch % sconfig.save_freq == 0 or is_first_epoch or epoch == sconfig.end_epoch:
        save_checkpoint(net_G, net_D, optimizer_G, optimizer_D, epoch)

# Evaluation