In [1]:
import numpy as np
import os
import torch

from glob import glob
from skimage import io
from tensorboardX import SummaryWriter
from torch.utils.data.dataloader import DataLoader
from torch import nn
from tqdm import tqdm
from time import time

from src.batch_generator import BatchGenerator, WraptorDataLoader
from src.command import Command
from src.models import LocalTransferNet, GlobalTransferer, RefinementNet
from src.train_utils import (draw_scalar_value, draw_images, rgb2gray, infer_batch,
                             inference_all_test_videos, load_all_test_videos, calculate_psnr_result, frame_to_tensor)


from color_utils import rgb2lab_torch, lab2rgb_torch
from skimage.color import rgb2lab, lab2rgb
from collections import defaultdict

In [2]:
# Restrict CUDA to two physical GPUs. The value is the documented
# comma-separated list with no embedded whitespace; it has to be set before
# the first CUDA call for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

# Everything this run writes lives under the experiment directory.
experiment_dir = "experiments/train_refinement_onepass_six_frames_loss_pretrainetlocal_amsgrad"
logs_directory = os.path.join(experiment_dir, "logs")
save_test_output_videos_dir = os.path.join(experiment_dir, "results_test_videos")
checkpoints_folder = os.path.join(experiment_dir, "checkpoints")

# TensorBoard server settings.
board_port = 6103
board_timeout = 24 * 60 * 60  # seconds (one day); not referenced by the cells visible here

batch_size = 4  # not referenced by the cells visible here
img_shape = (128, 128, 3)  # passed as resize_size when videos are loaded
val_num = 5  # number of trailing (sorted) dataset folders held out for validation/test
iteration_num = 5000  # upper bound of the training loop
dataset_dir = "datasets/Davis480/480p"
learning_rate = 1e-3 / 5

# Periods, in training iterations.
verbose_every_it = 10  # log losses, gradient statistics and images
save_models_every_it = 100  # write checkpoints

test_every_it = 100  # run inference on the held-out videos and log PSNR

In [3]:
# The last `val_num` dataset folders (in sorted order) are held out for testing.
all_video_folders = sorted(glob(dataset_dir + "/*"))
test_videos_folders = all_video_folders[-val_num:]

In [4]:
# Scalar/image summaries are written to the experiment's log directory.
writer = SummaryWriter(logs_directory)

# Launch a TensorBoard process pointed at the same directory.
board_command = 'tensorboard --logdir=run1:{} --port {}'.format(logs_directory, board_port)
board = Command(board_command)
board.run()

In [5]:
# Create the experiment directory tree. `exist_ok=True` keeps the cell
# re-runnable, while real failures (permissions, invalid path) are no longer
# silently swallowed.
dir_paths = [experiment_dir, logs_directory, checkpoints_folder, save_test_output_videos_dir]
for path in dir_paths:
    os.makedirs(path, exist_ok=True)

