In [6]:
import sys
sys.path.append("../../")
sys.path.append("../../models/DTQN")

from environments.Passive_T_Maze_Flag.env.env_passive_t_maze_flag import TMazeClassicPassive
from models.DTQN.dtqn.agents.dtqn import DtqnAgent
from models.DTQN.utils.agent_utils import get_agent

import numpy as np
import gym
import matplotlib.pyplot as plt
import random
import torch
import yaml

In [7]:
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict, List
from collections import OrderedDict, namedtuple


class _IncompatibleKeys(namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys'])):
    def __repr__(self):
        if not self.missing_keys and not self.unexpected_keys:
            return '<All keys matched successfully>'
        return super().__repr__()

    __str__ = __repr__

def load_state_dict(module, state_dict: Mapping[str, Any],
                    strict: bool = True, assign: bool = False):
    r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.

    If :attr:`strict` is ``True``, then
    the keys of :attr:`state_dict` must exactly match the keys returned
    by this module's :meth:`~torch.nn.Module.state_dict` function.

    .. warning::
        If :attr:`assign` is ``True`` the optimizer must be created after
        the call to :attr:`load_state_dict` unless
        :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        strict (bool, optional): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
        assign (bool, optional): When ``False``, the properties of the tensors
            in the current module are preserved while when ``True``, the
            properties of the Tensors in the state dict are preserved. The only
            exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
            for which the value from the module is preserved.
            Default: ``False``

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Note:
        If a parameter or buffer is registered as ``None`` and its corresponding key
        exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
        ``RuntimeError``.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")

    missing_keys: List[str] = []
    unexpected_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        # mypy isn't aware that "_metadata" exists in state_dict
        state_dict._metadata = metadata  # type: ignore[attr-defined]


    def load(module, local_state_dict, prefix=''):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        if assign:
            local_metadata['assign_to_params_buffers'] = assign
        module._load_from_state_dict(
            local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                child_prefix = prefix + name + '.'
                child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
                load(child, child_state_dict, child_prefix)  # noqa: F821

        # Note that the hook can modify missing_keys and unexpected_keys.
        incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
        for hook in module._load_state_dict_post_hooks.values():
            out = hook(module, incompatible_keys)
            assert out is None, (
                "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                "expected to return new values, if incompatible_keys need to be modified,"
                "it should be done inplace."
            )

    load(module, state_dict)
    del load

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join(f'"{k}"' for k in missing_keys)))

    if len(error_msgs) > 0:
        print('Error(s) in loading state_dict for {}:\n\t{}'.format(
                        module.__class__.__name__, "\n\t".join(error_msgs)))
        #raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
        #                module.__class__.__name__, "\n\t".join(error_msgs)))
    return _IncompatibleKeys(missing_keys, unexpected_keys)

from itertools import permutations

def generate_permutations(nums):

    perms = permutations(nums)
    result = [int(''.join(map(str, perm))) for perm in perms]
    
    return result


In [9]:
config_path = '/opt/Memory-RL-Codebase/configs/DTQN_configs/Passive_T_Maze_Flag/Dense/ICLR_exp_1/Passive_T_Maze_Flag_SHORT_TERM.yaml'



episode_timeout = 20
corridor_length = episode_timeout - 2
penalty = -1/(episode_timeout - 1)

device = torch.device('cuda:0')


with open(config_path, 'r') as file:
    args = yaml.safe_load(file)

env = TMazeClassicPassive(episode_length=episode_timeout, 
                            corridor_length=corridor_length, 
                            goal_reward=1.0,
                            penalty=penalty)


# Checkpoint 1

In [10]:
checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints/Passive_T_Maze_Flag/DTQN/DTQN_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_28-23_43_40.pt'


args['inembed'] = 64
args['context'] = episode_timeout

# Checkpoint 2

In [28]:
# args['pos'] = 'sin'

In [11]:
agent = get_agent(
        args['model'],
        env,
        env,
        args['obsembed'],
        args['inembed'],
        args['buf_size'],
        device,
        args['lr'],
        args['batch'],
        args['context'],
        args['history'],
        args['num_steps'],
        # DTQN specific
        args['heads'],
        args['layers'],
        args['dropout'],
        args['identity'],
        args['gate'],
        args['pos'],
    )

MultiDiscrete([3 3 2 3], start=[-1 -1  0 -1]) Discrete(4)
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
LayerNorm weigths init!
MultiDiscrete([3 3 2 3], start=[-1 -1  0 -1]) Discrete(4)


In [12]:
# att_mask = agent.policy_network.transformer_layers[1].attn_mask

In [36]:
agent.policy_network.state_dict().keys()

