In [None]:
import datetime
import logging
import random
import time
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_text as text

import data_io
from data_io import _parse_function

from utils import *
from model import *


In [None]:
# GPU setup: turn on memory growth for every physical GPU that is visible.
#
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # The memory-growth flag has to be set identically on all GPUs.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Memory growth can only be changed before the GPUs are initialized;
        # report the problem and carry on with the current settings.
        print(err)

In [None]:
#
def prepare_batch(example, label):
    """Turn one batched (example, label) pair into model inputs.

    The first element of ``example`` is passed through unchanged; the second
    is tokenized with the notebook-level ``tokenizer`` and densified.

    Args:
        example: indexable pair; ``example[0]`` is the DNA input and
            ``example[1]`` is the input handed to ``tokenizer.tokenize``
            (presumably a batch of strings -- TODO confirm against
            ``_parse_function``).
        label: target values, returned untouched.

    Returns:
        ``((dna, epi), label)`` where ``epi`` is a dense token-id tensor.
    """
    dna = example[0]
    tokens = tokenizer.tokenize(example[1])
    # Merge the two innermost dimensions of the tokenizer output into one,
    # then pad the ragged result into a dense tensor.
    epi = tokens.merge_dims(-2, -1).to_tensor()
    return (dna, epi), label

def make_batches(ds, batch_size=32, buffer_size=20000, shuffle=False):
    """Build the batched input pipeline consumed by ``fit``.

    Args:
        ds: dataset of individual (example, label) elements.
        batch_size: number of elements per batch.
        buffer_size: shuffle buffer size; only used when ``shuffle`` is true.
        shuffle: when true, shuffle individual elements before batching.

    Returns:
        The dataset after (optional) shuffle -> batch -> ``prepare_batch``
        -> prefetch.
    """
    source = ds.shuffle(buffer_size) if shuffle else ds
    batched = source.batch(batch_size)
    # The second positional argument of ``map`` is ``num_parallel_calls``.
    prepared = batched.map(prepare_batch, tf.data.AUTOTUNE)
    return prepared.prefetch(buffer_size=tf.data.AUTOTUNE)


In [None]:
#Matthews Correlation Coefficient (MCC) metric
#
from tensorflow.keras.metrics import Metric
from tensorflow.keras import backend as K

class MatthewsCorrelationCoefficient(Metric):
    """Streaming Matthews Correlation Coefficient for binary classification.

    Confusion-matrix counts (TP, TN, FP, FN) are accumulated over batches and
    combined in ``result`` as::

        MCC = (TP*TN - FP*FN) / sqrt((TP+FP) * (TP+FN) * (TN+FP) * (TN+FN))

    Args:
        name: metric name reported by Keras.
        threshold: predictions strictly greater than this value count as
            positive. Use ``0.`` when the model outputs logits and ``0.5``
            when it outputs probabilities.
        **kwargs: forwarded to ``Metric.__init__``.
    """

    def __init__(self, name='matthews_correlation_coefficient', threshold=0.5, **kwargs):
        super().__init__(name=name, **kwargs)
        # Keep the configured threshold so update_state can apply it; it was
        # previously dropped, which made the constructor argument a no-op.
        self.threshold = threshold
        self.true_positives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='true_positives'
        )
        self.true_negatives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='true_negatives'
        )
        self.false_positives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='false_positives'
        )
        self.false_negatives = self.add_variable(
            shape=(),
            initializer='zeros',
            name='false_negatives'
        )

    def update_state(self, y_true, y_pred, sample_weight=None, threshold=None):
        """Add one batch to the running confusion-matrix counts.

        Args:
            y_true: ground-truth labels; any non-zero value is a positive.
            y_pred: model outputs, compared against the threshold.
            sample_weight: accepted for Keras API compatibility but ignored;
                every element counts once.
            threshold: optional per-call override; defaults to the threshold
                given to the constructor.
        """
        if threshold is None:
            threshold = self.threshold

        # Flatten both tensors so that e.g. (batch,) labels and (batch, 1)
        # predictions are compared element-wise instead of being broadcast
        # against each other into a (batch, batch) grid.
        y_true = tf.cast(tf.reshape(y_true, [-1]), 'bool')
        y_pred = tf.reshape(y_pred, [-1]) > threshold

        not_true = tf.logical_not(y_true)
        not_pred = tf.logical_not(y_pred)

        tp = tf.reduce_sum(tf.cast(tf.logical_and(y_true, y_pred), self.dtype))
        tn = tf.reduce_sum(tf.cast(tf.logical_and(not_true, not_pred), self.dtype))
        fp = tf.reduce_sum(tf.cast(tf.logical_and(not_true, y_pred), self.dtype))
        fn = tf.reduce_sum(tf.cast(tf.logical_and(y_true, not_pred), self.dtype))

        self.true_positives.assign_add(tp)
        self.true_negatives.assign_add(tn)
        self.false_positives.assign_add(fp)
        self.false_negatives.assign_add(fn)

    def result(self):
        """Return the MCC computed from the accumulated counts."""
        numerator = self.true_positives * self.true_negatives - self.false_positives * self.false_negatives
        denominator = tf.sqrt(
            (self.true_positives + self.false_positives) *
            (self.true_positives + self.false_negatives) *
            (self.true_negatives + self.false_positives) *
            (self.true_negatives + self.false_negatives)
        )
        # Epsilon keeps the division finite when a row or column of the
        # confusion matrix is empty; the result is then 0.
        mcc = numerator / (denominator + K.epsilon())
        return mcc

    def reset_state(self):
        """Zero all four confusion-matrix counters."""
        self.true_positives.assign(0)
        self.true_negatives.assign(0)
        self.false_positives.assign(0)
        self.false_negatives.assign(0)

    def get_config(self):
        """Return the serializable config, including the threshold."""
        config = super().get_config()
        config['threshold'] = self.threshold
        return config

