In [1]:
import numpy as np
import tensorflow as tf
import os
import sys
import glob
from random import shuffle
import cv2
from skimage import color
from bilinear_sampler import bilinear_sampler
from stn import spatial_transformer_network as transformer
import matplotlib.pyplot as plt


#Image path
INPUT_PATH = 'flower/Flowers_8bit/*.png'
TEST_PATH = 'flower_test/*.png'

#LF Parameter
ANGULAR_RES_X = 14;
ANGULAR_RES_Y = 14;
ANGULAR_RES_TARGET = 8*8
IMG_WIDTH = 3584
IMG_HEIGHT = 2688
SPATIAL_HEIGHT = int(IMG_HEIGHT / ANGULAR_RES_Y)
SPATIAL_WIDTH = int(IMG_WIDTH / ANGULAR_RES_X)
CH_INPUT = 3
CH_OUTPUT = 3
SHIFT_VALUE = 1.0

#Training parameter
BATCH_SIZE = 1
TRAIN_SIZE = 1.0
LR_G =  0.001 # 0.001 0.0005 0.00146 0.0002
EPOCH = 180000
DECAY_STEP = 5000
LAMBDA_L1 = 100.0
LAMBDA_PIXEL = 50.0
LAMBDA_TV = 1e-7


#LF INDEX
# +---+----+----+----+----+----+----+----+
# | 0 | 8  | 16 | 24 | 32 | 40 | 48 | 56 |
# +---+----+----+----+----+----+----+----+
# | 1 | 9  | 17 | 25 | 33 | 41 | 49 | 57 |
# +---+----+----+----+----+----+----+----+
# | 2 | 10 | 18 | 26 | 34 | 42 | 50 | 58 |
# +---+----+----+----+----+----+----+----+
# | 3 | 11 | 19 | 27 | 35 | 43 | 51 | 59 |
# +---+----+----+----+----+----+----+----+
# | 4 | 12 | 20 | 28 | 36 | 44 | 52 | 60 |
# +---+----+----+----+----+----+----+----+
# | 5 | 13 | 21 | 29 | 37 | 45 | 53 | 61 |
# +---+----+----+----+----+----+----+----+
# | 6 | 14 | 22 | 30 | 38 | 46 | 54 | 62 |
# +---+----+----+----+----+----+----+----+
# | 7 | 15 | 23 | 31 | 39 | 47 | 55 | 63 |
# +---+----+----+----+----+----+----+----+

# UTILITIES

In [2]:
#Shift pixels
def tf_image_translate(images, tx, ty, interpolation='BILINEAR'):
    # got these parameters from solving the equations for pixel translations
    # on https://www.tensorflow.org/api_docs/python/tf/contrib/image/transform
    
    #+tx -> shift to left +ty ->shift up
    transforms = [1, 0, tx, 0, 1, ty, 0, 0]
    return tf.contrib.image.transform(images, transforms, interpolation)
    
    #+tx -> shift to right +ty ->shift down
    #translate = [-tx, -ty]
    #return tf.contrib.image.translate(images, translate, interpolation)

def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1
    
def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2
    
#Input raw png LF, output center LF, 8x8 grid LF, and stacked LF in channel axis
def process_LF(lf):    
    full_LF_crop = np.zeros((SPATIAL_HEIGHT, SPATIAL_WIDTH, 3, ANGULAR_RES_Y, ANGULAR_RES_X))
    for ax in range(ANGULAR_RES_X):
        for ay in range(ANGULAR_RES_Y):
            resized = lf[ay::ANGULAR_RES_Y, ax::ANGULAR_RES_X, :]
            resized2 = cv2.resize(resized, dsize=(SPATIAL_WIDTH, SPATIAL_HEIGHT), interpolation=cv2.INTER_LINEAR)
            full_LF_crop[:, :, :, ay, ax] = resized2
            
    #Take 8x8 LF on the middle, since the side part suffer from vignetting
    middle_LF = full_LF_crop[:, :, :, 3:11, 3:11] # Take 8x8 LF in 5D
    
    #To visualize the 8x8 LF
    for ax in range(8):
        for ay in range(8):
            if ay == 0:
                y_img = middle_LF[:,:,:,ay,ax]
            else:
                y_img = np.concatenate((y_img, middle_LF[:,:,:,ay,ax]), 0)
            
            if ax == 0 and ay ==0:
                LF_stack = middle_LF[:,:,:,0,0]
            else:
                LF_stack = np.concatenate((LF_stack, middle_LF[:,:,:,ay,ax]), 2)
        if ax == 0:
            LF_grid = y_img
        else:
            LF_grid = np.concatenate((LF_grid, y_img), 1)
        y_img = middle_LF[:,:,:,ay,ax]
    
    center_view = middle_LF[:,:,:,3,3]
    return center_view, LF_stack, LF_grid


# INPUT PIPELINE

In [3]:
with tf.name_scope('Input_Pipeline'):
    #Augment input
    
    #TRAIN CASE
    gamma_val = tf.random_uniform(shape=[], minval=0.4, maxval=1.0) 
    #TEST CASE
    #gamma_val = tf.random_uniform(shape=[], minval=0.4, maxval=0.5)  
    
    #X Single RGB image
    tf_x = tf.placeholder(tf.float32, [None, SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_INPUT], name='Input')
    tf_x = tf.image.adjust_gamma(tf_x, gamma_val)
    view_image = tf.summary.image('input', tf.reshape(tf_x, [-1, SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_INPUT]), 1)
    image = tf.reshape(tf_x, [-1, SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_INPUT], name='img_x')# (batch, height, width, channel)
    image_min = preprocess(image) #-1..1
    
    #LF GT in grid style for visualization purpose only
    tf_grid = tf.placeholder(tf.float32, [None, SPATIAL_HEIGHT*8, SPATIAL_WIDTH*8, CH_OUTPUT], name='Grids')
    tf_grid = tf.image.adjust_gamma(tf_grid, gamma_val)
    label_image = tf.summary.image('GT', tf.reshape(tf_grid, [-1, SPATIAL_HEIGHT*8, SPATIAL_WIDTH*8, CH_OUTPUT]), 1)
    
    #Y LF GT stacked in channel direction for loss
    tf_y = tf.placeholder(tf.float32, [None, SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_OUTPUT*64], name='Target')
    tf_y = tf.image.adjust_gamma(tf_y, gamma_val)
    color_norm = tf.reshape(tf_y, [-1, SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_OUTPUT*64], name='img_y')# (batch, height, width, channel)
    color_norm_min = preprocess(color_norm)



