In [139]:
import numpy as np
import tensorflow.contrib.distributions as ds
from sklearn.utils import shuffle
import tensorflow as tf
import argparse
import ram

In [140]:
sess = tf.InteractiveSession()

In [141]:
glimpse_size = 4
num_resolutions = 3
batch_size = 9
length = 100
num_tfs = 2

In [142]:
def make_dna_seq(batch_size, length):
    one_hot_bases = np.eye(4)
    sample_indices = np.random.randint(0, 4, [batch_size, length])
    return one_hot_bases[sample_indices]


In [143]:
def make_atac_seq(batch_size, length):
    return np.random.randint(0, 20, [batch_size, length, 1])
    

In [144]:
def make_chip_seq(batch_size, num_tfs):
    return np.random.randint(0, 2, [batch_size, num_tfs])


In [145]:
dna = make_dna_seq(batch_size, length)
atac = make_atac_seq(batch_size, length)
chip = make_chip_seq(batch_size, num_tfs)

In [146]:
dna.shape, atac.shape, chip.shape

((9, 100, 4), (9, 100, 1), (9, 2))

In [147]:
input_ = np.concatenate([dna, atac], axis=-1)
input_.shape

(9, 100, 5)

In [148]:
def get_glimpses(data, num_resolutions):
    glimpses = []
    for i in range(num_resolutions):
        resolution = 2**i
        glimpse = tf.nn.pool(
            input=data,
            window_shape=[resolution],
            strides=[resolution],
            pooling_type='MAX',
            padding='SAME')
        glimpses.append(glimpse)
    return glimpses


In [149]:
def index_glimpses(dna, num_resolutions, glimpses, location, glimpse_size):
    to_concatenate = []
    for i in range(num_resolutions):
        glimpse = glimpses[i]
        location_index = int(location / 2**i)
            
        left_pad = max(glimpse_size - location_index, 0)
        right_pad = max(glimpse_size + location_index - glimpse.get_shape().as_list()[1], 0)
        
        left_index = max(location_index - glimpse_size, 0)
        right_index = min(location_index + glimpse_size, glimpse.get_shape().as_list()[1])
        
        if i == 0:
            to_concatenate.append(dna[:, left_index: right_index, :])
        
        resolution = glimpse[:, left_index: right_index, :]
        resolution = tf.pad(resolution, [[0, 0], [left_pad, right_pad], [0, 0]])
        
        to_concatenate.append(resolution)
    return tf.concat(to_concatenate, axis=-1)


In [150]:
glimpses = get_glimpses(atac, num_resolutions)
index_glimpses(dna, num_resolutions, glimpses, 50, glimpse_size).eval()

