In [1]:
import h5py
import sys
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append('/home/ztang/multitask_RNA/data_generation')
import utils
import numpy as np

In [2]:
file = h5py.File("/home/ztang/multitask_RNA/data/CAGI/230/CAGI_230_onehot.h5", "r")
alt = file['alt']
ref = file['ref']

## Nucleotide Transformer zero shot test

cosine similarity between embeddings with different allele

In [3]:
import nucleotide_transformer
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model
from tqdm import tqdm
model_name = '2B5_multi_species'

if '2B5' in model_name:
    print('2B5_model')
    embed_layer = 32
else:
    print('500M model')
    embed_layer = 24

2B5_model


In [4]:
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    mixed_precision=False,
    embeddings_layers_to_save=(embed_layer,),
    attention_maps_to_save=(),
    max_positions=513,
)
forward_fn = hk.transform(forward_fn)

In [7]:
random_key = jax.random.PRNGKey(0)
N, L, A = alt.shape
mut_i = int(L/2-1)
batch_size = 128
cagi_llr=[]
for i in tqdm(range(0,N,batch_size)):
    b_size = 128
    if i + batch_size > N:
        b_size = N-batch_size
    onehot = np.concatenate((ref[i:i+b_size],alt[i:i+b_size]))
    seq = utils.onehot_to_seq(onehot)
    token_out = tokenizer.batch_tokenize(seq)
    token_id = [b[1] for b in token_out]
    seq_pair = jnp.asarray(token_id,dtype=jnp.int32)
    outs = forward_fn.apply(parameters, random_key, seq_pair)
    for i in range(b_size):
        ref_out = outs['embeddings_'+str(embed_layer)][i]
        alt_out = outs['embeddings_'+str(embed_layer)][i+b_size]
        cagi_llr.append((ref_out * alt_out).sum()/(jnp.linalg.norm(ref_out)*jnp.linalg.norm(alt_out)))

100%|██████████| 145/145 [29:30<00:00, 12.21s/it]


In [8]:
output = h5py.File('/home/ztang/multitask_RNA/data/CAGI/zero_shot/230/cagi_'+model_name+'.h5', 'w')
output.create_dataset('llr', data=np.array(cagi_llr))
output.close()

## Visulization

In [1]:
import h5py
import pandas as pd
import numpy as np
from operator import itemgetter
import seaborn as sns
import matplotlib.pyplot as plt 
import scipy.stats as stats
cagi_result = h5py.File('/home/ztang/multitask_RNA/data/CAGI/zero_shot/230/cagi_2B5_multi_species.h5', 'r')
llr = cagi_result['llr'][()]

In [2]:
cagi_df = pd.read_csv('../../data/CAGI/230/final_cagi_metadata.csv',
                      index_col=0).reset_index()
exp_list = cagi_df['8'].unique()
plot_figure=False

In [3]:
performance_dict = {}
idx = {'A':0,'C':1,'G':2,'T':3}
for exp in exp_list:
    exp_df = cagi_df[cagi_df['8']==exp]
    idx_df = exp_df[['0','1','2']].drop_duplicates().sort_values(by=['1'])
    exp_len = len(exp_df['1'].unique())
    effect_size = np.zeros((4,exp_len))
    predict_size = np.zeros((4,exp_len))
    
    for pos in range(0,exp_len):
        row = idx_df.iloc[pos]
        loci_df = exp_df[(exp_df['0']==row['0'])&(exp_df['1']==row['1'])&(exp_df['2']==row['2'])]
        loci_idx = loci_df.index
        ref_allele = loci_df['3'].drop_duplicates().values
        alt_allele = loci_df['4'].values.tolist()
        diff = loci_df['6'].values

        effect_size[itemgetter(*alt_allele)(idx),pos] =np.absolute(diff)
        #predict_size [itemgetter(*alt_allele)(idx),pos] =llr[loci_idx]
        predict_size [itemgetter(*alt_allele)(idx),pos] =1
    r_value = stats.pearsonr(effect_size.flatten(),predict_size.flatten())
    performance_dict[exp]= r_value[0]
    if plot_figure:
        fig,ax = plt.subplots(2,1,figsize = (20,7))
        #fig2=plt.figure(figsize = (20,2))
        fig1 = sns.heatmap(effect_size,cmap = 'vlag',
                            center = 0,
                            #annot = exp_annot,fmt = '',
                        cbar_kws = dict(use_gridspec=False,location="bottom"),
                        ax = ax[0]);
        ax[0].tick_params(left=True, bottom=False);
        #ax.set_yticklabels(['A','C','G','T'],size = 1);
        ax[0].set_yticklabels([])
        ax[0].set_xticklabels([]);
        ax[0].set_title(exp+' ground truth')
        #plt.tight_layout()

        #fig3=plt.figure(figsize = (20,2))
        fig2 = sns.heatmap(predict_size,cmap = 'vlag',
                            center = 0,
                            #annot = pred_annot,fmt = '',
                            cbar_kws = dict(use_gridspec=False,location="bottom"),
                            ax = ax[1]);
        ax[1].tick_params(left=True, bottom=False);
        #ax.set_yticklabels(['A','C','G','T'],size = 1);
        ax[1].set_yticklabels([])
        ax[1].set_xticklabels([])
        ax[1].set_title(exp+' mutagenesis')

