In [1]:
import tensorflow as tf
import numpy as np
import glob

## data processing

In [2]:
def dna_encoder(seq, bases='ACTG'):
    # one-hot-encoding for sequence data
    # enumerates base in a sequence
    indices = [
        bases.index(x) if x in bases else -1
        for x in seq
    ]
    # one extra index for unknown
    eye = np.eye(len(bases) + 1)
    return eye[indices].astype(np.float32)

def tf_dna_encoder(seq, bases='ACTG'):
    # wraps `dna_encoder` with a `py_func`
    return tf.py_func(dna_encoder, [seq, bases], [tf.float32])[0]

def dataset_input_fn(filenames,
                     buffer_size=10000,
                     batch_size=32,
                     num_epochs=20,
                     ):
    dataset = tf.contrib.data.TFRecordDataset(filenames)
    
    # Use `tf.parse_single_example()` to extract data from a `tf.Example`
    # protocol buffer, and perform any additional per-record preprocessing.
    def parser(record):
        keys_to_features = {
            "sequence": tf.FixedLenFeature((), tf.string),
            "atacCounts": tf.FixedLenFeature((1000,), tf.int64),
            "Labels": tf.FixedLenFeature((1,), tf.int64),
        }
        parsed = tf.parse_single_example(record, keys_to_features)

        # Perform additional preprocessing on the parsed data.
        seq = tf_dna_encoder(parsed["sequence"])
        seq = tf.reshape(seq, [1000, 5])
        atac = parsed["atacCounts"]
        label = parsed["Labels"]

        # add more here if needed
        return {'seq': seq, 'atac': atac}, label
    
    # Use `Dataset.map()` to build a pair of a feature dictionary and a label
    # tensor for each example.
    dataset = dataset.map(parser)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()
    
    # `features` is a dictionary in which each value is a batch of values for
    # that feature; `labels` is a batch of labels.
    features, labels = iterator.get_next()
    return features, labels

## hyper-parameters

In [43]:
# training hyper-parameters
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 1e-3
# for shuffled batches
BUFFER_SIZE = 10000
# weighting positive examples
POS_WEIGHT = 1

# logging parameters
CONSOLE_LOG_STEPS = 200
SUMMARY_SAVE_SECS = 20
CHECKPOINT_SAVE_SECS = 200
CHECKPOINT_DIR = './deleteme/checkpoint_dir/'
OUTPUT_DIR = './deleteme/output_dir/'

# percent of data to use for training
TRAIN_PROPORTION = 0.9

## training setup

In [37]:
# reset default graph
tf.reset_default_graph()

# Creates a variable to hold the global_step.
global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

# import data as a shuffle-batch iterator
# https://www.tensorflow.org/programmers_guide/datasets
filenames = glob.glob('../../CEBPB-A549-hg38.txt/part-r-*')
num_train_files = int(len(filenames) * TRAIN_PROPORTION)
train_filenames = filenames[:num_train_files]
valid_filenames = filenames[num_train_files:]

# training data-flow tensors
train_features, train_labels = dataset_input_fn(
    filenames=train_filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=NUM_EPOCHS)

# validation data-flow tensors
valid_features, valid_labels = dataset_input_fn(
    filenames=valid_filenames,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    num_epochs=None)

## define model architecture

In [38]:
def multi_layer_perceptron(x, units, activations, name='mlp',
                           reuse=None):
    # `x` is a tensor with shape [batch, features]
    # `units` is a list of ints, corresponding to hidden units
    # `activations` is a list of activation functions
    with tf.variable_scope(name, reuse=reuse):
        assert len(units) == len(activations)
        # flatten the inputs
        x = tf.contrib.layers.flatten(x)
        for u, a in zip(units, activations):
            x = tf.layers.dense(x, units=u, activation=a)
        return x
    
def hard_coded_cnn(x, name='hard_coded_cnn', reuse=None):
    # leaky-relu helper function for this cnn
    def lrelu(x, alpha=0.2):
        return tf.maximum(alpha*x, x)

    with tf.variable_scope(name, reuse=reuse):
        # specify hard-coded parameters here
        
        # number of filters per conv1d layer
        units = [64] * 4
        # width of the 1d kernel
        kernels = [8] * 4
        # stride at each layer
        strides = [1] * 4
        # pooling window at each layer
        pools = [2, 2, 2, 4]
        # dropout keep prob at each layer
        dropouts = [0.9] * 4
        # activation function at each layer
        activations = [lrelu] * 4

        # conv 1-4
        for u, k, s, p, d, a in zip(units,
                                    kernels,
                                    strides,
                                    pools,
                                    dropouts,
                                    activations):
            x = tf.layers.conv1d(x, u, k, s,
                                 padding='same',
                                 activation=a)
            x = tf.layers.max_pooling1d(x, p, p,
                                        padding="same")
            x = tf.layers.dropout(x, d)

        x = tf.contrib.layers.flatten(x)

        # fc 1
        x = tf.layers.dense(x, 925)
        x = tf.contrib.layers.layer_norm(x)
        x = lrelu(x)
        x = tf.layers.dropout(x, 0.9)

        # fc 2
        x = tf.layers.dense(x, 1)
        x = tf.contrib.layers.layer_norm(x)
        x = tf.layers.dropout(x, 0.9)

        # return logits (no activation)
        return x

