###  **Module imports**

In [0]:
# coding: utf-8
#get_ipython().magic(u'matplotlib inline')
import tensorflow as tf
import numpy as np
import scipy as sp
from scipy import io
from scipy import interpolate
from scipy import ndimage
from scipy.misc import imsave
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import matplotlib.pyplot as plt

###  **Hyperparameters**

In [0]:
# ---- Hyperparameters ----
# Light-field geometry: [rows, cols, angular rows (v), angular cols (u)] of the Lytro data.
lfsize = [360, 540, 7, 7]
# Maximum disparity between adjacent views; scales the tanh output of the depth network.
disp_mult = 10.0

# Training crops: batch size, spatial crop size, and crops pushed per light field by
# each input-queue thread.
batchsize, patchsize, num_crops = 6, [120, 120], 5

# Validation crops (same meaning as above).
test_batchsize, test_patchsize, test_num_crops = 8, [210, 210], 4

# Optimisation schedule. `test` is used later in the notebook; presumably an
# evaluation interval in iterations (TODO confirm).
learning_rate = 1e-4
test, train_iters, test_iters = 1500, 100000, 75

###  **Defocus code**

In [0]:
# Uniform angular kernel: every view gets the same weight, 1 / (number of views), so
# the defocused image built later (kernel-weighted sum over the two angular axes) is
# the plain average of all views. Normalising by lfsize instead of a hard-coded 49
# keeps the kernel summing to 1 for any angular resolution.
defocus_code = np.ones([1, 1, 1, lfsize[2], lfsize[3], 1]) / float(lfsize[2] * lfsize[3])

# Tile to [1, rows, cols, v, u, 3] so it multiplies element-wise with a
# [B, rows, cols, v, u, 3] light-field batch (training and validation crop sizes).
blur_code = np.tile(defocus_code, (1, patchsize[0], patchsize[1], 1, 1, 3))
test_code = np.tile(defocus_code, (1, test_patchsize[0], test_patchsize[1], 1, 1, 3))

###  **CNN layer functions**

In [0]:

def weight_variable(w_shape, name):
    """Get (or, under variable reuse, fetch) a kernel variable with Xavier/Glorot init.

    Args:
        w_shape: kernel shape passed to tf.get_variable.
        name: variable name inside the current variable scope.
    """
    glorot = tf.contrib.layers.xavier_initializer_conv2d()
    return tf.get_variable(name, w_shape, initializer=glorot)

def bias_variable(b_shape, init_bias=0.0):
    """Get the variable named 'bias' in the current scope, filled with `init_bias`."""
    fill = tf.constant_initializer(init_bias)
    return tf.get_variable('bias', b_shape, initializer=fill)

# Dilated conv + bias, no activation
def cnn_layer_no_act(input_tensor, w_shape, b_shape, layer_name, is_training, rate=1, padding_type='SAME'):
    """Atrous 2-D convolution plus bias, with no activation and no normalisation.

    `is_training` is accepted for signature parity with cnn_layer and is not used.
    """
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.atrous_conv2d(input_tensor, kernel, rate,
                                   padding=padding_type, name=layer_name + '_conv')
        return conv + bias_variable(b_shape)
    
# Standard atrous block: dilated conv -> bias -> ELU -> batch norm
def cnn_layer(input_tensor, w_shape, b_shape, layer_name, is_training, rate=1, padding_type='SAME'):
    """Atrous 2-D conv, bias, ELU and batch norm, all inside scope `layer_name`.

    Args:
        rate: dilation rate for tf.nn.atrous_conv2d (1 = ordinary convolution).
        is_training: forwarded to batch norm.
    """
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.atrous_conv2d(input_tensor, kernel, rate,
                                   padding=padding_type, name=layer_name + '_conv')
        activated = tf.nn.elu(conv + bias_variable(b_shape))
        return tf.contrib.layers.batch_norm(activated, scale=True, updates_collections=None,
                                            is_training=is_training, scope=layer_name + '_bn')
    
