In [1]:
import torch
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='3'
import sys
sys.path.append('../model/sei_model/')
import sei
from torchinfo import summary
import random
import numpy as np
import h5py
from tqdm import tqdm
import math
import random
import glob
sys.path.append('../data_generation/')
import utils

# Load the pretrained Sei weights. map_location='cpu' lets the checkpoint
# deserialize regardless of which GPU index it was saved from (only one device
# is visible here because CUDA_VISIBLE_DEVICES is set above); the model is
# moved to the GPU explicitly afterwards.
file_dict = torch.load('../model/sei_model/sei.pth', map_location='cpu')
# Every checkpoint key carries a 13-character prefix that sei.Sei() does not
# expect (presumably a wrapper prefix such as 'module.model.' — TODO confirm),
# so it is stripped before loading.
clean_dict = {key[13:]: value for key, value in file_dict.items()}
model = sei.Sei()
model.load_state_dict(clean_dict)
model.to('cuda').eval();

In [7]:
class embed_extractor():
    """Collect the outputs of hooked modules as per-row numpy arrays.

    `activation` maps a hook name to a flat list; each forward pass appends one
    numpy array per element along the first axis of the module's output.
    """

    def __init__(self):
        self.activation = {}

    def get_activation(self, name):
        """Return a forward hook that stores the module output under `name`.

        The output is detached and copied to CPU before conversion to numpy, so
        the stored arrays hold no autograd graph or GPU memory.
        """
        def hook(model, input, output):
            bucket = self.activation.setdefault(name, [])
            bucket.extend(output.detach().cpu().numpy())
        return hook

## Zero-shot learning for CAGI

In [3]:
# One-hot ref/alt sequences for the CAGI variants, opened read-only from HDF5;
# slices are read from disk on demand. datalen is a string because it is
# spliced into file paths.
datalen = '4096'
file = h5py.File(f"../data/CAGI/{datalen}/CAGI_onehot.h5", "r")
ref = file['ref']
alt = file['alt']

In [13]:
batch_size = 128


def _embed_sequences(seqs, batch_size=batch_size):
    """Run Sei over `seqs` in batches and return the model.spline_tr outputs.

    seqs: indexable array of one-hot sequences; axes 1 and 2 are swapped before
        the batch is fed to the model (assumed stored as (sample, position,
        channel) — TODO confirm against the HDF5 writer).
    Returns a numpy array with one row per input sequence, stacked from the
    activations captured by a temporary forward hook on model.spline_tr.
    """
    extractor = embed_extractor()
    handle = model.spline_tr.register_forward_hook(extractor.get_activation('s_out'))
    try:
        # Each input is iterated over its own length (the old ref loop used len(alt)).
        for start in tqdm(range(0, len(seqs), batch_size)):
            batch = np.swapaxes(seqs[start:start + batch_size], 1, 2).astype('float32')
            with torch.no_grad():
                model(torch.from_numpy(batch).to('cuda'))
    finally:
        # Without this, every hook stays attached to the module and keeps
        # appending (and retaining) outputs from all later forward passes.
        handle.remove()
    return np.array(extractor.activation.get('s_out', []))


alt_out = _embed_sequences(alt)
ref_out = _embed_sequences(ref)
# The scoring cell pairs ref_out[i] with alt_out[i], so the two must line up.
assert alt_out.shape == ref_out.shape, (alt_out.shape, ref_out.shape)

100%|██████████| 145/145 [00:23<00:00,  6.14it/s]
100%|██████████| 145/145 [00:23<00:00,  6.06it/s]


In [19]:
# Per-variant comparison of the ref and alt Sei embeddings.
# NOTE: 'l2' is the *squared* Euclidean distance (no square root is taken).
cos, dot, l1, l2 = [], [], [], []
for idx in range(len(alt)):
    ref_vec, alt_vec = ref_out[idx], alt_out[idx]
    inner = (ref_vec * alt_vec).sum()
    delta = ref_vec - alt_vec
    dot.append(inner)
    cos.append(inner / (np.linalg.norm(ref_vec) * np.linalg.norm(alt_vec)))
    l1.append(np.absolute(delta).sum())
    l2.append(np.square(delta).sum())

