In [1]:
import numpy as np
import tensorflow as tf

import os
import scipy.misc
import numpy as np
from utils import *
from VESPCN_utils import *
import tensorflow as tf
from easydict import EasyDict as edict

%reload_ext autoreload
%autoreload 2

config = edict()

config.sample_dir = "samples_MCT"
config.checkpoint_dir = "checkpoint/MCT"
config.log_dir = "logs"
config.train_size = 100000000 # use large number if you have enough memory
config.valid_size = 10 # use large number if you have enough memory
config.test_size = 6400 # use large number if you have enough memory
config.batch_size = 8 # use large number if you have enough memory
config.patch_shape = [50,50,3] #[51,51,3]
config.scale = 3 #3
config.learning_rate = 1e-5
config.epoch = 1000000
config.channels = 3
config.mode = "RGB"

#config.channels = 1
#config.mode = "YCbCr"

#'''
config.dataset = "CDVL"
config.num_videos = 76

#'''
config.train = edict()
config.train.lr_init = 1e-3
config.train.lr_decay = 0.5
config.train.decay_iter = 10
config.augmentation = True

if not os.path.exists(config.checkpoint_dir):
    os.makedirs(config.checkpoint_dir)
if not os.path.exists(config.sample_dir):
    os.makedirs(config.sample_dir)
if not os.path.exists(config.log_dir):
    os.makedirs(config.log_dir)

In [6]:
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import scipy.misc
from subpixel import PS
import numpy as np

from utils import *
from VESPCN_utils import *
from warp import *

