# Check the status of uploaded experiments

## Prerequisites

In [1]:
import datetime
import importlib
import re
from tqdm import tqdm

import plotnine as p9

import pandas as pd

from utils import add_src_to_sys_path

add_src_to_sys_path()

from common import wandb_utils, nest

wandb_utils = importlib.reload(wandb_utils)

KeyboardInterrupt: 

## Load the experiments

In [None]:
wandb_api = wandb_utils.get_wandb_api()

In [None]:
runs = wandb_api.runs(
    f"{wandb_utils.get_entity_name()}/{wandb_utils.get_project_name()}",
    filters={
        "tags": {
            "$in": ["attention_analysis"]
        }
    }
)

runs = list(runs)
len(runs)

In [5]:
manually_checked_to_be_complete = {}

In [12]:
def get_model_name(group_str: str) -> str:
    pe_names = ["pe_none", "pe_t5", "pe_abs_sin", "pe_alibi", "pe_rotary", "pe_newRot"]
    for pe in pe_names:
        if pe in group_str:
            return pe
    raise ValueError("Invalid pe")

def get_launcher(run) -> str:
    if run.job_type == "agent":
        return run.id

    launcher_tag = [t for t in run.tags if t.startswith("launched_by_")][0]
    return launcher_tag.split("launched_by_")[1]


def check_runs_generated_by_launcher_id(launcher_id: str) -> None:
    l_runs = wandb_api.runs(
        f"{wandb_utils.get_entity_name()}/{wandb_utils.get_project_name()}",
        filters={
            "tags": {
                "$in": ["launched_by_" + launcher_id]
            }
        }
    )

    l_runs = list(l_runs)

    for run in l_runs:
        if run.job_type == "agent":
            continue
        print(run.id, run.state, run.job_type, run.group, run.tags)
        is_complete = is_run_complete(run)
        print(is_complete)
        print()

dataset_to_num_metadata_files = {
    ("s2s_addition", "len_tr8_ts16"): 256,
    ('s2s_copy', 'cmc2x_tr20_ts40'): 39,
    ('s2s_copy', 'rsc2x_tr20_ts40'): 39,
    ('s2s_copy', 'cmc_tr20_ts40'): 39,
    ('s2s_copy', 'rdc_tr20_ts40'): 39,
    ('s2s_copy', 'rsc_tr20_ts40'): 39,
    ('s2s_reverse', 'mc_tr20_ts40'): 39,
    ('s2s_reverse', 'mc2x_tr20_ts40'): 39,
    ('s2s_reverse', 'mcrv_tr20_ts40'): 39,
    ('pcfg', 'md_productivity'): 32,
    ('scan', 'len_tr25_ts48'): 32,
}

def is_run_complete(run) -> bool:
    if run.id in manually_checked_to_be_complete:
        return True

    # Assert training is done
    try:
        ds = run.config["dataset"]["name"]
        split = run.config["dataset"]["split"]
    except Exception as e:
        print("Could not get max_steps")
        return False

    # Check if the best checkpoint was loaded:
    analyze_all_test_ckpt_path = run.summary["analyze_all_test_ckpt_path"]
    if analyze_all_test_ckpt_path is None:
        print("No analyze_all_test_ckpt_path")
        return False

    if "experiments/" not in analyze_all_test_ckpt_path.lower():
        print("analyze_all_test_ckpt_path is none")
        return False

    files = list(run.files())
    metadata_files = [
        f for f in files if f.name.startswith("attn_metadata_") and f.name.endswith(".json")
    ]
    # print(len(metadata_files))
    if len(metadata_files) != dataset_to_num_metadata_files[(ds, split)]:
        print(f"Not enough metadata files: {len(metadata_files)}")
        return False

    return True

df_data = []

scratchpad_config_pattern = re.compile(r"(.)*_scratchpad(.)+_ufs__(i._c._o._v._r.)_.*___.*")
dataset_name_pattern = re.compile(f".*___data-(.+)-(.+)")

for run in tqdm(runs):
    group = run.group

    scratchpad_config = "no_scratchpad"
    result = scratchpad_config_pattern.search(group)
    if result:
        scratchpad_config = result.group(3)

    result = dataset_name_pattern.search(group)
    dataset_name = result.group(1)
    dataset_split = result.group(2)
    dataset_comb = f"{dataset_name}--{dataset_split}"

    model_name = get_model_name(group)

    if run.job_type == "agent":
        is_complete = True
    elif run.job_type != "attn_analysis2":
        continue
    else:
        is_complete = is_run_complete(run)

    df_data.append({
        "run_group": group,
        "job_type": run.job_type,
        "launcher_id": get_launcher(run),
        "ds_name": dataset_name,
        "ds_split": dataset_split,
        "ds": dataset_comb,
        "scratchpad_config": scratchpad_config,
        "model": model_name,
        "is_complete": is_complete,
        "state": run.state,
        "id": run.id,
        "gr_url": f"https://wandb.ai/kzmnjd/len_gen/groups/{group}",
        "run_url": run.url,
        "created_at": run.created_at,
        "host": run.host,
    })

