<a href="https://colab.research.google.com/github/bahrad/Covid/blob/main/Corona_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Initialization

Load dependencies

In [1]:
%tensorflow_version 2.x
import tensorflow as tf
from tensorflow import keras

import numpy as np
import os
import csv

import pandas as pd
import pickle

from datetime import datetime
from dateutil.parser import parse as dateparse
from collections import Counter
from tqdm.notebook import tqdm
import itertools
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder, QuantileTransformer, OneHotEncoder, LabelBinarizer, OrdinalEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold, RandomizedSearchCV, StratifiedShuffleSplit
from sklearn.metrics import mean_squared_error, accuracy_score,classification_report, make_scorer, balanced_accuracy_score, coverage_error, roc_auc_score, confusion_matrix, plot_confusion_matrix, multilabel_confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.cluster import KMeans, DBSCAN
from sklearn.utils import class_weight
import sklearn as sk

from imblearn.over_sampling import SMOTE, RandomOverSampler, ADASYN, BorderlineSMOTE
from imblearn.under_sampling import RandomUnderSampler, EditedNearestNeighbours
from imblearn.combine import SMOTEENN, SMOTETomek
from imblearn.pipeline import make_pipeline,Pipeline

# !python --version

If using Google colab, mount drive and set location

In [2]:
# Colab helper modules (requires the google.colab package).
from google.colab import drive, files
# Uncomment to mount Google Drive; FILELOC below points inside the mounted drive.
# drive.mount('/content/drive')

# Root folder that the data-loading cells prepend to every file path they read.
FILELOC = "/content/drive/My Drive/COVID_Python/"

Activate TPU resources if available

In [3]:
# Probe for a TPU. A ValueError raised anywhere in the try body is treated as
# "no TPU attached" and only flips the tpu_env flag.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    tpu_strategy = tf.distribute.TPUStrategy(tpu)
    tpu_env=True
except ValueError:
    print('Not connected to a TPU runtime.')
    # NOTE: tpu_strategy is not defined on this path; gate any use of it on tpu_env.
    tpu_env=False

Not connected to a TPU runtime.


##Initialize Model

Function to define the model.

* "regress" sets the model to output a continuous value from 0 to 1 for regression.
* "singleclass" is for a 0/1 binary output.
* "multiclass" outputs a softmax for a number of classes. The code may be readily modified for multilabel classification as well.
* "output_multiheadatt" generates the attention values for the transformer model, which may be extracted later.
* "use_att" adds a flat sequence-wide attention layer.
* "output_two" allows for two outputs, which may be jointly used to compute the loss and optimize the model.
* "nclasses" is the number of classes (ignored for regress or singleclass modes).
* "mask" allows for masking zeros in the sequence.
* "numvars" if set to a non-zero value adds additional variables as input (such as age/date/gender for clinical severity calculation) which are concatenated after the transformer and/or sequence-wide flat attention layer (if "use_att" is set to True) 

In [4]:
def reset_model(regress, singleclass, multiclass, output_multiheadatt, use_att, nclasses=4,
                output_two=False, mask=True, numvars = 0):
    """Build and compile a fresh model from the notebook-level hyperparameters.

    The architecture is chosen from the flags: ``AttMod_2`` when
    ``output_multiheadatt`` is set, otherwise ``AttMod_3`` when ``output_two``
    is set, otherwise ``AttModel``. Layer sizes and rates are read from the
    module-level constants (SEQLEN, ENCDIM, NHEADS, FFDIM, NDENSE, DROPRATE,
    TRANSDROPRATE, NT, NC, NL, LEARN_RATE); the vocabulary size is
    ``len(aa_list) + 1``.

    Args:
        regress: compile with mean-squared-error loss and regression metrics.
        singleclass: compile with binary cross-entropy, accuracy and AUC.
        multiclass: compile with sparse categorical cross-entropy and accuracy.
            When several mode flags are set, multiclass takes precedence over
            singleclass, which takes precedence over regress.
        output_multiheadatt: build ``AttMod_2`` (takes precedence over
            ``output_two``).
        use_att: forwarded to the model builder as ``use_att``.
        nclasses: forwarded to the model builder as ``nclasses``.
        output_two: build ``AttMod_3`` and compile it with one MSE loss per
            named output ('outfirst', 'outpeak'), equally weighted.
        mask: forwarded to the model builder as ``mask_zero``.
        numvars: forwarded to the model builder as ``nvars``.

    Returns:
        The compiled Keras model.

    Raises:
        ValueError: if a single-output model is requested and none of
            ``regress``, ``singleclass`` or ``multiclass`` is set.
    """
    if output_multiheadatt:
        model_fn = AttMod_2
    elif output_two:
        model_fn = AttMod_3
    else:
        model_fn = AttModel

    two_outputs = model_fn is AttMod_3
    if not two_outputs and not (regress or singleclass or multiclass):
        raise ValueError(
            "reset_model needs one of regress, singleclass or multiclass to be True")

    model = model_fn(L=SEQLEN,
                     vocab_size=len(aa_list)+1,
                     embdim = ENCDIM,
                     numheads = NHEADS,
                     ffdim = FFDIM,
                     num_dense = NDENSE,
                     mask_zero = mask,
                     dropout_rate = DROPRATE,
                     trans_drop = TRANSDROPRATE,
                     Nt = NT,
                     W = 1, Nc = NC, Nl = NL,
                     regress=regress,
                     singleclass=singleclass,
                     multiclass=multiclass,
                     use_att=use_att,
                     nclasses=nclasses,
                     nvars=numvars
                     )

    optimizer = keras.optimizers.Adam(learning_rate=LEARN_RATE)

    def regression_metrics():
        # The MAE entry is a keras.metrics object (the original passed a
        # keras.losses object in the metrics list).
        return [keras.metrics.MeanSquaredError(name='mse'),
                keras.metrics.MeanSquaredLogarithmicError(name='msle'),
                keras.metrics.MeanAbsoluteError(name='mae')]

    if two_outputs:
        # The per-output loss dictionaries only make sense for AttMod_3, whose
        # outputs are named 'outfirst' and 'outpeak'; compile once, directly.
        losses = {'outfirst':'mean_squared_error',
                  'outpeak':'mean_squared_error'}
        lossweights = {'outfirst':1.0, 'outpeak':1.0}
        model.compile(loss=losses, loss_weights=lossweights, optimizer=optimizer,
                      metrics=regression_metrics())
        return model

    if multiclass:
        loss = keras.losses.SparseCategoricalCrossentropy()
        metrics = [keras.metrics.SparseCategoricalAccuracy(name='acc')]
    elif singleclass:
        loss = keras.losses.BinaryCrossentropy()
        metrics = [keras.metrics.BinaryAccuracy(name='acc'),
                   keras.metrics.AUC(name='auc')]
    else:
        loss = keras.losses.MeanSquaredError()
        metrics = regression_metrics()

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics,)
                #   steps_per_execution = STEPS_PER_EXECUTION,)

    return model

##Model Parameters

In [5]:
# These parameters are currently hard-coded
# (reset_model() reads them when it builds and compiles a model).
ENCDIM = 1500           # embedding width, passed to the model builders as embdim
NC = 300                # passed as Nc: Conv1D filter count, also used as the transformer width
NL = 1                  # set to 0 to remove CNN pre-filtering
NT = 1                  # passed as Nt: number of stacked TransformerBlock layers
NHEADS = 4  # 8         # attention heads (8 is the alternative value noted by the author)
FFDIM = 64              # hidden width of the transformer feed-forward network
NDENSE = 64             # passed as num_dense; set to 0 to skip the Dense layer after pooling
TRANSDROPRATE = 0.4     # dropout rate inside the transformer block (trans_drop)
DROPRATE = 0.0          # dropout rate before the output layer; 0.0 adds no Dropout layer

LEARN_RATE = 0.0001     # Adam learning rate used in reset_model()

BATCH_SIZE = 48         # not referenced in the cells shown here; presumably the training batch size

STEPS_PER_EXECUTION = 50    # only referenced by a commented-out compile argument in reset_model()

##Model Definitions

Transformer and Token & Position Embedding definitions. Adapted from https://keras.io/examples/nlp/text_classification_with_transformer/

