# Prioritize notebook

# Loading Modules

In [None]:
from importlib import reload
import question_answering as question_answering
question_answering = reload(question_answering)
from question_answering import aggregate_multihop_answer as aggregate_multihop_answer_hotpotqa, get_cb_answer as get_cb_answer_hotpotqa, get_singlehop_ob_answer as get_singlehop_ob_answer_hotpotqa, get_multihop_ob_answer as get_multihop_ob_answer_hotpotqa

import question_answering_2wiki as question_answering_2wiki
question_answering_2wiki = reload(question_answering_2wiki)
from question_answering_2wiki import aggregate_multihop_answer as aggregate_multihop_answer_2wiki, get_cb_answer as get_cb_answer_2wiki, get_singlehop_ob_answer as get_singlehop_ob_answer_2wiki, get_multihop_ob_answer as get_multihop_ob_answer_2wiki

import question_answering_musique as question_answering_musique
question_answering_musique = reload(question_answering_musique)
from question_answering_musique import aggregate_multihop_answer as aggregate_multihop_answer_musique, get_cb_answer as get_cb_answer_musique, get_singlehop_ob_answer as get_singlehop_ob_answer_musique, get_multihop_ob_answer as get_multihop_ob_answer_musique

from question_answering import togetherai_caller

In [None]:
def get_cb_answer(question, dataset):
    """
    Route a closed-book answer request to the implementation for `dataset`.

    Args:
        question: Passed through unchanged to the dataset-specific function.
        dataset (str): One of 'hotpotqa', '2wiki', or 'musique'.

    Returns:
        Whatever the selected dataset-specific `get_cb_answer` returns.

    Raises:
        ValueError: If `dataset` is not one of the supported names.
    """
    # Reject unknown names first, before any dataset-specific function is looked up.
    if dataset not in ("hotpotqa", "2wiki", "musique"):
        raise ValueError("Invalid dataset name. Choose from 'hotpotqa', '2wiki', or 'musique'.")

    handlers = {
        "hotpotqa": get_cb_answer_hotpotqa,
        "2wiki": get_cb_answer_2wiki,
        "musique": get_cb_answer_musique,
    }
    return handlers[dataset](question)
    
def get_singlehop_ob_answer(question, topic_entities, dataset):
    """
    Route a single-hop open-book answer request to the implementation for `dataset`.

    Args:
        question: Passed through unchanged to the dataset-specific function.
        topic_entities: Passed through unchanged to the dataset-specific function.
        dataset (str): One of 'hotpotqa', '2wiki', or 'musique'.

    Returns:
        Whatever the selected dataset-specific `get_singlehop_ob_answer` returns.

    Raises:
        ValueError: If `dataset` is not one of the supported names.
    """
    # Reject unknown names first, before any dataset-specific function is looked up.
    if dataset not in ("hotpotqa", "2wiki", "musique"):
        raise ValueError("Invalid dataset name. Choose from 'hotpotqa', '2wiki', or 'musique'.")

    handlers = {
        "hotpotqa": get_singlehop_ob_answer_hotpotqa,
        "2wiki": get_singlehop_ob_answer_2wiki,
        "musique": get_singlehop_ob_answer_musique,
    }
    return handlers[dataset](question, topic_entities)
    
def get_multihop_ob_answer(question, tree, dataset):
    """
    Route a multi-hop open-book answer request to the implementation for `dataset`.

    Args:
        question: Passed through unchanged to the dataset-specific function.
        tree: Passed through unchanged to the dataset-specific function.
        dataset (str): One of 'hotpotqa', '2wiki', or 'musique'.

    Returns:
        Whatever the selected dataset-specific `get_multihop_ob_answer` returns.

    Raises:
        ValueError: If `dataset` is not one of the supported names.
    """
    # Reject unknown names first, before any dataset-specific function is looked up.
    if dataset not in ("hotpotqa", "2wiki", "musique"):
        raise ValueError("Invalid dataset name. Choose from 'hotpotqa', '2wiki', or 'musique'.")

    handlers = {
        "hotpotqa": get_multihop_ob_answer_hotpotqa,
        "2wiki": get_multihop_ob_answer_2wiki,
        "musique": get_multihop_ob_answer_musique,
    }
    return handlers[dataset](question, tree)
    
def aggregate_multihop_answer(question, tree, dataset):
    """
    Route a multi-hop answer aggregation request to the implementation for `dataset`.

    Args:
        question: Passed through unchanged to the dataset-specific function.
        tree: Passed through unchanged to the dataset-specific function.
        dataset (str): One of 'hotpotqa', '2wiki', or 'musique'.

    Returns:
        Whatever the selected dataset-specific `aggregate_multihop_answer` returns.

    Raises:
        ValueError: If `dataset` is not one of the supported names.
    """
    # Reject unknown names first, before any dataset-specific function is looked up.
    if dataset not in ("hotpotqa", "2wiki", "musique"):
        raise ValueError("Invalid dataset name. Choose from 'hotpotqa', '2wiki', or 'musique'.")

    handlers = {
        "hotpotqa": aggregate_multihop_answer_hotpotqa,
        "2wiki": aggregate_multihop_answer_2wiki,
        "musique": aggregate_multihop_answer_musique,
    }
    return handlers[dataset](question, tree)
    


In [None]:
def are_answers_equivalent_using_llm(gold, candidate):
    """
    Ask an LLM whether two answers are semantically equivalent.

    The verdict is the first whole-word "yes" or "no" in the model's reply.
    Whole-word matching is needed because a plain substring test finds "yes"
    inside words such as "yesterday" and "no" inside "not" or "know".

    Args:
        gold (str): The gold answer.
        candidate (str): The candidate answer.

    Returns:
        bool: True if the verdict is "yes". False if the verdict is "no", if
        either answer is None, if the reply holds no recognizable verdict, or
        if the API call fails (the error is printed in the last two cases).
    """
    # Local import: this function is defined before the notebook's own `import re`.
    import re

    # A missing answer can never match; do not send the literal "None" to the model.
    if gold is None or candidate is None:
        return False

    # Few-shot prompt
    # NOTE(review): some "Yes" examples are lenient (e.g. "The Eiffel Tower" vs
    # "a landmark in Paris"); confirm this leniency is intended for scoring.
    prompt = f"""
    I want you to compare two answers and determine if they are equivalent.
    You should respond with "Yes" if the two answers have the same meaning, even if they are written differently. 
    Respond with "No" if they do not have the same meaning.

    Here are some examples:

    Gold Answer: "two"
    Candidate Answer: "2"
    Are these answers equivalent? Yes

    Gold Answer: "Paris"
    Candidate Answer: "the capital city of France"
    Are these answers equivalent? Yes

    Gold Answer: "four hundred years"
    Candidate Answer: "400 years"
    Are these answers equivalent? Yes

    Gold Answer: "climate change"
    Candidate Answer: "global warming"
    Are these answers equivalent? Yes

    Gold Answer: "The Eiffel Tower"
    Candidate Answer: "a landmark in Paris"
    Are these answers equivalent? Yes

    Gold Answer: "3 meters"
    Candidate Answer: "300 centimeters"
    Are these answers equivalent? Yes

    Gold Answer: "apple"
    Candidate Answer: "orange"
    Are these answers equivalent? No

    Gold Answer: "2023"
    Candidate Answer: "2022"
    Are these answers equivalent? No

    Gold Answer: "Canada"
    Candidate Answer: "United States"
    Are these answers equivalent? No

    Now, here is the pair for you to evaluate:

    Gold Answer: "{gold}"
    Candidate Answer: "{candidate}"
    Are these answers equivalent?
    """

    # Make API call
    try:
        # The second element of the returned pair is not used here.
        response, _ = togetherai_caller.req2provider(prompt=prompt, max_tokens=None, stop= None, use_cache=True)
        response = response[0]
        answer = response['message']['content'].strip().lower()

        # First whole-word verdict wins; later words in an explanation are ignored.
        verdict = re.search(r"\b(yes|no)\b", answer)
        if verdict is None:
            raise ValueError(f"Unexpected LLM response: {answer}")
        return verdict.group(1) == "yes"

    except Exception as e:
        # Best-effort judge: any failure is reported and scored as "not equivalent".
        print(f"Error during LLM API call: {e}")
        return False


In [None]:
# Quick manual check of the LLM judge; this goes through togetherai_caller.req2provider
# (with use_cache=True). Presumably expected to evaluate to True — verify.
are_answers_equivalent_using_llm("Two", "2")

In [None]:
import re
import string
import copy

def normalize_answer(s):
    """
    Normalize an answer string for exact-match comparison.

    Steps, in order: lower-case, delete every character in
    `string.punctuation`, replace the whole words a/an/the with a space,
    then collapse all whitespace runs to single spaces and trim the ends.

    Args:
        s (str): The answer text.

    Returns:
        str: The normalized text.
    """
    lowered = s.lower()
    # One pass that deletes all ASCII punctuation characters.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())

# Some Analysis

In [None]:
import json

# Load the raw data and question decompositions.
# Each file is read inside a `with` block so its handle is closed; blank lines are skipped.
with open('./hotpotqa__v2_dev_random_100.jsonl') as file:
    raw_data = [json.loads(line.strip()) for line in file if line.strip()]
# NOTE(review): hotpotqa items keep whatever 'dataset' value the file provides;
# items without one are skipped when q2gold is built below.
with open('./2wikimultihopqa__v2_dev_random_100.jsonl') as file:
    raw_data_2wiki = [json.loads(line.strip()) for line in file if line.strip()]
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
with open('./musique_ans__v2_dev_random_100.jsonl') as file:
    raw_data_musique = [json.loads(line.strip()) for line in file if line.strip()]
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

with open("./question_decompositions-devset-hotpotqa.json") as file:
    q2dq = json.load(file)
with open("./question_decompositions-devset-2wiki.json") as file:
    q2dq_2wiki = json.load(file)
with open("./question_decompositions-devset-musique.json") as file:
    q2dq_musique = json.load(file)
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: first key of the question's decomposition -> (gold answer, dataset)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
        q2gold[question] = (gold, q_type)
    except Exception as e:
        # Skip if question not found in question_decompositions (or the item is malformed)
        print(f"Error processing item: {item}, Error: {e}")
        continue

# Load the data to analyze
with open('results-devset-hotpotqa.json', 'r') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)

# Initialize counters
correct_answers = 0
total_parent_nodes = 0
# Initialize lists to store logprobs (not filled in this cell)
correct_logprobs = []
incorrect_logprobs = []

# Prepare the dataset (not filled in this cell)
X = []  # Features (logprobs)
y = []  # Labels (1 for correct, 0 for incorrect)

# Loop through each example in the data
for example in data:
    for node in example:
        # Only parent nodes (no "fa" entry) are scored
        if "fa" in node:
            continue
        total_parent_nodes += 1
        question_text = node.get("question_text", "").strip()
        # First element is taken as the best answer; a missing or empty list yields None
        answer = (node.get("answer") or [None])[0]
        # Get the gold answer from q2gold
        if question_text not in q2gold:
            continue
        gold_answer, _ = q2gold[question_text]
        print(f"Question: {question_text}, Gold: {gold_answer}, answer: {answer}")
        # A missing answer counts as wrong without asking the LLM judge
        if answer is not None and are_answers_equivalent_using_llm(gold_answer, answer):
            correct_answers += 1


# Calculate the accuracy.
# NOTE(review): the denominator counts every parent node, including those with
# no gold answer in q2gold — confirm that is the intended rate.
accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0

# Print the results
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times probtree answer matches gold answer: {correct_answers}")
print(f"correct answers match rate for probtree answers: {accuracy:.2f}%")

# Closed-book answers

In [None]:
import json

# Load the raw data and question decompositions.
# Each file is read inside a `with` block so its handle is closed; blank lines are skipped.
with open('./hotpotqa__v2_dev_random_100.jsonl') as file:
    raw_data = [json.loads(line.strip()) for line in file if line.strip()]
with open('./2wikimultihopqa__v2_dev_random_100.jsonl') as file:
    raw_data_2wiki = [json.loads(line.strip()) for line in file if line.strip()]
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
with open('./musique_ans__v2_dev_random_100.jsonl') as file:
    raw_data_musique = [json.loads(line.strip()) for line in file if line.strip()]
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

with open("./question_decompositions-devset-hotpotqa.json") as file:
    q2dq = json.load(file)
with open("./question_decompositions-devset-2wiki.json") as file:
    q2dq_2wiki = json.load(file)
with open("./question_decompositions-devset-musique.json") as file:
    q2dq_musique = json.load(file)
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: first key of the question's decomposition -> (gold answer, dataset)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
        q2gold[question] = (gold, q_type)
    except (KeyError, IndexError, TypeError, AttributeError):
        # Skip items whose question has no decomposition or whose fields are
        # missing; only these lookup failures are treated as skippable.
        continue

# Load the data to analyze
with open('results-devset-hotpotqa.json', 'r') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)

# Initialize counters
correct_cb_answers = 0
total_parent_nodes = 0
# Initialize lists to store logprobs
correct_logprobs = []
incorrect_logprobs = []
# Initialize lists to store lengths
correct_lengths = []
incorrect_lengths = []

# Prepare the dataset
X = []  # Features (logprob sequences)
y = []  # Labels (1 for correct, 0 for incorrect)

# Loop through each example in the data
for example in data:
    for node in example:
        # Only parent nodes (no "fa" entry) are scored
        if "fa" in node:
            continue
        total_parent_nodes += 1
        question_text = node.get("question_text", "").strip()

        # cb_answer is read at positions 0, 1 and 3; skip nodes where it is
        # missing or too short instead of failing with IndexError.
        cb_answer = node.get("cb_answer")
        if not cb_answer or len(cb_answer) < 4:
            continue
        cb_answer_text = cb_answer[0]
        cb_logprob = cb_answer[1]   # single logprob value
        cb_logprobs = cb_answer[3]  # whole logprob sequence

        # Get the gold answer from q2gold
        if question_text not in q2gold:
            continue
        gold_answer, _ = q2gold[question_text]

        # Compare cb_answer_text with the gold answer using the LLM judge
        is_correct = are_answers_equivalent_using_llm(gold_answer, cb_answer_text)
        print(f"CB Answer: {cb_answer_text}, Gold Answer: {gold_answer} : {'Correct' if is_correct else 'Incorrect'}")

        # Only answers with logprob above -10 are kept as samples.
        # NOTE(review): correct_cb_answers is also gated by this threshold, unlike
        # the child/open-book cells which count every correct answer — confirm.
        if not (cb_logprob > -10):
            continue
        X.append(cb_logprobs)  # Feature
        if is_correct:
            correct_cb_answers += 1
            correct_logprobs.append(cb_logprob)
            correct_lengths.append(len(cb_logprobs))
            y.append(1)  # Label (correct)
        else:
            incorrect_logprobs.append(cb_logprob)
            incorrect_lengths.append(len(cb_logprobs))
            y.append(0)  # Label (incorrect)

