In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import pickle
import sys 
import copy
sys.path.append('..')

from xlstm_scaling_laws.flops.count_flops import count_model_flops_fwbw, FlopCountConfig
from xlstm_scaling_laws.params.count_params import count_model_params, ParamCountConfig

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
def compute_fwbw_flops_and_num_params_for_models(
    model_config_dict: dict[str, dict],
    model_type: str = "mlstm_v1",
    context_length: int = 8192,
) -> dict[str, dict]:
    updated_model_dict = copy.deepcopy(model_config_dict)
    for model_key, model_dict in model_config_dict.items():
        updated_model_dict[model_key].update(
            {
                "num_flops_fwbw": count_model_flops_fwbw(
                    model_type=model_type, context_length=context_length,
                    model_kwargs=model_dict,
                    config=FlopCountConfig()
                )[0],
                "num_params": count_model_params(model_type=model_type, model_kwargs=model_dict, config=ParamCountConfig()),
            }
        )
    return updated_model_dict

In [20]:
context_length = 8192
model_size_dict = {
    "mlstm_7B_nh4": {
        "num_blocks": 32,
        "embedding_dim": 4096,
        "proj_factor_ffn": 2.667,
        "num_heads": 4,
        "proj_factor_qk": 0.5,
        "chunk_size": 64,
        "vocab_size": 50304,
        "ffn_multiple_of": 64,
        "global_batch_size": 512,
        "learning_rate": 0.0005,
    },
    "mlstm_7B_nh8": {
        "num_blocks": 32,
        "embedding_dim": 4096,
        "proj_factor_ffn": 2.667,
        "num_heads": 8,
        "proj_factor_qk": 0.5,
        "chunk_size": 64,
        "vocab_size": 50304,
        "ffn_multiple_of": 64,
        "global_batch_size": 512,
        "learning_rate": 0.0005,
    },
    "mlstm_7B_nh16": {
        "num_blocks": 32,
        "embedding_dim": 4096,
        "proj_factor_ffn": 2.667,
        "num_heads": 16,
        "proj_factor_qk": 0.5,
        "chunk_size": 64,
        "vocab_size": 50304,
        "ffn_multiple_of": 64,
        "global_batch_size": 512,
        "learning_rate": 0.0005,
    },
    "mlstm_7B_nh32": {
        "num_blocks": 32,
        "embedding_dim": 4096,
        "proj_factor_ffn": 2.667,
        "num_heads": 32,
        "proj_factor_qk": 0.5,
        "chunk_size": 64,
        "vocab_size": 50304,
        "ffn_multiple_of": 64,
        "global_batch_size": 512,
        "learning_rate": 0.0005,
    },
}

In [21]:
model_size_dict_w_flop_params = compute_fwbw_flops_and_num_params_for_models(
    model_config_dict=model_size_dict, model_type="mlstm_v1", context_length=8192
)

In [22]:
res_df = pd.DataFrame(model_size_dict_w_flop_params).T
res_df 

Unnamed: 0,num_blocks,embedding_dim,proj_factor_ffn,num_heads,proj_factor_qk,chunk_size,vocab_size,ffn_multiple_of,global_batch_size,learning_rate,num_flops_fwbw,num_params
mlstm_7B_nh4,32.0,4096.0,2.667,4.0,0.5,64.0,50304.0,64.0,512.0,0.0005,334172300000000.0,6864376000.0
mlstm_7B_nh8,32.0,4096.0,2.667,8.0,0.5,64.0,50304.0,64.0,512.0,0.0005,330901800000000.0,6865425000.0
mlstm_7B_nh16,32.0,4096.0,2.667,16.0,0.5,64.0,50304.0,64.0,512.0,0.0005,329347300000000.0,6867523000.0
mlstm_7B_nh32,32.0,4096.0,2.667,32.0,0.5,64.0,50304.0,64.0,512.0,0.0005,328731500000000.0,6871718000.0
