In [6]:
import math
import random
import config
from torch_geometric.data import Data
from reward import explanation_reward, similarity, compute_fidelity
from constraint import constraint
from model import GCN_2l, GIN
import torch
import networkx as nx
import matplotlib.pyplot as plt
from MCTS_algo import MCTS
from utils import to_networkx_graph, mutag_dataset, ba2motif_dataset
from subgraph_matching import subgraph_score
from networkx.algorithms.isomorphism import GraphMatcher
from tqdm import tqdm
import torch.nn.functional as F

dataset = mutag_dataset
metric_weights = {'sparse': 1, 'interpret': 1, 'fidelity': 1}
config.metric_weights = metric_weights
fidelity_weights = {'plus': 0.3, 'minus': 0.7}
config.fidelity_weights = fidelity_weights
main_model = GIN(input_dim = dataset[0].x.shape[1], output_dim = 2, multi=True)
main_model.load_state_dict(torch.load('models/GIN_model_MUTAG.pth', map_location=torch.device('cpu'), weights_only=True))

<All keys matched successfully>

In [12]:
net_stability = 0
net_interpret = 0 
net_fidelity = 0        
num_graphs = 40

for k in tqdm(range(num_graphs)):

    config.graph_index = k  # You can change this to analyze different molecules
    graph_index = config.graph_index
    # print(f"Analyzing molecule {graph_index} from MUTAG dataset..")

    # Extract data from the selected graph
    x = dataset[graph_index].x
    edge_index = dataset[graph_index].edge_index
    edge_attr = dataset[graph_index].edge_attr
    edge_list = []

    for i in range(edge_index.size(1)):
        src, dst = edge_index[0, i].item(), edge_index[1, i].item()
        edge_list.append((src, dst))

    # Set edge_attr in config (needed by reward function)
    config.edge_attr = edge_attr


    # Initialize and run MCTS
    config.max_edges = 12
    config.allowed = range(len(edge_list))

    mcts = MCTS(main_model, x, edge_list, edge_index, explanation_reward, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

    exec(open("interpret_norm.py").read(), globals())
    
    present_state = set()
    best_subset = set()
    best_reward = [0,0,0,0]

    for _ in range(config.max_edges):
        try:
            result = mcts.search(present_state).state
            present_state = result
            reward = explanation_reward(present_state, metric_weights)
            if(reward[-1] >= best_reward[-1]):
                best_reward = reward
                best_subset = present_state
        except:
            break

    target_edge_list = torch.zeros((2,len(best_subset)), dtype = torch.long)
    last_filled = 0 
    unique_nodes = set()

    for idx,edge in enumerate(edge_list):
        if(idx not in best_subset): continue
        target_edge_list[0][last_filled] = edge[0]
        target_edge_list[1][last_filled] = edge[1]
        unique_nodes.add(edge[0])
        unique_nodes.add(edge[1])
        last_filled+=1

    unique_nodes = sorted(list(unique_nodes))
    mapping = {}
    for idx, node in enumerate(unique_nodes):
        mapping[node] = idx

    for edge in range(target_edge_list.shape[1]):
        target_edge_list[0][edge] = mapping[target_edge_list[0][edge].item()]
        target_edge_list[1][edge] = mapping[target_edge_list[1][edge].item()]

    target_x = config.node_features[list(unique_nodes)]
    target_graph_data = Data(x=target_x, edge_index=target_edge_list, edge_attr=config.edge_attr[list(best_subset)])
    # config.alter_graphs.append(target_graph_data)
    config.alter_graphs.append((best_subset,best_reward[-1]))


    # Sample random graphs and get their explanations with the same user metrics preference
    for _ in range(10):

        k = 0.8
        sampled_indices = random.sample(range(len(edge_list)), int(k*len(edge_list)))
        config.allowed = sampled_indices

        present_state = set()
        best_subset = set()
        best_reward = [0,0,0,0]

        mcts = MCTS(main_model, x, edge_list, edge_index, explanation_reward, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

        for _ in range(config.max_edges):
            try:
                result = mcts.search(present_state).state
                present_state = result
                reward = explanation_reward(present_state, metric_weights)
                if(reward[-1] >= best_reward[-1]):
                    best_reward = reward
                    best_subset = present_state
            except:
                break

        config.alter_graphs.append((best_subset,best_reward[-1]))

    # Run MCTS with updated reward function
    config.allowed = range(len(edge_list))
    present_state = set()
    best_subset = set()
    best_reward = [0,0,0,0]

    mcts = MCTS(main_model, x, edge_list, edge_index, similarity_score, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

    for _ in tqdm(range(config.max_edges)):
        try:
            result = mcts.search(present_state).state
            present_state = result
            reward = explanation_reward(present_state, metric_weights)
            if(reward[-1] >= best_reward[-1]):
                best_reward = reward
                best_subset = present_state
        except:
            break

    # constraint(best_subset,log=True)
    target_edge_list = torch.zeros((2,len(best_subset)), dtype = torch.long)
    last_filled = 0 
    unique_nodes = set()

    for idx,edge in enumerate(edge_list):
        if(idx not in best_subset): continue
        target_edge_list[0][last_filled] = edge[0]
        target_edge_list[1][last_filled] = edge[1]
        unique_nodes.add(edge[0])
        unique_nodes.add(edge[1])
        last_filled+=1

    unique_nodes = sorted(list(unique_nodes))
    mapping = {}
    for idx, node in enumerate(unique_nodes):
        mapping[node] = idx

    for edge in range(target_edge_list.shape[1]):
        target_edge_list[0][edge] = mapping[target_edge_list[0][edge].item()]
        target_edge_list[1][edge] = mapping[target_edge_list[1][edge].item()]

    target_x = config.node_features[list(unique_nodes)]
    target_graph_data = Data(x=target_x, edge_index=target_edge_list, edge_attr=config.edge_attr[list(best_subset)])

    net_interpret += subgraph_score(best_subset)
    stability = 0
    for alter_graph in config.alter_graphs:
        stability += similarity(best_subset, alter_graph[0])
    net_stability += stability
    net_fidelity += compute_fidelity(best_subset, fidelity_weights)

  0%|          | 0/40 [00:00<?, ?it/s]

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 41.53it/s]
  2%|▎         | 1/40 [02:20<1:31:02, 140.06s/it]

Graph has 13 nodes and 14 edges


100%|██████████| 12/12 [00:00<00:00, 50.79it/s]
  5%|▌         | 2/40 [03:50<1:10:22, 111.12s/it]

Graph has 13 nodes and 14 edges


100%|██████████| 12/12 [00:00<00:00, 49.88it/s]
  8%|▊         | 3/40 [05:08<59:01, 95.71s/it]   

Graph has 19 nodes and 22 edges


100%|██████████| 12/12 [00:00<00:00, 33.60it/s]
 10%|█         | 4/40 [07:28<1:08:02, 113.40s/it]

Graph has 11 nodes and 11 edges


 92%|█████████▏| 11/12 [00:00<00:00, 55.12it/s]
 12%|█▎        | 5/40 [08:03<49:35, 85.02s/it]   

Graph has 28 nodes and 31 edges


100%|██████████| 12/12 [00:00<00:00, 24.29it/s]
 15%|█▌        | 6/40 [10:53<1:04:35, 113.99s/it]

Graph has 16 nodes and 17 edges


100%|██████████| 12/12 [00:00<00:00, 31.67it/s]
 18%|█▊        | 7/40 [12:26<58:49, 106.97s/it]  

Graph has 20 nodes and 22 edges


100%|██████████| 12/12 [00:00<00:00, 29.32it/s]
 20%|██        | 8/40 [14:37<1:01:12, 114.78s/it]

Graph has 12 nodes and 13 edges


100%|██████████| 12/12 [00:00<00:00, 43.46it/s]
 22%|██▎       | 9/40 [15:24<48:19, 93.54s/it]   

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 27.88it/s]
 25%|██▌       | 10/40 [17:37<52:49, 105.64s/it]

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 26.82it/s]
 28%|██▊       | 11/40 [19:42<53:58, 111.69s/it]

