In [None]:
#  -------------------------------   Imports  -------------------------------  #
import gym
import minihack
from nle import nethack
from skimage.io import imshow
import numpy as np
from CustomRewardManager import RewardManager, InventoryEvent, MessageEvent
from actor_critic_v2_hierarchical_options import *

In [None]:
# -------------------------------------------------------------------------------#
#                    Custom Sub-Policies - Imported/Defined                      #
# -------------------------------------------------------------------------------#
class PotionPolicy:
    """
    A policy loader that reads a stored PyTorch checkpoint with model and optimizer parameters.

    INPUT: The policy should be stored in a file as a dictionary with the format:
        {"model_policy": model_state_dict, "optimizer": optimizer_state_dict}
    Only the "model_policy" entry is needed for inference; the optimizer state is ignored.

    This policy loader is made specifically for detecting potions in the agent's surroundings.
    """

    # Contextual-information encoding: maps a neighbouring-tile description to an
    # integer feature. These values must match the ones the stored policy was
    # trained with, so do not renumber them.
    # NOTE(review): "human rogue called Agent" shares code 1 with "floor of a room"
    # and code 2 is unused -- possibly intentional; confirm against the training code.
    my_dict = {
        "": 0,
        "floor of a room": 1,
        "human rogue called Agent": 1,
        "staircase up": 3,
        "staircase down": 4,
    }
    # One-hot slots for the direction towards `obj_to_find`. The strings look like
    # ASCII codes of the NetHack move keys k, l, j, h, u, n, b, y (TODO confirm);
    # the trailing None slot means "no direction found". Order must match training.
    directions = ["107", "108", "106", "104", "117", "110", "98", "121", None]
    obj_to_find = "potion"

    def __init__(self, policy_file, actions, map_location="cpu"):
        """Load the trained actor-critic network from ``policy_file``.

        Args:
            policy_file: path to a checkpoint in the format described on the class.
            actions: the action set the policy was trained on; its length sets the
                output size of the network.
            map_location: forwarded to ``torch.load``. Defaults to "cpu" so that
                checkpoints saved on a GPU also load on CPU-only machines. The
                weights are copied into the freshly built model either way.
        """
        model = ActorCritic(h_size=512, a_size=len(actions))

        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(policy_file, map_location=map_location)
        model.load_state_dict(checkpoint["model_policy"])

        model.eval()
        self.policy = model
        self.actions = actions

    def select_action(self, env, next_state):
        """Sample one action from the stored policy.

        Args:
            env: environment wrapper providing ``get_neighbor_descriptions()`` and
                ``get_object_direction(name)``.
            next_state: current state, in whatever form the stored model's
                ``forward`` expects as its first argument.

        Returns:
            An int sampled from the policy's categorical distribution. Presumably
            this is an index into ``self.actions``, mapped to an env action by the
            caller (TODO confirm).

        Raises:
            ValueError: if the direction reported by the env is not one of
                ``self.directions``.
        """
        # Neighbouring tile descriptions -> integer features, as a single-row batch.
        neighbor_descriptions = env.get_neighbor_descriptions()
        mapped_descriptions = np.array(
            map_descriptions(self.my_dict, neighbor_descriptions)
        )
        mapped_descriptions = mapped_descriptions.reshape((1, len(mapped_descriptions)))

        # Direction towards the target object -> one-hot row of shape (1, len(directions)).
        direction = env.get_object_direction(self.obj_to_find)
        direction_key = None if direction is None else str(direction)
        direction_one_hot = np.zeros((1, len(self.directions)), dtype=int)
        direction_one_hot[0, self.directions.index(direction_key)] = 1

        # Pure inference: only the sampled int leaves this method, so there is no
        # need to build an autograd graph on every environment step.
        with torch.no_grad():
            action_probs, _state_value = self.policy(
                next_state, mapped_descriptions, direction_one_hot
            )
            action = torch.distributions.Categorical(action_probs).sample()
        return action.item()