# Calculate the accuracy
accuracy = (correct_cb_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0

# Print the results
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times cb_answer matches gold answer: {correct_cb_answers}")
print(f"Closed-book match rate: {accuracy:.2f}%")

In [None]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.preprocessing import StandardScaler
import numpy as np

# Bring every logprob sequence in X to the length of the longest one;
# shorter sequences are filled at the end with -100.
max_length = max(map(len, X))
X_cb_padded = pad_sequences(
    X,
    maxlen=max_length,
    padding="post",
    truncating="post",
    value=-100,
    dtype="float32",
)

In [None]:
import numpy as np

# Mean / standard deviation of the collected logprobs; an empty list reports 0 for both
if correct_logprobs:
    correct_mean = np.mean(correct_logprobs)
    correct_std = np.std(correct_logprobs)
else:
    correct_mean = correct_std = 0

if incorrect_logprobs:
    incorrect_mean = np.mean(incorrect_logprobs)
    incorrect_std = np.std(incorrect_logprobs)
else:
    incorrect_mean = incorrect_std = 0

# Report both groups in the same layout
for _heading, _values, _mean, _std in (
    ("Correct Cases:", correct_logprobs, correct_mean, correct_std),
    ("\nIncorrect Cases:", incorrect_logprobs, incorrect_mean, incorrect_std),
):
    print(_heading)
    print(f"  Number of cases: {len(_values)}")
    print(f"  Mean logprobs: {_mean:.4f}")
    print(f"  Standard deviation of logprobs: {_std:.4f}")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Summary statistics and plots for the logprobs collected by the closed-book cell.
# `np` is not imported in this cell; it relies on an earlier cell's import.

# Calculate statistics for correct cases (0 when the list is empty)
correct_mean = np.mean(correct_logprobs) if correct_logprobs else 0
correct_std = np.std(correct_logprobs) if correct_logprobs else 0

# Calculate statistics for incorrect cases (0 when the list is empty)
incorrect_mean = np.mean(incorrect_logprobs) if incorrect_logprobs else 0
incorrect_std = np.std(incorrect_logprobs) if incorrect_logprobs else 0

# Print the results
print("Correct Cases:")
print(f"  Number of cases: {len(correct_logprobs)}")
print(f"  Mean logprobs: {correct_mean:.4f}")
print(f"  Standard deviation of logprobs: {correct_std:.4f}")

print("\nIncorrect Cases:")
print(f"  Number of cases: {len(incorrect_logprobs)}")
print(f"  Mean logprobs: {incorrect_mean:.4f}")
print(f"  Standard deviation of logprobs: {incorrect_std:.4f}")

# Visualizations: one figure with two side-by-side panels
plt.figure(figsize=(12, 6))

# Left panel: overlaid histograms (with KDE) of logprobs for correct and incorrect cases
plt.subplot(1, 2, 1)
sns.histplot(correct_logprobs, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_logprobs, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Logprobs")
plt.xlabel("Logprobs")
plt.ylabel("Frequency")
plt.legend()

# Right panel: boxplot of logprobs for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_logprobs, incorrect_logprobs], palette=["green", "red"])
# Box 0 is the correct list, box 1 the incorrect list (order of `data` above)
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Logprobs")
plt.xlabel("Case Type")
plt.ylabel("Logprobs")

plt.tight_layout()
plt.show()


In [None]:
# Mean and standard deviation of the closed-book logprobs, per correctness class
cb_logprob_correct_mean, cb_logprob_correct_std = (
    np.mean(correct_logprobs),
    np.std(correct_logprobs),
)
cb_logprob_incorrect_mean, cb_logprob_incorrect_std = (
    np.mean(incorrect_logprobs),
    np.std(incorrect_logprobs),
)

print("Correct - Mean:", cb_logprob_correct_mean, "Std:", cb_logprob_correct_std)
print("Incorrect - Mean:", cb_logprob_incorrect_mean, "Std:", cb_logprob_incorrect_std)

In [None]:
import matplotlib.pyplot as plt

# One histogram of log_probs per correctness class. The title names the class
# so the two otherwise identical figures can be told apart.
for _label, _values in (("correct", correct_logprobs), ("incorrect", incorrect_logprobs)):
    plt.hist(_values, bins=5, edgecolor='black')
    plt.title(f"Distribution of log_probs ({_label})")
    plt.xlabel("log_prob")
    plt.ylabel("Frequency")
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Visualizations of the logprob-sequence lengths (correct_lengths / incorrect_lengths).
# Titles and axis labels name lengths, since that is what is plotted here.
plt.figure(figsize=(12, 6))

# Left panel: overlaid histograms (with KDE) of sequence lengths
plt.subplot(1, 2, 1)
sns.histplot(correct_lengths, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_lengths, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Logprob Sequence Lengths")
plt.xlabel("Sequence length")
plt.ylabel("Frequency")
plt.legend()

# Right panel: boxplot of sequence lengths for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_lengths, incorrect_lengths], palette=["green", "red"])
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Logprob Sequence Lengths")
plt.xlabel("Case Type")
plt.ylabel("Sequence length")

plt.tight_layout()
plt.show()


In [None]:
# Mean and standard deviation of the closed-book sequence lengths, per correctness class
cb_length_correct_mean, cb_length_correct_std = (
    np.mean(correct_lengths),
    np.std(correct_lengths),
)
cb_length_incorrect_mean, cb_length_incorrect_std = (
    np.mean(incorrect_lengths),
    np.std(incorrect_lengths),
)

print("Correct - Mean:", cb_length_correct_mean, "Std:", cb_length_correct_std)
print("Incorrect - Mean:", cb_length_incorrect_mean, "Std:", cb_length_incorrect_std)

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt


# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X_cb_padded, y, test_size=0.2, random_state=42)

# Check class distribution
print("Class Distribution:", np.bincount(y))

# Train a Logistic Regression model with class weights
model = LogisticRegression(class_weight="balanced")
model.fit(X_train, y_train)

# Predicted probability of class 1 (the "correct" label)
y_pred_proba = model.predict_proba(X_test)[:, 1]

# ROC AUC and the ROC curve depend only on the probabilities, not on the
# decision threshold, so they are computed once instead of once per threshold.
roc_auc = roc_auc_score(y_test, y_pred_proba)
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)

# Sweep decision thresholds 0.0, 0.1, ..., 0.9
for i in range(0, 10):
    threshold = i / 10
    y_pred_adj = (y_pred_proba >= threshold).astype(int)

    # Evaluate the model; zero_division=0 reports 0 (without a warning) when a
    # threshold produces no positive predictions or the test split has no positives.
    accuracy = accuracy_score(y_test, y_pred_adj)
    precision = precision_score(y_test, y_pred_adj, zero_division=0)
    recall = recall_score(y_test, y_pred_adj, zero_division=0)
    f1 = f1_score(y_test, y_pred_adj, zero_division=0)

    print(f"Model Evaluation (Adjusted Threshold): {threshold}")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1-Score: {f1:.4f}")
    print(f"  ROC AUC: {roc_auc:.4f}")

# Plot the ROC curve once (it is identical for every threshold)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.4f})")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()

# Child answer

In [None]:
import json

# Load the data to analyze.
# This cell reuses q2gold and are_answers_equivalent_using_llm from earlier cells.
with open('results-devset-hotpotqa.json', 'r') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)

# Initialize counters
correct_child_answers = 0
total_parent_nodes = 0
# Initialize lists to store logprobs
correct_logprobs = []
incorrect_logprobs = []
# Initialize lists to store lengths
correct_lengths = []
incorrect_lengths = []

# Prepare the dataset
X = []  # Features (logprob sequences)
y = []  # Labels (1 for correct, 0 for incorrect)

# Loop through each example in the data
for example in data:
    for node in example:
        # Only parent nodes (no "fa" entry) are scored
        if "fa" in node:
            continue
        total_parent_nodes += 1
        question_text = node.get("question_text", "").strip()

        # child_answer is read at positions 0, 1 and 3; skip nodes where it is
        # missing or too short instead of failing with IndexError.
        child_answer = node.get("child_answer")
        if not child_answer or len(child_answer) < 4:
            continue
        child_answer_text = child_answer[0]
        child_logprob = child_answer[1]   # single logprob value
        child_logprobs = child_answer[3]  # whole logprob sequence

        # Get the gold answer from q2gold
        if question_text not in q2gold:
            continue
        gold_answer, _ = q2gold[question_text]

        # Compare child_answer_text with the gold answer using the LLM judge.
        # Output is labelled "Child Answer" (it was labelled "OB Answer" before).
        is_correct = are_answers_equivalent_using_llm(gold_answer, child_answer_text)
        print(f"Child Answer: {child_answer_text}, Gold Answer: {gold_answer} : {'Correct' if is_correct else 'Incorrect'}")

        # NOTE(review): correct samples are kept unconditionally while incorrect
        # ones are kept only when logprob > -10 — confirm the asymmetry is intended.
        if is_correct:
            correct_child_answers += 1
            correct_logprobs.append(child_logprob)
            correct_lengths.append(len(child_logprobs))
            X.append(child_logprobs)  # Feature
            y.append(1)  # Label (correct)
        elif child_logprob > -10:
            incorrect_logprobs.append(child_logprob)
            incorrect_lengths.append(len(child_logprobs))
            X.append(child_logprobs)  # Feature
            y.append(0)  # Label (incorrect)

# Calculate the accuracy
accuracy = (correct_child_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0

# Print the results
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times child_answer matches gold answer: {correct_child_answers}")
print(f"Child-answer match rate: {accuracy:.2f}%")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Summary statistics and plots for the logprobs collected by the child-answer cell.
# `np` is not imported in this cell; it relies on an earlier cell's import.

# Calculate statistics for correct cases (0 when the list is empty)
correct_mean = np.mean(correct_logprobs) if correct_logprobs else 0
correct_std = np.std(correct_logprobs) if correct_logprobs else 0

# Calculate statistics for incorrect cases (0 when the list is empty)
incorrect_mean = np.mean(incorrect_logprobs) if incorrect_logprobs else 0
incorrect_std = np.std(incorrect_logprobs) if incorrect_logprobs else 0

# Print the results
print("Correct Cases:")
print(f"  Number of cases: {len(correct_logprobs)}")
print(f"  Mean logprobs: {correct_mean:.4f}")
print(f"  Standard deviation of logprobs: {correct_std:.4f}")

print("\nIncorrect Cases:")
print(f"  Number of cases: {len(incorrect_logprobs)}")
print(f"  Mean logprobs: {incorrect_mean:.4f}")
print(f"  Standard deviation of logprobs: {incorrect_std:.4f}")

# Visualizations: one figure with two side-by-side panels
plt.figure(figsize=(12, 6))

# Left panel: overlaid histograms (with KDE) of logprobs for correct and incorrect cases
plt.subplot(1, 2, 1)
sns.histplot(correct_logprobs, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_logprobs, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Logprobs")
plt.xlabel("Logprobs")
plt.ylabel("Frequency")
plt.legend()

# Right panel: boxplot of logprobs for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_logprobs, incorrect_logprobs], palette=["green", "red"])
# Box 0 is the correct list, box 1 the incorrect list (order of `data` above)
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Logprobs")
plt.xlabel("Case Type")
plt.ylabel("Logprobs")

plt.tight_layout()
plt.show()


In [None]:
# Mean and standard deviation of the child-answer logprobs, per correctness class
child_logprob_correct_mean, child_logprob_correct_std = (
    np.mean(correct_logprobs),
    np.std(correct_logprobs),
)
child_logprob_incorrect_mean, child_logprob_incorrect_std = (
    np.mean(incorrect_logprobs),
    np.std(incorrect_logprobs),
)

print("Correct - Mean:", child_logprob_correct_mean, "Std:", child_logprob_correct_std)
print("Incorrect - Mean:", child_logprob_incorrect_mean, "Std:", child_logprob_incorrect_std)

In [None]:
import matplotlib.pyplot as plt

# One histogram of log_probs per correctness class. The title names the class
# so the two otherwise identical figures can be told apart.
for _label, _values in (("correct", correct_logprobs), ("incorrect", incorrect_logprobs)):
    plt.hist(_values, bins=5, edgecolor='black')
    plt.title(f"Distribution of log_probs ({_label})")
    plt.xlabel("log_prob")
    plt.ylabel("Frequency")
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Visualizations of the logprob-sequence lengths (correct_lengths / incorrect_lengths).
# Titles and axis labels name lengths, since that is what is plotted here.
plt.figure(figsize=(12, 6))

# Left panel: overlaid histograms (with KDE) of sequence lengths
plt.subplot(1, 2, 1)
sns.histplot(correct_lengths, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_lengths, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Logprob Sequence Lengths")
plt.xlabel("Sequence length")
plt.ylabel("Frequency")
plt.legend()

# Right panel: boxplot of sequence lengths for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_lengths, incorrect_lengths], palette=["green", "red"])
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Logprob Sequence Lengths")
plt.xlabel("Case Type")
plt.ylabel("Sequence length")

plt.tight_layout()
plt.show()

In [None]:
# Mean and standard deviation of the child-answer sequence lengths, per correctness class
child_length_correct_mean, child_length_correct_std = (
    np.mean(correct_lengths),
    np.std(correct_lengths),
)
child_length_incorrect_mean, child_length_incorrect_std = (
    np.mean(incorrect_lengths),
    np.std(incorrect_lengths),
)

print("Correct - Mean:", child_length_correct_mean, "Std:", child_length_correct_std)
print("Incorrect - Mean:", child_length_incorrect_mean, "Std:", child_length_incorrect_std)

# Open-book answer

In [None]:
import json

# Load the data to analyze.
# This cell reuses q2gold and are_answers_equivalent_using_llm from earlier cells.
with open('results-devset-hotpotqa.json', 'r') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)

# Initialize counters
correct_ob_answers = 0
total_parent_nodes = 0
# Initialize lists to store logprobs
correct_logprobs = []
incorrect_logprobs = []
# Initialize lists to store lengths
correct_lengths = []
incorrect_lengths = []

# Prepare the dataset
X = []  # Features (logprob sequences)
y = []  # Labels (1 for correct, 0 for incorrect)

# Loop through each example in the data
for example in data:
    for node in example:
        # Only parent nodes (no "fa" entry) are scored
        if "fa" in node:
            continue
        total_parent_nodes += 1
        question_text = node.get("question_text", "").strip()

        # ob_answer is read at positions 0, 1 and 3; skip nodes where it is
        # missing or too short instead of failing with IndexError.
        ob_answer = node.get("ob_answer")
        if not ob_answer or len(ob_answer) < 4:
            continue
        ob_answer_text = ob_answer[0]
        ob_logprob = ob_answer[1]   # single logprob value
        ob_logprobs = ob_answer[3]  # whole logprob sequence

        # Get the gold answer from q2gold
        if question_text not in q2gold:
            continue
        gold_answer, _ = q2gold[question_text]

        # Compare ob_answer_text with the gold answer using the LLM judge
        is_correct = are_answers_equivalent_using_llm(gold_answer, ob_answer_text)
        print(f"OB Answer: {ob_answer_text}, Gold Answer: {gold_answer} : {'Correct' if is_correct else 'Incorrect'}")

        # NOTE(review): correct samples are kept unconditionally while incorrect
        # ones are kept only when logprob > -10 — confirm the asymmetry is intended.
        if is_correct:
            correct_ob_answers += 1
            correct_logprobs.append(ob_logprob)
            correct_lengths.append(len(ob_logprobs))
            X.append(ob_logprobs)  # Feature
            y.append(1)  # Label (correct)
        elif ob_logprob > -10:
            incorrect_logprobs.append(ob_logprob)
            incorrect_lengths.append(len(ob_logprobs))
            X.append(ob_logprobs)  # Feature
            y.append(0)  # Label (incorrect)

# Calculate the accuracy
accuracy = (correct_ob_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0

# Print the results
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times ob_answer matches gold answer: {correct_ob_answers}")
print(f"Open-book match rate: {accuracy:.2f}%")

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Summary statistics and plots for the logprobs collected by the open-book cell.
# `np` is not imported in this cell; it relies on an earlier cell's import.

# Calculate statistics for correct cases (0 when the list is empty)
correct_mean = np.mean(correct_logprobs) if correct_logprobs else 0
correct_std = np.std(correct_logprobs) if correct_logprobs else 0

# Calculate statistics for incorrect cases (0 when the list is empty)
incorrect_mean = np.mean(incorrect_logprobs) if incorrect_logprobs else 0
incorrect_std = np.std(incorrect_logprobs) if incorrect_logprobs else 0

# Print the results
print("Correct Cases:")
print(f"  Number of cases: {len(correct_logprobs)}")
print(f"  Mean logprobs: {correct_mean:.4f}")
print(f"  Standard deviation of logprobs: {correct_std:.4f}")

print("\nIncorrect Cases:")
print(f"  Number of cases: {len(incorrect_logprobs)}")
print(f"  Mean logprobs: {incorrect_mean:.4f}")
print(f"  Standard deviation of logprobs: {incorrect_std:.4f}")

# Visualizations: one figure with two side-by-side panels
plt.figure(figsize=(12, 6))

# Left panel: overlaid histograms (with KDE) of logprobs for correct and incorrect cases
plt.subplot(1, 2, 1)
sns.histplot(correct_logprobs, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_logprobs, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Logprobs")
plt.xlabel("Logprobs")
plt.ylabel("Frequency")
plt.legend()

# Right panel: boxplot of logprobs for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_logprobs, incorrect_logprobs], palette=["green", "red"])
# Box 0 is the correct list, box 1 the incorrect list (order of `data` above)
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Logprobs")
plt.xlabel("Case Type")
plt.ylabel("Logprobs")

plt.tight_layout()
plt.show()

In [None]:
# Mean and standard deviation of the open-book logprobs, per correctness class
ob_logprob_correct_mean, ob_logprob_correct_std = (
    np.mean(correct_logprobs),
    np.std(correct_logprobs),
)
ob_logprob_incorrect_mean, ob_logprob_incorrect_std = (
    np.mean(incorrect_logprobs),
    np.std(incorrect_logprobs),
)

print("Correct - Mean:", ob_logprob_correct_mean, "Std:", ob_logprob_correct_std)
print("Incorrect - Mean:", ob_logprob_incorrect_mean, "Std:", ob_logprob_incorrect_std)

In [None]:
import matplotlib.pyplot as plt

# One histogram of log_probs per correctness class. The title names the class
# so the two otherwise identical figures can be told apart.
for _label, _values in (("correct", correct_logprobs), ("incorrect", incorrect_logprobs)):
    plt.hist(_values, bins=5, edgecolor='black')
    plt.title(f"Distribution of log_probs ({_label})")
    plt.xlabel("log_prob")
    plt.ylabel("Frequency")
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Side-by-side view of the length distributions for correct vs. incorrect cases.
# The titles and axis labels name the plotted quantity (lengths); they previously
# said "Logprobs", carried over from the log-probability plots.
plt.figure(figsize=(12, 6))

# Histogram of lengths for correct and incorrect cases
plt.subplot(1, 2, 1)
sns.histplot(correct_lengths, color="green", label="Correct", kde=True, bins=20)
sns.histplot(incorrect_lengths, color="red", label="Incorrect", kde=True, bins=20)
plt.title("Distribution of Lengths")
plt.xlabel("Length")
plt.ylabel("Frequency")
plt.legend()

# Boxplot of lengths for correct and incorrect cases
plt.subplot(1, 2, 2)
sns.boxplot(data=[correct_lengths, incorrect_lengths], palette=["green", "red"])
plt.xticks([0, 1], ["Correct", "Incorrect"])
plt.title("Boxplot of Lengths")
plt.xlabel("Case Type")
plt.ylabel("Length")

plt.tight_layout()
plt.show()

In [None]:
# Mean and standard deviation of the lengths, separately for the correct and
# the incorrect cases
ob_length_correct_mean, ob_length_correct_std = (
    np.mean(correct_lengths),
    np.std(correct_lengths),
)
ob_length_incorrect_mean, ob_length_incorrect_std = (
    np.mean(incorrect_lengths),
    np.std(incorrect_lengths),
)

print(
    "Correct - Mean:", ob_length_correct_mean,
    "Std:", ob_length_correct_std,
)
print(
    "Incorrect - Mean:", ob_length_incorrect_mean,
    "Std:", ob_length_incorrect_std,
)

In [None]:
from tensorflow.keras.preprocessing.sequence import pad_sequences

# The longest sequence in X fixes the common width; shorter sequences are
# right-padded with -100 so every row has the same number of columns.
max_length = max(map(len, X))
X_ob_padded = pad_sequences(
    X,
    maxlen=max_length,
    padding="post",
    truncating="post",
    value=-100,
    dtype="float32",
)
X_ob_padded

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt

# Split the data into training and testing sets (80/20, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(X_ob_padded, y, test_size=0.2, random_state=42)

# Check class distribution
print("Class Distribution:", np.bincount(y))

# Train a Logistic Regression model with class weights
model = LogisticRegression(class_weight="balanced")
model.fit(X_train, y_train)

# Probability of the positive class for every held-out sample
y_pred_proba = model.predict_proba(X_test)[:, 1]

# ROC AUC and the ROC curve are computed from the raw scores, so they do not
# depend on the decision threshold: compute them once, outside the sweep.
roc_auc = roc_auc_score(y_test, y_pred_proba)
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
print(f"ROC AUC: {roc_auc:.4f}")

# Sweep decision thresholds from 0.1 to 0.9
for i in range(1, 10):
    threshold = i / 10
    y_pred_adj = (y_pred_proba >= threshold).astype(int)

    # Threshold-dependent metrics. zero_division=0 reports 0 instead of
    # warning when a threshold produces no positive predictions.
    accuracy = accuracy_score(y_test, y_pred_adj)
    precision = precision_score(y_test, y_pred_adj, zero_division=0)
    recall = recall_score(y_test, y_pred_adj, zero_division=0)
    f1 = f1_score(y_test, y_pred_adj, zero_division=0)

    print(f"Model Evaluation (Adjusted Threshold = {threshold:.1f}):")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Precision: {precision:.4f}")
    print(f"  Recall: {recall:.4f}")
    print(f"  F1-Score: {f1:.4f}")

# Plot the ROC curve once (it is identical for every threshold)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc:.4f})")
plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.show()

# Training a random forest classifier for each answer type

In [None]:
import json

# Load the dev-set result trees for the three datasets and merge them into
# one list (`data`), HotpotQA first, then 2Wiki, then MuSiQue.
with open('results-devset-hotpotqa.json', 'r') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)

# Initialize counters
correct_answers = 0
total_parent_nodes = 0
# Initialize lists to store logprobs
# NOTE(review): these two lists and X / y below are reset here but never
# filled in this cell.
correct_logprobs = []
incorrect_logprobs = []

# Prepare the dataset
X = []  # Features (logprobs)
y = []  # Labels (1 for correct, 0 for incorrect)

# Loop through each example in the data
for example in data:
    for node in example:
        # Check if the node is a parent node (no "fa" entry)
        if "fa" not in node:
            total_parent_nodes += 1
            question_text = node.get("question_text", "").strip()
            ob_answer = node.get("ob_answer", [None])  # Extract the ob_answer
            ob_answer_text = ob_answer[0]
            cb_answer_text = node.get("cb_answer", [None])[0]


            # Get the gold answer from q2gold
            # NOTE(review): nodes whose question is missing from q2gold are
            # still counted in total_parent_nodes, so they lower the rate.
            if question_text in q2gold:
                gold_answer, _ = q2gold[question_text]

                # A node counts as correct when EITHER the open-book OR the
                # closed-book answer is judged equivalent to the gold answer.
                # if normalize_answer(ob_answer_text) == normalize_answer(gold_answer) or normalize_answer(cb_answer_text) == normalize_answer(gold_answer):
                if are_answers_equivalent_using_llm(gold_answer, ob_answer_text) or are_answers_equivalent_using_llm(gold_answer, cb_answer_text):
                    print(f"OB Answer: {ob_answer_text}, CB Answer: {cb_answer_text}, Gold Answer: {gold_answer} : Correct")
                    correct_answers += 1
                else:
                    print(f"OB Answer: {ob_answer_text}, CB Answer: {cb_answer_text}, Gold Answer: {gold_answer} : Incorrect")

# Calculate the accuracy as a percentage of all parent nodes seen
accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0

# Print the results
# NOTE(review): the labels below say "&", but the counter is incremented when
# either answer matches (logical OR above).
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times cb_answer & ob_answer matches gold answer: {correct_answers}")
print(f"Closed-book & Open-book match rate: {accuracy:.2f}%")

In [None]:
import numpy as np

# Stand-in for an absent or null answer entry. Entries are indexed at
# position 0 (answer text) and position 3 (per-token logprobs), so the
# placeholder must have four slots; the previous two-slot default raised
# IndexError as soon as a node lacked one of the answer keys.
_MISSING_ANSWER = [None, None, None, []]

# Features (per-token logprob sequences) and labels (1 if the answer matches
# the gold answer, else 0) for each answer type, collected from root nodes.
X_cb, X_ob, X_child = [], [], []
y_cb, y_ob, y_child = [], [], []

_answer_sources = (
    ("cb_answer", X_cb, y_cb),
    ("ob_answer", X_ob, y_ob),
    ("child_answer", X_child, y_child),
)

for example in data:
    for node in example:
        if "fa" in node:  # Only process root nodes (original questions)
            continue

        question_text = node.get("question_text", "").strip()
        if question_text not in q2gold:
            # No gold answer available: skip the node for all three datasets
            # so that every X_* / y_* pair keeps the same number of rows.
            continue
        final_answer, _ = q2gold[question_text]

        for answer_key, features, labels in _answer_sources:
            entry = node.get(answer_key) or _MISSING_ANSWER
            answer_text = entry[0]
            logprobs = entry[3] if len(entry) > 3 and entry[3] else []

            # A missing answer can never be correct, so it is labelled 0
            # without asking the LLM judge.
            is_match = bool(answer_text) and bool(
                are_answers_equivalent_using_llm(final_answer, answer_text)
            )
            labels.append(1 if is_match else 0)
            features.append(logprobs)

# Pad/truncate sequences to the maximum log prob sequence length
cb_max_length = max(len(seq) for seq in X_cb)
X_cb_padded = pad_sequences(X_cb, maxlen=cb_max_length, padding="post", truncating="post", value=-100, dtype="float32")
ob_max_length = max(len(seq) for seq in X_ob)
X_ob_padded = pad_sequences(X_ob, maxlen=ob_max_length, padding="post", truncating="post", value=-100, dtype="float32")
child_max_length = max(len(seq) for seq in X_child)
X_child_padded = pad_sequences(X_child, maxlen=child_max_length, padding="post", truncating="post", value=-100, dtype="float32")

# Convert to numpy arrays
X_cb = np.array(X_cb_padded)
X_ob = np.array(X_ob_padded)
X_child = np.array(X_child_padded)
y_cb = np.array(y_cb)
y_ob = np.array(y_ob)
y_child = np.array(y_child)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# --- Phase 1: the same seeded 80/20 split for each answer type -------------
X_cb_train, X_cb_test, y_cb_train, y_cb_test = train_test_split(
    X_cb, y_cb, test_size=0.2, random_state=42
)
X_ob_train, X_ob_test, y_ob_train, y_ob_test = train_test_split(
    X_ob, y_ob, test_size=0.2, random_state=42
)
X_child_train, X_child_test, y_child_train, y_child_test = train_test_split(
    X_child, y_child, test_size=0.2, random_state=42
)

# --- Phase 2: one class-balanced random forest per answer type -------------
# (LogisticRegression(class_weight="balanced") was the earlier alternative.)
cb_model = RandomForestClassifier(class_weight="balanced", random_state=42)
ob_model = RandomForestClassifier(class_weight="balanced", random_state=42)
child_model = RandomForestClassifier(class_weight="balanced", random_state=42)

cb_model.fit(X_cb_train, y_cb_train)
ob_model.fit(X_ob_train, y_ob_train)
child_model.fit(X_child_train, y_child_train)

# --- Phase 3: held-out accuracy, reported in CB / OB / Child order ---------
y_cb_pred = cb_model.predict(X_cb_test)
print("CB Model Accuracy:", accuracy_score(y_cb_test, y_cb_pred))

y_ob_pred = ob_model.predict(X_ob_test)
print("OB Model Accuracy:", accuracy_score(y_ob_test, y_ob_pred))

y_child_pred = child_model.predict(X_child_test)
print("Child Model Accuracy:", accuracy_score(y_child_test, y_child_pred))

In [None]:
# Held-out accuracy of each reliability model, keyed by answer type
model_accuracies = {
    name: accuracy_score(y_true, y_predicted)
    for name, y_true, y_predicted in (
        ("CB", y_cb_test, y_cb_pred),
        ("OB", y_ob_test, y_ob_pred),
        ("Child", y_child_test, y_child_pred),
    )
}

# Fitted models under the same keys
model_dict = {"CB": cb_model, "OB": ob_model, "Child": child_model}

# Answer-type names ordered from the most to the least accurate model
# (ties keep the CB, OB, Child insertion order because the sort is stable)
sorted_models = sorted(model_accuracies, key=model_accuracies.get, reverse=True)

sorted_models

# Greedy Solver

In [None]:
def find_node_by_idx(idx):
    """Return the first node in the global `data` whose "idx" equals `idx`.

    Examples are scanned in order and, within each example, nodes in order.
    Returns None when no node matches.
    """
    matching_nodes = (
        candidate
        for tree in data
        for candidate in tree
        if candidate.get("idx") == idx
    )
    return next(matching_nodes, None)

def aggregate_answers(answers):
    """Pick the most frequent answer (majority vote).

    Returns None for an empty or missing list of answers.
    """
    from collections import Counter

    if answers:
        [(winner, _votes)] = Counter(answers).most_common(1)
        return winner
    return None

def greedy_solver(node, cb_model, ob_model, child_model):
    """Pick an answer for `node` by trying OB, then Child, then CB.

    For each answer type the per-token logprobs are padded to the width the
    corresponding classifier was trained on (globals `ob_max_length`,
    `child_max_length`, `cb_max_length`) and the classifier is asked whether
    the answer is reliable (prediction == 1). The first reliable answer that
    does not contain "unknown" is returned. If none qualifies, the answer
    with the highest overall score is returned.

    :param node: dict with "ob_answer" and "cb_answer" entries, each a
        4-item sequence (answer text, score, unused, per-token logprobs),
        and optionally "sons" / "child_answer".
    :param cb_model: classifier for closed-book answers (needs `.predict`).
    :param ob_model: classifier for open-book answers.
    :param child_model: classifier for child-aggregated answers.
    :return: tuple (answer text, method) with method in "OB", "Child", "CB".
    """
    neg_inf = -float('inf')

    def _score(value):
        # A missing score must never win the fallback comparison.
        return neg_inf if value is None else value

    # Step 1: Check OB answer
    ob_answer, ob_logprob, _, ob_logprobs = node.get("ob_answer")
    ob_logprobs_padded = pad_sequences([ob_logprobs], maxlen=ob_max_length, padding="post", truncating="post", value=-100, dtype="float32")
    ob_reliable = ob_model.predict(ob_logprobs_padded)[0] == 1  # Predict reliability
    if ob_reliable:
        if "unknown" not in ob_answer.lower().strip():
            print("OB answer is reliable", ob_answer)
            return ob_answer, "OB"  # Accept OB answer
        # Reliable but "unknown": exclude it from the fallback ranking
        ob_logprob = neg_inf

    # Step 2: Expand the Tree (Child Aggregation)
    # Defaults cover nodes without children; previously these names were
    # unbound in the fallback below and raised UnboundLocalError.
    child_answer, child_logprob = None, neg_inf
    child_entry = node.get("child_answer")
    if node.get("sons") and child_entry:
        child_answer, child_logprob, _, child_logprobs = child_entry
        child_logprobs_padded = pad_sequences([child_logprobs], maxlen=child_max_length, padding="post", truncating="post", value=-100, dtype="float32")
        child_reliable = child_model.predict(child_logprobs_padded)[0] == 1  # Predict reliability
        if child_reliable:
            if "unknown" not in child_answer.lower().strip():
                print("Child answer is reliable", child_answer)
                return child_answer, "Child"  # Accept Child answer
            child_logprob = neg_inf

    # Step 3: Check CB answer
    cb_answer, cb_logprob, _, cb_logprobs = node.get("cb_answer")
    cb_logprobs_padded = pad_sequences([cb_logprobs], maxlen=cb_max_length, padding="post", truncating="post", value=-100, dtype="float32")
    cb_reliable = cb_model.predict(cb_logprobs_padded)[0] == 1  # Predict reliability
    if cb_reliable:
        if "unknown" not in cb_answer.lower().strip():
            print("CB answer is reliable", cb_answer)
            return cb_answer, "CB"  # Accept CB answer
        # Reliable but "unknown": exclude it from the fallback ranking
        cb_logprob = neg_inf

    # If no method produces a reliable answer, return the best available
    print("No reliable answer found, returning the best available answer")
    cb_logprob, ob_logprob, child_logprob = _score(cb_logprob), _score(ob_logprob), _score(child_logprob)

    if child_answer is None:
        # No child aggregation exists: choose between OB and CB only,
        # preferring CB on ties.
        if ob_logprob > cb_logprob:
            print("Best Answer: OB, Best Method: OB")
            return ob_answer, "OB"
        print("Best Answer: CB, Best Method: CB")
        return cb_answer, "CB"

    if cb_logprob > ob_logprob and cb_logprob > child_logprob:
        print("Best Answer: CB, Best Method: CB")
        return cb_answer, "CB"
    elif ob_logprob > cb_logprob and ob_logprob > child_logprob:
        print("Best Answer: OB, Best Method: OB")
        return ob_answer, "OB"
    else:
        print("Best Answer: Child, Best Method: Child")
        return child_answer, "Child"

In [None]:
import json


def _load_json(path):
    """Parse a single JSON document and close the file afterwards."""
    with open(path, 'r') as handle:
        return json.load(handle)


def _load_jsonl(path):
    """Parse a JSON-lines file into a list of records.

    The file is closed on exit (the earlier inline `open(...)` calls left
    their handles open) and blank lines are ignored.
    """
    with open(path) as handle:
        return [json.loads(line) for line in handle if line.strip()]


# Load test data
test_data_hotpotqa = _load_json('results-testset-hotpotqa.json')
test_data_2wiki = _load_json('results-testset-2wiki.json')
test_data_musique = _load_json('results-testset-musique.json')

# Load ground truth answers
raw_data = _load_jsonl('./hotpotqa__v2_test_random_500.jsonl')
raw_data_2wiki = _load_jsonl('./2wikimultihopqa__v2_test_random_500.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data_musique = _load_jsonl('./musique_ans__v2_test_random_500.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'


In [None]:
# Evaluate the greedy solver on every root question of the HotpotQA test set

# Gold answers keyed by stripped question text: question -> (gold, dataset)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        gold = item['answers_objects'][0]['spans'][0]
        q2gold[question] = (gold, item['dataset'])
    except Exception as e:
        print("ERROR CASE", e)

results = []
correct_answers = 0
total_parent_nodes = 0
for example in test_data_hotpotqa:
    for node in example:
        if "fa" in node:
            continue  # has a parent entry, so not an original question

        total_parent_nodes += 1
        answer, method = greedy_solver(node, cb_model, ob_model, child_model)
        question_text = node.get("question_text", "").strip()
        final_answer, _ = q2gold[question_text]
        # Correctness is judged by the LLM equivalence check
        if are_answers_equivalent_using_llm(final_answer, answer):
            correct_answers += 1
        results.append(dict(
            idx=node["idx"],
            question=node["question_text"],
            answer=answer,
            method=method,
        ))

# One four-line report per node, followed by a blank line
for result in results:
    print(
        f"Node {result['idx']}:\n"
        f"  Question: {result['question']}\n"
        f"  Answer: {result['answer']}\n"
        f"  Method: {result['method']}\n"
    )

# Match rate as a percentage of the root questions processed
accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes else 0
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times answer matches gold answer: {correct_answers}")
print(f"Greedy Solver match rate: {accuracy:.2f}%")

In [None]:
# Evaluate the greedy solver on every root question of the 2Wiki test set

# Gold answers keyed by stripped question text: question -> (gold, dataset)
q2gold = {}
for item in raw_data_2wiki:
    try:
        question = item['question_text'].strip()
        gold = item['answers_objects'][0]['spans'][0]
        q2gold[question] = (gold, item['dataset'])
    except Exception as e:
        print("ERROR CASE", e)

results = []
correct_answers = 0
total_parent_nodes = 0
for example in test_data_2wiki:
    for node in example:
        if "fa" in node:
            continue  # has a parent entry, so not an original question

        total_parent_nodes += 1
        answer, method = greedy_solver(node, cb_model, ob_model, child_model)
        question_text = node.get("question_text", "").strip()
        final_answer, _ = q2gold[question_text]
        # Correctness is judged by the LLM equivalence check
        if are_answers_equivalent_using_llm(final_answer, answer):
            correct_answers += 1
        results.append(dict(
            idx=node["idx"],
            question=node["question_text"],
            answer=answer,
            method=method,
        ))

# One four-line report per node, followed by a blank line
for result in results:
    print(
        f"Node {result['idx']}:\n"
        f"  Question: {result['question']}\n"
        f"  Answer: {result['answer']}\n"
        f"  Method: {result['method']}\n"
    )

# Match rate as a percentage of the root questions processed
accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes else 0
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times answer matches gold answer: {correct_answers}")
print(f"Greedy Solver match rate: {accuracy:.2f}%")

In [None]:
# Evaluate the greedy solver on every root question of the MuSiQue test set

# Gold answers keyed by stripped question text: question -> (gold, dataset)
q2gold = {}
for item in raw_data_musique:
    try:
        question = item['question_text'].strip()
        gold = item['answers_objects'][0]['spans'][0]
        q2gold[question] = (gold, item['dataset'])
    except Exception as e:
        print("ERROR CASE", e)

results = []
correct_answers = 0
total_parent_nodes = 0
for example in test_data_musique:
    for node in example:
        if "fa" in node:
            continue  # has a parent entry, so not an original question

        total_parent_nodes += 1
        answer, method = greedy_solver(node, cb_model, ob_model, child_model)
        question_text = node.get("question_text", "").strip()
        final_answer, _ = q2gold[question_text]
        # Correctness is judged by the LLM equivalence check
        if are_answers_equivalent_using_llm(final_answer, answer):
            correct_answers += 1
        results.append(dict(
            idx=node["idx"],
            question=node["question_text"],
            answer=answer,
            method=method,
        ))

# One four-line report per node, followed by a blank line
for result in results:
    print(
        f"Node {result['idx']}:\n"
        f"  Question: {result['question']}\n"
        f"  Answer: {result['answer']}\n"
        f"  Method: {result['method']}\n"
    )

# Match rate as a percentage of the root questions processed
accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes else 0
print(f"Total parent nodes: {total_parent_nodes}")
print(f"Times answer matches gold answer: {correct_answers}")
print(f"Greedy Solver match rate: {accuracy:.2f}%")

# Resampling a tree from a question

In [None]:
# Demo: build the resampling pipeline and resample a tree for one
# hand-written question.
from Tree_Generation.tree_resampling import TreeResamplingPipeline
# Alternative construction with the pipeline's default arguments:
# tree_resampling_pipeline = TreeResamplingPipeline()
tree_resampling_pipeline = TreeResamplingPipeline(use_cache=True)
question = "what is the population size of the smallest city in the world?"
# Last expression of the cell, so the notebook displays the call's result
tree_resampling_pipeline.resample_tree(question=question)

# RL Approach 

In [None]:
import json


def _load_json(path):
    """Parse a single JSON document and close the file afterwards."""
    with open(path, 'r') as handle:
        return json.load(handle)


def _load_jsonl(path):
    """Parse a JSON-lines file into a list of records.

    The file is closed on exit (the earlier inline `open(...)` calls left
    their handles open) and blank lines are ignored.
    """
    with open(path) as handle:
        return [json.loads(line) for line in handle if line.strip()]


# Load the raw data and question decompositions
raw_data = _load_jsonl('./hotpotqa__v2_dev_random_100.jsonl')
raw_data_2wiki = _load_jsonl('./2wikimultihopqa__v2_dev_random_100.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
raw_data_musique = _load_jsonl('./musique_ans__v2_dev_random_100.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

q2dq = _load_json("./question_decompositions-devset-hotpotqa.json")
q2dq_2wiki = _load_json("./question_decompositions-devset-2wiki.json")
q2dq_musique = _load_json("./question_decompositions-devset-musique.json")
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: root question of the decomposition -> (gold, dataset)
q2gold = {}
skipped_questions = 0
for item in raw_data:
    try:
        question = item['question_text'].strip()
        # The first key of the decomposition is used as the lookup question
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
        q2gold[question] = (gold, q_type)
    except (KeyError, IndexError, TypeError, AttributeError):
        # Question missing from the decompositions, or record lacks a gold
        # answer / dataset tag: skip it, but keep count so it is visible.
        skipped_questions += 1
        continue

if skipped_questions:
    print(f"Skipped {skipped_questions} questions without a decomposition or gold answer")

# Load the data to analyze (dev-set result trees, merged into one list)
data = _load_json('results-devset-hotpotqa.json')
data_2wiki = _load_json('results-devset-2wiki.json')
data_musique = _load_json('results-devset-musique.json')

data.extend(data_2wiki)
data.extend(data_musique)


In [None]:
import numpy as np
import random
from collections import defaultdict

# Define the state space, action space, and Q-table
action_space = ['CB', 'OB', 'Child']  # Possible actions
# Q-table: maps a state key to one Q-value per action (same order as
# action_space); unseen states start as all zeros.
Q = defaultdict(lambda: np.zeros(len(action_space)))  # Q-table

# Hyperparameters
alpha = 0.1  # Learning rate
gamma = 0.9  # Discount factor
epsilon = 1.0  # Exploration rate (decayed by q_learning on every call)
epsilon_decay = 0.98  # Multiplicative decay applied to epsilon
epsilon_min = 0.01  # Lower bound for epsilon

# Initialize global counters for success rates
# Keys are the lower-cased action names.
success_counts = {"cb": 0, "ob": 0, "child": 0}
attempt_counts = {"cb": 0, "ob": 0, "child": 0}

# Reward signal for the Q-learning update
def get_reward(chosen_answer, gold_answer):
    """Return 1 if the LLM judge deems the chosen answer equivalent to gold, else -1."""
    is_equivalent = are_answers_equivalent_using_llm(gold_answer, chosen_answer)
    return 1 if is_equivalent else -1

# Running tally of attempts and successes per answer type
def update_success_rate(answer_type, is_correct):
    """
    Record one attempt for an answer type and, if it was correct, one success.
    :param answer_type: "cb", "ob", or "child" (case-insensitive)
    :param is_correct: True if the answer was correct, False otherwise
    """
    key = answer_type.lower()
    attempt_counts[key] += 1
    if not is_correct:
        return
    success_counts[key] += 1

def get_success_rate(answer_type):
    """
    Get the current success rate for a given answer type.

    The key is lower-cased first, matching update_success_rate, so "CB" and
    "cb" refer to the same counters.

    :param answer_type: "cb", "ob", or "child" (case-insensitive)
    :return: Success rate (float); a fixed default while there are no attempts
    :raises KeyError: if the answer type is not one of the three known keys
    """
    answer_type = answer_type.lower()
    attempts = attempt_counts[answer_type]
    if attempts == 0:
        # Default success rates (can be adjusted based on prior knowledge)
        default_rates = {"cb": 0.32, "ob": 0.32, "child": 0.42}
        return default_rates[answer_type]
    return success_counts[answer_type] / attempts

def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """Return `logprobs` as a list of exactly `max_length` items.

    Shorter inputs are right-padded with `pad_value`; longer inputs are cut
    to their first `max_length` items. The input list is not modified.
    """
    shortfall = max_length - len(logprobs)
    if shortfall > 0:
        return logprobs + [pad_value] * shortfall
    return logprobs[:max_length]

def get_state(node, depth=0, max_length=50):
    """Build the flat feature list that describes `node` for the Q-table.

    Layout of the returned list, in order:
      * CB, OB and child per-token logprobs, each padded/truncated to
        `max_length` items (so 3 * max_length values),
      * 7 basic features: has_children, question_length, question_type,
        num_children, and the current CB / OB / child success rates,
      * 3 confidence scores (CB, OB, child),
      * 2 structural features: tree_depth, tree_position,
      * 2 answer-quality features: CB and OB answer length in words.

    Absent or null answer entries and null answer text are treated as
    "no answer" (empty logprobs, confidence 0.0, length 0) instead of raising.

    :param node: dict describing one tree node.
    :param depth: depth of the node in its tree (0 for the root).
    :param max_length: width each logprob sequence is padded/truncated to.
    :return: list of features; callers convert it to a tuple to use as a key.
    """
    no_answer = [None, None, None, []]

    def _entry(key):
        # `or` also covers an explicit null stored under the key
        return node.get(key) or no_answer

    cb_entry = _entry("cb_answer")
    ob_entry = _entry("ob_answer")
    child_entry = _entry("child_answer")

    # Per-token log probabilities, each brought to exactly max_length items
    cb_logprob = pad_or_truncate(cb_entry[3] or [], max_length)
    ob_logprob = pad_or_truncate(ob_entry[3] or [], max_length)
    child_logprob = pad_or_truncate(child_entry[3] or [], max_length)

    # Basic features (7)
    sons = node.get("sons", [])
    question_text = node.get("question_text", "")
    has_children = 1 if len(sons) else 0
    question_length = len(question_text.split())
    question_type = encode_question_type(question_text)
    num_children = len(sons)
    cb_success_rate = get_success_rate("cb")
    ob_success_rate = get_success_rate("ob")
    child_success_rate = get_success_rate("child")

    # Confidence scores (3): second slot of each answer entry, 0.0 if absent
    cb_confidence = cb_entry[1] or 0.0
    ob_confidence = ob_entry[1] or 0.0
    child_confidence = child_entry[1] or 0.0

    # Structural features (2)
    tree_depth = depth
    tree_position = 0 if depth == 0 else 1  # 0 for root, 1 for intermediate/leaf

    # Answer quality (2): word counts of the CB and OB answer texts
    cb_answer_length = len((cb_entry[0] or "").split())
    ob_answer_length = len((ob_entry[0] or "").split())

    # Not included (prototyped earlier, currently disabled): semantic
    # embeddings of the question/answers, action history, retrieval
    # features, and the child answer length.

    # Build state vector
    state = (
        cb_logprob +  # CB log probabilities
        ob_logprob +  # OB log probabilities
        child_logprob +  # Child log probabilities
        [has_children, question_length, question_type, num_children, cb_success_rate, ob_success_rate, child_success_rate] +  # Basic features
        [cb_confidence, ob_confidence, child_confidence] +  # Confidence scores for CB, OB and child
        [tree_depth, tree_position] +  # Structural features
        [cb_answer_length, ob_answer_length]  # Answer quality features
    )
    return state

# Q-learning algorithm
def q_learning(node, gold_answer):
    """Run one epsilon-greedy Q-learning step on `node` and return the choice.

    Picks an action (random with probability `epsilon`, otherwise one of the
    actions with the highest Q-value for this state), takes the matching
    answer from the node, scores it with get_reward, updates the Q-table and
    the per-action success counters, then decays the global `epsilon`.

    :param node: dict with "cb_answer", "ob_answer" and "child_answer" entries
    :param gold_answer: reference answer passed to get_reward
    :return: tuple (answer text, chosen action)
    """
    global epsilon
    state = tuple(get_state(node))  # Convert state to tuple
    if random.uniform(0, 1) < epsilon:
        action = random.choice(action_space)  # Explore action space
    else:
        # action = action_space[np.argmax(Q[state])]  # Exploit learned values
        max_q_value = np.max(Q[state])  # Find the maximum Q-value
        best_actions = [action_space[i] for i, q in enumerate(Q[state]) if q == max_q_value]  # Get all actions with max Q-value
        action = np.random.choice(best_actions)  # Choose randomly among them


    # Simulate the action (choose answer based on action)
    # NOTE(review): node.get(...) returns None for a missing key, so a node
    # without the chosen entry (e.g. no "child_answer") raises TypeError here.
    if action == 'CB':
        answer = node.get("cb_answer")[0]
    elif action == 'OB':
        answer = node.get("ob_answer")[0]
    else:
        answer = node.get("child_answer")[0]

    # Compute reward
    reward = get_reward(answer, gold_answer)
    # NOTE(review): the next state is rebuilt from the same node before the
    # success counters change, so it appears to equal `state` — confirm that
    # bootstrapping on the same state is intended.
    next_state = tuple(get_state(node))  # Convert next_state to tuple

    # Update Q-value
    old_value = Q[state][action_space.index(action)]
    next_max = np.max(Q[next_state])
    Q[state][action_space.index(action)] = old_value + alpha * (reward + gamma * next_max - old_value)

    # Update success rate for the chosen action
    is_correct = (reward == 1)
    update_success_rate(action.lower(), is_correct)

    # Decay epsilon (once per call, never below epsilon_min)
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    return answer, action



def encode_question_type(question_text):
    """Encode a question by its first word.

    :param question_text: the question; may be empty or None
    :return: 0 for "what", 1 for "where", 2 for "how" (case-insensitive),
        3 for anything else, including an empty or whitespace-only question
        (which previously raised IndexError)
    """
    words = (question_text or "").split()
    if not words:
        return 3  # Other
    codes = {"what": 0, "where": 1, "how": 2}
    return codes.get(words[0].lower(), 3)

# Hyperparameters for multiple iterations
num_iterations = 2000  # Number of iterations over the dataset

# Best match rate seen so far, the correct-answer count recorded with it,
# and a snapshot of the Q-table taken at that point
max_percentage, max_correct = 0, None
best_Q = None

# Run the Q-learning algorithm for multiple iterations
for iteration in range(num_iterations):
    print(f"Iteration {iteration + 1}/{num_iterations}")
    results = []
    correct_answers = 0
    total_parent_nodes = 0

    for example in data:
        for node in example:
            if "fa" not in node:  # Only process root nodes (original questions)
                total_parent_nodes += 1
                question_text = node.get("question_text", "").strip()
                final_answer, _ = q2gold[question_text]
                answer, method = q_learning(node, final_answer)
                # if normalize_answer(answer) == normalize_answer(final_answer):
                if are_answers_equivalent_using_llm(final_answer, answer):
                    correct_answers += 1
                results.append({
                    "idx": node["idx"],
                    "question": node["question_text"],
                    "answer": answer,
                    "method": method
                })

    # Print the results for this iteration
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    if accuracy >= max_percentage:
        # New best (ties included, as before)
        max_percentage = accuracy
        max_correct = correct_answers
        # Q.copy() is shallow: the snapshot would share its value arrays with
        # Q and keep changing as training continues. Copy each array so
        # best_Q really is the table as it was at the best iteration.
        best_Q = defaultdict(
            Q.default_factory,
            {state_key: q_row.copy() for state_key, q_row in Q.items()},
        )

    print(f"Total parent nodes: {total_parent_nodes}")
    print(f"Times answer matches gold answer: {correct_answers}")
    print(f"Q-learning Solver match rate: {accuracy:.2f}%")
    print()

In [None]:
# Display the per-action attempt and success tallies kept by update_success_rate
attempt_counts, success_counts

In [None]:
# Display the best per-iteration match rate (%) and the correct-answer count recorded with it
max_percentage, max_correct

In [None]:
# For every state stored in the Q-table, look up the action with the highest
# Q-value, print it, and collect it for the histogram below
actions = []
for val in Q.values():
    greedy_action = action_space[np.argmax(val)]
    print(greedy_action)
    actions.append(greedy_action)

# How often each action is the highest-valued one across stored states
plt.figure(figsize=(8, 6))
sns.histplot(actions, bins=3, discrete=True)
plt.title("Distribution of Maximum Actions")
plt.xlabel("Actions")
plt.ylabel("Frequency")
plt.show()

## Try it on the HotpotQA test set

In [None]:
import json
import numpy as np
import random
from collections import defaultdict


# Function to select an action using the trained Q-table (no exploration)
def select_action(state, has_children_index=150):
    """Select the best action for `state` using the trained Q-values.

    The 'Child' action is ruled out when the state says the node has no
    children. Masking is done on a copy, so the learned Q-table is never
    overwritten with -inf.

    :param state: state tuple as produced by tuple(get_state(node)).
    :param has_children_index: position of the has_children flag in `state`.
        get_state puts three padded logprob sequences first, so with its
        default max_length of 50 the flag sits at index 3 * 50 = 150. (The
        previous check read state[3], which is a CB logprob.)
    :return: the chosen action name; ties are broken at random.
    """
    q_values = np.array(Q[state], dtype=float)  # copy: never mutate the Q-table
    if len(state) > has_children_index and state[has_children_index] == 0:
        # no children
        q_values[action_space.index('Child')] = -float('inf')
    max_q_value = np.max(q_values)  # Get max Q-value
    best_actions = [action_space[i] for i, q in enumerate(q_values) if q == max_q_value]  # Get all actions with max Q-value
    print("best_actions", best_actions)
    return np.random.choice(best_actions)  # Choose randomly if multiple exist

# Function to make predictions on test data
def run_inference(test_data_hotpotqa):
    """Score the greedy Q-table policy on every root question of the given trees.

    For each node without an ``"fa"`` (parent) key, the gold answer is looked
    up in the module-level ``q2gold`` map, the state is built with
    ``get_state`` and the action is picked by ``select_action``. The answer
    already stored on the node for that action is used; if it contains
    "unknown" the closed-book answer is used instead and the recorded method
    becomes "CB". Correctness is judged by ``are_answers_equivalent_using_llm``.

    :param test_data_hotpotqa: Iterable of trees, each a list of node dicts.
    :return: ``(results, correct_answers, total_parent_nodes)`` where
        ``results`` holds one dict per root question.
    """
    records = []
    num_correct = 0
    num_roots = 0

    for tree in test_data_hotpotqa:
        for node in tree:
            if "fa" in node:
                # Nodes with a parent are sub-questions; only roots are scored.
                continue

            final_answer, dataset_used = q2gold[node["question_text"].strip()]
            print(node["question_text"])
            # The Q-table is keyed by tuples, so the state is converted first.
            action = select_action(tuple(get_state(node)))

            # Map the chosen action onto the node field holding its answer.
            if action == 'CB':
                answer_field = "cb_answer"
            elif action == 'OB':
                answer_field = "ob_answer"
            else:
                answer_field = "child_answer"
            answer = node.get(answer_field)[0]

            # An "unknown" answer is replaced by the closed-book one.
            if "unknown" in answer.lower().strip():
                action = "CB"
                answer = node.get("cb_answer")[0]

            num_roots += 1

            print("final_answer", final_answer)
            print("answer", answer)
            if are_answers_equivalent_using_llm(final_answer, answer):
                num_correct += 1

            records.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": answer,
                "gold": final_answer,
                "method": action
            })

    return records, num_correct, num_roots


# Utilities

In [None]:
def get_question_after_resolving_references(node, tree):
    """Replace ``<k>`` placeholders in a node's question with sibling answers.

    A placeholder ``<k>`` refers to the k-th (1-based) entry of the parent's
    ``"sons"`` list. If that sibling already has an ``"answer"``, every
    occurrence of the placeholder is replaced by ``answer[0]`` and that text
    is also collected as a topic entity. Placeholders that are out of range,
    or whose sibling has no answer yet, are left in the question unchanged.
    Nodes without an ``"fa"`` (parent) key are never rewritten.

    The resolved question is also stored on the node under ``"question"``.

    :param node: Node dict with at least ``"question_text"``.
    :param tree: List of node dicts, indexable by the ids used in ``"fa"``
        and ``"sons"``.
    :return: ``(question, topic_entities)``.
    """
    question = node["question_text"].strip()
    topic_entities = []
    ref_tokens = re.findall(r"<\d+>", question)

    # The parent is only consulted when there is something to resolve, so
    # reference-free questions do not require a "sons" list on the parent.
    if ref_tokens and "fa" in node:
        siblings = tree[node["fa"]]["sons"]
        for ref_token in ref_tokens:
            position = int(ref_token[1:-1])
            # Placeholders are 1-based. Without the lower bound, "<0>" would
            # index sons[-1] and silently pull in the last sibling's answer.
            if not 1 <= position <= len(siblings):
                continue
            referenced = tree[siblings[position - 1]]
            if "answer" in referenced:
                answer_text = referenced["answer"][0]
                question = question.replace(ref_token, answer_text)
                topic_entities.append(answer_text)

    node["question"] = question
    return question, topic_entities

In [None]:

import json
import torch

# Load test data
def _read_json_file(path):
    """Return the parsed content of a JSON file, closing the file afterwards."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _read_jsonl_file(path, dataset=None):
    """Return the records of a JSON-lines file, skipping blank lines.

    When ``dataset`` is given, every record's ``'dataset'`` field is set to it.
    """
    with open(path, 'r', encoding='utf-8') as file:
        records = [json.loads(line) for line in file if line.strip()]
    if dataset is not None:
        for record in records:
            record['dataset'] = dataset
    return records


def _index_gold_answers(records):
    """Map each stripped question text to ``(first gold span, dataset name)``.

    Records missing any of the required fields are reported and skipped so
    that one malformed line does not abort the whole load.
    """
    q2gold_map = {}
    for item in records:
        try:
            question = item['question_text'].strip()
            gold = item['answers_objects'][0]['spans'][0]
            q2gold_map[question] = (gold, item['dataset'])
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            print("ERROR CASE", e)
    return q2gold_map


# Result trees produced for each test set
test_data_hotpotqa = _read_json_file('results-testset-hotpotqa.json')
test_data_2wiki = _read_json_file('results-testset-2wiki.json')
test_data_musique = _read_json_file('results-testset-musique.json')

# Raw test questions with gold answers.
# NOTE(review): no 'dataset' value is stamped on the HotpotQA records, so they
# must already carry that key in the file; otherwise each one is reported as an
# ERROR CASE below. Confirm against the data file.
raw_data_hotpotqa = _read_jsonl_file('./hotpotqa__v2_test_random_500.jsonl')
raw_data_2wiki = _read_jsonl_file('./2wikimultihopqa__v2_test_random_500.jsonl', dataset='2wiki')
raw_data_musique = _read_jsonl_file('./musique_ans__v2_test_random_500.jsonl', dataset='musique')

# question text -> (gold answer, dataset) for each test set
q2gold_test_hotpotqa = _index_gold_answers(raw_data_hotpotqa)
q2gold_test_2wiki = _index_gold_answers(raw_data_2wiki)
q2gold_test_musique = _index_gold_answers(raw_data_musique)

print("q2gold_test_hotpotqa", len(q2gold_test_hotpotqa))
print("q2gold_test_2wiki", len(q2gold_test_2wiki))
print("q2gold_test_musique", len(q2gold_test_musique))
print("raw_data_hotpotqa", len(raw_data_hotpotqa))
print("raw_data_2wiki", len(raw_data_2wiki))
print("raw_data_musique", len(raw_data_musique))

# RL Approach with deep learning using question only in state

In [None]:
# Load the raw data and question decompositions.
# Every file is read inside a `with` block so its handle is closed right away.
# NOTE: this rebinds raw_data_2wiki / raw_data_musique to the *dev* split.
with open('./hotpotqa__v2_dev_random_100.jsonl', encoding='utf-8') as file:
    raw_data = [json.loads(line) for line in file if line.strip()]

with open('./2wikimultihopqa__v2_dev_random_100.jsonl', encoding='utf-8') as file:
    raw_data_2wiki = [json.loads(line) for line in file if line.strip()]
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)

with open('./musique_ans__v2_dev_random_100.jsonl', encoding='utf-8') as file:
    raw_data_musique = [json.loads(line) for line in file if line.strip()]
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

with open("./question_decompositions-devset-hotpotqa.json", encoding='utf-8') as file:
    q2dq = json.load(file)
with open("./question_decompositions-devset-2wiki.json", encoding='utf-8') as file:
    q2dq_2wiki = json.load(file)
with open("./question_decompositions-devset-musique.json", encoding='utf-8') as file:
    q2dq_musique = json.load(file)
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: root question of the decomposition -> (gold answer, dataset)
q2gold = {}
skipped_questions = 0
for item in raw_data:
    try:
        question = item['question_text'].strip()
        # The decomposition is keyed by its own root question text, which is
        # the key later used to look gold answers up.
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
        q2gold[question] = (gold, q_type)
    except (KeyError, IndexError, TypeError, AttributeError):
        # Skip if question not found in question_decompositions (or the record
        # lacks a gold span / dataset), but keep count so the loss is visible.
        skipped_questions += 1
        continue
if skipped_questions:
    print(f"q2gold: skipped {skipped_questions} of {len(raw_data)} raw questions")

# Load the data to analyze
with open('results-devset-hotpotqa.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

with open('results-devset-2wiki.json', 'r', encoding='utf-8') as file:
    data_2wiki = json.load(file)
with open('results-devset-musique.json', 'r', encoding='utf-8') as file:
    data_musique = json.load(file)

data.extend(data_2wiki)
data.extend(data_musique)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    Two hidden layers of ``hidden_dim`` units with ReLU activations are
    followed by a linear output layer of ``action_dim`` units.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The attribute names double as state_dict keys, so they stay fc1..fc3.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return the Q-values for ``state`` (no activation on the output)."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# def get_state(node, depth=0):
#     # Extract raw features
#     cb_logprob = node.get("cb_answer", [None, None])[1] or -100
#     ob_logprob = node.get("ob_answer", [None, None])[1] or -100
#     child_logprob = node.get("child_answer", [None, None])[1] or -100
#     has_children = 1 if len(node.get("sons", [])) else 0
#     question_length = len(node.get("question_text", "").split())
#     question_type = encode_question_type(node.get("question_text", ""))
#     num_children = len(node.get("sons", []))
#     cb_success_rate = get_success_rate("cb")
#     ob_success_rate = get_success_rate("ob")
#     child_success_rate = get_success_rate("child")

#     # Build state vector
#     state = [
#         cb_logprob,
#         ob_logprob,
#         # child_logprob,
#         has_children,
#         question_length,
#         question_type,
#         num_children,
#         cb_success_rate,
#         ob_success_rate,
#         child_success_rate
#     ]
#     return torch.FloatTensor(state)

def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """Return ``logprobs`` as a new list of exactly ``max_length`` items.

    Longer inputs are cut after ``max_length`` items; shorter ones are
    right-padded with ``pad_value``. Any finite iterable is accepted (list,
    tuple, array, ...). The previous ``logprobs + [...]`` form only worked
    for lists: tuples raised ``TypeError`` and NumPy arrays were added
    element-wise instead of being extended.

    :param logprobs: Sequence of values to resize; it is never modified.
    :param max_length: Target length of the result.
    :param pad_value: Filler appended when the input is too short.
    :return: A new list of length ``max_length``.
    """
    values = list(logprobs)[:max_length]
    # A non-positive repeat count yields an empty list, so no padding is added
    # once the input already fills max_length.
    return values + [pad_value] * (max_length - len(values))

# # Define the reward function
# def get_reward(chosen_answer, gold_answer):
#     if normalize_answer(chosen_answer) == normalize_answer(gold_answer):
#         return 1
#     else:
#         return -1

from sentence_transformers import SentenceTransformer, util
# Load a pre-trained sentence embedding model once at module level; both
# get_reward and get_state call model.encode on this shared instance.
# The state-size arithmetic in this notebook assumes 384-dimensional embeddings.
model = SentenceTransformer('all-MiniLM-L6-v2')
# def get_reward(chosen_answer, gold_answer):
#     # Compute embeddings for the chosen and gold answers
#     chosen_embedding = model.encode(chosen_answer, convert_to_tensor=True)
#     gold_embedding = model.encode(gold_answer, convert_to_tensor=True)
#     # Compute cosine similarity
#     similarity = util.cos_sim(chosen_embedding, gold_embedding).item()
#     return similarity  # Reward is the similarity score (between -1 and 1)

def get_reward(chosen_answer, gold_answer, action, num_llm_calls, alpha=1.0, beta=0.1):
    """
    Score a decision as weighted answer similarity minus weighted LLM usage.

    The accuracy term is the cosine similarity between the sentence embeddings
    (computed with the module-level ``model``) of the chosen and gold answers;
    negative similarities are clipped to 0. The efficiency term is
    ``num_llm_calls`` multiplied by a per-action unit cost. Every listed
    action, and any unlisted one, currently has a unit cost of 1.

    :param chosen_answer: Answer chosen by the agent.
    :param gold_answer: Ground truth answer.
    :param action: The action selected by the agent (0=CB, 1=OB, 2=Child).
    :param num_llm_calls: The number of LLM calls made during the current decision.
    :param alpha: Weight for accuracy reward.
    :param beta: Weight for efficiency penalty.
    :return: ``alpha * accuracy_term - beta * usage_penalty``.
    """
    # Accuracy: embed both answers (chosen first, then gold) and compare them.
    chosen_vec, gold_vec = (
        model.encode(text, convert_to_tensor=True)
        for text in (chosen_answer, gold_answer)
    )
    cosine = util.cos_sim(chosen_vec, gold_vec).item()
    accuracy_term = cosine if cosine > 0 else 0

    # Efficiency: unit LLM cost per action index; all are 1 at the moment.
    llm_cost_by_action = {0: 1, 1: 1, 2: 1}
    unit_cost = llm_cost_by_action.get(action, 1)
    usage_penalty = num_llm_calls * unit_cost

    return alpha * accuracy_term - beta * usage_penalty


def _answer_confidence_and_length(node, key):
    """Return ``(confidence, word_count)`` for the answer stored under ``key``.

    The confidence is element 1 of the stored answer list (0.0 when missing or
    falsy); the word count is the number of whitespace-separated tokens in
    element 0. A node without ``key`` yields ``(0.0, 0)``.
    """
    confidence = node.get(key, [None, None, None, []])[1] or 0.0
    word_count = len(node.get(key, [""])[0].split())
    return confidence, word_count


def get_state(node, depth=0, max_length=50):
    """Build the state vector the DQN uses to decide how to answer ``node``.

    Layout of the returned tensor, in order:

    * 7 basic features: has_children (index 0), question_length,
      question_type, num_children, and the CB / OB / Child success rates;
    * the sentence embedding of the question text (from the module-level
      ``model``);
    * 2 confidences: cb_confidence, ob_confidence;
    * 2 structural features: tree_depth, tree_position (0 for depth 0, else 1);
    * 2 answer lengths in words: CB answer, OB answer.

    Earlier revisions also padded the CB/OB token log-probs, embedded both
    answers and read the child confidence, but none of those values were put
    into the vector. That dead work has been removed: it cost two extra
    ``model.encode`` calls per state and raised ``IndexError`` on answer lists
    shorter than expected. The returned tensor is unchanged.

    :param node: Node dict; reads ``question_text``, ``sons``, ``cb_answer``
        and ``ob_answer`` when present.
    :param depth: Depth of the node in its tree (0 for the root question).
    :param max_length: Unused; kept so existing callers keep working (it only
        sized the log-prob padding that never reached the state).
    :return: ``torch.FloatTensor`` holding the features listed above.
    """
    question_text = node.get("question_text", "")
    sons = node.get("sons", [])

    # Basic features (7). has_children must stay first: the agent reads that
    # position (HAS_CHILDREN_INDEX) to rule out the Child action.
    basic_features = [
        1 if len(sons) else 0,
        len(question_text.split()),
        encode_question_type(question_text),
        len(sons),
        get_success_rate("cb"),
        get_success_rate("ob"),
        get_success_rate("child"),
    ]

    # Semantic feature: embedding of the question only.
    question_embedding = model.encode(question_text, convert_to_tensor=False)

    # Confidence and answer-quality features for the two single-hop answers.
    cb_confidence, cb_answer_length = _answer_confidence_and_length(node, "cb_answer")
    ob_confidence, ob_answer_length = _answer_confidence_and_length(node, "ob_answer")

    # Structural features
    tree_depth = depth
    tree_position = 0 if depth == 0 else 1  # 0 for root, 1 for intermediate/leaf

    state = (
        basic_features +
        list(question_embedding) +
        [cb_confidence, ob_confidence] +
        [tree_depth, tree_position] +
        [cb_answer_length, ob_answer_length]
    )

    return torch.FloatTensor(state)

# Position of the has_children flag inside the state vector built by get_state
# (it is the first basic feature). DQNAgent.select_action reads this position to
# exclude the Child action for nodes that have no children.
HAS_CHILDREN_INDEX = 0


class DQNAgent:
    """DQN agent with an online Q-network, a target network and a replay buffer.

    ``replay_buffer`` is a plain list of
    ``(state, action, reward, next_state, done)`` tuples that callers append
    to directly; ``train`` samples mini-batches from it.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99):
        """Create both networks with identical weights and an Adam optimizer.

        :param state_dim: Length of the state vector.
        :param action_dim: Number of actions (size of the network output).
        :param hidden_dim: Width of the hidden layers.
        :param lr: Learning rate for Adam.
        :param gamma: Discount factor used in the TD target.
        """
        self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
        # Start the target network as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = []

    def select_action(self, state, epsilon):
        """Pick an action index epsilon-greedily for a single state vector.

        The last action (Child) is never returned when the state's
        has-children flag is 0, in both the random and the greedy branch.

        :param state: 1-D state tensor as produced by ``get_state``.
        :param epsilon: Probability of choosing a random action.
        :return: Action index as a Python ``int``.
        """
        if random.random() < epsilon:
            if state[HAS_CHILDREN_INDEX] == 0:
                # No children: sample from every action except the last one.
                return random.randint(0, len(action_space) - 2)
            return random.randint(0, len(action_space) - 1)

        with torch.no_grad():
            q_values = self.q_network(state)
            if state[HAS_CHILDREN_INDEX] == 0:
                # No children: make the last action impossible to select.
                q_values[-1] = -float('inf')
            return torch.argmax(q_values).item()

    def train(self, batch_size):
        """Run one gradient step on a random mini-batch from the replay buffer.

        Does nothing until the buffer holds at least ``batch_size`` entries.
        The TD target is ``r + gamma * max_a Q_target(s', a) * (1 - done)`` and
        the loss is Smooth L1 (Huber).

        :param batch_size: Number of transitions to sample.
        """
        if len(self.replay_buffer) < batch_size:
            return
        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.stack(next_states)
        dones = torch.FloatTensor(dones)

        # Q-value of the action actually taken in each sampled state.
        # squeeze(1) removes only the gather column. A bare squeeze() would
        # also drop the batch axis when batch_size == 1, leaving a 0-d
        # prediction against a (1,)-shaped target.
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Update the success rates for every sampled transition.
        # NOTE(review): success is recorded only when the reward equals exactly
        # 1. With the shaped reward from get_reward this is rarely true, and a
        # transition is re-counted each time it is sampled. Confirm intent.
        for action_index, reward_value in zip(actions.tolist(), rewards.tolist()):
            update_success_rate(action_space[action_index], reward_value == 1)

        # Compute loss and update the network
        loss = nn.SmoothL1Loss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
import os
import torch
import copy

# Train one DQN agent per (alpha, beta) reward configuration on the dev trees in
# `data`, tracking per-episode reward, accuracy and LLM-call counts, and save
# both the final and the best-accuracy network of every configuration.

# Create directory for saved models
os.makedirs("saved_models_question_only_state_trained_all", exist_ok=True)

# Define different configurations for alpha and beta
# (alpha weights the accuracy term of get_reward, beta the LLM-usage penalty)
configurations = [
    {"alpha": 2.0, "beta": 0.05, "name": "High_Accuracy"},
    {"alpha": 1.0, "beta": 0.1, "name": "Balanced"},
    {"alpha": 0.5, "beta": 0.2, "name": "Efficiency_Focused"}
]

# Dictionary to store results: config name -> per-episode rewards / accuracy / num_calls
accuracy_rewards_numcalls = {}

# Loop through different settings
for config in configurations:
    alpha = config["alpha"]
    beta = config["beta"]
    config_name = config["name"]
    # Must equal the length of the vector returned by get_state:
    # 7 basic + 384 question embedding + 2 confidences + 2 structural + 2 answer lengths.
    state_dim = 7 + 384 + 2 + 2 + 2
    action_dim = len(action_space)  # Number of actions (CB, OB, Child)
    hidden_dim = 128  # Hidden layer size
    lr = 5e-3  # Learning rate (raised from the earlier 1e-3)
    gamma = 0.99  # Discount factor
    epsilon = 1.0  # Initial exploration rate
    epsilon_min = 0.01  # Minimum exploration rate
    epsilon_decay = 0.98  # Decay rate for exploration (applied once per episode)
    batch_size = 128  # Mini-batch size
    num_episodes = 20  # Number of training episodes

    # A fresh agent (and replay buffer) is created for every configuration.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)

    print(f"\nTraining with {config_name} (alpha={alpha}, beta={beta})")

    # Best per-episode accuracy seen so far and a snapshot of the agent that reached it
    max_accuracy = 0
    best_agent = None

    # Store rewards, accuracy and total number of LLM calls for each episode
    rewards_list = []
    accuracy_list = []
    total_num_llm_calls_list = []

    # Set model to training mode
    agent.q_network.train()

    for episode in range(num_episodes):
        print(f"Episode {episode + 1}/{num_episodes}")
        total_reward = 0
        correct_answers = 0
        total_parent_nodes = 0
        total_num_llm_calls = 0

        for example in data:
            for node in example:
                if "fa" not in node:  # Only process root nodes (original questions)
                    gold_answer, dataset_used = q2gold[node["question_text"].strip()]
                    total_parent_nodes += 1
                    # State is built from the answers currently stored on the node,
                    # i.e. before this step's action overwrites them.
                    state = get_state(node)  
                    action = agent.select_action(state, epsilon)  
                    chosen_answer = None
                    fallback_used = False  
                    num_llm_calls = None

                    question, _ = get_question_after_resolving_references(node, example)

                    # Simulate the action
                    # Actions 0 and 1 query the answering helpers again and write the
                    # result back onto the node in `data` (in place).
                    if action == 0: 
                        node["cb_answer"] = get_cb_answer(question, dataset_used)
                        chosen_answer = node.get("cb_answer", [None])
                        num_llm_calls = 1
                    elif action == 1:
                        node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])
                        num_llm_calls = 1
                    elif action == 2:
                        # Child action: answer the sub-questions on a deep copy so the
                        # original tree in `data` is not modified by this branch.
                        tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                        child_experiences = []

                        for node_ in tree_with_answers_chosen_by_agent:
                            # References are resolved against answers already set on
                            # earlier nodes of the same copied tree.
                            question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)

                            if len(node_["sons"]) == 0:
                                # Leaf: let the agent choose CB or OB; an "unknown"
                                # answer switches to the other method.
                                child_state = get_state(node_, depth=1)
                                child_action = agent.select_action(child_state, epsilon)
                                if child_action == 0:  
                                    node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 1  
                                        node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                elif child_action == 1:  
                                    node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 0  
                                        node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])

                                node_["answer"] = child_answer
                                # Reward 0 is a placeholder; the reward actually stored
                                # is filled in when the experiences are appended below.
                                child_experiences.append((child_state, child_action, 0, child_state, False))  
                            else:
                                # Non-leaf: aggregate the answers of its children.
                                # chosen_answer ends up as the aggregate of the last
                                # non-leaf node visited (presumably the root; verify
                                # against the node ordering of the trees).
                                node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
                                chosen_answer = node_["child_answer"]

                        num_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation)

                    # Root-level fallback when the chosen answer is "unknown":
                    # CB falls back to OB, while OB and Child fall back to CB.
                    # NOTE(review): the extra calls made by this fallback (and by the
                    # leaf-level fallbacks above) are not added to num_llm_calls, so
                    # both the reward penalty and the reported totals exclude them.
                    if "unknown" in chosen_answer[0].lower().strip():
                        fallback_used = True
                        if action == 0:  
                            fallback_action = 1
                            node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                            chosen_answer = node.get("ob_answer", [None])
                        elif action == 1:  
                            fallback_action = 0
                            node["cb_answer"] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])
                        elif action == 2:  
                            fallback_action = 0
                            node["cb_answer"] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])

                    # Compute reward (uses the originally selected action, even when
                    # a fallback produced the final answer)
                    reward = get_reward(chosen_answer[0], gold_answer, action, num_llm_calls, alpha, beta)
                    total_reward += reward

                    # Accuracy is judged separately from the reward, by the LLM-based
                    # equivalence check.
                    if are_answers_equivalent_using_llm(gold_answer, chosen_answer[0]):
                        correct_answers += 1

                    # Same node, re-encoded after its stored answers may have changed.
                    next_state = get_state(node)

                    # Credit the action that actually produced the answer.
                    # NOTE(review): done is always False, so the TD target always
                    # bootstraps from next_state even though each question is a
                    # single decision. Confirm this is intended.
                    if not fallback_used:
                        agent.replay_buffer.append((state, action, reward, next_state, False))  
                    else:
                        agent.replay_buffer.append((state, fallback_action, reward, next_state, False))

                    if action == 2:
                        num_children = len(node.get("sons", []))
                        # NOTE(review): child_reward is computed but never used; the
                        # loop below stores the full parent reward for every child
                        # experience. Confirm which of the two is intended.
                        child_reward = reward / num_children  
                        for child_state, child_action, _, next_child_state, done in child_experiences:
                            agent.replay_buffer.append((child_state, child_action, reward, next_child_state, done))

                    # update total number of llm calls in this episode
                    total_num_llm_calls += num_llm_calls

                    # One gradient step per root question (no-op until the buffer
                    # holds at least batch_size transitions).
                    agent.train(batch_size)

        # Decay exploration once per episode, never below epsilon_min.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Sync the target network every 10th episode (episodes 0 and 10 with
        # num_episodes = 20).
        if episode % 10 == 0:
            agent.update_target_network()

        accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
        rewards_list.append(total_reward)
        accuracy_list.append(accuracy)
        total_num_llm_calls_list.append(total_num_llm_calls)
        # Snapshot the agent whenever an episode sets a new accuracy record.
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            best_agent = copy.deepcopy(agent)

        print(f"Total Reward: {total_reward}")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Total LLM Calls in this episode: {total_num_llm_calls}")
        print(f"Epsilon: {epsilon:.4f}")
        print()

    # Store the rewards and accuracy for this configuration
    accuracy_rewards_numcalls[config_name] = {
        "rewards": rewards_list,
        "accuracy": accuracy_list,
        "num_calls": total_num_llm_calls_list
    }

    # Save model after training with each configuration (weights after the last episode)
    model_path = f"saved_models_question_only_state_trained_all/agent_{config_name}.pth"
    torch.save(agent.q_network.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save the best model of this configuration (highest per-episode accuracy)
    if best_agent is not None:
        best_model_path = f"saved_models_question_only_state_trained_all/best_agent_{config_name}.pth"
        torch.save(best_agent.q_network.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import matplotlib.pyplot as plt

def _min_max_scale(values):
    """Rescale ``values`` linearly to the range [0, 1].

    A constant series has zero range; it is mapped to all zeros instead of
    dividing by zero and producing NaNs. An empty series is returned empty.
    """
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return arr
    low, high = arr.min(), arr.max()
    if high == low:
        return np.zeros_like(arr)
    return (arr - low) / (high - low)


# Get rewards, accuracy and number of LLM calls for each configuration
balanced_rewards = accuracy_rewards_numcalls["Balanced"]["rewards"]
balanced_accuracy = accuracy_rewards_numcalls["Balanced"]["accuracy"]
balanced_num_calls = accuracy_rewards_numcalls["Balanced"]["num_calls"]

efficiency_rewards = accuracy_rewards_numcalls["Efficiency_Focused"]["rewards"]
efficiency_accuracy = accuracy_rewards_numcalls["Efficiency_Focused"]["accuracy"]
efficiency_num_calls = accuracy_rewards_numcalls["Efficiency_Focused"]["num_calls"]

high_accuracy_accuracy = accuracy_rewards_numcalls["High_Accuracy"]["accuracy"]
high_accuracy_rewards = accuracy_rewards_numcalls["High_Accuracy"]["rewards"]
high_accuracy_num_calls = accuracy_rewards_numcalls["High_Accuracy"]["num_calls"]

# Min-max scale the rewards per configuration: each configuration uses different
# alpha/beta weights, so raw reward totals are not on a common scale.
balanced_rewards = _min_max_scale(balanced_rewards)
efficiency_rewards = _min_max_scale(efficiency_rewards)
high_accuracy_rewards = _min_max_scale(high_accuracy_rewards)

# (legend label, normalized rewards, accuracy, LLM calls) in plotting order
curves = [
    ("Balanced", balanced_rewards, balanced_accuracy, balanced_num_calls),
    ("Efficiency Focused", efficiency_rewards, efficiency_accuracy, efficiency_num_calls),
    ("High Accuracy", high_accuracy_rewards, high_accuracy_accuracy, high_accuracy_num_calls),
]

# Plotting: rewards and accuracy side by side
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
for label, rewards, _, _ in curves:
    plt.plot(rewards, label=label)
plt.title("Normalized Rewards")
plt.xlabel("Episode")
plt.ylabel("Normalized Reward")
plt.legend()

plt.subplot(1, 2, 2)
for label, _, accuracy_curve, _ in curves:
    plt.plot(accuracy_curve, label=label)
plt.title("Accuracy")
plt.xlabel("Episode")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.tight_layout()
plt.show()

# LLM calls per episode in a separate figure
plt.figure(figsize=(8, 6))
for label, _, _, calls_curve in curves:
    plt.plot(calls_curve, label=label)
plt.title("Number of LLM Calls")
plt.xlabel("Episode")
plt.ylabel("Total LLM Calls")
plt.legend()
plt.show()

In [None]:
# Report the LLM-call total of the last training episode for each configuration
balanced_num_calls, efficiency_num_calls, high_accuracy_num_calls = (
    accuracy_rewards_numcalls[config_key]["num_calls"][-1]
    for config_key in ("Balanced", "Efficiency_Focused", "High_Accuracy")
)

for label, final_calls in (
    ("balanced_num_calls", balanced_num_calls),
    ("efficiency_num_calls", efficiency_num_calls),
    ("high_accuracy_num_calls", high_accuracy_num_calls),
):
    print(label, final_calls)

## Try it on the HotpotQA test set

In [None]:
# Evaluate the agent on the test data
def _is_unknown_answer(answer_text):
    """Return True when an answer is missing or contains the marker "unknown"."""
    if not isinstance(answer_text, str):
        # A non-string (e.g. None) cannot be scored, so treat it like "unknown".
        return True
    return "unknown" in answer_text.lower()


def _answer_with_fallback(target, question, topic_entities, dataset_used, ob_first):
    """
    Answer `question` with one single-hop method and fall back to the other.

    The closed-book (CB) result is stored in ``target["cb_answer"]`` and the
    open-book (OB) result in ``target["ob_answer"]``, as the inline code did.

    :param target: Node dict that receives the "cb_answer"/"ob_answer" entries.
    :param question: Question text with references already resolved.
    :param topic_entities: Topic entities passed to the open-book call.
    :param dataset_used: Dataset name forwarded to the answer functions.
    :param ob_first: True to try OB first and fall back to CB; False for CB then OB.
    :return: (answer, used_fallback) where answer is the full value returned by
        the answering function and used_fallback says whether the second
        method produced it.
    """
    def ask_cb():
        target["cb_answer"] = get_cb_answer(question, dataset_used)
        return target["cb_answer"]

    def ask_ob():
        target["ob_answer"] = get_singlehop_ob_answer(question, topic_entities, dataset_used)
        return target["ob_answer"]

    first, second = (ask_ob, ask_cb) if ob_first else (ask_cb, ask_ob)
    answer = first()
    if _is_unknown_answer(answer[0]):
        return second(), True
    return answer, False


def evaluate_agent(agent, data, q2gold):
    """
    Run the agent greedily (epsilon=0) on every root question in `data`.

    For each root node the agent picks CB (0), OB (1) or Child (2). CB and OB
    fall back to each other when the answer contains "unknown". Child answers
    each direct son with CB/OB, aggregates the sons' answers, and falls back
    to CB and then OB on the root question if the aggregate is "unknown".

    Side effects: answers are written into the root nodes of `data`
    ("cb_answer", "ob_answer", "child_answer", "answer"); sons are answered on
    a deep copy of the example. The accuracy is printed.

    :param agent: Agent exposing `q_network` and `select_action(state, epsilon)`.
    :param data: Iterable of examples, each a list of node dicts.
    :param q2gold: Maps stripped root question text to (gold_answer, dataset_name).
    :return: (results, correct_answers, total_parent_nodes); `results` holds one
        dict per root with keys "idx", "question", "answer", "gold", "method".
    """
    agent.q_network.eval()  # Set the model to evaluation mode
    correct_answers = 0
    total_parent_nodes = 0
    results = []
    for example in data:
        for node in example:
            if "fa" in node:
                # Only root nodes (original questions) are evaluated.
                continue

            gold_answer, dataset_used = q2gold[node["question_text"].strip()]
            total_parent_nodes += 1
            state = get_state(node)
            action = agent.select_action(state, epsilon=0)  # No exploration during evaluation
            chosen_answer = None
            fallback_used = False  # Track if a fallback was used
            fallback_action = None

            question, _ = get_question_after_resolving_references(node, example)
            if action == 0:  # CB, falling back to OB
                answer, fallback_used = _answer_with_fallback(
                    node, question, [], dataset_used, ob_first=False
                )
                chosen_answer = answer[0]
                if fallback_used:
                    fallback_action = 1
            elif action == 1:  # OB, falling back to CB
                answer, fallback_used = _answer_with_fallback(
                    node, question, [], dataset_used, ob_first=True
                )
                chosen_answer = answer[0]
                if fallback_used:
                    fallback_action = 0
            elif action == 2:  # Child
                tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                for child_idx in node.get("sons", []):
                    child = tree_with_answers_chosen_by_agent[child_idx]
                    child_state = get_state(child, depth=1)
                    child_action = agent.select_action(child_state, epsilon=0)
                    question_, topic_entities = get_question_after_resolving_references(
                        child, tree_with_answers_chosen_by_agent
                    )
                    # Every child gets a fresh answer. OB-first only when the
                    # agent chose OB; any other action (including Child, which
                    # is not expanded here) uses CB then OB. Previously such a
                    # child left `child_answer` unbound or reused a sibling's.
                    child_answer, _ = _answer_with_fallback(
                        child, question_, topic_entities, dataset_used,
                        ob_first=(child_action == 1),
                    )
                    child["answer"] = child_answer

                # Generate child_answer for the parent node
                node["child_answer"], node["answer"] = aggregate_multihop_answer(
                    node, tree_with_answers_chosen_by_agent, dataset_used
                )
                chosen_answer = node["child_answer"][0]
                if _is_unknown_answer(chosen_answer):
                    # Fallback on the root question: CB first, then OB.
                    # NOTE(review): this fallback is not reflected in "method"
                    # below (fallback_used stays False, as in the original).
                    answer, _ = _answer_with_fallback(
                        node, question, [], dataset_used, ob_first=False
                    )
                    chosen_answer = answer[0]

            # Count the prediction as correct when the LLM judge says it matches gold.
            if chosen_answer and are_answers_equivalent_using_llm(chosen_answer, gold_answer):
                correct_answers += 1

            # Store results
            results.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": chosen_answer,
                "gold": gold_answer,
                "method": action_space[action] if not fallback_used else action_space[fallback_action]
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes

# Reward-weight settings whose trained checkpoints are evaluated below.
configurations = [
    dict(alpha=2.0, beta=0.05, name="High_Accuracy"),
    dict(alpha=1.0, beta=0.1, name="Balanced"),
    dict(alpha=0.5, beta=0.2, name="Efficiency_Focused"),
]

# Evaluation outputs, each keyed by configuration name.
predictions_per_config, correct_answers_per_config = {}, {}
total_parent_nodes_per_config, accuracy_per_config = {}, {}

In [None]:
# Evaluate the best checkpoint of every configuration on the HotpotQA test split.
# NOTE: the *_per_config dicts are keyed by configuration name only, so each
# run of a cell like this one replaces the entries of the previous run.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Rebuild an agent with the current hyperparameters and restore its weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    _checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(_checkpoint)
    agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Greedy evaluation over all root questions of the split.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_hotpotqa, q2gold_test_hotpotqa
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the split has no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

In [None]:
# Evaluate the best checkpoint of every configuration on the 2Wiki test split.
# NOTE: the *_per_config dicts are keyed by configuration name only, so each
# run of a cell like this one replaces the entries of the previous run.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Rebuild an agent with the current hyperparameters and restore its weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    _checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(_checkpoint)
    agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Greedy evaluation over all root questions of the split.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_2wiki, q2gold_test_2wiki
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the split has no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

In [None]:
# Evaluate the best checkpoint of every configuration on the MuSiQue test split.
# NOTE: the *_per_config dicts are keyed by configuration name only, so each
# run of a cell like this one replaces the entries of the previous run.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Rebuild an agent with the current hyperparameters and restore its weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    _checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(_checkpoint)
    agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Greedy evaluation over all root questions of the split.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_musique, q2gold_test_musique
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the split has no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

# RL approach with deep learning (state built from the question plus CB and OB answers)

In [None]:
# Load the raw data and question decompositions
def _read_jsonl(path):
    """Parse a JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


def _read_json(path):
    """Parse a whole JSON file and close the file handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Raw dev samples of the three datasets, merged into one list. The 2Wiki and
# MuSiQue items are tagged with the short dataset names used in this notebook.
# NOTE(review): HotpotQA items are not tagged here — presumably the file
# already carries a 'dataset' field; items without one are skipped below.
raw_data = _read_jsonl('./hotpotqa__v2_dev_random_100.jsonl')
raw_data_2wiki = _read_jsonl('./2wikimultihopqa__v2_dev_random_100.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
raw_data_musique = _read_jsonl('./musique_ans__v2_dev_random_100.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

# Question decompositions of all three datasets, merged into one dict.
q2dq = _read_json("./question_decompositions-devset-hotpotqa.json")
q2dq_2wiki = _read_json("./question_decompositions-devset-2wiki.json")
q2dq_musique = _read_json("./question_decompositions-devset-musique.json")
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: first key of the question's decomposition entry
# -> (first gold answer span, dataset name)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
    except (KeyError, IndexError):
        # Skip items without a decomposition or without a usable gold span.
        continue
    q2gold[question] = (gold, q_type)

# Load the data to analyze: the per-dataset result files, merged into one list.
data = _read_json('results-devset-hotpotqa.json')
data_2wiki = _read_json('results-devset-2wiki.json')
data_musique = _read_json('results-devset-musique.json')

data.extend(data_2wiki)
data.extend(data_musique)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Three-layer MLP mapping a state vector to one Q-value per action."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        """
        :param state_dim: Size of the input state vector.
        :param action_dim: Number of actions (size of the output).
        :param hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Return the Q-values for `state` (ReLU after each hidden layer)."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """
    Return `logprobs` as a list of exactly `max_length` entries.

    Shorter lists are right-padded with `pad_value`; longer ones are cut off
    after `max_length` items. The input list is not modified.
    """
    missing = max_length - len(logprobs)
    if missing > 0:
        return logprobs + [pad_value] * missing
    return logprobs[:max_length]

# # Define the reward function
# def get_reward(chosen_answer, gold_answer):
#     if normalize_answer(chosen_answer) == normalize_answer(gold_answer):
#         return 1
#     else:
#         return -1

from sentence_transformers import SentenceTransformer, util
# Load a pre-trained sentence embedding model once at module level; it is used
# through `model.encode(...)` by get_reward (answer similarity) and by
# get_state (question and answer embeddings).
model = SentenceTransformer('all-MiniLM-L6-v2')
# def get_reward(chosen_answer, gold_answer):
#     # Compute embeddings for the chosen and gold answers
#     chosen_embedding = model.encode(chosen_answer, convert_to_tensor=True)
#     gold_embedding = model.encode(gold_answer, convert_to_tensor=True)
#     # Compute cosine similarity
#     similarity = util.cos_sim(chosen_embedding, gold_embedding).item()
#     return similarity  # Reward is the similarity score (between -1 and 1)

def get_reward(chosen_answer, gold_answer, action, num_llm_calls, alpha=1.0, beta=0.1):
    """
    Score an answer by its similarity to the gold answer minus an LLM-usage penalty.

    :param chosen_answer: Answer text chosen by the agent.
    :param gold_answer: Ground truth answer text.
    :param action: The action selected by the agent (0=CB, 1=OB, 2=Child).
    :param num_llm_calls: Number of LLM calls made for the current decision.
    :param alpha: Weight of the accuracy term.
    :param beta: Weight of the efficiency penalty.
    :return: alpha * accuracy - beta * (num_llm_calls * per-call cost).
    """
    # Accuracy term: cosine similarity of the two sentence embeddings,
    # shifted from [-1, 1] to [0, 1].
    answer_vec = model.encode(chosen_answer, convert_to_tensor=True)
    gold_vec = model.encode(gold_answer, convert_to_tensor=True)
    cosine = util.cos_sim(answer_vec, gold_vec).item()
    accuracy_term = (cosine + 1) / 2

    # Per-call cost multiplier for each action. All actions currently cost the
    # same (1), and unrecognized actions default to 1 as well.
    per_call_cost = {0: 1, 1: 1, 2: 1}.get(action, 1)
    usage_term = num_llm_calls * per_call_cost

    return alpha * accuracy_term - beta * usage_term


def get_state(node, depth=0, max_length=50):
    """
    Build the state vector the Q-network sees for one question node.

    Layout, in order:
      * `max_length` CB log-probs, then `max_length` OB log-probs (padded/truncated)
      * 7 basic features: has_children, question length (words), question type,
        number of children, and the CB / OB / child success rates
      * sentence embeddings of the question, the CB answer and the OB answer
      * CB and OB confidence
      * tree depth and tree position (0 for the root, 1 otherwise)
      * CB and OB answer length (words)

    NOTE: DQNAgent.select_action reads `state[100]` as the has_children flag,
    which only holds for the default `max_length=50`.

    :param node: Node dict. "cb_answer"/"ob_answer", when present, are indexed
        as [0]=text, [1]=confidence, [3]=log-probs; missing answers count as
        empty text, confidence 0.0 and no log-probs.
    :param depth: Depth of the node in its question tree (0 for the root).
    :param max_length: Number of log-prob slots reserved per answer.
    :return: 1-D torch.FloatTensor holding the features above.
    """
    # One lookup per answer instead of one per feature.
    missing_answer = ["", None, None, []]
    cb_answer = node.get("cb_answer", missing_answer)
    ob_answer = node.get("ob_answer", missing_answer)
    question_text = node.get("question_text", "")
    sons = node.get("sons", [])

    # Token log-probs, fixed to `max_length` entries each.
    cb_logprob = pad_or_truncate(cb_answer[3] or [], max_length)
    ob_logprob = pad_or_truncate(ob_answer[3] or [], max_length)

    # Basic features (7)
    basic_features = [
        1 if len(sons) else 0,                 # has_children
        len(question_text.split()),            # question_length
        encode_question_type(question_text),   # question_type
        len(sons),                             # num_children
        get_success_rate("cb"),
        get_success_rate("ob"),
        get_success_rate("child"),
    ]

    # Semantic features. The training cell sizes the network for 1152 = 3 x 384
    # embedding values.
    question_embedding = model.encode(question_text, convert_to_tensor=False)
    cb_answer_embedding = model.encode(cb_answer[0], convert_to_tensor=False)
    ob_answer_embedding = model.encode(ob_answer[0], convert_to_tensor=False)

    # Confidence scores (2); a missing or falsy confidence becomes 0.0.
    cb_confidence = cb_answer[1] or 0.0
    ob_confidence = ob_answer[1] or 0.0

    # Structural features (2)
    tree_depth = depth
    tree_position = 0 if depth == 0 else 1  # 0 for root, 1 for intermediate/leaf

    # Answer quality (2): answer lengths in whitespace-separated words.
    cb_answer_length = len(cb_answer[0].split())
    ob_answer_length = len(ob_answer[0].split())

    # Build state vector (the order is part of the contract with the trained models).
    state = (
        cb_logprob
        + ob_logprob
        + basic_features
        + list(question_embedding)
        + list(cb_answer_embedding)
        + list(ob_answer_embedding)
        + [cb_confidence, ob_confidence]
        + [tree_depth, tree_position]
        + [cb_answer_length, ob_answer_length]
    )

    return torch.FloatTensor(state)

class DQNAgent:
    """DQN agent with an online Q-network, a target network and a replay buffer."""

    # Position of the has_children flag in the state vector built by get_state
    # (it follows the 2 x 50 log-prob slots).
    HAS_CHILDREN_INDEX = 100

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99):
        """
        :param state_dim: Size of the state vector.
        :param action_dim: Number of actions.
        :param hidden_dim: Hidden layer width of both networks.
        :param lr: Adam learning rate for the online network.
        :param gamma: Discount factor used in the TD target.
        """
        self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
        # The target network starts as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        # Unbounded list of (state, action, reward, next_state, done) tuples.
        self.replay_buffer = []

    def select_action(self, state, epsilon):
        """
        Pick an action epsilon-greedily.

        The last action (Child) is excluded whenever the state's has_children
        flag is 0, both when exploring and when exploiting.

        :param state: 1-D state tensor as built by get_state.
        :param epsilon: Probability of choosing a random action.
        :return: Action index as a Python int.
        """
        if random.random() < epsilon:
            can_decompose = state[self.HAS_CHILDREN_INDEX] != 0
            # randint is inclusive: drop the last action when there are no children.
            highest_action = len(action_space) - 1 if can_decompose else len(action_space) - 2
            return random.randint(0, highest_action)

        with torch.no_grad():
            q_values = self.q_network(state)
            if state[self.HAS_CHILDREN_INDEX] == 0:
                q_values[-1] = -float('inf')  # mask the Child action
            return torch.argmax(q_values).item()

    def train(self, batch_size):
        """
        Run one optimisation step on a random mini-batch from the replay buffer.

        Does nothing until the buffer holds at least `batch_size` transitions.
        Also reports every sampled transition to `update_success_rate`.
        """
        if len(self.replay_buffer) < batch_size:
            return
        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.stack(next_states)
        dones = torch.FloatTensor(dones)

        # Q-values of the actions actually taken. squeeze(1) removes only the
        # gather dimension, so a batch of size 1 keeps shape (1,) and matches
        # the target instead of collapsing to a scalar.
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # TD target from the frozen target network.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Update the success rates. Actions are converted to plain ints before
        # indexing action_space.
        # NOTE(review): a transition counts as a success only when its reward
        # equals exactly 1; confirm this is intended with the continuous
        # reward from get_reward.
        for action, reward in zip(actions.tolist(), rewards.tolist()):
            update_success_rate(action_space[action], reward == 1)

        # Huber loss between predicted and target Q-values, then one Adam step.
        loss = nn.SmoothL1Loss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
import os
import torch
import copy

# Checkpoints of this state representation are written into this directory.
os.makedirs("saved_models_trained_all", exist_ok=True)

# Reward-weight settings to train: (alpha, beta, name), where alpha weights the
# accuracy term and beta the LLM-usage penalty in get_reward.
configurations = [
    {"alpha": _alpha_weight, "beta": _beta_weight, "name": _label}
    for _alpha_weight, _beta_weight, _label in (
        (2.0, 0.05, "High_Accuracy"),
        (1.0, 0.1, "Balanced"),
        (0.5, 0.2, "Efficiency_Focused"),
    )
]

# Per-configuration training curves, filled in by the training loop.
accuracy_rewards_numcalls = dict()

# Train one DQN agent per reward configuration. Each run records per-episode
# reward, accuracy and LLM-call totals, and saves both the final weights and
# the weights of the most accurate episode.
for config in configurations:
    alpha = config["alpha"]
    beta = config["beta"]
    config_name = config["name"]
    # State layout from get_state: 100 log-prob slots (2 x 50), 7 basic
    # features, 1152 embedding values (3 x 384), 2 confidences, 2 structural
    # features, 2 answer lengths.
    state_dim = 100 + 7 + 1152 + 2 + 2 + 2 # Number of features in the state vector
    action_dim = len(action_space)  # Number of actions (CB, OB, Child)
    hidden_dim = 128  # Hidden layer size
    lr = 5e-3  # Learning rate (increased from 1e-3)
    gamma = 0.99  # Discount factor
    epsilon = 1.0  # Initial exploration rate
    epsilon_min = 0.01  # Minimum exploration rate
    epsilon_decay = 0.98  # Decay rate for exploration (applied once per episode)
    batch_size = 128  # Mini-batch size
    num_episodes = 20  # Number of training episodes

    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)

    print(f"\nTraining with {config_name} (alpha={alpha}, beta={beta})")

    max_accuracy = 0
    best_agent = None

    # Store rewards, accuracy and total number of LLM calls for each episode
    rewards_list = []
    accuracy_list = []
    total_num_llm_calls_list = []

    # Set model to training mode
    agent.q_network.train()

    for episode in range(num_episodes):
        print(f"Episode {episode + 1}/{num_episodes}")
        total_reward = 0
        correct_answers = 0
        total_parent_nodes = 0
        total_num_llm_calls = 0

        for example in data:
            for node in example:
                if "fa" not in node:  # Only process root nodes (original questions)
                    gold_answer, dataset_used = q2gold[node["question_text"].strip()]
                    total_parent_nodes += 1
                    state = get_state(node)
                    action = agent.select_action(state, epsilon)
                    chosen_answer = None
                    fallback_used = False
                    num_llm_calls = None

                    question, _ = get_question_after_resolving_references(node, example)
                    # Simulate the action. Unlike evaluate_agent, chosen_answer
                    # holds the full answer value here; its text is [0].
                    if action == 0:
                        # Closed-book answer for the root question.
                        node['cb_answer'] = get_cb_answer(question, dataset_used)
                        chosen_answer = node.get("cb_answer", [None])
                        num_llm_calls = 1
                    elif action == 1:
                        # since you are the parent, so topic entities are empty because no references
                        node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])
                        num_llm_calls = 1
                    elif action == 2:
                        # Decompose: answer the leaves on a private copy of the
                        # tree, then aggregate at every node that has sons.
                        tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                        child_experiences = []

                        for node_ in tree_with_answers_chosen_by_agent:
                            question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)
                            if len(node_["sons"]) == 0:
                                # Leaf: the agent picks CB or OB; an "unknown"
                                # answer switches to the other method, and
                                # child_action is rewritten to the method used.
                                # NOTE(review): child_answer is only assigned
                                # for child_action 0 or 1 — this relies on
                                # select_action never returning 2 for a leaf.
                                child_state = get_state(node_, depth=1)
                                child_action = agent.select_action(child_state, epsilon)
                                if child_action == 0:
                                    node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 1
                                        node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                elif child_action == 1:
                                    node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 0
                                        node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])

                                node_["answer"] = child_answer
                                # Placeholder transition: reward 0, next state equal to the state, not done.
                                child_experiences.append((child_state, child_action, 0, child_state, False))
                            else:
                                # Node with sons: aggregate its sons' answers. The
                                # last aggregation in the loop becomes the chosen
                                # answer (presumably the root's — depends on the
                                # tree's node order; confirm).
                                node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
                                chosen_answer = node_["child_answer"]

                        num_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation)

                    # Root-level fallback when the chosen answer is "unknown":
                    # CB -> OB, OB -> CB, Child -> CB.
                    # NOTE(review): the extra fallback call is not added to num_llm_calls.
                    if "unknown" in chosen_answer[0].lower().strip():
                        fallback_used = True
                        if action == 0:
                            fallback_action = 1
                            node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                            chosen_answer = node.get("ob_answer", [None])
                        elif action == 1:
                            fallback_action = 0
                            node["cb_answer"] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])
                        elif action == 2:
                            fallback_action = 0
                            node["cb_answer"] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])

                    # Compute reward (uses the originally selected action and its call count)
                    reward = get_reward(chosen_answer[0], gold_answer, action, num_llm_calls, alpha, beta)
                    total_reward += reward

                    # Accuracy is judged by the LLM-based equivalence check, not by the reward.
                    if are_answers_equivalent_using_llm(gold_answer, chosen_answer[0]):
                        correct_answers += 1

                    # The node now carries the new answers, so the state is rebuilt.
                    next_state = get_state(node)

                    # Store the transition under the action that produced the
                    # answer (the fallback action if one was used). Every
                    # transition is stored with done=False.
                    if not fallback_used:
                        agent.replay_buffer.append((state, action, reward, next_state, False))
                    else:
                        agent.replay_buffer.append((state, fallback_action, reward, next_state, False))

                    if action == 2:
                        num_children = len(node.get("sons", []))
                        # NOTE(review): child_reward is computed but unused — the
                        # child transitions below are stored with the full parent
                        # reward. Confirm which one is intended.
                        child_reward = reward / num_children
                        for child_state, child_action, _, next_child_state, done in child_experiences:
                            agent.replay_buffer.append((child_state, child_action, reward, next_child_state, done))

                    # update total number of llm calls in this episode
                    total_num_llm_calls += num_llm_calls

                    # One optimisation step per processed root question.
                    agent.train(batch_size)
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # With num_episodes = 20 the target network is synced at episodes 0 and 10 only.
        if episode % 10 == 0:
            agent.update_target_network()

        accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
        rewards_list.append(total_reward)
        accuracy_list.append(accuracy)
        total_num_llm_calls_list.append(total_num_llm_calls)
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            # NOTE(review): deepcopy also copies the agent's replay buffer; only
            # q_network's state_dict is saved from best_agent below.
            best_agent = copy.deepcopy(agent)

        print(f"Total Reward: {total_reward}")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Total LLM Calls in this episode: {total_num_llm_calls}")
        print(f"Epsilon: {epsilon:.4f}")
        print()

    # Store the rewards and accuracy for this configuration
    accuracy_rewards_numcalls[config_name] = {
        "rewards": rewards_list,
        "accuracy": accuracy_list,
        "num_calls": total_num_llm_calls_list
    }

    # Save model after training with each configuration
    model_path = f"saved_models_trained_all/agent_{config_name}.pth"
    torch.save(agent.q_network.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save the weights of the most accurate episode for this configuration
    if best_agent is not None:
        best_model_path = f"saved_models_trained_all/best_agent_{config_name}.pth"
        torch.save(best_agent.q_network.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import matplotlib.pyplot as plt

# Get rewards, accuracy and number of LLM calls for each configuration
balanced_rewards = accuracy_rewards_numcalls["Balanced"]["rewards"]
balanced_accuracy = accuracy_rewards_numcalls["Balanced"]["accuracy"]
balanced_num_calls = accuracy_rewards_numcalls["Balanced"]["num_calls"]

efficiency_rewards = accuracy_rewards_numcalls["Efficiency_Focused"]["rewards"]
efficiency_accuracy = accuracy_rewards_numcalls["Efficiency_Focused"]["accuracy"]
efficiency_num_calls = accuracy_rewards_numcalls["Efficiency_Focused"]["num_calls"]

high_accuracy_accuracy = accuracy_rewards_numcalls["High_Accuracy"]["accuracy"]
high_accuracy_rewards = accuracy_rewards_numcalls["High_Accuracy"]["rewards"]
high_accuracy_num_calls = accuracy_rewards_numcalls["High_Accuracy"]["num_calls"]


# Min-max scale each reward curve to [0, 1] so the three configurations can
# share one axis even though their alpha/beta weights give different ranges.
def _min_max_scale(values):
    """Return *values* rescaled to [0, 1]; a constant series maps to all zeros."""
    arr = np.asarray(values, dtype=float)
    low = np.min(arr)
    span = np.max(arr) - low
    if span == 0:
        # Every episode produced the same reward: avoid 0/0 -> NaN.
        return np.zeros_like(arr)
    return (arr - low) / span


balanced_rewards = _min_max_scale(balanced_rewards)
efficiency_rewards = _min_max_scale(efficiency_rewards)
high_accuracy_rewards = _min_max_scale(high_accuracy_rewards)

# Figure 1: normalized rewards (left) and accuracy (right) per episode.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(balanced_rewards, label="Balanced")
plt.plot(efficiency_rewards, label="Efficiency Focused")
plt.plot(high_accuracy_rewards, label="High Accuracy")
plt.title("Normalized Rewards")
plt.xlabel("Episode")
plt.ylabel("Normalized Reward")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(balanced_accuracy, label="Balanced")
plt.plot(efficiency_accuracy, label="Efficiency Focused")
plt.plot(high_accuracy_accuracy, label="High Accuracy")
plt.title("Accuracy")
plt.xlabel("Episode")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: total LLM calls per episode.
plt.figure(figsize=(8, 6))
plt.plot(balanced_num_calls, label="Balanced")
plt.plot(efficiency_num_calls, label="Efficiency Focused")
plt.plot(high_accuracy_num_calls, label="High Accuracy")
plt.title("Number of LLM Calls")
plt.xlabel("Episode")
plt.ylabel("Total LLM Calls")
plt.legend()
plt.show()

In [None]:
# Show how many LLM calls the last training episode needed per configuration.
# NOTE: this rebinds the *_num_calls names from per-episode lists to single values.
_calls_by_config = {
    _config_key: accuracy_rewards_numcalls[_config_key]["num_calls"][-1]
    for _config_key in ("Balanced", "Efficiency_Focused", "High_Accuracy")
}
balanced_num_calls = _calls_by_config["Balanced"]
efficiency_num_calls = _calls_by_config["Efficiency_Focused"]
high_accuracy_num_calls = _calls_by_config["High_Accuracy"]

print("balanced_num_calls", balanced_num_calls)
print("efficiency_num_calls", efficiency_num_calls)
print("high_accuracy_num_calls", high_accuracy_num_calls)

## Try it on test set data of hotpot qa 

In [None]:
# Evaluate the agent on the test data
def _is_unknown_answer(answer_text):
    """Return True when an answer is missing or contains the marker "unknown"."""
    return answer_text is None or "unknown" in answer_text.lower().strip()


def _answer_with_fallback(node, question, topic_entities, dataset_used, first_action):
    """
    Answer ``question`` with CB (0) or OB (1) first and fall back to the other one
    when the first answer is unknown.

    The produced records are stored on ``node`` under "cb_answer" / "ob_answer".
    :return: (answer_record, action_used) where action_used is the last method tried.
    """
    order = (0, 1) if first_action == 0 else (1, 0)
    for used_action in order:
        if used_action == 0:
            record = get_cb_answer(question, dataset_used)
            node["cb_answer"] = record
        else:
            record = get_singlehop_ob_answer(question, topic_entities, dataset_used)
            node["ob_answer"] = record
        if not _is_unknown_answer(record[0]):
            break
    return record, used_action


def evaluate_agent(agent, data, q2gold):
    """
    Run the greedy policy (epsilon=0) of ``agent`` on every root question in ``data``.

    :param agent: Agent exposing ``q_network`` and ``select_action(state, epsilon)``.
    :param data: Iterable of question trees; each tree is a list of node dicts and
                 root nodes are the ones without an "fa" key.
    :param q2gold: Maps the stripped root question text to (gold_answer, dataset_name).
    :return: (results, correct_answers, total_parent_nodes); ``results`` holds one dict
             per root question with keys "idx", "question", "answer", "gold", "method".

    Root nodes are modified in place: the answers produced here are written to
    "cb_answer", "ob_answer", "child_answer" and "answer".
    """
    agent.q_network.eval()  # Set the model to evaluation mode
    correct_answers = 0
    total_parent_nodes = 0
    results = []
    for example in data:
        for node in example:
            if "fa" in node:  # Only process root nodes (original questions)
                continue

            gold_answer, dataset_used = q2gold[node["question_text"].strip()]
            total_parent_nodes += 1
            state = get_state(node)
            action = agent.select_action(state, epsilon=0)  # No exploration during evaluation
            chosen_answer = None
            method_action = action  # Method reported in the results

            question, _ = get_question_after_resolving_references(node, example)
            if action in (0, 1):  # CB or OB, each falling back to the other
                record, method_action = _answer_with_fallback(node, question, [], dataset_used, action)
                chosen_answer = record[0]
            elif action == 2:  # Child: answer the sub-questions, then aggregate
                tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                for child_idx in node.get("sons", []):
                    child = tree_with_answers_chosen_by_agent[child_idx]
                    child_state = get_state(child, depth=1)
                    child_action = agent.select_action(child_state, epsilon=0)
                    question_, topic_entities = get_question_after_resolving_references(
                        child, tree_with_answers_chosen_by_agent
                    )
                    # A child that itself has sons can be given action 2, which has no
                    # handler at this depth. Previously that left child_answer unbound
                    # (NameError) or silently reused the previous child's answer; such
                    # children are now answered CB-first, like the root-level fallback.
                    first_action = child_action if child_action in (0, 1) else 0
                    child_answer, _ = _answer_with_fallback(
                        child, question_, topic_entities, dataset_used, first_action
                    )
                    tree_with_answers_chosen_by_agent[child_idx]["answer"] = child_answer

                # Generate child_answer for the parent node
                node["child_answer"], node["answer"] = aggregate_multihop_answer(
                    node, tree_with_answers_chosen_by_agent, dataset_used
                )
                chosen_answer = node["child_answer"][0]
                if _is_unknown_answer(chosen_answer):
                    # Try CB first, then OB. This fallback is deliberately not reflected
                    # in "method", which stays the agent's original choice.
                    record, _ = _answer_with_fallback(node, question, [], dataset_used, 0)
                    chosen_answer = record[0]

            # Count the answer as correct when the LLM judge deems it equivalent to gold
            if chosen_answer and are_answers_equivalent_using_llm(chosen_answer, gold_answer):
                correct_answers += 1

            # Store results
            results.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": chosen_answer,
                "gold": gold_answer,
                "method": action_space[method_action]
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes

# Trained models to evaluate, listed as (name, alpha, beta)
configurations = [
    dict(alpha=alpha_weight, beta=beta_weight, name=label)
    for label, alpha_weight, beta_weight in (
        ("High_Accuracy", 2.0, 0.05),
        ("Balanced", 1.0, 0.1),
        ("Efficiency_Focused", 0.5, 0.2),
    )
]

# Per-configuration evaluation outputs, keyed by configuration name
predictions_per_config, correct_answers_per_config = {}, {}
total_parent_nodes_per_config, accuracy_per_config = {}, {}

In [None]:
# Evaluate each model on the HotpotQA test data.
# state_dim, action_dim, hidden_dim, lr and gamma are not defined in this cell and
# must already exist in the notebook namespace.
# NOTE(review): the *_per_config dicts are keyed by configuration name only, and the
# 2wiki / musique evaluation cells write the same keys, so running those cells
# overwrites the entries stored here.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_trained_all/best_agent_{config_name}.pth"

    # Load the trained model
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    agent.q_network.load_state_dict(torch.load(model_path))
    agent.q_network.eval()  # Set model to evaluation mode

    print(f"\nEvaluating model: {config_name}")
    
    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(agent, test_data_hotpotqa, q2gold_test_hotpotqa)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy (percentage of root questions answered correctly)
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:
# Evaluate each model on the 2wiki test data.
# state_dim, action_dim, hidden_dim, lr and gamma are not defined in this cell and
# must already exist in the notebook namespace.
# NOTE(review): the *_per_config dicts are keyed by configuration name only, so this
# cell overwrites whatever another dataset's evaluation cell stored under the same keys.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_trained_all/best_agent_{config_name}.pth"

    # Load the trained model
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    agent.q_network.load_state_dict(torch.load(model_path))
    agent.q_network.eval()  # Set model to evaluation mode

    print(f"\nEvaluating model: {config_name}")
    
    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(agent, test_data_2wiki, q2gold_test_2wiki)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy (percentage of root questions answered correctly)
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:
# Evaluate each model on the musique test data.
# state_dim, action_dim, hidden_dim, lr and gamma are not defined in this cell and
# must already exist in the notebook namespace.
# NOTE(review): the *_per_config dicts are keyed by configuration name only, so this
# cell overwrites whatever another dataset's evaluation cell stored under the same keys.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_trained_all/best_agent_{config_name}.pth"

    # Load the trained model
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    agent.q_network.load_state_dict(torch.load(model_path))
    agent.q_network.eval()  # Set model to evaluation mode

    print(f"\nEvaluating model: {config_name}")
    
    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(agent, test_data_musique, q2gold_test_musique)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy (percentage of root questions answered correctly)
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


# RL approach with deep learning (using question + CB + OB + question reformulation)

In [None]:
# Load the raw data and question decompositions
def _read_jsonl(path):
    """Parse a JSON-lines file into a list of objects, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line.strip()) for line in handle]


def _read_json(path):
    """Parse a JSON file, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


raw_data = _read_jsonl('./hotpotqa__v2_dev_random_100.jsonl')
raw_data_2wiki = _read_jsonl('./2wikimultihopqa__v2_dev_random_100.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
raw_data_musique = _read_jsonl('./musique_ans__v2_dev_random_100.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

# Merge the decompositions of all three datasets into one lookup
q2dq = _read_json("./question_decompositions-devset-hotpotqa.json")
q2dq_2wiki = _read_json("./question_decompositions-devset-2wiki.json")
q2dq_musique = _read_json("./question_decompositions-devset-musique.json")
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: first key of the question's decomposition -> (gold answer, dataset)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
    except (KeyError, IndexError):
        # Skip if the question is not in question_decompositions or a field is missing;
        # any other error is a real problem and is no longer swallowed.
        continue
    q2gold[question] = (gold, q_type)

# Load the data to analyze (hotpotqa first, then 2wiki, then musique)
data = _read_json('results-devset-hotpotqa.json')
data_2wiki = _read_json('results-devset-2wiki.json')
data_musique = _read_json('results-devset-musique.json')

data.extend(data_2wiki)
data.extend(data_musique)


In [None]:
# The six actions the agent can pick, in Q-network output order
action_space = ['CB', 'OB', 'Child', 'CB_REFORMULATE', 'OB_REFORMULATE', 'Child_REFORMULATE']
# Success and attempt tallies, keyed by the lower-cased action name
success_counts = {action_name.lower(): 0 for action_name in action_space}
attempt_counts = dict.fromkeys(success_counts, 0)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Fully connected network with two ReLU hidden layers; outputs one value per action."""

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        # The attribute names fc1/fc2/fc3 determine the state_dict keys, so keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """Map a state tensor (last dimension ``state_dim``) to ``action_dim`` Q-values."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """Return ``logprobs`` as exactly ``max_length`` items: right-padded with ``pad_value`` or cut."""
    missing = max_length - len(logprobs)
    if missing > 0:
        # Too short: append the filler values
        return logprobs + [pad_value] * missing
    # Long enough: keep only the leading max_length items
    return logprobs[:max_length]

# # Define the reward function
# def get_reward(chosen_answer, gold_answer):
#     if normalize_answer(chosen_answer) == normalize_answer(gold_answer):
#         return 1
#     else:
#         return -1

from sentence_transformers import SentenceTransformer, util
# Load a pre-trained sentence embedding model once at notebook level;
# get_reward and get_state both call model.encode on this instance.
model = SentenceTransformer('all-MiniLM-L6-v2')
# def get_reward(chosen_answer, gold_answer):
#     # Compute embeddings for the chosen and gold answers
#     chosen_embedding = model.encode(chosen_answer, convert_to_tensor=True)
#     gold_embedding = model.encode(gold_answer, convert_to_tensor=True)
#     # Compute cosine similarity
#     similarity = util.cos_sim(chosen_embedding, gold_embedding).item()
#     return similarity  # Reward is the similarity score (between -1 and 1)

def get_reward(chosen_answer, gold_answer, action, num_llm_calls, alpha=1.0, beta=0.1):
    """
    Score an answer by trading semantic accuracy against LLM usage.
    :param chosen_answer: Answer chosen by the agent.
    :param gold_answer: Ground truth answer.
    :param action: Index of the selected action (0-5 have an entry in the cost table).
    :param num_llm_calls: The number of LLM calls made during the current decision.
    :param alpha: Weight for the accuracy term.
    :param beta: Weight for the efficiency penalty.
    :return: alpha * accuracy - beta * (num_llm_calls * per-action cost).
    """
    # Accuracy term: cosine similarity of the two sentence embeddings, floored at 0
    answer_vector = model.encode(chosen_answer, convert_to_tensor=True)
    gold_vector = model.encode(gold_answer, convert_to_tensor=True)
    cosine = util.cos_sim(answer_vector, gold_vector).item()
    accuracy_term = cosine if cosine > 0 else 0

    # Per-action cost multiplier; all six actions currently cost 1, and so does any
    # action index that is not in the table.
    cost_by_action = dict.fromkeys(range(6), 1)
    per_call_cost = cost_by_action.get(action, 1)

    # Efficiency term grows with the number of LLM calls
    usage_penalty = num_llm_calls * per_call_cost

    return alpha * accuracy_term - beta * usage_penalty


def _answer_fields(node, key):
    """
    Return (text, confidence, logprobs) read from positions 0, 1 and 3 of ``node[key]``.

    A missing record, a record that is too short, or a falsy field yields the defaults
    "" / 0.0 / [] instead of raising.
    """
    record = node.get(key) or []
    text = record[0] if len(record) > 0 and record[0] is not None else ""
    confidence = (record[1] if len(record) > 1 else None) or 0.0
    logprobs = (record[3] if len(record) > 3 else None) or []
    return text, confidence, logprobs


def get_state(node, depth=0, max_length=50):
    """
    Build the state vector the Q-network sees for ``node``.

    Layout, in order:
      * ``max_length`` CB log-probs, then ``max_length`` OB log-probs (padded or truncated)
      * 7 basic features: has_children, question_length, question_type, num_children,
        and the cb / ob / child success rates
      * embeddings of the question, the CB answer and the OB answer
        (384 dimensions each according to the original note, 1152 in total)
      * cb_confidence, ob_confidence
      * tree_depth, tree_position
      * cb_answer_length, ob_answer_length

    NOTE: DQNAgent.select_action reads ``state[100]`` as has_children, which only
    matches this layout when ``max_length`` is 50.

    :param node: Question node dict; "cb_answer" / "ob_answer" records are optional.
    :param depth: 0 for a root question, greater than 0 otherwise.
    :param max_length: Number of log-probs kept per answer.
    :return: 1-D ``torch.FloatTensor``.
    """
    # Read each answer record once instead of re-fetching it for every feature
    cb_text, cb_confidence, cb_logprob = _answer_fields(node, "cb_answer")
    ob_text, ob_confidence, ob_logprob = _answer_fields(node, "ob_answer")

    # Pad or truncate logprobs so each has exactly max_length entries
    cb_logprob = pad_or_truncate(cb_logprob, max_length)
    ob_logprob = pad_or_truncate(ob_logprob, max_length)

    # Basic features (7)
    sons = node.get("sons", [])
    question_text = node.get("question_text", "")
    has_children = 1 if len(sons) else 0
    question_length = len(question_text.split())
    question_type = encode_question_type(question_text)
    num_children = len(sons)
    cb_success_rate = get_success_rate("cb")
    ob_success_rate = get_success_rate("ob")
    child_success_rate = get_success_rate("child")

    # Semantic features
    question_embedding = model.encode(question_text, convert_to_tensor=False)
    cb_answer_embedding = model.encode(cb_text, convert_to_tensor=False)
    ob_answer_embedding = model.encode(ob_text, convert_to_tensor=False)

    # Structural features (2)
    tree_depth = depth
    tree_position = 0 if depth == 0 else 1  # 0 for root, 1 for intermediate/leaf

    # Answer quality (2): word counts of the answers
    cb_answer_length = len(cb_text.split())
    ob_answer_length = len(ob_text.split())

    # Build state vector
    state = (
        cb_logprob +  # CB log probabilities
        ob_logprob +  # OB log probabilities
        [has_children, question_length, question_type, num_children, cb_success_rate, ob_success_rate, child_success_rate] +  # Basic features
        list(question_embedding) +  # Semantic embedding of the question
        list(cb_answer_embedding) +  # Semantic embedding of the CB answer
        list(ob_answer_embedding) +  # Semantic embedding of the OB answer
        [cb_confidence, ob_confidence] +  # Confidence scores for CB and OB
        [tree_depth, tree_position] +  # Structural features
        [cb_answer_length, ob_answer_length]  # Answer quality features
    )

    # maybe try question only by masking all and add log probs for the solved one and try adding verification step in the end.

    return torch.FloatTensor(state)

class DQNAgent:
    """
    DQN agent over six actions, with an online Q-network, a target network and a
    list used as the replay buffer.

    ``select_action`` assumes the state layout built by ``get_state`` with
    ``max_length=50``: index 100 holds the has_children flag.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99):
        self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
        # Start the target network as an exact copy of the online network
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = []

    def select_action(self, state, epsilon):
        """
        Epsilon-greedy action choice.

        Actions 2 and 5 (the child-based ones) are never returned when
        ``state[100]`` is 0, i.e. when the node has no children.
        :return: Action index as a Python int.
        """
        if random.random() < epsilon:
            if state[100] == 0:  # has children = 0, meaning child actions are not possible
                return random.choice([0, 1, 3, 4])  # Exclude actions 2 and 5
            else:
                return random.randint(0, 5)  # Choose from all 6 actions
        else:
            with torch.no_grad():
                q_values = self.q_network(state).clone()  # Clone to avoid in-place modification issues

                if state[100] == 0:  # has children = 0
                    q_values[2] = -float('inf')  # Mask action 2
                    q_values[5] = -float('inf')  # Mask action 5

                return torch.argmax(q_values).item()

    def train(self, batch_size):
        """
        Run one optimisation step on a random batch from the replay buffer.

        Does nothing while the buffer holds fewer than ``batch_size`` transitions.
        Each transition is a (state, action, reward, next_state, done) tuple.
        """
        if len(self.replay_buffer) < batch_size:
            return
        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.stack(next_states)
        dones = torch.FloatTensor(dones)

        # Q-values of the actions actually taken; squeeze only the gathered dimension so
        # the result stays shape (batch,) even for batch_size == 1 (a bare squeeze()
        # would yield a 0-dim tensor that no longer matches target_q_values).
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # update the success rates
        # NOTE(review): this counts replayed samples, and only a reward of exactly 1 is
        # treated as a success - confirm that this is intended for continuous rewards.
        for i, action in enumerate(actions):
            if rewards[i] == 1:
                update_success_rate(action_space[action], True)
            else:
                update_success_rate(action_space[action], False)

        # Compute loss and update the network
        loss = nn.SmoothL1Loss()(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
def reformulate_question(question):
   """
   Ask the language model to rewrite a question more clearly without changing its meaning.
   :param question: The original question to reformulate.
   :return: Reformulated question text (the stripped content of the first response).
   """

   # The question is interpolated directly by the f-string. The previous extra
   # `.format(question=question)` call on the already-interpolated text raised
   # (or mangled the prompt) whenever the question itself contained '{' or '}'.
   prompt = f"""
   Reformulate the question to improve readability, grammar, and clarity. Simplify complex phrasing and structure the question in a way that is easier to understand, without changing the core intent or meaning. Avoid adding or removing any information. Follow these steps:
   1. Simplify the question by removing unnecessary or overly complex wording.
   2. Fix grammar and punctuation issues to improve readability.
   3. Rewrite the question in a way that keeps the meaning identical but makes it easier to interpret.
   4. Return the result directly without any additional information.

   Examples:

   1. Original: "What are the possible reasons for the decline in monarch butterfly populations, and how does urbanization contribute to this issue?"
      Reformulated: "What causes the decline in monarch butterfly populations, and how does urbanization play a role?"

   2. Original: "How is social media influencing the mental health of teenagers, specifically with regard to anxiety and depression levels?"
      Reformulated: "How does social media affect teenagers' mental health, particularly with respect to anxiety and depression?"

   3. Original: "What is known about the origin of black holes and how do they affect the galaxies they exist in?"
      Reformulated: "What do we know about the origin of black holes, and how do they affect their galaxies?"

   4. Original: "Describe the ways in which renewable energy sources like solar and wind power are replacing fossil fuels in energy production."
      Reformulated: "How are renewable energy sources like solar and wind replacing fossil fuels in energy production?"

   Now, reformulate the question below:

   Original Question:
   {question}

   Reformulated Question:
   """

   # Use LLM API to get the reformulated question

   # TODO:: this should be with temperature 0.7, and dont use cache
   response, tag = togetherai_caller.req2provider(prompt=prompt, max_tokens=None, stop= None, use_cache=True)
   response = response[0]
   return response['message']['content'].strip()

In [None]:
import os
import torch
import copy

# Create directory for saved models
os.makedirs("saved_models_with_reformulation_trained_all", exist_ok=True)

# Hyperparameters every configuration shares unless it overrides them
_shared_hyperparameters = {
    "hidden_dim": 128,
    "lr": 1e-3,
    "gamma": 0.99,
    "epsilon": 1.0,
    "epsilon_min": 0.01,
    "epsilon_decay": 0.996,
    "batch_size": 32,
    "num_episodes": 50,
}

# Training configurations; they differ in alpha and beta, and the last one also
# overrides lr and epsilon_min.
configurations = [
    {**_shared_hyperparameters, "alpha": 5.0, "beta": 0.05, "name": "High_Accuracy_Priority"},
    {**_shared_hyperparameters, "alpha": 1.0, "beta": 0.1, "name": "Balanced_Reward_Tradeoff"},
    {
        **_shared_hyperparameters,
        "lr": 5e-4,
        "epsilon_min": 0.1,
        "alpha": 0.5,
        "beta": 2.0,
        "name": "Efficiency_Focused",
    },
]

# Filled by the training loop
accuracy_rewards_numcalls = {}

# Loop through different settings
for config in configurations:
    alpha = config["alpha"]
    beta = config["beta"]
    
    epsilon = config["epsilon"]
    epsilon_min = config["epsilon_min"]
    epsilon_decay = config["epsilon_decay"]
    batch_size = config["batch_size"]
    num_episodes = config["num_episodes"]
    lr = config["lr"]
    gamma = config["gamma"]
    hidden_dim = config["hidden_dim"]
    action_dim = len(action_space)

    config_name = config["name"]
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)

    print(f"\nTraining with {config_name} (alpha={alpha}, beta={beta})")

    max_accuracy = 0
    best_agent = None

    # Store rewards, accuracy and total number of LLM calls for each episode
    rewards_list = []
    accuracy_list = []
    total_num_llm_calls_list = []

    # Set model to training mode
    agent.q_network.train()

    for episode in range(num_episodes):
        print(f"Episode {episode + 1}/{num_episodes}")
        total_reward = 0
        correct_answers = 0
        total_parent_nodes = 0
        total_num_llm_calls = 0

        for example in data:
            for node in example:
                if "fa" not in node:  # Only process root nodes (original questions)
                    gold_answer, dataset_used = q2gold[node["question_text"].strip()]
                    total_parent_nodes += 1
                    state = get_state(node)  
                    action = agent.select_action(state, epsilon)  
                    chosen_answer = None
                    fallback_used = False  
                    num_llm_calls = None

                    question, _ = get_question_after_resolving_references(node, example)
                    # Simulate the action
                    if action == 0:
                        node['cb_answer'] = get_cb_answer(question, dataset_used)
                        chosen_answer = node.get("cb_answer", [None])
                        num_llm_calls = 1
                    elif action == 1:  
                        node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])
                        num_llm_calls = 1
                    elif action == 2:
                        tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                        child_experiences = []

                        for node_ in tree_with_answers_chosen_by_agent:
                            question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)
                        
                            if len(node_["sons"]) == 0:
                                child_state = get_state(node_, depth=1)
                                child_action = agent.select_action(child_state, epsilon)
                                if child_action == 0:  
                                    node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 1  
                                        node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                elif child_action == 1:  
                                    node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 0  
                                        node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])
                                elif child_action == 3:
                                    # CB with reformulation
                                    reformulated_question = reformulate_question(question_)
                                    node_["cb_answer"] = get_cb_answer(reformulated_question, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 4
                                        node_["ob_answer"] = get_singlehop_ob_answer(reformulated_question, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                elif child_action == 4:
                                    # OB with reformulation
                                    reformulated_question = reformulate_question(question_)
                                    node_["ob_answer"] = get_singlehop_ob_answer(reformulated_question, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 3
                                        node_["cb_answer"] = get_cb_answer(reformulated_question, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                else:
                                    # Throw exception
                                    raise ValueError("Invalid action selected for child node : " + action_space[child_action])

                                node_["answer"] = child_answer
                                child_experiences.append((child_state, child_action, 0, child_state, False))  
                            else:
                                # If the node has children, aggregate the answers
                                node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
                                chosen_answer = node_["child_answer"]
                        
                        num_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation) 
                        # TODO i am not counting fallbacks here in the children
                    elif action == 3:
                        # CB with reformulation
                        reformulated_question = reformulate_question(question)
                        chosen_answer = get_cb_answer(reformulated_question, dataset_used)
                        num_llm_calls = 2  # 1 for reformulation, 1 for CB
                    
                    elif action == 4:
                        # OB with reformulation
                        reformulated_question = reformulate_question(question)
                        chosen_answer = get_singlehop_ob_answer(reformulated_question, [], dataset_used)
                        num_llm_calls = 2  # 1 for reformulation, 1 for OB
                    
                    elif action == 5:
                        # Child with reformulation
                        tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                        child_experiences = []
                        num_llm_calls = 0

                        for node_ in tree_with_answers_chosen_by_agent:
                            question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)

                            if len(node_["sons"]) == 0:
                                child_state = get_state(node_, depth=1)
                                child_action = agent.select_action(child_state, epsilon)
                                # num_llm_calls += 1 # for solving the child
                                if child_action == 0:  
                                    node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 1  
                                        node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                elif child_action == 1:  
                                    node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 0  
                                        node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                elif child_action == 3:
                                    # CB with reformulation
                                    reformulated_question = reformulate_question(question_)
                                    node_["cb_answer"] = get_cb_answer(reformulated_question, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 4
                                        node_["ob_answer"] = get_singlehop_ob_answer(reformulated_question, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                elif child_action == 4:
                                    # OB with reformulation
                                    reformulated_question = reformulate_question(question_)
                                    node_["ob_answer"] = get_singlehop_ob_answer(reformulated_question, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 3
                                        node_["cb_answer"] = get_cb_answer(reformulated_question, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])
                                        # num_llm_calls += 1 # for fallback
                                else:
                                    # throw error
                                    raise Exception('Wrong action decided for the children ?? ') # Don't! If you catch, likely to hide bugs.


                                node_["answer"] = child_answer
                                child_experiences.append((child_state, child_action, 0, child_state, False))  
                            else:
                                node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
                                chosen_answer = node_["child_answer"]
                        
                        num_llm_calls = len(child_experiences) + 1 + 1  # 1 (decomposition) + 1 (aggregation)                        

                    if "unknown" in chosen_answer[0].lower().strip():
                        # num_llm_calls += 1  # Fallback to CB or OB
                        fallback_used = True
                        if action == 0 or action == 3:  
                            fallback_action = 1
                            node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                            chosen_answer = node.get("ob_answer", [None])
                        elif action == 1 or action == 4:  
                            fallback_action = 0
                            node['cb_answer'] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])
                        elif action == 2 or action == 5:  
                            fallback_action = 0
                            node['cb_answer'] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])

                    # Compute reward
                    reward = get_reward(chosen_answer[0], gold_answer, action, num_llm_calls, alpha, beta)
                    total_reward += reward

                    # if normalize_answer(chosen_answer[0]) == normalize_answer(gold_answer):
                    if are_answers_equivalent_using_llm(gold_answer, chosen_answer[0]):
                        correct_answers += 1

                    next_state = get_state(node)

                    if not fallback_used:
                        agent.replay_buffer.append((state, action, reward, next_state, False))  
                    else:
                        agent.replay_buffer.append((state, fallback_action, reward, next_state, False))

                    if action == 2 or action == 5:
                        num_children = len(node.get("sons", []))
                        child_reward = reward / num_children  
                        for child_state, child_action, _, next_child_state, done in child_experiences:
                            agent.replay_buffer.append((child_state, child_action, reward, next_child_state, done))

                    # update total number of llm calls in this episode
                    total_num_llm_calls += num_llm_calls

                    agent.train(batch_size)

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if episode % 10 == 0:
            agent.update_target_network()

        accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
        rewards_list.append(total_reward)
        accuracy_list.append(accuracy)
        total_num_llm_calls_list.append(total_num_llm_calls)
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            best_agent = copy.deepcopy(agent)

        print(f"Total Reward: {total_reward}")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Total LLM Calls in this episode: {total_num_llm_calls}")
        print(f"Epsilon: {epsilon:.4f}")
        print()

    # Store the rewards and accuracy for this configuration
    accuracy_rewards_numcalls[config_name] = {
        "rewards": rewards_list,
        "accuracy": accuracy_list,
        "num_calls": total_num_llm_calls_list
    }

    # Save model after training with each configuration
    model_path = f"saved_models_with_reformulation_trained_all/agent_{config_name}.pth"
    torch.save(agent.q_network.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save the best model overall
    if best_agent is not None:
        best_model_path = f"saved_models_with_reformulation_trained_all/best_agent_{config_name}.pth"
        torch.save(best_agent.q_network.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import numpy as np

import matplotlib.pyplot as plt

# Gather the per-episode training curves of every configuration.
config_names = list(accuracy_rewards_numcalls.keys())
normalized_rewards = {}
accuracies = {}
num_calls = {}

for config_name in config_names:
    rewards = np.array(accuracy_rewards_numcalls[config_name]["rewards"], dtype=float)
    accuracy = accuracy_rewards_numcalls[config_name]["accuracy"]
    calls = accuracy_rewards_numcalls[config_name]["num_calls"]

    # Min-max normalize rewards to [0, 1]. A constant curve (e.g. a single
    # episode) or an empty one has no range to scale by; map it to zeros
    # instead of dividing by zero / calling min() on an empty array.
    reward_span = (rewards.max() - rewards.min()) if rewards.size else 0.0
    if reward_span > 0:
        normalized_rewards[config_name] = (rewards - rewards.min()) / reward_span
    else:
        normalized_rewards[config_name] = np.zeros_like(rewards)
    accuracies[config_name] = accuracy
    num_calls[config_name] = calls

# Figure 1: normalized rewards (left) and accuracy (right), one line per configuration.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
for config_name in config_names:
    plt.plot(normalized_rewards[config_name], label=config_name)
plt.title("Normalized Rewards")
plt.xlabel("Episode")
plt.ylabel("Normalized Reward")
plt.legend()

plt.subplot(1, 2, 2)
for config_name in config_names:
    plt.plot(accuracies[config_name], label=config_name)
plt.title("Accuracy")
plt.xlabel("Episode")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: total number of LLM calls per episode.
plt.figure(figsize=(8, 6))
for config_name in config_names:
    plt.plot(num_calls[config_name], label=config_name)
plt.title("Number of LLM Calls")
plt.xlabel("Episode")
plt.ylabel("Total LLM Calls")
plt.legend()
plt.show()


In [None]:
# Report, for every configuration, the LLM-call total of its last recorded episode.
config_names = [*accuracy_rewards_numcalls]
for config in config_names:
    num_calls = accuracy_rewards_numcalls[config]["num_calls"][-1]
    print("config", config, "num calls", num_calls)

## Evaluate on the HotpotQA test set

In [None]:
# Evaluate the agent on the test data
def _answer_leaf_node(node_, question_, topic_entities, child_action, dataset_used):
    """Answer one leaf sub-question, retrying with the counterpart method on "unknown".

    Actions 0 and 3 ask closed-book first and fall back to open-book; actions
    1 and 4 do the reverse. Actions 3 and 4 answer a reformulated version of
    the question, and the fallback reuses that same reformulation. Every
    answer that is requested is stored on the node under "cb_answer" or
    "ob_answer".

    :param node_: Leaf node dict; mutated in place.
    :param question_: Sub-question text with references already resolved.
    :param topic_entities: Entities passed to the open-book answerer.
    :param child_action: Action id chosen for this leaf (0, 1, 3 or 4).
    :param dataset_used: Dataset name forwarded to the answer functions.
    :return: The answer finally kept for the leaf (text at index 0).
    :raises ValueError: If `child_action` is not a single-hop action.
    """
    if child_action not in (0, 1, 3, 4):
        raise ValueError(f"Invalid action selected for child node : {child_action}")

    # Reformulate once so the fallback asks about the same rewritten question.
    query = reformulate_question(question_) if child_action in (3, 4) else question_

    def closed_book():
        node_["cb_answer"] = get_cb_answer(query, dataset_used)
        return node_["cb_answer"]

    def open_book():
        node_["ob_answer"] = get_singlehop_ob_answer(query, topic_entities, dataset_used)
        return node_["ob_answer"]

    if child_action in (0, 3):
        primary, secondary = closed_book, open_book
    else:
        primary, secondary = open_book, closed_book

    child_answer = primary()
    if "unknown" in child_answer[0].lower().strip():
        child_answer = secondary()
    return child_answer


def _answer_via_decomposition(agent, example, dataset_used):
    """Answer a question tree bottom-up and return the aggregated answer.

    Works on a deep copy of `example` so the caller's tree is left untouched.
    Leaves are answered with a greedily selected action; nodes with children
    are aggregated with `aggregate_multihop_answer`.

    :return: The "child_answer" of the last node that has children, or None
        when the tree contains no such node.
    """
    tree = copy.deepcopy(example)
    aggregated_answer = None

    for node_ in tree:
        question_, topic_entities = get_question_after_resolving_references(node_, tree)

        if len(node_["sons"]) == 0:
            child_state = get_state(node_, depth=1)
            # epsilon=0: evaluation must be greedy for children as well,
            # not driven by whatever exploration rate is left at module level.
            child_action = agent.select_action(child_state, epsilon=0)
            node_["answer"] = _answer_leaf_node(
                node_, question_, topic_entities, child_action, dataset_used
            )
        else:
            node_["child_answer"], node_["answer"] = aggregate_multihop_answer(
                node_, tree, dataset_used
            )
            aggregated_answer = node_["child_answer"]

    return aggregated_answer


def evaluate_agent(agent, data, q2gold):
    """Run the agent greedily over `data` and score its answers against the gold answers.

    Only root nodes (those without an "fa" key) are scored. For each root the
    agent picks an action: 0 = closed-book, 1 = open-book, 2 / 5 = answer via
    the decomposition tree, 3 / 4 = closed-/open-book on a reformulated
    question. If the resulting answer contains "unknown" (or the tree yields
    no aggregated answer), a single-hop fallback on the original question is
    used: open-book after actions 0/3, closed-book otherwise.

    :param agent: Object exposing `q_network.eval()` and
        `select_action(state, epsilon=...)`.
    :param data: Iterable of examples; each example is an iterable of node dicts.
    :param q2gold: Maps the stripped root question text to
        `(gold_answer, dataset_name)`.
    :return: `(results, correct_answers, total_parent_nodes)` where `results`
        is a list of dicts with keys "idx", "question", "answer", "gold" and
        "method".
    :raises ValueError: If the agent selects an action outside 0-5.
    """
    agent.q_network.eval()  # Set the model to evaluation mode
    correct_answers = 0
    total_parent_nodes = 0
    results = []

    for example in data:
        for node in example:
            if "fa" in node:
                # Nodes with a parent are handled through the tree, not scored directly.
                continue

            gold_answer, dataset_used = q2gold[node["question_text"].strip()]
            total_parent_nodes += 1
            state = get_state(node)
            action = agent.select_action(state, epsilon=0)  # No exploration during evaluation
            fallback_action = None

            question, _ = get_question_after_resolving_references(node, example)

            # Execute the selected action.
            if action == 0:
                node['cb_answer'] = get_cb_answer(question, dataset_used)
                chosen_answer = node['cb_answer']
            elif action == 1:
                node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                chosen_answer = node['ob_answer']
            elif action in (2, 5):
                chosen_answer = _answer_via_decomposition(agent, example, dataset_used)
            elif action == 3:
                # CB with reformulation
                chosen_answer = get_cb_answer(reformulate_question(question), dataset_used)
            elif action == 4:
                # OB with reformulation
                chosen_answer = get_singlehop_ob_answer(reformulate_question(question), [], dataset_used)
            else:
                raise ValueError(f"Invalid action selected for root node : {action}")

            # A missing aggregated answer is treated like "unknown" so it
            # falls back instead of failing on chosen_answer[0].
            if chosen_answer is None or "unknown" in chosen_answer[0].lower().strip():
                if action in (0, 3):
                    fallback_action = 1
                    node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                    chosen_answer = node['ob_answer']
                else:
                    fallback_action = 0
                    node['cb_answer'] = get_cb_answer(question, dataset_used)
                    chosen_answer = node['cb_answer']

            if chosen_answer and are_answers_equivalent_using_llm(gold_answer, chosen_answer[0]):
                correct_answers += 1

            # Report the method that actually produced the final answer.
            method_id = action if fallback_action is None else fallback_action
            results.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": chosen_answer[0],
                "gold": gold_answer,
                "method": action_space[method_id],
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes


# Evaluation outputs, each keyed by configuration name and filled in by the
# evaluation loops that follow.
predictions_per_config = {}  # config name -> list of per-question result dicts
correct_answers_per_config = {}  # config name -> number of correct answers
total_parent_nodes_per_config = {}  # config name -> number of root questions evaluated
accuracy_per_config = {}  # config name -> accuracy in percent

In [None]:
# Evaluate the saved model of every configuration on the HotpotQA test set.
# NOTE: results are stored under the configuration name only, so entries that
# another dataset's evaluation cell wrote for the same configuration are replaced.
for config in configurations:

    # Only hidden_dim, lr and gamma are needed to rebuild the agent; the
    # remaining settings are still assigned at module level as before.
    epsilon = config["epsilon"]
    epsilon_min = config["epsilon_min"]
    epsilon_decay = config["epsilon_decay"]
    batch_size = config["batch_size"]
    num_episodes = config["num_episodes"]
    lr = config["lr"]
    gamma = config["gamma"]
    hidden_dim = config["hidden_dim"]

    config_name = config["name"]
    # model_path = f"saved_models_with_reformulation_trained_all/best_agent_{config_name}.pth"

    model_path = f"saved_models_with_reformulation_trained_all/agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}, hidden_dim={hidden_dim}, lr={lr}, gamma={gamma}")
    # Load the trained weights. map_location="cpu" lets a checkpoint saved from
    # a GPU run be read on a CPU-only machine; load_state_dict then copies the
    # values into the network's own parameters.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_agent.q_network.load_state_dict(torch.load(model_path, map_location="cpu"))
    best_agent.q_network.eval()  # Set model to evaluation mode

    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(best_agent, test_data_hotpotqa, q2gold_test_hotpotqa)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:
# Evaluate the saved model of every configuration on the 2Wiki test set.
# NOTE: results are stored under the configuration name only, so entries that
# another dataset's evaluation cell wrote for the same configuration are replaced.
for config in configurations:

    # Only hidden_dim, lr and gamma are needed to rebuild the agent; the
    # remaining settings are still assigned at module level as before.
    epsilon = config["epsilon"]
    epsilon_min = config["epsilon_min"]
    epsilon_decay = config["epsilon_decay"]
    batch_size = config["batch_size"]
    num_episodes = config["num_episodes"]
    lr = config["lr"]
    gamma = config["gamma"]
    hidden_dim = config["hidden_dim"]

    config_name = config["name"]
    # model_path = f"saved_models_with_reformulation_trained_all/best_agent_{config_name}.pth"

    model_path = f"saved_models_with_reformulation_trained_all/agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}, hidden_dim={hidden_dim}, lr={lr}, gamma={gamma}")
    # Load the trained weights. map_location="cpu" lets a checkpoint saved from
    # a GPU run be read on a CPU-only machine; load_state_dict then copies the
    # values into the network's own parameters.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_agent.q_network.load_state_dict(torch.load(model_path, map_location="cpu"))
    best_agent.q_network.eval()  # Set model to evaluation mode

    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(best_agent, test_data_2wiki, q2gold_test_2wiki)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:
# Evaluate the saved model of every configuration on the MuSiQue test set.
# NOTE: results are stored under the configuration name only, so entries that
# another dataset's evaluation cell wrote for the same configuration are replaced.
for config in configurations:

    # Only hidden_dim, lr and gamma are needed to rebuild the agent; the
    # remaining settings are still assigned at module level as before.
    epsilon = config["epsilon"]
    epsilon_min = config["epsilon_min"]
    epsilon_decay = config["epsilon_decay"]
    batch_size = config["batch_size"]
    num_episodes = config["num_episodes"]
    lr = config["lr"]
    gamma = config["gamma"]
    hidden_dim = config["hidden_dim"]

    config_name = config["name"]
    # model_path = f"saved_models_with_reformulation_trained_all/best_agent_{config_name}.pth"

    model_path = f"saved_models_with_reformulation_trained_all/agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}, hidden_dim={hidden_dim}, lr={lr}, gamma={gamma}")
    # Load the trained weights. map_location="cpu" lets a checkpoint saved from
    # a GPU run be read on a CPU-only machine; load_state_dict then copies the
    # values into the network's own parameters.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_agent.q_network.load_state_dict(torch.load(model_path, map_location="cpu"))
    best_agent.q_network.eval()  # Set model to evaluation mode

    # Run evaluation
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(best_agent, test_data_musique, q2gold_test_musique)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Print final accuracy
    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


# RL approach with deep learning (Transformer networks) using only the question in the state

In [None]:
# Load the raw dev-set samples, the question decompositions and the
# pre-computed results for HotpotQA, 2Wiki and MuSiQue, merged in that order.
def _read_jsonl(path):
    """Parse a JSON-lines file into a list of objects, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line.strip()) for line in handle if line.strip()]