# PREPARE DATA


In [4]:
#Wrapper function
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def load_image(addr):
    # cv2 load images as BGR, convert it to RGB
    img = cv2.imread(addr)
    if img is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    center_view, GT, grid = process_LF(img)
    
    center_view = np.uint8(center_view)
    grid = np.uint8(grid)
    GT = np.uint8(GT)
    
    return center_view, GT, grid

# DATASET RECORD

In [5]:
def createDataRecord(out_filename, addrs):
    # open the TFRecords file
    writer = tf.python_io.TFRecordWriter(out_filename)
    for i in range(len(addrs)):
        # print how many images are loaded every # images
        if not i % 10:
            print('Train data: {}/{} images'.format(i, len(addrs)))
            sys.stdout.flush()
        # Load the image
        img, label, grid = load_image(addrs[i]) 
        
        if img is None:
            continue

        if label is None:
            continue
            
        if grid is None:
            continue

        # Create a feature
        feature = {
            'image_raw': _bytes_feature(img.tostring()),
            'label': _bytes_feature(label.tostring()),
            'grid': _bytes_feature(grid.tostring())
        }
        # Create an example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        
        # Serialize to string and write on the file
        writer.write(example.SerializeToString())
        
    writer.close()
    sys.stdout.flush()

# CREATE DATA FOR TRAINING

In [6]:
# %%time
# with tf.name_scope('Data_Folder_Read'):
#     input_path = INPUT_PATH
#     addrs = sorted(glob.glob(input_path))
    
# with tf.name_scope('Shuffle_Data'):
#     # to shuffle data
#     c = list(addrs)
#     shuffle(c)
#     addrs = c

# with tf.name_scope('Create_Datarecord_Train'):
#     # Divide the data into % train and % test
#     #train_addrs = addrs[0:int(TRAIN_SIZE*len(addrs))]
#     #createDataRecord('train.tfrecords', train_addrs)
    
#     train_addrs = addrs[0:int(TRAIN_SIZE*len(addrs))]
#     createDataRecord('train.tfrecords', train_addrs)

# NETWORK STRUCTURE [SYNTHESIS]

In [7]:
#Return shifting value based on the angular coordinate
def shift_value(i):
    if i<=7:
        tx = 3*SHIFT_VALUE
    elif i>7 and i<=15:
        tx = 2*SHIFT_VALUE
    elif i>15 and i<=23:
        tx = 1*SHIFT_VALUE
    elif i>23 and i<=31:
        tx = 0
    elif i>31 and i<=39:
        tx = -1*SHIFT_VALUE
    elif i>39 and i<=47:
        tx = -2*SHIFT_VALUE
    elif i>47 and i<=55:
        tx = -3*SHIFT_VALUE
    else:
        tx = -4*SHIFT_VALUE
    
    if i==0 or (i%8==0 and i>7):
        ty = 3*SHIFT_VALUE
    elif i == 1 or (i-1)%8==0:
        ty = 2*SHIFT_VALUE
    elif i == 2 or (i-2)%8==0:
        ty = 1*SHIFT_VALUE
    elif i == 3 or (i-3)%8==0:
        ty = 0
    elif i == 4 or (i-4)%8==0:
        ty = -1*SHIFT_VALUE
    elif i == 5 or (i-5)%8==0:
        ty = -2*SHIFT_VALUE
    elif i == 6 or (i-6)%8==0:
        ty = -3*SHIFT_VALUE
    else:
        ty = -4*SHIFT_VALUE
        
    return tx, ty

def add_layer(input_=None, rate=1):
    c = tf.nn.relu(input_)
    c = tf.layers.conv2d(c, 12, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal(), dilation_rate=rate)
    return tf.concat([input_, c], -1)

def transition(input_=None):
    shape = input_.get_shape().as_list()
    filters = shape[-1]
    c = tf.nn.relu(input_)
    c = tf.layers.conv2d(c, filters, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal())
    #No average pooling
    return c

def flow_layer(input_=None):
    shape = input_.get_shape().as_list()
    filters = shape[-1]
    c = tf.nn.relu(input_)
    c = tf.layers.conv2d(c, filters, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal())
    c = tf.layers.conv2d(c, filters, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal())    
    c = tf.layers.conv2d(c, (ANGULAR_RES_TARGET)*2, 3, padding='SAME', activation=None, kernel_initializer=tf.keras.initializers.he_normal())
    return c
    
def LF_Synthesis(input_=None):
    with tf.name_scope('Initial_Conv'):
        conv = tf.layers.conv2d(input_, 8, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal(), name='INITIAL_CONV_1')
        conv = tf.layers.conv2d(conv, 16, 3, padding='SAME', activation=tf.nn.relu, kernel_initializer=tf.keras.initializers.he_normal(), name='INITIAL_CONV_3')
        conv_prob = conv
        
    with tf.name_scope('Flow_Generator'): #Dense Net
        with tf.name_scope('Block_1'):
            for i in range(3):
                conv = add_layer(conv, 2)
            conv = transition(conv)

        with tf.name_scope('Block_2'):
            for i in range(3):
                conv = add_layer(conv, 4)
            conv = transition(conv)

        with tf.name_scope('Block_3'):
            for i in range(3):
                conv = add_layer(conv, 8)
            conv = transition(conv)

        with tf.name_scope('Block_4'):
            for i in range(3):
                conv = add_layer(conv,16)
            conv = transition(conv)

        with tf.name_scope('Flow'):
            flow_LF = flow_layer(conv)
        
    ###################################################################################################

    #Synthesize LF by element wise multiplication with flow
    with tf.name_scope('Estimation_Layer'):
        yuv = tf.image.rgb_to_yuv(image)
        y = yuv[:,:,:,0:1]
        y = preprocess(y)
        for i in range(ANGULAR_RES_TARGET):
         
            tx, ty = shift_value(i)
            image_shift = tf_image_translate(image_min, tx, ty)
            y_shift = tf_image_translate(y, tx, ty)
            
            if i==0:
                pred_LF = bilinear_sampler(image_shift, flow_LF[:, :, :, i*2:(i*2)+2])
                
                pred_LF_loss = bilinear_sampler(y_shift, flow_LF[:, :, :, i*2:(i*2)+2])
            elif i==27: #Input
                pred_LF = tf.concat((pred_LF, image_min), -1)
                
                pred_LF_loss = tf.concat((pred_LF_loss, y), -1)
            else:
                trans_image = bilinear_sampler(image_shift, flow_LF[:, :, :, i*2:(i*2)+2])
                pred_LF = tf.concat((pred_LF, trans_image), -1)
                
                trans_image = bilinear_sampler(y_shift, flow_LF[:, :, :, i*2:(i*2)+2])
                pred_LF_loss = tf.concat((pred_LF_loss, trans_image), -1)  
                
    print('pred_LF',pred_LF.shape)           
    print('pred_LF_loss',pred_LF_loss.shape) 
    return pred_LF, pred_LF_loss, flow_LF
    

