In [2]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import warnings
import time

In [3]:
images_paths = list()
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        image_path = os.path.join(dirname, filename)
        if image_path.endswith('.jpg') or image_path.endswith(".png"):
            images_paths.append(image_path)

In [4]:
class ColorizationDataset(Dataset):
    def __init__(self, paths):
        self.paths = paths

    def __getitem__(self, idx):
        return self.paths[idx]

    def __len__(self):
        return len(self.paths)
    
    
def make_dataloader(batch_size=16, n_workers=2, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader

In [5]:
np.random.seed(123)
all_images = images_paths
train_range = int(0.8 * len(all_images))
rand_idxs = np.random.permutation(len(all_images))
train_idxs = rand_idxs[:train_range] 
val_idxs = rand_idxs[train_range:] 
train_paths = [all_images[x] for x in train_idxs]
val_paths = [all_images[x] for x in val_idxs]
print(len(train_paths), len(val_paths))

11019 2755


In [6]:
train_loader = make_dataloader(paths=train_paths)
val_loader = make_dataloader(batch_size=32, paths=val_paths)  # INCREASED
len(train_loader), len(val_loader)

(689, 87)

In [7]:
def get_name_from_path(path):
    return path.split('/')[-1]

In [8]:
def save_grayscale_batch(batch):
    for img_path in batch:
        img_rgb = Image.open(img_path)
        img_gray = img_rgb.convert('L')
        img_gray.save(f'{get_name_from_path(img_path)}')

In [9]:
def save_batch(batch):
    for img_path in batch:
        img_rgb = Image.open(img_path)
        img_rgb.save(f'{get_name_from_path(img_path)}')

In [10]:
def load_grayscale_batch(batch):
    return [f'{get_name_from_path(img_path)}' for img_path in batch]

In [11]:
def del_grayscale_batch(batch):
    for img_path in batch:
        os.remove(f'{get_name_from_path(img_path)}')

In [12]:
import torch
from torch import nn

class BaseColor(nn.Module):
    def __init__(self):
        super(BaseColor, self).__init__()

        self.l_cent = 50.
        self.l_norm = 100.
        self.ab_norm = 110.

    def normalize_l(self, in_l):
        return (in_l-self.l_cent)/self.l_norm

    def unnormalize_l(self, in_l):
        return in_l*self.l_norm + self.l_cent

    def normalize_ab(self, in_ab):
        return in_ab/self.ab_norm

    def unnormalize_ab(self, in_ab):
        return in_ab*self.ab_norm



class ECCVGenerator(BaseColor):
    def __init__(self, norm_layer=nn.BatchNorm2d):
        super(ECCVGenerator, self).__init__()

        model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
        model1+=[nn.ReLU(True),]
        model1+=[norm_layer(64),]

        model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
        model2+=[nn.ReLU(True),]
        model2+=[norm_layer(128),]

        model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
        model3+=[nn.ReLU(True),]
        model3+=[norm_layer(256),]

        model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model4+=[nn.ReLU(True),]
        model4+=[norm_layer(512),]

        model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model5+=[nn.ReLU(True),]
        model5+=[norm_layer(512),]

        model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
        model6+=[nn.ReLU(True),]
        model6+=[norm_layer(512),]

        model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
        model7+=[nn.ReLU(True),]
        model7+=[norm_layer(512),]

        model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]
        model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
        model8+=[nn.ReLU(True),]

        model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]

        self.model1 = nn.Sequential(*model1)
        self.model2 = nn.Sequential(*model2)
        self.model3 = nn.Sequential(*model3)
        self.model4 = nn.Sequential(*model4)
        self.model5 = nn.Sequential(*model5)
        self.model6 = nn.Sequential(*model6)
        self.model7 = nn.Sequential(*model7)
        self.model8 = nn.Sequential(*model8)

        self.softmax = nn.Softmax(dim=1)
        self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
        self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')

    def forward(self, input_l):
        conv1_2 = self.model1(self.normalize_l(input_l))
        conv2_2 = self.model2(conv1_2)
        conv3_3 = self.model3(conv2_2)
        conv4_3 = self.model4(conv3_3)
        conv5_3 = self.model5(conv4_3)
        conv6_3 = self.model6(conv5_3)
        conv7_3 = self.model7(conv6_3)
        conv8_3 = self.model8(conv7_3)
        out_reg = self.model_out(self.softmax(conv8_3))

        return self.unnormalize_ab(self.upsample4(out_reg))

