In [6]:
import numpy as np
import tensorflow as tf

import os
import scipy.misc
import numpy as np
from utils import *
from VESPCN_utils import *
from warp import *
import tensorflow as tf
from easydict import EasyDict as edict

%reload_ext autoreload
%autoreload 2

config = edict()

config.sample_dir = "samples_MCT"
config.checkpoint_dir = "checkpoint/MCT"
config.log_dir = "logs"
config.train_size = 100000000 # use large number if you have enough memory
config.valid_size = 10 # use large number if you have enough memory
config.test_size = 6400 # use large number if you have enough memory
config.batch_size = 8 # use large number if you have enough memory
config.patch_shape = [48,48,3] #[51,51,3]
config.scale = 3 #3
config.learning_rate = 1e-4
config.epoch = 1000000
config.channels = 3
config.mode = "RGB"

#config.channels = 1
#config.mode = "YCbCr"

#'''
config.dataset = "CDVL"
config.num_videos = 76

#'''
config.train = edict()
config.train.lr_init = 1e-3
config.train.lr_decay = 0.5
config.train.decay_iter = 10
config.augmentation = True

if not os.path.exists(config.checkpoint_dir):
    os.makedirs(config.checkpoint_dir)
if not os.path.exists(config.sample_dir):
    os.makedirs(config.sample_dir)
if not os.path.exists(config.log_dir):
    os.makedirs(config.log_dir)

In [7]:
def get_batch_MCT(imdb, num_frames_per_video, batch_size, patch_size, augmentation = False):
    batch_t0 = np.zeros([batch_size, patch_size[0], patch_size[1], 3], dtype = 'uint8')
    batch_t1 = np.zeros([batch_size, patch_size[0], patch_size[1], 3], dtype = 'uint8')
    resize_ratio = 1
    flags = np.zeros(len(num_frames_per_video))
    for i in range(batch_size):
        video_index = np.random.randint(len(num_frames_per_video)) #select random frame from video
        while flags[video_index] == 1:
            video_index = np.random.randint(len(num_frames_per_video)) #select random frame from video
        flags[video_index] = 1
        #video_index = 0
        frame_num = np.random.randint(num_frames_per_video[video_index])
        if augmentation == True:
            resize_ratio = np.random.rand()*0.5 + 0.5
        if frame_num == 0 or frame_num == num_frames_per_video[video_index] - 1:
            t0 = imdb[video_index][frame_num, : ,: ,:]
            t1 = imdb[video_index][frame_num, : ,: ,:]
        else:
            t0 = imdb[video_index][frame_num, : ,: ,:]
            t1 = imdb[video_index][frame_num+1, : ,: ,:]
        #print(t0.shape)
        H = np.random.randint(t0.shape[0]-int(np.ceil(patch_size[0]/resize_ratio)))
        W = np.random.randint(t0.shape[1]-int(np.ceil(patch_size[1]/resize_ratio)))
        patch_t0 = t0[H:H+int(np.ceil(patch_size[0]/resize_ratio)), W:W+int(np.ceil(patch_size[1]/resize_ratio)),:]
        patch_t0 = imresize(patch_t0, patch_size, interp = "bicubic")
        patch_t1 = t1[H:H+int(np.ceil(patch_size[0]/resize_ratio)), W:W+int(np.ceil(patch_size[1]/resize_ratio)),:]
        patch_t1 = imresize(patch_t1, patch_size, interp = "bicubic")
        batch_t0[i,:,:,:] = patch_t0
        batch_t1[i,:,:,:] = patch_t1
    return batch_t0, batch_t1

In [5]:
#imdb, num_frames_per_video = load_videos(10, 60, 50, "RGB")
#print(imdb[0].shape)
batch_t0, batch_t1 = get_batch_MCT(imdb, num_frames_per_video, 1, [1000, 1000], augmentation = False)
print(batch_t0.shape, batch_t1.shape)
imageio.imwrite("test_0.png", np.squeeze(batch_t0[0,:,:,:]))
imageio.imwrite("test_1.png", np.squeeze(batch_t1[0,:,:,:]))

(1080, 1920, 3)
(1, 1000, 1000, 3) (1, 1000, 1000, 3)


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


In [8]:
from __future__ import division
import os
import time
from glob import glob
import tensorflow as tf
import scipy.misc
from subpixel import PS
import numpy as np

from utils import *
from VESPCN_utils import *
from warp import *

