In [1]:
from environment import root_dir
from constants import get_result_dir, supported_llms
import pandas as pd
import os
import numpy as np

In [2]:
replicate_ids = [0, 1]
model_name = "xai/grok-3-mini-beta"
# model_name = "openrouter-microsoft/phi-4-reasoning-plus"
dataset_name = "zebralogic"
n_samples = 100
# data_mode = "instructiond"
data_mode = "default"

result_dirs = []
for replicate_id in replicate_ids:
    result_dir = get_result_dir(
        dataset_name,
        model_name,
        shot = 0,
        template_type = supported_llms[model_name]["template_type"],
        response_length = 404,
        num_samples = 100,
        feature_noise = None,
        label_noise = 0.0,
        data_mode = data_mode,
        n_query = 1,
        temperature = 1.0,
        replicate_id = replicate_id,
    )
    result_dirs.append(result_dir)

In [None]:


result_dfs = []
for result_dir in result_dirs:
    result_df = pd.read_csv(os.path.join(result_dir, "tree_vis_google/gemini-2.5-pro-preview-03-25", "metric_df.csv"))
    result_df.rename(columns={result_df.columns[0]: "index"}, inplace=True)
    result_dfs.append(result_df)

corrs_list = []
filtered_ajd_list = []
forgetting_rates_list = []
average_solution_count_list = []
success_rates_list = []
overthinking_rates_list = []
average_verification_rates_list = []

for i in range(n_samples):
    selected_replicate_id = None
    selected_filtered_ajd = 100
    
    for idx, df in enumerate(result_dfs):
        filtered_ajd_value = df[df.index == i]["filtered_ajd"]
        if filtered_ajd_value.empty: continue
        if filtered_ajd_value.values[0] < selected_filtered_ajd:
            selected_replicate_id = idx
            selected_filtered_ajd = filtered_ajd_value.values[0]
    
    if selected_replicate_id is None: continue
    selected_result_df = result_dfs[selected_replicate_id]
    
    corrs_list.append(selected_result_df[selected_result_df.index == i]["corrs"].values[0])
    filtered_ajd_list.append(selected_result_df[selected_result_df.index == i]["filtered_ajd"].values[0])
    forgetting_rates_list.append(selected_result_df[selected_result_df.index == i]["forgetting_rates"].values[0])
    average_solution_count_list.append(selected_result_df[selected_result_df.index == i]["average_solution_count"].values[0])
    success_rates_list.append(selected_result_df[selected_result_df.index == i]["success_rates"].values[0])
    overthinking_rates_list.append(selected_result_df[selected_result_df.index == i]["overthinking_rates"].values[0])
    average_verification_rates_list.append(selected_result_df[selected_result_df.index == i]["average_verification_rates"].values[0])
    
print("ensemble result:")
print("accuracy: ", f"{np.mean(corrs_list):.3f}", "±", f"{np.std(corrs_list):.3f}")
print("filtered_ajd: ", f"{np.mean(filtered_ajd_list):.3f}", "±", f"{np.std(filtered_ajd_list):.3f}")
print("forgetting_rates: ", f"{np.mean(forgetting_rates_list):.3f}", "±", f"{np.std(forgetting_rates_list):.3f}")
print("average_solution_count: ", f"{np.mean(average_solution_count_list):.3f}", "±", f"{np.std(average_solution_count_list):.3f}")
print("success_rates: ", f"{np.mean(success_rates_list):.3f}", "±", f"{np.std(success_rates_list):.3f}")
print("overthinking_rates: ", f"{np.mean(overthinking_rates_list):.3f}", "±", f"{np.std(overthinking_rates_list):.3f}")
print("average_verification_rates: ", f"{np.mean(average_verification_rates_list):.3f}", "±", f"{np.std(average_verification_rates_list):.3f}")


ensemble result:
accuracy:  0.290 ± 0.454
filtered_ajd:  17.805 ± 10.242
forgetting_rates:  0.000 ± 0.000
average_solution_count:  3.800 ± 2.098
success_rates:  0.174 ± 0.244
overthinking_rates:  0.005 ± 0.050
average_verification_rates:  0.065 ± 0.042


In [4]:
result_dfs = []
for result_dir in result_dirs:
    result_df = pd.read_csv(os.path.join(result_dir, "tree_vis_google/gemini-2.5-pro-preview-03-25", "metric_df.csv"))
    result_df.rename(columns={result_df.columns[0]: "index"}, inplace=True)
    result_dfs.append(result_df)

