In [None]:
import torch
import gym
import numpy as np

import torch.optim as op
import torch.nn.functional as torch_func

import matplotlib.pyplot as plt
import pandas as pd
import minihack
from minihack import RewardManager
from nle import nethack
from gym import spaces
import cv2
# Turn off OpenCV's OpenCL acceleration so cv2 image calls use the plain CPU implementation.
cv2.ocl.setUseOpenCL(False)

In [None]:
# --- Hyperparameters --------------------------------------------------------
learning_rate = 5e-3  # Adam learning rate used by ActorCriticRun

# --- Constants --------------------------------------------------------------
GAMMA = 0.99  # discount factor for the n-step returns

# Dimensions of the glyph grid consumed by the convolutional trunk.
PICTURE_HEIGHT = 21
PICTURE_WIDTH = 79
PICTURE_CHANNELS = 1

# Single-step moves in the eight compass directions.
_ONE_STEP_MOVES = [
    nethack.CompassDirection.N,
    nethack.CompassDirection.E,
    nethack.CompassDirection.S,
    nethack.CompassDirection.W,
    nethack.CompassDirection.NE,
    nethack.CompassDirection.SE,
    nethack.CompassDirection.SW,
    nethack.CompassDirection.NW,
]

# CompassDirectionLonger moves *without* NW.
# NOTE(review): MOVE_ACTION_SPACE / EAT_ACTION_SPACE are built from this list and
# therefore lack CompassDirectionLonger.NW while ACTION_SPACE includes it --
# confirm the omission is intentional.
_LONG_MOVES_NO_NW = [
    nethack.CompassDirectionLonger.N,
    nethack.CompassDirectionLonger.E,
    nethack.CompassDirectionLonger.S,
    nethack.CompassDirectionLonger.W,
    nethack.CompassDirectionLonger.NE,
    nethack.CompassDirectionLonger.SE,
    nethack.CompassDirectionLonger.SW,
]

# Remaining nethack.Command members of the full action set, in their original order.
_GENERAL_COMMANDS = [
    getattr(nethack.Command, command_name)
    for command_name in (
        "ADJUST", "APPLY", "ATTRIBUTES", "CALL", "CAST", "CHAT", "CLOSE", "DIP",
        "DROP", "DROPTYPE", "ENGRAVE", "ENHANCE", "ESC", "FIGHT", "FIRE", "FORCE",
        "INVENTORY", "INVENTTYPE", "INVOKE", "JUMP", "KICK", "LOOK", "LOOT",
        "MONSTER", "MOVE", "MOVEFAR", "OFFER", "OPEN", "PAY", "PICKUP", "PRAY",
        "PUTON", "QUAFF", "QUIVER", "READ", "REMOVE", "RIDE", "RUB", "RUSH",
        "RUSH2", "SEARCH", "SEEARMOR", "SEERINGS", "SEETOOLS", "SEETRAP",
        "SEEWEAPON", "SHELL", "SIT", "SWAP", "TAKEOFF", "TAKEOFFALL", "THROW",
        "TIP", "TURN", "TWOWEAPON", "UNTRAP", "VERSIONSHORT", "WEAR", "WIELD",
        "WIPE", "ZAP",
    )
]

# Full action set: every move, a few misc actions, all commands, then text keys.
ACTION_SPACE = (
    _ONE_STEP_MOVES
    + _LONG_MOVES_NO_NW
    + [
        nethack.CompassDirectionLonger.NW,
        nethack.Command.EAT,
        nethack.MiscDirection.DOWN,
        nethack.MiscDirection.WAIT,
        nethack.MiscAction.MORE,
    ]
    + _GENERAL_COMMANDS
    + [
        nethack.TextCharacters.PLUS,
        nethack.TextCharacters.QUOTE,
        nethack.TextCharacters.DOLLAR,
        nethack.TextCharacters.SPACE,
    ]
)

# Quest action set: the four cardinal moves plus item/interaction commands.
QUEST_ACTION_SPACE = _ONE_STEP_MOVES[:4] + [
    nethack.Command.EAT,
    nethack.Command.PICKUP,
    nethack.Command.APPLY,
    nethack.Command.FIRE,
    nethack.Command.RUSH,
    nethack.Command.ZAP,
    nethack.Command.PUTON,
    nethack.Command.READ,
    nethack.Command.WEAR,
    nethack.Command.QUAFF,
]

