In [1]:
import json
import os
import random
import warnings

import numpy as np
import pandas as pd
from prettytable import PrettyTable

# Read the jsonl file and convert it to a JSON list
def jsonl_to_json_list(jsonl_file_path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line is parsed independently with ``json.loads``. Blank or
    whitespace-only lines (e.g. an extra empty line at the end of the file) are
    skipped instead of raising ``json.JSONDecodeError``.

    Args:
        jsonl_file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list: The decoded objects, in file order.
    """
    json_list = []
    with open(jsonl_file_path, 'r', encoding='utf-8') as file:
        for line in file:
            stripped = line.strip()
            if not stripped:
                continue  # json.loads('') would raise; tolerate empty lines
            json_list.append(json.loads(stripped))  # Parse each line as JSON
    return json_list

# Save the JSON list to a file
def save_as_json(json_list, output_file_path):
    """Serialize ``json_list`` to ``output_file_path`` as JSON indented by 4 spaces.

    The file is opened in text mode with UTF-8 encoding and overwritten if it exists.
    """
    with open(output_file_path, mode='w', encoding='utf-8') as handle:
        json.dump(json_list, handle, indent=4)

def save_as_jsonl(json_list, output_file_path):
    """Write each element of ``json_list`` as one compact JSON object per line.

    Every record, including the last, is terminated by a newline. The target file
    is opened with UTF-8 encoding and overwritten.
    """
    with open(output_file_path, mode='w', encoding='utf-8') as handle:
        for record in json_list:
            handle.write(json.dumps(record) + '\n')

def load_json(file_path):
    """Parse the UTF-8 JSON document at ``file_path`` and return the decoded object."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)

def load_jsonl(file_path):
    """Load a JSON Lines file into a list of decoded records.

    Blank or whitespace-only lines are skipped; previously they made
    ``json.loads`` raise ``JSONDecodeError``. Surrounding whitespace on a line is
    accepted by ``json.loads`` itself, so lines are not stripped before parsing.

    Args:
        file_path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        list: The decoded objects, in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return [json.loads(line) for line in file if line.strip()]

def deduplicate_data(data):
    """Drop repeated items, keeping only the first item seen for each ``realidx``.

    The relative order of the kept items is preserved, and the returned list
    contains the original item objects (no copies).
    """
    # dicts keep insertion order, and setdefault never replaces an existing entry,
    # so each key maps to the first item that carried it.
    first_by_idx = {}
    for record in data:
        first_by_idx.setdefault(record['realidx'], record)
    return list(first_by_idx.values())

def calculate_accuracy(data):
    """Return the fraction of items whose ``predicted_answer`` equals ``answer_idx``.

    Items without a ``predicted_answer`` key have their ``realidx`` printed and are
    counted as incorrect. Previously they raised ``KeyError`` immediately after
    being printed.

    Args:
        data: Sized sequence of dicts with ``answer_idx`` and ``realidx`` keys and,
            normally, a ``predicted_answer`` key.

    Returns:
        Accuracy in [0, 1]; 0 when ``data`` is empty.
    """
    total_predictions = len(data)
    correct_predictions = 0
    for item in data:
        if 'predicted_answer' not in item:
            # Missing prediction: report it and score it as wrong instead of crashing.
            print(item['realidx'])
            continue
        if item['answer_idx'] == item['predicted_answer']:
            correct_predictions += 1
    return correct_predictions / total_predictions if total_predictions > 0 else 0

# Price per 1M tokens as (prompt, completion) for items that carry no explicit ``cost``
# (presumably USD — confirm against the provider price lists).
_PRICE_PER_MILLION_TOKENS = {
    'gpt-4o-mini': (0.15, 0.6),
    'gpt-4o': (2.5, 10),
    'o3-mini': (1.1, 4.4),
    'o1-mini': (1.1, 4.4),
    'claude-3-5-sonnet': (3.0, 15.0),
    'claude-3-5-haiku': (0.8, 4.0),
    'QwQ-32B-Preview': (1.2, 1.2),
    'DeepSeek-R1': (7, 7),
    'DeepSeek-V3': (1.25, 1.25),
    'Llama-3.3-70B-Instruct-Turbo': (0.88, 0.88),
}


def calculate_cost_from_token_usage(data, model):
    """Return the average per-item cost of ``data`` for ``model``.

    An item's own ``cost`` field takes precedence. Otherwise its
    ``token_usage['prompt_tokens']`` and ``token_usage['completion_tokens']`` are
    priced with ``_PRICE_PER_MILLION_TOKENS[model]``. If ``model`` has no price
    entry, such items contribute 0 (unchanged behaviour), but a warning now
    reports how many were affected.

    Args:
        data: Sized sequence of result dicts.
        model: Model key, e.g. ``'gpt-4o'``.

    Returns:
        Mean cost per item; 0.0 when ``data`` is empty (previously ZeroDivisionError).
    """
    if not data:
        return 0.0
    rates = _PRICE_PER_MILLION_TOKENS.get(model)
    total_cost = 0
    unpriced = 0
    for item in data:
        if 'cost' in item:
            total_cost += item['cost']
        elif rates is None:
            unpriced += 1
        else:
            prompt_rate, completion_rate = rates
            usage = item['token_usage']
            total_cost += usage['prompt_tokens'] * prompt_rate / 1000000 + usage['completion_tokens'] * completion_rate / 1000000
    if unpriced:
        warnings.warn(f"No price entry for model {model!r}; {unpriced} item(s) without 'cost' were counted as 0.")
    return total_cost / len(data)

def calculate_time_from_data(data):
    """Return the mean of the ``time_elapsed`` field over ``data``.

    Returns 0 for empty ``data`` (matching ``calculate_accuracy``) instead of
    raising ``ZeroDivisionError``.
    """
    if not data:
        return 0
    return sum(item['time_elapsed'] for item in data) / len(data)

In [5]:
tasks = ['medqa', 'pubmedqa', 'medmcqa', 'medbullets', 'mmlu', 'mmlu-pro','medexqa', 'medxpertqa-r', 'medxpertqa-u']
task_map = {
    'medqa': 'MedQA',
    'pubmedqa': 'PubMedQA',
    'medmcqa': 'MedMCQA',
    'medbullets': 'MedBullets',
    'mmlu': 'MMLU',
    'mmlu-pro': 'MMLU-Pro',
    'afrimedqa': 'AfriMedQA',
    'medexqa': 'MedExQA',
    'medxpertqa-r': 'MedXpertQA-R',
    'medxpertqa-u': 'MedXpertQA-U',
}
models = [
    'gpt-4o-mini',
    'gpt-4o',
    'o3-mini',
]
model_map = {
    'gpt-4o-mini': 'GPT-4o-mini',
    'gpt-4o': 'GPT-4o',
    'o3-mini': 'o3-mini',
}
method_map = {
    'zero_shot': 'Zero-shot',
    'cot': 'CoT',
    'cot_sc-5': 'CoT-SC',
    'multipersona-2': 'MultiPersona',
    'self_refine-3': 'Self-Refine',
    'medprompt-3': 'MedPrompt',
    'medagents': 'MedAgents',
    'mdagents': 'MDAgents',
    'spo': 'SPO',
    'aflow': 'AFlow'
}
methods = ['zero_shot', 'cot', 'cot_sc-5', 'multipersona-2', 'self_refine-3', 'medprompt-3', 'medagents', 'mdagents', 'spo', 'aflow']

# Number of repeated runs; results are read from ../output/run-{run}/ for run in range(NUM_RUNS).
NUM_RUNS = 3
# Abbreviated per-model labels for the second header row (falls back to the raw model name).
model_short_labels = {
    'gpt-4o-mini': '\\footnotesize{\\textsc{4o-m}}',
    'gpt-4o': '\\textsc{4o}',
    'o3-mini': '\\textsc{o3-m}',
}


def _load_run_accuracy(task, model, method, run):
    """Accuracy (%) of `method` with `model` on the Hard split of `task` in run `run`.

    Returns None when the result file does not exist (an expected gap). Any other
    failure (malformed JSON, missing keys, ...) is skipped as well, but with a
    warning so that broken result files no longer disappear silently.
    """
    file_path = f"../output/run-{run}/{task}/{model}-{task}-test_hard-{method}.json"
    try:
        data = load_json(file_path)
        return calculate_accuracy(deduplicate_data(data)) * 100  # Convert to percentage.
    except FileNotFoundError:
        return None
    except Exception as exc:
        warnings.warn(f"Skipping {file_path}: {exc!r}")
        return None


def _mean_std(values):
    """Mean and sample std (ddof=1, or 0.0 for a single value) of `values`; (None, None) if empty."""
    if not values:
        return None, None
    mean = np.mean(values)
    std = np.std(values, ddof=1) if len(values) > 1 else 0.0
    return mean, std


def _rank_stats(values):
    """(best, second best strictly below best, min, max) of `values`; all None if empty."""
    if not values:
        return None, None, None, None
    best = max(values)
    lower = [v for v in values if v < best]
    return best, (max(lower) if lower else None), min(values), best


def _mean_std_heat_cell(mean_val, std_val, stats):
    """LaTeX cell "mean ± std", shaded and marked relative to `stats` = (best, second, min, max)."""
    best, second, lo, hi = stats
    if mean_val is None or lo is None or hi is None:
        return "N/A"
    norm = (mean_val - lo) / (hi - lo) if (hi - lo) != 0 else 1
    # Linear blend from white (#FFFFFF) at the group minimum to #EE9CA7 at the group maximum.
    r = int(round(255 + norm * (238 - 255)))
    g = int(round(255 + norm * (156 - 255)))
    b = int(round(255 + norm * (167 - 255)))
    text = f"{mean_val:.1f} $\\pm$ {std_val:.1f}"
    if mean_val == best:
        text = "\\textbf{" + text + "}"
    elif second is not None and mean_val == second:
        text = "\\underline{" + text + "}"
    return f"\\cellcolor[HTML]{{{r:02X}{g:02X}{b:02X}}}" + text


num_tasks = len(tasks)
num_models = len(models)
num_cols = num_tasks * num_models

# Step 1: read every (method, task, model, run) result file exactly once.
run_acc = {
    (method, task, model, run): _load_run_accuracy(task, model, method, run)
    for method in methods
    for task in tasks
    for model in models
    for run in range(NUM_RUNS)
}

# Step 2: one row per method. Each cell is ((mean, std), formatted_string): first one cell
# per (task, model), then one "Average" cell per model. The average is taken per run
# (mean over the tasks available in that run) and then summarised as mean ± std across runs.
table_data = []
for method in methods:
    row = []
    for task in tasks:
        for model in models:
            accs = [run_acc[(method, task, model, run)] for run in range(NUM_RUNS)]
            mean_acc, std_acc = _mean_std([a for a in accs if a is not None])
            cell_str = "N/A" if mean_acc is None else f"{mean_acc:.1f} $\\pm$ {std_acc:.1f}"
            row.append(((mean_acc, std_acc), cell_str))
    for model in models:
        per_run_avgs = []
        for run in range(NUM_RUNS):
            run_accs = [run_acc[(method, task, model, run)] for task in tasks]
            run_accs = [a for a in run_accs if a is not None]
            if run_accs:
                per_run_avgs.append(np.mean(run_accs))
        avg_val, std_val = _mean_std(per_run_avgs)
        cell_str = "N/A" if avg_val is None else f"{avg_val:.1f} $\\pm$ {std_val:.1f}"
        row.append(((avg_val, std_val), cell_str))
    table_data.append(row)

# Step 3: best / second best / min / max per dataset group (all methods x models of one task)
# and over the Average columns; these drive the bold/underline marks and the heatmap scale.
group_stats = [
    _rank_stats([row[j][0][0] for row in table_data
                 for j in range(i * num_models, (i + 1) * num_models)
                 if row[j][0][0] is not None])
    for i in range(num_tasks)
]
avg_stats = _rank_stats([row[j][0][0] for row in table_data
                         for j in range(num_cols, num_cols + num_models)
                         if row[j][0][0] is not None])

# Step 4: print the LaTeX table.
print("\\begin{table*}[t]")
print("  \\centering")
print("  \\small")
print("  \\setlength{\\tabcolsep}{2.2mm}{")
print("  \\scalebox{0.47}{")
print(f"  \\begin{{tabular}}{{l|{'c|' * (num_tasks * num_models)}{'c|' * num_models}}}")
print("    \\toprule")
header1 = ["\\textbf{Method}"]
for task in tasks:
    header1.append(f"\\multicolumn{{{num_models}}}{{c|}}{{\\textbf{{{task_map.get(task, task)}}}}}")
header1.append(f"\\multicolumn{{{num_models}}}{{c|}}{{\\textbf{{Average}}}}")
print("    " + " & ".join(header1) + " \\\\")
print("    \\midrule")
header2 = [""] + [model_short_labels.get(model, model) for model in models] * (num_tasks + 1)
print("    " + " & ".join(header2) + " \\\\")
print("    \\midrule")

for row_idx, method in enumerate(methods):
    row_cells = ["\\textsc{" + method_map.get(method, method) + "}"]
    for j, ((mean_val, std_val), _) in enumerate(table_data[row_idx]):
        # Task columns are scaled within their dataset group; the trailing columns are the averages.
        stats = group_stats[j // num_models] if j < num_cols else avg_stats
        row_cells.append(_mean_std_heat_cell(mean_val, std_val, stats))
    print("    " + " & ".join(row_cells) + " \\\\")
    print("    \\midrule")
print("    \\bottomrule")
print("  \\end{tabular}")
print("  }")
print("}")
print("  \\caption{\\textbf{Performance heatmap by task and method.} All the tasks are evaluated on the \\textsc{Hard} set. For each task, three models are evaluated in order: \\textsc{GPT-4o-mini}, \\textsc{GPT-4o}, and \\textsc{o3-mini}. Accuracy values are in percentages, reported as mean $\\pm$ std over 3 runs. The best values are highlighted in \\textbf{bold}, and the second-best values are highlighted in \\underline{underlined} format.}")
print("  \\label{tab:method_task_heatmap}")
print("\\end{table*}")

\begin{table*}[t]
  \centering
  \small
  \setlength{\tabcolsep}{2.2mm}{
  \scalebox{0.47}{
  \begin{tabular}{l|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|}
    \toprule
    \textbf{Method} & \multicolumn{3}{c|}{\textbf{MedQA}} & \multicolumn{3}{c|}{\textbf{PubMedQA}} & \multicolumn{3}{c|}{\textbf{MedMCQA}} & \multicolumn{3}{c|}{\textbf{MedBullets}} & \multicolumn{3}{c|}{\textbf{MMLU}} & \multicolumn{3}{c|}{\textbf{MMLU-Pro}} & \multicolumn{3}{c|}{\textbf{MedExQA}} & \multicolumn{3}{c|}{\textbf{MedXpertQA-R}} & \multicolumn{3}{c|}{\textbf{MedXpertQA-U}} & \multicolumn{3}{c|}{\textbf{Average}} \\
    \midrule
     &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \textsc{o3-m} &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \textsc{o3-m} &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \textsc{o3-m} &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \textsc{o3-m} &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \textsc{o3-m} &  \footnotesize{\textsc{4o-m}} & \textsc{4o} & \tex

In [3]:
tasks = ['medqa', 'pubmedqa', 'medmcqa', 'medbullets', 'mmlu', 'mmlu-pro','medexqa', 'medxpertqa-r', 'medxpertqa-u']
task_map = {
    'medqa': 'MedQA',
    'pubmedqa': 'PubMedQA',
    'medmcqa': 'MedMCQA',
    'medbullets': 'MedBullets',
    'mmlu': 'MMLU',
    'mmlu-pro': 'MMLU-Pro',
    'afrimedqa': 'AfriMedQA',
    'medexqa': 'MedExQA',
    'medxpertqa-r': 'MedXpertQA-R',
    'medxpertqa-u': 'MedXpertQA-U',
}
models = [
    'gpt-4o-mini',
    'gpt-4o',
    'DeepSeek-V3',
]
model_map = {
    'gpt-4o-mini': 'GPT-4o-mini',
    'gpt-4o': 'GPT-4o',
    'DeepSeek-V3': 'DeepSeek-V3',
}
method_map = {
    'zero_shot': 'Zero-shot',
    'few_shot': 'Few-shot',
    'cot': 'CoT',
    'cot_sc-5': 'CoT-SC',
    'multipersona-2': 'MultiPersona',
    'self_refine-3': 'Self-Refine',
    'medprompt-3': 'MedPrompt',
    'medagents': 'MedAgents',
    'mdagents': 'MDAgents',
    'spo': 'SPO',
    'aflow': 'AFlow'
}
methods = ['zero_shot', 'few_shot', 'cot', 'cot_sc-5', 'multipersona-2', 'self_refine-3', 'medprompt-3', 'medagents', 'mdagents', 'spo', 'aflow']

# Build and print a LaTeX heatmap table with tasks as rows and methods as columns. Each task
# row is subdivided into one subrow per model, and results come from ../output/{task}/ (one run).
# Each accuracy cell is shaded from white (#FFFFFF, lowest value) to #EE9CA7 (highest value),
# relative to all values of its task (dataset) group. The best value in a group is bold and
# the second best is underlined.

# Abbreviated labels for the Model column (falls back to the raw model name).
model_short_labels = {
    'gpt-4o-mini': '\\footnotesize{\\textsc{4o-m}}',
    'gpt-4o': '\\textsc{4o}',
    'DeepSeek-V3': '\\footnotesize{\\textsc{ds-v3}}',
}


def _load_method_accuracy(task, model, method):
    """Accuracy (%) of `method` with `model` on the Hard split of `task`, or None if unavailable.

    A missing result file is an expected gap and yields None silently. Any other
    failure is skipped as well, but with a warning so broken files are noticed.
    """
    file_path = f"../output/{task}/{model}-{task}-test_hard-{method}.json"
    try:
        data = load_json(file_path)
        return calculate_accuracy(deduplicate_data(data)) * 100  # Convert to percentage.
    except FileNotFoundError:
        return None
    except Exception as exc:
        warnings.warn(f"Skipping {file_path}: {exc!r}")
        return None


def _value_heat_cell(num_val, stats):
    """LaTeX cell for `num_val`, shaded and marked relative to `stats` = (best, second, min, max)."""
    best, second, lo, hi = stats
    if num_val is None or lo is None or hi is None:
        return "N/A"
    norm = (num_val - lo) / (hi - lo) if (hi - lo) != 0 else 1
    # Linear blend from white (#FFFFFF) at the group minimum to #EE9CA7 at the group maximum.
    r = int(round(255 + norm * (238 - 255)))
    g = int(round(255 + norm * (156 - 255)))
    b = int(round(255 + norm * (167 - 255)))
    text = f"{num_val:.1f}"
    if num_val == best:
        text = "\\textbf{" + text + "}"
    elif second is not None and num_val == second:
        text = "\\underline{" + text + "}"
    return f"\\cellcolor[HTML]{{{r:02X}{g:02X}{b:02X}}}" + text


num_tasks = len(tasks)
num_models = len(models)
num_methods = len(methods)

# Step 1: one row per (task, model) in task-major order; one (numeric_value, formatted_string)
# cell per method.
table_data = []
for task in tasks:
    for model in models:
        row = []
        for method in methods:
            acc = _load_method_accuracy(task, model, method)
            row.append((acc, "N/A" if acc is None else f"{acc:.1f}"))
        table_data.append(row)

# Step 2: (best, second best, min, max) over all models x methods of each task.
group_stats = []
for task_idx in range(num_tasks):
    task_rows = table_data[task_idx * num_models:(task_idx + 1) * num_models]
    values = [val for row in task_rows for val, _ in row if val is not None]
    if values:
        best = max(values)
        lower = [v for v in values if v < best]
        group_stats.append((best, max(lower) if lower else None, min(values), best))
    else:
        group_stats.append((None, None, None, None))

# Step 3: print the LaTeX table.
print("\\begin{table*}[t]")
print("  \\centering")
print("  \\small")
print("  \\setlength{\\tabcolsep}{2.2mm}{")
print("  \\scalebox{0.47}{")
# Task and Model columns, then one column per method (count derived from `methods`).
print(f"  \\begin{{tabular}}{{lc|*{{{num_methods}}}{{c}}|}}")
print("    \\toprule")
print("    \\textbf{Task} & \\textbf{Model} & " + " & ".join([f"\\textbf{{\\textsc{{" + method_map.get(method, method) + "}}" for method in methods]) + " \\\\")
print("    \\midrule")

for task_idx, task in enumerate(tasks):
    stats = group_stats[task_idx]
    for model_idx, model in enumerate(models):
        row_idx = task_idx * num_models + model_idx
        row_cells = [_value_heat_cell(num_val, stats) for num_val, _ in table_data[row_idx]]
        if model_idx == 0:
            # The task name spans all model subrows of this task.
            prefix = f"    \\multirow{{{num_models}}}{{*}}{{\\textbf{{" + task_map.get(task, task) + "}} "
        else:
            prefix = "    "
        label = model_short_labels.get(model, model)
        print(prefix + f"& {label} & " + " & ".join(row_cells) + " \\\\")

    # Add midrule after each task (except the last one)
    if task_idx < num_tasks - 1:
        print("    \\midrule")

print("    \\bottomrule")
print("  \\end{tabular}")
print("  }")
print("}")
print("  \\caption{\\textbf{Performance heatmap by task and method.} All the tasks are evaluated on the \\textsc{Hard} set. For each task, three models are evaluated: \\textsc{GPT-4o-mini}, \\textsc{GPT-4o}, and \\textsc{DeepSeek-V3}. Accuracy values are in percentages. The best values are highlighted in \\textbf{bold}, and the second-best values are highlighted in \\underline{underlined} format.}")
print("  \\label{tab:method_task_heatmap}")
print("\\end{table*}")

\begin{table*}[t]
  \centering
  \small
  \setlength{\tabcolsep}{2.2mm}{
  \scalebox{0.47}{
  \begin{tabular}{lc|*{11}{c}|}
    \toprule
    \textbf{Task} & \textbf{Model} & \textbf{\textsc{Zero-shot}} & \textbf{\textsc{Few-shot}} & \textbf{\textsc{CoT}} & \textbf{\textsc{CoT-SC}} & \textbf{\textsc{MultiPersona}} & \textbf{\textsc{Self-Refine}} & \textbf{\textsc{MedPrompt}} & \textbf{\textsc{MedAgents}} & \textbf{\textsc{MDAgents}} & \textbf{\textsc{SPO}} & \textbf{\textsc{AFlow}} \\
    \midrule
    \multirow{3}{*}{\textbf{MedQA}} & \footnotesize{\textsc{4o-m}} & \cellcolor[HTML]{FCECEE}22.0 & \cellcolor[HTML]{F8D4D8}30.0 & \cellcolor[HTML]{FCF0F1}21.0 & \cellcolor[HTML]{FDF3F4}20.0 & \cellcolor[HTML]{F8D7DB}29.0 & \cellcolor[HTML]{F6CED3}32.0 & \cellcolor[HTML]{FBE9EC}23.0 & \cellcolor[HTML]{FBE6E9}24.0 & \cellcolor[HTML]{FBE6E9}24.0 & \cellcolor[HTML]{FDF6F7}19.0 & \cellcolor[HTML]{F8D4D8}30.0 \\
    & \textsc{4o} & \cellcolor[HTML]{F6CED3}32.0 & \cellcolor[HTML]{F9DADE}28.0 & \cell

In [6]:
tasks = ['medqa', 'pubmedqa', 'medmcqa', 'medbullets', 'mmlu', 'mmlu-pro', 'medexqa', 'medxpertqa-r', 'medxpertqa-u']
task_map = {
    'medqa': 'MedQA',
    'pubmedqa': 'PubMedQA',
    'medmcqa': 'MedMCQA',
    'medbullets': 'MedBullets',
    'mmlu': 'MMLU',
    'mmlu-pro': 'MMLU-Pro',
    'afrimedqa': 'AfriMedQA',
    'medexqa': 'MedExQA',
    'medxpertqa-r': 'MedXpertQA-R',
    'medxpertqa-u': 'MedXpertQA-U',
}
models = [
    'gpt-4o-mini',
    'gpt-4o',
    'DeepSeek-V3',
    'o1-mini',
    'o3-mini',
    'QwQ-32B',
    'DeepSeek-R1',
    'Llama-3.3-70B-Instruct-Turbo',
    'claude-3-5-sonnet',
    'claude-3-5-haiku'
]
model_map = {
    'gpt-4o-mini': 'GPT-4o-mini',
    'gpt-4o': 'GPT-4o',
    'o1-mini': 'o1-mini',
    'o3-mini': 'o3-mini',
    'DeepSeek-V3': 'DeepSeek-V3',
    'DeepSeek-R1': 'DeepSeek-R1',
    # `models` (and therefore the result file names) use 'QwQ-32B'; the '-Preview' key is kept
    # for backward compatibility.
    'QwQ-32B': 'QwQ-32B',
    'QwQ-32B-Preview': 'QwQ-32B',
    'Llama-3.3-70B-Instruct-Turbo': 'Llama-3.3-70B',
    'claude-3-5-sonnet': 'Claude-3.5-Sonnet',
    'claude-3-5-haiku': 'Claude-3.5-Haiku'
}

# Build and print a LaTeX heatmap table (a table* environment) with models as rows and tasks
# as columns. Each task gets one zero-shot accuracy column per evaluation subset (Full / Hard).
subsets = ["test", "test_hard"]
subset_labels = {"test": "\\textsc{Full}", "test_hard": "\\textsc{Hard}"}


def _load_zero_shot_accuracy(task, model, subset):
    """Zero-shot accuracy (%) of `model` on `subset` of `task`, or None if unavailable.

    A missing result file is an expected gap and yields None silently. Any other
    failure is skipped as well, but with a warning so broken files are noticed.
    """
    file_path = f"../output/{task}/{model}-{task}-{subset}-zero_shot.json"
    try:
        data = load_json(file_path)
        return calculate_accuracy(deduplicate_data(data)) * 100  # Convert to percentage.
    except FileNotFoundError:
        return None
    except Exception as exc:
        warnings.warn(f"Skipping {file_path}: {exc!r}")
        return None


def _zero_shot_heat_cell(num_val, stats):
    """LaTeX cell for `num_val`, shaded and marked relative to `stats` = (best, second, min, max)."""
    best, second, lo, hi = stats
    if num_val is None or lo is None or hi is None:
        return "N/A"
    norm = (num_val - lo) / (hi - lo) if (hi - lo) != 0 else 1
    # Linear blend from white (#FFFFFF) at the column minimum to #EE9CA7 at the column maximum.
    r = int(round(255 + norm * (238 - 255)))
    g = int(round(255 + norm * (156 - 255)))
    b = int(round(255 + norm * (167 - 255)))
    text = f"{num_val:.1f}"
    if num_val == best:
        text = "\\textbf{" + text + "}"
    elif second is not None and num_val == second:
        text = "\\underline{" + text + "}"
    return f"\\cellcolor[HTML]{{{r:02X}{g:02X}{b:02X}}}" + text


num_tasks = len(tasks)
num_models = len(models)
num_subsets = len(subsets)

# One row per model; columns are ordered task-major: (task0, test), (task0, test_hard), ...
table_data = []
for model in models:
    row = []
    for task in tasks:
        for subset in subsets:
            acc = _load_zero_shot_accuracy(task, model, subset)
            row.append((acc, "N/A" if acc is None else f"{acc:.1f}"))
    table_data.append(row)

# (best, second best, min, max) per column, i.e. per (task, subset) pair across all models.
column_stats = []
for col_idx in range(num_tasks * num_subsets):
    values = [row[col_idx][0] for row in table_data if row[col_idx][0] is not None]
    if values:
        best = max(values)
        lower = [v for v in values if v < best]
        column_stats.append((best, max(lower) if lower else None, min(values), best))
    else:
        column_stats.append((None, None, None, None))

print("\\begin{table*}[t]")
print("  \\centering")
print("  \\small")
print("  \\setlength{\\tabcolsep}{2.2mm}{")
print("  \\scalebox{0.62}{")
# One left-aligned column for Model and one column per (task, subset).
print(f"  \\begin{{tabular}}{{l|{'c|' * (num_tasks * num_subsets)}}}")
print("    \\toprule")
# First header row: "Model" plus each task spanning its subset columns.
header1 = ["\\multirow{2}{*}{\\textbf{Model}}"]
for task in tasks:
    header1.append(f"\\multicolumn{{{num_subsets}}}{{c|}}{{\\textbf{{{task_map.get(task, task)}}}}}")
print("    " + " & ".join(header1) + " \\\\")
# Second header row: empty cell under Model, then one label per subset for every task.
header2 = [""]
for _ in tasks:
    header2.extend(subset_labels.get(subset, subset) for subset in subsets)
print("    " + " & ".join(header2) + " \\\\")
print("    \\midrule")
for model, row in zip(models, table_data):
    row_cells = ["\\textsc{" + model_map.get(model, model) + "}"]
    row_cells += [_zero_shot_heat_cell(num_val, column_stats[j]) for j, (num_val, _) in enumerate(row)]
    print("    " + " & ".join(row_cells) + " \\\\")
    print("    \\midrule")
print("    \\bottomrule")
print("  \\end{tabular}")
print("  }")
print("}")
print("  \\caption{\\textbf{Performance heatmap by model and task.} For each task, accuracy values are in percentages, with separate columns for \\textsc{Full} and \\textsc{Hard}. The best values are highlighted in \\textbf{bold}, and the second-best values are highlighted in \\underline{underlined} format.}")
# NOTE(review): this is a table* but the label uses the fig: prefix; kept as-is so existing
# \ref{fig:model_task_heatmap} references keep resolving.
print("  \\label{fig:model_task_heatmap}")
print("\\end{table*}")

\begin{table*}[t]
  \centering
  \small
  \setlength{\tabcolsep}{2.2mm}{
  \scalebox{0.62}{
  \begin{tabular}{l|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|c|}
    \toprule
    \multirow{2}{*}{\textbf{Model}} & \multicolumn{2}{c|}{\textbf{MedQA}} & \multicolumn{2}{c|}{\textbf{PubMedQA}} & \multicolumn{2}{c|}{\textbf{MedMCQA}} & \multicolumn{2}{c|}{\textbf{MedBullets}} & \multicolumn{2}{c|}{\textbf{MMLU}} & \multicolumn{2}{c|}{\textbf{MMLU-Pro}} & \multicolumn{2}{c|}{\textbf{MedExQA}} & \multicolumn{2}{c|}{\textbf{MedXpertQA-R}} & \multicolumn{2}{c|}{\textbf{MedXpertQA-U}} \\
     & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} & \textsc{Full} & \textsc{Hard} \\
    \midrule
    \textsc{GPT-4o-mini} & \cellcolor[HTML]{F9DDE1}73.4 & \cellcolor[HTML]{FBE9EB}22.0 & \cellcolor[HTML]{F8D4D9}76.2 &