In [None]:
import tensorflow as tf
import numpy as np
import glob

## data processing

In [None]:
def dna_encoder(seq, bases='ACTG'):
    """One-hot encode a sequence of symbols.

    Args:
        seq: sequence to encode, as `str` or `bytes` (`tf.py_func` hands
            string inputs to the Python function as `bytes`).
        bases: alphabet of known symbols, as `str` or `bytes`. The column
            used for a symbol is its (first) position in `bases`.

    Returns:
        `np.float32` array of shape `(len(seq), len(bases) + 1)`. Row `i`
        is the one-hot vector of `seq[i]`; every symbol that does not occur
        in `bases` is assigned to the extra, last column.
    """
    # Normalise both arguments to text so symbols are compared like with
    # like, whether the caller passed `str` or `bytes`.
    if isinstance(seq, bytes):
        seq = seq.decode('ascii', 'replace')
    if isinstance(bases, bytes):
        bases = bases.decode('ascii', 'replace')

    # one extra index for unknown
    unknown_column = len(bases)
    column_of = {}
    for column, base in enumerate(bases):
        # `setdefault` keeps the first position, matching `bases.index`.
        column_of.setdefault(base, column)

    # Build a concrete integer array: a lazy `map` object (Python 3) is not
    # a valid numpy index, and an explicit dtype keeps empty input valid.
    indices = np.array(
        [column_of.get(symbol, unknown_column) for symbol in seq],
        dtype=np.int64)

    eye = np.eye(len(bases) + 1, dtype=np.float32)
    return eye[indices]

def tf_dna_encoder(seq, bases='ACTG'):
    """Run `dna_encoder` inside the graph via `tf.py_func`.

    `seq` and `bases` are forwarded as the inputs of the op. The op is
    declared with a single `tf.float32` output, and that output tensor is
    what this function returns.
    """
    outputs = tf.py_func(
        func=dna_encoder,
        inp=[seq, bases],
        Tout=[tf.float32])
    encoded = outputs[0]
    return encoded

def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     ):
    """Build the input pipeline and return its `get_next()` tensors.

    Args:
        filenames: TFRecord files handed to `tf.data.TFRecordDataset`.
        buffer_size: size of the shuffle buffer.
        batch_size: number of examples per batch.
        num_epochs: repeat count passed to `Dataset.repeat`.

    Returns:
        `(features, labels)`, where `features` is a dict holding the batched
        'seq' and 'atac' tensors and `labels` is the batched "Labels"
        feature.
    """
    # Layout of one serialized `tf.Example`.
    feature_spec = {
        "sequence": tf.FixedLenFeature((), tf.string),
        "atacCounts": tf.FixedLenFeature((1000,), tf.int64),
        "Labels": tf.FixedLenFeature((1,), tf.int64),
    }

    def parser(record):
        # Decode one record and apply the per-example preprocessing.
        example = tf.parse_single_example(record, feature_spec)

        # One-hot encode the raw sequence, then pin the shape explicitly.
        one_hot = tf.reshape(tf_dna_encoder(example["sequence"]), [1000, 5])

        # add more here if needed
        inputs = {'seq': one_hot, 'atac': example["atacCounts"]}
        return inputs, example["Labels"]

    # parse -> shuffle -> batch -> repeat, in that order
    pipeline = (
        tf.data.TFRecordDataset(filenames)
        .map(parser)
        .shuffle(buffer_size)
        .batch(batch_size)
        .repeat(num_epochs)
    )

    # Each value in `features` is a batch of that feature; `labels` is the
    # matching batch of labels.
    features, labels = pipeline.make_one_shot_iterator().get_next()
    return features, labels

## hyper-parameters

In [None]:
# --- optimisation -------------------------------------------------------
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
BATCH_SIZE = 32
# size of the shuffle buffer used by the input pipeline
BUFFER_SIZE = 10000
# `pos_weight` of the weighted cross-entropy loss (weight on positives)
POS_WEIGHT = 1

# --- logging / checkpointing --------------------------------------------
# print the summaries to the console every this many global steps
CONSOLE_LOG_STEPS = 200
# seconds between summary saves
SUMMARY_SAVE_SECS = 20
# seconds between checkpoint saves
CHECKPOINT_SAVE_SECS = 200
OUTPUT_DIR = './deleteme/output_dir/'
CHECKPOINT_DIR = './deleteme/checkpoint_dir/'

