In [1]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")
import logging

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap, LogNorm, Normalize

from xlstm_scaling_laws.analysis.parametric_sclaw_fit.data import (
    get_all_parametric_sclaw_fit_data_dataframe,
)
from xlstm_scaling_laws.analysis.parametric_sclaw_fit.plot.plot_model_training_data import (
    create_run_data_scatter_plot,
    get_combined_run_data_scatter_plot,
)
from xlstm_scaling_laws.load_data.token_param_ratio import (
    create_token_param_ratio_data_table,
)

logging.basicConfig(
    level=logging.ERROR,
    format="%(levelname)s: %(message)s",
    force=True,
)


def add_row_colors(latex_str):
    lines = latex_str.split("\n")
    new_lines = []
    in_tabular = False
    row_count = 0

    for line in lines:
        if "\\begin{tabular}" in line:
            in_tabular = True
            new_lines.append(line)
        elif "\\end{tabular}" in line:
            in_tabular = False
            new_lines.append(line)
        elif in_tabular and "\\\\" in line and not line.strip().startswith("\\"):
            if row_count % 2 == 1:
                new_lines.append("\\rowcolor{gray!10}" + line)
            else:
                new_lines.append(line)
            row_count += 1
        else:
            new_lines.append(line)

    return "\n".join(new_lines)


def add_adjustbox_scaling(latex_str, height_scale=0.9):
    """Add adjustbox scaling to a LaTeX table"""
    lines = latex_str.split("\n")
    new_lines = []

    for i, line in enumerate(lines):
        if "\\begin{tabular}" in line:
            new_lines.append(
                f"\\begin{{adjustbox}}{{max height={height_scale}\\textheight,center}}"
            )
            new_lines.append(line)
        elif "\\end{tabular}" in line:
            new_lines.append(line)
            new_lines.append("\\end{adjustbox}")
        else:
            new_lines.append(line)

    return "\n".join(new_lines)

# Create Run Dataset Model configuration Tables

We want to have the following columns in the table:

- Parameters (million)
- Architecture hyperparams
    - embedding dim
    - v_head dim
    - qk_head dim (only for xLSTM)
    - n heads
    - ffw dim
    - num blocks
- Optim parameters
    - ctx length
    - global batch size
    - learning rate

In [2]:
df = get_all_parametric_sclaw_fit_data_dataframe(model_type="all")
df[(df["experiment_set"] == "tokenparam") & (df["model_type"] == "llama")]

Unnamed: 0,experiment_set_ctx_length,name,run_tag,model_type,num_params,num_tokens_training,num_flops_training,val/.dclm_loss,token_param_ratio,width_depth_ratio,...,num_heads,proj_factor_ffn,ffn_multiple_of,ffn_dim,head_dim_qk,head_dim_v,IsoFLOP,train/.loss_mean,run_id,model_checkpoint_paths
609,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps3500_gbs128,scl_llama_160M,llama,162220800.0,3670016000.0,4.416455e+18,3.298485,22.623585,64.0,...,12,2.667,64,2048,,64,,3.30066,o2y7xnfn,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
610,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps5000_gbs128,scl_llama_160M,llama,162220800.0,5242880000.0,6.309221e+18,3.220001,32.319407,64.0,...,12,2.667,64,2048,,64,,3.217189,8b27x4t0,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
611,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps7000_gbs128,scl_llama_160M,llama,162220800.0,7340032000.0,8.832909e+18,3.162786,45.247169,64.0,...,12,2.667,64,2048,,64,,3.161473,4ce0r9qc,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
612,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps8000_gbs128,scl_llama_160M,llama,162220800.0,8388608000.0,1.009475e+19,3.143749,51.711051,64.0,...,12,2.667,64,2048,,64,,3.142752,x987at37,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
613,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps18000_gbs128,scl_llama_160M,llama,162220800.0,18874370000.0,2.27132e+19,3.050733,116.349864,64.0,...,12,2.667,64,2048,,64,,3.051504,1mmzna50,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
614,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps36000_gbs128,scl_llama_160M,llama,162220800.0,37748740000.0,4.542639e+19,2.995063,232.699728,64.0,...,12,2.667,64,2048,,64,,2.995211,p2nsobw9,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
615,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.003_steps87000_gbs128,scl_llama_160M,llama,162220800.0,91226110000.0,1.097804e+20,2.946427,562.357675,64.0,...,12,2.667,64,2048,,64,,2.946148,prhdeg55,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
616,tokenparam_ctx8192,dclm_llama_160M_ctx8192_lr0.001_steps173000_gb...,scl_llama_160M,llama,162220800.0,181403600000.0,2.18299e+20,2.933174,1118.25147,64.0,...,12,2.667,64,2048,,64,,2.934087,6i73refq,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
617,tokenparam_ctx8192,dclm_llama_400M_ctx8192_lr0.003_steps10000_gbs128,scl_llama_400M,llama,406635500.0,10485760000.0,3.525725e+19,2.96188,25.786631,42.666667,...,16,2.667,64,2752,,64,,2.966652,22m4wtpw,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."
618,tokenparam_ctx8192,dclm_llama_400M_ctx8192_lr0.003_steps18000_gbs128,scl_llama_400M,llama,406635500.0,18874370000.0,6.346304e+19,2.852509,46.415935,42.666667,...,16,2.667,64,2752,,64,,2.85314,299iz8at,"[""/nfs-gpu/xlstm/outputs_beck/sclaw/dclm_llama..."