In [6]:
class TransformerBlock(keras.layers.Layer):
    """Self-attention block: multi-head attention and a two-layer feed-forward
    network, each followed by dropout, a residual add and layer normalisation.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        """Create the sub-layers.

        embed_dim is used both as the attention key_dim and as the output width
        of the feed-forward network; ff_dim is the hidden width of that network;
        rate is the dropout rate applied after each sub-layer.
        """
        super().__init__()
        self.att = keras.layers.MultiHeadAttention(key_dim=embed_dim, num_heads=num_heads)
        expand = keras.layers.Dense(ff_dim, activation="relu")
        project = keras.layers.Dense(embed_dim)
        self.ffn = keras.Sequential([expand, project])
        self.layernorm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def call(self, inputs, training, mask=None):
        """Apply the block to ``inputs``; ``training`` switches the dropout layers.

        ``mask`` is accepted but unused: attention masking is currently not
        implemented, so attention always runs with attention_mask=None.
        """
        attended = self.dropout1(
            self.att(inputs, inputs, attention_mask=None), training=training)
        residual = self.layernorm1(inputs + attended)
        transformed = self.dropout2(self.ffn(residual), training=training)
        return self.layernorm2(residual + transformed)

class TokenAndPositionEmbedding(keras.layers.Layer):
    """Sum of a learned token embedding and a learned position embedding.

    When ``mask_zero`` is True the layer returns ``(embeddings, mask)``, where
    ``mask`` is the token embedding's padding mask; otherwise it returns only
    the embeddings. Callers must unpack accordingly.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, mask_zero=False):
        """Create the two embedding tables.

        maxlen is the number of positions, vocab_size the number of token ids,
        embed_dim the embedding width, and mask_zero marks token id 0 as padding.
        """
        super(TokenAndPositionEmbedding, self).__init__()
        self.token_emb = keras.layers.Embedding(input_dim=vocab_size,
                                                output_dim=embed_dim,
                                                mask_zero=mask_zero)
        # Positions are plain indices 0..maxlen-1, so index 0 is a real position
        # and must not be flagged as padding; no mask_zero here.
        self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
        self.mask_zero = mask_zero

    def call(self, x):
        """Embed the token ids in ``x`` and add the position embeddings."""
        maxlen = tf.shape(x)[-1]
        positions = tf.range(start=0, limit=maxlen, delta=1)
        positions = self.pos_emb(positions)
        embedded = self.token_emb(x)

        if self.mask_zero:
            # Ask the embedding layer for its mask through the public API
            # instead of reading the private `_keras_mask` attribute.
            mask = self.token_emb.compute_mask(x)
            return embedded + positions, mask
        return embedded + positions

###Default model definition

In [7]:
def AttModel(L, vocab_size, embdim, numheads, ffdim, num_dense=False,
             mask_zero=False, dropout_rate=False, trans_drop=0.1,
             Nt=1, W=False, Nc=False, Nl=False,
             regress=True, singleclass=False, multiclass=False, use_att=True,
             nclasses=4, nvars=0):
    """Build the default single-output model (not compiled).

    Pipeline: token+position embedding -> optional Conv1D stack -> ``Nt``
    TransformerBlock layers -> sequence pooling (flat attention or global
    average) -> optional concatenation with extra variables -> optional Dense
    and Dropout -> output layer.

    Args:
        L: sequence length of the token input.
        vocab_size: number of token ids for the embedding.
        embdim: embedding width.
        numheads: attention heads per TransformerBlock.
        ffdim: hidden width of the transformer feed-forward network.
        num_dense: units of the Dense layer after pooling; falsy skips it.
        mask_zero: forwarded to the embedding layer (the resulting mask is not
            consumed; attention masking is not implemented).
        dropout_rate: dropout rate before the output layer; falsy skips it.
        trans_drop: dropout rate inside each TransformerBlock.
        Nt: number of TransformerBlock layers.
        W, Nc, Nl: Conv1D kernel size, filter count and layer count; the
            convolution stack is built only when all three are truthy.
        regress, singleclass: one sigmoid output unit.
        multiclass: softmax over ``nclasses`` units (wins if several flags are set).
        use_att: pool with the flat attention layer instead of global averaging.
        nclasses: number of classes for the multiclass output.
        nvars: number of extra input variables; > 0 adds a second model input.

    Returns:
        keras.Model taking the sequence input (and the extra-variable input
        when ``nvars > 0``).

    Raises:
        ValueError: if none of regress, singleclass or multiclass is set.
    """
    if not (regress or singleclass or multiclass):
        raise ValueError(
            "AttModel needs one of regress, singleclass or multiclass to be True")

    inpSeq = keras.Input(shape=(L,))
    # additional variables besides sequence
    inpVars = keras.Input(shape=(nvars,)) if nvars > 0 else None

    if mask_zero:
        # The embedding layer returns (embeddings, mask) in this mode; the mask
        # is dropped because nothing downstream consumes it.
        x, _ = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)
    else:
        x = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)

    # Channel width of x: embdim after the embedding, Nc once the convolution
    # stack has run. The transformer and attention layers must match it, so the
    # model still builds when the CNN pre-filter is switched off.
    width = embdim
    if W and Nc and Nl:
        for n in range(Nl):
            x = keras.layers.Conv1D(filters = Nc,
                                    kernel_size = W,
                                    activation = 'relu',
                                    padding = 'same',
                                    )(x)
            # NOTE(review): this skips the first two and the last conv layer, so
            # no batch normalisation is added unless Nl >= 4 - confirm intended.
            if n > 1 and n < Nl-1:
                x = keras.layers.BatchNormalization()(x)
        width = Nc

    for _ in range(Nt):
        # Keras supplies `training` itself. The original passed mask_zero as the
        # second positional argument, which bound it to `training` and kept the
        # block's dropout active at inference whenever masking was enabled.
        x = TransformerBlock(width, numheads, ffdim, rate=trans_drop)(x)

    if use_att:
        # Attention layer: score every position, softmax over the sequence,
        # then take the score-weighted sum of the tanh-projected features.
        h = keras.layers.TimeDistributed(keras.layers.Dense(width, activation='tanh'))(x)
        attention = keras.layers.TimeDistributed(keras.layers.Dense(1, activation='tanh'))(h)
        attention = keras.layers.Flatten()(attention)
        attention = keras.layers.Softmax(axis=1, name='attention')(attention) # normalize attention values
        attention = keras.layers.RepeatVector(width)(attention)
        attention = keras.layers.Permute([2, 1])(attention)
        representation = keras.layers.multiply([h, attention])
        x = tf.math.reduce_sum(representation, axis = 1)
    else:
        x = keras.layers.GlobalAveragePooling1D()(x)

    if nvars > 0:
        # concatenate additional variables with the transformer output
        x = keras.layers.concatenate([x, inpVars])

    # x already includes the extra variables here, so they reach the output
    # even when num_dense is falsy (the original dropped them in that case).
    if num_dense:
        x = keras.layers.Dense(num_dense, activation = 'relu')(x)
    if dropout_rate:
        x = keras.layers.Dropout(dropout_rate)(x)

    if multiclass:
        finalOut = keras.layers.Dense(nclasses, activation='softmax')(x)
    else:
        # regress and singleclass both use a single sigmoid unit
        finalOut = keras.layers.Dense(1, activation='sigmoid')(x)

    # define the model's start and end points
    if nvars > 0:
        model = keras.Model([inpSeq,inpVars], finalOut)
    else:
        model = keras.Model(inpSeq, finalOut)

    return model

###Return multihead attention scores

Used when "output_multiheadatt" is set to True.

In [8]:
def AttMod_2(L, vocab_size, embdim, numheads, ffdim, num_dense=False,
             mask_zero=False, dropout_rate=False, trans_drop=0.1,
             Nt=1, W=False, Nc=False, Nl=False,
             regress=True, singleclass=False, multiclass=False, use_att=True,
             nclasses=4, nvars=0):
    """Build the single-output model with an inline transformer block whose
    MultiHeadAttention layer is called with return_attention_scores=True.

    The transformer block is written out layer by layer instead of using
    TransformerBlock, and exactly one block is built: ``Nt`` is accepted for
    signature parity with AttModel but is not used. The attention-score tensor
    is not wired to a model output.
    NOTE(review): presumably the scores are extracted later through the
    MultiHeadAttention layer - confirm against the extraction code.

    All other arguments have the same meaning as in AttModel: ``W``/``Nc``/``Nl``
    configure the optional Conv1D stack, ``use_att`` selects attention pooling,
    ``nvars > 0`` adds a second input that is concatenated after pooling,
    ``num_dense``/``dropout_rate`` add the optional Dense/Dropout layers, and
    the mode flags select the output layer (multiclass wins if several are set).

    Returns:
        keras.Model (not compiled).

    Raises:
        ValueError: if none of regress, singleclass or multiclass is set.
    """
    if not (regress or singleclass or multiclass):
        raise ValueError(
            "AttMod_2 needs one of regress, singleclass or multiclass to be True")

    inpSeq = keras.Input(shape=(L,))
    # additional variables besides sequence
    inpVars = keras.Input(shape=(nvars,)) if nvars > 0 else None

    if mask_zero:
        # (embeddings, mask) is returned in this mode; the mask is not consumed.
        x, _ = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)
    else:
        x = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)

    # Channel width of x: embdim after the embedding, Nc after the conv stack.
    width = embdim
    if W and Nc and Nl:
        for n in range(Nl):
            x = keras.layers.Conv1D(filters = Nc,
                                    kernel_size = W,
                                    activation = 'relu',
                                    padding = 'same',
                                    )(x)
            # NOTE(review): no batch normalisation is added unless Nl >= 4 -
            # confirm this condition is intended.
            if n > 1 and n < Nl-1:
                x = keras.layers.BatchNormalization()(x)
        width = Nc

    # attention masking currently not implemented
    y, attout = keras.layers.MultiHeadAttention(num_heads=numheads, key_dim=width,
                                                )(x, x, return_attention_scores=True,
                                                  attention_mask=None)
    y = keras.layers.Dropout(trans_drop)(y)
    z = keras.layers.LayerNormalization(epsilon=1e-6)(x + y)
    # Feed-forward sub-layer. The original built this Sequential and then
    # overwrote the variable without ever applying it, so the block had no
    # feed-forward network. Its output width must equal the width of z for the
    # residual add below.
    ffn = keras.Sequential([keras.layers.Dense(ffdim, activation="relu"),
                            keras.layers.Dense(width),])
    z1 = ffn(z)
    z1 = keras.layers.Dropout(trans_drop)(z1)
    x = keras.layers.LayerNormalization(epsilon=1e-6)(z + z1)

    if use_att:
        # Attention layer: score every position, softmax over the sequence,
        # then take the score-weighted sum of the tanh-projected features.
        h = keras.layers.TimeDistributed(keras.layers.Dense(width, activation='tanh'))(x)
        attention = keras.layers.TimeDistributed(keras.layers.Dense(1, activation='tanh'))(h)
        attention = keras.layers.Flatten()(attention)
        attention = keras.layers.Softmax(axis=1, name='attention')(attention) # normalize attention values
        attention = keras.layers.RepeatVector(width)(attention)
        attention = keras.layers.Permute([2, 1])(attention)
        representation = keras.layers.multiply([h, attention])
        x = tf.math.reduce_sum(representation, axis = 1)
    else:
        x = keras.layers.GlobalAveragePooling1D()(x)

    if nvars > 0:
        # concatenate additional variables with the transformer output
        x = keras.layers.concatenate([x, inpVars])

    # x already includes the extra variables here, so they reach the output
    # even when num_dense is falsy (the original dropped them in that case).
    if num_dense:
        x = keras.layers.Dense(num_dense, activation = 'relu')(x)
    if dropout_rate:
        x = keras.layers.Dropout(dropout_rate)(x)

    if multiclass:
        finalOut = keras.layers.Dense(nclasses, activation='softmax')(x)
    else:
        # regress and singleclass both use a single sigmoid unit
        finalOut = keras.layers.Dense(1, activation='sigmoid')(x)

    # define the model's start and end points
    if nvars > 0:
        model = keras.Model([inpSeq,inpVars], finalOut)
    else:
        model = keras.Model(inpSeq, finalOut)

    return model

