# 0. Import dependencies 

In [12]:
# PyTorch
import torch
import torch_scatter
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.loader import DataLoader
from torch_geometric.nn.pool import global_mean_pool, global_max_pool, global_add_pool
import torch.nn.functional as F

# Pytorch geometric explainer
from torch_geometric.explain import *

import pickle
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import random
import os
import json

# cust_functions folder
from cust_functions.training import *
from cust_functions.graph_networks import *
from cust_functions.graph_creation import *
from cust_functions.explain_helper import *

# Pick the compute device: CUDA when it is available, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


Using device: cpu


# 1. Import data and preprocess it 

- We need a list of graphs, where each graph represents one patient.
- Proteins are encoded using their UniProt names.

In [13]:
# Build the reference pathway graph (needed for both the AKI and the Covid data).
# The two tables are read independently and then combined by create_pathway_graph.
pathways = pd.read_csv("aki_data/pathways.tsv", sep="\t")
translation = pd.read_csv("aki_data/translation.tsv", sep="\t", index_col=0)
G = create_pathway_graph(pathways, translation, descendants=True)


In [14]:
# 
## Initialize PERTURBED pathway graph 
#

# Randomize the protein assignment and perturb the graph edges.
#
# The RNGs are seeded first so that G_pert is identical on every run. The
# per-patient graphs derived from G_pert are passed a pickle path together with
# load/save flags further down in this notebook; without a fixed seed, a freshly
# generated G_pert would not match graphs loaded from an earlier run's pickle.
PERTURBATION_SEED = 42
random.seed(PERTURBATION_SEED)
# NOTE(review): seeded as well in case create_pathway_graph draws from NumPy's
# global RNG when perturb=True -- confirm against its implementation.
np.random.seed(PERTURBATION_SEED)

# Shuffle the 'translation' column so every row is paired with a random entry.
translation_pert = translation.copy()
trans = translation_pert['translation'].tolist()
random.shuffle(trans)
translation_pert['translation'] = trans

G_pert = create_pathway_graph(pathways, translation_pert, descendants=True, perturb = True, edge_removal_prob=0.9, edge_addition_prob=0.0005)


## 1.1 AKI data

In [15]:
# --- Read the AKI tables ---
aki_input_data = pd.read_csv("aki_data/test_data.tsv", sep="\t", )
aki_input_data_qm = pd.read_csv("aki_data/test_qm.csv")
aki_design_matrix = pd.read_csv("aki_data/design_matrix.tsv", sep="\t")

# --- Preprocess ---
# Missing values become 0; in the design matrix every 1 is recoded to 0 and every 2 to 1.
aki_input_data_preprocessed = aki_input_data_qm.fillna(0)
aki_design_matrix = aki_design_matrix.replace(1, 0).replace(2, 1)

# --- Train/test split ---
# Columns (and samples) whose name contains "M2012" go to the training set,
# everything else to the test set. The "Protein" column is kept in both.
aki_columns = aki_input_data_preprocessed.columns
column_in_m2012 = aki_columns.str.contains("M2012")
column_is_protein = aki_columns.str.contains("Protein")
aki_X_train = aki_input_data_preprocessed.loc[:, column_in_m2012 | column_is_protein]
aki_X_test = aki_input_data_preprocessed.loc[:, ~column_in_m2012 | column_is_protein]

sample_in_m2012 = aki_design_matrix['sample'].str.contains("M2012")
aki_y_train = aki_design_matrix[sample_in_m2012]
aki_y_test = aki_design_matrix[~sample_in_m2012]

# --- Load/Create graph data per patient ---
# If the pickle already exists it is loaded, otherwise the data is created and saved.
aki_train_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/AKI_train_graph_data.pkl'
aki_test_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/AKI_test_graph_data.pkl'

load_train = os.path.exists(aki_train_pkl)
save_train = not load_train
load_test = os.path.exists(aki_test_pkl)
save_test = not load_test

aki_train_graph_data = pytorch_graphdata(
    aki_y_train, aki_X_train, G,
    gen_column = 'Protein', load_data = load_train, save_data = save_train, path = aki_train_pkl,
)
aki_test_graph_data = pytorch_graphdata(
    aki_y_test, aki_X_test, G,
    gen_column = 'Protein', load_data = load_test, save_data = save_test, path = aki_test_pkl,
)

