In [1]:
import os
import glob

import numpy as np
import cv2
from PIL import Image
from pathlib import Path
# from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import cv2
import torchvision
from torchvision import models
import torch.nn.functional as F

# import scipy
import torch
import torchvision.datasets as dset
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
np.random.seed(2021)
torch.manual_seed(2021)
import copy

# **Defining Dataset**

In [2]:
data_root = "../input/flickrfaceshq-dataset-ffhq"

class FaceData(torch.utils.data.Dataset):
    """Face images served as ``(L, ab, masked_L)`` tensors for colorization + inpainting.

    Every non-hidden entry directly under ``root`` is treated as an image file.
    Each item is converted to RGB, resized to ``IMG_SIZE`` x ``IMG_SIZE``,
    converted to CIE L*a*b* and returned as:

    * ``L``        -- (1, H, W) lightness, mapped from [0, 100] to [-1, 1]
    * ``ab``       -- (2, H, W) a*/b* channels divided by 100
    * ``masked_L`` -- (1, H, W) lightness of the same image after random white
      strokes were painted over it (see ``mask``), scaled like ``L``

    ``transformations`` is stored but not applied anywhere in this class.
    """

    IMG_SIZE = 224

    def __init__(self, transformations=None, root = '../input/flickrfaceshq-dataset-ffhq'):
        self.root = root
        self.transformations = transformations
        # Sorted so that index i always refers to the same file; the
        # train/val/test split is done by index.
        self.img_list = sorted(glob.glob(os.path.join(root, '*')))

    def __len__(self):
        # Use the list gathered in __init__: this honours `root` and avoids
        # re-globbing the directory every time len() is called.
        return len(self.img_list)

    def mask(self, img):
        """Paint 1-9 random black lines on a white canvas and whiten those pixels in ``img``.

        Args:
            img: (H, W) or (H, W, C) uint8 image array. It is not modified.

        Returns:
            ``(masked_img, canvas)``: a copy of ``img`` with stroke pixels set to
            255, and the (H, W, 3) uint8 canvas (0 on strokes, 255 elsewhere).
        """
        h, w = img.shape[:2]
        canvas = np.full((h, w, 3), 255, np.uint8)  # white canvas

        for _ in range(np.random.randint(1, 10)):
            x1, x2 = np.random.randint(1, w), np.random.randint(1, w)
            y1, y2 = np.random.randint(1, h), np.random.randint(1, h)
            t = np.random.randint(1, 3)
            cv2.line(canvas, (x1, y1), (x2, y2), (0, 0, 0), t)

        masked_img = img.copy()
        # Strokes are drawn as (0, 0, 0), so one canvas channel identifies them;
        # a 2-D pixel mask also works for single-channel images.
        masked_img[canvas[..., 0] == 0] = 255
        return masked_img, canvas

    def __getitem__(self, idx):
        # convert("RGB") guarantees three channels: rgb2lab rejects grayscale or
        # RGBA arrays, and mask() expects an (H, W, 3) image here.
        with Image.open(self.img_list[idx]) as pil_img:
            ip_img = np.array(pil_img.convert("RGB").resize((self.IMG_SIZE, self.IMG_SIZE)))

        img_lab = transforms.ToTensor()(rgb2lab(ip_img).astype("float32"))  # HWC -> CHW, no rescaling for float input
        L = img_lab[[0]] / 50. - 1.   # L* is in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2]] / 100.   # a*/b* are roughly within [-128, 127]

        masked_rgb, _ = self.mask(ip_img)  # mask() works on a copy
        masked_lab = transforms.ToTensor()(rgb2lab(masked_rgb).astype("float32"))
        masked_L = masked_lab[[0]] / 50. - 1.

        return L, ab, masked_L
        

In [3]:
# One shared dataset instance; every DataLoader in the next cell indexes into it.
data_set = FaceData(root=data_root)

In [4]:
# Fixed index-based split of the (sorted) image list:
# first N_TRAIN images train, next N_VAL validation, next N_TEST test.
N_TRAIN, N_VAL, N_TEST = 20000, 5000, 5000
BATCH_SIZE = 32
NUM_WORKERS = 2

n_used = N_TRAIN + N_VAL + N_TEST
assert len(data_set) >= n_used, f'need at least {n_used} images, found {len(data_set)}'
indices = list(range(n_used))

train_indices = indices[:N_TRAIN]
val_indices = indices[N_TRAIN:N_TRAIN + N_VAL]
test_indices = indices[N_TRAIN + N_VAL:n_used]