def _read_json(path):
    """Parse a whole JSON file, closing the handle afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


raw_data = _read_jsonl('./hotpotqa__v2_dev_random_100.jsonl')
for item in raw_data:
    # Tag HotpotQA rows only when they carry no tag yet, so that the
    # item['dataset'] lookup below cannot silently drop them.
    item.setdefault('dataset', 'hotpotqa')
raw_data_2wiki = _read_jsonl('./2wikimultihopqa__v2_dev_random_100.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
raw_data_musique = _read_jsonl('./musique_ans__v2_dev_random_100.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

q2dq = _read_json("./question_decompositions-devset-hotpotqa.json")
q2dq_2wiki = _read_json("./question_decompositions-devset-2wiki.json")
q2dq_musique = _read_json("./question_decompositions-devset-musique.json")
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: first key of the question's decomposition entry -> (gold answer, dataset name)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
    except (KeyError, IndexError):
        # Skip rows without a decomposition entry or without a gold span.
        continue
    q2gold[question] = (gold, q_type)

# Load the data to analyze
data = _read_json('results-devset-hotpotqa.json')
data_2wiki = _read_json('results-devset-2wiki.json')
data_musique = _read_json('results-devset-musique.json')

data.extend(data_2wiki)
data.extend(data_musique)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class TransformerQNetwork(nn.Module):
    """Q-network built as: linear embedding -> Transformer encoder stack -> linear action head."""

    def __init__(self, state_dim, action_dim, hidden_dim=128, nhead=4, num_transformer_layers=2):
        super().__init__()

        # Project raw state features into the encoder's model width.
        self.embedding_layer = nn.Linear(state_dim, hidden_dim)

        # Encoder stack; the feed-forward sub-layer is four times the model width.
        layer_template = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=4 * hidden_dim,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template,
            num_layers=num_transformer_layers,
        )

        # Map encoder features to one Q-value per action.
        self.output_layer = nn.Linear(hidden_dim, action_dim)

    def forward(self, state):
        """
        Compute Q-values for a batch of states.

        Input:
          state -> Shape: [sequence_length, batch_size, state_dim]
        Output:
          Q-values -> Shape: [batch_size, action_dim] when sequence_length is 1
        """
        encoded = self.transformer_encoder(self.embedding_layer(state))
        # squeeze(0) drops the leading dimension only when its size is 1;
        # for longer sequences that dimension is kept in the output.
        return self.output_layer(encoded.squeeze(0))




def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """Return `logprobs` as a list of exactly `max_length` items.

    Shorter inputs are right-padded with `pad_value`; inputs of at least
    `max_length` items are cut to their first `max_length` items.
    """
    missing = max_length - len(logprobs)
    if missing > 0:
        return logprobs + [pad_value] * missing
    return logprobs[:max_length]

# # Define the reward function
# def get_reward(chosen_answer, gold_answer):
#     if normalize_answer(chosen_answer) == normalize_answer(gold_answer):
#         return 1
#     else:
#         return -1

from sentence_transformers import SentenceTransformer, util
# Load a pre-trained sentence embedding model once at module level; both
# get_reward and get_state call model.encode on answer / question text.
model = SentenceTransformer('all-MiniLM-L6-v2')
# def get_reward(chosen_answer, gold_answer):
#     # Compute embeddings for the chosen and gold answers
#     chosen_embedding = model.encode(chosen_answer, convert_to_tensor=True)
#     gold_embedding = model.encode(gold_answer, convert_to_tensor=True)
#     # Compute cosine similarity
#     similarity = util.cos_sim(chosen_embedding, gold_embedding).item()
#     return similarity  # Reward is the similarity score (between -1 and 1)

def get_reward(chosen_answer, gold_answer, action, num_llm_calls, alpha=1.0, beta=0.1):
    """
    Score an answer as semantic closeness to the gold answer minus an LLM-usage cost.

    :param chosen_answer: Answer text produced by the agent.
    :param gold_answer: Ground truth answer text.
    :param action: Action id selected by the agent; used to look up a cost multiplier.
    :param num_llm_calls: Number of LLM calls spent on the current decision.
    :param alpha: Weight of the similarity term.
    :param beta: Weight of the cost term.
    :return: alpha * max(0, cosine similarity) - beta * num_llm_calls * action cost.
    """
    # Similarity term: cosine similarity of the two sentence embeddings,
    # with negative similarities clipped to zero.
    answer_vec = model.encode(chosen_answer, convert_to_tensor=True)
    gold_vec = model.encode(gold_answer, convert_to_tensor=True)
    cosine = util.cos_sim(answer_vec, gold_vec).item()
    clipped_similarity = max(0, cosine)

    # Cost multiplier per action id. Every listed action currently costs 1 and
    # unlisted actions default to 1 as well, so the penalty equals the call count.
    cost_by_action = {0: 1, 1: 1, 2: 1}
    usage_cost = num_llm_calls * cost_by_action.get(action, 1)

    return alpha * clipped_similarity - beta * usage_cost


def get_state(node, depth=0, max_length=50):
    """
    Build the feature vector the Q-network receives for one question node.

    Layout of the returned vector, in order:
      * 7 basic features: has_children, question length (words), encoded
        question type, number of children, and the CB / OB / Child success
        rates reported by ``get_success_rate``;
      * the sentence embedding of the question text;
      * CB and OB confidence (``answer[1]``, 0.0 when missing or falsy);
      * tree depth and tree position (0 for the root, 1 otherwise);
      * CB and OB answer length in words (0 when no answer is stored).

    Only the features above are part of the state. The answer embeddings,
    padded log-probabilities and child confidence that used to be computed here
    were never added to the vector, so they are no longer computed. This saves
    two embedding-model calls per state.

    :param node: Tree node dict; reads "question_text", "sons", "cb_answer"
        and "ob_answer" when present.
    :param depth: Depth of the node in its tree (0 for the root).
    :param max_length: Unused. Kept so existing callers keep working; it was
        the padding length of the log-prob features that are not in the state.
    :return: 1-D ``torch.FloatTensor`` with the features listed above.
    """
    question_text = node.get("question_text", "")
    sons = node.get("sons", [])

    # Basic features (7)
    has_children = 1 if len(sons) else 0
    question_length = len(question_text.split())
    question_type = encode_question_type(question_text)
    num_children = len(sons)
    cb_success_rate = get_success_rate("cb")
    ob_success_rate = get_success_rate("ob")
    child_success_rate = get_success_rate("child")

    # Semantic feature: embedding of the question only.
    question_embedding = model.encode(question_text, convert_to_tensor=False)

    # Stored answers are indexed as [text, confidence, ...]; a missing answer
    # behaves like an empty text with no confidence.
    cb_answer = node.get("cb_answer", ["", None])
    ob_answer = node.get("ob_answer", ["", None])

    # Confidence (2)
    cb_confidence = cb_answer[1] or 0.0
    ob_confidence = ob_answer[1] or 0.0

    # Structural features (2)
    tree_depth = depth
    tree_position = 0 if depth == 0 else 1  # 0 for root, 1 for intermediate/leaf

    # Answer quality (2); `or ""` keeps a None answer text from crashing split().
    cb_answer_length = len((cb_answer[0] or "").split())
    ob_answer_length = len((ob_answer[0] or "").split())

    state = (
        [has_children, question_length, question_type, num_children,
         cb_success_rate, ob_success_rate, child_success_rate]
        + list(question_embedding)
        + [cb_confidence, ob_confidence]
        + [tree_depth, tree_position]
        + [cb_answer_length, ob_answer_length]
    )

    return torch.FloatTensor(state)

# Position of the "has_children" flag inside the state vector built by get_state().
HAS_CHILDREN_INDEX = 0


class DQNAgent:
    """
    DQN agent that picks how to answer a question node (CB, OB or Child).

    The last action index is treated as the "Child" (decompose) action and is
    only valid for states whose ``HAS_CHILDREN_INDEX`` feature is non-zero.
    That restriction is applied when exploring, when exploiting, and when
    bootstrapping the training target.

    Attributes:
        action_dim: Number of actions the Q-network scores.
        q_network / target_network: Online and target Q-networks. They are
            called with input shaped [1, batch, state_dim] and are expected to
            return [1, batch, action_dim].
        replay_buffer: List of ``(state, action, reward, next_state, done)``
            transitions; states are 1-D tensors.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99):
        # Keep the action count on the agent so action sampling does not depend
        # on a notebook-level `action_space` global being defined.
        self.action_dim = action_dim
        self.q_network = TransformerQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = TransformerQNetwork(state_dim, action_dim, hidden_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = []

    def select_action(self, state, epsilon):
        """
        Epsilon-greedy action selection for a single state.

        :param state: 1-D state tensor as produced by ``get_state``.
        :param epsilon: Probability of picking a uniformly random valid action.
        :return: Action index as a Python int; never the Child action when the
            state has no children.
        """
        no_children = bool(state[HAS_CHILDREN_INDEX] == 0)

        if random.random() < epsilon:  # Exploration
            # Child is the last index, so drop it from the range for leaves.
            num_valid_actions = self.action_dim - 1 if no_children else self.action_dim
            return random.randint(0, num_valid_actions - 1)

        # Exploitation: add sequence and batch dimensions for the network.
        with torch.no_grad():
            q_values = self.q_network(state.unsqueeze(0).unsqueeze(0))
            q_values = q_values.squeeze(0).squeeze(0)  # [action_dim]
            if no_children:
                q_values[-1] = -float('inf')  # Mask out the Child action
            return torch.argmax(q_values).item()

    def train(self, batch_size):
        """
        Run one optimisation step on a random mini-batch from the replay buffer.

        Does nothing (returns None) while the buffer holds fewer than
        ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < batch_size:
            return

        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # [1, batch, state_dim]: sequence length 1 in front of the batch.
        states = torch.stack(states).unsqueeze(0)
        next_states = torch.stack(next_states).unsqueeze(0)
        actions = torch.tensor(actions, dtype=torch.long)      # [batch]
        rewards = torch.tensor(rewards, dtype=torch.float32)   # [batch]
        dones = torch.tensor(dones, dtype=torch.float32)       # [batch]

        # Q(s, a) of the actions that were actually taken. squeeze(1) rather
        # than squeeze() so a batch of one stays 1-D like the target.
        q_values = self.q_network(states).squeeze(0)            # [batch, action_dim]
        q_taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values = self.target_network(next_states).squeeze(0)
            # The Child action is not available in next-states without
            # children, so it must not contribute to the bootstrap maximum
            # (same rule select_action applies).
            leaf_next = next_states[0, :, HAS_CHILDREN_INDEX] == 0
            next_q_values[:, -1] = next_q_values[:, -1].masked_fill(leaf_next, float('-inf'))
            max_next_q_values = next_q_values.max(1)[0]         # [batch]

            # Bellman target; terminal transitions (done=1) do not bootstrap.
            target_q_values = rewards + self.gamma * max_next_q_values * (1 - dones)

        loss = nn.SmoothL1Loss()(q_taken, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
import os
import torch
import copy

# Create directory for saved models
os.makedirs("saved_models_transformers_question_only_state_trained_all", exist_ok=True)

# Reward weightings to compare: alpha scales the accuracy term and beta the
# LLM-call penalty in get_reward(). A fresh agent is trained for each entry.
configurations = [
    {"alpha": 2.0, "beta": 0.05, "name": "High_Accuracy"},
    {"alpha": 1.0, "beta": 0.1, "name": "Balanced"},
    {"alpha": 0.5, "beta": 0.2, "name": "Efficiency_Focused"}
]

# Training curves per configuration:
# config name -> {"rewards", "accuracy", "num_calls"}, one list entry per episode.
accuracy_rewards_numcalls = {}

# Loop through different settings
for config in configurations:
    alpha = config["alpha"]
    beta = config["beta"]
    config_name = config["name"]
    
    # Define hyperparameters (re-initialised for every configuration, so each
    # run starts with a new agent and epsilon back at 1.0)
    action_space = ['CB', 'OB', 'Child']  # Possible actions
    # 7 basic features + 384 question-embedding values + 2 confidences
    # + 2 structural features + 2 answer lengths; must match get_state().
    state_dim = 7 + 384 + 2 + 2 + 2
    action_dim = len(action_space)  # Number of actions (CB, OB, Child)
    hidden_dim = 128  # Hidden layer size
    lr = 5e-3  # Learning rate (raised from 1e-3)
    gamma = 0.99  # Discount factor
    epsilon = 1.0  # Initial exploration rate
    epsilon_min = 0.01  # Minimum exploration rate
    epsilon_decay = 0.98  # Decay rate for exploration
    batch_size = 32  # Mini-batch size
    num_episodes = 50  # Number of training episodes
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)

    print(f"\nTraining with {config_name} (alpha={alpha}, beta={beta})")

    # Best training accuracy seen so far and a snapshot of the agent that reached it.
    max_accuracy = 0
    best_agent = None

    # Store rewards, accuracy and total number of LLM calls for each episode
    rewards_list = []
    accuracy_list = []
    total_num_llm_calls_list = []

    # Set model to training mode
    agent.q_network.train()

    for episode in range(num_episodes):
        print(f"Episode {episode + 1}/{num_episodes}")
        # Per-episode accumulators.
        total_reward = 0
        correct_answers = 0
        total_parent_nodes = 0
        total_num_llm_calls = 0

        # One pass over the whole training set per episode.
        for example in data:
            for node in example:
                if "fa" not in node:  # Only process root nodes (original questions)
                    # q2gold maps the stripped question text to (gold answer, dataset name).
                    gold_answer, dataset_used = q2gold[node["question_text"].strip()]
                    total_parent_nodes += 1
                    state = get_state(node)  
                    action = agent.select_action(state, epsilon)  
                    chosen_answer = None
                    fallback_used = False  
                    num_llm_calls = None

                    question, _ = get_question_after_resolving_references(node, example)
                    # Simulate the action. Answers are written onto `node` itself,
                    # so they persist into later episodes.
                    if action == 0:  # CB
                        node["cb_answer"] = get_cb_answer(question, dataset_used)
                        chosen_answer = node.get("cb_answer", [None])
                        num_llm_calls = 1
                    elif action == 1:  # OB (no topic entities are passed for the root)
                        node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])
                        num_llm_calls = 1
                    elif action == 2:  # Child: answer the sub-questions, then aggregate
                        # Work on a copy so sub-question answers chosen in this
                        # episode do not modify the stored example.
                        tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                        child_experiences = []

                        # Walk every node of the copied tree in list order: nodes
                        # without sons are answered by the agent, nodes with sons
                        # are aggregated from the answers already in the tree.
                        # NOTE(review): this relies on sons appearing before their
                        # parents in the list, and on the root being the last node
                        # with sons so that chosen_answer ends up holding the
                        # root's aggregate. Confirm against the tree format.
                        for node_ in tree_with_answers_chosen_by_agent:
                            question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)
                        
                            if len(node_["sons"]) == 0:
                                child_state = get_state(node_, depth=1)
                                # select_action never returns Child for a state
                                # without children, so child_action is 0 or 1 here.
                                child_action = agent.select_action(child_state, epsilon)
                                if child_action == 0:  # CB, falling back to OB on "unknown"
                                    node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                    child_answer = node_.get("cb_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 1  
                                        node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                        child_answer = node_.get("ob_answer", [None])
                                elif child_action == 1:  # OB, falling back to CB on "unknown"
                                    node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                    child_answer = node_.get("ob_answer", [None])
                                    if "unknown" in child_answer[0].lower().strip():
                                        child_action = 0  
                                        node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                        child_answer = node_.get("cb_answer", [None])

                                node_["answer"] = child_answer
                                # Reward 0 is a placeholder (replaced below); the
                                # next state is the same leaf state.
                                child_experiences.append((child_state, child_action, 0, child_state, False))  
                            else:
                                node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
                                chosen_answer = node_["child_answer"]
                        
                        # NOTE(review): extra calls made by leaf-level fallbacks
                        # above are not included in this count.
                        num_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation)

                    # Root-level fallback when the chosen strategy answered "unknown":
                    # CB -> OB, OB -> CB, Child -> CB.
                    # NOTE(review): the extra call is not added to num_llm_calls,
                    # so it affects neither the reward penalty nor the episode
                    # total. Confirm that this is intended.
                    if "unknown" in chosen_answer[0].lower().strip():
                        fallback_used = True
                        if action == 0:  
                            fallback_action = 1
                            node['ob_answer'] = get_singlehop_ob_answer(question, [], dataset_used)
                            chosen_answer = node.get("ob_answer", [None])
                        elif action == 1:  
                            fallback_action = 0
                            node['cb_answer'] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])
                        elif action == 2:  
                            fallback_action = 0
                            node['cb_answer'] = get_cb_answer(question, dataset_used)
                            chosen_answer = node.get("cb_answer", [None])

                    # Compute reward (the originally selected action is passed,
                    # even when a fallback produced the answer)
                    reward = get_reward(chosen_answer[0], gold_answer, action, num_llm_calls, alpha, beta)
                    total_reward += reward

                    # Accuracy is judged by the LLM-based equivalence check.
                    # if normalize_answer(chosen_answer[0]) == normalize_answer(gold_answer):
                    if are_answers_equivalent_using_llm(gold_answer, chosen_answer[0]):
                        correct_answers += 1

                    # Same node, re-featurised after answers were written onto it.
                    next_state = get_state(node)

                    # Store the root transition; when a fallback produced the
                    # answer, the transition is recorded under the fallback action.
                    if not fallback_used:
                        agent.replay_buffer.append((state, action, reward, next_state, False))  
                    else:
                        agent.replay_buffer.append((state, fallback_action, reward, next_state, False))

                    if action == 2:
                        num_children = len(node.get("sons", []))
                        # NOTE(review): child_reward is computed but never used.
                        # The loop below stores the full root reward for every
                        # leaf experience. Confirm which of the two is intended.
                        child_reward = reward / num_children  
                        for child_state, child_action, _, next_child_state, done in child_experiences:
                            agent.replay_buffer.append((child_state, child_action, reward, next_child_state, done))

                    # update total number of llm calls in this episode
                    total_num_llm_calls += num_llm_calls

                    # One gradient step per processed question (a no-op until
                    # the buffer holds batch_size transitions).
                    agent.train(batch_size)

        # Decay exploration once per episode.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Sync the target network on episodes 0, 10, 20, ...
        if episode % 10 == 0:
            agent.update_target_network()

        accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
        rewards_list.append(total_reward)
        accuracy_list.append(accuracy)
        total_num_llm_calls_list.append(total_num_llm_calls)
        # Snapshot the agent whenever training accuracy reaches a new maximum.
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            best_agent = copy.deepcopy(agent)

        print(f"Total Reward: {total_reward}")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Total LLM Calls in this episode: {total_num_llm_calls}")
        print(f"Epsilon: {epsilon:.4f}")
        print()

    # Store the rewards and accuracy for this configuration
    accuracy_rewards_numcalls[config_name] = {
        "rewards": rewards_list,
        "accuracy": accuracy_list,
        "num_calls": total_num_llm_calls_list
    }

    # Save the final-episode model for this configuration
    model_path = f"saved_models_transformers_question_only_state_trained_all/agent_{config_name}.pth"
    torch.save(agent.q_network.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save the best-accuracy snapshot for this configuration (skipped when no
    # episode scored above 0%)
    if best_agent is not None:
        best_model_path = f"saved_models_transformers_question_only_state_trained_all/best_agent_{config_name}.pth"
        torch.save(best_agent.q_network.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import matplotlib.pyplot as plt

# Pull the per-episode curves recorded during training for each configuration.
_balanced_run = accuracy_rewards_numcalls["Balanced"]
_efficiency_run = accuracy_rewards_numcalls["Efficiency_Focused"]
_high_accuracy_run = accuracy_rewards_numcalls["High_Accuracy"]

balanced_rewards = _balanced_run["rewards"]
balanced_accuracy = _balanced_run["accuracy"]
balanced_num_calls = _balanced_run["num_calls"]

efficiency_rewards = _efficiency_run["rewards"]
efficiency_accuracy = _efficiency_run["accuracy"]
efficiency_num_calls = _efficiency_run["num_calls"]

high_accuracy_accuracy = _high_accuracy_run["accuracy"]
high_accuracy_rewards = _high_accuracy_run["rewards"]
high_accuracy_num_calls = _high_accuracy_run["num_calls"]


def _min_max(values):
    """Rescale a reward curve to [0, 1] so differently weighted runs are comparable."""
    lowest = np.min(values)
    return (values - lowest) / (np.max(values) - lowest)


# Rewards use different alpha/beta scales per configuration, so normalise each curve.
balanced_rewards = _min_max(balanced_rewards)
efficiency_rewards = _min_max(efficiency_rewards)
high_accuracy_rewards = _min_max(high_accuracy_rewards)

# Figure 1: normalised reward (left) and accuracy (right) per episode.
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
for _label, _series in (
    ("Balanced", balanced_rewards),
    ("Efficiency Focused", efficiency_rewards),
    ("High Accuracy", high_accuracy_rewards),
):
    plt.plot(_series, label=_label)
plt.title("Normalized Rewards")
plt.xlabel("Episode")
plt.ylabel("Normalized Reward")
plt.legend()

plt.subplot(1, 2, 2)
for _label, _series in (
    ("Balanced", balanced_accuracy),
    ("Efficiency Focused", efficiency_accuracy),
    ("High Accuracy", high_accuracy_accuracy),
):
    plt.plot(_series, label=_label)
plt.title("Accuracy")
plt.xlabel("Episode")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.tight_layout()
plt.show()

# Figure 2: total LLM calls per episode.
plt.figure(figsize=(8, 6))
for _label, _series in (
    ("Balanced", balanced_num_calls),
    ("Efficiency Focused", efficiency_num_calls),
    ("High Accuracy", high_accuracy_num_calls),
):
    plt.plot(_series, label=_label)
plt.title("Number of LLM Calls")
plt.xlabel("Episode")
plt.ylabel("Total LLM Calls")
plt.legend()
plt.show()

In [None]:
# Report the LLM-call total of the last training episode for each configuration.
_last_episode_calls = {
    _key: accuracy_rewards_numcalls[_key]["num_calls"][-1]
    for _key in ("Balanced", "Efficiency_Focused", "High_Accuracy")
}
balanced_num_calls = _last_episode_calls["Balanced"]
efficiency_num_calls = _last_episode_calls["Efficiency_Focused"]
high_accuracy_num_calls = _last_episode_calls["High_Accuracy"]

for _name, _count in (
    ("balanced_num_calls", balanced_num_calls),
    ("efficiency_num_calls", efficiency_num_calls),
    ("high_accuracy_num_calls", high_accuracy_num_calls),
):
    print(_name, _count)

## Evaluate the trained agents on the HotpotQA test set

In [None]:
def _eval_ask(node, question, topic_entities, dataset_used, action):
    """Answer `question` with CB (action 0) or OB (otherwise), store the record on `node`, and return it."""
    if action == 0:
        node["cb_answer"] = get_cb_answer(question, dataset_used)
        return node["cb_answer"]
    node["ob_answer"] = get_singlehop_ob_answer(question, topic_entities, dataset_used)
    return node["ob_answer"]


def _eval_answer_with_fallback(node, question, topic_entities, dataset_used, first_action):
    """Try `first_action` (0=CB, 1=OB); on an "unknown" answer retry with the other one.

    Returns ``(answer_record, action_used)`` where ``action_used`` is the
    action that produced the returned record.
    """
    record = _eval_ask(node, question, topic_entities, dataset_used, first_action)
    if "unknown" not in record[0].lower().strip():
        return record, first_action
    other_action = 1 - first_action
    return _eval_ask(node, question, topic_entities, dataset_used, other_action), other_action


def _eval_answer_child(agent, child, tree, dataset_used):
    """Let the agent answer one non-root node of `tree` in place (sets ``child["answer"]``).

    A child that the agent decides to decompose has its own sons answered
    first and is then aggregated, mirroring what is done for the root.
    """
    # depth=1 matches how non-root nodes are featurised during training.
    child_state = get_state(child, depth=1)
    child_action = agent.select_action(child_state, epsilon=0)
    question_, topic_entities = get_question_after_resolving_references(child, tree)

    if child_action == 2:
        # select_action only returns Child for nodes that have sons.
        for grandchild_idx in child.get("sons", []):
            _eval_answer_child(agent, tree[grandchild_idx], tree, dataset_used)
        child["child_answer"], child["answer"] = aggregate_multihop_answer(child, tree, dataset_used)
    else:
        child["answer"], _ = _eval_answer_with_fallback(
            child, question_, topic_entities, dataset_used, child_action
        )


def evaluate_agent(agent, data, q2gold):
    """
    Evaluate a trained agent greedily (epsilon=0) on every root question in `data`.

    For each root node the agent picks CB, OB or Child. CB and OB fall back to
    each other when the answer contains "unknown". Child answers the node's
    sons on a deep copy of the tree, aggregates them, and falls back to CB and
    then OB if the aggregate is "unknown". Correctness is judged with
    ``are_answers_equivalent_using_llm``.

    Side effects: the root nodes in `data` get "cb_answer" / "ob_answer" /
    "child_answer" / "answer" entries written onto them, and the accuracy is
    printed.

    :param agent: Agent exposing ``q_network`` and ``select_action``.
    :param data: Iterable of trees; each tree is a list of node dicts and the
        root is the node without an "fa" key.
    :param q2gold: Maps the stripped root question text to
        ``(gold_answer, dataset_name)``.
    :return: ``(results, correct_answers, total_parent_nodes)`` where `results`
        holds one dict per root with keys "idx", "question", "answer", "gold"
        and "method".
    """
    agent.q_network.eval()  # Set the model to evaluation mode
    correct_answers = 0
    total_parent_nodes = 0
    results = []
    for example in data:
        for node in example:
            if "fa" in node:  # Only root nodes (original questions) are scored
                continue

            gold_answer, dataset_used = q2gold[node["question_text"].strip()]
            total_parent_nodes += 1
            state = get_state(node)
            action = agent.select_action(state, epsilon=0)  # No exploration during evaluation
            chosen_answer = None
            method_action = action  # Action reported as "method" in the results

            question, _ = get_question_after_resolving_references(node, example)
            if action in (0, 1):  # CB or OB, with fallback to the other one
                record, method_action = _eval_answer_with_fallback(
                    node, question, [], dataset_used, action
                )
                chosen_answer = record[0]
            elif action == 2:  # Child
                tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                for child_idx in node.get("sons", []):
                    _eval_answer_child(
                        agent,
                        tree_with_answers_chosen_by_agent[child_idx],
                        tree_with_answers_chosen_by_agent,
                        dataset_used,
                    )

                # Aggregate the sons' answers into the root's answer.
                node["child_answer"], node["answer"] = aggregate_multihop_answer(
                    node, tree_with_answers_chosen_by_agent, dataset_used
                )
                chosen_answer = node["child_answer"][0]
                if "unknown" in chosen_answer.lower().strip():
                    # Fall back to CB first, then OB. The reported method
                    # deliberately stays "Child" in this case.
                    record, _ = _eval_answer_with_fallback(node, question, [], dataset_used, 0)
                    chosen_answer = record[0]

            # if chosen_answer and normalize_answer(chosen_answer) == normalize_answer(gold_answer):
            if chosen_answer and are_answers_equivalent_using_llm(chosen_answer, gold_answer):
                correct_answers += 1

            results.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": chosen_answer,
                "gold": gold_answer,
                "method": action_space[method_action]
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes


# List of trained models to evaluate (same names as used when the checkpoints
# were saved by the training cell)
configurations = [
    {"alpha": 2.0, "beta": 0.05, "name": "High_Accuracy"},
    {"alpha": 1.0, "beta": 0.1, "name": "Balanced"},
    {"alpha": 0.5, "beta": 0.2, "name": "Efficiency_Focused"}
]

# Evaluation outputs, each keyed by configuration name.
# NOTE(review): the evaluation cells for the different test sets all write to
# these same keys, so running a later cell overwrites the earlier dataset's
# entries. Copy them elsewhere first if they are needed.
predictions_per_config = {}
correct_answers_per_config = {}
total_parent_nodes_per_config = {}
accuracy_per_config = {}

In [None]:

# Score the best checkpoint of every configuration on the HotpotQA test set.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_transformers_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Build an agent with the training-time dimensions, then restore the saved weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(checkpoint)
    agent.q_network.eval()  # Evaluation mode

    print(f"\nEvaluating model: {config_name}")

    # Run the greedy evaluation and keep its outputs under this configuration's name.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_hotpotqa, q2gold_test_hotpotqa
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the test set contained no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:

# Score the best checkpoint of every configuration on the 2Wiki test set.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_transformers_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Build an agent with the training-time dimensions, then restore the saved weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(checkpoint)
    agent.q_network.eval()  # Evaluation mode

    print(f"\nEvaluating model: {config_name}")

    # Run the greedy evaluation and keep its outputs under this configuration's name.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_2wiki, q2gold_test_2wiki
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the test set contained no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:

# Score the best checkpoint of every configuration on the MuSiQue test set.
for config in configurations:
    config_name = config["name"]
    model_path = f"saved_models_transformers_question_only_state_trained_all/best_agent_{config_name}.pth"

    # Build an agent with the training-time dimensions, then restore the saved weights.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    checkpoint = torch.load(model_path)
    agent.q_network.load_state_dict(checkpoint)
    agent.q_network.eval()  # Evaluation mode

    print(f"\nEvaluating model: {config_name}")

    # Run the greedy evaluation and keep its outputs under this configuration's name.
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        agent, test_data_musique, q2gold_test_musique
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when the test set contained no root questions.
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


# Random Solver

## Evaluate the random solver on the HotpotQA test set

In [None]:
def select_random_action(is_child):
    """
    Pick a uniformly random action index from the global ``action_space``.

    :param is_child: When truthy, the last action ("Child") is excluded and
        only the remaining actions can be drawn; when falsy, every action can
        be drawn. Despite the name, callers pass True for nodes that cannot be
        decomposed (no sons).
    :return: Action index as an int.
    """
    # The Child action sits at the last index, so shrinking the upper bound by
    # one removes it from the draw.
    highest_index = len(action_space) - (2 if is_child else 1)
    return random.randint(0, highest_index)
            
# Evaluate the agent on the test data
def evaluate_random_agent(data, q2gold):
    """
    Evaluate a random action policy on the root question of every tree in ``data``.

    For each root node (a node without an ``"fa"`` key) an action is drawn with
    ``select_random_action`` and executed: closed-book (0), open-book (1) or
    child aggregation (2). An answer containing "unknown" triggers a fallback
    to another strategy. The final answer is compared with the gold answer via
    ``are_answers_equivalent_using_llm``.

    Side effects: answers computed for a root are written back into the node
    dicts of ``data`` (``cb_answer``, ``ob_answer``, ``child_answer``, ``answer``).

    :param data: iterable of trees, each tree being a list of node dicts.
    :param q2gold: maps the stripped root question to ``(gold_answer, dataset_name)``.
    :return: ``(results, correct_answers, total_parent_nodes)`` where ``results``
        holds one dict per root with ``idx``, ``question``, ``answer``, ``gold``
        and ``method`` (name of the action that produced the final answer).
    """
    correct_answers = 0
    total_parent_nodes = 0
    results = []
    for example in data:
        for node in example:
            if "fa" in node:  # Only root nodes (original questions) are scored
                continue

            gold_answer, dataset_used = q2gold[node["question_text"].strip()]
            total_parent_nodes += 1

            # Roots that have sub-questions draw from the full action space;
            # roots without sons draw from the restricted (child) set.
            has_sons = len(node.get("sons", [])) != 0
            action = select_random_action(is_child=not has_sons)

            chosen_answer = None
            fallback_used = False  # Track if a fallback produced the final answer
            fallback_action = None

            question, _ = get_question_after_resolving_references(node, example)
            # Simulate the action (choose answer based on action)
            if action == 0:  # CB
                node["cb_answer"] = get_cb_answer(question, dataset_used)
                chosen_answer = node.get("cb_answer", [None])[0]
                if "unknown" in chosen_answer.lower().strip():  # Fallback to OB
                    node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                    chosen_answer = node.get("ob_answer", [None])[0]
                    fallback_action = 1
                    fallback_used = True
            elif action == 1:  # OB
                node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                chosen_answer = node.get("ob_answer", [None])[0]
                if "unknown" in chosen_answer.lower().strip():  # Fallback to CB
                    node["cb_answer"] = get_cb_answer(question, dataset_used)
                    chosen_answer = node.get("cb_answer", [None])[0]
                    fallback_action = 0
                    fallback_used = True
            elif action == 2:  # Child
                # Work on a copy so the per-node answers do not leak into the input tree
                tree_with_answers_chosen_by_agent = copy.deepcopy(example)
                for node_ in tree_with_answers_chosen_by_agent:
                    question_, topic_entities = get_question_after_resolving_references(node_, tree_with_answers_chosen_by_agent)

                    if len(node_["sons"]) == 0:
                        # Leaf node -> answer it with a random CB/OB action plus fallback.
                        # NOTE(review): the leaf answer is stored under cb_answer/ob_answer
                        # only, never as node_["answer"] - confirm that the aggregation
                        # step reads those keys.
                        child_action = select_random_action(is_child=True)
                        if child_action == 0:  # CB
                            node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                            child_answer = node_.get("cb_answer", [None])
                            if "unknown" in child_answer[0].lower().strip():  # Fallback to OB
                                child_action = 1
                                node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                                child_answer = node_.get("ob_answer", [None])
                        elif child_action == 1:  # OB
                            node_["ob_answer"] = get_singlehop_ob_answer(question_, topic_entities, dataset_used)
                            child_answer = node_.get("ob_answer", [None])
                            if "unknown" in child_answer[0].lower().strip():  # Fallback to CB
                                child_action = 0
                                node_["cb_answer"] = get_cb_answer(question_, dataset_used)
                                child_answer = node_.get("cb_answer", [None])
                    else:
                        # NOTE(review): this aggregates the root `node` for every
                        # non-leaf node_ of the copy; verify whether `node_` was intended.
                        node["child_answer"], node["answer"] = aggregate_multihop_answer(node, tree_with_answers_chosen_by_agent, dataset_used)
                        chosen_answer = node["child_answer"][0]

                # Fallback to CB, then OB, when aggregation gave no usable answer
                if chosen_answer is None or "unknown" in chosen_answer.lower().strip():
                    node["cb_answer"] = get_cb_answer(question, dataset_used)
                    chosen_answer = node.get("cb_answer", [None])[0]  # Try CB first
                    # Record the CB fallback so "method" reports what actually answered
                    fallback_action = 0
                    fallback_used = True
                    if "unknown" in chosen_answer.lower().strip():
                        fallback_action = 1
                        node["ob_answer"] = get_singlehop_ob_answer(question, [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])[0]  # Try OB next

            # Count the answer as correct when the LLM judge deems it equivalent to gold
            if chosen_answer and are_answers_equivalent_using_llm(chosen_answer, gold_answer):
                correct_answers += 1

            # Store results
            results.append({
                "idx": node["idx"],
                "question": node["question_text"],
                "answer": chosen_answer,
                "gold": gold_answer,
                "method": action_space[action] if not fallback_used else action_space[fallback_action]
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes

# The random baseline is evaluated as a single named configuration
configurations = [dict(name="Random")]

# Result stores filled by the evaluation cells, keyed by configuration name
predictions_per_config = dict()
correct_answers_per_config = dict()
total_parent_nodes_per_config = dict()
accuracy_per_config = dict()


In [None]:

# Run the random baseline on the hotpotqa test data for every configuration
for config in configurations:
    config_name = config["name"]
    # Path is built for parity with the trained-model cells; nothing is loaded from it here
    model_path = f"saved_models_random_trained_all/best_agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}")

    # Run the evaluation and record its raw outputs under the configuration name
    rl_predictions, correct_answers, total_parent_nodes = evaluate_random_agent(test_data_hotpotqa, q2gold_test_hotpotqa)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

In [None]:

# Run the random baseline on the 2wiki test data for every configuration
for config in configurations:
    config_name = config["name"]
    # Path is built for parity with the trained-model cells; nothing is loaded from it here
    model_path = f"saved_models_random_trained_all/best_agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}")

    # Run the evaluation and record its raw outputs under the configuration name
    rl_predictions, correct_answers, total_parent_nodes = evaluate_random_agent(test_data_2wiki, q2gold_test_2wiki)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

In [None]:

# Run the random baseline on the musique test data for every configuration
for config in configurations:
    config_name = config["name"]
    # Path is built for parity with the trained-model cells; nothing is loaded from it here
    model_path = f"saved_models_random_trained_all/best_agent_{config_name}.pth"

    print(f"\nEvaluating model: {config_name}")

    # Run the evaluation and record its raw outputs under the configuration name
    rl_predictions, correct_answers, total_parent_nodes = evaluate_random_agent(test_data_musique, q2gold_test_musique)
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")

# RL Approach with deep learning (using Question + cb + ob + RESAMPLING ACTION)

In [None]:
def _load_jsonl(path):
    """Parse a JSON-Lines file into a list of objects, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line.strip()) for line in handle]


