In [1]:
import tkinter as tk
from tkinter import messagebox
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from networkx.drawing.nx_agraph import graphviz_layout  # Ensure Graphviz is installed
import json


In [2]:
# Build q2gold: question text -> gold answer span.
# The key is the first key of the question's entry in the decomposition file;
# update_tree() looks gold answers up by the root node's question_text.
with open('./hotpotqa__v2_dev_random_100.jsonl') as raw_file:
    # Skip blank lines (e.g. a trailing newline), which json.loads would reject.
    raw_data = [json.loads(line) for line in raw_file if line.strip()]
with open("./question_decompositions-devset.json") as decomp_file:
    q2dq = json.load(decomp_file)

q2gold = {}
skipped_items = []  # raw records that could not be matched; inspect these when gold answers are missing
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
    except (KeyError, IndexError, TypeError, AttributeError):
        # The question is missing from question_decompositions (something went wrong earlier,
        # possibly JSON parsing in a previous tree-generation step), or the record has no
        # answer span. Skip it, but keep it so the gap is visible instead of silently lost.
        skipped_items.append(item)
        continue
    q2gold[question] = gold

In [3]:
# Load the dataset: a list of trees, each a list of node dicts (see update_tree for the fields used)
with open('test-devset.json') as f:
    trees = json.load(f)

# Fail fast with a clear message: the tree dropdown and the initial
# update_tree(current_tree_index) call further down both need at least one tree.
if not trees:
    raise ValueError("test-devset.json contains no trees")

# Initialize global variables
selected_node = None  # "idx" of the node currently highlighted in the graph, or None
tree_data = None  # node list of the tree currently on screen
current_tree_index = 0  # Default tree index

# Create the main GUI window
root = tk.Tk()
root.title("Interactive Tree Solver")

# Label to display gold answer for root node
gold_answer_label = tk.Label(root, text="Gold Answer (Root): ", font=("Arial", 16, "bold"), fg="green")
gold_answer_label.pack(side=tk.TOP, pady=5)

def update_tree(index):
    """Load ``trees[index]`` into the global graph ``G`` and redraw it.

    The root is the node dict without a "fa" key. Each node's "idx" becomes a
    graph node and each entry of its "sons" list an edge. The graph is laid
    out with Graphviz ``dot`` and drawn on ``ax`` with pickable node circles.
    The selection/result labels are reset, and the gold answer for the root
    question is looked up in ``q2gold``.

    If the tree has no root node, an error dialog is shown. The tree currently
    on screen, and all related global state, stay unchanged.

    Args:
        index: position of the tree in ``trees``.
    """
    global tree_data, G, pos, node_circles, node_colors, selected_node, current_tree_index

    new_tree = trees[index]

    # Identify the root (node without "fa" key) BEFORE touching any state.
    # Clearing first would leave a blank canvas, a stale selection and
    # tree_data pointing at the broken tree when the root is missing.
    root_node = next((node for node in new_tree if "fa" not in node), None)
    if root_node is None:
        messagebox.showerror("Error", "Root node not found in this tree!")
        # Point the dropdown back at the tree that is still displayed.
        tree_var.set(f"Tree {current_tree_index}")
        return

    tree_data = new_tree
    current_tree_index = index

    # Clear previous graph
    ax.clear()
    G.clear()

    # Build new graph: one node per "idx", one edge per parent -> son link
    for node in tree_data:
        G.add_node(node["idx"], question=node["question_text"])
        for son in node["sons"]:
            G.add_edge(node["idx"], son)

    # Update layout and colors; node_colors follows the order of G.nodes,
    # which on_node_select relies on when mapping event.ind back to a node
    pos = graphviz_layout(G, prog="dot")
    node_colors = ["skyblue"] * len(G.nodes)

    # Redraw nodes, edges, and labels
    node_circles = nx.draw_networkx_nodes(G, pos, node_color=node_colors, ax=ax)
    nx.draw_networkx_edges(G, pos, ax=ax)
    nx.draw_networkx_labels(G, pos, labels={node["idx"]: node["question_text"] for node in tree_data}, ax=ax)

    # Enable selection on nodes
    node_circles.set_picker(True)
    selected_node = None  # Reset selection
    selected_node_label.config(text="Selected Node: None")
    result_label.config(text="Result: ")

    # Update gold answer for the identified root node
    root_question = root_node["question_text"].strip()
    gold_answer = q2gold.get(root_question, "No gold answer available")
    gold_answer_label.config(text=f"Gold Answer (Root): {gold_answer}")

    # Redraw canvas
    canvas.draw()



