In [1]:
import numpy as np
import pandas as pd
import pickle
import math
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.functional import F

import torch_geometric
from torch_geometric.nn import Set2Set
from torch_geometric.data import Data, DataLoader, Dataset
from torch_geometric.utils import to_networkx, from_networkx

import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.preprocessing import OneHotEncoder, MinMaxScaler
from tqdm.notebook import tqdm

from deepSNEM.models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder
from models.deep_graph_infomax.infomax import SNInfomax
from utils.data_gen import load_prot_embs, load_prot_embs_go, ucsv2graph_infomax, SNDatasetInfomax

import re
import gc

from captum.attr import IntegratedGradients, DeepLift, LayerIntegratedGradients, Saliency
from captum.attr import configure_interpretable_embedding_layer,remove_interpretable_embedding_layer

dev = torch.device('cuda')
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
torch.cuda.is_available()

True

In [None]:
prot_embs, global_dict = load_prot_embs_go(512, norm=False)

def load_infomax_model(emb_type = 'seqveq'):
    if emb_type=='GO':
        trained_model_path = 'embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_GO_uniform_un.pt'
    elif emb_type=='seqveq':
        trained_model_path = 'embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_seq_uniform_un.pt'
    elif emb_type=='random':
        trained_model_path = 'embeddings/deep_graph_infomax/unsupervised/DGI_JSD_512_random_uniform_un.pt'
    else:
        raise AttributeError

    summarizer = summarizer = Set2Set(512, 3)
    encoder = GraphTransformerEncoder(n_layers=1, n_heads=4, n_hid=512, pretrained_weights=prot_embs, 
                                    summarizer=None).to(dev)
    model = SNInfomax(hidden_channels=512, encoder=encoder,
                                        summary=summarizer, semi=True).to(dev)
    model.load_state_dict(torch.load(trained_model_path))
    
    return model

def drug_group(drug_name ,smoothing=0.1):
    full_dataset = pd.read_csv('data/graph_info_df/full_dataset.csv')
    full_dataset = full_dataset[full_dataset.pert_iname == drug_name]
    #full_dataset = full_dataset[full_dataset.moa_v1 != 'Unknown']
    
    u_path_list = full_dataset.files_combined.values
    all_moas = full_dataset.moa_v1.values
    moa_v1 = full_dataset.moa_v1.values
    labels = full_dataset.sigs_g.to_numpy().reshape(-1,1)

    oh = OneHotEncoder()
    labels = oh.fit_transform(labels).toarray()
    dataset = SNDatasetInfomax(u_path_list, global_dict, labels)
    return full_dataset, dataset

def moa_group(moa):
    df = pd.read_csv('data/graph_info_df/kris.csv')
    full_dataset = df[df.moa_v1 == moa].groupby('sig_id').apply(lambda df: df.sample(1))
    
    u_path_list = full_dataset.files_combined.values
    all_moas = full_dataset.moa_v1.values
    moa_v1 = full_dataset.moa_v1.values

    dataset = SNDatasetInfomax(u_path_list, global_dict)
    return full_dataset, dataset

def custom_forward(x, data):
    act = model.encoder.pe(data)
    x = torch.add(x, act)
    for t in model.encoder.transformers:
        x = t(x, data)
    
    summary = model.summary(x, data.batch)
    
    return summary

def cni_all_models(model, dataset, topk=5):
    sal = Saliency(custom_forward)
    
    data_loader = DataLoader(dataset, batch_size=1, num_workers=12, shuffle=False)
        
    important_nodes = []
    attribs = []
    
    for tb in tqdm(data_loader):
        tb = tb.to(dev)
        for tar in range(1024):
            model.zero_grad()
            x = model.encoder.emb_layer(tb.global_idx)
            attributions = sal.attribute(x, additional_forward_args=(tb), target=tar, abs=True)
            attributions_a = attributions.sum(1).cpu().sort(descending=True)[0].numpy().reshape(-1,1)
            
            minmax = MinMaxScaler()
            attributions_a = minmax.fit_transform(attributions_a)
            attribs.append(attributions_a.reshape(-1))
            
            att_keys = attributions.sum(1).cpu().sort(descending=True)[1]
            important_nodes.append(tb.global_idx[att_keys].cpu().numpy())
            
    return np.array(important_nodes), np.array(attribs)


# Important Features

In [65]:
full, dataset = moa_group('HSP inhibitor')

In [66]:
model = load_infomax_model('seqveq')

In [None]:
#all_nodes, all_attrs = cni_all_models(model, dataset, topk=5)

In [10]:
reverse_global_dict = {v:k for k,v in global_dict.items()}

In [11]:
all_nodes = np.concatenate(all_nodes).ravel()
all_attrs = np.concatenate(all_attrs).ravel()

NameError: name 'all_nodes' is not defined

In [72]:
nodes = np.array([reverse_global_dict[key] for key in all_nodes.reshape(-1)])
attribs = all_attrs.reshape(-1)
df = pd.DataFrame({'N':nodes,'A':attribs})

In [73]:
df_filt = df.groupby('N').sum().reset_index().sort_values(by='A', ascending=False)

In [74]:
df1 = pd.merge(df, df_filt, on='N')
order = df1.groupby('N').mean().sort_values(by='A_x', ascending=False).reset_index().N.values
df2 = df1.groupby('N').mean().sort_values(by='A_x', ascending=False).reset_index().drop('A_y', axis=1)

In [75]:
random_hek = df2.copy()

In [77]:
df3 = pd.concat([seqveq_hek, go_hek, random_hek], axis=1)
df3.columns = ['N_s', 'A_s', 'N_G', 'A_G', 'N_r', 'A_r']

In [79]:
df3_new = df3.iloc[:7]

# Vis Attention Heads

In [81]:
from networkx.drawing.nx_agraph import graphviz_layout
sns.set_style('white')

In [110]:
sample = pd.read_csv('../snac_data/' + df_atp.files_combined.values[5])
G = nx.from_pandas_edgelist(sample, source='node1', target='node2', 
                            edge_attr=['sign'], create_using=nx.DiGraph())

In [112]:
colors = np.nan_to_num(np.array(colors, dtype=np.float))