In [51]:
# Import Dependencies
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb, rgba2rgb
import cv2

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import wandb
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
use_colab = None
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

cuda


In [None]:
wandb.init(project="image_colorization", config={
    "architecture": "GAN",
    "dataset": "CIFAR-100_VAL_2017",
})

In [53]:
path = "Data/MS COCO/MS COCO/val2017"
image_paths = glob.glob(path + "/*.jpg") # Get all image filenames
np.random.seed(420)
train_idxs = image_paths[:4500]
val_idxs = image_paths[4500:]
print(f'N_TRAIN IMG: {len(train_idxs)}\nN_VAL IMG: {len(val_idxs)}')

N_TRAIN IMG: 4500
N_VAL IMG: 490


In [58]:
size = 224
class ColorizationDataset(Dataset):
    def __init__(self, base_images, split='train'):
        self.file_list = os.listdir(base_images)
        if split == 'train':
            self.file_list = self.file_list
            print(len(self.file_list))
            self.transforms = transforms.Compose([
                transforms.Resize((256, 256), Image.BICUBIC),
                transforms.RandomHorizontalFlip(), # Low image augmentation to maintain realism
            ])
        elif split == 'val':
            self.file_list = self.file_list
            print(len(self.file_list))
            self.transforms = transforms.Resize((256, 256), Image.BICUBIC)
        
        self.split = split
        self.base_images = base_images

    def __len__(self):
        return len(self.file_list)
    
    def rgb_to_lab(self, rgb_image):
        # Transform the image to proper specifications
        img = self.transforms(rgb_image)
        rgb_img = np.array(img)
        lab_img = rgb2lab(rgb_img)
        lab_img = lab_img.transpose(2,1,0)
        lab_img_torch = torch.from_numpy(lab_img)
        # Normalize between -1 and 1
        L = (lab_img_torch[[0], ...] / 50 - 1.).to(dtype=torch.float32)
        ab = (lab_img_torch[[1, 2], ...] / 110.).to(dtype=torch.float32)
        return L.to(device), ab.to(device)
        
    def __getitem__(self, idx):
        colored_path = os.path.join(self.base_images, self.file_list[idx])
        colored_image = Image.open(colored_path)
        L, ab = self.rgb_to_lab(colored_image)
        return {'L': L, 'ab': ab}
    
# Utility function for creating the data-loader
def make_dataloaders(batch_size=16, n_workers=0, **kwargs): # Function for creating the dataloaders
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers)
    return dataloader

In [55]:
train_dl = make_dataloaders(base_images="Data/MS COCO/MS COCO/val2017", split='train')
val_dl = make_dataloaders(base_images="Data/MS COCO/MS COCO/val2017", split='val')
print('Dataloaders Completed')

4500
490
Dataloaders Completed


In [6]:
# Building blocks for our full generator model
# Model will be very complex, so these blocks simplify the final product
class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)
        
        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)
    
    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
        
# Now we will use these blocks for the final product
class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
    
    def forward(self, x):
        return self.model(x)

In [7]:
# Simple Discriminator Model
class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down-1) else 2) 
                          for i in range(n_down)] # the 'if' statement is taking care of not using
                                                  # stride of 2 for the last block in this loop
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)] # Make sure to not use normalization or
                                                                                             # activation for the last layer of the model
        self.model = nn.Sequential(*model)                                                   
        
    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True): # when needing to make some repeatitive blocks of layers,
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]          # it's always helpful to make a separate method for that purpose
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)
    
    def forward(self, x):
        return self.model(x)

In [8]:
PatchDiscriminator(3)

PatchDiscriminator(
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (

In [9]:
discriminator = PatchDiscriminator(3)
dummy_input = torch.randn(16, 3, 256, 256) # batch_size, channels, size, size
out = discriminator(dummy_input)
out.shape

torch.Size([16, 1, 30, 30])

In [10]:
# Now we will make a class that can calculate the GAN loss of our final model
class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
    
    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)
    
    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss

In [11]:
# Function for instantiating model and weights
def init_weights(net, init='norm', gain=0.02):
    
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)
            
    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model

