In [1]:
import unittest

import torch as th
import numpy as np

from active_critic.model_src.whole_sequence_model import WholeSequenceModel, OptimizeMaximumCritic, OptimizeEndCritic
from active_critic.policy.active_critic_policy import ActiveCriticPolicy, ACPOptEnd
from active_critic.utils.test_utils import (make_acps, make_obs_act_space,
                                            make_wsm_setup)
from active_critic.utils.gym_utils import (DummyExtractor, make_dummy_vec_env,
                                           new_epoch_pap,
                                           new_epoch_reach)

from active_critic.utils.gym_utils import make_policy_dict, new_epoch_reach, make_dummy_vec_env, sample_expert_transitions, parse_sampled_transitions

  from .autonotebook import tqdm as notebook_tqdm
  logger.warn(


In [2]:
def _build_dummy_policy(make_critic=WholeSequenceModel, policy_cls=ActiveCriticPolicy,
                        opt_steps=None, seq_len=50, d_output=2, obs_dim=3, batch_size=2):
    """Build an actor/critic policy on a dummy observation/action space.

    Shared implementation behind ``setup_ac``, ``setup_opt_max`` and
    ``setup_opt_end``, which previously triplicated this code and differed only
    in the critic constructor, the policy class and the ``acps.opt_steps``
    override.

    The construction order (WSM setups -> acps -> spaces -> actor -> critic ->
    policy) is kept identical to the original functions, so runs seeded with
    ``th.manual_seed`` before calling a setup function are unaffected.

    Args:
        make_critic: callable that takes the critic's WSM setup (built with
            ``d_output=1``) and returns the critic model.
        policy_cls: policy class to instantiate with the actor/critic.
        opt_steps: if not None, assigned to ``acps.opt_steps`` right after
            ``acps`` is created.
        seq_len, d_output, obs_dim, batch_size: dummy problem dimensions;
            ``d_output`` is the action dimension.

    Returns:
        Tuple ``(policy, acps, d_output, obs_dim, batch_size)``.
    """
    wsm_actor_setup = make_wsm_setup(seq_len=seq_len, d_output=d_output)
    wsm_critic_setup = make_wsm_setup(seq_len=seq_len, d_output=1)
    acps = make_acps(seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_pap)
    if opt_steps is not None:
        acps.opt_steps = opt_steps
    obs_space, acts_space = make_obs_act_space(obs_dim=obs_dim, action_dim=d_output)
    actor = WholeSequenceModel(wsm_actor_setup)
    critic = make_critic(wsm_critic_setup)
    ac = policy_cls(observation_space=obs_space, action_space=acts_space,
                    actor=actor, critic=critic, acps=acps)
    return ac, acps, d_output, obs_dim, batch_size


def setup_ac():
    """ActiveCriticPolicy with a plain WholeSequenceModel critic on the dummy space.

    Returns:
        Tuple ``(policy, acps, d_output, obs_dim, batch_size)``.
    """
    return _build_dummy_policy(make_critic=WholeSequenceModel)


def setup_opt_max():
    """ActiveCriticPolicy with an OptimizeMaximumCritic and ``acps.opt_steps = 2``.

    Returns:
        Tuple ``(policy, acps, d_output, obs_dim, batch_size)``.
    """
    return _build_dummy_policy(
        make_critic=lambda setup: OptimizeMaximumCritic(wsms=setup),
        opt_steps=2)


def setup_opt_end():
    """ACPOptEnd policy with an OptimizeEndCritic and ``acps.opt_steps = 2``.

    Returns:
        Tuple ``(policy, acps, d_output, obs_dim, batch_size)``.
    """
    return _build_dummy_policy(
        make_critic=lambda setup: OptimizeEndCritic(wsms=setup),
        policy_cls=ACPOptEnd,
        opt_steps=2)


def setup_ac_reach():
    """ActiveCriticPolicy for the 'reach' dummy vec env (seq_len=50).

    The action dimension is taken from the environment's action space; actor
    and critic are both WholeSequenceModels (critic setup uses ``d_output=1``).

    Returns:
        Tuple ``(policy, acps, env, gt_policy)`` where ``env`` and
        ``gt_policy`` come from ``make_dummy_vec_env``.
    """
    seq_len = 50
    env, gt_policy = make_dummy_vec_env('reach', seq_len=seq_len)
    d_output = env.action_space.shape[0]
    wsm_actor_setup = make_wsm_setup(seq_len=seq_len, d_output=d_output)
    wsm_critic_setup = make_wsm_setup(seq_len=seq_len, d_output=1)
    acps = make_acps(seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_reach)
    actor = WholeSequenceModel(wsm_actor_setup)
    critic = WholeSequenceModel(wsm_critic_setup)
    ac = ActiveCriticPolicy(observation_space=env.observation_space,
                            action_space=env.action_space,
                            actor=actor, critic=critic, acps=acps)
    return ac, acps, env, gt_policy

In [3]:
# Baseline run: plain ActiveCriticPolicy (setup_ac) optimising the action
# sequence from step 30 with stop_opt=True. The seed is set before the policy
# is built so the run is repeatable; end_wo/all_wo below summarise this run.
th.manual_seed(3)
current_step = 30

ac, acps, act_dim, obs_dim, batch_size =  setup_ac()
# NOTE(review): presumably ac.args_obj is the same acps object returned above — confirm.
ac.args_obj.opt_steps = 2

# All-zero starting actions for the whole epoch (requires_grad=True, presumably
# because optimize_act_sequence optimises them by gradient).
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=True)
# Observations for steps 0..current_step: batch element 0 is filled with 4s,
# element 1 with 2s, so the two sequences differ.
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
# Sanity check: the freshly built critic scores the initial sequence below 1.
assert th.all(critic_scores < 1)
actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=True)

  logger.warn(


In [4]:
# Summaries of the setup_ac run: mean expected success from `current_step`
# onwards, and over the whole sequence. The slice follows `current_step`
# instead of a hard-coded 30, so it stays in sync if the step is changed above.
end_wo = expected_success[:, current_step:].mean()
all_wo = expected_success.mean()

In [5]:
# Mean expected success from step 30 onwards for the setup_ac run.
end_wo

tensor(0.1856, device='cuda:0', grad_fn=<MeanBackward0>)

In [6]:
# Mean expected success over the whole sequence for the setup_ac run.
all_wo

tensor(0.0966, device='cuda:0', grad_fn=<MeanBackward0>)

In [7]:
# Same experiment as the setup_ac baseline (same seed, step 30, same inputs),
# but with setup_opt_end: an OptimizeEndCritic inside an ACPOptEnd policy.
# end_w/all_w below summarise this run for comparison with end_wo/all_wo.
th.manual_seed(3)
current_step = 30

ac, acps, act_dim, obs_dim, batch_size =  setup_opt_end()
# NOTE(review): setup_opt_end already sets acps.opt_steps = 2; presumably
# ac.args_obj is that same acps object, making this redundant — confirm.
ac.args_obj.opt_steps = 2

opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=True)
# Batch element 0 observations are 4s, element 1 observations are 2s.
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
# Sanity check: the freshly built critic scores the initial sequence below 1.
assert th.all(critic_scores < 1)
actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=True)

  logger.warn(


In [8]:
# Summaries of the setup_opt_end run: mean expected success from
# `current_step` onwards, and over the whole sequence. The slice follows
# `current_step` instead of a hard-coded 30 so it matches the cell above.
end_w = expected_success[:, current_step:].mean()
all_w = expected_success.mean()

In [9]:
# Mean expected success from step 30 onwards for the setup_opt_end run.
end_w

tensor(0.1861, device='cuda:0', grad_fn=<MeanBackward0>)

In [10]:
# Mean expected success over the whole sequence for the setup_opt_end run.
all_w

tensor(0.0964, device='cuda:0', grad_fn=<MeanBackward0>)

In [14]:
# Re-display the setup_ac baseline value next to end_w for comparison.
end_wo

tensor(0.1856, device='cuda:0', grad_fn=<MeanBackward0>)

In [28]:
# Per-batch-element peak expected success and the step index where it occurs.
th.max(expected_success, dim=1)

torch.return_types.max(
values=tensor([[0.2679],
        [0.2087]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[20],
        [20]], device='cuda:0'))

In [29]:
# Same seeded experiment, but with an OptimizeMaximumCritic (setup_opt_max)
# and optimisation starting at step 1.
th.manual_seed(3)
current_step = 1

ac, acps, act_dim, obs_dim, batch_size = setup_opt_max()
# NOTE(review): setup_opt_max already sets acps.opt_steps = 2; presumably
# ac.args_obj is that same acps object — confirm.
ac.args_obj.opt_steps = 2

opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                       device=acps.device, dtype=th.float, requires_grad=True)
# Batch element 0 observations are 4s, element 1 observations are 2s.
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                      device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
# Compute the scores *before* asserting on them. The original cell asserted
# first, which silently checked `critic_scores` left over from an earlier cell
# (or raised NameError on a fresh kernel).
critic_scores = ac.critic.forward(critic_input)
assert th.all(critic_scores < 1)
actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=True)

