In [1]:
!pip install tf_slim

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tf_slim
  Downloading tf_slim-1.1.0-py2.py3-none-any.whl (352 kB)
[K     |████████████████████████████████| 352 kB 4.2 MB/s 
Installing collected packages: tf-slim
Successfully installed tf-slim-1.1.0


In [2]:
# training dataset path
TRAIN_FOLDER = '/content/drive/MyDrive/Colab Notebooks/DIR-D/training'

# testing dataset path
TEST_FOLDER = '/content/drive/MyDrive/Colab Notebooks/DIR-D/testing'

# GPU index
GPU = '0'

# batch size for training
TRAIN_BATCH_SIZE = 1

# batch size for testing
TEST_BATCH_SIZE = 1

# num of iterations
ITERATIONS = 100000

# checkpoints path
SNAPSHOT_DIR = "/content/drive/MyDrive/Colab Notebooks/checkpoints"

# summary path
SUMMARY_DIR = "/content/drive/MyDrive/Colab Notebooks/summary"

# define the mesh resolution
GRID_W = 8
GRID_H = 6

In [3]:
import tensorflow as tf
import numpy as np

#######################################################
# Auxiliary matrices used to solve DLT
Aux_M1 = np.array([
    [0, 0, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.float64)

Aux_M2 = np.array([
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.float64)

Aux_M3 = np.array([
    [0],
    [1],
    [0],
    [1],
    [0],
    [1],
    [0],
    [1]], dtype=np.float64)

Aux_M4 = np.array([
    [-1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, -1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, -1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, -1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float64)

Aux_M5 = np.array([
    [0, -1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, -1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, -1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, -1],
    [0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float64)

Aux_M6 = np.array([
    [-1],
    [0],
    [-1],
    [0],
    [-1],
    [0],
    [-1],
    [0]], dtype=np.float64)

Aux_M71 = np.array([
    [0, 1, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.float64)

Aux_M72 = np.array([
    [1, 0, 0, 0, 0, 0, 0, 0],
    [-1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, -1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, -1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0, 0, -1, 0]], dtype=np.float64)

Aux_M8 = np.array([
    [0, 1, 0, 0, 0, 0, 0, 0],
    [0, -1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0],
    [0, 0, 0, -1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0, -1, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 0, 0, -1]], dtype=np.float64)

Aux_Mb = np.array([
    [0, -1, 0, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, -1, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, -1, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, -1],
    [0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.float64)


########################################################

def solve_DLT(orig_pt4, pred_pt4):
    batch_size = tf.shape(input=orig_pt4)[0]
    orig_pt4 = tf.expand_dims(orig_pt4, [2])
    pred_pt4 = tf.expand_dims(pred_pt4, [2])

    # Auxiliary tensors used to create Ax = b equation
    m1 = tf.constant(Aux_M1, tf.float32)
    m1_tensor = tf.expand_dims(m1, [0])
    m1_tile = tf.tile(m1_tensor, [batch_size, 1, 1])

    m2 = tf.constant(Aux_M2, tf.float32)
    m2_tensor = tf.expand_dims(m2, [0])
    m2_tile = tf.tile(m2_tensor, [batch_size, 1, 1])

    m3 = tf.constant(Aux_M3, tf.float32)
    m3_tensor = tf.expand_dims(m3, [0])
    m3_tile = tf.tile(m3_tensor, [batch_size, 1, 1])

    m4 = tf.constant(Aux_M4, tf.float32)
    m4_tensor = tf.expand_dims(m4, [0])
    m4_tile = tf.tile(m4_tensor, [batch_size, 1, 1])

    m5 = tf.constant(Aux_M5, tf.float32)
    m5_tensor = tf.expand_dims(m5, [0])
    m5_tile = tf.tile(m5_tensor, [batch_size, 1, 1])

    m6 = tf.constant(Aux_M6, tf.float32)
    m6_tensor = tf.expand_dims(m6, [0])
    m6_tile = tf.tile(m6_tensor, [batch_size, 1, 1])

    m71 = tf.constant(Aux_M71, tf.float32)
    m71_tensor = tf.expand_dims(m71, [0])
    m71_tile = tf.tile(m71_tensor, [batch_size, 1, 1])

    m72 = tf.constant(Aux_M72, tf.float32)
    m72_tensor = tf.expand_dims(m72, [0])
    m72_tile = tf.tile(m72_tensor, [batch_size, 1, 1])

    m8 = tf.constant(Aux_M8, tf.float32)
    m8_tensor = tf.expand_dims(m8, [0])
    m8_tile = tf.tile(m8_tensor, [batch_size, 1, 1])

    mb = tf.constant(Aux_Mb, tf.float32)
    mb_tensor = tf.expand_dims(mb, [0])
    mb_tile = tf.tile(mb_tensor, [batch_size, 1, 1])

    # Form the equations Ax = b to compute H
    # Form A matrix
    a1 = tf.matmul(m1_tile, orig_pt4)  # Column 1
    a2 = tf.matmul(m2_tile, orig_pt4)  # Column 2
    a3 = m3_tile  # Column 3
    a4 = tf.matmul(m4_tile, orig_pt4)  # Column 4
    a5 = tf.matmul(m5_tile, orig_pt4)  # Column 5
    a6 = m6_tile  # Column 6
    a7 = tf.matmul(m71_tile, pred_pt4) * tf.matmul(m72_tile, orig_pt4)  # Column 7
    a8 = tf.matmul(m71_tile, pred_pt4) * tf.matmul(m8_tile, orig_pt4)  # Column 8

    # tmp = tf.reshape(a1, [-1, 8])  #batch_size * 8
    # A_mat: batch_size * 8 * 8          a1-A8相当�?*8中的每一�?
    a_mat = tf.transpose(a=tf.stack([tf.reshape(a1, [-1, 8]), tf.reshape(a2, [-1, 8]),
                                   tf.reshape(a3, [-1, 8]), tf.reshape(a4, [-1, 8]),
                                   tf.reshape(a5, [-1, 8]), tf.reshape(a6, [-1, 8]),
                                   tf.reshape(a7, [-1, 8]), tf.reshape(a8, [-1, 8])], axis=1),
                         perm=[0, 2, 1])  # BATCH_SIZE x 8 (A_i) x 8
    print('--Shape of A_mat:', a_mat.get_shape().as_list())
    # Form b matrix
    b_mat = tf.matmul(mb_tile, pred_pt4)
    print('--shape of b:', b_mat.get_shape().as_list())

    # Solve the Ax = b
    h_8el = tf.linalg.solve(a_mat, b_mat)  # BATCH_SIZE x 8.
    print('--shape of H_8el', h_8el)

    # Add ones to the last cols to reconstruct H for computing reprojection error
    h_ones = tf.ones([batch_size, 1, 1])
    h_9el = tf.concat([h_8el, h_ones], 1)
    h_flat = tf.reshape(h_9el, [-1, 9])
    # H_mat = tf.reshape(h_flat ,[-1 ,3 ,3])   # BATCH_SIZE x 3 x 3
    return h_flat

In [4]:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from keras.layers import UpSampling2D

grid_w = GRID_W
grid_h = GRID_H


def transformer(_u, mask, theta, name='SpatialTransformer'):
    def _repeat(x, n_repeats):
        with tf.compat.v1.variable_scope('_repeat'):
            rep = tf.transpose(
                a=tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), perm=[1, 0])
            rep = tf.cast(rep, 'int32')
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])

    def _interpolate(im, x, y, out_size):
        with tf.compat.v1.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(input=im)[0]
            height = tf.shape(input=im)[1]
            width = tf.shape(input=im)[2]
            channels = tf.shape(input=im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')

            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(tf.shape(input=im)[1] - 1, 'int32')
            max_x = tf.cast(tf.shape(input=im)[2] - 1, 'int32')

            # scale indices from [-1, 1] to [0, width/height]
            # x = (x + 1.0)*(width_f) / 2.0
            # y = (y + 1.0)*(height_f) / 2.0

            # do sampling
            x0 = tf.cast(tf.floor(x), 'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y), 'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)
            dim2 = width
            dim1 = width * height
            base = _repeat(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # channels dim
            im_flat = tf.reshape(im, tf.stack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            ia = tf.gather(im_flat, idx_a)
            ib = tf.gather(im_flat, idx_b)
            ic = tf.gather(im_flat, idx_c)
            i_d = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
            wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
            wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
            wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
            output = tf.add_n([wa * ia, wb * ib, wc * ic, wd * i_d])
            return output

    # input:  batch_size*(grid_h+1)*(grid_w+1)*2
    # output: batch_size*grid_h*grid_w*9
    def get_Hs(_theta, width, height):
        with tf.compat.v1.variable_scope('get_Hs'):
            num_batch = tf.shape(input=_theta)[0]
            h = height / grid_h
            w = width / grid_w
            hs = []
            for i in range(grid_h):
                for j in range(grid_w):
                    hh = i * h
                    ww = j * w
                    ori = tf.tile(tf.constant([ww, hh, ww + w, hh, ww, hh + h, ww + w, hh + h], shape=[1, 8],
                                              dtype=tf.float32), multiples=[num_batch, 1])
                    # id = i * (grid_w + 1) + grid_w
                    tar = tf.concat([tf.slice(_theta, [0, i, j, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i, j + 1, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i + 1, j, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i + 1, j + 1, 0], [-1, 1, 1, -1])], axis=1)

                    tar = tf.reshape(tar, [num_batch, 8])

                    # tar = tf.Print(tar, [tf.slice(ori, [0, 0], [1, -1])],message="[ori--i:"+str(i)+",j:"+str(j)+"]:",
                    # summarize=100,first_n=5)
                    # tar = tf.Print(tar, [tf.slice(tar, [0, 0], [1, -1])],message="[tar--i:"+str(i)+",j:"+str(j)+"]:",
                    # summarize=100,first_n=5)

                    hs.append(tf.reshape(tensorDLT_local.solve_DLT(ori, tar), [num_batch, 1, 9]))

            hs = tf.reshape(tf.concat(hs, axis=1), [num_batch, grid_h, grid_w, 9], name='Hs')
        return hs

    def _mesh_grid(height, width):

        x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                        tf.transpose(a=tf.expand_dims(tf.linspace(0., tf.cast(width, 'float32') - 1.001, width), 1),
                                     perm=[1, 0]))
        y_t = tf.matmul(tf.expand_dims(tf.linspace(0., tf.cast(height, 'float32') - 1.001, height), 1),
                        tf.ones(shape=tf.stack([1, width])))

        x_t_flat = tf.reshape(x_t, (1, -1))
        y_t_flat = tf.reshape(y_t, (1, -1))

        ones = tf.ones_like(x_t_flat)
        grid = tf.concat([x_t_flat, y_t_flat, ones], 0)

        return grid

    def _transform3(_theta, input_dim, _mask):
        with tf.compat.v1.variable_scope('_transform'):
            num_batch = tf.shape(input=input_dim)[0]
            height = tf.shape(input=input_dim)[1]
            width = tf.shape(input=input_dim)[2]
            num_channels = tf.shape(input=input_dim)[3]

            # the width/height should be an integral multiple of grid_w/grid_h
            width_float = 512.
            height_float = 384.

            _theta = tf.cast(_theta, 'float32')
            h_s = get_Hs(_theta, width_float, height_float)

            ##########################################
            print("Hs")
            print(h_s.shape)
            h_array = UpSampling2D(size=(384 / grid_h, 512 / grid_w))(h_s)
            h_array = tf.reshape(h_array, [-1, 3, 3])
            ##########################################

            out_height = height
            out_width = width
            grid = _mesh_grid(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.stack([num_batch]))  # stack num_batch grids
            grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
            print("grid")
            print(grid.shape)
            # [bs, 3, N]

            grid = tf.expand_dims(tf.transpose(a=grid, perm=[0, 2, 1]), 3)
            # [bs, 3, N] -> [bs, N, 3] -> [bs, N, 3, 1]
            grid = tf.reshape(grid, [-1, 3, 1])
            # [bs*N, 3, 1]

            grid_row = tf.reshape(grid, [-1, 3])
            print("grid_row")
            print(grid_row.shape)
            x_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 0, :], grid_row), axis=1)
            y_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 1, :], grid_row), axis=1)
            t_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 2, :], grid_row), axis=1)

            # The problem may be here as a general homo does not preserve the parallelism
            # while an affine transformation preserves it.
            t_s_flat = tf.reshape(t_s, [-1])
            t_1 = tf.ones(shape=tf.shape(input=t_s_flat))
            t_0 = tf.zeros(shape=tf.shape(input=t_s_flat))
            sign_t = tf.compat.v1.where(t_s_flat >= 0, t_1, t_0) * 2 - 1
            t_s_flat = t_s_flat + sign_t * 1e-8

            x_s_flat = tf.reshape(x_s, [-1]) / t_s_flat
            y_s_flat = tf.reshape(y_s, [-1]) / t_s_flat

            out_size = (height, width)
            input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)
            mask_transformed = _interpolate(_mask, x_s_flat, y_s_flat, out_size)

            _warp_image = tf.reshape(input_transformed, tf.stack([num_batch, height, width, num_channels]),
                                     name='output_img')
            _warp_mask = tf.reshape(mask_transformed, tf.stack([num_batch, height, width, num_channels]),
                                    name='output_mask')

            return _warp_image, _warp_mask

    with tf.compat.v1.variable_scope(name):
        # output = _transform(theta, U, out_size)
        _u = _u - 1.
        warp_image, warp_mask = _transform3(theta, _u, mask)
        warp_image = warp_image + 1.
        warp_image = tf.clip_by_value(warp_image, -1, 1)
        return warp_image, warp_mask

In [5]:
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from keras.layers import UpSampling2D


grid_w = GRID_W
grid_h = GRID_H


def transformer_feature(_u, theta, name='SpatialTransformer'):
    def _repeat_feature(x, n_repeats):
        with tf.compat.v1.variable_scope('_repeat'):
            rep = tf.transpose(
                a=tf.expand_dims(tf.ones(shape=tf.stack([n_repeats, ])), 1), perm=[1, 0])
            rep = tf.cast(rep, 'int32')
            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
            return tf.reshape(x, [-1])

    def _interpolate_feature(im, x, y, out_size):
        with tf.compat.v1.variable_scope('_interpolate'):
            # constants
            num_batch = tf.shape(input=im)[0]
            height = tf.shape(input=im)[1]
            width = tf.shape(input=im)[2]
            channels = tf.shape(input=im)[3]

            x = tf.cast(x, 'float32')
            y = tf.cast(y, 'float32')

            out_height = out_size[0]
            out_width = out_size[1]
            zero = tf.zeros([], dtype='int32')
            max_y = tf.cast(tf.shape(input=im)[1] - 1, 'int32')
            max_x = tf.cast(tf.shape(input=im)[2] - 1, 'int32')

            # scale indices from [-1, 1] to [0, width/height]
            # x = (x + 1.0)*(width_f) / 2.0
            # y = (y + 1.0)*(height_f) / 2.0

            # do sampling
            x0 = tf.cast(tf.floor(x), 'int32')
            x1 = x0 + 1
            y0 = tf.cast(tf.floor(y), 'int32')
            y1 = y0 + 1

            x0 = tf.clip_by_value(x0, zero, max_x)
            x1 = tf.clip_by_value(x1, zero, max_x)
            y0 = tf.clip_by_value(y0, zero, max_y)
            y1 = tf.clip_by_value(y1, zero, max_y)
            dim2 = width
            dim1 = width * height
            base = _repeat_feature(tf.range(num_batch) * dim1, out_height * out_width)
            base_y0 = base + y0 * dim2
            base_y1 = base + y1 * dim2
            idx_a = base_y0 + x0
            idx_b = base_y1 + x0
            idx_c = base_y0 + x1
            idx_d = base_y1 + x1

            # use indices to lookup pixels in the flat image and restore
            # channels dim
            im_flat = tf.reshape(im, tf.stack([-1, channels]))
            im_flat = tf.cast(im_flat, 'float32')
            ia = tf.gather(im_flat, idx_a)
            ib = tf.gather(im_flat, idx_b)
            ic = tf.gather(im_flat, idx_c)
            i_d = tf.gather(im_flat, idx_d)

            # and finally calculate interpolated values
            x0_f = tf.cast(x0, 'float32')
            x1_f = tf.cast(x1, 'float32')
            y0_f = tf.cast(y0, 'float32')
            y1_f = tf.cast(y1, 'float32')
            wa = tf.expand_dims(((x1_f - x) * (y1_f - y)), 1)
            wb = tf.expand_dims(((x1_f - x) * (y - y0_f)), 1)
            wc = tf.expand_dims(((x - x0_f) * (y1_f - y)), 1)
            wd = tf.expand_dims(((x - x0_f) * (y - y0_f)), 1)
            output = tf.add_n([wa * ia, wb * ib, wc * ic, wd * i_d])
            return output

    # input:  batch_size*(grid_h+1)*(grid_w+1)*2
    # output: batch_size*grid_h*grid_w*9
    def get_Hs_feature(_theta, width, height):
        with tf.compat.v1.variable_scope('get_Hs'):
            num_batch = tf.shape(input=_theta)[0]
            h = height / grid_h
            w = width / grid_w
            hs = []
            for i in range(grid_h):
                for j in range(grid_w):
                    hh = i * h
                    ww = j * w
                    ori = tf.tile(
                        tf.constant([ww, hh, ww + w, hh, ww, hh + h, ww + w, hh + h], shape=[1, 8], dtype=tf.float32),
                        multiples=[num_batch, 1])
                    # id = i * (grid_w + 1) + grid_w
                    tar = tf.concat([tf.slice(_theta, [0, i, j, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i, j + 1, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i + 1, j, 0], [-1, 1, 1, -1]),
                                     tf.slice(_theta, [0, i + 1, j + 1, 0], [-1, 1, 1, -1])], axis=1)

                    tar = tf.reshape(tar, [num_batch, 8])

                    hs.append(tf.reshape(tensorDLT_local.solve_DLT(ori, tar), [num_batch, 1, 9]))
            hs = tf.reshape(tf.concat(hs, axis=1), [num_batch, grid_h, grid_w, 9], name='Hs')
        return hs

    def _mesh_grid_feature(height, width):
        x_t = tf.matmul(tf.ones(shape=tf.stack([height, 1])),
                        tf.transpose(a=tf.expand_dims(tf.linspace(0., tf.cast(width, 'float32') - 1.001, width), 1),
                                     perm=[1, 0]))
        y_t = tf.matmul(tf.expand_dims(tf.linspace(0., tf.cast(height, 'float32') - 1.001, height), 1),
                        tf.ones(shape=tf.stack([1, width])))

        x_t_flat = tf.reshape(x_t, (1, -1))
        y_t_flat = tf.reshape(y_t, (1, -1))

        ones = tf.ones_like(x_t_flat)
        grid = tf.concat([x_t_flat, y_t_flat, ones], 0)

        return grid

    def _transform3_feature(_theta, input_dim):
        with tf.compat.v1.variable_scope('_transform'):
            num_batch = tf.shape(input=input_dim)[0]
            height = tf.shape(input=input_dim)[1]
            width = tf.shape(input=input_dim)[2]
            num_channels = tf.shape(input=input_dim)[3]

            # the width/height should be an integral multiple of grid_w/grid_h
            width_float = 32.
            height_float = 24.

            _theta = tf.cast(_theta, 'float32')
            hs = get_Hs_feature(_theta, width_float, height_float)

            ##########################################
            print("Hs")
            print(hs.shape)
            h_array = UpSampling2D(size=(24 / grid_h, 32 / grid_w))(hs)
            h_array = tf.reshape(h_array, [-1, 3, 3])
            ##########################################

            out_height = height
            out_width = width
            grid = _mesh_grid_feature(out_height, out_width)
            grid = tf.expand_dims(grid, 0)
            grid = tf.reshape(grid, [-1])
            grid = tf.tile(grid, tf.stack([num_batch]))  # stack num_batch grids
            grid = tf.reshape(grid, tf.stack([num_batch, 3, -1]))
            print("grid")
            print(grid.shape)
            # [bs, 3, N]

            grid = tf.expand_dims(tf.transpose(a=grid, perm=[0, 2, 1]), 3)
            # [bs, 3, N] -> [bs, N, 3] -> [bs, N, 3, 1]
            grid = tf.reshape(grid, [-1, 3, 1])
            # [bs*N, 3, 1]

            grid_row = tf.reshape(grid, [-1, 3])
            print("grid_row")
            print(grid_row.shape)
            x_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 0, :], grid_row), axis=1)
            y_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 1, :], grid_row), axis=1)
            t_s = tf.reduce_sum(input_tensor=tf.multiply(h_array[:, 2, :], grid_row), axis=1)

            # The problem may be here as a general homo does not preserve the parallelism
            # while an affine transformation preserves it.
            t_s_flat = tf.reshape(t_s, [-1])
            t_1 = tf.ones(shape=tf.shape(input=t_s_flat))
            t_0 = tf.zeros(shape=tf.shape(input=t_s_flat))
            sign_t = tf.compat.v1.where(t_s_flat >= 0, t_1, t_0) * 2 - 1
            t_s_flat = t_s_flat + sign_t * 1e-8

            x_s_flat = tf.reshape(x_s, [-1]) / t_s_flat
            y_s_flat = tf.reshape(y_s, [-1]) / t_s_flat

            out_size = (height, width)
            input_transformed = _interpolate_feature(input_dim, x_s_flat, y_s_flat, out_size)
            # mask_transformed = _interpolate(mask, x_s_flat, y_s_flat, out_size)

            _warp_image = tf.reshape(input_transformed, tf.stack([num_batch, height, width, num_channels]),
                                     name='output_img')

            return _warp_image

    with tf.compat.v1.variable_scope(name):
        # output = _transform(theta, U, out_size)
        # U = U - 1.
        warp_image = _transform3_feature(theta, _u)
        # warp_image = warp_image + 1.
        # warp_image = tf.clip_by_value(warp_image, -1, 1)
        return warp_image

