In [23]:
import sys
sys.path.append("../../")
sys.path.append("../../models/Memory_RL")

# from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive
from models.Memory_RL.envs.tmaze import TMazeClassicPassive
from models.Memory_RL.policies.models.policy_rnn_dqn import ModelFreeOffPolicy_DQN_RNN
import os 

import numpy as np
import gym
import matplotlib.pyplot as plt
import random
import torch
import yaml
import time
from moviepy.editor import ImageSequenceClip, VideoFileClip


In [24]:
from configs.rl.name_fns import name_fn as name_fn1
from ml_collections import ConfigDict
from typing import Tuple
from torchkit import pytorch_utils as ptu

def dqn_name_fn(
    config: ConfigDict, max_episode_steps: int, max_training_steps: int
) -> Tuple[ConfigDict, str]:
    config, name = name_fn1(config)
    # set eps = 1/T, so that the asymptotic prob to
    # sample fully exploited trajectory during exploration is
    # (1-1/T)^T = 1/e
    config.init_eps = 1.0
    config.end_eps = 1.0 / max_episode_steps
    config.schedule_steps = config.schedule_end * max_training_steps

    return config, name


def get_rl_config():
    config = ConfigDict()
    config.name_fn = dqn_name_fn

    config.algo = "dqn"

    config.critic_lr = 3e-4

    config.config_critic = ConfigDict()
    config.config_critic.hidden_dims = (256, 256)

    config.discount = 0.99
    config.tau = 0.005
    config.schedule_end = 0.1  # at least good for TMaze-like envs

    config.replay_buffer_size = 1e6
    config.replay_buffer_num_episodes = 1e3

    return config

In [25]:
from ml_collections import ConfigDict
from typing import Tuple
from configs.seq_models.name_fns import name_fn


def attn_name_fn(config: ConfigDict, max_episode_steps: int) -> Tuple[ConfigDict, str]:
    config, name = name_fn(config, max_episode_steps)

    config.model.seq_model_config.hidden_size = 0
    if config.model.observ_embedder is not None:
        config.model.seq_model_config.hidden_size += (
            config.model.observ_embedder.hidden_size
        )
    if config.model.action_embedder is not None:
        config.model.seq_model_config.hidden_size += (
            config.model.action_embedder.hidden_size
        )
    if config.model.reward_embedder is not None:
        config.model.seq_model_config.hidden_size += (
            config.model.reward_embedder.hidden_size
        )

    config.model.seq_model_config.max_seq_length = (
        config.sampled_seq_len + 1
    )  # NOTE: zero-prepend

    return config, name


def get_seq_config():
    config = ConfigDict()
    config.name_fn = attn_name_fn

    config.is_markov = False
    config.is_attn = True
    config.use_dropout = True

    config.sampled_seq_len = -1

    config.clip = False
    config.max_norm = 1.0
    config.use_l2_norm = False

    # fed into Module
    config.model = ConfigDict()

    # seq_model_config specific
    config.model.seq_model_config = ConfigDict()
    config.model.seq_model_config.name = "gpt"

    config.model.seq_model_config.hidden_size = (
        128  # NOTE: will be overwritten by name_fn
    )
    config.model.seq_model_config.n_layer = 1
    config.model.seq_model_config.n_head = 1
    config.model.seq_model_config.pdrop = 0.1
    config.model.seq_model_config.position_encoding = "sine"

    # embedders
    config.model.observ_embedder = ConfigDict()
    config.model.observ_embedder.name = "mlp"
    config.model.observ_embedder.hidden_size = 64

    config.model.action_embedder = ConfigDict()
    config.model.action_embedder.name = "mlp"
    config.model.action_embedder.hidden_size = 64

    config.model.reward_embedder = ConfigDict()
    config.model.reward_embedder.name = "mlp"
    config.model.reward_embedder.hidden_size = 0

    return config

In [26]:
from itertools import permutations

def generate_permutations(nums):

    perms = permutations(nums)
    result = [int(''.join(map(str, perm))) for perm in perms]
    
    return result

In [27]:

