In [1]:
# Notebook support or argpase
import sys; sys.argv=['']; del sys

In [7]:
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import os
import apex
import copy
import random
import numpy as np
import matplotlib as mpl
import compress_pickle as cpkl

import rsatoolbox
from torchinfo import summary

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
# plt.style.use("seaborn-darkgrid")

import models
from models import GW_Actor, GRU_Actor

from mpl_toolkits.axes_grid1 import make_axes_locatable
%matplotlib inline

from ss_baselines.common.utils import plot_top_down_map


mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["axes.facecolor"] = "white"
mpl.rcParams["savefig.facecolor"] = "white" 

In [8]:
# General config related
from configurator import get_arg_dict, generate_args

# Env config related
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class

# region: Generating additional hyparams
CUSTOM_ARGS = [
	# General hyper parameters
	get_arg_dict("seed", int, 111),
	get_arg_dict("total-steps", int, 1_000_000),

	# Behavior cloning gexperiment config
	get_arg_dict("dataset-path", str, "SAVI_Oracle_Dataset_v0"),

	# SS env config
	get_arg_dict("config-path", str, "env_configs/savi/savi_ss1_rgb_spectro.yaml"),

	# PPO Hyper parameters
	get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
	get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
	get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
	get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
	get_arg_dict("gamma", float, 0.99),
	get_arg_dict("gae-lambda", float, 0.95),
	get_arg_dict("norm-adv", bool, True, metatype="bool"),
	get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
	get_arg_dict("clip-vloss", bool, True, metatype="bool"),
	get_arg_dict("ent-coef", float, 0.2), # Entropy loss coef; 0.2 in SS baselines
	get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
	get_arg_dict("max-grad-norm", float, 0.5),
	get_arg_dict("target-kl", float, None),
	get_arg_dict("lr", float, 2.5e-4), # Learning rate
	get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
	## Agent network params
	get_arg_dict("agent-type", str, "gw", metatype="choice",
		choices=["gw", "gru"]),
	get_arg_dict("gru-type", str, "layernorm", metatype="choice",
					choices=["default", "layernorm"]),
	get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features

	## BC related hyper parameters
	get_arg_dict("batch-chunk-length", int, 0), # For gradient accumulation
	get_arg_dict("dataset-ce-weights", bool, False, metatype="bool"), # If True, will read CEL weights based on action dist. from the 'dataset_statistics.bz2' file.
	get_arg_dict("ce-weights", float, None, metatype="list"), # Weights for the Cross Entropy loss

	## GW Agent with custom attention, recurrent encoder and null inputs
	get_arg_dict("gw-size", int, 512), # Dim of the GW vector
	get_arg_dict("recenc-use-gw", bool, True, metatype="bool"), # Use GW at Recur. Enc. level
	get_arg_dict("recenc-gw-detach", bool, True, metatype="bool"), # When using GW at Recurrent Encoder level, whether to detach the grads or not
	get_arg_dict("gw-use-null", bool, True, metatype="bool"), # Use Null at CrossAtt level
	get_arg_dict("gw-cross-heads", int, 1), # num_heads of the CrossAttn

	# Eval protocol
	get_arg_dict("eval", bool, True, metatype="bool"),
	get_arg_dict("eval-every", int, int(1.5e4)), # Every X frames || steps sampled
	get_arg_dict("eval-n-episodes", int, 5),

	# Logging params
	# NOTE: Video logging expensive
	get_arg_dict("save-videos", bool, False, metatype="bool"),
	get_arg_dict("save-model", bool, True, metatype="bool"),
	get_arg_dict("save-model-every", int, int(5e5)), # Every X frames || steps sampled
	get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
	get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
	get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)
# endregion: Generating additional hyparams

# Load environment config
is_SAVi = str.__contains__(args.config_path, "savi")
if is_SAVi:
	env_config = get_savi_config(config_paths=args.config_path)
else:
	env_config = get_config(config_paths=args.config_path)

## Instantiate obs / act space based on args and env_config

