In [2]:
import sys
sys.path.append("../../")
sys.path.append("../../models/episodic_transformer_memory_ppo")

from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive
from models.episodic_transformer_memory_ppo.model import ActorCriticModel
import os 

import numpy as np
import gym
import matplotlib.pyplot as plt
import random
import torch
import yaml
import time
from moviepy.editor import ImageSequenceClip, VideoFileClip


In [3]:
def init_transformer_memory(trxl_conf, max_episode_steps, device):
    """Returns initial tensors for the episodic memory of the transformer.

    Arguments:
        trxl_conf {dict} -- Transformer configuration dictionary; the keys
            "memory_length", "num_blocks" and "embed_dim" are read
        max_episode_steps {int} -- Maximum number of steps per episode
        device {torch.device} -- Target device for all returned tensors

    Returns:
        memory {torch.Tensor} -- zero-initialised episodic memory of shape
            (1, max_episode_steps, num_blocks, embed_dim)
        memory_mask {torch.Tensor} -- (memory_length, memory_length) float mask;
            row t has ones in columns 0..t-1 (strictly lower triangle)
        memory_indices {torch.Tensor} -- (max_episode_steps, memory_length) long
            tensor; row t holds the memory slots inside the sliding window at step t

    Raises:
        ValueError -- if max_episode_steps < memory_length, because the sliding
            window cannot be built in that case
    """
    memory_length = trxl_conf["memory_length"]
    if max_episode_steps < memory_length:
        raise ValueError(
            f"max_episode_steps ({max_episode_steps}) must be >= memory_length ({memory_length})"
        )

    # Episodic memory mask used in attention: strictly lower-triangular, so at
    # window position t only the t earlier slots are visible.
    memory_mask = torch.tril(torch.ones((memory_length, memory_length)), diagonal=-1).to(device)
    # Episodic memory tensor
    memory = torch.zeros((1, max_episode_steps, trxl_conf["num_blocks"], trxl_conf["embed_dim"])).to(device)

    # Sliding memory window indices. The first memory_length - 1 steps all use
    # the window [0, memory_length); after that the window advances by one slot per step.
    repetitions = torch.repeat_interleave(torch.arange(0, memory_length).unsqueeze(0), memory_length - 1, dim=0).long()
    num_windows = max_episode_steps - memory_length + 1
    # Row i of `windows` is arange(i, i + memory_length), built by broadcasting.
    windows = (torch.arange(num_windows).unsqueeze(1) + torch.arange(memory_length).unsqueeze(0)).long()
    memory_indices = torch.cat((repetitions, windows)).to(device)
    return memory, memory_mask, memory_indices

# Train on L = 15, test on L = 20

In [4]:
config_path = '/opt/Memory-RL-Codebase/configs/GTRXL_configs/Passive_T_Maze_Flag/Dense/Passive_T_Maze_Flag_SHORT_TERM.yaml'

# Episode length L; the corridor length is derived as L - 2.
episode_timeout = 15
corridor_length = episode_timeout - 2
# Penalty handed to the environment, scaled by the episode length
# (presumably applied per step — confirm against TMazeClassicPassive).
penalty = -1/(episode_timeout - 1)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

env = TMazeClassicPassive(episode_length=episode_timeout,
                            corridor_length=corridor_length,
                            goal_reward=1.0,
                            penalty=penalty)


# Checkpoint 1

In [5]:
# ckp 1
# NOTE: the three checkpoint cells each overwrite `checkpoint_path`, `checkpoint`
# and `config`; run exactly one of them before building the agent.

checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints_2024_09_29_19_00/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_29-12_25_14.pt'
# map_location remaps tensors saved on another device (e.g. another GPU) onto
# `device`, so loading also works on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location=device)

config['transformer']['memory_length'] = corridor_length

# Architecture overrides; presumably these must match the trained checkpoint,
# otherwise the state dict will not load into the model.
config['transformer']['num_blocks'] = 3
config['transformer']['embed_dim'] = 64
config['transformer']['num_heads'] = 4
config['hidden_layer_size'] = 64


# Checkpoint 2

In [46]:
# ckp 2
# NOTE: the three checkpoint cells each overwrite `checkpoint_path`, `checkpoint`
# and `config`; run exactly one of them before building the agent.

checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints_2024_09_29_22_00/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_28-23_36_02.pt'
# map_location remaps tensors saved on another device (e.g. another GPU) onto
# `device`, so loading also works on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location=device)

config['transformer']['memory_length'] = corridor_length

# Architecture overrides; presumably these must match the trained checkpoint,
# otherwise the state dict will not load into the model.
config['transformer']['num_blocks'] = 6
config['transformer']['embed_dim'] = 128
config['transformer']['num_heads'] = 8
config['hidden_layer_size'] = 128


# Checkpoint 3

In [33]:
# ckp 3
# NOTE: the three checkpoint cells each overwrite `checkpoint_path`, `checkpoint`
# and `config`; run exactly one of them before building the agent.

checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints/Passive_T_Maze_Flag/GTXL/GTXL_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_29-12_23_46.pt'
# map_location remaps tensors saved on another device (e.g. another GPU) onto
# `device`, so loading also works on CPU-only machines.
checkpoint = torch.load(checkpoint_path, map_location=device)

config['transformer']['memory_length'] = corridor_length

