In [14]:
import pandas as pd 
from tqdm.auto import tqdm
from pathlib import Path
import shutil

pd.set_option('display.max_colwidth', 120)

In [2]:
path = Path("./raw_data")
out_path = Path("./data")
out_path.mkdir(exist_ok=True, parents=True)

In [8]:
meta = []
for p in tqdm(list(path.iterdir())):
    if "_" not in p.name:
        continue
    meta.append(
        {
            "id": p.name.split("_")[1].removesuffix(".tsv"),
            "min_steps": pd.read_csv(p, nrows=5, sep="\t")["step"].min(),
            "filename": str(p),
        }
    )

runs_df = pd.read_csv(path / "runs.tsv", sep="\t")
runs_df = pd.merge(runs_df, pd.DataFrame(meta), on="id", how="left")

order = ["filename", "group", "name", "id", "model_size", "seed", "min_steps", "max_steps"]
runs_df = (
    runs_df[order]
    .sort_values(["model_size", "seed", "min_steps", "max_steps"])
    .assign(total_steps=lambda _df: _df["max_steps"] - _df["min_steps"] + 1)
    .assign(selected=lambda _df: _df["total_steps"] == 143_000)
)
# runs_df.to_csv("./data/runs.tsv", index=False, sep="\t")

  0%|          | 0/143 [00:00<?, ?it/s]

## Good runs

Complete runs are directly moved to the `./data` folder

In [19]:
runs_df

Unnamed: 0,filename,group,name,id,model_size,seed,min_steps,max_steps,total_steps,selected
0,raw_data/14m-seed1_ggjp6fm7.tsv,pythia-14m_3kv6crxx,ip-26-0-149-246-0,ggjp6fm7,14m,1,1,63,63,False
1,raw_data/14m-seed1_jtsljaoj.tsv,pythia-14m_wi22ql21,ip-26-0-144-150-0,jtsljaoj,14m,1,1,143000,143000,True
2,raw_data/14m-seed2_q72jl2x5.tsv,pythia-14m_nwndy40r,ip-26-0-144-150-0,q72jl2x5,14m,2,1,143000,143000,True
3,raw_data/14m-seed3_khu6r64g.tsv,pythia-14m_rgh72080,ip-26-0-156-120-0,khu6r64g,14m,3,1,143000,143000,True
4,raw_data/14m-seed4_nb29awnu.tsv,pythia-14m_adskb1rh,ip-26-0-156-120-0,nb29awnu,14m,4,1,143000,143000,True
...,...,...,...,...,...,...,...,...,...,...
136,raw_data/70m-seed8_3qmbwmti.tsv,pythia-70m_f5agu3sy,ip-10-0-228-209-0,3qmbwmti,70m,8,43001,43300,300,False
132,raw_data/70m-seed8_9ugck98l.tsv,pythia-70m_nst0l0ku,ip-10-0-201-106-0,9ugck98l,70m,8,43001,143000,100000,False
141,raw_data/70m-seed9_qtv9ujq1.tsv,pythia-70m_qeyx8no0,ip-10-0-228-32-0,qtv9ujq1,70m,9,1,137755,137755,False
139,raw_data/70m-seed9_l2ui5xws.tsv,pythia-70m_cat910n6,ip-10-0-231-1-0,l2ui5xws,70m,9,1,143000,143000,True


In [28]:
names = []
for p in runs_df.query("selected == True")["filename"].tolist():
    p = Path(p)
    new_name = f"{p.name.split('_')[0]}.tsv"
    names.append(new_name)
    
    
    shutil.copy(p, out_path / new_name)

runs_df.query("selected == True").assign(name=names).to_csv(out_path / "metadata.tsv", index=False, sep="\t")

## Manual checks

In [29]:
to_check_df = (
    runs_df
    .groupby(["model_size", "seed"])["selected"].unique().map(max)
    .reset_index()
    .query("selected != True")
)
to_check_df

Unnamed: 0,model_size,seed,selected
9,160m,4,False
14,160m,9,False
15,410m,1,False
16,410m,2,False
17,410m,3,False
18,410m,4,False
19,410m,5,False
20,410m,6,False
21,410m,7,False
22,410m,8,False


