In [1]:
import numpy as np
import os
import torch

from glob import glob
from skimage import io
from tensorboardX import SummaryWriter
from torch.utils.data.dataloader import DataLoader
from torch import nn
from tqdm import tqdm
from time import time

from src.batch_generator import BatchGenerator, WraptorDataLoader
from src.command import Command
from src.models import LocalTransferNet, GlobalTransferer, RefinementNet
from src.train_utils import (draw_scalar_value, draw_images, rgb2gray, infer_batch,
                             inference_all_test_videos, load_all_test_videos, calculate_psnr_result, frame_to_tensor)


from color_utils import rgb2lab_torch, lab2rgb_torch
from skimage.color import rgb2lab, lab2rgb
from collections import defaultdict

In [4]:
os.environ['CUDA_VISIBLE_DEVICES']='6, 7'

experiment_dir = "experiments/train_refinement_onepass_six_frames_loss_pretrainetlocal_amsgrad_mobile_optical_flow"
logs_directory = os.path.join(experiment_dir, "logs")
save_test_output_videos_dir =  os.path.join(experiment_dir,  "results_test_videos")
checkpoints_folder =  os.path.join(experiment_dir, "checkpoints")

board_port = 6101
board_timeout=24*60*60
batch_size = 4
img_shape = (128, 128, 3)
val_num = 5
iteration_num = 5000
dataset_dir = "datasets/Davis480/480p"
learning_rate = 1e-3 / 5
verbose_every_it = 10
save_models_every_it = 100

test_every_it = 100

