In [1]:
import h5py
import sys
import os
import math
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
sys.path.append('../data_generation')
import utils 
import numpy as np
datalen = '5994'

In [2]:
file = h5py.File("../data/CAGI/"+datalen+"/CAGI_onehot.h5", "r")
alt = file['alt']
ref = file['ref']

## NT zero-shot

Cosine similarity between embeddings with different allele

In [3]:
import nucleotide_transformer
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model
from tqdm import tqdm
model_name = '500M_1000G'

if '2B5' in model_name:
    print('2B5_model')
    embed_layer = 32
elif 'huamn_ref' in model_name: 
    print('500M model')
    embed_layer = 24
elif '_v2' in model_name:
    print('V2 model')
    embed_layer = 29

V2 model


In [4]:
max_len = math.ceil(len(alt[0])/6)+1
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(embed_layer,),
    attention_maps_to_save=(),
    max_positions=max_len,
)
forward_fn = hk.transform(forward_fn)

In [5]:
# CLS = 3
# PAD = 2
random_key = jax.random.PRNGKey(0)
N, L, A = alt.shape
batch_size = 50
cagi_llr=[]
for i in tqdm(range(0,N,batch_size)):
    b_size = batch_size
    if i + batch_size > N:
        b_size = N-i
    onehot = np.concatenate((ref[i:i+b_size],alt[i:i+b_size]))
    seq = utils.onehot_to_seq(onehot)
    token_out = tokenizer.batch_tokenize(seq)
    token_id = [b[1] for b in token_out]
    seq_pair = jnp.asarray(token_id,dtype=jnp.int32)
    outs = forward_fn.apply(parameters, random_key, seq_pair)
    for a in range(b_size):
        ref_out = outs['embeddings_'+str(embed_layer)][a, 1:, :]
        alt_out = outs['embeddings_'+str(embed_layer)][a+b_size, 1:, :]
        cagi_llr.append((ref_out * alt_out).sum()/(jnp.linalg.norm(ref_out)*jnp.linalg.norm(alt_out)))

100%|██████████| 369/369 [24:20<00:00,  3.96s/it]


In [6]:
output = h5py.File('../data/CAGI/NT_zeroshot/'+'cagi_'+datalen+'_'+model_name+'.h5', 'w')
output.create_dataset('llr', data=np.array(cagi_llr))
output.close()

## Print performance

In [3]:
import h5py
import pandas as pd
import numpy as np
from operator import itemgetter
import seaborn as sns
import matplotlib.pyplot as plt 
import scipy.stats as stats
input_length = '5994'
model_name = '2B5_1000G'
cagi_df = pd.read_csv('../data/CAGI/'+input_length+'/final_cagi_metadata.csv',
                      index_col=0).reset_index()
exp_list = cagi_df['8'].unique()
target = cagi_df['6'].values.tolist()

In [4]:
cagi_result = h5py.File('/home/ztang/LLM_eval/data/CAGI/NT_zeroshot'+'/cagi_'+input_length+'_'+model_name+'.h5', 'r')
cagi_llr = cagi_result['llr'][()]

In [5]:
perf = []
for exp in ['LDLR','SORT1','F9','PKLR']:
    sub_df = cagi_df[cagi_df['8'] == exp]
    exp_target = np.array(target)[sub_df.index.to_list()]
    exp_pred = np.squeeze(cagi_llr)[sub_df.index.to_list()]
    exp_target = np.absolute(exp_target)
    exp_pred = exp_pred
    print(exp)
    perf.append(stats.pearsonr(exp_pred,exp_target)[0])
    print(stats.pearsonr(exp_pred,exp_target)[0])

LDLR
-0.16390417169264007
SORT1
-0.10878043469983589
F9
-0.10292111437360503
PKLR
-0.0068942530809115435


In [7]:
(-0.164 -0.109 -0.102)/3

-0.125

## CAGI embedding generation for lenti model

In [2]:
import nucleotide_transformer
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_model
from tqdm import tqdm
model_name = '2B5_1000G'
datalen='230'

file = h5py.File("../data/CAGI/"+datalen+"/CAGI_onehot.h5", "r")
alt = file['alt']
ref = file['ref']

if '2B5' in model_name:
    print('2B5_model')
    embed_layer = 32

max_len = len(alt[0])//6+len(alt[0])%6+1
parameters, forward_fn, tokenizer, config = get_pretrained_model(
    model_name=model_name,
    embeddings_layers_to_save=(embed_layer,),
    attention_maps_to_save=(),
    max_positions=max_len,
)
forward_fn = hk.transform(forward_fn)

2B5_model


In [3]:
# CLS = 3
# PAD = 2
random_key = jax.random.PRNGKey(0)
N, L, A = alt.shape
batch_size = 50
ref_out = []
alt_out = []
for i in tqdm(range(0,N,batch_size)):
    ref_seq = ref[i:i+batch_size]
    alt_seq = alt[i:i+batch_size]
    ref_seq = utils.onehot_to_seq(ref_seq)
    alt_seq = utils.onehot_to_seq(alt_seq)

    ref_token = tokenizer.batch_tokenize(ref_seq)
    alt_token = tokenizer.batch_tokenize(alt_seq)
    ref_token = [b[1] for b in ref_token]
    alt_token = [b[1] for b in alt_token]

    ref_pair = jnp.asarray(ref_token,dtype=jnp.int32)
    alt_pair = jnp.asarray(alt_token,dtype=jnp.int32)

    ref_output = forward_fn.apply(parameters, random_key, ref_pair)['embeddings_'+str(embed_layer)]
    alt_output = forward_fn.apply(parameters, random_key, alt_pair)['embeddings_'+str(embed_layer)]

    ref_out.extend(np.asarray(ref_output))
    alt_out.extend(np.asarray(alt_output))

100%|██████████| 369/369 [27:14<00:00,  4.43s/it]


In [4]:
output = h5py.File('../data/CAGI/230_embed/'+'NT.h5', 'w')
output.create_dataset('ref', data=np.array(ref_out))
output.create_dataset('alt', data=np.array(alt_out))
output.close()