# Standard strided block: conv -> bias -> ELU -> batch norm
def cnn_layer_strided(input_tensor, w_shape, b_shape, layer_name, is_training, stride=1, padding_type='SAME'):
    """Strided 2-D conv, bias, ELU and batch norm, all inside scope `layer_name`."""
    step = [1, stride, stride, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.conv2d(input_tensor, kernel, strides=step,
                            padding=padding_type, name=layer_name + '_conv')
        activated = tf.nn.elu(conv + bias_variable(b_shape))
        return tf.contrib.layers.batch_norm(activated, scale=True, updates_collections=None,
                                            is_training=is_training, scope=layer_name + '_bn')
    
def cnn_layer_transposed(input_tensor, w_shape, b_shape, o_shape, layer_name, is_training, stride=1, padding_type='SAME'):
    """Transposed 2-D conv to `o_shape`, then bias, ELU and batch norm inside `layer_name`."""
    step = [1, stride, stride, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        up = tf.nn.conv2d_transpose(input_tensor, kernel, o_shape, strides=step,
                                    padding=padding_type, name=layer_name + '_conv')
        activated = tf.nn.elu(up + bias_variable(b_shape))
        return tf.contrib.layers.batch_norm(activated, scale=True, updates_collections=None,
                                            is_training=is_training, scope=layer_name + '_bn')
    
# Linear layer: conv + bias, no normalisation or activation
def cnn_layer_no_bn(input_tensor, w_shape, b_shape, layer_name, stride=1, padding_type='SAME'):
    """Strided 2-D conv plus bias with a linear output."""
    step = [1, stride, stride, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.conv2d(input_tensor, kernel, strides=step,
                            padding=padding_type, name=layer_name + '_conv')
        return conv + bias_variable(b_shape)

# Atrous conv + bias + ELU, without batch norm
def cnn_layer_act_nobn(input_tensor, w_shape, b_shape, layer_name, is_training, rate=1, padding_type='SAME'):
    """Atrous 2-D conv, bias and ELU; `is_training` is accepted but unused (no batch norm)."""
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.atrous_conv2d(input_tensor, kernel, rate,
                                   padding=padding_type, name=layer_name + '_conv')
        return tf.nn.elu(conv + bias_variable(b_shape))

# Transposed conv + bias + ELU, without batch norm
def cnn_layer_tr_nobn(input_tensor, w_shape, b_shape, o_shape, layer_name, is_training, stride=1, padding_type='SAME'):
    """Transposed 2-D conv to `o_shape`, bias and ELU; `is_training` is unused."""
    step = [1, stride, stride, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        up = tf.nn.conv2d_transpose(input_tensor, kernel, o_shape, strides=step,
                                    padding=padding_type, name=layer_name + '_conv')
        return tf.nn.elu(up + bias_variable(b_shape))
    
def res_sum(a1, a2, layer_name, is_training):
    """Residual merge: ELU(a1 + a2) inside variable scope `layer_name`.

    `is_training` is unused; a batch-norm step after the sum was disabled.
    """
    with tf.variable_scope(layer_name):
        merged = a1 + a2
        return tf.nn.elu(merged)

# Linear transposed conv: conv_transpose + bias only
def cnn_layer_tr_nobn_noact(input_tensor, w_shape, b_shape, o_shape, layer_name, is_training, stride=1, padding_type='SAME'):
    """Transposed 2-D conv to `o_shape` plus bias; no activation, no batch norm."""
    step = [1, stride, stride, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        up = tf.nn.conv2d_transpose(input_tensor, kernel, o_shape, strides=step,
                                    padding=padding_type, name=layer_name + '_conv')
        return up + bias_variable(b_shape)
    
    
#3D convolution layer (stride 1, no dilation) followed by ELU and batch norm
def cnn_layer3D(input_tensor, w_shape, b_shape, layer_name, is_training, rate=(1, 1, 1, 1, 1), padding_type='SAME'):
    """Stride-1 3-D conv, bias, ELU and batch norm inside variable scope `layer_name`.

    Args:
        input_tensor: 5-D input for tf.nn.conv3d (default NDHWC layout).
        w_shape: [depth, height, width, in_channels, out_channels] kernel shape.
        b_shape: bias shape, normally [out_channels].
        layer_name: variable scope; also prefixes the conv op and the batch-norm scope.
        is_training: forwarded to batch norm.
        rate: ignored. It is kept only so existing callers keep working. The old comment
            called this an "atrous" layer, but no dilation is applied. The default is now
            an immutable tuple instead of a shared mutable list.
        padding_type: padding mode for tf.nn.conv3d.

    Returns:
        Batch-normalised activations.
    """
    with tf.variable_scope(layer_name):
        W = weight_variable(w_shape, '_weights')
        h = tf.nn.conv3d(input_tensor, W, strides=[1, 1, 1, 1, 1], padding=padding_type,
                         name=layer_name + '_conv')
        h = h + bias_variable(b_shape)
        h = tf.nn.elu(h)
        h = tf.contrib.layers.batch_norm(h, scale=True, updates_collections=None,
                                         is_training=is_training, scope=layer_name + '_bn')
        return h

# Linear 3-D conv: conv3d + bias, no activation, no batch norm
def cnn_layer3D_no_bn(input_tensor, w_shape, b_shape, layer_name, padding_type='SAME'):
    """Stride-1 3-D conv plus bias with a linear output."""
    unit_strides = [1, 1, 1, 1, 1]
    with tf.variable_scope(layer_name):
        kernel = weight_variable(w_shape, '_weights')
        conv = tf.nn.conv3d(input_tensor, kernel, strides=unit_strides,
                            padding=padding_type, name=layer_name + '_conv')
        return conv + bias_variable(b_shape)

###  **Depth estimation network**

In [0]:
def depth_network(x, xc, lfsize, disp_mult, is_training, name):
    """Predict one disparity per ray from the input image and the defocused image.

    Args:
        x: input image; concatenated with `xc` along axis 3 to form the 6-channel
            input that layer 'c1' expects (3 + 3 channels at the call sites here).
        xc: defocused (coded) image.
        lfsize: [rows, cols, v, u]; only the angular sizes lfsize[2], lfsize[3] are used.
        disp_mult: scale on the final tanh, so outputs lie in (-disp_mult, disp_mult).
        is_training: passed to every batch-norm layer.
        name: variable scope holding all layers.

    Returns:
        Tensor named 'rayd' of shape [B, H, W, lfsize[2], lfsize[3]].
    """
    with tf.variable_scope(name):
        
        b_sz = tf.shape(x)[0]
        y_sz = tf.shape(x)[1]
        x_sz = tf.shape(x)[2]
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        
        # Trunk: three plain convs, then atrous convs with rates 2, 2, 4, 8, 16.
        net_in = tf.concat([x,xc],axis=3)
        c1 = cnn_layer(net_in, [3, 3, 6, 16], [16], 'c1', is_training)
        c2 = cnn_layer(c1, [3, 3, 16, 64], [64], 'c2', is_training)
        c3 = cnn_layer(c2, [3, 3, 64, 128], [128], 'c3', is_training)
        c4 = cnn_layer(c3, [3, 3, 128, 128], [128], 'c4', is_training, rate=2)
        c5 = cnn_layer(c4, [3, 3, 128, 128], [128], 'c5', is_training, rate=2)
        c6 = cnn_layer(c5, [3, 3, 128, 128], [128], 'c6', is_training, rate=4)
        c7 = cnn_layer(c6, [3, 3, 128, 128], [128], 'c7', is_training, rate=8)
        c8 = cnn_layer(c7, [3, 3, 128, 64], [64], 'c8', is_training, rate=16)

        # Skip branches tapped at c3, c6, c7 and c8 (increasing dilation).
        sc1 = cnn_layer(c3, [3, 3, 128, 128], [128], 'sc1', is_training)
        sc2 = cnn_layer(sc1, [3, 3, 128, 128], [128], 'sc2', is_training)
        sc3 = cnn_layer(sc2, [3, 3, 128, 64], [64], 'sc3', is_training)
        
        dsc1 = cnn_layer(c6, [3, 3, 128, 64], [64], 'dsc1', is_training)
        dsc2 = cnn_layer(dsc1, [3, 3, 64, 64], [64], 'dsc2', is_training)
        
        dsc3 = cnn_layer(c7, [3, 3, 128, 32], [32], 'dsc3', is_training)        
        dsc4 = cnn_layer(dsc3, [3, 3, 32, 32], [32], 'dsc4', is_training)        
        
        dsc5 = cnn_layer(c8, [3, 3, 64, 32], [32], 'dsc5', is_training)        
        dsc6 = cnn_layer(dsc5, [3, 3, 32, 32], [32], 'dsc6', is_training)        
        
        # 64 + 64 + 32 + 32 = 192 channels, matching c13's input depth.
        concat_feat = tf.concat([sc3,dsc2,dsc4,dsc6],axis=3)
        
        c13 = cnn_layer(concat_feat, [3, 3, 192, 128], [128], 'c13', is_training)
        c14 = cnn_layer(c13, [3, 3, 128, 128], [128], 'c14', is_training)
        c15 = cnn_layer(c14, [3, 3, 128, 49], [49], 'c15', is_training)
        c12 = cnn_layer(c15, [3, 3, 49, 49], [49], 'c12', is_training)
        c16 = cnn_layer(c12, [3, 3, 49, lfsize[2]*lfsize[3]], [lfsize[2]*lfsize[3]], 'c16', is_training)
        # One channel per view, bounded by disp_mult*tanh.
        # NOTE(review): the final layer's scope is 'c10' rather than 'c17'; presumably kept
        # so existing checkpoints still load. Do not rename without checking.
        c17 = disp_mult*tf.tanh(cnn_layer_no_bn(c16, [3, 3, lfsize[2]*lfsize[3], lfsize[2]*lfsize[3]], 
                                                [lfsize[2]*lfsize[3]], 'c10'))
        #print tf.shape(c16), disp_mult
        # Channel axis (v*u) unfolded into the two angular axes.
        return tf.reshape(c17, [b_sz, y_sz, x_sz, v_sz, u_sz], name='rayd')

###  **Light field refinement network**

In [0]:
def occlusions_network3D_1(d, xc, shear, lfsize, is_training, name):
    """Refine a Lambertian light field with a 3-D CNN (deeper variant: 5 BN layers).

    Not called in the visible cells; forward_model uses occlusions_network3D_2.

    Args:
        d: ray depths [B, H, W, v, u].
        xc: defocused image [B, H, W, 3] (expanded to [B, H, W, 1, 3] below).
        shear: Lambertian light field [B, H, W, v, u, 3].
        lfsize: [rows, cols, v, u]; only the angular sizes are used.
        is_training: passed to the batch-norm layers.
        name: variable scope.

    Returns:
        (refined, residual): sigmoid(residual + shear) reshaped to [B, H, W, v, u, 3],
        and the raw residual c7 with the same layout.
    """
    with tf.variable_scope(name):
        
        b_sz = tf.shape(d)[0]
        y_sz = tf.shape(d)[1]
        x_sz = tf.shape(d)[2]
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        
        #depth, repeated for each of the 3 colour channels -> [B, H, W, v*u, 3]
        d = tf.tile(tf.expand_dims(d,5), [1,1,1,1,1,3])
        d = tf.reshape(d, [b_sz, y_sz, x_sz, v_sz*u_sz, 3])
        
        #light field, views flattened -> [B, H, W, v*u, 3]
        x = tf.reshape(shear, [b_sz, y_sz, x_sz, v_sz*u_sz, 3])
        #defocused image as one extra "view" -> [B, H, W, 1, 3]
        xc = tf.expand_dims(xc,3)
        #concatenating light field, depth and coded image as input to the network
        xdc = tf.concat([x, d, xc],axis=3)
        
        
        # Colour becomes the conv3d depth axis; the 2*v*u + 1 stacked inputs are channels.
        xdc = tf.transpose(xdc, [0, 4, 1, 2, 3]) # B, C, H, W, 2*v*u + 1
        #shear = tf.reshape(shear, [b_sz, y_sz, x_sz, v_sz*u_sz*3])
        #[filter_depth, filter_height, filter_width, in_channels, out_channels]
        # NOTE: 99 = 2*49 + 1 and 98 = 2*49 are hard-coded, so this assumes v*u == 49.
        c1 = cnn_layer3D(xdc, [3,3,3,99,98], [98], 'c1', is_training, padding_type='SAME')
        c2 = cnn_layer3D(c1, [3,3,3,98,98], [98], 'c2', is_training, padding_type='SAME')
        c3 = cnn_layer3D(c2, [3,3,3,98,98], [98], 'c3', is_training, padding_type='SAME')
        c4 = cnn_layer3D(c3, [3,3,3,98,49], [49], 'c4', is_training, padding_type='SAME')
        c5 = cnn_layer3D(c4, [3,3,3,49,49], [49], 'c5', is_training, padding_type='SAME')
        c6 = cnn_layer3D_no_bn(c5, [3, 3, 3, v_sz*u_sz, v_sz*u_sz], [v_sz*u_sz], 'c6', padding_type='SAME')
        # o - b,3,h,w,49 ; back to [B, H, W, v, u, 3]
        c7 = tf.transpose(tf.reshape(c6, [b_sz, 3, y_sz, x_sz, v_sz, u_sz]), [0, 2, 3, 4, 5, 1])
        # Residual on top of the Lambertian light field, squashed to (0, 1).
        c8 = tf.sigmoid(c7 + shear)
        
        return tf.reshape(c8, [b_sz, y_sz, x_sz, v_sz, u_sz, 3]), c7

In [0]:
def occlusions_network3D_2(d, xc, shear, lfsize, is_training, name):
    """Refine a Lambertian light field with a 3-D CNN (shallow variant used by forward_model).

    Args:
        d: ray depths [B, H, W, v, u].
        xc: defocused image [B, H, W, 3] (expanded to [B, H, W, 1, 3] below).
        shear: Lambertian light field [B, H, W, v, u, 3].
        lfsize: [rows, cols, v, u]; only the angular sizes are used.
        is_training: passed to the batch-norm layers.
        name: variable scope.

    Returns:
        sigmoid(residual + shear) as [B, H, W, v, u, 3].
    """
    with tf.variable_scope(name):
        
        b_sz = tf.shape(d)[0]
        y_sz = tf.shape(d)[1]
        x_sz = tf.shape(d)[2]
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        
        # Depth repeated per colour channel -> [B, H, W, v*u, 3]
        d = tf.tile(tf.expand_dims(d,5), [1,1,1,1,1,3])
        d = tf.reshape(d, [b_sz, y_sz, x_sz, v_sz*u_sz, 3])
        
        # Lambertian light field with flattened views -> [B, H, W, v*u, 3]
        x = tf.reshape(shear, [b_sz, y_sz, x_sz, v_sz*u_sz, 3])
        
        # Defocused image as one extra slice -> [B, H, W, 1, 3]
        xc = tf.expand_dims(xc,3)
        xdc = tf.concat([x, d, xc],axis=3)
        
        
        # Colour becomes the conv3d depth axis; the 2*v*u + 1 stacked inputs are channels.
        xdc = tf.transpose(xdc, [0, 4, 1, 2, 3]) # B, C, H, W, 2*v*u + 1
        #shear = tf.reshape(shear, [b_sz, y_sz, x_sz, v_sz*u_sz*3])
        #[filter_depth, filter_height, filter_width, in_channels, out_channels]
        # NOTE: 99 = 2*49 + 1 is hard-coded, so this assumes v*u == 49.
        c1 = cnn_layer3D(xdc, [3,3,3,99,98], [98], 'c1', is_training, padding_type='SAME')
        c2 = cnn_layer3D(c1, [3,3,3,98,98], [98], 'c2', is_training, padding_type='SAME')
        c3 = cnn_layer3D(c2, [3,3,3,98,49], [49], 'c3', is_training, padding_type='SAME')
        c6 = cnn_layer3D_no_bn(c3, [3, 3, 3, v_sz*u_sz, v_sz*u_sz], [v_sz*u_sz], 'c6', padding_type='SAME')
        # o - b,3,h,w,49 ; back to [B, H, W, v, u, 3]
        c7 = tf.transpose(tf.reshape(c6, [b_sz, 3, y_sz, x_sz, v_sz, u_sz]), [0, 2, 3, 4, 5, 1])
        # Residual on top of the Lambertian light field, squashed to (0, 1).
        c8 = tf.sigmoid(c7 + shear)
        
        return tf.reshape(c8, [b_sz, y_sz, x_sz, v_sz, u_sz, 3])

###  **Pipeline**

In [0]:
def _lambertian_lf(x, ray_depths, lfsize):
    """Warp each colour channel of `x` by `ray_depths`; stack to [B, H, W, v, u, 3]."""
    channels = [depth_rendering(x[:, :, :, ch], ray_depths, lfsize) for ch in range(3)]
    return tf.stack(channels, axis=5)


#full forward model
def forward_model(x, xc, lf_batch, lfsize, shear_values, disp_mult, is_training):
    """Build the full pipeline: depth prediction, Lambertian rendering, refinement.

    The first call creates the 'forward_model' variables and also builds the refocused
    pair used by the refocusing loss. Later calls reuse those variables and return the
    Lambertian light field in place of the refocused pair.

    Args:
        x: input image [B, H, W, 3].
        xc: defocused image [B, H, W, 3].
        lf_batch: ground-truth light field [B, H, W, v, u, 3] (refocused on the first call).
        lfsize: [rows, cols, v, u].
        shear_values: per-example refocus shears (first call only).
        disp_mult: disparity bound for depth_network.
        is_training: batch-norm mode.

    Returns:
        (ray_depths, lf_shear, y, lf_ref, lf_shear_ref)
    """
    # Decide reuse explicitly. The old version caught the ValueError raised by
    # get_variable, which also swallowed unrelated graph-construction errors and
    # silently retried them with reuse=True.
    first_call = not tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='forward_model/')

    with tf.variable_scope('forward_model', reuse=None if first_call else True):
        #predict ray depths from input image and coded image
        ray_depths = depth_network(x, xc, lfsize, disp_mult, is_training, 'ray_depths')

        #shear input image by predicted ray depths to render Lambertian light field
        lf_shear = _lambertian_lf(x, ray_depths, lfsize)

        if first_call:
            #refocus rendered and ground-truth light fields by the same shear
            lf_shear_ref, _ = refocus(lf_shear, batchsize, shift_tf=shear_values)
            lf_ref, _ = refocus(lf_batch, batchsize, shift_tf=shear_values)
        else:
            #no refocusing term on reuse (validation): placeholders equal to lf_shear
            lf_shear_ref = lf_shear
            lf_ref = lf_shear

        #occlusion/non-Lambertian prediction network (no gradients into depth/render)
        d = tf.stop_gradient(ray_depths)
        lfs = tf.stop_gradient(lf_shear)
        y = occlusions_network3D_2(d, xc, lfs, lfsize, is_training, 'occlusions')

    return ray_depths, lf_shear, y, lf_ref, lf_shear_ref

###  **Refocusing light field**

In [0]:
def refocus_lf(lf, shifts, b_sz):
    """Refocus one colour channel of a light field by shearing every view.

    View (v, u) is resampled bilinearly at (y - v * shift, x - u * shift), with v and u
    measured as offsets from the central view. Out-of-range samples repeat the border.

    Args:
        lf: single-channel light field indexed [b, y, x, v, u]. The angular sizes come
            from the global ``lfsize`` (lfsize[2] for axis 3, lfsize[3] for axis 4).
        shifts: per-example shear broadcastable against ``lf``; ``refocus`` passes a
            [b_sz, 1, 1, 1, 1] tensor.
        b_sz: batch size as a Python int, used to build the sampling grid.

    Returns:
        Tensor with the same shape as ``lf`` holding the refocused light field.
    """
    with tf.variable_scope('refocus_lf'):
        y_sz = tf.shape(lf)[1]
        x_sz = tf.shape(lf)[2]
        # Axis 3 holds lfsize[2] views and axis 4 holds lfsize[3] (same order as
        # depth_network). The old u/v assignment was swapped and only worked for
        # square angular grids.
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        # Floor division keeps the central view at offset 0 on Python 2 and 3. Under
        # Python 3, 7/2 == 3.5 would shift every view by half a step.
        v_c = v_sz // 2
        u_c = u_sz // 2

        #create and reparameterize light field grid (angular coords centred)
        b_vals = tf.to_float(tf.range(b_sz))
        v_vals = tf.to_float(tf.range(v_sz)) - tf.to_float(v_c)
        u_vals = tf.to_float(tf.range(u_sz)) - tf.to_float(u_c)
        y_vals = tf.to_float(tf.range(y_sz))
        x_vals = tf.to_float(tf.range(x_sz))

        b, y, x, v, u = tf.meshgrid(b_vals, y_vals, x_vals, v_vals, u_vals, indexing='ij')
        #shear spatial coordinates proportionally to the view offset
        y_t = y - v * shifts * tf.ones_like(lf)
        x_t = x - u * shifts * tf.ones_like(lf)

        #angular coordinates are not warped; convert back to 0-based view indices
        v_r = v + tf.to_float(v_c)
        u_r = u + tf.to_float(u_c)

        #indices for linear interpolation
        b_1 = tf.to_int32(b)
        y_1 = tf.to_int32(tf.floor(y_t))
        y_2 = y_1 + 1
        x_1 = tf.to_int32(tf.floor(x_t))
        x_2 = x_1 + 1
        v_1 = tf.to_int32(v_r)
        u_1 = tf.to_int32(u_r)

        y_1 = tf.clip_by_value(y_1, 0, y_sz - 1)
        y_2 = tf.clip_by_value(y_2, 0, y_sz - 1)
        x_1 = tf.clip_by_value(x_1, 0, x_sz - 1)
        x_2 = tf.clip_by_value(x_2, 0, x_sz - 1)

        #assemble interpolation indices
        interp_pts_1 = tf.stack([b_1, y_1, x_1, v_1, u_1], -1)
        interp_pts_2 = tf.stack([b_1, y_2, x_1, v_1, u_1], -1)
        interp_pts_3 = tf.stack([b_1, y_1, x_2, v_1, u_1], -1)
        interp_pts_4 = tf.stack([b_1, y_2, x_2, v_1, u_1], -1)

        #gather light fields to be interpolated
        lf_1 = tf.gather_nd(lf, interp_pts_1)
        lf_2 = tf.gather_nd(lf, interp_pts_2)
        lf_3 = tf.gather_nd(lf, interp_pts_3)
        lf_4 = tf.gather_nd(lf, interp_pts_4)

        #bilinear weights (each axis' pair sums to 1)
        y_1_f = tf.to_float(y_1)
        x_1_f = tf.to_float(x_1)
        d_y_1 = 1.0 - (y_t - y_1_f)
        d_y_2 = 1.0 - d_y_1
        d_x_1 = 1.0 - (x_t - x_1_f)
        d_x_2 = 1.0 - d_x_1

        w1 = d_y_1 * d_x_1
        w2 = d_y_2 * d_x_1
        w3 = d_y_1 * d_x_2
        w4 = d_y_2 * d_x_2

        refocused = tf.add_n([w1 * lf_1, w2 * lf_2, w3 * lf_3, w4 * lf_4])

    return refocused

In [0]:
#render light field from input image and ray depths by backward warping
def depth_rendering(central, ray_depths, lfsize):
    """Render one colour channel of a Lambertian light field by backward warping.

    View (v, u) at pixel (y, x) samples `central` bilinearly at
    (y + v * d, x + u * d), where d = ray_depths[b, y, x, v, u]. The offsets v and u
    are measured from the central view. Out-of-range samples repeat the border.

    Args:
        central: single-channel image [B, H, W].
        ray_depths: per-ray disparities [B, H, W, lfsize[2], lfsize[3]].
        lfsize: [rows, cols, v, u]; only the angular sizes are used.

    Returns:
        Rendered light field [B, H, W, lfsize[2], lfsize[3]].
    """
    with tf.variable_scope('depth_rendering'):
        b_sz = tf.shape(central)[0]
        y_sz = tf.shape(central)[1]
        x_sz = tf.shape(central)[2]
        # Axis 3 holds lfsize[2] views and axis 4 holds lfsize[3], matching ray_depths.
        # The old u/v assignment was swapped and only worked for square grids.
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        # `//` keeps the central view at offset 0 on Python 2 and 3 (`/` gives 3.5
        # under Python 3, which would mis-centre every view).
        v_c = v_sz // 2
        u_c = u_sz // 2

        # Singleton angular axes so gather_nd can address [b, y, x, 0, 0].
        central = tf.expand_dims(tf.expand_dims(central, 3), 4)

        #create and reparameterize light field grid (angular coords centred)
        b_vals = tf.to_float(tf.range(b_sz))
        v_vals = tf.to_float(tf.range(v_sz)) - tf.to_float(v_c)
        u_vals = tf.to_float(tf.range(u_sz)) - tf.to_float(u_c)
        y_vals = tf.to_float(tf.range(y_sz))
        x_vals = tf.to_float(tf.range(x_sz))

        b, y, x, v, u = tf.meshgrid(b_vals, y_vals, x_vals, v_vals, u_vals, indexing='ij')

        #warp coordinates by ray depths
        y_t = y + v * ray_depths
        x_t = x + u * ray_depths

        # The source has a single view, so the angular gather index is always 0.
        v_r = tf.zeros_like(b)
        u_r = tf.zeros_like(b)

        #indices for linear interpolation
        b_1 = tf.to_int32(b)
        y_1 = tf.to_int32(tf.floor(y_t))
        y_2 = y_1 + 1
        x_1 = tf.to_int32(tf.floor(x_t))
        x_2 = x_1 + 1
        v_1 = tf.to_int32(v_r)
        u_1 = tf.to_int32(u_r)

        y_1 = tf.clip_by_value(y_1, 0, y_sz - 1)
        y_2 = tf.clip_by_value(y_2, 0, y_sz - 1)
        x_1 = tf.clip_by_value(x_1, 0, x_sz - 1)
        x_2 = tf.clip_by_value(x_2, 0, x_sz - 1)

        #assemble interpolation indices
        interp_pts_1 = tf.stack([b_1, y_1, x_1, v_1, u_1], -1)
        interp_pts_2 = tf.stack([b_1, y_2, x_1, v_1, u_1], -1)
        interp_pts_3 = tf.stack([b_1, y_1, x_2, v_1, u_1], -1)
        interp_pts_4 = tf.stack([b_1, y_2, x_2, v_1, u_1], -1)

        #gather light fields to be interpolated
        lf_1 = tf.gather_nd(central, interp_pts_1)
        lf_2 = tf.gather_nd(central, interp_pts_2)
        lf_3 = tf.gather_nd(central, interp_pts_3)
        lf_4 = tf.gather_nd(central, interp_pts_4)

        #bilinear weights
        y_1_f = tf.to_float(y_1)
        x_1_f = tf.to_float(x_1)
        d_y_1 = 1.0 - (y_t - y_1_f)
        d_y_2 = 1.0 - d_y_1
        d_x_1 = 1.0 - (x_t - x_1_f)
        d_x_2 = 1.0 - d_x_1

        w1 = d_y_1 * d_x_1
        w2 = d_y_2 * d_x_1
        w3 = d_y_1 * d_x_2
        w4 = d_y_2 * d_x_2

        lf = tf.add_n([w1 * lf_1, w2 * lf_2, w3 * lf_3, w4 * lf_4])

    return lf

In [0]:
#resample ray depths for depth consistency regularization
def transform_ray_depths(ray_depths, u_step, v_step, lfsize):
    """Resample ray depths from a neighbouring view for the depth-consistency loss.

    Output[b, y, x, v, u] is `ray_depths` bilinearly sampled at
    (y + v_step * d, x + u_step * d) in view (v - v_step, u - u_step), where
    d = ray_depths[b, y, x, v, u]. Angular and spatial indices are clamped to the grid.

    Args:
        ray_depths: per-ray disparities [B, H, W, lfsize[2], lfsize[3]].
        u_step, v_step: angular displacement (the callers in this notebook use 0.0 or 1.0).
        lfsize: [rows, cols, v, u]; only the angular sizes are used.

    Returns:
        Tensor with the same shape as `ray_depths`.
    """
    with tf.variable_scope('transform_ray_depths'):
        b_sz = tf.shape(ray_depths)[0]
        y_sz = tf.shape(ray_depths)[1]
        x_sz = tf.shape(ray_depths)[2]
        # Axis 3 holds lfsize[2] views and axis 4 holds lfsize[3], matching ray_depths.
        # The old u/v assignment was swapped and only worked for square grids.
        v_sz = lfsize[2]
        u_sz = lfsize[3]
        # `//` keeps the central view at offset 0 on Python 2 and 3.
        v_c = v_sz // 2
        u_c = u_sz // 2

        #create and reparameterize light field grid (angular coords centred)
        b_vals = tf.to_float(tf.range(b_sz))
        v_vals = tf.to_float(tf.range(v_sz)) - tf.to_float(v_c)
        u_vals = tf.to_float(tf.range(u_sz)) - tf.to_float(u_c)
        y_vals = tf.to_float(tf.range(y_sz))
        x_vals = tf.to_float(tf.range(x_sz))

        b, y, x, v, u = tf.meshgrid(b_vals, y_vals, x_vals, v_vals, u_vals, indexing='ij')

        #warp spatial coordinates by ray depths
        y_t = y + v_step * ray_depths
        x_t = x + u_step * ray_depths

        #source view indices (0-based), shifted by the angular step
        v_t = v - v_step + tf.to_float(v_c)
        u_t = u - u_step + tf.to_float(u_c)

        #indices for linear interpolation
        b_1 = tf.to_int32(b)
        y_1 = tf.to_int32(tf.floor(y_t))
        y_2 = y_1 + 1
        x_1 = tf.to_int32(tf.floor(x_t))
        x_2 = x_1 + 1
        v_1 = tf.to_int32(v_t)
        u_1 = tf.to_int32(u_t)

        y_1 = tf.clip_by_value(y_1, 0, y_sz - 1)
        y_2 = tf.clip_by_value(y_2, 0, y_sz - 1)
        x_1 = tf.clip_by_value(x_1, 0, x_sz - 1)
        x_2 = tf.clip_by_value(x_2, 0, x_sz - 1)
        v_1 = tf.clip_by_value(v_1, 0, v_sz - 1)
        u_1 = tf.clip_by_value(u_1, 0, u_sz - 1)

        #assemble interpolation indices
        interp_pts_1 = tf.stack([b_1, y_1, x_1, v_1, u_1], -1)
        interp_pts_2 = tf.stack([b_1, y_2, x_1, v_1, u_1], -1)
        interp_pts_3 = tf.stack([b_1, y_1, x_2, v_1, u_1], -1)
        interp_pts_4 = tf.stack([b_1, y_2, x_2, v_1, u_1], -1)

        #gather depths to be interpolated
        lf_1 = tf.gather_nd(ray_depths, interp_pts_1)
        lf_2 = tf.gather_nd(ray_depths, interp_pts_2)
        lf_3 = tf.gather_nd(ray_depths, interp_pts_3)
        lf_4 = tf.gather_nd(ray_depths, interp_pts_4)

        #bilinear weights
        y_1_f = tf.to_float(y_1)
        x_1_f = tf.to_float(x_1)
        d_y_1 = 1.0 - (y_t - y_1_f)
        d_y_2 = 1.0 - d_y_1
        d_x_1 = 1.0 - (x_t - x_1_f)
        d_x_2 = 1.0 - d_x_1

        w1 = d_y_1 * d_x_1
        w2 = d_y_2 * d_x_1
        w3 = d_y_1 * d_x_2
        w4 = d_y_2 * d_x_2

        lf = tf.add_n([w1 * lf_1, w2 * lf_2, w3 * lf_3, w4 * lf_4])

    return lf

###  **Loss functions**

In [0]:
#loss to encourage consistency of ray depths corresponding to same scene point
def fn_depth_consistency_loss(x, lfsize):
    """Mean L1 gap between ray depths and their re-projections from neighbouring views.

    Each depth is compared with the depth resampled from the view one step back in u,
    in v, and diagonally. The first angular row/column is excluded because its
    neighbour index falls outside the grid and is clamped.
    """
    from_u = transform_ray_depths(x, 1.0, 0.0, lfsize)
    from_v = transform_ray_depths(x, 0.0, 1.0, lfsize)
    from_uv = transform_ray_depths(x, 1.0, 1.0, lfsize)

    inner = x[:, :, :, 1:, 1:]
    gap_u = tf.abs(inner - from_u[:, :, :, 1:, 1:])
    gap_v = tf.abs(inner - from_v[:, :, :, 1:, 1:])
    gap_uv = tf.abs(inner - from_uv[:, :, :, 1:, 1:])
    return tf.reduce_mean(gap_u + gap_v + gap_uv)

def gradient(img):
    """Backward-minus-forward neighbour differences along axes 2 and 1.

    Args:
        img: rank-4 array/tensor indexed [n, y, x, c].

    Returns:
        (gx, gy): gx = img[:, :, x] - img[:, :, x+1] with shape [n, H, W-1, c];
        gy = img[:, y] - img[:, y+1] with shape [n, H-1, W, c].
    """
    left, right = img[:, :, :-1, :], img[:, :, 1:, :]
    top, bottom = img[:, :-1, :, :], img[:, 1:, :, :]
    return left - right, top - bottom

#spatial TV loss (l1 of spatial derivatives)
def fn_tv_loss(x):
    """Anisotropic total-variation penalty over the two spatial axes.

    Forward differences along axis 1 (dy) and axis 2 (dx) are taken on the common
    (H-1) x (W-1) support, and the mean of |dy| + |dx| is returned. The extent now
    comes from `x` itself instead of the global `patchsize`. The old version silently
    ignored pixels when `x` was larger than the patch and failed with a shape
    mismatch when it was smaller. Results are unchanged for patch-sized inputs.

    Args:
        x: rank-5 tensor indexed [b, y, x, v, u] (e.g. ray depths).
    """
    base = x[:, :-1, :-1, :, :]
    dy = x[:, 1:, :-1, :, :] - base
    dx = x[:, :-1, 1:, :, :] - base
    return tf.reduce_mean(tf.abs(dy) + tf.abs(dx))

def fn_grad_loss(depths, lf):
    """Edge-aware smoothness penalty on ray depths.

    Every view's depth gradients are weighted by exp(-mean |colour gradient|) of the
    matching light-field view, so depth changes are penalised less across image edges.

    Args:
        depths: ray depths [B, H, W, v, u].
        lf: light field [B, H, W, v, u, 3].

    Returns:
        Scalar mean of the weighted |dx| + |dy| over the common (H-1) x (W-1) support.
    """
    b_sz = tf.shape(depths)[0]
    y_sz = tf.shape(depths)[1]
    x_sz = tf.shape(depths)[2]
    n_views = lfsize[2] * lfsize[3]

    # Fold the views into the batch axis: [B*v*u, H, W, c]
    per_view_depth = tf.expand_dims(
        tf.reshape(tf.transpose(depths, [0, 3, 4, 1, 2]), [b_sz * n_views, y_sz, x_sz]), 3)
    per_view_rgb = tf.reshape(tf.transpose(lf, [0, 3, 4, 1, 2, 5]),
                              [b_sz * n_views, y_sz, x_sz, 3])

    depth_dx, depth_dy = gradient(per_view_depth)
    rgb_dx, rgb_dy = gradient(per_view_rgb)
    edge_wx = tf.exp(-tf.reduce_mean(tf.abs(rgb_dx), 3, keep_dims=True))
    edge_wy = tf.exp(-tf.reduce_mean(tf.abs(rgb_dy), 3, keep_dims=True))

    term_x = edge_wx * tf.abs(depth_dx)
    term_y = edge_wy * tf.abs(depth_dy)
    return tf.reduce_mean(term_x[:, :-1, :, :] + term_y[:, :, :-1, :])

#identity normalisation hook (light fields stay in [0, 1])
def normalize_lf(lf):
    """Return `lf` unchanged.

    The previous comment claimed this rescales [0, 1] input to [-1, 1], but that
    mapping, 2.0 * (lf - 0.5), is disabled. Light fields therefore keep the [0, 1]
    range produced by process_lf's clip. If the rescale is re-enabled, check the code
    that assumes [0, 1] values, such as the sigmoid outputs of the occlusion networks.
    """
    return lf

In [0]:
def get_corners(data):
    """Gather the four corner views of a light field into a 2x2 angular grid.

    Args:
        data: rank-6 tensor [B, H, W, v, u, C].

    Returns:
        tf.squeeze of a [B, H, W, 2, 2, C] tensor laid out
        [[v0u0, v0u_last], [v_last u0, v_last u_last]]. The squeeze drops every
        size-1 axis, e.g. C == 1 for ray depths, but also B == 1.
    """
    b_sz = tf.shape(data)[0]
    y_sz = tf.shape(data)[1]
    x_sz = tf.shape(data)[2]
    c_sz = tf.shape(data)[5]

    # Index the last view with -1 instead of a hard-coded 6 so any angular size works.
    tl = data[:, :, :, :1, :1, :]
    bl = data[:, :, :, -1:, :1, :]
    tr = data[:, :, :, :1, -1:, :]
    br = data[:, :, :, -1:, -1:, :]
    cat_data = tf.concat([tl, tr, bl, br], 3)
    cat_data = tf.reshape(cat_data, [b_sz, y_sz, x_sz, 2, 2, c_sz])
    return tf.squeeze(cat_data)

def get_shear_vals(shift):
    """Draw one random refocus shear per batch element with NumPy (graph-build time).

    With probability 0.56 the shear comes from U(-1.15, -0.81), otherwise from
    U(0.35, 0.55). Only ``shift.shape[0]`` is read; the values of `shift` are ignored.

    Returns:
        float32 tensor of shape [shift.shape[0]].
    """
    def _draw_one():
        if np.random.uniform(0, 1, 1)[0] < 0.56:
            return np.random.uniform(-1.15, -0.81, 1)[0]
        return np.random.uniform(0.35, 0.55, 1)[0]

    samples = [_draw_one() for _ in range(shift.shape[0])]
    return tf.convert_to_tensor(samples, dtype=tf.float32)


def get_rival_shifts(shift):
    """Pick, per element, a random shear on the opposite side of the threshold -0.40.

    Shifts below -0.40 map to U(0.65, 1.08); all others map to U(-0.58, -0.43).
    Sampling uses NumPy, so values are fixed when the graph is built.

    Args:
        shift: 1-D NumPy array of shears (e.g. the array returned by ``refocus``).

    Returns:
        float32 tensor with one rival shear per element.
    """
    rival = np.zeros([shift.shape[0]])
    for idx, value in enumerate(shift):
        if value < -0.40:
            rival[idx] = np.random.uniform(0.65, 1.08, 1)[0]
        else:
            rival[idx] = np.random.uniform(-0.58, -0.43, 1)[0]
    return tf.convert_to_tensor(rival, dtype=tf.float32)

In [0]:
# input pipeline
def refocus(lightfield, b_sz, shift_tf=None):
    """Refocus an RGB light field with one shear per batch element.

    Args:
        lightfield: rank-6 tensor [b_sz, H, W, v, u, 3].
        b_sz: static batch size.
        shift_tf: optional per-example shears (anything reshapeable to [b_sz]). When
            None, shears are sampled with NumPy at graph-build time: with probability
            0.851 from U(-0.55, -0.08), otherwise from U(0.04, 0.15).

    Returns:
        (lf_refocus, shift): the refocused [b_sz, H, W, v, u, 3] tensor, plus the NumPy
        array of sampled shears, or a list of b_sz Nones when `shift_tf` was given.
    """
    if shift_tf is not None:
        shift = [None] * b_sz
    else:
        shift = np.zeros([b_sz])
        for k in range(b_sz):
            if np.random.uniform(0, 1, 1)[0] < 0.851:
                shift[k] = np.random.uniform(-0.55, -0.08, 1)[0]
            else:
                shift[k] = np.random.uniform(0.04, 0.15, 1)[0]
        shift_tf = tf.convert_to_tensor(shift, dtype=tf.float32)

    shear = tf.reshape(shift_tf, [b_sz, 1, 1, 1, 1])
    per_channel = [refocus_lf(lightfield[:, :, :, :, :, ch], shear, b_sz) for ch in range(3)]
    return tf.stack(per_channel, axis=5), shift

def process_lf(lf, num_crops, lfsize, patchsize, prob_refocus):
    """Turn one decoded light-field image into `num_crops` random training crops.

    Args:
        lf: decoded uint16 image holding interleaved 14x14 angular samples per pixel,
            i.e. [lfsize[0]*14, lfsize[1]*14, 3] after the initial slice (see reshape).
        num_crops: number of spatial crops to emit.
        lfsize: [rows, cols, v, u]; a v x u angular window is kept.
        patchsize: [rows, cols] of each crop.
        prob_refocus: probability that a crop is refocused by a random shear. The coin
            is flipped with NumPy, so each crop's choice is fixed when the graph is built.

    Returns:
        (aif_list, lf_list, shift_list): per crop, the central view [p0, p1, 3], the
        light field [p0, p1, v, u, 3], and a shape-[1] float32 shear. The shear is the
        rival shear for refocused crops and -0.61 otherwise.
    """
    n_ang = 14     # angular samples per pixel in the raw image
    vsz = 8        # centre angular window kept before the random 7x7 pick
    usz = 8
    margin = 10    # border added around refocused crops and trimmed afterwards

    lf = tf.to_float(lf[:lfsize[0]*n_ang, :lfsize[1]*n_ang, :])/65535.0

    # Gamma 1/1.5 followed by a 1.5x boost of the HSV saturation channel.
    lf = tf.image.rgb_to_hsv(tf.pow(lf, 1/1.5))
    lf = tf.concat([lf[:, :, 0:1], lf[:, :, 1:2]*1.5, lf[:, :, 2:3]], axis=2)
    lf = tf.image.hsv_to_rgb(lf)
    lf = tf.clip_by_value(lf, 0.0, 1.0)

    lf = normalize_lf(lf)
    lf = tf.transpose(tf.reshape(lf, [lfsize[0], n_ang, lfsize[1], n_ang, 3]), [0, 2, 1, 3, 4], name='process')
    # Floor division: with `/` these slice bounds are floats (a TypeError) under Python 3.
    ang_c = n_ang // 2
    lf = lf[:, :, ang_c - vsz//2:ang_c + vsz//2, ang_c - usz//2:ang_c + usz//2, :]

    # Random lfsize[2] x lfsize[3] (7x7) angular window inside the 8x8 centre.
    su = np.random.randint(0, 2, 1)[0]
    sv = np.random.randint(0, 2, 1)[0]
    lf = lf[:, :, su:su+lfsize[2], sv:sv+lfsize[3], :]

    v_c = lfsize[2] // 2
    u_c = lfsize[3] // 2
    aif = lf[:, :, v_c, u_c, :]
    aif_list = []
    lf_list = []
    shift_list = []
    for _ in range(num_crops):
        prefocus = np.random.uniform(0, 1, 1)[0]
        if prefocus < prob_refocus:
            # tf.random_uniform's maxval is exclusive; +1 makes the last valid offset
            # reachable and avoids an empty range when the image is exactly crop-sized.
            big_r = patchsize[0] + 2*margin
            big_c = patchsize[1] + 2*margin
            r = tf.random_uniform(shape=[], minval=0, maxval=tf.shape(lf)[0]-big_r+1, dtype=tf.int32)
            c = tf.random_uniform(shape=[], minval=0, maxval=tf.shape(lf)[1]-big_c+1, dtype=tf.int32)
            patch_lf = lf[r:r+big_r, c:c+big_c, :, :, :]

            ref_patch_lf, shift = refocus(tf.expand_dims(patch_lf, 0), 1)
            ref_patch_lf = tf.squeeze(ref_patch_lf)

            new_shift = get_rival_shifts(shift)

            # Trim the margin that absorbed refocusing border effects.
            lf_list.append(ref_patch_lf[margin:-margin, margin:-margin, :, :, :])
            aif_list.append(ref_patch_lf[margin:-margin, margin:-margin, v_c, u_c, :])
            shift_list.append(new_shift)
        else:
            r = tf.random_uniform(shape=[], minval=0, maxval=tf.shape(lf)[0]-patchsize[0]+1, dtype=tf.int32)
            c = tf.random_uniform(shape=[], minval=0, maxval=tf.shape(lf)[1]-patchsize[1]+1, dtype=tf.int32)
            lf_list.append(lf[r:r+patchsize[0], c:c+patchsize[1], :, :, :])
            aif_list.append(aif[r:r+patchsize[0], c:c+patchsize[1], :])

            shift_list.append(tf.convert_to_tensor([-0.61], dtype=tf.float32))

    return aif_list, lf_list, shift_list

def read_lf(filename_queue, num_crops, lfsize, Test, patchsize, prob_refocus):
    """Read one 16-bit RGB PNG named by the queue slice and cut it into crops.

    `Test` is accepted for interface compatibility and is not used.

    Returns:
        (aif_list, lf_list, shift_list) as produced by process_lf.
    """
    png_bytes = tf.read_file(filename_queue[0])
    decoded = tf.image.decode_png(png_bytes, channels=3, dtype=tf.uint16)
    return process_lf(decoded, num_crops, lfsize, patchsize, prob_refocus)

def input_pipeline(filenames, lfsize, patchsize, batchsize, num_crops, m, n, c, prob_refocus, Test=False):
    """Build a shuffled batch queue of (central view, light field, shear) crops.

    Args:
        filenames: light-field PNG paths, sliced in shuffled order.
        lfsize, patchsize, num_crops, prob_refocus: forwarded to read_lf / process_lf.
        batchsize: examples per dequeued batch.
        m: ``min_after_dequeue`` of the shuffling queue.
        n: number of reader subgraphs (threads) populating the queue.
        c: queue ``capacity``.
        Test: forwarded to read_lf, which does not use it.

    Returns:
        (aif_batch, lf_batch, shift_list) with per-example shapes [p0, p1, 3],
        [p0, p1, lfsize[2], lfsize[3], 3] and [1].
    """
    queue_item = tf.train.slice_input_producer([filenames], shuffle=True)
    readers = [read_lf(queue_item, num_crops, lfsize, Test, patchsize, prob_refocus)
               for _reader in range(n)]
    element_shapes = [[patchsize[0], patchsize[1], 3],
                      [patchsize[0], patchsize[1], lfsize[2], lfsize[3], 3],
                      [1]]
    aif_batch, lf_batch, shift_list = tf.train.shuffle_batch_join(
        readers, batch_size=batchsize, capacity=c, min_after_dequeue=m,
        enqueue_many=True, shapes=element_shapes)
    return aif_batch, lf_batch, shift_list

###  **Loading light field data**

In [0]:
# Root of the training/validation data. Set the ULF_DATA_ROOT environment variable to
# point elsewhere instead of editing the absolute path (the default is unchanged).
data_root = os.environ.get('ULF_DATA_ROOT', '/media/flash/ExTra/ulf_focdef/data')


def list_lf_files(directory):
    """Return full paths of the non-hidden entries of `directory`, in sorted order."""
    return [os.path.join(directory, f) for f in sorted(os.listdir(directory)) if not f.startswith('.')]


#path to training examples (trailing '' keeps the original trailing slash)
train_path = os.path.join(data_root, 'TrainingSet', 'OURS', '')
train_filenames = list_lf_files(train_path)

#paths to validation examples
val_path2 = os.path.join(data_root, 'TestSet', 'PAPER', '')
val_path = os.path.join(data_root, 'TestSet', 'EXTRA', '')
val_filenames = list_lf_files(val_path) + list_lf_files(val_path2)

#loading light field data from files
aif_batch, lf_batch, shear_values = input_pipeline(train_filenames, lfsize, patchsize, batchsize, num_crops,
                                                   m=30, n=5, c=88, prob_refocus=.15)
vaif_batch, vlf_batch, vshear_values = input_pipeline(val_filenames, lfsize, test_patchsize, test_batchsize, test_num_crops,
                                                      m=0, n=2, c=10, prob_refocus=0.0011)

is_training = tf.placeholder(tf.bool, [])
lr = tf.placeholder(tf.float32) #learning rate

###  **Forward pass of the pipeline**

In [0]:
# The angular blur kernels were built as NumPy arrays; convert them for use in the graph.
blur_code = tf.convert_to_tensor(blur_code, dtype=tf.float32)
test_code = tf.convert_to_tensor(test_code, dtype=tf.float32)

# Simulated defocused images: kernel-weighted sum of each light field over its two
# angular axes, [B, H, W, v, u, 3] -> [B, H, W, 3].
lfc = tf.reduce_sum(lf_batch * blur_code, axis=[-2, -3])
vlfc = tf.reduce_sum(vlf_batch * test_code, axis=[-2, -3])

# Training graph first: this call creates the forward_model variables.
ray_depths, lf_shear, y, lf_ref, lf_shear_ref = forward_model(
    aif_batch, lfc, lf_batch, lfsize, shear_values, disp_mult, is_training)

# Validation graph second: reuses the variables created above.
vray_depths, vlf_shear, vy, _, _ = forward_model(
    vaif_batch, vlfc, vlf_batch, lfsize, vshear_values, disp_mult, is_training)

###  **Losses**

In [0]:
# weights of the auxiliary terms in the training objective
lam_tv = 0.01    # weight on fn_tv_loss(ray_depths)
lam_dc = 0.0065  # weight on fn_depth_consistency_loss(ray_depths, lfsize)
lam_gr = 0.01    # weight on fn_grad_loss(ray_depths, lf_batch)


def _mean_abs_diff(a, b):
    """Mean absolute (L1) difference between two tensors."""
    return tf.reduce_mean(tf.abs(a - b))


# training objective
with tf.name_scope('loss'):
    shear_loss = _mean_abs_diff(lf_shear, lf_batch)
    ref_loss = _mean_abs_diff(lf_shear_ref, lf_ref)
    output_loss = _mean_abs_diff(y, lf_batch)

    tv_loss = lam_tv * fn_tv_loss(ray_depths)
    grad_loss = lam_gr * fn_grad_loss(ray_depths, lf_batch)
    depth_consistency_loss = lam_dc * fn_depth_consistency_loss(ray_depths, lfsize)

    # NOTE(review): ref_loss is added here without a lambda weight — confirm intended
    regu_loss = tv_loss + depth_consistency_loss + grad_loss + ref_loss
    train_loss = shear_loss + output_loss + regu_loss

# validation metric: data terms only, no regularizers
with tf.name_scope('vloss'):
    vshear_loss = _mean_abs_diff(vlf_shear, vlf_batch)
    voutput_loss = _mean_abs_diff(vy, vlf_batch)
    val_loss = vshear_loss + voutput_loss

with tf.name_scope('train'):
    train_step = tf.train.AdamOptimizer(learning_rate=lr).minimize(train_loss)

###  **Generating summaries**

In [0]:
# NOTE(review): corner1 / corner2 are not referenced in this cell; get_corners may read
# them as globals — confirm before removing
corner1 = tf.convert_to_tensor(np.array([0, 0, 6, 6]).reshape(-1, 1), dtype=tf.int32)
corner2 = tf.convert_to_tensor(np.array([0, 6, 0, 6]).reshape(-1, 1), dtype=tf.int32)

corner_rays = get_corners(tf.expand_dims(ray_depths, 5))
corner_lfbatch = get_corners(lf_batch)
corner_lfshear = get_corners(lf_shear)
corner_lfy = get_corners(y)

# scalar loss curves (the 'grd_loss' tag is kept as-is so existing TensorBoard runs line up)
for tag, loss_tensor in [('shear_loss', shear_loss),
                         ('ref_loss', ref_loss),
                         ('output_loss', output_loss),
                         ('tv_loss', tv_loss),
                         ('grd_loss', grad_loss),
                         ('depth_consistency_loss', depth_consistency_loss),
                         ('train_loss', train_loss)]:
    tf.summary.scalar(tag, loss_tensor)

tf.summary.histogram('ray_depths', ray_depths)
tf.summary.histogram('shear_vals', shear_values)

# first two examples of the batch: the aif_batch input, the coded image (views weighted by
# the numpy defocus_code and summed over the view axes), and view (3, 3) of ray_depths
# clipped to [-1.25, 1.25]
tf.summary.image('input_image', aif_batch[0:2])
tf.summary.image('coded_image', tf.reduce_sum(lf_batch[0:2] * defocus_code, [-2, -3]))
tf.summary.image('disp_image',
                 tf.reshape(tf.clip_by_value(ray_depths[0:2, :, :, 3, 3], -1.25, 1.25),
                            [2, patchsize[0], patchsize[1], 1]))


def _corner_mosaic_summary(tag, corners, perm, channels):
    """Image summary of the first two examples, transposed by `perm` and reshaped to a (2*H, 2*W) mosaic."""
    mosaic = tf.reshape(tf.transpose(corners[0:2, ...], perm=perm),
                        [2, patchsize[0] * 2, patchsize[1] * 2, channels])
    tf.summary.image(tag, mosaic)


# NOTE(review): presumably tiles the four views returned by get_corners into a 2x2 grid — confirm
_corner_mosaic_summary('lf_rays', corner_rays, [0, 3, 1, 4, 2], 1)
_corner_mosaic_summary('lf_gt', corner_lfbatch, [0, 3, 1, 4, 2, 5], 3)
_corner_mosaic_summary('lf_shear', corner_lfshear, [0, 3, 1, 4, 2, 5], 3)
_corner_mosaic_summary('lf_y', corner_lfy, [0, 3, 1, 4, 2, 5], 3)

merged = tf.summary.merge_all()


###  **Training**

In [0]:
logdir = '/media/flash/ExTra/ulf_focdef/logs/' #path to store logs
checkpointdir = '/media/flash/ExTra/ulf_focdef/logs/ckpt/' #path to store checkpoints
vf = '/media/flash/ExTra/ulf_focdef/logs/val_results.txt' #path to store validation loss

summary_every = 50      # iterations between printed losses and TensorBoard summaries
checkpoint_every = 500  # iterations between checkpoints
validate_every = 1000   # iterations between validation passes
num_val_batches = 160   # NOTE(review): the hyperparameter cell defines test_iters = 75, which is unused — confirm which count is intended

# tf.train.Saver.save() does not create checkpointdir, and open(vf, 'a') below runs before
# the FileWriter would create logdir, so create both up front.
# (os.makedirs without exist_ok keeps this runnable under Python 2.)
for d in (logdir, checkpointdir):
    if not os.path.isdir(d):
        os.makedirs(d)

# NOTE(review): 0.3 is appended at the start of every run, presumably as a separator
# between runs in this append-only log — confirm
with open(vf, 'a') as fid:
    fid.write(str(.3) + '\n')


def run_validation(sess):
    """Return the means of (val_loss, vshear_loss, voutput_loss) over num_val_batches validation batches."""
    vloss, sloss, oloss = [], [], []
    for j in range(num_val_batches):
        vl = sess.run([val_loss, vshear_loss, voutput_loss], feed_dict={is_training: False})
        vloss.append(vl[0])
        sloss.append(vl[1])
        oloss.append(vl[2])
        if (j + 1) % 40 == 0:
            print('  validation batch %d/%d' % (j + 1, num_val_batches))
    return np.mean(vloss), np.mean(sloss), np.mean(oloss)


config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Session(config=config) as sess:
    train_writer = tf.summary.FileWriter(logdir, sess.graph)
    sess.run(tf.global_variables_initializer())  # initialize variables (comment out if restoring from trained model)

    saver = tf.train.Saver()
    #saver.restore(sess, '/media/data/susmitha/ucsd/dila_defgtcv/checkpoints/2-occ/hsv_occ/model.ckpt-1999') # restore trained model

    print('training')
    coord = tf.train.Coordinator()  # coordinator for input queue threads
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)  # start input queue threads
    train_feed = {is_training: True, lr: learning_rate}
    try:
        for i in range(train_iters):
            step = i + 1

            # training step
            if step % summary_every == 0:
                # Fetch the summaries in the same run as the update: a separate
                # sess.run(merged) would dequeue (and discard) an extra batch and log
                # summaries for data the model never trained on.
                s_loss, o_loss, _, trainsummary = sess.run(
                    [shear_loss, output_loss, train_step, merged], feed_dict=train_feed)
                print('%s %s %s %s' % (i, s_loss, o_loss, o_loss / s_loss))
                train_writer.add_summary(trainsummary, i)
            else:
                sess.run(train_step, feed_dict=train_feed)

            # save checkpoint
            if step % checkpoint_every == 0:
                saver.save(sess, checkpointdir + 'model.ckpt', global_step=i)

            # validation pass; appends "val shear output" means to vf
            if step % validate_every == 0:
                print('testing')
                mean_v, mean_s, mean_o = run_validation(sess)
                print('%s %s %s' % (mean_s, mean_v, mean_o))
                with open(vf, 'a') as fid:
                    fid.write('%s %s %s\n' % (mean_v, mean_s, mean_o))
    finally:
        # stop the input queue threads and flush summaries even if training is interrupted
        coord.request_stop()
        coord.join(threads)
        train_writer.close()