In [9]:
# Overriding some envs parametes from the .yaml env config
env_config.defrost()
env_config.NUM_PROCESSES = 1 # Corresponds to number of envs, makes script startup faster for debugs
env_config.USE_SYNC_VECENV = True
# env_config.USE_VECENV = False
# env_config.CONTINUOUS = args.env_continuous
## In caes video saving is enabled, make sure there is also the rgb videos
env_config.freeze()
# print(env_config)

# Environment instantiation
# envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
# Dummy environment spaces

# TODO: add dyanmicallly set single_observation_space so that RGB and RGBD based variants
# can be evaluated at thet same time
from gym import spaces
single_action_space = spaces.Discrete(4)
single_observation_space = spaces.Dict({
    "rgb": spaces.Box(shape=[128,128,3], low=0, high=255, dtype=np.uint8),
    # "depth": spaces.Box(shape=[128,128,1], low=0, high=255, dtype=np.uint8),
    "audiogoal": spaces.Box(shape=[2,16000], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32),
    "spectrogram": spaces.Box(shape=[65,26,2], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32)
})
# single_observation_space = envs.observation_spaces[0]
# single_action_space = envs.action_spaces[0]

single_observation_space, single_action_space

(Dict(audiogoal:Box(-3.4028235e+38, 3.4028235e+38, (2, 16000), float32), rgb:Box(0, 255, (128, 128, 3), uint8), spectrogram:Box(-3.4028235e+38, 3.4028235e+38, (65, 26, 2), float32)),
 Discrete(4))

# Loading the Category-Scene-Trajs, Scene-Category-Trajs, and Dataset's metadata

### Loads data for analysis, as well as dataset's metadata

In [5]:
# Specify file name
analysis_trajs_filename = "cats_scenes_trajs_C_6_M_5_N_5__2023_06_01_10_41.bz2"

# Read the filtred trajectories data
## Default format is {cat -> { scenes -> traj: []}}
with open(analysis_trajs_filename, "rb") as f:
    cats_scenes_trajs_dict = cpkl.load(f)

## Compute the equivalent scenes cat trajs format
## {scenes -> { cat -> trajs: []}}
scenes_cats_trajs_dict = {}
for cat, cat_scenes_trajs in cats_scenes_trajs_dict.items():
    for scene, scenes_trajs in cat_scenes_trajs.items():
        if scene not in scenes_cats_trajs_dict.keys():
            scenes_cats_trajs_dict[scene] = {}
        
        scenes_cats_trajs_dict[scene][cat] = scenes_trajs

# Generic: load the dataset statistics
## Compute action coefficient for CEL of BC
dataset_stats_filepath = f"{args.dataset_path}/dataset_statistics.bz2"
# Override dataset statistics if the file already exists
if os.path.exists(dataset_stats_filepath):
    with open(dataset_stats_filepath, "rb") as f:
        dataset_statistics = cpkl.load(f)

# Extract some global metadata
# TARGET_SCENE_LIST = list(cats_scenes_trajs_dict[list(cats_scenes_trajs_dict.keys())[0]].keys())
TARGET_SCENE_LIST = list(dataset_statistics["scene_counts"].keys())
TARGET_SCENE_DICT = {scene: i for i, scene in enumerate(TARGET_SCENE_LIST)}
TARGET_CATEGORY_LIST = list(cats_scenes_trajs_dict.keys())
TARGET_CATEGORY_DICT = {cat: i for i, cat in enumerate(TARGET_CATEGORY_LIST)}

from soundspaces.mp3d_utils import CATEGORY_INDEX_MAPPING
def get_category_name(idx):
    assert idx >= 0 and idx <=20, f"Invalid category index number: {idx}"

    for catname, catidx in CATEGORY_INDEX_MAPPING.items():
        if catidx == idx:
            return catname

def get_sceneid_by_idx(scene_idx):
    for k, v in TARGET_SCENE_DICT.items():
        if v == scene_idx:
            return k

C = len(TARGET_CATEGORY_LIST) # C: total number of categories
M = len(TARGET_SCENE_LIST) # M: total number of rooms, assuming all categories has N trajs for a same set of scenes.