def _load_json(path):
    """Parse a JSON file, closing the file afterwards."""
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# Load the raw data of the three dev sets into one list. The 2wiki and musique
# items get an explicit short dataset tag; hotpotqa items keep whatever
# 'dataset' value they already carry.
raw_data = _load_jsonl('./hotpotqa__v2_dev_random_100.jsonl')
raw_data_2wiki = _load_jsonl('./2wikimultihopqa__v2_dev_random_100.jsonl')
for item in raw_data_2wiki:
    item['dataset'] = '2wiki'
raw_data.extend(raw_data_2wiki)
raw_data_musique = _load_jsonl('./musique_ans__v2_dev_random_100.jsonl')
for item in raw_data_musique:
    item['dataset'] = 'musique'
raw_data.extend(raw_data_musique)

# Question decompositions of all three datasets merged into a single mapping
q2dq = _load_json("./question_decompositions-devset-hotpotqa.json")
q2dq_2wiki = _load_json("./question_decompositions-devset-2wiki.json")
q2dq_musique = _load_json("./question_decompositions-devset-musique.json")
q2dq.update(q2dq_2wiki)
q2dq.update(q2dq_musique)

# Create q2gold map: root question of the decomposition -> (gold answer, dataset tag)
q2gold = {}
for item in raw_data:
    try:
        question = item['question_text'].strip()
        # The first key of the decomposition is used as the lookup question
        question = list(q2dq[question].keys())[0]
        gold = item['answers_objects'][0]['spans'][0]
        q_type = item['dataset']
    except (KeyError, IndexError, TypeError):
        # Skip if the question has no decomposition or the record is incomplete
        continue
    q2gold[question] = (gold, q_type)

