In [1]:
# Install the GraphLIME explainer (provides `graphlime.GraphLIME`, imported further down).
# NOTE(review): unpinned install — consider pinning a version for reproducible runs.
!pip install graphlime




In [None]:
!pip uninstall scikit-learn -y
!pip install --no-binary :all: scikit-learn


Found existing installation: scikit-learn 1.1.3
Uninstalling scikit-learn-1.1.3:
  Successfully uninstalled scikit-learn-1.1.3
Collecting scikit-learn
  Downloading scikit_learn-1.6.1.tar.gz (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# PyTorch Geometric: provides GCNConv, Data and the graph utilities imported below.
# NOTE(review): unpinned install — consider pinning a version compatible with the torch build.
!pip install torch_geometric

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from graphlime import GraphLIME
from torch_geometric.data import Data
import numpy as np



In [None]:
!git clone https://github.com/mims-harvard/GraphXAI.git
%cd GraphXAI
!pip install -e .
%cd ..
!pip install --upgrade collections
!pip install ipdb
import os
import sys
# Assuming GraphXAI repository was cloned to /content/GraphXAI
graphxai_path = '/content/GraphXAI'

# Check if directory exists; if not, clone it
if not os.path.exists(graphxai_path):
    !git clone https://github.com/mims-harvard/GraphXAI.git $graphxai_path

# Navigate to GraphXAI directory and install
%cd $graphxai_path
!pip install -e .

# Return to original directory
%cd /content

# Add the GraphXAI directory to your PYTHONPATH
sys.path.append(graphxai_path)
!sed -i 's/from collections import Iterable/from collections.abc import Iterable/g' /content/GraphXAI/graphxai/visualization/visualizations.py

!sed -i 's/from collections import Iterable/from collections.abc import Iterable/g' /content/GraphXAI/graphxai/visualization/explanation_vis.py


In [None]:
from graphxai.datasets.shape_graph import ShapeGGen

In [None]:
import matplotlib.pyplot as plt

In [None]:
import networkx as nx
from torch_geometric.utils import from_networkx
# Generate a synthetic benchmark graph with GraphXAI's ShapeGGen generator.
# NOTE(review): no seed is passed, so the generated graph presumably differs between
# runs — check whether ShapeGGen accepts a seed argument.
# NOTE(review): model_layers=2 here while the GNN defined later stacks three GCNConv
# layers — confirm the mismatch is intended.
data = ShapeGGen(
    model_layers=2,
    num_subgraphs=15,
    subgraph_size=13,
    prob_connection=0.3,
    add_sensitive_feature=False
)
# Underlying graph; later cells call networkx graph methods on it (subgraph, neighbors).
G = data.G


# Visualize the full generated graph.
plt.figure(figsize=(8,8))
data.visualize(show=True)

In [None]:


class GNN(torch.nn.Module):
    """Three-layer GCN node classifier: in_channels -> 16 -> 32 -> out_channels.

    `forward` returns per-node log-probabilities (log-softmax over the class
    dimension). That is the input `NLLLoss` expects.
    NOTE(review): the graphlime package reportedly exponentiates the model output to
    recover probabilities, which also expects log-probabilities — verify against the
    installed version.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
            in_channels: number of input node features.
            out_channels: number of output classes.
        """
        super().__init__()
        self.conv1 = GCNConv(in_channels, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, out_channels)

    def forward(self, x, edge_index):
        """Return class log-probabilities, one row per node."""
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        # log_softmax rather than softmax: paired with NLLLoss this is exactly
        # cross-entropy on the raw logits. softmax followed by CrossEntropyLoss would
        # normalize twice, squashing the loss range and weakening gradients.
        return F.log_softmax(x, dim=1)


# Seed torch so weight initialization (and hence the trained model) is reproducible.
SEED = 42
torch.manual_seed(SEED)

# Instantiate the model.
# NOTE(review): out_channels=2 assumes binary node labels — confirm against data.y.
model = GNN(in_channels=data.x.size(1), out_channels=2)

# Full-batch training over all nodes.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# NLLLoss on log-probabilities == cross-entropy on logits.
criterion = torch.nn.NLLLoss()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()

In [None]:

from torch_geometric.utils import k_hop_subgraph
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx




In [None]:
# Display the node feature matrix used as model input.
data.x

In [None]:

# Node of interest and neighbourhood radius in hops (num_hops is passed to GraphLIME as `hop`).
central_node, num_hops = 0, 3

In [None]:
# Display the node IDs of the generated graph.
G.nodes


In [None]:
import numpy as np
from graphlime import GraphLIME

# GraphLIME explainer over a `num_hops`-hop neighbourhood; num_hops is 3, matching the
# three GCNConv layers of the model.
explainer = GraphLIME(model=model, hop=num_hops, rho=0.1, cached=True)

# Explain the node configured in `central_node` rather than a second hard-coded 0, so
# changing the configuration actually changes the explained node.
node_id = central_node

# Explain the node, passing the 'x' attributes and 'edge_index'.
explanation = explainer.explain_node(node_id, x=data.x, edge_index=data.edge_index)

if isinstance(explanation, np.ndarray):
    # NOTE(review): GraphLIME's explain_node is documented to return one coefficient
    # per input *feature*, so the index paired with each value here is most likely a
    # feature index, not a node ID — confirm before reading these as node importances.
    node_importances = list(enumerate(explanation))
    node_decisions = None
    print("O objeto retornado é um numpy array com as importâncias.")
else:
    # Explanation object mapping subgraph-local indices to original node IDs.
    subgraph_mapping = explanation.subgraph_mapping
    node_importances = [(original_id, explanation.node_importances[sub_idx])
                        for sub_idx, original_id in subgraph_mapping.items()]
    node_decisions = [(original_id, explanation.node_decisions[sub_idx])
                      for sub_idx, original_id in subgraph_mapping.items()]

# Show the results.
print("Importâncias dos nós:", node_importances)
if node_decisions is not None:
    print("Decisões dos nós:", node_decisions)


In [None]:

def _plot_importance_subgraph(graph, nodelist, node_color, title=None):
    """Draw `graph` with the nodes in `nodelist` colored by `node_color`.

    `nodelist` and `node_color` must have the same length and order: networkx pairs
    colors with nodes positionally, and without an explicit `nodelist` it uses
    `graph.nodes()` order, which need not match the order the colors were built in.
    """
    plt.figure(figsize=(8, 6))
    pos = nx.spring_layout(graph, seed=42)
    nodes = nx.draw_networkx_nodes(graph, pos, nodelist=nodelist, node_color=node_color,
                                   cmap=plt.cm.viridis, node_size=500)
    nx.draw_networkx_edges(graph, pos, alpha=0.6)
    nx.draw_networkx_labels(graph, pos, font_color='white')
    if title is not None:
        plt.title(title)
    plt.colorbar(nodes, label="Importância")
    plt.axis('off')
    plt.show()


if hasattr(explanation, 'subgraph_mapping'):
    subgraph_mapping = explanation.subgraph_mapping
    # Original node IDs of the subgraph, in mapping order; the colors below are
    # collected in the same order, so passing this list as `nodelist` keeps them aligned.
    sub_nodes = list(subgraph_mapping.values())
    G_sub = G.subgraph(sub_nodes).copy()
    node_color = [explanation.node_importances[sub_idx] for sub_idx in subgraph_mapping]
    _plot_importance_subgraph(G_sub, sub_nodes, node_color,
                              "Subgrafo Explicativo (com mapeamento)")

elif isinstance(explanation, np.ndarray):
    # The explained node first, followed by its direct neighbours.
    if node_id in G:
        explained_nodes = [node_id] + list(G.neighbors(node_id))
    else:
        explained_nodes = [node_id]

    # Keep only as many nodes as there are values, so every drawn node gets exactly
    # one color and every color one node.
    # NOTE(review): GraphLIME values are reportedly per-feature coefficients, so
    # pairing value i with neighbour i is not a real node importance — revisit.
    num_nodes = min(len(explanation), len(explained_nodes))
    explained_nodes = explained_nodes[:num_nodes]
    G_sub = G.subgraph(explained_nodes).copy()
    node_color = explanation[:num_nodes]
    _plot_importance_subgraph(G_sub, explained_nodes, node_color)

else:
    print("O formato de retorno da explicação não foi reconhecido.")

In [None]:
print(node_importances)

# Mean of the importances across entries (axis=0, so vector-valued importances yield
# a per-dimension mean). An empty list yields NaN without numpy's empty-mean warning.
importance_values = [imp for _, imp in node_importances]
node_importances_mean = np.mean(importance_values, axis=0) if importance_values else float('nan')
print(node_importances_mean)
