# 0. Import dependencies 

In [1]:
# PyTorch
import torch
import torch_scatter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import global_mean_pool, global_max_pool, global_add_pool
import torch.nn.functional as F

# Pytorch geometric explainer
from torch_geometric.explain import *

import pickle
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import random
import os

# cust_functions folder
from cust_functions.training import *
from cust_functions.graph_networks import *
from cust_functions.graph_creation import *
from cust_functions.explain_helper import *

# Set up device: use CUDA when it is available, otherwise fall back to the CPU
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Using device: {device}")


Using device: cpu


# 1. Import data and preprocess it 

- We need a list of graphs, where each graph represents one patient.
- Proteins are encoded by their UniProt accession IDs.

## 1.1 AKI data

In [2]:
# Initialize pathway graph
translation = pd.read_csv("aki_data/translation.tsv", sep="\t", index_col=0)
pathways = pd.read_csv("aki_data/pathways.tsv", sep="\t")
G = create_pathway_graph(pathways, translation, descendants=True)

# Load AKI disease data
# NOTE(review): aki_input_data is not used in this cell; only the *_qm table feeds the graphs.
aki_input_data = pd.read_csv("aki_data/test_data.tsv", sep="\t")
aki_input_data_qm = pd.read_csv("aki_data/test_qm.csv")
aki_design_matrix = pd.read_csv("aki_data/design_matrix.tsv", sep="\t")

# Preprocess input data: missing values become 0
aki_input_data_preprocessed = aki_input_data_qm.fillna(0)
# Recode values 1 -> 0, then 2 -> 1. The order matters: doing 2 -> 1 first would
# let the second replace turn those new 1s into 0s.
# NOTE(review): DataFrame.replace acts on every column, not only the label column.
aki_design_matrix = aki_design_matrix.replace(1, 0)
aki_design_matrix = aki_design_matrix.replace(2, 1)

# Split data into train and test: columns whose name contains "M2012" go to train,
# every other column goes to test; the "Protein" column is kept in both.
is_m2012_col = aki_input_data_preprocessed.columns.str.contains("M2012")
is_protein_col = aki_input_data_preprocessed.columns.str.contains("Protein")
aki_X_train = aki_input_data_preprocessed.loc[:, is_m2012_col | is_protein_col]
aki_X_test = aki_input_data_preprocessed.loc[:, ~is_m2012_col | is_protein_col]

# The same "M2012" rule applied to the design-matrix rows keeps labels aligned with X
is_m2012_sample = aki_design_matrix['sample'].str.contains("M2012")
aki_y_train = aki_design_matrix[is_m2012_sample]
aki_y_test = aki_design_matrix[~is_m2012_sample]


# Load/Create graph data per patient. If a cached pickle exists, ask
# pytorch_graphdata to load it; otherwise ask it to build and save to that path.
graph_data_dir = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files'
aki_train_graph_path = os.path.join(graph_data_dir, 'AKI_train_graph_data.pkl')
aki_test_graph_path = os.path.join(graph_data_dir, 'AKI_test_graph_data.pkl')

# Each split gets its own load/save flag pair, so a missing test cache
# triggers save_test (and never touches the train flags).
load_train = os.path.exists(aki_train_graph_path)
save_train = not load_train
load_test = os.path.exists(aki_test_graph_path)
save_test = not load_test

aki_train_graph_data = pytorch_graphdata(aki_y_train, aki_X_train, G, gen_column = 'Protein', load_data = load_train, save_data = save_train, path = aki_train_graph_path)
aki_test_graph_data = pytorch_graphdata(aki_y_test, aki_X_test, G, gen_column = 'Protein', load_data = load_test, save_data = save_test, path = aki_test_graph_path)

print(f"Number of training graphs: {len(aki_train_graph_data)}")
print(f"Number of test graphs: {len(aki_test_graph_data)}")
print(f"Number of features: {aki_train_graph_data[0].num_features}")
# Prints the distinct label values; .item() works for CPU and CUDA tensors alike
print(f"Number of classes: {np.unique([graph.y[0].item() for graph in aki_train_graph_data])}")
print(f"Is directed: {aki_train_graph_data[0].is_directed()}")
print(aki_train_graph_data[0])

# Later needed in this format
aki_structural_data = [aki_input_data_preprocessed, pathways, translation]


Number of training graphs: 141
Number of test graphs: 56
Number of features: 554
Number of classes: [0 1]
Is directed: True
Data(x=[2585, 554], edge_index=[2, 2603], y=[1])


## 1.2 Covid data

In [3]:
#tbd

________________________________________________________________________________
________________________________________________________________________________


# 2. Explain ResGCN on AKI data

In [4]:
## 
# Define the ResGCN model without batch processing, because batching causes problems with the PyTorch Geometric Explainer
##