# CREATE NETWORK

In [8]:
with tf.name_scope('View_Synthesis'):
    with tf.name_scope('Main_Network'):
        #image_norm = tf.divide(image, 255)
        yuv = tf.image.rgb_to_yuv(image)
        Y = tf.summary.image('Y', tf.reshape(yuv[:,:,:,0:1], [-1, SPATIAL_HEIGHT, SPATIAL_WIDTH, 1]), 1)
        y = preprocess(yuv[:,:,:,0:1])
        pred_LF, pred_LF_loss, flow_LF = LF_Synthesis(y)
        pred_LF_norm = deprocess(pred_LF)
    #####################################################################################
    with tf.name_scope('EPI_Slicing'):
        
        #Convert GT LF into Luma GT
        for i in range(ANGULAR_RES_TARGET):  
            temp = tf.image.rgb_to_yuv(color_norm[:,:,:,i*3:(i*3)+3])
            if i == 0:
                y_GT = temp[:,:,:,0:1]
            else:
                y_GT = tf.concat([y_GT, temp[:,:,:,0:1]], -1)
        y_GT = preprocess(y_GT)    
        
        #Cross spatial EPI slice
        center_width = int(SPATIAL_WIDTH/2)
        center_height = int(SPATIAL_HEIGHT/2)

        slice_epi_H = pred_LF_loss[:, center_height:center_height+1, :, :]
        slice_epi_GT_H = y_GT[:, center_height:center_height+1, :, :]

        slice_epi_V = pred_LF_loss[:, :, center_width:center_width+1, :]
        slice_epi_GT_V = y_GT[:, :, center_width:center_width+1, :]

        #Because of the stack is in row order for horizontal EPI it cannot be directly reshaped
        for j in range(8):
            for i in range(8):
                temp = slice_epi_H[:, :, :, (i*8)+(j):(i*8)+(j)+1]
                temp2 = slice_epi_GT_H[:, :, :, (i*8)+(j):(i*8)+(j)+1]
                temp3 = pred_LF_loss[:, :, :, (i*8)+(j):(i*8)+(j)+1]
                temp4 = y_GT[:, :, :, (i*8)+(j):(i*8)+(j)+1]
                
                if i==0 and j==0:
                    epi_synth_H = temp
                    epi_GT_H = temp2
                    pred_LF_loss_H = temp3
                    y_GT_H = temp4

                else:
                    epi_synth_H = tf.concat([epi_synth_H,temp], 1)
                    epi_GT_H = tf.concat([epi_GT_H,temp2], 1)
                    pred_LF_loss_H = tf.concat([pred_LF_loss_H,temp3], -1)
                    y_GT_H = tf.concat([y_GT_H,temp4], -1)

        
        #Vertical EPI fit the image stack row order and can be directly obtained with just reshape
        epi_synth_V = tf.reshape(slice_epi_V, [-1,SPATIAL_HEIGHT, 64, 1])
        epi_GT_V = tf.reshape(slice_epi_GT_V, [-1,SPATIAL_HEIGHT, 64, 1])
        
        print('pred_LF_loss_H', pred_LF_loss_H.shape)
        print('y_GT_H', y_GT_H.shape)
        #####################################################################################

pred_LF (?, 192, 256, 192)
pred_LF_loss (?, 192, 256, 64)
pred_LF_loss_H (?, 192, 256, 64)
y_GT_H (?, 192, 256, 64)


# LOSS

In [9]:
######################################################################################
with tf.name_scope('Pixel_Based_Loss'):
    #EPI based loss in horizontal and vertical direction sampled every 8 angular images
    with tf.name_scope('L_Loss'):
        pixel_loss_V = pixel_loss_H = 0
        for i in range(8):  
            temp1 = tf.reduce_sum(y_GT[:,:,:,(i*8):(i*8)+8], axis=-1)
            temp2 = tf.reduce_sum(pred_LF_loss[:,:,:,(i*8):(i*8)+8], axis=-1)
            pixel_loss_V += tf.losses.absolute_difference(temp1, temp2)
            
            temp3 = tf.reduce_sum(y_GT_H[:,:,:,(i*8):(i*8)+8], axis=-1)
            temp4 = tf.reduce_sum(pred_LF_loss_H[:,:,:,(i*8):(i*8)+8], axis=-1)
            pixel_loss_H += tf.losses.absolute_difference(temp3, temp4)
            
        tf.summary.scalar('pixel_loss_V', pixel_loss_V)
        tf.summary.scalar('pixel_loss_H', pixel_loss_H)
        
        pixel_wise_loss = tf.losses.mean_squared_error(y_GT, pred_LF_loss)
        tf.summary.scalar('pixel_wise_loss', pixel_wise_loss)
    # Total variation loss for flow to surpress amount of artifact and smooth flow
    with tf.name_scope('Total_Variation_Loss'):    
        tv_loss_x = tf.reduce_mean(tf.image.total_variation(flow_LF[:,:,:,0::2]))
        tv_loss_y = tf.reduce_mean(tf.image.total_variation(flow_LF[:,:,:,1::2]))
        tv_loss = tf.reduce_mean(tv_loss_x + tv_loss_y)
        tf.summary.scalar('TV_loss', tv_loss)

    #Cross EPI loss. TODO: Can be removed redundant with EPI based loss