print(f"# of categories C: {C} | # of scenes: {M}")
print(f"TARGET_CATEGORY_DICT: {TARGET_CATEGORY_DICT}")
print(f"TARGET_SCENE_DICT: {TARGET_SCENE_DICT}")
print("")

# for catname, cat_scenes_trajs in cats_scenes_trajs_dict.items():
#     print(f"Cat: {catname}; Scenes: {[k for k in cat_scenes_trajs.keys()]}")

# Basic check of the scene -> categories fileted trajectories
# for scene, scenes_cat_trajs in scenes_cats_trajs_dict.items():
#     print(f"Scene: {scene}; Cats: {[k for k in scenes_cat_trajs.keys()]}")

# More detailed breakdown of the trajectories per categories then scenes
for catname, cat_scenes_trajs in cats_scenes_trajs_dict.items():
    print(f"{catname}:")
    for scene, scene_trajs in cat_scenes_trajs.items():
        traj_lengths = [len(traj_data["edd"]["done_list"]) for traj_data in scene_trajs]
        print(f"\t{scene}: {traj_lengths}")
    print("")

# More detailed breakdown of the trajectories per categories then scenes
for scene, scene_cats_trajs in scenes_cats_trajs_dict.items():
    print(f"{scene}")
    for cat, cat_trajs in scene_cats_trajs.items():
        traj_lengths = [len(traj_data["edd"]["done_list"]) for traj_data in cat_trajs]
        print(f"\t{cat}: {traj_lengths}")
    print("")