class MotionCompensationTransformer(object):
    def __init__(self, sess, config):
        self.sess = sess
        self.config = config
        self.batch_size = config.batch_size
        self.valid_size = config.batch_size
        self.patch_shape = config.patch_shape
        self.input_size = int(config.patch_shape[0])
        self.dataset_name = config.dataset
        self.mode = config.mode
        self.channels = config.channels
        self.augmentation = config.augmentation
        self.checkpoint_dir = config.checkpoint_dir
        self.build_model()
        tf.global_variables_initializer().run(session=self.sess)
        #x = self.sess.run([self.output], feed_dict = {self.input_t0: np.zeros([8,48,48,self.channels]), self.input_t1: np.zeros([8,48,48,self.channels])})
        #print(x[0].shape)
        self.num_videos = config.num_videos

    def build_model(self):   
        
        identity_x = np.zeros([self.batch_size, self.input_size, self.input_size, 1])
        identity_y = np.zeros([self.batch_size, self.input_size, self.input_size, 1])
        for i in range(0, self.batch_size):
            for j in range(0, self.input_size):
                for k in range(0, self.input_size):
                    identity_x[i,j,k] = 2.0*j/(self.input_size-1) -1
                    identity_y[i,j,k] = 2.0*k/(self.input_size-1) -1
        self.id_x = tf.constant(identity_x, dtype = tf.float32)  
        self.id_y = tf.constant(identity_y, dtype = tf.float32)   
        
        #input frames
        self.input_t0 = tf.placeholder(tf.float32, [self.batch_size, self.input_size, self.input_size, self.channels], name='input_t0') 
        self.input_t1 = tf.placeholder(tf.float32, [self.batch_size, self.input_size, self.input_size, self.channels], name='input_t1') 

        #output frame (compensated t1)
        self.output,self.coarse_x, self.coarse_y = self.network(self.input_t0, self.input_t1)
        print("output shape:", self.output.shape)
        
        #for unknown sizes
        self.input2_t0 = tf.placeholder(tf.float32, [1, 480, 720, self.channels], name='input_t0_unkown')
        self.input2_t1 = tf.placeholder(tf.float32, [1, 480, 720, self.channels], name='input_t1_unkown')
        self.output2,self.coarse_x2, self.coarse_y2 = self.network(self.input2_t0, self.input2_t1)

        self.loss = tf.reduce_mean(tf.square(self.input_t0-self.output)) \
        #+ 0.01 * tf.reduce_mean(tf.sqrt(tf.add(tf.square(self.coarse_x)+tf.square(self.coarse_y),1e-4)))
        + 0.01 * tf.reduce_mean(tf.sqrt(tf.add(tf.square(self.coarse_x-self.id_x)+tf.square(self.coarse_y-self.id_y),1e-4))) #approximated Huber loss
        self.vars = tf.trainable_variables()
        print("Number of variables in network:",len(self.vars),", full list:",self.vars)
        self.optimizer = tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.loss, var_list=self.vars)

        self.saver = tf.train.Saver()

    def network(self, t0, t1):
        '''
        ######## coasrse flow ######
        tmp = tf.concat([t0, t1], axis = 3) #early fusion
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Coarse_1', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_2', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Coarse_3', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_4', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 32, 3, strides = 1, padding = 'SAME', name = 'Coarse_5', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.tanh(tmp)

        coarse_x = tf.layers.conv2d(tmp, 4*4*1, 3, strides = 1, padding = 'SAME', name = 'Coarse_x', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        coarse_y = tf.layers.conv2d(tmp, 4*4*1, 3, strides = 1, padding = 'SAME', name = 'Coarse_y', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        coarse_x = PS(coarse_x, 4, color=False)
        coarse_y = PS(coarse_y, 4, color=False)
        #print("shape: ", coarse_x.shape, coarse_y.shape)
        
        ######## fine flow ######
        tmp = tf.concat([t0, t1], axis = 3) #early fusion
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Fine_1', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_2', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_3', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_4', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 8, 3, strides = 1, padding = 'SAME', name = 'Fine_5', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.tanh(tmp)

        fine_x = tf.layers.conv2d(tmp, 2*2*1, 3, strides = 1, padding = 'SAME', name = 'Fine_x', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        fine_y = tf.layers.conv2d(tmp, 2*2*1, 3, strides = 1, padding = 'SAME', name = 'Fine_y', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        fine_x = PS(fine_x, 2, color=False)
        fine_y = PS(fine_y, 2, color=False)
        
        #add coarse, fine flow
        flow_x = tf.add(coarse_x, fine_x)
        flow_y = tf.add(coarse_y, fine_y)
        flow = tf.concat([flow_x, flow_y], axis = 3)
        #Warp
        out = batch_warp2d(t1, flow)
        #print("shape:", t1.shape, flow.shape, out.shape)
        return out
        '''
         ######## coasrse flow ######
        tmp = tf.concat([t0, t1], axis = 3) #early fusion
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 1, padding = 'SAME', name = 'Coarse_1', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_2', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 1, padding = 'SAME', name = 'Coarse_3', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_4', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 32, 3, strides = 1, padding = 'SAME', name = 'Coarse_5', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        #tmp = tf.nn.tanh(tmp)
        tmp = tf.nn.relu(tmp)

        coarse_x = tf.layers.conv2d(tmp, 1, 1, strides = 1, padding = 'SAME', name = 'Coarse_x', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        coarse_y = tf.layers.conv2d(tmp, 1, 1, strides = 1, padding = 'SAME', name = 'Coarse_y', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)

        #coarse_x = PS(coarse_x, 4, color=False)
        #coarse_y = PS(coarse_y, 4, color=False)
        n, h, w, _ = t0.get_shape().as_list()
        base_x = np.zeros([n, h, w, 1])
        base_y = np.zeros([n, h, w, 1])
        for i in range(0, n):
            for j in range(0, h):
                for k in range(0, w):
                    base_x[i,j,k] = 2.0*j/(h-1) -1
                    base_y[i,j,k] = 2.0*k/(w-1) -1
        _x = tf.constant(base_x, dtype = tf.float32)  
        _y = tf.constant(base_y, dtype = tf.float32)
        #Warp
        #flow_x = tf.add(tf.multiply(coarse_x,1e-10), _x)
        #flow_y = tf.add(tf.multiply(coarse_y,1e-10), _y)
        flow_x = tf.add(coarse_x, _x)
        flow_y = tf.add(coarse_y, _y)
        out = batch_warp2d_2(t1, flow_x, flow_y)

        return out, coarse_x, coarse_y
                
    def train(self, config, load = True):
        # setup train/validation data
        '''
        valid = sorted(glob(os.path.join(self.config.valid.hr_path, "*.png")))
        shuffle(valid)
        
        valid_files = valid[0:self.valid_size]
        valid = [load_image(valid_file, self.mode) for valid_file in valid_files]
        valid_LR = [doresize(xx, [self.input_size,]*2) for xx in valid]
        valid_HR = np.array(valid)
        valid_LR = np.array(valid_LR)
        if self.mode == "YCbCr":
            valid_RGB_HR =  np.copy(valid_HR)
            valid_HR = np.split(valid_RGB_HR,3, axis=3)[0]
            valid_RGB_LR = np.copy(valid_LR)
            valid_LR = np.split(valid_RGB_LR,3, axis=3)[0]
        '''
        counter = 1
        start_time = time.time()
        if load == True:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
        else:
            print(" Training starts from beginning")

        for epoch in range(self.config.epoch):
            if epoch % 100 == 0:
                print("Loading videos again...")
                self.imdb = []
                self.num_frames_per_video = []
                self.imdb, self.num_frames_per_video = load_videos(30, self.num_videos, 50, self.mode)
            batch_idxs = min(len(self.imdb), self.config.train_size) // self.config.batch_size

            #for idx in range(0, batch_idxs):
            for idx in range(0, 100):
                batch_t0, batch_t1 = get_batch_MCT(self.imdb, self.num_frames_per_video, self.batch_size, 
                                                   [self.input_size, self.input_size], augmentation = self.augmentation)
                batch_t0 = np.array(batch_t0)
                batch_t1 = np.array(batch_t1)

                _, loss = self.sess.run([self.optimizer, self.loss],
                    feed_dict={ self.input_t0: batch_t0, self.input_t1: batch_t1 })

                counter+=1
                if idx % 500 == 1 and epoch % 100 == 0:
                    #print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f" %(epoch, idx, batch_idxs, time.time() - start_time, loss))
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f" %(epoch, idx, 100, time.time() - start_time, loss))
                    self.save(self.config.checkpoint_dir)
             
            # occasional testing
            if epoch % 100 == 0:
                avg_PSNR_original, avg_PSNR_MCT = self.test(load = False, epoch = epoch)
                print("Epoch: [%2d] test PSNR original, MTC: %.6f, %.6f" % (epoch, avg_PSNR_original, avg_PSNR_MCT))
        self.save(self.config.checkpoint_dir)
    
    def test(self, name = "foliage", epoch = 0, load = True):
        result_dir = os.path.join("./samples_MCT/",str(name))
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        img_list = sorted(glob(os.path.join("/home/johnyi/deeplearning/research/VSR_Datasets/test/",str(name),"*.png")))
        
        if load == True:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
        avg_PSNR_original = 0
        avg_PSNR_MCT = 0
        xx = sorted(glob(os.path.join("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4", "*.png")))
        frame_t0 = scipy.misc.imread("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4/"+str(name)+"/001.png", mode = self.mode)
        frame_t1 = scipy.misc.imread("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4/"+str(name)+"/002.png", mode = self.mode)

        out, flow_x, flow_y = self.sess.run([self.output2, self.coarse_x2, self.coarse_y2],
                    feed_dict={ self.input2_t0: [frame_t0], self.input2_t1:[frame_t1] })
        #print("out:",np.squeeze(out[0]).shape)
        #print("flow:", flow_x, flow_y)
        #print(np.squeeze(out[0]))
        PSNR_original = calc_PSNR(frame_t0, frame_t1)
        #PSNR_MCT = calc_PSNR(frame_t0, out[0][0,:,:,:])
        PSNR_MCT = calc_PSNR(frame_t0, out[0,:,:,:])
        avg_PSNR_original += PSNR_original
        avg_PSNR_MCT += PSNR_MCT
        
        imageio.imwrite(result_dir+"/original_0.png", frame_t0)
        imageio.imwrite(result_dir+"/original_1.png", frame_t1)
        #imageio.imwrite(result_dir+"/compensated_0_"+str(epoch)+".png", np.squeeze(out[0]))
        imageio.imwrite(result_dir+"/compensated_0_"+str(epoch)+".png", out[0])
        print("PSNR original:", PSNR_original, "MCT: ", PSNR_MCT)
        return avg_PSNR_original, avg_PSNR_MCT
    
    def save(self, checkpoint_dir):
        model_name = "MCT-"+str(self.mode)
        model_dir = "%s" % (self.dataset_name)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, model_name))

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        model_dir = "%s"% (self.dataset_name)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        print("loading from ",checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            model_name = "MCT-"+str(self.mode)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, model_name))
            return True
        else:
            return False