#     with tf.name_scope('EPI_Loss'): 
#         epi_loss_H = tf.losses.absolute_difference(epi_GT_H, epi_synth_H)
#         epi_loss_V = tf.losses.absolute_difference(epi_GT_V, epi_synth_V)
#         epi_loss_total = epi_loss_H + epi_loss_V
#         tf.summary.scalar('EPI_loss', epi_loss_total)
        
    
with tf.name_scope('Total_Loss'):
    Total_Loss = (LAMBDA_L1 * pixel_loss_V) + (LAMBDA_L1 * pixel_loss_H) \
                + (LAMBDA_TV * tv_loss) + (LAMBDA_PIXEL * pixel_wise_loss)
    tf.summary.scalar('Total_Loss', Total_Loss)
    y_img = pred_LF[:,:,:,0:3]
    
    
    #Reshape the output into a grid LF
    for i in range(1,ANGULAR_RES_TARGET):
        if i==8:
            grid_LF = y_img
            y_img = pred_LF[:,:,:,i*3:(i*3)+3]
                
        elif i%8==0 and i>8:
            grid_LF = tf.concat([grid_LF, y_img], 2)
            y_img = pred_LF[:,:,:,i*3:(i*3)+3]

        elif i == 63:
            y_img = tf.concat([y_img, pred_LF[:,:,:,i*3:(i*3)+3]], 1)
            grid_LF = tf.concat([grid_LF, y_img], 2)

        else:
            y_img = tf.concat([y_img, pred_LF[:,:,:,i*3:(i*3)+3]], 1)

###################################################################################### 

with tf.name_scope('Evaluation'):
    psnr = tf.image.psnr(color_norm, pred_LF_norm, max_val=1.0)
    tf.summary.scalar('PSNR', psnr[0])
    
    ssim = tf.image.ssim(color_norm, pred_LF_norm, max_val=1.0)
    tf.summary.scalar('SSIM', ssim[0])

# TRAIN OP

In [10]:
with tf.name_scope('Train'):
    global_step_G = tf.Variable(0, dtype=tf.float32)
    learning_rate_G = tf.train.exponential_decay(
                      LR_G,                  # Base learning rate.
                      global_step_G,         # Current index into the dataset.
                      DECAY_STEP,            # Decay step.
                      0.90,                  # Decay rate.
                      staircase=True)
    tf.summary.scalar('LR_G', learning_rate_G)
    train_G = tf.train.AdamOptimizer(learning_rate=learning_rate_G, name='optimizer_G').minimize(Total_Loss, global_step=global_step_G)


# VISUALIZATION

In [11]:
with tf.name_scope('Visualization'):     
    #Open the predicted LF into grid for EPI visualization
    grid_LF = tf.zeros_like(image)
    y_img = pred_LF[:,:,:,0:3]
    
    for i in range(1,ANGULAR_RES_TARGET):
        if i==8:
            grid_LF = y_img
            y_img = pred_LF[:,:,:,i*3:(i*3)+3]

                  
        elif i%8==0 and i>8:
            grid_LF = tf.concat([grid_LF, y_img], 2)
            y_img = pred_LF[:,:,:,i*3:(i*3)+3]
            
        elif i == 63:
            y_img = tf.concat([y_img, pred_LF[:,:,:,i*3:(i*3)+3]], 1)
            grid_LF = tf.concat([grid_LF, y_img], 2)
        
        else:
            y_img = tf.concat([y_img, pred_LF[:,:,:,i*3:(i*3)+3]], 1)
            
    ######################################################################################
    
    #TF Summary image
    grid_LF_show = deprocess(grid_LF)
    output_image = tf.summary.image('Synthesized_LF', tf.cast(tf.reshape(grid_LF_show*255, 
                            [-1, SPATIAL_HEIGHT*8, SPATIAL_WIDTH*8, CH_OUTPUT]), tf.uint8) , 1)
    
    #EPI VISUALIZATION
    with tf.name_scope('EPI'): 
        epi_synth_H = deprocess(epi_synth_H)
        epi_GT_H = deprocess(epi_GT_H)
        epi_synth_V = deprocess(epi_synth_V)
        epi_GT_V = deprocess(epi_GT_V)
        epi_horizontal = tf.summary.image('epi_horizontal', tf.cast(tf.reshape(epi_synth_H*255, 
                                [-1, 64, SPATIAL_WIDTH, 1]), tf.uint8) , 1)
        epi_horizontal_GT = tf.summary.image('epi_horizontal_GT', tf.cast(tf.reshape(epi_GT_H*255, 
                                [-1, 64, SPATIAL_WIDTH, 1]), tf.uint8) , 1)
        epi_vertical = tf.summary.image('epi_vertical', tf.cast(tf.reshape(epi_synth_V*255, 
                                [-1, SPATIAL_HEIGHT, 64, 1]), tf.uint8) , 1)
        epi_vertical_GT = tf.summary.image('epi_vertical_GT', tf.cast(tf.reshape(epi_GT_V*255, 
                                [-1, SPATIAL_HEIGHT, 64, 1]), tf.uint8) , 1)  

# VALIDATION

