In [20]:
import tensorflow as tf
import numpy as np
import glob

## utility functions

In [40]:
def dna_encoder(seq, bases='ACTG'):
    """One-hot encode a nucleotide sequence.

    Args:
        seq: the sequence to encode, as ``str`` or ``bytes`` (values routed
            through ``tf.py_func`` may arrive as ``bytes``).
        bases: the alphabet, as ``str`` or ``bytes``. The position of each
            symbol in ``bases`` is its one-hot column.

    Returns:
        A ``float32`` array of shape ``(len(seq), len(bases) + 1)``. The last
        column is set for any symbol that does not appear in ``bases``.
    """
    def _as_text(value):
        # latin-1 maps every byte to exactly one character and never fails,
        # so bytes and str inputs can be compared symbol by symbol
        return value.decode('latin-1') if isinstance(value, bytes) else value

    seq = _as_text(seq)
    bases = _as_text(bases)

    # first occurrence wins, mirroring `bases.index`
    column_of = {}
    for column, base in enumerate(bases):
        column_of.setdefault(base, column)

    # one extra index for unknown symbols (the last column)
    unknown = len(bases)

    # build a concrete integer array: a lazy `map` object is not a valid
    # numpy index on Python 3, and an explicit dtype keeps the empty
    # sequence case well-defined
    indices = np.asarray(
        [column_of.get(symbol, unknown) for symbol in seq],
        dtype=np.intp)

    eye = np.eye(len(bases) + 1)
    return eye[indices].astype(np.float32)

def np_sigmoid(x):
    """Logistic sigmoid, ``1 / (1 + e**-x)``.

    Built from plain arithmetic operators, so it applies element-wise when
    ``x`` is a numpy array and also accepts Python scalars.
    """
    decay = np.e ** (-x)
    denominator = 1 + decay
    return 1 / denominator

def tf_dna_encoder(seq, bases='ACTG'):
    """Graph-mode wrapper around `dna_encoder`.

    Runs `dna_encoder(seq, bases)` through `tf.py_func` and returns the
    single `tf.float32` output tensor.
    """
    # `Tout` is a one-element list, so `py_func` hands back a one-element list
    encoded = tf.py_func(func=dna_encoder, inp=[seq, bases], Tout=[tf.float32])
    return encoded[0]


def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     ):
    """Build a one-shot TFRecord input pipeline.

    Args:
        filenames: TFRecord file(s) handed to `tf.data.TFRecordDataset`.
        buffer_size: size of the shuffle buffer.
        batch_size: number of examples per batch.
        num_epochs: how many times the batched dataset is repeated.

    Returns:
        `(features, labels)` as produced by the iterator's `get_next()`:
        `features` is a dict with keys `'seq'` (one-hot sequence reshaped to
        `[1000, 5]` per example) and `'atac'` (the `"atacCounts"` feature);
        `labels` is the `"Labels"` feature.
    """
    def _parse_record(serialized):
        # schema of one serialized `tf.Example`
        feature_spec = {
            "sequence": tf.FixedLenFeature((), tf.string),
            "atacCounts": tf.FixedLenFeature((1000,), tf.int64),
            "Labels": tf.FixedLenFeature((1,), tf.int64),
        }
        example = tf.parse_single_example(serialized, feature_spec)

        # per-record preprocessing: one-hot encode the raw sequence and pin
        # it to a fixed [1000, 5] shape
        one_hot = tf.reshape(tf_dna_encoder(example["sequence"]), [1000, 5])

        # add more here if needed
        inputs = {'seq': one_hot, 'atac': example["atacCounts"]}
        return inputs, example["Labels"]

    # parse -> shuffle -> batch -> repeat, in exactly this order
    pipeline = (
        tf.data.TFRecordDataset(filenames)
        .map(_parse_record)
        .shuffle(buffer_size)
        .batch(batch_size)
        .repeat(num_epochs)
    )

    # each `get_next()` evaluation yields one batch of features and labels
    features, labels = pipeline.make_one_shot_iterator().get_next()
    return features, labels

## training setup

In [126]:
# training hyper-parameters
BUFFER_SIZE = 10000       # shuffle-buffer size passed to the input pipeline
BATCH_SIZE = 64           # examples per training batch
NUM_EPOCHS = 50           # passes over the data before the iterator ends
LEARNING_RATE = 1e-3      # step size for the optimizer
LOGGING_FREQUENCY = 100   # log every this many iterations

# glob pattern matching the TFRecord shards to train on
DATA_GLOB = '../../deleteme/CEBPB-A549-hg38.txt/part-r-*'

# reset default graph so re-running this cell does not pile up duplicate ops
tf.reset_default_graph()

# import data as a shuffle-batch iterator
# https://www.tensorflow.org/programmers_guide/datasets
# `glob.glob` returns matches in arbitrary order; sort for a reproducible
# shard order across runs and machines
filenames = sorted(glob.glob(DATA_GLOB))
if not filenames:
    # without this check an empty match silently yields a dataset that is
    # exhausted immediately, and training appears to "finish" with no work
    raise IOError('no TFRecord shards matched %r' % DATA_GLOB)

features, labels = dataset_input_fn(filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS)

## MLP sanity check (two ReLU hidden layers, logistic output)

In [127]:
# define symbolic inputs (see previous cell); everything is cast to float32
# so the features and labels can feed the dense layers and the loss
sy_seq_n = tf.cast(features['seq'], tf.float32)
sy_atac_n = tf.cast(features['atac'], tf.float32)
sy_label_n = tf.cast(labels, tf.float32)

