In [None]:
# upload the VCF file you wish to analyze
import os
from google.colab import files
uploaded = files.upload()
if not uploaded:
    raise ValueError('No file was uploaded; please upload a .vcf file')
# only the first uploaded file is analyzed
inputVCFFileName=list(uploaded.keys())[0]
# strip only a trailing '.vcf' (str.replace would also remove '.vcf' occurring mid-name)
if inputVCFFileName.endswith('.vcf'):
    outBase=inputVCFFileName[:-len('.vcf')]
else:
    outBase=inputVCFFileName
# downstream cells read the VCF from the fixed name 'input.vcf'
os.rename(inputVCFFileName, 'input.vcf')

In [None]:
# Is this VCF file in GRCh37 or GRCh38 coordinates? 
genome='GRCh37'
# uncomment the below line if using GRCh38
#genome='GRCh38'
# later cells treat any value other than 'GRCh37' as GRCh38, so reject typos up front
if genome not in ('GRCh37', 'GRCh38'):
    raise ValueError("genome must be 'GRCh37' or 'GRCh38', got " + repr(genome))

In [None]:
# Python dependencies; tensorflow and tf-models-official are pinned to 2.7, the rest are unpinned
!pip install wget pandas numpy matplotlib scikit-learn scipy biopython
!pip install tensorflow==2.7
!pip install tf-models-official==2.7
# NOTE(review): transformers is unpinned, so a fresh install resolves to the latest release, which
# may no longer support tensorflow 2.7 — pin a version compatible with TF 2.7 (TODO confirm which).
!pip install transformers
# download the resources
# (`python -m wget` comes from the `wget` package installed above; the tarball is unpacked into the
# current working directory and then removed)
!python -m wget https://zuchnerlab.s3.amazonaws.com/VariantPathogenicity/Maverick_resources.tar.gz
!tar -zxvf Maverick_resources.tar.gz
!rm Maverick_resources.tar.gz

In [None]:
%%bash -s "$genome"
# process variants with annovar
# IPython does not substitute {var} inside a %%bash cell body, so the reference build is passed
# in as the first positional argument ("$genome" on the magic line is expanded by IPython).
GENOME="$1"
echo "Starting Step 1: Get coding changes with Annovar"
dos2unix input.vcf
# chrom, pos, ref, alt of every VCF body line (read back later as input_locations.txt)
grep -v '^#' input.vcf | cut -f 1,2,4,5 > input_locations.txt
annovar/convert2annovar.pl -format vcf4 input.vcf > input.avinput
if [[ "${GENOME}" == 'GRCh37' ]]; then
    annovar/annotate_variation.pl -dbtype wgEncodeGencodeBasicV33lift37 -buildver hg19 --exonicsplicing input.avinput annovar/humandb/
else
    annovar/annotate_variation.pl -dbtype wgEncodeGencodeBasicV33 -buildver hg38 --exonicsplicing input.avinput annovar/humandb/
fi
# if there are no scorable variants, end early
SCORABLEVARIANTS=$(cat input.avinput.exonic_variant_function | wc -l || true)
if [[ ${SCORABLEVARIANTS} -eq 0 ]]; then exit 0; fi
if [[ "${GENOME}" == 'GRCh37' ]]; then
    annovar/coding_change.pl input.avinput.exonic_variant_function annovar/humandb/hg19_wgEncodeGencodeBasicV33lift37.txt annovar/humandb/hg19_wgEncodeGencodeBasicV33lift37Mrna.fa --includesnp --onlyAltering --alltranscript > input.coding_changes.txt
else
    annovar/coding_change.pl input.avinput.exonic_variant_function annovar/humandb/hg38_wgEncodeGencodeBasicV33.txt annovar/humandb/hg38_wgEncodeGencodeBasicV33Mrna.fa --includesnp --onlyAltering --alltranscript > input.coding_changes.txt
fi

In [None]:
import os
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import official.nlp
import official.nlp.keras_nlp.layers
from transformers import TFT5EncoderModel, T5Tokenizer,T5Config
import pandas
pandas.options.mode.chained_assignment = None
from sklearn.preprocessing import QuantileTransformer
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report
from sklearn.utils import resample
import scipy
from scipy.stats import rankdata
from datetime import datetime

In [None]:
import pandas
import numpy as np
from Bio import SeqIO
# transcript conversion table (full/short transcript IDs, gene names and IDs) for the selected build
if genome=='GRCh38':
    approvedTranscripts=pandas.read_csv('gencodeBasicFullLengthTranscriptsConversionTable_GRCh38.txt',sep='\t',low_memory=False)
else:
    approvedTranscripts=pandas.read_csv('gencodeBasicFullLengthTranscriptsConversionTable.txt',sep='\t',low_memory=False)

canonical=pandas.read_csv('gnomad211_constraint_canonical_simple.txt',sep='\t',low_memory=False)
# remove the gnomad canonical transcripts that are not approvedTranscripts
canonical=canonical.loc[canonical['transcript'].isin(approvedTranscripts['transcriptIDShort'].values),:].reset_index(drop=True)

GTEx=pandas.read_csv('GTEx.V7.tx_medians.021820.tsv',sep='\t',low_memory=False)
# remove the non-approvedTranscripts from the expression data
GTEx=GTEx.loc[GTEx['transcript_id'].isin(approvedTranscripts['transcriptIDShort'].values),:].reset_index(drop=True)
# add an overall expression column: the per-transcript (row-wise) mean over columns 2..54
# (assumed to be the per-tissue medians — confirm against the file header).
# axis=1 is required: the default axis=0 yields one value per column, which pandas aligns
# against the row index and would leave the whole column NaN.
GTEx['overallAvg']=GTEx.iloc[:,2:55].mean(axis=1)

# protein sequences of the approved transcripts, keyed by full transcript ID
# NOTE(review): the GENCODE v33lift37 translations are read for both builds — confirm for GRCh38.
approvedFullIDs=set(approvedTranscripts['transcriptID'].values)
sequences={}
for record in SeqIO.parse("gencode.v33lift37.pc_translations.fa","fasta"):
    transcriptID=record.id.split('|')[1]
    if transcriptID in approvedFullIDs:
        sequences[transcriptID]=record.seq

def _approvedTranscriptInfo(transcriptID, approvedTranscripts):
    """Return the full transcript ID and gene annotations for a short (15-character) transcript ID.

    The result is keyed by the `sample` column names it fills. Values come from the first row of
    approvedTranscripts whose 'transcriptIDShort' equals transcriptID; raises IndexError if none does.
    """
    rows = approvedTranscripts.loc[approvedTranscripts['transcriptIDShort'] == transcriptID, :]
    return {'TranscriptID': rows['transcriptID'].values[0],
            'geneName': rows['geneName'].values[0],
            'geneID': rows['geneID'].values[0],
            'geneIDShort': rows['geneIDShort'].values[0]}


def _mostExpressedTranscript(candidates, GTEx):
    """Return the candidate with the highest GTEx 'overallAvg', or None if no candidate is in GTEx.

    The returned ID is read from the first GTEx column of the top-ranked row.
    """
    expressed = GTEx.loc[GTEx['transcript_id'].isin(candidates), :]
    if len(expressed) == 0:
        return None
    return expressed.sort_values(by=['overallAvg'], ascending=False).reset_index(drop=True).iloc[0, 0]