df = pd.DataFrame.from_records(df_data)
# df

 86%|████████▌ | 334/388 [01:43<01:12,  1.34s/it]

analyze_all_test_ckpt_path is none
analyze_all_test_ckpt_path is none


 90%|████████▉ | 348/388 [01:58<00:52,  1.31s/it]

analyze_all_test_ckpt_path is none


 91%|█████████ | 352/388 [02:03<00:43,  1.22s/it]

analyze_all_test_ckpt_path is none
analyze_all_test_ckpt_path is none


 91%|█████████▏| 355/388 [02:04<00:27,  1.19it/s]

analyze_all_test_ckpt_path is none


 94%|█████████▍| 364/388 [02:15<00:29,  1.22s/it]

analyze_all_test_ckpt_path is none


 94%|█████████▍| 366/388 [02:16<00:21,  1.02it/s]

analyze_all_test_ckpt_path is none


 96%|█████████▌| 371/388 [02:22<00:19,  1.15s/it]

analyze_all_test_ckpt_path is none
analyze_all_test_ckpt_path is none
analyze_all_test_ckpt_path is none
analyze_all_test_ckpt_path is none


100%|██████████| 388/388 [02:40<00:00,  2.42it/s]


In [17]:
def get_compute_cluster(host: str):
    if "cedar" in host:
        return "cc_cedar"
    elif "narval" in host:
        return "cc_narval"
    elif host.startswith("cn-"):
        return "mila"
    else:
        return host

def get_grouped_df(gdf):
    seed_runs = gdf[gdf.job_type == "attn_analysis2"]
    completed = len(seed_runs[seed_runs.is_complete == True])
    is_running = "running" in gdf.state.unique().tolist()
    group_url = gdf.gr_url.tolist()[0]
    host = gdf.host.tolist()[0]
    launcher_ids = gdf[["launcher_id", "created_at"]]
    launcher_ids = [tuple(x) for x in launcher_ids.values]
    launcher_ids.sort(key=lambda x: datetime.datetime.fromisoformat(x[1]), reverse=True)
    launcher_id = launcher_ids[0][0]

    return pd.DataFrame.from_records([{
        "num_completed": completed,
        "num_seed_runs": len(seed_runs),
        "is_running": is_running,
        "is_done": completed >= 3,
        "launcher_id": launcher_id,
        "group_url": group_url,
        "cluster": get_compute_cluster(host),
    }])


In [18]:
xdf = df.groupby(["ds", "model", "scratchpad_config"]).apply(get_grouped_df)
xdf

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,num_completed,num_seed_runs,is_running,is_done,launcher_id,group_url,cluster
ds,model,scratchpad_config,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
pcfg--md_productivity,pe_abs_sin,no_scratchpad,0,3,3,False,True,tmow5sw8,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
pcfg--md_productivity,pe_alibi,no_scratchpad,0,3,3,False,True,hgdqbeme,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
pcfg--md_productivity,pe_none,no_scratchpad,0,3,3,False,True,10fmychn,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
pcfg--md_productivity,pe_rotary,no_scratchpad,0,3,3,False,True,20va5o6z,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
pcfg--md_productivity,pe_t5,no_scratchpad,0,3,3,False,True,3q091sft,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
...,...,...,...,...,...,...,...,...,...,...
scan--len_tr25_ts48,pe_abs_sin,no_scratchpad,0,3,3,False,True,2h78fggj,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_cedar
scan--len_tr25_ts48,pe_alibi,no_scratchpad,0,3,3,False,True,2f1ugo2p,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_cedar
scan--len_tr25_ts48,pe_none,no_scratchpad,0,3,3,False,True,1c6wqxht,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_cedar
scan--len_tr25_ts48,pe_rotary,no_scratchpad,0,3,3,False,True,3ubk5nqu,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_cedar


