# Explainability of GNN

In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Render matplotlib figures inline in the notebook's cell output.
%matplotlib inline

## Imports

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer
import os
import networkx as nx
import torch
import pickle
import pandas as pd
from gnn_model import StationFlowGCN
from train_gnn import(
    train_node_gnn_model,
    MAPE_loss
)
from utils.station_network import StationNetworkSimul
from utils.data import get_degraded_network_loader, create_degraded_networks

## Loading data & generating network dataset

In [7]:
# Metro plan: keep only the rows whose 'vers Ligne' value is NOT one of the
# listed entries (the raw values carry a leading non-breaking space, '\xa0').
df_stations = pd.read_csv('plan du métro.csv')
df_stations = df_stations.loc[
    ~df_stations['vers Ligne'].isin(['\xa01', '\xa07', '\xa02', '\xa08', '\xa06'])
]

# GPS positions of the stations; the row labelled 151 is dropped
# (per the original note: the Malesherbes RER station).
df_pos = pd.read_csv("position gps des stations de métro.csv").drop(index=[151])

# Passenger flows: make the count column float, then turn every row into a
# {'de': ..., 'vers': ..., 'nombre': ...} record.
df_flow = pd.read_csv('passagers.csv').astype({'nombre': float})
path_flows = df_flow[['de', 'vers', 'nombre']].to_dict(orient='records')

# Build the network simulator from the station plan and the positions.
network_simul = StationNetworkSimul(df_stations=df_stations, df_pos=df_pos)

In [8]:
# Initialise the simulated network: set the edge weights, then set the node
# traffic from the passenger flow records built above (`path_flows`).
# NOTE(review): presumably the traffic step relies on the edge weights being
# set first — confirm in StationNetworkSimul before reordering.
network_simul.set_edges_weights()
network_simul.set_nodes_traffic(path_flows=path_flows)

In [None]:
# Root directory handed to `create_degraded_networks` as `data_dir`.
data_dir = "graph_dataset/"

# makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: it creates any
# missing parent directories and is a no-op when the directory already
# exists, so the cell can be re-run safely and without a check-then-create race.
os.makedirs(data_dir, exist_ok=True)

# `num_delete` is also used further down to locate the `delete_<num_delete>`
# sub-folder. NOTE(review): presumably the number of elements removed per
# degraded network and the number of degraded networks to generate — confirm
# against utils.data.create_degraded_networks.
num_delete = 1
num_degraded = 50

create_degraded_networks(
    network_simul,
    df_flow,
    num_delete=num_delete,
    num_degraded=num_degraded,
    data_dir=data_dir,
)

In [None]:
# Load every pickled degraded network for the chosen `num_delete` and split
# the collection into dev / train / test lists.
train_degraded_graphs = []
dev_degraded_graphs = []
test_degraded_graphs = []

train_test_ratio = 0.9  # share of all graphs that goes to train + dev
dev_train_ratio = 0.1   # share of the train + dev portion held out as dev

folder_path = os.path.join(data_dir, f'delete_{num_delete}')

# os.listdir returns entries in arbitrary order; sorting makes the split
# membership reproducible across runs and machines.
all_files = sorted(
    file_path for file_path in os.listdir(folder_path)
    if os.path.isfile(os.path.join(folder_path, file_path))
)

degraded_graphs = []
for file_path in all_files:
    # SECURITY: pickle.load can execute arbitrary code. These files are
    # expected to be the ones written into `data_dir` by this notebook; do not
    # point `data_dir` at untrusted content.
    with open(os.path.join(folder_path, file_path), 'rb') as f:
        new_net = pickle.load(f)
    degraded_graphs.append(new_net)

# Split points: [0, dev_split_idx) -> dev, [dev_split_idx, train_split_idx)
# -> train, [train_split_idx, end) -> test.
train_split_idx = int(train_test_ratio * len(all_files))
dev_split_idx = int(dev_train_ratio * train_split_idx)

dev_degraded_graphs.extend(degraded_graphs[:dev_split_idx])
train_degraded_graphs.extend(degraded_graphs[dev_split_idx:train_split_idx])
test_degraded_graphs.extend(degraded_graphs[train_split_idx:])

# Sanity check: the three splits partition the loaded graphs.
assert (
    len(dev_degraded_graphs) + len(train_degraded_graphs) + len(test_degraded_graphs)
    == len(degraded_graphs)
)

## Training GNN

In [11]:
# Training settings; this dictionary is passed to `train_node_gnn_model`.
config = {
    'epochs': 5,
    'lr': 0.001,
    # L1 (mean absolute error) loss; presumably the training objective.
    'criterion': torch.nn.L1Loss(),
    # Named metric callables.
    'metrics': {
        'MAE': torch.nn.L1Loss(),
        'MAPE': MAPE_loss(),
    },
}

In [14]:
# Node attribute used as the regression target and node attributes used as
# input features when building the loaders.
node_target_name = 'traffic'
node_feature_names = ['x', 'y']

# Only the training loader is shuffled. The dev and test loaders keep a fixed
# order so that evaluation is deterministic and `test_loader.dataset[0]`
# style inspection lines up with the iteration order.
train_loader = get_degraded_network_loader(
    train_degraded_graphs,
    node_target_name=node_target_name,
    node_feature_names=node_feature_names,
    shuffle=True,
)
dev_loader = get_degraded_network_loader(
    dev_degraded_graphs,
    node_target_name=node_target_name,
    node_feature_names=node_feature_names,
    shuffle=False,
)
test_loader = get_degraded_network_loader(
    test_degraded_graphs,
    node_target_name=node_target_name,
    node_feature_names=node_feature_names,
    shuffle=False,
)

In [None]:
# Read the feature-matrix shape of the first training graph once:
# axis 0 is passed as `num_nodes`, axis 1 as `input_dim`.
feature_shape = train_loader.dataset[0].x.shape
input_dim = feature_shape[1]
output_dim = 1

nodes_gnn_model = StationFlowGCN(
    input_dim=input_dim,
    output_dim=output_dim,
    num_nodes=feature_shape[0],
)

# Train the model with the settings in `config`, using the train and dev loaders.
train_node_gnn_model(nodes_gnn_model, config, train_loader, dev_loader)


## Explainability

In [16]:
# Graph to explain: the first graph of the test set.
data = test_loader.dataset[0]

# Entry stored for 'Châtelet' / line '4' in the simulator's station table,
# used as the `index` to explain.
# NOTE(review): assumes this value is a node index that is still valid in the
# degraded test graph — confirm against StationNetworkSimul.
idx_chtl_4 = network_simul.network_stations['Châtelet']['4']

# GNNExplainer (200 optimisation epochs) explaining the model's own prediction
# ('model' explanation type) for a node-level regression task on raw outputs,
# with a mask over node attributes and a mask over edges.
explainer = Explainer(
    nodes_gnn_model,
    GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config={
        'mode': 'regression',
        'task_level': 'node',
        'return_type': 'raw',
    },
)

explanation = explainer(x=data.x, edge_index=data.edge_index, index=idx_chtl_4)

In [None]:
# Collect the 'title' attribute of every node of the full network graph
# (a mapping node -> title) and use it to label the explanation plot.
node_labels = nx.get_node_attributes(network_simul.network_graph, name='title')
explanation.visualize_graph(node_labels=node_labels)