def _selectTranscript(canonicals, transcripts, transcriptLengths, GTEx):
    """Choose one short transcript ID for a variant, or None when no approved transcript was seen.

    Priority: the single canonical transcript; else the single approved transcript; else the most
    expressed of the canonicals (or of all approved transcripts when none is canonical); else the
    approved transcript with the longest protein (first one seen on ties).
    """
    if len(transcripts) == 0:
        return None
    if len(canonicals) == 1:
        return canonicals[0]
    if len(canonicals) == 0 and len(transcripts) == 1:
        return transcripts[0]
    candidates = transcripts if len(canonicals) == 0 else canonicals
    best = _mostExpressedTranscript(candidates, GTEx)
    if best is not None:
        return best
    # NOTE(review): when several canonicals were seen but none is in GTEx, this still picks the
    # longest among ALL approved transcripts, not only the canonical ones — confirm intended.
    return transcripts[transcriptLengths.index(max(transcriptLengths))]


def groomAnnovarOutput(base,sequences=sequences,approvedTranscripts=approvedTranscripts,canonical=canonical,GTEx=GTEx,genome=genome):
    """Pick one transcript per ANNOVAR exonic variant and attach its wild-type and altered protein.

    Reads:
      base + '.avinput.exonic_variant_function'  ANNOVAR annotate_variation output
      base + '_locations.txt'                    chrom, pos, ref, alt per input line (no header)
      base + '.coding_changes.txt'               ANNOVAR coding_change.pl output, parsed as FASTA
    Writes base + '.groomed.txt' (tab-separated), keeping only variants for which a wild-type
    sequence, an altered sequence and a change position were all found. Returns None.

    The wild-type sequence is accepted only when, minus its final character, it equals
    sequences[TranscriptID].
    """
    if genome=='GRCh37':
        chromCol, posCol = 'hg19_chr', 'hg19_pos(1-based)'
    else:
        chromCol, posCol = 'hg38_chr', 'hg38_pos(1-based)'
    sample=pandas.read_csv(base + ".avinput.exonic_variant_function",sep='\t',low_memory=False,header=None,
                           names=['line','varType','location',chromCol,posCol,'end','ref','alt','genotype','qual','depth'])
    # convert the position, ref, and alt alleles to long form
    longForm=pandas.read_csv(base + "_locations.txt",sep='\t',low_memory=False,header=None,names=['chrom','pos_long','ref_long','alt_long'])
    # 'lineN' -> N-1, the 0-based row of longForm
    sample['lineNum']=sample.loc[:,'line'].str[4:].astype(int)-1
    sample=sample.merge(longForm,how='inner',left_on='lineNum',right_on=longForm.index)
    sample=sample.loc[:,['line','varType','location',chromCol,'pos_long','end','ref_long','alt_long','genotype','qual','depth']].rename(columns={'pos_long':posCol,'ref_long':'ref','alt_long':'alt'}).reset_index(drop=True)
    # add new columns with placeholders to be filled in
    sample['WildtypeSeq']=""
    sample['AltSeq']=""
    sample['ChangePos']=-1
    sample['TranscriptID']=""
    sample['TranscriptIDShort']=sample['location'].str.split(':',expand=True)[1].str[:15]
    sample['geneName']=sample['location'].str.split(':',expand=True)[0]
    sample['geneID']=""
    sample['geneIDShort']=""

    # sets make the per-transcript membership tests O(1)
    canonicalIDs=set(canonical['transcript'].values)
    approvedIDs=set(approvedTranscripts['transcriptID'].values)

    for i in range(len(sample)):
        if i % 1000 == 0:
            print(str(i) + ' rows completed')
        canonicals=[]
        transcripts=[]
        transcriptLengths=[]
        # the location field ends with a comma, so its last split element is empty and skipped
        for entry in sample.loc[i,'location'].split(',')[:-1]:
            fullID=entry.split(':')[1]
            shortID=fullID[:15]
            if shortID in canonicalIDs:
                canonicals.append(shortID)
            if fullID in approvedIDs:
                transcripts.append(shortID)
                transcriptLengths.append(len(sequences[fullID]))

        transcriptID=_selectTranscript(canonicals,transcripts,transcriptLengths,GTEx)
        if transcriptID is None:
            continue
        sample.loc[i,'TranscriptIDShort']=transcriptID
        for column, value in _approvedTranscriptInfo(transcriptID,approvedTranscripts).items():
            sample.loc[i,column]=value

    for record in SeqIO.parse(base + ".coding_changes.txt", "fasta"):
        rowMask=sample['line']==record.id
        fields=record.description.split(' ')
        recordTranscript=fields[1]
        # only use the transcript that we selected above
        # (.any() also handles a record whose line matches no row)
        if not (sample.loc[rowMask,'TranscriptID'].values==recordTranscript).any():
            continue
        if 'WILDTYPE' in record.description:
            if str(record.seq)[:-1] == sequences[recordTranscript]:
                sample.loc[rowMask,'WildtypeSeq']=str(record.seq)
        else:
            sample.loc[rowMask,'AltSeq']=str(record.seq)
            if 'startloss' in record.description:
                sample.loc[rowMask,'ChangePos']=1
            elif 'silent' in record.description:
                sample.loc[rowMask,'ChangePos']=-1
            else:
                # the text before '-' in the 8th space-separated field of the description
                sample.loc[rowMask,'ChangePos']=fields[7].split('-')[0]
    sample2=sample.loc[~((sample['WildtypeSeq']=="") | (sample['AltSeq']=="") | (sample['ChangePos']==-1)),:]
    sample2.to_csv(base + '.groomed.txt',sep='\t',index=False)
    return


In [None]:
# reads the input.* ANNOVAR outputs and writes input.groomed.txt
groomAnnovarOutput(base='input')

In [None]:
import pandas
constraint=pandas.read_csv('gnomad211_constraint_canonical_simple.txt',sep='\t',low_memory=False)

# Genomic positions far exceed int16's maximum of 32767 (and counts such as nhomalt may too), and
# float16 keeps only ~3 significant digits, so both builds use 32-bit integer/float columns.
isGRCh37 = genome=='GRCh37'
# GRCh37 tables are parsed with low_memory=True, GRCh38 tables with low_memory=False
lowMemory = isGRCh37

def _tableChromToInt(frame, col, dropAltContigs):
    """Return `frame` with 'X'/'Y'/'MT' in `col` mapped to 23/24/25 and `col` cast to int.

    When dropAltContigs is true, rows whose contig name contains '_' are removed first
    (presumably alt/unplaced contigs, which cannot be cast to int).
    """
    if dropAltContigs:
        frame = frame.loc[~(frame[col].str.contains('_')), :].reset_index(drop=True)
    for name, code in (('X', 23), ('Y', 24), ('MT', 25)):
        frame.loc[frame[col]==name, col] = code
    frame[col] = frame[col].astype(int)
    return frame

if isGRCh37:
    afFile, afChrom, afPos = 'gnomad211_exomes_AFs.txt', 'hg19_chr', 'hg19_pos(1-based)'