In [6]:
import tensorflow as tf

grid_w = GRID_W
grid_h = GRID_H

min_w = (512 / grid_w) / 8
min_h = (384 / grid_h) / 8


# pixel-level loss (l_num=1 for L1 loss, l_num=2 for L2 loss, ......)
def intensity_loss(gen_frames, gt_frames, l_num):
    return tf.reduce_mean(input_tensor=tf.abs((gen_frames - gt_frames) ** l_num))


# intra-grid constraint
def intra_grid_loss(pts):
    with tf.compat.v1.name_scope('soft_mesh_loss2'):

        delta_x = pts[:, :, 0:grid_w, 0] - pts[:, :, 1:grid_w + 1, 0]
        delta_y = pts[:, 0:grid_h, :, 1] - pts[:, 1:grid_h + 1, :, 1]

        loss_x = tf.nn.relu(delta_x + min_w)
        loss_y = tf.nn.relu(delta_y + min_h)

        loss = tf.reduce_mean(input_tensor=loss_x) + tf.reduce_mean(input_tensor=loss_y)

    return loss


# inter-grid constraint
def inter_grid_loss(train_mesh):
    w_edges = train_mesh[:, :, 0:grid_w, :] - train_mesh[:, :, 1:grid_w + 1, :]
    cos_w = tf.reduce_sum(input_tensor=w_edges[:, :, 0:grid_w - 1, :] * w_edges[:, :, 1:grid_w, :], axis=3) / (
                tf.sqrt(tf.reduce_sum(input_tensor=w_edges[:, :, 0:grid_w - 1, :] * w_edges[:, :, 0:grid_w - 1, :], axis=3)) * tf.sqrt(
                    tf.reduce_sum(input_tensor=w_edges[:, :, 1:grid_w, :] * w_edges[:, :, 1:grid_w, :], axis=3)))
    print("cos_w.shape")
    print(cos_w.shape)
    delta_w_angle = 1 - cos_w

    h_edges = train_mesh[:, 0:grid_h, :, :] - train_mesh[:, 1:grid_h + 1, :, :]
    cos_h = tf.reduce_sum(input_tensor=h_edges[:, 0:grid_h - 1, :, :] * h_edges[:, 1:grid_h, :, :], axis=3) / (
                tf.sqrt(tf.reduce_sum(input_tensor=h_edges[:, 0:grid_h - 1, :, :] * h_edges[:, 0:grid_h - 1, :, :], axis=3)) * tf.sqrt(
                    tf.reduce_sum(input_tensor=h_edges[:, 1:grid_h, :, :] * h_edges[:, 1:grid_h, :, :], axis=3)))

    delta_h_angle = 1 - cos_h

    loss = tf.reduce_mean(input_tensor=delta_w_angle) + tf.reduce_mean(input_tensor=delta_h_angle)

    return loss