In [7]:
tf.reset_default_graph()
sess = tf.Session()
MCT = MotionCompensationTransformer(sess, config)

output shape: (?, ?, ?, ?)
Number of variables in network: 14 , full list: [<tf.Variable 'Coarse_1/kernel:0' shape=(5, 5, 6, 24) dtype=float32_ref>, <tf.Variable 'Coarse_1/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_2/kernel:0' shape=(3, 3, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_2/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_3/kernel:0' shape=(5, 5, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_3/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_4/kernel:0' shape=(3, 3, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_4/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_5/kernel:0' shape=(3, 3, 24, 32) dtype=float32_ref>, <tf.Variable 'Coarse_5/bias:0' shape=(32,) dtype=float32_ref>, <tf.Variable 'Coarse_x/kernel:0' shape=(1, 1, 32, 1) dtype=float32_ref>, <tf.Variable 'Coarse_x/bias:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'Coarse_y/kernel:0' shape=(1, 1, 32, 1) dtype=float32_ref>, <tf.Variable 'Coarse_y/bias:0' s

In [9]:
with sess.as_default():
    MCT.train(config, False)

 Training starts from beginning
Loading videos again...
loaded video indexes: [52, 70, 14, 28, 65, 46, 31, 43, 39, 8, 61, 34, 48, 63, 20, 7, 36, 71, 3, 2, 16, 41, 42, 1, 10, 73, 37, 74, 54, 15] num_frames per video: 50 runtime:  72.0022828578949
Epoch: [ 0] [   1/ 100] time: 72.1303, loss: 188.08206177


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 21.325452844498134 MCT:  15.959287965164444
Epoch: [ 0] test PSNR original, MTC: 21.325453, 15.959288
Loading videos again...
loaded video indexes: [43, 53, 73, 14, 27, 19, 31, 72, 63, 16, 1, 0, 49, 2, 15, 67, 64, 38, 68, 65, 56, 42, 34, 17, 48, 39, 24, 29, 45, 37] num_frames per video: 50 runtime:  74.37800860404968
Epoch: [100] [   1/ 100] time: 223.0540, loss: 1359.01306152


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 21.325452844498134 MCT:  15.875836276693335
Epoch: [100] test PSNR original, MTC: 21.325453, 15.875836
Loading videos again...
loaded video indexes: [66, 13, 36, 4, 23, 46, 54, 67, 53, 33, 44, 69, 14, 72, 64, 32, 51, 12, 39, 35, 18, 57, 19, 15, 38, 26, 59, 1, 60, 75] num_frames per video: 50 runtime:  73.23020839691162
Epoch: [200] [   1/ 100] time: 372.4361, loss: 615.30114746


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 21.325452844498134 MCT:  15.919345656392693
Epoch: [200] test PSNR original, MTC: 21.325453, 15.919346


KeyboardInterrupt: 

In [8]:
with sess.as_default():
    MCT.test()

 [*] Reading checkpoints...
loading from  checkpoint/MCT/CDVL
INFO:tensorflow:Restoring parameters from checkpoint/MCT/CDVL/MCT-RGB
 [*] Load SUCCESS


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.


PSNR original: 21.325452844498134 MCT:  15.797822092797935


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


In [None]:
with sess.as_default():
    start_time = time.time()
    a,b = espcn.test(name = "Set5", load = True)
    print(time.time()-start_time)
    print("avg:",a,"bicubic:",b)