In [22]:
# Persist the four per-variant scores. The with-block closes the file even if
# a write fails; otherwise the handle stays open in the kernel and HDF5 refuses
# to truncate the still-open file when this cell is re-run with mode 'w'.
score_lists = {'cosine': cos, 'dot': dot, 'l1': l1, 'l2': l2}
with h5py.File(f'../data/CAGI/cagi_{datalen}_sei.h5', 'w') as output:
    for score_name, values in score_lists.items():
        output.create_dataset(score_name, data=np.array(values))

## Evaluate zero-shot performance

In [1]:
import h5py
import pandas as pd
import numpy as np
from operator import itemgetter
import seaborn as sns
import matplotlib.pyplot as plt 
import scipy.stats as stats
datalen = '4096'
# Variant metadata. Column '6' is used as the target and column '8' as the
# experiment label (e.g. 'LDLR'); reset_index() gives positional row ids that
# the evaluation uses to index the per-variant scores.
cagi_df = (
    pd.read_csv(f'../data/CAGI/{datalen}/final_cagi_metadata.csv', index_col=0)
    .reset_index()
)
target = cagi_df['6'].tolist()
exp_list = cagi_df['8'].unique()
cagi_result = h5py.File(f'../data/CAGI/cagi_{datalen}_sei.h5', 'r')

In [4]:
# Metadata rows belonging to the LDLR experiment.
sub_df = cagi_df.loc[cagi_df['8'].eq('LDLR')]
sub_df

Unnamed: 0,index,0,1,2,3,4,5,6,7,8,9
1113,1115,chr19,11197874,11201970,A,C,+,0.11,0.02,LDLR,challenge
1114,1116,chr19,11197874,11201970,A,G,+,-0.03,0.00,LDLR,challenge
1115,1117,chr19,11197874,11201970,A,T,+,-0.07,0.02,LDLR,challenge
1116,1118,chr19,11197875,11201971,C,A,+,0.02,0.00,LDLR,challenge
1117,1119,chr19,11197875,11201971,C,G,+,0.15,0.02,LDLR,challenge
...,...,...,...,...,...,...,...,...,...,...,...
4550,4554,chr19,11198128,11202224,C,G,+,-0.03,0.01,LDLR,release
4551,4555,chr19,11198128,11202224,C,T,+,-0.08,0.09,LDLR,release
4552,4556,chr19,11198129,11202225,A,C,+,-0.09,0.02,LDLR,release
4553,4557,chr19,11198129,11202225,A,G,+,-0.03,0.02,LDLR,release


In [2]:
# Zero-shot evaluation: Pearson r between each embedding-distance score and the
# absolute value of the target, separately per experiment. perf keeps the flat
# layout used downstream: for each score (in cagi_result key order), one entry
# per experiment in the order listed below.
experiments = ['LDLR', 'SORT1', 'F9', 'PKLR']
target_abs = np.absolute(np.asarray(target))
# Positional rows per experiment (valid because cagi_df was reset_index()-ed;
# assumes the HDF5 score rows follow the metadata CSV order — TODO confirm).
exp_rows = {exp: cagi_df[cagi_df['8'] == exp].index.to_list() for exp in experiments}

perf = []
for key in cagi_result.keys():
    print(key)
    # Read each score dataset from disk once, not once per experiment.
    score_values = np.squeeze(cagi_result[key])
    for exp in experiments:
        rows = exp_rows[exp]
        r = stats.pearsonr(score_values[rows], target_abs[rows])[0]
        print(exp)
        print(r)
        perf.append(r)