# --- Quick summary of the resulting graph data ---
first_aki_graph = aki_train_graph_data[0]
print(f"Number of training graphs: {len(aki_train_graph_data)}")
print(f"Number of test graphs: {len(aki_test_graph_data)}")
print(f"Number of features: {first_aki_graph.num_features}")
aki_train_labels = [g.y.detach().numpy()[0] for g in aki_train_graph_data]
print(f"Number of classes: {np.unique(aki_train_labels)}")
print(f"Is directed: {first_aki_graph.is_directed()}")
print(first_aki_graph)

# Later needed in this format
aki_structural_data = [aki_X_train, G]


Number of training graphs: 141
Number of test graphs: 56
Number of features: 554
Number of classes: [0 1]
Is directed: True
Data(x=[2585, 554], edge_index=[2, 2603], y=[1])


## 1.2 Covid data

In [16]:
# The preprocessed Covid data was already created and saved by the
# GNN_implementation_v3_yves.ipynb script, so it only needs to be loaded here.
Covid_X_train = pd.read_csv('covid_data/covid_train_qm.csv', index_col=False)
Covid_X_test = pd.read_csv('covid_data/covid_test_qm.csv', index_col=False)
Covid_y_train = pd.read_csv('covid_data/covid_train_design_qm.csv', index_col=False)
Covid_y_test = pd.read_csv('covid_data/covid_test_design_qm.csv', index_col=False)

# --- Load/Create graph data per patient ---
# If the pickle already exists it is loaded, otherwise the data is created and saved.
covid_train_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Covid_train_graph_data.pkl'
covid_test_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Covid_test_graph_data.pkl'

load_train = os.path.exists(covid_train_pkl)
save_train = not load_train
load_test = os.path.exists(covid_test_pkl)
save_test = not load_test

covid_train_graph_data = pytorch_graphdata(
    Covid_y_train, Covid_X_train, G,
    gen_column = 'Protein', load_data = load_train, save_data = save_train, path = covid_train_pkl,
)
covid_test_graph_data = pytorch_graphdata(
    Covid_y_test, Covid_X_test, G,
    gen_column = 'Protein', load_data = load_test, save_data = save_test, path = covid_test_pkl,
)

# --- Quick summary of the resulting graph data ---
first_covid_graph = covid_train_graph_data[0]
print(f"Number of training graphs: {len(covid_train_graph_data)}")
print(f"Number of test graphs: {len(covid_test_graph_data)}")
print(f"Number of features: {first_covid_graph.num_features}")
covid_train_labels = [g.y.detach().numpy()[0] for g in covid_train_graph_data]
print(f"Number of classes: {np.unique(covid_train_labels)}")
print(f"Is directed: {first_covid_graph.is_directed()}")
print(first_covid_graph)

# Needed later in this format
covid_structural_data = [Covid_X_train, G]


Number of training graphs: 707
Number of test graphs: 79
Number of features: 169
Number of classes: [0 1]
Is directed: True
Data(x=[2585, 169], edge_index=[2, 2603], y=[1])


## 1.3 Perturbed AKI data

In [17]:
# Create PyTorch Geometric Data objects for the AKI train and test data on the
# perturbed pathway graph. If the pickle already exists it is loaded, otherwise
# the data is created and saved.
# NOTE(review): the recorded run of this cell ended with a FileNotFoundError for
# the train pickle path; the cause is not visible from this notebook -- check
# how pytorch_graphdata handles `path` when the file is missing.
pert_aki_train_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Pert_AKI_train_graph_data.pkl'
pert_aki_test_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Pert_AKI_test_graph_data.pkl'

load_train = os.path.exists(pert_aki_train_pkl)
save_train = not load_train
load_test = os.path.exists(pert_aki_test_pkl)
save_test = not load_test

pert_aki_train_graph_data = pytorch_graphdata(
    aki_y_train, aki_X_train, G_pert,
    gen_column = 'Protein', load_data = load_train, save_data = save_train, path = pert_aki_train_pkl,
)
pert_aki_test_graph_data = pytorch_graphdata(
    aki_y_test, aki_X_test, G_pert,
    gen_column = 'Protein', load_data = load_test, save_data = save_test, path = pert_aki_test_pkl,
)

