In [1]:
import glob
import shutil

import numpy as np
import tensorflow as tf

## utility functions

In [2]:
def dna_encoder(seq, bases='ACTG'):
    """One-hot encode a nucleotide sequence.

    Args:
        seq: the sequence as `str` or `bytes` (byte strings when invoked via
            `tf.py_func`). Must be the same string type as `bases`.
        bases: alphabet; the base at position i in `bases` becomes column i.

    Returns:
        float32 array of shape (len(seq), len(bases) + 1). Every character not
        in `bases` (e.g. 'N' or lowercase letters with the default alphabet)
        is encoded in the extra last column.
    """
    # Build a concrete list: on Python 3 `map` returns a lazy iterator, which
    # numpy cannot use as a fancy index.
    indices = [bases.index(base) if base in bases else -1 for base in seq]
    # one extra row for unknown characters; index -1 selects that last row
    eye = np.eye(len(bases) + 1, dtype=np.float32)
    # explicit intp dtype so an empty sequence yields shape (0, len(bases) + 1)
    return eye[np.asarray(indices, dtype=np.intp)]


def np_sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

    Args:
        x: scalar or array-like of numbers.

    Returns:
        numpy scalar or array of the same shape as `x`. Large-magnitude
        negative inputs saturate to 0.0 instead of raising.
    """
    # For very negative x, exp(-x) overflows to inf and 1 / (1 + inf) == 0.0,
    # which is the correct limit; Python-float `**` would raise OverflowError.
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-np.asarray(x)))


def tf_dna_encoder(seq, bases='ACTG'):
    """Graph-mode one-hot encoder that runs `dna_encoder` through `tf.py_func`.

    Args:
        seq: string tensor holding one sequence (a scalar `tf.string` as used
            in `dataset_input_fn`).
        bases: alphabet forwarded to `dna_encoder`.

    Returns:
        float32 tensor produced by `dna_encoder`. Its static shape is unknown,
        so callers need to reshape it.
    """
    # py_func takes a list of output dtypes and returns a matching list of
    # tensors; unpack the single float32 result
    (encoded,) = tf.py_func(dna_encoder, [seq, bases], [tf.float32])
    return encoded


def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     seq_length=1000,
                     ):
    """Build a one-shot (features, labels) input pipeline from TFRecord files.

    Args:
        filenames: TFRecord path(s) passed to `tf.data.TFRecordDataset`.
        buffer_size: number of parsed examples held in the shuffle buffer.
        batch_size: examples per batch. Batching happens before `repeat`, so
            the last batch of each epoch may be smaller.
        num_epochs: passes over the data (`None` repeats indefinitely).
        seq_length: expected length of every record's `sequence` string and
            `atacCounts` vector.

    Returns:
        (features, labels) from `Iterator.get_next()`:
        features = {'seq': float32 (batch, seq_length, 5) one-hot,
                    'atac': int64 (batch, seq_length)},
        labels = int64 (batch, 1).

    Raises:
        ValueError: if `filenames` is an empty list or tuple.
    """
    if isinstance(filenames, (list, tuple)) and not filenames:
        # An empty list builds an empty dataset, so training would silently
        # stop after zero steps.
        raise ValueError('dataset_input_fn: `filenames` is empty')

    # one-hot width produced by `dna_encoder` with its default bases 'ACTG':
    # 4 bases + 1 column for unknown characters
    num_channels = 5

    dataset = tf.data.TFRecordDataset(filenames)

    # Use `tf.parse_single_example()` to extract data from a `tf.Example`
    # protocol buffer, and perform any additional per-record preprocessing.
    def parser(record):
        keys_to_features = {
            "sequence": tf.FixedLenFeature((), tf.string),
            "atacCounts": tf.FixedLenFeature((seq_length,), tf.int64),
            "Labels": tf.FixedLenFeature((1,), tf.int64),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # One-hot encode the sequence. The py_func output has no static shape,
        # so the reshape restores it (and fails at run time if a record's
        # sequence length differs from `seq_length`).
        seq = tf_dna_encoder(parsed["sequence"])
        seq = tf.reshape(seq, [seq_length, num_channels])

        # add more here if needed
        return {'seq': seq, 'atac': parsed["atacCounts"]}, parsed["Labels"]

    # Parse, shuffle individual examples, batch, then repeat. Shuffling before
    # `repeat` keeps epoch boundaries intact.
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features, labels = iterator.get_next()
    return features, labels

## training setup

In [63]:
# training hyper-parameters
BUFFER_SIZE = 10000
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
PRINT_FREQUENCY = 100

# graph-level random seed (shuffle order, weight initialisation)
SEED = 42

# what type of data to use?
# options: 'SEQ', 'ATAC', or 'BOTH'
DATA_SOURCE = 'BOTH'

# number of hidden units per layer
# empty-list implies logistic regression
# example: [256, 128, 64]
PERCEPTRON_UNITS = []

# positively weighted loss
# pos weight is 1 by default
POS_WEIGHT = 1

# TFRecord shards to train on (relative to the notebook's working directory)
DATA_GLOB = '../../deleteme/CEBPB-A549-hg38.txt/part-r-*'

# reset default graph so re-running this cell does not duplicate ops
tf.reset_default_graph()
# the graph-level seed belongs to the default graph, so set it after the reset
tf.set_random_seed(SEED)

# Creates a variable to hold the global_step.
global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

# import data as a shuffle-batch iterator
# https://www.tensorflow.org/programmers_guide/datasets
# glob order is filesystem-dependent; sort it so the record order (and hence
# the seeded shuffle) is reproducible across machines
filenames = sorted(glob.glob(DATA_GLOB))
if not filenames:
    raise IOError('no TFRecord files match DATA_GLOB=%r' % (DATA_GLOB,))
features, labels = dataset_input_fn(filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS)

## model graph (logistic-regression sanity check; an MLP when `PERCEPTRON_UNITS` is non-empty)

In [64]:
# define symbolic inputs (see previous cell)
# With `dataset_input_fn` defaults the inputs are:
#   features['seq']  -> (batch, 1000, 5) one-hot sequence
#   features['atac'] -> (batch, 1000) int64 counts
#   labels           -> (batch, 1) int64
sy_seq_n = tf.cast(features['seq'], tf.float32)
sy_atac_n = tf.cast(features['atac'], tf.float32)
sy_label_n = tf.cast(labels, tf.float32)

# select the input representation
if DATA_SOURCE == 'BOTH':
    # append atac counts as one extra channel next to the one-hot bases
    sy_input_n = tf.concat(
        [sy_seq_n, tf.expand_dims(sy_atac_n, axis=-1)],
        axis=-1)
elif DATA_SOURCE == 'SEQ':
    # only use sequence data
    sy_input_n = sy_seq_n
elif DATA_SOURCE == 'ATAC':
    # only use atac data
    sy_input_n = sy_atac_n
else:
    # a bad configuration value, not an unimplemented feature
    raise ValueError(
        "DATA_SOURCE must be one of 'SEQ', 'ATAC', 'BOTH'; got %r" % (DATA_SOURCE,))
# NOTE(review): atac counts are fed unnormalised; consider e.g. log1p scaling

# flatten input tensor data to (batch, features)
sy_net_n = tf.contrib.layers.flatten(sy_input_n)

# multi-layer perceptron (skipped entirely when PERCEPTRON_UNITS is empty)
for units in PERCEPTRON_UNITS:
    sy_net_n = tf.layers.dense(sy_net_n, units=units)
    sy_net_n = tf.nn.relu(sy_net_n)

# exit neural-network to logits
sy_logit_n = tf.layers.dense(sy_net_n, units=1, activation=None)
sy_prediction_n = tf.nn.sigmoid(sy_logit_n)

# optimizer configuration: mean (optionally positive-weighted) sigmoid
# cross-entropy over the batch
sy_loss = tf.reduce_mean(
    tf.nn.weighted_cross_entropy_with_logits(
        logits=sy_logit_n,
        targets=sy_label_n,
        pos_weight=POS_WEIGHT))

optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
train_op = optimizer.minimize(sy_loss, global_step=global_step_tensor)

# define logging metrics
# `tf.metrics.auc` accumulates over every batch since its local variables were
# initialised, so this is a running AUC on the *training* stream, not a
# held-out estimate. `auc_op` updates the accumulators and returns the new AUC.
sy_auc, auc_op = tf.metrics.auc(
    labels=sy_label_n,
    predictions=sy_prediction_n)

# tensorboard summaries
tf.summary.scalar('loss', sy_loss)
tf.summary.scalar('auc', sy_auc)
summary_op = tf.summary.merge_all()

## training loop

In [68]:
# Delete the previous run's checkpoints and TensorBoard summaries under
# ./deleteme/ so runs (possibly with different graph structure/variables)
# don't mix.
# Portable, pure-Python equivalent of `!rm -rf ./deleteme/`:
# `ignore_errors=True` makes a missing directory a no-op, like `rm -rf`.
shutil.rmtree('./deleteme/', ignore_errors=True)

In [69]:
# session config
saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir='./deleteme/checkpoint_dir',
                                          save_secs=100)
summary_hook = tf.train.SummarySaverHook(output_dir='./deleteme/output_dir',
                                         save_secs=100,
                                         summary_op=summary_op)

# begin the training loop
# MonitoredSession initialises global and local (metric) variables, calls the
# hooks around every `sess.run`, and exits the `with` block cleanly once the
# input iterator is exhausted.
with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(),
                               hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
        # Fetch everything in ONE run. Every MonitoredSession.run also runs the
        # hooks; a summary triggered on a second run would evaluate
        # `summary_op`, pulling (and discarding) a batch that is never trained
        # on. `auc_op` yields the AUC after this batch's update, whereas
        # fetching `sy_auc` alongside it has no ordering guarantee.
        _, loss, auc, global_step = sess.run(
            [train_op, sy_loss, auc_op, global_step_tensor])
        # NOTE: `global_step` is read in the same run as its increment, so it
        # may be the value before or after this step's update.

        # log the single-batch cross-entropy loss and the running training AUC
        if global_step % PRINT_FREQUENCY == 0:
            print('global_step %i: loss=%.4f, auc=%.4f' % (global_step, loss, auc))

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into ./deleteme/checkpoint_dir/model.ckpt.
global_step 100: loss=0.6203, auc=0.5043
global_step 200: loss=0.5454, auc=0.5078
global_step 300: loss=0.6789, auc=0.5106
global_step 400: loss=0.6866, auc=0.5160
global_step 500: loss=0.5871, auc=0.5219


KeyboardInterrupt: 