In [7]:
import tensorflow as tf
import tf_slim as slim
# from tensorflow.contrib.layers import conv2d
from tensorflow.nn import conv2d


grid_w = GRID_W
grid_h = GRID_H


def shift2mesh(mesh_shift, width, height):
    batch_size = tf.shape(input=mesh_shift)[0]
    h = height / grid_h
    w = width / grid_w
    ori_pt = []
    for i in range(grid_h + 1):
        for j in range(grid_w + 1):
            ww = j * w
            hh = i * h
            p = tf.constant([ww, hh], shape=[2], dtype=tf.float32)
            ori_pt.append(tf.expand_dims(p, 0))
    ori_pt = tf.concat(ori_pt, axis=0)
    ori_pt = tf.reshape(ori_pt, [grid_h + 1, grid_w + 1, 2])
    ori_pt = tf.tile(tf.expand_dims(ori_pt, 0), [batch_size, 1, 1, 1])

    tar_pt = ori_pt + mesh_shift
    # tar_pt = tf.reshape(tar_pt, [batch_size, grid_h+1, grid_w+1, 2])

    return tar_pt


def RectanglingNetwork(train_input, train_mask, width=512., height=384.):
    mesh_shift_primary, mesh_shift_final = build_model(train_input, train_mask)

    mesh_primary = shift2mesh(mesh_shift_primary, width, height)
    mesh_final = shift2mesh(mesh_shift_final + mesh_shift_primary, width, height)

    warp_image_primary, warp_mask_primary = tf_spatial_transform_local.transformer(train_input, train_mask,
                                                                                   mesh_primary)
    warp_image_final, warp_mask_final = tf_spatial_transform_local.transformer(train_input, train_mask, mesh_final)

    return mesh_primary, warp_image_primary, warp_mask_primary, mesh_final, warp_image_final, warp_mask_final


