# Install NLE

In [None]:
# System build/runtime dependencies for NLE.
# NOTE: the GLib runtime package is named `libglib2.0-0`; the bare name
# `libglib2.0` cannot be located by apt, which aborts the whole install line
# (so flex/bison/cmake would not be installed either).
! apt update -qq && apt install -qq -y flex bison libbz2-dev libglib2.0-0 libsm6 libxext6 cmake
! pip install gnuplotlib
! apt-get install -y gnuplot
! pip install nle
! pip install "nle[agent]"
! pip install torch-geometric

55 packages can be upgraded. Run 'apt list --upgradable' to see them.
[1;33mW: [0mSkipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)[0m
[1;31mE: [0mUnable to locate package libglib2.0[0m
[1;31mE: [0mCouldn't find any package by glob 'libglib2.0'[0m
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
gnuplot is already the newest version (5.4.2+dfsg2-2).
0 upgraded, 0 newly installed, 0 to remove and 55 not upgraded.


In [None]:
# From: https://github.com/facebookresearch/nle/issues/359#issue-1782082844
# !sudo apt-get install -y build-essential autoconf libtool pkg-config \
#     python3-dev python3-pip python3-numpy git flex bison libbz2-dev
# ! pip install --no-use-pep517 nle
# ! pip install "nle[agent]"

# Failed
# !apt update -qq && apt install -qq -y flex bison libbz2-dev libglib2.0 libsm6 libxext6 cmake
# !pip install nle==0.9.0

In [None]:
import os
# Make the Colab notebooks folder the working directory so that relative paths
# used later (e.g. the checkpoint path `f_dir`) resolve against it.
# NOTE(review): assumes Google Drive is already mounted at /content/drive —
# no mount call is visible in this notebook chunk; confirm.
os.chdir("/content/drive/My Drive/Colab Notebooks")

# Base Model

In [None]:
import argparse
import logging
import pprint
import threading
import time
import timeit
import traceback
import shutil
import numpy as np

# Necessary for multithreading.
# NOTE: `os` is imported in an earlier notebook cell, not in this one.
os.environ["OMP_NUM_THREADS"] = "1"
# When True, train() restores model/optimizer/scheduler state from `f_dir`.
load_model = True
# Checkpoint to resume from; a relative path, resolved against the current
# working directory.
f_dir = 'torchbeast/plots/Relational_v21.3_3.tar'
# Seconds train() sleeps between two progress log lines.
log_time = 100

# Human-readable names of the actions; learn() indexes this list with the
# sampled action ids to build its "action_counts" statistic.
# NOTE(review): presumably ordered exactly like the environment's action
# space — confirm against nle.nethack.ACTIONS.
act_list = [
    "N", "E", "S", "W", "NE", "SE", "SW", "NW",  # CompassDirection
    "N_", "E_", "S_", "W_", "NE_", "SE_", "SW_", "NW_",  # CompassDirectionLonger
    "UP", "DOWN", "WAIT", "MORE",           # MiscDirection and MiscAction
    "ADJUST", "APPLY", "ATTRIBUTES", "CALL", "CAST", "CHAT", "CLOSE", "DIP", "DROP", "DROPTYPE",
    "EAT", "ENGRAVE", "ENHANCE", "ESC", "FIGHT", "FIRE", "FORCE", "INVENTORY", "INVENTTYPE", "INVOKE",
    "JUMP", "KICK", "LOOK", "LOOT", "MONSTER", "MOVE", "MOVEFAR", "OFFER", "OPEN", "PAY", "PICKUP",
    "PRAY", "PUTON", "QUAFF", "QUIVER", "READ", "REMOVE", "RIDE", "RUB", "RUSH", "RUSH2", "SEARCH",
    "SEEARMOR", "SEERINGS", "SEETOOLS", "SEETRAP", "SEEWEAPON", "SHELL", "SIT", "SWAP", "TAKEOFF",
    "TAKEOFFALL", "THROW", "TIP", "TURN", "TWOWEAPON", "UNTRAP", "VERSIONSHORT", "WEAR", "WIELD",
    "WIPE", "ZAP",
    "PLUS", "QUOTE", "DOLLAR", "SPACE"         # TextCharacters
]

# Subset of `act_list` the agent is allowed to pick. The comment line under
# each row lists the corresponding NetHack key bindings.
act_useful = [
    "MORE",
    # C-m
    "N", "E", "S", "W", "NE", "SE", "SW", "NW",
    # k  l  j   h  u   n   b   y
    "N_", "E_", "S_", "W_", "NE_", "SE_", "SW_", "NW_",
    "UP", "DOWN", "WAIT", "KICK", "EAT", "SEARCH",
    # <   >    .   C-d    e    s
     "PICKUP", "ESC", "DROP", "LOOK",
    #  ,         d    :
     "WIELD" ,"PUTON", "REMOVE", "WEAR", "TAKEOFF",
    #  w    P     R     W     T
     "APPLY", "CLOSE", "FIRE", "RUSH", "INVENTORY", "MOVE",
    #  a    c     f    g     i     m
     "OPEN", "PAY", "QUAFF",# "READ", "TAKEOFFALL", "UNTRAP", "ZAP", "CAST",
    #  o    p    q
]

# act_index = [act_list.index(i) for i in act_useful]
# 0/1 mask aligned with `act_list`: 1 where the action is in `act_useful`.
act_mask = [1 if action in act_useful else 0 for action in act_list]


try:
    import torch
    from torch import multiprocessing as mp
    from torch import nn
    from torch.nn import functional as F
except ImportError:
    logging.exception(
        "PyTorch not found. Please install the agent dependencies with "
        '`pip install "nle[agent]"`'
    )

import gymnasium as gym  # noqa: E402

import nle  # noqa: F401, E402
from nle import nethack  # noqa: E402
from nle.agent import vtrace  # noqa: E402

# yapf: disable
# Command-line flags for the agent; the parsed namespace is passed around as
# `flags` (train() additionally attaches `rundir` and `device` to it).
parser = argparse.ArgumentParser(description="PyTorch Scalable Agent")

parser.add_argument("--env", type=str, default="NetHackScore-v0",
                    help="Gym environment.")
parser.add_argument("--mode", default="train",
                    choices=["train", "test", "test_render"],
                    help="Training or test mode.")

# Training settings.
parser.add_argument("--disable_checkpoint", action="store_true",
                    help="Disable saving checkpoint.")
parser.add_argument("--savedir", default="~/torchbeast/",
                    help="Root dir where experiment data will be saved.")
parser.add_argument("--num_actors", default=4, type=int, metavar="N",
                    help="Number of actors (default: 4).")
# Horizon of the learning-rate schedule in train(): the LR multiplier decays
# linearly and reaches 0 at this many steps (counted across resumed runs).
parser.add_argument("--total_steps", default=100000, type=int, metavar="T",
                    help="Total environment steps to train for.")
# Step budget of the current invocation only: train() stops once this many
# steps have been processed since it was started.
parser.add_argument("--total_steps_", default=100000, type=int, metavar="T",
                    help="Total environment steps each time to train for.")
parser.add_argument("--batch_size", default=8, type=int, metavar="B",
                    help="Learner batch size.")
parser.add_argument("--unroll_length", default=80, type=int, metavar="T",
                    help="The unroll length (time dimension).")
# Left as None here; train() then picks max(2 * num_actors, batch_size).
parser.add_argument("--num_buffers", default=None, type=int,
                    metavar="N", help="Number of shared-memory buffers.")
parser.add_argument("--num_learner_threads", "--num_threads", default=2, type=int,
                    metavar="N", help="Number learner threads.")
parser.add_argument("--disable_cuda", action="store_true",
                    help="Disable CUDA.")
parser.add_argument("--use_lstm", action="store_true",
                    help="Use LSTM in agent model.")
parser.add_argument("--save_ttyrec_every", default=1000, type=int,
                    metavar="N", help="Save ttyrec every N episodes.")


# Loss settings.
parser.add_argument("--entropy_cost", default=0.0006,
                    type=float, help="Entropy cost/multiplier.")
parser.add_argument("--baseline_cost", default=0.5,
                    type=float, help="Baseline cost/multiplier.")
parser.add_argument("--discounting", default=0.99,
                    type=float, help="Discounting factor.")
parser.add_argument("--reward_clipping", default="abs_one",
                    choices=["abs_one", "none"],
                    help="Reward clipping.")

# Optimizer settings.
parser.add_argument("--learning_rate", default=0.00048,
                    type=float, metavar="LR", help="Learning rate.")
parser.add_argument("--alpha", default=0.99, type=float,
                    help="RMSProp smoothing constant.")
parser.add_argument("--momentum", default=0, type=float,
                    help="RMSProp momentum.")
parser.add_argument("--epsilon", default=0.01, type=float,
                    help="RMSProp epsilon.")
parser.add_argument("--grad_norm_clipping", default=40.0, type=float,
                    help="Global gradient norm clip.")
# yapf: enable


# Log at INFO level; every record is prefixed with level, process id,
# module:line and timestamp.
logging.basicConfig(
    format=(
        "[%(levelname)s:%(process)d %(module)s:%(lineno)d %(asctime)s] " "%(message)s"
    ),
    level=logging.INFO,
)


def nested_map(f, n):
    """Apply ``f`` to every leaf of a nested tuple/list/dict structure.

    Tuples and lists are rebuilt with their own class, dicts are rebuilt as
    plain dicts with the same keys, and anything else is treated as a leaf
    and replaced by ``f(leaf)``.
    """
    if isinstance(n, dict):
        return {key: nested_map(f, value) for key, value in n.items()}
    if isinstance(n, (tuple, list)):
        rebuild = n.__class__
        return rebuild(nested_map(f, item) for item in n)
    return f(n)


def compute_baseline_loss(advantages):
    """Return half the sum of squared advantages (value-function loss)."""
    squared = advantages**2
    return 0.5 * squared.sum()


def compute_entropy_loss(logits):
    """Return the entropy loss, i.e. the negative policy entropy.

    The softmax is taken over the last dimension and the result is summed
    over every element of ``logits``.
    """
    probabilities = F.softmax(logits, dim=-1)
    log_probabilities = F.log_softmax(logits, dim=-1)
    return (probabilities * log_probabilities).sum()


def compute_policy_gradient_loss(logits, actions, advantages):
    """Return the summed advantage-weighted negative log-likelihood.

    The first two dimensions of ``logits`` and ``actions`` are flattened
    together, the per-sample negative log-probability of the taken action is
    computed, reshaped like ``advantages`` and weighted by the (detached)
    advantages.
    """
    flat_log_probs = F.log_softmax(torch.flatten(logits, 0, 1), dim=-1)
    flat_actions = torch.flatten(actions, 0, 1)
    neg_log_likelihood = F.nll_loss(
        flat_log_probs, target=flat_actions, reduction="none"
    ).view_as(advantages)
    weighted = neg_log_likelihood * advantages.detach()
    return weighted.sum()


def create_env(name, *args, **kwargs):
    """Build the Gym environment ``name`` with the observation keys the agent uses.

    Extra positional and keyword arguments are forwarded to ``gym.make``.
    """
    observation_keys = (
        "glyphs",
        "blstats",
        "message",
        "inv_glyphs",
        "inv_letters",
        "inv_oclasses",
        "inv_strs",
    )
    return gym.make(name, *args, observation_keys=observation_keys, **kwargs)

def act(
    flags,
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    buffers,
    initial_agent_state_buffers,
):
    """Actor loop: fill rollout buffers with environment/agent outputs.

    Repeatedly takes a buffer index from ``free_queue``, writes the last
    observation/agent output of the previous rollout at time step 0, unrolls
    ``flags.unroll_length`` further steps into the same buffer slot and hands
    the index over through ``full_queue``. A ``None`` index ends the loop.

    ``KeyboardInterrupt`` ends the actor silently; any other exception is
    logged, printed and re-raised.
    """

    def store(slot, step, *sources):
        # Copy every entry of each source dict into the rollout buffers.
        for source in sources:
            for key in source:
                buffers[key][slot][step, ...] = source[key]

    try:
        logging.info("Actor %i started.", actor_index)

        env = ResettingEnvironment(
            create_env(
                flags.env,
                savedir=flags.rundir,
                save_ttyrec_every=flags.save_ttyrec_every,
            )
        )
        env_output = env.initial()
        agent_state = model.initial_state(batch_size=1)
        agent_output, _ = model(env_output, agent_state)

        while True:
            slot = free_queue.get()
            if slot is None:
                break

            # Write old rollout end.
            store(slot, 0, env_output, agent_output)
            for position, state_tensor in enumerate(agent_state):
                initial_agent_state_buffers[slot][position][...] = state_tensor

            # Do new rollout.
            for step in range(flags.unroll_length):
                with torch.no_grad():
                    agent_output, agent_state = model(env_output, agent_state)

                env_output = env.step(agent_output["action"])
                store(slot, step + 1, env_output, agent_output)

            full_queue.put(slot)

    except KeyboardInterrupt:
        pass  # Return silently.
    except Exception:
        logging.error("Exception in worker process %i", actor_index)
        traceback.print_exc()
        print()
        raise


def get_batch(
    flags,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers,
    initial_agent_state_buffers,
    lock=threading.Lock(),  # noqa: B008
):
    """Assemble one learner batch from filled rollout buffers.

    Args:
        flags: needs ``batch_size`` and ``device``.
        free_queue: receives the buffer indices once their content is copied.
        full_queue: yields indices of buffers filled by the actors.
        buffers: dict mapping a key to a list of per-buffer tensors.
        initial_agent_state_buffers: per-buffer tuples of agent-state tensors.
        lock: created once at definition time on purpose, so that all calls
            share it and dequeue their ``batch_size`` indices atomically.

    Returns:
        ``(batch, initial_agent_state)``: ``batch`` maps each key to the
        selected buffers stacked along dim 1; ``initial_agent_state`` is a
        tuple of the selected states concatenated along dim 1. Both are moved
        to ``flags.device``.
    """
    with lock:
        indices = [full_queue.get() for _ in range(flags.batch_size)]
    batch = {
        key: torch.stack([buffers[key][m] for m in indices], dim=1) for key in buffers
    }
    # Materialise the concatenated states *before* the indices are released:
    # a lazy generator here would read the shared buffers only after the
    # actors may already have started overwriting them.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    for m in indices:
        free_queue.put(m)
    batch = {k: t.to(device=flags.device, non_blocking=True) for k, t in batch.items()}
    initial_agent_state = tuple(
        t.to(device=flags.device, non_blocking=True) for t in initial_agent_state
    )
    return batch, initial_agent_state


def learn(
    flags,
    actor_model,
    model,
    batch,
    initial_agent_state,
    optimizer,
    scheduler,
    lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step.

    Runs ``model`` on ``batch``, builds the V-trace policy-gradient, baseline
    and entropy losses, takes one optimizer/scheduler step and copies the
    updated weights into ``actor_model``.

    The default ``lock`` is created once at definition time, so all callers
    that do not pass their own lock are serialized through the same one.

    Returns:
        A dict of statistics: up to 10 finished-episode returns, their mean,
        the individual and total losses, the 30 most frequent sampled actions
        with counts, and the number of finished episodes in the batch.
    """
    with lock:
        learner_outputs, unused_state = model(batch, initial_agent_state)

        # Take final value function slice for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {key: tensor[:-1] for key, tensor in learner_outputs.items()}

        rewards = batch["reward"]


        # The argparse `choices` for --reward_clipping only allow these two
        # values; any other value would leave `clipped_rewards` undefined.
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        # Zero discount on steps where the episode ended.
        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch["action"],
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"]
        )
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"]
        )

        total_loss = pg_loss + baseline_loss + entropy_loss

        # TODO
        # Histogram of the actions the learner model sampled on this batch,
        # labelled through `act_list` and sorted by descending count.
        acts = learner_outputs["action"]
        elements, counts = torch.unique(acts, return_counts=True)
        combined = list(zip([act_list[i] for i in elements.tolist()], counts.tolist()))
        combined_ = sorted(combined, key=lambda x: x[1], reverse=True)
        # Returns of the episodes that finished inside this batch.
        episode_returns = batch["episode_return"][batch["done"]]

        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()[:10]),
            # NOTE: `episode_returns` is empty when no episode finished in
            # this batch; the mean is then taken over an empty tensor.
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
            #"actions": elements,
            "action_counts": combined_[:30],
            "num_episodes": episode_returns.shape[0]
        }

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        # Publish the updated weights to the model used by the actors.
        actor_model.load_state_dict(model.state_dict())
        return stats


def create_buffers(flags, observation_space, num_actions, num_overlapping_steps=1):
    """Allocate the shared-memory rollout buffers.

    One sample is drawn from ``observation_space`` to learn the shape and
    dtype of every observation entry. Each buffer tensor has a leading time
    dimension of ``flags.unroll_length + num_overlapping_steps``.

    Returns:
        A dict mapping every key (observation entries plus reward, done,
        episode statistics, policy logits, baseline and actions) to a list of
        ``flags.num_buffers`` uninitialised tensors placed in shared memory.
    """
    time_dim = (flags.unroll_length + num_overlapping_steps,)

    # Get specimens to infer shapes and dtypes.
    specs = {}
    for name, array in observation_space.sample().items():
        specimen = torch.from_numpy(array)
        specs[name] = {"size": time_dim + specimen.shape, "dtype": specimen.dtype}

    specs.update(
        {
            "reward": {"size": time_dim, "dtype": torch.float32},
            "done": {"size": time_dim, "dtype": torch.bool},
            "episode_return": {"size": time_dim, "dtype": torch.float32},
            "episode_step": {"size": time_dim, "dtype": torch.int32},
            "policy_logits": {
                "size": time_dim + (num_actions,),
                "dtype": torch.float32,
            },
            "baseline": {"size": time_dim, "dtype": torch.float32},
            "last_action": {"size": time_dim, "dtype": torch.int64},
            "action": {"size": time_dim, "dtype": torch.int64},
        }
    )

    return {
        name: [
            torch.empty(**spec).share_memory_() for _ in range(flags.num_buffers)
        ]
        for name, spec in specs.items()
    }


def _format_observations(observation, keys=("glyphs", "blstats", "message", "inv_glyphs", "inv_letters", "inv_oclasses", "inv_strs")):
    observations = {}
    for key in keys:
        entry = observation[key]
        entry = torch.from_numpy(entry)
        entry = entry.view((1, 1) + entry.shape)  # (...) -> (T,B,...).
        observations[key] = entry
    return observations


class ResettingEnvironment:
    """Turns a Gym environment into something that can be step()ed indefinitely.

    The wrapper resets the underlying environment whenever an episode ends and
    replaces the raw environment reward with a shaped reward (see
    ``_shape_reward``). ``episode_return`` still accumulates the *raw* reward.

    The shaping reads fixed positions of the ``blstats`` vector:
    12 = depth, 16 = armor class, 20 = time, 21 = hunger state, and
    ``attr_index`` = the six ability scores (3-8), max hitpoints (11),
    max energy (15) and experience level (18).
    """

    def __init__(self, gym_env):
        self.gym_env = gym_env
        self.episode_return = None
        self.episode_step = None
        # Trackers used by the reward shaping; they hold the value seen at
        # the previous step (or the best value seen so far in the episode).
        self.depth = 1
        self.hungry = 1
        self.glyphs = None
        self.glyphs_sum = None
        self.attr_index = [3, 4, 5, 6, 7, 8, 11, 15, 18]
        self.attributes = None
        self.AC = None
        self.T = 1
        # Counters start negative so that a penalty only applies after the
        # condition has persisted for more than -init consecutive steps.
        self.frozen_init = -3
        self.loop_init = -20
        self.frozen_step = self.frozen_init
        self.loop_step = self.loop_init

    def _sync_trackers(self, obs):
        """Initialise all shaping trackers from a fresh (post-reset) observation."""
        blstats = obs['blstats']
        self.depth = blstats[12]
        self.hungry = blstats[21]
        self.glyphs = np.count_nonzero(obs['glyphs'] == 2359)
        self.glyphs_sum = np.sum(obs['glyphs'])
        self.attributes = sum(blstats[pos] for pos in self.attr_index)
        self.AC = blstats[16]
        self.T = blstats[20]
        self.frozen_step = self.frozen_init
        self.loop_step = self.loop_init

    def _shape_reward(self, reward, observation):
        """Return the shaped reward for one step and update the trackers."""
        blstats = observation['blstats']

        # Squash the raw environment reward into (-2, 2).
        reward = np.tanh(reward)*2

        # Bonus when the summed attributes exceed the best value seen so far.
        attributes = sum(blstats[pos] for pos in self.attr_index)
        reward += max(0, attributes - self.attributes)*2
        self.attributes = max(attributes, self.attributes)

        # Bonus when armor class drops below the lowest value seen so far.
        AC = blstats[16]
        reward += max(0, self.AC - AC)*2
        self.AC = min(AC, self.AC)

        # Bonus when the hunger state decreases compared to the last step.
        hungry = blstats[21]
        reward += max(0, self.hungry - hungry)*3
        self.hungry = hungry

        # Bonus when fewer map cells show glyph 2359 than at the last step.
        # NOTE(review): 2359 is presumably the "unexplored" glyph, making
        # this an exploration bonus — confirm against the NLE glyph table.
        glyphs = np.count_nonzero(observation['glyphs'] == 2359)
        glyphs_diff = np.tanh(self.glyphs - glyphs)
        reward += max(0, glyphs_diff)
        self.glyphs = glyphs

        # Bonus for reaching a new maximum depth, reduced by glyphs_diff.
        depth = blstats[12]
        reward += max(0, depth - self.depth - glyphs_diff)*5
        self.depth = max(depth, self.depth)

        # Penalty while the in-game time does not advance.
        T = blstats[20]
        if T == self.T:
            self.frozen_step += 1
        else:
            self.frozen_step = self.frozen_init
        reward -= max(0, self.frozen_step)*0.001
        self.T = T

        # Penalty while the glyph map (its sum) stays unchanged.
        glyphs_sum = np.sum(observation['glyphs'])
        if glyphs_sum == self.glyphs_sum:
            self.loop_step += 1
        else:
            self.loop_step = self.loop_init
        # Fixed: this used to subtract the frozen counter a second time,
        # leaving loop_step without any effect.
        reward -= max(0, self.loop_step)*0.001
        self.glyphs_sum = glyphs_sum

        # Final squash into (-1, 1).
        return np.tanh(reward)

    def initial(self):
        """Reset the environment and return the first (T=1, B=1) output dict."""
        initial_reward = torch.zeros(1, 1)
        # This supports only single-tensor actions ATM.
        initial_last_action = torch.zeros(1, 1, dtype=torch.int64)
        self.episode_return = torch.zeros(1, 1)
        self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
        initial_done = torch.ones(1, 1, dtype=torch.uint8)
        obs, reset_info = self.gym_env.reset()
        self._sync_trackers(obs)
        result = _format_observations(obs)
        result.update(
            reward=initial_reward,
            done=initial_done,
            episode_return=self.episode_return,
            episode_step=self.episode_step,
            last_action=initial_last_action,
        )
        return result

    def step(self, action):
        """Apply ``action`` (a single-element tensor) and return the output dict.

        The returned ``reward`` is the shaped reward; ``episode_return`` and
        ``episode_step`` describe the episode the step belonged to, even when
        that episode just ended and the returned observation already comes
        from the next one.
        """
        # NOTE(review): `truncated` is ignored, so a truncated episode is
        # neither reported as done nor reset here — confirm this is intended.
        observation, reward, done, truncated, unused_info = self.gym_env.step(
            action.item()
        )
        self.episode_step += 1

        self.episode_return += reward
        episode_step = self.episode_step
        episode_return = self.episode_return

        reward = self._shape_reward(reward, observation)

        if done:
            observation, reset_info = self.gym_env.reset()
            self.episode_return = torch.zeros(1, 1)
            self.episode_step = torch.zeros(1, 1, dtype=torch.int32)
            # Same initialisation as in initial(): take depth and hunger from
            # the fresh observation instead of hard-coding them to 1.
            self._sync_trackers(observation)

        result = _format_observations(observation)

        reward = torch.tensor(reward).view(1, 1)
        done = torch.tensor(done).view(1, 1)

        result.update(
            reward=reward,
            done=done,
            episode_return=episode_return,
            episode_step=episode_step,
            last_action=action,
        )
        return result

    def close(self):
        """Close the wrapped environment."""
        self.gym_env.close()


def train(flags):  # pylint: disable=too-many-branches, too-many-statements
    """Run actor processes and learner threads until the step budget is used.

    Creates a timestamped run directory below ``flags.savedir`` (plus a
    ``latest`` symlink), optionally restores model/optimizer/scheduler state
    from the module-level ``f_dir`` when ``load_model`` is set, forks
    ``flags.num_actors`` actor processes and starts
    ``flags.num_learner_threads`` learner threads. The main thread logs
    progress every ``log_time`` seconds and checkpoints every 30 minutes.

    ``flags.total_steps_`` bounds the steps of this invocation, while
    ``flags.total_steps`` is the horizon of the linear learning-rate decay.
    ``flags.rundir``, ``flags.device`` and (if unset) ``flags.num_buffers``
    are filled in as a side effect.

    Raises:
        ValueError: if ``num_buffers`` is not larger than ``num_actors`` or is
            smaller than ``batch_size``.
    """
    flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir))

    rundir = os.path.join(
        flags.savedir, "torchbeast-%s" % time.strftime("%Y%m%d-%H%M%S")
    )

    os.makedirs(rundir, exist_ok=True)
    logging.info("Logging results to %s", rundir)

    # Point "<savedir>/latest" at this run; OSError propagates to the caller.
    symlink = os.path.join(flags.savedir, "latest")
    if os.path.islink(symlink):
        os.remove(symlink)
    if not os.path.exists(symlink):
        os.symlink(rundir, symlink)
    logging.info("Symlinked log directory: %s", symlink)

    # Line-buffered so that every row reaches the disk as soon as it is written.
    logfile = open(os.path.join(rundir, "logs.tsv"), "a", buffering=1)

    checkpointpath = os.path.join(rundir, "model.tar")

    flags.rundir = rundir

    if flags.num_buffers is None:  # Set sensible default for num_buffers.
        flags.num_buffers = max(2 * flags.num_actors, flags.batch_size)
    if flags.num_actors >= flags.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")
    if flags.num_buffers < flags.batch_size:
        raise ValueError("num_buffers should be larger than batch_size")

    T = flags.unroll_length
    B = flags.batch_size

    flags.device = None
    if not flags.disable_cuda and torch.cuda.is_available():
        logging.info("Using CUDA.")
        flags.device = torch.device("cuda")
    else:
        logging.info("Not using CUDA.")
        flags.device = torch.device("cpu")

    env = create_env(flags.env)
    observation_space = env.observation_space
    action_space = env.action_space
    del env  # End this before forking.

    # Actor-side model; lives in shared memory and is refreshed by learn().
    model = Net(observation_space, action_space.n, flags.use_lstm)

    # Optionally resume from the checkpoint configured at module level.
    if load_model:
        ckpt = torch.load(f_dir, map_location="cpu")
        state_dict = ckpt["model_state_dict"]

        # if 'act_mask' in state_dict:
        #     del state_dict['act_mask']

        model.load_state_dict(state_dict,)
        # Re-apply the current action mask on top of the restored weights.
        model.reset_act_mask(act_mask)

    buffers = create_buffers(flags, observation_space, model.num_actions)

    model.share_memory()

    # Add initial RNN state.
    initial_agent_state_buffers = []
    for _ in range(flags.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(flags.num_actors):
        actor = ctx.Process(
            target=act,
            args=(
                flags,
                i,
                free_queue,
                full_queue,
                model,
                buffers,
                initial_agent_state_buffers,
            ),
            name="Actor-%i" % i,
        )
        actor.start()
        actor_processes.append(actor)

    # Learner-side copy of the model, placed on the training device.
    learner_model = Net(observation_space, action_space.n, flags.use_lstm).to(
        device=flags.device
    )
    learner_model.load_state_dict(model.state_dict())

    optimizer = torch.optim.RMSprop(
        learner_model.parameters(),
        lr=flags.learning_rate,
        momentum=flags.momentum,
        eps=flags.epsilon,
        alpha=flags.alpha,
    )

    # optimizer = torch.optim.Adam(
    #     learner_model.parameters(),
    #     lr=flags.learning_rate,
    #     eps=flags.epsilon,
    # )

    def lr_lambda(epoch):
        # Linear decay from 1 to 0 over flags.total_steps environment steps.
        return 1 - min(epoch * T * B, flags.total_steps) / flags.total_steps

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    if load_model:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        # Recover the global step count from the scheduler's epoch counter.
        last_epoch = ckpt["scheduler_state_dict"]["last_epoch"]
        step = last_epoch*T*B
    else:
        step = 0

    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
        "num_episodes",
    ]
    logfile.write("# Step\t%s\n" % "\t".join(stat_keys))

    # `step` counts steps across resumed runs, `step_` only those of this run.
    step_, stats = 0, {}

    def batch_and_learn(i, lock=threading.Lock()):
        """Thread target for the learning process."""
        nonlocal step, step_, stats
        while step_ < flags.total_steps_:
            batch, agent_state = get_batch(
                flags, free_queue, full_queue, buffers, initial_agent_state_buffers
            )

            stats = learn(
                flags, model, learner_model, batch, agent_state, optimizer, scheduler
            )
            with lock:
                logfile.write("%i\t" % step)
                logfile.write("\t".join(str(stats[k]) for k in stat_keys))
                logfile.write("\n")
                step += T * B
                step_ += T * B

    for m in range(flags.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(flags.num_learner_threads):
        thread = threading.Thread(
            target=batch_and_learn,
            name="batch-and-learn-%d" % i,
            args=(i,),
            daemon=True,  # To support KeyboardInterrupt below.
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        if flags.disable_checkpoint:
            return
        logging.info("Saving checkpoint to %s", checkpointpath)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "flags": vars(flags),
            },
            checkpointpath,
        )

    timer = timeit.default_timer
    try:
        last_checkpoint_time = timer()
        start_step = 0
        while step_ < flags.total_steps_:

            if start_step == step_:
                # No learner progress since the last report: wait briefly
                # instead of busy-spinning (which would compete with the
                # learner threads for the GIL).
                time.sleep(0.1)
                continue
            start_step = step_
            start_time = timer()
            time.sleep(log_time)

            if timer() - last_checkpoint_time > 30 * 60:  # Save every 30 min.
                checkpoint()
                last_checkpoint_time = timer()
                # Keep a copy of the log next to the checkpoint.
                source_path = os.path.join(rundir, "logs.tsv")
                destination_path = os.path.join(rundir, "logs_.tsv")
                shutil.copy(source_path, destination_path)

            # Steps per second over the interval that was just slept through.
            sps = (step_ - start_step) / (timer() - start_time)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            logging.info(
                "Steps %i ( %i of %i ) @ %.1f SPS. Loss %f. %sStats:\n%s",
                step_,
                step,
                flags.total_steps,
                sps,
                total_loss,
                mean_return,
                pprint.pformat(stats),
            )
    except KeyboardInterrupt:
        logging.warning("Quitting.")
        return  # Try joining actors then quit.
    else:
        for thread in threads:
            thread.join()
        logging.info("Learning finished after %d steps.", step)
    finally:
        # A None index tells each actor to leave its loop.
        for _ in range(flags.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    logfile.close()


def test(flags, num_episodes=10):
    """Evaluate a saved model for ``num_episodes`` episodes and log the returns.

    The checkpoint is read from ``<flags.savedir>/GAT_TRAN_test_4/model.tar``.
    When ``flags.mode`` is ``"test_render"`` the environment is rendered
    before every step.
    """
    flags.savedir = os.path.expandvars(os.path.expanduser(flags.savedir))
    checkpointpath = os.path.join(flags.savedir, "GAT_TRAN_test_4", "model.tar")

    gym_env = create_env(flags.env, save_ttyrec_every=flags.save_ttyrec_every)
    env = ResettingEnvironment(gym_env)
    model = Net(gym_env.observation_space, gym_env.action_space.n, flags.use_lstm)
    model.eval()
    checkpoint = torch.load(checkpointpath, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    observation = env.initial()
    returns = []

    agent_state = model.initial_state(batch_size=1)

    # Pure inference: without no_grad() the recurrent state would drag an
    # ever-growing autograd graph along for the whole evaluation.
    with torch.no_grad():
        while len(returns) < num_episodes:
            if flags.mode == "test_render":
                env.gym_env.render()
            policy_outputs, agent_state = model(observation, agent_state)
            observation = env.step(policy_outputs["action"])
            if observation["done"].item():
                returns.append(observation["episode_return"].item())
                logging.info(
                    "Episode ended after %d steps. Return: %.1f",
                    observation["episode_step"].item(),
                    observation["episode_return"].item(),
                )
    env.close()
    # The average is over episodes, not steps.
    logging.info(
        "Average returns over %i episodes: %.1f",
        num_episodes,
        sum(returns) / len(returns),
    )


class RandomNet(nn.Module):
    """Baseline agent that ignores its input and samples uniformly random actions.

    It has a single parameter ``theta`` that is multiplied by zero, so the
    logits and the baseline are always zero.
    """

    def __init__(self, observation_shape, num_actions, use_lstm):
        super().__init__()
        del observation_shape, use_lstm
        self.num_actions = num_actions
        self.theta = torch.nn.Parameter(torch.zeros(self.num_actions))

    def forward(self, inputs, core_state):
        """Return zero logits/baseline and a random action per (T, B) entry.

        Only the two leading dimensions of ``inputs["observation"]`` are used;
        ``core_state`` is passed through unchanged.
        """
        T, B, *_ = inputs["observation"].shape

        # All-zero logits, one row per flattened (T * B) entry.
        flat_logits = (self.theta * 0)[None, :].expand(T * B, -1)
        # Summing zero logits gives the zero baseline.
        baseline = flat_logits.sum(dim=1).view(-1, B)

        # Uniform distribution over actions -> random action.
        probabilities = F.softmax(flat_logits, dim=1)
        action = torch.multinomial(probabilities, num_samples=1).view(T, B)

        outputs = dict(
            policy_logits=flat_logits.view(T, B, self.num_actions),
            baseline=baseline,
            action=action,
        )
        return outputs, core_state

    def initial_state(self, batch_size):
        """Return the (empty) recurrent state."""
        return ()


def _step_to_range(delta, num_steps):
    """Range of `num_steps` integers with distance `delta` centered around zero."""
    return delta * torch.arange(-num_steps // 2, num_steps // 2)


class Crop(nn.Module):
    """Helper class for NetHackNet below.

    Crops a fixed-size window out of a 2-D map, centered on per-sample
    coordinates, by sampling the map with ``F.grid_sample``.
    """

    def __init__(self, height, width, height_target, width_target):
        super().__init__()
        self.width = width
        self.height = height
        self.width_target = width_target
        self.height_target = height_target

        # Offsets of the window cells relative to its center, expressed in
        # the normalized coordinates grid_sample expects.
        column_offsets = _step_to_range(2 / (self.width - 1), self.width_target)
        row_offsets = _step_to_range(2 / (self.height - 1), height_target)
        width_grid = column_offsets[None, :].expand(self.height_target, -1)
        height_grid = row_offsets[:, None].expand(-1, self.width_target)

        # "clone" necessary, https://github.com/pytorch/pytorch/issues/34880
        self.register_buffer("width_grid", width_grid.clone())
        self.register_buffer("height_grid", height_grid.clone())

    def forward(self, inputs, coordinates):
        """Calculates centered crop around given x,y coordinates.
        Args:
           inputs [B x H x W]
           coordinates [B x 2] x,y coordinates
        Returns:
           [B x H' x W'] inputs cropped and centered around x,y coordinates.
        """
        assert inputs.shape[1] == self.height
        assert inputs.shape[2] == self.width

        # Add a channel dimension and switch to float for grid_sample.
        images = inputs[:, None, :, :].float()

        # Both coordinates are shifted by one before centering.
        x = coordinates[:, 0]+1
        y = coordinates[:, 1]+1

        # Window center in normalized coordinates.
        x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
        y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)

        sample_x = self.width_grid[None, :, :] + x_shift[:, None, None]
        sample_y = self.height_grid[None, :, :] + y_shift[:, None, None]
        grid = torch.stack([sample_x, sample_y], dim=3)

        # TODO: only cast to int if original tensor was int
        sampled = F.grid_sample(images, grid, align_corners=True)
        return torch.round(sampled).squeeze(1).long()


# relational

In [None]:
# 'chars': Box(0, 255, (21, 79), uint8), 'colors': Box(0, 15, (21, 79), uint8), 'specials': Box(0, 255, (21, 79), uint8),
# 'inv_glyphs': Box(0, 5976, (55,), int16), 'inv_letters': Box(0, 127, (55,), uint8), 'inv_oclasses': Box(0, 18, (55,), uint8), 'inv_strs': Box(0, 255, (55, 80), uint8),
# 'chars','colors','specials'
from torch.nn import Linear, LayerNorm
import numpy as np
from torch.nn.functional import one_hot

class NetHackNet_GAT(nn.Module):
    """Recurrent actor-critic network over NetHack observations.

    In the active code path (``base_model`` is False) the network:

    1. crops the glyph map around the hero and encodes the crop with a small
       CNN, producing one token per crop cell;
    2. embeds blstats, the last action, the inventory and the message into
       extra tokens of the same width;
    3. mixes all tokens with ``nn.MultiheadAttention`` plus a shared MLP,
       max-pools over tokens and projects the result;
    4. runs an LSTM core and emits policy logits, a baseline value and an
       action.

    Despite the class name, no graph-attention layer is used here; the
    attention is plain multi-head self-attention.
    """

    def __init__(
        self,
        observation_shape,
        num_actions,
        use_lstm,
        embedding_dim=32,
        crop_dim=13,
        num_layers=5,
    ):
        """Build all sub-modules.

        Args:
            observation_shape: mapping with "glyphs", "blstats", "message"
                and "inv_strs" entries, each exposing ``.shape``.
            num_actions: size of the discrete action space.
            use_lstm: accepted for interface compatibility; it is not read
                here (the LSTM core is always created).
            embedding_dim: width of the glyph embeddings.
            crop_dim: side length of the square crop around the hero.
            num_layers: number of convolution layers in the crop CNN.

        Also reads the module-level ``act_mask`` to build the initial
        action mask.
        """
        super(NetHackNet_GAT, self).__init__()

        # Additive logit mask: (mask - 1) * 1e9 is 0 where the mask is 1 and
        # -1e9 where it is 0 (assumes act_mask holds 0/1 entries -- confirm).
        self.act_mask = nn.Parameter((torch.tensor(act_mask, dtype=torch.float32)-1)*1e9, requires_grad=False)


        # Per-feature multipliers applied to blstats in forward(); a factor
        # of 0.0 zeroes the feature out entirely.
        BLSTAT_NORMALIZATION_STATS = [[
            1.0 / 79.0, # hero col
            1.0 / 21, # hero row
            0.0, # strength pct
            1.0 / 10, # strength
            1.0 / 10, # dexterity
            1.0 / 10, # constitution
            1.0 / 10, # intelligence
            1.0 / 10, # wisdom
            1.0 / 10, # charisma
            0.0,      # score
            1.0 / 10, # hitpoints
            1.0 / 10, # max hitpoints
            1.0, # depth
            1.0 / 1000, # gold
            1.0 / 10, # energy
            1.0 / 10, # max energy
            1.0 / 10, # armor class
            0.0, # monster level
            1.0 / 10, # experience level
            1.0 / 100, # experience points
            1.0 / 1000, # time
            1.0, # hunger_state
            1.0 / 10, # carrying capacity
            0.0, # dungeon number
            0.0, # level number
            0.0, # condition bits
            0.0, # character alignment
            ]]
        # NOTE(review): unlike act_mask above, this Parameter is created
        # without requires_grad=False, so the scale factors are trainable --
        # confirm that is intended.
        self.blstats_scale = nn.Parameter(torch.tensor(BLSTAT_NORMALIZATION_STATS, dtype=torch.float32))
        # Scaled blstats are clamped into this (min, max) range in forward().
        self.BLSTAT_CLIP_RANGE = (0, 5)

        self.glyph_shape = observation_shape["glyphs"].shape
        self.blstats_size = observation_shape["blstats"].shape[0]
        self.message_size = observation_shape["message"].shape[0]
        self.inv_shape = observation_shape["inv_strs"].shape
        # Wall-clock helpers; not used anywhere else in this class.
        self.timer = timeit.default_timer
        self.t = self.timer()

        self.num_actions = num_actions
        # Hard-coded: the attention path is always used and the "base CNN"
        # branches below are never taken.
        self.base_model = False # not(use_lstm)
        self.use_message = True
        self.use_inv = True

        self.H = self.glyph_shape[0]
        self.W = self.glyph_shape[1]

        self.k_dim = embedding_dim  # glyph_dim
        self.s_dim = 128 # 128    blstats_dim
        self.att_fc_h_dim = 256
        self.att_fc_out_dim = 512   # 1024
        self.h_dim = 512   # 1024
        self.cnn_out_dim = self.s_dim   # 128

        self.crop_dim = crop_dim


        #--------------inv emb---------------#

        if self.use_inv:
            # One more row than the map glyph embedding below (MAX_GLYPH + 1);
            # presumably MAX_GLYPH itself marks an empty slot -- confirm.
            self.inv_embed = nn.Embedding(nethack.MAX_GLYPH+1, self.k_dim)

            self.inv_emb_norm = nn.LayerNorm(self.k_dim)

            # Input width = glyph embedding + inv_strs row + 1 (letter)
            # + 19 (one-hot object class), matching the concat in forward().
            self.embed_inv = nn.Sequential(
                nn.Linear(self.k_dim+self.inv_shape[-1]+1+19, 128),
                nn.ReLU(),
                nn.Linear(128, self.s_dim),
                nn.ReLU(),
                )
            self.inv_norm = nn.LayerNorm(self.s_dim)

        #--------------------------------------#



        #--------------CNN crop---------------#
        # NOTE: the local name F below shadows any module-level `F` alias
        # for the remainder of __init__ only.
        K = embedding_dim  # number of input filters
        F = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 128  # number of intermediate filters   16
        Y = self.cnn_out_dim  # number of output filters  8
        L = num_layers  # number of convnet layers

        self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim)
        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [Y]
        def interleave(xs, ys):
            # [x0, y0, x1, y1, ...] -- used to alternate conv and activation.
            return [val for pair in zip(xs, ys) for val in pair]
        self.embed = nn.Embedding(nethack.MAX_GLYPH, self.k_dim)
        conv_extract_crop = [
            nn.Conv2d(
                in_channels=in_channels[i],
                out_channels=out_channels[i],
                kernel_size=(F, F),
                stride=S,
                padding=P,
            )
            for i in range(L)
        ]

        # [nn.ELU()] * n repeats one ELU instance; ELU holds no parameters.
        self.extract_crop_representation = nn.Sequential(
            *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract_crop))
        )
        self.CNN_residual_norm = nn.LayerNorm(Y)

        # self.conv_pool = nn.MaxPool2d(self.crop_dim,5)
        # self.CNN_residual_norm2 = nn.LayerNorm(Y)
        #--------------------------------------#

        #---------------baseCNN----------------#
        # Not built while base_model is False.
        if self.base_model:
            in_channels_ = [self.cnn_out_dim] + [32] * 4
            out_channels_ = [32] * 4 + [self.cnn_out_dim]
            base_conv = [
                nn.Conv2d(
                    in_channels=in_channels_[i],
                    out_channels=out_channels_[i],
                    kernel_size=(3, 3),
                    stride=1,
                    padding=1,
                )
                for i in range(len(in_channels_))
            ]

            self.base_conv_representation = nn.Sequential(
                *interleave(base_conv, [nn.ELU()] * len(base_conv))
            )
            out_dim = self.att_fc_out_dim + self.s_dim
        #--------------------------------------#

        #----------------ATT-------------------#
        else:
            # encoder_layer = nn.TransformerEncoderLayer(d_model=self.cnn_out_dim, nhead=2, batch_first=True)
            # self.attention = nn.TransformerEncoder(encoder_layer, num_layers=2)
            self.attention = nn.MultiheadAttention(embed_dim=self.cnn_out_dim, num_heads=2, dropout=0.1, batch_first=True)
            # Token-wise MLP; forward() applies it twice with shared weights.
            self.attention_mlp = nn.Sequential(
                nn.Linear(self.cnn_out_dim, self.att_fc_h_dim),
                nn.ReLU(),
                nn.Linear(self.att_fc_h_dim, self.cnn_out_dim),
                nn.ReLU(),
            )
            # Not referenced in forward(); dropout1/dropout2 are used there.
            self.dropout = nn.Dropout(0.1)

            self.att_norm1 = nn.LayerNorm(self.cnn_out_dim)
            self.att_norm2 = nn.LayerNorm(self.cnn_out_dim)
            self.att_norm3 = nn.LayerNorm(self.cnn_out_dim)
            self.att_norm4 = nn.LayerNorm(self.cnn_out_dim)
            self.dropout1 = nn.Dropout(0.1)
            self.dropout2 = nn.Dropout(0.1)

            out_dim = self.att_fc_out_dim
            # Projects the pooled token to att_fc_out_dim.
            self.att_fc = nn.Sequential(
                nn.Linear(self.cnn_out_dim, self.att_fc_h_dim),
                nn.ReLU(),
                nn.Linear(self.att_fc_h_dim, self.att_fc_out_dim),
                nn.ReLU(),
            )

        #--------------------------------------#

        if self.use_message:
            self.embed_message = nn.Sequential(
                nn.Linear(self.message_size, 128),
                nn.ReLU(),
                nn.Linear(128, self.s_dim),
                nn.ReLU(),
            )
            self.message_norm = nn.LayerNorm(self.s_dim)

        self.action_embed = nn.Embedding(self.num_actions, self.s_dim)
        self.embed_action = nn.Sequential(
            nn.Linear(self.s_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.s_dim),
            nn.ReLU(),
        )
        self.embed_blstats = nn.Sequential(
            nn.Linear(self.blstats_size, 128),
            nn.ReLU(),
            nn.Linear(128, self.s_dim),
            nn.ReLU(),
        )

        self.att_norm = nn.LayerNorm(self.att_fc_out_dim)
        self.blstats_norm = nn.LayerNorm(self.s_dim)
        self.action_norm = nn.LayerNorm(self.s_dim)


        # Only used by the (disabled) base-model branch of forward().
        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

        self.policy = nn.Linear(self.h_dim, self.num_actions)
        self.baseline = nn.Linear(self.h_dim, 1)

    def initial_state(self, batch_size=1):
        """Return a zero (h, c) pair shaped [num_layers x batch_size x hidden]."""
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )
    def reset_act_mask(self,a_mask=None):
        """Replace the additive action mask.

        Args:
            a_mask: sequence convertible by ``torch.tensor``; it is turned
                into ``(a_mask - 1) * 1e9`` exactly as in ``__init__``.
                ``None`` removes the mask, so forward() skips masking.
        """
        if a_mask is None:
            self.act_mask = None
        else:
            self.act_mask = nn.Parameter((torch.tensor(a_mask, dtype=torch.float32)-1)*1e9, requires_grad=False)

    def _select(self, embed, x):
        """Look up rows of ``embed.weight`` for every index in ``x``.

        Returns a tensor shaped ``x.shape + (embedding_dim,)``.
        """
        # Work around slow backward pass of nn.Embedding, see
        # https://github.com/pytorch/pytorch/issues/24912
        out = embed.weight.index_select(0, x.reshape(-1))
        return out.reshape(x.shape + (-1,))

    def forward(self, env_outputs, core_state):
        """Run the network on a [T x B] batch of environment outputs.

        Args:
            env_outputs: dict read with keys "glyphs", "blstats",
                "last_action", "done" and, in the active path,
                "inv_glyphs", "inv_letters", "inv_oclasses", "inv_strs"
                and "message". Leading dimensions are [T x B].
            core_state: (h, c) tuple for the LSTM, as produced by
                ``initial_state``.

        Returns:
            A pair ``(outputs, core_state)`` where ``outputs`` is a dict with
            ``policy_logits`` [T x B x num_actions], ``baseline`` [T x B] and
            ``action`` [T x B], and ``core_state`` is the updated LSTM state.
        """
        # print(env_outputs)
        # time.sleep(100)

        # -- [T x B x H x W]
        glyphs = env_outputs["glyphs"]
        T, B, *_ = glyphs.shape

        # -- [T x B x F]
        blstats = env_outputs["blstats"]
        # -- [B' x F]
        blstats = blstats.view(T * B, -1).float()
        # First two blstats columns, taken before scaling, are handed to the
        # crop as the x,y position.
        coordinates = blstats[:, :2]

        # -- [T x B x 1]
        last_actions = env_outputs["last_action"]

        #----------blstats+action---------#
        # -- [B' x 1]
        last_actions = last_actions.view(T * B, -1)
        # -- [B' x emb]
        # (with last_actions as [B' x 1] the embedding output keeps that
        # middle axis, so the action already forms one token per element)
        action_emb_ = self.action_embed(last_actions)
        action_emb = self.embed_action(action_emb_)

        ## scale
        # Creates a new tensor, so `coordinates` above stays unscaled.
        blstats = blstats * self.blstats_scale
        blstats = torch.clamp(blstats, min=self.BLSTAT_CLIP_RANGE[0], max=self.BLSTAT_CLIP_RANGE[1])

        ##

        # -- [B' x 27]
        #blstats = torch.log1p(torch.nn.functional.relu(blstats))


        # -- [B' x 1+27]
        #blstats_action = torch.cat([action_emb.squeeze(1),blstats],dim=1)

        # -- [B' x K''']
        blstats_emb = self.embed_blstats(blstats)

        assert blstats_emb.shape[0] == T * B
        #----------------------------------#

        #----------glyphs CNN-------------#
        # -- [B' x H x W]
        glyphs = torch.flatten(glyphs, 0, 1)  # Merge time and batch.
        # -- [B' x H x W]
        glyphs = glyphs.long()
        # -- [B' x H' x W']
        crop = self.crop(glyphs, coordinates)
        # -- [B' x H' x W' x K]
        crop_emb = self._select(self.embed, crop)
        # CNN crop model.
        # -- [B' x K x W' x H']
        crop_emb = crop_emb.transpose(1, 3)  # -- TODO: slow?
        # -- [B' x K' x W' x H']
        crop_rep = self.extract_crop_representation(crop_emb)
        #--------------------------------------#

        ## full
        # glyphs_emb = self._select(self.embed, glyphs)
        # glyphs_emb = glyphs_emb.transpose(1, 3)
        # glyphs_rep_  = self.extract_crop_representation(glyphs_emb)
        # glyphs_rep = self.conv_pool(glyphs_rep_)
        ##

        # crop_rep_ += crop_emb
        # crop_rep_ = crop_rep_.transpose(1, 3)
        # crop_rep = self.CNN_residual_norm(crop_rep_)
        # # -- [B' x K' x W' x H']
        # crop_rep = crop_rep.transpose(1, 3)


        #-------------base CNN----------------#
        # NOTE(review): unreachable while base_model is hard-coded to False.
        # If re-enabled, this branch reads `message_emb` (only assigned in
        # the else-branch below) and `self.att_fc` (only created in the
        # else-branch of __init__), so it would fail as written.
        if self.base_model:
            # # -- [B' x K' x W' x H']
            observation_att = self.base_conv_representation(crop_rep)
            observation_att = observation_att.view(T * B, -1, self.cnn_out_dim)
            observation_rep_, max_ind = torch.max(observation_att,dim=1)
            # -- [B' x K']
            observation_rep = self.att_fc(observation_rep_)

            reps =[self.att_norm(observation_rep)]
            reps.append(self.blstats_norm(blstats_emb))
            reps.append(self.message_norm(message_emb))

            st = torch.cat(reps, dim=1)
            # -- [B x K]
            st = self.fc(st)

        #--------------------------------------#





        #----------------ATT-------------------#
        else:
            # -- [B' x W'H' x K']
            # NOTE(review): crop_rep is channels-first ([B' x K' x W' x H'])
            # at this point; reshape regroups the flat memory into rows of
            # K' values without moving the channel axis last -- confirm this
            # token layout is intended (trained checkpoints depend on it).
            crop_rep = crop_rep.reshape(T * B, -1, self.cnn_out_dim)
            assert crop_rep.shape[0] == T * B
            # glyphs_rep = glyphs_rep.reshape(T * B, -1, self.cnn_out_dim)
            # assert glyphs_rep.shape[0] == T * B

            # Token list, concatenated along dim=1 below: crop cells, one
            # blstats token, the last-action token, then inventory/message.
            reps = [self.CNN_residual_norm(crop_rep)]
            # reps.append(self.CNN_residual_norm2(glyphs_rep))
            reps.append(self.blstats_norm(blstats_emb).unsqueeze(1))
            reps.append(self.action_norm(action_emb))

            if self.use_inv:
                # assumes inv_* tensors are [T x B x slots (x chars)] -- TODO
                # confirm against the environment wrapper.
                inv_glyphs = env_outputs["inv_glyphs"]
                inv_glyphs_emb = self.inv_embed(inv_glyphs.long())
                inv_glyphs_emb = self.inv_emb_norm(inv_glyphs_emb)
                # Raw byte values are divided by constants (127, 255) to
                # bring them to roughly [0, 1].
                inv_letters = env_outputs["inv_letters"]/127
                inv_letters = inv_letters.unsqueeze(-1)
                inv_oclasses = torch.nn.functional.one_hot(env_outputs["inv_oclasses"].long(),19)
                inv_strs = env_outputs["inv_strs"]/255

                inv_emb_ = torch.cat((inv_glyphs_emb, inv_letters, inv_oclasses, inv_strs), dim=-1)
                inv_emb = self.embed_inv(inv_emb_)
                inv_emb = self.inv_norm(inv_emb)
                # Merge time and batch; each inventory slot stays a separate
                # token along dim=1.
                inv_emb = torch.flatten(inv_emb, 0, 1)
                reps.append(inv_emb)

            if self.use_message:
                # -- [T x B x F]
                message = env_outputs["message"]
                # -- [B' x F]
                message = message.view(T * B, -1).float()
                message_emb = self.embed_message(message/255)
                reps.append(self.message_norm(message_emb).unsqueeze(1))

            # -- [B' x W'H'+3 x K']
            # (the "+3" predates the inventory tokens: with use_inv the
            # sequence also contains one token per inventory slot)
            crop_rep_ = torch.cat(reps, dim=1)

            # -- [B' x W'H' x K']
            #observation_att = self.attention(crop_rep_)

            # Self-attention over all tokens, then residual + LayerNorm.
            observation_att_, att_w = self.attention(crop_rep_, crop_rep_, crop_rep_)
            observation_att = self.att_norm1(self.dropout1(observation_att_) + crop_rep_)
            observation_att_mlp = self.att_norm2(self.attention_mlp(observation_att)+observation_att)
            #observation_att_2, _ = self.attention(crop_rep_, crop_rep_, crop_rep_)
            # NOTE(review): this residual adds the raw attention output
            # `observation_att_`, not the normalised `observation_att` --
            # confirm intended.
            observation_att = self.att_norm3(self.dropout2(observation_att_mlp) + observation_att_)
            # Second pass through the same attention_mlp (shared weights).
            observation_att_out = self.att_norm4(self.attention_mlp(observation_att)+observation_att)
            ## back
            # observation_att_, _ = self.attention(crop_rep_, crop_rep_, crop_rep_)
            # observation_att_ = self.att_norm1(self.dropout1(observation_att_) + crop_rep_)
            # #observation_att_2, _ = self.attention(crop_rep_, crop_rep_, crop_rep_)
            # observation_att = self.att_norm2(self.dropout2(observation_att_) + crop_rep_)



            # -- [B' x K']
            # Max-pool over the token axis.
            observation_rep_, max_ind = torch.max(observation_att_out,dim=1)
            # -- [B' x K']
            observation_rep = self.att_fc(observation_rep_)
            st = self.att_norm(observation_rep)
        #--------------------------------------#


        #----------------LSTM-------------------#
        core_input = st.view(T, B, -1)
        core_output_list = []
        notdone = (~env_outputs["done"]).float()
        # Step through time one frame at a time so the state can be reset
        # at episode boundaries.
        for input, nd in zip(core_input.unbind(), notdone.unbind()):
            # Reset core state to zero whenever an episode ended.
            # Make `done` broadcastable with (num_layers, B, hidden_size)
            # states:
            nd = nd.view(1, -1, 1)
            core_state = tuple(nd * s for s in core_state)
            output, core_state = self.core(input.unsqueeze(0), core_state)
            core_output_list.append(output)
        # -- [T x B x K']
        core_output_ = torch.cat(core_output_list)

        # -- [B' x K']
        core_output = torch.flatten(core_output_, 0, 1)

        #---------------------------------------#

        # -- [B' x A]
        policy_logits = self.policy(core_output)
        # -- [B' x A]
        baseline = self.baseline(core_output)

        # Masked actions receive -1e9, i.e. ~zero probability after softmax
        # and never the argmax. The returned logits include the mask.
        if self.act_mask is not None:
            policy_logits += self.act_mask

        if self.training:
            ##
            #policy_logits_ = policy_logits.clone()
            #policy_logits_[:, 1:17] *= (1 / 16)
            # print(f"logits/16 is {policy_logits_}")
            ##
            # Sample from the policy while training.
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        # if last_actions.shape[0] == 1:
        #     if last_actions[0] == 21 or last_actions[0] == 18:
        #         if torch.rand(1) < 0.6:
        #             action[0] = 8


        # Restore the separate time and batch dimensions.
        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)
        return (
            dict(policy_logits=policy_logits, baseline=baseline, action=action),
            core_state,
        )

def main(flags):
    """Run training when ``flags.mode`` is "train", otherwise run testing."""
    entry_point = train if flags.mode == "train" else test
    entry_point(flags)

# train

TODO:
- action mask
- experience replay actor


In [None]:
# Choose the network class; the alternative is kept commented out.
#Net = NetHackNet
Net = NetHackNet_GAT

root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
#logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Command-line arguments handed to the argument parser, assembled in order
# from (option, value) pairs; a one-element tuple is a bare switch.
simulated_args = []
for option in (
    ('--env', "NetHack-v0"),     # NetHackChallenge-v0 NetHackScore-v0
    ('--savedir', "torchbeast/"),
    ('--num_actors', '80'),
    ('--batch_size', '32'),
    ('--unroll_length', '80'),
    ('--learning_rate', '0.0005'),
    ('--entropy_cost', '0.001'),
    ('--use_lstm',),
    ('--total_steps', '10000000000'),
    ('--total_steps_', '20000000'),
):
    simulated_args.extend(option)

flags = parser.parse_args(simulated_args)
main(flags)

## info

0 MORE 13 C-m read the next message

1 North 107 k   75 K

2 East 108 l    76 L

3 South 106 j   74 J

4 West 104 h    72 H

5 North East 117 u  85 U

6 South East 110 n  78 N

7 South West 98 b   66 B

8 North West 121 y  89 Y

17 UP 60 < go up (e.g., a staircase)

18 DOWN 62 > go down (e.g., a staircase)

19 WAIT / SELF 46 . rest one move while doing nothing / apply to self

20 KICK 4 C-d kick something

21 EAT 101 e eat something

22 SEARCH 115 s search for traps and secret doors

0 MiscAction.MORE
1 CompassDirection.N
2 CompassDirection.E
3 CompassDirection.S
4 CompassDirection.W
5 CompassDirection.NE
6 CompassDirection.SE
7 CompassDirection.SW
8 CompassDirection.NW
9 CompassDirectionLonger.N
10 CompassDirectionLonger.E
11 CompassDirectionLonger.S
12 CompassDirectionLonger.W
13 CompassDirectionLonger.NE
14 CompassDirectionLonger.SE
15 CompassDirectionLonger.SW
16 CompassDirectionLonger.NW
17 MiscDirection.UP
18 MiscDirection.DOWN
19 MiscDirection.WAIT
20 Command.KICK
21 Command.EAT
22 Command.SEARCH

<details>
<summary>Action space of the NetHack environment (action index followed by action name):</summary>

Discrete(86)
0 CompassDirection.N
1 CompassDirection.E
2 CompassDirection.S
3 CompassDirection.W
4 CompassDirection.NE
5 CompassDirection.SE
6 CompassDirection.SW
7 CompassDirection.NW
8 CompassDirectionLonger.N
9 CompassDirectionLonger.E
10 CompassDirectionLonger.S
11 CompassDirectionLonger.W
12 CompassDirectionLonger.NE
13 CompassDirectionLonger.SE
14 CompassDirectionLonger.SW
15 CompassDirectionLonger.NW
16 MiscDirection.UP
17 MiscDirection.DOWN
18 MiscDirection.WAIT
19 MiscAction.MORE
20 Command.ADJUST
21 Command.APPLY
22 Command.ATTRIBUTES
23 Command.CALL
24 Command.CAST
25 Command.CHAT
26 Command.CLOSE
27 Command.DIP
28 Command.DROP
29 Command.DROPTYPE
30 Command.EAT
31 Command.ENGRAVE
32 Command.ENHANCE
33 Command.ESC
34 Command.FIGHT
35 Command.FIRE
36 Command.FORCE
37 Command.INVENTORY
38 Command.INVENTTYPE
39 Command.INVOKE
40 Command.JUMP
41 Command.KICK
42 Command.LOOK
43 Command.LOOT
44 Command.MONSTER
45 Command.MOVE
46 Command.MOVEFAR
47 Command.OFFER
48 Command.OPEN
49 Command.PAY
50 Command.PICKUP
51 Command.PRAY
52 Command.PUTON
53 Command.QUAFF
54 Command.QUIVER
55 Command.READ
56 Command.REMOVE
57 Command.RIDE
58 Command.RUB
59 Command.RUSH
60 Command.RUSH2
61 Command.SEARCH
62 Command.SEEARMOR
63 Command.SEERINGS
64 Command.SEETOOLS
65 Command.SEETRAP
66 Command.SEEWEAPON
67 Command.SHELL
68 Command.SIT
69 Command.SWAP
70 Command.TAKEOFF
71 Command.TAKEOFFALL
72 Command.THROW
73 Command.TIP
74 Command.TURN
75 Command.TWOWEAPON
76 Command.UNTRAP
77 Command.VERSIONSHORT
78 Command.WEAR
79 Command.WIELD
80 Command.WIPE
81 Command.ZAP
82 TextCharacters.PLUS
83 TextCharacters.QUOTE
84 TextCharacters.DOLLAR
85 TextCharacters.SPACE

# Eval

In [None]:
def stack_ob(obs_list, max_len=80):
    """Concatenate the most recent observations key by key.

    Args:
        obs_list: sequence of observation dicts. The keys of the first
            retained dict decide which entries are stacked.
        max_len: maximum number of trailing observations to keep.

    Returns:
        Dict mapping each stackable key to ``torch.cat`` of its values along
        dim 0. Keys whose values cannot be concatenated (non-tensors,
        mismatched shapes, missing in a later observation) are left out.
        An empty input, or ``max_len`` of 0, yields an empty dict.
    """
    # Keep only the newest max_len entries. Slicing from an explicit start
    # index (rather than [-max_len:]) also behaves correctly for max_len == 0.
    if len(obs_list) > max_len:
        obs_list = obs_list[len(obs_list) - max_len:]

    if not obs_list:
        return {}

    stacked_obs = {}
    for key in obs_list[0]:
        try:
            stacked_obs[key] = torch.cat([obs[key] for obs in obs_list], dim=0)
        except (KeyError, TypeError, RuntimeError, ValueError):
            # Best effort: skip entries that are not stackable tensors, but
            # do not swallow unrelated errors such as KeyboardInterrupt.
            continue

    return stacked_obs

In [None]:
from IPython.display import clear_output
from random import sample
# Evaluation script: load a checkpoint, play `num_episodes` episodes while
# rendering every step, and log the average episode return.
#checkpointpath = 'torchbeast/plots/Relational_full_v1.tar'
checkpointpath = 'torchbeast/plots/model4.tar'
gym_env = gym.make("NetHackScore-v0")
# NetHack-v0 NetHackScore-v0 NetHackChallenge-v0
# NetHackStaircase-v0 NetHackStaircasePet-v0 NetHackOracle-v0 NetHackGold-v0 NetHackEat-v0 NetHackScout-v0
env = ResettingEnvironment(gym_env)
model = NetHackNet_GAT(gym_env.observation_space, gym_env.action_space.n, True)
model.eval()
checkpoint = torch.load(checkpointpath, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

agent_state = model.initial_state(batch_size=1)

observation = env.initial()
returns = []

a_list = [9,10,9,10,9,10]

num_episodes = 200
reward = 0
# Actions taken so far; accumulated across episodes for the histogram below.
action_list = []
try:
    while len(returns) < num_episodes:
        # Inference only: without no_grad the autograd graph would keep
        # growing through the recurrent state carried from step to step.
        with torch.no_grad():
            policy_outputs, agent_state = model(observation, agent_state)
        #raction = torch.tensor([[gym_env.action_space.sample()]])
        #raction_ = torch.tensor([[sample([9,10,11,12],1)]])
        action = policy_outputs["action"]
        # Store plain ints so torch.tensor(action_list) yields a 1-D tensor.
        action_list.append(action.item())
        observation = env.step(action)
        reward += observation["reward"].item()

        clear_output(wait=True)
        env.gym_env.render()
        print(observation["blstats"])
        print(policy_outputs)
        print(reward)
        if observation["done"].item():
            actions = torch.tensor(action_list)
            elements, counts = torch.unique(actions, return_counts=True)
            print('elements',elements)
            print('counts',counts)
            print('reward',reward)

            # action_list = []
            reward = 0
            returns.append(observation["episode_return"].item())
            # One placeholder per argument; the original passed three
            # arguments to a two-placeholder format string.
            logging.info(
                "Episode ended after %d steps. Return: %.1f. Last action: %d",
                observation["episode_step"].item(),
                observation["episode_return"].item(),
                observation["last_action"].item()
            )
            # Pause so the final frame of the episode stays visible.
            time.sleep(3)

        # Throttle rendering.
        time.sleep(0.3)
finally:
    # Release the environment even if the loop is interrupted.
    env.close()
logging.info(
    "Average returns over %i episodes: %.1f", num_episodes, sum(returns) / len(returns)
)

0 MiscAction.MORE
1 CompassDirection.N
2 CompassDirection.E
3 CompassDirection.S
4 CompassDirection.W
5 CompassDirection.NE
6 CompassDirection.SE
7 CompassDirection.SW
8 CompassDirection.NW
9 CompassDirectionLonger.N
10 CompassDirectionLonger.E
11 CompassDirectionLonger.S
12 CompassDirectionLonger.W
13 CompassDirectionLonger.NE
14 CompassDirectionLonger.SE
15 CompassDirectionLonger.SW
16 CompassDirectionLonger.NW
17 MiscDirection.UP
18 MiscDirection.DOWN
19 MiscDirection.WAIT
20 Command.KICK
21 Command.EAT
22 Command.SEARCH

# Plot

In [None]:
! python -m nle.scripts.plot --file "torchbeast/Relational_v1/logs.tsv"
! python -m nle.scripts.plot --file "torchbeast/Relational_v2/logs.tsv"

plotting torchbeast/Relational_v1/logs.tsv
[0;39m                                                                                
[0;39m                                                                                
[0;39m                           [0;39maveraged mean_episode_return                         
[0;39m [0;39m 120 +---------------------------------------------------------------------+   
[0;39m      [0;39m|             +             +             +             +             |   
[0;39m      [0;39m|             [0;39m:             :       [1;35m++-+ ++-++-+ ++-+-+ [0;39m:             [0;39m|   
[0;39m [0;39m 100 |-+[0;39m...........:.............:..[1;35m++-++++++++++++++++++++++-+[0;39m..........[0;39m+-|   
[0;39m      [0;39m|[1;35m-+           [0;39m:             [1;35m++++++++|+++++|++|++++|+||++-+            [0;39m|   
[0;39m      [0;39m|[1;35m-+           [0;39m:     [1;35m+-++-+++++++|++|||++|||||||||||||||||             [0;39m|   
[0

# Tensorboard

In [None]:
!pip install tensorboardX

In [None]:
from math import nan, isnan
import os
from tensorboardX import SummaryWriter
import csv

# Work relative to the notebook directory on the mounted drive.
os.chdir("/content/drive/My Drive/Colab Notebooks")

# Root directory under which one TensorBoard run per log file is written.
log_dir = "runs/log_example"

# The directory is not created here; the creation code is left disabled:
# if not os.path.exists(log_dir):
#     os.makedirs(log_dir)

# TSV training logs to convert.
tsv_files = [
    "torchbeast/plots/Baseline_v1.tsv",
    "torchbeast/plots/Relational_v1.tsv",
    # Add more log files here
]
# Function to process a TSV file and log it to TensorBoard
def process_tsv_file(tsv_file_path, run_name):
    """Replay one TSV training log into a TensorBoard run.

    Args:
        tsv_file_path: path of a tab-separated log with the columns
            '# Step', 'total_loss' and 'mean_episode_return'.
        run_name: sub-directory of the module-level ``log_dir`` that
            receives the events.

    For every row the loss is logged under 'Loss' and, when it is not NaN,
    the return under 'Return'. Rows with missing or unparsable fields are
    reported on stdout and skipped.
    """
    # Create a SummaryWriter for each run (log file)
    writer = SummaryWriter(os.path.join(log_dir, run_name))

    try:
        # newline='' is what the csv module expects of files it reads.
        with open(tsv_file_path, 'r', newline='') as tsv_file:
            reader = csv.DictReader(tsv_file, delimiter='\t')

            for row in reader:
                try:
                    # Extract the step, loss, and return values from the row
                    step = int(row['# Step'])              # Step or iteration number
                    loss = float(row['total_loss'])        # Loss value
                    ret = float(row['mean_episode_return'])# Return value
                except KeyError as e:
                    print(f"Missing expected column: {e}")
                    continue
                except (TypeError, ValueError) as e:
                    # TypeError covers short rows, whose missing fields are
                    # filled with None by DictReader.
                    print(f"Data conversion error: {e}")
                    continue

                # Log the scalar values to TensorBoard
                writer.add_scalar('Loss', loss, step)
                if not isnan(ret):
                    writer.add_scalar('Return', ret, step)
    finally:
        # Always flush and close the writer, even if reading the file fails.
        writer.close()

# Convert every listed log; each run is named after its file stem.
for tsv_file_path in tsv_files:
    file_name = os.path.basename(tsv_file_path)
    run_name, _extension = os.path.splitext(file_name)
    process_tsv_file(tsv_file_path, run_name)

In [None]:
# Load the TensorBoard extension
%load_ext tensorboard

# Launch TensorBoard
%tensorboard --logdir runs/log_example

# Info

Glyph ids: 2359 = black (blank), 2360 = wall `|`, 2361 = wall `-`, 2362-2365 = corners (↖ ↗ ↘ ↙), 2378 = floor,
2382 = `<`, 2383 = `>`

<details>
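<summary>Action space of the challenge environment. Tab-indented entries appear only in this list; a trailing number is the index noted for the same action in the 86-action NetHack list:</summary>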

Discrete(121)
0 CompassDirection.N
1 CompassDirection.E
2 CompassDirection.S
3 CompassDirection.W
4 CompassDirection.NE
5 CompassDirection.SE
6 CompassDirection.SW
7 CompassDirection.NW
8 CompassDirectionLonger.N
9 CompassDirectionLonger.E
10 CompassDirectionLonger.S
11 CompassDirectionLonger.W
12 CompassDirectionLonger.NE
13 CompassDirectionLonger.SE
14 CompassDirectionLonger.SW
15 CompassDirectionLonger.NW
16 MiscDirection.UP
17 MiscDirection.DOWN
18 MiscDirection.WAIT
19 MiscAction.MORE
	20 Command.EXTCMD
	21 Command.EXTLIST
22 Command.ADJUST 20
	23 Command.ANNOTATE
24 Command.APPLY 21
25 Command.ATTRIBUTES 22
	26 Command.AUTOPICKUP
27 Command.CALL 23
28 Command.CAST
29 Command.CHAT
30 Command.CLOSE 26
	31 Command.CONDUCT
32 Command.DIP 27
33 Command.DROP
34 Command.DROPTYPE
35 Command.EAT
36 Command.ENGRAVE
37 Command.ENHANCE
38 Command.ESC
39 Command.FIGHT
40 Command.FIRE
41 Command.FORCE 36
	42 Command.GLANCE
	43 Command.HISTORY
44 Command.INVENTORY 37
45 Command.INVENTTYPE
46 Command.INVOKE
47 Command.JUMP
48 Command.KICK 41
	49 Command.KNOWN
	50 Command.KNOWNCLASS
51 Command.LOOK 42
52 Command.LOOT
53 Command.MONSTER
54 Command.MOVE
55 Command.MOVEFAR
56 Command.OFFER
57 Command.OPEN 48
	58 Command.OPTIONS
	59 Command.OVERVIEW
60 Command.PAY 49
61 Command.PICKUP
62 Command.PRAY
63 Command.PUTON
64 Command.QUAFF 53
	65 Command.QUIT
66 Command.QUIVER 54
67 Command.READ 55
	68 Command.REDRAW
69 Command.REMOVE 56
70 Command.RIDE
71 Command.RUB
72 Command.RUSH
73 Command.RUSH2 60
	74 Command.SAVE
75 Command.SEARCH 61
	76 Command.SEEALL
	77 Command.SEEAMULET
78 Command.SEEARMOR 62
	79 Command.SEEGOLD
80 Command.SEERINGS 63
	81 Command.SEESPELLS
82 Command.SEETOOLS 64
83 Command.SEETRAP
84 Command.SEEWEAPON
85 Command.SHELL
86 Command.SIT
87 Command.SWAP
88 Command.TAKEOFF
89 Command.TAKEOFFALL 71
	90 Command.TELEPORT
91 Command.THROW 72
92 Command.TIP 73
	93 Command.TRAVEL
94 Command.TURN 74
95 Command.TWOWEAPON
96 Command.UNTRAP 76
	97 Command.VERSION
98 Command.VERSIONSHORT 77
99 Command.WEAR 78
	100 Command.WHATDOES
	101 Command.WHATIS
102 Command.WIELD 79
103 Command.WIPE
104 Command.ZAP
105 TextCharacters.PLUS 82
	106 TextCharacters.MINUS
107 TextCharacters.SPACE 83
	108 TextCharacters.APOS
109 TextCharacters.QUOTE 84
	110 TextCharacters.NUM_0
	111 TextCharacters.NUM_1
	112 TextCharacters.NUM_2
	113 TextCharacters.NUM_3
	114 TextCharacters.NUM_4
	115 TextCharacters.NUM_5
	116 TextCharacters.NUM_6
	117 TextCharacters.NUM_7
	118 TextCharacters.NUM_8
	119 TextCharacters.NUM_9
120 TextCharacters.DOLLAR 85

<details>
<summary>Explanation of each key in the action space:</summary>

TASK_ACTIONS = tuple(
    [nethack.MiscAction.MORE]
    + list(nethack.CompassDirection)
    + list(nethack.CompassDirectionLonger)
    + list(nethack.MiscDirection)
    + [nethack.Command.KICK, nethack.Command.EAT, nethack.Command.SEARCH]
)
class MiscAction(enum.IntEnum):
    MORE = ord("\r")  # read the next message

CompassDirection = enum.IntEnum(
    "CompassDirection",
    {
        **CompassCardinalDirection.__members__,
        **CompassIntercardinalDirection.__members__,
    },
)
class CompassCardinalDirection(enum.IntEnum):
    N = ord("k")
    E = ord("l")
    S = ord("j")
    W = ord("h")


class CompassIntercardinalDirection(enum.IntEnum):
    NE = ord("u")
    SE = ord("n")
    SW = ord("b")
    NW = ord("y")


CompassDirectionLonger = enum.IntEnum(
    "CompassDirectionLonger",
    {
        **CompassCardinalDirectionLonger.__members__,
        **CompassIntercardinalDirectionLonger.__members__,
    },
)
class CompassCardinalDirectionLonger(enum.IntEnum):
    N = ord("K")
    E = ord("L")
    S = ord("J")
    W = ord("H")

class CompassIntercardinalDirectionLonger(enum.IntEnum):
    NE = ord("U")
    SE = ord("N")
    SW = ord("B")
    NW = ord("Y")

class MiscDirection(enum.IntEnum):
    UP = ord("<")  # go up a staircase
    DOWN = ord(">")  # go down a staircase
    WAIT = ord(".")  # rest one move while doing nothing / apply to self

KICK = C("d")  # kick something
EAT = ord("e")  # eat something
SEARCH = ord("s")  # search for traps and secret doors

<details>
<summary>Explanation of each key in the observation dictionary:</summary>

1. glyphs: A 2D array representing the visual state of the game.
   Each value corresponds to a specific visual element or "glyph" in NetHack.
   Example shape: (21, 79) - corresponding to a 21x79 grid.

2. chars: A 2D array containing the character representation of the glyphs.
   Each value is a character code that visually represents the game's state.
   Example shape: (21, 79).

3. colors: A 2D array providing color information for the glyphs.
   Each value corresponds to a color code, giving more context to the visual elements.
   Example shape: (21, 79).

4. specials: A 2D array that indicates special attributes or states for each cell.
   These might include things like "lit" or "dark" areas, traps, etc.
   Example shape: (21, 79).

5. blstats: A 1D array containing various statistics and information about the player's status.
   This includes health, experience, gold, etc.
   Example shape: (25,) - representing various player statistics.

6. message: A 1D array (or string) that contains the latest message displayed to the player.
   This is typically the last line of text describing what happened in the game.
   Example shape: (256,) - representing the characters in the message.

7. inv_glyphs: A 1D array showing the glyphs for items in the player's inventory.
   Each glyph corresponds to an item.
   Example shape: (55,) - representing the inventory slots.

8. inv_strs: A 1D array containing the string representations of the inventory items.
   Each string describes an item in the player's inventory.
   Example shape: (55,) - one string per inventory slot.

9. inv_letters: A 1D array giving the letters corresponding to each item in the inventory.
   In NetHack, each item in the inventory is usually assigned a letter for quick access.
   Example shape: (55,).

10. inv_oclasses: A 1D array showing the object classes for items in the inventory.
    This indicates the type of each item, such as weapon, armor, potion, etc.
    Example shape: (55,).

11. tty_chars: A 2D array representing the state of the game as displayed in a traditional terminal (TTY) view.
    This is a character-based view of the game, similar to the "chars" key.
    Example shape: (24, 80) - a standard terminal size.

12. tty_colors: A 2D array providing color information for the TTY view.
    Each value corresponds to a color code, similar to the "colors" key.
    Example shape: (24, 80).

13. tty_cursor: A 1D array indicating the position of the cursor in the TTY view.
    This shows where the cursor is currently located on the screen.
    Example shape: (2,) - representing the row and column of the cursor.

<details>
<summary>Explanation of each key in the blstats:</summary>

        1.0 / 79.0, # hero col
        1.0 / 21, # hero row
        0.0, # strength pct
        1.0 / 10, # strength
        1.0 / 10, # dexterity
        1.0 / 10, # constitution
        1.0 / 10, # intelligence
        1.0 / 10, # wisdom
        1.0 / 10, # charisma
        0.0,      # score
        1.0 / 10, # hitpoints
        1.0 / 10, # max hitpoints
        0.0, # depth
        1.0 / 1000, # gold
        1.0 / 10, # energy
        1.0 / 10, # max energy
        1.0 / 10, # armor class
        0.0, # monster level
        1.0 / 10, # experience level
        1.0 / 100, # experience points
        1.0 / 1000, # time
        1.0, # hunger_state
        1.0 / 10, # carrying capacity
        0.0, # dungeon number
        0.0, # level number
        0.0, # condition bits

        /* blstats indices, see also botl.c and statusfields in botl.h. */
        #define NLE_BL_X 0
        #define NLE_BL_Y 1
        #define NLE_BL_STR25 2  /* strength 3..25 */
        #define NLE_BL_STR125 3 /* strength 3..125   */
        #define NLE_BL_DEX 4
        #define NLE_BL_CON 5
        #define NLE_BL_INT 6
        #define NLE_BL_WIS 7
        #define NLE_BL_CHA 8
        #define NLE_BL_SCORE 9
        #define NLE_BL_HP 10
        #define NLE_BL_HPMAX 11
        #define NLE_BL_DEPTH 12
        #define NLE_BL_GOLD 13
        #define NLE_BL_ENE 14
        #define NLE_BL_ENEMAX 15
        #define NLE_BL_AC 16
        #define NLE_BL_HD 17  /* monster level, "hit-dice" */
        #define NLE_BL_XP 18  /* experience level */
        #define NLE_BL_EXP 19 /* experience points */
        #define NLE_BL_TIME 20
        #define NLE_BL_HUNGER 21 /* hunger state */
        #define NLE_BL_CAP 22    /* carrying capacity */
        #define NLE_BL_DNUM 23
        #define NLE_BL_DLEVEL 24
        #define NLE_BL_CONDITION 25 /* condition bit mask */
        #define NLE_BL_ALIGN 26

# Test

In [None]:
# def generate_mask(done, input_tensor):

#     cum_sum_mark = torch.cumsum(done.float(), dim=0)  # Shape [T*B, 1]
#     print(cum_sum_mark)
#     mask = (cum_sum_mark > 0).float()  # Valid elements where the cumulative sum is 0

#     return mask

# end_mark = torch.tensor([[False,False], [True,False], [False,False], [True,False]])

# # Define an input tensor of shape [T*B, seq_len]
# input_tensor = torch.randn(4, 2, 8)  # [T*B, seq_len], 5 sequences of length 8

# # Generate the mask
# mask = generate_mask(end_mark, input_tensor)
# print(mask)
# glyphs = torch.randint(0, 2500, (2, 9, 9))

# # Step 2: Instantiate the get_graph_batch class
# graph_generator = get_graph_batch()

# # Step 3: Use the forward method to process the glyph maps and generate a batch of graphs
# graph_batch = graph_generator(glyphs)

# # Step 4: Print out key information to verify the output
# print("Node features (x):", graph_batch.x)          # Node features (glyph IDs)
# print("Edge index (edges):", graph_batch.edge_index)  # Edge indices between nodes

# # If you want to verify further details:
# print(f"Number of nodes: {graph_batch.num_nodes}")
# print(f"Number of edges: {graph_batch.num_edges}")
# print(f"Batch size (graphs): {graph_batch.batch.max().item() + 1}")
# print(graph_batch.x.shape,graph_batch.edge_index.shape)

# gym_env22 = create_env('NetHackScore-v0')
# print(gym_env22.action_space)
# gym_env22.unwrapped.print_action_meanings()
# print(gym_env22.observation_space)

In [None]:
# Smoke test: build a batched graph from two toy 3x3 glyph maps, embed the
# node ids, and run the result through a GAT.
grid1 = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])

grid2 = np.array([
    [10, 11, 12],
    [13, 14, 15],
    [16, 17, 18],
])

# Blocked node IDs (walls)
blocked_ids = {2, 5, 10, 11, 14}

# Stack into one (2, 3, 3) ndarray before converting: calling torch.tensor()
# on a Python list of ndarrays is slow and emits a UserWarning.
batch = get_graph_batch(torch.tensor(np.stack([grid1, grid2])), blocked_ids)
print(batch)

emb = nn.Embedding(25, 8)
# batch.x is (num_nodes, 1) in the recorded output; squeeze only that trailing
# axis so a batch with a single node keeps its node dimension.
embed = emb(batch.x).squeeze(1)
print(embed.shape)
md = GAT(8, 4, 4)
out = md(embed, batch.edge_index, batch.batch)
# Recorded output is (2, 4) for this two-graph batch.
print(out.shape)

DataBatch(x=[13, 1], edge_index=[2, 30], batch=[13], ptr=[3])
torch.Size([13, 8])
torch.Size([2, 4])


In [None]:
from gensim.models import Word2Vec,KeyedVectors
# Load a gensim Word2Vec model previously saved as "wiki.model".
w2v = Word2Vec.load("wiki.model")
# Look up the embedding vector for a single word.
v1 = w2v.wv['dog']
print(v1) # NOTE(review): the recorded output shows 20 components, not 300 - confirm the model's vector size

similar_words = w2v.wv.most_similar('door', topn=30)  # topn determines how many similar words you want

# Print each neighbour with its similarity score, one per line
for similar_word, similarity_score in similar_words:
    print(f"{similar_word}: {similarity_score}")

[-4.5952139e+00  2.5131800e+00 -3.8313661e+00  1.9113704e+00
  8.7944448e-01 -2.3180022e+00  6.1128769e+00  1.2038110e+00
 -1.2344328e+00  1.5616350e+00  2.4153082e+00  2.5785027e+00
 -3.6446410e-01  3.0473328e-01 -6.0826463e-01 -1.1008358e+00
  5.7924190e-04 -2.1493638e+00 -4.9723701e+00 -4.4495182e+00]
doorway: 0.940357506275177
drawbridge: 0.925424337387085
doors: 0.8861480355262756
closet: 0.8858508467674255
closed: 0.8701543211936951
locked: 0.870032012462616
wall: 0.8460908532142639
secret: 0.8290387392044067
square: 0.8033416271209717
stairs: 0.7987379431724548
downwards: 0.7928858995437622
hole: 0.7917185425758362
corridor: 0.7909206748008728
ladder: 0.7736232876777649
floor: 0.7626956105232239
portcullis: 0.7608734965324402
vibrating: 0.7536466717720032
downstairs: 0.7511041760444641
chest: 0.7509174346923828
doorways: 0.750511646270752
magically: 0.7503724694252014
couldnt: 0.7420575022697449
grave: 0.741557776927948
untrapped: 0.7406994104385376
unseen: 0.7403778433799744
dr

In [None]:
import gymnasium as gym
import nle
import random
import time
from IPython.display import clear_output
import torch
from nle.env.tasks import NetHackChallenge
from nle.env.tasks import NetHackScore
import numpy as np
# Seed the Python, NumPy and PyTorch RNGs.
# NOTE(review): the environment itself is not seeded here - confirm whether
# dungeon generation needs its own seed for reproducible runs.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Dump the whole gymnasium registry (the recorded output is very long).
print(gym.envs.registry )
# Env ids noted as alternatives: NetHack-v0, NetHackScore-v0, NetHackChallenge-v0.
# NOTE(review): "mon-hum-neu-mal" presumably encodes role-race-alignment-gender - confirm against NLE docs.
env = gym.make("NetHackScore-v0",character="mon-hum-neu-mal") # NetHack-v0 NetHackScore-v0 NetHackChallenge-v0
# Inspect the action space, the meaning of each action index, and the observation space.
print(env.action_space)
env.unwrapped.print_action_meanings()
print(env.observation_space)

{'CartPole-v0': EnvSpec(id='CartPole-v0', entry_point='gymnasium.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=195.0, nondeterministic=False, max_episode_steps=200, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=0, additional_wrappers=(), vector_entry_point='gymnasium.envs.classic_control.cartpole:CartPoleVectorEnv'), 'CartPole-v1': EnvSpec(id='CartPole-v1', entry_point='gymnasium.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=475.0, nondeterministic=False, max_episode_steps=500, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=1, additional_wrappers=(), vector_entry_point='gymnasium.envs.classic_control.cartpole:CartPoleVectorEnv'), 'MountainCar-v0': EnvSpec(id='MountainCar-v0', entry_point='gymnasium.envs.classic_control.mountain_car:MountainCarEnv', reward_thr

In [None]:
# Start a new episode; the value returned by reset() is discarded here.
env.reset()  # each reset generates a new dungeon
#ob = env.step(1)  # move agent '@' north
#env.render()

def to_char(l):
    """Decode a sequence of integer character codes into a string.

    Codes outside the 7-bit ASCII range 0..127 are dropped. NUL (0) codes are
    inside that range and are therefore kept in the result.

    :param l: iterable of integer character codes.
    :return: the decoded string, or None (after printing a diagnostic) when
        `l` is not iterable or contains items that cannot be compared with /
        converted from integers.
    """
    try:
        return ''.join(chr(n) for n in l if 0 <= n <= 127)
    except TypeError:
        # Report the offending argument itself. The handler used to read the
        # names `i` and `obs`, which are not parameters of this function, so it
        # either raised NameError or described unrelated global state.
        print(f"Error processing value of type: {type(l)}, value: {l}")
        return None

# Scripted action indices to replay; a previously tried sequence was [0, 9, 3, 8, 8, 21].
a_list = [9, 10, 9, 10, 9, 10]

for a in a_list:
    # a = env.action_space.sample()  # alternative: take a random action instead
    print('act:', a)
    # clear_output(wait=True)  # optional: redraw in place instead of appending frames
    obs, reward, done, truncation, info = env.step(a)
    # blstats[0] and blstats[1] are NLE_BL_X and NLE_BL_Y.
    cord = [obs['blstats'][0], obs['blstats'][1]]
    print(cord)
    # `key` rather than `i`: the inner loop used to reuse the outer loop's
    # index name, which was confusing to read and debug.
    for key in ('glyphs', 'chars', 'colors', 'specials'):
        # obs[key][None] adds the leading batch axis on the ndarray itself,
        # avoiding torch.tensor() on a Python list of ndarrays (slow, warns).
        print(key, crop(torch.tensor(obs[key][None]), torch.tensor([cord])))
    print(get_edge_index(torch.tensor(obs['glyphs'][None])))
    # char_message = to_char(obs['message'])
    # print(char_message)
    # char_inv = to_char(obs['inv_letters'])
    # print(char_inv)
    # print(obs['inv_oclasses'])
    # print(obs['inv_strs'])

    # print(obs)
    env.render()
    time.sleep(1.0)  # pause between steps so each rendered frame can be read

    if done or truncation:
        # The episode has ended; stop instead of stepping a finished env.
        break

# env.close()

# CoPE

In [None]:
import torch
from torch import nn, einsum
import math

from einops import rearrange,repeat

def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if exists(val):
        return val
    return d

class CoPE(nn.Module):
    """Contextual position encoding (CoPE).

    Positions are not fixed token offsets: they are accumulated from sigmoid
    gates computed on the attention logits, and a learnable embedding table is
    interpolated at those (fractional) positions.
    """

    def __init__ (self, C, T=None) :
        '''
        :param C: channel dimension of each attention head
        :param T: maximum context length (number of position embeddings);
                  falls back to C when not given
        '''
        super().__init__()
        self.T = C if T is None else T
        # One embedding column per integer position: shape (1, C, T), zero-initialised.
        self.pos_emb = nn.parameter.Parameter(torch.zeros(1, C, self.T))

    def forward (self, q, qk) :
        '''
        Fuse contextual addressing with positional addressing:
          1. derive contextual positions from the logits q_i^T k_j;
          2. encode those positions as in relative position encoding, using
             the contextual position in place of the token offset.

        :param q: (B*H,N,C) or (B,H,N,C)
        :param qk: (B*H,N,N) or (B,H,N,N), the q @ k^T logits before softmax
        :return: E: (B*H,N,N) or (B,H,N,N)
        '''
        # Gate every (query, key) pair into (0, 1); a larger gate means the key
        # counts more towards the position. Original author's note: equation
        # (3), g_ij = sigmoid(q_i^T k_j).
        gates = torch.sigmoid(qk)

        # Reverse cumulative sum along the key axis:
        #   [a, b, c] -> flip -> [c, b, a] -> cumsum -> [c, c+b, c+b+a] -> flip -> [a+b+c, b+c, c]
        # so entry j holds the sum of the gates of keys j..last. No causal mask
        # is applied inside this module. Original author's note: equation (4).
        positions = torch.flip(
            torch.cumsum(torch.flip(gates, dims=[-1]), dim=-1), dims=[-1]
        )
        # Positions past the end of the embedding table share the last slot.
        positions = torch.clamp(positions, max=self.T - 1)

        # Positions are fractional, so the table cannot be indexed directly:
        # interpolate linearly between the two neighbouring integer positions
        # (original author's note: equation (5), using the cheaper form of
        # equation (9) - project q onto all T embeddings once, then gather).
        lower = positions.floor().long()
        upper = positions.ceil().long()
        # (B,H,N,C) @ (1,C,T) = (B,H,N,T)
        logits = torch.matmul(q, self.pos_emb)
        frac = positions - lower
        return frac * logits.gather(-1, upper) + (1 - frac) * logits.gather(-1, lower)


class Attention(nn.Module):
    """Multi-head attention whose logits are augmented with a CoPE term.

    Acts as self-attention when `kv` is omitted in `forward`, and as
    cross-attention over `kv` otherwise.
    """

    def __init__(
            self,
            Q_dim,
            KV_dim = None,
            heads = 8,
            dim_head = 64
    ):
        """
        :param Q_dim: feature size of the query input (also the output size)
        :param KV_dim: feature size of the key/value input; defaults to Q_dim
        :param heads: number of attention heads
        :param dim_head: channel size of each head
        """
        super().__init__()
        KV_dim = default(KV_dim, Q_dim)
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        # Submodules are created in the same order as before (to_q, to_kv,
        # to_out, cope) so seeded weight initialisation is unchanged.
        self.to_q = nn.Linear(Q_dim, inner_dim, bias=False)
        # Keys and values come from a single projection that is split in forward().
        self.to_kv = nn.Linear(KV_dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, Q_dim)

        # CoPE is built with its default table size, i.e. dim_head positions.
        self.cope = CoPE(dim_head)

    def forward(self, x, kv = None, mask=None):
        """
        :param x: query input, (b, n, Q_dim)
        :param kv: optional key/value input, (b, m, KV_dim); `x` is used when omitted
        :param mask: optional per-key mask, flattened to (b, j); positions where
            the mask is False are filled with the most negative finite value.
            Assumed to be a boolean tensor (it is inverted with `~`) - TODO confirm
        :return: tensor of shape (b, n, Q_dim)
        """
        heads = self.heads
        queries = self.to_q(x)
        context = default(kv, x)
        keys, values = self.to_kv(context).chunk(2, dim=-1)

        # Fold the head axis into the batch axis: (b, n, h*d) -> (b*h, n, d).
        queries, keys, values = (
            rearrange(t, 'b n (h d) -> (b h) n d', h=heads)
            for t in (queries, keys, values)
        )
        logits = einsum('b i d, b j d -> b i j', queries, keys) * self.scale

        if exists(mask):
            key_mask = rearrange(mask, 'b ... -> b (...)')
            key_mask = repeat(key_mask, 'b j -> (b h) () j', h=heads)
            # Masking happens before CoPE, so masked keys get a sigmoid gate of ~0.
            logits.masked_fill_(~key_mask, -torch.finfo(logits.dtype).max)

        # Add the contextual position term to the content logits.
        logits = logits + self.cope(queries, logits)

        weights = logits.softmax(dim=-1)
        merged = einsum('b i j, b j d -> b i d', weights, values)
        # Unfold the heads again: (b*h, n, d) -> (b, n, h*d).
        merged = rearrange(merged, '(b h) n d -> b n (h d)', h=heads)
        return self.to_out(merged)

if __name__ == '__main__':
    # Shape smoke test: self-attention over a batch of 2 sequences of 128
    # tokens with 64 features should give back a tensor of the same shape.
    q = torch.ones(2, 128, 64)
    a = Attention(Q_dim=64)
    x = a(q)
    print(x.shape)

torch.Size([2, 128, 64])
