In [1]:
import os
from glob import glob

import torch
from skimage import io
from tensorboardX import SummaryWriter
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from src.batch_generator import BatchGenerator, WraptorDataLoader
from src.command import Command
from src.models import LocalTransferNet, GlobalTransferer, RefinementNet
from src.train_utils import draw_scalar_value

In [2]:
# Directory handed to both the SummaryWriter and the TensorBoard command line.
logs_directory = "experiments/test_tensorboard/logs"
# Port given to TensorBoard via --port.
board_port = 9000
# One day expressed in seconds (not referenced by the other visible cells).
board_timeout = 24 * 60 * 60

In [3]:
# Writer that emits event files into the configured log directory.
writer = SummaryWriter(logs_directory)

# Build the TensorBoard command line first, then hand it to the project's
# Command wrapper (presumably runs it as an external process -- confirm in src.command).
tensorboard_cmd = 'tensorboard --logdir=run1:{} --port {}'.format(logs_directory, board_port)
board = Command(tensorboard_cmd)
board.run()



In [4]:
# Make sure every output directory exists before anything writes into it.
dir_paths = [logs_directory]
for path in dir_paths:
    # makedirs also creates missing parent directories, and exist_ok=True keeps
    # the cell idempotent on re-runs. Genuine failures (permissions, a file in
    # the way) are allowed to surface instead of being silently ignored.
    os.makedirs(path, exist_ok=True)

In [5]:
import numpy as np
def rgb2gray(rgb):
    """Convert a channels-last RGB array to a 3-channel grayscale array.

    The first three channels along axis 2 are combined with the weights
    0.2989 / 0.5870 / 0.1140 and the resulting luminance plane is repeated
    three times along axis 2, so the output keeps a 3-channel layout.
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    # Re-add the channel axis, then replicate it to get three identical channels.
    return np.repeat(luminance[..., np.newaxis], 3, axis=2)


In [6]:
# Collect every entry directly under the DAVIS 480p image folder.
davis_sequence_paths = glob("datasets/DAVIS/JPEGImages/480p/*")

# Batch generator over those paths, with frames resized to 64x64x3.
train_bg = BatchGenerator(davis_sequence_paths, resize_shape=(64, 64, 3))
# Project loader wrapper serving one sample per batch.
train_loader = WraptorDataLoader(train_bg, batch_size=1)

100%|██████████| 50/50 [00:00<00:00, 780.08it/s]


In [7]:
# Local transfer network, cast with .double() so it matches the torch.double
# tensors that the training loop feeds it.
local_transfer_net = LocalTransferNet().double()
# Global transferer: it is not cast and its parameters are not handed to the
# optimizer, and the training loop calls it with numpy arrays.
global_transferer = GlobalTransferer()

In [8]:
# Refinement network, cast with .double() like the local transfer network; the
# training loop feeds it the concatenated gray / local / global tensors.
refinement_net = RefinementNet().double()

In [9]:
# Training objective: L1 loss (torch.nn.MSELoss() is the alternative noted by the author).
loss_function = torch.nn.L1Loss()

# One Adam optimizer with default hyper-parameters over the parameters of both
# trainable networks.
trainable_parameters = [*local_transfer_net.parameters(), *refinement_net.parameters()]
optimizer = torch.optim.Adam(trainable_parameters)

In [10]:
from color_utils import rgb2lab_torch, lab2rgb_torch

In [11]:
# Display the length reported by the loader (built with batch_size=1).
len(train_loader)

3455

In [12]:
# Total number of optimisation steps performed by the training loop.
iteration_num = 5000
# The training loss is written to TensorBoard every this many iterations.
verbose_every_it = 5

In [13]:
# Training loop for the local-transfer + refinement networks.
# Per the author's note each batch is indexed as (I1, Ik_1, Ik). The images are
# assumed to be channels-last RGB tensors -- TODO confirm against BatchGenerator.


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it only once exhausted.

    Raises:
        RuntimeError: if a full pass over ``loader`` yields nothing, which
            would otherwise spin forever.
    """
    while True:
        produced_any = False
        for loaded_batch in loader:
            produced_any = True
            yield loaded_batch
        if not produced_any:
            raise RuntimeError("train_loader produced no batches")