# Load the data to analyze and merge the three result sets into `data`
data = _load_json('results-devset-hotpotqa.json')
data_2wiki = _load_json('results-devset-2wiki.json')
data_musique = _load_json('results-devset-musique.json')

data.extend(data_2wiki)
data.extend(data_musique)

In [None]:
# Possible actions, addressed by their position in this list
action_space = ['CB', 'OB', 'Child', 'ResampleTreeDecompositionThenSolve']
# Per-action tallies keyed by the lower-cased action name, all starting at zero
success_counts = {action_name.lower(): 0 for action_name in action_space}
attempt_counts = dict.fromkeys(success_counts, 0)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Fully connected network that maps a state vector to one Q-value per action.

    Two ReLU-activated hidden layers of width ``hidden_dim`` are followed by a
    linear output layer of width ``action_dim``.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        self.init_weights()

    def forward(self, state):
        """Return the Q-values for ``state``."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

    def init_weights(self):
        """Xavier-initialise the hidden layers' weights and zero their biases.

        The output layer ``fc3`` keeps the default ``nn.Linear`` initialisation.
        """
        for hidden_layer in (self.fc1, self.fc2):
            nn.init.xavier_uniform_(hidden_layer.weight)
            nn.init.zeros_(hidden_layer.bias)


def pad_or_truncate(logprobs, max_length=50, pad_value=-100):
    """Return ``logprobs`` with exactly ``max_length`` entries.

    Shorter inputs are right-padded with ``pad_value``; inputs that are already
    long enough are cut down to their first ``max_length`` entries.
    """
    missing = max_length - len(logprobs)
    if missing > 0:
        return logprobs + [pad_value] * missing
    return logprobs[:max_length]


from sentence_transformers import SentenceTransformer, util
# Load a pre-trained sentence embedding model
# Kept at module level so get_state() can call model.encode(...) on the
# question text and on the CB / OB / Child answer texts.
model = SentenceTransformer('all-MiniLM-L6-v2')

# def get_reward(chosen_answer, gold_answer, action, num_llm_calls, alpha=1.0, beta=0.1):
#     """
#     Reward function with tradeoff between accuracy and efficiency.
#     :param chosen_answer: Answer chosen by the agent.
#     :param gold_answer: Ground truth answer.
#     :param action: The action selected by the agent (0=CB, 1=OB, 2=Child).
#     :param num_llm_calls: The number of LLM calls made during the current decision.
#     :param alpha: Weight for accuracy reward.
#     :param beta: Weight for efficiency penalty.
#     :return: A reward value that balances accuracy and efficiency.
#     """
#     # Compute Accuracy Reward
#     chosen_embedding = model.encode(chosen_answer, convert_to_tensor=True)
#     gold_embedding = model.encode(gold_answer, convert_to_tensor=True)
#     similarity = util.cos_sim(chosen_embedding, gold_embedding).item()  # Value between -1 (opposite) and 1 (exact match)
#     accuracy_reward = (similarity + 1) / 2  # Normalize to range [0, 1]