# feature extraction module
def feature_extractor(image_tf):
    feature = []
    # 512*384
    with tf.compat.v1.variable_scope('conv_block1'):
        conv1 = conv2d(inputs=image_tf, num_outputs=64, kernel_size=3, rate=1, activation_fn=tf.nn.relu)
        conv1 = conv2d(inputs=conv1, num_outputs=64, kernel_size=3, rate=1, activation_fn=tf.nn.relu)
        maxpool1 = slim.max_pool2d(conv1, 2, stride=2, padding='SAME')
    # 256*192
    with tf.compat.v1.variable_scope('conv_block2'):
        conv2 = conv2d(inputs=maxpool1, num_outputs=64, kernel_size=3, activation_fn=tf.nn.relu)
        conv2 = conv2d(inputs=conv2, num_outputs=64, kernel_size=3, activation_fn=tf.nn.relu)
        maxpool2 = slim.max_pool2d(conv2, 2, stride=2, padding='SAME')
    # 128*96
    with tf.compat.v1.variable_scope('conv_block3'):
        conv3 = conv2d(inputs=maxpool2, num_outputs=128, kernel_size=3, activation_fn=tf.nn.relu)
        conv3 = conv2d(inputs=conv3, num_outputs=128, kernel_size=3, activation_fn=tf.nn.relu)
        maxpool3 = slim.max_pool2d(conv3, 2, stride=2, padding='SAME')
    # 64*48
    with tf.compat.v1.variable_scope('conv_block4'):
        conv4 = conv2d(inputs=maxpool3, num_outputs=128, kernel_size=3, activation_fn=tf.nn.relu)
        conv4 = conv2d(inputs=conv4, num_outputs=128, kernel_size=3, activation_fn=tf.nn.relu)
        feature.append(conv4)

    return feature


