# Image Recoloring with Conditional GANs
<b>Reference paper</b>:
Phillip Isola et al., “Image-to-Image Translation with
Conditional Adversarial Networks”, CVPR 2017

In [None]:
!pip3 install opencv-python-headless
!pip3 install torchmetrics[image]
import numpy as np
import matplotlib.pyplot as plt
import cv2
import glob
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

In [None]:
# Download COCO dataset
!mkdir data

print("="*51 + "DOWNLOADING TRAINING SET" + "="*51)
!wget -c http://images.cocodataset.org/zips/train2017.zip -P ./data
!unzip -qq ./data/train2017.zip -d ./data
!rm ./data/train2017.zip

print("="*50 + "DOWNLOADING VALIDATION SET" + "="*50)
!wget -c http://images.cocodataset.org/zips/val2017.zip -P ./data
!unzip -qq ./data/val2017.zip -d ./data
!rm ./data/val2017.zip

print("="*53 + "DOWNLOADING TEST SET" + "="*53)
!wget -c http://images.cocodataset.org/zips/test2017.zip -P ./data
!unzip -qq ./data/test2017.zip -d ./data
!rm ./data/test2017.zip

print("="*59 + "FINISHED" + "="*59)

## Dataset class

In [None]:
class COCODataset(Dataset):
    """Colorization dataset built from a flat folder of ``*.jpg`` images.

    Each item is an ``(input, target)`` pair of tensors rescaled to [-1, 1]:

    * ``color_space="RGB"``: input is the 1-channel grayscale image, target is
      the 3-channel RGB image.
    * ``color_space="Lab"``: input is the L channel, target is the two
      a/b channels of the 8-bit OpenCV Lab image.

    :param root: directory that directly contains the ``.jpg`` files
    :param color_space: ``"RGB"`` or ``"Lab"``
    :param size_limit: if given, keep only the first ``size_limit`` paths
    :param transform: optional callable applied to the raw BGR array returned
        by ``cv2.imread`` before the color conversion
    :raises ValueError: if ``color_space`` is not one of the supported values
    """

    _COLOR_SPACES = ("RGB", "Lab")

    def __init__(self, root:str, color_space:str = "RGB", size_limit=None, transform=None):
        if color_space not in self._COLOR_SPACES:
            raise ValueError(
                f"Unsupported color_space {color_space!r}; expected one of {self._COLOR_SPACES}"
            )
        # glob returns files in an arbitrary, filesystem-dependent order; sorting
        # makes the subset selected by size_limit reproducible across runs.
        self.paths = sorted(glob.glob(root + "/*.jpg"))
        if size_limit is not None:
            self.paths = self.paths[:size_limit]
        self.transform = transform
        self.color_space = color_space

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        image_path = self.paths[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable/corrupt file by returning None
            raise OSError(f"Could not read image: {image_path}")

        if self.transform is not None:
            image = self.transform(image)

        # The transform may return a PIL image; go back to an ndarray for OpenCV
        image = np.array(image)
        if self.color_space == "RGB":
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            target_img = image
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
            # List indexing keeps the channel axis, so shapes stay H x W x C
            img = image[:, :, [0]]
            target_img = image[:, :, [1, 2]]

        to_tensor = transforms.ToTensor()
        # ToTensor yields values in [0, 1] for uint8 input; shift them to [-1, 1]
        img = 2.0 * to_tensor(img) - 1.0
        target_img = 2.0 * to_tensor(target_img) - 1.0
        return (img, target_img)

In [None]:
# Preview configuration: load the test images at 256x256 in both color spaces.
IMG_SIZE = 256

# ToPILImage is needed first because the dataset hands the transform a raw array.
test_transforms = transforms.Compose(
    [transforms.ToPILImage(), transforms.Resize((IMG_SIZE, IMG_SIZE))]
)

RGB_test_set, Lab_test_set = (
    COCODataset(root="./data/test2017", color_space=space, transform=test_transforms)
    for space in ("RGB", "Lab")
)

# Sample shown by the two preview cells below
index = 4

In [None]:
# Show one RGB sample: grayscale input, color target and its three channels.
grayscale_image, image = RGB_test_set[index]

# Map the tensors from [-1, 1] back to [0, 1] for display
grayscale_image = (grayscale_image + 1.0) / 2.0
image = (image + 1.0) / 2.0

to_pil = transforms.ToPILImage()
panels = [
    ("Grayscale", grayscale_image, 'gray'),
    ("Colored", image, None),
    ("Red channel", image[[0],:,:], 'Reds'),
    ("Green channel", image[[1],:,:], 'Greens'),
    ("Blue channel", image[[2],:,:], 'Blues'),
]

fig, ax = plt.subplots(1, 5, figsize=(15,15))
for axis, (title, tensor, cmap) in zip(ax, panels):
    axis.imshow(to_pil(tensor), cmap=cmap)
    axis.set_title(title)
    axis.axis("off")
plt.show()

print()
print(f"Input (Grayscale image) shape: {grayscale_image.shape}")
print(f"Label (Colored image) shape: {image.shape}")


In [None]:
# Show one Lab sample: the recombined color image, the L input and the a/b targets.
L_image, ab_image = Lab_test_set[index]

# Map the tensors from [-1, 1] back to [0, 1] for display
L_image = (L_image + 1.0) / 2.0
ab_image = (ab_image + 1.0) / 2.0

# Stack L + ab back into one 3-channel Lab image and convert it to RGB
colored_image = transforms.ToPILImage()(torch.cat([L_image, ab_image]))
colored_image = np.array(colored_image)
colored_image = cv2.cvtColor(colored_image, cv2.COLOR_Lab2RGB)

to_pil = transforms.ToPILImage()
# The a/b panels must come from ab_image; the previous version plotted channels
# of `image`, the RGB sample left over from the preceding cell.
panels = [
    ("Colored", colored_image, None),
    ("L channel", L_image, 'gray'),
    ("a channel", ab_image[[0],:,:], 'Reds'),
    ("b channel", ab_image[[1],:,:], 'Blues'),
]

fig, ax = plt.subplots(1, 4, figsize=(15,15))
for axis, (title, picture, cmap) in zip(ax, panels):
    axis.imshow(to_pil(picture), cmap=cmap)
    axis.set_title(title)
    axis.axis("off")
plt.show()

print()
print(f"Input (L channel) shape: {L_image.shape}")
print(f"Label (ab channels) shape: {ab_image.shape}")

In [None]:
# Training configuration: seed, image size, color space and the three loaders.
torch.manual_seed(0)
IMG_SIZE = 128
COLOR_SPACE = "Lab"
BATCH_SIZE = 32

# Training images are resized and randomly flipped along both axes
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

# Test images are only resized
test_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
])