train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
# Deterministic order over the validation indices. SequentialSampler(val_indices)
# would yield range(len(val_indices)) == 0..N_VAL-1, i.e. *training* images.
# A plain list of indices is a valid (ordered) sampler for DataLoader.
val_sampler = val_indices
test_sampler = torch.utils.data.sampler.SubsetRandomSampler(test_indices)

train_data = torch.utils.data.DataLoader(data_set, batch_size=BATCH_SIZE, sampler=train_sampler,
                                         num_workers=NUM_WORKERS)
val_data = torch.utils.data.DataLoader(data_set, batch_size=BATCH_SIZE, sampler=val_sampler,
                                       num_workers=NUM_WORKERS)
test_data = torch.utils.data.DataLoader(data_set, batch_size=BATCH_SIZE, sampler=test_sampler,
                                        num_workers=NUM_WORKERS)
print(f'Train_data_size={len(train_data)}; Val_data_size={len(val_data)} ; Test_data_size={len(test_data)}')

# **Defining Model**

In [5]:
class Image_Colorization_Model(nn.Module):
    """Encoder-decoder colorizer with skip connections and an auxiliary inpainting branch.

    Encoder: two 3x3 convs lift the 1-channel L input to 3 channels. These feed
    the backbone stem (``conv1``/``bn1``/``relu``/``maxpool``), then ``layer1``
    and ``layer2``.

    Two decoders mirror the encoder with 2x nearest-neighbour upsampling and
    skip connections:

    * colorization decoder (``conv_decode*``) -> 2-channel ab map (after a ReLU)
    * inpainting decoder (``seg_conv_decode*``) -> 1-channel map (linear output)

    Training mode (``eval_flag=False``):
        The input batch must be ``[gray; masked_gray]``: two equally sized
        halves. The first half feeds the colorization decoder and the second
        half the inpainting decoder. Inpainting features multiplicatively gate
        the colorization features at two decoder stages.
        Returns ``(ab, inpainted)``, each with half the input batch size.

    Eval mode (``eval_flag=True``):
        The whole batch goes through the colorization decoder *without*
        inpainting gating. Returns ``ab``.

    Args:
        backbone: optional module exposing ``conv1``, ``bn1``, ``relu``,
            ``maxpool``, ``layer1`` and ``layer2`` with ResNet-50 channel widths
            (64 / 256 / 512). Defaults to an ImageNet-pretrained
            ``torchvision.models.resnet50``. That default downloads weights;
            passing a backbone avoids this, e.g. when a checkpoint is loaded
            right afterwards.
    """

    def __init__(self, backbone=None):
        super(Image_Colorization_Model, self).__init__()
        # Encoder: map 1-channel L to the 3 channels the backbone stem expects.
        self.conv_preprocess1 = nn.Conv2d(1, 3, kernel_size=3, padding=1)
        self.conv_preprocess2 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        model_resnet = models.resnet50(pretrained=True) if backbone is None else backbone
        # Only the stem, layer1 and layer2 are kept. The attribute names are
        # part of the state_dict keys stored in checkpoints.
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2

        # Colorization decoder. Input widths include the skip connections:
        # 768 = 256 (layer1) + 512 (layer2), 128 = 64 (conv1) + 64, 11 = 3 (pre) + 8.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_decode2_1 = nn.Conv2d(768, 128, kernel_size=3, padding=1)
        self.conv_decode2_2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        self.conv_decode1_1 = nn.Conv2d(128, 16, kernel_size=3, padding=1)
        self.conv_decode1_2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)

        self.conv_decode0_1 = nn.Conv2d(11, 4, kernel_size=3, padding=1)
        self.conv_decode0_2 = nn.Conv2d(4, 2, kernel_size=1)

        # Inpainting decoder: same layout, single output channel.
        self.seg_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.seg_conv_decode2_1 = nn.Conv2d(768, 128, kernel_size=3, padding=1)
        self.seg_conv_decode2_2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        self.seg_conv_decode1_1 = nn.Conv2d(128, 16, kernel_size=3, padding=1)
        self.seg_conv_decode1_2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)

        self.seg_conv_decode0_1 = nn.Conv2d(11, 4, kernel_size=3, padding=1)
        self.seg_conv_decode0_2 = nn.Conv2d(4, 1, kernel_size=1)

    @staticmethod
    def _split_halves(t):
        """Return ``(first_half, second_half)`` of ``t`` along the batch dimension."""
        half = t.size(0) // 2
        return t[:half], t[half:]

    def forward(self, x, eval_flag=False):
        """Run the network; see the class docstring for the two modes.

        Raises:
            ValueError: in training mode when the batch size is odd. The two
                halves would otherwise be misaligned; a batch of 1 used to
                return empty tensors silently.
        """
        if not eval_flag and x.size(0) % 2 != 0:
            raise ValueError(
                'training-mode input must be an even-sized batch [gray; masked_gray], '
                f'got batch size {x.size(0)}')

        # ENCODER (shared by both decoders)
        x_pre = F.relu(self.conv_preprocess1(x))
        x_pre = F.relu(self.conv_preprocess2(x_pre))

        encode_1 = self.relu(self.bn1(self.conv1(x_pre)))
        x_mp = self.maxpool(encode_1)
        encode_2 = self.layer1(x_mp)
        bottle_neck = self.layer2(encode_2)

        if eval_flag:
            # Whole batch is colorized; the inpainting branch is skipped.
            col_bottle, col_enc2, col_enc1, col_pre = bottle_neck, encode_2, encode_1, x_pre
        else:
            col_bottle, seg_bottle = self._split_halves(bottle_neck)
            col_enc2, seg_enc2 = self._split_halves(encode_2)
            col_enc1, seg_enc1 = self._split_halves(encode_1)
            col_pre, seg_pre = self._split_halves(x_pre)

            # INPAINTING BRANCH (second half of the batch)
            seg_decode_2 = torch.cat((seg_enc2, self.seg_upsample(seg_bottle)), dim=1)
            seg_decode_2 = F.relu(self.seg_conv_decode2_1(seg_decode_2))
            seg_decode_2 = F.relu(self.seg_conv_decode2_2(seg_decode_2))

            seg_decode_1 = torch.cat((seg_enc1, self.seg_upsample(seg_decode_2)), dim=1)
            seg_decode_1 = F.relu(self.seg_conv_decode1_1(seg_decode_1))
            seg_decode_1 = F.relu(self.seg_conv_decode1_2(seg_decode_1))

            seg_decode_pre = torch.cat((seg_pre, self.seg_upsample(seg_decode_1)), dim=1)
            seg_decode_pre = F.relu(self.seg_conv_decode0_1(seg_decode_pre))
            seg_decode_pre = self.seg_conv_decode0_2(seg_decode_pre)  # linear output

        # COLORIZATION DECODER (first half in training, whole batch in eval)
        decode_2 = torch.cat((col_enc2, self.upsample(col_bottle)), dim=1)
        decode_2 = F.relu(self.conv_decode2_1(decode_2))
        decode_2 = F.relu(self.conv_decode2_2(decode_2))
        if not eval_flag:
            decode_2 = seg_decode_2 * decode_2  # gate with inpainting features

        decode_1 = torch.cat((col_enc1, self.upsample(decode_2)), dim=1)
        decode_1 = F.relu(self.conv_decode1_1(decode_1))
        decode_1 = F.relu(self.conv_decode1_2(decode_1))
        if not eval_flag:
            decode_1 = seg_decode_1 * decode_1  # gate with inpainting features

        decode_pre = torch.cat((col_pre, self.upsample(decode_1)), dim=1)
        decode_pre = F.relu(self.conv_decode0_1(decode_pre))
        # NOTE(review): this final ReLU clamps the predicted ab to >= 0, although
        # a*/b* are signed in CIE L*a*b*. It is kept so existing checkpoints
        # behave identically; consider removing it when retraining.
        decode_pre = F.relu(self.conv_decode0_2(decode_pre))

        if eval_flag:
            return decode_pre
        return decode_pre, seg_decode_pre