# AGENT_CLASSES = {
#     "Policy_MLP": Policy_MLP,
#     "Policy_RNN_MLP": Policy_RNN_MLP,
#     "Policy_Separate_RNN": Policy_Separate_RNN,
#     "Policy_Shared_RNN": Policy_Shared_RNN,
#     "Policy_DQN_RNN": Policy_DQN_RNN,
# }
from torchkit.pytorch_utils import set_gpu_mode
set_gpu_mode('cuda', 0)

agent_class = ModelFreeOffPolicy_DQN_RNN
agent_arch = agent_class.ARCH

device = torch.device('cuda:0')
torch.set_default_device(device)

set device: cuda:0


In [77]:
#config_path = '/opt/Memory-RL-Codebase/configs/GTRXL_configs/MinigridMemory/Static/MinigridMemory_SHORT_TERM.yaml'
config_path = '/opt/Memory-RL-Codebase/configs/GTRXL_configs/Passive_T_Maze_Flag/Dense/Passive_T_Maze_Flag_SHORT_TERM.yaml'



episode_timeout = 18
corridor_length = episode_timeout - 2
penalty = -1/(episode_timeout - 1)



with open(config_path, 'r') as file:
    config = yaml.safe_load(file)

env = TMazeClassicPassive(episode_length=episode_timeout, 
                            corridor_length=corridor_length, 
                            goal_reward=1.0,
                            penalty=penalty)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]
 [0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]


In [90]:
max_episode_steps = 18
max_training_steps = 999

config_seq, _ = attn_name_fn(get_seq_config(), max_episode_steps = max_episode_steps)
config_rl, _ = dqn_name_fn(config = get_rl_config(), max_episode_steps =max_episode_steps , max_training_steps =max_training_steps)



In [91]:
image_encoder_fn = lambda: None

obs_dim = env.observation_space.shape[0]
act_dim = 4

freeze_critic = False

agent = agent_class(
    obs_dim=obs_dim,
    action_dim=act_dim,
    config_seq=config_seq,
    config_rl=config_rl,
    image_encoder_fn=image_encoder_fn,
    freeze_critic=freeze_critic,
).to(device)

{'h.0.ln_1.weight': torch.Size([128]), 'h.0.ln_1.bias': torch.Size([128]), 'h.0.attn.c_attn.weight': torch.Size([128, 384]), 'h.0.attn.c_attn.bias': torch.Size([384]), 'h.0.attn.c_proj.weight': torch.Size([128, 128]), 'h.0.attn.c_proj.bias': torch.Size([128]), 'h.0.ln_2.weight': torch.Size([128]), 'h.0.ln_2.bias': torch.Size([128]), 'h.0.mlp.c_fc.weight': torch.Size([128, 512]), 'h.0.mlp.c_fc.bias': torch.Size([512]), 'h.0.mlp.c_proj.weight': torch.Size([512, 128]), 'h.0.mlp.c_proj.bias': torch.Size([128]), 'ln_f.weight': torch.Size([128]), 'ln_f.bias': torch.Size([128])}


In [102]:
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from collections import OrderedDict, namedtuple


class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
    def __repr__(self):
        if not self.missing_keys and not self.unexpected_keys:
            return '<All keys matched successfully>'
        return super().__repr__()

    __str__ = __repr__