class explanain_ResGCN(torch.nn.Module):
    """Residual GCN graph classifier that operates on one unbatched graph.

    An initial ``GCNBlock`` is followed by residual ``GCNBlock`` hidden layers.
    All node embeddings are then max-pooled into a single graph embedding,
    which a configurable MLP head maps to class scores.

    Args:
        num_features: Number of input node features (input size of the
            initial block).
        layer_configs: One dict per GCN layer. The first entry supplies
            ``out_channels``, ``dropout_rate`` and ``batch_norm`` for the
            initial block; its ``in_channels`` is not read. Every later entry
            also supplies ``in_channels`` and is built with ``residual=True``
            (any ``residual`` key in the dict is not read).
        mlp_config: One dict per hidden MLP layer. Each needs ``out_channels``
            and may set ``batch_norm`` (default False), ``relu``
            (default True) and ``dropout_rate`` (Dropout is added only when
            the key is present).
        num_classes: Width of the final Linear layer.
    """

    def __init__(self, num_features, layer_configs, mlp_config, num_classes):
        super().__init__()

        first = layer_configs[0]
        self.initial = GCNBlock(num_features, first['out_channels'], first['dropout_rate'], first['batch_norm'])

        self.hidden_layers = torch.nn.ModuleList(
            GCNBlock(cfg['in_channels'], cfg['out_channels'], cfg['dropout_rate'], cfg['batch_norm'], residual=True)
            for cfg in layer_configs[1:]
        )

        # MLP head: per entry Linear [-> BatchNorm1d] [-> ReLU] [-> Dropout],
        # followed by a final Linear projecting to num_classes.
        head = []
        width = layer_configs[-1]['out_channels']
        for spec in mlp_config:
            out_width = spec['out_channels']
            head.append(torch.nn.Linear(width, out_width))
            if spec.get('batch_norm', False):
                head.append(torch.nn.BatchNorm1d(out_width))
            if spec.get('relu', True):
                head.append(torch.nn.ReLU())
            if 'dropout_rate' in spec:
                head.append(torch.nn.Dropout(spec['dropout_rate']))
            width = out_width
        head.append(torch.nn.Linear(width, num_classes))
        self.mlp = torch.nn.Sequential(*head)

    def forward(self, x, edge_index):
        """Return class scores for a single graph.

        ``batch=None`` makes the max pooling treat every node as part of
        the same graph.
        """
        h = self.initial(x, edge_index)
        for block in self.hidden_layers:
            h = block(h, edge_index)
        pooled = global_max_pool(h, batch=None)
        return self.mlp(pooled)


In [4]:
# Layer configuration as used in ResGCN training for AKI data 

# Only one GCN layer is listed, so the model gets no residual hidden layers.
# The "in_channels" of this first entry is not read: the initial block uses the
# dataset's feature count instead.
layer_configs = [
    dict(in_channels=32, out_channels=32, dropout_rate=0.5, batch_norm=True, residual=True),
]
mlp_config = [
    dict(out_channels=64, relu=True, batch_norm=False, dropout_rate=0.1),
]

path2 = "AKI_ResGCN_fold_1.pt"
num_aki_features = aki_train_graph_data[0].num_features
ResGCN_model = explanain_ResGCN(num_aki_features, layer_configs, mlp_config, 2).to(device)

In [5]:
# Explain the fold-1 ResGCN checkpoint (path2) on the training graphs.
# NOTE(review): 20 presumably is the number of top nodes/features to rank, and
# explain_wrapper presumably loads the checkpoint into ResGCN_model — TODO confirm.
top_nodes2, top_features2 = explain_wrapper(ResGCN_model, 20, path2, aki_train_graph_data, aki_structural_data, device)

In [6]:
top_nodes2.head(20)  # show up to the first 20 rows of the fold-1 node (pathway) scores

Unnamed: 0,Node_score,Pathway
543,21.121393,R-HSA-6798695
1247,14.868736,R-HSA-381426
2069,13.079555,R-HSA-114608
500,12.263504,R-HSA-977606
1796,7.82482,R-HSA-8957275
906,6.902922,R-HSA-2168880
502,6.82643,R-HSA-2855086
958,6.062621,R-HSA-2871796
252,5.440669,R-HSA-216083
826,5.169696,R-HSA-210993


In [7]:
top_features2.head(20)  # show up to the first 20 rows of the fold-1 feature (protein) scores

Unnamed: 0,Feature_score,Protein
1,2.663656,P02671
93,2.618227,P02679
62,2.002209,P02751
172,1.777125,P02675
48,1.722862,P01009
76,1.626122,P01700
433,1.604076,P01764
153,1.580175,P16070
444,1.562226,P04433
405,1.557812,P04211


In [8]:
# Explain the fold-5 ResGCN checkpoint, using the same model object and
# arguments as the fold-1 explanation; only the checkpoint path differs.
# NOTE(review): presumably explain_wrapper loads path3's weights into ResGCN_model — TODO confirm.
path3 = 'AKI_ResGCN_fold_5.pt'
top_nodes, top_features = explain_wrapper(ResGCN_model, 20, path3, aki_train_graph_data, aki_structural_data, device)

In [9]:
top_nodes.head(20)  # show up to the first 20 rows of the fold-5 node (pathway) scores

