In [1]:
import pandas as pd

from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.utils.attr_dict import AttrDict
from sf_examples.nethack.train_nethack import parse_nethack_args, register_nethack_components
from sf_examples.nethack.models.simba import SimBaEncoder
from sf_examples.nethack.models.vit import ViTEncoder
from sf_examples.nethack.models.chaotic_dwarf import ChaoticDwarvenGPT5

In [2]:
# Register the NetHack encoder / actor-critic factories and parse a config for the
# score task with the image observation enabled.
register_nethack_components()
cfg = parse_nethack_args(argv=["--env=nethack_score", "--add_image_observation=True"])

# An environment is built only to read its spaces into `env_info`; the model cells
# below are constructed from `env_info.obs_space` and never touch `env` directly.
env = make_env_func_batched(cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0))
env_info = extract_env_info(env, cfg)
# Release the environment's resources now that its spaces have been captured;
# otherwise it stays open for the lifetime of the kernel.
env.close()

[2025-01-27 15:36:52,392][76384] register_encoder_factory: <function make_nethack_encoder at 0x70fa55becdc0>
[2025-01-27 15:36:52,392][76384] register_actor_critic_factory: <function make_nethack_actor_critic at 0x70fa55becca0>


### SimBa

Total parameter count of `SimBaEncoder` over a grid of hidden dimensions and depths.

In [3]:
# Sweep SimBaEncoder width/depth and tabulate the total parameter count in millions.
# Sizes are kept numeric in the DataFrame so the table can be sorted, compared or
# plotted; the "M" suffix is applied only when the table is rendered.
SIMBA_HIDDEN_DIMS = [64, 128, 256, 512, 1024]
SIMBA_DEPTHS = [1, 2, 3, 4]

results = []
for hidden_dim in SIMBA_HIDDEN_DIMS:
    for depth in SIMBA_DEPTHS:
        model = SimBaEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            use_prev_action=cfg.use_prev_action,
            use_max_pool=True,
            expansion=2,
        )
        # Counts every registered parameter tensor, trainable or not.
        total_params = sum(p.numel() for p in model.parameters())

        results.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            "Model Size": total_params / 10**6,  # millions of parameters (float)
        })

df = pd.DataFrame(results)
pivot_table = df.pivot(index="Depth", columns="Hidden Dim", values="Model Size")
# Render with two decimals and an "M" suffix without altering the numeric data.
print(pivot_table.to_string(float_format=lambda v: f"{v:.2f}M"))

Hidden Dim   64     128    256     512     1024
Depth                                          
1           1.50M  3.03M  6.21M  13.06M  28.73M
2           1.51M  3.04M  6.22M  13.07M  28.75M
3           1.54M  3.07M  6.25M  13.11M  28.79M
4           1.66M  3.19M  6.37M  13.24M  28.94M


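### ChaoticDwarvenGPT5

Total parameter count of the `ChaoticDwarvenGPT5` model built from `cfg` (a single configuration, no sweep).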

In [4]:
# Only `cfg` and the observation space are passed here, so there is a single
# parameter count to report rather than a table.
model = ChaoticDwarvenGPT5(cfg, env_info.obs_space)
total_params = sum(map(lambda tensor: tensor.numel(), model.parameters()))
print(f"Model Size: {total_params / 10**6:.2f}M")

Model Size: 3.28M


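### ViT

Total parameter count of `ViTEncoder` (8 heads, MLP width = 2 × hidden dim) over a grid of hidden dimensions and depths.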

In [6]:
# Sweep ViTEncoder width/depth and tabulate the total parameter count in millions.
# As in the SimBa sweep, sizes stay numeric in the DataFrame and are formatted with
# an "M" suffix only when printed.
VIT_HIDDEN_DIMS = [64, 128, 256, 512]
VIT_DEPTHS = [1, 2, 3, 4]
VIT_HEADS = 8
VIT_MLP_RATIO = 2  # MLP hidden width as a multiple of hidden_dim

results = []
for hidden_dim in VIT_HIDDEN_DIMS:
    for depth in VIT_DEPTHS:
        model = ViTEncoder(
            obs_space=env_info.obs_space,
            hidden_dim=hidden_dim,
            depth=depth,
            heads=VIT_HEADS,
            mlp_dim=hidden_dim * VIT_MLP_RATIO,
            use_prev_action=cfg.use_prev_action,
        )
        # Counts every registered parameter tensor, trainable or not.
        total_params = sum(p.numel() for p in model.parameters())

        results.append({
            "Hidden Dim": hidden_dim,
            "Depth": depth,
            "Model Size": total_params / 10**6,  # millions of parameters (float)
        })

df = pd.DataFrame(results)
pivot_table = df.pivot(index="Depth", columns="Hidden Dim", values="Model Size")
# Render with two decimals and an "M" suffix without altering the numeric data.
print(pivot_table.to_string(float_format=lambda v: f"{v:.2f}M"))

Hidden Dim    64     128    256     512
Depth                                  
1           1.66M  3.38M  7.07M  15.43M
2           1.80M  3.71M  7.86M  17.53M
3           1.95M  4.03M  8.64M  19.63M
4           2.10M  4.36M  9.43M  21.73M
