Just press 'Runtime' -> 'Run all' and the program will start.

By default the notebook uses pretrained models; if training is needed, change the last block of code according to the instructions given there.

In [None]:
import numpy as np
import cv2
import pickle
import warnings
warnings.simplefilter("ignore", UserWarning)
from skimage.color import rgb2lab, rgb2gray, lab2rgb
from skimage.io import imread, imshow
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
from PIL import Image, ImageCms
from torchsummary import summary
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from torchvision.transforms import InterpolationMode
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg16, vgg19
import torch.optim as optim
import glob
import sys
import tqdm
import os
import shutil


# Flipped to True by the optional GDrive cell; when True the training
# functions also copy images, losses and models to Google Drive.
GDRIVE_MOUNT = False
# Local output directories (created further down if missing).
models_path = '/content/saved_models/'
saved_images_path = '/content/saved_images/'
saved_images_test_path = '/content/saved_images_test/'
# Google Drive folder used for checkpoints, losses and image copies.
gdrive_saved_models = "/content/gdrive/MyDrive/Computer_Science/Colab/project/unet_only/"

# Selects which branch below runs: 'test_256', 'small_test_256' or 'landscape'.
data_path = 'landscape'

# Each branch downloads its archive only when the archive file is not already
# present in the working directory, then unpacks it into /content/data/.

# dataset of 328K images
if data_path == 'test_256':
    if not os.path.isfile('test_256.tar'):
        !wget 'http://data.csail.mit.edu/places/places365/test_256.tar'
        shutil.unpack_archive("test_256.tar", "/content/data/")

# dataset of 20K images
elif data_path == 'small_test_256':
    if not os.path.isfile('small_test_256.zip'):
      !gdown --id '1_u1S2C6zuBgeT__QukSUP41sjs8uOg_M'
      shutil.unpack_archive("small_test_256.zip", "/content/data/")

# dataset of landscapes
elif data_path == 'landscape':
    # Two separate archives are fetched for this dataset.
    if not os.path.isfile('landscape_data.zip'):
        !gdown --id '1gCMCD1CrTu1QK_KCoSROy6rEFiAfC9QB'
        shutil.unpack_archive("landscape_data.zip", "/content/data/")
    if not os.path.isfile('landscape_test.zip'):
        !gdown --id '188faJSg2spTlt25ndD9jddR5DbXTsqQu'
        shutil.unpack_archive("landscape_test.zip", "/content/data/")

# Make sure every output directory exists before anything is written to it.
for _out_dir in (saved_images_path, saved_images_test_path, models_path):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)

Downloading...
From: https://drive.google.com/uc?id=1gCMCD1CrTu1QK_KCoSROy6rEFiAfC9QB
To: /content/landscape_data.zip
651MB [00:05, 120MB/s] 
Downloading...
From: https://drive.google.com/uc?id=188faJSg2spTlt25ndD9jddR5DbXTsqQu
To: /content/landscape_test.zip
7.66MB [00:00, 24.4MB/s]


In [None]:
## Uncomment to also use GDrive, for debugging

# from google.colab import drive
# drive.mount('/content/gdrive')
# from google.colab import files

# GDRIVE_MOUNT = True

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): my_lab2rgb squeezes dim 0 of its inputs, which presumably
# relies on a batch size of 1 -- confirm before raising this value.
BATCH_SIZE = 1
# Learning rate passed to the Adam optimizers.
LEARNING_RATE = 2e-5
# Number of DataLoader worker processes.
NUM_WORKERS = 4
# Number of epochs run by the train_* entry points.
NUM_EPOCHS = 50
# Directory the dataset archives are unpacked into.
ROOT_PATH = "/content/data/"
# Target size handed to transforms.Resize.
IM_SIZE = (512, 512)
# IM_SIZE = (256, 256)
# LOAD_MODEL / SAVE_MODEL are not referenced in the visible part of the file.
LOAD_MODEL = False
SAVE_MODEL = True

In [None]:
# Utils

# Data transforms
# Training: resize to IM_SIZE and flip horizontally with probability 0.5.
train_transform = transforms.Compose([
    transforms.Resize(IM_SIZE),
    transforms.RandomHorizontalFlip(p=0.5)
])

# Default (used by load_dataset when test=True): resize only, no augmentation.
default_transform = transforms.Compose([
    transforms.Resize(IM_SIZE)
])