class DrinkPolicy:
    """Hand-coded two-step Option that drinks (quaffs) an item from the inventory.

    Step 1 issues ``nethack.Command.QUAFF``. Step 2 answers the follow-up prompt
    with the action whose value is the inventory letter of ``consumable``.
    After step 2 the policy rewinds to step 1. Without the rewind, a reused
    instance would stay on step 2 forever, and every later invocation of the
    Option would skip the QUAFF command.
    """

    def __init__(self, consumable):
        """
        Args:
            consumable: name passed to ``env.key_in_inventory`` to find the
                item's inventory letter (e.g. "potion").
        """
        self.consumable = consumable
        self.reset()

    def reset(self):
        """Rewind the policy to its first step (QUAFF)."""
        self.policy_step = 1

    def select_action(self, env, observation=None):
        """Return the next action of the two-step quaff sequence.

        Args:
            env: environment providing ``key_in_inventory(name)`` and ``actions``.
            observation: unused; accepted for a uniform sub-policy interface.

        Returns:
            ``nethack.Command.QUAFF`` on the first step. On the second step, the
            member of ``env.actions`` whose value equals ``ord(inventory_letter)``,
            or ``None`` if the item or a matching action cannot be found.
        """
        if self.policy_step == 1:
            self.policy_step = 2
            return nethack.Command.QUAFF

        # Second step: always rewind, so the next invocation starts with QUAFF
        # again, even if this step fails.
        self.reset()
        inv_key = env.key_in_inventory(self.consumable)
        if inv_key is None:
            # Previously ord(None) raised TypeError here.
            print(f"Item '{self.consumable}' not found in inventory")
            return None
        target_value = ord(inv_key)
        for action in env.actions:
            if action.value == target_value:
                return action
        print(f"Confirm action '{inv_key}' not found in those available")
        return None

In [None]:
# -------------------------------------------------------------------------------#
#                               Termination Events                               #
# -------------------------------------------------------------------------------#


class InventoryTerminationEvent:
    """Termination condition that fires once a given item is in the inventory."""

    def __init__(self, inv_item: str):
        """Initialise the event.

        Args:
            inv_item: substring searched for in each decoded inventory line
                (e.g. "potion").
        """
        self.inv_item = inv_item

    @staticmethod
    def _decode(row) -> str:
        """Decode one NUL-padded byte row into text, up to the first NUL.

        A row that fills its whole buffer has no NUL terminator. Such a row is
        decoded in full instead of raising IndexError.
        """
        zeros = np.flatnonzero(row == 0)
        end = zeros[0] if zeros.size else len(row)
        return row[:end].tobytes().decode("utf-8")

    def check_complete(self, env, observation) -> bool:
        """Return True if any inventory line contains ``self.inv_item``.

        Args:
            env: unused; kept so every termination event shares one signature.
            observation: observation dict. Its "inv_strs" entry is assumed to be
                a 2-D uint8 array with one NUL-padded row per inventory slot
                (TODO confirm).
        """
        return any(
            self.inv_item in self._decode(row) for row in observation["inv_strs"]
        )


class MessageTerminationEvent:
    """Termination condition that fires when any of the given messages is shown."""

    def __init__(self, messages):
        """Initialise the event.

        Args:
            messages: iterable of substrings to look for in the current message.
                A single ``str`` is treated as a one-element list. Iterating a
                bare string would test its individual characters, which matches
                almost any message.
        """
        if isinstance(messages, str):
            messages = [messages]
        self.messages = messages

    def check_complete(self, env, observation) -> bool:
        """Return True if the current message contains any of ``self.messages``.

        A malformed observation (missing key, undecodable bytes) is treated as
        "not complete". A diagnostic is printed; the method never raises.

        Args:
            env: unused; kept so every termination event shares one signature.
            observation: observation dict. Its "message" entry is assumed to be
                a NUL-padded uint8 array (TODO confirm).
        """
        try:
            raw = observation["message"]
            zeros = np.flatnonzero(raw == 0)
            # A message filling the whole buffer has no NUL terminator.
            end = zeros[0] if zeros.size else len(raw)
            curr_msg = raw[:end].tobytes().decode("utf-8")
        except (KeyError, TypeError, UnicodeDecodeError) as err:
            # Best effort: report and carry on. The old handler indexed the
            # observation dict by an int and crashed inside the except block.
            print("Failed to decode message:")
            print(repr(err))
            return False
        return any(target in curr_msg for target in self.messages)

In [None]:
# -------------------------------------------------------------------------------#
#                               Sample Options Setup                             #
# -------------------------------------------------------------------------------#

# Action sets. MOVE_ACTIONS (the compass directions) is also the action space
# passed to the pre-trained potion policy loaded further below.
MOVE_ACTIONS = tuple(nethack.CompassDirection)
NAVIGATE_ACTIONS = MOVE_ACTIONS + (
    nethack.Command.OPEN,
    nethack.Command.KICK,
    nethack.Command.SEARCH,
    nethack.Command.FIGHT,
)

# Lava ENV -- extra actions needed for confirming action.
# DrinkPolicy answers the quaff prompt by returning the member of env.actions
# whose value equals ord(<inventory letter>), so that letter must be covered by
# one of the actions listed here.
LAVA_ACTIONS = NAVIGATE_ACTIONS + (
    nethack.Command.PICKUP,
    nethack.Command.APPLY,
    nethack.Command.PUTON,
    nethack.Command.WEAR,
    nethack.Command.QUAFF,
    nethack.Command.INVOKE,
    nethack.Command.CAST,
    nethack.Command.INVENTORY,
    nethack.MiscAction.MORE,
    nethack.Command.ESC,
    nethack.Command.FIRE,  # most commonly needed confirmation action (NOTE(review): presumably because its key value doubles as the inventory letter to confirm -- verify)
    nethack.Command.DROP,
    nethack.Command.RUSH,
)