# --- data split ---------------------------------------------------------
# fraction of the input files used for training; the rest is validation
TRAIN_PROPORTION = 0.9

## training setup

In [None]:
# reset default graph, so re-running this cell starts from an empty graph
tf.reset_default_graph()

# Creates a variable to hold the global_step.
global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

# import data as a shuffle-batch iterator
# https://www.tensorflow.org/programmers_guide/datasets
# `glob.glob` returns matches in arbitrary order; sort them so the
# train/validation split below is the same on every run and machine.
filenames = sorted(glob.glob('../../deleteme/CEBPB-A549-hg38.txt/part-r-*'))

# the first TRAIN_PROPORTION of the files train, the remainder validates
num_train_files = int(len(filenames) * TRAIN_PROPORTION)
train_filenames = filenames[:num_train_files]
valid_filenames = filenames[num_train_files:]

# Fail early with a clear message instead of silently building an empty
# dataset (wrong path, or too few files for the requested proportion).
if not train_filenames or not valid_filenames:
    raise ValueError(
        'Need at least one training and one validation file, got '
        '%d training and %d validation file(s) out of %d matched.'
        % (len(train_filenames), len(valid_filenames), len(filenames)))

# training data-flow tensors
train_features, train_labels = dataset_input_fn(
    filenames=train_filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS)

# validation data-flow tensors
# (`num_epochs=None` is forwarded to `Dataset.repeat`)
valid_features, valid_labels = dataset_input_fn(
    filenames=valid_filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=None)

## define model architecture

In [None]:
def multi_layer_perceptron(x, units, activations, name='mlp',
                           reuse=None):
    """Chain `tf.layers.dense` layers inside one variable scope.

    Args:
        x: input tensor with shape [batch, features].
        units: list of ints, the number of units of each layer.
        activations: list of activation functions, one per layer.
        name: name of the variable scope wrapping all layers.
        reuse: forwarded to `tf.variable_scope`.

    Returns:
        The output of the last dense layer (or `x` if `units` is empty).
    """
    with tf.variable_scope(name, reuse=reuse):
        assert len(units) == len(activations)
        layer_specs = list(zip(units, activations))
        hidden = x
        for width, activation in layer_specs:
            hidden = tf.layers.dense(
                hidden, units=width, activation=activation)
        return hidden
    
    
def preprocess_inputs(features, labels):
    """Cast the raw batch to float32 and merge it into one input tensor.

    Args:
        features: dict with the 'seq' and 'atac' tensors.
        labels: label tensor.

    Returns:
        `(inputs, targets)`: `inputs` is 'seq' with 'atac' appended along
        the last axis (after 'atac' gained a trailing axis), `targets` is
        `labels` cast to float32.
    """
    seq = tf.cast(features['seq'], tf.float32)
    atac = tf.cast(features['atac'], tf.float32)
    targets = tf.cast(labels, tf.float32)

    # give the counts a trailing axis so they line up with the one-hot
    # encoded sequence, then append them as one more channel
    atac_channel = tf.expand_dims(atac, axis=-1)
    inputs = tf.concat([seq, atac_channel], axis=-1)
    return inputs, targets
    
    
