In [5]:
import math
import random
import config
from torch_geometric.data import Data
from reward import explanation_reward, similarity, compute_fidelity, similarity_score
from constraint import constraint
from model import GCN_2l, GIN
import torch
import networkx as nx
import matplotlib.pyplot as plt
from MCTS_algo import MCTS
from utils import to_networkx_graph, mutag_dataset, ba2motif_dataset
from subgraph_matching import subgraph_score
from networkx.algorithms.isomorphism import GraphMatcher
from tqdm import tqdm
import torch.nn.functional as F

dataset = mutag_dataset
metric_weights = {'sparse': 1, 'interpret': 1, 'fidelity': 1}
config.metric_weights = metric_weights
fidelity_weights = {'plus': 0.3, 'minus': 0.7}
config.fidelity_weights = fidelity_weights
main_model = GIN(input_dim = dataset[0].x.shape[1], output_dim = 2, multi=True)
main_model.load_state_dict(torch.load('models/GIN_model_MUTAG.pth', map_location=torch.device('cpu'), weights_only=True))

<All keys matched successfully>

In [2]:
config.query_graphs

{'nitro_group': <networkx.classes.graph.Graph at 0x168e538f0>,
 'benzene_ring': <networkx.classes.graph.Graph at 0x16864ec90>,
 'napthalene': <networkx.classes.graph.Graph at 0x168eb00b0>,
 'anthracene': <networkx.classes.graph.Graph at 0x169251eb0>,
 'pyridine': <networkx.classes.graph.Graph at 0x1682cef00>,
 'ethyl': <networkx.classes.graph.Graph at 0x1693ed8b0>,
 'fluoro': <networkx.classes.graph.Graph at 0x1696dbf50>,
 'propyl': <networkx.classes.graph.Graph at 0x168f3e360>,
 'ester_group': <networkx.classes.graph.Graph at 0x1692776b0>,
 'aromatic_oxy': <networkx.classes.graph.Graph at 0x16895e090>,
 'imidazole': <networkx.classes.graph.Graph at 0x168cada60>,
 'amino_benzene': <networkx.classes.graph.Graph at 0x16929b230>,
 'ketone': <networkx.classes.graph.Graph at 0x1697ad280>,
 'cyanide': <networkx.classes.graph.Graph at 0x1697afce0>,
 'iodo': <networkx.classes.graph.Graph at 0x1697ae3f0>,
 'ethene': <networkx.classes.graph.Graph at 0x1697ad700>,
 'chloro': <networkx.classes.gra

In [8]:
net_stability = 0
net_interpret = 0 
net_fidelity = 0        
num_graphs = 40

for k in tqdm(range(num_graphs)):

    config.graph_index = k  # You can change this to analyze different molecules
    graph_index = config.graph_index
    config.alter_graphs = []
    # print(f"Analyzing molecule {graph_index} from MUTAG dataset..")

    # Extract data from the selected graph
    x = dataset[graph_index].x
    edge_index = dataset[graph_index].edge_index
    edge_attr = dataset[graph_index].edge_attr
    edge_list = []

    for i in range(edge_index.size(1)):
        src, dst = edge_index[0, i].item(), edge_index[1, i].item()
        edge_list.append((src, dst))

    # Set edge_attr in config (needed by reward function)
    config.edge_attr = edge_attr
    
    # Initialize and run MCTS
    config.max_edges = 12
    config.allowed = range(len(edge_list))

    mcts = MCTS(main_model, x, edge_list, edge_index, explanation_reward, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

    exec(open("interpret_norm.py").read(), globals())
    
    present_state = set()
    best_subset = set()
    best_reward = [0,0,0,0]

    for _ in range(config.max_edges):
        try:
            result = mcts.search(present_state).state
            present_state = result
            reward = explanation_reward(present_state, metric_weights)
            if(reward[-1] >= best_reward[-1]):
                best_reward = reward
                best_subset = present_state
        except:
            break

    target_edge_list = torch.zeros((2,len(best_subset)), dtype = torch.long)
    last_filled = 0 
    unique_nodes = set()

    for idx,edge in enumerate(edge_list):
        if(idx not in best_subset): continue
        target_edge_list[0][last_filled] = edge[0]
        target_edge_list[1][last_filled] = edge[1]
        unique_nodes.add(edge[0])
        unique_nodes.add(edge[1])
        last_filled+=1

    unique_nodes = sorted(list(unique_nodes))
    mapping = {}
    for idx, node in enumerate(unique_nodes):
        mapping[node] = idx

    for edge in range(target_edge_list.shape[1]):
        target_edge_list[0][edge] = mapping[target_edge_list[0][edge].item()]
        target_edge_list[1][edge] = mapping[target_edge_list[1][edge].item()]

    target_x = config.node_features[list(unique_nodes)]
    target_graph_data = Data(x=target_x, edge_index=target_edge_list)
    # config.alter_graphs.append(target_graph_data)
    config.alter_graphs.append((best_subset,best_reward[-1]))


    # Sample random graphs and get their explanations with the same user metrics preference
    for _ in range(10):

        k = 0.8
        sampled_indices = random.sample(range(len(edge_list)), int(k*len(edge_list)))
        config.allowed = sampled_indices

        present_state = set()
        best_subset = set()
        best_reward = [0,0,0,0]

        mcts = MCTS(main_model, x, edge_list, edge_index, explanation_reward, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

        for _ in range(config.max_edges):
            try:
                result = mcts.search(present_state).state
                present_state = result
                reward = explanation_reward(present_state, metric_weights)
                if(reward[-1] >= best_reward[-1]):
                    best_reward = reward
                    best_subset = present_state
            except:
                break

        config.alter_graphs.append((best_subset,best_reward[-1]))

    # Run MCTS with updated reward function
    config.allowed = range(len(edge_list))
    present_state = set()
    best_subset = set()
    best_reward = [0,0,0,0]

    mcts = MCTS(main_model, x, edge_list, edge_index, similarity_score, metric_weights, 
                constraint, C=10, num_simulations=50, rollout_depth=100)

    for _ in range(config.max_edges):
        try:
            result = mcts.search(present_state).state
            present_state = result
            reward = similarity_score(present_state, metric_weights)
            if(reward[-1] >= best_reward[-1]):
                best_reward = reward
                best_subset = present_state
        except:
            break

    # constraint(best_subset,log=True)
    target_edge_list = torch.zeros((2,len(best_subset)), dtype = torch.long)
    last_filled = 0 
    unique_nodes = set()

    for idx,edge in enumerate(edge_list):
        if(idx not in best_subset): continue
        target_edge_list[0][last_filled] = edge[0]
        target_edge_list[1][last_filled] = edge[1]
        unique_nodes.add(edge[0])
        unique_nodes.add(edge[1])
        last_filled+=1

    unique_nodes = sorted(list(unique_nodes))
    mapping = {}
    for idx, node in enumerate(unique_nodes):
        mapping[node] = idx

    for edge in range(target_edge_list.shape[1]):
        target_edge_list[0][edge] = mapping[target_edge_list[0][edge].item()]
        target_edge_list[1][edge] = mapping[target_edge_list[1][edge].item()]

    target_x = config.node_features[list(unique_nodes)]
    target_graph_data = Data(x=target_x, edge_index=target_edge_list)

    net_interpret += subgraph_score(best_subset)
    stability = 0
    for alter_graph in config.alter_graphs:
        stability += similarity(best_subset, alter_graph[0])
    net_stability += stability
    net_fidelity += compute_fidelity(best_subset, fidelity_weights)

  0%|          | 0/40 [00:00<?, ?it/s]

Graph has 17 nodes and 19 edges


  2%|▎         | 1/40 [02:21<1:31:39, 141.00s/it]

Graph has 13 nodes and 14 edges


In [16]:
print(net_stability/num_graphs)
print(net_fidelity/num_graphs)
print(net_interpret/num_graphs)

5.77143766833143
0.5520170528970352
89.45