In [12]:
# Combine the generator and discriminator into one model for ease of use during training
class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4, 
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1
        
        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))
    
    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad
        
    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)
        
    def forward(self):
        self.fake_color = self.net_G(self.L)
    
    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()
    
    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()
    
    def optimize(self):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()
        
        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()

In [13]:
# Create some utility functions
class AverageMeter:
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3
    
    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()
    
    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}

def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)
        wandb.log({loss_name: loss.item()})

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)

def visualize(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    number = np.random.randint(low=0, high=16)
    wandb.log({
        "input_image": wandb.Image(np.rot90(L[number][0].cpu(), 3)),
        "reconstruction_image": wandb.Image(np.rot90(fake_imgs[number], 3)),
        "ground_truth_image": wandb.Image(np.rot90(real_imgs[number], 3)),
    })
    
def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f'{loss_name}: {loss_meter.avg:.5f}')

In [14]:
def train_model(model, train_dl, epochs, display_every=200):
    data = next(iter(val_dl)) # getting a batch for visualizing the model output after fixed intrvals
    for e in range(epochs):
        loss_meter_dict = create_loss_meters() # function returing a dictionary of objects to 
        i = 0                                  # log the losses of the complete network
        for data in tqdm(train_dl):
            model.setup_input(data) 
            model.optimize()
            update_losses(model, loss_meter_dict, count=data['L'].size(0)) # function updating the log objects
            i += 1
            if i % display_every == 0:
                print(f"\nEpoch {e+1}/{epochs}")
                print(f"Iteration {i}/{len(train_dl)}")
                log_results(loss_meter_dict) # function to print out the losses
                visualize(model, data, save=False) # function displaying the model's outputs
    torch.save(model.state_dict(), 'colorization_model.pth')

model = MainModel()
train_model(model, train_dl, 100)

model initialized with norm initialization
model initialized with norm initialization


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 1/100
Iteration 200/282
loss_D_fake: 0.41537
loss_D_real: 0.42294
loss_D: 0.41916
loss_G_GAN: 1.83694
loss_G_L1: 9.45337
loss_G: 11.29031


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 2/100
Iteration 200/282
loss_D_fake: 0.48764
loss_D_real: 0.52443
loss_D: 0.50603
loss_G_GAN: 1.48347
loss_G_L1: 10.31897
loss_G: 11.80243


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 3/100
Iteration 200/282
loss_D_fake: 0.50506
loss_D_real: 0.55371
loss_D: 0.52938
loss_G_GAN: 1.33586
loss_G_L1: 10.41135
loss_G: 11.74721


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 4/100
Iteration 200/282
loss_D_fake: 0.51448
loss_D_real: 0.56770
loss_D: 0.54109
loss_G_GAN: 1.32454
loss_G_L1: 10.40909
loss_G: 11.73362


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 5/100
Iteration 200/282
loss_D_fake: 0.51859
loss_D_real: 0.57694
loss_D: 0.54777
loss_G_GAN: 1.25302
loss_G_L1: 10.34082
loss_G: 11.59384


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 6/100
Iteration 200/282
loss_D_fake: 0.51534
loss_D_real: 0.58110
loss_D: 0.54822
loss_G_GAN: 1.24479
loss_G_L1: 10.29425
loss_G: 11.53904


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 7/100
Iteration 200/282
loss_D_fake: 0.51961
loss_D_real: 0.59096
loss_D: 0.55528
loss_G_GAN: 1.21787
loss_G_L1: 10.25211
loss_G: 11.46998


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 8/100
Iteration 200/282
loss_D_fake: 0.51658
loss_D_real: 0.58905
loss_D: 0.55281
loss_G_GAN: 1.22240
loss_G_L1: 10.18595
loss_G: 11.40834


  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)
  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 9/100