###Output two predictions for joint optimization



if "output_two" is set to True

In [9]:
def AttMod_3(L, vocab_size, embdim, numheads, ffdim, num_dense=False,
             mask_zero=False, dropout_rate=False, trans_drop=0.1,
             Nt=1, W=False, Nc=False, Nl=False,
             regress=True, singleclass=False, multiclass=False, use_att=True,
             nclasses=4, nvars=0):
    """Build the two-output model used for joint optimisation (not compiled).

    The trunk is the same as AttModel (embedding -> optional Conv1D stack ->
    ``Nt`` TransformerBlock layers -> pooling -> optional extra variables ->
    optional Dense/Dropout). It ends in two single-unit sigmoid outputs named
    'outfirst' and 'outpeak'.

    ``regress``, ``singleclass``, ``multiclass`` and ``nclasses`` are accepted
    for signature parity with the other builders but do not affect the model:
    both outputs are always sigmoid units. (The original also created an
    output layer from these flags but never connected it to the model.)

    Returns:
        keras.Model with outputs [outfirst, outpeak]; it takes the sequence
        input, plus the extra-variable input when ``nvars > 0``.
    """
    inpSeq = keras.Input(shape=(L,))
    # additional variables besides sequence
    inpVars = keras.Input(shape=(nvars,)) if nvars > 0 else None

    if mask_zero:
        # (embeddings, mask) is returned in this mode; the mask is not consumed.
        x, _ = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)
    else:
        x = TokenAndPositionEmbedding(L, vocab_size, embdim, mask_zero)(inpSeq)

    # Channel width of x: embdim after the embedding, Nc after the conv stack.
    width = embdim
    if W and Nc and Nl:
        for n in range(Nl):
            x = keras.layers.Conv1D(filters = Nc,
                                    kernel_size = W,
                                    activation = 'relu',
                                    padding = 'same',
                                    )(x)
            # NOTE(review): no batch normalisation is added unless Nl >= 4 -
            # confirm this condition is intended.
            if n > 1 and n < Nl-1:
                x = keras.layers.BatchNormalization()(x)
        width = Nc

    for _ in range(Nt):
        # Keras supplies `training` itself. The original passed mask_zero as the
        # second positional argument, which bound it to `training` and kept the
        # block's dropout active at inference whenever masking was enabled.
        x = TransformerBlock(width, numheads, ffdim, rate=trans_drop)(x)

    if use_att:
        # Attention layer: score every position, softmax over the sequence,
        # then take the score-weighted sum of the tanh-projected features.
        h = keras.layers.TimeDistributed(keras.layers.Dense(width, activation='tanh'))(x)
        attention = keras.layers.TimeDistributed(keras.layers.Dense(1, activation='tanh'))(h)
        attention = keras.layers.Flatten()(attention)
        attention = keras.layers.Softmax(axis=1, name='attention')(attention) # normalize attention values
        attention = keras.layers.RepeatVector(width)(attention)
        attention = keras.layers.Permute([2, 1])(attention)
        representation = keras.layers.multiply([h, attention])
        x = tf.math.reduce_sum(representation, axis = 1)
    else:
        x = keras.layers.GlobalAveragePooling1D()(x)

    if nvars > 0:
        # concatenate additional variables with the transformer output
        x = keras.layers.concatenate([x, inpVars])

    # x already includes the extra variables here, so they reach the outputs
    # even when num_dense is falsy (the original dropped them in that case).
    if num_dense:
        x = keras.layers.Dense(num_dense, activation = 'relu')(x)
    if dropout_rate:
        x = keras.layers.Dropout(dropout_rate)(x)

    out1 = keras.layers.Dense(1, activation='sigmoid', name='outfirst')(x)
    out2 = keras.layers.Dense(1, activation='sigmoid', name='outpeak')(x)
    # define the model's start and end points
    if nvars > 0:
        model = keras.Model([inpSeq,inpVars], [out1,out2])
    else:
        model = keras.Model(inpSeq, [out1,out2])

    return model

##Function to tokenize sequences

In [10]:
def tokenize_sequences(data_dataframe, SeqCol='ISM', seqlen=1273):
    """Convert a column of residue strings into a 2-D array of integer tokens.

    Every sequence is truncated to ``seqlen`` characters or right-padded with
    '*' up to ``seqlen``. The 20 amino-acid letters and '-' map to 1..21 in the
    order listed below. The padding character '*', the ambiguity codes
    'X', 'B', 'Z', 'J' and any other unlisted character map to 0.

    Args:
        data_dataframe: DataFrame holding the sequences.
        SeqCol: name of the column with the sequence strings.
        seqlen: fixed output length per sequence.

    Returns:
        Integer numpy array of shape (len(data_dataframe), seqlen).
    """
    aa_list = ['A', 'R', 'N', 'D', 'C', 'Q', 'E',
               'G', 'H', 'I', 'L', 'K', 'M', 'F',
               'P', 'S', 'T', 'W', 'Y', 'V', '-',
               ]
    aa_tokenizer = {aa: idx for idx, aa in enumerate(aa_list, start=1)}

    sequences = data_dataframe[SeqCol]
    if len(sequences) == 0:
        # np.array / np.vectorize cannot infer a 2-D shape from no rows.
        return np.zeros((0, seqlen), dtype=int)

    def fit_length(seq):
        # truncate long sequences, pad short ones with the '*' filler
        return seq[:seqlen].ljust(seqlen, '*')

    data = np.array([list(fit_length(seq)) for seq in sequences])

    # Unlisted characters fall back to 0. A plain `aa_tokenizer.get` returns
    # None for them, which np.vectorize cannot store in an integer array.
    # Fixing otypes also keeps the output dtype independent of the first value.
    to_token = np.vectorize(lambda aa: aa_tokenizer.get(aa, 0), otypes=[int])
    return to_token(data)

# Module-level copy of the residue alphabet that tokenize_sequences() also
# defines internally; reset_model() uses len(aa_list) + 1 as the vocabulary size.
aa_list = ['A', 'R', 'N', 'D', 'C', 'Q', 'E',
        'G', 'H', 'I', 'L', 'K', 'M', 'F',
        'P', 'S', 'T', 'W', 'Y', 'V', '-',
        ]

#Dataset Preprocessing

Each of the modules (or submodules in the case of SARS-CoV-2 data) processes a different kind of data for different classification tasks. The modules all start with loading a datafile (which may have already been pre-processed) and end with the creation of numpy arrays for training, testing features and labels. Additional modules may be added for additional classification tasks.

##Corona (Multi-Genus) Sequence Data

Read the csv file with coronavirus sequences (different species)

In [None]:
# Load the multi-genus table; the sequence column 'Seq' is renamed to "Spike"
# for consistency with downstream operations.
data = pd.read_csv(
    FILELOC + "coronavirus_spike/" + "coronataxonomy_dataset_notpreprocessed.csv"
).rename(columns={'Seq':'Spike'})
print(len(data))

Remove short sequences or those whose species is not coronavirus

In [None]:
# Keep records longer than 1000 whose species name contains 'coronavirus';
# .copy() detaches the result from `data` so later column assignments are safe.
dataset = data.loc[
    (data.Length > 1000) & (data.Species.str.contains('coronavirus'))
].copy()