In [30]:
check_df = (
    pd.merge(runs_df, to_check_df.drop(columns=["selected"]), on=["model_size", "seed"], how="inner")
    .assign(interval=lambda _df: _df["min_steps"].astype(str) + "-" + _df["max_steps"].astype(str))
)

In [31]:
check_df.groupby(["model_size", "seed"]).agg(
    total_steps=("total_steps", "sum"),
    intervals=("interval", "unique"),
)

Unnamed: 0_level_0,Unnamed: 1_level_0,total_steps,intervals
model_size,seed,Unnamed: 2_level_1,Unnamed: 3_level_1
160m,4,143335,"[1-61185, 61001-61150, 61001-143000]"
160m,9,144278,"[1-22623, 22001-22008, 22001-31647, 31001-143000]"
410m,1,143000,"[1-10000, 10001-29000, 29001-30000, 30001-32000, 32001-41000, 41001-105000, 105001-116000, 116001-143000]"
410m,2,143000,"[1-3000, 3001-4000, 4001-29000, 29001-50000, 50001-88000, 88001-143000]"
410m,3,143000,"[1-7000, 7001-15000, 15001-16000, 16001-17000, 17001-34000, 34001-43000, 43001-60000, 60001-73000, 73001-75000, 7500..."
410m,4,143890,"[1-7000, 7001-55841, 55001-104000, 104001-121000, 121001-122000, 122001-122049, 122001-143000]"
410m,5,147075,"[1-81715, 81001-83463, 83001-95416, 95001-95427, 95001-124741, 124001-127586, 127001-137727, 137001-143000]"
410m,6,145204,"[1-61003, 61001-63474, 63001-75835, 75001-104892, 104001-143000]"
410m,7,162662,"[1-399, 1-18000, 257-381, 257-44907, 44001-53341, 53001-53038, 53001-58108, 58001-90000, 90001-143000]"
410m,8,145667,"[1-666, 513-1666, 1001-2252, 2001-23026, 23001-29392, 29001-29376, 29001-39344, 39001-74999, 75001-75078, 75001-9038..."


### 160m

For both seed 4 and 9 we simply concatenate and deduplicate all the data

In [91]:
meta = [pd.read_csv(out_path / "metadata.tsv", sep="\t")]

for seed in (4, 9):
    meta_df = check_df.query(f"(seed == {seed}) & (model_size == '160m')")
    meta.append(meta_df)

    data = pd.concat([pd.read_csv(p, sep="\t") for p in meta_df["filename"]], axis=0, ignore_index=False).assign(filename=p)
    # data = data.drop_duplicates()
    data.to_csv(out_path / f"160m-seed{seed}.tsv", index=False, sep="\t")
    break

pd.concat(meta, axis=0, ignore_index=False).drop_duplicates().to_csv(out_path / "metadata.tsv", index=False, sep="\t")


In [98]:
d = data.groupby("step")["train/lm_loss"].unique()

In [101]:
d[d.map(len) > 1].head(10)

step
61002    [2.453532695770264, 2.4644408226013184]
61003     [2.520005702972412, 2.544607639312744]
61004    [2.4462103843688965, 2.470116138458252]
61005     [2.474903345108032, 2.504105567932129]
61006    [2.506353616714477, 2.5264477729797363]
61007    [2.488455295562744, 2.5084657669067383]
61008    [2.4787654876708984, 2.495170831680298]
61009     [2.501901149749756, 2.515127897262573]
61010    [2.4980216026306152, 2.508857727050781]
61011    [2.4755306243896484, 2.484673023223877]
Name: train/lm_loss, dtype: object

In [77]:
d = data.query("(step >= 61002) & (step <= 61186)").reset_index(drop=True)

In [69]:
d.groupby("train/lm_loss")["filename"].nunique().value_counts()

filename
1    370
Name: count, dtype: int64

In [49]:
data

Unnamed: 0,train/lm_loss,step
0,10.973518,1
1,10.974413,2
2,10.973560,3
3,10.952214,4
4,10.899228,5
...,...,...
111995,2.475273,142996
111996,2.492165,142997
111997,2.454960,142998
111998,2.468572,142999