cosine
LDLR
-0.5520171780285357
SORT1
-0.5387206115465768
F9
-0.5432590148862912
PKLR
-0.6408864481498749
dot
LDLR
-0.27938162080021284
SORT1
-0.23028795488017276
F9
-0.11118765212780396
PKLR
-0.12963570025082222
l1
LDLR
0.5434136986989314
SORT1
0.517112310908757
F9
0.5800116641738409
PKLR
0.6820189345508577
l2
LDLR
0.5521511442917562
SORT1
0.5469994406536017
F9
0.5494278051521219
PKLR
0.6519684223759711


In [3]:

# Mean Pearson r of the cosine score over LDLR, SORT1 and F9, computed from
# perf rather than hand-copied numbers so it cannot go stale. perf[:3] holds
# exactly those three results: cosine is the first score key and LDLR, SORT1,
# F9 are the first experiments in the evaluation loop. PKLR is left out, as in
# the original hand calculation — TODO record why.
float(sum(perf[:3]) / 3)

-0.5446656014871346

## Embeddings for the downstream Lenti model

In [3]:
# Shorter one-hot ref/alt sequences for the CAGI variants (datalen '230'),
# opened read-only from HDF5; they are zero-padded before being embedded.
datalen = '230'
file = h5py.File(f"../data/CAGI/{datalen}/CAGI_onehot.h5", "r")
ref = file['ref']
alt = file['alt']

In [4]:
batch_size = 128
SEI_INPUT_LEN = 4096  # input length used for Sei in the zero-shot section


def _embed_padded_sequences(seqs, target_len=SEI_INPUT_LEN, batch_size=batch_size):
    """Zero-pad short one-hot sequences to `target_len`, run Sei, and return the
    model.spline_tr outputs (one row per input sequence).

    Each batch gets swapaxes(1, 2), as in the zero-shot section, and is then
    padded along its LAST axis. The pad width is computed from that same axis.
    The previous version used alt.shape[-1], i.e. the axis *before* the swap.
    With data stored (sample, position, channel), as the swap implies, that is
    the 4-wide channel axis, so inputs came out longer than 4096 (e.g. 4322).
    The pad is split floor/ceil, with the odd extra position on the right.

    Raises ValueError if a sequence is already longer than `target_len`.
    """
    extractor = embed_extractor()
    handle = model.spline_tr.register_forward_hook(extractor.get_activation('s_out'))
    try:
        for start in tqdm(range(0, len(seqs), batch_size)):
            batch = np.swapaxes(seqs[start:start + batch_size], 1, 2).astype('float32')
            pad_total = target_len - batch.shape[-1]
            if pad_total < 0:
                raise ValueError(
                    f'sequence length {batch.shape[-1]} exceeds target length {target_len}')
            pad_left = pad_total // 2
            padded = np.pad(batch, ((0, 0), (0, 0), (pad_left, pad_total - pad_left)))
            with torch.no_grad():
                model(torch.from_numpy(padded).to('cuda'))
    finally:
        # Detach the hook so it stops collecting outputs from later forward passes.
        handle.remove()
    return np.array(extractor.activation.get('s_out', []))


alt_out = _embed_padded_sequences(alt)
ref_out = _embed_padded_sequences(ref)
# Downstream code pairs ref/alt rows, so the two must line up.
assert alt_out.shape == ref_out.shape, (alt_out.shape, ref_out.shape)

  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 145/145 [00:24<00:00,  5.87it/s]
100%|██████████| 145/145 [00:24<00:00,  5.83it/s]


In [8]:
# Save the ref/alt Sei embeddings for the downstream model. The with-block
# closes the file even if a write fails. ref_out/alt_out are already numpy
# arrays, so they are written directly rather than through another full copy.
import os

embed_dir = '../data/CAGI/230_embed/'
os.makedirs(embed_dir, exist_ok=True)  # lets a fresh checkout run end to end
with h5py.File(embed_dir + 'sei.h5', 'w') as output:
    output.create_dataset('ref', data=ref_out)
    output.create_dataset('alt', data=alt_out)