In [None]:
import gym
import numpy as np
import torch as th
from ActiveCritic.metaworld.metaworld.envs import \
    ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from ActiveCritic.model_src.transformer import (ModelSetup,
                                                TransformerModel)
from ActiveCritic.model_src.whole_sequence_model import (WholeSequenceModelSetup)
from ActiveCritic.policy.active_critic_policy import (ACPOptResult,
                                                      ActiveCriticPolicy,
                                                      ActiveCriticPolicySetup)
from ActiveCritic.tests.test_utils.utils import make_wsm_setup
from ActiveCritic.utils.gym_utils import (DummyExtractor, make_policy_dict,
                                          new_epoch_reach)
from ActiveCritic.utils.pytorch_utils import make_partially_observed_seq
from ActiveCritic.utils.gym_utils import make_dummy_vec_env

from gym.wrappers import TimeLimit
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv

# Seed torch's global RNG so the actor/critic initialisation below is
# reproducible; the trailing ';' hides the returned Generator's repr.
th.manual_seed(0);

In [None]:
def make_obs_act_space(obs_dim, action_dim, low=0, high=1):
    """Build matching 1-D observation and action ``Box`` spaces.

    Every component of both spaces is bounded by ``[low, high]``
    (``[0, 1]`` by default, as before) and the spaces use dtype ``float``.

    Args:
        obs_dim: number of observation components.
        action_dim: number of action components.
        low: lower bound applied to every component of both spaces.
        high: upper bound applied to every component of both spaces.

    Returns:
        ``(observation_space, action_space)`` as ``gym.spaces.box.Box``.
    """
    def _box(dim):
        # One constant-bounded Box of length `dim`; shape is passed explicitly.
        return gym.spaces.box.Box(
            np.full(dim, low), np.full(dim, high), (dim,), float)

    observation_space = _box(obs_dim)
    action_space = _box(action_dim)
    return observation_space, action_space

In [None]:
from ActiveCritic.model_src.whole_sequence_model import WholeSequenceModel


# Model sizes for the toy experiment. NOTE(review): d_output is presumably the
# action width and has to agree with `act_dim` used for the action tensors below.
seq_len = 20
d_output = 2
d_result = 1  # critic output width

wsm_actor_setup = make_wsm_setup(seq_len=seq_len, d_output=d_output)
# Use d_result (previously defined but unused) so the critic width lives in one place.
wsm_critic_setup = make_wsm_setup(seq_len=seq_len, d_output=d_result)
actor = WholeSequenceModel(wsm_actor_setup)
critic = WholeSequenceModel(wsm_critic_setup)

In [None]:
obs_dim = 4
# Action width is the actor's output width; derive it instead of repeating the
# literal so the action space and the actor cannot silently disagree.
act_dim = d_output

In [None]:
# Observation / action Box spaces handed to the policy constructor below.
obs, acts = make_obs_act_space(obs_dim=obs_dim, action_dim=act_dim)

In [None]:
def make_acps(seq_len, extractor, new_epoch, device='cuda', opt_steps=100,
              optimisation_threshold=0.5, inference_opt_lr=1e-1, optimize=True):
    """Build an ``ActiveCriticPolicySetup`` for the notebook experiments.

    The former hard-coded values are now keyword arguments with the same
    defaults, so existing calls behave exactly as before.

    Args:
        seq_len: stored as ``epoch_len``.
        extractor: stored as ``extractor``.
        new_epoch: callable stored as ``new_epoch`` (e.g. ``new_epoch_reach``).
        device: torch device string stored as ``device`` (default ``'cuda'``).
        opt_steps: stored as ``opt_steps`` (default 100).
        optimisation_threshold: stored as ``optimisation_threshold`` (default 0.5).
        inference_opt_lr: stored as ``inference_opt_lr`` (default 1e-1).
        optimize: stored as ``optimize`` (default True).

    Returns:
        The populated ``ActiveCriticPolicySetup``.
    """
    acps = ActiveCriticPolicySetup()
    acps.device = device
    acps.epoch_len = seq_len
    acps.extractor = extractor
    acps.new_epoch = new_epoch
    acps.opt_steps = opt_steps
    acps.optimisation_threshold = optimisation_threshold
    acps.inference_opt_lr = inference_opt_lr
    acps.optimize = optimize
    return acps


acps = make_acps(seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_reach)

In [None]:
# Active-critic policy over the toy spaces, wired to the actor/critic above.
ac = ActiveCriticPolicy(
    observation_space=obs,
    action_space=acts,
    actor=actor,
    critic=critic,
    acps=acps,
)