class Descriminator(nn.Module):
    """Convolutional discriminator over 3-channel (L, a, b) images.

    Four Conv(k=4, p=1) -> BatchNorm -> LeakyReLU(0.2) stages with strides
    2, 2, 2, 1 are followed by a final 1-channel Conv(k=4, s=1, p=1). The
    output is a spatial map of raw logits (no sigmoid), so pair it with
    ``nn.BCEWithLogitsLoss``. A 64x64 input yields a 6x6 map.
    """

    # (in_channels, out_channels, stride) for the conv{i}/bn{i} stages, i = 1..4.
    _STAGES = ((3, 64, 2), (64, 128, 2), (128, 256, 2), (256, 512, 1))

    def __init__(self):
        super(Descriminator, self).__init__()
        # Registered as conv{i} / bn{i}, in this order, so state_dict keys and
        # parameter initialisation order match saved checkpoints.
        for stage_no, (c_in, c_out, step) in enumerate(self._STAGES, start=1):
            setattr(self, f'conv{stage_no}', nn.Conv2d(c_in, c_out, kernel_size=4, stride=step, padding=1))
            setattr(self, f'bn{stage_no}', nn.BatchNorm2d(c_out))
        self.conv5 = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        for stage_no in range(1, len(self._STAGES) + 1):
            conv_layer = getattr(self, f'conv{stage_no}')
            norm_layer = getattr(self, f'bn{stage_no}')
            x = F.leaky_relu(norm_layer(conv_layer(x)), negative_slope=0.2)
        return self.conv5(x)
        