#     # Define LLM cost for each action
#     ACTION_COSTS = {
#         0: 1,  # CB cost
#         1: 1,  # OB cost (higher than CB because of retrieval)
#         2: 1,  # Child decomposition cost (higher due to multiple child evaluations)
#         3: 1,  # Tree Resampling cost
#     }

#     # Compute Efficiency Penalty
#     action_cost = ACTION_COSTS.get(action, 1)  # Default to 1 if action is unrecognized
#     efficiency_penalty = num_llm_calls * action_cost  # Penalize based on LLM usage

#     print(f"Action: {action}, Chosen Answer: {chosen_answer}, Gold Answer: {gold_answer}, accuracy_reward: {accuracy_reward:.2f}, Efficiency Penalty: {efficiency_penalty}")
#     # Combine Accuracy and Efficiency
#     reward = alpha * accuracy_reward - beta * efficiency_penalty

#     return reward

# use llm to get the reward
def get_reward(action, accuracy_reward, num_llm_calls, alpha=1.0, beta=0.1):
    """
    Combine answer accuracy and LLM usage into a single scalar reward.

    :param action: Index of the action taken (0=CB, 1=OB, 2=Child, 3=Tree resampling).
    :param accuracy_reward: Accuracy signal for the chosen answer.
    :param num_llm_calls: Number of LLM calls spent on the decision.
    :param alpha: Weight of the accuracy term.
    :param beta: Weight of the efficiency penalty.
    :return: ``alpha * accuracy_reward - beta * (num_llm_calls * action_cost)``.
    """
    # Per-call cost of each action. All four are currently 1, and an
    # unrecognized action also falls back to a cost of 1.
    cost_per_action = dict.fromkeys((0, 1, 2, 3), 1)
    action_cost = cost_per_action.get(action, 1)

    # Penalize in proportion to LLM usage
    efficiency_penalty = num_llm_calls * action_cost

    return alpha * accuracy_reward - beta * efficiency_penalty