Iteration 200/282
loss_D_fake: 0.52179
loss_D_real: 0.58706
loss_D: 0.55442
loss_G_GAN: 1.22475
loss_G_L1: 10.18636
loss_G: 11.41111


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 10/100
Iteration 200/282
loss_D_fake: 0.52856
loss_D_real: 0.59525
loss_D: 0.56191
loss_G_GAN: 1.22157
loss_G_L1: 10.12553
loss_G: 11.34710


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 11/100
Iteration 200/282
loss_D_fake: 0.52682
loss_D_real: 0.58992
loss_D: 0.55837
loss_G_GAN: 1.21487
loss_G_L1: 10.02279
loss_G: 11.23766


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 12/100
Iteration 200/282
loss_D_fake: 0.52537
loss_D_real: 0.59114
loss_D: 0.55826
loss_G_GAN: 1.21872
loss_G_L1: 9.92439
loss_G: 11.14311


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 13/100
Iteration 200/282
loss_D_fake: 0.52324
loss_D_real: 0.57605
loss_D: 0.54965
loss_G_GAN: 1.22168
loss_G_L1: 9.84742
loss_G: 11.06910


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 14/100
Iteration 200/282
loss_D_fake: 0.52600
loss_D_real: 0.58769
loss_D: 0.55684
loss_G_GAN: 1.21498
loss_G_L1: 9.72738
loss_G: 10.94236


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 15/100
Iteration 200/282
loss_D_fake: 0.52804
loss_D_real: 0.58730
loss_D: 0.55767
loss_G_GAN: 1.21522
loss_G_L1: 9.62782
loss_G: 10.84304


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 16/100
Iteration 200/282
loss_D_fake: 0.52936
loss_D_real: 0.58834
loss_D: 0.55885
loss_G_GAN: 1.19966
loss_G_L1: 9.51563
loss_G: 10.71529


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 17/100
Iteration 200/282
loss_D_fake: 0.52895
loss_D_real: 0.58304
loss_D: 0.55600
loss_G_GAN: 1.20658
loss_G_L1: 9.34929
loss_G: 10.55587


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 18/100
Iteration 200/282
loss_D_fake: 0.52679
loss_D_real: 0.58001
loss_D: 0.55340
loss_G_GAN: 1.20136
loss_G_L1: 9.27660
loss_G: 10.47796


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 19/100
Iteration 200/282
loss_D_fake: 0.52706
loss_D_real: 0.58215
loss_D: 0.55461
loss_G_GAN: 1.21139
loss_G_L1: 9.15562
loss_G: 10.36701


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 20/100
Iteration 200/282
loss_D_fake: 0.52814
loss_D_real: 0.57421
loss_D: 0.55117
loss_G_GAN: 1.20946
loss_G_L1: 9.08841
loss_G: 10.29787


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 21/100
Iteration 200/282
loss_D_fake: 0.52926
loss_D_real: 0.58282
loss_D: 0.55604
loss_G_GAN: 1.20890
loss_G_L1: 8.86633
loss_G: 10.07523


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 22/100
Iteration 200/282
loss_D_fake: 0.53372
loss_D_real: 0.58193
loss_D: 0.55782
loss_G_GAN: 1.21321
loss_G_L1: 8.77223
loss_G: 9.98545


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 23/100
Iteration 200/282
loss_D_fake: 0.53568
loss_D_real: 0.58019
loss_D: 0.55793
loss_G_GAN: 1.20589
loss_G_L1: 8.67814
loss_G: 9.88403


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 24/100
Iteration 200/282
loss_D_fake: 0.53553
loss_D_real: 0.57542
loss_D: 0.55547
loss_G_GAN: 1.19861
loss_G_L1: 8.54755
loss_G: 9.74616


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 25/100
Iteration 200/282
loss_D_fake: 0.54278
loss_D_real: 0.58005
loss_D: 0.56141
loss_G_GAN: 1.20572
loss_G_L1: 8.40195
loss_G: 9.60767


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 26/100
Iteration 200/282
loss_D_fake: 0.54007
loss_D_real: 0.57517
loss_D: 0.55762
loss_G_GAN: 1.19834
loss_G_L1: 8.29553
loss_G: 9.49387


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 27/100
Iteration 200/282
loss_D_fake: 0.53943
loss_D_real: 0.57567
loss_D: 0.55755
loss_G_GAN: 1.19925
loss_G_L1: 8.18088
loss_G: 9.38014


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 28/100
Iteration 200/282
loss_D_fake: 0.54531
loss_D_real: 0.57285
loss_D: 0.55908
loss_G_GAN: 1.20198
loss_G_L1: 8.03055
loss_G: 9.23252


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 29/100
Iteration 200/282
loss_D_fake: 0.54562
loss_D_real: 0.56982
loss_D: 0.55772
loss_G_GAN: 1.18156
loss_G_L1: 7.94227
loss_G: 9.12383


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 30/100
Iteration 200/282
loss_D_fake: 0.54805
loss_D_real: 0.56851
loss_D: 0.55828
loss_G_GAN: 1.20275
loss_G_L1: 7.84847
loss_G: 9.05122


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 31/100
Iteration 200/282
loss_D_fake: 0.54505
loss_D_real: 0.56370
loss_D: 0.55437
loss_G_GAN: 1.18709
loss_G_L1: 7.76796
loss_G: 8.95505


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 32/100
Iteration 200/282
loss_D_fake: 0.54594
loss_D_real: 0.56531
loss_D: 0.55563
loss_G_GAN: 1.18768
loss_G_L1: 7.62189
loss_G: 8.80956


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 33/100
Iteration 200/282
loss_D_fake: 0.54218
loss_D_real: 0.56105
loss_D: 0.55162
loss_G_GAN: 1.19826
loss_G_L1: 7.54235
loss_G: 8.74060


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 34/100
Iteration 200/282
loss_D_fake: 0.54312
loss_D_real: 0.56582
loss_D: 0.55447
loss_G_GAN: 1.18437
loss_G_L1: 7.46103
loss_G: 8.64541


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 35/100
Iteration 200/282
loss_D_fake: 0.54838
loss_D_real: 0.56316
loss_D: 0.55577
loss_G_GAN: 1.18311
loss_G_L1: 7.33521
loss_G: 8.51833


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 36/100
Iteration 200/282
loss_D_fake: 0.54113
loss_D_real: 0.56002
loss_D: 0.55058
loss_G_GAN: 1.18998
loss_G_L1: 7.24366
loss_G: 8.43364


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 37/100
Iteration 200/282
loss_D_fake: 0.54737
loss_D_real: 0.56685
loss_D: 0.55711
loss_G_GAN: 1.18449
loss_G_L1: 7.18601
loss_G: 8.37050


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 38/100
Iteration 200/282
loss_D_fake: 0.54234
loss_D_real: 0.55808
loss_D: 0.55021
loss_G_GAN: 1.19334
loss_G_L1: 7.09901
loss_G: 8.29234


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 39/100
Iteration 200/282
loss_D_fake: 0.55093
loss_D_real: 0.56817
loss_D: 0.55955
loss_G_GAN: 1.18941
loss_G_L1: 7.00168
loss_G: 8.19109


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 40/100
Iteration 200/282
loss_D_fake: 0.55885
loss_D_real: 0.56837
loss_D: 0.56361
loss_G_GAN: 1.18149
loss_G_L1: 6.90139
loss_G: 8.08287


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 41/100
Iteration 200/282
loss_D_fake: 0.55150
loss_D_real: 0.55843
loss_D: 0.55497
loss_G_GAN: 1.18643
loss_G_L1: 6.80871
loss_G: 7.99515


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 42/100
Iteration 200/282
loss_D_fake: 0.55979
loss_D_real: 0.56424
loss_D: 0.56201
loss_G_GAN: 1.19813
loss_G_L1: 6.71148
loss_G: 7.90961


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 43/100
Iteration 200/282
loss_D_fake: 0.55510
loss_D_real: 0.56303
loss_D: 0.55906
loss_G_GAN: 1.18652
loss_G_L1: 6.67722
loss_G: 7.86374


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 44/100
Iteration 200/282
loss_D_fake: 0.54585
loss_D_real: 0.55285
loss_D: 0.54935
loss_G_GAN: 1.18169
loss_G_L1: 6.62922
loss_G: 7.81091


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 45/100
Iteration 200/282
loss_D_fake: 0.55508
loss_D_real: 0.56007
loss_D: 0.55757
loss_G_GAN: 1.20482
loss_G_L1: 6.53195
loss_G: 7.73677


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 46/100
Iteration 200/282
loss_D_fake: 0.56019
loss_D_real: 0.56030
loss_D: 0.56024
loss_G_GAN: 1.18841
loss_G_L1: 6.45377
loss_G: 7.64217


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 47/100
Iteration 200/282
loss_D_fake: 0.54870
loss_D_real: 0.55364
loss_D: 0.55117
loss_G_GAN: 1.19384
loss_G_L1: 6.40053
loss_G: 7.59437


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 48/100
Iteration 200/282
loss_D_fake: 0.55643
loss_D_real: 0.55958
loss_D: 0.55801
loss_G_GAN: 1.20121
loss_G_L1: 6.35823
loss_G: 7.55944


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 49/100
Iteration 200/282
loss_D_fake: 0.54967
loss_D_real: 0.55499
loss_D: 0.55233
loss_G_GAN: 1.19310
loss_G_L1: 6.32365
loss_G: 7.51675


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 50/100
Iteration 200/282
loss_D_fake: 0.54825
loss_D_real: 0.55569
loss_D: 0.55197
loss_G_GAN: 1.19437
loss_G_L1: 6.26642
loss_G: 7.46079


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 51/100
Iteration 200/282
loss_D_fake: 0.55376
loss_D_real: 0.55409
loss_D: 0.55393
loss_G_GAN: 1.20423
loss_G_L1: 6.19761
loss_G: 7.40184


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 52/100
Iteration 200/282
loss_D_fake: 0.55351
loss_D_real: 0.55743
loss_D: 0.55547
loss_G_GAN: 1.20429
loss_G_L1: 6.15352
loss_G: 7.35781


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 53/100
Iteration 200/282
loss_D_fake: 0.54615
loss_D_real: 0.55105
loss_D: 0.54860
loss_G_GAN: 1.19579
loss_G_L1: 6.09691
loss_G: 7.29270


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 54/100
Iteration 200/282
loss_D_fake: 0.56814
loss_D_real: 0.57074
loss_D: 0.56944
loss_G_GAN: 1.19870
loss_G_L1: 6.02412
loss_G: 7.22282


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 55/100
Iteration 200/282
loss_D_fake: 0.54779
loss_D_real: 0.54964
loss_D: 0.54872
loss_G_GAN: 1.19907
loss_G_L1: 6.00541
loss_G: 7.20448


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 56/100
Iteration 200/282
loss_D_fake: 0.56385
loss_D_real: 0.56246
loss_D: 0.56315
loss_G_GAN: 1.21162
loss_G_L1: 5.95429
loss_G: 7.16591


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 57/100
Iteration 200/282
loss_D_fake: 0.53939
loss_D_real: 0.54254
loss_D: 0.54097
loss_G_GAN: 1.20082
loss_G_L1: 5.92799
loss_G: 7.12881


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 58/100
Iteration 200/282
loss_D_fake: 0.56046
loss_D_real: 0.55945
loss_D: 0.55996
loss_G_GAN: 1.22722
loss_G_L1: 5.87918
loss_G: 7.10641


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 59/100
Iteration 200/282
loss_D_fake: 0.54427
loss_D_real: 0.54740
loss_D: 0.54583
loss_G_GAN: 1.22666
loss_G_L1: 5.86978
loss_G: 7.09644


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 60/100
Iteration 200/282
loss_D_fake: 0.55416
loss_D_real: 0.54905
loss_D: 0.55160
loss_G_GAN: 1.23589
loss_G_L1: 5.84776
loss_G: 7.08365


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 61/100
Iteration 200/282
loss_D_fake: 0.56160
loss_D_real: 0.56201
loss_D: 0.56180
loss_G_GAN: 1.24289
loss_G_L1: 5.76195
loss_G: 7.00484


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 62/100
Iteration 200/282
loss_D_fake: 0.57731
loss_D_real: 0.58165
loss_D: 0.57948
loss_G_GAN: 1.23929
loss_G_L1: 5.70593
loss_G: 6.94522


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 63/100
Iteration 200/282
loss_D_fake: 0.53556
loss_D_real: 0.53505
loss_D: 0.53530
loss_G_GAN: 1.23061
loss_G_L1: 5.72111
loss_G: 6.95172


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 64/100
Iteration 200/282
loss_D_fake: 0.53888
loss_D_real: 0.54180
loss_D: 0.54034
loss_G_GAN: 1.26754
loss_G_L1: 5.68805
loss_G: 6.95558


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 65/100
Iteration 200/282
loss_D_fake: 0.53450
loss_D_real: 0.53429
loss_D: 0.53439
loss_G_GAN: 1.24770
loss_G_L1: 5.62748
loss_G: 6.87518


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 66/100
Iteration 200/282
loss_D_fake: 0.53877
loss_D_real: 0.53821
loss_D: 0.53849
loss_G_GAN: 1.27074
loss_G_L1: 5.61304
loss_G: 6.88379


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 67/100
Iteration 200/282
loss_D_fake: 0.53611
loss_D_real: 0.53213
loss_D: 0.53412
loss_G_GAN: 1.28638
loss_G_L1: 5.57368
loss_G: 6.86006


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 68/100
Iteration 200/282
loss_D_fake: 0.62041
loss_D_real: 0.60658
loss_D: 0.61349
loss_G_GAN: 1.27820
loss_G_L1: 5.52392
loss_G: 6.80212


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 69/100
Iteration 200/282
loss_D_fake: 0.52620
loss_D_real: 0.52721
loss_D: 0.52670
loss_G_GAN: 1.26024
loss_G_L1: 5.52578
loss_G: 6.78602


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 70/100
Iteration 200/282
loss_D_fake: 0.53310
loss_D_real: 0.53033
loss_D: 0.53172
loss_G_GAN: 1.29156
loss_G_L1: 5.51097
loss_G: 6.80254


  img_rgb = lab2rgb(img)


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 71/100
Iteration 200/282
loss_D_fake: 0.54632
loss_D_real: 0.54389
loss_D: 0.54511
loss_G_GAN: 1.31314
loss_G_L1: 5.43008
loss_G: 6.74322


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 72/100
Iteration 200/282
loss_D_fake: 0.52328
loss_D_real: 0.51828
loss_D: 0.52078
loss_G_GAN: 1.34642
loss_G_L1: 5.43459
loss_G: 6.78102


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 73/100
Iteration 200/282
loss_D_fake: 0.57468
loss_D_real: 0.57622
loss_D: 0.57545
loss_G_GAN: 1.30287
loss_G_L1: 5.40205
loss_G: 6.70493


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 74/100
Iteration 200/282
loss_D_fake: 0.52296
loss_D_real: 0.51788
loss_D: 0.52042
loss_G_GAN: 1.32369
loss_G_L1: 5.36417
loss_G: 6.68786


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 75/100
Iteration 200/282
loss_D_fake: 0.51974
loss_D_real: 0.52138
loss_D: 0.52056
loss_G_GAN: 1.34746
loss_G_L1: 5.36930
loss_G: 6.71676


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 76/100
Iteration 200/282
loss_D_fake: 0.51437
loss_D_real: 0.51392
loss_D: 0.51414
loss_G_GAN: 1.34320
loss_G_L1: 5.33066
loss_G: 6.67387


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 77/100
Iteration 200/282
loss_D_fake: 0.52946
loss_D_real: 0.52302
loss_D: 0.52624
loss_G_GAN: 1.36180
loss_G_L1: 5.29368
loss_G: 6.65548


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 78/100
Iteration 200/282
loss_D_fake: 0.51820
loss_D_real: 0.51446
loss_D: 0.51633
loss_G_GAN: 1.37895
loss_G_L1: 5.27977
loss_G: 6.65872


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 79/100
Iteration 200/282
loss_D_fake: 0.56703
loss_D_real: 0.55469
loss_D: 0.56086
loss_G_GAN: 1.37144
loss_G_L1: 5.26034
loss_G: 6.63178


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 80/100
Iteration 200/282
loss_D_fake: 0.51243
loss_D_real: 0.50504
loss_D: 0.50873
loss_G_GAN: 1.38514
loss_G_L1: 5.23596
loss_G: 6.62110


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 81/100
Iteration 200/282
loss_D_fake: 0.50430
loss_D_real: 0.49989
loss_D: 0.50210
loss_G_GAN: 1.35378
loss_G_L1: 5.21258
loss_G: 6.56636


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 82/100
Iteration 200/282
loss_D_fake: 0.50690
loss_D_real: 0.50510
loss_D: 0.50600
loss_G_GAN: 1.40573
loss_G_L1: 5.18932
loss_G: 6.59504


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 83/100
Iteration 200/282
loss_D_fake: 0.53306
loss_D_real: 0.52861
loss_D: 0.53083
loss_G_GAN: 1.42681
loss_G_L1: 5.17930
loss_G: 6.60610


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 84/100
Iteration 200/282
loss_D_fake: 0.57340
loss_D_real: 0.57052
loss_D: 0.57196
loss_G_GAN: 1.39553
loss_G_L1: 5.11557
loss_G: 6.51110


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 85/100
Iteration 200/282
loss_D_fake: 0.51368
loss_D_real: 0.51057
loss_D: 0.51212
loss_G_GAN: 1.43272
loss_G_L1: 5.14621
loss_G: 6.57894


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 86/100
Iteration 200/282
loss_D_fake: 0.49624
loss_D_real: 0.48954
loss_D: 0.49289
loss_G_GAN: 1.45026
loss_G_L1: 5.12652
loss_G: 6.57678


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 87/100
Iteration 200/282
loss_D_fake: 0.49234
loss_D_real: 0.48961
loss_D: 0.49097
loss_G_GAN: 1.47953
loss_G_L1: 5.10905
loss_G: 6.58858


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 88/100
Iteration 200/282
loss_D_fake: 0.50802
loss_D_real: 0.49653
loss_D: 0.50228
loss_G_GAN: 1.47700
loss_G_L1: 5.08076
loss_G: 6.55775


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 89/100
Iteration 200/282
loss_D_fake: 0.49627
loss_D_real: 0.48818
loss_D: 0.49223
loss_G_GAN: 1.48524
loss_G_L1: 5.05195
loss_G: 6.53719


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 90/100
Iteration 200/282
loss_D_fake: 0.47961
loss_D_real: 0.47486
loss_D: 0.47723
loss_G_GAN: 1.47340
loss_G_L1: 5.04651
loss_G: 6.51991


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 91/100
Iteration 200/282
loss_D_fake: 0.47475
loss_D_real: 0.46880
loss_D: 0.47177
loss_G_GAN: 1.42135
loss_G_L1: 5.02375
loss_G: 6.44510


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 92/100
Iteration 200/282
loss_D_fake: 0.48116
loss_D_real: 0.47133
loss_D: 0.47625
loss_G_GAN: 1.44564
loss_G_L1: 4.99558
loss_G: 6.44122


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 93/100
Iteration 200/282
loss_D_fake: 0.48505
loss_D_real: 0.47666
loss_D: 0.48085
loss_G_GAN: 1.48151
loss_G_L1: 4.99512
loss_G: 6.47663


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 94/100
Iteration 200/282
loss_D_fake: 0.48302
loss_D_real: 0.47281
loss_D: 0.47792
loss_G_GAN: 1.52374
loss_G_L1: 4.97812
loss_G: 6.50186


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 95/100
Iteration 200/282
loss_D_fake: 0.49885
loss_D_real: 0.49499
loss_D: 0.49692
loss_G_GAN: 1.51792
loss_G_L1: 4.96003
loss_G: 6.47795


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 96/100
Iteration 200/282
loss_D_fake: 0.48076
loss_D_real: 0.47623
loss_D: 0.47850
loss_G_GAN: 1.53662
loss_G_L1: 4.93984
loss_G: 6.47647


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 97/100
Iteration 200/282
loss_D_fake: 0.47646
loss_D_real: 0.46780
loss_D: 0.47213
loss_G_GAN: 1.53616
loss_G_L1: 4.94124
loss_G: 6.47740


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 98/100
Iteration 200/282
loss_D_fake: 0.48379
loss_D_real: 0.47666
loss_D: 0.48023
loss_G_GAN: 1.47357
loss_G_L1: 4.91722
loss_G: 6.39079


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 99/100
Iteration 200/282
loss_D_fake: 0.46489
loss_D_real: 0.45517
loss_D: 0.46003
loss_G_GAN: 1.56501
loss_G_L1: 4.93891
loss_G: 6.50392


  0%|          | 0/282 [00:00<?, ?it/s]