In [3]:
# Get min FLOPs, max FLOPs, min train tokens, max train tokens, min params, max params
df = get_all_parametric_sclaw_fit_data_dataframe(model_type="all")
print(df["num_flops_training"].sum())
df["num_flops_training"].describe()

3.1347372475364666e+23


count    6.400000e+02
mean     4.898027e+20
std      3.935773e+21
min      2.809833e+18
25%      9.981188e+18
50%      2.992133e+19
75%      3.064443e+19
max      8.480968e+22
Name: num_flops_training, dtype: float64

In [4]:
df["num_tokens_training"].describe()

count    6.400000e+02
mean     3.852149e+10
std      1.374194e+11
min      1.887437e+09
25%      4.404019e+09
50%      8.808038e+09
75%      2.013266e+10
max      2.097152e+12
Name: num_tokens_training, dtype: float64

In [5]:
df["num_params"].describe()

count    6.400000e+02
mean     7.374013e+08
std      1.050585e+09
min      8.363469e+07
25%      2.046973e+08
50%      4.066355e+08
75%      8.340864e+08
max      6.867523e+09
Name: num_params, dtype: float64

## Model Configuration Tables

In [6]:
def get_experiment_set_df(exp_set: str | list[str], model_type: str) -> pd.DataFrame:
    mlstm_df = get_all_parametric_sclaw_fit_data_dataframe(model_type=model_type)
    if model_type == "mlstm":
        sel_cols = [
            "num_params",
            "embedding_dim",
            "ffn_dim",
            "head_dim_qk",
            "head_dim_v",
            "num_heads",
            "num_blocks",
            # "context_length",
            # "global_batch_size",
            # "learning_rate",
        ]
    elif model_type == "llama":
        sel_cols = [
            "num_params",
            "embedding_dim",
            "ffn_dim",
            "head_dim_v",
            "num_heads",
            "num_blocks",
            # "context_length",
            # "global_batch_size",
            # "learning_rate",
        ]
    if "tokenparam" in exp_set:
        sel_cols += ["global_batch_size", "learning_rate"]
    if isinstance(exp_set, str):
        exp_set = [exp_set]
    exp_set_df = (
        mlstm_df[mlstm_df["experiment_set_ctx_length"].isin(exp_set)][sel_cols]
        .drop_duplicates()
        .sort_values(by=["num_params"])
    )
    if "head_dim_qk" in sel_cols:
        exp_set_df["head_dim_qk"] = exp_set_df["head_dim_qk"].astype(int)

    if "global_batch_size" in sel_cols:
        exp_set_df["global_batch_size"] = exp_set_df["global_batch_size"].astype(int)

    # convert num_params in millions
    exp_set_df["num_params"] = (exp_set_df["num_params"] / 1e6).astype(int)
    exp_set_df = exp_set_df.rename(columns={"num_params": "num_params (M)"})
    exp_set_df = exp_set_df.reset_index(drop=True)

    # add a \ before each _ in column names for latex
    # exp_set_df.columns = [col.replace("_", "\\_") for col in exp_set_df.columns]

    # prettify column names
    if model_type == "mlstm":
        col_name_map = {
            "num_params (M)": "\#Params (M)",
            "embedding_dim": r"$d_{\text{model}}$",
            "ffn_dim": r"$d_{\text{ff}}$",
            "head_dim_qk": r"$d_{\text{qk}}$",
            "head_dim_v": r"$d_{\text{hv}}$",
            "num_heads": r"$n_{\text{heads}}$",
            "num_blocks": r"$n_{\text{layer}}$",
            "context_length": r"$T$ (ctx)",
            "global_batch_size": r"$B$ (batch)",
            "learning_rate": "LR",
        }
    elif model_type == "llama":
        col_name_map = {
            "num_params (M)": "\#Params (M)",
            "embedding_dim": r"$d_{\text{model}}$",
            "ffn_dim": r"$d_{\text{ff}}$",
            "head_dim_v": r"$d_{\text{v}}$",
            "num_heads": r"$n_{\text{heads}}$",
            "num_blocks": r"$n_{\text{layer}}$",
            "context_length": r"$T$ (ctx)",
            "global_batch_size": r"$B$ (batch)",
            "learning_rate": "LR",
        }
    exp_set_df = exp_set_df.rename(columns=col_name_map)

    return exp_set_df