In [39]:
def preprocess_inputs(features, labels):
    # define symbolic inputs (see previous cell)
    sy_seq_n = tf.cast(features['seq'], tf.float32)
    sy_atac_n = tf.cast(features['atac'], tf.float32)
    sy_label_n = tf.cast(labels, tf.float32)
    
    # concatenate one-hot encoded seq with atac counts
    sy_input_n = tf.concat(
        [sy_seq_n, tf.expand_dims(sy_atac_n, axis=-1)],
        axis=-1)
    return sy_input_n, sy_label_n
    
    
def core_pipeline(features, labels, name='model', reuse=None,
                 is_train=False):
    # wrapper function for scoping
    # variable sharing across train/valid
    
    # preprocess inputs to neural-network
    sy_input_n, sy_label_n = preprocess_inputs(features,labels)
    
    # pass inputs through a hard-coded network
    sy_logit_n = hard_coded_cnn(sy_input_n, name=name, reuse=reuse)
    
#     def lrelu(x, alpha=0.2):
#         return tf.maximum(alpha*x, x)
    
#     sy_logit_n = multi_layer_perceptron(sy_input_n, [200, 200, 1], [lrelu,lrelu,lrelu], name=name,
#                            reuse=reuse)
    
    # pass inputs through multi-layer perceptron
#     sy_logit_n = multi_layer_perceptron(
#         sy_input_n,
#         units=[256, 128, 1],
#         activations=[tf.nn.relu, tf.nn.relu, None],
#         name=name,
#         reuse=reuse)
    
    # pass logits through sigmoid to get predictions
    sy_pred_n = tf.nn.sigmoid(sy_logit_n)
    
    # computing weighted cross-entropy loss
    sy_loss = tf.reduce_mean(
        tf.nn.weighted_cross_entropy_with_logits(
            logits=sy_logit_n,
            targets=sy_label_n,
            pos_weight=POS_WEIGHT))
    
    # computing area under ROC
    sy_auc, auc_op = tf.metrics.auc(
        labels=sy_label_n,
        predictions=sy_pred_n)
    
    # dictionary of results
    results = {}
    results['pred'] = sy_pred_n
    results['loss'] = sy_loss
    results['auc'] = (sy_auc, auc_op)
    
    if is_train:
        # optimizer configuration
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        train_op = optimizer.minimize(sy_loss, global_step=global_step_tensor)
        results['train_op'] = train_op
    return results

## tensorboard logging

In [40]:
# pass through simple model
sy_train_dict = core_pipeline(train_features, train_labels, name='model', is_train=True)
sy_valid_dict = core_pipeline(valid_features, valid_labels, name='model', reuse=True)

# tensorboard summaries
tf.summary.scalar('loss/train', sy_train_dict['loss'])
tf.summary.scalar('loss/valid', sy_valid_dict['loss'])
tf.summary.scalar('auc/train', sy_train_dict['auc'][0])
tf.summary.scalar('auc/valid', sy_valid_dict['auc'][0])

summary_op = tf.summary.merge_all()

## training loop

In [41]:
# maybe delete existing checkpoint file if
# the graph structure/variables have changed
!rm -r ./deleteme/

rm: cannot remove './deleteme/output_dir/.nfs0000000004c14cbb00000006': Device or resource busy
rm: cannot remove './deleteme/output_dir/.nfs0000000004c14cbc00000005': Device or resource busy
rm: cannot remove './deleteme/checkpoint_dir/.nfs0000000004c14cba00000004': Device or resource busy


In [44]:
# saver hook periodically checkpoints model
# session creator can restore from these
saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir=CHECKPOINT_DIR,
                                          save_secs=CHECKPOINT_SAVE_SECS)

# summary hook allows you to use tensorboard
# specify the metrics you want to log above
summary_hook = tf.train.SummarySaverHook(summary_writer=tf.summary.FileWriter(OUTPUT_DIR),
                                         save_secs=SUMMARY_SAVE_SECS, summary_op=summary_op)