In [30]:
# Peak expected success (and its step index) per batch element for the setup_opt_max run.
th.max(expected_success, dim=1)

torch.return_types.max(
values=tensor([[0.2953],
        [0.2442]], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[20],
        [26]], device='cuda:0'))

In [18]:
# Roll out the setup_ac_reach policy in the dummy 'reach' env for `epsiodes`
# full epochs, recording the taken actions and, at every step, the policy's
# expected-success tensors after and before optimisation.
th.manual_seed(0)
ac, acps, env, _ = setup_ac_reach()
obsv = env.reset()
all_taken_actions = []
all_observations = [obsv]
all_scores_after = []
all_scores_before = []
ac.reset()
epsiodes = 2

for i in range(epsiodes*ac.args_obj.epoch_len):
    action = ac.predict(obsv)
    all_taken_actions.append(action)
    obsv, rew, dones, info = env.step(action)
    all_observations.append(obsv)
    all_scores_after.append(ac.current_result.expected_succes_after)
    all_scores_before.append(ac.current_result.expected_succes_before)
    # NOTE(review): all_observations is reset every 5 steps and never read
    # after this loop; the interval 5 looks unrelated to epoch_len — confirm intent.
    if (i+1) % 5 == 0:
        all_observations = [obsv]

# Stack the per-step score tensors into [episodes, epoch_len, epoch_len, 1].
# Assumes each recorded tensor holds epoch_len values (e.g. batch 1 x epoch_len x 1)
# — TODO confirm; otherwise the reshape fails.
all_scores_after_th = th.tensor(np.array([s.detach().cpu().numpy() for s in all_scores_after]).reshape(
    [epsiodes, ac.args_obj.epoch_len, ac.args_obj.epoch_len, 1]), device=ac.args_obj.device)