Graph has 20 nodes and 23 edges


100%|██████████| 12/12 [00:00<00:00, 21.69it/s]
 30%|███       | 12/40 [21:51<54:35, 116.98s/it]

Graph has 22 nodes and 25 edges


100%|██████████| 12/12 [00:00<00:00, 21.53it/s]
 32%|███▎      | 13/40 [25:14<1:04:22, 143.05s/it]

Graph has 13 nodes and 14 edges


100%|██████████| 12/12 [00:00<00:00, 28.29it/s]
 35%|███▌      | 14/40 [26:22<52:07, 120.28s/it]  

Graph has 19 nodes and 22 edges


100%|██████████| 12/12 [00:00<00:00, 21.41it/s]
 38%|███▊      | 15/40 [28:43<52:40, 126.41s/it]

Graph has 22 nodes and 25 edges


100%|██████████| 12/12 [00:00<00:00, 20.34it/s]
 40%|████      | 16/40 [31:38<56:27, 141.16s/it]

Graph has 11 nodes and 11 edges


 92%|█████████▏| 11/12 [00:00<00:00, 31.25it/s]
 42%|████▎     | 17/40 [32:08<41:17, 107.73s/it]

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 20.82it/s]
 45%|████▌     | 18/40 [33:51<38:58, 106.31s/it]

