In [1]:
from pathlib import Path
import pickle
import gym
import random
import nle
from pprint import pprint
import numpy as np
import time
from datetime import datetime
import cProfile
from IPython import display
import os

In [2]:
env = gym.make('NetHackScore-v0')
obs = env.reset()

In [3]:
list(enumerate(env._actions))

[(0, <MiscAction.MORE: 13>),
 (1, <CompassDirection.N: 107>),
 (2, <CompassDirection.E: 108>),
 (3, <CompassDirection.S: 106>),
 (4, <CompassDirection.W: 104>),
 (5, <CompassDirection.NE: 117>),
 (6, <CompassDirection.SE: 110>),
 (7, <CompassDirection.SW: 98>),
 (8, <CompassDirection.NW: 121>),
 (9, <CompassDirectionLonger.N: 75>),
 (10, <CompassDirectionLonger.E: 76>),
 (11, <CompassDirectionLonger.S: 74>),
 (12, <CompassDirectionLonger.W: 72>),
 (13, <CompassDirectionLonger.NE: 85>),
 (14, <CompassDirectionLonger.SE: 78>),
 (15, <CompassDirectionLonger.SW: 66>),
 (16, <CompassDirectionLonger.NW: 89>),
 (17, <MiscDirection.UP: 60>),
 (18, <MiscDirection.DOWN: 62>),
 (19, <MiscDirection.WAIT: 46>),
 (20, <Command.KICK: 4>),
 (21, <Command.EAT: 101>),
 (22, <Command.SEARCH: 115>)]

In [17]:
class RandomAgent():
    def __init__(self):
        self.actions = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    
    def act(self):
        return random.choice(self.actions)

In [18]:
class TerminalStream():
    def __init__(self, env, save_dir='nh-runs'):
        # Indices for getting terminal screen data
        self.tty_chars_idx = env._observation_keys.index('tty_chars')
        self.tty_colors_idx = env._observation_keys.index('tty_colors')
        self.tty_cursor_idx = env._observation_keys.index('tty_cursor')
        
        # Total frames collected
        self.frame_counter = 0
        
        # List of np arrays holding collected data, empty rn
        self.tty_chars_stack = []
        self.tty_colors_stack = []
        self.tty_cursors_stack = []
        
        # If no savedir, then leave
        if save_dir is None:
            self.save_path = None
            return 
        
        # Otherwise, make sure path exists, and reset run (name run)
        self.save_path = Path(save_dir)
        if not self.save_path.is_dir():
            self.save_path.mkdir(parents=True)
        self.reset_run()
        
        # Finally, just set the correct function for ansi playback / recording
        try:
            from nle.nethack import tty_render as nle_tty_render
            self.tty_render_version = 'tty_render'
        except ImportError:
            print(f'Warning: tty_render function not found in nle.nethack. Searching in nle.env.base.')
            from nle.env.base import NLE
            self.tty_render_version = 'NLE'
    
    # reset character name, save path
    def reset_run(self):
        self.run_name = f'{env.character}_{datetime.now().strftime("%X_%x").replace("/","-").replace(":","-")}'
        self.run_path = self.save_path / Path(f'{self.run_name}.pickle')
    
    # drop data as long as save path set
    def save_data(self):
        # Exit if save path is n/a
        if self.save_path is None:
            return
        
        # Make a block with the data to save, and dump it as a pickle to the run folder
        save_block = {
            'frame_counter': self.frame_counter, 
            'chars_stack': self.tty_chars_stack, 
            'colors_stack': self.tty_colors_stack,
            'cursors_stack': self.tty_cursors_stack,
        }
        self.run_path.write_bytes(pickle.dumps(save_block))

        # Reset arrays to hold only last data
        self.tty_chars_stack = []
        self.tty_colors_stack = []
        self.tty_cursors_stack = []
    
    def load_data(self, path):
        file_path = Path(path)
        save_block = pickle.loads(file_path.read_bytes())
        self.frame_counter = save_block['frame_counter']
        self.tty_chars_stack = save_block['chars_stack']
        self.tty_colors_stack = save_block['colors_stack']
        self.tty_cursors_stack = save_block['cursors_stack']
    
    def record(self, env):
        self.tty_chars_stack += [env.last_observation[self.tty_chars_idx]]
        self.tty_colors_stack += [env.last_observation[self.tty_colors_idx]]
        self.tty_cursors_stack += [env.last_observation[self.tty_cursor_idx]]
        self.frame_counter += 1
    
    def finish(self):
        if self.save_path is not None:
            self.save_data()
    
    def get_frame_text(self, tty_chars, tty_colors, tty_cursor):
        if self.tty_render_version == 'tty_render':
            return nle_tty_render(tty_chars, tty_colors, tty_cursor)
        else:
            return NLE().get_tty_rendering(tty_chars, tty_colors)
    
    def print_frame(self, tty_chars, tty_colors):
        print(self.get_frame_text(tty_chars, tty_colors))
    
    def render(self, export_path=None, export=True):
        if not export:
            return
        export_path = self.run_path.with_suffix('.html')
        pass