all_scores_before_th = th.tensor(np.array([s.detach().cpu().numpy() for s in all_scores_before]).reshape(
    [epsiodes, ac.args_obj.epoch_len, ac.args_obj.epoch_len, 1]), device=ac.args_obj.device)


  logger.warn(


In [12]:
# Invariants of optimize_act_sequence with stop_opt=False, opt_end=False:
#   * actions before current_step are left unchanged,
#   * actions from current_step onwards are changed,
#   * the last step's expected success reaches ac.args_obj.optimisation_threshold,
#   * the observation tensor is not modified.
th.manual_seed(0)

current_step = 1
ac, acps, act_dim, obs_dim, batch_size =  setup_ac()

# Untouched reference copies of the starting actions and observations.
org_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=True)
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
org_obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                            device=acps.device, dtype=th.float, requires_grad=False)

actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=False, opt_end=False)
assert th.equal(
org_actions[:, :current_step], actions[:, :current_step]), 'org_actions were overwritten'

assert not th.equal(actions[:, current_step:], org_actions[:,
                                                                current_step:]),'seq optimisation did not change the actions'
assert th.all(
expected_success[:, -1] >= ac.args_obj.optimisation_threshold), 'optimisation does not work.'
assert th.equal(obs_seq, org_obs_seq), 'Observations were changed.'