In [19]:
xdf[(xdf["is_done"] == False) & (xdf["is_running"] == False)]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,num_completed,num_seed_runs,is_running,is_done,launcher_id,group_url,cluster
ds,model,scratchpad_config,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
s2s_addition--len_tr8_ts16,pe_abs_sin,i0_c1_o1_v1_r1,0,0,3,False,False,dww6ak56,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
s2s_addition--len_tr8_ts16,pe_abs_sin,i1_c0_o1_v1_r1,0,0,3,False,False,2zaci7fu,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
s2s_addition--len_tr8_ts16,pe_abs_sin,i1_c1_o1_v0_r1,0,0,3,False,False,18hplwyn,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval
s2s_addition--len_tr8_ts16,pe_abs_sin,i1_c1_o1_v1_r0,0,0,3,False,False,x5ui27ay,https://wandb.ai/kzmnjd/len_gen/groups/SW-t5_d...,cc_narval


In [36]:
check_runs_generated_by_launcher_id("2nxodkyt")

sw_7ab136dc4d07a4313a1f8bf646f4808c_2_256788 crashed attn_analysis2 SW-t5_dec_base_pe_t5_w_scratchpad_f_ufs__i1_c1_o1_v0_r1_s2s_addition_sweep___data-s2s_addition-len_tr8_ts16 ['attention_analysis', 'launched_by_2nxodkyt', 'manual_sweep', 'scratchpad', 'scratchpad_f', 'scratchpad_s2s_addition', 'sweep']
Not enough metadata files: 119
False



In [None]:
xxdf = xdf.reset_index()
xxdf[(xxdf["is_done"] == False) & (xxdf["ds"] == "s2s_addition--len_tr8_ts16")]

In [None]:
xdf[(xdf["is_done"] == False) & (xdf["is_running"] == False) & (xdf["ds"] == "s2s_addition--len_tr8_ts16")]

In [None]:
xdf.columns

## Fix Runs without summary

In [10]:
# completed_runs = set(df[(df["is_complete"] == True) & (df["job_type"] == "best_run_seed_exp")]["id"].tolist())

completed_runs = set(df[(df["job_type"] == "best_run_seed_exp")]["id"].tolist())


In [11]:
c = 0
for run in tqdm(runs):
    # if run.id not in completed_runs:
    #     continue
    if run.job_type == "agent":
        continue

    # keys_not_present = []
    # for k in target_summary_keys:
    #     if k not in run.summary:
    #         keys_not_present.append(k)
    if "pred/test_acc_overall" in run.summary:
        continue

    print(run.group)
    print(run.id)
    print(run.url)
    print("------------------")

    c += 1
    # continue

    h_all = list(run.scan_history())

    s = {}
    for h in h_all:
        for k,v in h.items():
            if k.startswith("pred/"):
                s[k] = v

    run.summary.update(s)
    run.save()

c

  0%|          | 0/684 [00:00<?, ?it/s]

SW-t5_dec_base_pe_t5_w_scratchpad_f_ufs__i1_c1_o1_v1_r1_s2s_poly_sweep___data-s2s_poly-n_terms_tr8_ts16
sw_a62383120826ac1dfd973fd78242ae47__234054
https://wandb.ai/kzmnjd/len_gen/runs/sw_a62383120826ac1dfd973fd78242ae47__234054
------------------


  2%|▏         | 11/684 [00:01<01:13,  9.18it/s]

SW-t5_dec_base_pe_rotary_w_scratchpad_f_ufs__i1_c1_o0_v0_r1_s2s_sort_sweep___data-s2s_sort-len_mltd_tr8_ts16
sw_a79eb30b88160539e41e9c6fcc3edd63__146317
https://wandb.ai/kzmnjd/len_gen/runs/sw_a79eb30b88160539e41e9c6fcc3edd63__146317
------------------


  5%|▌         | 36/684 [00:03<01:01, 10.48it/s]

SW-t5_dec_base_pe_rotary_w_scratchpad_f_ufs__i0_c1_o1_v0_r1_s2s_sort_sweep___data-s2s_sort-len_mltd_tr8_ts16
sw_a36b16200f2194a71894e7f64f30cf65__146317
https://wandb.ai/kzmnjd/len_gen/runs/sw_a36b16200f2194a71894e7f64f30cf65__146317
------------------


  6%|▌         | 41/684 [00:05<01:51,  5.74it/s]

SW-t5_dec_base_pe_rotary_w_scratchpad_f_ufs__i1_c1_o1_v0_r0_s2s_sort_sweep___data-s2s_sort-len_mltd_tr8_ts16
sw_eeeb4be8a6b915880903e86e11cc8846__146317
https://wandb.ai/kzmnjd/len_gen/runs/sw_eeeb4be8a6b915880903e86e11cc8846__146317
------------------


100%|██████████| 684/684 [00:06<00:00, 99.22it/s]


4