# Initialize lists to store metrics for each sample
sample_corrs_list = []
sample_filtered_ajd_list = []
sample_forgetting_rates_list = []
sample_average_solution_count_list = []
sample_success_rates_list = []
sample_overthinking_rates_list = []
sample_average_verification_rates_list = []

# Loop through each sample index
for i in range(n_samples):
    sample_corrs = 0
    sample_filtered_ajd = 0
    sample_forgetting_rates = 0
    sample_average_solution_count = 0
    sample_success_rates = 0
    sample_overthinking_rates = 0
    sample_average_verification_rates = 0
    valid_replicates = 0
    
    # Collect data from each replicate (result_df) for the current sample
    for df in result_dfs:
        if i in df['index'].values:
            valid_replicates += 1
            sample_corrs += df[df['index'] == i]["corrs"].values[0]
            sample_filtered_ajd += df[df['index'] == i]["filtered_ajd"].values[0]
            sample_forgetting_rates += df[df['index'] == i]["forgetting_rates"].values[0]
            sample_average_solution_count += df[df['index'] == i]["average_solution_count"].values[0]
            sample_success_rates += df[df['index'] == i]["success_rates"].values[0]
            sample_overthinking_rates += df[df['index'] == i]["overthinking_rates"].values[0]
            sample_average_verification_rates += df[df['index'] == i]["average_verification_rates"].values[0]
    
    # If there are valid replicates for this sample, compute the average and add to lists
    if valid_replicates > 0:
        sample_corrs_list.append(sample_corrs / valid_replicates)
        sample_filtered_ajd_list.append(sample_filtered_ajd / valid_replicates)
        sample_forgetting_rates_list.append(sample_forgetting_rates / valid_replicates)
        sample_average_solution_count_list.append(sample_average_solution_count / valid_replicates)
        sample_success_rates_list.append(sample_success_rates / valid_replicates)
        sample_overthinking_rates_list.append(sample_overthinking_rates / valid_replicates)
        sample_average_verification_rates_list.append(sample_average_verification_rates / valid_replicates)

print("Average with std")
# Compute and print the overall averages and standard deviations across all samples
if len(sample_corrs_list) > 0:
    print("accuracy: ", f"{np.mean(sample_corrs_list):.3f}", "±", f"{np.std(sample_corrs_list):.3f}")
    print("filtered_ajd: ", f"{np.mean(sample_filtered_ajd_list):.3f}", "±", f"{np.std(sample_filtered_ajd_list):.3f}")
    print("forgetting_rates: ", f"{np.mean(sample_forgetting_rates_list):.3f}", "±", f"{np.std(sample_forgetting_rates_list):.3f}")
    print("average_solution_count: ", f"{np.mean(sample_average_solution_count_list):.3f}", "±", f"{np.std(sample_average_solution_count_list):.3f}")
    print("success_rates: ", f"{np.mean(sample_success_rates_list):.3f}", "±", f"{np.std(sample_success_rates_list):.3f}")
    print("overthinking_rates: ", f"{np.mean(sample_overthinking_rates_list):.3f}", "±", f"{np.std(sample_overthinking_rates_list):.3f}")
    print("average_verification_rates: ", f"{np.mean(sample_average_verification_rates_list):.3f}", "±", f"{np.std(sample_average_verification_rates_list):.3f}")
else:
    print("No valid samples found for averaging.")


Average with std
accuracy:  0.330 ± 0.318
filtered_ajd:  11.679 ± 7.302
forgetting_rates:  0.000 ± 0.000
average_solution_count:  3.745 ± 2.322
success_rates:  0.322 ± 0.271
overthinking_rates:  0.003 ± 0.025
average_verification_rates:  0.059 ± 0.034


In [5]:
result_dfs = []
for result_dir in result_dirs:
    result_df = pd.read_csv(os.path.join(result_dir, "tree_vis_google/gemini-2.5-pro-preview-03-25", "metric_df.csv"))
    result_df.rename(columns={result_df.columns[0]: "index"}, inplace=True)
    result_dfs.append(result_df)

# Initialize lists to store metrics for each sample
sample_corrs_list = []
sample_filtered_ajd_list = []
sample_forgetting_rates_list = []
sample_average_solution_count_list = []
sample_success_rates_list = []
sample_overthinking_rates_list = []
sample_average_verification_rates_list = []

