# Setup

In [None]:
# Install the Python dependencies. TensorFlow and tf-models-official are pinned to 2.7;
# later cells import `official.nlp.keras_nlp` from tf-models-official.
# NOTE(review): `tensorflow_addons` is imported in the next cell but is not installed explicitly
# here; presumably it arrives as a dependency of tf-models-official — confirm.
# NOTE(review): `transformers` is unpinned, so the resolved version (and T5 behaviour) may drift.
# NOTE(review): `%pip` would guarantee the install targets the running kernel's environment.
!pip install wget pandas numpy matplotlib scikit-learn scipy
!pip install tensorflow==2.7
!pip install tf-models-official==2.7
!pip install transformers
# download the Maverick resources archive and unpack it into the working directory
!python -m wget https://zuchnerlab.s3.amazonaws.com/VariantPathogenicity/Maverick_resources.tar.gz
!tar -zxvf Maverick_resources.tar.gz

In [None]:
import os
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import Callback
import numpy as np
import official.nlp
import official.nlp.keras_nlp.layers
import tensorflow_addons as tfa
from transformers import TFT5EncoderModel, T5Tokenizer,T5Config
import pandas
pandas.options.mode.chained_assignment = None
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.utils import resample
import scipy
from scipy.stats import rankdata
from datetime import datetime