# NOTE(review): train/val images are taken from ./data/test2017 and the test
# images from ./data/val2017; the downloaded train2017 split is not used in
# this cell — confirm this is intended.
# NOTE(review): val_set is split off `dataset`, so it also receives the random
# flips of train_transforms — confirm this is intended.
dataset = COCODataset(root="./data/test2017", color_space=COLOR_SPACE, size_limit=11000, transform=train_transforms)
# random_split needs exactly 10000 + 1000 items, i.e. at least 11000 images on disk
train_set, val_set = torch.utils.data.random_split(dataset, [10000, 1000])
test_set = COCODataset(root="./data/val2017", color_space=COLOR_SPACE, size_limit=1000, transform=test_transforms)

train_dl = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2, pin_memory=True, shuffle=True)
val_dl = DataLoader(val_set, batch_size=BATCH_SIZE,num_workers=2, pin_memory=True, shuffle=False)
test_dl = DataLoader(test_set, batch_size=8, num_workers=2, pin_memory=True, shuffle=False)

print(f"Training set has {len(train_set)} images")
print(f"Validation set has {len(val_set)} images")
print(f"Test set has {len(test_set)} images")
print()
# Peek at the first test batch only, to report the tensor shapes
for x, y in test_dl:
    print(f"Input batch has shape: {x.shape}")
    print(f"Output batch has shape: {y.shape}")
    break

## U-Net autoencoder for generator network