def write_metrics(history,file):
    """Write the per-epoch metrics of a training run to a CSV-like text file.

    Each metric becomes one line: the metric name followed by its per-epoch
    values rounded to 4 decimals, all comma-separated.

    Args:
        history: object exposing a ``history`` mapping of
            metric name -> list of per-epoch numbers (e.g. the value
            returned by ``fit``).
        file: destination path (``str`` or ``Path``). Missing parent
            directories are created.
    """
    out_path = Path(file)
    # Do not rely on some other component having created the run directory.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'w') as f:
        for name, values in history.history.items():
            rounded = ','.join(str(round(value, 4)) for value in values)
            f.write(f"{name},{rounded}\n")

In [None]:
# Load the TFRecord datasets and build the batched input pipelines
#
cell='GM12878'
tranf='CTCF'

# Pipeline settings; these are passed to make_batches below, so editing them
# here actually changes the pipeline.
BUFFER_SIZE = 20000
BATCH_SIZE = 32

# Data layout: data/training/<cell>/<tranf>/{train,test}_set.tfrecord
path_peak='data/training'
path_cell=path_peak+'/'+cell
path_tf=path_cell+'/'+tranf
file_train_set=path_tf+'/'+'train_set.tfrecord'
file_test_set=path_tf+'/'+'test_set.tfrecord'

# `tokenizer` is read as a global by prepare_batch, so it must exist before
# the pipelines below are built.
vocab_file='data/vocab.txt'
tokenizer = text.BertTokenizer(vocab_file, token_out_type=tf.int64)

train_ds = tf.data.TFRecordDataset([file_train_set]).map(_parse_function)
test_ds = tf.data.TFRecordDataset([file_test_set]).map(_parse_function)

train_batches=make_batches(train_ds,batch_size=BATCH_SIZE,buffer_size=BUFFER_SIZE,shuffle=True)
# Validation metrics are aggregated over the whole set, so element order is
# irrelevant; skip the shuffle and the buffer fill it costs every epoch.
test_batches=make_batches(test_ds,batch_size=BATCH_SIZE,shuffle=False)


In [None]:
# Seed every random number generator in use, then build the EIformer model
seed = 1234
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Model hyperparameters
vocab_size=3**7
d_model = 32
len_motif=12
dff = 128
num_heads = 1
num_layers = 1
dropout_rate = 0.1
epiformer = EIformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    vocab_size=vocab_size,
    len_motif=len_motif,
    dropout_rate=dropout_rate,
)

# Symbolic inputs for the two model branches.
# NOTE(review): presumably 200 positions with 4 channels for DNA and 200
# token ids for the second input -- confirm against _parse_function.
dna_in=tf.keras.Input((200,4))
epi_in=tf.keras.Input((200,))

# Call the model once on the symbolic inputs before asking for the summary.
_=epiformer((dna_in,epi_in))
epiformer.summary()


In [None]:
# Train the model, checkpoint its weights, and write the metric history
#
LEARNING_RATE=0.001
EPOCHs=10

current_time=datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Every run writes into its own directory: data/logs/<cell>/<tranf>/<timestamp>
out_dir='/'.join([cell, tranf, current_time])
checkpoint_dir='data/logs/'+out_dir
# NOTE(review): Keras 3 requires weights-only checkpoint names to end in
# `.weights.h5` -- confirm the installed Keras version accepts this name.
cp_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_dir+'/model_weights_{epoch:02d}.h5',
    save_weights_only=True,
    verbose=1,
)

adam=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
# The loss is configured for raw logits, so the thresholded metrics below
# use 0 as the decision boundary (sigmoid(0) == 0.5).
bce_loss=tf.keras.losses.BinaryCrossentropy(from_logits=True)
train_metrics=[
    tf.keras.metrics.BinaryAccuracy(threshold=0.),
    tf.keras.metrics.Recall(thresholds=0.),
    tf.keras.metrics.Precision(thresholds=0.),
    tf.keras.metrics.AUC(from_logits=True),
    tf.keras.metrics.AUC(from_logits=True, curve='PR'),
    MatthewsCorrelationCoefficient(threshold=0.),
]
epiformer.compile(optimizer=adam, loss=bce_loss, metrics=train_metrics)

history=epiformer.fit(
    x=train_batches,
    validation_data=test_batches,
    epochs=EPOCHs,
    callbacks=[cp_callback],
    verbose=2,
)

# Persist the per-epoch metrics next to the checkpoints.
write_metrics(history,checkpoint_dir+'/model_metrics.csv')
