In [3]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import pickle
import sys
import copy

sys.path.append("..")
from pathlib import Path
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.common import (
    generate_iso_configs_llama,
    save_iso_configs,
    filter_num_params_between
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.experiment_setup_llama import (
    create_train_len_df_llama_all,
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.experiment_setup_mlstm_ctx import (
    create_train_len_df_mlstm_all,
)
from xlstm_scaling_laws.analysis.isoflop.experiment_setup.config_templates import (
    config_template_str_llama,
)

# Generate IsoFLOP configs for Transformer / Llama models

This notebook shows an example of how we generated the configs for our IsoFLOP experiments.

In [4]:
# Columns of the training-length table passed to the config generator (and displayed below),
# grouped by role for readability. The order of the resulting list is unchanged.
cols = (
    # run identity and optimizer hyperparameters
    ["index", "num_params in M", "global_batch_size", "learning_rate", "model_tag"]
    # one column per IsoFLOP compute budget plus the Chinchilla reference
    # (values look like training lengths, presumably in steps — TODO confirm)
    + ["6.0e+18", "1.0e+19", "3.0e+19", "chinchilla_optimal"]
    # model architecture
    + ["num_blocks", "embedding_dim", "head_dim", "proj_factor_ffn"]
    # hardware / sequence layout
    + ["num_nodes", "batch_size_per_device", "context_length"]
)

In [5]:
# Build the per-model training-length table for context length 8192 and order it from
# the smallest to the largest model; the trailing expression renders the selected columns.
train_len_df_all_round0 = (
    create_train_len_df_llama_all(context_length=8192)
    .sort_values(by=["num_params in M"], ascending=True)
)
train_len_df_all_round0[cols]

Unnamed: 0,index,num_params in M,global_batch_size,learning_rate,model_tag,6.0e+18,1.0e+19,3.0e+19,chinchilla_optimal,num_blocks,embedding_dim,head_dim,proj_factor_ffn,num_nodes,batch_size_per_device,context_length
0,llama_80M_1,83.634688,128,0.003,nb10_ed512_hd64_pf2.667,9537.559152,15895.93192,47687.795759,1754.725586,10,512,64,2.667,1,16,8192
1,llama_80M_2,90.059264,128,0.003,nb12_ed512_hd64_pf2.667,8304.47502,13840.7917,41522.375101,1889.518555,12,512,64,2.667,1,16,8192
2,llama_80M_3,96.48384,128,0.003,nb14_ed512_hd64_pf2.667,7353.731773,12256.219622,36768.658865,2024.311523,14,512,64,2.667,1,16,8192
3,llama_80M_4,102.908416,128,0.003,nb16_ed512_hd64_pf2.667,6598.318921,10997.198202,32991.594606,2159.104492,16,512,64,2.667,1,16,8192
4,llama_100M_1,113.96416,128,0.003,nb10_ed640_hd64_pf2.667,7073.658242,11789.430403,35368.291209,2391.063232,10,640,64,2.667,1,16,8192
5,llama_100M_2,128.83648,128,0.003,nb13_ed640_hd64_pf2.667,5758.613755,9597.689592,28793.068777,2703.096924,13,640,64,2.667,1,16,8192
6,llama_100M_3,133.79392,128,0.003,nb14_ed640_hd64_pf2.667,5422.581327,9037.635545,27112.906634,2807.108154,14,640,64,2.667,1,16,8192
7,llama_100M_4,143.7088,128,0.003,nb16_ed640_hd64_pf2.667,4855.871876,8093.119793,24279.35938,3015.130615,16,640,64,2.667,1,16,8192
8,llama_160M_1,162.2208,128,0.003,nb12_ed768_hd64_pf2.667,4754.777271,7924.628784,23773.886353,3403.527832,12,768,64,2.667,1,16,8192
9,llama_160M_2,183.459072,128,0.003,nb15_ed768_hd64_pf2.667,3956.228931,6593.714885,19781.144656,3849.124512,15,768,64,2.667,1,16,8192


In [6]:
# Render config strings from the Llama template, one per row of the selected table.
# NOTE(review): the two step intervals below are magic numbers; run_valevery_steps=1000
# appears to end up as `check_val_every_n_steps` in the rendered config — confirm what
# valevery_steps=200 controls.
cfgs_round0 = generate_iso_configs_llama(
    train_len_df_all_round0.loc[:, cols],
    config_template_str_llama,
    valevery_steps=200,
    run_valevery_steps=1000,
)

In [7]:
# Show one generated config as an example. `print` (not the bare expression) is used
# so the multi-line YAML text renders with real line breaks.
example_cfg_str = cfgs_round0["llama_200M"][3]
print(example_cfg_str)

 
# @package _global_
defaults:
  - /data@data_train.ds1: dclm_arrayrecord_train
  - /data@data_eval.ds1: dclm_arrayrecord_eval_preprocessed # for different ctx len use: dclm_arrayrecord_eval
  # - /data@data_eval.ds1: slimpajama_627B_arrayrecord_eval_preprocessed
  - override /parallel: llama1.3B # no fsdp
  - override /model: llama_default
  - override /optimizer: adamw
  - override /scheduler: cosine_decay
  - override /hydra/launcher: slurm_launcher
  - _self_

# specify the deltas from the defaults:
task_name: sclaw_llama_iso13 #! adapt here
batch_size_per_device: 16
context_length: 8192
num_epochs: 1000
num_train_steps: 2400, 4000, 11800 #95_000
lr: 0.003

scheduler:
  warmup_steps: 750
  cooldown_steps: 1000 #2000

trainer:
  gradient_accumulate_steps: 1
  check_val_every_n_steps: 1000
  log_logit_stats: false
  log_intermediates: false

checkpointing:
  monitor: dclm_perplexity

data_train:
  ds1:
    tokenizer_path: "EleutherAI/gpt-neox-20b"

data_eval:
  ds1:
    tokenizer_pa