In [None]:
class Unet(nn.Module):
    """U-Net generator with eight stride-2 encoder convolutions and eight
    stride-2 transposed convolutions, joined by skip connections.

    Every convolution halves (encoder) or doubles (decoder) the spatial size,
    so a 256 x 256 input reaches 1 x 1 at the bottleneck and the output has the
    same spatial size as the input. The output passes through tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=64, use_bias=False, use_dropout=True):
        """
        :param input_nc: number of input channels
        :param output_nc: number of output channels
        :param ngf: number of generator filters in the first convolutional layer
        :param use_bias: whether the (transposed) convolutions carry a bias term
        :param use_dropout: if True the first three decoder stages use dropout
            with p=0.5, otherwise p=0.0
        """
        super().__init__()
        self.downrelu = nn.LeakyReLU(0.2, True)
        self.uprelu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.drop_rate = 0.5 if use_dropout else 0.0

        # Shared hyper-parameters of every (transposed) convolution
        conv_kw = dict(kernel_size=4, stride=2, padding=1, bias=use_bias)
        wide = ngf * 8

        # Encoder: channels grow ngf -> 2ngf -> 4ngf -> 8ngf, then stay at 8ngf.
        # The first and the last encoder convolution have no batch norm.
        self.downconv1 = nn.Conv2d(input_nc, ngf, **conv_kw)
        self.downconv2 = nn.Conv2d(ngf, ngf * 2, **conv_kw)
        self.downbn2 = nn.BatchNorm2d(ngf * 2)
        self.downconv3 = nn.Conv2d(ngf * 2, ngf * 4, **conv_kw)
        self.downbn3 = nn.BatchNorm2d(ngf * 4)
        self.downconv4 = nn.Conv2d(ngf * 4, wide, **conv_kw)
        self.downbn4 = nn.BatchNorm2d(wide)
        self.downconv5 = nn.Conv2d(wide, wide, **conv_kw)
        self.downbn5 = nn.BatchNorm2d(wide)
        self.downconv6 = nn.Conv2d(wide, wide, **conv_kw)
        self.downbn6 = nn.BatchNorm2d(wide)
        self.downconv7 = nn.Conv2d(wide, wide, **conv_kw)
        self.downbn7 = nn.BatchNorm2d(wide)
        self.downconv8 = nn.Conv2d(wide, wide, **conv_kw)

        # Decoder: from stage 2 on, the input width is doubled by the
        # concatenated skip connection.
        self.upconv1 = nn.ConvTranspose2d(wide, wide, **conv_kw)
        self.upbn1 = nn.BatchNorm2d(wide)
        self.updrop1 = nn.Dropout(self.drop_rate)
        self.upconv2 = nn.ConvTranspose2d(wide * 2, wide, **conv_kw)
        self.upbn2 = nn.BatchNorm2d(wide)
        self.updrop2 = nn.Dropout(self.drop_rate)
        self.upconv3 = nn.ConvTranspose2d(wide * 2, wide, **conv_kw)
        self.upbn3 = nn.BatchNorm2d(wide)
        self.updrop3 = nn.Dropout(self.drop_rate)
        self.upconv4 = nn.ConvTranspose2d(wide * 2, wide, **conv_kw)
        self.upbn4 = nn.BatchNorm2d(wide)
        self.upconv5 = nn.ConvTranspose2d(wide * 2, ngf * 4, **conv_kw)
        self.upbn5 = nn.BatchNorm2d(ngf * 4)
        self.upconv6 = nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, **conv_kw)
        self.upbn6 = nn.BatchNorm2d(ngf * 2)
        self.upconv7 = nn.ConvTranspose2d(ngf * 2 * 2, ngf, **conv_kw)
        self.upbn7 = nn.BatchNorm2d(ngf)
        self.upconv8 = nn.ConvTranspose2d(ngf * 2, output_nc, **conv_kw)

    def forward(self, x):
        """Run the generator on ``x`` of shape (N, input_nc, H, W).

        For a 256 x 256 input the encoder outputs e1..e8 have spatial sizes
        128, 64, 32, 16, 8, 4, 2 and 1; the decoder mirrors them back to 256.
        """
        encoder = (
            (self.downconv2, self.downbn2),
            (self.downconv3, self.downbn3),
            (self.downconv4, self.downbn4),
            (self.downconv5, self.downbn5),
            (self.downconv6, self.downbn6),
            (self.downconv7, self.downbn7),
        )
        decoder = (
            (self.upconv1, self.upbn1, self.updrop1),
            (self.upconv2, self.upbn2, self.updrop2),
            (self.upconv3, self.upbn3, self.updrop3),
            (self.upconv4, self.upbn4, None),
            (self.upconv5, self.upbn5, None),
            (self.upconv6, self.upbn6, None),
            (self.upconv7, self.upbn7, None),
        )

        # skips collects e1..e7. Both activations are in-place, so a stored
        # tensor is overwritten by its activation once the next stage reads it.
        skips = [self.downconv1(x)]
        for conv, norm in encoder:
            skips.append(norm(conv(self.downrelu(skips[-1]))))
        hidden = self.downconv8(self.downrelu(skips[-1]))  # bottleneck e8

        # Walk the decoder, pairing stage k with the encoder output of the
        # same spatial size (e7 first, e1 last).
        for (conv, norm, drop), skip in zip(decoder, reversed(skips)):
            hidden = norm(conv(self.uprelu(hidden)))
            if drop is not None:
                hidden = drop(hidden)
            hidden = torch.cat([hidden, skip], dim=1)

        return self.tanh(self.upconv8(self.uprelu(hidden)))

# Smoke test: a 256x256 input should come back with the same shape.
net = Unet(input_nc=3, output_nc=3)
x = torch.randn(1, 3, 256, 256)
net(x).shape

In [None]:
class Unet_128(nn.Module):
    """U-Net generator variant whose forward pass uses seven stride-2 encoder
    convolutions and seven stride-2 transposed convolutions.

    A 128 x 128 input reaches 1 x 1 at the bottleneck and the output has the
    same spatial size as the input. The output passes through tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=64, use_bias=False, use_dropout=True):
        """
        :param input_nc: number of input channels
        :param output_nc: number of output channels
        :param ngf: number of generator filters in the first convolutional layer
        :param use_bias: whether the (transposed) convolutions carry a bias term
        :param use_dropout: if True the first three decoder stages use dropout
            with p=0.5, otherwise p=0.0
        """
        super().__init__()
        self.downrelu = nn.LeakyReLU(0.2, True)
        self.uprelu = nn.ReLU(True)
        self.tanh = nn.Tanh()
        self.drop_rate = 0.5 if use_dropout else 0.0

        # Shared hyper-parameters of every (transposed) convolution
        conv_kw = dict(kernel_size=4, stride=2, padding=1, bias=use_bias)
        wide = ngf * 8

        # Encoder
        self.downconv1 = nn.Conv2d(input_nc, ngf, **conv_kw)
        self.downconv2 = nn.Conv2d(ngf, ngf * 2, **conv_kw)
        self.downbn2 = nn.BatchNorm2d(ngf * 2)
        self.downconv3 = nn.Conv2d(ngf * 2, ngf * 4, **conv_kw)
        self.downbn3 = nn.BatchNorm2d(ngf * 4)
        self.downconv4 = nn.Conv2d(ngf * 4, wide, **conv_kw)
        self.downbn4 = nn.BatchNorm2d(wide)
        self.downconv5 = nn.Conv2d(wide, wide, **conv_kw)
        self.downbn5 = nn.BatchNorm2d(wide)
        self.downconv6 = nn.Conv2d(wide, wide, **conv_kw)
        self.downbn6 = nn.BatchNorm2d(wide)
        self.downconv7 = nn.Conv2d(wide, wide, **conv_kw)
        # downbn7 and downconv8 are registered but never called by forward();
        # they are kept so parameter names and creation order stay unchanged.
        self.downbn7 = nn.BatchNorm2d(wide)
        self.downconv8 = nn.Conv2d(wide, wide, **conv_kw)

        # Decoder: from stage 2 on, the input width is doubled by the
        # concatenated skip connection.
        self.upconv1 = nn.ConvTranspose2d(wide, wide, **conv_kw)
        self.upbn1 = nn.BatchNorm2d(wide)
        self.updrop1 = nn.Dropout(self.drop_rate)
        self.upconv2 = nn.ConvTranspose2d(wide * 2, wide, **conv_kw)
        self.upbn2 = nn.BatchNorm2d(wide)
        self.updrop2 = nn.Dropout(self.drop_rate)
        self.upconv3 = nn.ConvTranspose2d(wide * 2, wide, **conv_kw)
        self.upbn3 = nn.BatchNorm2d(wide)
        self.updrop3 = nn.Dropout(self.drop_rate)
        self.upconv4 = nn.ConvTranspose2d(wide * 2, ngf * 4, **conv_kw)
        self.upbn4 = nn.BatchNorm2d(ngf * 4)
        self.upconv5 = nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, **conv_kw)
        self.upbn5 = nn.BatchNorm2d(ngf * 2)
        self.upconv6 = nn.ConvTranspose2d(ngf * 2 * 2, ngf, **conv_kw)
        self.upbn6 = nn.BatchNorm2d(ngf * 1)
        self.upconv7 = nn.ConvTranspose2d(ngf * 2, output_nc, **conv_kw)

    def forward(self, x):
        """Run the generator on ``x`` of shape (N, input_nc, H, W).

        For a 128 x 128 input the encoder outputs e1..e7 have spatial sizes
        64, 32, 16, 8, 4, 2 and 1. After concatenation the decoder tensors are
        d1..d3: (ngf*8*2) at 2, 4, 8; d4: (ngf*4*2) at 16; d5: (ngf*2*2) at 32;
        d6: (ngf*2) at 64; the final output is (output_nc) at 128.
        """
        encoder = (
            (self.downconv2, self.downbn2),
            (self.downconv3, self.downbn3),
            (self.downconv4, self.downbn4),
            (self.downconv5, self.downbn5),
            (self.downconv6, self.downbn6),
        )
        decoder = (
            (self.upconv1, self.upbn1, self.updrop1),
            (self.upconv2, self.upbn2, self.updrop2),
            (self.upconv3, self.upbn3, self.updrop3),
            (self.upconv4, self.upbn4, None),
            (self.upconv5, self.upbn5, None),
            (self.upconv6, self.upbn6, None),
        )

        # skips collects e1..e6. Both activations are in-place, so a stored
        # tensor is overwritten by its activation once the next stage reads it.
        skips = [self.downconv1(x)]
        for conv, norm in encoder:
            skips.append(norm(conv(self.downrelu(skips[-1]))))
        hidden = self.downconv7(self.downrelu(skips[-1]))  # bottleneck e7, no batch norm

        # Pair each decoder stage with the encoder output of matching size
        for (conv, norm, drop), skip in zip(decoder, reversed(skips)):
            hidden = norm(conv(self.uprelu(hidden)))
            if drop is not None:
                hidden = drop(hidden)
            hidden = torch.cat([hidden, skip], dim=1)

        return self.tanh(self.upconv7(self.uprelu(hidden)))

# Smoke test: a 128x128 input should come back with the same shape.
net = Unet_128(input_nc=3, output_nc=3)
x = torch.randn(1, 3, 128, 128)
net(x).shape

