In [1]:
%load_ext autoreload
%autoreload 2
import copy
import pickle
import sys

import numpy as np
import pandas as pd

sys.path.append("..")
from pathlib import Path

from xlstm_scaling_laws.analysis.isoflop.experiment_setup.common import (
    filter_num_params_between,
    generate_iso_configs_mlstm_ctx,
    save_iso_configs,
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.config_templates import (
    config_template_str_mlstm_ctx,
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.experiment_setup_llama import (
    create_train_len_df_llama_all,
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.experiment_setup_mlstm_ctx import (
    create_train_len_df_mlstm_all,
)

# Generate IsoFLOP configs for xLSTM models

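In this notebook we show an example of how we generated the configs for our IsoFLOP experiments: we build training-length tables for the mLSTM model sizes, select the models within a parameter-count window, and render each one into a config file from a template.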

In [2]:
# Columns displayed from the training-length tables and handed to the config
# generator, grouped by role:
#   - run identity and optimisation hyperparameters,
#   - training-length values for each FLOP budget plus the Chinchilla-optimal one,
#   - model architecture,
#   - parallel layout / sequence setup.
id_and_hparam_cols = [
    "index",
    "num_params in M",
    "global_batch_size",
    "learning_rate",
    "model_tag",
]
flop_budget_cols = ["6.0e+18", "1.0e+19", "3.0e+19", "chinchilla_optimal"]
architecture_cols = ["num_blocks", "embedding_dim", "num_heads", "proj_factor_ffn"]
layout_cols = ["num_nodes", "batch_size_per_device", "context_length"]

cols = id_and_hparam_cols + flop_budget_cols + architecture_cols + layout_cols

In [3]:
# Training-length table for all mLSTM model sizes at context length 8192,
# sorted by model size. It is only displayed here for comparison. It gets its
# own name so that the context-2048 table built next (which is the one that
# gets filtered and turned into configs) does not silently overwrite it.
train_len_df_ctx8192_round1 = create_train_len_df_mlstm_all(
    context_length=8192
).sort_values(by=["num_params in M"], ascending=True)
train_len_df_ctx8192_round1[cols]

Unnamed: 0,index,num_params in M,global_batch_size,learning_rate,model_tag,6.0e+18,1.0e+19,3.0e+19,chinchilla_optimal,num_blocks,embedding_dim,num_heads,proj_factor_ffn,num_nodes,batch_size_per_device,context_length
0,mlstm_80M_1,83.680848,128,0.003,nb10_ed512_nh4_pf2.667,16315.158775,27191.931292,81575.793875,1755.694061,10,512,4,2.667,1,16,8192
1,mlstm_80M_2,90.114656,128,0.003,nb12_ed512_nh4_pf2.667,14673.540273,24455.900455,73367.701366,1890.680725,12,512,4,2.667,1,16,8192
2,mlstm_80M_3,96.548464,128,0.003,nb14_ed512_nh4_pf2.667,13332.077112,22220.128521,66660.385562,2025.667389,14,512,4,2.667,1,16,8192
3,mlstm_80M_4,102.982272,128,0.003,nb16_ed512_nh4_pf2.667,12215.343646,20358.906076,61076.718229,2160.654053,16,512,4,2.667,1,16,8192
4,mlstm_100M_1,114.03466,128,0.003,nb10_ed640_nh5_pf2.667,11502.441138,19170.73523,57512.205691,2392.542381,10,640,5,2.667,1,16,8192
5,mlstm_100M_2,123.96364,128,0.003,nb12_ed640_nh5_pf2.667,10248.63265,17081.054416,51243.163248,2600.860672,12,640,5,2.667,1,16,8192
6,mlstm_100M_3,133.89262,128,0.003,nb14_ed640_nh5_pf2.667,9241.297188,15402.161979,46206.485938,2809.178963,14,640,5,2.667,1,16,8192
7,mlstm_100M_4,143.8216,128,0.003,nb16_ed640_nh5_pf2.667,8414.261585,14023.769308,42071.307923,3017.497253,16,640,5,2.667,1,16,8192
8,mlstm_160M_1,164.110224,128,0.003,nb12_ed768_nh6_pf2.667,7578.469654,12630.782757,37892.34827,3443.169525,12,768,6,2.667,2,8,8192
9,mlstm_160M_2,185.820852,128,0.003,nb15_ed768_nh6_pf2.667,6459.388374,10765.647289,32296.941868,3898.676628,15,768,6,2.667,2,8,8192


In [4]:
# Training-length table for all mLSTM model sizes at context length 2048,
# ordered from smallest to largest model. The next cell filters this frame.
train_len_df_all_round1 = (
    create_train_len_df_mlstm_all(context_length=2048)
    .sort_values(by="num_params in M")
)
train_len_df_all_round1[cols]

Unnamed: 0,index,num_params in M,global_batch_size,learning_rate,model_tag,6.0e+18,1.0e+19,3.0e+19,chinchilla_optimal,num_blocks,embedding_dim,num_heads,proj_factor_ffn,num_nodes,batch_size_per_device,context_length
0,mlstm_80M_1,83.680848,512,0.003,nb10_ed512_nh4_pf2.667,16315.158775,27191.931292,81575.793875,1755.694061,10,512,4,2.667,1,64,2048
1,mlstm_80M_2,90.114656,512,0.003,nb12_ed512_nh4_pf2.667,14673.540273,24455.900455,73367.701366,1890.680725,12,512,4,2.667,1,64,2048
2,mlstm_80M_3,96.548464,512,0.003,nb14_ed512_nh4_pf2.667,13332.077112,22220.128521,66660.385562,2025.667389,14,512,4,2.667,1,64,2048
3,mlstm_80M_4,102.982272,512,0.003,nb16_ed512_nh4_pf2.667,12215.343646,20358.906076,61076.718229,2160.654053,16,512,4,2.667,1,64,2048
4,mlstm_100M_1,114.03466,512,0.003,nb10_ed640_nh5_pf2.667,11502.441138,19170.73523,57512.205691,2392.542381,10,640,5,2.667,1,64,2048
5,mlstm_100M_2,123.96364,512,0.003,nb12_ed640_nh5_pf2.667,10248.63265,17081.054416,51243.163248,2600.860672,12,640,5,2.667,1,64,2048
6,mlstm_100M_3,133.89262,512,0.003,nb14_ed640_nh5_pf2.667,9241.297188,15402.161979,46206.485938,2809.178963,14,640,5,2.667,1,64,2048
7,mlstm_100M_4,143.8216,512,0.003,nb16_ed640_nh5_pf2.667,8414.261585,14023.769308,42071.307923,3017.497253,16,640,5,2.667,1,64,2048
8,mlstm_160M_1,164.110224,512,0.003,nb12_ed768_nh6_pf2.667,7578.469654,12630.782757,37892.34827,3443.169525,12,768,6,2.667,2,32,2048
9,mlstm_160M_2,185.820852,512,0.003,nb15_ed768_nh6_pf2.667,6459.388374,10765.647289,32296.941868,3898.676628,15,768,6,2.667,2,32,2048


In [5]:
# Model-size window used to select models for the IsoFLOP sweep. The unit is
# presumably millions of parameters, matching the `num_params in M` column
# (confirm in `filter_num_params_between`).
MIN_NUM_PARAMS_M = 50
MAX_NUM_PARAMS_M = 550

sel_train_len_df_round1 = filter_num_params_between(
    train_len_df_all_round1,
    min_num_params=MIN_NUM_PARAMS_M,
    max_num_params=MAX_NUM_PARAMS_M,
)
sel_train_len_df_round1[cols]

Unnamed: 0,index,num_params in M,global_batch_size,learning_rate,model_tag,6.0e+18,1.0e+19,3.0e+19,chinchilla_optimal,num_blocks,embedding_dim,num_heads,proj_factor_ffn,num_nodes,batch_size_per_device,context_length
0,mlstm_80M_1,83.680848,512,0.003,nb10_ed512_nh4_pf2.667,16315.158775,27191.931292,81575.793875,1755.694061,10,512,4,2.667,1,64,2048
1,mlstm_80M_2,90.114656,512,0.003,nb12_ed512_nh4_pf2.667,14673.540273,24455.900455,73367.701366,1890.680725,12,512,4,2.667,1,64,2048
2,mlstm_80M_3,96.548464,512,0.003,nb14_ed512_nh4_pf2.667,13332.077112,22220.128521,66660.385562,2025.667389,14,512,4,2.667,1,64,2048
3,mlstm_80M_4,102.982272,512,0.003,nb16_ed512_nh4_pf2.667,12215.343646,20358.906076,61076.718229,2160.654053,16,512,4,2.667,1,64,2048
4,mlstm_100M_1,114.03466,512,0.003,nb10_ed640_nh5_pf2.667,11502.441138,19170.73523,57512.205691,2392.542381,10,640,5,2.667,1,64,2048
5,mlstm_100M_2,123.96364,512,0.003,nb12_ed640_nh5_pf2.667,10248.63265,17081.054416,51243.163248,2600.860672,12,640,5,2.667,1,64,2048
6,mlstm_100M_3,133.89262,512,0.003,nb14_ed640_nh5_pf2.667,9241.297188,15402.161979,46206.485938,2809.178963,14,640,5,2.667,1,64,2048
7,mlstm_100M_4,143.8216,512,0.003,nb16_ed640_nh5_pf2.667,8414.261585,14023.769308,42071.307923,3017.497253,16,640,5,2.667,1,64,2048
8,mlstm_160M_1,164.110224,512,0.003,nb12_ed768_nh6_pf2.667,7578.469654,12630.782757,37892.34827,3443.169525,12,768,6,2.667,2,32,2048
9,mlstm_160M_2,185.820852,512,0.003,nb15_ed768_nh6_pf2.667,6459.388374,10765.647289,32296.941868,3898.676628,15,768,6,2.667,2,32,2048


In [6]:
# Render configs for the selected models from the mLSTM context-length
# template. `config_template_str_mlstm_ctx` is imported in the top import cell
# together with the other project helpers. The two validation-interval
# arguments are forwarded unchanged to `generate_iso_configs_mlstm_ctx`; how
# each one maps onto the trainer settings is defined there.
cfgs_round1 = generate_iso_configs_mlstm_ctx(
    sel_train_len_df_round1[cols],
    config_template_str_mlstm_ctx,
    valevery_steps=200,
    run_valevery_steps=1000,
)

In [7]:
# Show one generated config as an example: entry index 1 under "mLSTMv1_200M".
example_cfg = cfgs_round1["mLSTMv1_200M"][1]
print(example_cfg)


# @package _global_
defaults:
  - /data@data_train.ds1: dclm_arrayrecord_train
  - /data@data_eval.ds1: dclm_arrayrecord_eval_preprocessed
  # - /data@data_eval.ds2: slimpajama_627B_arrayrecord_eval_preprocessed
  - override /parallel: mLSTMv1_7B #mLSTMv1_7B # use fsdp #mLSTMv1_1.3B # use this for no FSDP (pure dp)
  - override /model: mLSTMv1_default
  - override /optimizer: adamw
  - override /scheduler: cosine_decay
  - override /hydra/launcher: slurm_launcher
  - _self_

# specify the deltas from the defaults:
task_name: sclaw_mlstm_ctx_iso8 #! adapt here
batch_size_per_device: 32
context_length: 2048
num_epochs: 1000
num_train_steps: 5000, 8200, 24800 #18_000
lr: 0.003

scheduler:
  warmup_steps: 750
  cooldown_steps: 1000 #2000

trainer:
  gradient_accumulate_steps: 1
  check_val_every_n_steps: 1000
  log_logit_stats: true
  log_intermediates: false

checkpointing:
  monitor: dclm_perplexity

data_train:
  ds1:
    tokenizer_path: "EleutherAI/gpt-neox-20b"

data_eval:
  ds1:
   

In [None]:
# Writing the generated configs to disk is opt-in. Set this flag to True to
# save `cfgs_round1` into the `iso_configs` directory.
SAVE_ISO_CONFIGS = False
if SAVE_ISO_CONFIGS:
    save_iso_configs(config_dict=cfgs_round1, save_dir=Path("iso_configs"))