Epoch 100/100
Iteration 200/282
loss_D_fake: 0.47133
loss_D_real: 0.46255
loss_D: 0.46694
loss_G_GAN: 1.59045
loss_G_L1: 4.90707
loss_G: 6.49752


In [17]:
model = MainModel()
model.load_state_dict(torch.load('colorization_model.pth'))
data = next(iter(val_dl))
visualize(model, data, save=False)

model initialized with norm initialization
model initialized with norm initialization


In [23]:
model = MainModel()
batch_size=16
input_shape = (16, 3, 256, 256)
summary(model.net_D, input_size=input_shape)

model initialized with norm initialization
model initialized with norm initialization


Layer (type:depth-idx)                   Output Shape              Param #
PatchDiscriminator                       [16, 1, 30, 30]           --
├─Sequential: 1-1                        [16, 1, 30, 30]           --
│    └─Sequential: 2-1                   [16, 64, 128, 128]        --
│    │    └─Conv2d: 3-1                  [16, 64, 128, 128]        3,136
│    │    └─LeakyReLU: 3-2               [16, 64, 128, 128]        --
│    └─Sequential: 2-2                   [16, 128, 64, 64]         --
│    │    └─Conv2d: 3-3                  [16, 128, 64, 64]         131,072
│    │    └─BatchNorm2d: 3-4             [16, 128, 64, 64]         256
│    │    └─LeakyReLU: 3-5               [16, 128, 64, 64]         --
│    └─Sequential: 2-3                   [16, 256, 32, 32]         --
│    │    └─Conv2d: 3-6                  [16, 256, 32, 32]         524,288
│    │    └─BatchNorm2d: 3-7             [16, 256, 32, 32]         512
│    │    └─LeakyReLU: 3-8               [16, 256, 32, 32]         --