In [7]:
# mlstm token param table
df = get_experiment_set_df("tokenparam_ctx8192", "mlstm")
latex_table = df.to_latex(
    index=False,
    formatters={"LR": lambda x: f"{x:.0e}".replace("e-0", "e-").replace("e+0", "e+")},
    caption="List of hyperparameters for xLSTM models trained with the Token/Param configuration.",
    label="tab:tokenparam_hyperparams",
    longtable=False,
    column_format="r|" + "r" * (len(df.columns) - 1),
)
colored_latex_table = add_row_colors(latex_table)
print(colored_latex_table)

\begin{table}
\caption{List of hyperparameters for xLSTM models trained with the Token/Param configuration.}
\label{tab:tokenparam_hyperparams}
\begin{tabular}{r|rrrrrrrr}
\toprule
\#Params (M) & $d_{\text{model}}$ & $d_{\text{ff}}$ & $d_{\text{qk}}$ & $d_{\text{hv}}$ & $n_{\text{heads}}$ & $n_{\text{layer}}$ & $B$ (batch) & LR \\
\midrule
164 & 768 & 2112 & 64 & 128 & 6 & 12 & 128 & 3e-3 \\
\rowcolor{gray!10}406 & 1024 & 2752 & 128 & 256 & 4 & 24 & 128 & 3e-3 \\
406 & 1024 & 2752 & 128 & 256 & 4 & 24 & 128 & 1e-3 \\
\rowcolor{gray!10}841 & 1536 & 4160 & 192 & 384 & 4 & 24 & 256 & 1e-3 \\
841 & 1536 & 4160 & 192 & 384 & 4 & 24 & 256 & 8e-4 \\
\rowcolor{gray!10}1420 & 2048 & 5504 & 256 & 512 & 4 & 24 & 256 & 8e-4 \\
1420 & 2048 & 5504 & 256 & 512 & 4 & 24 & 256 & 7e-4 \\
\rowcolor{gray!10}2780 & 2560 & 6848 & 256 & 512 & 5 & 32 & 512 & 7e-4 \\
6865 & 4096 & 10944 & 256 & 512 & 8 & 32 & 512 & 5e-4 \\
\rowcolor{gray!10}6865 & 4096 & 10944 & 256 & 512 & 8 & 32 & 256 & 5e-4 \\
6865 & 4096 &

In [8]:
# llama token param table
df = get_experiment_set_df("tokenparam_ctx8192", "llama")
latex_table = df.to_latex(
    index=False,
    formatters={"LR": lambda x: f"{x:.0e}".replace("e-0", "e-").replace("e+0", "e+")},
    caption="List of hyperparameters for Transformer models trained with the Token/Param configuration.",
    label="tab:tokenparam_hyperparams",
    longtable=False,
    column_format="r|" + "r" * (len(df.columns) - 1),
)
colored_latex_table = add_row_colors(latex_table)
print(colored_latex_table)