In [6]:
# discriminator = Descriminator()
# dummy_input = torch.randn(4, 3, 256, 256) # batch_size, channels, size, size
# out = discriminator(dummy_input)
# out.shape

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate = 1e-3

# Move the networks to the detected device (previously hard-coded .cuda(),
# which crashed on CPU-only machines even though `device` was computed).
model = Image_Colorization_Model().to(device)
model_D = Descriminator().to(device)

# Optimizers are built after .to(device) so they reference the moved parameters.
optimizer_G = torch.optim.Adam(model.parameters(), lr=learning_rate)
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate)

L1loss = nn.L1Loss()
# Descriminator.forward returns raw conv outputs (no sigmoid), hence the logits loss.
Disc_loss = nn.BCEWithLogitsLoss()

# **Training Colorization Model**

In [None]:
os.makedirs('./weights', exist_ok=True)  # checkpoint directory; exist_ok keeps the cell re-runnable
def plot_loss_L1(train_list):
    """Plot the per-iteration training L1 loss and save the figure as ./Loss_graph_L1.

    The file format follows matplotlib's ``savefig.format`` setting because the
    name has no extension. The figure is closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.plot(train_list, label='Train loss L1')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('L1 Loss')
    ax.legend(loc='upper right')
    fig.savefig(os.path.join('./', 'Loss_graph_L1'))
    plt.close(fig)

def plot_loss_mse(train_list):
    """Plot the per-iteration training MSE loss and save the figure as ./Loss_graph_mse.

    The file format follows matplotlib's ``savefig.format`` setting because the
    name has no extension. The figure is closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.plot(train_list, label='Train loss MSE')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('MSE Loss')
    ax.legend(loc='upper right')
    fig.savefig(os.path.join('./', 'Loss_graph_mse'))
    plt.close(fig)

def train_discriminator(model_D, gray, predicted_ab_space, true_ab_space):
    """Discriminator loss: mean of BCE(fake -> 0) and BCE(real -> 1).

    Fake images are ``cat(gray, predicted_ab)`` and real ones are
    ``cat(gray, true_ab)`` along the channel dim. The fake image is detached, so
    this loss produces no gradients for the generator. ``model_D`` is called on
    the fake image first, then on the real one. Uses the module-level
    ``Disc_loss`` (BCE on logits).
    """
    fake_logits = model_D(torch.cat((gray, predicted_ab_space), dim=1).detach())
    real_logits = model_D(torch.cat((gray, true_ab_space), dim=1))

    loss_on_fake = Disc_loss(fake_logits, torch.zeros_like(fake_logits))
    loss_on_real = Disc_loss(real_logits, torch.ones_like(real_logits))
    return 0.5 * (loss_on_fake + loss_on_real)

def train_generator(model_D, gray, predicted_ab_space):
    """Adversarial generator loss: BCE of ``model_D`` on the fake image against "real" (1) labels.

    The fake image is ``cat(gray, predicted_ab)`` along the channel dim. It is
    *not* detached, so gradients flow back into the generator. Uses the
    module-level ``Disc_loss`` (BCE on logits).
    """
    logits = model_D(torch.cat((gray, predicted_ab_space), dim=1))
    return Disc_loss(logits, torch.ones_like(logits))

# Resume configuration -- set RESUME_CHECKPOINT = None to train from scratch.
RESUME_CHECKPOINT = "../input/55-epoch-ad-color-mix/_epoch_55 (2)"
RESUMED_EPOCHS = 55  # epochs already contained in RESUME_CHECKPOINT; offsets saved file names
NUM_EPOCHS = 70      # epochs to run in this session
SAVE_EVERY = 5       # save a checkpoint every SAVE_EVERY epochs
LOG_WINDOW = 100     # number of most recent iterations averaged in the per-epoch log

if RESUME_CHECKPOINT is not None:
    # map_location lets a checkpoint written on GPU load on the current device.
    chkpt = torch.load(RESUME_CHECKPOINT, map_location=device)
    model.load_state_dict(chkpt['model_state_dict'])
    model_D.load_state_dict(chkpt['model_D_state_dict'])
    optimizer_D.load_state_dict(chkpt['optimizer_D_state_dict'])
    optimizer_G.load_state_dict(chkpt['optimizer_G_state_dict'])
    chkpt_epoch = chkpt['epoch']

    # Continue the per-iteration loss histories stored in the checkpoint.
    train_loss_l1 = chkpt['loss']['L1_Loss'].copy()
    train_loss_mse = chkpt['loss']['MSE_Loss'].copy()
    train_loss_inpaint = chkpt['loss']['inpaint_Loss'].copy()
    total_loss = chkpt['loss']['total_Loss'].copy()
    train_loss_ad = chkpt['loss']['train_loss_ad'].copy()
    train_disc_loss = chkpt['loss']['train_disc_loss'].copy()
