In [1]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")
import logging

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap, LogNorm, Normalize

from xlstm_scaling_laws.analysis.isoflop.data import create_filtered_isoflop_data_table
from xlstm_scaling_laws.analysis.parametric_sclaw_fit.data import (
    get_all_parametric_sclaw_fit_data_dataframe,
)
from xlstm_scaling_laws.analysis.parametric_sclaw_fit.plot.plot_model_training_data import (
    create_run_data_scatter_plot,
    get_combined_run_data_scatter_plot,
)

# Configure the root logger: only ERROR and above are emitted, each record
# prefixed with a timestamp, its level and the logger name.
# force=True replaces any handlers already attached to the root logger, so
# re-running this cell re-applies the configuration.
logging.basicConfig(
    level=logging.ERROR,
    force=True,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

In [2]:
# Load the filtered isoflop table for the "isoflop_mlstm_ctx8192" data specifier.
mlstm_df = create_filtered_isoflop_data_table(
    data_specifier="isoflop_mlstm_ctx8192",
)
# Show the identifying columns only, ordered by run_id, as a styled table.
(
    mlstm_df
    .sort_values(by="run_id")
    .loc[:, ["name", "run_id", "run_tag"]]
    .style
)

Unnamed: 0,name,run_id,run_tag
147,dclm_mLSTMv1_700M_ctx8192_lr0.002_steps22000_nb27_ed1408_nh11_pf2.667_gbs128,02y5ghni,"nb27_ed1408_nh11_pf2.667,sclaw_iso"
15,dclm_mLSTMv1_160M_ctx8192_lr0.003_steps126500_nb12_ed768_nh6_pf2.667_gbs128,0cf4r4r9,"nb12_ed768_nh6_pf2.667,sclaw_iso"
115,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps6000_nb12_ed896_nh7_pf2.667_gbs128,0rb1tbu7,"nb12_ed896_nh7_pf2.667,sclaw_iso"
43,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps24000_nb30_ed1280_nh5_pf2.667_gbs128,1gcpfr5u,"nb30_ed1280_nh5_pf2.667,sclaw_iso"
39,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps8000_nb27_ed1280_nh5_pf2.667_gbs128,1m8vuw6n,"nb27_ed1280_nh5_pf2.667,sclaw_iso"
14,dclm_mLSTMv1_160M_ctx8192_lr0.003_steps38000_nb12_ed768_nh6_pf2.667_gbs128,1md27kz0,"nb12_ed768_nh6_pf2.667,sclaw_iso"
38,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps2500_nb27_ed1280_nh5_pf2.667_gbs128,1nlesen4,"nb27_ed1280_nh5_pf2.667,sclaw_iso"
137,dclm_mLSTMv1_500M_ctx8192_lr0.002_steps35500_nb24_ed1152_nh9_pf2.667_gbs128,1w6f9yi7,"nb24_ed1152_nh9_pf2.667,sclaw_iso"
159,dclm_mLSTMv1_400M_ctx8192_lr0.003_steps5500_nb18_ed1024_nh4_pf2.667_gbs128,1zh7d80c,"nb18_ed1024_nh4_pf2.667,sclaw_iso"
133,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps17000_nb24_ed896_nh7_pf2.667_gbs128,1zvxf628,"nb24_ed896_nh7_pf2.667,sclaw_iso"


In [3]:
# Drop rows that are identical in every column except run_tag. Rows are
# sorted by run_tag first, so within each duplicate group the row that sorts
# first by run_tag is the one that is kept.
dedup_mlstm_df = (
    mlstm_df
    .sort_values(by="run_tag")
    .loc[lambda frame: ~frame.drop(columns="run_tag").duplicated(keep="first")]
)
# Same view as for the raw table: identifying columns ordered by run_id.
(
    dedup_mlstm_df
    .sort_values(by="run_id")
    .loc[:, ["name", "run_id", "run_tag"]]
    .style
)

Unnamed: 0,name,run_id,run_tag
147,dclm_mLSTMv1_700M_ctx8192_lr0.002_steps22000_nb27_ed1408_nh11_pf2.667_gbs128,02y5ghni,"nb27_ed1408_nh11_pf2.667,sclaw_iso"
15,dclm_mLSTMv1_160M_ctx8192_lr0.003_steps126500_nb12_ed768_nh6_pf2.667_gbs128,0cf4r4r9,"nb12_ed768_nh6_pf2.667,sclaw_iso"
115,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps6000_nb12_ed896_nh7_pf2.667_gbs128,0rb1tbu7,"nb12_ed896_nh7_pf2.667,sclaw_iso"
43,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps24000_nb30_ed1280_nh5_pf2.667_gbs128,1gcpfr5u,"nb30_ed1280_nh5_pf2.667,sclaw_iso"
39,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps8000_nb27_ed1280_nh5_pf2.667_gbs128,1m8vuw6n,"nb27_ed1280_nh5_pf2.667,sclaw_iso"
14,dclm_mLSTMv1_160M_ctx8192_lr0.003_steps38000_nb12_ed768_nh6_pf2.667_gbs128,1md27kz0,"nb12_ed768_nh6_pf2.667,sclaw_iso"
38,dclm_mLSTMv1_600M_ctx8192_lr0.002_steps2500_nb27_ed1280_nh5_pf2.667_gbs128,1nlesen4,"nb27_ed1280_nh5_pf2.667,sclaw_iso"
137,dclm_mLSTMv1_500M_ctx8192_lr0.002_steps35500_nb24_ed1152_nh9_pf2.667_gbs128,1w6f9yi7,"nb24_ed1152_nh9_pf2.667,sclaw_iso"
159,dclm_mLSTMv1_400M_ctx8192_lr0.003_steps5500_nb18_ed1024_nh4_pf2.667_gbs128,1zh7d80c,"nb18_ed1024_nh4_pf2.667,sclaw_iso"
133,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps17000_nb24_ed896_nh7_pf2.667_gbs128,1zvxf628,"nb24_ed896_nh7_pf2.667,sclaw_iso"


In [4]:
# Row counts before and after de-duplication; equal values mean that no
# rows were dropped.
tuple(len(frame) for frame in (mlstm_df, dedup_mlstm_df))

(145, 145)

In [5]:
# For each model type, report how many rows belong to every distinct
# experiment_set_ctx_length value and how many distinct run_ids those rows
# contain, then print the sum of the per-value row counts so it can be
# compared against the overall row count printed first.
# NOTE(review): the printed label says "unique names", but the number is the
# count of distinct run_id values.
for model_type in ("all", "mlstm", "llama"):
    df = get_all_parametric_sclaw_fit_data_dataframe(model_type=model_type)
    print(f"Model type: {model_type}, rows: {len(df)}")

    total = 0
    for exptag in df["experiment_set_ctx_length"].unique():
        # Select the rows of this experiment set once and reuse the subset.
        exptag_rows = df[df["experiment_set_ctx_length"] == exptag]
        count = len(exptag_rows)
        total += count
        num_unique_run_ids = len(exptag_rows["run_id"].unique())
        print(f"{exptag}: {num_unique_run_ids} unique names, total: {count}")

    print(f"Total rows in df: {total}")

Model type: all, rows: 640
isoflop_ctx2048: 171 unique names, total: 171
isoflop_ctx8192: 252 unique names, total: 252
isoflop_ctx16384: 151 unique names, total: 151
tokenparam_ctx8192: 66 unique names, total: 66
Total rows in df: 640
Model type: mlstm, rows: 348
isoflop_ctx2048: 87 unique names, total: 87
isoflop_ctx8192: 145 unique names, total: 145
isoflop_ctx16384: 81 unique names, total: 81
tokenparam_ctx8192: 35 unique names, total: 35
Total rows in df: 348
Model type: llama, rows: 292
isoflop_ctx2048: 84 unique names, total: 84
isoflop_ctx8192: 107 unique names, total: 107
isoflop_ctx16384: 70 unique names, total: 70
tokenparam_ctx8192: 31 unique names, total: 31
Total rows in df: 292