# Needed later in this format
pert_aki_structural_data = [aki_X_train, G_pert]


FileNotFoundError: [Errno 2] No such file or directory: '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Pert_AKI_train_graph_data.pkl'

## 1.4 Perturbed Covid data

In [None]:
# Create PyTorch Geometric Data objects for the Covid train and test data on the
# perturbed pathway graph. If the pickle already exists it is loaded, otherwise
# the data is created and saved.
pert_covid_train_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Pert_Covid_train_graph_data.pkl'
pert_covid_test_pkl = '/Users/hendrikplett/Downloads/Deep_Learning_Project/DL_files/Pert_Covid_test_graph_data.pkl'

load_train = os.path.exists(pert_covid_train_pkl)
save_train = not load_train
load_test = os.path.exists(pert_covid_test_pkl)
save_test = not load_test

pert_covid_train_graph_data = pytorch_graphdata(
    Covid_y_train, Covid_X_train, G_pert,
    gen_column = 'Protein', load_data = load_train, save_data = save_train, path = pert_covid_train_pkl,
)
pert_covid_test_graph_data = pytorch_graphdata(
    Covid_y_test, Covid_X_test, G_pert,
    gen_column = 'Protein', load_data = load_test, save_data = save_test, path = pert_covid_test_pkl,
)

# Structural data (feature table + perturbed graph) in list form
pert_covid_structural_data = [Covid_X_train, G_pert]


________________________________________________________________________________
________________________________________________________________________________


# 2. Explain ResGCN on AKI data

In [6]:
## 
# Define ResGCN model without batch processing as this causes problems with the Pytorch Geometric Explainer 
##