\begin{table}
\caption{List of hyperparameters for Transformer models trained with the Token/Param configuration.}
\label{tab:tokenparam_hyperparams}
\begin{tabular}{r|rrrrrrr}
\toprule
\#Params (M) & $d_{\text{model}}$ & $d_{\text{ff}}$ & $d_{\text{v}}$ & $n_{\text{heads}}$ & $n_{\text{layer}}$ & $B$ (batch) & LR \\
\midrule
162 & 768 & 2048 & 64 & 12 & 12 & 128 & 3e-3 \\
\rowcolor{gray!10}162 & 768 & 2048 & 64 & 12 & 12 & 128 & 1e-3 \\
406 & 1024 & 2752 & 64 & 16 & 24 & 128 & 3e-3 \\
\rowcolor{gray!10}406 & 1024 & 2752 & 64 & 16 & 24 & 128 & 1e-3 \\
834 & 1536 & 4096 & 96 & 16 & 24 & 256 & 1e-3 \\
\rowcolor{gray!10}1420 & 2048 & 5504 & 128 & 16 & 24 & 256 & 8e-4 \\
2779 & 2560 & 6848 & 80 & 32 & 32 & 512 & 7e-4 \\
\rowcolor{gray!10}6863 & 4096 & 10944 & 128 & 32 & 32 & 256 & 5e-4 \\
6863 & 4096 & 10944 & 128 & 32 & 32 & 512 & 5e-4 \\
\bottomrule
\end{tabular}
\end{table}



In [9]:
# mlstm isoflop table
df = get_experiment_set_df(
    ["isoflop_ctx2048", "isoflop_ctx8192", "isoflop_ctx16384"], "mlstm"
)
latex_table = df.to_latex(
    index=False,
    formatters={"LR": lambda x: f"{x:.0e}".replace("e-0", "e-").replace("e+0", "e+")},
    caption="List of hyperparameters for xLSTM models trained with the IsoFLOP configuration.",
    label="tab:xlstm_isoflop_hyperparams",
    longtable=False,
    column_format="r|" + "r" * (len(df.columns) - 1),
)
colored_latex_table = add_row_colors(latex_table)
colored_latex_table = add_adjustbox_scaling(colored_latex_table, height_scale=0.5)
print(colored_latex_table)

\begin{table}
\caption{List of hyperparameters for xLSTM models trained with the IsoFLOP configuration.}
\label{tab:xlstm_isoflop_hyperparams}
\begin{adjustbox}{max height=0.5\textheight,center}
\begin{tabular}{r|rrrrrr}
\toprule
\#Params (M) & $d_{\text{model}}$ & $d_{\text{ff}}$ & $d_{\text{qk}}$ & $d_{\text{hv}}$ & $n_{\text{heads}}$ & $n_{\text{layer}}$ \\
\midrule
83 & 512 & 1408 & 64 & 128 & 4 & 10 \\
\rowcolor{gray!10}90 & 512 & 1408 & 64 & 128 & 4 & 12 \\
96 & 512 & 1408 & 64 & 128 & 4 & 14 \\
\rowcolor{gray!10}102 & 512 & 1408 & 64 & 128 & 4 & 16 \\
114 & 640 & 1728 & 64 & 128 & 5 & 10 \\
\rowcolor{gray!10}123 & 640 & 1728 & 64 & 128 & 5 & 12 \\
128 & 640 & 1728 & 64 & 128 & 5 & 13 \\
\rowcolor{gray!10}133 & 640 & 1728 & 64 & 128 & 5 & 14 \\
143 & 640 & 1728 & 64 & 128 & 5 & 16 \\
\rowcolor{gray!10}164 & 768 & 2112 & 64 & 128 & 6 & 12 \\
185 & 768 & 2112 & 64 & 128 & 6 & 15 \\
\rowcolor{gray!10}207 & 896 & 2432 & 64 & 128 & 7 & 12 \\
207 & 768 & 2112 & 64 & 128 & 6 & 18 \\
\ro

In [10]:
# llama isoflop table
df = get_experiment_set_df(
    ["isoflop_ctx2048", "isoflop_ctx8192", "isoflop_ctx16384"], "llama"
)
latex_table = df.to_latex(
    index=False,
    formatters={"LR": lambda x: f"{x:.0e}".replace("e-0", "e-").replace("e+0", "e+")},
    caption="List of hyperparameters for Transformer models trained with the IsoFLOP configuration.",
    label="tab:transformer_isoflop_hyperparams",
    longtable=False,
    column_format="r|" + "r" * (len(df.columns) - 1),
)
colored_latex_table = add_row_colors(latex_table)
colored_latex_table = add_adjustbox_scaling(colored_latex_table, height_scale=0.5)
print(colored_latex_table)