# Function to handle node selection
def on_node_select(event):
    """Matplotlib ``pick_event`` handler that highlights the clicked node.

    Resets the previously selected node to "skyblue" and paints the picked
    one "red". Stores its id in the global ``selected_node`` and shows its
    question text in ``selected_node_label``.

    Args:
        event: the pick event; only picks on ``node_circles`` are handled.
    """
    global selected_node, node_colors

    # Ignore picks on anything other than the node collection drawn by update_tree.
    if event.artist != node_circles:
        return

    # node_colors and the drawn circles share this order (see update_tree)
    node_order = list(G.nodes)
    selected_index = event.ind[0]
    new_selected_node = node_order[selected_index]

    # Reset previous node color (only if it still belongs to the drawn graph)
    if selected_node is not None and selected_node in node_order:
        node_colors[node_order.index(selected_node)] = "skyblue"

    # Highlight newly selected node
    selected_node = new_selected_node
    node_colors[selected_index] = "red"

    # Apply changes
    node_circles.set_facecolor(node_colors)
    canvas.draw()

    # Look the node up by its "idx" field (as solve_node does), not by list position:
    # ids need not match positions, and nodes that appear only in a "sons" list have no entry.
    node_data = next((node for node in tree_data if node["idx"] == selected_node), None)
    question = node_data["question_text"] if node_data is not None else "<no question text>"
    selected_node_label.config(text=f"Selected Node: {selected_node} : {question}")

# Function to solve a node using CB, OB, or Child method
def solve_node(method):
    """Display the stored answer of the selected node for *method*.

    Args:
        method: "CB", "OB", "Child" or "Return best answer". These map to the
            node fields "cb_answer", "ob_answer", "child_answer" and "answer";
            the first element of that field is shown in ``result_label``.

    A warning dialog is shown when no node is selected. Unknown methods,
    nodes without an entry in ``tree_data`` and missing/empty answer fields
    produce an explanatory message instead of an exception.
    """
    if selected_node is None:
        messagebox.showwarning("No Node Selected", "Please select a node first.")
        return

    answer_keys = {
        "CB": "cb_answer",
        "OB": "ob_answer",
        "Child": "child_answer",
        "Return best answer": "answer",
    }

    # Default None: nodes referenced only via a "sons" list have no entry of their own.
    node_data = next((node for node in tree_data if node["idx"] == selected_node), None)

    if method not in answer_keys:
        result = "Invalid method."
    elif node_data is None:
        result = "No data for the selected node."
    elif method == "Child" and not node_data.get("sons"):
        result = "No child nodes available."
    else:
        answers = node_data.get(answer_keys[method])
        result = answers[0] if answers else "No answer available."

    result_label.config(text=f"Result ({method}): {result}")

# Create a matplotlib figure for visualization
fig, ax = plt.subplots(figsize=(8, 6))
G = nx.DiGraph()

# Embed matplotlib figure in tkinter window; clicks on nodes arrive as pick events
canvas = FigureCanvasTkAgg(fig, master=root)
canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=1)
canvas.mpl_connect("pick_event", on_node_select)

# Dropdown menu for tree selection
tree_selection_frame = tk.Frame(root)
tree_selection_frame.pack(side=tk.TOP, pady=5)

tk.Label(tree_selection_frame, text="Select Tree:").pack(side=tk.LEFT)
tree_var = tk.StringVar(root)
# Default selection follows current_tree_index, the tree drawn at startup
tree_var.set(f"Tree {current_tree_index}")
tree_dropdown = tk.OptionMenu(tree_selection_frame, tree_var, *[f"Tree {i}" for i in range(len(trees))],
                              command=lambda choice: update_tree(int(choice.split()[1])))
tree_dropdown.pack(side=tk.LEFT)

# Button frame
button_frame = tk.Frame(root)
button_frame.pack(side=tk.BOTTOM, pady=10)

# Create buttons for solving methods: (button text, method passed to solve_node)
SOLVE_BUTTONS = (
    ("Solve with CB", "CB"),
    ("Solve with OB", "OB"),
    ("Solve with Child", "Child"),
    ("Return best answer", "Return best answer"),
)
for button_text, solve_method in SOLVE_BUTTONS:
    # Bind the method as a default argument; a plain closure would make every
    # button use the last method of the loop (late binding).
    tk.Button(button_frame, text=button_text,
              command=lambda m=solve_method: solve_node(m)).pack(side=tk.LEFT, padx=5)

# Label to display the selected node
selected_node_label = tk.Label(root, text="Selected Node: None", font=("Arial", 12))
selected_node_label.pack(side=tk.BOTTOM, pady=5)

# Label to display the result
result_label = tk.Label(root, text="Result: ", font=("Arial", 12))
result_label.pack(side=tk.BOTTOM, pady=5)

# Draw the default tree so the window opens with a graph already on screen
update_tree(index=current_tree_index)

# Function to handle safe exit
def on_close():
    """Ask for confirmation, then free the figure and destroy the window."""
    confirmed = messagebox.askokcancel("Quit", "Do you really want to exit?")
    if not confirmed:
        return
    plt.close(fig)  # release the matplotlib figure's resources
    root.destroy()  # tear down the tkinter window

# Route the window-manager close button through the confirmation dialog
root.wm_protocol("WM_DELETE_WINDOW", on_close)

# Start the Tk event loop; this call does not return while the window is open
root.mainloop()


2025-02-25 15:31:42.390 python[40560:3075165] +[IMKClient subclass]: chose IMKClient_Modern
2025-02-25 15:31:42.390 python[40560:3075165] +[IMKInputSession subclass]: chose IMKInputSession_Modern