class explanain_ResGCN(torch.nn.Module):
    """ResGCN variant whose forward pass takes a single graph without a batch vector.

    The network is an initial ``GCNBlock``, zero or more residual ``GCNBlock``s,
    a global max pooling over the nodes and a configurable MLP head.

    Args:
        num_features: Input size handed to the initial ``GCNBlock``.
        layer_configs: One dict per GCN block. The first entry supplies
            ``out_channels``, ``dropout_rate`` and ``batch_norm``; every further
            entry additionally supplies ``in_channels`` and is built with
            ``residual=True``.
        mlp_config: One dict per hidden MLP layer. ``out_channels`` is required;
            ``batch_norm`` (default False), ``relu`` (default True) and
            ``dropout_rate`` (no dropout layer when the key is absent) are optional.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_features, layer_configs, mlp_config, num_classes):
        super().__init__()

        # NOTE(review): the attribute names `initial`, `hidden_layers` and `mlp`
        # are kept as-is; presumably saved checkpoints rely on them -- verify.
        first_cfg = layer_configs[0]
        self.initial = GCNBlock(
            num_features,
            first_cfg['out_channels'],
            first_cfg['dropout_rate'],
            first_cfg['batch_norm'],
        )

        # Every configuration after the first becomes a residual block
        self.hidden_layers = torch.nn.ModuleList([
            GCNBlock(cfg['in_channels'], cfg['out_channels'], cfg['dropout_rate'], cfg['batch_norm'], residual=True)
            for cfg in layer_configs[1:]
        ])

        # The MLP head starts from the width of the last GCN block
        self.mlp = self._build_mlp(layer_configs[-1]['out_channels'], mlp_config, num_classes)

    @staticmethod
    def _build_mlp(in_channels, mlp_config, num_classes):
        """Assemble the configurable MLP head as a ``torch.nn.Sequential``."""
        modules = []
        width = in_channels
        for spec in mlp_config:
            modules.append(torch.nn.Linear(width, spec['out_channels']))
            if spec.get('batch_norm', False):
                modules.append(torch.nn.BatchNorm1d(spec['out_channels']))
            if spec.get('relu', True):
                modules.append(torch.nn.ReLU())
            if 'dropout_rate' in spec:
                modules.append(torch.nn.Dropout(spec['dropout_rate']))
            width = spec['out_channels']

        # Final projection onto the class outputs
        modules.append(torch.nn.Linear(width, num_classes))
        return torch.nn.Sequential(*modules)

    def forward(self, x, edge_index):
        """Run the GCN blocks, max-pool over the nodes and apply the MLP head."""
        hidden = self.initial(x, edge_index)
        for block in self.hidden_layers:
            hidden = block(hidden, edge_index)
        # batch=None: no batch assignment vector is passed to the pooling
        pooled = global_max_pool(hidden, batch=None)
        return self.mlp(pooled)


In [8]:
# Layer configuration as used in ResGCN training for AKI data
layer_configs = [
    {"in_channels": 32, "out_channels": 32, "dropout_rate": 0.5, "batch_norm": True, "residual": True},
]
mlp_config = [
    {"out_channels": 64, "relu": True, "batch_norm": False, "dropout_rate": 0.1},
]

# One saved model file per fold
paths_to_AKI_ResGCN = [
    "AKI_ResGCN_fold_1.pt",
    "AKI_ResGCN_fold_2.pt",
    "AKI_ResGCN_fold_3.pt",
    "AKI_ResGCN_fold_4.pt",
    "AKI_ResGCN_fold_5.pt",
]

# Maps each model path to [top_nodes, top_features]
AKI_ResGCN_top_features_nodes = {}
aki_num_features = aki_train_graph_data[0].num_features

for path in paths_to_AKI_ResGCN:
    # Fresh, un-batched ResGCN with two output classes for this fold
    ResGCN_model = explanain_ResGCN(aki_num_features, layer_configs, mlp_config, 2).to(device)

    # Retrieve the most important nodes and features.
    # NOTE(review): only the first two training graphs are passed here, whereas
    # the other explain cells in this notebook pass the full training set --
    # confirm this is intentional.
    top_nodes, top_features = explain_wrapper(
        ResGCN_model, path, aki_train_graph_data[0:2], aki_structural_data, device
    )

    AKI_ResGCN_top_features_nodes[path] = [top_nodes, top_features]


# 3. Explain ResGCN on Covid data

In [None]:
# Layer configuration used for ResGCN on the Covid data set
layer_configs = [
    {"in_channels": 16, "out_channels": 16, "dropout_rate": 0.5, "batch_norm": True, "residual": True},
]
mlp_config = [
    {"out_channels": 128, "relu": True, "batch_norm": False, "dropout_rate": 0.5},
    {"out_channels": 64, "relu": True, "batch_norm": False, "dropout_rate": 0.5},
    {"out_channels": 32, "relu": True, "batch_norm": False, "dropout_rate": 0.5},
]

# One saved model file per fold
paths_to_Covid_ResGCN = [
    "COVID_ResGCN_fold_1.pt",
    "COVID_ResGCN_fold_2.pt",
    "COVID_ResGCN_fold_3.pt",
    "COVID_ResGCN_fold_4.pt",
    "COVID_ResGCN_fold_5.pt",
]

# Maps each model path to [top_nodes, top_features]
Covid_ResGCN_top_features_nodes = {}
covid_num_features = covid_train_graph_data[0].num_features

for path in paths_to_Covid_ResGCN:
    # Fresh, un-batched ResGCN with two output classes for this fold
    ResGCN_model = explanain_ResGCN(covid_num_features, layer_configs, mlp_config, 2).to(device)

    # Retrieve the most important nodes and features
    top_nodes, top_features = explain_wrapper(
        ResGCN_model, path, covid_train_graph_data, covid_structural_data, device
    )

    Covid_ResGCN_top_features_nodes[path] = [top_nodes, top_features]



# 4. Explain ResGCN on perturbed AKI data

# 5. Explain ResGCN on perturbed Covid data

________________________________________________________________________________
________________________________________________________________________________


# 6. Explain ResGAT on AKI data

In [12]:
## 
# Define ResGAT model without batch processing as this causes problems with the Pytorch Geometric Explainer 
##

class explain_ResGAT(torch.nn.Module):
    """ResGAT variant whose forward pass takes a single graph without a batch vector.

    The network is an initial ``GATBlock``, zero or more residual ``GATBlock``s,
    a global max pooling over the nodes and a configurable MLP head.

    Args:
        num_features: Input size handed to the initial ``GATBlock``.
        layer_configs: One dict per GAT block. The first entry supplies
            ``out_channels``, ``dropout_rate``, ``batch_norm`` and optionally
            ``heads`` (default 1); every further entry additionally supplies
            ``in_channels`` and is built with ``residual=True``.
        mlp_config: One dict per hidden MLP layer. ``out_channels`` is required;
            ``batch_norm`` (default False), ``relu`` (default True) and
            ``dropout_rate`` (no dropout layer when the key is absent) are optional.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_features, layer_configs, mlp_config, num_classes):
        super().__init__()

        # GAT layers
        # NOTE(review): the attribute names `initial`, `hidden_layers` and `mlp`
        # are kept as-is; presumably saved checkpoints rely on them -- verify.
        first_cfg = layer_configs[0]
        self.initial = GATBlock(
            num_features,
            first_cfg['out_channels'],
            first_cfg.get('heads', 1),
            first_cfg['dropout_rate'],
            first_cfg['batch_norm'],
        )

        # Every configuration after the first becomes a residual block
        self.hidden_layers = torch.nn.ModuleList([
            GATBlock(cfg['in_channels'], cfg['out_channels'], cfg.get('heads', 1), cfg['dropout_rate'], cfg['batch_norm'], residual=True)
            for cfg in layer_configs[1:]
        ])

        # The MLP input width is out_channels times the head count of the last GAT block
        last_cfg = layer_configs[-1]
        self.mlp = self._build_mlp(last_cfg['out_channels'] * last_cfg.get('heads', 1), mlp_config, num_classes)

    @staticmethod
    def _build_mlp(in_channels, mlp_config, num_classes):
        """Assemble the configurable MLP head as a ``torch.nn.Sequential``."""
        modules = []
        width = in_channels
        for spec in mlp_config:
            modules.append(torch.nn.Linear(width, spec['out_channels']))
            if spec.get('batch_norm', False):
                modules.append(torch.nn.BatchNorm1d(spec['out_channels']))
            if spec.get('relu', True):
                modules.append(torch.nn.ReLU())
            if 'dropout_rate' in spec:
                modules.append(torch.nn.Dropout(spec['dropout_rate']))
            width = spec['out_channels']

        # Final projection onto the class outputs
        modules.append(torch.nn.Linear(width, num_classes))
        return torch.nn.Sequential(*modules)

    def forward(self, x, edge_index):
        """Run the GAT blocks, max-pool over the nodes and apply the MLP head."""
        hidden = self.initial(x, edge_index)
        for block in self.hidden_layers:
            hidden = block(hidden, edge_index)
        # batch=None: no batch assignment vector is passed to the pooling
        pooled = global_max_pool(hidden, batch = None)
        return self.mlp(pooled)


