In [1]:
import pandas as pd
from analysis_utils import get_run_data, process_run_data, aggregate_metrics, METRIC_NAMES

In [2]:
data, config = get_run_data("eli-carrami/Cprt-Paper-Baselines")

# Baseline (non-classifier) runs: one row per run, taken from the last row of its data
# (presumably the final logged step), tagged with a `baseline` label and its seed.
out = []

# Classifier probes: classifier -> seed -> {classification_task: final 'metrics/val_metric'}.
# Seeds 0 and 7 are pre-seeded so both rows always exist downstream; any other seed is
# added on demand instead of raising KeyError.
models = {"linear": {0: {}, 7: {}}, "mlp": {0: {}, 7: {}}}

for d, c in zip(data, config):
    h = d.iloc[-1].copy()
    model_cfg = c['model']['value']
    if 'classifier' in model_cfg:
        cls = model_cfg['classifier']
        task = c['datamodule']['value']['classification_task']
        # Classifier types not listed in `models` are deliberately ignored.
        if cls in models:
            models[cls].setdefault(c['seed']['value'], {})[task] = h['metrics/val_metric']
        continue

    # A missing/None strategy falls through to the explicit error below rather than
    # crashing with KeyError/TypeError.
    strategy = model_cfg.get('multimodal_strategy') or ""
    if "baseline" in strategy:
        h['baseline'] = f"{strategy}_{model_cfg['language_model']}"
        h['seed'] = c['seed']['value']
        out.append(h)
    else:
        raise ValueError(f"run not supported here (multimodal_strategy={strategy!r})")

In [22]:
# One row per (classifier, seed): that seed's per-task final val metrics plus identifiers.
# Named `classifier_rows` rather than reusing `data`, so the run data loaded earlier
# is not silently overwritten with a different kind of object.
classifier_rows = [
    task_metrics | {'seed': seed, 'baseline': classifier}
    for classifier, by_seed in models.items()
    for seed, task_metrics in by_seed.items()
]
# `columns=` both orders and filters: any classification task not listed here is dropped,
# and a listed task with no run for a given seed becomes NaN.
df1 = pd.DataFrame(
    classifier_rows,
    columns=['is_real', 'is_enzyme', 'kingdom', 'localization', 'mw', 'seed', 'baseline'],
)
# NOTE(review): grouping by this Categorical column is a likely source of pandas'
# `observed=False` FutureWarning seen in the aggregation output — confirm.
df1['baseline'] = pd.Categorical(df1['baseline'])
df1

Unnamed: 0,is_real,is_enzyme,kingdom,localization,mw,seed,baseline
0,0.988,0.8056,0.939044,0.8742,0.332147,0,linear
1,0.98654,0.839272,0.947755,0.860417,0.183087,7,linear
2,0.984,0.8688,0.936573,0.876333,0.071771,0,mlp
3,0.989707,0.860649,0.951837,0.9125,0.06353,7,mlp


In [27]:
# Mean/std across seeds per classifier, rendered as "mean (std)" strings for the paper table.
classifier_stats = aggregate_metrics(df1, group_by='baseline')

# Display label per task column; `mw` is reported as MALE, everything else as F1.
classifier_labels = {
    col: (f"{col} F1" if col != "mw" else "MW MALE")
    for col in ['is_real', 'is_enzyme', 'kingdom', 'localization', 'mw']
}

def _fmt(values):
    # Fixed two decimals so trailing zeros are kept ("0.90", not "0.9").
    return values.map('{:.2f}'.format)

# Build a fresh flat-column table (no in-place mutation), so the aggregated `seed`
# statistics are not exported alongside the metrics.
agg_df = pd.DataFrame(
    {
        label: _fmt(classifier_stats[(col, 'mean')]) + " (" + _fmt(classifier_stats[(col, 'std')]) + ")"
        for col, label in classifier_labels.items()
    },
    index=classifier_stats.index,
)
agg_df.to_clipboard()

  return df.groupby(group_by).agg(['mean', 'std'])


In [4]:
# Column that identifies each baseline; reused as the group-by key when aggregating below.
var = 'baseline'

# Desired order of the baseline labels. The name is historical (these are baselines,
# not layers); the displayed output follows this order — presumably enforced by
# process_run_data via `ordering`.
layer_order = [
    "random_baseline_gpt2-medium",
    "llm_only_baseline_gpt2-medium",
    "llm_only_baseline_microsoft/phi-2",
]
ordering = var, layer_order