# custom dataset for colorization task
class ColorizationDataset(Dataset):
    """Dataset yielding normalised ``(L, ab)`` tensor pairs for colorization.

    Files are addressed by number rather than by directory listing order:
    item ``i`` is the file ``<prefix><first_idx + i + 1, 8 digits>.jpg`` inside
    ``path``.  ``first_idx`` and ``<prefix>`` are taken from the alphabetically
    smallest file name, whose last 12 characters are expected to be an
    8-digit number followed by ``.jpg``.

    Args:
        path: Directory containing the numbered images.
        transform: Optional callable applied to the PIL RGB image before it is
            converted to L*a*b.

    Raises:
        ValueError: If ``path`` is empty or the smallest file name does not
            end in an 8-digit number plus a 4-character extension.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        self.root_path = path
        names = os.listdir(self.root_path)
        if not names:
            raise ValueError(f"No image files found in '{self.root_path}'")
        # min() replaces the former scan seeded with 'Z', which never matched
        # file names sorting after 'Z' (e.g. lowercase prefixes).
        first = min(names)
        self.first_idx = int(first[-12:-4]) - 1
        # Anything in front of the 8-digit number (empty for plain names) so
        # that prefixed files such as "name_00000001.jpg" can be reopened.
        self._prefix = first[:-12]

    def __len__(self):
        return len(os.listdir(self.root_path))

    def __getitem__(self, index):
        """Return the ``(l, ab)`` tensors of the ``index``-th image."""
        # Convert tensor indices to plain Python values before any arithmetic.
        if torch.is_tensor(index):
            index = index.tolist()
        number = index + self.first_idx + 1
        file_name = f"{self._prefix}{number:08d}.jpg"

        im_rgb = Image.open(f'{self.root_path}/{file_name}').convert("RGB")
        if self.transform:
            im_rgb = self.transform(im_rgb)
        im_rgb = np.array(im_rgb)
        im_lab = rgb2lab(im_rgb).astype("float32")  # Converting RGB to L*a*b
        im_lab = transforms.ToTensor()(im_lab)  # channel-first tensor

        # LAB colorspace normalization
        l = im_lab[[0], ...] / 50. - 1.  # Between -1 and 1
        ab = im_lab[[1, 2], ...] / 110.  # Between -1 and 1

        # l = grey, ab = color
        return l, ab


# Load dataset
def load_dataset(path, test = False, shuffle=True):
    """Build a DataLoader over the images stored in ``path``.

    Args:
        path: Directory handed to ``ColorizationDataset``.
        test: When exactly ``True`` the resize-only ``default_transform`` is
            used; otherwise ``train_transform``.
        shuffle: When truthy, samples are drawn through a
            ``SubsetRandomSampler``; otherwise sequentially.

    Returns:
        A ``DataLoader`` using the module-level ``BATCH_SIZE`` and
        ``NUM_WORKERS``.
    """
    transform = default_transform if test is True else train_transform
    dataset = ColorizationDataset(path, transform)

    indices = list(range(len(dataset)))
    if not shuffle:
        sampler = torch.utils.data.sampler.SequentialSampler(indices)
    else:
        np.random.shuffle(indices)
        sampler = torch.utils.data.sampler.SubsetRandomSampler(indices)

    return torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )


# Denormalize images
def denormalize(images, means, stds):
    """Undo a per-channel normalisation, i.e. return ``images * stds + means``.

    ``means`` and ``stds`` must each hold three values; they are reshaped to
    ``(1, 3, 1, 1)`` so they broadcast over the other dimensions of ``images``.
    """
    per_channel = (1, 3, 1, 1)
    shift = torch.tensor(means).reshape(per_channel)
    scale = torch.tensor(stds).reshape(per_channel)
    return images * scale + shift


# Convert lab image to rgb
def my_lab2rgb(L, AB, denormalize = True):
    """Convert an L tensor and an ab tensor into a channel-first RGB array.

    Dimension 0 of both inputs is squeezed (a no-op unless its size is 1) and
    the results are concatenated along the new dimension 0.

    Args:
        L: Lightness tensor.
        AB: Colour (a, b) tensor.
        denormalize: When exactly ``True``, the dataset normalisation is
            reverted first (``ab * 110`` and ``(L + 1) * 50``).

    Returns:
        ``numpy.float64`` array produced by ``lab2rgb``, transposed back to
        channel-first order.
    """
    lightness = torch.squeeze(L, dim=0)
    chroma = torch.squeeze(AB, dim=0)
    if denormalize is True:
        chroma = chroma * 110.0
        lightness = (lightness + 1.0) * 50.0

    lab_chw = torch.cat((lightness, chroma), dim=0).data.cpu().float().numpy()
    # lab2rgb is fed channel-last data, so move the channel axis to the end.
    lab_hwc = np.transpose(lab_chw.astype(np.float64), (1, 2, 0))
    rgb_hwc = lab2rgb(lab_hwc)
    return np.transpose(rgb_hwc.astype(np.float64), (2, 0, 1))


# Delete contents of provided path
def delete_images(path):
    """Empty the directory ``path`` by removing it and creating it again.

    A missing directory is not an error: removal failures are ignored and the
    directory is (re)created afterwards, so ``path`` always exists on return.
    """
    shutil.rmtree(path, ignore_errors=True)
    # exist_ok: rmtree may have silently failed and left the directory behind.
    os.makedirs(path, exist_ok=True)


# Find normalizing parameters
def find_mean_std(loader):
    """Print and return the per-channel mean and std over a data loader.

    For every sample the mean and standard deviation are taken over all
    non-channel positions; these per-sample statistics are summed and divided
    by ``len(loader.dataset)``.

    Args:
        loader: Iterable of batches.  A batch is either a single tensor or a
            list/tuple of tensors (such as the ``(l, ab)`` pairs produced by
            this file's loaders), which are concatenated along dimension 1.

    Returns:
        Tuple ``(mean, std)`` of per-channel tensors (also printed).
    """
    mean = 0.
    std = 0.
    with tqdm.tqdm(total=(len(loader)), file=sys.stdout) as pbar:
        for images in loader:
            # Loaders yielding several tensors per batch would otherwise fail
            # on `.size`; merge them channel-wise first.
            if isinstance(images, (list, tuple)):
                images = torch.cat(images, dim=1)
            batch_samples = images.size(0)
            images = images.view(batch_samples, images.size(1), -1)
            mean += images.mean(2).sum(0)
            std += images.std(2).sum(0)
            pbar.update()
    mean /= len(loader.dataset)
    std /= len(loader.dataset)
    print(mean, std)
    return mean, std


# Save checkpoint
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


# Load checkpoint
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a checkpoint file, in place.

    The file is expected to hold a dict with ``"state_dict"`` and
    ``"optimizer"`` entries.  Tensors are mapped onto ``DEVICE``.  After
    loading, the learning rate of every parameter group is overwritten
    with ``lr``.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The stored learning rate is discarded in favour of the requested one.
    for group in optimizer.param_groups:
        group["lr"] = lr


# Set whether model requires grad
def set_requires_grad(model, requires_grad=True):
    """Switch gradient tracking on or off for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(requires_grad)


# Class to keep track of loss
class LossClassPix2pix:
    """Per-iteration holder for the Pix2Pix / U-Net loss values.

    Every field starts at 0 and is overwritten by the training loop.
    """

    _FIELDS = (
        "loss_D_fake",
        "loss_D_real",
        "loss_D",
        "loss_G_GAN",
        "loss_G_L1",
        "loss_G",
        "loss_VGG",
    )

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, 0)


class LossClassCycle:
    """Per-iteration holder for the CycleGAN loss values.

    Every field starts at 0 and is overwritten by the training loop.
    """

    _FIELDS = (
        "loss_D_C_fake",
        "loss_D_G_fake",
        "loss_D_C_real",
        "loss_D_G_real",
        "loss_D_C",
        "loss_D_G",
        "loss_G_C_GAN",
        "loss_G_G_GAN",
        "loss_G_C_L1",
        "loss_G_G_L1",
        "loss_G_C_Cycle",
        "loss_G_G_Cycle",
        "loss_G_C",
        "loss_G_G",
        "loss_VGG",
    )

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, 0)


