## Soft-Module performance check

In [2]:
import pprint
import sys

import gym
import imageio
import numpy as np
import torch

# Project-local packages live one directory up; extend the path before importing them.
sys.path.append("..")
import torchrl.policies as policies
import torchrl.networks as networks
from torchrl.utils import get_params
from torchrl.algo import MTSAC
from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer
from torchrl.collector.para.async_mt import AsyncMultiTaskParallelCollectorUniform
from metaworld_utils.meta_env import get_meta_env, generate_single_mt_env

In [3]:
SEED = 11
params = get_params("../meta_config/mt10/modular_2_2_2_256_reweight_rand.json")

# Build the multi-task environment; `cls_dicts` and `cls_args` are indexed by task name below.
env, cls_dicts, cls_args = get_meta_env(params['env_name'], params['env'], params['meta_env'])
tasks = list(cls_dicts.keys())
example_embedding = env.active_task_one_hot
print("Obs_space:", env.observation_space)
print("Act_space:", env.action_space)

# Create the policy. Fall back to the CPU when no GPU is available so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
snapshot_path = "../log/starter/mt10/11/model/model_pf_best.pth"
params['net']['base_type'] = networks.MLPBase
pf = policies.ModularGuassianGatedCascadeCondContPolicy(
    input_shape=env.observation_space.shape[0],
    em_input_shape=np.prod(example_embedding.shape),
    # Twice the action dimension (presumably mean and log-std per action -- TODO confirm).
    output_shape=2 * env.action_space.shape[0],
    **params['net'])
# Deserialize the snapshot onto the CPU first, then move the whole module to `device`.
pf.load_state_dict(torch.load(snapshot_path, map_location='cpu'))
pf.to(device)
pf.eval()
env.eval()

# Print the module itself: `pf.named_modules` without a call only shows the
# bound-method repr instead of the layer summary.
print("Module:", pf)

