In [10]:
import pandas as pd

from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.utils.attr_dict import AttrDict
from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
from sf_examples.nethack.models.simba import SimBaEncoder

In [5]:
# Register the NetHack env/model components with sample-factory, then build a
# single env instance so its spaces can be read via extract_env_info.
register_nethack_components()

nethack_argv = ["--env=nethack_score", "--add_image_observation=True"]
cfg = parse_nethack_args(argv=nethack_argv)

# One env (worker 0, vector slot 0, id 0) is enough to inspect the spaces.
single_env_config = AttrDict(worker_index=0, vector_index=0, env_id=0)
env = make_env_func_batched(cfg, env_config=single_env_config)
env_info = extract_env_info(env, cfg)
# NOTE(review): `env` is only used here to obtain `env_info`; it is never
# closed. Consider `env.close()` once no later cell depends on it.

[33m[2025-01-26 08:25:41,847][2216020] Environment nethack_progress already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,848][2216020] Environment nethack_staircase already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,848][2216020] Environment nethack_score already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,849][2216020] Environment nethack_pet already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,850][2216020] Environment nethack_oracle already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,851][2216020] Environment nethack_gold already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,851][2216020] Environment nethack_eat already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,852][2216020] Environment nethack_scout already registered, overwriting...[0m
[33m[2025-01-26 08:25:41,852][2216020] Environment nethack_challenge already registered, overwriting...[0m
[36m[2025-01-26 08:25:41,853][2216020] register

In [16]:
# Sweep SimBaEncoder width (hidden_dim) and depth, and tabulate the parameter
# count of each configuration. Only the encoder is instantiated here, so the
# numbers describe SimBaEncoder alone, not a full actor-critic model.
HIDDEN_DIMS = [16, 32, 64, 128, 256, 512]
DEPTHS = [1, 2, 3, 4]

results = []
for hidden_dim in HIDDEN_DIMS:
    for depth in DEPTHS:
        model = SimBaEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            use_prev_action=cfg.use_prev_action,
        )
        total_params = sum(p.numel() for p in model.parameters())

        results.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            # Keep the raw count so the data stays numeric (sortable/plottable);
            # "Model Size" is the human-readable form in millions of parameters.
            "Total Params": total_params,
            "Model Size": f"{total_params / 10**6:.2f}M",
        })

df = pd.DataFrame(results)
pivot_table = df.pivot(index="Depth", columns="Hidden Dim", values="Model Size")
# Last expression -> rich (HTML) table display in Jupyter instead of print().
pivot_table

Hidden Dim    16     32     64     128     256     512
Depth                                                 
1           0.38M  0.77M  1.58M  3.35M   7.47M  18.08M
2           0.39M  0.79M  1.65M  3.64M   8.65M  22.80M
3           0.39M  0.81M  1.73M  3.94M   9.83M  27.52M
4           0.40M  0.83M  1.80M  4.23M  11.02M  32.24M