else:
    afFile, afChrom, afPos = 'gnomad211_GRCh38_exomes_AFs.txt', 'hg38_chr', 'hg38_pos(1-based)'
gnomadAF=pandas.read_csv(afFile,sep='\t',low_memory=lowMemory,dtype={afChrom:str,afPos:np.int32,'ref':str,'alt':str,'AF':np.float32,'nhomalt':np.int32,'controls_AF':np.float32,'controls_nhomalt':np.int32})
gnomadAF=_tableChromToInt(gnomadAF,afChrom,dropAltContigs=not isGRCh37)

# per-base tables: keep the highest score when a (chromosome, position) appears more than once
CCR=pandas.read_csv('ccrs.enumerated.txt' if isGRCh37 else 'ccrs_GRCh38.enumerated.txt',sep='\t',low_memory=lowMemory,dtype={'chrom':str,'pos':np.int32,'ccr_pct':np.float32})
CCR=_tableChromToInt(CCR,'chrom',dropAltContigs=not isGRCh37)
CCR=CCR.sort_values(by=['chrom','pos','ccr_pct'],ascending=[True,True,False]).drop_duplicates(subset=['chrom','pos'],keep='first').reset_index(drop=True)

pext=pandas.read_csv('gnomAD_pext_values.txt' if isGRCh37 else 'gnomAD_pext_values_GRCh38.txt',sep='\t',low_memory=lowMemory,dtype={'chr':str,'pos':np.int32,'pext':np.float32})
pext=_tableChromToInt(pext,'chr',dropAltContigs=not isGRCh37)
pext=pext.sort_values(by=['chr','pos','pext'],ascending=[True,True,False]).drop_duplicates(subset=['chr','pos'],keep='first').reset_index(drop=True)

gerp=pandas.read_csv('gerpOnExons.txt' if isGRCh37 else 'gerpOnExons_GRCh38.txt',sep='\t',low_memory=lowMemory,header=None,names=['chr','pos','gerp'],dtype={'chr':str,'pos':np.int32,'gerp':np.float32})
gerp=_tableChromToInt(gerp,'chr',dropAltContigs=not isGRCh37)
gerp=gerp.sort_values(by=['chr','pos','gerp'],ascending=[True,True,False]).drop_duplicates(subset=['chr','pos'],keep='first').reset_index(drop=True)

GDI=pandas.read_csv('GDI.groomed.txt',sep='\t',low_memory=False)
RVIS=pandas.read_csv('RVIS.groomed.txt',sep='\t',low_memory=False)

def _sampleChromToInt(frame, col):
    """Map 'X'/'Y'/'MT' in frame[col] to 23/24/25 and cast the column to int (modifies frame)."""
    frame.loc[frame[col]=='X',col]=23
    frame.loc[frame[col]=='Y',col]=24
    frame.loc[frame[col]=='MT',col]=25
    frame[col]=frame[col].astype(int)
    return frame


def _maxScoreOverIndelSpans(sampleIndels, table, tableChromCol, tableValueCol, chromCol, posCol):
    """Per indel row, return max(table[tableValueCol]) over positions pos+1 .. pos+len(ref).

    Positions match on table[tableChromCol] == chromosome and table['pos'] in
    range(pos+1, pos+1+len(ref)); rows without any match get NaN. Prints progress every 100 rows.
    """
    total=len(sampleIndels)
    scores=np.full(total,np.nan)
    rows=zip(sampleIndels[chromCol].values,sampleIndels[posCol].values,sampleIndels['length'].values)
    for i,(chrom,pos,length) in enumerate(rows):
        if i%100==0:
            print(str(i) + ' rows complete of ' + str(total))
        startPos=pos+1
        endPos=startPos+length
        scores[i]=table.loc[((table[tableChromCol]==chrom) & (table['pos'].isin(range(startPos,endPos)))),tableValueCol].max()
    return scores


def annotateVariants(base,constraint=constraint,gnomadAF=gnomadAF,CCR=CCR,pext=pext,gerp=gerp,GDI=GDI,RVIS=RVIS,genome=genome):
    """Add allele-frequency, constraint, CCR, pext, GERP, GDI and RVIS annotations to groomed variants.

    Reads base + '.groomed.txt' and 'gnomad211_constraint_simple_geneLevel.txt' (current directory),
    writes base + '.annotated.txt' (tab-separated), one row per (chrom, pos, ref, alt). Returns None.

    Constraint is matched by short transcript ID first, then by gene ID, then by gene name.
    SNV-type variants take the per-base CCR/pext/GERP value at their position. Indels take the
    maximum over positions pos+1 .. pos+len(ref).
    """
    import pandas
    import numpy as np
    # coordinate column names written by groomAnnovarOutput for this build
    if genome=='GRCh37':
        chromCol,posCol='hg19_chr','hg19_pos(1-based)'
    else:
        chromCol,posCol='hg38_chr','hg38_pos(1-based)'
    variantKey=[chromCol,posCol,'ref','alt']

    sample=pandas.read_csv(base + '.groomed.txt',sep='\t',low_memory=False)
    if genome!='GRCh37':
        # drop rows on contigs whose name contains '_' (they cannot be cast to int below);
        # astype(str) because the column is read as int64 when every chromosome is numeric
        sample=sample.loc[~(sample[chromCol].astype(str).str.contains('_')),:].reset_index(drop=True)
    sample=_sampleChromToInt(sample,chromCol)

    # merge on the allele frequency data
    sample=sample.merge(gnomadAF,how='left',on=variantKey)

    # merge on the constraint data (try transcript ID merge first)
    sampleTranscript=sample.merge(constraint,how='inner',left_on=['TranscriptIDShort'],right_on=['transcript'])
    notMatched=sample.loc[~(sample['TranscriptIDShort'].isin(sampleTranscript['TranscriptIDShort'])),:]
    geneConstraint=pandas.read_csv('gnomad211_constraint_simple_geneLevel.txt',sep='\t',low_memory=False)
    sampleGeneID=notMatched.merge(geneConstraint,how='inner',left_on=['geneIDShort'],right_on=['gene_id'])
    notMatched2=notMatched.loc[~(notMatched['geneIDShort'].isin(sampleGeneID['geneIDShort'])),:]
    sampleGeneName=notMatched2.merge(geneConstraint,how='left',left_on=['geneName'],right_on=['gene'])
    # stack them all back together
    sample2=pandas.concat([sampleTranscript,sampleGeneID,sampleGeneName],axis=0,ignore_index=True)
    sample2=_sampleChromToInt(sample2,chromCol)

    # per-base scores: exact-position lookup for SNV-type variants, max over the span for indels
    snvTypes=['nonsynonymous SNV','synonymous SNV','stopgain','stoploss']
    indelTypes=['frameshift insertion','frameshift deletion','frameshift substitution',
                'nonframeshift insertion','nonframeshift deletion','nonframeshift substitution']
    sampleSNVs=sample2.loc[sample2['varType'].isin(snvTypes),[chromCol,posCol]]
    sampleIndels=sample2.loc[sample2['varType'].isin(indelTypes),[chromCol,posCol,'ref']].copy()
    sampleIndels['length']=sampleIndels['ref'].str.len()
    scoreTables=((CCR,'chrom','ccr_pct','CCR'),(pext,'chr','pext','pext'),(gerp,'chr','gerp','gerp'))
    for table,tableChromCol,tableValueCol,outCol in scoreTables:
        sample2[outCol]=np.nan
        # score tables are deduplicated on (chromosome, pos), so the left merge keeps one row per SNV
        snvScores=sampleSNVs.merge(table,how='left',left_on=[chromCol,posCol],right_on=[tableChromCol,'pos']).set_index(sampleSNVs.index)
        sample2.loc[snvScores.index,outCol]=snvScores.loc[:,tableValueCol].values
        sample2.loc[sampleIndels.index,outCol]=_maxScoreOverIndelSpans(sampleIndels,table,tableChromCol,tableValueCol,chromCol,posCol)

    sample2=sample2.drop_duplicates(subset=variantKey,keep='first')
    sample2=sample2.drop(columns=['line','location','end','qual','depth','gene','transcript', 'canonical','gene_id'])
    sample2=sample2.sort_values(by=variantKey).reset_index(drop=True)

    # merge on GDI data
    sample2=sample2.merge(GDI,how='left',on='geneName')
    # merge on RVIS data
    sample2=sample2.merge(RVIS,how='left',on='geneName')

    sample2=sample2.sort_values(by=variantKey).reset_index(drop=True)
    sample2=sample2.drop_duplicates(subset=variantKey,keep='first').reset_index(drop=True)

    sample2.to_csv(base + '.annotated.txt',sep='\t',index=False)
    return