In [26]:
transforms = transforms.Resize((256, 256), Image.BICUBIC)
image = Image.open('lincoln.png')
resized_image = transforms(image)
resized_image_np = np.array(resized_image)
resized_rgb_img = rgba2rgb(resized_image_np)
print(resized_rgb_img.shape)

(256, 256, 3)


In [46]:
l_chan_origin = resized_rgb_img[:, :, 0].reshape(1, 256, 256)
l_chan = l_chan_origin.reshape(1, 1, 256, 256)
l_chan = torch.from_numpy(l_chan).to(dtype=torch.float)
ab_chan = model.net_G(l_chan.to(device)).detach().cpu().numpy()
ab_chan_origin = ab_chan[0]
print(l_chan_origin.shape, ab_chan_origin.shape)

(1, 256, 256) (2, 256, 256)


In [79]:
def visualize_thumb(model, data, save=True):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    for i in range(4):
        wandb.log({
            "input_image": wandb.Image(np.rot90(L[i][0].cpu(), 3)),
            "reconstruction_image": wandb.Image(np.rot90(fake_imgs[i], 3)),
            "ground_truth_image": wandb.Image(np.rot90(real_imgs[i], 3)),
        })

In [80]:
test_dl = make_dataloaders(base_images="Test/", split='val')
model = MainModel()
model.load_state_dict(torch.load('colorization_model.pth'))
data = next(iter(test_dl))
visualize_thumb(model, data, save=False)

4
model initialized with norm initialization
model initialized with norm initialization
