In [1]:
import tensorflow as tf
import numpy as np
from model import GCNet
from dataset import Scene_Flow_disparity
from pfm_IO import read
import cv2
import os
import timeit
from PIL import Image
from random import shuffle
from random import randrange
from IPython.display import clear_output
from tensorflow.python.keras.callbacks import CSVLogger
from tensorflow.python.keras.callbacks import Callback

  from ._conv import register_converters as _register_converters


In [2]:
# Instantiate the project's Scene Flow dataset helper. Later cells read
# `dataset.data_paths` (the list of file paths) and call `dataset.read_pfm`
# to load disparity maps. The captured output below shows the constructor
# reporting which archives are already downloaded / extracted.
dataset = Scene_Flow_disparity()

>> already download flyingthings3d__frames_cleanpass.tar of Scene Flow Datasets
>> already extracted flyingthings3d__frames_cleanpass of Scene Flow Datasets
>> already download driving__frames_cleanpass.tar of Scene Flow Datasets
>> already extracted driving__frames_cleanpass of Scene Flow Datasets
>> already download monkaa__frames_cleanpass.tar of Scene Flow Datasets
>> already extracted monkaa__frames_cleanpass of Scene Flow Datasets
>> already download flyingthings3d__frames_finalpass.tar of Scene Flow Datasets
>> already extracted flyingthings3d__frames_finalpass of Scene Flow Datasets
>> already download driving__frames_finalpass.tar of Scene Flow Datasets
>> already extracted driving__frames_finalpass of Scene Flow Datasets
>> already download monkaa__frames_finalpass.tar of Scene Flow Datasets
>> already extracted monkaa__frames_finalpass of Scene Flow Datasets
>> already download flyingthings3d__disparity.tar.bz2 of Scene Flow Datasets
>> already extracted flyingthings3d__disp

In [3]:
# --- GPU selection -----------------------------------------------------------
# Set both CUDA environment variables in one call: device ordering follows the
# PCI bus id, and only device "0" is made visible to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# --- Training configuration --------------------------------------------------
num_of_gpu = 1                       # forwarded to GCNet and used to scale the batch size
dataset_img_size = (540, 960, 3)     # indexed as (rows, cols, channels) when cropping frames
max_disp = 192                       # forwarded to GCNet as `disp_range`
t_v_data_rate = 0.8                  # fraction of the shuffled paths used for training
train_batch_size = 1 * num_of_gpu    # one sample per GPU
input_img_size = (256, 512, 3)       # (height, width, depth) of the crops fed to the network
learning_rate = 0.1 ** 4

In [4]:
# Configure the GCNet wrapper from the constants defined in the config cell.
# `input_img_size` is unpacked positionally as (height, width, depth).
model = GCNet(img_height=input_img_size[0],
              img_width=input_img_size[1],
              img_depth=input_img_size[2],
              disp_range=max_disp,
              learning_rate=learning_rate,
              num_of_gpu=num_of_gpu)

# `net` is the object the later cells train (`fit_generator`) and
# checkpoint (`save_weights`).
net = model.inference()

input image resized by (height = 256, width = 512)
Tensor("CostVolume_left/stack:0", shape=(?, 96, 128, 256, 64), dtype=float32)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
l_img (InputLayer)              (None, 256, 512, 3)  0                                            
__________________________________________________________________________________________________
r_img (InputLayer)              (None, 256, 512, 3)  0                                            
__________________________________________________________________________________________________
model_1 (Model)                 (None, 128, 256, 32) 159648      l_img[0][0]                      
                                                                 r_img[0][0]                      
_______________________________________________________________________________

In [None]:
# Both the CSV training log and the weight checkpoints are written below
# './checkpoints/'. Create the directory up front so a fresh checkout does not
# fail on a missing path once training starts.
if not os.path.exists('./checkpoints/'):
    os.makedirs('./checkpoints/')

# Append to (rather than overwrite) the training log, using ';' as separator.
csv_logger = CSVLogger('./checkpoints/GCNet_training_log.csv', append=True, separator=';')

# NOTE: this instance is deliberately kept under the name `TensorBoard`
# because the training cell passes it to `fit_generator` by that name.
TensorBoard = tf.contrib.keras.callbacks.TensorBoard(
    log_dir="TensorBoard/GCNetTensorBoard/",
    histogram_freq = 0,
    write_graph=True, 
    write_images=True)
class batchtime(Callback):
    """Keras callback that prints the wall-clock duration of each batch.

    The notebook cell output is cleared before every print, so only the
    timing of the most recently finished batch stays visible.
    """

    def __init__(self):
        # Run the base Callback initializer so the attributes Keras manages
        # on callbacks exist before training starts.
        super(batchtime, self).__init__()
        self.start_time = 0
        self.end_time = 0

    def on_batch_begin(self, batch, logs=None):
        """Record the timestamp at which the batch starts."""
        self.start_time = timeit.default_timer()

    def on_batch_end(self, batch, logs=None):
        """Record the end timestamp and print the elapsed batch time."""
        self.end_time = timeit.default_timer()
        clear_output()
        print('batch_step_time = %.3f' % (self.end_time - self.start_time))
        