def get_state(node, depth=0, max_length=50):
    """
    Build the flat feature tensor that describes ``node`` for the Q-network.

    Layout, in order: padded CB / OB / Child log-probabilities (``max_length``
    slots each), seven basic features, sentence embeddings of the question and
    of the CB / OB / Child answers, three confidence scores, two structural
    features and three answer lengths.

    :param node: Question-tree node (dict). Missing answer records fall back to
        empty log-probs, zero confidence and an empty answer text.
    :param depth: Depth of the node in the tree; 0 marks the root.
    :param max_length: Number of log-prob slots kept per answer source.
    :return: ``torch.FloatTensor`` with the concatenated features.
    """
    answer_keys = ("cb_answer", "ob_answer", "child_answer")
    empty_record = [None, None, None, []]

    # Log-probabilities live in slot 3 of an answer record; fix their width
    logprob_features = []
    for key in answer_keys:
        raw_logprobs = node.get(key, empty_record)[3] or []
        logprob_features = logprob_features + pad_or_truncate(raw_logprobs, max_length)

    # Basic features (7): has-children flag, question length in words, encoded
    # question type, number of children and the running success rate per action
    sons = node.get("sons", [])
    basic_features = [
        1 if len(sons) else 0,
        len(node.get("question_text", "").split()),
        encode_question_type(node.get("question_text", "")),
        len(sons),
        get_success_rate("cb"),
        get_success_rate("ob"),
        get_success_rate("child"),
    ]

    # Semantic features: embedding of the question, then of each answer text (slot 0)
    semantic_features = list(model.encode(node.get("question_text", ""), convert_to_tensor=False))
    for key in answer_keys:
        answer_text = node.get(key, [""])[0]
        semantic_features = semantic_features + list(model.encode(answer_text, convert_to_tensor=False))

    # Confidence scores live in slot 1 of an answer record; missing ones count as 0.0
    confidence_features = [node.get(key, empty_record)[1] or 0.0 for key in answer_keys]

    # Structural features: depth and a root (0) vs. non-root (1) flag
    structural_features = [depth, 0 if depth == 0 else 1]

    # Answer quality: word count of each answer text
    length_features = [len(node.get(key, [""])[0].split()) for key in answer_keys]

    return torch.FloatTensor(
        logprob_features
        + basic_features
        + semantic_features
        + confidence_features
        + structural_features
        + length_features
    )

