In [1]:
import tensorflow as tf
import numpy as np
import glob

## utility functions

In [2]:
def dna_encoder(seq, bases='ACTG'):
    # one-hot-encoding for sequence data
    # enumerates base in a sequence
    indices = map(
        lambda x: bases.index(x) if x in bases else -1,
        seq)
    # one extra index for unknown
    eye = np.eye(len(bases) + 1)
    return eye[indices].astype(np.float32)


def np_sigmoid(x):
    # logistic sigmoid function
    return 1 / (1 + np.e**-x)


def tf_dna_encoder(seq, bases='ACTG'):
    # wraps `dna_encoder` with a `py_func`
    return tf.py_func(dna_encoder, [seq, bases], [tf.float32])[0]


def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     ):
    dataset = tf.data.TFRecordDataset(filenames)
    
    # Use `tf.parse_single_example()` to extract data from a `tf.Example`
    # protocol buffer, and perform any additional per-record preprocessing.
    def parser(record):
        keys_to_features = {
            "sequence": tf.FixedLenFeature((), tf.string),
            "atacCounts": tf.FixedLenFeature((1000,), tf.int64),
            "Labels": tf.FixedLenFeature((1,), tf.int64),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # Perform additional preprocessing on the parsed data.
        seq = tf_dna_encoder(parsed["sequence"])
        seq = tf.reshape(seq, [1000, 5])

        # add more here if needed
        return {'seq': seq, 'atac': parsed["atacCounts"]}, parsed["Labels"]
    
    # Use `Dataset.map()` to build a pair of a feature dictionary and a label
    # tensor for each example.
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    
    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features, labels = iterator.get_next()
    return features, labels

## training setup

In [10]:
# training hyper-parameters
BUFFER_SIZE = 10000
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
LOGGING_FREQUENCY = 100

# reset default graph
tf.reset_default_graph()

# import data as a shuffle-batch iterator
# https://www.tensorflow.org/programmers_guide/datasets
filenames = glob.glob('../../deleteme/CEBPB-A549-hg38.txt/part-r-*')
features, labels = dataset_input_fn(filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS)

## logistic-regression sanity check

In [11]:
# define symbolic inputs (see previous cell)
sy_seq_n = tf.cast(features['seq'], tf.float32)
sy_atac_n = tf.cast(features['atac'], tf.float32)
sy_label_n = tf.cast(labels, tf.float32)

# concatenate one-hot encoded seq with atac counts
sy_input_n = tf.concat(
    [sy_seq_n, tf.expand_dims(sy_atac_n, axis=-1)],
    axis=-1)

# only use sequence data
# sy_input_n = sy_seq_n

# only use atac data
# sy_input_n = sy_atac_n

# flatten input tensor data
sy_input_n = tf.contrib.layers.flatten(sy_input_n)

# entering neural-network
sy_net_n = sy_input_n

# multi-layer perceptron
# for units in [256, 128, 64]:
#     sy_net_n = tf.layers.dense(sy_net_n, units=units)
#     sy_net_n = tf.nn.relu(sy_net_n)

# exit neural-network to logits
sy_logit_n = tf.layers.dense(sy_net_n, units=1, activation=None)
sy_prediction_n = tf.nn.sigmoid(sy_logit_n)

# positively weighted loss
# pos weight is 1 by default
POS_WEIGHT = 1

# optimizer configuration
sy_loss = tf.reduce_mean(
    tf.nn.weighted_cross_entropy_with_logits(
        logits=sy_logit_n,
        targets=sy_label_n,
        pos_weight=POS_WEIGHT))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
train_op = optimizer.minimize(sy_loss)

# define logging metrics
sy_auc, auc_op = tf.metrics.auc(
    labels=sy_label_n,
    predictions=sy_prediction_n)

## training loop

In [None]:
# begin the training loop
with tf.train.MonitoredTrainingSession() as sess:
    iteration = 0
    while not sess.should_stop():
        update_ops = [train_op, auc_op]
        fetches = [sy_loss, sy_auc, update_ops]
        loss, auc, _ = sess.run(fetches)
        
        # log the cross-entropy loss, AUC, precision, recall
        if iteration % LOGGING_FREQUENCY == 0:
            print('iteration %i: loss=%.4f, auc=%.4f' % (iteration, loss, auc))
        iteration += 1

iteration 0: loss=0.6899, auc=0.0000
iteration 100: loss=0.6568, auc=0.5085
iteration 200: loss=0.6501, auc=0.5137
iteration 300: loss=0.5393, auc=0.5125
iteration 400: loss=0.6896, auc=0.5159
iteration 500: loss=0.5945, auc=0.5198
iteration 600: loss=0.7205, auc=0.5193
iteration 700: loss=0.7294, auc=0.5206
iteration 800: loss=0.7368, auc=0.5222
iteration 900: loss=0.4711, auc=0.5226
iteration 1000: loss=0.5740, auc=0.5225
iteration 1100: loss=0.6925, auc=0.5221
iteration 1200: loss=0.6163, auc=0.5228
iteration 1300: loss=0.6791, auc=0.5236
iteration 1400: loss=0.6539, auc=0.5274
iteration 1500: loss=0.6932, auc=0.5271
iteration 1600: loss=0.5450, auc=0.5275
iteration 1700: loss=0.6548, auc=0.5284
iteration 1800: loss=0.4839, auc=0.5280
iteration 1900: loss=0.6187, auc=0.5282
iteration 2000: loss=0.5966, auc=0.5296
iteration 2100: loss=0.5987, auc=0.5305
iteration 2200: loss=0.6647, auc=0.5308
iteration 2300: loss=0.5106, auc=0.5321
iteration 2400: loss=0.7596, auc=0.5328
iteration 25