In [23]:
class GameEnv():
    def __init__(self, agent):
        self.agent = agent
        self.env = gym.make('NetHackScore-v0')
        self.tty_stream = TerminalStream(self.env)
    
    def reset(self):
        self.env.reset()
        self.tty_stream.record(self.env)
    
    def step(self):
        act = self.agent.act()
        print(act)
        obs, reward, done, info = self.env.step(act)
        return obs, reward, done
    
    def run(self, total_steps = 300):
        self.reset()
        for i in range(total_steps):
            obs, reward, done = self.step()
            self.tty_stream.record(self.env)
            if done:
                self.reset()
        self.tty_stream.finish()

In [24]:
agent = RandomAgent()
game = GameEnv(agent)
game.run()

7
5
7
3
8
1
5
3
0
4
5
1
4
8
1
2
2
6
6
5
4
8
1
8
5
0
0
6
7
7
8
4
0
1
7
2
2
8
4
0
3
6
6
2
6
6
7
2
1
6
5
7
4
5
3
6
4
6
6
3
2
6
0
1
2
2
4
4
6
5
5
6
6
8
3
6
2
5
5
7
5
2
4
1
5
2
6
7
2
1
8
6
8
4
8
2
4
2
5
8
3
3
6
0
8
5
1
4
5
2
2
5
8
1
4
6
6
2
3
8
3
0
7
7
1
1
0
7
7
5
6
7
1
4
5
4
7
0
6
8
1
3
3
2
5
0
3
0
5
3
4
7
3
1
8
7
3
0
7
6
5
5
0
1
1
2
7
5
5
6
4
6
4
4
5
8
5
3
0
3
6
1
7
2
7
3
8
2
8
8
3
5
6
3
6
8
0
4
3
6
4
4
2
8
5
3
6
1
4
7
4
2
2
7
4
1
0
6
5
1
7
4
6
5
3
0
6
4
0
4
3
2
8
8
0
8
8
2
3
4
0
2
4
0
7
1
6
5
5
6
0
2
3
0
3
4
7
0
5
3
1
1
1
2
1
3
4
8
3
3
8
4
8
4
2
1
3
1
8
2
0
2
7
7
7
7
6
4
0
3
2
2
6
0
2
2
1
5
1
6


In [13]:
print(''.join(obs))

In [11]:
pprint(obs)

{'blstats': array([53,  3, 17, 17, 12, 10, 10, 16, 10,  0, 14, 14,  1,  0,  5,  5,  4,
        0,  1,  0,  1,  1,  0,  0,  1,  0]),
 'chars': array([[32, 32, 32, ..., 32, 32, 32],
       [32, 32, 32, ..., 32, 32, 32],
       [32, 32, 32, ..., 32, 32, 32],
       ...,
       [32, 32, 32, ..., 32, 32, 32],
       [32, 32, 32, ..., 32, 32, 32],
       [32, 32, 32, ..., 32, 32, 32]], dtype=uint8),
 'colors': array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
 'glyphs': array([[2359, 2359, 2359, ..., 2359, 2359, 2359],
       [2359, 2359, 2359, ..., 2359, 2359, 2359],
       [2359, 2359, 2359, ..., 2359, 2359, 2359],
       ...,
       [2359, 2359, 2359, ..., 2359, 2359, 2359],
       [2359, 2359, 2359, ..., 2359, 2359, 2359],
       [2359, 2359, 2359, ..., 2359, 2359, 2359]], dtype=int16),
 'inv_glyphs': array([2043, 2028, 2