In [1]:
import pandas as pd 
import wandb
from multiprocessing import Pool
from os import cpu_count
from tqdm.auto import tqdm
import time
from pathlib import Path


In [2]:
# Directory holding the raw inputs used below: `runs.tsv` and the other files found via `path.iterdir()`.
path = Path("./raw_data/")

In [68]:
# Build one metadata record per raw file in `path`.
# Files whose name has no underscore are skipped; the run id is the second
# "_"-separated piece of the file name with a trailing ".tsv" removed.
meta = []
for p in tqdm(list(path.iterdir())):
    if "_" in p.name:
        meta.append(
            dict(
                id=p.name.split("_")[1].removesuffix(".tsv"),
                # Only the first 5 rows are read, so this is the minimum of those rows.
                # NOTE(review): presumably each file is ordered by step - confirm.
                min_steps=pd.read_csv(p, sep="\t", nrows=5)["step"].min(),
                filename=str(p),
            )
        )



  0%|          | 0/140 [00:00<?, ?it/s]

In [69]:
# Load the tab-separated run table that sits next to the raw files.
runs_df = pd.read_csv(path.joinpath("runs.tsv"), delimiter="\t")

In [70]:
# Left-join the per-file metadata onto the run table by run id;
# runs without a matching file keep their row and get missing values.
runs_df = runs_df.merge(pd.DataFrame(meta), how="left", on="id")

In [71]:
# Columns kept for the summary, in the order they should appear.
order = ["group", "name", "id", "model_size", "seed", "min_steps", "max_steps"]
runs_df = (
    runs_df.loc[:, order]
    .sort_values(by=["model_size", "seed", "min_steps", "max_steps"])
    .assign(
        # Steps covered by the run, counting both endpoints.
        total_steps=lambda frame: frame["max_steps"] - frame["min_steps"] + 1,
        # A run is selected when it covers exactly 143_000 steps.
        selected=lambda frame: frame["total_steps"] == 143_000,
    )
)

In [72]:
# Persist the annotated run table as TSV (no index column).
# `to_csv` does not create missing directories, so make sure `./data` exists
# first; otherwise a fresh checkout fails here with an OSError.
Path("./data").mkdir(parents=True, exist_ok=True)
runs_df.to_csv("./data/runs.tsv", index=False, sep="\t")

In [73]:
# List every (model_size, seed) pair that has no run with `selected == True`.
check_df = (
    runs_df
    # `.any()` is the built-in "at least one True per group" reduction and keeps
    # a real bool dtype, unlike `.unique().map(max)` which yields an object column.
    .groupby(["model_size", "seed"])["selected"].any()
    .reset_index()
    # Keep only the groups without a selected run; the index labels produced by
    # reset_index are preserved, so the displayed row labels are unchanged.
    .loc[lambda _df: ~_df["selected"]]
)
check_df

Unnamed: 0,model_size,seed,selected
9,160m,4,False
14,160m,9,False
15,410m,1,False
16,410m,2,False
17,410m,3,False
18,410m,4,False
19,410m,5,False
20,410m,6,False
21,410m,7,False
22,410m,8,False


In [77]:
# Restrict the run table to the (model_size, seed) pairs listed in `check_df`.
# `selected` is dropped from the right side so it does not clash with the
# column of the same name in `runs_df`.
r = runs_df.merge(
    check_df.drop(columns=["selected"]),
    how="inner",
    on=["model_size", "seed"],
)

In [89]:
# Text label "<min_steps>-<max_steps>" for each run's step range.
r["interval"] = r["min_steps"].astype(str).add("-").add(r["max_steps"].astype(str))

In [97]:
# Let wide cells (the interval lists shown below) display up to 120 characters.
pd.options.display.max_colwidth = 120

In [98]:
# Per (model_size, seed): the summed step count over its runs and the distinct
# step intervals those runs cover.
r.groupby(["model_size", "seed"]).agg(
    total_steps=pd.NamedAgg(column="total_steps", aggfunc="sum"),
    intervals=pd.NamedAgg(column="interval", aggfunc="unique"),
)

Unnamed: 0_level_0,Unnamed: 1_level_0,total_steps,intervals
model_size,seed,Unnamed: 2_level_1,Unnamed: 3_level_1
160m,4,143335,"[1-61185, 61001-61150, 61001-143000]"
160m,9,144278,"[1-22623, 22001-22008, 22001-31647, 31001-143000]"
410m,1,143000,"[1-10000, 10001-29000, 29001-30000, 30001-32000, 32001-41000, 41001-105000, 105001-116000, 116001-143000]"
410m,2,143000,"[1-3000, 3001-4000, 4001-29000, 29001-50000, 50001-88000, 88001-143000]"
410m,3,143000,"[1-7000, 7001-15000, 15001-16000, 16001-17000, 17001-34000, 34001-43000, 43001-60000, 60001-73000, 73001-75000, 7500..."
410m,4,143890,"[1-7000, 7001-55841, 55001-104000, 104001-121000, 121001-122000, 122001-122049, 122001-143000]"
410m,5,147075,"[1-81715, 81001-83463, 83001-95416, 95001-95427, 95001-124741, 124001-127586, 127001-137727, 137001-143000]"
410m,6,145204,"[1-61003, 61001-63474, 63001-75835, 75001-104892, 104001-143000]"
410m,7,77662,"[1-399, 1-18000, 257-381, 257-44907, 44001-53341, 53001-53038, 53001-58108]"
410m,8,145667,"[1-666, 513-1666, 1001-2252, 2001-23026, 23001-29392, 29001-29376, 29001-39344, 39001-74999, 75001-75078, 75001-9038..."