\begin{table}
\caption{List of hyperparameters for Transformer models trained with the IsoFLOP configuration.}
\label{tab:transformer_isoflop_hyperparams}
\begin{adjustbox}{max height=0.5\textheight,center}
\begin{tabular}{r|rrrrr}
\toprule
\#Params (M) & $d_{\text{model}}$ & $d_{\text{ff}}$ & $d_{\text{v}}$ & $n_{\text{heads}}$ & $n_{\text{layer}}$ \\
\midrule
83 & 512 & 1408 & 64 & 8 & 10 \\
\rowcolor{gray!10}90 & 512 & 1408 & 64 & 8 & 12 \\
96 & 512 & 1408 & 64 & 8 & 14 \\
\rowcolor{gray!10}102 & 512 & 1408 & 64 & 8 & 16 \\
113 & 640 & 1728 & 64 & 10 & 10 \\
\rowcolor{gray!10}128 & 640 & 1728 & 64 & 10 & 13 \\
133 & 640 & 1728 & 64 & 10 & 14 \\
\rowcolor{gray!10}143 & 640 & 1728 & 64 & 10 & 16 \\
162 & 768 & 2048 & 64 & 12 & 12 \\
\rowcolor{gray!10}183 & 768 & 2048 & 64 & 12 & 15 \\
204 & 768 & 2048 & 64 & 12 & 18 \\
\rowcolor{gray!10}207 & 896 & 2432 & 64 & 14 & 12 \\
236 & 896 & 2432 & 64 & 14 & 15 \\
\rowcolor{gray!10}265 & 896 & 2432 & 64 & 14 & 18 \\
294 & 896 & 2432 & 64 & 14 

In [11]:
mlstm_df = get_all_parametric_sclaw_fit_data_dataframe(model_type="mlstm")
mlstm_df[mlstm_df["embedding_dim"] == 4096][
    [
        "experiment_set_ctx_length",
        "name",
        "run_tag",
        "learning_rate",
        "global_batch_size",
        "num_params",
        "ffn_dim",
        "head_dim_qk",
        "head_dim_v",
        "num_heads",
        "num_blocks",
        "context_length",
        "train/.loss_mean",
    ]
]

Unnamed: 0,experiment_set_ctx_length,name,run_tag,learning_rate,global_batch_size,num_params,ffn_dim,head_dim_qk,head_dim_v,num_heads,num_blocks,context_length,train/.loss_mean
228,isoflop_ctx8192,dclm_mLSTMv1_7B_ctx8192_lr0.0009_steps7600_nb3...,sclaw_iso_round8,0.0009,256.0,6464058000.0,10944,128.0,256,16,30,8192,2.553666
229,isoflop_ctx8192,dclm_mLSTMv1_7B_ctx8192_lr0.0009_steps7200_nb3...,sclaw_iso_round8,0.0009,256.0,6867523000.0,10944,128.0,256,16,32,8192,2.56239
605,tokenparam_ctx8192,dclm_mLSTMv1_7B_ctx8192_lr0.0005_steps73000_gb...,scl_mlstm_7B,0.0005,512.0,6865425000.0,10944,256.0,512,8,32,8192,2.206036
606,tokenparam_ctx8192,dclm_mLSTMv1_7B_ctx8192_lr0.0005_steps76000_gb...,scl_mlstm_7B,0.0005,256.0,6865425000.0,10944,256.0,512,8,32,8192,2.251832
607,tokenparam_ctx8192,dclm_mLSTMv1_7B_ctx8192_lr0.0005_steps181000_g...,scl_mlstm_7B,0.0005,512.0,6865425000.0,10944,256.0,512,8,32,8192,2.150207
608,tokenparam_ctx8192,dclm_mLSTMv1_7B_ctx8192_gbs512,dclm_mLSTMv1_7B_longrun_pretraining_final,0.0004,512.0,6865425000.0,10944,256.0,512,8,32,8192,2.100448