# of categories C: 6 | # of scenes: 56
TARGET_CATEGORY_DICT: {'chair': 0, 'picture': 1, 'table': 2, 'cushion': 3, 'cabinet': 4, 'plant': 5}
TARGET_SCENE_DICT: {'gTV8FGcVJC9': 0, '5LpN3gDmAk7': 1, 'vyrNrziPKCB': 2, 'b8cTxDM8gDG': 3, 'Vvot9Ly1tCj': 4, 'rPc6DW4iMge': 5, 'PuKPg4mmafe': 6, '759xd9YjKW5': 7, 'ZMojNkEp431': 8, 'VzqfbhrpDEA': 9, 'ac26ZMwG7aT': 10, 'D7N2EKCX4Sj': 11, 'E9uDoFAP3SH': 12, 'S9hNv5qa7GM': 13, '5q7pvUzZiYa': 14, 'kEZ7cmS4wCh': 15, 'VFuaQ6m2Qom': 16, '7y3sRwLe3Va': 17, 'p5wJjkQkbXX': 18, 'V2XKFyX4ASd': 19, 'VVfe2KiqLaN': 20, 'mJXqzFtmKg4': 21, 'SN83YJsR3w2': 22, 'EDJbREhghzL': 23, 'PX4nDJXEHrG': 24, 'JmbYfDe2QKZ': 25, 'r1Q1Z4BcV1o': 26, 'aayBHfsNo7d': 27, 'r47D5H71a5s': 28, 'pRbA3pwrgk9': 29, 'Pm6F8kyY3z2': 30, 'sKLMLpTHeUy': 31, 'GdvgFV5R1Z5': 32, 'e9zR4mvMWw7': 33, 'JeFG25nYj2p': 34, 'B6ByNegPMKs': 35, 'uNb9QFRL6hY': 36, 'cV4RVeZvu5T': 37, 'D7G3Y4RVNrH': 38, 'XcA2TqTSSAj': 39, 'ur6pFq6Qu1A': 40, '29hnd4uzFmX': 41, 's8pcmisQ38h': 42, 'qoiz87JEwZ2': 43, 'ULsKaCPVFJR':

## Helpers to extract traj. data based on "category", "scene", etc...

In [10]:
# region: Categories -> Scenes
## cats_scenes_trajs_dict: dictionary structured as: {category: {scene: [traj_data]}}
# TODO: add support for the device in case tensors are returned
def get_traj_data_by_category_scene_trajIdx(trajs_dicts, category, scene, trajIdx=0, tensorize=False, device="cpu"):
    # Get a single trajectory specified by idx, for a specificed category and scene
    # TODO: maybe fix the "depth" dimension here directly ?
    obs_list_dict = trajs_dicts[category][scene][trajIdx]["edd"]["obs_list"]
    done_list = trajs_dicts[category][scene][trajIdx]["edd"]["done_list"]

    obs_dict_list = []
    target_scene_idx_list, target_category_idx_list = [], []

    T = len(obs_list_dict["rgb"])
    for t in range(T):
        obs_dict_list.append({k: v[t] for k, v in obs_list_dict.items()})
        target_scene_idx_list.append(TARGET_SCENE_DICT[scene])
        target_category_idx_list.append(CATEGORY_INDEX_MAPPING[category])

    # Tensorize if required
    if tensorize:
        done_list__th = []
        obs_dict_list__th = []

        for t, (obs_dict, done) in enumerate(zip(obs_dict_list, done_list)):
            # done_list__th.append(th.Tensor(np.array([done])[None, :]))
            done_list__th.append(th.Tensor(np.array([done])).to(device)) # TODO: make sure that the deprecation warning stops showing up. Or always stay on current Torch version.
            tmp_dict = {}
            for k, v in obs_dict.items():
                if k == "depth":
                    v = np.array(v)[:, :, None] # From (H, W) -> (H, W, 1)
                tmp_dict[k] = th.Tensor(v)[None, :].to(device)
            
            obs_dict_list__th.append(tmp_dict)
        
        return obs_dict_list__th, done_list__th, target_scene_idx_list, target_category_idx_list

    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_traj_data_by_category_scene(trajs_dicts, category, scene, max_scenes=0, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category and scene
    obs_dict_list, done_list = [], []
    target_scene_idx_list, target_category_idx_list = [], []

    N_SCENES = len(trajs_dicts[category][scene])
    res_n_scenes = N_SCENES if max_scenes <= 0 else max_scenes

    for i in range(N_SCENES):
        traj_obs_dict_list, traj_done_list, target_scene_idxes, target_category_idxes = \
            get_traj_data_by_category_scene_trajIdx(trajs_dicts, category, scene, i, tensorize=tensorize, device=device)

        obs_dict_list.extend(traj_obs_dict_list)
        done_list.extend(traj_done_list)
        target_scene_idx_list.extend(target_scene_idxes)
        target_category_idx_list.extend(target_category_idxes)

        traj_length = len(traj_done_list)
        # print(f"Selected traj of length: {traj_length}")
        if i >= res_n_scenes - 1:
            break

    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_traj_data_by_category(trajs_dicts, category, max_scenes=0, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category, across all scenes and all trajectories
    obs_dict_list, done_list =[], []
    target_scene_idx_list, target_category_idx_list = [], []

    for scene in trajs_dicts[category].keys():
        scene_obs_dict_list, scene_done_list, target_scene_idxes, target_category_idxes = \
            get_traj_data_by_category_scene(trajs_dicts, category, scene, max_scenes=max_scenes, tensorize=tensorize, device=device)

        obs_dict_list.extend(scene_obs_dict_list)
        done_list.extend(scene_done_list)
        target_scene_idx_list.extend(target_scene_idxes)
        target_category_idx_list.extend(target_category_idxes)
    
    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_all_traj_data_by_category(trajs_dicts, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category, across all scenes and all trajectories
    obs_dict_list, done_list =[], []
    target_scene_idx_list, target_category_idx_list = [], []

    for cat in trajs_dicts.keys():
        cat_scene_obs_dict_list, cat_scene_done_list, cat_target_scene_idxes, cat_target_category_idxes = \
            get_traj_data_by_category(trajs_dicts, cat, tensorize=tensorize, device=device)

        obs_dict_list.extend(cat_scene_obs_dict_list)
        done_list.extend(cat_scene_done_list)
        target_scene_idx_list.extend(cat_target_scene_idxes)
        target_category_idx_list.extend(cat_target_category_idxes)
    
    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list
# endregion: Categories -> Scenes


# region: Scenes -> Categories
# TODO: add "return" for target categories and scenes label
## scenes_cats_trajs_dict: dictionary structured as: {scene: {category: [traj-data]}}
def get_traj_data_by_scene_category_trajIdx(trajs_dicts, scene, category, trajIdx=0, tensorize=False, device="cpu"):
    # Get a single trajectory specified by idx, for a specificed category and scene
    # TODO: maybe fix the "depth" dimension here directly ?
    obs_list_dict = trajs_dicts[scene][category][trajIdx]["edd"]["obs_list"]
    done_list = trajs_dicts[scene][category][trajIdx]["edd"]["done_list"]
    target_scene_idx_list, target_category_idx_list = [], []

    obs_dict_list = []
    T = len(obs_list_dict["rgb"])
    for t in range(T):
        obs_dict_list.append({k: v[t] for k, v in obs_list_dict.items()})
        target_scene_idx_list.append(TARGET_SCENE_DICT[scene])
        target_category_idx_list.append(CATEGORY_INDEX_MAPPING[category])

    # Tensorize if required
    if tensorize:
        done_list__th = []
        obs_dict_list__th = []

        for t, (obs_dict, done) in enumerate(zip(obs_dict_list, done_list)):
            # done_list__th.append(th.Tensor(np.array([done])[None, :]))
            done_list__th.append(th.Tensor(np.array([done])).to(device)) # TODO: make sure that the deprecation warning stops showing up. Or always stay on current Torch version.
            tmp_dict = {}
            for k, v in obs_dict.items():
                if k == "depth":
                    v = np.array(v)[:, :, None] # From (H, W) -> (H, W, 1)
                tmp_dict[k] = th.Tensor(v)[None, :].to(device)
            
            obs_dict_list__th.append(tmp_dict)
        
        return obs_dict_list__th, done_list__th, target_scene_idx_list, target_category_idx_list
        
    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_traj_data_by_scene_category(trajs_dicts, scene, category, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category and scene
    obs_dict_list, done_list = [], []
    target_scene_idx_list, target_category_idx_list = [], []

    for i in range(len(trajs_dicts[scene][category])):
        traj_obs_dict_list, traj_done_list, target_scene_idxes, target_category_idxes = \
            get_traj_data_by_scene_category_trajIdx(trajs_dicts, scene, category, i, tensorize=tensorize, device=device)

        obs_dict_list.extend(traj_obs_dict_list)
        done_list.extend(traj_done_list)
        target_scene_idx_list.extend(target_scene_idxes)
        target_category_idx_list.extend(target_category_idxes)

        traj_length = len(traj_done_list)
        # print(f"Selected traj of length: {traj_length}")

    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_traj_data_by_scene(trajs_dicts, scene, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category, across all scenes and all trajectories
    obs_dict_list, done_list =[], []
    target_scene_idx_list, target_category_idx_list = [], []
    
    for cat in trajs_dicts[scene].keys():
        cat_obs_dict_list, cat_done_list, target_scene_idxes, target_category_idxes = \
            get_traj_data_by_scene_category(trajs_dicts, scene, cat, tensorize=tensorize, device=device)

        obs_dict_list.extend(cat_obs_dict_list)
        done_list.extend(cat_done_list)
        target_scene_idx_list.extend(target_scene_idxes)
        target_category_idx_list.extend(target_category_idxes)
    
    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list

def get_all_traj_data_by_scene(trajs_dicts, tensorize=False, device="cpu"):
    # Get all trajectories for a specific category, across all scenes and all trajectories
    obs_dict_list, done_list =[], []
    target_scene_idx_list, target_category_idx_list = [], []

    for scene in trajs_dicts.keys():
        # Too rushing / lazy to change the names of the temporary list of obs
        cat_scene_obs_dict_list, cat_scene_done_list, cat_target_scene_idxes, cat_target_category_idxes = \
            get_traj_data_by_category(trajs_dicts, scene, tensorize=tensorize, device=device)

        obs_dict_list.extend(cat_scene_obs_dict_list)
        done_list.extend(cat_scene_done_list)
        target_scene_idx_list.extend(cat_target_scene_idxes)
        target_category_idx_list.extend(cat_target_category_idxes)
    
    return obs_dict_list, done_list, target_scene_idx_list, target_category_idx_list
# endregion: Scenes -> Categories

In [15]:
# Loading pretrained agent
MODEL_VARIANTS_TO_STATEDICT_PATH = {
    ## GRU
    # region: SAVI BC GRUv3 variants: rec enc gw3 detach
    "ppo_bc__sweep_gru_512": {
        "pretty_name": "GRU 1",
        "state_dict_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc-revised-sweep/"
            "ppo_bc_seed_42__2024_02_05_18_30_00_569723.musashi"
            # "/models/ppo_agent.19995001.ckpt.pth",
            "/models/ppo_agent.10000500.ckpt.pth",
        # TODO: prending probes
        "probe_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc-probing/"
            "ppo_bc__savi_ss1_rgb_spectro__gruv3__gw_detach__usenull__grulynrm__entcoef_0.2__no_cew__n_mb_50__prb_dpth_2_seed_111__2023_11_16_16_08_52_068321.conan"
    },
    # endregion: SAVI BC GRUv3 variants: rec enc gw3 detach

    ## GWTv3 H=512
    # region: SAVI BC GWTv3 variants: rec enc gw3 detach; CA uses null
    "ppo_bc__sweep_gw_64": {
        "pretty_name": "GW | 1",
        "state_dict_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc-revised-sweep/"
            "ppo_bc_seed_42__2024_01_23_15_44_57_777702.musashi"
            "/models/ppo_agent.10000500.ckpt.pth",
        # TODO: prending probes
        "probe_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc-probing/"
            "ppo_bc__savi_ss1_rgb_spectro__gwtv3__gw_detach__usenull__grulynrm__entcoef_0.2__no_cew__n_mb_50__prb_dpth_2_seed_111__2023_11_14_18_38_49_687853.musashi"
    },
    # endregion: SAVI BC GWTv3 variants: rec enc gw3 detach; CA uses null
}

# dev = th.device("cpu")
dev = th.device("cuda")

# 'variant named' indexed 'torch agent'
MODEL_VARIANTS_TO_AGENTMODEL = {}

for k, v in MODEL_VARIANTS_TO_STATEDICT_PATH.items():
    args_copy = copy.copy(args)
    # Override args depending on the model in use
    if k.__contains__("gru"):
        print(f"Loaded GRU Agent: '{k}'")
        tmp_args = copy.copy(args)
        # TODO: find a better way, swap hidden size and agent type checking script
        if k.__contains__("64"):
          tmp_args.gw_size = 64
        elif k.__contains__("512"):
          tmp_args.gw_size = 512

        print(f"  GW size: {tmp_args.hidden_size}")

        agent = GRU_Actor(single_observation_space, single_action_space, tmp_args,
            analysis_layers=models.GWTAGENT_DEFAULT_ANALYSIS_LAYER_NAMES)
        # print(agent)
    elif k.__contains__("gw"):
        print(f"Loaded GW Agent: '{k}'")
        tmp_args = copy.copy(args)
        # TODO: find a better way:
        if k.__contains__("64"):
          tmp_args.gw_size = 64
        elif k.__contains__("512"):
          tmp_args.gw_size = 512
        print(f"  GW size: {tmp_args.hidden_size}")

        agent = GW_Actor(single_observation_space, single_action_space, tmp_args,
            analysis_layers=models.GWTAGENT_DEFAULT_ANALYSIS_LAYER_NAMES + ["state_encoder.ca.mha"])
        # print(agent)

    agent.eval()
    # Load the model weights
    # TODO: add map location device to use CPU only ?
    if v["state_dict_path"] != "":
        agent_state_dict = th.load(v["state_dict_path"], map_location=dev)
        agent.load_state_dict(agent_state_dict)
    agent = agent.to(dev)

    MODEL_VARIANTS_TO_AGENTMODEL[k] = agent

Loaded GRU Agent: 'ppo_bc__sweep_gru_512'
  GW size: 512
Loaded GW Agent: 'ppo_bc__sweep_gw_64'
  GW size: 512