In [5]:
#import torch
from pytorch_spynet.run import estimate, Network
optical_flow_net = Network().cuda().eval()
optical_flow_net.load_state_dict(torch.load('pytorch_spynet/network-' + "sintel-final" + '.pytorch'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [6]:
test_videos_folders = sorted(glob(dataset_dir + "/*"))[-val_num:]

In [7]:
writer = SummaryWriter(logs_directory) 
board = Command('tensorboard --logdir=run1:{} --port {}'.format(logs_directory, board_port))
board.run()

In [8]:
dir_paths = [experiment_dir, logs_directory, checkpoints_folder, save_test_output_videos_dir]
for path in dir_paths:
    try:
        os.mkdir(path)
    except:
        pass

In [9]:
train_videos = load_all_test_videos(sorted(glob(dataset_dir + "/*"))[-1:], need_resize=False)
val_videos = load_all_test_videos(sorted(glob(dataset_dir + "/*"))[-1:], need_resize=False)

In [17]:
local_transfer_net = nn.DataParallel(LocalTransferNet().double().cuda())
global_transferer = GlobalTransferer()
refinement_net = RefinementNet(input_channels=[12, 64, 64, 64, 64, 64]).double()

In [18]:
pretrain_local = False
end2end = False

In [19]:
loss_function = torch.nn.L1Loss() #torch.nn.MSELoss() or torch.nn.L1Loss()
if pretrain_local:
    optimizer = torch.optim.Adam(list(local_transfer_net.parameters()), 
                                 lr=learning_rate, amsgrad=True)
else:
    if end2end:
        optimizer = torch.optim.Adam(list(local_transfer_net.parameters()) + list(refinement_net.parameters()), 
                                     lr=learning_rate, amsgrad=True)
    else:
        optimizer = torch.optim.Adam(list(refinement_net.parameters()), 
                                     lr=learning_rate, amsgrad=True)

In [14]:
pretrain_local = False
end2end = False

In [15]:
loss_function = torch.nn.L1Loss() #torch.nn.MSELoss() or torch.nn.L1Loss()
if pretrain_local:
    optimizer = torch.optim.Adam(list(local_transfer_net.parameters()), 
                                 lr=learning_rate, amsgrad=True)
else:
    if end2end:
        optimizer = torch.optim.Adam(list(local_transfer_net.parameters()) + list(refinement_net.parameters()), 
                                     lr=learning_rate, amsgrad=True)
    else:
        optimizer = torch.optim.Adam(list(refinement_net.parameters()), 
                                     lr=learning_rate, amsgrad=True)

In [16]:
from skimage.transform import resize
def get_random_video_frames(input_videos, frames_num, crop_size=256, resie_size=128):
    video_id = np.random.randint(0, len(input_videos))
    start_frame = np.random.randint(0, len(input_videos[video_id]) -  1)
    frames = input_videos[video_id][start_frame: start_frame + frames_num]
    random_crop_h = np.random.randint(0, max(frames[0].shape[0] - crop_size, 1))
    random_crop_w = np.random.randint(0, max(frames[0].shape[1] - crop_size, 1))
    for i in range(len(frames)):
        frames[i] = frames[i][random_crop_h: random_crop_h + crop_size, random_crop_w: random_crop_w + crop_size, :]
        frames[i] = resize(frames[i], (resie_size, resie_size))
    return frames
        
        

In [17]:
def tensor_from_list(list_frames):
    return torch.cat([frame[None, ...] for frame in list_frames], dim=0)

In [18]:
def inference_train_video(frames, refinement_net, local_transfer_net, global_transferer, use_only_local=False):
    I0 = frames[0]
    I_prev = frame_to_tensor(frames[0]).cuda()
    Gk_1 = rgb2gray(I0)
    G0 = Gk_1.copy()
    output_rgb_frames = []
    output_lab_frames = []
    output_local = []
    output_global = []
    output_optical = []
    for cur_frame in frames[1:]:
        batch = (frame_to_tensor(I0[None, ...]), 
                 I_prev[None, ...].detach(),
                 frame_to_tensor(cur_frame[None, ...]))
        result_lab, result_rgb, result_local, result_global, result_optical = infer_batch(batch, refinement_net, local_transfer_net, 
                                                                                          global_transferer, use_only_local, 
                                                                                          use_optical_flow=True,
                                                                                          optical_flow_net=optical_flow_net)
        output_rgb_frames.append(result_rgb[0].cpu())
        output_lab_frames.append(result_lab[0])
        output_local.append(result_local[0].cpu())
        output_global.append(result_global[0].cpu())
        output_optical.append(result_optical[0].cpu())
        I_prev = output_rgb_frames[-1]
        
    return (tensor_from_list(output_lab_frames),
            tensor_from_list(output_rgb_frames), 
            tensor_from_list(output_local), 
            tensor_from_list(output_global),
            tensor_from_list(output_optical))

In [19]:
start_it = 0
inference_frames = 3

In [20]:
start_it = 5000

In [25]:
for i in range(len(refinement_net.module.conv_layers)):
    weight_grad_mean = torch.mean((refinement_net.module.conv_layers[i].weight.grad) ** 2)
    bias_grad_mean = torch.mean((refinement_net.module.conv_layers[i].bias.grad) ** 2)
    layers_grad["layer_%d_weight_grad_norm" % layer_id] = weight_grad_mean
    layers_grad["layer_%d_bias_grad_norm" % layer_id] = bias_grad_mean
    layer_id += 1

TypeError: unsupported operand type(s) for ** or pow(): 'NoneType' and 'int'

In [21]:
#I1, Ik_1, Ik
for it in tqdm(range(start_it, start_it + iteration_num)):
    
    input_frames = get_random_video_frames(train_videos, inference_frames)
    gt_frames = torch.tensor(input_frames).cuda()
    refinement_output_lab_train, refinement_output_rgb_train, local_output_train, global_output_train, optical_output_train = inference_train_video(input_frames, refinement_net, 
                                                                                        local_transfer_net, 
                                                                                        global_transferer, 
                                                                                        use_only_local=False)
    
    train_loss = loss_function(refinement_output_lab_train[..., 1:], 
                               rgb2lab_torch(gt_frames)[1:, ..., 1:]).clone().detach().requires_grad_(True)
    
    
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    
    layers_grad = defaultdict()
    
    layer_id = 0
    for i in range(len(refinement_net.module.conv_layers)):
        try:
            weight_grad_mean = torch.mean((refinement_net.module.conv_layers[i].weight.grad) ** 2)
            bias_grad_mean = torch.mean((refinement_net.module.conv_layers[i].bias.grad) ** 2)
            layers_grad["layer_%d_weight_grad_norm" % layer_id] = weight_grad_mean
            layers_grad["layer_%d_bias_grad_norm" % layer_id] = bias_grad_mean
            layer_id += 1
        except:
            continue
    
    
    if (it % verbose_every_it) == 0:
        input_frames = get_random_video_frames(val_videos, inference_frames)
        gt_frames_train = gt_frames[1:]
        gt_frames_val = torch.tensor(input_frames[1:]).cuda()
        refinement_output_lab_val, refinement_output_rgb_val, local_output_val, global_output_val, optical_output_val = inference_train_video(input_frames, refinement_net, 
                                                                                        local_transfer_net, 
                                                                                        global_transferer, 
                                                                                        use_only_local=False)
        val_loss = loss_function(refinement_output_lab_val[..., 1:], rgb2lab_torch(gt_frames_val)[..., 1:])
        

        draw_scalar_value(writer, "losses", "train loss", train_loss.detach().cpu().numpy(), it)
        draw_scalar_value(writer, "losses", "val loss", val_loss.detach().cpu().numpy(), it)
        writer.add_scalars("grad_norm", layers_grad, it)
        
        local_output_train = torch.clamp(local_output_train,  0, 1)
        local_output_val = torch.clamp(local_output_val,  0, 1)
        refinement_output_rgb_train = torch.clamp(refinement_output_rgb_train,  0, 1)
        refinement_output_rgb_val = torch.clamp(refinement_output_rgb_val,  0, 1)
        optical_output_train = torch.clamp(optical_output_train,  0, 1)
        optical_output_val = torch.clamp(optical_output_val,  0, 1)

        concat_res_train = torch.cat((global_output_train.cpu(), local_output_train.cpu(), 
                                      optical_output_train.cpu(), refinement_output_rgb_train.cpu()), dim=2)
        draw_images(writer, gt_frames_train.cpu().detach().numpy(), concat_res_train.cpu().detach().numpy(), it, tag="train")
        
        concat_res_val = torch.cat((global_output_val.cpu(), local_output_val.cpu(), 
                                    optical_output_val.cpu(), refinement_output_rgb_val.cpu()), dim=2)
        draw_images(writer, gt_frames_val.cpu().detach().numpy(), concat_res_val.cpu().detach().numpy(), it, tag="val")
    
    if ((it % save_models_every_it) == 0) and (it > 0):
        os.mkdir(os.path.join(checkpoints_folder, str(it)))
        torch.save(refinement_net, os.path.join(checkpoints_folder,  str(it), "refinement_net"))
        torch.save(refinement_net.state_dict(), os.path.join(checkpoints_folder,  str(it), "refinement_net_state_dict"))
        torch.save(local_transfer_net, os.path.join(checkpoints_folder,  str(it), "local_net"))
        torch.save(local_transfer_net.state_dict(), os.path.join(checkpoints_folder,  str(it), "local_net_state_dict"))
    
    if ((it % test_every_it) == 0) and (it > 0):
        refinement_result, local_result, global_result, _ = inference_all_test_videos(test_videos_folders,
                              refinement_net, local_transfer_net, global_transferer,
                              save_test_output_videos_dir, it, use_only_local=pretrain_local, 
                              use_optical_flow=True, optical_flow_net=optical_flow_net,
                              save_result=True, resize_size=img_shape)
        gt_frames = load_all_test_videos(test_videos_folders, resize_size=img_shape)
        metric_value = calculate_psnr_result(gt_frames, refinement_result)
        draw_scalar_value(writer, "metrics", "test psnr", metric_value, it)
    

        

  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  .format(dtypeobj_in, dtypeobj_out))


FileExistsError: [Errno 17] File exists: 'experiments/train_refinement_onepass_six_frames_loss_pretrainetlocal_amsgrad_mobile_optical_flow/checkpoints/5000'

In [22]:
videos_frames = load_all_test_videos(test_videos_folders, img_shape)

KeyboardInterrupt: 