# mesh motion regression module
def regression_Net(correlation):
    conv1 = conv2d(inputs=correlation, num_outputs=256, kernel_size=3, activation_fn=tf.nn.relu)
    conv1 = conv2d(inputs=conv1, num_outputs=256, kernel_size=3, activation_fn=tf.nn.relu)

    maxpool1 = slim.max_pool2d(conv1, 2, stride=2, padding='SAME')  # 16
    conv2 = conv2d(inputs=maxpool1, num_outputs=256, kernel_size=3, activation_fn=tf.nn.relu)
    conv2 = conv2d(inputs=conv2, num_outputs=256, kernel_size=3, activation_fn=tf.nn.relu)

    maxpool2 = slim.max_pool2d(conv2, 2, stride=2, padding='SAME')  # 8
    conv3 = conv2d(inputs=maxpool2, num_outputs=512, kernel_size=3, activation_fn=tf.nn.relu)
    conv3 = conv2d(inputs=conv3, num_outputs=512, kernel_size=3, activation_fn=tf.nn.relu)

    maxpool3 = slim.max_pool2d(conv3, 2, stride=2, padding='SAME')  # 4
    conv4 = conv2d(inputs=maxpool3, num_outputs=512, kernel_size=3, activation_fn=tf.nn.relu)
    conv4 = conv2d(inputs=conv4, num_outputs=512, kernel_size=3, activation_fn=tf.nn.relu)

    fc1 = conv2d(inputs=conv4, num_outputs=2048, kernel_size=[3, 4], activation_fn=tf.nn.relu, padding="VALID")
    fc2 = conv2d(inputs=fc1, num_outputs=1024, kernel_size=1, activation_fn=tf.nn.relu)
    fc3 = conv2d(inputs=fc2, num_outputs=(grid_w + 1) * (grid_h + 1) * 2, kernel_size=1, activation_fn=None)
    # net3_f = tf.expand_dims(tf.squeeze(tf.squeeze(fc3,1),1), [2])
    net3_f_local = tf.reshape(fc3, (-1, grid_h + 1, grid_w + 1, 2))

    return net3_f_local