# Architecture overrides; presumably these must match the trained checkpoint,
# otherwise the state dict will not load into the model.
config['transformer']['num_blocks'] = 6
config['transformer']['embed_dim'] = 128
config['transformer']['num_heads'] = 8
config['hidden_layer_size'] = 128

In [6]:
# Build the policy from `config` (architecture must match the loaded checkpoint),
# move it to `device`, load the trained weights and switch to inference mode.
agent = ActorCriticModel(config, env.observation_space, (env.action_space.n,), env.max_episode_steps).to(device)
agent.load_state_dict(checkpoint["model_state_dict"])
agent.eval()
# From here on, tensors created without an explicit device are allocated on `device`.
torch.set_default_device(device)

In [7]:
from itertools import permutations

def generate_permutations(nums):
    """Concatenate every ordering of ``nums`` into one integer.

    Orderings are produced in ``itertools.permutations`` order, e.g.
    ``[1, 2, 3] -> [123, 132, 213, 231, 312, 321]``.

    Caveats: a leading 0 is dropped by ``int`` and multi-digit entries can
    collide (``[1, 23]`` and ``[12, 3]`` both give 123).
    """
    return [
        int("".join(str(digit) for digit in ordering))
        for ordering in permutations(nums)
    ]



In [8]:

### evaluate !

# NOTE(review): this directory is named after Minigrid_Memory although the
# environment evaluated here is Passive_T_Maze_Flag — confirm the intended path.
videos_dir = '/opt/Memory-RL-Codebase/eval/Minigrid_Memory/GTRXL'

nums = [1, 2, 3, 4, 5]
# Every ordering of 1..5 read as an integer -> 120 evaluation seeds.
eval_seeds = generate_permutations(nums)

videos_limit = len(eval_seeds) + 1
n_episode = len(eval_seeds)


render = False

total_reward = 0
num_successes = 0
total_steps = 0

memory_length = config["transformer"]["memory_length"]

for i in range(n_episode):

    if render:
        frames = []

    done = False
    memory, memory_mask, memory_indices = init_transformer_memory(config["transformer"], env.max_episode_steps, device)

    memory = memory.to(device)
    memory_mask = memory_mask.to(device)
    memory_indices = memory_indices.to(device)

    t = 0
    ep_reward = 0

    if eval_seeds is not None:
        obs = env.reset(eval_seeds[i])
    else:
        obs = env.reset()

    if render and i < videos_limit:
        frame = env.render()
        time.sleep(0.5)
        frames.append(frame)

    while not done:
        # Prepare observation and memory (all tensors below already live on `device`)
        obs = torch.tensor(np.expand_dims(obs, 0), dtype=torch.float32, device=device)
        in_memory = memory[0, memory_indices[t].unsqueeze(0)]
        t_ = max(0, min(t, memory_length - 1))
        mask = memory_mask[t_].unsqueeze(0)
        indices = memory_indices[t].unsqueeze(0)
        # Forward model without autograd: evaluation never backpropagates, and
        # copying grad-tracking outputs into `memory` would otherwise chain every
        # step of the episode into one ever-growing graph.
        with torch.no_grad():
            policy, value, new_memory = agent(obs, in_memory, mask, indices)
        memory[:, t] = new_memory
        # Sample one action per action branch
        action = [action_branch.sample().item() for action_branch in policy]
        # Step environment
        obs, reward, done, info = env.step(action)
        # Accumulate before rendering so the termination message below
        # includes the reward of the final step.
        ep_reward += reward
        t += 1
        if render and i < videos_limit:
            frame = env.render()
            if done:
                print(f"Episode terminated. Episode reward: {ep_reward}")
            time.sleep(0.5)
            frames.append(frame)

    if info.get("is_success"):
        num_successes += 1
    total_reward += ep_reward
    total_steps += t

    if render and i < videos_limit:
        desired_resolution = (945, 540)
        original_aspect_ratio = 112 / 64
        width = int(desired_resolution[0] * original_aspect_ratio)
        height = desired_resolution[1]

        observations = [np.squeeze(o) for o in frames]

        clip = ImageSequenceClip(observations, fps=2)
        # NOTE(review): in moviepy 1.x `resize` uses `height` and ignores `width`
        # when both are given — confirm the intended output size.
        clip = clip.resize(width=width, height=height)

        # File name without the ".pt" extension (str.strip('.pt') would strip
        # any leading/trailing '.', 'p', 't' characters, not the suffix).
        run_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
        run_type = checkpoint_path.split('/')[-2]
        curr_seed = eval_seeds[i]
        curr_reward = float(info['reward'])

        run_dir = videos_dir + f"/{run_type}/{run_name}"
        os.makedirs(run_dir, exist_ok=True)

        clip.write_videofile(run_dir + f"/{run_name}_seed={curr_seed}_reward={curr_reward:0.2}.mp4", fps=2)

    curr_seed = eval_seeds[i]
    print(f'Episode: {i}, seed: {curr_seed} Reward: {ep_reward}, Steps: {t} Mean reward: {total_reward / (i + 1)}, Mean steps: {total_steps / (i + 1)}')


print(f'Total num episodes: {n_episode} Success rate: {num_successes / n_episode}, Mean reward: {total_reward / n_episode}, Mean steps: {total_steps / n_episode}')
