## NeuroAlign - Training



In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import random
import os
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import CSVLogger
import Model as model
import Data as data


GPUS = tf.config.experimental.list_logical_devices('GPU')
NUM_DEVICES = max(1, len(GPUS))

if len(GPUS) > 0:
    print("Using ", NUM_DEVICES, " GPU devices.")
else:
    print("Using CPU.")

Using  2  GPU devices.


In [2]:
NUM_EPOCHS = 200
NAME = "base2"
MODEL_PATH = "./models/" + NAME
CHECKPOINT_PATH = MODEL_PATH + "/model.ckpt"

os.makedirs(MODEL_PATH, exist_ok=True)

##################################################################################################
##################################################################################################
neuroalign, neuroalign_config = model.make_neuro_align_model(NAME)

KeyboardInterrupt: 

In [None]:
#Pfam protein families have identifiers of the form PF00001, PF00002, ...
#The largest id is PF19227, but the counting is not contiguous, there may be missing numbers
pfam = ["PF"+"{0:0=5d}".format(i) for i in range(1,19228)]
pfam_not_found = 0

fasta = []

for i,file in enumerate(pfam):
    try:
        f = data.Fasta("../Pfam/alignments/" + file + ".fasta", gaps = True, contains_lower_case = True)
        fasta.append(f)
        for x in range(1,10):
            if i/len(pfam) > x/10 and (i-1)/len(pfam) < x/10:
                print(x*10, "% loaded")
                gc.collect()
    except:
        pfam_not_found += 1

np.random.seed(0)
random.seed(0)

indices = np.arange(len(fasta))
np.random.shuffle(indices)
if len(fasta) > 10:
    print("Using the full dataset.")
    train, val = np.split(indices, [int(len(fasta)*(1-neuroalign_config["validation_split"]))]) 
    train_gen = data.AlignmentSampleGenerator(train, fasta, neuroalign_config, neuroalign_config["family_size"], NUM_DEVICES)
    val_gen = data.AlignmentSampleGenerator(val, fasta, neuroalign_config, 2*neuroalign_config["family_size"], NUM_DEVICES, False)
else: 
    print("Using a small test dataset.")
    train_gen = data.AlignmentSampleGenerator(np.arange(len(fasta)), fasta, neuroalign_config, neuroalign_config["family_size"], NUM_DEVICES)
    val_gen = data.AlignmentSampleGenerator(np.arange(len(fasta)), fasta, neuroalign_config, 2*neuroalign_config["family_size"], NUM_DEVICES, False) 

In [None]:
INPUT_DIM = 28

COLUMN_LOSS_WEIGHT = 0.02
ATTENTION_LOSS_WEIGHT = 0.98
SEQUENCE_LOSS_WEIGHT = 1

POS_WEIGHT = 1
NEG_WEIGHT = 1

##################################################################################################
##################################################################################################

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(neuroalign_config["col_dim"], tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)


##################################################################################################
##################################################################################################

optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

##################################################################################################
##################################################################################################

#loss for aligned aminoacid pairs (= attention)

bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
     
def make_mask(y_true):
    mask = tf.math.equal(y_true, 0)
    mask = tf.math.reduce_all(mask, axis=-1)
    mask = tf.cast(tf.math.logical_not(mask), y_true.dtype)
    return mask

def make_sq(y, mask):
    y = tf.boolean_mask(y, mask)
    y_sq = tf.matmul(y, y, transpose_b=True)
    y_sq = tf.reshape(y_sq, (-1, 1))
    y_sq = tf.clip_by_value(y_sq, 0.0, 1.0)
    return y_sq

def att_loss(y_true, y_pred):
    mask = make_mask(y_true)
    y_true_sq = make_sq(y_true, mask)
    y_pred_sq = make_sq(y_pred, mask)
    l = tf.expand_dims(bce(y_true_sq, y_pred_sq), -1)
    w = POS_WEIGHT * y_true_sq + NEG_WEIGHT * (1-y_true_sq)
    l *= w
    return tf.reduce_sum(l) / tf.reduce_sum(w)

##################################################################################################
##################################################################################################

#loss for sequence reconstruction from columns (unsupervised)

ce = keras.losses.CategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)

def seq_loss(y_true, y_pred):
    mask = make_mask(y_true)
    l = ce(y_true, y_pred) * mask
    return tf.math.reduce_sum(l) / tf.math.reduce_sum(mask)

##################################################################################################
##################################################################################################

#precision and recall metrics for aligned aminoacid pairs

threshold = 0.5