def build_model(train_input, train_mask):
    with tf.compat.v1.variable_scope('model'):

        with tf.compat.v1.variable_scope('feature_extract', reuse=None):
            features = feature_extractor(tf.concat([train_input, train_mask], axis=3))

        feature = tf.image.resize(features[-1], [24, 32], method=0)
        with tf.compat.v1.variable_scope('regression_coarse', reuse=None):
            mesh_shift_primary = regression_Net(feature)

        with tf.compat.v1.variable_scope('regression_fine', reuse=None):
            mesh_primary = shift2mesh(mesh_shift_primary / 16, 32., 24.)
            feature_warp = tf_spatial_transform_local_feature.transformer_feature(feature, mesh_primary)
            mesh_shift_final = regression_Net(feature_warp)

        return mesh_shift_primary, mesh_shift_final

In [8]:
import tensorflow as tf
import numpy as np
from collections import OrderedDict
import sys
import os
import glob
import cv2

rng = np.random.RandomState(2017)


class DataLoader(object):
    def __init__(self, data_folder):
        self.dir = data_folder
        self.datas = OrderedDict()
        self.setup()

    def __call__(self, batch_size):
        data_info_list = list(self.datas.values())
        length = data_info_list[0]['length']

        def data_clip_generator():
            while True:
                data_clip = []
                frame_id = rng.randint(0, length - 1)
                # inputs

                input_img = np_load_frame(data_info_list[1]['frame'][frame_id], 384, 512)
                mask_img = np_load_frame(data_info_list[2]['frame'][frame_id], 384, 512)
                gt_img = np_load_frame(data_info_list[0]['frame'][frame_id], 384, 512)

                data_clip.append(input_img)
                data_clip.append(mask_img)
                data_clip.append(gt_img)
                data_clip = np.concatenate(data_clip, axis=2)

                yield data_clip

                # creating augmentations

                data_clip = []

                flipped_input = np.fliplr(input_img)
                flipped_mask = np.fliplr(mask_img)
                flipped_gt = np.fliplr(gt_img)

                data_clip.append(flipped_input)
                data_clip.append(flipped_mask)
                data_clip.append(flipped_gt)
                data_clip = np.concatenate(data_clip, axis=2)

                yield data_clip

        dataset = tf.data.Dataset.from_generator(generator=data_clip_generator, output_types=tf.float32,
                                                 output_shapes=[384, 512, 9])

        print('generator dataset, {}'.format(dataset))
        dataset = dataset.prefetch(buffer_size=128)
        dataset = dataset.shuffle(buffer_size=128).batch(batch_size)
        print('epoch dataset, {}'.format(dataset))

        return dataset

    def __getitem__(self, data_name):
        assert data_name in self.datas.keys(), 'data = {} is not in {}!'.format(data_name, self.datas.keys())
        return self.datas[data_name]

    def setup(self):
        datas = glob.glob(os.path.join(self.dir, '*'))
        for data in sorted(datas):

            if sys.platform[:3] == 'win':
                data_name = data.split('\\')[-1]
            else:
                data_name = data.split('/')[-1]

            if data_name == 'gt' or data_name == 'input' or data_name == 'mask':
                self.datas[data_name] = {}
                self.datas[data_name]['path'] = data
                self.datas[data_name]['frame'] = glob.glob(os.path.join(data, '*.jpg'))
                self.datas[data_name]['frame'].sort()
                self.datas[data_name]['length'] = len(self.datas[data_name]['frame'])

        print(self.datas.keys())

    def get_data_clips(self, index):
        batch = []
        data_info_list = list(self.datas.values())

        batch.append(np_load_frame(data_info_list[1]['frame'][index], 384, 512))
        batch.append(np_load_frame(data_info_list[2]['frame'][index], 384, 512))
        batch.append(np_load_frame(data_info_list[0]['frame'][index], 384, 512))

        return np.concatenate(batch, axis=2)