Create class labels (comment / comment-out code as needed)

In [None]:
# # Create labels for host: human/non-human classification task

# NCLASSES = 2

# dataset['hostlabel'] = dataset.Host.apply(lambda x: 1 if x=='Homo sapiens' else 0)

# Genus classification task: four genera, mapped to class ids 0..3

NCLASSES = 4

def f(x):
    """Return the class id for genus name ``x``, or None if it is not one of the four."""
    genus_ids = {
        'Alphacoronavirus': 0,
        'Betacoronavirus': 1,
        'Gammacoronavirus': 2,
        'Deltacoronavirus': 3,
    }
    return genus_ids.get(x)

# f() returns None for any genus other than the four it lists, so such rows end
# up without a usable class id in 'genuslabel'.
dataset['genuslabel'] = dataset.Genus.apply(f)

Display histogram of sequence lengths and set the maximum sequence length (SEQLEN)

In [None]:
# Histogram of spike sequence lengths, used to pick SEQLEN in the next cell.
# Axis labels make the figure readable on its own; the trailing ';' hides the repr.
ax = dataset.Spike.apply(len).plot.hist()
ax.set(xlabel='Spike sequence length', ylabel='Number of sequences');

Set sequence length (pad shorter sequences / truncate longer sequences)

In [None]:
# Fixed token length: tokenize_sequences() pads shorter sequences with '*' and
# truncates longer ones to this length; reset_model() uses it as the input length.
SEQLEN = 1500

Sample dates may indicate that a sample was collected after SARS-CoV-2 was discovered. For some validation tasks (i.e. determining whether SARS-CoV-2 is correctly classified as a Betacoronavirus), it may be desirable to remove samples dated after January 2020. Additional code may be added below to remove those samples.

In [None]:
# Parse each release date string and keep only its calendar date.
dataset['date'] = dataset['Release_Date'].apply(dateparse).dt.date

In [None]:
# Number of records without a collection date (some dates may be NaN).
print(dataset.Collection_Date.isna().sum())

Tokenize sequences and create training and test data sets

In [None]:
seqtok = tokenize_sequences(dataset, 'Spike', SEQLEN)

# the code below defines the labels as genus label
# can be modified to define labels as host label, or as both
# 'genuslabel' is the integer class column created from dataset.Genus above;
# the frame has no attribute named `genus`.
y = dataset['genuslabel'].values

In [None]:
# To draw a fresh random 80/20 split instead of loading the saved one:
# trainindex = np.random.choice(range(len(seqtok)), size = int(0.8*len(seqtok)), replace=False)

# save train index for future use:
# np.savetxt(FILELOC + 'corona_trainindex.csv', trainindex, fmt='%i', delimiter=',')
trainindex = np.loadtxt(FILELOC + 'corona_trainindex.csv', dtype=int, delimiter=',')

# Test rows are every row position not listed in trainindex (sorted ascending).
# Without this line testindex is undefined on a fresh kernel.
testindex = np.setdiff1d(np.arange(len(seqtok)), trainindex)

xtrain = seqtok[trainindex]
xtest = seqtok[testindex]
NVARS = 0       # there are no additional variables besides the sequence

ytrain = y[trainindex]
ytest = y[testindex]

##SARS-CoV-2 Lineage Sequence Data

###Raw Sequences (random sample)

Read a file with raw sequences (i.e. not aligned) used to demonstrate lineage prediction. These need to be generated by downloading sequences from GISAID because they cannot be separately distributed.

In [None]:
# Specific code used to generate a set of raw samples from a pre-existing
# dataframe of sequences, sequence IDs, and date of first collection

# with open(FILELOC + 'spike_reldate_0303.pkl', 'rb') as f:
#     df = pickle.load(f)
# df.reset_index(drop=False, inplace=True)
# df = df[df.Lineage!="None"].reset_index(drop=True)
# df_sample = pd.concat([df[df.Count >= 100].sample(4000), df[df.Count.between(10,99)].sample(12000),
#                        df[df.Count.between(2,3)].sample(3000), df[df.Count==1].sample(1000)], axis=0)
# with open(f'{FILELOC}coronavirus_spike_sars2cov_rawsample.pkl', 'wb') as f:
#     pickle.dump(df_sample, f)

# Load the pickled raw-sequence sample and move its index into a regular column.
# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open(f'{FILELOC}coronavirus_spike_sars2cov_rawsample.pkl', 'rb') as f:
    datadf = pickle.load(f).reset_index(drop=False)

Assign lineages to labels (this can be readily modified)

In [None]:
SEQLEN = 1500   # set sequence length (pad shorter sequences / truncate longer sequences)

# Lineage -> class id; lineages that share an id are treated as one class.
labelmap = {'AY.4':0, 'B.1.617.2':0,
            'B.1':1, 'B.1.177':1, 'B.1.1':1, 'B.1.2':1,
            'BA.1':2,
            'BA.1.1':3,
            'BA.2':4,
            'P.1':5,
            'B.1.351':6,
            'B.1.427':7, 'B.1.429':7,
            }

# Number of distinct class ids in labelmap (8). Set here, as the genus task
# does, so a value left over from another task is not reused.
# NOTE(review): assumes the training cells read NCLASSES - confirm.
NCLASSES = len(set(labelmap.values()))

# Lineages missing from labelmap map to NaN and those rows are dropped.
datadf['Label'] = datadf['Lineage'].map(labelmap)
datadf = datadf[datadf.Label.notna()].reset_index(drop=True)
datadf['Label'] = datadf['Label'].astype(int)

Tokenize sequences and create training and test data sets

In [None]:
# Tokenize the SARS-CoV-2 lineage frame built in the cells above. The original
# cell was copied from the multi-genus section and still referenced `dataset`.
# NOTE(review): assumes the raw-sample frame stores its sequences in a 'Spike'
# column - confirm against the pickle.
seqtok = tokenize_sequences(datadf, 'Spike', SEQLEN)

# labels are the integer lineage classes assigned through labelmap
y = datadf.Label.values

In [None]:
# To draw a fresh random 80/20 split instead of loading the saved one:
# trainindex = np.random.choice(range(len(seqtok)), size = int(0.8*len(seqtok)), replace=False)

# save train index for future use:
# np.savetxt(FILELOC + 'corona_trainindex.csv', trainindex, fmt='%i', delimiter=',')
# NOTE(review): this is the same index file the multi-genus section loads;
# confirm it was saved for this dataset before relying on the split.
trainindex = np.loadtxt(FILELOC + 'corona_trainindex.csv', dtype=int, delimiter=',')
if trainindex.size and trainindex.max() >= len(seqtok):
    raise IndexError(
        "saved train index refers to rows beyond the current dataset; "
        "regenerate the split for this dataset")

# Test rows are every row position not listed in trainindex (sorted ascending).
# Without this line testindex is undefined on a fresh kernel.
testindex = np.setdiff1d(np.arange(len(seqtok)), trainindex)

xtrain = seqtok[trainindex]
xtest = seqtok[testindex]
NVARS = 0       # there are no additional variables besides the sequence

ytrain = y[trainindex]
ytest = y[testindex]

###Aligned Sequences (covid-patient dataset)

Read a file with aligned SARS-CoV-2 sequences to predict lineage or date of first occurrence. In this case, the data were originally generated for clinical severity prediction, and are then processed to generate distinct sequences (i.e. remove repeated sequences from the database) and assign them a date of first occurrence and lineage.

In [None]:
# seqs = pd.read_csv(FILELOC + 'covid_patient_seqs_20220228/covid_patient_seqs_20220228.csv')
# seqs.drop(columns="Spike", inplace=True)

# df = seqs.copy()
# df['Country'] = df['Location'].apply(lambda x: x.split('/')[1].strip())
# REFDATE = '2019-12-01'; refdt = dateparse(REFDATE)
# def f(x):
#     return (dateparse(x) - refdt).days
# df['Date'] = df['Collection date'].apply(f)
# df = df[['SequenceID', 'Date', 'Clade', 'Lineage', 'Country', 'MaskedSeq', 'Collection date', 'Location']]

# df = df[df.Date > 0] # get rid of negative date values

# df_reldate = df.groupby("MaskedSeq")["Date"].apply(list).to_frame()
# df_reldate['First_Date'] = df_reldate.Date.apply(np.amin)
# df_reldate['Last_Date'] = df_reldate.Date.apply(np.amax)
# df_reldate['Peak_Date'] = df_reldate.Date.apply(lambda x: np.median(np.argwhere(np.bincount(x)==np.amax(np.bincount(x)))))
# df_reldate['Count'] = df_reldate.Date.apply(len)
# max_first_date = max(df_reldate.First_Date)
# df_reldate['relfirstdate'] = df_reldate['First_Date'].apply(lambda x: x/max_first_date)
# df_out = df_reldate.join(df.groupby("MaskedSeq")["Location"].apply(list).to_frame())
# def f(x):
#     try:
#         return pd.Series.mode(x)[0]
#     except:
#         return np.nan

# df_out = df_out.join(df.groupby("MaskedSeq")["Clade"].agg(f).to_frame())
# df_out = df_out.join(df.groupby("MaskedSeq")["Lineage"].agg(f).to_frame())