else:
    RESUMED_EPOCHS = 0
    train_loss_l1, train_loss_mse, train_loss_inpaint = [], [], []
    total_loss, train_loss_ad, train_disc_loss = [], [], []

model.train()
model_D.train()

for i in range(NUM_EPOCHS):
    for gray, true_ab_space, masked_L in train_data:
        # model_D is switched to eval() for the generator step below, so reset
        # both networks to train mode at the start of every iteration.
        model.train()
        model_D.train()
        gray = gray.to(device)
        true_ab_space = true_ab_space.to(device)
        masked_L = masked_L.to(device)

        # First half of the batch is colorized, second half inpainted (see model).
        input_cat = torch.cat((gray, masked_L))

        optimizer_D.zero_grad()
        predicted_ab_space, inpainted_img = model(input_cat)

        # --- Discriminator step (the fake image is detached inside train_discriminator) ---
        for param in model_D.parameters():
            param.requires_grad = True
        discriminator_loss = train_discriminator(model_D, gray, predicted_ab_space, true_ab_space)
        discriminator_loss.backward()
        optimizer_D.step()

        # --- Generator step: freeze D so only the generator receives gradients ---
        optimizer_G.zero_grad()
        for param in model_D.parameters():
            param.requires_grad = False
        model_D.eval()
        ad_loss = train_generator(model_D, gray, predicted_ab_space)
        loss_color = L1loss(predicted_ab_space, true_ab_space) * 100.
        loss_inpaint = L1loss(inpainted_img, gray) * 100.

        gen_loss = ad_loss + loss_color + loss_inpaint
        gen_loss.backward()
        optimizer_G.step()

        # Bookkeeping only; no autograd graph needed for the MSE metric.
        with torch.no_grad():
            mse_value = F.mse_loss(predicted_ab_space, true_ab_space).item()
        train_loss_l1.append(loss_color.item())
        train_loss_mse.append(mse_value)
        train_loss_inpaint.append(loss_inpaint.item())
        train_loss_ad.append(ad_loss.item())
        train_disc_loss.append(discriminator_loss.item())
        total_loss.append(gen_loss.item())

    # Per-epoch summary (the previous `(i+1) % 1 == 0` guard was always true).
    print('%d epochs' % (i+1))
    print('L1_Loss %.3f' % np.mean(train_loss_l1[-LOG_WINDOW:]))
    print('MSE_Loss: %.3f' % np.mean(train_loss_mse[-LOG_WINDOW:]))
    print('train_loss_inpaint: %.3f' % np.mean(train_loss_inpaint[-LOG_WINDOW:]))
    print('train_loss_ad: %.3f' % np.mean(train_loss_ad[-LOG_WINDOW:]))
    print('train_disc_loss: %.3f' % np.mean(train_disc_loss[-LOG_WINDOW:]))
    print('total_loss: %.3f' % np.mean(total_loss[-LOG_WINDOW:]))

    if (i+1) % SAVE_EVERY == 0:
        # The file name carries the global epoch count. NOTE(review): 'epoch'
        # stores this session's 0-based index; it is kept for format compatibility.
        torch.save({'epoch':i,
                    'model_state_dict':model.state_dict(),
                    'model_D_state_dict':model_D.state_dict(),
                    'optimizer_D_state_dict':optimizer_D.state_dict(),
                    'optimizer_G_state_dict':optimizer_G.state_dict(),
                    'loss':{'L1_Loss':train_loss_l1.copy(),'MSE_Loss':train_loss_mse.copy(),'inpaint_Loss':train_loss_inpaint.copy(),'train_loss_ad':train_loss_ad.copy(),'train_disc_loss':train_disc_loss.copy(),'total_Loss':total_loss.copy()},
                   },os.path.join('./weights',f'_epoch_{RESUMED_EPOCHS + i + 1}'))

In [None]:
# val_loss_L1 = []
# val_loss_mse = []
# model.eval()
# with torch.no_grad():
#     for i,data in enumerate(val_data):
#             print(i)
            
#             gray = data[0].to(device)  

#             true_ab_space = data[1].to(device)
#             masked_L = data[2].to(device)

#             input_cat = torch.cat((gray, masked_L))
            
#             predicted_ab_space,_ = model(input_cat)
# #             predicted_ab_space = model(gray,eval_flag=True)
            

