In [11]:
import pandas as pd

from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.utils.attr_dict import AttrDict
from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
from sf_examples.nethack.models.simba import SimBaEncoder
from sf_examples.nethack.models.vit import ViTEncoder
from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5

In [12]:
# Register the NetHack envs/encoders with sample-factory (re-running this cell
# re-registers them and logs "already registered, overwriting" warnings), then
# parse a config for the nethack_score task with image observations enabled.
register_nethack_components()
cfg = parse_nethack_args(
    argv=[
        "--env=nethack_score",
        "--add_image_observation=True",
    ]
)

# Build a single env only so extract_env_info can read its spaces; the sweeps
# below use env_info.obs_space and cfg.use_prev_action.
# NOTE(review): `env` is never closed — call env.close() once nothing needs it.
env = make_env_func_batched(
    cfg,
    env_config=AttrDict(worker_index=0, vector_index=0, env_id=0),
)
env_info = extract_env_info(env, cfg)

[33m[2025-01-26 18:43:01,984][139285] Environment nethack_progress already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,985][139285] Environment nethack_staircase already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,986][139285] Environment nethack_score already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,987][139285] Environment nethack_pet already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,987][139285] Environment nethack_oracle already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,988][139285] Environment nethack_gold already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,988][139285] Environment nethack_eat already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,989][139285] Environment nethack_scout already registered, overwriting...[0m
[33m[2025-01-26 18:43:01,990][139285] Environment nethack_challenge already registered, overwriting...[0m
[36m[2025-01-26 18:43:01,990][139285] register_encoder_f

### SimBa encoder: parameter count by hidden dim and depth

In [13]:
# Sweep SimBa encoder width (hidden_dim) and depth, and tabulate the total
# parameter count of each configuration (rows: Depth, columns: Hidden Dim).
# Uses SimBa-specific names so this cell neither clobbers nor depends on the
# variables produced by the other sweeps in this notebook.
simba_records = []
for hidden_dim in [16, 32, 64, 128, 256, 512, 1024]:
    for depth in [1, 2, 3, 4, 6, 8]:
        encoder = SimBaEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            use_prev_action=cfg.use_prev_action,
        )
        n_params = sum(p.numel() for p in encoder.parameters())
        # Drop the reference immediately: the largest configs exceed 100M
        # parameters, and otherwise the previous encoder stays alive while the
        # next one is constructed, roughly doubling peak memory.
        del encoder

        simba_records.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            "Params": n_params,  # raw count, kept numeric for sorting/plotting
            "Model Size": f"{n_params / 10**6:.2f}M",  # human-readable label
        })

simba_sizes = pd.DataFrame(simba_records)
simba_table = simba_sizes.pivot(index="Depth", columns="Hidden Dim", values="Model Size")
simba_table

Hidden Dim   16     32     64     128    256     512      1024
Depth                                                         
1           0.02M  0.04M  0.11M  0.38M  1.44M   5.59M   22.06M
2           0.02M  0.06M  0.18M  0.68M  2.64M  10.38M   41.20M
3           0.03M  0.07M  0.26M  0.98M  3.83M  15.17M   60.34M
4           0.03M  0.09M  0.33M  1.28M  5.03M  19.95M   79.48M
6           0.04M  0.13M  0.48M  1.88M  7.43M  29.53M  117.77M
8           0.05M  0.17M  0.64M  2.48M  9.82M  39.10M  156.05M


### ChaoticDwarvenGPT5

In [4]:
# Total parameter count of the ChaoticDwarvenGPT5 model built from the shared
# cfg / observation space. The class comes from the notebook's imports cell, so
# it is not re-imported here; names are cell-specific so the sweep cells'
# variables are left untouched.
dwarven_model = ChaoticDwarvenGPT5(cfg, env_info.obs_space)
dwarven_params = sum(p.numel() for p in dwarven_model.parameters())
print(f"Model Size: {dwarven_params / 10**6:.2f}M")

Model Size: 3.28M


### ViT encoder: parameter count by hidden dim and depth

In [10]:
# Sweep ViT encoder width (hidden_dim) and depth, and tabulate the total
# parameter count of each configuration (rows: Depth, columns: Hidden Dim).
# Every config uses 8 heads and an MLP width of 2 * hidden_dim.
# Uses ViT-specific names so this cell neither clobbers nor depends on the
# variables produced by the other sweeps in this notebook.
# NOTE(review): heads stays 8 even at hidden_dim=16; confirm ViTEncoder's
# per-head sizing is sensible for the smallest widths.
vit_records = []
for hidden_dim in [16, 32, 64, 128, 256, 512]:
    for depth in [1, 2, 3, 4]:
        encoder = ViTEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            heads=8,
            mlp_dim=hidden_dim * 2,
            use_prev_action=cfg.use_prev_action,
        )
        n_params = sum(p.numel() for p in encoder.parameters())
        # Drop the reference immediately so the previous encoder is not kept
        # alive while the next (larger) one is constructed.
        del encoder

        vit_records.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            "Params": n_params,  # raw count, kept numeric for sorting/plotting
            "Model Size": f"{n_params / 10**6:.2f}M",  # human-readable label
        })

vit_sizes = pd.DataFrame(vit_records)
vit_table = vit_sizes.pivot(index="Depth", columns="Hidden Dim", values="Model Size")
vit_table

Hidden Dim    16     32     64     128    256     512
Depth                                                
1           0.41M  0.82M  1.65M  3.37M  7.07M  15.43M
2           0.45M  0.89M  1.80M  3.70M  7.85M  17.53M
3           0.48M  0.96M  1.95M  4.03M  8.64M  19.63M
4           0.51M  1.03M  2.09M  4.36M  9.43M  21.74M