def np_load_frame(filename, resize_height, resize_width):
    image_decoded = cv2.imread(filename)

    if resize_height is not None:
        image_resized = cv2.resize(image_decoded, (resize_width, resize_height))
    else:
        image_resized = image_decoded

    image_resized = image_resized.astype(dtype=np.float32)
    image_resized = (image_resized / 127.5) - 1.0
    return image_resized


def load(saver, sess, ckpt_path):
    print(ckpt_path)
    saver.restore(sess, ckpt_path)
    print("Restored model parameters from {}".format(ckpt_path))


def save(saver, sess, logdir, step):
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    saver.save(sess, checkpoint_path, global_step=step, save_format='h5')
    print('The checkpoint has been created.')

In [9]:
import tensorflow as tf
import os
import numpy as np
import cv2 as cv

os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = GPU

test_folder = TEST_FOLDER
batch_size = TEST_BATCH_SIZE
grid_w = GRID_W
grid_h = GRID_H


def draw_mesh_on_warp(warp, f_local):
    min_w = np.minimum(np.min(f_local[:, :, 0]), 0).astype(np.int32)
    max_w = np.maximum(np.max(f_local[:, :, 0]), 512).astype(np.int32)
    min_h = np.minimum(np.min(f_local[:, :, 1]), 0).astype(np.int32)
    max_h = np.maximum(np.max(f_local[:, :, 1]), 384).astype(np.int32)
    cw = max_w - min_w
    ch = max_h - min_h

    pic = np.ones([ch + 10, cw + 10, 3], np.int32) * 255
    pic[0 - min_h + 5:0 - min_h + 384 + 5, 0 - min_w + 5:0 - min_w + 512 + 5, :] = warp

    warp = pic
    f_local[:, :, 0] = f_local[:, :, 0] - min_w + 5
    f_local[:, :, 1] = f_local[:, :, 1] - min_h + 5

    point_color = (0, 255, 0)  # BGR
    thickness = 2
    line_type = 8
    num = 1

    for i in range(grid_h + 1):
        for j in range(grid_w + 1):
            num = num + 1
            if j == grid_w and i == grid_h:
                continue
            elif j == grid_w:
                cv.line(warp, (f_local[i, j, 0], f_local[i, j, 1]), (f_local[i + 1, j, 0], f_local[i + 1, j, 1]),
                        point_color, thickness, line_type)
            elif i == grid_h:
                cv.line(warp, (f_local[i, j, 0], f_local[i, j, 1]), (f_local[i, j + 1, 0], f_local[i, j + 1, 1]),
                        point_color, thickness, line_type)
            else:
                cv.line(warp, (f_local[i, j, 0], f_local[i, j, 1]), (f_local[i + 1, j, 0], f_local[i + 1, j, 1]),
                        point_color, thickness, line_type)
                cv.line(warp, (f_local[i, j, 0], f_local[i, j, 1]), (f_local[i, j + 1, 0], f_local[i, j + 1, 1]),
                        point_color, thickness, line_type)

    return warp


snapshot_dir = '/content/drive/MyDrive/Colab Notebooks/checkpoints/pretrained_model/model.ckpt-100000'

# define dataset
with tf.compat.v1.name_scope('dataset'):
    # ----------- testing ----------- #
    tf.compat.v1.disable_eager_execution()
    test_inputs_clips_tensor = tf.compat.v1.placeholder(shape=[batch_size, None, None, 3 * 3], dtype=tf.float32)

    test_input = test_inputs_clips_tensor[..., 0:3]
    test_mask = test_inputs_clips_tensor[..., 3:6]
    test_gt = test_inputs_clips_tensor[..., 6:9]

    print('test input = {}'.format(test_input))
    print('test mask = {}'.format(test_mask))
    print('test gt = {}'.format(test_gt))

