In [8]:
import numpy as np
import pandas as pd
import pickle
import math
import sys
sys.path.append('..')

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

import torch
import torch.nn as nn
from torch.functional import F

import torch_geometric
from torch_geometric.data import Data, DataLoader, Dataset

import networkx as nx
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm.notebook import tqdm

from models.graph_transformer.euclidean_graph_transformer import GraphTransformerEncoder
from models.graph_transformer.autoencoder_base import DeepSNEM, LinearDecoder, FermiDiracDecoder
from utils.data_gen import SNLDataset, load_prot_embs, wcsv2graph, load_prot_embs_go

import re
import gc

from captum.attr import IntegratedGradients,InterpretableEmbeddingBase
from captum.attr import visualization, configure_interpretable_embedding_layer,remove_interpretable_embedding_layer

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA; every later `.to(dev)` follows this choice.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if dev.type == 'cuda':
    # Release cached allocator blocks left over from earlier runs in this kernel.
    torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
# Last expression: displays whether CUDA is actually available.
torch.cuda.is_available()

True

In [9]:
# Map every protein name listed in the CSV to its row position.
unique_prots = 'data/prot_embeddings/new_features/proteins.csv'
unique_df = pd.read_csv(unique_prots)
global_dict = {
    prot: idx
    for idx, prot in enumerate(unique_df.proteins.to_numpy())
}

In [10]:
# Read the index of weighted files and pick one entry as the working sample.
weighted_fnames = '../snac_data/file_info_weighted.csv'
w_fnames = pd.read_csv(weighted_fnames)
w_path_list = w_fnames['files_weighted'].to_numpy()
wsample = w_path_list[2]  # third path in the index

In [11]:
# Convert the sample file into a graph object, using `global_dict` for the
# protein -> index lookup.
# NOTE(review): the third argument looks like a one-hot label for this sample
# (the displayed object reports y=[3]) -- confirm against wcsv2graph.
data = wcsv2graph(wsample, global_dict, [0,0,1])
# Bare expression so the notebook displays the resulting object's fields.
data

Data(acts=[68, 2], edge_index=[2, 83], global_idx=[68], label=[1], neg_childs=[68], pos_childs=[68], seq_mat=[68, 68], sign=[83, 2], weight=[83], y=[3])

In [12]:
SIZE = 512
EMB_DIM = 512
prot_embs = load_prot_embs_go(SIZE, norm=False)


def summarizer(z, *args, **kwargs):
    """Reduce `z` by averaging over its first dimension; extra arguments are ignored."""
    return z.mean(dim=0)


# Assemble the encoder/decoder pair and place every module on `dev`.
encoder = GraphTransformerEncoder(n_layers=1, n_heads=4, n_hid=EMB_DIM, pretrained_weights=prot_embs[0],
                                  summarizer=summarizer).to(dev)
decoder = FermiDiracDecoder(1.0).to(dev)
autoenc = DeepSNEM(encoder, decoder).to(dev)
# `map_location=dev` makes the load independent of the device the checkpoint
# was saved from (e.g. another GPU index, or a CPU-only machine).
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
# Last expression: displays the key-matching result of load_state_dict.
autoenc.load_state_dict(torch.load('embeddings/autoencoder_graph/gt_512_tl_1_lp.pt', map_location=dev))

<All keys matched successfully>

In [19]:
# Display the weight parameter of the encoder's embedding layer.
autoenc.encoder.emb_layer.weight

Parameter containing:
tensor([[-0.5638,  1.7662, -1.2303,  ...,  0.3325, -0.2374,  0.0486],
        [-0.9060,  1.8102,  1.4526,  ..., -0.1316,  0.1973,  0.2973],
        [-0.9025, -0.6047, -1.1277,  ..., -0.1601,  0.2491,  0.1278],
        ...,
        [-0.7644, -0.5591, -1.1038,  ...,  0.2015,  0.2613,  0.2287],
        [-0.7576, -0.5490,  0.8497,  ..., -0.3825, -0.0021,  0.1408],
        [-0.7599, -0.5567, -1.3666,  ..., -0.0693,  0.3519,  0.0655]],
       device='cuda:0', requires_grad=True)

# Important Features

In [18]:
# Hook Captum into the layer at 'encoder.emb_layer'; the returned handle is
# used below to turn node indices into embeddings.
# NOTE(review): Captum expects remove_interpretable_embedding_layer (imported
# above) to be called once attribution is done -- confirm it happens later.
interpretable_embedding = configure_interpretable_embedding_layer(autoenc, 'encoder.emb_layer')

In [24]:
# Integrated Gradients attribution over the full autoencoder.
ig = IntegratedGradients(autoenc)
# Turn the sample's global node indices into embeddings. The indices are moved
# with `.to(dev)` rather than a hard-coded `.cuda()` so this cell follows the
# same device as the model.
input_embedding = interpretable_embedding.indices_to_embeddings(data.global_idx.to(dev))