odict_keys(['position_embedding', 'obs_embedding.embedding.weight', 'obs_embedding.embedding.bias', 'transformer_layers.0.attn_mask', 'transformer_layers.0.layernorm1.weight', 'transformer_layers.0.layernorm1.bias', 'transformer_layers.0.layernorm2.weight', 'transformer_layers.0.layernorm2.bias', 'transformer_layers.0.attention.in_proj_weight', 'transformer_layers.0.attention.in_proj_bias', 'transformer_layers.0.attention.out_proj.weight', 'transformer_layers.0.attention.out_proj.bias', 'transformer_layers.0.ffn.0.weight', 'transformer_layers.0.ffn.0.bias', 'transformer_layers.0.ffn.2.weight', 'transformer_layers.0.ffn.2.bias', 'transformer_layers.1.attn_mask', 'transformer_layers.1.layernorm1.weight', 'transformer_layers.1.layernorm1.bias', 'transformer_layers.1.layernorm2.weight', 'transformer_layers.1.layernorm2.bias', 'transformer_layers.1.attention.in_proj_weight', 'transformer_layers.1.attention.in_proj_bias', 'transformer_layers.1.attention.out_proj.weight', 'transformer_layers.

In [49]:
ckp  = torch.load(checkpoint_path)

In [51]:
ckp['policy_net_state_dict'].keys()

odict_keys(['position_embedding', 'obs_embedding.embedding.weight', 'obs_embedding.embedding.bias', 'transformer_layers.0.attn_mask', 'transformer_layers.0.layernorm1.weight', 'transformer_layers.0.layernorm1.bias', 'transformer_layers.0.layernorm2.weight', 'transformer_layers.0.layernorm2.bias', 'transformer_layers.0.attention.in_proj_weight', 'transformer_layers.0.attention.in_proj_bias', 'transformer_layers.0.attention.out_proj.weight', 'transformer_layers.0.attention.out_proj.bias', 'transformer_layers.0.ffn.0.weight', 'transformer_layers.0.ffn.0.bias', 'transformer_layers.0.ffn.2.weight', 'transformer_layers.0.ffn.2.bias', 'transformer_layers.1.attn_mask', 'transformer_layers.1.layernorm1.weight', 'transformer_layers.1.layernorm1.bias', 'transformer_layers.1.layernorm2.weight', 'transformer_layers.1.layernorm2.bias', 'transformer_layers.1.attention.in_proj_weight', 'transformer_layers.1.attention.in_proj_bias', 'transformer_layers.1.attention.out_proj.weight', 'transformer_layers.

In [28]:
ckp['policy_net_state_dict']['position_embedding'].shape

torch.Size([1, 15, 64])

In [25]:
ckp['policy_net_state_dict']['position_embedding'][:, :12, :].shape

torch.Size([1, 12, 64])

In [85]:
(ckp['policy_net_state_dict']['obs_embedding.embedding.weight'] != agent.policy_network.obs_embedding.embedding.weight).sum()

tensor(0, device='cuda:0')

In [31]:
agent.policy_network.position_embedding.shape

torch.Size([1, 12, 64])

In [64]:
agent.policy_network.position_embedding = torch.nn.Parameter(ckp['policy_net_state_dict']['position_embedding'][:, :13, :])

In [12]:
agent.load_checkpoint(checkpoint_path)

'tensorboard'

In [13]:
load_state_dict(agent.policy_network, ckp['policy_net_state_dict'])
load_state_dict(agent.target_network, ckp['target_net_state_dict'])


Error(s) in loading state_dict for DTQN:
	size mismatch for position_embedding: copying a param with shape torch.Size([1, 15, 64]) from checkpoint, the shape in current model is torch.Size([1, 20, 64]).
	size mismatch for transformer_layers.0.attn_mask: copying a param with shape torch.Size([15, 15]) from checkpoint, the shape in current model is torch.Size([20, 20]).
	size mismatch for transformer_layers.1.attn_mask: copying a param with shape torch.Size([15, 15]) from checkpoint, the shape in current model is torch.Size([20, 20]).
	size mismatch for transformer_layers.2.attn_mask: copying a param with shape torch.Size([15, 15]) from checkpoint, the shape in current model is torch.Size([20, 20]).
	size mismatch for transformer_layers.3.attn_mask: copying a param with shape torch.Size([15, 15]) from checkpoint, the shape in current model is torch.Size([20, 20]).
	size mismatch for transformer_layers.4.attn_mask: copying a param with shape torch.Size([15, 15]) from checkpoint, the shape

<All keys matched successfully>

In [86]:
agent.policy_network.transformer_layers[1].attn_mask.size()

torch.Size([12, 12])

In [14]:

nums = [1, 2, 3, 4, 5]
eval_seeds = generate_permutations(nums)

videos_limit = len(eval_seeds) + 1
n_episode = len(eval_seeds)

videos_dir = '/opt/Memory-RL-Codebase/eval/Minigrid_Memory/DTQN'

run_name = checkpoint_path.split('/')[-1].strip('.pt')
run_type = checkpoint_path.split('/')[-2]

In [15]:
agent.evaluate(n_episode = n_episode, eval_seeds = eval_seeds, render = False, videos_limit = videos_limit, videos_dir = videos_dir, run_name = run_name, run_type =run_type )

Evaluate!
Episode: 0, seed: 12345 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 1, seed: 12354 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 2, seed: 12435 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 3, seed: 12453 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 4, seed: 12534 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 5, seed: 12543 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 6, seed: 13245 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 7, seed: 13254 Reward: -0.05263157894736842, Steps: 20 Mean reward: -0.05263157894736842, Mean steps: 20.0
Episode: 8, seed: 13425 Reward: 1.0, Steps: 19 Mean reward: 0.

(0.2, 0.14473684210526327, 19.825)

In [None]:
agent.evaluate(n_episode = n_episode, eval_seeds = eval_seeds, render = False, videos_limit = videos_limit, videos_dir = videos_dir, run_name = run_name, run_type =run_type )

In [7]:
agent.evaluate(n_episode = n_episode, eval_seeds = eval_seeds, render = False, videos_limit = videos_limit, videos_dir = videos_dir, run_name = run_name, run_type =run_type )

Evaluate!
Episode: 0, seed: 12345 Reward: -0.9090909090909093, Steps: 12 Mean reward: -0.9090909090909093, Mean steps: 12.0
Episode: 1, seed: 12354 Reward: -1.0000000000000002, Steps: 12 Mean reward: -0.9545454545454548, Mean steps: 12.0
Episode: 2, seed: 12435 Reward: -1.0000000000000002, Steps: 12 Mean reward: -0.9696969696969701, Mean steps: 12.0
Episode: 3, seed: 12453 Reward: -1.0000000000000002, Steps: 12 Mean reward: -0.9772727272727275, Mean steps: 12.0
Episode: 4, seed: 12534 Reward: -1.0000000000000002, Steps: 12 Mean reward: -0.981818181818182, Mean steps: 12.0
Episode: 5, seed: 12543 Reward: -0.9090909090909093, Steps: 12 Mean reward: -0.9696969696969698, Mean steps: 12.0
Episode: 6, seed: 13245 Reward: -0.9090909090909093, Steps: 12 Mean reward: -0.9610389610389612, Mean steps: 12.0
Episode: 7, seed: 13254 Reward: -1.0000000000000002, Steps: 12 Mean reward: -0.965909090909091, Mean steps: 12.0
Episode: 8, seed: 13425 Reward: -0.9090909090909093, Steps: 12 Mean reward: -0.9

(0.0, -0.9462121212121203, 12.0)

In [5]:
config_path = '/opt/Memory-RL-Codebase/configs/DTQN_configs/Passive_T_Maze_Flag/Dense/ICLR_exp_1/Passive_T_Maze_Flag_SHORT_TERM.yaml'
testing_res = []

for episode_timeout in [15, 20, 30]:

    episode_timeout = 40
    corridor_length = episode_timeout - 2
    penalty = -1/(episode_timeout - 1)

    device = torch.device('cuda:0')


    with open(config_path, 'r') as file:
        args = yaml.safe_load(file)

    env = TMazeClassicPassive(episode_length=episode_timeout, 
                                corridor_length=corridor_length, 
                                goal_reward=1.0,
                                penalty=penalty)

    checkpoint_path = '/opt/Memory-RL-Codebase/autorun/checkpoints/Passive_T_Maze_Flag/DTQN/DTQN_Passive_T_Maze_Flag_SHORT_TERM_dense/2024_09_28-23_43_40.pt'


    args['inembed'] = 64
    args['context'] = episode_timeout


    agent = get_agent(
        args['model'],
        env,
        env,
        args['obsembed'],
        args['inembed'],
        args['buf_size'],
        device,
        args['lr'],
        args['batch'],
        args['context'],
        args['history'],
        args['num_steps'],
        # DTQN specific
        args['heads'],
        args['layers'],
        args['dropout'],
        args['identity'],
        args['gate'],
        args['pos'],
    )
    

    ckp  = torch.load(checkpoint_path)
    load_state_dict(agent.policy_network, ckp['policy_net_state_dict'])
    load_state_dict(agent.target_network, ckp['target_net_state_dict'])

    nums = [1, 2, 3, 4, 5]

    eval_seeds = generate_permutations(nums)

    videos_limit = len(eval_seeds) + 1
    n_episode = len(eval_seeds)

    videos_dir = '/opt/Memory-RL-Codebase/eval/Minigrid_Memory/DTQN'

    run_name = checkpoint_path.split('/')[-1].strip('.pt')
    run_type = checkpoint_path.split('/')[-2]

    sr, rw, sr = agent.evaluate(n_episode = n_episode, eval_seeds = eval_seeds, render = False, videos_limit = videos_limit, videos_dir = videos_dir, run_name = run_name, run_type =run_type )

    testing_res.append((episode_timeout, sr, rw, sr))

    print(f'res: {(episode_timeout, sr, rw, sr)}')

Episode: 8, seed: 13425 Reward: 1.0, Steps: 39 Mean reward: 0.31339031339031337, Mean steps: 39.77777777777778
Episode: 9, seed: 13452 Reward: -0.02564102564102564, Steps: 40 Mean reward: 0.2794871794871795, Mean steps: 39.8
Episode: 10, seed: 13524 Reward: 1.0, Steps: 39 Mean reward: 0.34498834498834496, Mean steps: 39.72727272727273


KeyboardInterrupt: 