Obs_space: Box(9,)
Act_space: Box(4,)
Module: <bound method Module.named_modules of ModularGuassianGatedCascadeCondContPolicy(
  (base): MLPBase(
    (fc0): Linear(in_features=9, out_features=400, bias=True)
    (fc1): Linear(in_features=400, out_features=400, bias=True)
  )
  (em_base): MLPBase(
    (fc0): Linear(in_features=10, out_features=400, bias=True)
  )
  (module_0_0): Linear(in_features=400, out_features=256, bias=True)
  (module_0_1): Linear(in_features=400, out_features=256, bias=True)
  (module_1_0): Linear(in_features=256, out_features=256, bias=True)
  (module_1_1): Linear(in_features=256, out_features=256, bias=True)
  (last): Linear(in_features=256, out_features=8, bias=True)
  (gating_fc_0): Linear(in_features=400, out_features=256, bias=True)
  (gating_fc_1): Linear(in_features=256, out_features=256, bias=True)
  (gating_weight_fc_0): Linear(in_features=256, out_features=4, bias=True)
  (gating_weight_cond_last): Linear(in_features=4, out_features=256, bias=True)
  (

In [4]:
#choose a single task to evaluate on.
pprint.pprint(tasks)
# A single index drives both the task lookup and `env_rank`, so the selected
# task and the rank stored in `env_args` cannot drift apart.
task_index = 5
env_id = tasks[task_index]
env_args = {
    "task_cls": cls_dicts[env_id],
    "task_args": cls_args[env_id],
    "env_rank": task_index,
    "num_tasks": env.num_tasks,
    "max_obs_dim": np.prod(env.observation_space.shape),
    "env_params": params["env"],
    "meta_env_params": params["meta_env"]
}
pprint.pprint(env_args)
# NOTE(review): `env` is rebound from the multi-task env to the single-task env
# here, while `num_tasks` / `max_obs_dim` above are read from `env`. Re-running
# this cell alone would therefore read them from the single-task env; re-run the
# setup cell first.
env = generate_single_mt_env(**env_args)
env.seed(SEED)

# test env
example_ob = env.reset()
print(env.action_space)
print(env.observation_space)
print(example_ob)

['reach-v1',
 'push-v1',
 'pick-place-v1',
 'door-v1',
 'drawer-open-v1',
 'drawer-close-v1',
 'button-press-topdown-v1',
 'ped-insert-side-v1',
 'window-open-v1',
 'window-close-v1']
{'env_params': {'obs_norm': False, 'reward_scale': 1},
 'env_rank': 5,
 'max_obs_dim': 9,
 'meta_env_params': {'obs_type': 'with_goal', 'random_init': True},
 'num_tasks': 10,
 'task_args': {'args': [],
               'kwargs': {'obs_type': 'plain', 'random_init': True}},
 'task_cls': <class 'metaworld.envs.mujoco.sawyer_xyz.sawyer_drawer_close.SawyerDrawerCloseEnv'>}
Box(4,)
Box(6,)
[-0.03265199  0.51487863  0.23688568  0.07903633  0.49999998  0.09
  0.07903633  0.69999998  0.04      ]


In [5]:
#eval for timesteps
# Roll the policy out for a fixed number of steps on the selected task,
# accumulating the reward in `rew` and the rendered frames in `imgs`.
eval_ob = env.reset()
rew = 0
done = False
imgs = []
try:
    # The task one-hot embedding does not change between steps, so build it
    # (and move it to the device) once instead of on every iteration.
    embedding_input = torch.zeros(env.num_tasks)
    embedding_input[env_args["env_rank"]] = 1
    embedding_input = embedding_input.unsqueeze(0).to(device)
    for i in range(1000):
        # Add a leading batch dimension to the observation before querying the policy.
        ob_tensor = torch.Tensor(eval_ob).to(device).unsqueeze(0)
        act = pf.eval_act(ob_tensor, embedding_input)
        # The env's own `done` flag is ignored; `done` tracks task success instead.
        eval_ob, r, _, info = env.step(act)
        rew += r
        imgs.append(env.render('rgb_array'))
        # NOTE(review): the loop does not stop on success; presumably so the
        # recording covers the full horizon -- confirm.
        done = info["success"]
        if i % 20 == 0:
            print("moving..", act)
finally:
    # Always release the environment, even if the rollout fails.
    env.close()


Creating window glfw
moving.. [ 0.71749854 -0.3827432  -0.64066875 -0.93301755]
moving.. [ 0.24028979 -0.62976956 -0.11336626 -0.8342447 ]
moving.. [ 0.00202824 -0.7172431   0.06591776 -0.72614825]
moving.. [ 0.00358261 -0.71576416  0.04401302 -0.7354229 ]
moving.. [ 0.00078825 -0.71581423  0.03266397 -0.7407523 ]
moving.. [ 3.6201809e-04 -7.1585149e-01  2.3268715e-02 -7.4440801e-01]
moving.. [ 3.5193004e-04 -7.1572053e-01  1.6463466e-02 -7.4701065e-01]
moving.. [ 2.4689920e-04 -7.1565342e-01  1.1746833e-02 -7.4881101e-01]
moving.. [ 1.6042031e-04 -7.1560335e-01  8.3354926e-03 -7.5008130e-01]
moving.. [ 1.1291541e-04 -7.1556449e-01  5.9105488e-03 -7.5097907e-01]
moving.. [ 7.9924241e-05 -7.1553707e-01  4.1921837e-03 -7.5161326e-01]
moving.. [ 5.6521967e-05 -7.1551776e-01  2.9741414e-03 -7.5206184e-01]
moving.. [ 4.1591004e-05 -7.1550214e-01  2.1184701e-03 -7.5237530e-01]
moving.. [ 2.8962269e-05 -7.1549177e-01  1.5100057e-03 -7.5259823e-01]
moving.. [ 2.0148233e-05 -7.1548456e-01  1.07

In [6]:
# Write every second rendered frame to an animated GIF (`imageio` is imported in the top cell).
# NOTE(review): the unit of `duration` (seconds vs. milliseconds per frame)
# depends on the installed imageio version/plugin -- confirm the playback speed.
frames = [np.array(img) for img in imgs[::2]]
imageio.mimsave("close-soft-module1.gif", frames, duration=200)