# Movement-only action set, and the same set extended with EAT.
MOVE_ACTION_SPACE = _ONE_STEP_MOVES + _LONG_MOVES_NO_NW
EAT_ACTION_SPACE = _ONE_STEP_MOVES + _LONG_MOVES_NO_NW + [nethack.Command.EAT]

device = torch.device("cpu")

In [None]:
reward_manager = RewardManager()

# Entity names whose eat / kill / wield events each earn +1.
EAT_ITEMS = ["apple"]
MONSTERS = ["jackal", "rat", "lichen"]
WIELDED_ITEMS = ["dagger"]

# Register positive events category by category (eat, then kill, then wield).
for event_names, register_event in (
    (EAT_ITEMS, reward_manager.add_eat_event),
    (MONSTERS, reward_manager.add_kill_event),
    (WIELDED_ITEMS, reward_manager.add_wield_event),
):
    for event_name in event_names:
        register_event(name=event_name, reward=1)

# Standing on a sink or on lava costs -1 but does not by itself end the episode.
for hazard in ("sink", "lava"):
    reward_manager.add_location_event(hazard, reward=-1, terminal_required=False)

# Presumably the reward given when the episode is lost -- confirm against the minihack RewardManager docs.
# NOTE(review): in this notebook reward_manager is only passed to gym.make in the recording cell.
reward_manager.reward_lose = -1


In [None]:

def _env_spec(name, episodes, action_space, must_print=False):
    """Build one training-schedule entry (env id, episode budget, action list, print flag)."""
    return {
        "name": name,
        "episodes": episodes,
        "action_space": action_space,
        "must_print": must_print,
    }


# Curriculum for the shared agent, trained in this order.
envs = [
    _env_spec("MiniHack-Room-Random-5x5-v0", 50, MOVE_ACTION_SPACE),
    _env_spec("MiniHack-Eat-v0", 300, EAT_ACTION_SPACE),
    _env_spec("MiniHack-LavaCross-Full-v0", 250, QUEST_ACTION_SPACE),
]

# Quest levels; each one is trained with a fresh model.
quest_envs = [
    _env_spec("MiniHack-Quest-Easy-v0", 500, QUEST_ACTION_SPACE),
    _env_spec("MiniHack-Quest-Medium-v0", 500, QUEST_ACTION_SPACE),
    _env_spec("MiniHack-Quest-Hard-v0", 500, QUEST_ACTION_SPACE),
]

In [None]:
class WarpFrame(gym.ObservationWrapper):
    """Turn the ``pixel`` observation into a single-channel PICTURE_HEIGHT x PICTURE_WIDTH frame.

    Modeled on the Nature-DQN frame warping, but resized to the glyph-grid
    dimensions (PICTURE_WIDTH x PICTURE_HEIGHT) rather than 84x84. The output
    has shape ``(height, width, 1)``, which matches ``observation_space``.
    """

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.width = PICTURE_WIDTH
        self.height = PICTURE_HEIGHT
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(self.height, self.width, 1), dtype=np.uint8
        )

    def observation(self, frame):
        """Extract ``frame["pixel"]``, convert it to grayscale and resize it to (height, width, 1)."""
        frame = frame["pixel"]
        # The declared observation space has a single channel. Without this conversion,
        # an RGB frame would come out as (h, w, 3, 1).
        if frame.ndim == 3 and frame.shape[-1] == 3:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(
            frame, (self.width, self.height), interpolation=cv2.INTER_AREA
        )
        # cv2.resize returns a 2-D array for single-channel input; restore the channel axis.
        return frame[:, :, None] if frame.ndim == 2 else frame


class GlyphWrapper(gym.ObservationWrapper):
    """Return the glyph map as one flattened, max-scaled float32 row of shape ``(1, n_cells)``.

    The raw ``pixel`` frame of each observation is kept in ``LastFullImg`` for rendering.
    """

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.LastFullImg = None

    def observation(self, observation):
        """Flatten ``observation["glyphs"]`` and divide it by its maximum (skipped when that is 0)."""
        self.LastFullImg = observation["pixel"]
        flattened = torch.tensor(observation["glyphs"], dtype=torch.float32).reshape(-1)
        peak = torch.max(flattened)
        # Dividing an all-zero map by its max would fill it with NaN.
        if peak != 0:
            flattened = flattened / peak
        return flattened.unsqueeze(0)

    def get_last_full_img(self):
        """Return the ``pixel`` frame from the most recent observation (None before the first)."""
        return self.LastFullImg

