# 0. Import dependencies 

In [21]:
# PyTorch
import torch
import torch_scatter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import global_mean_pool, global_max_pool, global_add_pool
import torch.nn.functional as F

# Pytorch geometric explainer
from torch_geometric.explain import *

import pickle
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import random
import os

# cust_functions folder
from cust_functions.training import *
from cust_functions.graph_networks import *
from cust_functions.graph_creation import *
from cust_functions.explain_helper import *

# -1: Helper functions, to be exported later on

The explainer outputs a node mask: a matrix of shape (num_nodes, num_features).
Feature importance is calculated as the sum of a feature's mask values across all nodes.
Node importance is calculated as the sum of a node's mask values across all features (the model weights may also have some influence; this is still unclear!).

In [61]:
def importance_calculator(explanation: Explanation, node: bool, k_node: int, feature: bool, k_feature: int, input_data_preprocessed, pathways, translation):
    """Rank the most important pathway nodes and protein features of an explanation.

    The explanation's ``node_mask`` is read as a (num_nodes, num_features)
    matrix. A node's score is the sum of its row; a feature's score is the sum
    of its column.

    Args:
        explanation: Explanation object exposing a ``node_mask`` tensor.
        node: If truthy, compute the ``k_node`` highest-scoring nodes.
        k_node: Number of top nodes to keep.
        feature: If truthy, compute the ``k_feature`` highest-scoring features.
        k_feature: Number of top features to keep.
        input_data_preprocessed: DataFrame with a ``Protein`` column; a feature
            index is matched to the row with the same index label.
            NOTE(review): assumes row order equals feature order - confirm
            against the graph-creation code.
        pathways: Forwarded to ``create_pathway_graph``.
        translation: Forwarded to ``create_pathway_graph``.

    Returns:
        Tuple ``(top_nodes, top_features)``. ``top_nodes`` has the columns
        ``Node_score`` and ``Pathway``; ``top_features`` has ``Feature_score``
        and ``Protein``. Both are indexed by node / feature position and sorted
        by descending score. A ranking that was not requested is returned as an
        empty DataFrame with the same columns.

    Raises:
        ValueError: If the explanation carries no node mask.
    """
    node_mask = explanation.node_mask
    if node_mask is None:
        raise ValueError("explanation.node_mask is None; cannot compute importances")

    # detach()/cpu() make the conversion safe for masks that require grad or
    # live on a GPU; for plain CPU tensors this is a no-op.
    mask = pd.DataFrame(node_mask.detach().cpu().numpy())

    # Real (empty) frames rather than the pd.DataFrame class, so a ranking that
    # is switched off still yields a usable return value.
    top_nodes = pd.DataFrame(columns=['Node_score', 'Pathway'])
    top_features = pd.DataFrame(columns=['Feature_score', 'Protein'])

    if node:
        # Per node: sum horizontally across all features
        node_scores = mask.sum(axis=1).to_frame('Node_score')
        top_nodes = node_scores.nlargest(k_node, 'Node_score')

        # Connect node indices to their pathway name.
        # NOTE(review): assumes list(graph.nodes()) has the same order as the
        # node indexing of the explained graphs - confirm in graph_creation.
        sample_graph = create_pathway_graph(pathways, translation, descendants=True)
        index_node_match = pd.DataFrame({'Pathway': list(sample_graph.nodes())})
        top_nodes = pd.merge(top_nodes, index_node_match, left_index=True, right_index=True, how='inner')

    if feature:
        # Per feature: sum vertically across all nodes. The result is kept in
        # its own frame indexed by feature position; assigning it as a column
        # of the node-indexed mask would align on the node index and drop
        # features whenever num_nodes < num_features.
        feature_scores = mask.sum(axis=0).to_frame('Feature_score')
        top_features = feature_scores.nlargest(k_feature, 'Feature_score')

        # Connect feature indices to their protein name
        index_protein_match = input_data_preprocessed[['Protein']]
        top_features = pd.merge(top_features, index_protein_match, left_index=True, right_index=True, how='inner')

    return top_nodes, top_features

# 1. Import data and preprocess it 

- We need a list of graphs where each graph represents one patient
- Proteins are encoded using UniProt names
- Top protein: RESGCN_fold_1_rocauc_0.87, on training data

In [23]:
# Initialize pathway graph
translation = pd.read_csv("aki_data/translation.tsv", sep="\t", index_col=0)
pathways = pd.read_csv("aki_data/pathways.tsv", sep="\t")
G = create_pathway_graph(pathways, translation, descendants=True)