# define testing generator function
with tf.compat.v1.variable_scope('generator', reuse=None):
    print('testing = {}'.format(tf.compat.v1.get_variable_scope().name))
    test_mesh_primary, test_warp_image_primary, test_warp_mask_primary, test_mesh_final, test_warp_image_final, \
        test_warp_mask_final = RectanglingNetwork(test_input, test_mask)

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
with tf.compat.v1.Session(config=config) as sess:
    # dataset
    input_loader = DataLoader(test_folder)

    # initialize weights
    sess.run(tf.compat.v1.global_variables_initializer())
    print('Init global successfully!')

    # tf saver
    saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables(), max_to_keep=None)

    restore_var = [v for v in tf.compat.v1.global_variables()]
    loader = tf.compat.v1.train.Saver(var_list=restore_var)


    def inference_func(ckpt):
        print("============")
        print(ckpt)
        load(loader, sess, ckpt)
        print("============")
        length = 519  # len(os.listdir(test_folder+"/input"))

        for i in range(0, length):
            input_clip = np.expand_dims(input_loader.get_data_clips(i), axis=0)

            mesh_primary, warp_image_primary, warp_mask_primary, mesh_final, warp_image_final, warp_mask_final = \
                sess.run([test_mesh_primary, test_warp_image_primary, test_warp_mask_primary, test_mesh_final,
                          test_warp_image_final, test_warp_mask_final],
                         feed_dict={test_inputs_clips_tensor: input_clip})

            mesh = mesh_final[0]
            input_image = (input_clip[0, :, :, 0:3] + 1) / 2 * 255

            input_image = draw_mesh_on_warp(input_image, mesh)
            # input_mask = draw_mesh_on_warp(np.ones([384, 512, 3], np.int32)*255, mesh)

            path = "../final_mesh/" + str(i + 1).zfill(5) + ".jpg"
            cv.imwrite(path, input_image)

            # path = "../mesh_mask/" + str(i+1).zfill(5) + ".jpg"
            # cv.imwrite(path, input_mask)

            print('i = {} / {}'.format(i + 1, length))


    inference_func(snapshot_dir)

test input = Tensor("strided_slice:0", shape=(1, None, None, 3), dtype=float32)
test mask = Tensor("strided_slice_1:0", shape=(1, None, None, 3), dtype=float32)
test gt = Tensor("strided_slice_2:0", shape=(1, None, None, 3), dtype=float32)
testing = generator


TypeError: ignored

In [None]:
import numpy as np
import tensorflow as tf
import os
import cv2
import skimage
from skimage import metrics

os.environ['CUDA_DEVICES_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = GPU

test_folder = TEST_FOLDER
batch_size = TEST_BATCH_SIZE

snapshot_dir = SNAPSHOT_DIR + '/pretrained_model/model.ckpt-100000'

# define dataset
with tf.compat.v1.name_scope('dataset'):
    # --------- testing --------- #
    test_inputs_clips_tensor = tf.compat.v1.placeholder(shape=[batch_size, None, None, 3 * 3], dtype=tf.float32)

    test_input = test_inputs_clips_tensor[..., 0:3]
    test_mask = test_inputs_clips_tensor[..., 3:6]
    test_gt = test_inputs_clips_tensor[..., 6:9]

    print('test input = {}'.format(test_input))
    print('test mask = {}'.format(test_mask))
    print('test gt = {}'.format(test_gt))

# define testing generator function 
with tf.compat.v1.variable_scope('generator', reuse=None):
    print('testing = {}'.format(tf.compat.v1.get_variable_scope().name))
    test_mesh_primary, test_warp_image_primary, test_warp_mask_primary, test_mesh_final, test_warp_image_final, \
        test_warp_mask_final = RectanglingNetwork(test_input, test_mask)

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
with tf.compat.v1.Session(config=config) as sess:
    # dataset
    input_loader = DataLoader(test_folder)

    # initialize weights
    sess.run(tf.compat.v1.global_variables_initializer())
    print('Init global successfully!')

    # tf saver
    saver = tf.compat.v1.train.Saver(var_list=tf.compat.v1.global_variables(), max_to_keep=None)

    restore_var = [v for v in tf.compat.v1.global_variables()]
    loader = tf.compat.v1.train.Saver(var_list=restore_var)


    def inference_func(ckpt):
        print("============")
        print(ckpt)
        load(loader, sess, ckpt)
        print("============")
        length = 519  # len(os.listdir(test_folder+"/input"))
        psnr_list = []
        ssim_list = []

        for i in range(0, length):
            input_clip = np.expand_dims(input_loader.get_data_clips(i), axis=0)

            mesh_primary, warp_image_primary, warp_mask_primary, mesh_final, warp_image_final, warp_mask_final = \
                sess.run([test_mesh_primary, test_warp_image_primary, test_warp_mask_primary, test_mesh_final,
                          test_warp_image_final, test_warp_mask_final],
                         feed_dict={test_inputs_clips_tensor: input_clip})

            warp_image = (warp_image_final[0] + 1) * 127.5
            warp_gt = (input_clip[0, :, :, 6:9] + 1) * 127.5

            psnr = skimage.metrics.peak_signal_noise_ratio(warp_image, warp_gt, data_range=255)
            ssim = skimage.metrics.structural_similarity(warp_image, warp_gt, data_range=255, multichannel=True)

            path = "../final_rectangling/" + str(i + 1).zfill(5) + ".jpg"
            cv2.imwrite(path, warp_image)

            print('i = {} / {}, psnr = {:.6f}'.format(i + 1, length, psnr))

            psnr_list.append(psnr)
            ssim_list.append(ssim)

        print("===================Results Analysis==================")
        print('average psnr:', np.mean(psnr_list))
        print('average ssim:', np.mean(ssim_list))
        # as for FID, we use the CODE from https://github.com/bioinf-jku/TTUR to evaluate


    inference_func(snapshot_dir)