Skip to content

Commit

Permalink
changed patch_size and batch_size to variables
Browse files Browse the repository at this point in the history
To enable the use of different models
Also SPCI models are now runnable
  • Loading branch information
TalSchuster committed Nov 16, 2016
1 parent 7bdfcda commit 74b05bf
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
22 changes: 13 additions & 9 deletions patchbatch.py
Expand Up @@ -14,17 +14,15 @@
import cPickle as pickle

DEBUG = False
patch_size = 51
batch_size = 256

def calc_descs(img1_filename, img2_filename, model_name):
def calc_descs(img1_filename, img2_filename, model_name, patch_size, batch_size):
""" given two image files and a CNN model name, calculates the dense descriptor tensor of both images
img1_filename - full path of source image
img2_filename - full path of target image
model_name - name of the trained CNN to use, out of the supported models, documented in pb_Models """

net_name, weights_filename, eparams_filename = Models.nets[model_name]
nn_model, theano_func = NN.get_net_and_funcs(net_name, batch_size, weights_filename, eparams_filename)
nn_model, theano_func = NN.get_net_and_funcs(net_name, patch_size, batch_size, weights_filename, eparams_filename)

img_descs = []
for img_filename in [img1_filename, img2_filename]:
Expand Down Expand Up @@ -137,7 +135,7 @@ def calc_flow_and_cost(img1_descs, img2_descs, pm_params, eliminate_bidi_errors

return flow_res, cost_res

def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=False):
def calc_flow(img1_filename, img2_filename, model_name, output_filename, patch_size=51, batch_size=256, bidi=False):
""" Given two input filenames, model_name and output_filename return+save flow res
img1_filename - filename of source image
img2_filename - filename of target image
Expand All @@ -154,7 +152,7 @@ def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=Fa
pm_params = (2, 20, 10, 10)

print 'Calculating descriptors...'
img_descs = calc_descs(img1_filename, img2_filename, model_name)
img_descs = calc_descs(img1_filename, img2_filename, model_name, patch_size, batch_size)
print 'Calculating flow fields and matching cost'
flow_res, cost_res = calc_flow_and_cost(img_descs[0], img_descs[1], pm_params, bidi)

Expand All @@ -168,8 +166,7 @@ def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=Fa
return flow_res


if __name__ == '__main__':

def main(patch_size=51, batch_size=256):
parser = argparse.ArgumentParser(description = 'PatchBatch Optical Flow algorithm')
parser.add_argument('img1_filename', help = 'Filename (+path) of the source image')
parser.add_argument('img2_filename', help = 'Filename (+path) of the target image')
Expand All @@ -192,6 +189,9 @@ def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=Fa
print 'Error! Unsupported pm_params'
sys.exit()

if 'SPCI' in parser.model_name:
patch_size = 71
batch_size = 255

#model_name = 'KITTI2012_CENTSD_ACCURATE'
#img1_filename = '/home/MAGICLEAP/dgadot/patchflow_data/training/image_0/000000_10.png'
Expand All @@ -206,7 +206,7 @@ def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=Fa
print 'DEBUG mode is', DEBUG

print 'Calculating descriptors...'
img_descs = calc_descs(parser.img1_filename, parser.img2_filename, parser.model_name)
img_descs = calc_descs(parser.img1_filename, parser.img2_filename, parser.model_name, patch_size, batch_size)
print 'Calculating flow fields and matching cost'
flow_res, cost_res = calc_flow_and_cost(img_descs[0], img_descs[1], pm_params, parser.bidi)

Expand All @@ -224,3 +224,7 @@ def calc_flow(img1_filename, img2_filename, model_name, output_filename, bidi=Fa
pickle.dump(img_descs, f)

kittitool.flow_visualize(flow_res, mode='Y')


if __name__ == '__main__':
main()
25 changes: 19 additions & 6 deletions pb_Models.py
Expand Up @@ -7,7 +7,6 @@

leaky_param = 0.1
in_channels = 1
patch_size = 51
border_mode = 'valid'

cur_dir = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -21,11 +20,11 @@
cur_dir + '/weights/KITTI2012_ACCURATE/241015_080511PAPERKITTI2012-model_drlim7_33conv_allconv_neg1_8_m100_epoch4000_adadelta_testsamples800k_impdrlimv3.yaml-best-test-weights.pickle',
cur_dir + '/weights/KITTI2012_ACCURATE/241015_080511-eparams-test.pickle'],

'KITTI2015_SPCI' : ['model_CENTSD_33conv',
'KITTI2015_SPCI' : ['model_CENTSD_33conv_elu',
cur_dir + '/weights/KITTI2015_SPCI/111016_151719PAPERKITTI2015-model_drlim7_33conv_elu_allconv_neg1_8_m100_epoch4000_adadelta_testsamples200k_hingelosssd_p71_normsamp_sp_load_k15.yaml-best-test-weights.pickle',
cur_dir + '/weights/KITTI2015_SPCI/111016_151719-eparams-test.pickle'],

'KITTI2012_SPCI' : ['model_CENTSD_33conv',
'KITTI2012_SPCI' : ['model_CENTSD_33conv_elu',
cur_dir + '/weights/KITTI2012_SPCI/131016_180941PAPERKITTI2012-model_drlim7_33conv_elu_allconv_neg1_8_m100_epoch4000_adadelta_testsamples200k_hingelosssd_p71_normsamp_sp_load.yaml-best-test-weights.pickle',
cur_dir + '/weights/KITTI2012_SPCI/131016_180941-eparams-test.pickle']}

Expand Down Expand Up @@ -98,10 +97,16 @@ def layer_factory(in_layer, layer_type, **kwargs):

return output_layer

def model_CENTSD_33conv(batch_size,FAST_network=False, FAST_imgheight=None, FAST_imgwidth=None):
def model_CENTSD_33conv(patch_size, batch_size, FAST_network=False, FAST_imgheight=None, FAST_imgwidth=None, nonlin_func='leaky'):
""" Describes the main network used in the PatchBatch paper """

nonlin = nonlinearities.LeakyRectify(leaky_param)
if nonlin_func == 'leaky':
nonlin = nonlinearities.LeakyRectify(leaky_param)
elif nonlin_func == 'elu':
nonlin = nonlinearities.elu
else:
print 'Error! Unsupported non-linearity function'
return

if FAST_network:
l_in0 = layers.InputLayer(
Expand Down Expand Up @@ -157,4 +162,12 @@ def model_CENTSD_33conv(batch_size,FAST_network=False, FAST_imgheight=None, FAST

return layer

all_models = {'model_CENTSD_33conv' : model_CENTSD_33conv}

def model_CENTSD_33conv_elu(patch_size, batch_size, FAST_network=False, FAST_imgheight=None, FAST_imgwidth=None):
    """ creates a CENTSD_33conv model with elu nonlinearity

    Thin wrapper around model_CENTSD_33conv that only fixes nonlin_func to 'elu'.
    patch_size, batch_size - forwarded unchanged to model_CENTSD_33conv
    FAST_network, FAST_imgheight, FAST_imgwidth - forwarded unchanged to model_CENTSD_33conv
    returns whatever model_CENTSD_33conv returns """

    # The signature mirrors model_CENTSD_33conv on purpose: models are invoked
    # through the all_models registry as builder(patch_size, batch_size), so the
    # first two positional parameters must be patch_size and batch_size. With the
    # previous (batch_size, FAST_network, ...) signature every argument was
    # shifted by one slot and FAST_imgwidth was never forwarded correctly.
    return model_CENTSD_33conv(patch_size, batch_size, FAST_network, FAST_imgheight, FAST_imgwidth, nonlin_func='elu')


# Registry of network-builder functions keyed by model name. A builder is looked
# up by name and called with (patch_size, batch_size) to construct the network.
all_models = {'model_CENTSD_33conv' : model_CENTSD_33conv,
'model_CENTSD_33conv_elu' : model_CENTSD_33conv_elu}
7 changes: 3 additions & 4 deletions pb_NN.py
Expand Up @@ -83,13 +83,12 @@ def get_descriptors(nn_model, theano_func, patches, batch_size, patch_size):
descs.append(cur_descs.squeeze())

res = numpy.vstack(descs)
desc_size = res.shape[-1]
res = res.reshape(h, w, 1, desc_size)
res = res.reshape(h, w, 1, -1)

return res


def get_net_and_funcs(net_name, batch_size, weights_filename, eparams_filename):
def get_net_and_funcs(net_name, patch_size, batch_size, weights_filename, eparams_filename):
""" Creates a Lasagne network + theano function given a model name
net_name - one of the supported models for KITTI2012, KITTI2015 and MPI-Sintel
batch_size - the batch size
Expand All @@ -101,7 +100,7 @@ def get_net_and_funcs(net_name, batch_size, weights_filename, eparams_filename):


print 'Creating NN', net_name
nn_model = Models.all_models[net_name](batch_size)
nn_model = Models.all_models[net_name](patch_size, batch_size)

print 'Describing network:'
describe_network(nn_model)
Expand Down

0 comments on commit 74b05bf

Please sign in to comment.