In [13]:
# Layer configuration as used in ResGAT training
layer_configs = [
    {"in_channels": 64, "out_channels": 64, "heads": 1, "dropout_rate": 0.5, "batch_norm": True, "residual": True},
]
mlp_config = [
    {"out_channels": 64, "relu": True, "batch_norm": False, "dropout_rate": 0.1},
    {"out_channels": 64, "relu": True, "batch_norm": False, "dropout_rate": 0.1},
]

# One saved model file per fold
paths_to_AKI_ResGAT = [
    "AKI_ResGAT_fold_1.pt",
    "AKI_ResGAT_fold_2.pt",
    "AKI_ResGAT_fold_3.pt",
    "AKI_ResGAT_fold_4.pt",
    "AKI_ResGAT_fold_5.pt",
]

# Maps each model path to [top_nodes, top_features]
AKI_ResGAT_top_features_nodes = {}
aki_num_features = aki_train_graph_data[0].num_features

for path in paths_to_AKI_ResGAT:
    # Fresh, un-batched ResGAT with two output classes for this fold
    ResGAT_model = explain_ResGAT(aki_num_features, layer_configs, mlp_config, 2).to(device)

    # Retrieve the most important nodes and features
    top_nodes, top_features = explain_wrapper(
        ResGAT_model, path, aki_train_graph_data, aki_structural_data, device
    )

    AKI_ResGAT_top_features_nodes[path] = [top_nodes, top_features]



# 7. Explain ResGAT on Covid data 

**TODO (open question):** Check in the BINN paper which data (train, test, or both) the model should be evaluated on. Currently it is evaluated on the test data.