## Patch Discriminator networks

In [None]:
class Discriminator(nn.Module):
    """Convolutional patch discriminator.

    Three stride-2 convolutions followed by two stride-1 convolutions produce a
    single-channel map of raw scores (logits), one per image patch, rather than
    one score per image.
    """

    def __init__(self, input_nc, ndf=64):
        """
        :param input_nc: number of input channels
        :param ndf: number of discriminator filters in the first convolutional layer
        """
        super().__init__()

        self.leaky_relu = nn.LeakyReLU(0.2, True)

        # Arguments shared by all convolutions; only the stride differs
        shared = dict(kernel_size=4, padding=1)

        self.conv1 = nn.Conv2d(input_nc, ndf, stride=2, **shared)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, stride=2, **shared)
        self.conv2_bn = nn.BatchNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, stride=2, **shared)
        self.conv3_bn = nn.BatchNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, stride=1, **shared)
        self.conv4_bn = nn.BatchNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, 1, stride=1, **shared)

    def forward(self, x):
        """Return the map of patch logits for the image batch ``x``."""
        hidden = self.leaky_relu(self.conv1(x))  # first stage has no batch norm
        normalized_stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
        )
        for conv, norm in normalized_stages:
            hidden = self.leaky_relu(norm(conv(hidden)))
        # No sigmoid since BCEWithLogitsLoss is used
        return self.conv5(hidden)

# Smoke test: check the shape of the patch-score map for a 256x256 input.
net = Discriminator(input_nc=3)
x = torch.randn(1, 3, 256, 256)
net(x).shape

In [None]:
class Critic(nn.Module):
    """Convolutional patch critic.

    Same layer layout as ``Discriminator`` but with instance normalization in
    place of batch normalization. The output is a single-channel map of
    unbounded scores.
    """

    def __init__(self, input_nc, ndf=64):
        """
        :param input_nc: number of input channels
        :param ndf: number of discriminator filters in the first convolutional layer
        """
        super().__init__()

        self.leaky_relu = nn.LeakyReLU(0.2, True)

        # Arguments shared by all convolutions; only the stride differs
        shared = dict(kernel_size=4, padding=1)

        self.conv1 = nn.Conv2d(input_nc, ndf, stride=2, **shared)
        self.conv2 = nn.Conv2d(ndf, ndf * 2, stride=2, **shared)
        self.conv2_n = nn.InstanceNorm2d(ndf * 2)
        self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, stride=2, **shared)
        self.conv3_n = nn.InstanceNorm2d(ndf * 4)
        self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, stride=1, **shared)
        self.conv4_n = nn.InstanceNorm2d(ndf * 8)
        self.conv5 = nn.Conv2d(ndf * 8, 1, stride=1, **shared)

    def forward(self, x):
        """Return the map of critic scores for the image batch ``x``."""
        hidden = self.leaky_relu(self.conv1(x))  # first stage has no normalization
        normalized_stages = (
            (self.conv2, self.conv2_n),
            (self.conv3, self.conv3_n),
            (self.conv4, self.conv4_n),
        )
        for conv, norm in normalized_stages:
            hidden = self.leaky_relu(norm(conv(hidden)))
        # No sigmoid since Critic
        return self.conv5(hidden)
# Smoke test: check the shape of the critic score map for a 128x128 input.
net = Critic(input_nc=3)
x = torch.randn(1, 3, 128, 128)
net(x).shape

In [None]:
def init_weights(net, init='norm', gain=0.02):
    """Re-initialize the convolution and BatchNorm2d layers of ``net`` in place.

    Modules whose class name contains ``Conv`` get their weights drawn by the
    selected scheme and their bias (if any) set to zero. ``BatchNorm2d``
    modules get weights drawn from N(1, gain) and zero bias.

    :param net: module whose sub-modules are initialized
    :param init: ``'norm'`` (N(0, gain)), ``'xavier'`` (Xavier normal scaled by
        ``gain``) or ``'kaiming'`` (Kaiming normal, fan-in mode)
    :param gain: standard deviation / scaling factor used by the schemes above
    :return: the same ``net`` object
    :raises ValueError: if ``init`` is not one of the supported schemes
    """
    conv_initializers = {
        'norm': lambda weight: nn.init.normal_(weight, mean=0.0, std=gain),
        'xavier': lambda weight: nn.init.xavier_normal_(weight, gain=gain),
        'kaiming': lambda weight: nn.init.kaiming_normal_(weight, a=0, mode='fan_in'),
    }
    # Fail loudly: an unknown scheme used to leave the conv weights untouched
    # while still reporting a successful initialization.
    if init not in conv_initializers:
        raise ValueError(
            f"Unknown init {init!r}; expected one of {sorted(conv_initializers)}"
        )
    init_conv = conv_initializers[init]

    def init_func(m):
        classname = m.__class__.__name__
        # getattr(..., None) also covers layers built without affine parameters
        has_weight = getattr(m, 'weight', None) is not None
        if 'Conv' in classname and has_weight:
            init_conv(m.weight.data)
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname and has_weight:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    """Move ``model`` to ``device``, apply the default weight initialization
    and return the initialized model."""
    return init_weights(model.to(device))

## Main Model

In [None]:
class GANModel(nn.Module):
    """Conditional GAN for colorization: a ``Unet`` generator, a
    ``Discriminator``, their Adam optimizers and the two loss functions.

    The generator takes 1 input channel and predicts 2 channels for
    ``color_space="Lab"`` or 3 channels otherwise. The discriminator always
    takes 3 channels.
    """

    def __init__(self, generator_lr=2e-4, discriminator_lr=2e-4, color_space="Lab", lambda_coef=100.0, betas=(0.5, 0.999)):
        """
        :param generator_lr: Adam learning rate of the generator
        :param discriminator_lr: Adam learning rate of the discriminator
        :param color_space: ``"Lab"`` selects a 2-channel generator output, any
            other value a 3-channel output
        :param lambda_coef: weight stored for the L1 term of the generator loss
        :param betas: Adam betas shared by both optimizers
        """
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.color_space = color_space
        out_ch = 2 if self.color_space == "Lab" else 3

        self.generator = init_model(Unet(input_nc=1, output_nc=out_ch), self.device)
        self.discriminator = init_model(Discriminator(input_nc=3), self.device)

        self.GANLoss = nn.BCEWithLogitsLoss()
        self.L1Loss = nn.L1Loss()
        self.lambda_coef = lambda_coef

        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=generator_lr, betas=betas)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=discriminator_lr, betas=betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def save(self, epoch, log, path="./checkpoint.pt", download=True):
        """Saves state_dict for generator, discriminator and optimizers to path

        :param epoch: epoch number stored in the checkpoint
        :param log: training log stored in the checkpoint
        :param path: destination file
        :param download: if True, also trigger a browser download of ``path``
            (only works on Colab)
        """
        torch.save({
            'epoch' : epoch,
            'log' : log,
            'generator_state_dict' : self.generator.state_dict(),
            'discriminator_state_dict' : self.discriminator.state_dict(),
            'generator_optimizer_state_dict' : self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict' : self.discriminator_optimizer.state_dict()
        }, path)

        if download:
            # Only works on Colab. Imported here because `files` was never
            # imported at file level and the module only exists on Colab.
            from google.colab import files
            files.download(path)

    def load(self, path="./checkpoint.pt"):
        """Loads state_dict for generator, discriminator and optimizers from path

        :param path: checkpoint file written by ``save``
        :return: tuple ``(epoch, log)`` stored in the checkpoint
        """
        # map_location lets a checkpoint written on a GPU machine be restored
        # on a CPU-only machine (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['log']

            
