# Attention weights

In [1]:
%load_ext autoreload
%reload_ext autoreload

In [None]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [2]:
import sys
import os
sys.path.append('/workspace/src')
from src.pair_graphs import InferenceDataset, collate_batch
from src.evaluation import Evaluation, overlaps, overlap_3, disease_dicts, disease_subtype_dict
import pandas as pd
import umap
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sns
import numpy as np
from rdkit import Chem, DataStructs
from scipy import stats
from scipy.spatial.distance import pdist,squareform
from src.train_drp import MultimodalAttentionNet, Conf

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


* Load all models

In [3]:
from src.create_graphs_ppi import PPIGraphsDRP

In [4]:
data_dir = None
project_dir = Path('/workspace')

In [5]:
conf = Conf(
    lr = 1e-3,
    batch_size = 32,
    epochs = 300,
    reduce_lr = True,
    ppi_depth = 3,
    mat_depth = 4,
    mat_heads = 4,
).to_hparams()

model = MultimodalAttentionNet(conf, data_dir=None, 
    mat_depth=4, mat_heads=4, ppi_depth=3)

Global seed set to 42


In [6]:
drp = Evaluation('/workspace', dataset='NCI60DRP', split='random', 
                 ppi_depth=3, seed=42, mat_depth=4, mat_heads=4,
                ckpt_path='workspace/models/NCI60DRP_random_42/'+
                '1680458624/checkpoint/epoch=212-step=1957470.ckpt',
                model = model)

Lightning automatically upgraded your loaded checkpoint from v1.9.4 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file workspace/models/NCI60DRP_random_42/1680458624/checkpoint/epoch=212-step=1957470.ckpt`
Global seed set to 42


In [7]:
attention = drp.attention_links()

100%|██████████| 11/11 [00:16<00:00,  1.51s/it]


In [8]:
attention[attention['cell'] == 'CVCL_0292'] 

Unnamed: 0,attention,protein_1,protein_2,cell,p1_name,p2_name
0,0.145695,2960,205,CVCL_0292,GBP6,AEN
1,0.574381,2960,4455,CVCL_0292,GBP6,MED15
2,0.180436,2960,5346,CVCL_0292,GBP6,PARP11
3,0.018591,2960,2819,CVCL_0292,GBP6,FOXO1
4,0.012227,2960,2493,CVCL_0292,GBP6,EVI5L
...,...,...,...,...,...,...
102208,0.041306,5426,5426,CVCL_0292,PCDHGB2,PCDHGB2
102209,0.402354,5427,5427,CVCL_0292,PCDHGB3,PCDHGB3
102210,0.087990,5428,5428,CVCL_0292,PCDHGB4,PCDHGB4
102211,0.018921,5429,5429,CVCL_0292,PCDHGC3,PCDHGC3


In [9]:
# extract rows that have cell 'CVCL_0292' and p1_name 'NDUFAF7'
attention[attention['cell'] == 'CVCL_0292'] 
attention[attention['p2_name'] == 'NDUFAF7'].sum()


attention                                                  1.0
protein_1                                                40118
protein_2                                                48960
cell         CVCL_0292CVCL_0292CVCL_0292CVCL_0292CVCL_0292C...
p1_name      NDUFB10M1APNDUFB4MTSS1DVL3DYNC1H1NDUFAF4DYMNDU...
p2_name      NDUFAF7NDUFAF7NDUFAF7NDUFAF7NDUFAF7NDUFAF7NDUF...
dtype: object

In [None]:
attention.to_csv('/workspace/data/processed/NCI60DRP_random_self_att/attentionDRP.csv')

In [None]:
def get_cell_df(cell):
    expression = pd.read_pickle(project_dir / 'data/processed/NCI60DRP_random/cell_features_drp.pkl')
    ppi_links_cell = pd.read_pickle(project_dir / 'data/processed/NCI60DRP_random/ppi_links_drp.pkl')
    print("Getting {} ", cell)
    data_list = []
    ppi=pd.DataFrame([])
    idx = expression['RRID'] == cell
    cell_expression = expression.loc[idx]
    oe_gene = OrdinalEncoder()
    if self.self_att == '_self_att':               
        idx_2 = (ppi_links_cell['protein_1'].isin(
            cell_expression['gene'])) & (
            ppi_links_cell['protein_2'].isin(
                cell_expression['gene']))
                
        pi1 = ppi_links_cell[['protein_1', 'protein_2']].loc[idx_2]
        cell_expression = cell_expression.loc[
            cell_expression['gene'].isin(pi1['protein_1'].unique())]
                
        oe_gene.fit(cell_expression['gene'].unique().reshape(-1,1))
        oe_length = cell_expression['gene'].unique().shape[0]
        oe_l = oe_gene.transform(
            cell_expression['gene'].unique().reshape(-1,1))
                
        gene_l = cell_expression['gene'].unique().reshape(-1,1)
        cell_expression['cell_gene_ordinal'] = oe_gene.transform(
            cell_expression['gene'].values.reshape(-1,1))

Creating PPI graphs ['CVCL_0292' 'CVCL_1331' 'CVCL_1779' 'CVCL_1304' 'CVCL_1690' 'CVCL_1092'
 'CVCL_0021' 'CVCL_0062' 'CVCL_1195' 'CVCL_1051' 'CVCL_0004']