In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras `Sequence` that builds Maverick model inputs, one batch at a time, from a variant table.

    For every variant ID it loads the transcript's MMSeqs profile from
    ``HHMFiles/<TranscriptID>_MMSeqsProfile.npz``. It takes a window of ``2*padding + 1``
    positions centred on the (0-based) changed position and returns one of two layouts:

    * ``returnStyle == 1``: ``{'mm_orig_seq', 'mm_alt_seq', 'non_seq_info'}``. These are the
      wild-type profile window, the edited alt profile window and 12 structured features.
    * ``returnStyle == 2``: ``{'alt_cons', 'alt_emb', 'non_seq_info'}``. These are the edited
      alt profile window, T5 embeddings of the alt sequence window and 12 structured features.

    Labels are returned one-hot encoded with ``n_classes`` classes.

    Args:
        list_IDs: index labels of ``dataFrameIn`` to iterate over.
        labels: indexable by ID; ``labels[ID]`` is the integer class of the variant.
        dataFrameIn: pandas DataFrame with the columns read below. These are TranscriptID,
            ChangePos (converted to 0-based by subtracting 1), AltSeq, WildtypeSeq, varType,
            ref, alt and the 12 structured-feature columns.
        tokenizer, T5Model: T5 tokenizer and encoder; only used when ``returnStyle == 2``.
        batch_size: samples per batch (the last batch may be smaller).
        padding: number of positions kept on each side of the changed position.
        n_channels_emb: width of each T5 embedding vector.
        n_channels_mm: width of each profile row.
        n_classes: number of output classes.
        returnStyle: 1 or 2; selects the input dictionary layout (see above).
        shuffle: reshuffle the sample order in ``on_epoch_end``.

    Raises:
        ValueError: for an unsupported ``varType`` or inconsistent sequence lengths.
    """
    def __init__(self, list_IDs, labels, dataFrameIn, tokenizer, T5Model, batch_size=32, padding=100, n_channels_emb=1024, n_channels_mm=51, n_classes=3, returnStyle=1, shuffle=True):
        self.padding = padding
        # window width: `padding` positions either side of the changed position, plus that position
        self.dim = self.padding + self.padding + 1
        self.batch_size = batch_size
        self.labels = labels
        self.list_IDs = list_IDs
        self.n_channels_emb = n_channels_emb
        self.n_channels_mm = n_channels_mm
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.dataFrameIn=dataFrameIn
        self.tokenizer = tokenizer
        self.T5Model = T5Model
        self.returnStyle = returnStyle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch (the last batch may be partial)'
        return int(np.ceil(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # numpy slicing stops at the end of the array, so the final partial batch needs no special case
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(list_IDs_temp)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _t5_window(self, seqLen, changePos):
        '''Return (startPos, endPos) of the residue slice that is fed to the T5 tokenizer.

        seqLen is the protein length excluding the terminal '*'. The slice is normally
        [changePos-padding, changePos+padding). Near either end of the sequence it is
        shifted or clipped.
        '''
        startPos=0
        if changePos>self.padding:
            if (changePos+self.padding)<seqLen:
                startPos=changePos-self.padding
            elif seqLen>=self.dim:
                startPos=seqLen-self.dim
        endPos=changePos+self.padding
        if changePos<self.padding:
            if self.dim<seqLen:
                endPos=self.dim
            else:
                endPos=seqLen
        elif (changePos+self.padding)>=seqLen:
            endPos=seqLen
        return int(startPos), int(endPos)

    def _profile_window(self, seqLen, changePos):
        '''Map a length-seqLen profile onto the dim-wide window centred on changePos.

        Returns (startPos, endPos, startOffset, endOffset). Profile rows [startPos:endPos]
        are written to window slots [startOffset:endOffset]; slots outside that range stay zero.
        '''
        startPos=changePos-self.padding
        endPos=changePos+self.padding+1
        startOffset=0
        endOffset=self.dim
        if changePos<self.padding:
            startPos=0
            startOffset=self.padding-changePos
        if (changePos+self.padding)>=seqLen:
            endPos=seqLen
            endOffset=self.padding+seqLen-changePos
        return startPos, endPos, startOffset, endOffset

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'
        # Initialization
        thisBatchSize=len(list_IDs_temp)
        altEmbeddings=np.zeros((thisBatchSize, self.dim, self.n_channels_emb))
        mm_alt=np.zeros((thisBatchSize, self.dim, self.n_channels_mm))
        mm_orig=np.zeros((thisBatchSize, self.dim, self.n_channels_mm))
        nonSeq=np.zeros((thisBatchSize, 12))
        y = np.empty((thisBatchSize), dtype=int)
        # profile column holding each amino acid's one-hot entry (columns 0-19)
        AMINO_ACIDS = {'A':0,'C':1,'D':2,'E':3,'F':4,'G':5,'H':6,'I':7,'K':8,'L':9,'M':10,'N':11,'P':12,'Q':13,'R':14,'S':15,'T':16,'V':17,'W':18,'Y':19}

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # variant-level fields; both return styles need these for the profile windows below
            transcriptID=self.dataFrameIn.loc[ID,'TranscriptID']
            changePos=self.dataFrameIn.loc[ID,'ChangePos']-1  # 1-based -> 0-based
            if changePos<0:
                changePos=0
            AltSeq=self.dataFrameIn.loc[ID,'AltSeq']
            if AltSeq[-1]!="*":
                AltSeq=AltSeq + "*"
            seqLenAlt=len(AltSeq)-1  # alt length excluding the terminal '*'

            if self.returnStyle==2:
                # window the Alt and WT sequences as space-separated residues for the T5 tokenizer
                WTSeq=self.dataFrameIn.loc[ID,'WildtypeSeq']
                if WTSeq[-1]!="*":
                    WTSeq=WTSeq + "*"
                T5AltSeqTokens=[]
                for seq in (AltSeq, WTSeq):
                    startPos, endPos = self._t5_window(len(seq)-1, changePos)
                    T5AltSeqTokens.append(" ".join(seq[startPos:endPos]))

                # process the altSeq and wtSeq through the T5 tokenizer (for consistency with pre-computed data used for training)
                allTokens=self.tokenizer.batch_encode_plus(T5AltSeqTokens,add_special_tokens=True, padding=True, return_tensors="tf")
                input_ids=allTokens['input_ids']
                # but only process the altSeq through the T5 model
                input_ids=tf.expand_dims(input_ids[0],0)
                embeddings=self.T5Model(input_ids)
                allEmbeddings=np.asarray(embeddings.last_hidden_state)
                seq_len = (np.asarray(allTokens['attention_mask'])[0] == 1).sum()
                # drop the first and last non-padding token positions
                seq_emb = allEmbeddings[0][1:seq_len-1]
                altEmbeddings[i,:seq_emb.shape[0],:]=seq_emb

            # collect MMSeqs WT info (context manager closes the .npz file handle)
            # NOTE(review): allow_pickle=True lets np.load unpickle object arrays from these downloaded
            # files; if the profiles are plain numeric arrays it can be dropped.
            with np.load("HHMFiles/" + transcriptID + "_MMSeqsProfile.npz",allow_pickle=True) as profileFile:
                tmp=profileFile['arr_0']
            seqLen=tmp.shape[0]
            startPos, endPos, startOffset, endOffset = self._profile_window(seqLen, changePos)
            mm_orig[i,startOffset:endOffset,:] = tmp[startPos:endPos,:]
            # (mm_orig holds a copy, so the branches below may modify `tmp` in place)

            # collect MMSeqs Alt info
            # change the amino acid at 'ChangePos' and any after that if needed.
            # Columns 0-19 are the AMINO_ACIDS one-hot; column 50 is set/cleared together with them
            # (presumably a residue-present flag — confirm).
            varType=self.dataFrameIn.loc[ID,'varType']
            WTSeq=self.dataFrameIn.loc[ID,'WildtypeSeq']
            if varType=='nonsynonymous SNV':
                if changePos==0:
                    # then this transcript is ablated
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[:,0:20]=0
                    altEncoded[:,50]=0
                else:
                    # change the single amino acid
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[changePos,AMINO_ACIDS[WTSeq[changePos]]]=0
                    altEncoded[changePos,AMINO_ACIDS[AltSeq[changePos]]]=1
            elif varType=='stopgain':
                if changePos==0:
                    # then this transcript is ablated
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[:,0:20]=0
                    altEncoded[:,50]=0
                elif seqLenAlt>seqLen:
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                else:
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[changePos:,0:20]=0
                    altEncoded[changePos:,50]=0
            elif varType=='stoploss':
                altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                altEncoded[:seqLen,:]=tmp
                for j in range(seqLen,seqLenAlt):
                    altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                altEncoded[seqLen:,50]=1
            elif varType=='synonymous SNV':
                # no change
                altEncoded=tmp
            elif ((varType=='frameshift deletion') | (varType=='frameshift insertion') | (varType=='frameshift substitution')):
                if seqLen<seqLenAlt:
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                elif seqLen>seqLenAlt:
                    for j in range(changePos,seqLenAlt):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                        tmp[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLenAlt,seqLen):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                    altEncoded=tmp
                elif seqLen==seqLenAlt:
                    for j in range(changePos,seqLen):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                        tmp[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded=tmp
                else:
                    raise ValueError('Error: seqLen comparisons did not work')
            elif varType=='nonframeshift deletion':
                # how many amino acids deleted?
                altNucLen=0
                if self.dataFrameIn.loc[ID,'alt']!='-':
                    altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                refNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                numAADel=int((refNucLen-altNucLen)/3)
                if (seqLen-numAADel)==seqLenAlt:
                    # non-frameshift deletion: deleted residues keep their rows but lose their
                    # amino-acid one-hot (columns 0-19), so WT and alt stay position-aligned
                    for j in range(changePos,(changePos+numAADel)):
                        tmp[j,:20]=0
                    altEncoded=tmp
                elif seqLen>=seqLenAlt:
                    # early truncation
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLenAlt:,0:20]=0
                    altEncoded[seqLenAlt:,50]=0
                elif seqLen<seqLenAlt:
                    # deletion causes stop-loss; the extension rows are left empty
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,0:20]=0
                    altEncoded[seqLen:,50]=0
                else:
                    raise ValueError('Error: seqLen comparisons did not work for nonframeshift deletion')
            elif varType=='nonframeshift insertion':
                # how many amino acids inserted? ('-' marks an empty ref allele)
                refNucLen=0
                if self.dataFrameIn.loc[ID,'ref']!='-':
                    refNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                numAAIns=int((altNucLen-refNucLen)/3)
                if (seqLen+numAAIns)==seqLenAlt:
                    # non-frameshift insertion: shift WT rows from changePos down by numAAIns
                    # and one-hot the inserted residues
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:changePos,:]=tmp[:changePos,:]
                    altEncoded[(changePos+numAAIns):,:]=tmp[changePos:,:]
                    for j in range(numAAIns):
                        altEncoded[(changePos+j),AMINO_ACIDS[AltSeq[(changePos+j)]]]=1
                    altEncoded[:,50]=1
                elif seqLen<seqLenAlt:
                    # stop loss
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                elif seqLen>=seqLenAlt:
                    # stop gain
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLenAlt:,0:20]=0
                    altEncoded[seqLenAlt:,50]=0
                else:
                    raise ValueError('Error: seqLen comparisons did not work for nonframeshift insertion')
            elif varType=='nonframeshift substitution':
                # is this an insertion or a deletion?
                # note that there will not be any '-' symbols in these ref or alt fields because it is a substitution
                refNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                if refNucLen>altNucLen:
                    # deletion
                    # does this cause an early truncation or non-frameshift deletion?
                    if seqLen>seqLenAlt:
                        numAADel=int((refNucLen-altNucLen)/3)
                        if (seqLen-numAADel)==seqLenAlt:
                            # non-frameshift deletion: deleted residues keep their rows but lose
                            # their amino-acid one-hot (columns 0-19)
                            for j in range(changePos,(changePos+numAADel)):
                                tmp[j,:20]=0
                            altEncoded=tmp
                        else:
                            # early truncation
                            altEncoded=np.zeros((seqLen,self.n_channels_mm))
                            altEncoded[:seqLen,:]=tmp
                            for j in range(changePos,seqLenAlt):
                                altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            altEncoded[seqLenAlt:,0:20]=0
                            altEncoded[seqLenAlt:,50]=0
                    # does this cause a stop loss?
                    elif seqLen<seqLenAlt:
                        altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        for j in range(seqLen,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLen:,50]=1
                    else: # not sure how this would happen
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                elif refNucLen<altNucLen:
                    # insertion
                    # does this cause a stop loss or non-frameshift insertion?
                    if seqLen<seqLenAlt:
                        numAAIns=int((altNucLen-refNucLen)/3)
                        if (seqLen+numAAIns)==seqLenAlt:
                            # non-frameshift insertion
                            altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                            altEncoded[:changePos,:]=tmp[:changePos,:]
                            altEncoded[(changePos+numAAIns):,:]=tmp[changePos:,:]
                            for j in range(numAAIns):
                                altEncoded[(changePos+j),AMINO_ACIDS[AltSeq[(changePos+j)]]]=1
                            altEncoded[:,50]=1
                        else:
                            # stop loss
                            altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                            altEncoded[:seqLen,:]=tmp
                            for j in range(changePos,seqLen):
                                altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            for j in range(seqLen,seqLenAlt):
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            altEncoded[:,50]=1
                    # does this cause an early truncation?
                    elif seqLen>seqLenAlt:
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLenAlt:,0:20]=0
                        altEncoded[seqLenAlt:,50]=0
                    else: # not sure how this would happen
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                elif refNucLen==altNucLen:
                    if seqLen==seqLenAlt:
                        # synonymous or nonsynonymous change
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        altEncoded[changePos,AMINO_ACIDS[WTSeq[changePos]]]=0
                        altEncoded[changePos,AMINO_ACIDS[AltSeq[changePos]]]=1
                    elif seqLen>seqLenAlt:
                        # early truncation
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLenAlt:,0:20]=0
                        altEncoded[seqLenAlt:,50]=0
                    elif seqLen<seqLenAlt:
                        # stop loss
                        altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        for j in range(seqLen,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLen:,50]=1
                    else:
                        raise ValueError('non-frameshift substitution comparisons failed')
                else:
                    raise ValueError('Error: nonframeshift substitution nucleotide length comparison did not work')
            else:
                # without this, `altEncoded` would be unbound or silently left over from the previous sample
                raise ValueError('Unsupported varType: ' + str(varType))

            startPos, endPos, startOffset, endOffset = self._profile_window(seqLenAlt, changePos)
            # exception to deal with start loss SNVs that create new frameshifted products longer than the original protein (when original was shorter than padding length)
            if ((changePos==0) & (self.padding>=seqLen) & (seqLen<seqLenAlt) & (varType=='nonsynonymous SNV')):
                endPos=seqLen
                endOffset=self.padding + seqLen - changePos
            elif ((changePos==0) & (varType=='stopgain')): # related exception for stopgains at position 0
                if (seqLen+self.padding)<=self.dim:
                    endPos=seqLen
                    endOffset=self.padding + seqLen - changePos
                else:
                    endPos=self.padding+1
                    endOffset=self.dim
            mm_alt[i,startOffset:endOffset,:] = altEncoded[startPos:endPos,:]

            # non-seq info
            nonSeq[i] = self.dataFrameIn.loc[ID,['controls_AF','controls_nhomalt','pLI','pNull','pRec','mis_z','lof_z','CCR','GDI','pext','RVIS_ExAC_0.05','gerp']]

            # Store class
            y[i] = self.labels[ID]

        X={'mm_orig_seq':mm_orig,'mm_alt_seq':mm_alt,'non_seq_info':nonSeq}
        if self.returnStyle==2:
            X={'alt_cons':mm_alt,'alt_emb':altEmbeddings,'non_seq_info':nonSeq}

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)


In [None]:
def MaverickArchitecture1(input_shape=201,classes=3,classifier_activation='softmax',mmSize=51,**kwargs):
    """Build and compile the profile-based Maverick classifier.

    Its inputs are named 'mm_orig_seq', 'mm_alt_seq' and 'non_seq_info', matching the
    dictionary DataGenerator produces with returnStyle=1.

    The wild-type and alt profile windows are each projected to 64 channels and get a
    learned position embedding. Both then pass through the SAME six transformer encoder
    blocks (weights shared between branches). The encoder output at the centre position
    is pooled, and the alt representation is concatenated with the (alt - orig) difference
    and a Dense(64) projection of the 12 structured features before the classification head.

    Args:
        input_shape: window length; pooling uses the centre position input_shape // 2.
        classes: number of output classes.
        classifier_activation: activation of the output layer.
        mmSize: number of channels per profile position.
        **kwargs: accepted and ignored.

    Returns:
        A tf.keras.Model compiled with SGD (lr 1e-3, momentum 0.85) and categorical cross-entropy.
    """
    # the variant sits at the centre of the window (index 100 for the default width of 201)
    center = input_shape // 2

    input0 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='mm_orig_seq')
    input1 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='mm_alt_seq')
    input2 = tf.keras.layers.Input(shape=12,name='non_seq_info')

    # project input to an embedding size that is easier to work with
    x_orig = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input0)
    x_alt = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input1)

    posEnc_wt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(x_orig)
    x_orig = tf.keras.layers.Masking()(x_orig)
    x_orig = tf.keras.layers.Add()([x_orig,posEnc_wt])
    x_orig = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(x_orig)
    x_orig = tf.keras.layers.Dropout(0.05)(x_orig)

    posEnc_alt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(x_alt)
    x_alt = tf.keras.layers.Masking()(x_alt)
    x_alt = tf.keras.layers.Add()([x_alt,posEnc_alt])
    x_alt = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(x_alt)
    x_alt = tf.keras.layers.Dropout(0.05)(x_alt)

    # six encoder blocks, created once and applied to both branches (shared weights);
    # creation and call order are kept so auto-generated layer names and weight order are unchanged
    transformers = [
        official.nlp.keras_nlp.layers.TransformerEncoderBlock(16,256,tf.keras.activations.relu,output_dropout=0.1,attention_dropout=0.1)
        for _ in range(6)
    ]
    for transformer in transformers:
        x_orig = transformer(x_orig)
    for transformer in transformers:
        x_alt = transformer(x_alt)

    # pool the encoder output at the centre (variant) position
    first_token_tensor_orig = tf.keras.layers.Lambda(lambda a: tf.squeeze(a[:, center:center+1, :], axis=1))(x_orig)
    x_orig = tf.keras.layers.Dense(units=64,activation='tanh')(first_token_tensor_orig)
    x_orig = tf.keras.layers.Dropout(0.05)(x_orig)

    first_token_tensor_alt = tf.keras.layers.Lambda(lambda a: tf.squeeze(a[:, center:center+1, :], axis=1))(x_alt)
    x_alt = tf.keras.layers.Dense(units=64,activation='tanh')(first_token_tensor_alt)
    x_alt = tf.keras.layers.Dropout(0.05)(x_alt)

    diff = tf.keras.layers.Subtract()([x_alt,x_orig])
    combined = tf.keras.layers.concatenate([x_alt,diff])

    input2Dense1 = tf.keras.layers.Dense(64,activation='relu')(input2)
    input2Dense1 = tf.keras.layers.Dropout(0.05)(input2Dense1)
    x = tf.keras.layers.concatenate([combined,input2Dense1])
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(512,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(64,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(classes, activation=classifier_activation,name='output')(x)
    model = tf.keras.Model(inputs=[input0,input1,input2],outputs=x)

    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.85)
    model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])
    return model


In [None]:
def MaverickArchitecture2(input_shape=201,embeddingSize=1024,mmSize=51,classes=3,classifier_activation='softmax',**kwargs):
    """Build and compile the profile + T5-embedding Maverick classifier.

    Its inputs are named 'alt_cons', 'alt_emb' and 'non_seq_info', matching the dictionary
    DataGenerator produces with returnStyle=2.

    The alt profile window is projected to 64 channels and gets a learned position embedding.
    It then passes through six transformer encoder blocks and is pooled at the centre position.
    The alt T5 embedding window goes through a bidirectional LSTM(32). Both are concatenated
    with a Dense(64) projection of the 12 structured features before the classification head.

    Args:
        input_shape: window length; pooling uses the centre position input_shape // 2.
        embeddingSize: width of each T5 embedding vector.
        mmSize: number of channels per profile position.
        classes: number of output classes.
        classifier_activation: activation of the output layer.
        **kwargs: accepted and ignored.

    Returns:
        A tf.keras.Model compiled with SGD (lr 1e-3, momentum 0.85) and categorical cross-entropy.
    """
    # the variant sits at the centre of the window (index 100 for the default width of 201)
    center = input_shape // 2

    input0 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='alt_cons')
    input1 = tf.keras.layers.Input(shape=(input_shape,embeddingSize),name='alt_emb')
    input2 = tf.keras.layers.Input(shape=12,name='non_seq_info')

    # project input to an embedding size that is easier to work with
    alt_cons = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input0)

    posEnc_alt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(alt_cons)
    alt_cons = tf.keras.layers.Masking()(alt_cons)
    alt_cons = tf.keras.layers.Add()([alt_cons,posEnc_alt])
    alt_cons = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(alt_cons)
    alt_cons = tf.keras.layers.Dropout(0.05)(alt_cons)

    # six stacked encoder blocks; creation order is kept so auto-generated layer names are unchanged
    transformers = [
        official.nlp.keras_nlp.layers.TransformerEncoderBlock(16,256,tf.keras.activations.relu,output_dropout=0.1,attention_dropout=0.1)
        for _ in range(6)
    ]
    for transformer in transformers:
        alt_cons = transformer(alt_cons)

    # pool the encoder output at the centre (variant) position
    first_token_tensor_alt = tf.keras.layers.Lambda(lambda a: tf.squeeze(a[:, center:center+1, :], axis=1))(alt_cons)
    alt_cons = tf.keras.layers.Dense(units=64,activation='tanh')(first_token_tensor_alt)
    alt_cons = tf.keras.layers.Dropout(0.05)(alt_cons)

    sharedLSTM1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32, return_sequences=False, dropout=0.5))

    alt_emb=sharedLSTM1(input1)
    alt_emb=tf.keras.layers.Dropout(0.2)(alt_emb)

    structured = tf.keras.layers.Dense(64,activation='relu')(input2)
    structured = tf.keras.layers.Dropout(0.05)(structured)
    x = tf.keras.layers.concatenate([alt_cons,alt_emb,structured])
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(512,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(64,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(classes, activation=classifier_activation,name='output')(x)
    model = tf.keras.Model(inputs=[input0,input1,input2],outputs=x)

    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.85)
    model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])
    return model


In [None]:
class CosineAnnealer:
    """Cosine-shaped interpolation from `start` towards `end` over `steps` calls to `step()`.

    Each call advances the counter `n` by one and returns
    end + (start - end) / 2 * (1 + cos(pi * n / steps)).
    This equals `end` once n == steps.
    """

    def __init__(self, start, end, steps):
        self.start, self.end, self.steps = start, end, steps
        self.n = 0  # number of step() calls made so far

    def step(self):
        self.n = self.n + 1
        progress = self.n / self.steps
        half_span = (self.start - self.end) / 2.
        return self.end + half_span * (np.cos(np.pi * progress) + 1)


class OneCycleScheduler(Callback):
    """`Callback` that schedules the learning rate on a 1cycle policy as per Leslie Smith's paper (https://arxiv.org/pdf/1803.09820.pdf).

    If the optimizer exposes a `momentum` attribute, it is adapted by the schedule as well.
    The implementation adopts additional improvements from the fastai library
    (https://docs.fast.ai/callbacks.one_cycle.html): only two phases are used, and each
    one is interpolated with cosine annealing.

    * Phase 1: the LR increases from `lr_max / div_factor` to `lr_max`, while momentum
      decreases from `mom_max` to `mom_min`.
    * Phase 2: the LR decreases from `lr_max` to `lr_max / (div_factor * 1e4)`, while
      momentum increases from `mom_min` back to `mom_max`.

    By default the phases are not of equal length; phase 1 covers `phase_1_pct` of `steps`.

    Args:
        lr_max: peak learning rate, reached at the end of phase 1.
        steps: total number of training batches the schedule spans (batches per epoch * epochs).
        mom_min: lowest momentum (end of phase 1).
        mom_max: highest momentum (start of phase 1 / end of phase 2).
        phase_1_pct: fraction of `steps` spent in phase 1.
        div_factor: the starting LR is `lr_max / div_factor`.

    Raises:
        ValueError: if `steps` is not positive.

    The learning rate and momentum in effect at the start of every training batch are
    appended to `self.lrs` and `self.moms` (None when the optimizer lacks the attribute).
    """

    def __init__(self, lr_max, steps, mom_min=0.85, mom_max=0.95, phase_1_pct=0.3, div_factor=25.):
        super(OneCycleScheduler, self).__init__()
        if steps <= 0:
            raise ValueError('steps must be positive, got %r' % (steps,))
        lr_min = lr_max / div_factor
        final_lr = lr_max / (div_factor * 1e4)
        phase_1_steps = steps * phase_1_pct
        phase_2_steps = steps - phase_1_steps

        self.phase_1_steps = phase_1_steps
        self.phase_2_steps = phase_2_steps
        self.phase = 0
        self.step = 0

        # phases[i] = [learning-rate annealer, momentum annealer] for phase i
        self.phases = [[CosineAnnealer(lr_min, lr_max, phase_1_steps), CosineAnnealer(mom_max, mom_min, phase_1_steps)],
                       [CosineAnnealer(lr_max, final_lr, phase_2_steps), CosineAnnealer(mom_min, mom_max, phase_2_steps)]]

        self.lrs = []
        self.moms = []

    def on_train_begin(self, logs=None):
        """Restart the schedule from the beginning of phase 1."""
        self.phase = 0
        self.step = 0
        # Rewind the annealers too, otherwise a second fit() would continue from a stale position.
        for annealers in self.phases:
            for annealer in annealers:
                annealer.n = 0

        self.set_lr(self.lr_schedule().start)
        self.set_momentum(self.mom_schedule().start)

    def on_train_batch_begin(self, batch, logs=None):
        """Record the learning rate and momentum used for this batch."""
        self.lrs.append(self.get_lr())
        self.moms.append(self.get_momentum())

    def on_train_batch_end(self, batch, logs=None):
        """Advance the schedule by one batch, switching to phase 2 when phase 1 is exhausted."""
        self.step += 1
        if self.step >= self.phase_1_steps:
            self.phase = 1

        self.set_lr(self.lr_schedule().step())
        self.set_momentum(self.mom_schedule().step())

    def get_lr(self):
        """Current optimizer learning rate, or None if the optimizer has no `lr`."""
        try:
            return tf.keras.backend.get_value(self.model.optimizer.lr)
        except AttributeError:
            return None

    def get_momentum(self):
        """Current optimizer momentum, or None if the optimizer has no `momentum`."""
        try:
            return tf.keras.backend.get_value(self.model.optimizer.momentum)
        except AttributeError:
            return None

    def set_lr(self, lr):
        """Set the optimizer learning rate; silently skipped if the optimizer has no `lr`."""
        try:
            tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        except AttributeError:
            pass  # optimizer without an `lr` attribute: nothing to schedule

    def set_momentum(self, mom):
        """Set the optimizer momentum; silently skipped if the optimizer has no `momentum`."""
        try:
            tf.keras.backend.set_value(self.model.optimizer.momentum, mom)
        except AttributeError:
            pass  # optimizer without momentum (e.g. plain Adam): nothing to schedule

    def lr_schedule(self):
        """Learning-rate annealer of the current phase."""
        return self.phases[self.phase][0]

    def mom_schedule(self):
        """Momentum annealer of the current phase."""
        return self.phases[self.phase][1]

    def plot(self):
        """Plot the recorded learning-rate and momentum histories side by side."""
        # Imported locally: matplotlib.pyplot is not otherwise imported in this notebook.
        import matplotlib.pyplot as plt
        ax = plt.subplot(1, 2, 1)
        ax.plot(self.lrs)
        ax.set_title('Learning Rate')
        ax = plt.subplot(1, 2, 2)
        ax.plot(self.moms)
        ax.set_title('Momentum')



In [None]:
tokenizer = T5Tokenizer.from_pretrained("prot_t5_xl_bfd", do_lower_case=False,local_files_only=True)
T5Model = TFT5EncoderModel.from_pretrained("prot_t5_xl_bfd",local_files_only=True)

# Structured (non-sequence) per-variant feature columns; every data set gets the same
# capping, imputation and quantile scaling, all fitted on the training set only.
nonSeqColumns=['controls_AF','controls_nhomalt','pLI','pNull','pRec','mis_z','lof_z','CCR','GDI','pext','RVIS_ExAC_0.05','gerp']

def cleanNonSeqFeatures(dataFrame, medians=None):
    """Cap and impute the structured feature columns of `dataFrame`.

    GDI is capped at 2000 in `dataFrame` itself (in place). On a copy of the feature
    columns, missing controls_AF / controls_nhomalt become 0, controls_nhomalt is capped
    at 10, and all remaining gaps are filled with `medians`. When `medians` is None they
    are computed from this frame after the steps above (use this for the training set).

    Returns:
        (float32 feature matrix in `nonSeqColumns` order, medians used for imputation)
    """
    dataFrame.loc[dataFrame['GDI']>2000,'GDI']=2000
    features=dataFrame[nonSeqColumns].copy(deep=True)
    features.loc[features['controls_AF'].isna(),'controls_AF']=0
    features.loc[features['controls_nhomalt'].isna(),'controls_nhomalt']=0
    features.loc[features['controls_nhomalt']>10,'controls_nhomalt']=10
    if medians is None:
        medians=features.median()
    features=features.fillna(medians)
    return np.asarray(features.to_numpy()).astype(np.float32), medians

# calculate medians and quantiles from training data
trainingData=pandas.read_csv('trainingSet.txt',sep='\t',low_memory=False)
trainingDataNonSeqInfo,trainingDataNonSeqMedians=cleanNonSeqFeatures(trainingData)

# scale columns by QT (newer scikit-learn versions require an integer subsample)
qt = QuantileTransformer(subsample=int(1e6), random_state=0, output_distribution='uniform')
qt=qt.fit(trainingDataNonSeqInfo)
trainingDataNonSeqInfo=qt.transform(trainingDataNonSeqInfo)
# Write the scaled values back exactly as for the evaluation sets below, so the training
# frame is on the same feature scale (assumes DataGenerator reads these columns from
# dataFrameIn - TODO confirm against DataGenerator.__getitem__).
trainingData.loc[:,nonSeqColumns]=trainingDataNonSeqInfo

# apply to validation set
validationData=pandas.read_csv('validationSet.txt',sep='\t',low_memory=False)
validationDataNonSeqInfo=qt.transform(cleanNonSeqFeatures(validationData,trainingDataNonSeqMedians)[0])
validationData.loc[:,nonSeqColumns]=validationDataNonSeqInfo

# apply to known genes set
knownData=pandas.read_csv('knownGenes.txt',sep='\t',low_memory=False)
knownDataNonSeqInfo=qt.transform(cleanNonSeqFeatures(knownData,trainingDataNonSeqMedians)[0])
knownData.loc[:,nonSeqColumns]=knownDataNonSeqInfo

# apply to novel genes set
novelData=pandas.read_csv('novelGenes.txt',sep='\t',low_memory=False)
novelDataNonSeqInfo=qt.transform(cleanNonSeqFeatures(novelData,trainingDataNonSeqMedians)[0])
novelData.loc[:,nonSeqColumns]=novelDataNonSeqInfo


In [None]:
y_train=trainingData.loc[:,'classLabel'].to_numpy()
y_valid=validationData.loc[:,'classLabel'].to_numpy()
y_known=knownData.loc[:,'classLabel'].to_numpy()
y_novel=novelData.loc[:,'classLabel'].to_numpy()

# one-hot encode the class label; num_classes is fixed by the training set so every
# evaluation set gets one column per class even if it happens to lack a class
encoder=LabelEncoder()
encoder.fit(trainingData.loc[:,'classLabel'])
numClasses=len(encoder.classes_)
y_train_encoded=tf.keras.utils.to_categorical(encoder.transform(trainingData.loc[:,'classLabel']),num_classes=numClasses).astype(int)
y_valid_encoded=tf.keras.utils.to_categorical(encoder.transform(validationData.loc[:,'classLabel']),num_classes=numClasses).astype(int)
y_known_encoded=tf.keras.utils.to_categorical(encoder.transform(knownData.loc[:,'classLabel']),num_classes=numClasses).astype(int)
y_novel_encoded=tf.keras.utils.to_categorical(encoder.transform(novelData.loc[:,'classLabel']),num_classes=numClasses).astype(int)

# Batch sizes must be fixed here, before the generators are built (the training cells
# run afterwards): Architecture 1 trains with batches of 128, Architecture 2 with 16.
arc1BatchSize=128
arc2BatchSize=16

# create data generators; DataGenerator also requires the ProtT5 tokenizer and encoder loaded earlier
training_generator1=DataGenerator(np.arange(len(trainingData)),y_train,dataFrameIn=trainingData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc1BatchSize,returnStyle=1,shuffle=True)
validation_generator1=DataGenerator(np.arange(len(validationData)),y_valid,dataFrameIn=validationData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc1BatchSize,returnStyle=1,shuffle=False)
known_generator1=DataGenerator(np.arange(len(knownData)),y_known,dataFrameIn=knownData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc1BatchSize,returnStyle=1,shuffle=False)
novel_generator1=DataGenerator(np.arange(len(novelData)),y_novel,dataFrameIn=novelData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc1BatchSize,returnStyle=1,shuffle=False)

training_generator2=DataGenerator(np.arange(len(trainingData)),y_train,dataFrameIn=trainingData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc2BatchSize,returnStyle=2,shuffle=True)
validation_generator2=DataGenerator(np.arange(len(validationData)),y_valid,dataFrameIn=validationData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc2BatchSize,returnStyle=2,shuffle=False)
known_generator2=DataGenerator(np.arange(len(knownData)),y_known,dataFrameIn=knownData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc2BatchSize,returnStyle=2,shuffle=False)
novel_generator2=DataGenerator(np.arange(len(novelData)),y_novel,dataFrameIn=novelData,tokenizer=tokenizer,T5Model=T5Model,batch_size=arc2BatchSize,returnStyle=2,shuffle=False)


In [None]:
# Train Architecture 1, model 1
batchSize=128
numEpochs=20
# The 1cycle schedule must span exactly the number of optimizer steps (batches per epoch * epochs);
# deriving it from the generator keeps it correct if the data set or batch size changes.
lr_schedule=OneCycleScheduler(0.1,len(training_generator1)*numEpochs)
modelWeightsName='weights_Architecture1_model1'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc1Model1 = MaverickArchitecture1()
Arc1Model1.summary()
Arc1Model1.fit(training_generator1,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator1)
# restore the best checkpoint before evaluation
Arc1Model1.load_weights(modelWeightsName)

In [None]:
# Test Architecture 1, model 1: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc1Model1.predict(validation_generator1)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture1Model1Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc1Model1.predict(known_generator1)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture1Model1Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc1Model1.predict(novel_generator1)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture1Model1Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 1, model 2 (class-weighted loss)
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator1)*numEpochs)
modelWeightsName='weights_Architecture1_model2'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc1Model2 = MaverickArchitecture1()
Arc1Model2.summary()
# Keras' keyword is `class_weight` (singular); keys are the encoded class indices
# (reported as Benign/Dominant/Recessive in the evaluation cells).
Arc1Model2.fit(training_generator1,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator1,class_weight={0:1,1:2,2:7})
# restore the best checkpoint before evaluation
Arc1Model2.load_weights(modelWeightsName)

In [None]:
# Test Architecture 1, model 2: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc1Model2.predict(validation_generator1)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture1Model2Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc1Model2.predict(known_generator1)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture1Model2Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc1Model2.predict(novel_generator1)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture1Model2Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 1, model 3 (class-weighted loss)
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator1)*numEpochs)
modelWeightsName='weights_Architecture1_model3'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc1Model3 = MaverickArchitecture1()
Arc1Model3.summary()
# Keras' keyword is `class_weight` (singular); keys are the encoded class indices
# (reported as Benign/Dominant/Recessive in the evaluation cells).
Arc1Model3.fit(training_generator1,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator1,class_weight={0:1,1:2,2:7})
# restore the best checkpoint before evaluation
Arc1Model3.load_weights(modelWeightsName)

In [None]:
# Test Architecture 1, model 3: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc1Model3.predict(validation_generator1)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture1Model3Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc1Model3.predict(known_generator1)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture1Model3Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc1Model3.predict(novel_generator1)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture1Model3Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 2, model 1
batchSize=16
# The 1cycle schedule must span exactly the number of optimizer steps; derive it from the
# generator actually used for training instead of a hard-coded count that silently
# assumes a particular batch size.
lr_schedule=OneCycleScheduler(0.1,len(training_generator2)*numEpochs)
modelWeightsName='weights_Architecture2_model1'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc2Model1 = MaverickArchitecture2()
Arc2Model1.summary()
Arc2Model1.fit(training_generator2,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator2)
# restore the best checkpoint before evaluation
Arc2Model1.load_weights(modelWeightsName)

In [None]:
# Test Architecture 2, model 1: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc2Model1.predict(validation_generator2)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture2Model1Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc2Model1.predict(known_generator2)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture2Model1Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc2Model1.predict(novel_generator2)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture2Model1Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 2, model 2
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator2)*numEpochs)
modelWeightsName='weights_Architecture2_model2'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc2Model2 = MaverickArchitecture2()
Arc2Model2.summary()
Arc2Model2.fit(training_generator2,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator2)
# restore the best checkpoint before evaluation
Arc2Model2.load_weights(modelWeightsName)

In [None]:
# Test Architecture 2, model 2: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc2Model2.predict(validation_generator2)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture2Model2Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc2Model2.predict(known_generator2)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture2Model2Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc2Model2.predict(novel_generator2)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture2Model2Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 2, model 3
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator2)*numEpochs)
modelWeightsName='weights_Architecture2_model3'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc2Model3 = MaverickArchitecture2()
Arc2Model3.summary()
Arc2Model3.fit(training_generator2,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator2)
# restore the best checkpoint before evaluation
Arc2Model3.load_weights(modelWeightsName)

In [None]:
# Test Architecture 2, model 3: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc2Model3.predict(validation_generator2)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture2Model3Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc2Model3.predict(known_generator2)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture2Model3Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc2Model3.predict(novel_generator2)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture2Model3Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 2, model 4 (class-weighted loss)
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator2)*numEpochs)
modelWeightsName='weights_Architecture2_model4'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc2Model4 = MaverickArchitecture2()
Arc2Model4.summary()
# Keras' keyword is `class_weight` (singular); keys are the encoded class indices
# (reported as Benign/Dominant/Recessive in the evaluation cells).
Arc2Model4.fit(training_generator2,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator2,class_weight={0:1,1:2,2:3})
# restore the best checkpoint before evaluation
Arc2Model4.load_weights(modelWeightsName)

In [None]:
# Test Architecture 2, model 4: per-class report for each held-out set, then save the scores
reportLabels=['Benign','Dominant','Recessive']
scoreColumns=['BenignScore','DomScore','RecScore']

# validation set
y_valid_pred=Arc2Model4.predict(validation_generator2)
print("Validation set performance")
validTruth=np.argmax(y_valid_encoded,axis=1).astype('int')
validCalls=np.argmax(y_valid_pred,axis=1).astype('int')
print(classification_report(validTruth,validCalls,target_names=reportLabels,digits=3))
validationData[scoreColumns]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture2Model4Scores.txt',sep='\t',index=False)

# known-genes set
y_known_pred=Arc2Model4.predict(known_generator2)
print("Known genes set performance")
knownTruth=np.argmax(y_known_encoded,axis=1).astype('int')
knownCalls=np.argmax(y_known_pred,axis=1).astype('int')
print(classification_report(knownTruth,knownCalls,target_names=reportLabels,digits=3))
knownData[scoreColumns]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture2Model4Scores.txt',sep='\t',index=False)

# novel-genes set
y_novel_pred=Arc2Model4.predict(novel_generator2)
print("Novel genes set performance")
novelTruth=np.argmax(y_novel_encoded,axis=1).astype('int')
novelCalls=np.argmax(y_novel_pred,axis=1).astype('int')
print(classification_report(novelTruth,novelCalls,target_names=reportLabels,digits=3))
novelData[scoreColumns]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture2Model4Scores.txt',sep='\t',index=False)



In [None]:
# Train Architecture 2, model 5 (class-weighted loss)
# schedule length = batches per epoch * epochs, derived from the generator
lr_schedule=OneCycleScheduler(0.1,len(training_generator2)*numEpochs)
modelWeightsName='weights_Architecture2_model5'
# keep only the weights with the lowest validation loss
callbacks=[tf.keras.callbacks.ModelCheckpoint(filepath=modelWeightsName,save_weights_only=True,save_best_only=True,monitor='val_loss',verbose=1),lr_schedule]

Arc2Model5 = MaverickArchitecture2()
Arc2Model5.summary()
# Keras' keyword is `class_weight` (singular); keys are the encoded class indices
# (reported as Benign/Dominant/Recessive in the evaluation cells).
Arc2Model5.fit(training_generator2,epochs=numEpochs,callbacks=callbacks,validation_data=validation_generator2,class_weight={0:1,1:2,2:7})
# restore the best checkpoint before evaluation
Arc2Model5.load_weights(modelWeightsName)

In [None]:
# Test Architecture 2, model 5: per-class report for each held-out set, then save the scores
# (fixed: the original called `Arc2Model5predict`, missing the `.` before `predict`)
y_valid_pred=Arc2Model5.predict(validation_generator2)
print("Validation set performance")
print(classification_report(np.argmax(y_valid_encoded,axis=1).astype('int'), np.argmax(y_valid_pred,axis=1).astype('int'), target_names=['Benign','Dominant','Recessive'], digits=3))
validationData[['BenignScore','DomScore','RecScore']]=y_valid_pred
validationData.to_csv('validationSet_withArchitecture2Model5Scores.txt',sep='\t',index=False)

y_known_pred=Arc2Model5.predict(known_generator2)
print("Known genes set performance")
print(classification_report(np.argmax(y_known_encoded,axis=1).astype('int'), np.argmax(y_known_pred,axis=1).astype('int'), target_names=['Benign','Dominant','Recessive'], digits=3))
knownData[['BenignScore','DomScore','RecScore']]=y_known_pred
knownData.to_csv('knownGenes_withArchitecture2Model5Scores.txt',sep='\t',index=False)

y_novel_pred=Arc2Model5.predict(novel_generator2)
print("Novel genes set performance")
print(classification_report(np.argmax(y_novel_encoded,axis=1).astype('int'), np.argmax(y_novel_pred,axis=1).astype('int'), target_names=['Benign','Dominant','Recessive'], digits=3))
novelData[['BenignScore','DomScore','RecScore']]=y_novel_pred
novelData.to_csv('novelGenes_withArchitecture2Model5Scores.txt',sep='\t',index=False)

