In [1]:
from run_battery import BatteryConfigs
from metric import Metric
import os
import json
from scipy import stats
import numpy as np

In [2]:
from enum import Enum
class LinearGradeMetrics(Enum):
    """Error measures available for grading a linear fit."""

    RMSD = 1  # root-mean-square deviation
    MAE = 2   # mean absolute error

def eval_linear_metric(method, predictions, targets):
    """Measure how far `predictions` deviate from `targets`.

    Args:
        method: The `LinearGradeMetrics` member selecting the error measure.
        predictions: Predicted values (NumPy array or any array-like).
        targets: Observed values, broadcast-compatible with `predictions`.

    Returns:
        For `RMSD`, the square root of the mean squared difference; for
        `MAE`, the mean absolute difference.

    Raises:
        AssertionError: If `method` is neither `RMSD` nor `MAE`.
    """
    if method not in (LinearGradeMetrics.RMSD, LinearGradeMetrics.MAE):
        # Raised explicitly instead of `assert False`: assert statements are
        # stripped under `python -O`, which would make an unknown metric
        # silently return None. The exception type is kept for callers.
        raise AssertionError(f"Cannot find metric {method}")

    # np.asarray returns ndarray inputs unchanged and lets plain Python
    # sequences (lists, tuples) be subtracted element-wise as well.
    residuals = np.asarray(predictions) - np.asarray(targets)

    if method == LinearGradeMetrics.RMSD:
        return np.sqrt(np.mean(residuals ** 2))
    return np.mean(np.abs(residuals))

def predict(regression, x):
    """Evaluate the fitted line `regression` at `x`.

    `regression` only needs `slope` and `intercept` attributes; `x` may be a
    scalar or anything supporting `*` and `+` with them.
    """
    slope, intercept = regression.slope, regression.intercept
    return x * slope + intercept

# Fixed x-coordinates for every regression, one per score in `ys`.
# NOTE(review): presumably model sizes in billions of parameters, matching the
# 350M / 2.7B / 6.1B / 16.1B table columns -- confirm.
xs = [0.35, 2.7, 6.1, 16.1]
def grade_emergence(ys, threshold, method=LinearGradeMetrics.RMSD):
    """Grade how poorly a straight line over `xs` explains the scores `ys`.

    Fits a least-squares line through `(xs, ys)`, measures the deviation of
    the fitted values from `ys` with `method`, and flags the series as
    emergent when that deviation is not within `threshold`.

    Returns:
        A dict with the deviation under "metric" and the flag under "emergent".
    """
    fit = stats.linregress(xs, ys)
    deviation = eval_linear_metric(
        method,
        predict(fit, np.array(xs)),
        np.array(ys),
    )
    within_threshold = deviation <= threshold
    return {"metric": deviation, "emergent": not within_threshold}

In [4]:
# Battery configurations to render as LaTeX tables. The tabulation loop reads
# each entry's "task" key to locate that task's cached metrics.json file.
to_tabulate = [
    BatteryConfigs.Bugs2Fix,
    BatteryConfigs.Bugs2FixChecklist,
    BatteryConfigs.Code2Code,
]

In [10]:
# Names of the linear-fit error metrics shown as the trailing table columns.
# Derived from the enum (iteration follows definition order: RMSD, MAE) so the
# column list cannot drift from the members that the tabulation loop looks up
# by name via LinearGradeMetrics[...].
lineval_metrics = [member.name for member in LinearGradeMetrics]

def fmt_number(number, bold=False):
    """Render `number` as a LaTeX table cell.

    Zero is written as a bare "0"; any other value is fixed to four decimal
    places. When `bold` is truthy the text is wrapped in \\textbf{...}.
    """
    text = "0" if number == 0 else format(number, ".4f")
    if not bold:
        return text
    return "\\textbf{" + text + "}"