# concatenate one-hot encoded seq with atac counts
# sy_input_n = tf.concat([sy_seq_n, tf.expand_dims(sy_atac_n, axis=-1)], axis=-1)

# only use sequence data
# sy_input_n = sy_seq_n

# only use atac data
sy_input_n = sy_atac_n

# flatten input tensor data to (batch, features)
sy_input_n = tf.contrib.layers.flatten(sy_input_n)

# entering neural-network
sy_net_n = sy_input_n

# multi-layer perceptron: 128 -> 64 hidden units with ReLU activations
sy_net_n = tf.layers.dense(sy_net_n, units=128)
sy_net_n = tf.nn.relu(sy_net_n)
# feed the previous layer's output forward (this used to reference an
# undefined name `net`, which raises NameError on a fresh kernel)
sy_net_n = tf.layers.dense(sy_net_n, units=64)
sy_net_n = tf.nn.relu(sy_net_n)

# exit neural-network to logits (one un-activated unit per example)
sy_logit_n = tf.layers.dense(sy_net_n, units=1, activation=None)

# positively weighted loss: positive targets are weighted by POS_WEIGHT
POS_WEIGHT = 2

# optimizer configuration
sy_loss = tf.reduce_mean(
    tf.nn.weighted_cross_entropy_with_logits(
        logits=sy_logit_n,
        targets=sy_label_n,
        pos_weight=POS_WEIGHT))

# sy_loss = tf.losses.sigmoid_cross_entropy(
#     multi_class_labels=sy_label_n,
#     logits=sy_logit_n)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
train_op = optimizer.minimize(sy_loss)

# define logging metrics: fraction of examples whose rounded sigmoid
# probability (i.e. thresholded at 0.5) equals the label
sy_accuracy = tf.reduce_mean(tf.cast(
    tf.equal(x=sy_label_n, y=tf.round(tf.nn.sigmoid(sy_logit_n))),
    dtype=tf.float32))

## training loop

In [128]:
def binary_print(binary_seq):
    # render a 0/1 sequence as a compact string: 'x' for 1, '-' otherwise
    return ''.join('x' if x == 1 else '-' for x in binary_seq)

# tensors evaluated on every step; `train_op` is last so its (unused)
# result lands in the throw-away slot
fetches = [sy_logit_n, sy_label_n, sy_loss, sy_accuracy, train_op]

# begin the training loop; it runs until the session signals a stop
with tf.train.MonitoredTrainingSession() as sess:
    iteration = 0
    while not sess.should_stop():
        logit, label, loss, accuracy, _ = sess.run(fetches)
        # log the cross-entropy loss and accuracy
        is_logging_step = (iteration % LOGGING_FREQUENCY == 0)
        if is_logging_step:
            print('iteration %i: loss=%.4f, accuracy=%.4f' % (iteration, loss, accuracy))
            # side-by-side view of thresholded predictions vs. true labels
            prob = np_sigmoid(logit).flatten()
            predicted_bits = np.round(prob).astype(np.int32)
            label_bits = label.flatten().astype(np.int32)
            print('pred:  %s' % binary_print(predicted_bits))
            print('label: %s' % binary_print(label_bits))
        iteration += 1

iteration 0: loss=0.9268, accuracy=0.6875
pred:  ------x-----x---x-------------------------------------------x---
label: -x----x-----xx-x-----x-----xx-xx----x----x--x----xxx-x---xxxxx--
iteration 100: loss=0.8543, accuracy=0.7656
pred:  -x---------x------------------------------------------x---------
label: ----x-x----xx---xx-x---x--------x----x----x----xx-----x-x----x--
iteration 200: loss=0.9145, accuracy=0.7031
pred:  -------------------------x--------------------------------------
label: -x--------xx--x-------xx-xx----xx-xx-x-----x--x------x---xx--x-x
iteration 300: loss=0.9463, accuracy=0.6562
pred:  --------------------------------------------------------------xx
label: --xx--x--x-x-x----xxx-xx------xx-----xx---x-xx-x-------xx--x--xx
iteration 400: loss=0.9319, accuracy=0.6562
pred:  ----------------------------------------------------------------
label: -x---x--------x-x----x-----x-x-x-x--x--x---xxxx-x-x-x--x----xxx-
iteration 500: loss=0.8741, accuracy=0.7344
pred:  -----------

iteration 4400: loss=0.8233, accuracy=0.7656
pred:  -----x---------------------------------x--------x---------------
label: -----x--xx-----x--------x----xx---xx---x----xx-----x---x-x----x-
iteration 4500: loss=0.9515, accuracy=0.6250
pred:  ------------x-----x------------------x--------------------------
label: -xx-x-x--xxxxx----x-x-x-----xx-xx---xxx-xx-x------xx----x--x---x
iteration 4600: loss=0.8345, accuracy=0.7812
pred:  ----------------------------x------------------------------x----
label: --------------------------x-x-x--x-----xx-xx--xx-x-xx-x----x---x
iteration 4700: loss=0.8497, accuracy=0.7500
pred:  -------------------------------x--x------------x----------------
label: x-------xx-x--xx---x--x-----x-xx-xxx-x----------------------x--x
iteration 4800: loss=0.8070, accuracy=0.7969
pred:  ---------------------------------------------------------------x
label: ---xxx-x--------x-----x-x--x-----x-------x--------x------x--x--x
iteration 4900: loss=0.8680, accuracy=0.7500
pred:  ---

KeyboardInterrupt: 