# with open(FILELOC + 'covid_patient_seqs_20220228/covid_patient_seqs_grouped_20220228.pkl', 'wb') as f:
#     pickle.dump(df_out, f)

# Load the pickled table of distinct aligned sequences and renumber its rows from 0.
# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open(FILELOC + 'covid_patient_seqs_20220228/covid_patient_seqs_grouped_20220228.pkl', 'rb') as f:
    datadf = pickle.load(f).reset_index(drop=True)

The following is used to assign labels based on lineage.

In [None]:
# rename column with sequences to "Spike" for consistency with downstream operations
datadf.rename(columns={'MaskedSeq':'Spike'}, inplace=True)

# keep only sequences whose Count is at least 2
datadf = datadf[datadf.Count >= 2].reset_index(drop=True)

# Lineage -> class id; lineages that share an id are treated as one class.
labelmap = {'AY.4':0, 'B.1.617.2':0,
            'B.1':1, 'B.1.177':1, 'B.1.1':1, 'B.1.2':1,
            'BA.1':2,
            'BA.1.1':3,
            'BA.2':4,
            'P.1':5,
            'B.1.351':6,
            'B.1.427':7, 'B.1.429':7,
            }

# Number of distinct class ids in labelmap (8). Set here, as the genus task
# does, so a value left over from another task is not reused.
# NOTE(review): assumes the training cells read NCLASSES - confirm.
NCLASSES = len(set(labelmap.values()))

# Lineages missing from labelmap map to NaN and those rows are dropped.
datadf['Label'] = datadf['Lineage'].map(labelmap)
datadf = datadf[datadf.Label.notna()].reset_index(drop=True)
datadf['Label'] = datadf['Label'].astype(int)

# aligned sequence lengths are equally 1273
SEQLEN = 1273

Tokenize sequences and create training and test data sets

In [None]:
# Tokenize the aligned SARS-CoV-2 frame prepared above ('MaskedSeq' was renamed
# to 'Spike' there). The original cell was copied from the multi-genus section
# and still referenced `dataset`.
seqtok = tokenize_sequences(datadf, 'Spike', SEQLEN)

# labels are the integer lineage classes assigned through labelmap
y = datadf.Label.values

In [None]:
# To draw a fresh random 80/20 split instead of loading the saved one:
# trainindex = np.random.choice(range(len(seqtok)), size = int(0.8*len(seqtok)), replace=False)

# save train index for future use:
# np.savetxt(FILELOC + 'corona_trainindex.csv', trainindex, fmt='%i', delimiter=',')
# NOTE(review): this is the same index file the multi-genus section loads;
# confirm it was saved for this dataset before relying on the split.
trainindex = np.loadtxt(FILELOC + 'corona_trainindex.csv', dtype=int, delimiter=',')
if trainindex.size and trainindex.max() >= len(seqtok):
    raise IndexError(
        "saved train index refers to rows beyond the current dataset; "
        "regenerate the split for this dataset")

# Test rows are every row position not listed in trainindex (sorted ascending).
# Without this line testindex is undefined on a fresh kernel.
testindex = np.setdiff1d(np.arange(len(seqtok)), trainindex)

xtrain = seqtok[trainindex]
xtest = seqtok[testindex]
NVARS = 0       # there are no additional variables besides the sequence

ytrain = y[trainindex]
ytest = y[testindex]

##COVID-19 Disease Severity Data

###Load up-to-date processed sequence dataset

In [None]:
# Load the processed patient/sequence table.
# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open(FILELOC + 'covid_patient_seqs_20220415/covid_patient_20220415.pkl', 'rb') as f:
    pdf = pickle.load(f)

# Remove the leftover index column, then give missing lineages the literal 'None'.
pdf = pdf.drop(columns="Unnamed: 0")
pdf['Lineage'] = pdf.Lineage.fillna('None')
print(len(pdf))
# the relative-date column is referred to as 'Date' from here on
pdf = pdf.rename(columns={'reldate':'Date'})

In [None]:
# Use aligned sequences rather than "raw" sequences:
# drop the raw 'Spike' column and rename 'MaskedSeq' to "Spike" for consistency
# with downstream operations.
# Guarded so the cell can be re-run. Once 'MaskedSeq' has been renamed, running
# an unguarded drop again would delete the aligned sequences themselves.
if 'MaskedSeq' in pdf.columns:
    pdf = (pdf.drop(columns='Spike', errors='ignore')
              .rename(columns={'MaskedSeq':'Spike'}))

# aligned sequence lengths are equally 1273
SEQLEN = 1273

Use the following code to generate/regenerate labels (commented out because the data file loaded above already has labels).

In [None]:
# labelmap = {'alive' : -1,
#             'asymptomatic' : 0,
#             'dead' : 1,
#             'hospitalized' : 1,
#             'mild' : 0,
#             'moderate' : 0,
#             'released' : 1,
#             'screening' : 0,
#             'severe' : 1,
#             'symptomatic' : -1,
#             'unknown' : -1,
#         }

# pdf['Label'] = pdf['Category'].map(labelmap).astype(int)

Remove samples with invalid labels and patient variables (which were set to -1 in preprocessing).

In [None]:
# pdfbackup = pdf.copy()

In [None]:
# Remove rows flagged as invalid, one criterion at a time, printing the number
# of remaining rows after each step: label -1, then age -1 or above 100, then
# gender -1. Each test is evaluated on the already-reduced frame.
invalid_row_tests = (
    lambda frame: frame.Label == -1,
    lambda frame: (frame.Age == -1) | (frame.Age > 100),
    lambda frame: frame.Gender == -1,
)
for is_invalid in invalid_row_tests:
    pdf.drop(index=pdf.loc[is_invalid(pdf)].index, inplace=True)
    print(len(pdf))
pdf.reset_index(drop=True, inplace=True)

Optionally clean the dataset of sequences with ambiguous or missing residues

In [None]:
# pdfclean = pdf.drop(pdf[pdf['Spike'].str.contains('\*|X')].index).reset_index(drop=True)
# print(len(pdf), len(pdfclean))

Count sequence frequency (i.e. to exclude infrequent sequences)

In [None]:
# count all sequences
countdata = pdf.groupby(["Spike"]).size().to_frame().rename(columns={0:"Count"}).reset_index(drop=False)
pdfcount = pd.merge(pdf,countdata,on='Spike')

# The "clean" frame (sequences without * or X) is optional: it exists only if
# the cleaning cell above was uncommented and run. Test for it explicitly
# rather than relying on a bare `except:`, which would hide any other error too.
have_clean = 'pdfclean' in globals()
if have_clean:
    # count only sequences without * or X
    countdata_cl = pdfclean.groupby(["Spike"]).size().to_frame().rename(columns={0:"Count"}).reset_index(drop=False)
    pdfcount_cl = pd.merge(pdfclean,countdata_cl,on='Spike')

# For each threshold m print: m, number of distinct sequences seen at least m
# times, number of samples carrying such a sequence (then the same two numbers
# for the clean frame when it is available).
for m in [1000, 200, 100, 50, 10,5,4,3,2,1]:
    counts = [len(countdata[countdata.Count >= m]), len(pdfcount[pdfcount.Count >= m])]
    if have_clean:
        counts += [len(countdata_cl[countdata_cl.Count >= m]),
                   len(pdfcount_cl[pdfcount_cl.Count >= m])]
    print(m, *counts)

Optionally create a dataset of distinct sequences with averages of label and other variables

In [None]:
# pdfcount = pdf.groupby(["Spike"]).mean()
# pdfcount['Count'] = pdf.groupby(["Spike"]).size()
# def f(x):
#     try:
#         return pd.Series.mode(x)[0]
#     except:
#         return 'None'
# c = pdf.groupby("Spike")["Lineage"].agg(f)
# pdfcount = pdfcount.join(c)
# pdfcount.reset_index(drop=False, inplace=True)

# pdfcount_cl = pdfcount.drop(pdfcount[pdfcount['Spike'].str.contains('\*|X')].index).reset_index(drop=True)
# print(len(pdfcount), len(pdfcount_cl))

# for m in [1000, 200, 100, 50, 10,5,4,3,2,1]:
#     print(m, len(pdfcount[pdfcount.Count >= m]), len(pdfcount_cl[pdfcount_cl.Count >= m]))

Predefined train and test sets

In [None]:
# traindf = pdfcount[pdfcount.Count >= 10]
# testdf = pdfcount[pdfcount.Count < 10]
# print(len(traindf), len(testdf))

# # NVARS = 3       # 3 patient variables (age/gender/date)
# # NCLASSES = 2    # 2 outcomes (mild/severe)
# # trainvars = traindf[['Age', 'Gender', 'reldate']].values
# # testvars = testdf[['Age', 'Gender', 'reldate']].values

# NVARS = 2       # 3 patient variables (age/date)
# NCLASSES = 2    # 2 outcomes (mild/severe)
# trainvars = traindf[['Age', 'reldate']].values
# testvars = testdf[['Age', 'reldate']].values

# SEQLEN = 1273

# trainseq = tokenize_sequences(traindf, 'Spike', SEQLEN)
# testseq = tokenize_sequences(testdf, 'Spike', SEQLEN)