# MiniHack level description: a 6x3 room whose middle column is lava ('L').
# A blessed potion of levitation is placed at a random spot on the left bank,
# and the down staircase at a random spot on the right bank. BRANCH over the
# left bank presumably fixes the agent's arrival region there (TODO confirm
# against the MiniHack des-file docs).
des_file = """
MAZE: "mylevel", ' '
FLAGS:hardfloor
INIT_MAP: solidfill,' '
GEOMETRY:center,center
MAP
--------
|...L..|
|...L..|
|...L..|
--------
ENDMAP
REGION:(0,0,6,3),lit,"ordinary"
$left_bank = selection:fillrect (1,1,3,3)
$right_bank = selection:fillrect (5,1,6,3)
OBJECT:('!',"levitation"),rndcoord($left_bank),blessed
BRANCH:(1,1,3,3),(0,0,0,0)
STAIR:rndcoord($right_bank),down
"""

# Specify RewardManager for reward shaping
reward_manager = RewardManager()
# Reward for adding a potion to the inventory.
# NOTE(review): the positional flags are presumably (reward, repeatable,
# terminal_required, terminal_sufficient), mirroring add_location_event's
# keywords -- confirm in CustomRewardManager.
reward_manager.add_event(InventoryEvent(0.5, False, True, False, inv_item="potion"))
# Message based reward after drinking a levitation potion
reward_manager.add_event(
    MessageEvent(0.5, False, True, False, messages=["You start to float in the air!"])
)
# Add a penalty reward for entering the lava
reward_manager.add_location_event("molten lava", reward=-1, terminal_required=False)
# Add reward for reaching the final destination and set it to terminal
reward_manager.add_location_event(
    "staircase down",
    reward=1,
    repeatable=False,
    terminal_required=True,
    terminal_sufficient=True,
)

# Build the custom skill environment. "glyphs", "message" and "inv_strs" are
# among the requested keys; the termination events and policies below read
# "message" and "inv_strs" from the observation.
env = gym.make(
    "MiniHack-Skill-Custom-v0",
    des_file=des_file,
    observation_keys=[
        "glyphs",
        "pixel",
        "message",
        "pixel_crop",
        "glyphs_crop",
        "blstats",
        "inv_strs",
    ],
    reward_manager=reward_manager,
    # reward_lose=-1, # Does not work when reward manager is used
    actions=LAVA_ACTIONS,
    autopickup=True,
    allow_all_modes=True,  # Enables confirmation message for consuming potion
    max_episode_steps=500,
)

# Option 1: Loaded policy that's trained to find and pickup a potion.
# Loads a checkpoint from disk at import time; the file must exist next to the notebook.
termination_clause = InventoryTerminationEvent("potion")
potion_policy = PotionPolicy(
    "./policy_potion_pickup_with_neighbours_2000.pt", MOVE_ACTIONS
)
# Option 2: Defined policy for consuming a potion in the agent's inventory.
# NOTE: this single DrinkPolicy instance is shared by every invocation of the option.
drink_policy = DrinkPolicy("potion")
levitation_message = MessageTerminationEvent(["You start to float in the air!"])

# Specify the set of Options, including the primitive actions.
# Each entry is (policy_or_action, termination_event_or_None, int). The int is
# 1 for primitives, 20 and 2 for the two options -- presumably a per-option
# step budget (TODO confirm against run_actor_critic).
ENV_OPTIONS = [(action, None, 1) for action in env.actions] + [
    (potion_policy, termination_clause, 20),
    (drink_policy, levitation_message, 2),
]

In [None]:
# Train a hierarchical actor-critic whose choices are ENV_OPTIONS (primitive
# actions plus the two hand-built Options) instead of raw actions only.
training_config = dict(
    number_episodes=5000,
    max_episode_length=1000,
    iterations=3,
    env_options=ENV_OPTIONS,
)
policy, results, optimizer = run_actor_critic(env, **training_config)

# Persist the trained weights and optimizer state under the
# {"model_policy", "optimizer"} keys used by the checkpoint loader above.
checkpoint = {
    "model_policy": policy.state_dict(),
    "optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, "./policy_potion_pickup_with_options.pt")

# Plot the scores returned by training.
plot_results(
    env_name="Lava Cross with Potion Pickup Option",
    scores=results,
    ylim=(-1.5, 2),
    color="teal",
)