In [None]:
%pylab inline
!pip install natsort

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
from zipfile import ZipFile
from natsort import natsorted
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Seed torch's global random number generator with a fixed value.
torch.manual_seed(1)
# Shell escape: print the GPU status reported by nvidia-smi.
!nvidia-smi

In [None]:
class gray2image(torch.utils.data.Dataset):
    """Dataset of (gray, color) image pairs read from two directories.

    Both directory listings are natural-sorted and paired by position, so the
    i-th gray file is returned together with the i-th color file.

    Args:
        gray_dir: directory holding the grayscale images.
        color_dir: directory holding the color images.
        transform: callable applied independently to each loaded PIL image.
    """

    def __init__(self, gray_dir, color_dir, transform):
        self.gray_dir = gray_dir
        self.color_dir = color_dir
        self.transform = transform
        self.total_gray = natsorted(os.listdir(gray_dir))
        self.total_color = natsorted(os.listdir(color_dir))

    def __len__(self):
        # Pairs are formed by index, so only indices present in BOTH listings
        # are valid; using the gray count alone would let __getitem__ raise
        # IndexError when the color directory holds fewer files.
        return min(len(self.total_gray), len(self.total_color))

    def __getitem__(self, idx):
        """Return the transformed (gray, color) pair stored at position ``idx``."""
        gray_path = os.path.join(self.gray_dir, self.total_gray[idx])
        color_path = os.path.join(self.color_dir, self.total_color[idx])

        # convert() returns a fully loaded copy, so the underlying files can be
        # closed deterministically instead of waiting for garbage collection.
        with Image.open(gray_path) as gray_file:
            gray_img = gray_file.convert("RGB")
        with Image.open(color_path) as color_file:
            color_img = color_file.convert("RGB")

        return self.transform(gray_img), self.transform(color_img)

In [None]:
# Preprocessing applied to every image: resize to 256x256, convert to a
# tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
std_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(256, 256)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Batch size used by the training loader.
BS = 1

# Paired gray/color landscape images.
ds = gray2image(
    gray_dir='../input/landscape-image-colorization/landscape Images/gray',
    color_dir='../input/landscape-image-colorization/landscape Images/color',
    transform=std_transform,
)

# Shuffled loader; incomplete final batches are dropped.
images_loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=BS,
    shuffle=True,
    drop_last=True,
)

In [None]:
def concat_img(imgs):
    """Show a batch of images side by side in one 16x16-inch figure.

    Args:
        imgs: 4-D tensor indexed (batch, channel, height, width). Values are
            mapped with ``(x + 1) / 2`` before display, i.e. the inverse of a
            mean-0.5 / std-0.5 normalisation.

    Raises:
        ValueError: if the batch is empty (there is nothing to draw).
    """
    if imgs.shape[0] == 0:
        raise ValueError("concat_img needs at least one image, got an empty batch")

    # Map the values back to the display range and move channels last, which
    # is the layout imshow expects for colour images.
    imgs = (imgs.detach().cpu() + 1) / 2
    imgs = imgs.permute(0, 2, 3, 1).numpy()

    # Join the images along the width axis. Working on the array directly
    # avoids converting every pixel to a Python float via .tolist().
    strip = np.concatenate(list(imgs), axis=1)

    # Explicit pyplot calls: no dependency on names injected by %pylab, and
    # the figure size is set for this figure only rather than globally.
    plt.figure(figsize=(16, 16))
    plt.imshow(strip)
    plt.axis('off')
    plt.show()
    
def print_img(content, style, output):
    """Display three image batches next to each other.

    The three tensors are moved to the CPU, stacked along the batch axis in
    the order (content, style, output) and handed to ``concat_img``.
    """
    on_cpu = [batch.cpu() for batch in (content, style, output)]
    stacked = torch.cat(on_cpu, dim=0)
    concat_img(stacked.detach().cpu())
    
# Pull one batch from the shuffled training loader and display it.
gray, color = next(iter(images_loader))
print("Gray vs color")
# Show up to the first four gray images followed by their color counterparts
# (the slice yields fewer when the batch is smaller than four).
concat_img(torch.cat((gray[:4], color[:4]), 0))