# xtrain = [trainseq, trainvars]
# xtest = [testseq, testvars]
# ytrain = traindf.Label.values
# ytest = testdf.Label.values

Train and test sets determined through random split or lineage-based split

In [None]:
# Choose the working dataframe `tdf`. Exactly one of the `tdf = ...` assignments
# in this first group should be active; the others stay commented out.
# # use all sequences
# tdf = pdf
tdf = pdfcount
# # use only "clean" sequences (no * or X)
# tdf = pdfclean
# tdf = pdfcount_cl
# use only frequent sequences
# tdf = pdfcount[pdfcount.Count >= 5]
# # use only frequent sequences without * or X
# tdf = pdfcount_cl[pdfcount_cl.Count >= 10]

# Renumber the rows 0..n-1 (into a new frame) so that `.index` values taken
# from tdf later can be used directly as positions into seqtok and y.
tdf = tdf.reset_index(drop=True)

# tokenize_sequences is defined elsewhere in the notebook; presumably it returns
# one tokenized row of length SEQLEN per row of tdf — TODO confirm.
seqtok = tokenize_sequences(tdf, 'Spike', SEQLEN)
# Target labels, in the same row order as tdf.
y = tdf.Label.values

In [None]:
# RANDOM SPLIT

# trainindex = np.random.choice(range(len(seqtok)), size = int(0.7*len(seqtok)), replace=False)
# testindex = np.array([k for k in range(len(seqtok)) if k not in trainindex])

# save train index for future use:
# np.savetxt(FILELOC + 'covidpatient_trainindex.csv', trainindex, fmt='%i', delimiter=',')

# load predefined set of training indices
# trainindex = np.loadtxt(FILELOC + 'covidpatient_trainindex.csv', dtype=int, delimiter=',')

In [None]:
# LINEAGE SPLIT
# Hold out every record whose Lineage matches the test pattern; train on the rest.

# trainindex = tdf[~tdf.Lineage.str.contains('BA')].index
# testindex = tdf[tdf.Lineage.str.contains('BA')].index
# trainindex = tdf[~tdf.Lineage.str.contains('BA.1.|BA.2.')].index
# testindex = tdf[tdf.Lineage.str.contains('BA.1.|BA.2.')].index
# trainindex = tdf[~(tdf.Lineage.str.contains('AY')|tdf.Lineage.str.contains('B.1.617.2'))].index
# testindex = tdf[tdf.Lineage.str.contains('AY')|tdf.Lineage.str.contains('B.1.617.2')].index

# Build the mask once so that train and test are exact complements.
# na=False: a record with a missing Lineage counts as "no match" (it stays in
# the training set) instead of putting NaN into the mask, which can neither be
# negated with ~ nor used for boolean indexing.
is_test_lineage = tdf.Lineage.str.contains('AY', na=False)
trainindex = tdf[~is_test_lineage].index
testindex = tdf[is_test_lineage].index

print(len(trainindex), len(testindex))

In [None]:
# Add scaled copies of the date and age columns to tdf, using fixed divisors:
#   reldate = Date / 1000   (dates normalized against a maximum of 1000)
#   relage  = Age  / 100    (ages normalized against a maximum of 100)
tdf['reldate'] = tdf['Date'] / 1000
tdf['relage'] = tdf['Age'] / 100

# # otherwise uncomment below to use max. date and age to normalize
# maxdate = pdf.Date.max()
# pdf['reldate'] = pdf['Date'].apply(lambda x: x/maxdate)
# maxage = pdf.Age.max()
# pdf['relage'] = pdf['Age'].apply(lambda x: x/maxage)
# print(maxdate, maxage)

In [None]:
# Number of output classes and choice of patient-level input variables.
# Keep exactly one NVARS setting active; when NVARS is non-zero, `vars` must be
# assigned from the matching tdf columns.
NCLASSES = 2    # 2 outcomes (mild/severe)

# # Uncomment to do 3 variables including sample collection date

# NVARS = 3       # 3 patient variables (age/gender/date)
# vars = tdf[['relage', 'Gender', 'reldate']].values

# # Uncomment if not using patient/sample-specific features (i.e., age, gender, date, etc.)

# NVARS = 0

# # Uncomment to use age and gender and not date

NVARS = 2
# NOTE(review): `vars` shadows the Python builtin of the same name; the name is
# kept because the cell that builds xtrain/xtest indexes `vars[...]`.
vars = tdf[['relage', 'Gender']].values

Use regression to predict a continuous label between 0 and 1, or binary classification with labels 0 and 1

In [None]:
# Task switch read by the class-weighting, epoch/batch-size and model-building
# cells below: True = regression, False = classification.
# REGRESS = True
REGRESS = False

In [None]:
# Slice the labels with the chosen train/test indices.
ytrain, ytest = y[trainindex], y[testindex]

# Model inputs: with patient variables the input is the two-element list
# [tokenized sequences, patient variables]; without them it is the token
# array on its own.
if NVARS != 0:
    xtrain = [seqtok[trainindex], vars[trainindex]]
    xtest = [seqtok[testindex], vars[testindex]]
else:
    xtrain, xtest = seqtok[trainindex], seqtok[testindex]

###Use train/test files that were previously used for earlier versions of the manuscript

Sokhansanj, BA et al., "An Interpretable Deep Learning Model for Predicting the Risk of Severe COVID-19 from Spike Protein Sequence", https://www.researchsquare.com/article/rs-1234007/v1

In [None]:
# # The file loaded below already has training and test sets defined

# with open (FILELOC + 'covid_patient_seqs_20220228/covid_patient_data_old.pkl', 'rb') as f:
#     traindf, testdf = pickle.load(f)
# print(len(traindf), len(testdf))

# # rename column with sequences to "Spike" for consistency
# traindf.rename(columns={'ISM':'Spike'}, inplace=True)
# testdf.rename(columns={'ISM':'Spike'}, inplace=True)

# # NVARS = 3       # 3 patient variables (age/gender/date)
# # NCLASSES = 2    # 2 outcomes (mild/severe)
# # trainvars = traindf[['Age', 'Gender', 'reldate']].values
# # testvars = testdf[['Age', 'Gender', 'reldate']].values

# NVARS = 2       # 2 patient variables (age/date)
# NCLASSES = 2    # 2 outcomes (mild/severe)
# trainvars = traindf[['Age', 'reldate']].values
# testvars = testdf[['Age', 'reldate']].values

# SEQLEN = 1273

Tokenize sequences and define "xtrain/xtest/ytrain/ytest" consistently with other methods for use in training and validation below.

In [None]:
# trainseq = tokenize_sequences(traindf, 'Spike', SEQLEN)
# testseq = tokenize_sequences(testdf, 'Spike', SEQLEN)

# xtrain = [trainseq, trainvars]
# xtest = [testseq, testvars]
# ytrain = traindf.Label.values
# ytest = testdf.Label.values

#Model Training and Evaluation

##Class Balancing

Class weights or sample weights can be passed to `model.fit` in the training code below (uncomment or comment out the corresponding argument as needed)

In [None]:
if not REGRESS:
    # Balanced per-class weights computed from the training labels; kept as a
    # list so that class_weights[c] is the weight of class c.
    # (`class_weight` is the sklearn.utils module imported in the first cell.)
    class_weights = list(class_weight.compute_class_weight(class_weight='balanced',
                                                        classes=np.arange(NCLASSES), y=ytrain))

    # optionally, each sample may be weighted individually: look up the weight
    # of every training label in one vectorized indexing step (labels are
    # truncated to int so they can serve as indices).
    sample_weights = np.array(class_weights)[np.asarray(ytrain).astype(int)]
    print(class_weights)

##Train Model

This training routine operates in one shot with a preset number of epochs. Early stopping is optional.

Set REGRESS to False for classification and REGRESS to True for regression.

In [None]:
# Training length and batch size for the binary classification or 0-1
# regression task (i.e., disease severity prediction)
if REGRESS:
    # Regression: more epochs; the batch size is 48*8 whether or not a TPU is used
    NUM_EPOCHS = 250 # 100
    BATCH_SIZE = 48*8 # 48*8
else:
    NUM_EPOCHS = 50 # 70
    # A larger batch size more optimally uses TPU resources
    BATCH_SIZE = 48*8 if tpu_env else 48

# # if doing multiclass prediction (i.e. for taxonomic classification)
# NUM_EPOCHS = 25
# BATCH_SIZE = 48     # there are not many training samples

# See Tensorflow documentation for how to modify the early stopping callback

# VAL_SPLIT = 0.2   # use if defining early stopping callbacks with validation data

# Stop once the monitored 'loss' has gone 10 epochs without improving, then
# restore the best weights seen. 'loss' is not a validation metric, so no
# validation split is needed.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor = 'loss',
    mode = 'auto',
    min_delta = 0,
    patience = 10, #5,
    restore_best_weights = True,
    verbose = 1,
    )

The below training code is used for a single training set and for a binary classification or regression from 0-1 (i.e., disease severity prediction tasks)

In [None]:
# Build a fresh model with reset_model (defined elsewhere in the notebook) and
# train it once on xtrain / ytrain.
tf.keras.backend.clear_session()    # reset tensorflow session (clear model history)