Graph has 13 nodes and 14 edges


100%|██████████| 12/12 [00:00<00:00, 25.99it/s]
 48%|████▊     | 19/40 [35:07<34:00, 97.17s/it] 

Graph has 18 nodes and 20 edges


100%|██████████| 12/12 [00:00<00:00, 18.89it/s]
 50%|█████     | 20/40 [38:10<40:58, 122.91s/it]

Graph has 18 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 20.21it/s]
 52%|█████▎    | 21/40 [39:52<36:56, 116.63s/it]

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 19.02it/s]
 55%|█████▌    | 22/40 [48:59<1:13:47, 245.97s/it]

Graph has 23 nodes and 27 edges


100%|██████████| 12/12 [00:00<00:00, 17.82it/s]
 57%|█████▊    | 23/40 [51:19<1:00:40, 214.15s/it]

Graph has 27 nodes and 33 edges


100%|██████████| 12/12 [00:00<00:00, 16.01it/s]
 60%|██████    | 24/40 [54:29<55:10, 206.89s/it]  

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 16.77it/s]
 62%|██████▎   | 25/40 [56:46<46:27, 185.84s/it]

Graph has 13 nodes and 13 edges


100%|██████████| 12/12 [00:00<00:00, 20.41it/s]
 65%|██████▌   | 26/40 [57:41<34:13, 146.67s/it]

Graph has 23 nodes and 27 edges


100%|██████████| 12/12 [00:00<00:00, 16.75it/s]
 68%|██████▊   | 27/40 [1:00:45<34:11, 157.81s/it]

Graph has 17 nodes and 19 edges


100%|██████████| 12/12 [00:00<00:00, 16.52it/s]
 70%|███████   | 28/40 [1:02:49<29:30, 147.57s/it]

Graph has 23 nodes and 25 edges


100%|██████████| 12/12 [00:00<00:00, 15.01it/s]
 72%|███████▎  | 29/40 [1:05:41<28:23, 154.90s/it]

Graph has 23 nodes and 27 edges


100%|██████████| 12/12 [00:00<00:00, 14.30it/s]
 75%|███████▌  | 30/40 [1:08:36<26:49, 160.94s/it]

Graph has 22 nodes and 25 edges


100%|██████████| 12/12 [00:00<00:00, 13.52it/s]
 78%|███████▊  | 31/40 [1:11:21<24:19, 162.22s/it]

Graph has 24 nodes and 25 edges


100%|██████████| 12/12 [00:00<00:00, 14.23it/s]
 80%|████████  | 32/40 [1:13:26<20:08, 151.12s/it]

Graph has 23 nodes and 25 edges


In [10]:
print(net_stability/num_graphs)
print(net_fidelity/num_graphs)
print(net_interpret/num_graphs)

14.220652876535228
0.5320530414103282
77.0