array([[[ 0,  0,  0,  1,  2, 11, 19],
        [ 0,  0,  0,  1, 15, 18, 14],
        [ 0,  1,  0,  0, 13, 15, 16],
        [ 0,  0,  0,  1, 12, 13, 18],
        [ 0,  0,  1,  0,  1, 16, 16],
        [ 0,  0,  1,  0, 16, 16, 19],
        [ 0,  0,  1,  0,  5, 19, 19],
        [ 0,  0,  0,  1, 16, 19, 19]],

       [[ 0,  1,  0,  0, 16, 16,  9],
        [ 1,  0,  0,  0,  5,  9, 16],
        [ 1,  0,  0,  0,  5, 16, 16],
        [ 0,  0,  1,  0,  8,  8, 16],
        [ 1,  0,  0,  0, 11, 17, 17],
        [ 0,  0,  0,  1, 17, 17, 18],
        [ 0,  0,  0,  1,  8, 18, 14],
        [ 1,  0,  0,  0, 17, 14, 14]],

       [[ 0,  0,  1,  0, 17,  3, 12],
        [ 1,  0,  0,  0,  9, 11, 12],
        [ 0,  0,  0,  1, 14, 17, 19],
        [ 0,  0,  1,  0, 16, 16, 17],
        [ 0,  0,  0,  1, 12, 14, 16],
        [ 0,  0,  1,  0, 14, 18, 18],
        [ 0,  0,  0,  1, 13, 13, 17],
        [ 0,  1,  0,  0, 18, 17, 18]],

       [[ 0,  0,  0,  1, 13,  6, 16],
        [ 0,  0,  1,  0,  6, 13, 16],
      

In [151]:
[g.eval()[0,:,0] for g in glimpses]

[array([ 4,  4,  9, 12,  8,  1,  4,  1, 12,  4, 10,  4, 16, 19,  8,  0,  6,
         7, 15,  3, 18, 19, 14, 13,  9,  7, 19, 12, 15,  9,  6, 16, 10,  5,
         6, 19,  6, 14, 12,  2, 16, 16, 11, 11, 13, 18,  2, 15, 13, 12,  1,
        16,  5, 16, 19, 18, 18, 19,  9,  5, 19, 19, 11,  5,  6,  3,  3,  2,
        10,  4, 15, 11, 15, 10,  4,  7, 10, 14,  1, 17,  9,  0, 10, 13,  4,
         3, 17,  4,  9,  1,  4, 13,  2, 14, 10,  6, 18,  1,  1, 11]),
 array([ 4, 12,  8,  4, 12, 10, 19,  8,  7, 15, 19, 14,  9, 19, 15, 16, 10,
        19, 14, 12, 16, 11, 18, 15, 13, 16, 16, 19, 19,  9, 19, 11,  6,  3,
        10, 15, 15,  7, 14, 17,  9, 13,  4, 17,  9, 13, 14, 10, 18, 11]),
 array([12,  8, 12, 19, 15, 19, 19, 16, 19, 14, 16, 18, 16, 19, 19, 19,  6,
        15, 15, 17, 13, 17, 13, 14, 18])]

In [185]:
def index_glimpses(dna, num_resolutions, glimpses, location, glimpse_size):
    to_concatenate = []
    for i in range(num_resolutions):
        glimpse = glimpses[i]
        location_index = tf.to_int32(location / 2**i)
            
        left_pad = tf.maximum(glimpse_size - location_index, 0)
        right_pad = tf.maximum(glimpse_size + location_index - glimpse.get_shape().as_list()[1], 0)

        left_index = tf.maximum(location_index - glimpse_size, 0)
        right_index = tf.minimum(location_index + glimpse_size, glimpse.get_shape().as_list()[1])
        print(left_index.eval(), right_index.eval(), dna.shape)
        
        if i == 0:
            to_concatenate.append(tf.slice(dna, begin=tf.squeeze(left_index),
                                           size=tf.squeeze(left_index+right_index)))
        resolution = tf.slice(glimpse, begin=tf.squeeze(left_index),
                                           size=tf.squeeze(left_index+right_index))
        print(resolution.eval())
#         resolution = tf.pad(resolution, [[0, 0], [tf.squeeze(left_pad), tf.squeeze(right_pad)], [0, 0]])
#         print(resolution)
        
#         to_concatenate.append(resolution)
#     return tf.concat(to_concatenate, axis=-1)


In [186]:
glimpses = get_glimpses(atac, num_resolutions)
index_glimpses(dna, num_resolutions, glimpses, np.array([[20], [40], [60]]), glimpse_size)

[[16]
 [36]
 [56]] [[24]
 [44]
 [64]] (9, 100, 4)


InvalidArgumentError: Expected begin[0] in [0, 9], but got 16
	 [[Node: Slice_46 = Slice[Index=DT_INT32, T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](max_pool_172/Squeeze, Squeeze_94, Squeeze_95)]]

Caused by op 'Slice_46', defined at:
  File "/Library/Frameworks/Python.framework/Versions/3.4/lib/python3.4/runpy.py", line 170, in _run_module_as_main
    "__main__", mod_spec)
  File "/Library/Frameworks/Python.framework/Versions/3.4/lib/python3.4/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/ipykernel/zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2698, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2808, in run_ast_nodes
    if self.run_code(code, result):
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/IPython/core/interactiveshell.py", line 2862, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-186-e84f4074df01>", line 2, in <module>
    index_glimpses(dna, num_resolutions, glimpses, np.array([[20], [40], [60]]), glimpse_size)
  File "<ipython-input-185-bcb83ba0eea2>", line 18, in index_glimpses
    size=tf.squeeze(left_index+right_index))
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tensorflow/python/ops/array_ops.py", line 561, in slice
    return gen_array_ops._slice(input_, begin, size, name=name)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3125, in _slice
    name=name)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/Gunjan/.virtualenvs/deeprl/lib/python3.4/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Expected begin[0] in [0, 9], but got 16
	 [[Node: Slice_46 = Slice[Index=DT_INT32, T=DT_INT64, _device="/job:localhost/replica:0/task:0/cpu:0"](max_pool_172/Squeeze, Squeeze_94, Squeeze_95)]]