# NOTE(review): the two branches pass different arguments to reset_model.
# TPU: multiclass=False, output_multiheadatt=True, output_two=False, mask=True.
# Otherwise: multiclass=True, output_multiheadatt=False, mask=False, and
# output_two is not passed. With REGRESS False the non-TPU branch therefore sets
# singleclass and multiclass both True — confirm this is intended.
if tpu_env:
    with tpu_strategy.scope():
        model = reset_model(regress=REGRESS, singleclass=(not REGRESS), multiclass=False,
                            output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
                            output_two=False, numvars=NVARS, mask=True)
        model.summary() # Output table of deep model layers and parameters / connections
else:
    model = reset_model(regress=REGRESS, singleclass=(not REGRESS), multiclass=True,
                        output_multiheadatt=False, use_att=True, nclasses=NCLASSES,
                        numvars=NVARS, mask=False)
    model.summary() # Output table of deep model layers and parameters / connections

# The returned History object is kept in `history`.
history = model.fit(xtrain, ytrain,
                # optional per-sample weighting
                # sample_weight = sample_weights,
                # use the following instead of sample_weight to weight by class (class c -> class_weights[c])
                # class_weight = {c:class_weights[c] for c in range(NCLASSES)},
                batch_size = BATCH_SIZE,
                epochs = NUM_EPOCHS,
                verbose = 1,
                # validation_split = VAL_SPLIT,
                callbacks = [early_stopping],
                )

# optionally save weights
# model.save_weights(f"{FILELOC}weights_sarscov2.h5", save_format='h5', overwrite=True)

Output confusion matrix and sklearn classification report (precision, recall, f1-score)

In [None]:
print("Results for Testing Data:")
# Predict on the held-out set, then round both the labels and the predictions
# to the nearest integer so that they can be compared as classes.
test_predict = model.predict(xtest)
rounded_true, rounded_pred = np.round(ytest), np.round(test_predict)
ClassRep = classification_report(rounded_true, rounded_pred)
ConfMatrix = confusion_matrix(rounded_true, rounded_pred)
for report in (ClassRep, ConfMatrix):
    print(report)

In [None]:
# sample model save

# model.save_weights(f"{FILELOC}weights_sarscov2_test_Delta_AY_agegender_2.h5", save_format='h5', overwrite=True)
# model.save_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar_agegender.h5", save_format='h5', overwrite=True)
# model.save_weights(f"{FILELOC}weights_sarscov2_test_Omicron.h5", save_format='h5', overwrite=True)
# model.save_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar.h5", save_format='h5', overwrite=True)
# model.save_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar_nodemo.h5", save_format='h5', overwrite=True)

In [None]:
# Continue training for additional epochs

EXTRA_EPOCHS = 20

# Resume fitting the current `model` for EXTRA_EPOCHS more epochs; early
# stopping is not applied here.
history = model.fit(
    xtrain, ytrain,
    # optional per-sample weighting
    # sample_weight = sample_weights,
    # use the following instead of sample_weight to weight by class (class c -> class_weights[c])
    # class_weight = {c:class_weights[c] for c in range(NCLASSES)},
    epochs = EXTRA_EPOCHS,
    batch_size = BATCH_SIZE,
    verbose = 1,
    # validation_split = VAL_SPLIT,
    # callbacks = [early_stopping],
)

# model.save_weights(f"{FILELOC}weights_sarscov2_extra_epochs.h5", save_format='h5', overwrite=True)

# Re-evaluate on the test set: labels and predictions are rounded to the
# nearest integer before being compared as classes.
print("Results for Testing Data:")
test_predict = model.predict(xtest)
rounded_true, rounded_pred = np.round(ytest), np.round(test_predict)
ClassRep = classification_report(rounded_true, rounded_pred)
ConfMatrix = confusion_matrix(rounded_true, rounded_pred)
for report in (ClassRep, ConfMatrix):
    print(report)

In [None]:
# example of doing MULTIPLE REPEAT RUNS
# assume tpu_env is True

# The model settings and the rounded test labels are the same for every run.
model_settings = dict(regress=REGRESS, singleclass=(not REGRESS), multiclass=False,
                      output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
                      output_two=False, numvars=NVARS, mask=True)
rounded_true = np.round(ytest)

for run in range(3):
    tf.keras.backend.clear_session()    # start each repeat from a clean Keras state

    with tpu_strategy.scope():
        model = reset_model(**model_settings)

        history = model.fit(xtrain, ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, verbose = 1)
        # each run writes its own weights file, suffixed with the run number
        model.save_weights(f"{FILELOC}weights_sarscov2_test_AY_{run}.h5")
        test_predict = model.predict(xtest)
        rounded_pred = np.round(test_predict)
        print(classification_report(rounded_true, rounded_pred))
        print(confusion_matrix(rounded_true, rounded_pred))

##Load Pretrained Model

The arguments for reset_model() should be the same as used for model training.

In [None]:
# Rebuild the model architecture with reset_model (defined elsewhere in the
# notebook), then load saved weights into it.
# NOTE(review): the two branches build different configurations and load
# different files. TPU: single-class model with NCLASSES / NVARS and the
# "paper_resubmit" Delta weights. Otherwise: multiclass model with nclasses=4,
# no numvars/mask arguments, and the 20220408 weights. Confirm the branch taken
# matches the arguments used when the loaded weights were trained.
tf.keras.backend.clear_session()
if tpu_env:
    with tpu_strategy.scope():
        model = reset_model(regress=False, singleclass=True, multiclass=False,
                            output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
                            output_two=False, numvars=NVARS, mask=True)
        model.load_weights(f"{FILELOC}paper_resubmit/weights_sarscov2_test_Delta.h5")
        # compile() is called with no arguments after the weights are loaded
        model.compile()
else:
    model= reset_model(regress=False, singleclass=False, multiclass=True,
                    output_multiheadatt=False, use_att=True, nclasses=4)
    model.load_weights(f"{FILELOC}weights_sarscov2_20220408.h5")
    model.compile()

Load multiple runs and print out classification reports/confusion matrices

In [None]:
# Evaluate five saved runs in turn; the architecture settings and the rounded
# test labels are shared by all of them.
model_settings = dict(regress=False, singleclass=True, multiclass=False,
                      output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
                      output_two=False, numvars=NVARS, mask=True)
rounded_true = np.round(ytest)

for run in range(5):
    tf.keras.backend.clear_session()
    with tpu_strategy.scope():
        # rebuild the architecture, then load this run's weights into it
        model = reset_model(**model_settings)
        model.load_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar_agegender_{run}.h5")
        model.compile()
        test_predict = model.predict(xtest, batch_size=256, verbose=False)
        rounded_pred = np.round(test_predict)
        print(f"RUN= {run}")
        print(classification_report(rounded_true, rounded_pred))
        print(confusion_matrix(rounded_true, rounded_pred))

Generate Attention and Embedding values

In [None]:
# # this cell is an example of loading two different sets of attention and embeddings

# NVARS = 2; xtrain = [seqtok[trainindex], vars[trainindex]]; xtest = [seqtok[testindex], vars[testindex]]
# tf.keras.backend.clear_session()
# with tpu_strategy.scope():
#     model = reset_model(regress=False, singleclass=True, multiclass=False,
#                             output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
#                             output_two=False, numvars=NVARS, mask=True)
#     model.load_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar_agegender_4.h5")
#     get_embedding_model = keras.Model(inputs=model.input,outputs=model.get_layer('dense_4').output)
#     get_embedding_model.compile()
#     get_attention_model = keras.Model(inputs=model.input,outputs=model.get_layer('attention').output)
#     get_attention_model.compile()
#     pred = model.predict([seqtok,vars], verbose=True)
#     emb = get_embedding_model.predict([seqtok,vars], verbose=True)
#     att = get_attention_model.predict([seqtok,vars], verbose=True)
# with open(FILELOC + 'paper_resubmit/transformer_Omicron_subvar_agegender_eval.pkl', 'wb') as f:
#     pickle.dump([pred, emb, att], f)

# NVARS = 0; xtrain = seqtok[trainindex]; xtest = seqtok[testindex]
# tf.keras.backend.clear_session()
# with tpu_strategy.scope():
#     model = reset_model(regress=False, singleclass=True, multiclass=False,
#                             output_multiheadatt=True, use_att=True, nclasses=NCLASSES,
#                             output_two=False, numvars=NVARS, mask=True)
#     model.load_weights(f"{FILELOC}weights_sarscov2_test_Omicron_subvar_nodemo_0.h5")
#     get_embedding_model = keras.Model(inputs=model.input,outputs=model.get_layer('dense_4').output)
#     get_embedding_model.compile()
#     get_attention_model = keras.Model(inputs=model.input,outputs=model.get_layer('attention').output)
#     get_attention_model.compile()
#     pred = model.predict(seqtok, verbose=True)
#     emb = get_embedding_model.predict(seqtok, verbose=True)
#     att = get_attention_model.predict(seqtok, verbose=True)
# with open(FILELOC + 'paper_resubmit/transformer_Omicron_subvar_nodemo_eval.pkl', 'wb') as f:
#     pickle.dump([pred, emb, att], f)

###Plot Attention

