In [1]:
import h5py
import sys
import os
import math
# Restrict the process to GPU 0. This must be set before any CUDA-using library
# (jax, imported in a later cell) initialises, otherwise it has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Make the project's data_generation helpers importable (used below for utils.onehot_to_seq).
sys.path.append('/home/ztang/multitask_RNA/data_generation')
import utils 
import numpy as np
# Dataset variant used to pick the data and output sub-directories
# (appears to be the sequence length in nt — confirm).
datalen = '230'

In [2]:
# One-hot encoded reference / alternate sequences. The file is left open on purpose:
# the scoring loop below slices these h5py datasets batch by batch.
onehot_path = "/home/ztang/multitask_RNA/data/CAGI/" + datalen + "/CAGI_onehot.h5"
file = h5py.File(onehot_path, "r")
ref, alt = file['ref'], file['alt']

## Nucleotide Transformer zero-shot test

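Score each variant by the cosine similarity between the embeddings of its reference-allele sequence and its alternate-allele sequence (higher similarity means the model sees a smaller difference between the two alleles).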

In [3]:
import nucleotide_transformer
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model
from tqdm import tqdm
model_name = '2B5_1000G'

# Layer whose embeddings are saved and compared: 32 for the 2B5 checkpoints,
# 24 for everything else (the 500M checkpoints).
# NOTE(review): presumably the final layer of each model — confirm against the checkpoint config.
is_2b5 = '2B5' in model_name
print('2B5_model' if is_2b5 else '500M model')
embed_layer = 32 if is_2b5 else 24

2B5_model


In [4]:
# Number of token positions the model must accept. Per the NT README, the tokenizer
# emits one token per complete 6-mer, one token per leftover nucleotide, and a
# leading CLS token, padding every sequence to max_positions. The original
# ceil(L/6) + 2 only covers that count when L % 6 <= 2. Taking the larger of the
# two keeps the old value for those lengths and is long enough for the others.
# NOTE(review): sequences containing 'N' are split into extra tokens and may need
# more room — confirm utils.onehot_to_seq never emits 'N'.
seq_len = alt.shape[1]  # alt is (N, L, A); avoids reading a whole sample just for its length
max_len = max(math.ceil(seq_len / 6) + 2, seq_len // 6 + seq_len % 6 + 1)
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(embed_layer,),
    attention_maps_to_save=(),
    max_positions=max_len,
)
forward_fn = hk.transform(forward_fn)

In [5]:
# Special token ids of the NT tokenizer, for reference: CLS = 3, PAD = 2.
# Score every variant by the cosine similarity between the layer-`embed_layer`
# embeddings of its reference and alternate sequences.
random_key = jax.random.PRNGKey(0)  # fixed key so repeated runs are reproducible
N, L, A = alt.shape
batch_size = 200
embed_key = 'embeddings_' + str(embed_layer)
# NOTE: despite its name, cagi_llr holds cosine similarities. The name (and the
# 'llr' dataset key it is saved under) is kept because later cells rely on it.
cagi_llr = []
for i in tqdm(range(0, N, batch_size)):
    b_size = min(batch_size, N - i)  # the last batch may be short
    # Rows 0..b_size-1 are the ref sequences; rows b_size.. are the matching alts.
    onehot = np.concatenate((ref[i:i+b_size], alt[i:i+b_size]))
    seq = utils.onehot_to_seq(onehot)
    token_out = tokenizer.batch_tokenize(seq)
    token_id = [b[1] for b in token_out]  # second element of each tokenizer output: the ids
    seq_pair = jnp.asarray(token_id, dtype=jnp.int32)
    outs = forward_fn.apply(parameters, random_key, seq_pair)
    # Drop position 0 (CLS) and split back into the ref and alt halves.
    emb = outs[embed_key][:, 1:, :]
    ref_out, alt_out = emb[:b_size], emb[b_size:]
    # Vectorised cosine similarity of each flattened (position x feature) matrix.
    # A norm over axis=(1, 2) is the per-sample Frobenius norm.
    dot = jnp.sum(ref_out * alt_out, axis=(1, 2))
    norms = jnp.linalg.norm(ref_out, axis=(1, 2)) * jnp.linalg.norm(alt_out, axis=(1, 2))
    cagi_llr.extend(np.asarray(dot / norms))
assert len(cagi_llr) == N, 'every variant must receive exactly one score'

100%|██████████| 93/93 [03:09<00:00,  2.03s/it]


In [6]:
# Save the per-variant scores. The dataset keeps the key 'llr' (it is read back
# under that name later) even though the values are cosine similarities.
out_path = '/home/ztang/multitask_RNA/data/CAGI/zero_shot/'+datalen+'/cagi_'+model_name+'.h5'
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # first run for a new datalen
# Context manager guarantees the file is closed (and flushed) even if the write fails.
with h5py.File(out_path, 'w') as output:
    output.create_dataset('llr', data=np.array(cagi_llr))

## Visualization

In [10]:
import h5py
import pandas as pd
import numpy as np
from operator import itemgetter
import seaborn as sns
import matplotlib.pyplot as plt 
import scipy.stats as stats
# Analysis settings.
# NOTE(review): model_name here ('500M_1000G') differs from the model scored above,
# so this section reads a previously written result file. input_length duplicates
# datalen from the first cell.
model_name = '500M_1000G'
input_length = '230'
metadata_path = '../../data/CAGI/' + input_length + '/final_cagi_metadata.csv'
cagi_df = pd.read_csv(metadata_path, index_col=0).reset_index()
# Column '8' names the experiment; column '6' holds the measured effect used as the target.
exp_list = cagi_df['8'].unique()
target = cagi_df['6'].to_numpy().tolist()
plot_figure = False

In [11]:
# Load the per-variant scores written by the scoring section (stored under 'llr').
# [()] reads the whole dataset into memory, so the file can be closed right away.
with h5py.File('/home/ztang/multitask_RNA/data/CAGI/zero_shot/'+input_length+'/cagi_'+model_name+'.h5', 'r') as cagi_result:
    cagi_llr = cagi_result['llr'][()]

In [12]:
# Per-experiment Pearson r between each variant's score and |measured effect|.
# NOTE(review): the score is a cosine similarity (high = ref/alt embeddings agree),
# so a model that tracks effect size should give a *negative* r here. Consider
# 1 - similarity if a positive-is-better metric is wanted.
target_abs = np.absolute(np.asarray(target))  # hoisted: built once, not per experiment
pred_all = np.squeeze(cagi_llr)
# Scores are indexed by metadata row position, so the two must line up one-to-one.
assert len(pred_all) == len(cagi_df), 'score / metadata length mismatch'
perf = []
sanity_check = 0
for exp in cagi_df['8'].unique():
    rows = cagi_df.index[cagi_df['8'] == exp].to_list()
    sanity_check += len(rows)
    exp_target = target_abs[rows]
    exp_pred = pred_all[rows]
    r = stats.pearsonr(exp_pred, exp_target)[0]
    print(exp)
    perf.append(r)
    print(r)
# Every metadata row must belong to exactly one experiment (fails e.g. on NaN labels).
assert sanity_check == len(cagi_df), 'some rows were not assigned to an experiment'

ZFAND3
-0.0032445862816593684
HBG1
0.012777808439967672
MSMB
-0.003442195277169142
LDLR
0.019409858282437277
MYCrs6983267
-0.010262244619401712
SORT1
0.0181933016993205
PKLR
0.0538593500221866
F9
0.043339159835804185
TERT-HEK293T
-0.007765231618602955
IRF6
0.007473439548794598
HBB
0.018034245722466917
TERT-GBM
0.014877822453121033
IRF4
0.02228980885475156
GP1BB
0.018055375665325012
HNF4A
-0.029653319901501378


In [10]:
# Mean per-experiment Pearson r (an earlier run recorded 0.01021866551181953).
np.asarray(perf).mean()

0.01159617285505605

In [3]:
# Mask-only baseline.
# For each experiment, build a 4 x n_loci map (rows A/C/G/T, one column per locus)
# holding |measured effect| for each assayed alt allele and 0 elsewhere. Correlate
# it with a "prediction" map that is simply 1 wherever a variant was assayed.
# Both maps are 0 at unassayed cells, so this Pearson r only reflects the
# assayed/unassayed pattern (mean r ~0.36 in the recorded run). Model scores
# evaluated with the same flattened-map correlation should be compared to it.
performance_dict = {}
idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
for exp in exp_list:
    exp_df = cagi_df[cagi_df['8'] == exp]
    # Distinct loci, presumably (chrom, start, end) in columns '0'-'2', ordered by '1'.
    idx_df = exp_df[['0', '1', '2']].drop_duplicates().sort_values(by=['1'])
    # One column per distinct locus. Counting unique '1' values alone could drop
    # loci that share a start but differ in the other key columns.
    exp_len = len(idx_df)
    effect_size = np.zeros((4, exp_len))
    predict_size = np.zeros((4, exp_len))

    for pos in range(exp_len):
        row = idx_df.iloc[pos]
        loci_df = exp_df[(exp_df['0'] == row['0']) & (exp_df['1'] == row['1']) & (exp_df['2'] == row['2'])]
        # A list index works the same for one or many alt alleles and always assigns
        # an array into an array slice (itemgetter yields a scalar for one allele).
        alt_rows = [idx[allele] for allele in loci_df['4']]
        effect_size[alt_rows, pos] = np.absolute(loci_df['6'].values)
        predict_size[alt_rows, pos] = 1
    r_value = stats.pearsonr(effect_size.flatten(), predict_size.flatten())
    performance_dict[exp] = r_value[0]
    if plot_figure:
        fig, ax = plt.subplots(2, 1, figsize=(20, 7))
        panels = ((effect_size, exp + ' ground truth'), (predict_size, exp + ' mutagenesis'))
        for axis, (data, title) in zip(ax, panels):
            sns.heatmap(data, cmap='vlag', center=0,
                        cbar_kws=dict(use_gridspec=False, location="bottom"),
                        ax=axis)
            axis.tick_params(left=True, bottom=False)
            axis.set_yticklabels([])
            axis.set_xticklabels([])
            axis.set_title(title)
        # Show this experiment's figure now and free it, so the loop does not
        # accumulate one open figure per experiment.
        plt.show()
        plt.close(fig)

In [4]:
# Per-experiment Pearson r of the mask-only (constant-1 prediction) baseline computed above.
performance_dict

{'ZFAND3': 0.37958882258110616,
 'HBG1': 0.3648951366719255,
 'MSMB': 0.3939608936620037,
 'LDLR': 0.2867844760793666,
 'MYCrs6983267': 0.3637777785331868,
 'SORT1': 0.3457491448088308,
 'PKLR': 0.3535938034657003,
 'F9': 0.40615243352173375,
 'TERT-HEK293T': 0.40826779723051343,
 'IRF6': 0.296117921725497,
 'HBB': 0.4576562503479471,
 'TERT-GBM': 0.3852918724452555,
 'IRF4': 0.31538845447575753,
 'GP1BB': 0.3248181791422176,
 'HNF4A': 0.33819304988396887}

In [5]:
# Mean baseline r across all experiments.
np.mean(list(performance_dict.values()))

0.36128208964245495

## How different are the two models?

In [5]:
import h5py
import numpy as np
import pandas as pd
from operator import itemgetter
from scipy import stats

# Per-variant scores of two models, written by earlier scoring runs.
# NOTE(review): unlike the result paths used earlier, these have no '<datalen>/'
# sub-directory — confirm they point at the intended runs.
zero_shot_dir = '/home/ztang/multitask_RNA/data/CAGI/zero_shot/'
a_out = h5py.File(zero_shot_dir + 'cagi_2B5_multi_species.h5', 'r')
b_out = h5py.File(zero_shot_dir + 'cagi_500M_human_ref.h5', 'r')

In [6]:
# Metadata for the comparison.
# NOTE(review): this path has no '<length>/' sub-directory, unlike the
# visualization section — confirm it matches the variant order of both score files.
cagi_df = pd.read_csv('../../data/CAGI/final_cagi_metadata.csv',
                      index_col=0).reset_index()
exp_list = cagi_df['8'].unique()
plot_figure = False


def _load_llr(scores):
    """Return the 'llr' dataset of an open h5py.File as an in-memory array, closing the file.

    Anything that is not an h5py.File (i.e. an array loaded by a previous run of
    this cell) is returned unchanged, so the cell can be re-run safely.
    """
    if isinstance(scores, h5py.File):
        with scores:
            return scores['llr'][()]
    return scores


a_out = _load_llr(a_out)
b_out = _load_llr(b_out)
# Scores are indexed by metadata row position below, so lengths must agree.
assert len(a_out) == len(b_out) == len(cagi_df), 'score / metadata length mismatch'

In [6]:
# Flattened heat-map Pearson r per experiment for two models: a_out (2B5
# multi-species) fills predict_size, b_out (500M human-ref) fills alt_predict_size.
# NOTE(review): unassayed cells are 0 in every map, while assayed cells hold scores
# close to 1 (see the arrays displayed below). This r is therefore largely driven
# by the assayed/unassayed pattern, not by score differences; correlating only the
# assayed variants would isolate each model's signal.
performance_dict = {}
alt_performance_dict = {}
idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
for exp in exp_list:
    exp_df = cagi_df[cagi_df['8'] == exp]
    # Distinct loci, presumably (chrom, start, end) in columns '0'-'2', ordered by '1'.
    idx_df = exp_df[['0', '1', '2']].drop_duplicates().sort_values(by=['1'])
    # One column per distinct locus actually iterated below.
    exp_len = len(idx_df)
    effect_size = np.zeros((4, exp_len))
    predict_size = np.zeros((4, exp_len))
    alt_predict_size = np.zeros((4, exp_len))

    for pos in range(exp_len):
        row = idx_df.iloc[pos]
        loci_df = exp_df[(exp_df['0'] == row['0']) & (exp_df['1'] == row['1']) & (exp_df['2'] == row['2'])]
        loci_idx = loci_df.index  # row positions into a_out / b_out
        # A list index works the same for one or many alt alleles.
        alt_rows = [idx[allele] for allele in loci_df['4']]
        effect_size[alt_rows, pos] = np.absolute(loci_df['6'].values)
        predict_size[alt_rows, pos] = a_out[loci_idx]
        alt_predict_size[alt_rows, pos] = b_out[loci_idx]

    r_value = stats.pearsonr(effect_size.flatten(), predict_size.flatten())
    alt_r_value = stats.pearsonr(effect_size.flatten(), alt_predict_size.flatten())
    performance_dict[exp] = r_value[0]
    alt_performance_dict[exp] = alt_r_value[0]

In [13]:
# Model A's (a_out) prediction map for the *last* experiment only: the loop
# above overwrites predict_size on every iteration.
predict_size

array([[0.99991727, 0.        , 0.99992973, ..., 0.        , 0.9999243 ,
        0.99911875],
       [0.        , 0.99928528, 0.99990278, ..., 0.99988973, 0.99988055,
        0.        ],
       [0.        , 0.99911839, 0.        , ..., 0.99970317, 0.        ,
        0.99935865],
       [0.        , 0.9999314 , 0.99993867, ..., 0.9999159 , 0.99987179,
        0.9997077 ]])

In [14]:
# Model A's r for the last experiment; its statistic equals the last value stored in performance_dict.
stats.pearsonr(effect_size.ravel(), predict_size.ravel())

PearsonRResult(statistic=0.2867984790769957, pvalue=1.3880290840794402e-25)

In [12]:
# Model B's (b_out) prediction map for the *last* experiment only: the loop
# above overwrites alt_predict_size on every iteration.
alt_predict_size

array([[0.99982399, 0.        , 0.99986553, ..., 0.        , 0.99976522,
        0.9996832 ],
       [0.        , 0.99982858, 0.99980557, ..., 0.99964333, 0.99968797,
        0.        ],
       [0.        , 0.99987394, 0.        , ..., 0.99973935, 0.        ,
        0.99973369],
       [0.        , 0.99981195, 0.99978489, ..., 0.99969804, 0.99955088,
        0.99979353]])

In [15]:
# Model B's r for the last experiment; its statistic equals the last value stored in alt_performance_dict.
stats.pearsonr(effect_size.ravel(), alt_predict_size.ravel())

PearsonRResult(statistic=0.286789759083672, pvalue=1.3928977404317928e-25)