def eccv16(pretrained=True):
    model = ECCVGenerator()
    if(pretrained):
        import torch.utils.model_zoo as model_zoo
        model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
    return model


from PIL import Image
import numpy as np
from skimage import color
import torch
import torch.nn.functional as F

# def load_img(img_path):
def load_imgs(img_path_list):
    out_np_list = [np.asarray(Image.open(img_path).convert('RGB'))
                   for img_path in img_path_list]
    
    # if(out_np.ndim==2):
    #    out_np = np.tile(out_np[:,:,None],3)
    for i, out_np in enumerate(out_np_list):
        if out_np.ndim==2:
            out_np_list[i] = np.tile(out_np[:,:,None],3)
    # return out_np
    return out_np_list

def resize_img(img, HW=(256,256), resample=3):
    return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))

# def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
def preprocess_imgs(img_rgb_orig_list, HW=(256,256), resample=3):
    # return original size L and resized L as torch Tensors
    # img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
    img_rgb_rs_list = [resize_img(img_rgb_orig, HW=HW, resample=resample)
                       for img_rgb_orig in img_rgb_orig_list]

    # img_lab_orig = color.rgb2lab(img_rgb_orig)
    # img_lab_rs = color.rgb2lab(img_rgb_rs)
    img_lab_orig_list = [color.rgb2lab(img_rgb_orig)
                         for img_rgb_orig in img_rgb_orig_list]
    img_lab_rs_list = [color.rgb2lab(img_rgb_rs)
                       for img_rgb_rs in img_rgb_rs_list]

    # img_l_orig = img_lab_orig[:,:,0]
    # img_l_rs = img_lab_rs[:,:,0]
    img_l_orig_list = [img_lab_orig[:,:,0]
                       for img_lab_orig in img_lab_orig_list]
    img_l_rs_list = [img_lab_rs[:,:,0]
                     for img_lab_rs in img_lab_rs_list]
    # FOR GT
    # (256, 256, 2) 1 2 -> 0 1
    img_ab_rs_list = [torch.transpose(torch.transpose(torch.Tensor(img_lab_rs[:,:,1:3]), 1, 2), 0, 1)
                     for img_lab_rs in img_lab_rs_list]

    # RESHAPE FIRST
    img_l_rs_list_reshaped = [torch.Tensor(img_l_rs)[None, None, :, :]
                              for img_l_rs in img_l_rs_list]
    img_ab_rs_list_reshaped = [torch.Tensor(img_ab_rs)[None,:, :, :]
                              for img_ab_rs in img_ab_rs_list]

    # tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
    # tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
    tens_orig_l_list = [torch.Tensor(img_l_orig)[None,None,:,:]
                        for img_l_orig in img_l_orig_list]
    
    batch_size = len(img_rgb_orig_list)
    tens_rs_l_tensor = torch.cat(img_l_rs_list_reshaped, dim=0)
    tens_rs_ab_tensor = torch.cat(img_ab_rs_list_reshaped, dim=0)

    return (tens_orig_l_list, tens_rs_l_tensor, tens_rs_ab_tensor)

# def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
def postprocess_imgs(tens_orig_l_list, out_ab, mode='bilinear'):
    # tens_orig_l     1 x 1 x H_orig x W_orig
    # out_ab         1 x 2 x H x W
    
    # tens_orig_l_list: (16 elements) 1 x 1 x H_orig x W_orig
    # out_ab                          16 x 2 x H x W

    # HW_orig = tens_orig_l.shape[2:]
    # HW = out_ab.shape[2:]
    HW_orig_list = [tens_orig_l.shape[2:] for tens_orig_l in tens_orig_l_list]  # 16 ele H_orig x W_orig
    HW = out_ab.shape[2:]  # (H, W)

    # call resize function if needed
    #if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
    #    out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
    #else:
    #    out_ab_orig = out_ab
    
    out_ab_orig_list = list()
    for i, HW_orig in enumerate(HW_orig_list):
        if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
            out_ab_orig = F.interpolate(out_ab[i][None, :, :, :], size=HW_orig, mode='bilinear')
        else:
            out_ab_orig = out_ab[i][None, :, :, :]     
        out_ab_orig_list.append(out_ab_orig)

    # out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
    out_lab_orig_list = list()
    for i in range(len(out_ab_orig_list)):
        out_lab_orig_list.append(
            torch.cat((tens_orig_l_list[i].cuda(), out_ab_orig_list[i]), dim=1)
        )
    
    # return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
    return [color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
           for out_lab_orig in out_lab_orig_list]