In [None]:
# Load the saved evaluation results for the two models. Each pickle unpacks
# into three items (the commented-out cell above dumps them as [pred, emb, att]);
# only the third, the attention values, is kept here.
# NOTE: pickle.load can execute arbitrary code from the file — only load
# pickles that you produced yourself.
with open(FILELOC + 'paper_resubmit/transformer_Omicron_subvar_agegender_eval.pkl', 'rb') as f:
    _, _, att_var = pickle.load(f)      # model with age/gender inputs
with open(FILELOC + 'paper_resubmit/transformer_Omicron_subvar_nodemo_eval.pkl', 'rb') as f:
    _, _, att_novar = pickle.load(f)    # "nodemo" model

In [None]:
# Median attention (taken over axis 0) at each sequence position, one line per model.
fig, ax = plt.subplots(figsize=(16, 6))

positions = range(1, SEQLEN+1)
for attention, legend_text, line_color in [(att_var, "Age/Gender", 'red'),
                                           (att_novar, "Sequence Only", 'blue')]:
    ax.plot(positions, np.median(attention, axis=0),
            linewidth=1, label=legend_text, color=line_color)

ax.set_xlabel('Sequence Position', fontsize=18, fontweight='bold')
ax.set_ylabel('Attention', fontsize=18, fontweight='bold')

ax.legend()
plt.show()

In [None]:
# output the specific locations of high attention:
# positions whose median attention (over axis 0 of att_var) exceeds 0.001.
# np.where yields 0-based indices; the +1 converts them to the 1..SEQLEN
# numbering used on the x-axis of the attention plot.

np.where(np.median(att_var,axis=0) > 0.001)[0]+1

###Plot Multi-Head Attention

Sample code to plot and extract attention from transformer heads

In [None]:
# Display `trainseq`.
# NOTE(review): in this part of the notebook `trainseq` is only assigned in
# commented-out cells — confirm it is defined before running this cell.
trainseq

In [None]:
# Run get_mha_model on the rows selected by `topdelta`; predict returns two
# results, unpacked here as `output` and `wts`.
# NOTE(review): `get_mha_model` and `topdelta` are not defined in this section —
# confirm they exist before running. Later cells index wts as wts[k, m],
# presumably (sample, head, ...) — TODO confirm.
output,wts = get_mha_model.predict([seqtok[topdelta],vars[topdelta]], verbose=True)

In [None]:
# Position, within the `topdelta` selection, of the row whose tdf index label is 13333.
# NOTE(review): presumably this is the first-axis index into `wts` for that
# record; 13333 is a hard-coded row label — confirm both.
tdf.loc[topdelta].index.get_loc(13333)

In [None]:
# output,wts = get_mha_model.predict([trainseq[0:2],trainvars[0:2]])

In [None]:
from matplotlib.colors import LogNorm

In [None]:
k = 22    # first-axis index into wts of the sample to plot

# One logarithmic colour scale per attention head, spanning that head's own
# minimum..maximum weight for sample k.
norm = {}
for m in range(4):
    vmin = np.min(wts[k,m])
    vmax = np.max(wts[k,m])
    norm[m] = LogNorm(vmin=vmin, vmax=vmax)
    print(vmin, vmax)

fig,ax = plt.subplots(2,2)
# set_size_inches returns None, so its result must not be assigned back to fig
fig.set_size_inches(18,18)
# heads 0..3 are laid out column by column: (0,0), (1,0), (0,1), (1,1)
a = [ax[0,0], ax[1,0], ax[0,1], ax[1,1]]
for m in range(4):
    try:
        sns.heatmap(wts[k,m], ax = a[m], norm = norm[m])
    except Exception as err:
        # Best-effort per panel (e.g. a log scale cannot represent a minimum
        # weight of 0): leave the panel empty, but report why it was skipped.
        print(f"Head {m}: heatmap skipped ({err!r})")
# sns.heatmap(wts[k,0]/np.median(np.median(wts[k,0])), ax=ax[0,0])
# sns.heatmap(wts[k,1]/np.median(np.median(wts[k,1])), ax=ax[1,0])
# sns.heatmap(wts[k,2]/np.median(np.median(wts[k,2])), ax=ax[0,1])
# sns.heatmap(wts[k,3]/np.median(np.median(wts[k,3])), ax=ax[1,1])

In [None]:
# k = 0

# norm = {}
# for m in range(8):
#     norm[m] = LogNorm(vmin=np.min(np.min(wts[k,m])), vmax=np.max(np.max(wts[k,m])))

# fig,ax = plt.subplots(4,2)
# fig = fig.set_size_inches(18,18)
# a = [ax[0,0], ax[1,0], ax[2,0], ax[3,0],
#      ax[0,1], ax[1,1], ax[2,1], ax[3,1]]
# for m in range(8):
#     sns.heatmap(wts[k,m], ax = a[m], norm = norm[m])

# # sns.heatmap(wts[k,0]/np.median(np.median(wts[k,0])), ax=ax[0,0])
# # sns.heatmap(wts[k,1]/np.median(np.median(wts[k,1])), ax=ax[1,0])
# # sns.heatmap(wts[k,2]/np.median(np.median(wts[k,2])), ax=ax[2,0])
# # sns.heatmap(wts[k,3]/np.median(np.median(wts[k,3])), ax=ax[3,0])
# # sns.heatmap(wts[k,4]/np.median(np.median(wts[k,4])), ax=ax[0,1])
# # sns.heatmap(wts[k,5]/np.median(np.median(wts[k,5])), ax=ax[1,1])
# # sns.heatmap(wts[k,6]/np.median(np.median(wts[k,6])), ax=ax[2,1])
# # sns.heatmap(wts[k,7]/np.median(np.median(wts[k,7])), ax=ax[3,1])

In [None]:
# 1-based positions whose median attention (over axis 0 of `att`) exceeds 1/1273.
# NOTE(review): 1273 is hard-coded — presumably the sequence length (SEQLEN),
# which would make 1/1273 the uniform-attention level; confirm.
# NOTE(review): in this section `att` is only assigned in commented-out cells
# (the loaded arrays are att_var / att_novar) — confirm it is defined.
np.where(np.median(att, axis=0) > 1/1273)[0]+1

In [None]:
# For the first sample in wts, sum each head's weights over axis 0 and collect
# the results as one DataFrame column per head ('Head 0', 'Head 1', ...).
mha = {h: np.sum(wts[0,h], axis=0) for h in range(NHEADS)}
mhadf = pd.DataFrame.from_dict({f'Head {h}': head_total for h, head_total in mha.items()})

# For every head, the 50 rows with the largest summed weight; the +1 turns the
# 0-based row numbers into 1-based positions.
top_positions = {}
for column in mhadf.columns:
    ranked = mhadf.sort_values(by=column, ascending=False)
    top_positions[column] = ranked.head(50).index + 1
topdf = pd.DataFrame.from_dict(top_positions)
display(topdf)

In [None]:
# mha = {}
# for h in range(8):
#     mha[h] = np.sum(wts[0,h], axis=0)
# mhadf = pd.DataFrame.from_dict({f'Head {h}':mha[h] for h in range(8)})
# topdf = pd.DataFrame.from_dict({f'Head {h}':mhadf.sort_values(by=f'Head {h}', ascending=False).head(20).index for h in range(8)})
# display(topdf)

##Plot Embeddings

Sample code to plot TSNE of embeddings

In [None]:
# 2-D t-SNE projection of the embeddings in `emb`.
# perplexity and verbose are options of the TSNE constructor; fit_transform()
# takes only the data, so passing them there raises a TypeError.
# TSNE itself is already imported in the notebook's first cell.
# NOTE(review): in this section `emb` is only assigned in commented-out cells —
# confirm it is defined before running this cell.
t = TSNE(n_components=2, perplexity=50, verbose=True).fit_transform(emb)

In [None]:
# Lineages to show in the embedding scatter plot below; each name is matched
# exactly against tdf.Lineage.
Lineages = ['A', 'B.1', 'B.1.1.7', 'B.1.351', 'P.1', 'B.1.617.2', 'AY.4', 'BA.1', 'BA.2']

In [None]:
# Scatter the 2-D coordinates in `t`, with one legend entry per selected lineage.
fig, ax = plt.subplots(figsize=(12, 8))

for lin in Lineages:
    # index labels of the tdf rows with exactly this lineage, used as row numbers into t
    selind = tdf.loc[tdf.Lineage == lin].index
    xs, ys = t[selind,0], t[selind,1]
    ax.scatter(xs, ys, marker='x', label=lin)

ax.legend(bbox_to_anchor=(1.0, 0.9),  framealpha=1.0)
plt.show()

In [None]:
# fig, ax = plt.subplots()
# fig.set_size_inches(12,8)

# genus = ['Alphacoronavirus', 'Betacoronavirus', 'Gammacoronavirus', 'Deltacoronavirus']
# for genus_ind in np.unique(ytest):
#     selind = np.where(ytest==genus_ind)[0]
#     ax.scatter(t[selind,0], t[selind,1], marker='x', label=genus[genus_ind])
# ax.scatter(t[len(ytest):,0], t[len(ytest):,1], marker='x', label='Omicron')

# ax.legend(bbox_to_anchor=(1.0, 0.9),  framealpha=1.0)
# plt.show()