# NOTE(review): meaning of the empty list argument is defined in analysis_utils — confirm.
df2 = process_run_data(out, [], ordering)
df2

Unnamed: 0,biochem/val_is_enzyme_hard_balanced_accuracy,_runtime,biochem/val_in_mitochondria_balanced_accuracy,biochem/val_kingdom_eukaryota_accuracy,biochem/val_average_semantic_localization,biochem/val_kingdom_archaea_accuracy,biochem/val_in_membrane_balanced_accuracy,_step,biochem/val_in_membrane_f1,biochem/val_localization_f1,...,biochem/val_is_fake_balanced_accuracy,biochem/val_cofactor,biochem/val_is_fake_f1,biochem/val_kingdom_bacteria_accuracy,biochem/val_localization_mitochondrion_accuracy,_timestamp,biochem/val_is_enzyme_hard_f1,baseline,seed,avg_binary_loc_f1
5,0.490918,13.239086,0.556434,0.502392,0.464433,0.017857,0.491224,0.0,0.458175,0.361975,...,0.488675,0.11985,0.488627,0.436911,0.212121,1705954000.0,0.490823,random_baseline_gpt2-medium,0,0.464433
4,0.49946,16.849513,0.512935,0.545319,0.454742,0.017241,0.475278,0.0,0.446363,0.316693,...,0.505962,0.108696,0.50593,0.390688,0.148148,1705954000.0,0.498013,random_baseline_gpt2-medium,7,0.454742
3,0.495364,11.683635,0.487012,0.536364,0.44996,0.0,0.500012,0.0,0.467871,0.326365,...,0.480424,0.155844,0.479772,0.399202,0.076923,1705954000.0,0.495264,random_baseline_gpt2-medium,42,0.44996
8,0.0,2.878841,0.5,0.0,0.310313,0.0,0.5,0.0,0.166389,0.196643,...,0.5,0.0,0.325707,0.0,0.0,1705589000.0,0.0,llm_only_baseline_gpt2-medium,0,0.310313
7,0.0,2.396601,0.5,0.0,0.305945,0.0,0.5,0.0,0.182553,0.172043,...,0.5,0.019231,0.335085,0.0,0.0,1705589000.0,0.0,llm_only_baseline_gpt2-medium,7,0.305945
6,0.0,2.380484,0.5,0.0,0.308423,0.0,0.5,0.0,0.204687,0.16092,...,0.5,0.0,0.358942,0.0,0.0,1705590000.0,0.0,llm_only_baseline_gpt2-medium,42,0.308423
2,0.5,2.572741,0.5,1.0,0.307788,0.0,0.5,0.0,0.182472,0.0,...,0.5,0.022472,0.339672,0.0,0.0,1705961000.0,0.342451,llm_only_baseline_microsoft/phi-2,0,0.307788
1,0.5,2.65724,0.5,1.0,0.308877,0.0,0.5,0.0,0.189345,0.0,...,0.5,0.030435,0.334212,0.0,0.0,1705961000.0,0.345935,llm_only_baseline_microsoft/phi-2,7,0.308877
0,0.5,2.527341,0.5,1.0,0.309638,0.0,0.5,0.0,0.201805,0.0,...,0.5,0.04329,0.323127,0.0,0.0,1705962000.0,0.344627,llm_only_baseline_microsoft/phi-2,42,0.309638


In [23]:
# Mean/std across seeds per baseline, rendered as "mean (std)" strings for the paper table.
baseline_stats = aggregate_metrics(df2, group_by=var)

# Report every metric that has a display name and was logged, except ROUGE scores.
# `k in baseline_stats.columns` matches on the first level of the (metric, stat) columns.
metrics_names = {
    k: v for k, v in METRIC_NAMES.items()
    if 'rouge' not in k and k in baseline_stats.columns
}

def _fmt(values):
    # Fixed two decimals so trailing zeros are kept ("0.90", not "0.9").
    return values.map('{:.2f}'.format)

# Build the formatted table directly, instead of assigning into a column subset (which
# risks SettingWithCopyWarning) and mixing flat names into the MultiIndex columns.
agg_df = pd.DataFrame(
    {
        name: _fmt(baseline_stats[(col, 'mean')]) + " (" + _fmt(baseline_stats[(col, 'std')]) + ")"
        for col, name in metrics_names.items()
    },
    index=baseline_stats.index,
)
agg_df.to_clipboard()

  return df.groupby(group_by).agg(['mean', 'std'])
