In [1]:
from run_battery import BatteryConfigs
from metric import Metric
import os
import json
from scipy import stats
import numpy as np

In [2]:
from enum import Enum
class LinearGradeMetrics(Enum):
    """Error measures available for grading how well a straight line fits a series."""

    RMSD = 1  # root-mean-square deviation
    MAE = 2   # mean absolute error

def eval_linear_metric(method, predictions, targets):
    """Measure the error between `predictions` and `targets` with one metric.

    Args:
        method: `LinearGradeMetrics` member selecting the error measure.
        predictions: Fitted values; any array-like accepted by `np.asarray`.
        targets: Observed values, broadcast-compatible with `predictions`.

    Returns:
        The root-mean-square deviation for `RMSD`, or the mean absolute
        error for `MAE`.

    Raises:
        AssertionError: If `method` is not one of the supported metrics.
    """
    supported = (LinearGradeMetrics.RMSD, LinearGradeMetrics.MAE)
    if method not in supported:
        # Raised explicitly instead of `assert False`: a bare assert is removed
        # under `python -O`, which would make this function return None.
        raise AssertionError(f"Cannot find metric {method}")

    # Coerce up front so plain Python sequences work as well as ndarrays.
    residuals = np.asarray(predictions) - np.asarray(targets)

    if method == LinearGradeMetrics.RMSD:
        return np.sqrt(np.mean(residuals ** 2))
    return np.mean(np.abs(residuals))

def predict(regression, x):
    """Evaluate a fitted line at `x`.

    Args:
        regression: Object exposing `slope` and `intercept` attributes.
        x: Scalar or array of x-coordinates.

    Returns:
        `x * slope + intercept`, with the same shape as `x`.
    """
    slope, intercept = regression.slope, regression.intercept
    return x * slope + intercept

# x-coordinates for the linear fit, one per score in `ys`. The values line up
# with the 350M / 2.7B / 6.1B / 16.1B table columns (presumably model size in
# billions of parameters -- confirm).
xs = [0.35, 2.7, 6.1, 16.1]


def grade_emergence(ys, threshold, method=LinearGradeMetrics.RMSD):
    """Fit a straight line through (xs, ys) and flag the series when the fit is poor.

    Args:
        ys: Scores, one per entry of the module-level `xs`.
        threshold: Largest fit error that is still treated as linear.
        method: `LinearGradeMetrics` member used to measure the fit error.

    Returns:
        Dict with "metric" (the fit error) and "emergent" (True unless the
        error is <= `threshold`).
    """
    fit = stats.linregress(xs, ys)
    fit_error = eval_linear_metric(
        method,
        predict(fit, np.array(xs)),
        np.array(ys),
    )
    # Negated `<=` rather than `>`: a NaN error compares False and is
    # therefore reported as emergent.
    is_emergent = not (fit_error <= threshold)
    return {"metric": fit_error, "emergent": is_emergent}

In [3]:
# Battery configurations to render, one LaTeX table each. The table loop reads
# config["task"] from every entry to locate that task's cached metrics file.
to_tabulate = [BatteryConfigs.Bugs2Fix, BatteryConfigs.Code2Code]

In [8]:
# Linearity metrics appended as the trailing table columns. Each entry must be
# a member name of LinearGradeMetrics, because the table loop looks it up by name.
lineval_metrics = ["RMSD", "MAE"]

def fmt_number(number, bold=False):
    """Render a number as a LaTeX table cell.

    Args:
        number: Value to format; an exact zero is shown as "0", anything
            else with four decimal places.
        bold: When true, wrap the text in a LaTeX bold command.

    Returns:
        The formatted cell text.
    """
    text = "0" if number == 0 else f"{number:.4f}"
    return f"\\textbf{{{text}}}" if bold else text

# Fit error above which grade_emergence flags a score series as emergent.
EMERGENCE_THRESHOLD = 0.1

# Print one LaTeX table per battery config: a row per (metric, prompt) holding
# the per-model scores followed by the linearity metrics of that score series.
for config in to_tabulate:
    # Header row: fixed model-size columns, then one column per linearity metric.
    table = [
        ["Metric", "Prompt", "350M", "2.7B", "6.1B", "16.1B", *lineval_metrics],
    ]

    cache_file_path = os.path.join("./output", config["task"], "metrics.json")
    with open(cache_file_path, "r", encoding="utf-8") as cache_file:
        cache = json.load(cache_file)

    for metric_shortname, data in cache["results"].items():
        metric = Metric.from_shortname(metric_shortname)

        # Highest score for this metric across all prompts and model sizes;
        # None when there are no prompts. It is bolded in the rows below.
        best_result = max(
            (max(scores) for scores in data.values()),
            default=None,
        )

        for prompt_name, results in data.items():
            # One formatted cell per linearity metric; bold marks "emergent".
            # `lineval_name` is deliberately distinct from the outer `metric`.
            emergence_evaluations = []
            for lineval_name in lineval_metrics:
                evaluation = grade_emergence(
                    results,
                    threshold=EMERGENCE_THRESHOLD,
                    method=LinearGradeMetrics[lineval_name],
                )
                emergence_evaluations.append(
                    fmt_number(evaluation["metric"], bold=evaluation["emergent"])
                )

            # Bold the best score, unless the best score is zero (nothing to highlight).
            str_results = [
                fmt_number(grade, bold=(grade == best_result and best_result != 0))
                for grade in results
            ]

            table.append([
                metric.latex_name,
                prompt_name,
                *str_results,
                *emergence_evaluations,
            ])

    for idx, row in enumerate(table):
        row_display = " & ".join(row) + " \\\\"
        if idx == 0:
            # Rule under the header row only.
            row_display += " \\hline"

        print(row_display)

    # Visual separator between the tables of consecutive configs.
    print("-" * 30)

Metric & Prompt & 350M & 2.7B & 6.1B & 16.1B & RMSD & MAE \\ \hline
EM & prompt0 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt1 & 0 & 0 & 0 & 0 & 0 & 0 \\
BLEU & prompt0 & 0.4882 & 0.5006 & 0.4571 & \textbf{0.5607} & 0.0252 & 0.0218 \\
BLEU & prompt1 & 0.1416 & 0.1801 & 0.2297 & 0.1747 & 0.0307 & 0.0256 \\
CodeBLEU (Java) & prompt0 & 0.6139 & 0.6458 & 0.5699 & \textbf{0.6492} & 0.0303 & 0.0247 \\
CodeBLEU (Java) & prompt1 & 0.1867 & 0.2634 & 0.3111 & 0.2642 & 0.0404 & 0.0369 \\
------------------------------
Metric & Prompt & 350M & 2.7B & 6.1B & 16.1B & RMSD & MAE \\ \hline
EM & prompt0 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt1 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt2 & 0 & 0 & 0 & 0 & 0 & 0 \\
EM & prompt3 & 0 & 0 & 0 & 0 & 0 & 0 \\
BLEU & prompt0 & 0.1474 & 0.1273 & 0.1394 & \textbf{0.1491} & 0.0078 & 0.0063 \\
BLEU & prompt1 & 0.1212 & 0.0719 & 0.0693 & 0.1429 & 0.0269 & 0.0254 \\
BLEU & prompt2 & 0.0886 & 0.0364 & 0.0083 & 0.0247 & 0.0245 & 0.0227 \\
BLEU & prompt3 & 0.0157 & 0.0204 & 0.020