def simple_model(features, labels, name='model', reuse=None,
                 is_train=False):
    """Build the model graph for one (features, labels) input pair.

    Calling this once for training and again with `reuse=True` under the
    same `name` shares the variables across train/valid.

    Args:
        features: dict with the 'seq' and 'atac' tensors.
        labels: label tensor.
        name: variable scope of the network.
        reuse: forwarded to the variable scope of the network.
        is_train: when true, also create the optimizer and a train op.

    Returns:
        dict with 'pred' (sigmoid of the logits), 'loss' (mean weighted
        cross-entropy), 'auc' (the `(value, update_op)` pair returned by
        `tf.metrics.auc`) and, only if `is_train`, 'train_op'.
    """
    # preprocess inputs to neural-network
    inputs, targets = preprocess_inputs(features, labels)

    # flatten, then dense layers of 256 and 128 ReLU units and a single
    # linear output unit producing the logit
    flat_inputs = tf.contrib.layers.flatten(inputs)
    logits = multi_layer_perceptron(
        flat_inputs,
        units=[256, 128, 1],
        activations=[tf.nn.relu, tf.nn.relu, None],
        name=name,
        reuse=reuse)

    # pass logits through sigmoid to get predictions
    predictions = tf.nn.sigmoid(logits)

    # weighted cross-entropy, averaged into a scalar loss
    per_example_loss = tf.nn.weighted_cross_entropy_with_logits(
        logits=logits,
        targets=targets,
        pos_weight=POS_WEIGHT)
    loss = tf.reduce_mean(per_example_loss)

    # area under ROC
    auc_value, auc_update = tf.metrics.auc(
        labels=targets,
        predictions=predictions)

    results = {
        'pred': predictions,
        'loss': loss,
        'auc': (auc_value, auc_update),
    }

    if not is_train:
        return results

    # optimizer configuration; each train step also advances the
    # module-level `global_step_tensor`
    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    results['train_op'] = optimizer.minimize(
        loss, global_step=global_step_tensor)
    return results

## logistic-regression sanity check

In [None]:
# build the training graph first, then the validation graph under the
# same scope name with `reuse=True`
sy_train_dict = simple_model(
    features=train_features,
    labels=train_labels,
    name='model',
    is_train=True)
sy_valid_dict = simple_model(
    features=valid_features,
    labels=valid_labels,
    name='model',
    reuse=True)

# tensorboard summaries: loss and AUC value for both splits
tf.summary.scalar(name='loss/train', tensor=sy_train_dict['loss'])
tf.summary.scalar(name='loss/valid', tensor=sy_valid_dict['loss'])
tf.summary.scalar(name='auc/train', tensor=sy_train_dict['auc'][0])
tf.summary.scalar(name='auc/valid', tensor=sy_valid_dict['auc'][0])

# single op that evaluates every summary registered so far
summary_op = tf.summary.merge_all()

## training loop

In [None]:
# maybe delete existing checkpoint file if
# the graph structure/variables have changed
# NOTE: CHECKPOINT_DIR and OUTPUT_DIR both live under ./deleteme/, so this
# removes the saved checkpoints *and* the summary logs. Skip this cell to
# keep them (e.g. to restore from an existing checkpoint).
!rm -r ./deleteme/

In [None]:
# saver hook: writes a checkpoint every CHECKPOINT_SAVE_SECS seconds;
# the session creator below can restore from these
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=CHECKPOINT_DIR,
    save_secs=CHECKPOINT_SAVE_SECS)

# summary hook: writes `summary_op` for tensorboard every
# SUMMARY_SAVE_SECS seconds (the logged metrics are specified above)
summary_writer = tf.summary.FileWriter(OUTPUT_DIR)
summary_hook = tf.train.SummarySaverHook(
    summary_writer=summary_writer,
    save_secs=SUMMARY_SAVE_SECS,
    summary_op=summary_op)

session_creator = tf.train.ChiefSessionCreator(checkpoint_dir=CHECKPOINT_DIR)

# begin the training loop
with tf.train.MonitoredSession(session_creator=session_creator,
                               hooks=[summary_hook, saver_hook]) as sess:

    while not sess.should_stop():
        # step forward everything: train ops, validation ops and summaries
        _, _, summary = sess.run([sy_train_dict, sy_valid_dict, summary_op])
        # NOTE(review): this issues a second `sess.run` through the monitored
        # session; presumably the hooks fire for it as well -- confirm.
        global_step = tf.train.global_step(sess, global_step_tensor)

        # console logging is off when CONSOLE_LOG_STEPS is falsy, and
        # otherwise happens only on multiples of CONSOLE_LOG_STEPS
        if not CONSOLE_LOG_STEPS or global_step % CONSOLE_LOG_STEPS != 0:
            continue

        # parse the protocol-buffer string
        summary_proto = tf.summary.Summary()
        summary_proto.ParseFromString(summary)
        print('Iteration %s' % global_step)
        for value in summary_proto.value:
            print(value.tag + ': ' + str(value.simple_value))