#             val_loss_L1.append(L1loss(predicted_ab_space, true_ab_space).item())
#             val_loss_mse.append((F.mse_loss(predicted_ab_space, true_ab_space)).item())



In [None]:
# mean_val_loss_l1 = sum(val_loss_L1)/len(val_loss_L1)
# mean_val_loss_mse = sum(val_loss_mse)/len(val_loss_mse)

# mean_val_loss_l1,mean_val_loss_mse

# **Testing Trained Model**

In [None]:
# torch.cuda.memory_allocated(), torch.cuda.current_device()

In [8]:
chkpt = torch.load("../input/adversarial-mixup-105-epoch/_epoch_105",map_location=torch.device('cpu'))
# load_state_dict copies the CPU tensors onto whatever device the models live on.
model.load_state_dict(chkpt['model_state_dict'])
model_D.load_state_dict(chkpt['model_D_state_dict'])

model.eval()
# model_D must also be in eval mode. In train mode its BatchNorm layers would
# use batch statistics and overwrite the running statistics just loaded.
model_D.eval()

SHOW_BATCH = 6  # 0-based index of the validation batch to visualise

with torch.no_grad():
    data_iterator = iter(val_data)
    # Skip the earlier batches without running the networks on them.
    for _ in range(SHOW_BATCH):
        next(data_iterator)
    gray_val, true_ab_val, masked_L = next(data_iterator)

    # Keep CPU copies for the ground-truth RGB reconstruction below.
    gray_val_orig = gray_val.clone()
    true_ab_val_orig = true_ab_val.clone()

    gray_val = gray_val.to(device)
    masked_L = masked_L.to(device)
    true_ab_val = true_ab_val.to(device)

    # Training-mode call: batch is [gray; masked], returns (ab, inpainted).
    pred_ab, pred_inpaint = model(torch.cat((gray_val, masked_L)))

    fake_img = torch.cat((gray_val, pred_ab), dim=1)
    real_img = torch.cat((gray_val, true_ab_val), dim=1)
    disk_pred_fake = model_D(fake_img)
    disk_pred_real = model_D(real_img)

# Undo the dataset scaling (L: [-1, 1] -> [0, 100], ab: x100) and go NCHW -> NHWC for lab2rgb.
pred_ab = (pred_ab*100.).permute(0, 2, 3, 1).contiguous()
gray_val = ((gray_val+1)*50.).permute(0,2,3,1).contiguous()

ab_rgb = (true_ab_val_orig*100.).permute(0, 2, 3, 1).contiguous()
gray_rgb = ((gray_val_orig+1)*50.).permute(0,2,3,1).contiguous()

true_rgb = torch.tensor(lab2rgb(torch.cat((gray_rgb,ab_rgb),dim=3).cpu().numpy()))
true_rgb = true_rgb.permute(0,3,1,2)

pred_rgb = torch.tensor(lab2rgb(torch.cat((gray_val,pred_ab),dim=3).cpu().numpy()))
pred_rgb = pred_rgb.permute(0,3,1,2)
pred_rgb = pred_rgb.permute(0,3,1,2)

In [9]:
from torchvision.utils import make_grid

def show(img,figure_size=(20,20),name="output"):
    """Display a (C, H, W) image with matplotlib and save the figure as ``./<name>``.

    Args:
        img: (C, H, W) image, e.g. the output of ``make_grid``. Torch tensors
            are detached and moved to CPU first; anything else goes through
            ``np.asarray``.
        figure_size: matplotlib figure size in inches.
        name: file name (relative to ``./``) passed to ``plt.savefig``.
    """
    # Previously `img.numpy()` was computed and then discarded.
    if torch.is_tensor(img):
        npimg = img.detach().cpu().numpy()
    else:
        npimg = np.asarray(img)
    print(npimg.shape)
    _, ax = plt.subplots(figsize=figure_size)
    im = ax.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for imshow
    im.axes.get_xaxis().set_visible(True)
    im.axes.get_yaxis().set_visible(True)
    plt.savefig("./" + name)

show(make_grid(((gray_val/100.)).permute(0,3,1,2).cpu().data))

In [10]:
# Colorized predictions for the selected validation batch.
show(make_grid(pred_rgb.detach().cpu()))

In [None]:
# disk_pred_fake=disk_pred_fake>0.5
# disk_pred_fake

In [None]:
# disk_pred_real=disk_pred_real>0.5
# disk_pred_real

In [None]:
# show(make_grid(pred_rgb.cpu().data) )

In [None]:
show(make_grid(true_rgb.detach().cpu()))  # ground-truth colors for the same batch

In [None]:
# show(make_grid(pred_inpaint.cpu().data) )