In [None]:
class Generator(torch.nn.Module):
    """Encoder/decoder generator with skip connections.

    Every stage halves (encoder) or doubles (decoder) the spatial size with a
    4x4 stride-2 convolution. The output of each encoder LeakyReLU is kept and
    concatenated, deepest first, onto the input of each transposed convolution
    in ``self.up``. The final stage maps to 3 channels through a Tanh.

    Args:
        d: base channel width; the stages use multiples of it (d .. 8*d).
    """

    def __init__(self, d):
        super(Generator, self).__init__()

        def block_down(in_c, out_c, kernel, stride, padding, p, batchnorm=True):
            """Conv -> [BatchNorm] -> LeakyReLU(0.2) -> Dropout(p)."""
            layers = [torch.nn.Conv2d(in_c, out_c, kernel, stride, padding, bias=False)]
            if batchnorm:
                layers.append(torch.nn.BatchNorm2d(out_c))
            layers.append(torch.nn.LeakyReLU(0.2))
            # nn.Dropout follows the module's train/eval mode: it only drops
            # activations while the model is in train() mode.
            layers.append(torch.nn.Dropout(p))
            return layers

        def block_up(in_c, out_c, kernel, stride, padding, p, last=False):
            """ConvTranspose -> BatchNorm -> (Tanh | ReLU -> Dropout(p))."""
            layers = [
                torch.nn.ConvTranspose2d(in_c, out_c, kernel, stride, padding, bias=False),
                torch.nn.BatchNorm2d(out_c),
            ]
            if last:
                layers.append(torch.nn.Tanh())
            else:
                # ReLU's only argument is `inplace`; the former ReLU(0.2)
                # silently meant inplace=True rather than a negative slope.
                # The activation itself is unchanged, so results are too.
                layers.append(torch.nn.ReLU())
                layers.append(torch.nn.Dropout(p))
            return layers

        # Layer order (and therefore every state_dict key) is kept exactly as
        # before so previously saved checkpoints continue to load.
        self.down = torch.nn.Sequential(
            *block_down(3, d, 4, 2, 1, 0, batchnorm=False),  # 0
            *block_down(d, d * 2, 4, 2, 1, 0.5, batchnorm=False),  # 1
            *block_down(d * 2, d * 4, 4, 2, 1, 0.5),  # 2
            *block_down(d * 4, d * 8, 4, 2, 1, 0.5),  # 3
            *block_down(d * 8, d * 8, 4, 2, 1, 0.5),  # 4
            # Three further d*8 -> d*8 stages (5-7) are disabled here.
        )

        # Bottleneck: one more down stage followed by the first up stage.
        self.intermediate = torch.nn.Sequential(
            *block_down(d * 8, d * 8, 4, 2, 1, 0.5),
            *block_up(d * 8, d * 8, 4, 2, 1, 0.5),
        )

        # Input channels are doubled because each stage receives the previous
        # decoder output concatenated with the matching encoder activation.
        self.up = torch.nn.Sequential(
            *block_up(d * 8 * 2, d * 8, 4, 2, 1, 0.5),  # 4
            *block_up(d * 8 * 2, d * 4, 4, 2, 1, 0.5),  # 3
            *block_up(d * 4 * 2, d * 2, 4, 2, 1, 0),  # 2
            *block_up(d * 2 * 2, d, 4, 2, 1, 0),  # 1
        )

        # Output stage; it takes no skip connection.
        self.last = torch.nn.Sequential(
            *block_up(d, 3, 4, 2, 1, 0, last=True)
        )

    def forward(self, x):
        """Run the encoder, bottleneck and decoder and return the Tanh output."""
        skips = []
        for layer in self.down:
            x = layer(x)
            # Remember the activation of every encoder stage for the decoder.
            if isinstance(layer, torch.nn.LeakyReLU):
                skips.append(x)

        x = self.intermediate(x)

        for layer in self.up:
            if isinstance(layer, torch.nn.ConvTranspose2d):
                # Consume the skips from the deepest encoder stage upwards.
                x = layer(torch.cat((x, skips.pop()), 1))
            else:
                x = layer(x)

        return self.last(x)

# Build a generator with base width 64 on the GPU and check the shape it
# produces for the batch fetched above.
G = Generator(64).cuda()
G(gray.cuda()).shape