class Glyp2DhWrapper(gym.ObservationWrapper):
    """Turn a MiniHack observation into the tensor dict read by ActorCriticModel2D.

    Output keys:
      * ``"glyphs"``: ``observation["glyphs"]`` with two leading singleton axes
        (batch, channel) for Conv2d;
      * ``"glyphs_cropped"``: ``observation["glyphs_crop"]``, reshaped the same way;
      * ``"message"``: ``observation["message"]`` with a leading batch axis.
    Each array is cast to float32 and divided by its own maximum.
    NOTE(review): the scale therefore changes from frame to frame; a fixed divisor
    (e.g. the glyph-id upper bound) may be preferable -- confirm intent.
    The raw ``pixel`` frame is kept for rendering via ``get_last_full_img``.
    """

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.LastFullImg = None

    @staticmethod
    def _scaled_tensor(array):
        """Cast ``array`` to a float32 tensor and divide it by its maximum, unless that maximum is 0."""
        tensor = torch.tensor(array, dtype=torch.float32)
        peak = torch.max(tensor)
        # An all-zero array (e.g. an empty message) would otherwise become NaN.
        return tensor / peak if peak != 0 else tensor

    def observation(self, observation):
        """Build the model input dict from one raw observation."""
        self.LastFullImg = observation["pixel"]

        glyphs_tensor = self._scaled_tensor(observation["glyphs"])
        glyphs_cropped_tensor = self._scaled_tensor(observation["glyphs_crop"])
        message_tensor = self._scaled_tensor(observation["message"])

        return {
            "glyphs": glyphs_tensor.unsqueeze(0).unsqueeze(0),
            "glyphs_cropped": glyphs_cropped_tensor.unsqueeze(0).unsqueeze(0),
            "message": message_tensor.unsqueeze(0),
        }

    def get_last_full_img(self):
        """Return the ``pixel`` frame from the most recent observation (None before the first)."""
        return self.LastFullImg

