In [11]:
import pickle
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr

from synplan.utils.loading import load_reaction_rules

In [12]:
def route_to_node(tree, node_id):
    """Return the chain of node ids leading from the top-most ancestor to ``node_id``.

    The chain is built by repeatedly following ``tree.parents`` upwards from
    ``node_id`` and stops as soon as a falsy id (e.g. ``0`` or ``None``) is
    reached; that falsy id is not part of the result.

    Args:
        tree: object exposing ``parents``, a mapping from node id to parent id.
        node_id: id of the node the route should end at.

    Returns:
        A new list of node ids ordered ancestor-first, ending with ``node_id``.
        Empty when ``node_id`` itself is falsy.
    """
    path = []
    current = node_id
    while current:
        path.append(current)
        current = tree.parents[current]
    # Collected child-first while climbing; flip to ancestor-first order.
    path.reverse()
    return path

In [13]:
# Load the pickled search trees; the file stores a mapping, of which only the
# values are kept, as a list.
# NOTE: pickle can execute arbitrary code on load - only open trusted files.
with open("tree_list_hybrid.pickle", "rb") as f:
    tree_list = list(pickle.load(f).values())

In [14]:
# Load the pickled reaction rules. The statistics cell further down iterates this
# object as (rule, _) pairs and indexes it by rule id, reading the rule from
# position 0 of each entry.
# NOTE: pickle can execute arbitrary code on load - only open trusted files.
with open("training_hybrid/reaction_rules.pickle", "rb") as f:
    reaction_rules = pickle.load(f)

### 1. Extract rule statistics

In [15]:
# `defaultdict` is provided by the notebook's import cell.


def _record_rules(rule_ids, rules, seen_rules, seen_by_label):
    """Register rule ids globally and under the label of the rule they refer to.

    Args:
        rule_ids: iterable of ids usable as an index into ``rules``.
        rules: indexable collection whose entries hold the rule object at
            position 0; the label is read from ``rule.meta["label"]``.
        seen_rules: set, updated in place with every id.
        seen_by_label: ``defaultdict(set)`` keyed by label, updated in place.
    """
    for ri in rule_ids:
        seen_rules.add(ri)
        seen_by_label[rules[ri][0].meta["label"]].add(ri)


def _usage_stats(total, predicted, applied, solved):
    """Build the count / coverage summary for one group of rules.

    Args:
        total: number of rules in the group.
        predicted, applied, solved: sets of rule ids observed in each stage.

    Returns:
        Dict with the ``*_RULES`` counts and ``*_COVERAGE`` ratios
        (``len(stage) / total``). Coverage is 0 when ``total`` is 0, so an
        empty rule set does not raise ``ZeroDivisionError``.
    """

    def coverage(used):
        return len(used) / total if total else 0

    return {
        "TOTAL_RULES": total,
        "PREDICTED_RULES": len(predicted),
        "APPLIED_RULES": len(applied),
        "SOLVED_RULES": len(solved),
        "PREDICTED_COVERAGE": coverage(predicted),
        "APPLIED_COVERAGE": coverage(applied),
        "SOLVED_COVERAGE": coverage(solved),
    }


rule_predicted, rule_applied, rule_solved = set(), set(), set()
label_predicted, label_applied, label_solved = defaultdict(set), defaultdict(set), defaultdict(set)

# --- Count total rules per label ---
label_total = defaultdict(int)
for rule, _ in reaction_rules:
    label_total[rule.meta["label"]] += 1

# --- Traverse MCTS trees ---
for tree in tree_list:
    # 1. Recommended (predicted) rules
    _record_rules(tree.predicted_rules, reaction_rules, rule_predicted, label_predicted)

    # 2. Applied rules (one rule id per entry of tree.nodes_rules)
    _record_rules(tree.nodes_rules.values(), reaction_rules, rule_applied, label_applied)

    # 3. Rules along routes to winning nodes; trees with no winning nodes
    #    (empty or None) contribute nothing here.
    for wn in tree.winning_nodes or ():
        # [1:] drops the top-most node of the route - presumably the root,
        # which has no entry in tree.nodes_rules (TODO confirm).
        solv_route = route_to_node(tree, wn)[1:]
        _record_rules(
            (tree.nodes_rules[i] for i in solv_route),
            reaction_rules,
            rule_solved,
            label_solved,
        )

# --- Global statistics ---
res_dict = {
    "TOTAL_TREE": len(tree_list),
    "SOLVED_TREE": sum(bool(tree.winning_nodes) for tree in tree_list),
    **_usage_stats(len(reaction_rules), rule_predicted, rule_applied, rule_solved),
}

# --- Label-wise statistics ---
label_dict = {}
for label in sorted(set(label_predicted) | set(label_applied) | set(label_solved) | set(label_total)):
    label_dict[label] = _usage_stats(
        label_total[label],
        label_predicted[label],
        label_applied[label],
        label_solved[label],
    )
res_dict["BY_LABEL"] = label_dict

# final
res_dict

{'TOTAL_TREE': 100,
 'SOLVED_TREE': 17,
 'TOTAL_RULES': 30194,
 'PREDICTED_RULES': 12138,
 'APPLIED_RULES': 2146,
 'SOLVED_RULES': 556,
 'PREDICTED_COVERAGE': 0.40200039742995297,
 'APPLIED_COVERAGE': 0.07107372325627608,
 'SOLVED_COVERAGE': 0.018414254487646553,
 'BY_LABEL': {'radical': {'TOTAL_RULES': 5426,
   'PREDICTED_RULES': 712,
   'APPLIED_RULES': 58,
   'SOLVED_RULES': 10,
   'PREDICTED_COVERAGE': 0.1312200516033911,
   'APPLIED_COVERAGE': 0.010689273866568375,
   'SOLVED_COVERAGE': 0.0018429782528566164},
  'uspto': {'TOTAL_RULES': 24768,
   'PREDICTED_RULES': 11426,
   'APPLIED_RULES': 2088,
   'SOLVED_RULES': 546,
   'PREDICTED_COVERAGE': 0.46132105943152457,
   'APPLIED_COVERAGE': 0.08430232558139535,
   'SOLVED_COVERAGE': 0.022044573643410854}}}