In [None]:
import tensorflow as tf
import numpy as np
import glob

In [None]:
def dna_encoder(seq, bases='ACTG'):
    """One-hot encode a nucleotide sequence.

    Each character of ``seq`` is mapped to the column given by its first
    position in ``bases``. Any character not found in ``bases`` maps to one
    extra trailing "unknown" column. Matching is case-sensitive.

    Args:
        seq: the sequence, as ``str`` or ``bytes``. ``tf.py_func`` hands
            string tensors to Python as ``bytes``, so both must work.
        bases: alphabet defining the column order, as ``str`` or ``bytes``.

    Returns:
        ``np.float32`` array of shape ``(len(seq), len(bases) + 1)`` with
        exactly one 1.0 per row.
    """
    # Normalise both arguments to str so str/bytes can be mixed freely.
    # Iterating bytes yields ints, and `int in str` raises TypeError.
    # latin-1 maps every byte to exactly one character, so the output keeps
    # one row per input byte.
    if isinstance(seq, bytes):
        seq = seq.decode('latin-1')
    if isinstance(bases, bytes):
        bases = bases.decode('latin-1')

    unknown = len(bases)  # index of the extra "unknown" column
    # setdefault keeps the first occurrence, matching str.index semantics
    # if `bases` contains duplicates.
    lookup = {}
    for column, base in enumerate(bases):
        lookup.setdefault(base, column)

    # Explicit integer dtype so an empty sequence still indexes correctly.
    indices = np.array([lookup.get(ch, unknown) for ch in seq], dtype=np.intp)
    return np.eye(unknown + 1, dtype=np.float32)[indices]

def tf_dna_encoder(seq, bases='ACTG'):
    """Apply `dna_encoder` to a string tensor inside the TF graph.

    Args:
        seq: string tensor holding the sequence (a scalar ``tf.string`` in
            this notebook's parser).
        bases: alphabet passed through to `dna_encoder`.

    Returns:
        ``tf.float32`` tensor of one-hot rows. When ``bases`` is a Python
        string, its static shape is ``[None, len(bases) + 1]``.
    """
    # dna_encoder is a pure function, so the op does not need to be marked
    # stateful; that lets TF treat it like any deterministic op.
    encoded = tf.py_func(dna_encoder, [seq, bases], [tf.float32],
                         stateful=False)[0]
    # py_func discards static shape information. Restore the column count,
    # which is known when `bases` is a Python value; the row count (sequence
    # length) is only known at run time.
    if isinstance(bases, (str, bytes)):
        encoded.set_shape([None, len(bases) + 1])
    return encoded

def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     seq_length=1000,
                     ):
    """Build a shuffled, batched input pipeline over TFRecord shards.

    Each record is a ``tf.Example`` with a ``sequence`` string, an
    ``atacCounts`` int64 vector of length ``seq_length``, and a single
    int64 ``Labels`` value.

    Args:
        filenames: list of TFRecord file paths.
        buffer_size: number of records held in the shuffle buffer.
        batch_size: number of examples per batch. The last batch of each
            epoch may be smaller, because batching happens before repeat.
        num_epochs: number of passes over the data.
        seq_length: length of the ``sequence`` string and of the
            ``atacCounts`` vector (both 1000 by default, as before).

    Returns:
        ``(features, labels)`` from a one-shot iterator. ``features`` is a
        dict with ``'seq'`` (float32, ``[batch, seq_length, 5]``) and
        ``'atac'`` (int64, ``[batch, seq_length]``). ``labels`` is int64,
        ``[batch, 1]``.
    """
    dataset = tf.contrib.data.TFRecordDataset(filenames)

    # Use `tf.parse_single_example()` to extract data from a `tf.Example`
    # protocol buffer, and perform any additional per-record preprocessing.
    def parser(record):
        keys_to_features = {
            "sequence": tf.FixedLenFeature((), tf.string),
            "atacCounts": tf.FixedLenFeature((seq_length,), tf.int64),
            "Labels": tf.FixedLenFeature((1,), tf.int64),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # One-hot encode the sequence. 5 columns = the 4 default bases
        # ('ACTG') plus one "unknown" column. The explicit reshape also
        # fails loudly if the sequence length differs from seq_length.
        seq = tf_dna_encoder(parsed["sequence"])
        seq = tf.reshape(seq, [seq_length, 5])
        atac = parsed["atacCounts"]
        label = parsed["Labels"]

        # add more here if needed
        return {'seq': seq, 'atac': atac}, label

    # Use `Dataset.map()` to build a pair of a feature dictionary and a label
    # tensor for each example.
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features, labels = iterator.get_next()
    return features, labels

In [None]:
# data pipeline parameters
buffer_size = 10000
batch_size = 64
num_epochs = 1
data_glob = '../../deleteme/CEBPB-A549-hg38.txt/part-r-*'

# reset the graph because it might be finalized
# (tf.train.MonitoredSession finalizes the default graph when it is created)
tf.reset_default_graph()

# sharded tfrecord filenames; sorted because glob order is arbitrary
filenames = sorted(glob.glob(data_glob))
# an empty file list would silently yield an empty dataset
assert filenames, 'no TFRecord shards match {!r}'.format(data_glob)

# pass batch_size explicitly; otherwise the function default (32) is used
features, labels = dataset_input_fn(filenames=filenames,
                                    buffer_size=buffer_size,
                                    batch_size=batch_size,
                                    num_epochs=num_epochs)

sy_seq_n = tf.cast(features['seq'], tf.float32)
sy_atac_n = tf.cast(features['atac'], tf.float32)
sy_label_n = tf.cast(labels, tf.float32)

In [None]:
# Drain the pipeline, printing the mean ATAC count of every batch.
# Once the dataset is exhausted (after `num_epochs`), the session reports
# that it should stop and the loop ends.
with tf.train.MonitoredSession() as mon_sess:
    while not mon_sess.should_stop():
        atac_batch = mon_sess.run(sy_atac_n)
        print(np.mean(atac_batch))