In [None]:
# proj_actions must keep org_actions for steps before `current_step` and take
# opt_actions from `current_step` onwards.
current_step = 2
batch_size = 2
plan_shape = [batch_size, acps.epoch_len, act_dim]
org_actions = th.ones(plan_shape, device=acps.device, dtype=th.float, requires_grad=True)
opt_actions = th.zeros(plan_shape, device=acps.device, dtype=th.float, requires_grad=True)

pro_opt_actions = ac.proj_actions(org_actions, opt_actions, current_step=current_step)

assert th.equal(org_actions[:, :current_step], pro_opt_actions[:, :current_step])
assert th.equal(opt_actions[:, current_step:], pro_opt_actions[:, current_step:])



In [None]:
# One inference-time optimisation step. org_actions (fixed) and opt_actions
# (the tensor Adam updates) are combined by the policy around `current_step`.
current_step = 3
act_shape = [batch_size, acps.epoch_len, act_dim]
org_actions = th.ones(act_shape, device=acps.device, dtype=th.float, requires_grad=False)
opt_actions = th.zeros(act_shape, device=acps.device, dtype=th.float, requires_grad=True)
# current_step + 1 observations, presumably steps 0..current_step inclusive.
obs_seq = 2*th.ones([batch_size, current_step+1, obs_dim], device=acps.device, dtype=th.float, requires_grad=False)
# Independent copy, used later to verify obs_seq is not modified in place.
org_obs_seq = obs_seq.clone()
# Use the configured inference learning rate instead of repeating the literal.
optimizer = th.optim.Adam([opt_actions], lr=acps.inference_opt_lr)
# Target of 1 for every step, sized by the policy's epoch length.
goal_label = th.ones([batch_size, acps.epoch_len, 1], device=acps.device, dtype=th.float)

actions, critic_result = ac.inference_opt_step(
    org_actions=org_actions, opt_actions=opt_actions, obs_seq=obs_seq,
    optimizer=optimizer, goal_label=goal_label, current_step=current_step)

# Detached snapshot to compare against after further optimisation steps.
last_critic_result = critic_result.detach().clone()

In [None]:
# A few more optimisation steps. The fixed prefix must be untouched, the
# returned suffix must differ from the original plan, and the critic score
# must strictly improve on every step.
for step in range(3):
    actions, critic_result = ac.inference_opt_step(
        org_actions=org_actions, opt_actions=opt_actions, obs_seq=obs_seq,
        optimizer=optimizer, goal_label=goal_label, current_step=current_step)
    print(critic_result)
    assert th.equal(org_actions[:, :current_step], actions[:, :current_step]), 'org_actions were overwritten'
    # Check the plan actually returned, not just the optimiser's parameter tensor.
    assert not th.equal(actions[:, current_step:], org_actions[:, current_step:]), 'optimised actions equal org_actions'
    assert th.all(critic_result > last_critic_result), 'optimisation does not work.'
    last_critic_result = critic_result.detach().clone()

In [None]:
# Shape of the latest critic score snapshot.
last_critic_result.size()

In [None]:
# Display the plan returned by the most recent inference_opt_step call.
actions

In [None]:
# Run the full optimisation from the current plan. The expected success is
# required to reach the configured threshold.
actions, expected_success = ac.optimize_act_sequence(actions=actions, observations=obs_seq, current_step=current_step)
assert th.equal(org_actions[:, :current_step], actions[:, :current_step]), 'org_actions were overwritten'
# Compare the plan returned by optimize_act_sequence. `opt_actions` is a tensor
# from earlier cells and is not an input to this call.
assert not th.equal(actions[:, current_step:], org_actions[:, current_step:])
assert th.all(expected_success >= ac.args_obj.optimisation_threshold), 'optimisation does not work.'
# The observation sequence must not have been modified in place.
assert th.equal(obs_seq, org_obs_seq)


In [None]:
from ActiveCritic.tests.test_utils.utils import make_wsm_setup, make_obs_act_space, make_acps
from ActiveCritic.utils.gym_utils import new_epoch_pap, DummyExtractor

import gym
import numpy as np
import torch as th
from ActiveCritic.metaworld.metaworld.envs import \
    ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
from ActiveCritic.model_src.transformer import (ModelSetup,
                                                TransformerModel)
from ActiveCritic.model_src.whole_sequence_model import (WholeSequenceModelSetup, WholeSequenceModel)
from ActiveCritic.policy.active_critic_policy import (ACPOptResult,
                                                      ActiveCriticPolicy,
                                                      ActiveCriticPolicySetup)
from ActiveCritic.tests.test_utils.utils import make_wsm_setup
from ActiveCritic.utils.gym_utils import (DummyExtractor, make_policy_dict,
                                          new_epoch_reach)
from ActiveCritic.utils.pytorch_utils import make_partially_observed_seq
from ActiveCritic.utils.gym_utils import make_dummy_vec_env

from gym.wrappers import TimeLimit
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv

th.manual_seed(0)