class AverageMeter:
    """Running weighted average that also keeps its full history.

    Attributes (all reset by ``reset``):
        count: Sum of the weights seen so far.
        sum: Weighted sum of the values seen so far.
        avg: ``sum / count`` after the latest update.
        avg_list: Value of ``avg`` after each update.
        clean_list: Raw value passed to each update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated statistics and both history lists."""
        self.count = 0.
        self.avg = 0.
        self.sum = 0.
        self.avg_list = []
        self.clean_list = []

    def update(self, val, count=1):
        """Add ``val`` with weight ``count`` and refresh the running average."""
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count
        self.clean_list.append(val)
        self.avg_list.append(self.avg)


def create_loss_meters_pix2pix():
    """Return a dict with one fresh ``AverageMeter`` per Pix2Pix loss name.

    The keys match the attribute names of ``LossClassPix2pix``.
    """
    meter_names = (
        'loss_D_fake',
        'loss_D_real',
        'loss_D',
        'loss_G_GAN',
        'loss_G_L1',
        'loss_G',
        'loss_VGG',
    )
    return {meter_name: AverageMeter() for meter_name in meter_names}


def create_loss_meters_cycle():
    """Return a dict with one fresh ``AverageMeter`` per CycleGAN loss name.

    The keys match the attribute names of ``LossClassCycle``.
    """
    meter_names = (
        'loss_D_C_fake',
        'loss_D_G_fake',
        'loss_D_C_real',
        'loss_D_G_real',
        'loss_D_C',
        'loss_D_G',
        'loss_G_C_GAN',
        'loss_G_G_GAN',
        'loss_G_C_L1',
        'loss_G_G_L1',
        'loss_G_C_Cycle',
        'loss_G_G_Cycle',
        'loss_G_C',
        'loss_G_G',
        'loss_VGG',
    )
    return {meter_name: AverageMeter() for meter_name in meter_names}


def update_losses(losses, losses_dict, count):
    """Feed the current loss values into their running-average meters.

    Args:
        losses: Object exposing one attribute per key of ``losses_dict``.
        losses_dict: Mapping of loss name to a meter with an
            ``update(value, count=...)`` method.
        count: Weight passed to every meter update.
    """
    for name in losses_dict:
        value = getattr(losses, name)
        # Objects with `.item()` (e.g. tensors) are reduced to plain scalars.
        scalar = value.item() if hasattr(value, 'item') else value
        losses_dict[name].update(scalar, count=count)


def log_results(loss_meter_dict):
    """Print the running average of every meter, one ``name: value`` per line."""
    print('\n')
    for entry in loss_meter_dict.items():
        loss_name, loss_meter = entry
        print(f"{loss_name}: {loss_meter.avg:.5f}")

In [None]:
# Generator model - old

class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) + InstanceNorm + optional ReLU.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        down: ``True`` builds a reflect-padded ``Conv2d``; ``False`` builds a
            ``ConvTranspose2d``.
        use_act: ``True`` appends an in-place ReLU, otherwise an identity.
        **kwargs: Forwarded to the convolution constructor.
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        self.conv = nn.Sequential(conv, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s (the second without activation) plus a skip connection."""

    def __init__(self, channels):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1)
        self.block = nn.Sequential(
            ConvBlock(channels, channels, **conv_args),
            ConvBlock(channels, channels, use_act=False, **conv_args),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Encoder / residual / decoder generator with a tanh output.

    Layout: a 7x7 stem, five stride-2 ``ConvBlock``s that double the channel
    count each time, ``num_residuals`` residual blocks, five stride-2
    transposed ``ConvBlock``s that halve it again, and a final 7x7 convolution.

    Args:
        input_channels: Channels of the input tensor.
        output_channels: Channels of the produced tensor.
        num_features: Channel count after the stem.
        num_residuals: Number of residual blocks at the bottleneck.
    """

    def __init__(self, input_channels, output_channels, num_features = 64, num_residuals = 1):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        stem_args = dict(kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, num_features, **stem_args),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

        # Channel widths at each resolution level: 1x, 2x, ... 32x num_features.
        widths = [num_features * mult for mult in (1, 2, 4, 8, 16, 32)]
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, kernel_size=3, stride=2, padding=1)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ]
        )
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(widths[-1]) for _ in range(num_residuals)]
        )
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
                for c_in, c_out in zip(widths[:0:-1], widths[-2::-1])
            ]
        )
        self.last = nn.Conv2d(widths[0], output_channels, **stem_args)

    def forward(self, x, L = None):
        # When L is supplied it is returned untouched and the network is skipped.
        if L is not None:
            return L
        x = self.initial(x)
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            x = stage(x)
        return torch.tanh(self.last(x))

In [None]:
class Block(nn.Module):
    """4x4 stride-2 (transposed) convolution + BatchNorm + activation.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        down: ``True`` for a reflect-padded ``Conv2d`` (halves the resolution),
            ``False`` for a ``ConvTranspose2d`` (doubles it).
        act: ``"relu"`` selects ReLU; any other value selects LeakyReLU(0.2).
        use_dropout: When truthy, Dropout(0.5) is applied to the output.
    """

    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
        activation = nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(conv, nn.BatchNorm2d(out_channels), activation)

        self.use_dropout = use_dropout
        # The dropout layer is always created; forward() decides whether to use it.
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        if self.use_dropout:
            return self.dropout(x)
        return x