In [12]:
df = get_all_parametric_sclaw_fit_data_dataframe(model_type="all")
df[df["experiment_set_ctx_length"] == "isoflop_ctx8192"][
    df["IsoFLOP"].isin(["6e+18", "6e+20", "1e+20"])
][
    [
        "name",
        "run_tag",
        "global_batch_size",
        "IsoFLOP",
        "learning_rate",
    ]
].sort_values(by=["IsoFLOP", "global_batch_size"])

  df[df["experiment_set_ctx_length"] == "isoflop_ctx8192"][


Unnamed: 0,name,run_tag,global_batch_size,IsoFLOP,learning_rate
90,dclm_mLSTMv1_100M_ctx8192_lr0.003_steps192000_...,"nb10_ed640_nh5_pf2.667,sclaw_iso",128.0,1e+20,0.0030
91,dclm_mLSTMv1_160M_ctx8192_lr0.003_steps126500_...,"nb12_ed768_nh6_pf2.667,sclaw_iso",128.0,1e+20,0.0030
96,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps97500_n...,"nb12_ed896_nh7_pf2.667,sclaw_iso",128.0,1e+20,0.0030
101,dclm_mLSTMv1_100M_ctx8192_lr0.003_steps162000_...,"nb13_ed640_nh5_pf2.667,sclaw_iso",128.0,1e+20,0.0030
108,dclm_mLSTMv1_200M_ctx8192_lr0.003_steps82500_n...,"nb15_ed896_nh7_pf2.667,sclaw_iso",128.0,1e+20,0.0030
...,...,...,...,...,...
227,dclm_mLSTMv1_4.5B_ctx8192_lr0.0009_steps10600_...,sclaw_iso_round8,256.0,6e+20,0.0009
228,dclm_mLSTMv1_7B_ctx8192_lr0.0009_steps7600_nb3...,sclaw_iso_round8,256.0,6e+20,0.0009
229,dclm_mLSTMv1_7B_ctx8192_lr0.0009_steps7200_nb3...,sclaw_iso_round8,256.0,6e+20,0.0009
230,dclm_mLSTMv1_5.5B_ctx8192_lr0.0009_steps8200_n...,sclaw_iso_round8,256.0,6e+20,0.0009


In [13]:
def get_isoflop_batch_size_df(ctx_length=None) -> pd.DataFrame:
    df = get_all_parametric_sclaw_fit_data_dataframe(model_type="all")
    batch_size_df = (
        df[["IsoFLOP", "context_length", "global_batch_size"]]
        .drop_duplicates()
        .sort_values(by=["context_length", "IsoFLOP", "global_batch_size"])
        .dropna()
        .reset_index(drop=True)
    )
    batch_size_df.rename(columns={"global_batch_size": "bs_in_seqs"}, inplace=True)
    batch_size_df["bs_in_seqs"] = batch_size_df["bs_in_seqs"].astype(int)

    batch_size_df["bs_in_tokens"] = (
        batch_size_df["bs_in_seqs"] * batch_size_df["context_length"]
    )
    batch_size_df["bs_in_tokens"] = batch_size_df["bs_in_tokens"].astype(int)

    # rename columns for latex
    col_name_map = {
        "IsoFLOP": "IsoFLOP",
        "context_length": r"$T$ (ctx)",
        "bs_in_seqs": r"$B$ (seqs)",
        "bs_in_tokens": r"$B \times T$ (tokens)",
    }
    batch_size_df = batch_size_df.rename(columns=col_name_map)
    if ctx_length is not None:
        return batch_size_df[batch_size_df[r"$T$ (ctx)"] == ctx_length]
    return batch_size_df

In [14]:
df = get_isoflop_batch_size_df(8192)
latex_table = df.to_latex(
    index=False,
    formatters={r"$B \times T$ (tokens)": lambda x: f"{x:,}"},
    caption="Batch sizes used for models trained with the IsoFLOP configuration at context length 8192.",
    label="tab:isoflop_batch_sizes",
    longtable=False,
    column_format="r|" + "r" * (len(df.columns) - 1),
)
colored_latex_table = add_row_colors(latex_table)
# colored_latex_table = add_adjustbox_scaling(colored_latex_table, height_scale=0.5)
print(colored_latex_table)

\begin{table}
\caption{Batch sizes used for models trained with the IsoFLOP configuration at context length 8192.}
\label{tab:isoflop_batch_sizes}
\begin{tabular}{r|rrr}
\toprule
IsoFLOP & $T$ (ctx) & $B$ (seqs) & $B \times T$ (tokens) \\
\midrule
\rowcolor{gray!10}1e+19 & 8192 & 128 & 1,048,576 \\
1e+20 & 8192 & 128 & 1,048,576 \\
\rowcolor{gray!10}3e+19 & 8192 & 128 & 1,048,576 \\
6e+18 & 8192 & 128 & 1,048,576 \\
\rowcolor{gray!10}6e+20 & 8192 & 256 & 2,097,152 \\
\bottomrule
\end{tabular}
\end{table}