In [6]:
# Split the sorted dataset folders: everything but the last `val_num` folders
# is used for training, the last `val_num` for validation.
sorted_video_folders = sorted(glob(dataset_dir + "/*"))
train_videos = load_all_test_videos(sorted_video_folders[:-val_num], resize_size=img_shape)
val_videos = load_all_test_videos(sorted_video_folders[-val_num:], resize_size=img_shape)

  warn("The default mode, 'constant', will be changed to 'reflect' in "


In [7]:
# Both trainable networks run in double precision on the GPU and are wrapped
# in DataParallel; the global transferer is used as-is.
local_transfer_core = LocalTransferNet().double().cuda()
local_transfer_net = nn.DataParallel(local_transfer_core)
global_transferer = GlobalTransferer()
refinement_core = RefinementNet().double().cuda()
refinement_net = nn.DataParallel(refinement_core)

In [8]:
# Initialise the local transfer network from a previously saved state dict.
pretrained_local_checkpoint = "experiments/pretrain_local_several_frames_back_highlr_amsgrad/checkpoints/7800/local_net_state_dict"
pretrained_local_state = torch.load(pretrained_local_checkpoint)
local_transfer_net.load_state_dict(pretrained_local_state)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [9]:
# Training-mode switches read when the optimizer is built:
#   pretrain_local -> optimise only the local transfer network
#   end2end        -> optimise the local and refinement networks together
# With both off, only the refinement network is optimised.
end2end = False
pretrain_local = False

In [10]:
# L1 is the active criterion; torch.nn.MSELoss() is the alternative.
loss_function = torch.nn.L1Loss()

# Pick which parameters the optimizer updates, based on the mode flags.
if pretrain_local:
    trainable_parameters = list(local_transfer_net.parameters())
elif end2end:
    trainable_parameters = list(local_transfer_net.parameters()) + list(refinement_net.parameters())
else:
    trainable_parameters = list(refinement_net.parameters())

optimizer = torch.optim.Adam(trainable_parameters, lr=learning_rate, amsgrad=True)

In [11]:
def get_random_video_frames(input_videos, frames_num):
    """Pick a random video and return a random run of consecutive frames from it.

    Args:
        input_videos: sequence of videos; each video is a sliceable sequence
            of frames supporting ``len``.
        frames_num: number of consecutive frames requested.

    Returns:
        ``video[start: start + frames_num]`` for a uniformly chosen video and
        start index. The start index is limited so that the clip always holds
        exactly ``frames_num`` frames; if the chosen video is shorter than
        ``frames_num`` the whole video is returned.
    """
    video_id = np.random.randint(0, len(input_videos))
    video = input_videos[video_id]
    # Last start index that still leaves room for a full clip (0 for short videos).
    last_start = max(len(video) - frames_num, 0)
    start_frame = np.random.randint(0, last_start + 1)
    return video[start_frame: start_frame + frames_num]

In [12]:
def tensor_from_list(list_frames):
    """Stack equally shaped tensors along a new leading dimension."""
    return torch.stack(tuple(list_frames), dim=0)

In [13]:
def inference_train_video(frames, refinement_net, local_transfer_net, global_transferer, use_only_local=False):
    """Run the networks over a clip frame by frame, feeding each RGB output back in.

    ``frames[0]`` serves as the reference frame. For every later frame a batch
    ``(reference, previous, current)`` is passed to ``infer_batch``; the RGB
    result of one step becomes the "previous" frame of the next step.

    Args:
        frames: sequence of frames; ``frames[0]`` is the reference and
            ``frames[1:]`` are processed in order.
        refinement_net, local_transfer_net, global_transferer: models
            forwarded unchanged to ``infer_batch``.
        use_only_local: forwarded unchanged to ``infer_batch``.

    Returns:
        Tuple ``(lab, rgb, local, global)`` where each element stacks the
        per-frame results of ``frames[1:]`` along a new leading dimension.
        The rgb, local and global results are moved to the CPU; the lab
        results stay on whatever device ``infer_batch`` returned them on.
    """
    reference_frame = frames[0]
    # For the first step the reference frame also plays the role of the
    # previous frame.
    I_prev = frame_to_tensor(reference_frame).cuda()
    output_rgb_frames = []
    output_lab_frames = []
    output_local = []
    output_global = []
    for cur_frame in frames[1:]:
        # The previous frame is detached, so gradients of a step do not flow
        # back into the steps that produced it.
        batch = (frame_to_tensor(reference_frame[None, ...]),
                 I_prev[None, ...].detach(),
                 frame_to_tensor(cur_frame[None, ...]))
        result_lab, result_rgb, result_local, result_global = infer_batch(batch, refinement_net, local_transfer_net,
                                                                 global_transferer, use_only_local)
        output_rgb_frames.append(result_rgb[0].cpu())
        output_lab_frames.append(result_lab[0])
        output_local.append(result_local[0].cpu())
        output_global.append(result_global[0].cpu())
        # Feed the (CPU copy of the) RGB output back as the next previous frame.
        I_prev = output_rgb_frames[-1]

    return (tensor_from_list(output_lab_frames),
            tensor_from_list(output_rgb_frames),
            tensor_from_list(output_local),
            tensor_from_list(output_global))

In [14]:
# Clip length sampled per training/validation step (reference frame included).
inference_frames = 6
# First iteration index of the training loop (raise it to resume numbering).
start_it = 0

In [None]:
# Each inference step consumes the triple (I1, Ik_1, Ik): reference frame,
# previous frame and current frame.
for it in tqdm(range(start_it, iteration_num)):

    # ---- training step on a random training clip ----
    input_frames = get_random_video_frames(train_videos, inference_frames)
    gt_frames = torch.tensor(input_frames).cuda()
    (refinement_output_lab_train, refinement_output_rgb_train,
     local_output_train, global_output_train) = inference_train_video(input_frames, refinement_net,
                                                                      local_transfer_net,
                                                                      global_transferer,
                                                                      use_only_local=False)

    # The loss skips channel 0 of the Lab tensors and the first (reference)
    # ground-truth frame, for which there is no prediction.
    train_loss = loss_function(refinement_output_lab_train[..., 1:], rgb2lab_torch(gt_frames)[1:, ..., 1:])

    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    # ---- periodic logging ----
    if (it % verbose_every_it) == 0:
        # Mean squared gradient per layer of the refinement net. The gradients
        # of this iteration are still stored on the parameters, so they only
        # need to be summarised on iterations that are actually logged.
        # Layers without weight/bias or without a gradient are skipped.
        layers_grad = {}
        layer_id = 0
        for layer in refinement_net.module.conv_layers:
            weight = getattr(layer, "weight", None)
            bias = getattr(layer, "bias", None)
            if weight is None or bias is None or weight.grad is None or bias.grad is None:
                continue
            layers_grad["layer_%d_weight_grad_norm" % layer_id] = torch.mean(weight.grad ** 2).item()
            layers_grad["layer_%d_bias_grad_norm" % layer_id] = torch.mean(bias.grad ** 2).item()
            layer_id += 1

        val_input_frames = get_random_video_frames(val_videos, inference_frames)
        gt_frames_train = gt_frames[1:]
        gt_frames_val = torch.tensor(val_input_frames[1:]).cuda()
        # Validation is never back-propagated, so no autograd graph is needed.
        with torch.no_grad():
            (refinement_output_lab_val, refinement_output_rgb_val,
             local_output_val, global_output_val) = inference_train_video(val_input_frames, refinement_net,
                                                                          local_transfer_net,
                                                                          global_transferer,
                                                                          use_only_local=False)
            val_loss = loss_function(refinement_output_lab_val[..., 1:], rgb2lab_torch(gt_frames_val)[..., 1:])

        draw_scalar_value(writer, "losses", "train loss", train_loss.detach().cpu().numpy(), it)
        draw_scalar_value(writer, "losses", "val loss", val_loss.detach().cpu().numpy(), it)
        writer.add_scalars("grad_norm", layers_grad, it)

        # Clamp the images to [0, 1] before they are drawn.
        local_output_train = torch.clamp(local_output_train, 0, 1)
        local_output_val = torch.clamp(local_output_val, 0, 1)
        refinement_output_rgb_train = torch.clamp(refinement_output_rgb_train, 0, 1)
        refinement_output_rgb_val = torch.clamp(refinement_output_rgb_val, 0, 1)

        # Global, local and refined outputs are concatenated along dim 2 so
        # they appear side by side in one image per frame.
        concat_res_train = torch.cat((global_output_train.cpu(), local_output_train.cpu(), refinement_output_rgb_train.cpu()), dim=2)
        draw_images(writer, gt_frames_train.cpu().detach().numpy(), concat_res_train.cpu().detach().numpy(), it, tag="train")

        concat_res_val = torch.cat((global_output_val.cpu(), local_output_val.cpu(), refinement_output_rgb_val.cpu()), dim=2)
        draw_images(writer, gt_frames_val.cpu().detach().numpy(), concat_res_val.cpu().detach().numpy(), it, tag="val")

    # ---- periodic checkpoint ----
    if ((it % save_models_every_it) == 0) and (it > 0):
        checkpoint_dir = os.path.join(checkpoints_folder, str(it))
        # exist_ok keeps a re-run of the notebook from crashing on an
        # already existing checkpoint directory.
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(refinement_net, os.path.join(checkpoint_dir, "refinement_net"))
        torch.save(refinement_net.state_dict(), os.path.join(checkpoint_dir, "refinement_net_state_dict"))
        torch.save(local_transfer_net, os.path.join(checkpoint_dir, "local_net"))
        torch.save(local_transfer_net.state_dict(), os.path.join(checkpoint_dir, "local_net_state_dict"))

    # ---- periodic evaluation on the held-out videos ----
    if ((it % test_every_it) == 0) and (it > 0):
        refinement_result, local_result, global_result = inference_all_test_videos(test_videos_folders,
                              refinement_net, local_transfer_net, global_transferer,
                              save_test_output_videos_dir, it, use_only_local=pretrain_local,
                              save_result=True, resize_size=img_shape)
        # Kept under its own name so the training ground truth is not shadowed.
        test_gt_videos = load_all_test_videos(test_videos_folders, resize_size=img_shape)
        metric_value = calculate_psnr_result(test_gt_videos, refinement_result)
        draw_scalar_value(writer, "metrics", "test psnr", metric_value, it)
    

        

  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  .format(dtypeobj_in, dtypeobj_out))
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of 

In [None]:
# Load the frames of the held-out test video folders.
# NOTE(review): img_shape is passed positionally here, while the other calls in
# this notebook pass it as resize_size=...; presumably the second positional
# parameter of load_all_test_videos is resize_size — confirm.
videos_frames = load_all_test_videos(test_videos_folders, img_shape)