# Deviation above which grade_emergence flags a score series as emergent
# (the flagged deviation is printed in bold).
EMERGENCE_THRESHOLD = 0.1

# Print one LaTeX table per battery configuration, built from the metrics
# cached under ./output/<task>/metrics.json.
for config in to_tabulate:
    table = [
        ["Metric", "Prompt", "350M", "2.7B", "6.1B", "16.1B", *lineval_metrics],
    ]
    # Indices into `table` of the last row of each metric group; those rows
    # get a trailing \hline. Recorded explicitly so the separators stay right
    # when metrics have different numbers of prompts, and so an empty result
    # set no longer trips over an undefined or zero prompt count.
    group_end_rows = set()

    cache_file_path = os.path.join("./output", config["task"], "metrics.json")
    with open(cache_file_path, "r", encoding="utf-8") as cache_file:
        cache = json.loads(cache_file.read())

    for metric_shortname, data in cache["results"].items():
        metric = Metric.from_shortname(metric_shortname)

        # Highest score of this metric across all prompts and model sizes;
        # cells equal to it are bolded (unless it is zero).
        best_result = max(
            (score for results in data.values() for score in results),
            default=None,
        )

        for prompt_name, results in data.items():
            # One cell per linear-fit metric: the deviation of the scores from
            # a straight line, bolded when graded as emergent.
            emergence_evaluations = []
            for lineval_name in lineval_metrics:
                evaluation = grade_emergence(
                    results,
                    threshold=EMERGENCE_THRESHOLD,
                    method=LinearGradeMetrics[lineval_name],
                )
                emergence_evaluations.append(
                    fmt_number(evaluation["metric"], bold=evaluation["emergent"])
                )

            str_results = [
                fmt_number(grade, bold=grade == best_result and best_result != 0)
                for grade in results
            ]

            table.append([
                metric.latex_name,
                prompt_name,
                *str_results,
                *emergence_evaluations,
            ])

        if data:
            # The row just appended closes this metric's group.
            group_end_rows.add(len(table) - 1)

    print("\\hline ", end="")
    for idx, row in enumerate(table):
        row_display = " & ".join(row) + " \\\\"
        if idx == 0:
            # The header row is followed by a double rule.
            row_display += " \\hline \\hline"
        elif idx in group_end_rows:
            row_display += " \\hline"

        print(row_display)

    print("-" * 30)

\hline Metric & Prompt & 350M & 2.7B & 6.1B & 16.1B & RMSD & MAE \\ \hline \hline
EM & prompt0 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt1 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt2 & 0 & 0 & 0 & \textbf{0.0100} & 0.0015 & 0.0012 \\ \hline
BLEU & prompt0 & 0.4882 & 0.5006 & 0.4571 & 0.5607 & 0.0252 & 0.0218 \\
BLEU & prompt1 & 0.1416 & 0.1801 & 0.2297 & 0.1747 & 0.0307 & 0.0256 \\
BLEU & prompt2 & 0.4920 & \textbf{0.5989} & 0.4642 & 0.5689 & 0.0526 & 0.0449 \\ \hline
CodeBLEU (Java) & prompt0 & 0.6134 & 0.6464 & 0.5699 & 0.6492 & 0.0304 & 0.0247 \\
CodeBLEU (Java) & prompt1 & 0.1862 & 0.2634 & 0.3111 & 0.2642 & 0.0406 & 0.0371 \\
CodeBLEU (Java) & prompt2 & 0.5918 & \textbf{0.6909} & 0.5499 & 0.6595 & 0.0538 & 0.0455 \\ \hline
------------------------------
\hline Metric & Prompt & 350M & 2.7B & 6.1B & 16.1B & RMSD & MAE \\ \hline \hline
EM & prompt0 & 0.0100 & 0 & 0.0100 & 0.0100 & 0.0041 & 0.0033 \\
EM & prompt1 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt2 & 0 & 0 & \textbf{0.0200} & 0.0100 & 0.