# Round-trip check: write a checkpoint and read it back.
model = GANModel()
# save() requires a `log` argument; an empty dict stands in for "no history yet"
model.save(epoch=0, log={}, download=False)
# load() returns (epoch, log); unpack it so only the epoch number is printed
ep, saved_log = model.load()
print(ep)

## W-GAN with Gradient Penalty

In [None]:
class WGANModel(nn.Module):
    """Wasserstein GAN bundle for colorization: a ``Unet_128`` generator, a
    ``Critic``, their Adam optimizers, an L1 loss and an LPIPS (VGG)
    perceptual loss.

    The generator takes 1 input channel and predicts 2 channels for
    ``color_space="Lab"`` or 3 channels otherwise. The critic always takes
    3 channels.
    """

    def __init__(self, generator_lr=1e-4, discriminator_lr=1e-4, color_space="Lab", n_critic_iterations=5, lambda_gp=10.0, lambda_L1=0.0, betas=(0.0, 0.9)):
        """
        :param generator_lr: Adam learning rate of the generator
        :param discriminator_lr: Adam learning rate of the critic
        :param color_space: ``"Lab"`` selects a 2-channel generator output, any
            other value a 3-channel output
        :param n_critic_iterations: number of critic updates per training batch
        :param lambda_gp: weight stored for the gradient penalty term
        :param lambda_L1: weight stored for the L1 term
        :param betas: Adam betas shared by both optimizers
        """
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.color_space = color_space
        out_ch = 2 if self.color_space == "Lab" else 3

        self.generator = init_model(Unet_128(input_nc=1, output_nc=out_ch), self.device)
        self.critic = init_model(Critic(input_nc=3), self.device)

        self.n_critic_iterations = n_critic_iterations
        self.L1Loss = nn.L1Loss()
        self.lambda_L1 = lambda_L1
        self.lambda_gp = lambda_gp
        self.PerceptualLoss = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(self.device)

        self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=generator_lr, betas=betas)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=discriminator_lr, betas=betas)

    def set_requires_grad(self, model, requires_grad=True):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = requires_grad

    def save(self, epoch, log, path="./checkpoint.pt", download=True):
        """Saves state_dict for generator, critic and optimizers to path

        The critic and its optimizer are stored under the ``discriminator_*``
        keys, the same key names ``GANModel`` uses.

        :param epoch: epoch number stored in the checkpoint
        :param log: training log stored in the checkpoint
        :param path: destination file
        :param download: if True, also trigger a browser download of ``path``
            (only works on Colab)
        """
        torch.save({
            'epoch' : epoch,
            'log' : log,
            'generator_state_dict' : self.generator.state_dict(),
            'discriminator_state_dict' : self.critic.state_dict(),
            'generator_optimizer_state_dict' : self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict' : self.critic_optimizer.state_dict()
        }, path)

        if download:
            # Only works on Colab. Imported here because `files` was never
            # imported at file level and the module only exists on Colab.
            from google.colab import files
            files.download(path)

    def load(self, path="./checkpoint.pt"):
        """Loads state_dict for generator, critic and optimizers from path

        :param path: checkpoint file written by ``save``
        :return: tuple ``(epoch, log)`` stored in the checkpoint
        """
        # map_location lets a checkpoint written on a GPU machine be restored
        # on a CPU-only machine (and vice versa).
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.critic.load_state_dict(checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(checkpoint['generator_optimizer_state_dict'])
        self.critic_optimizer.load_state_dict(checkpoint['discriminator_optimizer_state_dict'])
        return checkpoint['epoch'], checkpoint['log']


In [None]:
def visualize(model, batch, color_space="Lab"):
    """Plot generator results for one batch as a 3-row image grid.

    Row 0 shows the generator input, row 1 the generated image and row 2 the
    ground-truth image; there is one column per sample in the batch.

    :param model: object exposing ``generator`` and ``device``
    :param batch: ``(input, label)`` tensor pair with values in [-1, 1]
    :param color_space: ``"Lab"`` means input is L and output/label are ab
        channels, which are stacked and converted to RGB for display; any
        other value shows output/label directly
    """
    generator = model.generator
    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing train mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            x, label = batch
            output = generator(x.to(model.device))
    finally:
        generator.train(was_training)

    # Back to the CPU and from [-1, 1] to [0, 1] for display
    x = (x.cpu() + 1.0) / 2.0
    output = (output.detach().cpu() + 1.0) / 2.0
    label = (label.cpu() + 1.0) / 2.0

    if color_space == "Lab":
        generated_images = torch.cat([x, output], dim=1)
        true_images = torch.cat([x, label], dim=1)
    else:
        generated_images = output
        true_images = label

    num_images = generated_images.shape[0]
    to_pil = transforms.ToPILImage()

    # squeeze=False keeps `ax` two-dimensional even for a single-image batch,
    # where plt.subplots would otherwise return a 1-D array.
    fig, ax = plt.subplots(3, num_images, figsize=(20,15), squeeze=False)

    for i, (inp, img, true_img) in enumerate(zip(x, generated_images, true_images)):
        if color_space == "Lab":
            img = cv2.cvtColor(np.array(to_pil(img)), cv2.COLOR_Lab2RGB)
            true_img = cv2.cvtColor(np.array(to_pil(true_img)), cv2.COLOR_Lab2RGB)
        ax[0,i].imshow(to_pil(inp), cmap='gray')
        ax[1,i].imshow(to_pil(img))
        ax[2,i].imshow(to_pil(true_img))
        for row in range(3):
            ax[row,i].axis("off")
    fig.tight_layout()
    plt.show()


In [None]:
def get_gradient_penalty(model, real_images, generated_images):
    """Compute the WGAN gradient penalty for ``model.critic``.

    Each real/generated pair is mixed with its own random weight, the critic is
    evaluated on the mixture, and the penalty is the mean over the batch of
    ``(||grad of critic output w.r.t. mixture||_2 - 1)^2``.

    :param model: object exposing ``critic`` and ``device``
    :param real_images: batch of real images, shape (N, C, H, W)
    :param generated_images: batch of generated images, same shape
    :return: scalar tensor, differentiable w.r.t. the critic parameters
    """
    batch_size = real_images.shape[0]
    # One mixing coefficient per sample, broadcast over channels and pixels
    epsilon = torch.rand(batch_size, 1, 1, 1, device=model.device)
    interpolated = epsilon * real_images + (1 - epsilon) * generated_images
    # Detach so the penalty only trains the critic and does not backpropagate
    # into whatever graph produced the generated images.
    interpolated = interpolated.detach().requires_grad_(True)
    interpolated_outputs = model.critic(interpolated)

    # create_graph=True keeps the gradient itself differentiable, which the
    # penalty needs in order to be backpropagated to the critic weights.
    gradients = torch.autograd.grad(
        outputs=interpolated_outputs,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_outputs),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    # The small offset avoids a NaN gradient of the norm at exactly zero
    gradients_norm = (gradients + 1e-16).norm(2, dim=1)
    gradient_penalty = ((gradients_norm - 1.0)**2).mean()

    return gradient_penalty

In [None]:
# Show what an untrained generator produces for the first validation batch.
# NOTE(review): GANModel builds the 8-stage `Unet`, whose shape comments assume
# 256x256 inputs, while val_dl is created with IMG_SIZE = 128 — confirm the
# generator accepts this resolution.
model = GANModel(color_space=COLOR_SPACE)
first_batch = next(iter(val_dl), None)
if first_batch is not None:
    visualize(model, first_batch, COLOR_SPACE)

In [None]:
def _assemble_images(model, x, generator_output, label):
    """Return the (generated, real) image tensors in the layout fed to the discriminator."""
    if model.color_space == "Lab":
        # The network only predicts ab; prepend the L input to form full images
        return torch.cat([x, generator_output], dim=1), torch.cat([x, label], dim=1)
    return generator_output, label


def _discriminator_loss(model, generated_images, true_images):
    """Return the mean of the BCE losses on generated (target 0) and real (target 1) images."""
    # detach: the discriminator loss must not backpropagate into the generator
    prediction_generated = model.discriminator(generated_images.detach())
    loss_generated = model.GANLoss(prediction_generated, torch.zeros_like(prediction_generated))
    prediction_true = model.discriminator(true_images)
    loss_true = model.GANLoss(prediction_true, torch.ones_like(prediction_true))
    return 0.5 * (loss_generated + loss_true)


def _generator_loss(model, generated_images, generator_output, label):
    """Return the adversarial loss (target 1) plus the L1 loss weighted by ``model.lambda_coef``."""
    prediction_generated = model.discriminator(generated_images)
    loss_adversarial = model.GANLoss(prediction_generated, torch.ones_like(prediction_generated))
    loss_L1 = model.L1Loss(generator_output, label)
    return loss_adversarial + loss_L1 * model.lambda_coef


def train(model, train_dl, val_dl, test_dl, epochs=100, show_every=100, checkpoint=None):
    """Train a ``GANModel`` and return the per-epoch loss log.

    Every batch performs one discriminator update followed by one generator
    update. After each epoch the losses are evaluated on ``val_dl`` and a
    checkpoint is written to ``./checkpoint.pt``.

    :param model: model exposing generator, discriminator, their optimizers,
        ``GANLoss``, ``L1Loss``, ``lambda_coef``, ``color_space``, ``device``,
        ``set_requires_grad``, ``save`` and ``load``
    :param train_dl: loader yielding ``(input, label)`` training batches
    :param val_dl: loader yielding validation batches
    :param test_dl: loader whose batches are visualized during training
    :param epochs: number of epochs to run in this call
    :param show_every: visualize a test batch every ``show_every`` training batches
    :param checkpoint: optional checkpoint path to resume from
    :return: dict with the keys ``tr_generator_loss``, ``tr_discriminator_loss``,
        ``val_generator_loss`` and ``val_discriminator_loss``, each a list with
        one mean loss per epoch
    """
    test_dl_iterator = iter(test_dl)
    if checkpoint:
        curr_ep, log = model.load(checkpoint)
    else:
        log = {'tr_generator_loss' : [], 'tr_discriminator_loss' : [], 'val_generator_loss' : [], 'val_discriminator_loss' : []}
        curr_ep = 0

    for ep in range(epochs):
        ep_num = curr_ep + ep + 1
        ep_gen_tr_loss = []
        ep_dis_tr_loss = []
        ep_gen_val_loss = []
        ep_dis_val_loss = []

        model.generator.train()
        model.discriminator.train()
        for b_n, batch in enumerate(tqdm(train_dl), start=1):
            x, label = batch
            x = x.to(model.device)
            label = label.to(model.device)

            generator_output = model.generator(x)
            generated_images, true_images = _assemble_images(model, x, generator_output, label)

            # Discriminator update
            model.discriminator.train()
            model.set_requires_grad(model.discriminator, True)
            model.discriminator_optimizer.zero_grad()
            discriminator_loss = _discriminator_loss(model, generated_images, true_images)
            discriminator_loss.backward()
            model.discriminator_optimizer.step()

            # Generator update; the discriminator is frozen so that no
            # gradients are accumulated in its parameters.
            model.generator.train()
            model.set_requires_grad(model.discriminator, False)
            model.generator_optimizer.zero_grad()
            generator_loss = _generator_loss(model, generated_images, generator_output, label)
            generator_loss.backward()
            model.generator_optimizer.step()

            ep_dis_tr_loss.append(discriminator_loss.item())
            ep_gen_tr_loss.append(generator_loss.item())

            if b_n % show_every == 0:
                b = next(test_dl_iterator, None)
                if b is None:
                    # The test loader ran out of batches: start over instead of
                    # crashing with StopIteration.
                    test_dl_iterator = iter(test_dl)
                    b = next(test_dl_iterator, None)
                if b is not None:
                    # Use the model's own color space rather than a notebook global
                    visualize(model, b, model.color_space)

        with torch.no_grad():
            model.generator.eval()
            model.discriminator.eval()
            for batch in tqdm(val_dl):
                x, label = batch
                x = x.to(model.device)
                label = label.to(model.device)

                generator_output = model.generator(x)
                generated_images, true_images = _assemble_images(model, x, generator_output, label)

                discriminator_loss = _discriminator_loss(model, generated_images, true_images)
                generator_loss = _generator_loss(model, generated_images, generator_output, label)

                ep_dis_val_loss.append(discriminator_loss.item())
                ep_gen_val_loss.append(generator_loss.item())

        # Plain Python floats keep the checkpointed log free of numpy objects
        tr_D_loss = float(np.mean(ep_dis_tr_loss))
        tr_G_loss = float(np.mean(ep_gen_tr_loss))
        val_D_loss = float(np.mean(ep_dis_val_loss))
        val_G_loss = float(np.mean(ep_gen_val_loss))

        log['tr_discriminator_loss'].append(tr_D_loss)
        log['tr_generator_loss'].append(tr_G_loss)
        log['val_discriminator_loss'].append(val_D_loss)
        log['val_generator_loss'].append(val_G_loss)

        print(f"EPOCH {ep_num}", end="")
        print("-"*100)
        print(f"discriminator tr_loss:{tr_D_loss} generator tr_loss:{tr_G_loss}")
        print(f"discriminator val_loss:{val_D_loss} generator val_loss:{val_G_loss}")
        model.save(epoch=ep_num, log=log, path="./checkpoint.pt", download=False)
        print(f"SAVED CHECKPOINT EPOCH {ep_num}")

    return log

In [None]:
def _wgan_critic_inputs(model, x, generator_output, label):
    """Return the ``(generated_images, true_images)`` pair scored by the critic.

    In "Lab" mode the generator input ``x`` is concatenated in front of both
    the generator output and the label along dim 1; for any other color space
    the generator output and the label are used unchanged.
    """
    if model.color_space == "Lab":
        return (
            torch.cat([x, generator_output], dim=1),
            torch.cat([x, label], dim=1),
        )
    return generator_output, label


def train_wgan(model, train_dl, val_dl, test_dl, epochs=100, show_every=100, checkpoint=None, use_perceptual_loss=True):
    """Train ``model`` with a Wasserstein critic and save a checkpoint every epoch.

    For each training batch the critic is updated ``model.n_critic_iterations``
    times (critic scores plus ``model.lambda_gp`` times the gradient penalty),
    then the generator is updated once (negated critic score, plus
    ``model.lambda_L1`` times the L1 loss, plus the optional perceptual loss).
    After the training pass the losses are evaluated on ``val_dl`` under
    ``torch.no_grad()``; no gradient-penalty term is included there.

    Args:
        model: object exposing ``generator``, ``critic``, their optimizers,
            ``device``, ``color_space``, the loss callables and the
            ``load`` / ``save`` / ``set_requires_grad`` methods used below.
            ``n_critic_iterations`` must be at least 1, because the generator
            update reuses tensors produced inside the critic loop.
        train_dl: iterable of ``(x, label)`` training batches.
        val_dl: iterable of ``(x, label)`` validation batches.
        test_dl: iterable whose batches are handed to ``visualize``; it is
            restarted from the beginning once exhausted.
        epochs: number of epochs to run in this call.
        show_every: call ``visualize`` every ``show_every`` training batches.
        checkpoint: optional path given to ``model.load``, which must return
            ``(last_epoch, log)``; epoch numbering continues from there.
        use_perceptual_loss: add ``model.PerceptualLoss`` to the generator loss.

    Returns:
        dict with the keys ``tr_generator_loss``, ``tr_discriminator_loss``,
        ``val_generator_loss`` and ``val_discriminator_loss``; each holds one
        mean loss per epoch.
    """
    test_dl_iterator = iter(test_dl)
    if checkpoint:
        curr_ep, log = model.load(checkpoint)
    else:
        log = {'tr_generator_loss' : [], 'tr_discriminator_loss' : [], 'val_generator_loss' : [], 'val_discriminator_loss' : []}
        curr_ep = 0

    for ep in range(epochs):
        ep_num = curr_ep + ep + 1
        ep_gen_tr_loss = []
        ep_dis_tr_loss = []
        ep_gen_val_loss = []
        ep_dis_val_loss = []

        model.generator.train()
        model.critic.train()
        for b_n, batch in enumerate(tqdm(train_dl), start=1):
            x, label = batch

            x = x.to(model.device)
            label = label.to(model.device)

            # ---- critic updates ----
            model.critic.train()
            model.set_requires_grad(model.critic, True)
            for _ in range(model.n_critic_iterations):
                generator_output = model.generator(x)
                model.critic_optimizer.zero_grad()

                generated_images, true_images = _wgan_critic_inputs(model, x, generator_output, label)

                prediction_generated_images = model.critic(generated_images)
                loss_generated_images = prediction_generated_images.mean()
                prediction_true_images = model.critic(true_images)
                loss_true_images = -prediction_true_images.mean()
                gradient_penalty_loss = get_gradient_penalty(model, true_images, generated_images)
                discriminator_loss = loss_generated_images + loss_true_images + model.lambda_gp * gradient_penalty_loss
                # retain_graph=True: the generator step below reuses
                # `generated_images` from the last critic iteration, so this
                # backward pass must not free the generator's graph.
                discriminator_loss.backward(retain_graph=True)
                model.critic_optimizer.step()

            # ---- generator update ----
            model.generator.train()
            model.set_requires_grad(model.critic, False)
            # Also discards the generator gradients accumulated by the critic
            # backward passes above.
            model.generator_optimizer.zero_grad()

            prediction_generated_images = model.critic(generated_images)
            loss_generated_images = -prediction_generated_images.mean()
            loss_L1 = model.L1Loss(generator_output, label)
            generator_loss = loss_generated_images + loss_L1 * model.lambda_L1
            if use_perceptual_loss:
                perceptual_loss = model.PerceptualLoss(generated_images, true_images)
                generator_loss += perceptual_loss
            generator_loss.backward()

            model.generator_optimizer.step()

            ep_dis_tr_loss.append(discriminator_loss.detach().cpu().numpy())
            ep_gen_tr_loss.append(generator_loss.detach().cpu().numpy())

            if b_n % show_every == 0:
                try:
                    b = next(test_dl_iterator)
                except StopIteration:
                    # The test loader ran out of batches: start over instead of
                    # letting StopIteration abort the whole training run.
                    test_dl_iterator = iter(test_dl)
                    b = next(test_dl_iterator)
                visualize(model, b, COLOR_SPACE)

        with torch.no_grad():
            model.generator.eval()
            model.critic.eval()
            for batch in tqdm(val_dl):
                x, label = batch

                x = x.to(model.device)
                label = label.to(model.device)

                generator_output = model.generator(x)
                generated_images, true_images = _wgan_critic_inputs(model, x, generator_output, label)

                prediction_generated_images = model.critic(generated_images)
                loss_generated_images = prediction_generated_images.mean()

                prediction_true_images = model.critic(true_images)
                loss_true_images = -prediction_true_images.mean()

                discriminator_loss = 0.5 * (loss_generated_images + loss_true_images)

                # Same critic, same input, no gradients: reuse the prediction
                # computed above rather than running the critic a second time.
                loss_generated_images = -prediction_generated_images.mean()
                loss_L1 = model.L1Loss(generator_output, label)
                generator_loss = loss_generated_images + loss_L1 * model.lambda_L1

                if use_perceptual_loss:
                    perceptual_loss = model.PerceptualLoss(generated_images, true_images)
                    generator_loss += perceptual_loss

                ep_dis_val_loss.append(discriminator_loss.detach().cpu().numpy())
                ep_gen_val_loss.append(generator_loss.detach().cpu().numpy())

        tr_D_loss = np.mean(ep_dis_tr_loss)
        tr_G_loss = np.mean(ep_gen_tr_loss)
        val_D_loss = np.mean(ep_dis_val_loss)
        val_G_loss = np.mean(ep_gen_val_loss)

        log['tr_discriminator_loss'].append(tr_D_loss)
        log['tr_generator_loss'].append(tr_G_loss)
        log['val_discriminator_loss'].append(val_D_loss)
        log['val_generator_loss'].append(val_G_loss)

        print(f"EPOCH {ep_num}", end="")
        print("-"*100)
        print(f"discriminator tr_loss:{tr_D_loss} generator tr_loss:{tr_G_loss}")
        print(f"discriminator val_loss:{val_D_loss} generator val_loss:{val_G_loss}")
        # The same path is used every epoch, so only the latest state is kept.
        model.save(epoch=ep_num, log=log, path="./checkpoint.pt", download=False)
        print(f"SAVED CHECKPOINT EPOCH {ep_num}")

    return log

In [None]:
# Start a training run without resuming: `checkpoint` is None.
checkpoint = None
model = GANModel(color_space=COLOR_SPACE)
train(
    model,
    train_dl,
    val_dl,
    test_dl,
    epochs=100,
    checkpoint=checkpoint,
)

In [None]:
# Resume training: build a new model object and pass the checkpoint file
# written by the training loop so `train` can load it.
checkpoint = "./checkpoint.pt"
model = GANModel(color_space=COLOR_SPACE)
train(
    model,
    train_dl,
    val_dl,
    test_dl,
    epochs=100,
    checkpoint=checkpoint,
)

## Using a Pre-trained Generator

In [None]:
from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet

In [None]:
# Device every pretraining batch is moved to: GPU when available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def pretrain_generator(generator, train_dl, optimizer, loss_fn=nn.L1Loss(), epochs=20):
    """Train ``generator`` alone with a supervised loss.

    After every epoch the generator's ``state_dict`` is written to
    ``pretrained_generator.pt``, overwriting the previous file.

    Args:
        generator: module to train; called as ``generator(x)``.
        train_dl: iterable of ``(x, label)`` batches.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable applied as ``loss_fn(generator(x), label)``.
        epochs: number of passes over ``train_dl``.

    Returns:
        dict with a single key ``'training_loss'`` holding the mean batch
        loss of each epoch.
    """
    log = {'training_loss' : []}
    for ep in range(epochs):
        ep_loss = []
        generator.train()
        for x, label in tqdm(train_dl):
            x = x.to(DEVICE)
            label = label.to(DEVICE)

            generator_output = generator(x)

            loss = loss_fn(generator_output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ep_loss.append(loss.detach().cpu().numpy())

        episode_loss = np.mean(ep_loss)
        log['training_loss'].append(episode_loss)
        print(f"EPOCH {ep} training_loss: {episode_loss}")
        torch.save(generator.state_dict(), "pretrained_generator.pt")
        print("SAVED MODEL")

    # The loss history used to be built and then dropped; hand it back.
    return log
        
            

In [None]:
# Build a U-Net generator on top of a ResNet-18 encoder.
# `create_body` receives an already-constructed model and only adapts its
# first layer (n_in=1) and cuts it (cut=-2); it does not load weights itself.
# A bare `resnet18()` is randomly initialised, so request the ImageNet weights
# from torchvision explicitly to make `pretrained=True` meaningful.
net = resnet18(weights="IMAGENET1K_V1")
body = create_body(model=net, pretrained=True, n_in=1, cut=-2)
generator = DynamicUnet(body, 2, (256, 256)).to(DEVICE)
optimizer = optim.Adam(generator.parameters(), lr=1e-4)

In [None]:
# Supervised pretraining of the generator with the function's default loss and epoch count.
pretrain_generator(generator=generator, train_dl=train_dl, optimizer=optimizer)

In [None]:
# Build a new GANModel and swap in the generator created earlier in the notebook.
model = GANModel(color_space=COLOR_SPACE)
model.generator = generator
# The generator was replaced, so rebuild its optimizer over the new parameters.
model.generator_optimizer = optim.Adam(
    model.generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
checkpoint = "./checkpoint.pt"
# NOTE(review): `train` is defined outside this cell. Confirm that it accepts
# `alpha` and `use_perceptual_loss`, and that loading "./checkpoint.pt" after
# the generator swap is intended.
train(
    model,
    train_dl,
    val_dl,
    test_dl,
    epochs=10,
    checkpoint=checkpoint,
    show_every=313,
    alpha=0.1,
    use_perceptual_loss=False,
)

In [None]:
# Reload the saved checkpoint into `model`; `load` returns the epoch number
# and the loss log.
checkpoint = "./checkpoint.pt"
ep, log = model.load(checkpoint)
print(ep)

# Fresh iterator over the test loader for the visualisation cell below.
it = iter(test_dl)

In [None]:
# Pull the next test batch and show it through `visualize`.
# Re-running this cell advances `it`; next() raises StopIteration once the
# iterator is exhausted.
b = next(it)
visualize(model, b, COLOR_SPACE)