In [1]:
import sys
from timeit import timeit
from os.path import join

import numpy as np

from nle import nethack

from sample_factory.algo.utils.context import global_model_factory, sf_global_context
from sample_factory.cfg.arguments import load_from_path, parse_full_cfg, parse_sf_args
from sample_factory.envs.env_utils import register_env
from sample_factory.utils.typing import ActionSpace, Config, ObsSpace
from sample_factory.utils.utils import log
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sf_examples.nethack.nethack_params import (
    add_extra_params_general,
    add_extra_params_learner,
    add_extra_params_model,
    add_extra_params_nethack_env,
    nethack_override_defaults,
)
from sf_examples.nethack.nethack_utils import NETHACK_ENVS, make_nethack_env

In [2]:
for env_name in NETHACK_ENVS.keys():
    register_env(env_name, make_nethack_env)

In [3]:
def parse_nethack_args(argv=None, evaluation=False):
    parser, partial_cfg = parse_sf_args(argv=argv, evaluation=evaluation)
    add_extra_params_nethack_env(parser)
    add_extra_params_model(parser)
    add_extra_params_learner(parser)
    add_extra_params_general(parser)
    nethack_override_defaults(partial_cfg.env, parser)
    final_cfg = parse_full_cfg(parser, argv)
    return final_cfg

In [4]:
cfg = parse_nethack_args(argv=["--env=challenge"], evaluation=True)

In [5]:
render_mode = "rgb_array"

In [6]:
env = make_env_func_batched(
    cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
)
env_info = extract_env_info(env, cfg)



In [7]:
obs = env.reset()

In [8]:
obs = env.last_observation
tty_chars, tty_colors, tty_cursor = obs[2], obs[3], obs[4]
ascii_observation = nethack.tty_render(tty_chars, tty_colors, tty_cursor)

  logger.warn(


In [9]:
print(ascii_observation)

Konnichi wa Agent, welcome to NetHack!  You are a lawful male human Samurai.    
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                            

### Old tty_render

In [None]:
def tty_render(chars, colors, cursor=None):
    """Returns chars as string with ANSI escape sequences.

    Args:
      chars: A row x columns numpy array of chars.
      colors: A numpy array of colors (0-15), same shape as chars.
      cursor: An optional (row, column) index for the cursor,
        displayed as underlined.

    Returns:
      A string with chars decorated by ANSI escape sequences.
    """
    rows, cols = chars.shape
    if cursor is None:
        cursor = (-1, -1)
    cursor = tuple(cursor)
    result = ""
    for i in range(rows):
        result += "\n"
        for j in range(cols):
            entry = "\033[%d;3%dm%s" % (
                # & 8 checks for brightness.
                bool(colors[i, j] & 8),
                colors[i, j] & ~8,
                chr(chars[i, j]),
            )
            if cursor != (i, j):
                result += entry
            else:
                result += "\033[4m%s\033[0m" % entry
    return result + "\033[0m"


In [38]:
%%timeit
tty_render(tty_chars, tty_colors, tty_cursor)

7.78 ms ± 93.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Faster tty_render

In [157]:
def tty_render_vectorized(chars, colors, cursor=None):
    """Returns chars as string with ANSI escape sequences.

    Args:
      chars: A row x columns numpy array of chars.
      colors: A numpy array of colors (0-15), same shape as chars.
      cursor: An optional (row, column) index for the cursor,
        displayed as underlined.

    Returns:
      A string with chars decorated by ANSI escape sequences.
    """
    # Generate ANSI escape sequences for each element in chars and colors
    entries = np.char.add(
        np.char.add(
            np.char.add(
                np.char.add(
                    "\033[",
                    np.where(colors & 8, "1;", "0;"),
                ),
                np.char.add(
                    "3",
                    np.char.mod("%d", colors & ~8),
                ),
            ),
            "m",
        ),
        np.char.mod("%c", chars),
    )

    # Add cursor 
    if cursor is not None:
        entries = entries.astype("<U16")
        cursor = tuple(cursor)
        entries[cursor] = np.char.add("\033[4m", np.char.add(entries[cursor], "\033[0m"))

    # Combine entries into a string and add reset escape sequence at the end
    result = "\n".join(np.apply_along_axis("".join, 1, entries)) + "\033[0m"

    return result

In [158]:
print(tty_render_vectorized(tty_chars, tty_colors, tty_cursor))

Hello Agent, welcome to NetHack!  You are a chaotic female human Rogue.         
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                                                                                
                            

In [159]:
%%timeit
tty_render_vectorized(tty_chars, tty_colors, tty_cursor)

4.51 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
7.78 / 4.51

1.7250554323725056

### new implementation is 72% faster