In [13]:
model = eccv16(pretrained=False).cuda()
model.load_state_dict(torch.load('/kaggle/input/fork-of-eccv16-train/model_9'))
model.eval()

ECCVGenerator(
  (model1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (model2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (model3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU(inplace=Tru

In [14]:
import math
from PIL import Image
import cv2
from cv2 import imshow

def gaussian(window_size, sigma):
    gauss = torch.Tensor([math.exp(-(x - window_size//2) ** 2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel=1):

    # Generate an 1D tensor containing values sampled from a gaussian distribution
    _1d_window = gaussian(window_size=window_size, sigma=1.5).unsqueeze(1)
    
    # Converting to 2D  
    _2d_window = _1d_window.mm(_1d_window.t()).float().unsqueeze(0).unsqueeze(0)
     
    window = torch.Tensor(_2d_window.expand(channel, 1, window_size, window_size).contiguous())

    return window


def ssim(img1, img2, val_range, window_size=11, window=None, size_average=True, full=False):

    L = val_range # L is the dynamic range of the pixel values (255 for 8-bit grayscale images),

    pad = window_size // 2
    
    try:
        _, channels, height, width = img1.size()
    except:
        channels, height, width = img1.size()

    # if window is not provided, init one
    if window is None: 
        real_size = min(window_size, height, width) # window should be atleast 11x11 
        window = create_window(real_size, channel=channels).to(img1.device)
    
    # calculating the mu parameter (locally) for both images using a gaussian filter 
    # calculates the luminosity params
    mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channels)
    
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2 
    mu12 = mu1 * mu2

    # now we calculate the sigma square parameter
    # Sigma deals with the contrast component 
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
    sigma12 =  F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu12

    # Some constants for stability 
    C1 = (0.01 * 255) ** 2  # NOTE: Removed L from here (ref PT implementation)
    C2 = (0.03 * 255) ** 2 

    contrast_metric = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    contrast_metric = torch.mean(contrast_metric)

    numerator1 = 2 * mu12 + C1  
    numerator2 = 2 * sigma12 + C2
    denominator1 = mu1_sq + mu2_sq + C1 
    denominator2 = sigma1_sq + sigma2_sq + C2

    ssim_score = (numerator1 * numerator2) / (denominator1 * denominator2)

    if size_average:
        ret = ssim_score.mean() 
    else: 
        ret = ssim_score.mean(1).mean(1).mean(1)
    
    if full:
        return ret, contrast_metric
    
    return ret


# display imgs 
def display_imgs(x, transpose=True, resize=True):
    if resize:
        x=cv2.resize(x, (400, 400))
    if transpose:
        imshow(cv2.cvtColor(x, cv2.COLOR_BGR2RGB))
    else:
        imshow(x)

def tensorify(x, mul = True):
#   if vector is already norm, set mul = False
    if mul:
        return torch.Tensor(x.transpose((2, 0, 1))).unsqueeze(0).float().div(255.0)
    else:
        return torch.Tensor(x.transpose((2, 0, 1))).unsqueeze(0).float()

def compute_ssim(image1, image2):
    img1 = tensorify(np.asarray(image1))
    img2 = tensorify(np.asarray(image2))
    return ssim(img1, img2, val_range= 255)


In [15]:
import gc

In [16]:
all_scores = list()

with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for batch_idx, batch in enumerate(val_loader):
        original_imgs = load_imgs(batch)
        orig_l_list, rs_l_tensor, gt = preprocess_imgs(original_imgs)

        output_tensor = model(rs_l_tensor.cuda()) 

        
        postprocessed_imgs = postprocess_imgs(orig_l_list, output_tensor)

        scores = [compute_ssim(postprocessed_imgs[i]*255, np.asarray(Image.open(batch[i]).convert("RGB")))
                  for i in range(len(batch))]
        all_scores.extend(scores)
        print(np.mean(scores), '\t', f'Current accumulative score={np.mean(all_scores):.15f}')
        
        gc.collect()
        torch.cuda.empty_cache()

0.99950624 	 Current accumulative score=0.999506235122681
0.9992124 	 Current accumulative score=0.999359250068665
0.99949014 	 Current accumulative score=0.999402940273285


KeyboardInterrupt: 