In [2]:
import pandas as pd

from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.utils.attr_dict import AttrDict
from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
from sf_examples.nethack.models import ChaoticDwarvenGPT5
from sf_examples.nethack.models.simba import SimBaEncoder

In [3]:
# Register the NetHack encoder / actor-critic factories, then build the config
# from CLI-style arguments.
register_nethack_components()
cfg = parse_nethack_args(
    argv=[
        "--env=nethack_score",
        "--add_image_observation=True",
    ]
)

# A single batched env; the cells below only use the env info extracted from it
# (notably `env_info.obs_space`).
env = make_env_func_batched(
    cfg,
    env_config=AttrDict(worker_index=0, vector_index=0, env_id=0),
)
env_info = extract_env_info(env, cfg)

[36m[2025-01-26 09:32:39,142][825930] register_encoder_factory: <function make_nethack_encoder at 0x71519d7113f0>[0m
[36m[2025-01-26 09:32:39,144][825930] register_actor_critic_factory: <function make_nethack_actor_critic at 0x71519d7112d0>[0m


In [4]:
# Sweep SimBaEncoder width (hidden_dim) and depth, recording the parameter
# count of every variant.
HIDDEN_DIMS = [16, 32, 64, 128, 256, 512]
DEPTHS = [1, 2, 3, 4]

results = []
for hidden_dim in HIDDEN_DIMS:
    for depth in DEPTHS:
        model = SimBaEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            use_prev_action=cfg.use_prev_action,
        )
        total_params = sum(p.numel() for p in model.parameters())

        results.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            # Raw count kept next to the display string so the sweep can still
            # be sorted, plotted or compared numerically later on.
            "Params": total_params,
            "Model Size": f"{total_params / 10**6:.2f}M",
        })

df = pd.DataFrame(results)

# Rows: depth, columns: hidden dim, cells: size in millions of parameters.
pivot_table = df.pivot(index="Depth", columns="Hidden Dim", values="Model Size")

# Last expression of the cell -> rendered with the notebook's rich DataFrame display.
pivot_table

Hidden Dim    16     32     64     128     256     512
Depth                                                 
1           0.39M  0.78M  1.59M  3.35M   7.48M  18.09M
2           0.39M  0.79M  1.66M  3.65M   8.66M  22.81M
3           0.40M  0.81M  1.73M  3.94M   9.84M  27.53M
4           0.40M  0.83M  1.81M  4.24M  11.02M  32.25M


In [None]:
# Parameter count of ChaoticDwarvenGPT5, built from the same cfg and observation
# space as the SimBaEncoder sweep, for comparison with the table above.
# (ChaoticDwarvenGPT5 is imported in the notebook's import cell.)
model = ChaoticDwarvenGPT5(cfg, env_info.obs_space)
total_params = sum(p.numel() for p in model.parameters())
print(f"Model Size: {total_params / 10**6:.2f}M")

Model Size: 3.28M