# Load AKI disease data
input_data = pd.read_csv("aki_data/test_data.tsv", sep="\t")
input_data_qm = pd.read_csv("aki_data/test_qm.csv")
design_matrix = pd.read_csv("aki_data/design_matrix.tsv", sep="\t")

# Preprocess input data: missing values become 0
input_data_preprocessed = input_data_qm.fillna(0)
# Recode 1 -> 0 and 2 -> 1 across the whole design matrix. The order matters:
# swapping the two replacements would collapse every value to 0.
design_matrix = design_matrix.replace(1, 0)
design_matrix = design_matrix.replace(2, 1)

# Split data into train and test: columns whose name contains "M2012" form the
# training set, all other sample columns the test set. The "Protein" column is
# kept in both.
is_train_column = input_data_preprocessed.columns.str.contains("M2012")
is_protein_column = input_data_preprocessed.columns.str.contains("Protein")
X_train = input_data_preprocessed.loc[:, is_train_column | is_protein_column]
X_test = input_data_preprocessed.loc[:, ~is_train_column | is_protein_column]

is_train_sample = design_matrix['sample'].str.contains("M2012")
y_train = design_matrix[is_train_sample]
y_test = design_matrix[~is_train_sample]


# Load/Create graph data per patient

# Directory holding the pickled per-patient graphs. The default is the original
# hard-coded location; set AKI_GRAPH_DATA_DIR to run on another machine.
GRAPH_DATA_DIR = os.environ.get("AKI_GRAPH_DATA_DIR", "/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files")
train_graph_path = os.path.join(GRAPH_DATA_DIR, "AKI_train_graph_data.pkl")
test_graph_path = os.path.join(GRAPH_DATA_DIR, "AKI_test_graph_data.pkl")

# Load a cached file when it exists, otherwise build the graphs and save them.
# The flags are derived per file: previously a missing *test* file set
# save_train, so the test graphs were never cached.
load_train = os.path.exists(train_graph_path)
save_train = not load_train
load_test = os.path.exists(test_graph_path)
save_test = not load_test

# NOTE(review): presumably load_data reads the pickle at `path` and save_data
# writes it - confirm in cust_functions.
train_graph_data = pytorch_graphdata(y_train, X_train, G, gen_column = 'Protein', load_data = load_train, save_data = save_train, path = train_graph_path)
test_graph_data = pytorch_graphdata(y_test, X_test, G, gen_column = 'Protein', load_data = load_test, save_data = save_test, path = test_graph_path)

print(f"Number of training graphs: {len(train_graph_data)}")
print(f"Number of test graphs: {len(test_graph_data)}")
print(f"Number of features: {train_graph_data[0].num_features}")
# Prints the distinct class labels found in the training graphs
print(f"Number of classes: {np.unique([graph.y.detach().numpy()[0] for graph in train_graph_data])}")
print(f"Is directed: {train_graph_data[0].is_directed()}")
print(train_graph_data[0])

Number of training graphs: 141
Number of test graphs: 56
Number of features: 554
Number of classes: [0 1]
Is directed: True
Data(x=[2585, 554], edge_index=[2, 2603], y=[1])


# 2. Explain ResGCN

In [24]:
## Define the ResGCN model without batch processing, because batched inputs cause problems with the PyTorch Geometric Explainer

class ResGCN(torch.nn.Module):
    """Residual GCN graph classifier operating on a single, un-batched graph.

    Architecture: one initial ``GCNBlock``, one residual ``GCNBlock`` per
    additional entry in ``layer_configs``, a global max pool over the nodes,
    and a two-layer MLP head.

    Args:
        num_features: Input feature dimension of each node.
        layer_configs: Sequence of dicts. The first entry supplies
            ``out_channels``, ``dropout_rate`` and ``batch_norm`` for the
            initial block (its input size is ``num_features``); every further
            entry additionally supplies ``in_channels`` for a residual block.
        num_classes: Number of output logits.
    """

    def __init__(self, num_features, layer_configs, num_classes):
        super().__init__()

        # Input block: maps the raw node features to the first hidden width
        first_cfg = layer_configs[0]
        self.initial = GCNBlock(
            num_features,
            first_cfg['out_channels'],
            first_cfg['dropout_rate'],
            first_cfg['batch_norm'],
        )

        # Residual blocks, one per remaining configuration entry
        self.hidden_layers = torch.nn.ModuleList([
            GCNBlock(
                cfg['in_channels'],
                cfg['out_channels'],
                cfg['dropout_rate'],
                cfg['batch_norm'],
                residual=True,
            )
            for cfg in layer_configs[1:]
        ])

        # Classification head fed by the pooled embedding of the last block
        pooled_dim = layer_configs[-1]['out_channels']
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(pooled_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x, edge_index):
        """Return the class logits for the graph given by ``x`` and ``edge_index``."""
        hidden = self.initial(x, edge_index)
        for block in self.hidden_layers:
            hidden = block(hidden, edge_index)
        # No batch vector is passed, so the pool runs over all nodes at once
        graph_embedding = global_max_pool(hidden, batch = None)
        return self.mlp(graph_embedding)