In [12]:
#SLICE another EPI in the LF to make sure not only cross EPI is correct but all EPI is correct
with tf.name_scope('VALIDATION'): 
    center_width = tf.random_uniform(shape=[], minval=10, maxval=SPATIAL_WIDTH-10, dtype=tf.int32)
    center_height = tf.random_uniform(shape=[], minval=10, maxval=SPATIAL_HEIGHT-10, dtype=tf.int32)

    slice_epi_H2 = pred_LF_norm[:, center_height:center_height+1, :, :]
    slice_epi_V2 = pred_LF_norm[:, :, center_width:center_width+1, :]
    
    slice_epi_GT_H2 = color_norm[:, center_height:center_height+1, :, :]
    slice_epi_GT_V2 = color_norm[:, :, center_width:center_width+1, :]
        

    #Because of the stack is in row order for horizontal EPI it cannot be directly reshaped
    for j in range(8):
        for i in range(8):
            temp = slice_epi_H2[:, :, :, (i*8*3)+(j*3):(i*8*3)+(j*3)+3]
            temp2 = slice_epi_GT_H2[:, :, :, (i*8*3)+(j*3):(i*8*3)+(j*3)+3]
            if i==0 and j==0:
                epi_synth_H2 = temp
                epi_GT_H2 = temp2
            else:
                epi_synth_H2 = tf.concat([epi_synth_H2,temp], 1)
                epi_GT_H2 = tf.concat([epi_GT_H2,temp2], 1)

    epi_synth_V2 = tf.reshape(slice_epi_V2, [-1,SPATIAL_HEIGHT, 64, 3])
    epi_GT_V2 = tf.reshape(slice_epi_GT_V2, [-1,SPATIAL_HEIGHT, 64, 3])

    EPI_H2 = tf.summary.image('epi_horizontal2', tf.cast(tf.reshape(epi_synth_H2*255, 
                            [-1, 64, SPATIAL_WIDTH, 3]), tf.uint8) , 1)
    epi_horizontal_GT2 = tf.summary.image('epi_horizontal_GT2', tf.cast(tf.reshape(epi_GT_H2*255, 
                                [-1, 64, SPATIAL_WIDTH, 3]), tf.uint8) , 1)
    EPI_V2 = tf.summary.image('epi_vertical2', tf.cast(tf.reshape(epi_synth_V2*255, 
                            [-1, SPATIAL_HEIGHT, 64, 3]), tf.uint8) , 1)
    epi_vertical_GT2 = tf.summary.image('epi_vertical_GT2', tf.cast(tf.reshape(epi_GT_V2*255, 
                                [-1, SPATIAL_HEIGHT, 64, 3]), tf.uint8) , 1)

# INPUT PARSING

In [13]:
#To get one record and parse it to get the label and image out
def parser(record):
    keys_to_features = {
        "image_raw": tf.FixedLenFeature([], tf.string),
        "label":     tf.FixedLenFeature([], tf.string),
        "grid":     tf.FixedLenFeature([], tf.string)
    }
    #Read one record
    parsed = tf.parse_single_example(record, keys_to_features)
    #Take the image and bytes
    image = tf.decode_raw(parsed["image_raw"], tf.uint8)
    label = tf.decode_raw(parsed["label"], tf.uint8)
    grid = tf.decode_raw(parsed["grid"], tf.uint8)
    #Cast to float
    image = tf.cast(image, tf.float32)
    label = tf.cast(label, tf.float32)
    grid = tf.cast(grid, tf.float32)
    
    image = tf.reshape(image, shape=[SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_INPUT])
    label = tf.reshape(label, shape=[SPATIAL_HEIGHT, SPATIAL_WIDTH, CH_OUTPUT*64])
    grid = tf.reshape(grid, shape=[SPATIAL_HEIGHT*8, SPATIAL_WIDTH*8, CH_OUTPUT])
    #Normalize the input and label into [0...1]
    image = tf.divide(image, 255)
    label = tf.divide(label, 255)

    return {'image': image}, {'label': label}, {'grid': grid}

def input_fn(filenames):
    #Create data record
    dataset = tf.data.TFRecordDataset(filenames=filenames, num_parallel_reads=1)
    dataset = dataset.map(parser, num_parallel_calls=1)
    dataset = dataset.shuffle(50).repeat().batch(BATCH_SIZE)
    #dataset = dataset.prefetch(buffer_size=2)
    return dataset

def test_fn(filenames):
    #Create data record
    dataset = tf.data.TFRecordDataset(filenames=filenames, num_parallel_reads=1)
    dataset = dataset.map(parser, num_parallel_calls=1)
    dataset = dataset.batch(10)
    return dataset

def train_input_fn():
    return input_fn(filenames=["train.tfrecords"])

def test_input_fn():
    return test_fn(filenames=["test.tfrecords"])

# CREATE TRAIN SET

In [14]:
with tf.name_scope('Create_Training_Set'):
    train_dataset = train_input_fn()
    iterator = train_dataset.make_initializable_iterator()
    next_batch = iterator.get_next()

# TRAIN

In [None]:
merged = tf.summary.merge_all()
saver = tf.train.Saver()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

sess=tf.Session(config=config)
sess.run(tf.group(tf.global_variables_initializer(), iterator.initializer))
writer = tf.summary.FileWriter('log/Synthesis/DENSE_GAN',sess.graph)

run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True, trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