def setup_ac(seq_len=4, d_output=2, obs_dim=3, batch_size=2):
    """Build a small ActiveCriticPolicy on toy Box spaces using ``new_epoch_pap``.

    Defaults reproduce the previous hard-coded configuration.

    Args:
        seq_len: sequence/epoch length shared by actor, critic and setup.
        d_output: actor output width, also used as the action-space dimension.
        obs_dim: observation-space dimension.
        batch_size: not used to build anything; returned so callers can size
            their test tensors consistently.

    Returns:
        ``(ac, acps, d_output, obs_dim, batch_size)``.
    """
    d_result = 1  # critic output width
    wsm_actor_setup = make_wsm_setup(
        seq_len=seq_len, d_output=d_output)
    wsm_critic_setup = make_wsm_setup(
        seq_len=seq_len, d_output=d_result)
    acps = make_acps(
        seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_pap)
    obs_space, acts_space = make_obs_act_space(
        obs_dim=obs_dim, action_dim=d_output)
    # Model construction order is kept so RNG consumption matches the seed above.
    actor = WholeSequenceModel(wsm_actor_setup)
    critic = WholeSequenceModel(wsm_critic_setup)
    ac = ActiveCriticPolicy(observation_space=obs_space, action_space=acts_space,
                            actor=actor, critic=critic, acps=acps)
    return ac, acps, d_output, obs_dim, batch_size

In [None]:
# Fresh new_epoch_pap policy plus zeroed action plans and constant
# observations for checks at current_step = 1.
th.manual_seed(0)

current_step = 1
ac, acps, act_dim, obs_dim, batch_size = setup_ac()

tensor_kw = dict(device=acps.device, dtype=th.float)
act_shape = [batch_size, acps.epoch_len, act_dim]
obs_shape = [batch_size, acps.epoch_len, obs_dim]
org_actions = th.zeros(act_shape, requires_grad=False, **tensor_kw)
opt_actions = th.zeros(act_shape, requires_grad=True, **tensor_kw)
obs_seq = 2 * th.ones(obs_shape, requires_grad=False, **tensor_kw)
org_obs_seq = 2 * th.ones(obs_shape, requires_grad=False, **tensor_kw)

In [None]:
def setup_ac_reach(seq_len=5):
    """Build an ActiveCriticPolicy for the 'reach' dummy vec env.

    Args:
        seq_len: episode length passed to the env, actor, critic and setup
            (default 5, as before).

    Returns:
        ``(ac, acps, env)``.
    """
    env, _ = make_dummy_vec_env('reach', seq_len=seq_len)  # second value (gt policy) unused
    d_result = 1  # critic output width
    d_output = env.action_space.shape[0]
    wsm_actor_setup = make_wsm_setup(
        seq_len=seq_len, d_output=d_output)
    wsm_critic_setup = make_wsm_setup(
        seq_len=seq_len, d_output=d_result)
    acps = make_acps(
        seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_reach)
    actor = WholeSequenceModel(wsm_actor_setup)
    critic = WholeSequenceModel(wsm_critic_setup)
    ac = ActiveCriticPolicy(observation_space=env.observation_space, action_space=env.action_space,
                            actor=actor, critic=critic, acps=acps)
    return ac, acps, env

In [None]:
# Re-seed so the reach policy's initialisation is independent of earlier cells.
th.random.manual_seed(0)
ac, acps, env = setup_ac_reach()

In [None]:
# Roll out one episode. At every step the policy's observation buffer must be
# empty beyond the current step. When a new epoch is detected, compare the
# recorded trajectory with what was actually executed.
obsv = env.reset()
last_obsv = th.tensor(obsv)
all_taken_actions = []
all_observations = [obsv]
for i in range(acps.epoch_len):
    action = ac.predict(obsv)
    all_taken_actions.append(action)
    obsv, rew, dones, info = env.step(action)
    all_observations.append(obsv)
    assert len(th.nonzero(ac.obs_seq[:, ac.current_step+1:])) == 0
    current_obsv = th.tensor(obsv)
    if new_epoch_reach(last_obsv, current_obsv):
        assert ac.current_step == ac.args_obj.epoch_len - 1
        ata = th.tensor(all_taken_actions).transpose(0, 1)
        print(th.equal(ata.to(acps.device), ac.current_result.gen_trj))
        aob = th.tensor(all_observations).transpose(0, 1)[:, :acps.epoch_len]
        print(th.equal(aob.to(acps.device), ac.obs_seq))
        assert ac.current_result.expected_succes_before < ac.current_result.expected_succes_after
    # Advance the reference so the next check compares consecutive observations,
    # not every observation against the episode's first one.
    last_obsv = current_obsv


In [None]:
# Shape of the recorded (batch, step, obs) observation tensor from the rollout.
aob.size()