## Explainability for GNNs

In this notebook, I want to try out the interpretation package for GNNS from PyTorch.

In [9]:
import torch
from torch_geometric.explain import Explainer, GNNExplainer, PGExplainer
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [10]:
dataset = 'data-oneil'  # or whichever dataset you want to explain
fold_n = 1  # The fold you used for training

form_data_path = './' + dataset + '/form_data'
xTe = np.load(form_data_path + '/xTe' + str(fold_n) + '.npy')
yTe = np.load(form_data_path + '/yTe' + str(fold_n) + '.npy')
drugTe = np.load(form_data_path + '/drugTe' + str(fold_n) + '.npy')
edge_index = torch.from_numpy(np.load(form_data_path + '/edge_index.npy')).long()

final_annotation_gene_df = pd.read_csv('./' + dataset + '/filtered_data/kegg_gene_annotation.csv')
gene_name_list = list(final_annotation_gene_df['kegg_gene'])
num_gene = len(gene_name_list)
dict_drug_num = pd.read_csv('./' + dataset + '/filtered_data/drug_num_dict.csv')
num_drug = dict_drug_num.shape[0]
node_num = num_gene + num_drug

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = 4  # As per your original configuration
hidden_dim = 4
output_dim = 36
decoder_dim = 150

gene_num_df = pd.read_csv('./' + dataset + '/filtered_data/kegg_gene_num_interaction.csv')
num_gene_edge = gene_num_df.shape[0]
drugbank_num_df = pd.read_csv('./' + dataset + '/filtered_data/final_drugbank_num_sym.csv')
num_drug_edge = drugbank_num_df.shape[0]
num_edge = num_gene_edge + num_drug_edge


In [11]:
from enc_dec.geo_webgnn_decoder import WeBGNNDecoder

model = WeBGNNDecoder(input_dim=input_dim, hidden_dim=hidden_dim, embedding_dim=output_dim, 
                      decoder_dim=decoder_dim, node_num=node_num, num_edge=num_edge, 
                      num_gene_edge=num_gene_edge, device=device)
model = model.to(device)

# Load the trained model weights
model_path = '/Users/olha/Study/Software Project ML for Cancer/epoch_500_fold_1/best_train_model.pt'
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

  checkpoint = torch.load(model_path, map_location=device)


WeBGNNDecoder(
  (conv_first): WeBGNNConv()
  (conv_block): WeBGNNConv()
  (conv_last): WeBGNNConv()
  (act): ReLU()
  (act2): LeakyReLU(negative_slope=0.1)
  (x_norm): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [19]:
x = torch.FloatTensor(xTe[0:1]).to(device)  # Take first sample for explanation
drug_index = torch.LongTensor(drugTe[0:1]).to(device)

class ModelWrapper(torch.nn.Module):
    def __init__(self, model, drug_index):
        super().__init__()
        self.model = model
        self.drug_index = drug_index
    
    def forward(self, x, edge_index):
        return self.model(x, edge_index, self.drug_index)

# Create a wrapped model
wrapped_model = ModelWrapper(model, drug_index)

# Use PGExplainer instead of GNNExplainer
explainer = Explainer(
    model=wrapped_model,
    algorithm=PGExplainer(epochs=100, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
)

# Create the explanation
true_value = torch.FloatTensor(yTe[0:1]).to(device)

# We'll use this as our target instead of running the model
target = true_value

explainer.algorithm.train(
    model = wrapped_model,
    epoch=100,  # Number of epochs to train
    x=x,
    edge_index=edge_index,
    target=target
)

# Then generate the explanation
explanation = explainer(x, edge_index, target=target)

RuntimeError: running_mean should contain 8108 elements not 4

In [None]:
# Visualize feature importance
path = './explanations/feature_importance.png'
os.makedirs(os.path.dirname(path), exist_ok=True)
explanation.visualize_feature_importance(path, top_k=10)
print(f"Feature importance plot has been saved to '{path}'")

# Visualize subgraph
path = './explanations/subgraph.pdf'
explanation.visualize_graph(path)
print(f"Subgraph visualization plot has been saved to '{path}'")

# To better understand the node indices, create a mapping
gene_names = list(final_annotation_gene_df['kegg_gene'])
drug_names = list(dict_drug_num['Drug'])
node_names = gene_names + drug_names

# Print the important nodes and their names
important_nodes = explanation.node_mask.topk(5).indices.cpu().numpy()
print("Most important nodes in the explanation:")
for node_idx in important_nodes:
    if node_idx < len(gene_names):
        print(f"Gene: {gene_names[node_idx]}")
    else:
        drug_idx = node_idx - len(gene_names)
        if drug_idx < len(drug_names):
            print(f"Drug: {drug_names[drug_idx]}")