class WeightsSaver(Callback):
    """Keras callback that saves the model weights every ``N`` epochs.

    Checkpoints are written to './checkpoints/GCNet(<k>th).hdf5', where <k>
    is this callback's own zero-based count of finished epochs.

    Args:
        N: save period in epochs; weights are saved when the internal epoch
            counter is a multiple of ``N`` (so the first epoch always saves).
    """

    def __init__(self, N):
        # Run the base Callback initializer so the attributes Keras manages
        # on callbacks exist before training starts.
        super(WeightsSaver, self).__init__()
        self.N = N
        self.epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        """Save the weights if this epoch falls on the save period."""
        if self.epoch % self.N == 0:
            self.model.save_weights('./checkpoints/GCNet(%dth).hdf5' % self.epoch)
        # Count every finished epoch, not only the ones that were saved.
        # Incrementing inside the branch above would freeze the counter at 1
        # for any N > 1, so no later checkpoint would ever be written.
        self.epoch += 1
        # NOTE(review): a disabled (string-quoted) test-set visualisation block
        # used to live here. It referenced `plt`, which this notebook never
        # imports, and was removed as dead code.

def _disparity_path(image_path):
    """Map a Scene Flow frame path to the matching disparity path.

    Only paths containing '<dataset>__frames_cleanpass' or
    '<dataset>__frames_finalpass' (dataset in driving / flyingthings3d /
    monkaa) are rewritten; any other path is returned unchanged. The file
    extension is left untouched.
    """
    for dataset_name in ("driving", "flyingthings3d", "monkaa"):
        for pass_name in ("frames_cleanpass", "frames_finalpass"):
            if dataset_name + "__" + pass_name in image_path:
                # Replacing the pass name everywhere also rewrites the
                # '<dataset>__<pass>' archive directory component.
                return image_path.replace(pass_name, "disparity")
    return image_path


def _read_color_image(image_path):
    """Read a colour image with OpenCV, failing loudly if it cannot be loaded."""
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("could not read image: %s" % image_path)
    return image


def trainDataGenerator(data_paths, target_input_size = (model.model_in_width, model.model_in_height), 
                       target_output_size = (model.model_in_width, model.model_in_height), batch_size = 1, suffle = True):
    """Endlessly yield batches of randomly cropped stereo pairs and disparities.

    For every left-view frame path, the right-view frame and both disparity
    maps are located by string substitution on the path. The same random crop
    window (sized by the global `input_img_size`, positioned inside the global
    `dataset_img_size`) is applied to all four arrays.

    Args:
        data_paths: iterable of left-view '.png' frame paths.
        target_input_size: unused; kept for backward compatibility.
        target_output_size: unused; kept for backward compatibility.
        batch_size: number of samples per yielded batch (must be >= 1).
        suffle: when true, the paths are reshuffled (on a private copy)
            before every pass over the data.

    Yields:
        Tuple ``([left_images, right_images], [left_disparity, right_disparity])``
        where each element is a numpy array stacking `batch_size` crops.

    Raises:
        ValueError: if `data_paths` is empty or `batch_size` is below 1.
        IOError: if a frame image cannot be read.
    """
    paths = list(data_paths)  # private copy: never reorder the caller's list
    if not paths:
        raise ValueError("data_paths is empty; nothing to generate")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    crop_h, crop_w = input_img_size[0], input_img_size[1]
    # Largest valid top-left offsets; +1 below because randrange excludes its stop.
    max_y = dataset_img_size[0] - crop_h
    max_x = dataset_img_size[1] - crop_w

    left_images, right_images = [], []
    left_disparities, right_disparities = [], []
    while True:
        if suffle:
            shuffle(paths)
        for left_path in paths:
            x = randrange(0, max_x + 1)
            y = randrange(0, max_y + 1)
            rows = slice(y, y + crop_h)
            cols = slice(x, x + crop_w)

            left_image = _read_color_image(left_path)
            right_image = _read_color_image(left_path.replace("left", "right"))

            left_disp_path = _disparity_path(left_path).replace(".png", ".pfm")
            left_disparity = dataset.read_pfm(left_disp_path)
            right_disparity = dataset.read_pfm(left_disp_path.replace("left", "right"))

            left_images.append(left_image[rows, cols, :])
            right_images.append(right_image[rows, cols, :])
            left_disparities.append(left_disparity[rows, cols])
            right_disparities.append(right_disparity[rows, cols])

            if len(left_disparities) == batch_size:
                batch_input_x = [np.array(left_images), np.array(right_images)]
                batch_input_y = [np.array(left_disparities), np.array(right_disparities)]
                # Start fresh lists so the yielded batch is not mutated later.
                left_images, right_images = [], []
                left_disparities, right_disparities = [], []

                yield (batch_input_x, batch_input_y)

# Collect every left-view PNG frame; the generator derives the right view and
# the disparity files from these paths.
directories = [i for i in dataset.data_paths if ('left' in i) and ('.png' in i)]
shuffle(directories)

# Split once so the train and validation slices can never overlap or leave a gap.
split_index = int(len(directories)*t_v_data_rate)

train_paths = directories[:split_index]
train_generator = trainDataGenerator(data_paths = train_paths, batch_size = train_batch_size)

validation_paths = directories[split_index:]
validation_generator = trainDataGenerator(data_paths = validation_paths, batch_size = train_batch_size)

# Step counts must be integers: use floor division (true division yields a
# float on Python 3) and run at least one step per epoch.
net.fit_generator(train_generator,
                  epochs = 3,
                  steps_per_epoch = max(1, len(train_paths) // train_batch_size),
                  validation_data = validation_generator,
                  validation_steps = max(1, len(validation_paths) // train_batch_size),
                  verbose = 1,
                  callbacks=[TensorBoard, csv_logger, WeightsSaver(1)])

Epoch 1/3
  231/63718 [..............................] - ETA: 56:15:55 - loss: 112.1648 - model_2_loss: 66.7017

In [None]:
# Save the network's current weights to './checkpoints/GCNet.hdf5'
# (separate from the per-epoch files written by WeightsSaver).
net.save_weights('./checkpoints/GCNet.hdf5')