In [None]:
class ActorCriticModel2D(torch.nn.Module):
    """Actor-critic network over the dict observations produced by ``Glyp2DhWrapper``.

    Three branches are concatenated into one feature vector that both heads share:
      * ``state["glyphs"]``: two conv + max-pool stages over the full glyph map;
      * ``state["message"]``: two tanh-activated linear layers (256 -> 400 -> 100);
      * ``state["glyphs_cropped"]``: conv + max-pool + conv over the local crop.

    Args:
        action_space: number of actions, i.e. the width of the actor's softmax.
        glyph_shape: (height, width) of the full glyph map. Defaults to
            (PICTURE_HEIGHT, PICTURE_WIDTH).
        crop_shape: (height, width) of the cropped glyph map. The default 9x9 is
            assumed to match the env's ``glyphs_crop`` size -- confirm if the env
            is created with a different crop.

    ``forward`` returns ``(action_probabilities, value)`` with shapes
    ``(batch, action_space)`` and ``(batch, 1)``.
    """

    def __init__(self, action_space, glyph_shape=None, crop_shape=(9, 9)):
        super(ActorCriticModel2D, self).__init__()

        self.action_space = action_space

        # Activations
        self.relu_layer = torch.nn.ReLU()
        self.max_pooling = torch.nn.MaxPool2d(kernel_size=2)
        self.tanh_layer = torch.nn.Tanh()

        # Convolutional layers BIG
        self.cnn_layer_one = torch.nn.Conv2d(in_channels=1, stride=1, kernel_size=(3, 3),
                                             out_channels=8)
        self.cnn_layer_two = torch.nn.Conv2d(in_channels=8, stride=1, kernel_size=(3, 3),
                                             out_channels=16)

        # Convolutional layers CROPPED
        self.cnn_layer_one_crop = torch.nn.Conv2d(in_channels=1, stride=1, kernel_size=(3, 3),
                                                  out_channels=4)
        self.cnn_layer_two_crop = torch.nn.Conv2d(in_channels=4, stride=1, kernel_size=(3, 3),
                                                  out_channels=8)

        # Message layers
        self.message_input = torch.nn.Linear(in_features=256, out_features=400)
        self.message_hidden = torch.nn.Linear(in_features=400, out_features=100)

        # Size both heads by running dummy inputs through the trunk. The glyph map
        # is laid out as (height, width), the same as the wrapper's tensors.
        if glyph_shape is None:
            glyph_shape = (PICTURE_HEIGHT, PICTURE_WIDTH)
        feature_size = self._feature_size(tuple(glyph_shape), tuple(crop_shape))

        # Critic layers
        self.critic_final = torch.nn.Linear(in_features=feature_size, out_features=1)

        # Actor layers
        self.actor_final = torch.nn.Linear(in_features=feature_size, out_features=self.action_space)

    def _feature_size(self, glyph_shape, crop_shape):
        """Return the length of the concatenated feature vector for a batch of one."""
        with torch.no_grad():
            dummy_state = {
                # Channel counts come from the conv layers so the dummies always match them.
                "glyphs": torch.zeros(1, self.cnn_layer_one.in_channels, *glyph_shape),
                "glyphs_cropped": torch.zeros(1, self.cnn_layer_one_crop.in_channels, *crop_shape),
                "message": torch.zeros(1, self.message_input.in_features),
            }
            features = self._features(dummy_state)
        return features.size(1)

    def forward_conv(self, state):
        """Apply two conv -> ReLU -> max-pool stages to the full glyph map."""
        state = self.relu_layer(self.cnn_layer_one(state))
        state = self.max_pooling(state)
        state = self.relu_layer(self.cnn_layer_two(state))
        state = self.max_pooling(state)
        return state

    def forward_conv_crop(self, state):
        """Apply conv -> ReLU -> max-pool -> conv -> ReLU to the cropped glyph map."""
        state = self.relu_layer(self.cnn_layer_one_crop(state))
        state = self.max_pooling(state)
        state = self.relu_layer(self.cnn_layer_two_crop(state))
        return state

    def actor_forward(self, state):
        """Map features to a softmax distribution over actions (softmax over dim 1)."""
        policy_dist = self.actor_final(state)
        return torch_func.softmax(policy_dist, dim=1)

    def critic_forward(self, state):
        """Map features to a scalar state-value estimate per batch row."""
        return self.critic_final(state)

    def message_forward(self, state):
        """Encode the message vector with two tanh-activated linear layers."""
        message_input = self.tanh_layer(self.message_input(state))
        return self.tanh_layer(self.message_hidden(message_input))

    def _features(self, state):
        """Concatenate the glyph, message and crop features (in that order) along dim 1."""
        conv_state = self.forward_conv(state=state["glyphs"])
        feature_extraction = conv_state.view(conv_state.size(0), -1)

        cropped_conv_state = self.forward_conv_crop(state=state["glyphs_cropped"])
        cropped_feature_extraction = cropped_conv_state.view(cropped_conv_state.size(0), -1)

        message_state = self.message_forward(state["message"])

        return torch.cat(tensors=(feature_extraction, message_state, cropped_feature_extraction), dim=1)

    def forward(self, state):
        """Return ``(action_probabilities, value)`` for a state dict."""
        output_state = self._features(state)
        policy_softmax_dist = self.actor_forward(output_state)
        value = self.critic_forward(output_state)
        return policy_softmax_dist, value