def precision(y_true, y_pred):
    mask = make_mask(y_true)
    y_true_sq = make_sq(y_true, mask)
    y_pred_sq = make_sq(y_pred, mask)
    positives = tf.cast(y_pred_sq >= threshold, tf.float32) 
    true_positives = positives * y_true_sq
    precision = tf.reduce_sum(true_positives) / tf.math.maximum(tf.reduce_sum(positives), 1.0)
    return precision

def recall(y_true, y_pred):
    mask = make_mask(y_true)
    y_true_sq = make_sq(y_true, mask)
    y_pred_sq = make_sq(y_pred, mask)
    positives = tf.cast(y_pred_sq >= threshold, tf.float32)
    true_positives = positives * y_true_sq
    recall = tf.reduce_sum(true_positives) / tf.math.maximum(tf.reduce_sum(y_true_sq), 1.0)
    return recall

#categorical accuracy for reconstructed sequences

def categorical_accuracy(y_true, y_pred):
    mask = make_mask(y_true)
    acc = tf.equal(tf.argmax(y_true, axis=-1), tf.argmax(y_pred, axis=-1)) 
    acc = tf.cast(acc, dtype=y_true.dtype)
    return tf.math.reduce_sum(acc * mask) / tf.math.reduce_sum(mask)
    

##################################################################################################
##################################################################################################


def losses_prefixed(losses, metrics, weights, prefix=""):
    if neuroalign_config["use_column_loss"]:
        losses.update({prefix+"out_columns" : tf.keras.losses.KLDivergence()})
        weights.update({prefix+"out_columns" : COLUMN_LOSS_WEIGHT})
    if neuroalign_config["use_attention_loss"]:
        losses.update({prefix+"out_attention" : att_loss})
        metrics.update({prefix+"out_attention" : [precision, recall]})
        weights.update({prefix+"out_attention" : ATTENTION_LOSS_WEIGHT})
        

losses, metrics, weights = {}, {}, {}
if NUM_DEVICES == 1:
    model = neuroalign
    losses_prefixed(losses, metrics, weights)
else:
    inputs, outputs = [], []
    for i, gpu in enumerate(GPUS):
        with tf.device(gpu.name):
            sequences = keras.Input(shape=(None,INPUT_DIM), name="GPU_"+str(i)+"_sequences")
            columns = keras.Input(shape=(INPUT_DIM), name="GPU_"+str(i)+"_in_columns")
            input_dict = {  "sequences" : sequences,
                            "in_columns" : columns }
            out_cols, A = neuroalign(input_dict)
            outputs.append(layers.Lambda(lambda x: x, name="GPU_"+str(i)+"_out_columns")(out_cols))
            outputs.append(layers.Lambda(lambda x: x, name="GPU_"+str(i)+"_out_attention")(A))
            inputs.extend([sequences, columns])

    model = keras.Model(inputs=inputs, outputs=outputs)
    for i, gpu in enumerate(GPUS):
        losses_prefixed(losses, metrics, weights, "GPU_"+str(i)+"_")

model.compile(loss=losses, optimizer=optimizer, metrics=metrics, loss_weights=weights)
    
class ModelCheckpoint(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        neuroalign.save_weights(CHECKPOINT_PATH)
        print("Saved model to " + CHECKPOINT_PATH, flush=True)

csv_logger = CSVLogger(MODEL_PATH + "/log.csv", append=True, separator=',')

history = model.fit(train_gen,
                    validation_data=val_gen,
                    epochs = NUM_EPOCHS,
                    verbose = 1,
                    callbacks=[ModelCheckpoint(), csv_logger])

Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
 2720/17346 [===>..........................] - ETA: 28:07 - loss: 0.0475 - GPU_0_out_columns_loss: 0.6914 - GPU_0_out_attention_loss: 0.0102 - GPU_1_out_columns_loss: 0.5772 - GPU_1_out_attention_loss: 0.0123

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 3347/17346 [====>.........................] - ETA: 27:24 - loss: 0.0450 - GPU_0_out_columns_loss: 0.6625 - GPU_0_out_attention_loss: 0.0095 - GPU_1_out_columns_loss: 0.5550 - GPU_1_out_attention_loss: 0.0116

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)





IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 18/200
  653/17346 [>.............................] - ETA: 32:43 - loss: 0.0444 - GPU_0_out_columns_loss: 0.6544 - GPU_0_out_attention_loss: 0.0094 - GPU_1_out_columns_loss: 0.5377 - GPU_1_out_attention_loss: 0.0115