In [None]:
class Discriminator(torch.nn.Module):
    """Conditional discriminator over a pair of 3-channel images.

    The two inputs are concatenated along the channel axis (6 channels) and
    passed through two stride-2 convolution stages and a final stride-1
    convolution with a Sigmoid, giving a one-channel spatial map of scores.

    Args:
        d: base channel width of the first convolution stage.
    """

    def __init__(self, d):
        super(Discriminator, self).__init__()

        def block_conv(in_c, out_c, kernel, stride, padding, batchnorm=True):
            """Conv -> [BatchNorm] -> LeakyReLU(0.2)."""
            stage = [torch.nn.Conv2d(in_c, out_c, kernel, stride, padding, bias=False)]
            if batchnorm:
                stage.append(torch.nn.BatchNorm2d(out_c))
            stage.append(torch.nn.LeakyReLU(0.2))
            return stage

        layers = []
        layers.extend(block_conv(6, d, 4, 2, 1, batchnorm=False))
        layers.extend(block_conv(d, d * 2, 4, 2, 1))
        # Two deeper stages (d*2 -> d*4 -> d*8) are disabled here.
        layers.append(torch.nn.Conv2d(d * 2, 1, 4, 1, 1))
        layers.append(torch.nn.Sigmoid())
        self.convos = torch.nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the pair (x, y); both are concatenated on dim 1 first."""
        pair = torch.cat((x, y), 1)
        return self.convos(pair)

# Build a discriminator with base width 64 on the GPU and check the shape of
# its output for the (gray, color) batch fetched earlier.
D = Discriminator(64).cuda()
print(gray.shape)
print(color.shape)
D(gray.cuda(), color.cuda()).shape
# Loss sketch (pseudo-code, y = real color image, y_hat = generated image):
#   d_loss = BinaryCrossEntropy(D([y_hat, y]), [0, 1])
#   g_loss = BinaryCrossEntropy(D[y_hat], [1]) + alpha * torch.mean(||y - y_hat||1)

In [None]:
# Weight of the L1 term passed to generator_loss, and the shared learning rate.
alpha = 100
LR = 2e-4

# Fresh models on the GPU (these replace the instances built for the shape
# checks above).
G = Generator(64).cuda()
D = Discriminator(64).cuda()

g_lr = LR
d_lr = LR

g_optim = torch.optim.Adam(G.parameters(), lr=g_lr, betas = (0.5, 0.9))
d_optim = torch.optim.Adam(D.parameters(), lr=d_lr, betas = (0.5, 0.9))

# Restore generator weights and optimizer state from the saved checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
PATH = '../input/models-pretrained-landscapes/landscapes_v1_g'
checkpoint = torch.load(PATH)
G.load_state_dict(checkpoint['model_state_dict'])
g_optim.load_state_dict(checkpoint['optimizer_state_dict'])
iters = checkpoint['iters']

# Same for the discriminator; PATH, checkpoint and iters are rebound, so the
# value of iters that remains is the one stored in this second checkpoint.
PATH = '../input/models-pretrained-landscapes/landscapes_v1_d'
checkpoint = torch.load(PATH)
D.load_state_dict(checkpoint['model_state_dict'])
d_optim.load_state_dict(checkpoint['optimizer_state_dict'])
iters = checkpoint['iters']

In [None]:
def discriminator_loss(inputs, targets):
    """Mean binary cross-entropy between predictions and their 0/1 targets.

    Args:
        inputs: discriminator outputs (probabilities).
        targets: tensor of the same shape holding the desired labels.

    Returns:
        Scalar tensor with the averaged loss.
    """
    return torch.nn.functional.binary_cross_entropy(inputs, targets)

def generator_loss(d_fake_predictions, generated, real, alpha):
    """Adversarial loss plus ``alpha`` times the L1 reconstruction loss.

    Args:
        d_fake_predictions: discriminator probabilities for generated images.
        generated: images produced by the generator.
        real: ground-truth images with the same shape as ``generated``.
        alpha: weight of the L1 term.

    Returns:
        Scalar tensor ``BCE(d_fake_predictions, 1) + alpha * L1(generated, real)``.
    """
    # The generator wants the discriminator to answer "real" (1) everywhere.
    # ones_like builds the target on the same device/dtype as the predictions,
    # so the function no longer requires a CUDA device.
    real_labels = torch.ones_like(d_fake_predictions)
    gan_loss = torch.nn.functional.binary_cross_entropy(d_fake_predictions, real_labels)
    # The L1 loss measures the distance between the real and generated image.
    l1_loss = torch.nn.functional.l1_loss(generated, real)
    return gan_loss + alpha * l1_loss

In [None]:
# NOTE(review): this resets the counter restored from the checkpoints, and BS
# is not read by the already-built images_loader - confirm both are intended.
iters = 0
BS = 1
epochs = 20
for epoch in range(epochs):
    G.train()
    D.train()
    for batch_idx, (gray, color) in enumerate(images_loader):
        gray, color = gray.cuda(), color.cuda()

        # ---- Discriminator step ----
        d_optim.zero_grad()

        # The discriminator update never back-propagates into the generator,
        # so skip building the generator's autograd graph for this pass.
        with torch.no_grad():
            fake = G(gray)
        d_real = D(gray, color)
        d_fake = D(gray, fake)
        predictions = torch.cat((d_real, d_fake), 0)
        # Real pairs are labelled 1, generated pairs 0; the labels are created
        # directly on the device of the predictions.
        targets = torch.cat((torch.ones_like(d_real), torch.zeros_like(d_fake)), 0)

        d_loss = discriminator_loss(predictions, targets)

        d_loss.backward()
        d_optim.step()

        # ---- Generator step ----
        g_optim.zero_grad()

        # Fresh forward pass, this time tracked by autograd.
        fake = G(gray)
        d_fake = D(gray, fake)
        g_loss = generator_loss(d_fake, fake, color, alpha)

        g_loss.backward()
        g_optim.step()

        if batch_idx % 500 == 0:
            # .item() logs plain Python floats instead of formatting tensors.
            print('Epoch {} batch {} Discriminator loss: {:.3f} Generator loss: {:.3f}'.format(
                epoch, batch_idx, d_loss.item(), g_loss.item()))
        if batch_idx % 1000 == 0:
            # Order shown: input gray, generated image, ground-truth color.
            # print_img ends in plt.show() via concat_img, so no extra call.
            print_img(gray[:1], fake[:1], color[:1])
        iters += 1

In [None]:
G.eval()
BS = 128
# The tensors left over from training hold a single image (the training loader
# uses batch size 1), so slicing them at index 19 and above yields empty
# batches and the plot fails. Fetch a fresh, larger batch for the preview.
# A separate loader name keeps the training loader untouched, and
# drop_last=False still yields a batch when the dataset is smaller than BS.
preview_loader = torch.utils.data.DataLoader(ds, batch_size=BS, shuffle=True, drop_last=False)
gray, color = next(iter(preview_loader))
with torch.no_grad():
    fake = G(gray.cuda())

# Show samples 20..23 of the batch as (gray, generated, color) triples.
for sample in range(20, 24):
    print_img(gray[sample - 1:sample], fake[sample - 1:sample], color[sample - 1:sample])

In [None]:
# lets test it on some unseen data from google to check if it generalizes
BS = 4
# The same directory is used for both sides of the pair; only the first
# element of each pair is consumed below.
ds = gray2image(
    gray_dir='../input/validation-images',
    color_dir='../input/validation-images',
    transform=std_transform,
)
validation_loader = torch.utils.data.DataLoader(
    dataset=ds,
    batch_size=BS,
    shuffle=True,
    drop_last=True,
)

gray, _ = next(iter(validation_loader))
gray = gray.cuda()
fake = G(gray)
print("This are unseen during training 256 by 256 images colored")
# One figure per image in the batch: input on the left, colorized on the right.
for sample in range(1, 5):
    concat_img(torch.cat((gray[sample - 1:sample], fake[sample - 1:sample]), 0))

In [None]:
def save_models():
    """Write generator and discriminator checkpoints to the working directory.

    Reads the notebook globals ``epoch``, ``iters``, ``G``, ``D``, ``g_optim``,
    ``d_optim``, ``g_loss`` and ``d_loss``; each checkpoint stores the epoch,
    model and optimizer state dicts, the last loss and the iteration count.
    """
    def _snapshot(model, optimizer, loss):
        # Same keys, in the same order, for both networks.
        return {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'iters': iters,
        }

    torch.save(_snapshot(G, g_optim, g_loss), './landscapes_v2_g')
    torch.save(_snapshot(D, d_optim, d_loss), './landscapes_v2_d')
save_models()