In [77]:
## Load trained model

# Use a GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Layer configuration as used in ResGCN training
layer_configs = [
    dict(in_channels=32, out_channels=32, dropout_rate=0.6, batch_norm=True, residual=True),
]

exp_ResGCN_model_path = "trained_models/Full_model_ResGCN_fold_2_rocauc_0.91.pt"

# Rebuild the architecture first, then restore the trained weights into it
num_node_features = train_graph_data[0].num_features
exp_ResGCN_model = ResGCN(num_node_features, layer_configs, 2).to(device)

# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source
trained_state = torch.load(exp_ResGCN_model_path, map_location=device)
exp_ResGCN_model.load_state_dict(trained_state)

# Switch to inference mode; as the last expression this also displays the module summary
exp_ResGCN_model.eval()

Using device: cpu


ResGCN(
  (initial): GCNBlock(
    (conv): GCNConv(554, 32)
    (dropout): Dropout(p=0.6, inplace=False)
    (bn): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (hidden_layers): ModuleList()
  (mlp): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=64, out_features=2, bias=True)
  )
)

In [78]:
# Create explanation object with model and data
# NOTE(review): explain_function comes from one of the cust_functions wildcard imports
# (presumably explain_helper); its return type is not visible here. The stored output
# of this cell is a KeyboardInterrupt, so the call may be long-running - verify that
# it completes on a fresh kernel before relying on exp_ResGCN below.
exp_ResGCN = explain_function(exp_ResGCN_model, train_graph_data)

KeyboardInterrupt: 

In [74]:
# Rank the 20 highest-scoring pathway nodes and the 20 highest-scoring protein features
top_nodes, top_features = importance_calculator(
    exp_ResGCN,
    node=True,
    k_node=20,
    feature=True,
    k_feature=20,
    input_data_preprocessed=input_data_preprocessed,
    pathways=pathways,
    translation=translation,
)

In [75]:
# Display up to 20 of the top-ranked nodes (Node_score plus the matched Pathway identifier)
top_nodes.head(20)

Unnamed: 0,Node_score,Pathway
500,24.386959,R-HSA-977606
1247,17.851124,R-HSA-381426
543,17.268383,R-HSA-6798695
2069,11.15559,R-HSA-114608
826,8.977059,R-HSA-210993
505,6.402675,R-HSA-173623
541,4.596302,R-HSA-2454202
918,3.898797,R-HSA-975634
498,3.392097,R-HSA-166665
255,3.284095,R-HSA-3000178


In [76]:
# Display the top-ranked features (Feature_score plus the matched Protein identifier)
top_features

Unnamed: 0,Feature_score,Protein
73,2.522169,P00734
433,2.414155,P01764
438,2.391601,P01614
387,2.230954,P01742
444,1.829935,P04433
93,1.736243,P02679
405,1.727084,P04211
474,1.715961,P01766
121,1.691966,P01714
344,1.653005,P80748


In [None]:
# Ideas: check whether the results now make sense
# Otherwise: try to use the explanation subgraph and look up which pathways belong to the node indices

In [28]:
# Deprecated: Do not run
# The stored output below was produced with the older ResGCN_fold_1 (ROC AUC 0.87) model;
# re-running this cell would overwrite it with the current top_features.
top_features # Deprecated: ResGCN_fold_1_0.87ROC

Unnamed: 0,Feature_score,Protein
73,2.507029,P00734
433,2.487755,P01764
438,2.434342,P01614
387,2.282706,P01742
444,1.817743,P04433
405,1.760881,P04211
474,1.749604,P01766
93,1.703529,P02679
121,1.692959,P01714
87,1.659538,P01024


In [29]:
# NOTE(review): the stored output below has no Pathway column and an older execution
# count than the cells above, so it appears to predate the pathway merge - confirm
# before re-running.
top_nodes

Unnamed: 0,Node_score
500,24.216917
1247,17.844151
543,17.276419
2069,11.084801
826,8.945814
505,5.669091
541,4.816395
504,4.152472
918,3.893676
498,3.321473