Unnamed: 0,Node_score,Pathway
543,19.647421,R-HSA-6798695
906,12.841311,R-HSA-2168880
826,9.727979,R-HSA-210993
1247,9.691864,R-HSA-381426
252,7.927331,R-HSA-216083
505,7.085743,R-HSA-173623
825,5.431609,R-HSA-210991
1796,5.234465,R-HSA-8957275
500,4.890698,R-HSA-977606
541,3.925312,R-HSA-2454202


In [10]:
top_features.head(20)  # show up to the first 20 rows of the fold-5 feature (protein) scores

Unnamed: 0,Feature_score,Protein
62,2.261529,P02751
224,2.198992,P06312
52,2.146339,P01834
344,2.058982,P80748
76,1.970525,P01700
487,1.965304,P04430
373,1.962967,P01763
433,1.96217,P01764
20,1.942912,P02647
485,1.940809,P01717


# 3. Explain ResGCN on Covid data

# 4. Explain ResGCN on perturbed AKI data

________________________________________________________________________________
________________________________________________________________________________


# 5. Explain ResGAT on AKI data

In [None]:
## 
# Define the ResGAT model without batch processing, because batching causes problems with the PyTorch Geometric Explainer 
##

class ResGAT(torch.nn.Module):
    """Residual GAT graph classifier that operates on one unbatched graph.

    An initial ``GATBlock`` is followed by residual ``GATBlock`` hidden layers.
    All node embeddings are then max-pooled into a single graph embedding,
    which a configurable MLP head maps to class scores.

    Args:
        num_features: Number of input node features (input size of the
            initial block).
        layer_configs: One dict per GAT layer with ``out_channels``,
            ``dropout_rate``, ``batch_norm`` and optional ``heads``
            (default 1). Entries after the first also need ``in_channels``
            and are built with ``residual=True``.
        mlp_config: One dict per hidden MLP layer. Each needs ``out_channels``
            and may set ``batch_norm`` (default False), ``relu``
            (default True) and ``dropout_rate`` (Dropout is added only when
            the key is present).
        num_classes: Width of the final Linear layer.
    """

    def __init__(self, num_features, layer_configs, mlp_config, num_classes):
        super().__init__()

        # GAT layers
        initial_layer = layer_configs[0]
        self.initial = GATBlock(num_features, initial_layer['out_channels'], initial_layer.get('heads', 1), initial_layer['dropout_rate'], initial_layer['batch_norm'])

        self.hidden_layers = torch.nn.ModuleList()
        for layer_config in layer_configs[1:]:
            self.hidden_layers.append(GATBlock(layer_config['in_channels'], layer_config['out_channels'], layer_config.get('heads', 1), layer_config['dropout_rate'], layer_config['batch_norm'], residual=True))

        # Configurable MLP. The input width is out_channels * heads of the last
        # GAT layer (presumably because GATBlock concatenates heads — TODO confirm).
        mlp_layers = []
        prev_channels = layer_configs[-1]['out_channels'] * layer_configs[-1].get('heads', 1)
        for layer in mlp_config:
            mlp_layers.append(torch.nn.Linear(prev_channels, layer['out_channels']))
            if layer.get('batch_norm', False):
                mlp_layers.append(torch.nn.BatchNorm1d(layer['out_channels']))
            if layer.get('relu', True):
                mlp_layers.append(torch.nn.ReLU())
            if 'dropout_rate' in layer:
                mlp_layers.append(torch.nn.Dropout(layer['dropout_rate']))
            prev_channels = layer['out_channels']

        mlp_layers.append(torch.nn.Linear(prev_channels, num_classes))
        self.mlp = torch.nn.Sequential(*mlp_layers)

    def forward(self, x, edge_index):
        """Return class scores for a single graph."""
        x = self.initial(x, edge_index)
        for layer in self.hidden_layers:
            x = layer(x, edge_index)
        # `batch` is a required argument of global_max_pool; passing None pools
        # every node as one graph (same call as in explanain_ResGCN.forward).
        x = global_max_pool(x, batch=None)
        x = self.mlp(x)
        return x


In [None]:
# Layer configuration as used in ResGAT training
layer_configs = [
    {"in_channels": 32, "out_channels": 32, "dropout_rate": 0.6, "batch_norm": True, "residual": True},
]
# NOTE(review): the MLP head used in ResGAT training is not recorded in this cell;
# this mirrors the ResGCN mlp_config defined earlier — TODO confirm against training.
gat_mlp_config = [{"out_channels": 64, "relu": True, "batch_norm": False, "dropout_rate": 0.1}]

# Build the ResGAT class defined above. The previous version built a ResGCN from a
# ResGCN checkpoint and overwrote ResGCN_model.
# NOTE(review): checkpoint name assumed to follow the "AKI_<model>_fold_<k>.pt"
# pattern of the ResGCN checkpoints — TODO confirm the file exists.
path1 = "AKI_ResGAT_fold_1.pt"
ResGAT_model = ResGAT(aki_train_graph_data[0].num_features, layer_configs, gat_mlp_config, 2).to(device)


# 6. Explain ResGAT on Covid data 