# Initialize lists for majority vote ensemble
majority_vote_corrs_list = []
majority_vote_filtered_ajd_list = []
majority_vote_forgetting_rates_list = []
majority_vote_average_solution_count_list = []
majority_vote_success_rates_list = []
majority_vote_overthinking_rates_list = []
majority_vote_average_verification_rates_list = []

# Loop through each sample index
for i in range(n_samples):
    sample_corrs = 0
    sample_filtered_ajd = 0
    sample_forgetting_rates = 0
    sample_average_solution_count = 0
    sample_success_rates = 0
    sample_overthinking_rates = 0
    sample_average_verification_rates = 0
    valid_replicates = 0
    
    # For majority vote ensemble
    answers_for_sample = []
    corrs_for_sample = []
    filtered_ajd_for_sample = []
    forgetting_rates_for_sample = []
    average_solution_count_for_sample = []
    success_rates_for_sample = []
    overthinking_rates_for_sample = []
    average_verification_rates_for_sample = []
    
    # Collect data from each replicate (result_df) for the current sample
    for df in result_dfs:
        if i in df['index'].values:
            valid_replicates += 1
            sample_corrs += df[df['index'] == i]["corrs"].values[0]
            sample_filtered_ajd += df[df['index'] == i]["filtered_ajd"].values[0]
            sample_forgetting_rates += df[df['index'] == i]["forgetting_rates"].values[0]
            sample_average_solution_count += df[df['index'] == i]["average_solution_count"].values[0]
            sample_success_rates += df[df['index'] == i]["success_rates"].values[0]
            sample_overthinking_rates += df[df['index'] == i]["overthinking_rates"].values[0]
            sample_average_verification_rates += df[df['index'] == i]["average_verification_rates"].values[0]
            
            # Collect answers and all metrics for majority vote
            answer = df[df['index'] == i]["answers"].values[0]
            corr = df[df['index'] == i]["corrs"].values[0]
            filtered_ajd = df[df['index'] == i]["filtered_ajd"].values[0]
            forgetting_rates = df[df['index'] == i]["forgetting_rates"].values[0]
            average_solution_count = df[df['index'] == i]["average_solution_count"].values[0]
            success_rates = df[df['index'] == i]["success_rates"].values[0]
            overthinking_rates = df[df['index'] == i]["overthinking_rates"].values[0]
            average_verification_rates = df[df['index'] == i]["average_verification_rates"].values[0]
            
            answers_for_sample.append(answer)
            corrs_for_sample.append(corr)
            filtered_ajd_for_sample.append(filtered_ajd)
            forgetting_rates_for_sample.append(forgetting_rates)
            average_solution_count_for_sample.append(average_solution_count)
            success_rates_for_sample.append(success_rates)
            overthinking_rates_for_sample.append(overthinking_rates)
            average_verification_rates_for_sample.append(average_verification_rates)
    
    # If there are valid replicates for this sample, compute the average and add to lists
    if valid_replicates > 0:
        sample_corrs_list.append(sample_corrs / valid_replicates)
        sample_filtered_ajd_list.append(sample_filtered_ajd / valid_replicates)
        sample_forgetting_rates_list.append(sample_forgetting_rates / valid_replicates)
        sample_average_solution_count_list.append(sample_average_solution_count / valid_replicates)
        sample_success_rates_list.append(sample_success_rates / valid_replicates)
        sample_overthinking_rates_list.append(sample_overthinking_rates / valid_replicates)
        sample_average_verification_rates_list.append(sample_average_verification_rates / valid_replicates)
        
        # Majority vote ensemble
        from collections import Counter
        answer_counts = Counter(answers_for_sample)
        most_common_answer, max_count = answer_counts.most_common(1)[0]
        
        # If there are at least 2 of the same answer, use that answer's metrics
        if max_count >= 2:
            # Find the metrics corresponding to the most common answer
            for j, answer in enumerate(answers_for_sample):
                if answer == most_common_answer:
                    majority_vote_corrs_list.append(corrs_for_sample[j])
                    majority_vote_filtered_ajd_list.append(filtered_ajd_for_sample[j])
                    majority_vote_forgetting_rates_list.append(forgetting_rates_for_sample[j])
                    majority_vote_average_solution_count_list.append(average_solution_count_for_sample[j])
                    majority_vote_success_rates_list.append(success_rates_for_sample[j])
                    majority_vote_overthinking_rates_list.append(overthinking_rates_for_sample[j])
                    majority_vote_average_verification_rates_list.append(average_verification_rates_for_sample[j])
                    break
        else:
            # If no majority, use the first answer's metrics
            majority_vote_corrs_list.append(corrs_for_sample[0])
            majority_vote_filtered_ajd_list.append(filtered_ajd_for_sample[0])
            majority_vote_forgetting_rates_list.append(forgetting_rates_for_sample[0])
            majority_vote_average_solution_count_list.append(average_solution_count_for_sample[0])
            majority_vote_success_rates_list.append(success_rates_for_sample[0])
            majority_vote_overthinking_rates_list.append(overthinking_rates_for_sample[0])
            majority_vote_average_verification_rates_list.append(average_verification_rates_for_sample[0])

