In [4]:
import numpy as np
import tensorflow.contrib.distributions as ds
from sklearn.utils import shuffle
import tensorflow as tf
import argparse
import ram

In [5]:
sess = tf.InteractiveSession()

In [15]:
# side of full glimpse will actually be glimpse_size * 2 + 1
glimpse_size = 4
num_resolutions = 3
batch_size = 9
length = 100
num_tfs = 2

In [244]:
def make_dna_seq(batch_size, length):
    one_hot_bases = np.eye(4)
    sample_indices = np.random.randint(0, 4, [batch_size, length])
    return one_hot_bases[sample_indices]

def make_atac_seq(batch_size, length):
    return np.random.randint(0, 20, [batch_size, length, 1])
    
def make_chip_seq(batch_size, num_tfs):
    return np.random.randint(0, 2, [batch_size, num_tfs])

def get_glimpses(data, num_resolutions):
    glimpses = []
    for i in range(num_resolutions):
        resolution = 2**i
        glimpse = tf.nn.pool(
            input=data,
            window_shape=[resolution],
            strides=[resolution],
            pooling_type='MAX',
            padding='SAME')
        glimpses.append(glimpse)
    return glimpses

def index_glimpses(dna, location, num_resolutions, glimpses, glimpse_size, length, batch_size):
    to_concatenate = []
    for i in range(num_resolutions):
        glimpse = glimpses[i]
        # add glimpse_size to location_index
        # because glimpses will be padded with glimpse_size values on each side
        location_index = tf.to_int32(location / 2.0**i) + glimpse_size
        start_index = location_index - glimpse_size
        boolean_mask = get_boolean_mask(glimpse_size, start_index, glimpse.shape[1])
        padded_glimpse = get_padded_glimspe(glimpse, glimpse_size)
        if i == 0:
            padded_dna = get_padded_dna(dna, glimpse_size)
            dna_boolean_mask = tf.squeeze(tf.stack([boolean_mask]*4, axis=-1))
            sliced_dna = tf.boolean_mask(padded_dna, dna_boolean_mask)
            print(boolean_mask)
            print(dna_boolean_mask)
            print(sliced_dna)
            sliced_dna = tf.reshape(sliced_dna, [batch_size, glimpse_size * 2, 4])
            to_concatenate.append(tf.to_int64(sliced_dna))
        sliced_glimpse = tf.boolean_mask(padded_glimpse, boolean_mask)
        sliced_glimpse = tf.reshape(sliced_glimpse, [batch_size, glimpse_size * 2, 1])
        to_concatenate.append(sliced_glimpse)
#     print(to_concatenate)
    return tf.concat(to_concatenate, axis=-1)
        
def get_boolean_mask(glimpse_size, start_index, length):
    curr_index = start_index
    padded_size = length + 2 * glimpse_size
    index_mask = tf.one_hot(indices=curr_index, depth=padded_size, axis=1)
    for i in range(glimpse_size * 2 - 1):
        curr_index += 1
        index_mask += tf.one_hot(indices=curr_index, depth=padded_size, axis=1)
    return index_mask > 0
    
def get_padded_glimspe(glimpse, glimpse_size):
    return tf.pad(glimpse, paddings=[[0, 0], [glimpse_size, glimpse_size], [0, 0]], constant_values=-1)

def get_padded_dna(dna, glimpse_size):
    return get_padded_glimspe(dna, glimpse_size)


In [245]:
dna = make_dna_seq(batch_size, length)
atac = make_atac_seq(batch_size, length)
chip = make_chip_seq(batch_size, num_tfs)
input_ = np.concatenate([dna, atac], axis=-1)
input_.shape

(9, 100, 5)

In [246]:
location = np.array([[20], [40], [99]] * 3)

glimpses = get_glimpses(atac, num_resolutions)
# index_glimpses(dna, num_resolutions, glimpses, 50, glimpse_size).eval()
g = index_glimpses(dna, location, num_resolutions, glimpses, glimpse_size, length, batch_size)
g.eval(), g.shape

Tensor("Greater_170:0", shape=(9, 108, 1), dtype=bool)
Tensor("Squeeze_32:0", shape=(9, 108, 4), dtype=bool)
Tensor("boolean_mask_214/Gather:0", shape=(?,), dtype=float64)


(array([[[ 0,  1,  0,  0,  4, 10, 16],
         [ 1,  0,  0,  0, 18, 10, 13],
         [ 0,  0,  1,  0, 12, 18, 10],
         [ 0,  0,  0,  1, 18, 18, 18],
         [ 1,  0,  0,  0, 12, 12, 15],
         [ 1,  0,  0,  0,  1, 15, 19],
         [ 0,  0,  0,  1,  7,  7, 16],
         [ 0,  1,  0,  0, 15, 19, 19]],
 
        [[ 0,  1,  0,  0,  2, 14, 19],
         [ 0,  0,  1,  0, 13,  4, 18],
         [ 1,  0,  0,  0, 15, 13, 14],
         [ 0,  0,  1,  0, 11, 15, 15],
         [ 0,  1,  0,  0,  9, 16, 16],
         [ 0,  1,  0,  0, 16, 16, 16],
         [ 1,  0,  0,  0, 16, 12, 19],
         [ 0,  0,  1,  0,  5, 16, 15]],
 
        [[ 0,  1,  0,  0, 10,  9,  9],
         [ 0,  0,  0,  1, 17, 16, 18],
         [ 0,  1,  0,  0,  2, 13, 18],
         [ 1,  0,  0,  0,  5, 17, 16],
         [ 0,  0,  1,  0, 18, 18, 18],
         [-1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1]],
 
        [[ 0,  0,  1,  0, 13, 15, 19],
         [ 0,  0