In [16]:
# Inspect the policy's recorded optimisation scores (the saved output is an empty list).
ac.history.opt_scores

[]

In [4]:
# Run the same seeded optimisation twice, once with stop_opt=False and once
# with stop_opt=True. Then check that stopping early gives a lower final
# expected success for some, but not all, batch elements.
th.manual_seed(1)
current_step = 1

ac, acps, act_dim, obs_dim, batch_size =  setup_ac()
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=True)
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
# NOTE(review): critic_input/critic_scores are not used below; presumably kept
# so both runs execute the same calls under the seed — confirm.
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
actions, expected_success_nonstop = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=False, opt_end=False)

# Re-seed and rebuild so the second run starts from an identical policy and inputs.
th.manual_seed(1)
ac, acps, act_dim, obs_dim, batch_size =  setup_ac()
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                        device=acps.device, dtype=th.float, requires_grad=True)
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                        device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=True, opt_end=False)

# At least one, but not every, batch element ends lower when stopping early.
assert (expected_success[:,-1] < expected_success_nonstop[:,-1]).sum() > 0
assert (expected_success[:,-1] < expected_success_nonstop[:,-1]).sum() < len(expected_success)

In [3]:
# With opt_end=True, the optimised action sequence must stay inside the
# policy's action space bounds (action_space.low / action_space.high).
th.manual_seed(1)
current_step = 1

ac, acps, act_dim, obs_dim, batch_size = setup_ac()
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                       device=acps.device, dtype=th.float, requires_grad=True)
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                      device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
actions, expected_success_nonstop = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=False, opt_end=False)

# Re-seed and rebuild so the opt_end=True run starts from an identical policy and inputs.
th.manual_seed(1)
ac, acps, act_dim, obs_dim, batch_size = setup_ac()
opt_actions = th.zeros([batch_size, acps.epoch_len, act_dim],
                       device=acps.device, dtype=th.float, requires_grad=True)
obs_seq = 2 * th.ones([batch_size, current_step + 1, obs_dim],
                      device=acps.device, dtype=th.float, requires_grad=False)
obs_seq[0] *= 2
critic_input = ac.get_critic_input(acts=opt_actions, obs_seq=obs_seq)
critic_scores = ac.critic.forward(critic_input)
actions, expected_success = ac.optimize_act_sequence(
    actions=opt_actions, observations=obs_seq, current_step=current_step, stop_opt=False, opt_end=True)

# Per-batch, per-action-dimension extremes over the sequence dimension.
maxim, _ = th.max(actions, dim=1)
minim, _ = th.min(actions, dim=1)
# Place the bounds on the same device as the reduced tensors. The original
# used `max.device`, i.e. the Python builtin `max`, which has no `.device`
# attribute and raises AttributeError.
high = th.tensor(ac.action_space.high, device=maxim.device)
low = th.tensor(ac.action_space.low, device=minim.device)
assert th.all(maxim <= high), 'optimised actions exceed action_space.high'
assert th.all(minim >= low), 'optimised actions fall below action_space.low'

  logger.warn(


In [12]:
# Inspect the reach environment's action space. `env` is created by the
# setup_ac_reach() rollout cell; that cell must run first in this kernel,
# otherwise this raises NameError (as in the saved output below).
env.action_space

NameError: name 'env' is not defined