In [4]:
performance_dict

{'ZFAND3': 0.37958882258110616,
 'HBG1': 0.3648951366719255,
 'MSMB': 0.3939608936620037,
 'LDLR': 0.2867844760793666,
 'MYCrs6983267': 0.3637777785331868,
 'SORT1': 0.3457491448088308,
 'PKLR': 0.3535938034657003,
 'F9': 0.40615243352173375,
 'TERT-HEK293T': 0.40826779723051343,
 'IRF6': 0.296117921725497,
 'HBB': 0.4576562503479471,
 'TERT-GBM': 0.3852918724452555,
 'IRF4': 0.31538845447575753,
 'GP1BB': 0.3248181791422176,
 'HNF4A': 0.33819304988396887}

In [5]:
np.array(list(performance_dict.values())).mean()

0.36128208964245495

## Compare how different are the models?

In [5]:
import h5py
import numpy as np
import pandas as pd
from operator import itemgetter
from scipy import stats

a_out = h5py.File('/home/ztang/multitask_RNA/data/CAGI/zero_shot/cagi_2B5_multi_species.h5','r')
b_out = h5py.File('/home/ztang/multitask_RNA/data/CAGI/zero_shot/cagi_500M_human_ref.h5','r')

In [6]:
cagi_df = pd.read_csv('../../data/CAGI/final_cagi_metadata.csv',
                      index_col=0).reset_index()
exp_list = cagi_df['8'].unique()
plot_figure=False
a_out = a_out['llr'][()]
b_out = b_out['llr'][()]

In [6]:
performance_dict = {}
alt_performance_dict = {}
idx = {'A':0,'C':1,'G':2,'T':3}
for exp in exp_list:
    exp_df = cagi_df[cagi_df['8']==exp]
    idx_df = exp_df[['0','1','2']].drop_duplicates().sort_values(by=['1'])
    exp_len = len(exp_df['1'].unique())
    effect_size = np.zeros((4,exp_len))
    predict_size = np.zeros((4,exp_len))
    alt_predict_size = np.zeros((4,exp_len))
    
    for pos in range(0,exp_len):
        row = idx_df.iloc[pos]
        loci_df = exp_df[(exp_df['0']==row['0'])&(exp_df['1']==row['1'])&(exp_df['2']==row['2'])]
        loci_idx = loci_df.index
        ref_allele = loci_df['3'].drop_duplicates().values
        alt_allele = loci_df['4'].values.tolist()
        diff = loci_df['6'].values

        effect_size[itemgetter(*alt_allele)(idx),pos] =np.absolute(diff)
        predict_size [itemgetter(*alt_allele)(idx),pos] =a_out[loci_idx]
        alt_predict_size [itemgetter(*alt_allele)(idx),pos] =b_out[loci_idx]

    r_value = stats.pearsonr(effect_size.flatten(),predict_size.flatten())
    alt_r_value = stats.pearsonr(effect_size.flatten(),alt_predict_size.flatten())
    performance_dict[exp]= r_value[0]
    alt_performance_dict[exp]= alt_r_value[0]

In [13]:
predict_size

array([[0.99991727, 0.        , 0.99992973, ..., 0.        , 0.9999243 ,
        0.99911875],
       [0.        , 0.99928528, 0.99990278, ..., 0.99988973, 0.99988055,
        0.        ],
       [0.        , 0.99911839, 0.        , ..., 0.99970317, 0.        ,
        0.99935865],
       [0.        , 0.9999314 , 0.99993867, ..., 0.9999159 , 0.99987179,
        0.9997077 ]])

In [14]:
stats.pearsonr(effect_size.flatten(),predict_size.flatten())

PearsonRResult(statistic=0.2867984790769957, pvalue=1.3880290840794402e-25)

In [12]:
alt_predict_size

array([[0.99982399, 0.        , 0.99986553, ..., 0.        , 0.99976522,
        0.9996832 ],
       [0.        , 0.99982858, 0.99980557, ..., 0.99964333, 0.99968797,
        0.        ],
       [0.        , 0.99987394, 0.        , ..., 0.99973935, 0.        ,
        0.99973369],
       [0.        , 0.99981195, 0.99978489, ..., 0.99969804, 0.99955088,
        0.99979353]])

In [15]:
stats.pearsonr(effect_size.flatten(),alt_predict_size.flatten())

PearsonRResult(statistic=0.286789759083672, pvalue=1.3928977404317928e-25)