class GeneratorPix2pix(nn.Module):
    """U-Net generator: 8 stride-2 encoder stages, 8 decoder stages, tanh output.

    Every decoder stage after the first receives the previous decoder output
    concatenated (along the channel dimension) with the encoder output of the
    same resolution.

    Args:
        input_channels: Channels of the input tensor.
        output_channels: Channels of the produced tensor.
        features: Base channel width of the first encoder stage.
    """

    def __init__(self, input_channels=3, output_channels=3, features=64):
        super().__init__()

        def encoder(c_in, c_out):
            return Block(c_in, c_out, down=True, act="leaky", use_dropout=False)

        def decoder(c_in, c_out, use_dropout):
            return Block(c_in, c_out, down=False, act="relu", use_dropout=use_dropout)

        f = features
        self.initial_down = nn.Sequential(
            nn.Conv2d(input_channels, f, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = encoder(f, f * 2)
        self.down2 = encoder(f * 2, f * 4)
        self.down3 = encoder(f * 4, f * 8)
        self.down4 = encoder(f * 8, f * 8)
        self.down5 = encoder(f * 8, f * 16)
        self.down6 = encoder(f * 16, f * 16)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(f * 16, f * 16, 4, 2, 1),
            nn.ReLU(),
        )

        # Decoder input widths are doubled (from up2 on) by the skip concatenation.
        self.up1 = decoder(f * 16, f * 16, True)
        self.up2 = decoder(f * 16 * 2, f * 16, True)
        self.up3 = decoder(f * 16 * 2, f * 8, True)
        self.up4 = decoder(f * 8 * 2, f * 8, False)
        self.up5 = decoder(f * 8 * 2, f * 4, False)
        self.up6 = decoder(f * 4 * 2, f * 2, False)
        self.up7 = decoder(f * 2 * 2, f, False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(f * 2, output_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        encoder_stages = (
            self.initial_down,
            self.down1,
            self.down2,
            self.down3,
            self.down4,
            self.down5,
            self.down6,
        )
        decoder_stages = (
            self.up2,
            self.up3,
            self.up4,
            self.up5,
            self.up6,
            self.up7,
            self.final_up,
        )

        # Encoder: remember every intermediate output for the skip connections.
        skips = []
        hidden = x
        for stage in encoder_stages:
            hidden = stage(hidden)
            skips.append(hidden)

        hidden = self.up1(self.bottleneck(hidden))

        # Decoder: consume the stored encoder outputs from deepest to shallowest.
        for stage in decoder_stages:
            hidden = stage(torch.cat([hidden, skips.pop()], 1))
        return hidden

In [None]:
# Discriminator model

class DisBlock(nn.Module):
    """Discriminator stage: reflect-padded 4x4 conv + InstanceNorm + LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        self.conv = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.conv(x)


class PatchDiscriminator(nn.Module):
    """Patch-based discriminator producing a sigmoid score map.

    Args:
        in_channels: Channels of the image fed to the discriminator.
        features: Channel widths of the successive stages.  Stages whose width
            equals the last entry use stride 1, all others stride 2.
    """

    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        first_conv = nn.Conv2d(
            in_channels,
            features[0],
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode="reflect",
        )
        self.initial = nn.Sequential(first_conv, nn.LeakyReLU(0.2, inplace=True))

        stages = []
        for c_in, c_out in zip(features[:-1], features[1:]):
            stride = 1 if c_out == features[-1] else 2
            stages.append(DisBlock(c_in, c_out, stride=stride))
        # Final projection to a single-channel prediction map.
        stages.append(
            nn.Conv2d(features[-1], 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        hidden = self.initial(x)
        logits = self.model(hidden)
        return torch.sigmoid(logits)

In [None]:
# VGG perceptual loss

class VGGPerceptualLoss(torch.nn.Module):
    """Perceptual (feature / style) L1 loss on frozen pretrained VGG16 features.

    The VGG16 feature extractor is split into four consecutive slices
    (``[:4]``, ``[4:9]``, ``[9:16]``, ``[16:23]``).  Inputs are normalised with
    the registered ``mean`` / ``std`` buffers, optionally resized to 224x224,
    and pushed through the slices one after another.

    Args:
        resize: When truthy, both images are bilinearly resized to 224x224
            before feature extraction.
    """

    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained network once and slice it, instead of
        # instantiating a full VGG16 for every slice.
        features = torchvision.models.vgg16(pretrained=True).features
        blocks = [
            features[start:stop].eval()
            for start, stop in ((0, 4), (4, 9), (9, 16), (16, 23))
        ]
        for bl in blocks:
            # Freeze the actual parameters; iterating the Sequential itself
            # yields layers, on which `requires_grad` has no effect.
            for p in bl.parameters():
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate
        self.resize = resize
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
        """Return the summed L1 distance between VGG activations.

        Args:
            input: Image batch to evaluate.
            target: Reference image batch.
            feature_layers: Indices of the slices whose activations are
                compared directly.
            style_layers: Indices of the slices whose Gram matrices are
                compared.

        Returns:
            Scalar tensor (``0.0`` float if no layer is selected).
        """
        # Inputs that do not have 3 channels are tiled along the channel axis.
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        loss = 0.0
        x = input
        y = target
        for i, block in enumerate(self.blocks):
            x = block(x)
            y = block(y)
            if i in feature_layers:
                loss += torch.nn.functional.l1_loss(x, y)
            if i in style_layers:
                # Gram matrices: channel-by-channel correlations of the activations.
                act_x = x.reshape(x.shape[0], x.shape[1], -1)
                act_y = y.reshape(y.shape[0], y.shape[1], -1)
                gram_x = act_x @ act_x.permute(0, 2, 1)
                gram_y = act_y @ act_y.permute(0, 2, 1)
                loss += torch.nn.functional.l1_loss(gram_x, gram_y)
        return loss

In [None]:
# Train function for U-Net only

def train_fn_unet(gen_C, opt_G, loader, l1_crit, gan_crit, epoch):
    """Run one epoch of L1-only training of the colorization generator.

    Args:
        gen_C: Generator mapping the L channel to the ab channels.
        opt_G: Optimizer of ``gen_C``.
        loader: Iterable of ``(l, ab)`` batches.
        l1_crit: Reconstruction criterion applied to ``(fake_ab, ab)``.
        gan_crit: Unused here; kept so all train functions share a signature shape.
        epoch: Epoch number, used as the sub-folder name for saved samples.

    Returns:
        Dict of ``AverageMeter`` objects holding the running losses.

    Every 100 iterations the running losses are printed and the generated and
    real images are written to ``saved_images_path/<epoch>/``.
    """
    losses_dict = create_loss_meters_pix2pix()
    vgg_crit = VGGPerceptualLoss().to(DEVICE)
    epoch_dir = saved_images_path + str(epoch)

    with tqdm.tqdm(total=(len(loader)), file=sys.stdout, leave=True, position=0) as pbar:
        for idx, (l, ab) in enumerate(loader, start=1):
            losses = LossClassPix2pix()
            l = l.to(DEVICE)
            ab = ab.to(DEVICE)

            # Switch to train mode BEFORE the forward pass: the model is put in
            # eval mode at the end of every iteration, so the forward pass the
            # loss is computed on would otherwise run in eval mode.
            gen_C.train()
            fake_color = gen_C(l)
            fake_color_im = torch.cat((l, fake_color), 1)
            real_color_im = torch.cat((l, ab), 1)

            # The VGG loss is only logged, never back-propagated, so it does
            # not need an autograd graph.
            with torch.no_grad():
                losses.loss_VGG = vgg_crit(fake_color_im.float(), real_color_im.float())

            # Train Generator
            opt_G.zero_grad()
            losses.loss_G_L1 = l1_crit(fake_color, ab)
            losses.loss_G = losses.loss_G_L1
            losses.loss_G.backward(retain_graph=False)
            opt_G.step()
            gen_C.eval()

            # Update losses
            update_losses(losses, losses_dict, count=l.size(0))

            if idx % 100 == 0:
                log_results(losses_dict)
                os.makedirs(epoch_dir, exist_ok=True)
                im_rgb = torch.from_numpy(my_lab2rgb(l, fake_color))
                save_image(im_rgb, epoch_dir + f"/{idx}_rgb_color.png")
                save_image(torch.from_numpy(my_lab2rgb(l, ab)), epoch_dir + f"/{idx}_rgb_real.png")

            pbar.update()
    return losses_dict

In [None]:
# Train function for Pix2Pix

def train_fn_pix2pix(disc_C, gen_C, loader, opt_D, opt_G, l1_crit, gan_crit, epoch):
    """Run one epoch of Pix2Pix (generator + patch discriminator) training.

    Args:
        disc_C: Discriminator judging 3-channel ``(l, ab)`` images.
        gen_C: Generator mapping the L channel to the ab channels.
        loader: Iterable of ``(l, ab)`` batches.
        opt_D: Optimizer of ``disc_C``.
        opt_G: Optimizer of ``gen_C``.
        l1_crit: Reconstruction criterion (weighted by ``lambda_l1 = 5``).
        gan_crit: Adversarial criterion comparing predictions to all-ones /
            all-zeros targets.
        epoch: Epoch number, used as the sub-folder name for saved samples.

    Returns:
        Dict of ``AverageMeter`` objects holding the running losses.
    """
    C_reals = 0
    C_fakes = 0
    lambda_l1 = 5

    losses_dict = create_loss_meters_pix2pix()
    vgg_crit = VGGPerceptualLoss().to(DEVICE)
    epoch_dir = saved_images_path + str(epoch)

    with tqdm.tqdm(total=(len(loader)), file=sys.stdout, leave=True, position=0) as pbar:
        for idx, (l, ab) in enumerate(loader, start=1):
            losses = LossClassPix2pix()
            l = l.to(DEVICE)
            ab = ab.to(DEVICE)

            # Train mode must be set before the forward pass; the generator is
            # left in eval mode at the end of every iteration.
            gen_C.train()
            fake_color = gen_C(l)
            fake_color_im = torch.cat((l, fake_color), 1)
            real_color_im = torch.cat((l, ab), 1)

            # Logged only, never back-propagated: no autograd graph needed.
            with torch.no_grad():
                losses.loss_VGG = vgg_crit(fake_color_im.float(), real_color_im.float())

            # Train Discriminator (fake image detached: no generator gradients)
            disc_C.train()
            set_requires_grad(disc_C, True)
            opt_D.zero_grad()
            real_pred = disc_C(real_color_im)
            fake_pred = disc_C(fake_color_im.detach())
            losses.loss_D_real = gan_crit(real_pred, torch.ones_like(real_pred))
            losses.loss_D_fake = gan_crit(fake_pred, torch.zeros_like(fake_pred))
            losses.loss_D = (losses.loss_D_fake + losses.loss_D_real) / 2
            losses.loss_D.backward(retain_graph=False)
            opt_D.step()
            disc_C.eval()

            # Train Generator
            set_requires_grad(disc_C, False)
            opt_G.zero_grad()
            # The fake image must NOT be detached here, otherwise the
            # adversarial loss cannot send any gradient to the generator.
            fake_pred = disc_C(fake_color_im)
            losses.loss_G_GAN = gan_crit(fake_pred, torch.ones_like(fake_pred))
            losses.loss_G_L1 = l1_crit(fake_color, ab) * lambda_l1
            losses.loss_G = losses.loss_G_GAN + losses.loss_G_L1
            losses.loss_G.backward(retain_graph=False)
            opt_G.step()
            gen_C.eval()

            # Update losses
            C_reals += real_pred.mean().item()
            C_fakes += fake_pred.mean().item()
            update_losses(losses, losses_dict, count=l.size(0))

            if idx % 100 == 0:
                log_results(losses_dict)
                os.makedirs(epoch_dir, exist_ok=True)
                im_rgb = torch.from_numpy(my_lab2rgb(l, fake_color))
                save_image(im_rgb, epoch_dir + f"/{idx}_rgb_color.png")
                save_image(torch.from_numpy(my_lab2rgb(l, ab)), epoch_dir + f"/{idx}_rgb_real.png")

            pbar.update()
            # idx starts at 1, so it already equals the number of iterations done.
            pbar.set_postfix(C_real=C_reals/idx, C_fake=C_fakes/idx)
    return losses_dict

In [None]:
# Train function for CycleGAN

def train_fn_cycle(disc_C, disc_G, gen_C, gen_G, loader, opt_D, opt_G, l1_crit, gan_crit, d_scaler, g_scaler, epoch):
    """Run one epoch of CycleGAN-style training (color and grey generators).

    Args:
        disc_C: Discriminator for images with generated colour channels.
        disc_G: Discriminator for images with a generated grey channel.
        gen_C: Generator mapping L (1 channel) to ab (2 channels).
        gen_G: Generator mapping ab (2 channels) to L (1 channel).
        loader: Iterable of ``(l, ab)`` batches.
        opt_D: Optimizer stepped once for both discriminators.
        opt_G: Optimizer stepped once for both generators.
        l1_crit: Criterion for the reconstruction and cycle terms.
        gan_crit: Adversarial criterion.
        d_scaler: Unused in this function.
        g_scaler: Unused in this function.
        epoch: Epoch number, used as the sub-folder name for saved samples.

    Returns:
        Dict of ``AverageMeter`` objects holding the running losses.
    """
    C_reals = 0
    C_fakes = 0
    lambda_gan = 5
    lambda_l1 = 1
    lambda_cycle = 10

    losses_dict = create_loss_meters_cycle()
    vgg_crit = VGGPerceptualLoss().to(DEVICE)
    epoch_dir = saved_images_path + str(epoch)

    with tqdm.tqdm(total=(len(loader)), file=sys.stdout, leave=True, position=0) as pbar:
        for idx, (l, ab) in enumerate(loader, start=1):
            losses = LossClassCycle()
            l = l.to(DEVICE)
            ab = ab.to(DEVICE)

            # Both generators are left in eval mode at the end of an iteration,
            # so train mode has to be restored before any forward pass.
            gen_C.train()
            gen_G.train()

            fake_color = gen_C(l)
            fake_color_im = torch.cat((l, fake_color), 1)
            real_color_im = torch.cat((l, ab), 1)

            # Logged only, never back-propagated: no autograd graph needed.
            with torch.no_grad():
                losses.loss_VGG = vgg_crit(fake_color_im.float(), real_color_im.float())

            fake_grey = gen_G(ab)
            # Kept attached to gen_G; it is detached explicitly where the
            # discriminator is trained.
            fake_grey_im = torch.cat((fake_grey, ab), 1)

            # Train Discriminators
            opt_D.zero_grad()
            # Color
            disc_C.train()
            set_requires_grad(disc_C, True)
            real_pred = disc_C(real_color_im)
            fake_pred = disc_C(fake_color_im.detach())
            losses.loss_D_C_real = gan_crit(real_pred, torch.ones_like(real_pred))
            losses.loss_D_C_fake = gan_crit(fake_pred, torch.zeros_like(fake_pred))
            losses.loss_D_C = (losses.loss_D_C_fake + losses.loss_D_C_real) / 2
            losses.loss_D_C.backward(retain_graph=True)

            # Grey
            disc_G.train()
            set_requires_grad(disc_G, True)
            real_pred = disc_G(real_color_im)
            fake_pred = disc_G(fake_grey_im.detach())
            losses.loss_D_G_real = gan_crit(real_pred, torch.ones_like(real_pred))
            losses.loss_D_G_fake = gan_crit(fake_pred, torch.zeros_like(fake_pred))
            losses.loss_D_G = (losses.loss_D_G_fake + losses.loss_D_G_real) / 2
            losses.loss_D_G.backward(retain_graph=False)
            opt_D.step()
            disc_C.eval()
            disc_G.eval()

            # Train Generators
            opt_G.zero_grad()
            # Color
            set_requires_grad(disc_C, False)
            # Not detached: the adversarial term must reach gen_C.
            fake_pred = disc_C(fake_color_im)
            losses.loss_G_C_GAN = gan_crit(fake_pred, torch.ones_like(fake_pred))
            losses.loss_G_C_L1 = l1_crit(fake_color, ab)
            cycle_color = gen_C(fake_grey)
            losses.loss_G_C_Cycle = l1_crit(cycle_color, ab)
            losses.loss_G_C = losses.loss_G_C_GAN * lambda_gan + losses.loss_G_C_L1 * lambda_l1 + losses.loss_G_C_Cycle * lambda_cycle
            # The graphs of fake_color / fake_grey are reused by the grey loss below.
            losses.loss_G_C.backward(retain_graph=True)

            # Grey
            set_requires_grad(disc_G, False)
            # Not detached: the adversarial term must reach gen_G.
            fake_pred = disc_G(fake_grey_im)
            losses.loss_G_G_GAN = gan_crit(fake_pred, torch.ones_like(fake_pred))
            losses.loss_G_G_L1 = l1_crit(fake_grey, l)
            cycle_grey = gen_G(fake_color)
            losses.loss_G_G_Cycle = l1_crit(cycle_grey, l)
            losses.loss_G_G = losses.loss_G_G_GAN * lambda_gan + losses.loss_G_G_L1 * lambda_l1 + losses.loss_G_G_Cycle * lambda_cycle
            losses.loss_G_G.backward(retain_graph=False)
            opt_G.step()
            gen_C.eval()
            gen_G.eval()

            # Update losses
            # NOTE(review): at this point real_pred / fake_pred hold disc_G's
            # outputs although the progress-bar fields are named C_real / C_fake.
            C_reals += real_pred.mean().item()
            C_fakes += fake_pred.mean().item()
            update_losses(losses, losses_dict, count=l.size(0))

            if idx % 100 == 0:
                log_results(losses_dict)
                os.makedirs(epoch_dir, exist_ok=True)
                im_rgb = torch.from_numpy(my_lab2rgb(l, fake_color))
                save_image(im_rgb, epoch_dir + f"/{idx}_rgb_color.png")
                save_image(torch.from_numpy(my_lab2rgb(l, ab)), epoch_dir + f"/{idx}_rgb_real.png")

            pbar.update()
            # idx starts at 1, so it already equals the number of iterations done.
            pbar.set_postfix(C_real=C_reals/idx, C_fake=C_fakes/idx)
    return losses_dict

In [None]:
# Init training for U-Net only

def train_unet(dataset_name, reload = False):
    delete_images(saved_images_path) 

    total_losses = []

    gen_C = GeneratorPix2pix(input_channels=1, output_channels=2).to(DEVICE)

    summary(gen_C, (1, 256, 256))

    opt_G = optim.Adam(
        gen_C.parameters(),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    if reload:
        load_checkpoint(gdrive_saved_models + 'CHECKPOINT_GEN_C', gen_C, opt_G, LEARNING_RATE)

    l1_crit = nn.L1Loss()
    gan_crit = nn.MSELoss()

    loader = load_dataset(ROOT_PATH + dataset_name)

    for epoch in range(NUM_EPOCHS):
        losses = train_fn_unet(gen_C, opt_G, loader, l1_crit, gan_crit, epoch)
        total_losses.append(losses)
        folder_origin = saved_images_path + str(epoch)
        folder_destination = gdrive_saved_models + 'images/'

        if GDRIVE_MOUNT:
          !cp -av $folder_origin $folder_destination

          pickle.dump(total_losses, open(gdrive_saved_models + "losses.p", "wb" ))

          torch.save(gen_C, gdrive_saved_models + 'CHECKPOINT_GEN_C_model')

        pickle.dump(total_losses, open("losses.p", "wb" ))
        torch.save(gen_C, 'CHECKPOINT_GEN_C_model')

In [None]:
# Init training for Pix2Pix

def train_pix2pix(dataset_name, reload = False):
    delete_images(saved_images_path) 

    total_losses = []

    gen_C = GeneratorPix2pix(input_channels=1, output_channels=2).to(DEVICE)
    disc_C = PatchDiscriminator(in_channels=3).to(DEVICE)

    summary(gen_C, (1, 256, 256))

    opt_D = optim.Adam(
        disc_C.parameters(),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_G = optim.Adam(
        gen_C.parameters(),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    if reload:
        load_checkpoint(gdrive_saved_models + 'CHECKPOINT_GEN_C', gen_C, opt_G, LEARNING_RATE)
        load_checkpoint(gdrive_saved_models + 'CHECKPOINT_DISC_C', disc_C, opt_D, LEARNING_RATE)

    l1_crit = nn.L1Loss()
    gan_crit = nn.MSELoss()

    loader = load_dataset(ROOT_PATH + dataset_name)

    for epoch in range(NUM_EPOCHS):
        losses = train_fn_pix2pix(disc_C, gen_C, loader, opt_D, opt_G, l1_crit, gan_crit, epoch)
        total_losses.append(losses)
        folder_origin = saved_images_path + str(epoch)
        folder_destination = gdrive_saved_models + 'images/'

        if GDRIVE_MOUNT:
          !cp -av $folder_origin $folder_destination

          pickle.dump(total_losses, open(gdrive_saved_models + "losses.p", "wb" ))

          torch.save(gen_C, gdrive_saved_models + 'CHECKPOINT_GEN_C_model')
          # save_checkpoint(gen_C, opt_G, filename=gdrive_saved_models + 'CHECKPOINT_GEN_C')
          # save_checkpoint(disc_C, opt_D, filename=gdrive_saved_models + 'CHECKPOINT_DISC_C')
        
        pickle.dump(total_losses, open("losses.p", "wb" ))

        torch.save(gen_C, 'CHECKPOINT_GEN_C_model')

In [None]:
# Init training for CycleGAN

def train_cycle(dataset_name, reload = False):
    delete_images(saved_images_path) 

    total_losses = []

    gen_C = GeneratorPix2pix(input_channels=1, output_channels=2).to(DEVICE)
    gen_G = GeneratorPix2pix(input_channels=2, output_channels=1).to(DEVICE)

    disc_C = PatchDiscriminator(in_channels=3).to(DEVICE)
    disc_G = PatchDiscriminator(in_channels=3).to(DEVICE)

    opt_D = optim.Adam(
        list(disc_C.parameters()) + list(disc_G.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_G = optim.Adam(
        list(gen_C.parameters()) + list(gen_G.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    if reload:
        load_checkpoint(gdrive_saved_models + 'CHECKPOINT_GEN_C', gen_C, opt_G, LEARNING_RATE)
        load_checkpoint(gdrive_saved_models + 'CHECKPOINT_DISC_C', disc_C, opt_D, LEARNING_RATE)

    l1_crit = nn.L1Loss()
    gan_crit = nn.MSELoss()
    # gan_crit = nn.BCEWithLogitsLoss()

    loader = load_dataset(ROOT_PATH + dataset_name)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        losses = train_fn_cycle(disc_C, disc_G, gen_C, gen_G, loader, opt_D, opt_G, l1_crit, gan_crit, d_scaler, g_scaler, epoch)
        total_losses.append(losses)
        folder_origin = saved_images_path + str(epoch)
        folder_destination = gdrive_saved_models + 'images/'
        !cp -av $folder_origin $folder_destination

        pickle.dump(total_losses, open(gdrive_saved_models + "losses.p", "wb" ))

        torch.save(gen_C, gdrive_saved_models + 'CHECKPOINT_GEN_C_model')
        # save_checkpoint(gen_C, opt_G, filename=gdrive_saved_models + 'CHECKPOINT_GEN_C')
        # save_checkpoint(disc_C, opt_D, filename=gdrive_saved_models + 'CHECKPOINT_DISC_C')

In [None]:
# Test function

def test_fn(model_path, test_path, gdrive_saved_images_test = None):
    delete_images(saved_images_test_path)

    # load model checkpoint
    else:
        if TRAINED_MODEL == 'Pix2Pix':
            if not os.path.isfile("/content/trained_gen_c_" + TRAINED_MODEL):
                !gdown --id '1d1PV3fg3bCoVwRhU3iyBMXYMwW7LuuiO'
        if TRAINED_MODEL == 'CycleGAN':
            if not os.path.isfile("/content/trained_gen_c_" + TRAINED_MODEL):
                !gdown --id '1R3nCUpRl2CdW0-4_BVd7CCirQ1ktuoFv'
        if TRAINED_MODEL == 'U-Net':
            if not os.path.isfile("/content/trained_gen_c_" + TRAINED_MODEL):
                !gdown --id '1imp1BKKyPu4rIPch0CHkBUh-1Fp8SPcG'
        gen_C = torch.load("/content/trained_gen_c" + TRAINED_MODEL, map_location=torch.device(DEVICE))
    gen_C.eval()

    # load test dataset
    loader = load_dataset(test_path, test = True, shuffle=False)

    # VGG critic
    vgg_crit = VGGPerceptualLoss().to(DEVICE)

    with tqdm.tqdm(total=(len(loader)), file=sys.stdout, leave=True, position=0) as pbar:
        for idx, (l, ab) in enumerate(loader, start=1):
            grey = torch.squeeze(l)
            l = l.to(DEVICE)
            ab = ab.to(DEVICE)

            fake_color = gen_C(l)

            real_im = np.transpose(my_lab2rgb(l, ab), (1,2,0))
            fake_im = np.transpose(my_lab2rgb(l, fake_color), (1,2,0))
            grey_im = grey

            fake = torch.from_numpy(np.transpose(fake_im, (2,0,1))).unsqueeze(dim=0).to(DEVICE)
            real = torch.from_numpy(np.transpose(real_im, (2,0,1))).unsqueeze(dim=0).to(DEVICE)

            vgg_loss = vgg_crit(fake.float(), real.float())

            f, axarr = plt.subplots(figsize=(40,20),nrows=1,ncols=3)
            plt.sca(axarr[0]); 
            plt.imshow(grey_im, cmap='gray'); plt.title('Grey Image', fontsize=20)
            plt.sca(axarr[1]); 
            plt.imshow(fake_im); plt.title('Generated Color', fontsize=20); plt.xlabel('VGG loss: ' + str(vgg_loss.item()), fontsize=20)
            plt.sca(axarr[2]); 
            plt.imshow(real_im); plt.title('Original Image', fontsize=20)
            plt.savefig(saved_images_test_path + 'im_' + str(idx))

            if GDRIVE_MOUNT:
                file_name = saved_images_test_path + 'im_' + str(idx) + ".png"
                !cp -av $file_name $gdrive_saved_images_test
            pbar.update();

In [None]:
# test for 2 models

def test_fn_2(test_path, gdrive_saved_images_test = None, first = 'CycleGAN', second = 'Pix2Pix'):
    delete_images(saved_images_test_path)
    model_addresses = {'CycleGAN':'17B5wxLCGNGfXoBtp05-Q3214Ebkwgmsm',
                       'Pix2Pix':'1d1PV3fg3bCoVwRhU3iyBMXYMwW7LuuiO',
                       'U-Net':'1h3KfiILcgQ6_H5-3wvAVnO-Y0oE46kft',
                       'U-Net_2':'1IrS2tZlrL8axqR4qMtZJPTs_jM4LCcsE',
                       'Pix2Pix_2':'1Gt4PEMW8fjzH2p3HRZOZiN4FdqB68ykE'}

    if not os.path.isfile("/content/trained_gen_c_" + first):
        !gdown --id {model_addresses[first]}
    if not os.path.isfile("/content/trained_gen_c_" + second):
        !gdown --id {model_addresses[second]}
    gen_C1 = torch.load("/content/trained_gen_c_" + first, map_location=torch.device(DEVICE))
    gen_C2 = torch.load("/content/trained_gen_c_" + second, map_location=torch.device(DEVICE))
    gen_C1.eval()
    gen_C2.eval()
    # load test dataset
    loader = load_dataset(test_path, test = True, shuffle=False)

    # VGG critic
    vgg_crit = VGGPerceptualLoss().to(DEVICE)

    with tqdm.tqdm(total=(len(loader)), file=sys.stdout, leave=True, position=0) as pbar:
        for idx, (l, ab) in enumerate(loader, start=1):
            grey = torch.squeeze(l)
            l = l.to(DEVICE)
            ab = ab.to(DEVICE)

            fake_color_1 = gen_C1(l)
            fake_color_2 = gen_C2(l)

            real_im = np.transpose(my_lab2rgb(l, ab), (1,2,0))
            fake_im_1 = np.transpose(my_lab2rgb(l, fake_color_1), (1,2,0))
            fake_im_2 = np.transpose(my_lab2rgb(l, fake_color_2), (1,2,0))
            grey_im = grey

            real = torch.from_numpy(np.transpose(real_im, (2,0,1))).unsqueeze(dim=0).to(DEVICE)
            fake = torch.from_numpy(np.transpose(fake_im_1, (2,0,1))).unsqueeze(dim=0).to(DEVICE)
            vgg_loss_1 = vgg_crit(fake.float(), real.float())
            fake = torch.from_numpy(np.transpose(fake_im_2, (2,0,1))).unsqueeze(dim=0).to(DEVICE)
            vgg_loss_2 = vgg_crit(fake.float(), real.float())

            f, axarr = plt.subplots(figsize=(40,20),nrows=1,ncols=4)
            plt.sca(axarr[0]); 
            plt.imshow(grey_im, cmap='gray'); plt.title('Grey Image', fontsize=20)
            plt.sca(axarr[1]); 
            plt.imshow(fake_im_1); plt.title(first, fontsize=20); plt.xlabel('VGG loss: ' + str(vgg_loss_1.item()), fontsize=20)
            plt.sca(axarr[2]); 
            plt.imshow(fake_im_2); plt.title(second, fontsize=20); plt.xlabel('VGG loss: ' + str(vgg_loss_2.item()), fontsize=20)
            plt.sca(axarr[3]); 
            plt.imshow(real_im); plt.title('Original Image', fontsize=20)
            plt.savefig(saved_images_test_path + 'im_' + str(idx), fontsize=20)

            if GDRIVE_MOUNT:
                file_name = saved_images_test_path + 'im_' + str(idx) + ".png"
                !cp -av $file_name $gdrive_saved_images_test
            pbar.update();

In [None]:
# Main

# Notebook entry point: choose ONE of the three modes below (train, test a
# single model, compare two models) by commenting / uncommenting its lines.
if __name__ == "__main__":
    # Dataset folder names; each is appended to ROOT_PATH where it is used.
    dataset_name = 'landscape'
    test_dataset = 'landscape_test'
    # Generator file written to Google Drive by the train_* functions.
    model_path = gdrive_saved_models + "CHECKPOINT_GEN_C_model"
    # Drive folder that receives copies of the test figures (only used when GDRIVE_MOUNT is set).
    gdrive_saved_images_test = gdrive_saved_models + 'images/test/'
    
    ## To train a new model, uncomment IM_SIZE and ONE of the train_* calls:
    ## train_pix2pix for Pix2Pix, train_cycle for CycleGAN, train_unet for U-Net only.
    # IM_SIZE = (256, 256)
    # train_pix2pix(dataset_name, False)
    # train_cycle(dataset_name, False)
    # train_unet(dataset_name, False)

    ## To test only one model, uncomment these lines.
    ## For Pix2Pix model, set TRAINED_MODEL = 'Pix2Pix'
    ## For CycleGAN model, set TRAINED_MODEL = 'CycleGAN'
    ## For U-Net model, set TRAINED_MODEL = 'U-Net'
    # TRAINED_MODEL = 'Pix2Pix'
    # IM_SIZE = (512, 512)
    # test_fn(model_path = model_path, test_path = ROOT_PATH + test_dataset, gdrive_saved_images_test = gdrive_saved_images_test)

    ## To compare two models side by side (currently U-Net vs Pix2Pix), keep these lines.
    ## `first` / `second` must be keys of the model_addresses table inside test_fn_2.
    # IM_SIZE is a notebook-level global; presumably read by the dataset loading code - TODO confirm.
    IM_SIZE = (512, 512)
    test_fn_2(ROOT_PATH + test_dataset, gdrive_saved_images_test = gdrive_saved_images_test, first = 'U-Net', second = 'Pix2Pix')