In [1]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from analysis_utils import get_run_data, process_run_data, add_significance_marks, METRICS, aggregate_metrics, metrics_modes, metrics_names
# Pull the "metric_groups" entry out of the METRICS mapping imported from analysis_utils.
# NOTE(review): the structure of this entry is defined in analysis_utils and is not visible here.
metric_groups = METRICS["metric_groups"]

In [2]:
# Load the per-run metric histories (`data`) and their matching configs (`config`).
# NOTE(review): presumably a "<entity>/<project>" experiment-tracker path; confirm in analysis_utils.
data, config = get_run_data("eli-carrami/Cprt-Paper-Exp-2")

# Metric whose maximum picks the representative row of each run.
# Set to a falsy value to take the last logged row with epoch == 0 instead.
best_basis = "biochem/val_localization_f1"

# One summary row (a pandas Series) per non-empty run, annotated with its hyper-parameters.
out = []
for d, c in zip(data, config):
    # Runs without any logged history contribute nothing.
    if len(d) == 0:
        continue

    if best_basis:
        # idxmax() returns an index *label*, so the row must be selected with .loc.
        # Using .iloc here would interpret the label as a position and silently pick
        # the wrong row (or raise) whenever the history index is not exactly 0..n-1.
        h = d.loc[d[best_basis].idxmax()].copy()
    else:
        h = d[d.epoch == 0].iloc[-1].copy()

    model_cfg = c['model']['value']
    data_cfg = c['datamodule']['value']

    h['esm'] = model_cfg['protein_model']
    h['llm'] = model_cfg['language_model']
    h['strategy'] = model_cfg['multimodal_strategy']
    h['layers'] = model_cfg['multimodal_layers']
    h['split'] = data_cfg['split_ratios']
    h['subsample'] = data_cfg['subsample_data']
    h['fields'] = data_cfg['data_field_names']
    h['seed'] = c['seed']['value']
    out.append(h)

In [3]:
# Variable the runs are compared across, and the order its values are shown in.
model_order = ["gpt2", "gpt2-medium", "gpt2-xl", "microsoft/phi-2"]
# model_order = ["esm2_t12_35M_UR50D", "esm2_t33_650M_UR50D"]
var = 'llm'
ordering = (var, model_order)

# One row per run; drop whatever index the collected Series carried over.
df = pd.DataFrame(out).reset_index(drop=True)

# Keep only the runs of interest. The commented lines are optional extra filters.
df = df[df["strategy"] == "soft-prompt"]
# df = df[df["split"].apply(lambda x: x == [0.945, 0.005, 0.05])]
df = df[df["esm"] == "esm2_t33_650M_UR50D"]
# df = df[df["llm"] == "microsoft/phi-2"]
# df = df[df["trainer/global_step"] > 150000]

# Explicit copy: the filtered frame is otherwise a slice of the unfiltered one, and the
# column assignments below would raise SettingWithCopyWarning on it.
df = df[df[var].isin(model_order)].copy()

# Row-wise mean over every column whose name contains "_in_".
in_cols = [col for col in df.columns if "_in_" in col]
df["avg_binary_loc_f1"] = df[in_cols].mean(axis=1)

# An ordered categorical makes the sort follow `model_order` rather than alphabetical order.
df[var] = pd.Categorical(df[var], categories=model_order, ordered=True)
df = df.sort_values([var, 'seed'])
df['metrics/val_perplexity'] = df['metrics/val_perplexity'].astype(float)
df

Unnamed: 0,trainer/global_step,epoch,loss/val_loss,biochem/val_cofactor,biochem/val_in_nucleus_f1,biochem/val_in_membrane_f1,biochem/val_in_mitochondria_f1,biochem/val_is_enzyme_f1,biochem/val_is_enzyme_hard_f1,biochem/val_is_fake_f1,...,metrics/val_perplexity,esm,llm,strategy,layers,split,subsample,fields,seed,avg_binary_loc_f1
5,224084.0,0.0,1.087606,0.426087,0.151601,0.227038,0.07192,0.889134,0.889193,0.011772,...,2.980547,esm2_t33_650M_UR50D,gpt2,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,7,0.150186
3,89721.0,0.0,1.204092,0.380952,0.312457,0.396514,0.325027,0.871486,0.877093,0.013587,...,3.359453,esm2_t33_650M_UR50D,gpt2,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,42,0.344666
4,179267.0,0.0,1.075419,0.36087,0.148184,0.270733,0.273378,0.890481,0.891329,0.011023,...,2.943372,esm2_t33_650M_UR50D,gpt2-medium,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,7,0.230765
2,89721.0,0.0,1.117515,0.212121,0.139503,0.216803,0.075927,0.873891,0.871283,0.035074,...,3.077381,esm2_t33_650M_UR50D,gpt2-medium,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,42,0.144077
0,134450.0,0.0,0.902861,0.547826,0.125315,0.193948,0.047826,0.885853,0.888213,0.021816,...,2.473285,esm2_t33_650M_UR50D,gpt2-xl,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,7,0.122363
1,224304.0,0.0,0.911543,0.380952,0.131052,0.206936,0.051657,0.872637,0.876828,0.015115,...,2.498868,esm2_t33_650M_UR50D,gpt2-xl,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,42,0.129882
12,89683.0,0.0,0.707736,0.494382,0.675456,0.276433,0.383909,0.904741,0.905538,0.315818,...,2.031271,esm2_t33_650M_UR50D,microsoft/phi-2,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,0,0.445266
16,156855.0,0.0,0.673837,0.565217,0.710948,0.342321,0.367271,0.886274,0.889479,0.980991,...,1.962734,esm2_t33_650M_UR50D,microsoft/phi-2,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,7,0.473513
8,179443.0,0.0,0.669927,0.571429,0.622256,0.275575,0.283083,0.882068,0.884491,0.670673,...,1.95587,esm2_t33_650M_UR50D,microsoft/phi-2,soft-prompt,[0],"[0.945, 0.005, 0.05]",1,qa,42,0.393638


In [4]:
# Leave every metric whose key mentions "rouge" out of the summary table.
metrics_names = {k: v for k, v in metrics_names.items() if 'rouge' not in k}

# Per-group statistics of the filtered runs; the code below reads the
# (metric, 'mean') and (metric, 'std') columns of the result.
agg_df = aggregate_metrics(df, group_by=var)

# Explicit copy: the column subset is otherwise a slice of the aggregate frame, and
# adding columns to it below would raise SettingWithCopyWarning.
agg_df = agg_df[list(metrics_names)].copy()

# Replace each metric's mean/std pair by one "mean (std)" text column, rounded to 2 decimals.
for col, name in metrics_names.items():
    mean_txt = agg_df[(col, 'mean')].round(2).astype(str)
    std_txt = agg_df[(col, 'std')].round(2).astype(str)
    agg_df[name] = mean_txt + " (" + std_txt + ")"
    agg_df = agg_df.drop(columns=[(col, 'mean'), (col, 'std')])

# Put the formatted table on the system clipboard.
agg_df.to_clipboard()

  return df.groupby(group_by).agg(['mean', 'std'])