In [None]:
# reads input.groomed.txt and writes input.annotated.txt
annotateVariants(base='input')

In [None]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, list_IDs, labels, dataFrameIn, tokenizer, T5Model, batch_size=32, padding=100, n_channels_emb=1024, n_channels_mm=51, n_classes=3, shuffle=True):
        """Store the batch-generation configuration and build the initial sample ordering.

        Args:
            list_IDs: row labels of `dataFrameIn` to serve, one sample each.
            labels: stored as self.labels.
            dataFrameIn: per-variant table that batches are built from (read via .loc[ID, column]).
            tokenizer: stored as self.tokenizer.
            T5Model: stored as self.T5Model.
            batch_size: samples per batch; the final batch may be smaller.
            padding: half-width of the window around the change position; the window length
                self.dim is 2 * padding + 1.
            n_channels_emb: last-axis size of the embedding array built per batch.
            n_channels_mm: last-axis size of the MMSeqs profile arrays built per batch.
            n_classes: stored as self.n_classes.
            shuffle: when == True, on_epoch_end shuffles the sample ordering.
        """
        # data sources
        self.dataFrameIn = dataFrameIn
        self.list_IDs = list_IDs
        self.labels = labels
        # tokenizer / model handles
        self.tokenizer = tokenizer
        self.T5Model = T5Model
        # batch and window geometry
        self.batch_size = batch_size
        self.padding = padding
        self.dim = 2 * self.padding + 1
        self.n_channels_emb = n_channels_emb
        self.n_channels_mm = n_channels_mm
        self.n_classes = n_classes
        self.shuffle = shuffle
        # relies on list_IDs and shuffle, so it must run after they are set
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch; a trailing partial batch counts as one more batch."""
        # integer ceiling division of the sample count by the batch size
        return -(-len(self.list_IDs) // self.batch_size)

    def __getitem__(self, index):
        """Return the (X, y) batch at position `index` of the current epoch ordering.

        The last batch holds the remaining samples when len(list_IDs) is not a multiple of
        batch_size.
        """
        first = index * self.batch_size
        # numpy slicing stops at the array end, which yields the shorter final batch
        batchPositions = self.indexes[first:first + self.batch_size]
        batchIDs = [self.list_IDs[pos] for pos in batchPositions]
        X, y = self.__data_generation(batchIDs)
        return X, y

    def on_epoch_end(self):
        """Rebuild self.indexes as 0..len(list_IDs)-1, shuffled in place when shuffle == True."""
        order = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(order)
        self.indexes = order

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'
        # Initialization
        thisBatchSize=len(list_IDs_temp)
        altEmbeddings=np.zeros((thisBatchSize, self.dim, self.n_channels_emb))
        mm_alt=np.zeros((thisBatchSize, self.dim, self.n_channels_mm))
        mm_orig=np.zeros((thisBatchSize, self.dim, self.n_channels_mm))
        nonSeq=np.zeros((thisBatchSize, 12))
        y = np.empty((thisBatchSize), dtype=int)
        AMINO_ACIDS = {'A':0,'C':1,'D':2,'E':3,'F':4,'G':5,'H':6,'I':7,'K':8,'L':9,'M':10,'N':11,'P':12,'Q':13,'R':14,'S':15,'T':16,'V':17,'W':18,'Y':19} 
        T5AltSeqTokens=[]

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # process Alt seq with T5 model to create embeddings
            transcriptID=self.dataFrameIn.loc[ID,'TranscriptID']
            changePos=self.dataFrameIn.loc[ID,'ChangePos']-1
            if changePos<0:
                changePos=0
            AltSeq=self.dataFrameIn.loc[ID,'AltSeq']
            if AltSeq[-1]!="*":
                AltSeq=AltSeq + "*"
            seqLenAlt=len(AltSeq)-1
            startPos=0
            if changePos>self.padding:
                if (changePos+self.padding)<seqLenAlt:
                    startPos=changePos-self.padding
                elif seqLenAlt>=self.dim:
                    startPos=seqLenAlt-self.dim
            endPos=changePos+self.padding
            if changePos<self.padding:
                if self.dim<seqLenAlt:
                    endPos=self.dim
                else:
                    endPos=seqLenAlt
            elif (changePos+self.padding)>=seqLenAlt:
                endPos=seqLenAlt
            T5AltSeqTokens.append(" ".join(AltSeq[startPos:endPos]))
            # prep the WT seq too
            WTSeq=self.dataFrameIn.loc[ID,'WildtypeSeq']
            if WTSeq[-1]!="*":
                WTSeq=WTSeq + "*"
            seqLen=len(WTSeq)-1
            startPos=0
            if changePos>self.padding:
                if (changePos+self.padding)<seqLen:
                    startPos=int(changePos-self.padding)
                elif seqLen>=self.dim:
                    startPos=int(seqLen-self.dim)
            endPos=int(changePos+self.padding)
            if changePos<self.padding:
                if self.dim<seqLen:
                    endPos=int(self.dim)
                else:
                    endPos=int(seqLen)
            elif (changePos+self.padding)>=seqLen:
                endPos=int(seqLen)
            T5AltSeqTokens.append(" ".join(WTSeq[startPos:endPos]))


            # collect MMSeqs WT info
            tmp=np.load("HHMFiles/" + transcriptID + "_MMSeqsProfile.npz",allow_pickle=True)
            tmp=tmp['arr_0']
            seqLen=tmp.shape[0]
            startPos=changePos-self.padding
            endPos=changePos+self.padding + 1
            startOffset=0
            endOffset=self.dim
            if changePos<self.padding:
                startPos=0
                startOffset=self.padding-changePos
            if (changePos + self.padding) >= seqLen:
                endPos=seqLen
                endOffset=self.padding + seqLen - changePos
            mm_orig[i,startOffset:endOffset,:] = tmp[startPos:endPos,:]

            # collect MMSeqs Alt info
            # change the amino acid at 'ChangePos' and any after that if needed
            varType=self.dataFrameIn.loc[ID,'varType']
            WTSeq=self.dataFrameIn.loc[ID,'WildtypeSeq']
            if varType=='nonsynonymous SNV':
                if changePos==0:
                    # then this transcript is ablated
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[:,0:20]=0
                    altEncoded[:,50]=0
                else:
                    # change the single amino acid
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[changePos,AMINO_ACIDS[WTSeq[changePos]]]=0
                    altEncoded[changePos,AMINO_ACIDS[AltSeq[changePos]]]=1
            elif varType=='stopgain':
                if changePos==0:
                    # then this transcript is ablated
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[:,0:20]=0
                    altEncoded[:,50]=0
                elif seqLenAlt>seqLen:
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                else:
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    altEncoded[changePos:,0:20]=0
                    altEncoded[changePos:,50]=0
            elif varType=='stoploss':
                altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                altEncoded[:seqLen,:]=tmp
                for j in range(seqLen,seqLenAlt):
                    altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                altEncoded[seqLen:,50]=1
            elif varType=='synonymous SNV':
                # no change
                altEncoded=tmp
            elif ((varType=='frameshift deletion') | (varType=='frameshift insertion') | (varType=='frameshift substitution')):
                if seqLen<seqLenAlt:
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                elif seqLen>seqLenAlt:
                    for j in range(changePos,seqLenAlt):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                        tmp[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLenAlt,seqLen):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                    altEncoded=tmp
                elif seqLen==seqLenAlt:
                    for j in range(changePos,seqLen):
                        tmp[j,AMINO_ACIDS[WTSeq[j]]]=0
                        tmp[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded=tmp
                else:
                    print('Error: seqLen comparisons did not work')
                    exit()
            elif varType=='nonframeshift deletion':
                # how many amino acids deleted?
                altNucLen=0
                if self.dataFrameIn.loc[ID,'alt']!='-':
                    altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                refNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                numAADel=int((refNucLen-altNucLen)/3)
                if (seqLen-numAADel)==seqLenAlt:
                    # non-frameshift deletion
                    #altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    #altEncoded[:changePos,:]=tmp[:changePos,:]
                    #altEncoded[changePos:,:]=tmp[(changePos+numAADel):,:]
                    for j in range(changePos,(changePos+numAADel)):
                        tmp[j,:20]=0
                    altEncoded=tmp
                elif seqLen>=seqLenAlt:
                    # early truncation
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    #for j in range(seqLenAlt,seqLen):
                    #    altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                    altEncoded[seqLenAlt:,0:20]=0
                    altEncoded[seqLenAlt:,50]=0
                elif seqLen<seqLenAlt:
                    # deletion causes stop-loss
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    #for j in range(seqLen,seqLenAlt):
                    #    altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,0:20]=0
                    altEncoded[seqLen:,50]=0
                else:
                    print('Error: seqLen comparisons did not work for nonframeshift deletion')
                    exit()
            elif varType=='nonframeshift insertion':
                # how many amino acids inserted?
                refNucLen=0
                if self.dataFrameIn.loc[ID,'ref']!='-':
                    altNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                numAAIns=int((altNucLen-refNucLen)/3)
                if (seqLen+numAAIns)==seqLenAlt:
                    # non-frameshift insertion
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:changePos,:]=tmp[:changePos,:]
                    altEncoded[(changePos+numAAIns):,:]=tmp[changePos:,:]
                    for j in range(numAAIns):
                        altEncoded[(changePos+j),AMINO_ACIDS[AltSeq[(changePos+j)]]]=1
                    altEncoded[:,50]=1
                elif seqLen<seqLenAlt:
                    # stop loss
                    altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLen):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    for j in range(seqLen,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLen:,50]=1
                elif seqLen>=seqLenAlt:
                    # stop gain
                    altEncoded=np.zeros((seqLen,self.n_channels_mm))
                    altEncoded[:seqLen,:]=tmp
                    for j in range(changePos,seqLenAlt):
                        altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                    altEncoded[seqLenAlt:,0:20]=0
                    altEncoded[seqLenAlt:,50]=0
                else:
                    print('Error: seqLen comparisons did not work for nonframeshift insertion')
                    exit()
            elif varType=='nonframeshift substitution':
                # is this an insertion or a deletion?
                # note that there will not be any '-' symbols in these ref or alt fields because it is a substitution
                refNucLen=len(self.dataFrameIn.loc[ID,'ref'])
                altNucLen=len(self.dataFrameIn.loc[ID,'alt'])
                if refNucLen>altNucLen:
                    # deletion
                    # does this cause an early truncation or non-frameshift deletion?
                    if seqLen>seqLenAlt: 
                        numAADel=int((refNucLen-altNucLen)/3)
                        if (seqLen-numAADel)==seqLenAlt:
                            # non-frameshift deletion
                            #altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                            #altEncoded[:changePos,:]=tmp[:changePos,:]
                            #altEncoded[changePos:,:]=tmp[(changePos+numAADel):,:]
                            for j in range(changePos,(changePos+numAADel)):
                                tmp[j,:20]=0
                            altEncoded=tmp
                        else:
                            # early truncation
                            altEncoded=np.zeros((seqLen,self.n_channels_mm))
                            altEncoded[:seqLen,:]=tmp
                            for j in range(changePos,seqLenAlt):
                                altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            #for j in range(seqLenAlt,seqLen):
                            #    altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[seqLenAlt:,0:20]=0
                            altEncoded[seqLenAlt:,50]=0
                    # does this cause a stop loss?
                    elif seqLen<seqLenAlt:
                        altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        for j in range(seqLen,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLen:,50]=1
                    else: # not sure how this would happen
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                elif refNucLen<altNucLen:
                    # insertion
                    # does this cause a stop loss or non-frameshift insertion?
                    if seqLen<seqLenAlt: 
                        numAAIns=int((altNucLen-refNucLen)/3)
                        if (seqLen+numAAIns)==seqLenAlt:
                            # non-frameshift insertion
                            altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                            altEncoded[:changePos,:]=tmp[:changePos,:]
                            altEncoded[(changePos+numAAIns):,:]=tmp[changePos:,:]
                            for j in range(numAAIns):
                                altEncoded[(changePos+j),AMINO_ACIDS[AltSeq[(changePos+j)]]]=1
                            altEncoded[:,50]=1
                        else:
                            # stop loss
                            altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                            altEncoded[:seqLen,:]=tmp
                            for j in range(changePos,seqLen):
                                altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            for j in range(seqLen,seqLenAlt):
                                altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                            altEncoded[:,50]=1
                    # does this cause an early truncation?
                    elif seqLen>seqLenAlt: 
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLenAlt:,0:20]=0
                        #for j in range(seqLenAlt,seqLen):
                        #    altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[seqLenAlt:,50]=0
                    else: # not sure how this would happen
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                elif refNucLen==altNucLen:
                    if seqLen==seqLenAlt:
                        # synonymous or nonsynonymous change
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        altEncoded[changePos,AMINO_ACIDS[WTSeq[changePos]]]=0
                        altEncoded[changePos,AMINO_ACIDS[AltSeq[changePos]]]=1
                    elif seqLen>seqLenAlt:
                        # early truncation
                        altEncoded=np.zeros((seqLen,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLenAlt:,0:20]=0
                        #for j in range(seqLenAlt,seqLen):
                        #    altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                        altEncoded[seqLenAlt:,50]=0
                    elif seqLen<seqLenAlt:
                        # stop loss
                        altEncoded=np.zeros((seqLenAlt,self.n_channels_mm))
                        altEncoded[:seqLen,:]=tmp
                        for j in range(changePos,seqLen):
                            altEncoded[j,AMINO_ACIDS[WTSeq[j]]]=0
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        for j in range(seqLen,seqLenAlt):
                            altEncoded[j,AMINO_ACIDS[AltSeq[j]]]=1
                        altEncoded[seqLen:,50]=1
                    else:
                        print('non-frameshift substitution comparisons failed')
                        exit()
                else:
                    print('Error: nonframeshift substitution nucleotide length comparison did not work')
                    exit()
            startPos=changePos-self.padding
            endPos=changePos+self.padding+1
            startOffset=0
            endOffset=self.dim
            if changePos<self.padding:
                startPos=0
                startOffset=self.padding-changePos
            if (changePos + self.padding) >= seqLenAlt:
                endPos=seqLenAlt
                endOffset=self.padding + seqLenAlt - changePos
            # exception to deal with start loss SNVs that create new frameshifted products longer than the original protein (when original was shorter than padding length)
            if ((changePos==0) & (self.padding>=seqLen) & (seqLen<seqLenAlt) & (varType=='nonsynonymous SNV')):
                endPos=seqLen
                endOffset=self.padding + seqLen - changePos
            elif ((changePos==0) & (varType=='stopgain')): # related exception for stopgains at position 0
                if (seqLen+self.padding)<=self.dim:
                    endPos=seqLen
                    endOffset=self.padding + seqLen - changePos
                else:
                    endPos=self.padding+1
                    endOffset=self.dim
            mm_alt[i,startOffset:endOffset,:] = altEncoded[startPos:endPos,:]


            # non-seq info
            nonSeq[i] = self.dataFrameIn.loc[ID,['controls_AF','controls_nhomalt','pLI','pNull','pRec','mis_z','lof_z','CCR','GDI','pext','RVIS_ExAC_0.05','gerp']]
            
            # Store class
            y[i] = self.labels[ID]

        # process the altSeq and wtSeq through the T5 tokenizer (for consistency with pre-computed data used for training)
        allTokens=self.tokenizer.batch_encode_plus(T5AltSeqTokens,add_special_tokens=True, padding=True, return_tensors="tf")
        input_ids=allTokens['input_ids'][::2]
        attnMask=allTokens['attention_mask'][::2]
        # but only process the altSeq through the T5 model
        #embeddings=self.T5Model(input_ids[::2],decoder_input_ids=input_ids[::2])
        embeddings=self.T5Model(input_ids,attention_mask=attnMask)
        allEmbeddings=np.asarray(embeddings.last_hidden_state)
        for i in range(thisBatchSize):
            seq_len = (np.asarray(attnMask)[i] == 1).sum()
            seq_emb = allEmbeddings[i][1:seq_len-1]
            altEmbeddings[i,:seq_emb.shape[0],:]=seq_emb


        X={'alt_cons':mm_alt,'alt_emb':altEmbeddings,'non_seq_info':nonSeq,'mm_orig_seq':mm_orig}

        return X, keras.utils.to_categorical(y, num_classes=self.n_classes)


In [None]:
def MaverickArchitecture1(input_shape=201,classes=3,classifier_activation='softmax',mmSize=51,**kwargs):
    """Build and compile the wildtype-vs-alternate ("Diff") Maverick transformer model.

    The wildtype ('mm_orig_seq') and alternate ('mm_alt_seq') per-position encodings are each
    projected to 64 dims, given learned position embeddings, masked, layer-normalised and passed
    through the SAME stack of six transformer encoder blocks (the block objects are shared, so the
    two branches share weights). The encoder output at the centre position of each window is
    pooled through a tanh dense layer; the alt branch and the (alt - orig) difference are
    concatenated with a dense projection of the structured features ('non_seq_info') and fed to
    an MLP classifier.

    Args:
        input_shape: window length (number of positions per sequence input). The pooled position
            is the window centre, input_shape // 2 (index 100 for the default 201).
        classes: number of output units.
        classifier_activation: activation of the output layer.
        mmSize: number of features per position in the two sequence inputs (default 51).
        **kwargs: accepted and ignored.

    Returns:
        A compiled tf.keras.Model with inputs named 'mm_orig_seq', 'mm_alt_seq' and
        'non_seq_info' and a single output layer named 'output'.
    """
    input0 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='mm_orig_seq')
    input1 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='mm_alt_seq')
    input2 = tf.keras.layers.Input(shape=12,name='non_seq_info')

    # project input to an embedding size that is easier to work with
    x_orig = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input0)
    x_alt = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input1)

    posEnc_wt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(x_orig)
    x_orig = tf.keras.layers.Masking()(x_orig)
    x_orig = tf.keras.layers.Add()([x_orig,posEnc_wt])
    x_orig = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(x_orig)
    x_orig = tf.keras.layers.Dropout(0.05)(x_orig)

    posEnc_alt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(x_alt)
    x_alt = tf.keras.layers.Masking()(x_alt)
    x_alt = tf.keras.layers.Add()([x_alt,posEnc_alt])
    x_alt = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(x_alt)
    x_alt = tf.keras.layers.Dropout(0.05)(x_alt)

    # One stack of six encoder blocks applied to both branches (shared weights).
    # Keep construction and call order stable: load_weights maps saved variables onto this layer graph.
    encoder_blocks = [official.nlp.keras_nlp.layers.TransformerEncoderBlock(16,256,tf.keras.activations.relu,output_dropout=0.1,attention_dropout=0.1) for _ in range(6)]
    for block in encoder_blocks:
        x_orig = block(x_orig)
    for block in encoder_blocks:
        x_alt = block(x_alt)

    # Pool the encoder output at the window centre (derived from input_shape rather than hard-coded).
    center = input_shape // 2
    center_token_orig = tf.keras.layers.Lambda(lambda a: a[:, center, :])(x_orig)
    x_orig = tf.keras.layers.Dense(units=64,activation='tanh')(center_token_orig)
    x_orig = tf.keras.layers.Dropout(0.05)(x_orig)

    center_token_alt = tf.keras.layers.Lambda(lambda a: a[:, center, :])(x_alt)
    x_alt = tf.keras.layers.Dense(units=64,activation='tanh')(center_token_alt)
    x_alt = tf.keras.layers.Dropout(0.05)(x_alt)

    # alt representation plus how it differs from wildtype
    diff = tf.keras.layers.Subtract()([x_alt,x_orig])
    combined = tf.keras.layers.concatenate([x_alt,diff])

    input2Dense1 = tf.keras.layers.Dense(64,activation='relu')(input2)
    input2Dense1 = tf.keras.layers.Dropout(0.05)(input2Dense1)
    x = tf.keras.layers.concatenate([combined,input2Dense1])
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(512,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(64,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(classes, activation=classifier_activation,name='output')(x)
    model = tf.keras.Model(inputs=[input0,input1,input2],outputs=x)

    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.85)
    model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])
    return model


In [None]:
def MaverickArchitecture2(input_shape=201,embeddingSize=1024,mmSize=51,classes=3,classifier_activation='softmax',**kwargs):
    """Build and compile the alternate-only Maverick model (transformer + BiLSTM over embeddings).

    Inputs:
      * 'alt_cons' (input_shape, mmSize): per-position encoding of the alternate sequence window.
        It is projected to 64 dims, given learned position embeddings, masked, layer-normalised and
        run through six transformer encoder blocks; the encoder output at the window centre is
        pooled through a tanh dense layer.
      * 'alt_emb' (input_shape, embeddingSize): per-residue embeddings of the alternate sequence,
        summarised by a bidirectional LSTM (32 units per direction).
      * 'non_seq_info' (12,): structured per-variant features, passed through a dense layer.
    The three branches are concatenated and fed to an MLP classifier.

    Args:
        input_shape: window length. The pooled position is input_shape // 2 (100 for the default 201).
        embeddingSize: width of the 'alt_emb' per-residue embeddings.
        mmSize: number of features per position in 'alt_cons'.
        classes: number of output units.
        classifier_activation: activation of the output layer.
        **kwargs: accepted and ignored.

    Returns:
        A compiled tf.keras.Model with inputs 'alt_cons', 'alt_emb', 'non_seq_info' and output 'output'.
    """
    input0 = tf.keras.layers.Input(shape=(input_shape,mmSize),name='alt_cons')
    input1 = tf.keras.layers.Input(shape=(input_shape,embeddingSize),name='alt_emb')
    input2 = tf.keras.layers.Input(shape=12,name='non_seq_info')

    # project input to an embedding size that is easier to work with
    alt_cons = tf.keras.layers.experimental.EinsumDense('...x,xy->...y',output_shape=64,bias_axes='y')(input0)

    posEnc_alt = official.nlp.keras_nlp.layers.PositionEmbedding(max_length=input_shape)(alt_cons)
    alt_cons = tf.keras.layers.Masking()(alt_cons)
    alt_cons = tf.keras.layers.Add()([alt_cons,posEnc_alt])
    alt_cons = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12,dtype=tf.float32)(alt_cons)
    alt_cons = tf.keras.layers.Dropout(0.05)(alt_cons)

    # Six encoder blocks applied in sequence.
    # Keep construction and call order stable: load_weights maps saved variables onto this layer graph.
    encoder_blocks = [official.nlp.keras_nlp.layers.TransformerEncoderBlock(16,256,tf.keras.activations.relu,output_dropout=0.1,attention_dropout=0.1) for _ in range(6)]
    for block in encoder_blocks:
        alt_cons = block(alt_cons)

    # Pool the encoder output at the window centre (derived from input_shape rather than hard-coded).
    center = input_shape // 2
    center_token_alt = tf.keras.layers.Lambda(lambda a: a[:, center, :])(alt_cons)
    alt_cons = tf.keras.layers.Dense(units=64,activation='tanh')(center_token_alt)
    alt_cons = tf.keras.layers.Dropout(0.05)(alt_cons)

    sharedLSTM1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32, return_sequences=False, dropout=0.5))

    alt_emb=sharedLSTM1(input1)
    alt_emb=tf.keras.layers.Dropout(0.2)(alt_emb)

    structured = tf.keras.layers.Dense(64,activation='relu')(input2)
    structured = tf.keras.layers.Dropout(0.05)(structured)
    x = tf.keras.layers.concatenate([alt_cons,alt_emb,structured])
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(512,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(64,activation='relu')(x)
    x = tf.keras.layers.Dropout(0.05)(x)
    x = tf.keras.layers.Dense(classes, activation=classifier_activation,name='output')(x)
    model = tf.keras.Model(inputs=[input0,input1,input2],outputs=x)

    optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.85)
    model.compile(optimizer=optimizer,loss='categorical_crossentropy',metrics=['accuracy'])
    return model


In [None]:
# Score every annotated variant with the 8-model Maverick ensemble.
# Writes one predictions file per model plus the ensemble mean (outBase + '_ensemblePredictions.txt').
batchSize=32
inFile='input.annotated.txt'
# Structured (non-sequence) features, in the column order of the models' 'non_seq_info' input.
nonSeqCols=['controls_AF','controls_nhomalt','pLI','pNull','pRec','mis_z','lof_z','CCR','GDI','pext','RVIS_ExAC_0.05','gerp']
scoreCols=['BenignScore','DomScore','RecScore']

tokenizer = T5Tokenizer.from_pretrained("prot_t5_xl_bfd", do_lower_case=False,local_files_only=True)
T5Model = TFT5EncoderModel.from_pretrained("prot_t5_xl_bfd",local_files_only=True)


def cleanNonSeqInfo(frame):
    """Return a copy of `frame`'s structured feature columns with the fixed clipping/defaults applied.

    GDI is capped at 2000; missing controls_AF and controls_nhomalt are set to 0 and
    controls_nhomalt is capped at 10. Any other missing values are left as NaN for the caller to
    fill (with the training-set medians). Training and input data share this helper so they are
    always preprocessed identically.
    """
    nonSeq=frame[nonSeqCols].copy(deep=True)
    nonSeq['GDI']=nonSeq['GDI'].clip(upper=2000)
    nonSeq['controls_AF']=nonSeq['controls_AF'].fillna(0)
    nonSeq['controls_nhomalt']=nonSeq['controls_nhomalt'].fillna(0).clip(upper=10)
    return nonSeq


# calculate medians and quantiles from training data
trainingData=pandas.read_csv('trainingSet_v4.groomed_withExtraInfo2_corrected.txt',sep='\t',low_memory=False)
trainingDataNonSeqInfo=cleanNonSeqInfo(trainingData)
trainingDataNonSeqMedians=trainingDataNonSeqInfo.median()
trainingDataNonSeqInfo=trainingDataNonSeqInfo.fillna(trainingDataNonSeqMedians).to_numpy().astype(np.float32)

# scale columns by QT; subsample must be an int (scikit-learn >= 1.2 rejects the float 1e6)
qt=QuantileTransformer(subsample=1_000_000, random_state=0, output_distribution='uniform')
qt.fit(trainingDataNonSeqInfo)

# load the ensemble members: (architecture, which batch inputs it takes, saved weights), in model1..model8 order
modelSpecs=[
    (MaverickArchitecture1,'diff','weights_TransformerNetDiff_model_1'),
    (MaverickArchitecture1,'diff','weights_TransformerNetDiff_classWeights_1_2_7_model_1'),
    (MaverickArchitecture1,'diff','weights_TransformerNetDiff_classWeights_1_2_7_model_2'),
    (MaverickArchitecture2,'t5','weights_T5_withBiLSTM_TransformerNet_altOnly_model_4'),
    (MaverickArchitecture2,'t5','weights_T5_withBiLSTM_TransformerNet_altOnly_model_5'),
    (MaverickArchitecture2,'t5','weights_T5_withBiLSTM_TransformerNet_altOnly_model_7'),
    (MaverickArchitecture2,'t5','weights_T5_withBiLSTM_TransformerNet_altOnly_classWeights_1_2_3_model_1'),
    (MaverickArchitecture2,'t5','weights_T5_withBiLSTM_TransformerNet_altOnly_classWeights_1_2_7_model_1'),
]
models=[]
for architecture,_,weightsFile in modelSpecs:
    member=architecture()
    member.load_weights(weightsFile)
    models.append(member)

# prep the data: same cleaning as training, training medians for missing values, then the fitted QT
inputData=pandas.read_csv(inFile,sep='\t',low_memory=False)
inputDataNonSeqInfo=cleanNonSeqInfo(inputData).fillna(trainingDataNonSeqMedians)
inputDataNonSeqInfo=qt.transform(inputDataNonSeqInfo.to_numpy().astype(np.float32))
# replace the raw feature columns with their scaled values; DataGenerator reads them from dataFrameIn
inputData[nonSeqCols]=inputDataNonSeqInfo

data_generator=DataGenerator(np.arange(len(inputData)),np.ones(len(inputData)),dataFrameIn=inputData,tokenizer=tokenizer,T5Model=T5Model,batch_size=batchSize,shuffle=False)

# set up the output collectors: variant coordinates plus float score columns (float so the
# float model outputs can be written in without a dtype change)
if genome=='GRCh37':
    coordCols=['hg19_chr','hg19_pos(1-based)','ref','alt']
else:
    coordCols=['hg38_chr','hg38_pos(1-based)','ref','alt']
predsTemplate=inputData.loc[:,coordCols].copy(deep=True)
for scoreCol in scoreCols:
    predsTemplate[scoreCol]=0.0
modelPreds=[predsTemplate.copy(deep=True) for _ in models]

# score the test data
for batchNum in range(int(np.ceil(len(inputData)/batchSize))):
    print('Starting batch number ' + str(batchNum), flush=True)
    batchX,batchY=data_generator[batchNum]
    batchInputs={
        # Architecture1 models compare the wildtype and alternate encodings
        'diff':{'mm_orig_seq':batchX['mm_orig_seq'],'mm_alt_seq':batchX['alt_cons'],'non_seq_info':batchX['non_seq_info']},
        # Architecture2 models use the alternate encoding plus the alternate-sequence embeddings
        't5':{'alt_cons':batchX['alt_cons'],'alt_emb':batchX['alt_emb'],'non_seq_info':batchX['non_seq_info']},
    }
    firstRow=batchNum*batchSize
    lastRow=firstRow+len(batchY)-1  # .loc label slices include the end label
    for (_,inputKind,_),member,preds in zip(modelSpecs,models,modelPreds):
        preds.loc[firstRow:lastRow,scoreCols]=member.predict(batchInputs[inputKind],verbose=0)

# save individual model results to file
for modelNum,preds in enumerate(modelPreds,start=1):
    preds.to_csv(outBase + '_model' + str(modelNum) + 'Predictions.txt',sep='\t',index=False)

# ensemble results together: unweighted mean of the eight models' scores, kept in its own frame
# so each per-model frame still holds that model's own predictions
y_pred=np.mean([preds.loc[:,scoreCols].to_numpy() for preds in modelPreds],axis=0)
ensemblePreds=predsTemplate.copy(deep=True)
ensemblePreds.loc[:,scoreCols]=y_pred
ensemblePreds.to_csv(outBase + '_ensemblePredictions.txt',sep='\t',index=False)


In [None]:
# Rank this sample's variants. Het calls are scored by the ensemble dominant score, hom calls by the
# ensemble recessive score, and every pair of het calls in the same gene is scored as a possible
# compound heterozygote using the harmonic mean of the two recessive scores.
import itertools

if genome=='GRCh37':
    coordCols=['hg19_chr','hg19_pos(1-based)','ref','alt']
else:
    coordCols=['hg38_chr','hg38_pos(1-based)','ref','alt']

# The scoring cell writes the ensemble scores in the same row order as input.annotated.txt (inputData).
ensembleScores=pandas.read_csv(outBase + '_ensemblePredictions.txt',sep='\t',low_memory=False)
if len(ensembleScores)!=len(inputData):
    raise ValueError('ensemble predictions and input.annotated.txt have different numbers of rows')
sample=inputData.copy(deep=True)
for scoreCol in ['BenignScore','DomScore','RecScore']:
    sample['Maverick_' + scoreCol]=ensembleScores[scoreCol].to_numpy()

# NOTE(review): genotype/geneName/geneID are assumed to come from the annotation step - confirm.
missingCols=[col for col in ['varType','genotype','geneName','geneID'] if col not in sample.columns]
if missingCols:
    raise KeyError('variant ranking needs columns missing from input.annotated.txt: ' + ', '.join(missingCols))

sample['varID']=['_'.join(values) for values in sample[coordCols].astype(str).itertuples(index=False,name=None)]
sample['TotalScore']=sample['Maverick_DomScore'].where(sample['genotype']!='hom',sample['Maverick_RecScore'])

# build all same-gene het pairs as records, then one DataFrame (avoids quadratic concat in a loop)
pairRecords=[]
hets=sample.loc[sample['genotype']=='het',:]
for _,geneHets in hets.groupby('geneID',sort=False):
    for site1,site2 in itertools.combinations(geneHets.to_dict('records'),2):
        pairRecords.append({'site1_varID':site1['varID'],
            'site2_varID':site2['varID'],
            'geneID':site2['geneID'],
            'geneName':site2['geneName'],
            'site1_RecScore':site1['Maverick_RecScore'],
            'site2_RecScore':site2['Maverick_RecScore'],
            'TotalScore':scipy.stats.hmean([site1['Maverick_RecScore'],site2['Maverick_RecScore']])})
compHetPairs=pandas.DataFrame(pairRecords,columns=['site1_varID','site2_varID','geneID','geneName','site1_RecScore','site2_RecScore','TotalScore'])

frames=[sample] if compHetPairs.empty else [sample,compHetPairs]
thisSampleFinalScores=pandas.concat(frames,axis=0,sort=False,ignore_index=True)
thisSampleFinalScores=thisSampleFinalScores.sort_values(by="TotalScore",ascending=False)
# tidy up; reindex (not .loc) so the pair columns are still written when there are no pairs
finalCols=['varType']+coordCols+['genotype','geneName','geneID','Maverick_BenignScore','Maverick_DomScore','Maverick_RecScore','varID','site1_varID','site2_varID','site1_RecScore','site2_RecScore','TotalScore']
thisSampleFinalScores=thisSampleFinalScores.reindex(columns=finalCols)
thisSampleFinalScores.to_csv(outBase + '.finalScores.txt',sep='\t',header=True,index=False)