# begin the training loop
with tf.train.MonitoredSession(session_creator=tf.train.ChiefSessionCreator(
    checkpoint_dir=CHECKPOINT_DIR), hooks=[summary_hook, saver_hook]) as sess:

    while not sess.should_stop():
        # step forward everything

        _, _, summary = sess.run([sy_train_dict, sy_valid_dict, summary_op])


        global_step = tf.train.global_step(sess, global_step_tensor)
        
        # maybe log to console
        if CONSOLE_LOG_STEPS and global_step % CONSOLE_LOG_STEPS == 0:
            # parse the protocol-buffer string
            summary_proto = tf.summary.Summary()
            summary_proto.ParseFromString(summary)
            print('Iteration %s' % global_step)
            for value in summary_proto.value:
                print(value.tag + ': ' + str(value.simple_value))

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from ./deleteme/checkpoint_dir/model.ckpt-1
INFO:tensorflow:Saving checkpoints for 2 into ./deleteme/checkpoint_dir/model.ckpt.
Iteration 200
loss/train: 0.525188326836
loss/valid: 0.526072561741
auc/train: 0.540401041508
auc/valid: 0.535288274288
Iteration 400
loss/train: 0.575474500656
loss/valid: 0.659981787205
auc/train: 0.555871188641
auc/valid: 0.54803609848
Iteration 600
loss/train: 0.473432719707
loss/valid: 0.588647127151
auc/train: 0.561644136906
auc/valid: 0.562285840511
Iteration 800
loss/train: 0.603337347507
loss/valid: 0.48531422019
auc/train: 0.567743659019
auc/valid: 0.565438568592
Iteration 1000
loss/train: 0.525676369667
loss/valid: 0.564874351025
auc/train: 0.569696724415
auc/valid: 0.567727565765
Iteration 1200
loss/train: 0.572950005531
loss/valid: 0.658159554005
auc/train: 0.573168754578
auc/valid: 0.570622801781
Iteration 1400
loss/train: 0.666670084
loss/valid: 0.462632000446
auc/t

Iteration 13000
loss/train: 0.615000069141
loss/valid: 0.636752724648
auc/train: 0.661698698997
auc/valid: 0.578765273094
Iteration 13200
loss/train: 0.618879079819
loss/valid: 0.604412198067
auc/train: 0.663091540337
auc/valid: 0.578633189201
Iteration 13400
loss/train: 0.47796985507
loss/valid: 0.588779807091
auc/train: 0.664322733879
auc/valid: 0.578734338284
Iteration 13600
loss/train: 0.421558022499
loss/valid: 0.533285021782
auc/train: 0.665739238262
auc/valid: 0.578565120697
Iteration 13800
loss/train: 0.602245330811
loss/valid: 0.655339598656
auc/train: 0.66732531786
auc/valid: 0.578717291355
Iteration 14000
loss/train: 0.500169038773
loss/valid: 0.520699441433
auc/train: 0.668982923031
auc/valid: 0.578652143478
INFO:tensorflow:Saving checkpoints for 14132 into ./deleteme/checkpoint_dir/model.ckpt.
Iteration 14200
loss/train: 0.489594519138
loss/valid: 0.570780873299
auc/train: 0.670753180981
auc/valid: 0.578399121761
Iteration 14400
loss/train: 0.418091833591
loss/valid: 0.827

Iteration 25800
loss/train: 0.337287604809
loss/valid: 0.94754087925
auc/train: 0.76244610548
auc/valid: 0.575443863869
Iteration 26000
loss/train: 0.506546616554
loss/valid: 0.420996069908
auc/train: 0.764119267464
auc/valid: 0.575433492661
Iteration 26200
loss/train: 0.219534680247
loss/valid: 0.83654999733
auc/train: 0.765833199024
auc/valid: 0.575428962708
Iteration 26400
loss/train: 0.407030165195
loss/valid: 0.545385122299
auc/train: 0.767462253571
auc/valid: 0.575310707092
Iteration 26600
loss/train: 0.415561616421
loss/valid: 1.11650061607
auc/train: 0.769137322903
auc/valid: 0.575275480747
Iteration 26800
loss/train: 0.375569939613
loss/valid: 0.682141542435
auc/train: 0.770657598972
auc/valid: 0.575161337852
Iteration 27000
loss/train: 0.313041418791
loss/valid: 0.799977004528
auc/train: 0.772211194038
auc/valid: 0.575079381466
Iteration 27200
loss/train: 0.21164034307
loss/valid: 0.793759405613
auc/train: 0.773635625839
auc/valid: 0.574935257435
Iteration 27400
loss/train: 0

KeyboardInterrupt: 