In [None]:
import itertools
import os
import sys
sys.path.append("colorization_pytorch")
sys.path.append("flownet2_pytorch")
import torch
import torch.nn as nn
from colorization_pytorch.models import create_model
from flownet2_pytorch.models import FlowNet2
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
from datasets import *
from torch.utils.data import DataLoader
from tqdm import tqdm
from colorization_pytorch.options.train_options import TrainOptions
from flownet2_pytorch.parser import initialize_args
import torchvision.transforms as transforms
from colorization_pytorch.util import util

class Mask(nn.Module):
    """Small fully-convolutional head that maps a feature map to one channel.

    Three 3x3 convolutions (stride 1, padding 1, so the spatial size is kept)
    narrow the channels ``in_channels -> 64 -> 8 -> 1``; each convolution is
    followed by ``LeakyReLU(0.1)``.

    NOTE(review): the last activation is a LeakyReLU, so the output is not
    limited to [0, 1] — confirm that is intended where the result is used as a
    blending weight.
    """

    def __init__(self, in_channels):
        super().__init__()
        channel_plan = (in_channels, 64, 8, 1)
        stages = []
        # Build conv/activation pairs in order so the Sequential indices
        # (cnn.0, cnn.2, cnn.4 for the convolutions) stay the same.
        for fan_in, fan_out in zip(channel_plan[:-1], channel_plan[1:]):
            stages.append(
                nn.Conv2d(in_channels=fan_in, out_channels=fan_out,
                          kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.LeakyReLU(0.1))
        self.cnn = nn.Sequential(*stages)

    def forward(self, data):
        """Run ``data`` through the convolution stack and return the result."""
        mask_logits = self.cnn(data)
        return mask_logits


class VideoColorization(nn.Module):
    """Colorization network with flow-guided feature propagation between two frames.

    The colorization model encodes both frames, FlowNet2 estimates the flow
    for the pair, the first frame's feature map is warped by that flow, and a
    learned mask blends the warped features with the second frame's features
    before decoding.

    Args:
        opt: options object passed to ``create_model`` / ``color_model.setup``.
        args: FlowNet2 settings; ``args.resume`` is the checkpoint path.
        batchNorm: forwarded to ``FlowNet2``.
        div_flow: forwarded to ``FlowNet2``.
    """

    def __init__(self, opt, args, batchNorm=False, div_flow=20.):
        super(VideoColorization,self).__init__()
        IDX_RANGE = 44
        self.opt = opt
        self.color_model = create_model(opt)
        self.color_model.setup(opt)
        self.color_model.print_networks(False)
        self.flownet = FlowNet2(args, batchNorm=batchNorm, div_flow=div_flow)
        # Load onto the CPU first: load_state_dict copies the values into the
        # module's own parameters, so the checkpoint's original device is
        # irrelevant and a GPU-saved file still loads on a CPU-only machine.
        checkpoint = torch.load(args.resume, map_location='cpu')
        self.flownet.load_state_dict(checkpoint['state_dict'])
        # (row, col) lookup table kept only for backward compatibility with
        # code that may read this attribute; forward() builds its sampling
        # grid from the actual feature-map size and device instead.
        self.ori_index = torch.tensor(list(itertools.product(np.arange(IDX_RANGE), np.arange(IDX_RANGE)))). \
            reshape(IDX_RANGE, IDX_RANGE, -1)
        self.mask = Mask(512)

    @staticmethod
    def _warp_features(source, flow):
        """Bilinearly sample ``source`` at ``(x + flow_x, y + flow_y)``.

        Args:
            source: tensor of shape (B, C, H, W) to sample from.
            flow: tensor of shape (B, 2, h, w). Channel 0 is treated as the
                horizontal and channel 1 as the vertical displacement, in
                pixels of the flow's own grid.
                NOTE(review): channel order and flow direction are
                assumptions — confirm against the FlowNet2 output convention.

        Returns:
            Tensor shaped like ``source``. Sampling positions outside the map
            are clamped to the border.
        """
        height, width = source.shape[-2:]
        flow = flow.to(device=source.device, dtype=source.dtype)
        flow_h, flow_w = flow.shape[-2:]
        if (flow_h, flow_w) != (height, width):
            flow = F.interpolate(flow, size=(height, width), mode='bilinear', align_corners=False)
            # Displacements were measured on the flow's grid; rescale them to
            # pixels of the (usually coarser) source grid.
            rescale = flow.new_tensor([width / flow_w, height / flow_h]).view(1, 2, 1, 1)
            flow = flow * rescale
        cols = torch.arange(width, device=source.device, dtype=source.dtype).view(1, 1, width)
        rows = torch.arange(height, device=source.device, dtype=source.dtype).view(1, height, 1)
        sample_x = cols + flow[:, 0]
        sample_y = rows + flow[:, 1]
        # grid_sample expects (x, y) in [-1, 1]; with align_corners=True the
        # values -1 and 1 are the centres of the first and last pixel.
        grid = torch.stack(
            (2.0 * sample_x / max(width - 1, 1) - 1.0,
             2.0 * sample_y / max(height - 1, 1) - 1.0),
            dim=-1,
        )
        return F.grid_sample(source, grid, mode='bilinear', padding_mode='border', align_corners=True)

    def forward(self, color_data, flow_data):
        """Colorize the second frame using features propagated from the first.

        Args:
            color_data: two-element sequence; each element is fed to
                ``color_model.set_input`` (first frame, then second frame).
            flow_data: input passed unchanged to ``self.flownet``.

        Returns:
            Tuple ``(fake_B_class, fake_B_reg, flow)``: the two outputs of
            ``color_model.decode`` and the estimated flow.
        """
        self.color_model.set_input(color_data[0])
        _, _, previous_feature_map = self.color_model.encode()
        self.color_model.set_input(color_data[1])
        conv1_2, conv2_2, feature_map = self.color_model.encode()
        flow = self.flownet(flow_data)
        # Warp the previous frame's features; differentiable w.r.t. both the
        # features and the flow.
        predicted_feature_map = self._warp_features(previous_feature_map, flow)
        delta_feature_map = torch.abs(predicted_feature_map - feature_map)
        # M is single-channel and broadcasts over the feature channels.
        M = self.mask(delta_feature_map)
        output_feature_map = (1 - M) * feature_map + M * predicted_feature_map
        fake_B_class, fake_B_reg = self.color_model.decode(conv1_2, conv2_2, output_feature_map)
        return fake_B_class, fake_B_reg, flow
        

        
class flow_args:
    """Plain settings holder handed to ``FlowNet2`` and the dataset.

    Every instance gets its own copy of the values below as instance
    attributes, in this order: ``rgb_max``, ``fp16``, ``fp16_scale``,
    ``crop_size``, ``inference_size``, ``resume``.
    """

    def __init__(self):
        # The dict literal is rebuilt on every call, so the list-valued
        # settings are never shared between instances.
        settings = {
            'rgb_max': 255,
            'fp16': False,
            'fp16_scale': 1024.,
            'crop_size': [384, 1024],
            'inference_size': [-1, -1],
            'resume': './check_point/FlowNet2_checkpoint.pth.tar',
        }
        for key, value in settings.items():
            setattr(self, key, value)
        

# Build the two option objects and instantiate the combined model.
# Order matters: VideoColorization needs both `opt` and `args`, and its
# constructor loads the FlowNet2 checkpoint named by `args.resume`.
# NOTE(review): TrainOptions().parse() presumably reads command-line
# arguments — confirm it behaves as expected inside a notebook kernel.
opt = TrainOptions().parse()
args = flow_args()
debug = VideoColorization(opt, args)


In [None]:
# Training data rooted at data/training/, served one sample per batch in
# shuffled order by 4 worker processes.
# NOTE(review): MpiSintel is presumably supplied by `from datasets import *`;
# the training loop unpacks each batch as (Img, flow, mask) — confirm the
# dataset yields exactly those three items.
train_set = MpiSintel(args, root = "data/training/")
train_loader = DataLoader(train_set, batch_size = 1, shuffle=True, num_workers=4)

In [None]:
class solver():
    """Training harness for a ``VideoColorization``-style model.

    Owns one Adam optimizer per sub-network (flow, colorization generator,
    optional discriminator, mask) and the loss terms used by ``train``.

    Args:
        model: object exposing ``flownet``, ``mask``, ``color_model`` (with
            ``netG`` and, when the GAN term is enabled, ``netD``) and ``opt``.
    """

    def __init__(self, model):
        self.model = model
        # Use the options stored on the model rather than a notebook global,
        # so the class does not depend on cell execution order.
        opt = model.opt
        self.lr = 1e-4
        self.optimizer_flow = torch.optim.Adam(self.model.flownet.parameters(), lr=2e-4, weight_decay=1e-6)
        self.optimizer_color_network = []
        self.use_D = opt.lambda_GAN > 0
        # lr=0 keeps the colorization weights unchanged by optimizer steps.
        self.optimizer_G = torch.optim.Adam(self.model.color_model.netG.parameters(),
                                            lr=0, betas=(opt.beta1, 0.999)) #lr=opt.lr
        self.optimizer_color_network.append(self.optimizer_G)

        if self.use_D:
            self.optimizer_D = torch.optim.Adam(self.model.color_model.netD.parameters(),
                                                    lr=0, betas=(opt.beta1, 0.999))
            self.optimizer_color_network.append(self.optimizer_D)

        self.optimizer_mask = torch.optim.Adam(self.model.mask.parameters(), lr=1e-4, weight_decay=1e-6)
        self.batch_size = 5
        self.epoch = 100
        self.H = 384
        self.W = 1024
        # (H, W, 2) table holding [row, col] at every position. Built with
        # broadcasting instead of a Python list of H*W tuples; values, shape
        # and dtype are unchanged. Kept for backward compatibility —
        # pixel_flow derives its grid from the image it is given.
        row_idx = torch.arange(self.H).view(-1, 1).expand(self.H, self.W)
        col_idx = torch.arange(self.W).view(1, -1).expand(self.H, self.W)
        self.ori_index = torch.stack((row_idx, col_idx), dim=-1)

    def plot_grad_flow(self, named_parameters):
        '''Plots the gradients flowing through different layers in the net during training.
        Can be used for checking for possible gradient vanishing / exploding problems.

        Usage: Plug this function in Trainer class after loss.backwards() as
        "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow

        Bias parameters, parameters that do not require grad, and parameters
        that have no gradient yet are skipped.'''
        # Local import: the notebook only imports pyplot, not the top-level
        # matplotlib package, so `matplotlib.lines` is not otherwise in scope.
        from matplotlib.lines import Line2D

        ave_grads = []
        max_grads = []
        layers = []
        for n, p in named_parameters:
            if p.requires_grad and ("bias" not in n) and p.grad is not None:
                layers.append(n)
                magnitude = p.grad.detach().abs()
                # .item() yields plain floats, which also works for GPU tensors.
                ave_grads.append(magnitude.mean().item())
                max_grads.append(magnitude.max().item())
        plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
        plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
        plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
        plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
        plt.xlim(left=0, right=len(ave_grads))
        plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
        plt.xlabel("Layers")
        plt.ylabel("average gradient")
        plt.title("Gradient flow")
        plt.grid(True)
        plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
        plt.show()

    def pixel_flow(self, flow, img):
        """Warp ``img`` by ``flow`` with bilinear sampling.

        Args:
            flow: (B, 2, h, w) or channel-last (B, h, w, 2) displacement
                field. Channel 0 is treated as horizontal and channel 1 as
                vertical displacement, in pixels of the flow's own grid.
                NOTE(review): channel order and direction are assumptions —
                confirm against the dataset's flow convention.
            img: (B, C, H, W) tensor to sample from.

        Returns:
            Tensor shaped like ``img`` whose value at (y, x) is ``img``
            sampled at ``(y + flow_y, x + flow_x)``; positions outside the
            image are clamped to the border.
        """
        if flow.shape[1] != 2 and flow.shape[-1] == 2:
            # Channel-last layout: move the displacement axis next to batch.
            flow = flow.permute(0, 3, 1, 2)
        flow = flow.to(device=img.device, dtype=img.dtype)
        height, width = img.shape[-2:]
        flow_h, flow_w = flow.shape[-2:]
        if (flow_h, flow_w) != (height, width):
            flow = F.interpolate(flow, size=(height, width), mode='bilinear', align_corners=False)
            # Rescale displacements from the flow grid to the image grid.
            rescale = flow.new_tensor([width / flow_w, height / flow_h]).view(1, 2, 1, 1)
            flow = flow * rescale
        cols = torch.arange(width, device=img.device, dtype=img.dtype).view(1, 1, width)
        rows = torch.arange(height, device=img.device, dtype=img.dtype).view(1, height, 1)
        sample_x = cols + flow[:, 0]
        sample_y = rows + flow[:, 1]
        # grid_sample takes (x, y) in [-1, 1]; align_corners=True maps the
        # extremes to the centres of the first and last pixel.
        grid = torch.stack(
            (2.0 * sample_x / max(width - 1, 1) - 1.0,
             2.0 * sample_y / max(height - 1, 1) - 1.0),
            dim=-1,
        )
        out = F.grid_sample(img, grid, mode='bilinear', padding_mode='border', align_corners=True)
        return out

    def eval_metric(self, O_1, O_2, flow, mask):
        """Return the masked sum of squared differences between ``O_2`` and ``O_1`` warped by ``flow``."""
        R_2 = self.pixel_flow(flow, O_1)
        Estab = torch.sum(mask * (O_2 - R_2) * (O_2 - R_2))
        return Estab

    def cohe_loss(self, O_2, S_1, flow, mask):
        """Return the masked sum of squared differences between ``O_2`` and ``S_1`` warped by ``flow``."""
        R_2 = self.pixel_flow(flow, S_1)
        L_cohe = torch.sum(mask * (O_2 - R_2) * (O_2 - R_2))
        return L_cohe

    def occ_loss(self, O_2, S_2, mask):
        """Return the sum of squared differences between ``O_2`` and ``S_2`` weighted by ``1 - mask``."""
        L_occ = torch.sum((1-mask) * (O_2 - S_2) * (O_2 - S_2))
        return L_occ

    def flow_loss(self, F_est, flow):
        """Return the summed squared error between ``F_est`` and ``flow`` resized to its resolution.

        Both tensors are (B, 2, H, W)-style; only the spatial size of
        ``flow`` is changed, its values are not rescaled.
        """
        # interpolate() wants the spatial size only, not the full 4-D shape.
        F_down = F.interpolate(flow, size=F_est.shape[-2:], mode='bilinear', align_corners=False)
        L_flow = torch.sum((F_est-F_down) * (F_est-F_down))
        return L_flow

    def train(self, train_loader):
        """Run ``self.epoch`` epochs over ``train_loader``.

        Each batch is unpacked as ``(Img, flow, mask)``. The total loss is
        ``alpha * cohe_loss + beta * occ_loss + gamma * flow_loss`` and the
        flow, generator and mask optimizers are stepped once per batch.
        """
        alpha = 1e-5
        beta = 2e-4
        gamma = 20
        # Follow the device the model's registered parameters live on.
        device = next(self.model.parameters()).device
        # Inclusive upper bound so exactly self.epoch epochs are run.
        for epoch in range(1, self.epoch + 1):
            for Img, flow, mask in tqdm(train_loader, desc="epoch {}".format(epoch)):
                # All loss inputs must sit on the same device as the outputs.
                Img = Img.to(device)
                flow = flow.to(device)
                mask = mask.to(device).float()
                # Frames are indexed on dim 2 of Img; unsqueeze(0) wraps the
                # frame so that element [0] of the argument is the image batch.
                # NOTE(review): presumably get_colorization_data reads
                # data_raw[0] — confirm.
                temp_1 = util.get_colorization_data(Img[:,:,0,:,:].unsqueeze(0),self.model.opt)
                I_1 = temp_1['A']
                S_1 = temp_1['B']
                temp_2 = util.get_colorization_data(Img[:,:,1,:,:].unsqueeze(0),self.model.opt)
                I_2 = temp_2['A']
                S_2 = temp_2['B']
                # NOTE(review): this concatenates the two 'A' tensors along
                # the channel axis — confirm this is the input layout the flow
                # network expects.
                inp = torch.cat((I_1,I_2), dim=1)
                self.optimizer_flow.zero_grad()
                self.optimizer_G.zero_grad()
                self.optimizer_mask.zero_grad()
                _, O_2, F_est = self.model([temp_1,temp_2],inp)
                loss = alpha*self.cohe_loss(O_2,S_1,flow,mask)+beta*self.occ_loss(O_2,S_2,mask)+gamma*self.flow_loss(F_est,flow)
                loss.backward()
                self.optimizer_flow.step()
                self.optimizer_G.step()
                self.optimizer_mask.step()

In [None]:
# Move the model's registered submodules to the GPU, wrap it in the training
# harness, and start training on the loader defined in an earlier cell.
debug = debug.cuda()
solv = solver(debug)
solv.train(train_loader)