In [None]:
# show(make_grid(masked_L.cpu().data) )

In [None]:
# pip install torch-summary

In [None]:
# from torchsummary import summary
# summary(model, (1, 224, 224))

In [None]:
# loss = chkpt['loss']['L1_Loss']
# print(len(loss))
# f = plt.figure(figsize=(20,8))
# ax = f.add_subplot(1,2,1)
# ax.plot(loss)
# # ax.set_yscale('log')
# ax.set_title('L1 Loss')
# ax.set_xlabel('Iteration')
# f.savefig('./loss')

In [None]:
# print('L1_Loss %.3f' % np.mean(loss[-100:]))

In [None]:
# loss = chkpt['loss']['MSE_Loss']
# f = plt.figure(figsize=(20,8))
# ax = f.add_subplot(1,2,1)
# ax.plot(loss)
# # ax.set_yscale('log')
# ax.set_title('MSE_Loss Loss')
# ax.set_xlabel('Iteration')

In [None]:
# loss = chkpt['loss']['train_disc_loss'][:200]
# f = plt.figure(figsize=(20,8))
# ax = f.add_subplot(1,2,1)
# ax.plot(loss)
# # ax.set_yscale('log')
# ax.set_title('MSE_Loss Loss')
# ax.set_xlabel('Iteration')

In [None]:
# loss = chkpt['loss']['train_loss_ad'][:1000]
# f = plt.figure(figsize=(20,8))
# ax = f.add_subplot(1,2,1)
# ax.plot(loss)
# # ax.set_yscale('log')
# ax.set_title('MSE_Loss Loss')
# ax.set_xlabel('Iteration')

In [11]:
def mask(img):
    """Paint 1-9 random black lines on a white canvas and whiten those pixels in ``img``.

    Args:
        img: (H, W) or (H, W, C) uint8 image array. It is not modified.

    Returns:
        ``(masked_img, canvas)``: a copy of ``img`` with stroke pixels set to 255,
        and the (H, W, 3) uint8 canvas (0 on strokes, 255 elsewhere).
    """
    h, w = img.shape[:2]
    canvas = np.full((h, w, 3), 255, np.uint8)  # white canvas sized like the image

    for _ in range(np.random.randint(1, 10)):
        x1, x2 = np.random.randint(1, w), np.random.randint(1, w)
        y1, y2 = np.random.randint(1, h), np.random.randint(1, h)
        t = np.random.randint(1, 3)
        cv2.line(canvas, (x1, y1), (x2, y2), (0, 0, 0), t)

    masked_img = img.copy()
    # Strokes are (0, 0, 0), so one channel identifies them. A 2-D pixel mask
    # works for both single- and multi-channel images.
    masked_img[canvas[..., 0] == 0] = 255
    return masked_img, canvas
    
ip_img = Image.open('../input/internet-imgs/Sardar_patel_(cropped).jpg').convert("RGB")
# ip_img = Image.open('../input/internet-imgs/netaji-subash-chandra-bose1611225337348.jpg').convert("RGB")
ip_img = ip_img.resize((224,224))
b_w = copy.deepcopy(ip_img)  # PIL copy of the resized input, displayed in the next cell
ip_img = np.array(ip_img)

# Masked copy for the inpainting half of the batch ([gray; masked] layout).
masked_L, _ = mask(ip_img)  # mask() works on a copy of its input
masked_L = transforms.ToTensor()(rgb2lab(masked_L).astype("float32"))
masked_L = masked_L[[0]] / 50. - 1.

img_lab = transforms.ToTensor()(rgb2lab(ip_img).astype("float32"))  # RGB -> L*a*b*, HWC -> CHW
L = img_lab[[0]] / 50. - 1.  # L* in [0, 100] -> [-1, 1]
ab = img_lab[[1,2]] / 100.
# Add a batch dim and use the detected device instead of a hard-coded .cuda().
L = L.unsqueeze(0).to(device)
masked_L = masked_L.unsqueeze(0).to(device)

chkpt = torch.load("../input/adversarial-mixup-105-epoch/_epoch_105",map_location=torch.device('cpu'))
model.load_state_dict(chkpt['model_state_dict'])

model.eval()
with torch.no_grad():
    gray_val = L.clone()
    true_ab_val_orig = ab.clone()
    # Training-mode call (not eval_flag) so the inpainting branch gates the
    # decoder as it did during training; the first output is the ab for L.
    pred_ab, _ = model(torch.cat((L, masked_L), dim=0))

pred_ab = (pred_ab*100.).permute(0, 2, 3, 1).contiguous()
gray_val = ((gray_val+1)*50.).permute(0,2,3,1).contiguous()
pred_rgb = torch.tensor(lab2rgb(torch.cat((gray_val,pred_ab),dim=3).cpu().numpy()))
pred_rgb = pred_rgb.permute(0,3,1,2)