def load_state_dict(module, state_dict: Mapping[str, Any],
                    strict: bool = True, assign: bool = False):
    r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

    If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    .. warning::
        If :attr:`assign` is ``True`` the optimizer must be created after
        the call to :attr:`load_state_dict` unless
        :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        assign (bool, optional): When ``False``, the properties of the tensors
            in the current module are preserved while when ``True``, the
            properties of the Tensors in the state dict are preserved. The only
            exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
            for which the value from the module is preserved.
            Default: ``False``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Note:
        If a parameter or buffer is registered as ``None`` and its corresponding key
        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
        ``RuntimeError``.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]


    def load(module, local_state_dict, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)  # noqa: F821

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace."
            )

    load(module, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        print('Error(s) in loading state_dict for {}:\n\t{}'.format(
                        module.__class__.__name__, "\n\t".join(error_msgs)))
        #raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        #                module.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

In [92]:
#ckpt_path = '/opt/Memory-RL-Codebase/models/Memory_RL/logs/Passive_T_Maze_Flag/GPT_2_DQN/LONG_TERM/2024_09_30-19_20_51/best_agent.pt'

ckpt_path = '/opt/Memory-RL-Codebase/models/Memory_RL/logs_2024_09_30_12_00/GPT_2_DQN/2024_09_30-02_17_36/curr_agent.pt'

In [103]:
ckp = torch.load(ckpt_path, map_location=device)
# agent.load_state_dict(ckp, strict = False)

load_state_dict(agent, ckp)


Error(s) in loading state_dict for ModelFreeOffPolicy_DQN_RNN:
	size mismatch for critic.seq_model.transformer.h.0.attn.bias: copying a param with shape torch.Size([1, 1, 18, 18]) from checkpoint, the shape in current model is torch.Size([1, 1, 21, 21]).
	size mismatch for critic_target.seq_model.transformer.h.0.attn.bias: copying a param with shape torch.Size([1, 1, 18, 18]) from checkpoint, the shape in current model is torch.Size([1, 1, 21, 21]).


<All keys matched successfully>

In [116]:
agent.critic.seq_model.transformer.h[0].attn.bias #.shape

tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1,

In [117]:
ckp['critic.seq_model.transformer.h.0.attn.bias'] #.shape

tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
          

In [82]:
from utils import helpers as utl

deterministic = True
eval_episodes = 2

In [83]:
agent = agent.to(device)

In [84]:
nums = [1, 2, 3, 4, 5]
eval_seeds = generate_permutations(nums)

videos_limit = len(eval_seeds) + 1
n_episode = len(eval_seeds)


render = False

total_reward = 0
num_successes = 0
total_steps = 0

In [85]:

agent.eval()  # set to eval mode for deterministic dropout

returns_per_episode = np.zeros(n_episode)
success_rate = np.zeros(n_episode)
# total_steps = np.zeros(n_episode)

for task_idx in range(n_episode):
    step = 0
    running_reward = 0.0
    done_rollout = False

    if eval_seeds is not None and False:
        obs = ptu.from_numpy(env.reset(seed = eval_seeds[task_idx])).to(device)  # reset
    else:
        obs = ptu.from_numpy(env.reset()).to(device)  # reset

    obs = obs.reshape(1, obs.shape[-1])

    # assume initial reward = 0.0
    action, reward, internal_state = agent.get_initial_info(
        config_seq.sampled_seq_len
    )

    while not done_rollout:
        action, internal_state = agent.act(
            prev_internal_state=internal_state,
            prev_action=action.to(device),
            reward=reward.to(device),
            obs=obs.to(device),
            deterministic=deterministic,
        )


        # observe reward and next obs
        next_obs, reward, done, info = utl.env_step(
            env, action.squeeze(dim=0)
        )

        # add raw reward
        running_reward += reward.item()
        step += 1
        done_rollout = False if ptu.get_numpy(done[0][0]) == 0.0 else True

        # set: obs <- next_obs
        obs = next_obs.clone()

    #returns_per_episode[task_idx] = running_reward
    #total_steps[task_idx] = step
    if "success" in info and info["success"] == True:  # keytodoor
        success_rate[task_idx] = 1.0
        num_successes += 1
    
    total_reward += running_reward
    total_steps += step

    curr_seed = eval_seeds[task_idx]
    print(f'Episode: {task_idx}, seed: {curr_seed} Reward: {running_reward}, Steps: {step} Mean reward: {total_reward / (task_idx + 1)}, Mean steps: {total_steps / (task_idx + 1)}')


print(f'Total num episodes: {n_episode} Success rate: {num_successes / n_episode}, Mean reward: {total_reward / n_episode}, Mean steps: {total_steps / n_episode}')


Episode: 0, seed: 12345 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 1, seed: 12354 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 2, seed: 12435 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 3, seed: 12453 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 4, seed: 12534 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 5, seed: 12543 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 6, seed: 13245 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 7, seed: 13254 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.1764705888926983, Mean steps: 18.0
Episode: 8, seed: 13425 Reward: -0.1764705888926983, Steps: 18 Mean reward: -0.176470588