def _gray_array_and_tensor(images):
    """Return (channels-last numpy array, channels-first double tensor) of gray images."""
    gray_array = np.array([rgb2gray(img.numpy()) for img in images])
    gray_tensor = torch.tensor(gray_array.transpose(0, 3, 1, 2), dtype=torch.double, requires_grad=False)
    return gray_array, gray_tensor


def _rgb_to_lab_channels_first(rgb_channels_first):
    """Convert a channels-first RGB batch to Lab, returning channels-first again."""
    channels_last = rgb_channels_first.permute(0, 2, 3, 1)
    return rgb2lab_torch(channels_last, use_gpu=False).permute(0, 3, 1, 2)


train_losses = []
# One persistent iterator: calling iter(train_loader) on every step would
# rebuild the loader iterator each time and restart it from its beginning.
batch_stream = _endless_batches(train_loader)
for it in tqdm(range(iteration_num)):
    batch = next(batch_stream)
    I1_lab = rgb2lab_torch(batch[0], use_gpu=False)  # not consumed inside this loop
    Ik_lab = rgb2lab_torch(batch[2], use_gpu=False)

    # Grayscale versions of the three frames (G1_tensor is not consumed inside this loop).
    G1_array, G1_tensor = _gray_array_and_tensor(batch[0])
    Gk_1_array, Gk_1_tensor = _gray_array_and_tensor(batch[1])
    Gk_array, Gk_tensor = _gray_array_and_tensor(batch[2])

    # Local branch: call the module itself rather than .forward() so that any
    # registered hooks run.
    local_batch_output = local_transfer_net(Gk_1_tensor, Gk_tensor, batch[1].permute(0, 3, 1, 2))
    local_batch_output_lab = _rgb_to_lab_channels_first(local_batch_output)

    # Global branch: works on numpy arrays and only on the first sample of the batch.
    global_batch_output = global_transferer.forward(G1_array[0], Gk_array[0], batch[0][0])
    global_batch_output = torch.tensor(global_batch_output[None, ...].transpose(0, 3, 1, 2), dtype=torch.double, requires_grad=False)
    global_batch_output_lab = _rgb_to_lab_channels_first(global_batch_output)

    # Refinement input: gray target frame + both candidates, stacked along the
    # channel axis (first sample only, matching the global branch).
    stacked_input_refinement = torch.cat([Gk_tensor[0:1], local_batch_output_lab[0:1], global_batch_output_lab], dim=1)

    refinement_output_lab = refinement_net(stacked_input_refinement).permute(0, 2, 3, 1)

    loss = loss_function(refinement_output_lab, Ik_lab[:1])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Keep the scalar history that train_losses was declared for.
    train_losses.append(loss.item())

    # Preview image: ground-truth L channel combined with the predicted a/b
    # channels. Built without autograd tracking because it is for inspection only.
    with torch.no_grad():
        gt_l = Ik_lab[0, ..., 0][..., None]
        predicted_ab = refinement_output_lab[0, ..., 1:]
        result_lab = torch.cat((gt_l, predicted_ab), dim=2).cpu()
        result_rgb = lab2rgb_torch(result_lab, use_gpu=False)

    if (it % verbose_every_it) == 0:
        draw_scalar_value(writer, "losses", "train loss", loss.detach().cpu().numpy(), it)

    # Uncomment to show the target frame next to the prediction:
    #io.imshow(np.concatenate((batch[2][0], result_rgb.cpu().detach().numpy()), 1))
    #io.show()

  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "
  0%|          | 2/5000 [00:13<9:35:32,  6.91s/it] 

KeyboardInterrupt: 