In [12]:
plt.imshow(np.asarray(b_w))  # resized 224x224 input image, for comparison

In [13]:
show(make_grid(pred_rgb.detach().cpu()), figure_size=(4, 4))  # colorized internet image

# **Colorizing Personal Images**

In [18]:
def show_p(img,name,figure_size=(20,20)):
    """Display a (C, H, W) image with matplotlib and save the figure as ``./<name>``.

    Args:
        img: (C, H, W) image, e.g. the output of ``make_grid``. Torch tensors
            are detached and moved to CPU first; anything else goes through
            ``np.asarray``.
        name: file name (relative to ``./``) passed to ``plt.savefig``.
        figure_size: matplotlib figure size in inches.
    """
    # Previously `img.numpy()` was computed and then discarded.
    if torch.is_tensor(img):
        npimg = img.detach().cpu().numpy()
    else:
        npimg = np.asarray(img)
    print(npimg.shape)
    _, ax = plt.subplots(figsize=figure_size)
    im = ax.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for imshow
    im.axes.get_xaxis().set_visible(True)
    im.axes.get_yaxis().set_visible(True)
    plt.savefig("./"+name)

In [29]:
def mask(img):
    """Paint 1-9 random black lines on a white canvas and whiten those pixels in ``img``.

    Args:
        img: (H, W) or (H, W, C) uint8 image array. It is not modified.

    Returns:
        ``(masked_img, canvas)``: a copy of ``img`` with stroke pixels set to 255,
        and the (H, W, 3) uint8 canvas (0 on strokes, 255 elsewhere).
    """
    h, w = img.shape[:2]
    canvas = np.full((h, w, 3), 255, np.uint8)  # white canvas sized like the image

    for _ in range(np.random.randint(1, 10)):
        x1, x2 = np.random.randint(1, w), np.random.randint(1, w)
        y1, y2 = np.random.randint(1, h), np.random.randint(1, h)
        t = np.random.randint(1, 3)
        cv2.line(canvas, (x1, y1), (x2, y2), (0, 0, 0), t)

    masked_img = img.copy()
    # Strokes are (0, 0, 0), so one channel identifies them. A 2-D pixel mask
    # works for both single- and multi-channel images.
    masked_img[canvas[..., 0] == 0] = 255
    return masked_img, canvas
    
add_string = '../input/p-photos/pandit ji.jpeg'
ip_img = Image.open(add_string).convert("RGB")
# Output name = file name without directory or extension ("pandit ji").
# str.strip() removes a *set of characters* from both ends, not a prefix or
# suffix, and previously produced "andit ji".
p_name = Path(add_string).stem

ip_img = ip_img.resize((224,224))
b_w = copy.deepcopy(ip_img)  # PIL copy of the resized input
ip_img = np.array(ip_img)

# Masked copy for the inpainting half of the batch ([gray; masked] layout).
masked_L, _ = mask(ip_img)  # mask() works on a copy of its input
masked_L = transforms.ToTensor()(rgb2lab(masked_L).astype("float32"))
masked_L = masked_L[[0]] / 50. - 1.

img_lab = transforms.ToTensor()(rgb2lab(ip_img).astype("float32"))  # RGB -> L*a*b*, HWC -> CHW
L = img_lab[[0]] / 50. - 1.  # L* in [0, 100] -> [-1, 1]
ab = img_lab[[1,2]] / 100.
# Add a batch dim and use the detected device instead of a hard-coded .cuda().
L = L.unsqueeze(0).to(device)
masked_L = masked_L.unsqueeze(0).to(device)

chkpt = torch.load("../input/adversarial-mixup-105-epoch/_epoch_105",map_location=torch.device('cpu'))
model.load_state_dict(chkpt['model_state_dict'])

model.eval()
with torch.no_grad():
    gray_val = L.clone()
    true_ab_val_orig = ab.clone()
    # Training-mode call (not eval_flag) so the inpainting branch gates the
    # decoder as it did during training; the first output is the ab for L.
    pred_ab, _ = model(torch.cat((L, masked_L), dim=0))

pred_ab = (pred_ab*100.).permute(0, 2, 3, 1).contiguous()
gray_val = ((gray_val+1)*50.).permute(0,2,3,1).contiguous()
pred_rgb = torch.tensor(lab2rgb(torch.cat((gray_val,pred_ab),dim=3).cpu().numpy()))
pred_rgb = pred_rgb.permute(0,3,1,2)

show_p(make_grid(pred_rgb.cpu().data),p_name,figure_size=(4,4))