class DQNAgent:
    """DQN agent holding an online Q-network, a target network and a replay buffer.

    Experiences are appended to ``replay_buffer`` by the caller as
    ``(state, action, reward, next_state, done)`` tuples.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=1e-3, gamma=0.99):
        self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
        self.target_network = QNetwork(state_dim, action_dim, hidden_dim)
        # The target network starts as an exact copy of the online network
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.replay_buffer = []

    def select_action(self, state, epsilon):
        """Pick an action index for ``state`` with an epsilon-greedy policy.

        ``state[150]`` is read as the has-children flag (it follows the
        3 x 50 log-prob slots produced by ``get_state`` with its default
        ``max_length``). When it is 0 only CB (0) and OB (1) can be returned.

        :param state: 1-D state tensor.
        :param epsilon: Probability of taking a random (exploratory) action.
        :return: Action index as a Python int.
        """
        if random.random() < epsilon:
            if state[150] == 0:  # has children = 0, because child is not possible
                return random.choice([0, 1])  # CB or OB
            else:
                return random.choice([0, 1, 2, 3])  # CB, OB, Child or resample
        else:
            with torch.no_grad():
                q_values = self.q_network(state)
                if state[150] == 0:  # has children = 0
                    # Mask the last two actions (Child and resample)
                    q_values[-1] = -float('inf')
                    q_values[-2] = -float('inf')
                return torch.argmax(q_values).item()

    def train(self, batch_size):
        """Run one optimisation step on a random mini-batch from the replay buffer.

        Does nothing while the buffer holds fewer than ``batch_size`` experiences.
        As a side effect, ``update_success_rate`` is called once for every
        sampled experience.
        """
        if len(self.replay_buffer) < batch_size:
            return
        batch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.stack(next_states)
        dones = torch.FloatTensor(dones)

        # Q-values of the actions that were actually taken, shape (batch, 1)
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1))

        # Bootstrapped targets from the frozen target network
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # Update the success rates.
        # NOTE(review): success is defined as reward == 1 exactly, and the same
        # experience is counted again every time it is re-sampled - confirm
        # that this is intended given the reward function in use.
        for i, action in enumerate(actions):
            succeeded = bool(rewards[i] == 1)
            update_success_rate(action_space[action], succeeded)

        # Compute loss and update the network. squeeze(1) drops only the gather
        # dimension, so the prediction keeps shape (batch,) even for batch_size == 1
        # and always matches the target shape.
        loss = nn.SmoothL1Loss()(q_values.squeeze(1), target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

In [None]:
# state_dim = 150 + 7 + 1536 + 3 + 2 + 3 # Number of features in the state vector
# action_dim = len(action_space)  # Number of actions (CB, OB, Child)
# hidden_dim = 128  # Hidden layer size
# lr = 5e-3  # Increase from 1e-3  //     1e-3  # Learning rate
# gamma = 0.99  # Discount factor
# epsilon = 1.0  # Initial exploration rate
# epsilon_min = 0.01  # Minimum exploration rate
# epsilon_decay = 0.95  # Decay rate for exploration
# batch_size = 128  # Mini-batch size
# num_episodes = 30  # Number of training episodes

In [None]:
def child_aggregation(example, dataset_used):
    """
    Answer a question tree bottom-up and aggregate the children's answers.

    Works on a deep copy of ``example``. For every node, ``<k>`` reference
    tokens in the question are replaced by the answer of the k-th son of the
    node's parent (when that answer exists). Leaf nodes are answered with the
    action chosen by the global ``agent`` (using the global ``epsilon``), with a
    CB <-> OB fallback when the answer contains "unknown"; nodes with sons are
    answered through ``aggregate_multihop_answer``.

    :param example: Tree as a list of node dicts.
    :param dataset_used: Dataset name forwarded to the answering helpers.
    :return: ``(child_experiences, chosen_answer)`` where ``child_experiences``
        holds one ``(state, action, 0, state, False)`` tuple per leaf and
        ``chosen_answer`` is the ``child_answer`` of the last node with sons
        that was processed (``None`` if the tree has no such node).
    :raises ValueError: if the agent selects an action other than CB (0) or
        OB (1) for a leaf.
    """
    tree_with_answers_chosen_by_agent = copy.deepcopy(example)
    child_experiences = []
    chosen_answer = None
    for node_ in tree_with_answers_chosen_by_agent:
        question = node_["question_text"].strip()
        topic_entities = []
        for ref_token in re.findall(r"<\d+>", question):
            if "fa" not in node_:
                continue
            ref_number = int(ref_token[1:-1])
            siblings = tree_with_answers_chosen_by_agent[node_["fa"]]["sons"]
            # References are 1-based; "<0>" must not wrap around to the last sibling
            if not 1 <= ref_number <= len(siblings):
                continue
            ref_idx = siblings[ref_number - 1]
            if "answer" in tree_with_answers_chosen_by_agent[ref_idx]:
                resolved_answer = tree_with_answers_chosen_by_agent[ref_idx]["answer"][0]
                question = question.replace(ref_token, resolved_answer)
                topic_entities.append(resolved_answer)
        node_["question"] = question

        if len(node_["sons"]) == 0:
            # Leaf: let the agent choose between CB and OB
            child_state = get_state(node_, depth=1)
            child_action = agent.select_action(child_state, epsilon)
            if child_action == 0:
                node_["cb_answer"] = get_cb_answer(question, dataset_used)
                child_answer = node_.get("cb_answer", [None])
                if "unknown" in child_answer[0].lower().strip():
                    # CB gave no answer: fall back to OB and record OB as the action
                    child_action = 1
                    node_["ob_answer"] = get_singlehop_ob_answer(question, topic_entities, dataset_used)
                    child_answer = node_.get("ob_answer", [None])
            elif child_action == 1:
                node_["ob_answer"] = get_singlehop_ob_answer(question, topic_entities, dataset_used)
                child_answer = node_.get("ob_answer", [None])
                if "unknown" in child_answer[0].lower().strip():
                    # OB gave no answer: fall back to CB and record CB as the action
                    child_action = 0
                    node_["cb_answer"] = get_cb_answer(question, dataset_used)
                    child_answer = node_.get("cb_answer", [None])
            else:
                print("Invalid action for child node", child_action)
                raise ValueError("Invalid action for child node")

            node_["answer"] = child_answer
            # Reward placeholder 0; the caller substitutes the actual child reward
            child_experiences.append((child_state, child_action, 0, child_state, False))
        else:
            node_["child_answer"], node_["answer"] = aggregate_multihop_answer(node_, tree_with_answers_chosen_by_agent, dataset_used)
            chosen_answer = node_["child_answer"]

    return child_experiences, chosen_answer

In [None]:
import os
import torch
import copy

# Create directory for saved models
os.makedirs("saved_models_with_resampling_trained_all", exist_ok=True)

# Reward trade-offs to train with: alpha weights accuracy, beta weights the LLM-call penalty
configurations = [
    {"alpha": 2.0, "beta": 0.05, "name": "High_Accuracy"},
    {"alpha": 1.0, "beta": 0.1, "name": "Balanced"},
    {"alpha": 0.5, "beta": 0.2, "name": "Efficiency_Focused"},
]

# Per-configuration training curves (rewards, accuracy, LLM calls per episode)
accuracy_rewards_numcalls = {}

# Train one fresh agent per configuration
for config in configurations:
    alpha = config["alpha"]
    beta = config["beta"]
    config_name = config["name"]
    # initialize the parameters
    state_dim = 150 + 7 + 1536 + 3 + 2 + 3 # Number of features in the state vector
    action_dim = len(action_space)  # Number of actions in action_space
    hidden_dim = 128  # Hidden layer size
    lr = 5e-3  # Learning rate
    gamma = 0.99  # Discount factor
    epsilon = 1.0  # Initial exploration rate
    epsilon_min = 0.01  # Minimum exploration rate
    epsilon_decay = 0.95  # Decay rate for exploration
    batch_size = 128  # Mini-batch size
    num_episodes = 50  # Number of training episodes

    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)

    print(f"\nTraining with {config_name} (alpha={alpha}, beta={beta})")

    max_accuracy = 0
    best_agent = None

    # Store rewards, accuracy and total number of LLM calls for each episode
    rewards_list = []
    accuracy_list = []
    total_num_llm_calls_list = []

    # Set model to training mode
    agent.q_network.train()

    for episode in range(num_episodes):
        print(f"Episode {episode + 1}/{num_episodes}")
        total_reward = 0
        correct_answers = 0
        total_parent_nodes = 0
        total_num_llm_calls = 0

        for example in data:
            for node in example:
                if "fa" not in node:  # Only process root nodes (original questions)
                    gold_answer, dataset_used = q2gold[node["question_text"]]
                    total_parent_nodes += 1
                    state = get_state(node)
                    action = agent.select_action(state, epsilon)
                    done_resampling = False

                    # saving original node question because maybe after resampling the question is changed
                    original_node_question = node["question_text"]
                    if action == 3: # ResampleTreeDecompositionThenSolve
                        root_question = node["question_text"] # dont strip it, because we will search the original question
                        example = tree_resampling_pipeline.resample_tree(root_question)
                        # get the parent node again because maybe the tree is changed in term of the order
                        for node_ in example:
                            if "fa" not in node_:
                                node = node_
                                break
                        done_resampling = True
                        action = 2 # Set action now to Child to execute child aggregation
                    chosen_answer = None
                    fallback_used = False
                    fallback_action = None
                    fallback_llm_calls = 0
                    num_llm_calls = None

                    # Simulate the action
                    if action == 0:
                        node['cb_answer'] = get_cb_answer(node["question_text"].strip(), dataset_used)
                        chosen_answer = node.get("cb_answer", [None])
                        num_llm_calls = 1
                    elif action == 1:
                        # since you are the parent, so topic entities are empty because no references
                        node['ob_answer'] = get_singlehop_ob_answer(node["question_text"].strip(), [], dataset_used)
                        chosen_answer = node.get("ob_answer", [None])
                        num_llm_calls = 1
                    elif action == 2:
                        # Call child aggregation function
                        child_experiences, chosen_answer = child_aggregation(example, dataset_used)
                        num_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation)

                    if done_resampling:
                        # if the original action was resampling, then reset the action to 3
                        action = 3
                        num_llm_calls += 1 # because of the resampling

                    # Private copy of the Q-values used to rank fallback actions;
                    # detached because it is only masked and arg-maxed, never trained on
                    q_values = agent.q_network(state).detach().clone()
                    while (chosen_answer is None or "unknown" in chosen_answer[0].lower().strip()):
                        fallback_used = True

                        # Mask the Q-value of the originally chosen action
                        q_values[action] = -float('inf')  # Ignore the originally chosen action

                        # if all q_values are -inf, then break
                        if torch.all(q_values == -float('inf')):
                            print("All actions are invalid, skipping fallback")
                            break

                        # Mask invalid actions if applicable
                        # For example, if state[150] == 0 (has_children == 0), disable actions 2 (Child) and 3 (Resample)
                        if state[150] == 0:
                            q_values[2] = -float('inf')  # Child aggregation
                            q_values[3] = -float('inf')  # ResampleTreeDecomposition

                        # Select the next best action
                        fallback_action = torch.argmax(q_values).item()

                        # set q_values of the fallback action to -inf in order to not select it again
                        q_values[fallback_action] = -float('inf')

                        # Execute the fallback action
                        if fallback_action == 0:  # CB
                            chosen_answer = get_cb_answer(node["question_text"].strip(), dataset_used)
                            fallback_llm_calls = 1
                        elif fallback_action == 1:  # OB
                            chosen_answer = get_singlehop_ob_answer(node["question_text"].strip(), [], dataset_used)
                            fallback_llm_calls = 1
                        elif fallback_action == 2:  # Child
                            child_experiences, chosen_answer = child_aggregation(example, dataset_used)
                            fallback_llm_calls = len(child_experiences) + 1 + 1  # number of children + 1 (decomposition) + 1 (aggregation)
                        elif fallback_action == 3:  # Resample
                            example = tree_resampling_pipeline.resample_tree(node["question_text"].strip())
                            child_experiences, chosen_answer = child_aggregation(example, dataset_used)
                            fallback_llm_calls = len(child_experiences) + 1 + 1 + 1 # number of children + 1 (decomposition) + 1 (aggregation) + 1 (resampling)

                    # Compute reward
                    if chosen_answer is None:
                        # Every action was exhausted without producing an answer:
                        # count the question as answered incorrectly instead of crashing
                        equivalent = False
                    else:
                        equivalent = are_answers_equivalent_using_llm(gold_answer, chosen_answer[0])
                    if equivalent:
                        correct_answers += 1

                    reward = get_reward(action, equivalent, num_llm_calls, alpha, beta)
                    total_reward += reward

                    next_state = get_state(node)

                    if not fallback_used:
                        agent.replay_buffer.append((state, action, reward, next_state, False))
                    else:
                        # if we used fallback, we need to add the fallback action to the replay buffer
                        fallback_reward = get_reward(fallback_action, equivalent, num_llm_calls + fallback_llm_calls, alpha, beta)
                        # NOTE(review): `reward` was already added to total_reward above, so
                        # this counts it twice; confirm whether `fallback_reward` was intended.
                        total_reward += reward
                        agent.replay_buffer.append((state, fallback_action, fallback_reward, next_state, False))

                    if action == 2 or action == 3:  # If the action was Child or ResampleTreeDecompositionThenSolve
                        # Share the root reward among the children; a (resampled) root
                        # without sons must not cause a division by zero
                        num_children = len(node.get("sons", []))
                        child_reward = reward / max(num_children, 1)
                        for child_state, child_action, _, next_child_state, done in child_experiences:
                            agent.replay_buffer.append((child_state, child_action, child_reward, next_child_state, done))

                    # update total number of llm calls in this episode
                    # (fallback calls are not included in this total)
                    total_num_llm_calls += num_llm_calls

                    agent.train(batch_size)
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Sync the target network every 10 episodes (including episode 0)
        if episode % 10 == 0:
            agent.update_target_network()

        accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
        rewards_list.append(total_reward)
        accuracy_list.append(accuracy)
        total_num_llm_calls_list.append(total_num_llm_calls)
        # Keep a snapshot of the agent from the most accurate episode so far
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            best_agent = copy.deepcopy(agent)

        print(f"Total Reward: {total_reward}")
        print(f"Accuracy: {accuracy:.2f}%")
        print(f"Total LLM Calls in this episode: {total_num_llm_calls}")
        print(f"Epsilon: {epsilon:.4f}")
        print()

    # Store the rewards and accuracy for this configuration
    accuracy_rewards_numcalls[config_name] = {
        "rewards": rewards_list,
        "accuracy": accuracy_list,
        "num_calls": total_num_llm_calls_list
    }

    # Save model after training with each configuration
    model_path = f"saved_models_with_resampling_trained_all/agent_{config_name}.pth"
    torch.save(agent.q_network.state_dict(), model_path)
    print(f"Model saved to {model_path}")

    # Save the best model of this configuration
    if best_agent is not None:
        best_model_path = f"saved_models_with_resampling_trained_all/best_agent_{config_name}.pth"
        torch.save(best_agent.q_network.state_dict(), best_model_path)
        print(f"Best model saved to {best_model_path}")


In [None]:
import matplotlib.pyplot as plt

# Get rewards, accuracy and number of LLM calls for each configuration
balanced_rewards = accuracy_rewards_numcalls["Balanced"]["rewards"]
balanced_accuracy = accuracy_rewards_numcalls["Balanced"]["accuracy"]
balanced_num_calls = accuracy_rewards_numcalls["Balanced"]["num_calls"]

efficiency_rewards = accuracy_rewards_numcalls["Efficiency_Focused"]["rewards"]
efficiency_accuracy = accuracy_rewards_numcalls["Efficiency_Focused"]["accuracy"]
efficiency_num_calls = accuracy_rewards_numcalls["Efficiency_Focused"]["num_calls"]

high_accuracy_accuracy = accuracy_rewards_numcalls["High_Accuracy"]["accuracy"]
high_accuracy_rewards = accuracy_rewards_numcalls["High_Accuracy"]["rewards"]
high_accuracy_num_calls = accuracy_rewards_numcalls["High_Accuracy"]["num_calls"]


def _min_max_scale(values):
    """Linearly rescale ``values`` so the minimum maps to 0 and the maximum to 1.

    Args:
        values: Sequence of per-episode numbers (e.g. the rewards list).

    Returns:
        A float NumPy array of the same length. An empty input is returned as
        an empty array, and a constant input (max == min) is returned as all
        zeros instead of dividing by zero and producing NaN.
    """
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # Every episode has the same reward: there is no range to normalize over.
        return np.zeros_like(values)
    return (values - lowest) / span


# min max scaling for rewards (each configuration is scaled independently,
# so the curves compare shape rather than absolute reward magnitude)
balanced_rewards = _min_max_scale(balanced_rewards)
efficiency_rewards = _min_max_scale(efficiency_rewards)
high_accuracy_rewards = _min_max_scale(high_accuracy_rewards)

# Plotting: normalized rewards (left) and accuracy (right) side by side
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(balanced_rewards, label="Balanced")
plt.plot(efficiency_rewards, label="Efficiency Focused")
plt.plot(high_accuracy_rewards, label="High Accuracy")
plt.title("Normalized Rewards")
plt.xlabel("Episode")
plt.ylabel("Normalized Reward")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(balanced_accuracy, label="Balanced")
plt.plot(efficiency_accuracy, label="Efficiency Focused")
plt.plot(high_accuracy_accuracy, label="High Accuracy")
plt.title("Accuracy")
plt.xlabel("Episode")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.tight_layout()
plt.show()

# Separate figure: LLM calls recorded per episode for each configuration
plt.figure(figsize=(8, 6))
plt.plot(balanced_num_calls, label="Balanced")
plt.plot(efficiency_num_calls, label="Efficiency Focused")
plt.plot(high_accuracy_num_calls, label="High Accuracy")
plt.title("Number of LLM Calls")
plt.xlabel("Episode")
plt.ylabel("Total LLM Calls")
plt.legend()
plt.show()

In [None]:
# Report, for every configuration, the LLM-call count stored for the last
# recorded episode (the final entry of its "num_calls" list).
balanced_num_calls, efficiency_num_calls, high_accuracy_num_calls = [
    accuracy_rewards_numcalls[name]["num_calls"][-1]
    for name in ("Balanced", "Efficiency_Focused", "High_Accuracy")
]

for calls_label, calls_value in (
    ("balanced_num_calls", balanced_num_calls),
    ("efficiency_num_calls", efficiency_num_calls),
    ("high_accuracy_num_calls", high_accuracy_num_calls),
):
    print(calls_label, calls_value)

## Evaluate on the HotpotQA test set

In [None]:
# Evaluate the agent on the test data
def _is_unknown_answer(answer):
    """Return True when ``answer`` carries no usable prediction.

    ``answer`` is indexed the same way ``evaluate_agent`` indexes it: element 0
    is the answer text. ``None``, an empty sequence, a ``None`` first element,
    or text containing "unknown" (case-insensitive) all count as unanswered.
    """
    if answer is None or len(answer) == 0:
        return True
    first = answer[0]
    if first is None:
        return True
    return "unknown" in str(first).lower()


def evaluate_agent(agent, data, q2gold):
    """Run the greedy policy of ``agent`` over every root question in ``data``.

    For each root node (a node without an ``"fa"`` key) the agent picks an
    action with ``epsilon=0``:

    * 0 -- closed-book answer via ``get_cb_answer``
    * 1 -- single-hop open-book answer via ``get_singlehop_ob_answer``
    * 2 -- aggregate the children via ``child_aggregation``
    * 3 -- resample the decomposition tree, then aggregate the children

    If the resulting answer is missing or contains "unknown", the remaining
    actions are tried in decreasing Q-value order until one produces an answer
    or no action is left.

    Args:
        agent: Object exposing ``q_network`` (called on the state to obtain
            Q-values) and ``select_action(state, epsilon=...)``.
        data: Iterable of examples; each example is an iterable of node dicts
            with at least ``"question_text"`` and, for roots, ``"idx"``.
        q2gold: Mapping from a root question text to a
            ``(gold_answer, dataset_used)`` pair.

    Returns:
        ``(results, correct_answers, total_parent_nodes)`` where ``results`` is
        a list with one dict per root question (keys ``idx``, ``question``,
        ``answer``, ``gold``, ``tree-decomposition``, ``method``). ``answer``
        is ``None`` when no action produced an answer. The accuracy is also
        printed.
    """
    agent.q_network.eval()  # Set the model to evaluation mode
    correct_answers = 0
    total_parent_nodes = 0
    results = []
    masked = -float('inf')

    for example in data:
        for node in example:
            if "fa" in node:
                # Only process root nodes (original questions)
                continue

            # Keep the original question: resampling may replace ``node``.
            original_node_question = node["question_text"]
            gold_answer, dataset_used = q2gold[original_node_question]
            total_parent_nodes += 1
            state = get_state(node)
            action = agent.select_action(state, epsilon=0)
            done_resampling = False

            if action == 3:  # ResampleTreeDecompositionThenSolve
                example = tree_resampling_pipeline.resample_tree(original_node_question)
                # Fetch the root again: the resampled tree may order nodes differently.
                node = next((candidate for candidate in example if "fa" not in candidate), node)
                done_resampling = True
                action = 2  # Solve the resampled tree through child aggregation

            chosen_answer = None
            fallback_used = False
            fallback_action = None

            # Simulate the action
            if action == 0:
                node['cb_answer'] = get_cb_answer(node["question_text"].strip(), dataset_used)
                chosen_answer = node.get("cb_answer", [None])
            elif action == 1:
                # The root has no references, so the topic entities are empty.
                node['ob_answer'] = get_singlehop_ob_answer(node["question_text"].strip(), [], dataset_used)
                chosen_answer = node.get("ob_answer", [None])
            elif action == 2:
                _, chosen_answer = child_aggregation(example, dataset_used)

            if done_resampling:
                # Report (and mask) the action the policy originally picked.
                action = 3

            # Private copy of the Q-values used to rank fallback actions.
            # no_grad: evaluation never backpropagates, so skip the autograd graph.
            with torch.no_grad():
                q_values = agent.q_network(state).clone()

            while _is_unknown_answer(chosen_answer):
                fallback_used = True

                # Never retry the originally chosen action.
                q_values[action] = masked

                # state[150] == 0 means "no children" according to the original
                # comment (has_children flag); Child and Resample are then invalid.
                if state[150] == 0:
                    q_values[2] = masked  # Child aggregation
                    q_values[3] = masked  # ResampleTreeDecomposition

                if torch.all(q_values == masked):
                    print("All actions are invalid, skipping fallback")
                    break

                # Select the next best action and mask it so it is tried only once.
                fallback_action = torch.argmax(q_values).item()
                q_values[fallback_action] = masked

                # Execute the fallback action
                if fallback_action == 0:  # CB
                    chosen_answer = get_cb_answer(node["question_text"].strip(), dataset_used)
                elif fallback_action == 1:  # OB
                    chosen_answer = get_singlehop_ob_answer(node["question_text"].strip(), [], dataset_used)
                elif fallback_action == 2:  # Child
                    _, chosen_answer = child_aggregation(example, dataset_used)
                elif fallback_action == 3:  # Resample
                    example = tree_resampling_pipeline.resample_tree(node["question_text"].strip())
                    _, chosen_answer = child_aggregation(example, dataset_used)

            # The loop may exit without any answer at all; record None instead
            # of crashing on ``chosen_answer[0]``.
            if chosen_answer is not None and len(chosen_answer) > 0:
                final_answer = chosen_answer[0]
            else:
                final_answer = None

            if final_answer is not None and are_answers_equivalent_using_llm(gold_answer, final_answer):
                correct_answers += 1

            # Method actually responsible for the recorded answer: the last
            # fallback tried when one was used, otherwise the policy's action.
            if fallback_used and fallback_action is not None:
                method_action = fallback_action
            else:
                method_action = action

            # Store results
            results.append({
                "idx": node["idx"],
                "question": original_node_question,
                "answer": final_answer,
                "gold": gold_answer,
                "tree-decomposition": example,
                "method": action_space[method_action]
            })

    accuracy = (correct_answers / total_parent_nodes) * 100 if total_parent_nodes > 0 else 0
    print(f"Evaluation Accuracy: {accuracy:.2f}%")
    return results, correct_answers, total_parent_nodes


# Trained models to evaluate: one (alpha, beta, name) triple per reward
# configuration; ``name`` is the suffix used in the saved checkpoint files.
configurations = [
    {"alpha": alpha, "beta": beta, "name": name}
    for alpha, beta, name in (
        (2.0, 0.05, "High_Accuracy"),
        (1.0, 0.1, "Balanced"),
        (0.5, 0.2, "Efficiency_Focused"),
    )
]

# Per-configuration evaluation outputs, keyed by configuration name.
predictions_per_config = dict()
correct_answers_per_config = dict()
total_parent_nodes_per_config = dict()
accuracy_per_config = dict()

In [None]:

# Evaluate every trained configuration on the HotpotQA test split
for config in configurations:
    config_name = config["name"]
    agent_model_path = f"saved_models_with_resampling_trained_all/agent_{config_name}.pth"
    best_model_path = f"saved_models_with_resampling_trained_all/best_agent_{config_name}.pth"

    # Restore the agent saved at the end of training. It is loaded here but
    # only ``best_agent`` is passed to ``evaluate_agent`` below.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    final_weights = torch.load(agent_model_path)
    agent.q_network.load_state_dict(final_weights)
    agent.q_network.eval()  # inference mode

    # Restore the checkpoint saved as "best_agent" for this configuration.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_weights = torch.load(best_model_path)
    best_agent.q_network.load_state_dict(best_weights)
    best_agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Score the best checkpoint and keep its raw outputs per configuration
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        best_agent, test_data_hotpotqa, q2gold_test_hotpotqa
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; defined as 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:

# Evaluate every trained configuration on the 2Wiki test split
for config in configurations:
    config_name = config["name"]
    agent_model_path = f"saved_models_with_resampling_trained_all/agent_{config_name}.pth"
    best_model_path = f"saved_models_with_resampling_trained_all/best_agent_{config_name}.pth"

    # Restore the agent saved at the end of training. It is loaded here but
    # only ``best_agent`` is passed to ``evaluate_agent`` below.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    final_weights = torch.load(agent_model_path)
    agent.q_network.load_state_dict(final_weights)
    agent.q_network.eval()  # inference mode

    # Restore the checkpoint saved as "best_agent" for this configuration.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_weights = torch.load(best_model_path)
    best_agent.q_network.load_state_dict(best_weights)
    best_agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Score the best checkpoint and keep its raw outputs per configuration
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        best_agent, test_data_2wiki, q2gold_test_2wiki
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; defined as 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")


In [None]:

# Evaluate every trained configuration on the MuSiQue test split
for config in configurations:
    config_name = config["name"]
    agent_model_path = f"saved_models_with_resampling_trained_all/agent_{config_name}.pth"
    best_model_path = f"saved_models_with_resampling_trained_all/best_agent_{config_name}.pth"

    # Restore the agent saved at the end of training. It is loaded here but
    # only ``best_agent`` is passed to ``evaluate_agent`` below.
    agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    final_weights = torch.load(agent_model_path)
    agent.q_network.load_state_dict(final_weights)
    agent.q_network.eval()  # inference mode

    # Restore the checkpoint saved as "best_agent" for this configuration.
    best_agent = DQNAgent(state_dim, action_dim, hidden_dim, lr, gamma)
    best_weights = torch.load(best_model_path)
    best_agent.q_network.load_state_dict(best_weights)
    best_agent.q_network.eval()  # inference mode

    print(f"\nEvaluating model: {config_name}")

    # Score the best checkpoint and keep its raw outputs per configuration
    rl_predictions, correct_answers, total_parent_nodes = evaluate_agent(
        best_agent, test_data_musique, q2gold_test_musique
    )
    predictions_per_config[config_name] = rl_predictions
    correct_answers_per_config[config_name] = correct_answers
    total_parent_nodes_per_config[config_name] = total_parent_nodes

    # Accuracy in percent; defined as 0 when no root question was evaluated
    if total_parent_nodes > 0:
        accuracy = (correct_answers / total_parent_nodes) * 100
    else:
        accuracy = 0
    accuracy_per_config[config_name] = accuracy
    print(f"Final Accuracy for {config_name}: {accuracy:.2f}%\n")