class MotionCompensationTransformer(object):
    def __init__(self, sess, config):
        self.sess = sess
        self.config = config
        self.batch_size = config.batch_size
        self.valid_size = config.batch_size
        self.patch_shape = config.patch_shape
        self.input_size = int(config.patch_shape[0])
        self.dataset_name = config.dataset
        self.mode = config.mode
        self.channels = config.channels
        self.augmentation = config.augmentation
        self.checkpoint_dir = config.checkpoint_dir
        self.build_model()
        tf.global_variables_initializer().run(session=self.sess)
        #x = self.sess.run([self.output], feed_dict = {self.input_t0: np.zeros([8,48,48,self.channels]), self.input_t1: np.zeros([8,48,48,self.channels])})
        #print(x[0].shape)
        self.num_videos = config.num_videos

    def build_model(self):
        #input frames
        self.input_t0 = tf.placeholder(tf.float32, [self.batch_size, self.input_size, self.input_size, self.channels], name='input_t0') 
        self.input_t1 = tf.placeholder(tf.float32, [self.batch_size, self.input_size, self.input_size, self.channels], name='input_t1') 

        #output frame (compensated t1)
        self.output = self.network(self.input_t0, self.input_t1)
        
        #for unknown sizes
        self.input2_t0 = tf.placeholder(tf.float32, [None, 576, 720, self.channels], name='input_t0_unkown')
        self.input2_t1 = tf.placeholder(tf.float32, [None, 576, 720, self.channels], name='input_t1_unkown')
        self.output2 = self.network(self.input2_t0, self.input2_t1)

        self.loss = tf.reduce_mean(tf.square(self.input_t0-self.output)) + \
                0.01 * tf.reduce_mean(tf.sqrt(tf.add(tf.square(self.input_t0-self.output),1e-4))) #approximated Huber loss
        self.vars = tf.trainable_variables()
        print("Number of variables in network:",len(self.vars),", full list:",self.vars)
        self.optimizer = tf.train.AdamOptimizer(self.config.learning_rate).minimize(self.loss, var_list=self.vars)

        self.saver = tf.train.Saver()

    def network(self, t0, t1):
        ######## coasrse flow ######
        tmp = tf.concat([t0, t1], axis = 3) #early fusion
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Coarse_1', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_2', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Coarse_3', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Coarse_4', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 32, 3, strides = 1, padding = 'SAME', name = 'Coarse_5', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.tanh(tmp)

        coarse_x = tf.layers.conv2d(tmp, 4*4*1, 3, strides = 1, padding = 'SAME', name = 'Coarse_x', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        coarse_y = tf.layers.conv2d(tmp, 4*4*1, 3, strides = 1, padding = 'SAME', name = 'Coarse_y', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        coarse_x = PS(coarse_x, 4, color=False)
        coarse_y = PS(coarse_y, 4, color=False)
        #print("shape: ", coarse_x.shape, coarse_y.shape)
        
        ######## fine flow ######
        tmp = tf.concat([t0, t1], axis = 3) #early fusion
        tmp = tf.layers.conv2d(tmp, 24, 5, strides = 2, padding = 'SAME', name = 'Fine_1', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_2', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_3', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 24, 3, strides = 1, padding = 'SAME', name = 'Fine_4', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.relu(tmp)
        tmp = tf.layers.conv2d(tmp, 8, 3, strides = 1, padding = 'SAME', name = 'Fine_5', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        tmp = tf.nn.tanh(tmp)

        fine_x = tf.layers.conv2d(tmp, 2*2*1, 3, strides = 1, padding = 'SAME', name = 'Fine_x', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        fine_y = tf.layers.conv2d(tmp, 2*2*1, 3, strides = 1, padding = 'SAME', name = 'Fine_y', 
                               kernel_initializer = tf.contrib.layers.xavier_initializer(), reuse=tf.AUTO_REUSE)
        fine_x = PS(fine_x, 2, color=False)
        fine_y = PS(fine_y, 2, color=False)
        
        #add coarse, fine flow
        flow_x = tf.add(coarse_x, fine_x)
        flow_y = tf.add(coarse_y, fine_y)
        flow = tf.concat([flow_x, flow_y], axis = 3)
        #Warp
        out = batch_warp2d(t1, flow)
        #print("shape:", t1.shape, flow.shape, out.shape, out)
        print(out)
        return out
        
        
    def train(self, config, load = True):
        # setup train/validation data
        '''
        valid = sorted(glob(os.path.join(self.config.valid.hr_path, "*.png")))
        shuffle(valid)
        
        valid_files = valid[0:self.valid_size]
        valid = [load_image(valid_file, self.mode) for valid_file in valid_files]
        valid_LR = [doresize(xx, [self.input_size,]*2) for xx in valid]
        valid_HR = np.array(valid)
        valid_LR = np.array(valid_LR)
        if self.mode == "YCbCr":
            valid_RGB_HR =  np.copy(valid_HR)
            valid_HR = np.split(valid_RGB_HR,3, axis=3)[0]
            valid_RGB_LR = np.copy(valid_LR)
            valid_LR = np.split(valid_RGB_LR,3, axis=3)[0]
        '''
        counter = 1
        start_time = time.time()
        if load == True:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
        else:
            print(" Training starts from beginning")

        for epoch in range(self.config.epoch):
            if epoch % 100 == 0:
                print("Loading videos again...")
                self.imdb = []
                self.num_frames_per_video = []
                self.imdb, self.num_frames_per_video = load_videos(30, self.num_videos, 50, self.mode)
            batch_idxs = min(len(self.imdb), self.config.train_size) // self.config.batch_size

            #for idx in range(0, batch_idxs):
            for idx in range(0, 100):
                batch_t0, batch_t1 = get_batch_MCT(self.imdb, self.num_frames_per_video, self.batch_size, 
                                                   [self.input_size, self.input_size], augmentation = self.augmentation)
                batch_t0 = np.array(batch_t0)
                batch_t1 = np.array(batch_t1)
                if self.mode == "YCbCr":
                    batch_t0 = np.split(batch_t0, 3, axis=3)[0]
                    batch_t1 = np.split(batch_t1, 3, axis=3)[0]

                _, loss = self.sess.run([self.optimizer, self.loss],
                    feed_dict={ self.input_t0: batch_t0, self.input_t1: batch_t1 })

                counter+=1
                if idx % 500 == 1 and epoch % 100 == 0:
                    #print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f" %(epoch, idx, batch_idxs, time.time() - start_time, loss))
                    print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f" %(epoch, idx, 100, time.time() - start_time, loss))
                    self.save(self.config.checkpoint_dir)
             
            # occasional testing
            if epoch % 100 == 0:
                avg_PSNR = self.test(load = False)
                print("Epoch: [%2d] test PSNR: %.6f" % (epoch, avg_PSNR))
        self.save(self.config.checkpoint_dir)
    
    def test(self, name = "calendar", load = True):
        result_dir = os.path.join("./samples_MCT/",str(name))
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
        img_list = sorted(glob(os.path.join("/home/johnyi/deeplearning/research/VSR_Datasets/test/",str(name),"*.png")))
        
        if load == True:
            if self.load(self.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")
        avg_PSNR = 0
        xx = sorted(glob(os.path.join("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4", "*.png")))
        frame_t0 = scipy.misc.imread("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4/calendar/001.png", mode = self.mode)
        frame_t1 = scipy.misc.imread("/home/johnyi/deeplearning/research/VSR_Datasets/test/vid4/calendar/002.png", mode = self.mode)
        if self.mode == "YCbCr":
            frame_t0 = np.split(frame_t0, 3, axis=2)[0]
            frame_t1 = np.split(frame_t1, 3, axis=2)[0]
        out = self.sess.run([self.output2],
                    feed_dict={ self.input2_t0: [frame_t0], self.input2_t1:[frame_t1] })
        PSNR_original = calc_PSNR(frame_t0, frame_t1)
        imageio.imwrite(result_dir+"/original_0.png", frame_t0)
        imageio.imwrite(result_dir+"/original_1.png", frame_t1)
        imageio.imwrite(result_dir+"/compensated_0.png", np.squeeze(out[0]))
        PSNR_MCT = calc_PSNR(frame_t0, out[0][0,:,:,:])
        print("PSNR original:", PSNR_original, "MCT: ", PSNR_MCT)
        return avg_PSNR
    
    def save(self, checkpoint_dir):
        model_name = "MCT-"+str(self.mode)
        model_dir = "%s" % (self.dataset_name)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        self.saver.save(self.sess, os.path.join(checkpoint_dir, model_name))

    def load(self, checkpoint_dir):
        print(" [*] Reading checkpoints...")
        model_dir = "%s"% (self.dataset_name)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        print("loading from ",checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            model_name = "MCT-"+str(self.mode)+"-x"+str(self.scale)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, model_name))
            return True
        else:
            return False

In [9]:
tf.reset_default_graph()
sess = tf.Session()
MCT = MotionCompensationTransformer(sess, config)

Tensor("Reshape_14:0", shape=(?, ?, ?, ?), dtype=float32)
Tensor("Reshape_29:0", shape=(?, ?, ?, ?), dtype=float32)
Number of variables in network: 28 , full list: [<tf.Variable 'Coarse_1/kernel:0' shape=(5, 5, 6, 24) dtype=float32_ref>, <tf.Variable 'Coarse_1/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_2/kernel:0' shape=(3, 3, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_2/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_3/kernel:0' shape=(5, 5, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_3/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_4/kernel:0' shape=(3, 3, 24, 24) dtype=float32_ref>, <tf.Variable 'Coarse_4/bias:0' shape=(24,) dtype=float32_ref>, <tf.Variable 'Coarse_5/kernel:0' shape=(3, 3, 24, 32) dtype=float32_ref>, <tf.Variable 'Coarse_5/bias:0' shape=(32,) dtype=float32_ref>, <tf.Variable 'Coarse_x/kernel:0' shape=(3, 3, 32, 16) dtype=float32_ref>, <tf.Variable 'Coarse_x/bias:0' shape=(16,) dtype=float32_ref>, <tf.Variable '

In [5]:
with sess.as_default():
    MCT.train(config, False)

 Training starts from beginning
Loading videos again...
loaded video indexes: [73, 17, 48, 11, 51, 0, 28, 66, 14, 41, 32, 7, 50, 18, 20, 8, 25, 58, 42, 26, 62, 54, 13, 23, 64, 69, 46, 22, 49, 15] num_frames per video: 50 runtime:  76.38733911514282


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


Epoch: [ 0] [   1/ 100] time: 81.1488, loss: 16282.41796875


`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 18.480092851641135 MCT:  13.22803646817676
Epoch: [ 0] test PSNR: 0.000000
Loading videos again...
loaded video indexes: [50, 15, 57, 33, 20, 39, 11, 68, 75, 12, 53, 60, 19, 23, 13, 1, 41, 70, 45, 8, 66, 27, 2, 22, 4, 47, 74, 5, 36, 16] num_frames per video: 50 runtime:  83.97659087181091
Epoch: [100] [   1/ 100] time: 225.3637, loss: 61.19231796


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 18.480092851641135 MCT:  21.106966172511868
Epoch: [100] test PSNR: 0.000000
Loading videos again...
loaded video indexes: [33, 60, 42, 71, 70, 1, 25, 14, 20, 40, 2, 55, 8, 3, 17, 4, 30, 50, 43, 21, 64, 63, 72, 39, 68, 0, 47, 66, 31, 51] num_frames per video: 50 runtime:  78.47426629066467
Epoch: [200] [   1/ 100] time: 364.8360, loss: 57.20231247


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 18.480092851641135 MCT:  22.5350557477666
Epoch: [200] test PSNR: 0.000000
Loading videos again...
loaded video indexes: [61, 36, 62, 63, 45, 74, 30, 5, 73, 20, 41, 13, 23, 3, 51, 24, 14, 52, 50, 57, 35, 11, 75, 46, 71, 68, 54, 28, 64, 43] num_frames per video: 50 runtime:  73.54035830497742
Epoch: [300] [   1/ 100] time: 498.1088, loss: 36.39022064


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 18.480092851641135 MCT:  23.21801573015413
Epoch: [300] test PSNR: 0.000000
Loading videos again...
loaded video indexes: [57, 37, 48, 71, 9, 38, 58, 5, 19, 8, 66, 10, 68, 36, 69, 43, 50, 51, 49, 64, 61, 59, 72, 3, 73, 23, 14, 31, 13, 47] num_frames per video: 50 runtime:  74.97953581809998
Epoch: [400] [   1/ 100] time: 632.3898, loss: 4.58948565


  'range [{2}, {3}]'.format(dtype_str, out_type.__name__, mi, ma))


PSNR original: 18.480092851641135 MCT:  23.82413472657514
Epoch: [400] test PSNR: 0.000000


KeyboardInterrupt: 

In [None]:
with sess.as_default():
    start_time = time.time()
    a,b = espcn.test(name = "Set5", load = True)
    print(time.time()-start_time)
    print("avg:",a,"bicubic:",b)