In [None]:
class ActorCriticRun():
    """Train an actor-critic model on one environment with one A2C update per episode.

    Each episode runs for at most ``num_steps`` steps. The model then receives one
    Adam step on ``actor_loss + critic_loss - entropy_coef * mean_entropy``, where
    the returns bootstrap from the critic only when the episode was cut off (not
    when it terminated).

    Args:
        action_space: actions the env was created with; its length sizes the sampling distribution.
        env_name: title used by ``plot_results``.
        model: module returning ``(action_probabilities, value)`` for a state.
        env: environment using the 4-tuple ``step`` API
            (``next_state, reward, done, info``).
        max_episodes: number of episodes played by ``run``.
        must_print: stored on the instance; not read by this class.
        num_steps: per-episode step cap (the env's own time limit may end episodes earlier).
        entropy_coef: weight of the entropy bonus in the loss.
    """

    # Keeps log() finite when a probability underflows to 0.
    _EPS = 1e-8

    def __init__(self, action_space, env_name, model, env, max_episodes, must_print=False,
                 num_steps=1000, entropy_coef=0.001):
        self.env = env
        self.A2C_model = model

        self.all_lengths = []
        self.average_lengths = []
        self.all_rewards = []
        # Summed policy entropy of the most recent episode (for inspection only).
        self.entropy_term = 0
        self.optimizer = op.Adam(self.A2C_model.parameters(), lr=learning_rate)
        self.action_space = action_space
        self.must_print = must_print
        self.max_episodes = max_episodes
        self.env_name = env_name
        self.num_steps = num_steps
        self.entropy_coef = entropy_coef
        # Frames captured by run(record=True); see get_frames().
        self.frames = []

    @staticmethod
    def _restrict_to_env_actions(policy_dist, n_actions):
        """Return a distribution over exactly ``n_actions`` actions.

        If the model has more outputs than the env has actions (e.g. a model shared
        across envs with different action lists), keep the first ``n_actions``
        probabilities and renormalize. Gradients still flow through the result.

        Raises:
            ValueError: if the model has fewer outputs than the env has actions.
        """
        n_outputs = policy_dist.size(-1)
        if n_outputs == n_actions:
            return policy_dist
        if n_outputs < n_actions:
            raise ValueError(
                f"model produces {n_outputs} action probabilities but the env has {n_actions} actions"
            )
        head = policy_dist[..., :n_actions]
        return head / head.sum(dim=-1, keepdim=True)

    @staticmethod
    def _discounted_returns(rewards, bootstrap, gamma):
        """Return ``G_t = r_t + gamma * G_{t+1}`` for every step, with ``G_T = bootstrap``."""
        returns = []
        running = bootstrap
        for reward in reversed(rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()
        return returns

    def run(self, record=False):
        """Play ``self.max_episodes`` episodes and update the model after each one.

        Args:
            record: if True, clear ``self.frames`` and append
                ``self.env.get_last_full_img()`` after the reset and after every
                step, so the rollout can be rendered later via ``get_frames``.
        """
        n_actions = len(self.action_space)
        if record:
            self.frames = []

        for ep in range(self.max_episodes):
            log_probs = []
            values = []
            rewards = []
            entropies = []
            bootstrap = 0.0  # stays 0 when the episode terminates

            state = self.env.reset()
            if record:
                self.frames.append(self.env.get_last_full_img())

            for step in range(self.num_steps):
                policy_dist, value = self.A2C_model(state)
                probs = self._restrict_to_env_actions(policy_dist, n_actions)

                # Sample in float64 and renormalize: float32 softmax output can miss
                # numpy's sum-to-one tolerance.
                p = probs.detach().cpu().double().numpy().reshape(-1)
                action = np.random.choice(n_actions, p=p / p.sum())

                # Keep these as graph-connected tensors so the losses backpropagate.
                log_prob = torch.log(probs.squeeze(0)[action] + self._EPS)
                entropy = -(probs * torch.log(probs + self._EPS)).sum()

                next_state, reward, terminated, info = self.env.step(action)
                if record:
                    self.frames.append(self.env.get_last_full_img())

                rewards.append(reward)
                values.append(value.reshape(()))
                log_probs.append(log_prob)
                entropies.append(entropy)
                state = next_state

                if terminated or step == self.num_steps - 1:
                    if not terminated:
                        # Cut off by the step cap: bootstrap from the critic's next-state estimate.
                        with torch.no_grad():
                            _, next_value = self.A2C_model(next_state)
                        bootstrap = next_value.reshape(-1)[0].item()

                    episode_reward = np.sum(rewards)
                    self.all_rewards.append(episode_reward)
                    self.all_lengths.append(step + 1)
                    self.average_lengths.append(np.mean(self.all_lengths[-10:]))

                    print(f"episode: {ep+1}, reward: {episode_reward}, average_reward: {np.mean(self.all_rewards)}, total length: {step+1}, average length: {self.average_lengths[-1]}")

                    break

            if not rewards:
                # num_steps == 0: nothing was collected, so there is nothing to learn from.
                continue

            returns = torch.tensor(
                self._discounted_returns(rewards, bootstrap, GAMMA), dtype=torch.float32
            )
            values_t = torch.stack(values)
            log_probs_t = torch.stack(log_probs)
            entropies_t = torch.stack(entropies)

            advantage = returns - values_t
            # The policy gradient treats the advantage as a constant; only the
            # critic loss should push gradients into the value head.
            actor_loss = (-log_probs_t * advantage.detach()).mean()
            critic_loss = 0.5 * advantage.pow(2).mean()
            # Subtracting the entropy bonus encourages exploration.
            ac_loss = actor_loss + critic_loss - self.entropy_coef * entropies_t.mean()
            self.entropy_term = entropies_t.sum().item()

            self.optimizer.zero_grad()
            ac_loss.backward()
            self.optimizer.step()

    def get_frames(self):
        """Return the frames captured by the most recent ``run(record=True)``."""
        return self.frames

    def plot_results(self):
        """Plot per-episode rewards, then episode lengths with their 10-episode running mean."""
        plt.plot(self.all_rewards, label="All Rewards")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.title(self.env_name)
        plt.legend()
        plt.show()

        plt.plot(self.all_lengths, label="All lengths")
        plt.plot(self.average_lengths, label="Avg lengths")
        plt.xlabel("Episode")
        plt.ylabel("Episode Length")
        plt.title(self.env_name)
        plt.legend()
        plt.show()

In [None]:
# Curriculum training: one shared agent is trained on each entry of `envs` in
# order, and its weights are checkpointed after every environment.
# NOTE(review): the model has len(ACTION_SPACE) outputs, but each env below is
# created with a shorter action list. ActorCriticRun.run must reconcile the two,
# otherwise np.random.choice rejects the size-mismatched probability vector.
# NOTE(review): the checkpoint path is an absolute Google Drive location and
# requires the Drive mount to exist.
agent_2d = ActorCriticModel2D(len(ACTION_SPACE))
# To train on GPU instead: agent_2d = agent_2d.cuda()

for stage in envs:
    print(f"{'-'*5} Starting {stage['name']} {'-'*5} ")

    env = Glyp2DhWrapper(
        gym.make(
            stage["name"],
            actions=stage["action_space"],
            observation_keys=("glyphs", "chars", "colors", "pixel", "message", "glyphs_crop"),
        )
    )

    a2c = ActorCriticRun(
        action_space=stage["action_space"],
        env_name=stage["name"],
        model=agent_2d,
        env=env,
        max_episodes=stage["episodes"],
        must_print=stage["must_print"],
    )
    a2c.run()
    a2c.plot_results()

    torch.save(
        agent_2d.state_dict(),
        "/content/drive/MyDrive/University/Honours/RLWeights/a2c_2d_weights.pth",
    )

In [None]:


# Train a fresh model on each quest level. After the loop, agent_quest holds the
# model trained on the last entry of quest_envs.
for quest_spec in quest_envs:
    print(f"{'-'*5} Starting {quest_spec['name']} {'-'*5} ")
    agent_quest = ActorCriticModel2D(len(QUEST_ACTION_SPACE))

    env = Glyp2DhWrapper(
        gym.make(
            quest_spec["name"],
            actions=QUEST_ACTION_SPACE,
            observation_keys=("glyphs", "chars", "colors", "pixel", "message", "glyphs_crop"),
        )
    )

    a2c = ActorCriticRun(
        action_space=QUEST_ACTION_SPACE,
        env_name=quest_spec["name"],
        model=agent_quest,
        env=env,
        max_episodes=quest_spec["episodes"],
        must_print=quest_spec["must_print"],
    )
    a2c.run()
    a2c.plot_results()

In [None]:
# Play one recorded episode per quest level with agent_quest (the model from the
# last quest-training iteration). videoFrames is reassigned on every iteration,
# so only the last level's frames survive the loop.
# NOTE(review): this cell needs ActorCriticRun to provide run(record=...) and get_frames().
videoFrames = []
for quest_spec in quest_envs:
    print(f"{'-'*5} Starting {quest_spec['name']} {'-'*5} ")

    env = Glyp2DhWrapper(
        gym.make(
            quest_spec["name"],
            actions=QUEST_ACTION_SPACE,
            observation_keys=("glyphs", "chars", "colors", "pixel", "message", "glyphs_crop"),
            reward_manager=reward_manager,
        )
    )

    a2c = ActorCriticRun(
        action_space=quest_spec["action_space"],
        env_name=quest_spec["name"],
        model=agent_quest,
        env=env,
        max_episodes=1,
        must_print=quest_spec["must_print"],
    )
    a2c.run(record=True)
    videoFrames = a2c.get_frames()

In [None]:
# Encode the recorded RGB frames in `videoFrames` as an MP4 file.
output_video_path = 'A2C_Quest_Hard.mp4'
fps = 5  # Frames per second

if len(videoFrames) == 0:
    raise RuntimeError("videoFrames is empty; run the recording cell before writing the video.")

# All frames are assumed to share the first frame's size; VideoWriter takes (width, height).
frame_height, frame_width = videoFrames[0].shape[:2]

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
if not out.isOpened():
    raise RuntimeError(f"Could not open a video writer for {output_video_path}")

try:
    for image in videoFrames:
        # OpenCV expects BGR channel order.
        frame = cv2.cvtColor(np.uint8(image), cv2.COLOR_RGB2BGR)
        out.write(frame)
finally:
    # Always release the writer so the file is finalized even if a frame fails.
    out.release()

print(f'Video saved at {output_video_path}')