print("Average with std")
# Compute and print the overall averages and standard deviations across all samples
if len(sample_corrs_list) > 0:
    print("accuracy: ", f"{np.mean(sample_corrs_list):.3f}", "±", f"{np.std(sample_corrs_list):.3f}")
    print("filtered_ajd: ", f"{np.mean(sample_filtered_ajd_list):.3f}", "±", f"{np.std(sample_filtered_ajd_list):.3f}")
    print("forgetting_rates: ", f"{np.mean(sample_forgetting_rates_list):.3f}", "±", f"{np.std(sample_forgetting_rates_list):.3f}")
    print("average_solution_count: ", f"{np.mean(sample_average_solution_count_list):.3f}", "±", f"{np.std(sample_average_solution_count_list):.3f}")
    print("success_rates: ", f"{np.mean(sample_success_rates_list):.3f}", "±", f"{np.std(sample_success_rates_list):.3f}")
    print("overthinking_rates: ", f"{np.mean(sample_overthinking_rates_list):.3f}", "±", f"{np.std(sample_overthinking_rates_list):.3f}")
    print("average_verification_rates: ", f"{np.mean(sample_average_verification_rates_list):.3f}", "±", f"{np.std(sample_average_verification_rates_list):.3f}")
else:
    print("No valid samples found for averaging.")

print("\nMajority Vote Ensemble")
if len(majority_vote_corrs_list) > 0:
    print("accuracy: ", f"{np.mean(majority_vote_corrs_list):.3f}", "±", f"{np.std(majority_vote_corrs_list):.3f}")
    print("filtered_ajd: ", f"{np.mean(majority_vote_filtered_ajd_list):.3f}", "±", f"{np.std(majority_vote_filtered_ajd_list):.3f}")
    print("forgetting_rates: ", f"{np.mean(majority_vote_forgetting_rates_list):.3f}", "±", f"{np.std(majority_vote_forgetting_rates_list):.3f}")
    print("average_solution_count: ", f"{np.mean(majority_vote_average_solution_count_list):.3f}", "±", f"{np.std(majority_vote_average_solution_count_list):.3f}")
    print("success_rates: ", f"{np.mean(majority_vote_success_rates_list):.3f}", "±", f"{np.std(majority_vote_success_rates_list):.3f}")
    print("overthinking_rates: ", f"{np.mean(majority_vote_overthinking_rates_list):.3f}", "±", f"{np.std(majority_vote_overthinking_rates_list):.3f}")
    print("average_verification_rates: ", f"{np.mean(majority_vote_average_verification_rates_list):.3f}", "±", f"{np.std(majority_vote_average_verification_rates_list):.3f}")
else:
    print("No valid samples found for majority vote ensemble.")


Average with std
accuracy:  0.330 ± 0.318
filtered_ajd:  11.679 ± 7.302
forgetting_rates:  0.000 ± 0.000
average_solution_count:  3.745 ± 2.322
success_rates:  0.322 ± 0.271
overthinking_rates:  0.003 ± 0.025
average_verification_rates:  0.059 ± 0.034

Majority Vote Ensemble
accuracy:  0.320 ± 0.466
filtered_ajd:  12.727 ± 10.949
forgetting_rates:  0.000 ± 0.000
average_solution_count:  3.510 ± 2.820
success_rates:  0.308 ± 0.380
overthinking_rates:  0.005 ± 0.050
average_verification_rates:  0.056 ± 0.044