for step in range(EPOCH+1):
    train_x, train_y, train_grid = sess.run(next_batch)                       
    
    _, G_loss_, psnr_, ssim_ = sess.run([train_G, Total_Loss, psnr, ssim], 
    {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']})
   
    if step%300 == 0:
        #writer.add_run_metadata(run_metadata, 'step%d' % step)
        summary_ = sess.run(merged, {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']}
                           , options=run_options, run_metadata=run_metadata)
        writer.add_summary(summary_, step)     
        print('Step:', step, '| loss:%.5f' %G_loss_, '| PSNR:%.3f' %psnr_[0], '| SSIM:%.3f' %ssim_[0])
        
    if step%20000==0:
        save_path = saver.save(sess, "saver/Synthesis/model%i.ckpt" %step)
        print("Model saved in path: %s" % save_path)

Step: 0 | loss:892.48090 | PSNR:18.587 | SSIM:0.745
Model saved in path: saver/Synthesis/model0.ckpt
Step: 300 | loss:440.47540 | PSNR:25.621 | SSIM:0.810
Step: 600 | loss:215.62901 | PSNR:31.108 | SSIM:0.896
Step: 900 | loss:381.49469 | PSNR:27.221 | SSIM:0.798
Step: 1200 | loss:401.80212 | PSNR:26.232 | SSIM:0.822
Step: 1500 | loss:139.62825 | PSNR:34.678 | SSIM:0.943
Step: 1800 | loss:204.91272 | PSNR:30.058 | SSIM:0.909
Step: 2100 | loss:512.47223 | PSNR:23.856 | SSIM:0.793
Step: 2400 | loss:170.21962 | PSNR:34.024 | SSIM:0.901
Step: 2700 | loss:286.08249 | PSNR:28.743 | SSIM:0.869
Step: 3000 | loss:201.32086 | PSNR:33.309 | SSIM:0.903
Step: 3300 | loss:394.19278 | PSNR:27.725 | SSIM:0.743
Step: 3600 | loss:237.38304 | PSNR:30.965 | SSIM:0.871
Step: 3900 | loss:116.07413 | PSNR:34.180 | SSIM:0.955
Step: 4200 | loss:229.27771 | PSNR:32.823 | SSIM:0.886
Step: 4500 | loss:85.70263 | PSNR:37.616 | SSIM:0.962
Step: 4800 | loss:389.35004 | PSNR:27.348 | SSIM:0.821
Step: 5100 | loss:283.8

Step: 43500 | loss:228.27841 | PSNR:27.352 | SSIM:0.933
Step: 43800 | loss:200.59691 | PSNR:30.171 | SSIM:0.945
Step: 44100 | loss:205.93430 | PSNR:33.024 | SSIM:0.916
Step: 44400 | loss:119.82607 | PSNR:36.074 | SSIM:0.965
Step: 44700 | loss:126.33183 | PSNR:33.807 | SSIM:0.956
Step: 45000 | loss:103.90488 | PSNR:36.946 | SSIM:0.971
Step: 45300 | loss:165.84773 | PSNR:32.655 | SSIM:0.954
Step: 45600 | loss:365.13190 | PSNR:27.875 | SSIM:0.890
Step: 45900 | loss:207.92690 | PSNR:30.498 | SSIM:0.931
Step: 46200 | loss:172.07497 | PSNR:33.671 | SSIM:0.928
Step: 46500 | loss:392.53528 | PSNR:24.623 | SSIM:0.888
Step: 46800 | loss:225.60959 | PSNR:30.262 | SSIM:0.917
Step: 47100 | loss:208.84998 | PSNR:31.436 | SSIM:0.940
Step: 47400 | loss:105.61834 | PSNR:37.096 | SSIM:0.965
Step: 47700 | loss:187.36946 | PSNR:33.377 | SSIM:0.922
Step: 48000 | loss:125.13520 | PSNR:36.791 | SSIM:0.966
Step: 48300 | loss:136.31500 | PSNR:35.480 | SSIM:0.948
Step: 48600 | loss:135.92682 | PSNR:34.507 | SSI

Step: 87000 | loss:126.42300 | PSNR:33.068 | SSIM:0.967
Step: 87300 | loss:114.91788 | PSNR:35.569 | SSIM:0.963
Step: 87600 | loss:171.35657 | PSNR:32.707 | SSIM:0.937
Step: 87900 | loss:102.74484 | PSNR:35.307 | SSIM:0.969
Step: 88200 | loss:112.69872 | PSNR:36.503 | SSIM:0.961
Step: 88500 | loss:110.20403 | PSNR:37.664 | SSIM:0.967
Step: 88800 | loss:119.30759 | PSNR:36.053 | SSIM:0.967
Step: 89100 | loss:119.60025 | PSNR:33.782 | SSIM:0.973
Step: 89400 | loss:84.58283 | PSNR:39.018 | SSIM:0.968
Step: 89700 | loss:101.39545 | PSNR:36.879 | SSIM:0.966
Step: 90000 | loss:115.58080 | PSNR:37.130 | SSIM:0.952
Step: 90300 | loss:246.66005 | PSNR:30.918 | SSIM:0.800
Step: 90600 | loss:84.51998 | PSNR:39.764 | SSIM:0.965
Step: 90900 | loss:123.80022 | PSNR:36.341 | SSIM:0.950
Step: 91200 | loss:153.62689 | PSNR:33.709 | SSIM:0.963
Step: 91500 | loss:92.91564 | PSNR:38.231 | SSIM:0.971
Step: 91800 | loss:110.58170 | PSNR:36.632 | SSIM:0.971
Step: 92100 | loss:96.20409 | PSNR:35.944 | SSIM:0.

In [None]:
save_path = saver.save(sess, "saver/Synthesis/model%i.ckpt" %step)
print("Model saved in path: %s" % save_path)

# TRAIN RESTORE

In [None]:
# merged = tf.summary.merge_all()
# saver = tf.train.Saver()

# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True

# sess=tf.Session(config=config)
# sess.run(iterator.initializer)
# writer = tf.summary.FileWriter('log/Synthesis/Restore',sess.graph)

# run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True, trace_level=tf.RunOptions.FULL_TRACE)
# run_metadata = tf.RunMetadata()
# saver.restore(sess, "saver/Synthesis/model1.ckpt")

# for step in range(EPOCH+1):
#     train_x, train_y, train_grid = sess.run(next_batch)    

# #     _, G_loss_, psnr_ = sess.run([train_op, total_loss, psnr], 
# #         {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']}, options=run_options, run_metadata=run_metadata)
    
#     if step%1 == 0:
#         _, G_loss_, psnr_ = sess.run([train_G, G_Adv_loss, psnr], 
#         {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']})
#     if step%1 == 0:
#         _, D_loss_ = sess.run([train_D, D_Total_Loss], 
#         {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']})
   
#     if step%150 == 0:
#         #writer.add_run_metadata(run_metadata, 'step%d' % step)
#         summary_ = sess.run(merged, {tf_x:train_x['image'], tf_y:train_y['label'], tf_grid:train_grid['grid']}
#                            , options=run_options, run_metadata=run_metadata)
#         writer.add_summary(summary_, step)     
#         print('Step:', step, '| G loss:%.5f' %G_loss_, '| D loss:%.5f' %D_loss_, '| PSNR[0]:%.3f' %psnr_[0])
        
#     if step%15000==0:
#         save_path = saver.save(sess, "saver/Synthesis/model%i.ckpt" %step)
#         print("Model saved in path: %s" % save_path)


# FORWARD 

In [None]:
# with tf.name_scope('Test_Folder_Read'):
#     label_path = TEST_PATH
#     addrs = sorted(glob.glob(label_path))
    
# with tf.name_scope('Create_Datarecord_Test'):
#     test_addrs = addrs[:]
#     createDataRecord('test.tfrecords', test_addrs)

In [None]:
import imageio

with tf.name_scope('Create_Test_Set'):
    test_dataset = test_input_fn()
    iterator = test_dataset.make_initializable_iterator()
    next_batch = iterator.get_next()

In [None]:
sess=tf.Session()

merged = tf.summary.merge_all()
sess.run(iterator.initializer)
writer = tf.summary.FileWriter('log/Synthesis/Test',sess.graph)
saver = tf.train.Saver()
run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)

%%time
saver.restore(sess, "saver/Synthesis/model120000.ckpt")

In [None]:
%%time
test_x, test_y, test_grid = sess.run(next_batch)
output_, grid_, epi_H_, epi_V_, epi_GT_H_, epi_GT_V_, summary_ = sess.run(
    [pred_LF_norm, grid_LF_show, epi_synth_H2, epi_synth_V2, epi_GT_H2, epi_GT_V2, merged], 
      {tf_x:test_x['image'], tf_y:test_y['label'], tf_grid:test_grid['grid']}, options=run_options)

In [None]:
# for i in range(10):
#     save = cv2.cvtColor(test_grid['grid'][i], cv2.COLOR_BGR2RGB)
#     cv2.imwrite('GT%i.png'%i, save)

In [None]:
i = 3
temp = grid_[i]*255
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
cv2.imwrite('Output.png', temp)

save = cv2.cvtColor(test_grid['grid'][i], cv2.COLOR_BGR2RGB)
cv2.imwrite('GT.png', save)

temp = epi_H_[i]*255
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
cv2.imwrite('EPI_H.png', temp)
temp = epi_V_[i]*255
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
cv2.imwrite('EPI_V.png', temp)

temp = epi_GT_H_[i]*255
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
cv2.imwrite('EPI_GT_H.png', temp)
temp = epi_GT_V_[i]*255
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
cv2.imwrite('EPI_GT_V.png', temp)

GT= []
Output = []
Error = []

In [None]:
temp1 = test_y['label'][i]*255
for n in range(ANGULAR_RES_TARGET):
    temp = temp1[:,:,(n*3):(n*3)+3]
    save = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    cv2.imwrite("LF_Synthesized/GT/output%i.png" %(n), save)
    temp = np.uint8(temp)
    GT.append(temp)

temp2 = output_[i,:]*255
for n in range(ANGULAR_RES_TARGET):
    temp = temp2[:,:,(n*3):(n*3)+3]
    save = cv2.cvtColor(temp2[:,:,(n*3):(n*3)+3], cv2.COLOR_BGR2RGB)
    cv2.imwrite("LF_Synthesized/Output/output%i.png" %(n), save)
    temp = np.uint8(temp)
    Output.append(temp)


In [None]:
temp1 = test_y['label'][i]*255
temp2 = output_[i,:]*255
for n in range(ANGULAR_RES_TARGET):
    save1 = cv2.cvtColor(temp1[:,:,(n*3):(n*3)+3], cv2.COLOR_BGR2RGB)
    save2 = cv2.cvtColor(temp2[:,:,(n*3):(n*3)+3], cv2.COLOR_BGR2RGB)
    save = np.absolute(save1-save2)*1.2
    save = np.uint8(save)
    save = cv2.applyColorMap(save, cv2.COLORMAP_JET)
    cv2.imwrite("LF_Synthesized/Error/output%i.png" %(n), save)
    save = cv2.cvtColor(save, cv2.COLOR_RGB2BGR)
    Error.append(save)
    
for y in range(8):
    for x in range(8):
        save1 = temp1[:,:,(x*8*3)+(y*3):(x*8*3)+(y*3)+3]
        save2 = temp2[:,:,(x*8*3)+(y*3):(x*8*3)+(y*3)+3]
        save = np.absolute(save1-save2)*1.2
        save = np.uint8(save)
        save = cv2.applyColorMap(save, cv2.COLORMAP_JET)
        save = cv2.cvtColor(save, cv2.COLOR_RGB2BGR)
        Error.append(save)

In [None]:
c=0
for y in range(8):
    for x in range(8):
        temp = temp1[:,:,(x*8*3)+(y*3):(x*8*3)+(y*3)+3]
        save = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
        cv2.imwrite("LF_Synthesized/Horizontal_GT/output%i.png" %(c), save)
        temp = np.uint8(temp)
        GT.append(temp)
        
        temp = temp2[:,:,(x*8*3)+(y*3):(x*8*3)+(y*3)+3]
        save = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
        cv2.imwrite("LF_Synthesized/Horizontal_Output/output%i.png" %(c), save)
        temp = np.uint8(temp)
        Output.append(temp)
        c+=1

In [None]:
import warnings
warnings.filterwarnings(action='once')
imageio.mimsave('GT.gif', GT, duration=0.05)
imageio.mimsave('Output.gif', Output, duration=0.05)
imageio.mimsave('Error.gif', Error, duration=0.05)

In [None]:
import math
from skimage.measure import compare_ssim as ssim

def psnr(img1, img2):
    mse = np.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

original = cv2.imread("GT.png")
contrast = cv2.imread("Output.png")
PSNR=psnr(original,contrast)
print(PSNR)
print("#########################################")

SSIM = ssim(original, contrast, multichannel=True,
              data_range=contrast.max() - contrast.min())
print(SSIM)
print("#########################################")

f= open("eval.txt","w+")
 
f.write("PSNR: %f\r" % PSNR)
f.write("SSIM: %f\r\n" % SSIM)
 
f.close() 

# END

In [None]:
lf = cv2.imread("test.png")
temp = lf[:, :, :]
temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
imgplots = plt.imshow((temp).astype('uint8'))
plt.show()
temp.shape

In [None]:
numImgsX = 14;
numImgsY = 14;

h,w,c = temp.shape
h_angular = h / numImgsY; 
w_angular = w / numImgsX;
fullLF = np.zeros((int(h_angular), int(w_angular), 3, numImgsY, numImgsX));

In [None]:
for ax in range(numImgsX):
    for ay in range(numImgsY):
        fullLF[:, :, :, ay, ax] = temp[ay::numImgsY, ax::numImgsX, :]
        
padded_LF = np.pad(fullLF, ((1,0),(0,0),(0,0),(0,0),(0,0)), 'constant', constant_values=(0))

In [None]:
middleLF = fullLF[:, :, :, 3:11, 3:11]
middleLF.shape

In [None]:
list_img = []
i = 1
for ax in range(8):
    for ay in range(8):
        if ay == 0:
            y_img = middleLF[:,:,:,ay,ax]
        else:
            y_img = np.concatenate((y_img, middleLF[:,:,:,ay,ax]), 0)
        list_img.append(middleLF[:,:,:,ay,ax])     
    if ax == 0:
        full_img = y_img
    else:
        full_img = np.concatenate((full_img, y_img), 1)
    y_img = fullLF[:,:,:,ay,ax]

In [None]:
temp = fullLF[:,:,:,7,7]
imgplots = plt.imshow((temp).astype('uint8'))
plt.show()

In [None]:
imgplots = plt.imshow((full_img).astype('uint8'))
plt.show()
cv2.imwrite("Grid.png", full_img)

In [None]:
for d in range(100):
    plt.figure(d)
    temp = list_img[d]
    imgplot = plt.imshow((temp).astype('uint8'))
    #plt.savefig("Stack%i.png" %d, dpi=100, bbox_inches='tight', frameon=False)
    plt.show()
    temp = np.uint8(temp)
    temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    cv2.imwrite("Stack%i.png" %d, temp)

In [None]:
temp = list_img[0]
temp.shape

# PSNR

In [None]:
import math
def psnr(img1, img2):
    mse = np.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

original = cv2.imread("GT.png")
contrast = cv2.imread("Output.png")
d=psnr(original,contrast)
print(d)
print("#########################################")

# SSIM

In [None]:
from skimage.measure import compare_ssim as ssim


original = cv2.imread("GT.png")
contrast = cv2.imread("Output.png")
ssim_noise = ssim(original, contrast, multichannel=True,
              data_range=contrast.max() - contrast.min())
print(ssim_noise)
print("#########################################")

In [None]:
ones_3d = np.ones((5,5,5))
weight_3d = np.ones((3,3,3))
strides_3d = [1, 1, 1, 1, 1]

in_3d = tf.constant(ones_3d, dtype=tf.float32)
filter_3d = tf.constant(weight_3d, dtype=tf.float32)

in_width = int(in_3d.shape[0])
in_height = int(in_3d.shape[1])
in_depth = int(in_3d.shape[2])

filter_width = int(filter_3d.shape[0])
filter_height = int(filter_3d.shape[1])
filter_depth = int(filter_3d.shape[2])

input_3d   = tf.reshape(in_3d, [1, in_depth, in_height, in_depth, 1])
kernel_3d = tf.reshape(filter_3d, [filter_depth, filter_height, filter_width, 1, 1])

temp = tf.nn.conv3d(input_3d, kernel_3d, strides=strides_3d, padding='SAME')
output_3d = tf.squeeze(temp)
sess=tf.Session()
sess.run(output_3d)

In [None]:
#LF INDEX
# +---+----+----+----+----+----+----+----+
# | 0 | 8  | 16 | 24 | 32 | 40 | 48 | 56 |
# +---+----+----+----+----+----+----+----+
# | 1 | 9  | 17 | 25 | 33 | 41 | 49 | 57 |
# +---+----+----+----+----+----+----+----+
# | 2 | 10 | 18 | 26 | 34 | 42 | 50 | 58 |
# +---+----+----+----+----+----+----+----+
# | 3 | 11 | 19 | 27 | 35 | 43 | 51 | 59 |
# +---+----+----+----+----+----+----+----+
# | 4 | 12 | 20 | 28 | 36 | 44 | 52 | 60 |
# +---+----+----+----+----+----+----+----+
# | 5 | 13 | 21 | 29 | 37 | 45 | 53 | 61 |
# +---+----+----+----+----+----+----+----+
# | 6 | 14 | 22 | 30 | 38 | 46 | 54 | 62 |
# +---+----+----+----+----+----+----+----+
# | 7 | 15 | 23 | 31 | 39 | 47 | 55 | 63 |
# +---+----+----+----+----+----+----+----+

In [None]:
SHIFT_VALUE = 1
for i in range(64):
    if i==0 or (i%8==0 and i>7):
        ty = 3*SHIFT_VALUE
    elif i == 1 or (i-1)%8==0:
        ty = 2*SHIFT_VALUE
    elif i == 2 or (i-2)%8==0:
        ty = 1*SHIFT_VALUE
    elif i == 3 or (i-3)%8==0:
        ty = 0
    elif i == 4 or (i-4)%8==0:
        ty = -1*SHIFT_VALUE
    elif i == 5 or (i-5)%8==0:
        ty = -2*SHIFT_VALUE
    elif i == 6 or (i-6)%8==0:
        ty = -3*SHIFT_VALUE
    else:
        ty = -4*SHIFT_VALUE
    print('i: ',i, 'ty: ',ty)

In [None]:
SHIFT_VALUE = 1
for i in range(64):
    if i<=7:
        tx = 3*SHIFT_VALUE
    elif i>7 and i<=15:
        tx = 2*SHIFT_VALUE
    elif i>15 and i<=23:
        tx = 1*SHIFT_VALUE
    elif i>23 and i<=31:
        tx = 0
    elif i>31 and i<=39:
        tx = -1*SHIFT_VALUE
    elif i>39 and i<=47:
        tx = -2*SHIFT_VALUE
    elif i>47 and i<=55:
        tx = -3*SHIFT_VALUE
    else:
        tx = -4*SHIFT_VALUE
    print('i: ',i, 'tx: ',tx)


In [None]:
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
image_shift = tf_image_translate(c, 1, 0)

In [None]:
sess = tf.Session()
result = sess.run(image_shift)

In [None]:
print(result)

In [None]:
print(result)

In [None]:
filename_queue = tf.train.string_input_producer(['test.png']) #  list of files to read

reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)

my_img = tf.image.decode_png(value) # use png or jpg decoder based on your files.
my_img = tf.to_float(my_img)
my_img2 = tf.expand_dims(my_img, 0)
print(my_img2.shape)
sobel = tf.image.sobel_edges(my_img2)

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    
    sess.run(init_op)

    # Start populating the filename queue.

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1): #length of your filename list
        image = sobel.eval() #here is your image Tensor :) 

    print(image.shape)

    coord.request_stop()
    coord.join(threads)

In [None]:
temp = image[0]

In [None]:
temp1 = temp[:,:,:,0]
temp2 = temp[:,:,:,1]
edge = np.sqrt(temp1**2 + temp2**2)

In [None]:
imgplot = plt.imshow((temp1).astype('uint8'))

In [None]:
imgplot = plt.imshow((temp2).astype('uint8'))

In [None]:
imgplot = plt.imshow((edge).astype('uint8'))