In [2]:
import unittest

import torch as th
import numpy as np

from active_critic.model_src.whole_sequence_model import WholeSequenceModel
from active_critic.policy.active_critic_policy import *
from active_critic.utils.test_utils import (make_acps, make_obs_act_space,
                                            make_wsm_setup)
from active_critic.utils.gym_utils import (DummyExtractor, make_dummy_vec_env,
                                           new_epoch_pap,
                                           new_epoch_reach)

from active_critic.utils.gym_utils import make_policy_dict, new_epoch_reach, make_dummy_vec_env, sample_expert_transitions, parse_sampled_transitions
from active_critic.model_src.state_model import StateModel, StateModelArgs
from active_critic.utils.pytorch_utils import build_tf_horizon_mask

  from .autonotebook import tqdm as notebook_tqdm
  logger.warn(


In [3]:
from active_critic.model_src.whole_sequence_model import *


def setup_opt_state(device='cuda'):
    seq_len = 6
    action_dim = 2
    obs_dim = 3
    batch_size = 2
    embed_dim = 4
    lr = 1e-3

    actor_args = StateModelArgs()
    actor_args.arch = [20, action_dim]
    actor_args.device = device
    actor_args.lr = lr
    actor = StateModel(args=actor_args)

    critic_args = StateModelArgs()
    critic_args.arch = [10, 1]
    critic_args.device = device
    critic_args.lr = lr
    critic = StateModel(args=critic_args)

    emitter_args = StateModelArgs()
    emitter_args.arch = [20, embed_dim]
    emitter_args.device = device
    emitter_args.lr = lr
    emitter = StateModel(args=emitter_args)

    predictor_args = make_wsm_setup(
        seq_len=seq_len, d_output=embed_dim, device=device)
    predictor_args.model_setup.d_hid = 200
    predictor_args.model_setup.d_model = 200
    predictor_args.model_setup.nlayers = 1
    predictor = WholeSequenceModel(args=predictor_args)


    acps = make_acps(
        seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_pap, device=device)
    acps.opt_steps = 2
    obs_space, acts_space = make_obs_act_space(
        obs_dim=obs_dim, action_dim=action_dim)
    ac = ActiveCriticPolicy(observation_space=obs_space, 
                            action_space=acts_space,
                            actor=actor,
                            critic=critic,
                            predictor=predictor,
                            emitter=emitter,
                            acps=acps)
    return ac, acps, action_dim, obs_dim, batch_size, embed_dim, seq_len

In [4]:
th.manual_seed(0)
device = 'cuda'
ac, acps, action_dim, obs_dim, batch_size, embed_dim, seq_len = setup_opt_state(device=device)
horizon = 0

  logger.warn(


In [None]:
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding - th.ones_like(seq_embedding))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == opt_actions[:,:-1].numel()


In [None]:
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == (min(seq_len-1, 1+horizon) * action_dim*batch_size)

In [None]:
horizon = 1
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == (min(seq_len-1, 1+horizon) * action_dim*batch_size)

In [None]:
horizon = seq_len
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == (min(seq_len-1, 1+horizon) * action_dim*batch_size)

In [None]:
horizon = seq_len
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == (min(seq_len-1, 1+horizon) * action_dim*batch_size)

In [None]:
horizon = seq_len
embeddings = th.ones([batch_size, 1, embed_dim], device=device)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
goal_embeddings = th.ones_like(embeddings)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == (min(seq_len-1, 1+horizon) * action_dim*batch_size)

In [None]:
horizon = 0
embeddings = th.ones([batch_size, 1, embed_dim], device=device, requires_grad=True)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=False)
loss = ((seq_embeddings[:,-1] - th.ones_like(seq_embeddings[:,-1] ))**2).mean()
loss.backward()
assert (actions.grad!=0).sum() == actions.numel() - action_dim*batch_size

In [None]:
horizon = 0
embeddings = th.ones([batch_size, 1, embed_dim], device=device, requires_grad=True)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device)
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)
opt_actions = actions.detach()
opt_actions.requires_grad = True
next_embedding = ac.predict_step(embeddings=seq_embeddings.detach(), actions=opt_actions, mask=mask)
seq_embedding = th.cat((embeddings[:,:1], next_embedding[:,:-1]), dim=1)
loss = ((seq_embedding[:,-1] - th.ones_like(seq_embedding[:,-1] ))**2).mean()
loss.backward()
assert (opt_actions.grad != 0).sum() == batch_size*action_dim

In [70]:
ac = ac.double()
seq_len = 8
horizon = 0
embeddings = th.ones([batch_size, 1, embed_dim], device=device, requires_grad=True, dtype=th.double)
org_embeddings = th.ones([batch_size, 1, embed_dim], device=device, requires_grad=True, dtype=th.double)
actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True, dtype=th.double)
org_actions = th.ones([batch_size, seq_len, action_dim], device=device, requires_grad=True, dtype=th.double)
action_optim = th.optim.Adam([actions], lr=1e-1)
opt_paras = action_optim.state_dict()

mask = build_tf_horizon_mask(seq_len=seq_len, horizon=0, device=device).double()
seq_embeddings = ac.build_sequence(embeddings=embeddings, actions=actions, seq_len=seq_len, mask=mask, detach=True)

In [200]:
goal_embeddings = th.ones_like(seq_embeddings)

def optimize_sequence(actions:th.Tensor, seq_embeddings:th.Tensor, mask:th.Tensor, goal_embeddings:th.Tensor, steps:int, detach:bool, current_step:int, reward_weight:float = 0):
    th.manual_seed(0)
    actions = actions.detach().clone()
    actions.requires_grad = True
    seq_embeddings = seq_embeddings.detach().clone()
    seq_embeddings.requires_grad = True
    optimizer = th.optim.Adam([actions, seq_embeddings], lr=1e-2)
    opt_paras = optimizer.state_dict()
    for i in range(steps):
        actions = actions.detach().clone()
        actions.requires_grad = True
        seq_embeddings = seq_embeddings.detach().clone()
        seq_embeddings.requires_grad = True

        last_embeddings = seq_embeddings.detach().clone()

        optimizer = th.optim.Adam([actions, seq_embeddings], lr=1e-2)


        if detach:
            next_seq_embeddings = ac.predict_step(embeddings=seq_embeddings, actions=actions, mask=mask)
            loss_embedding = calcMSE(next_seq_embeddings[:,:-1], goal_embeddings[:,1:])

        else:
            seq_embeddings = ac.build_sequence(embeddings=seq_embedding.detach()[:,:current_step], actions=actions, seq_len=actions.shape[1], mask=mask, detach=False)
            loss_embedding = 0


        optimizer.load_state_dict(opt_paras)

        loss_reward = calcMSE(seq_embeddings[:,-1], goal_embeddings[:,1])
        loss = loss_embedding + loss_reward * reward_weight
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        opt_paras = optimizer.state_dict()

    return loss_reward, loss_embedding, actions, seq_embeddings


In [162]:
loss_reward, loss_embedding, na, ns = optimize_sequence(actions=actions, seq_embeddings=seq_embeddings, mask=mask, goal_embeddings=goal_embeddings, steps=1, detach=True, current_step=1, reward_weight=1)
assert loss_embedding == 0
assert th.equal(org_actions, actions)
assert th.equal(embeddings, org_embeddings)

In [207]:
loss_reward, loss_embedding, na, ns = optimize_sequence(actions=actions, seq_embeddings=seq_embeddings, mask=mask, goal_embeddings=goal_embeddings, steps=5000, detach=True, current_step=1, reward_weight=1)

In [202]:
#steps: 1
print(loss_reward)
print(loss_embedding)
print(na)

0
tensor(3.1139, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [1.0000, 1.0000]],

        [[0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [0.9900, 0.9900],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [204]:
#steps: 10
print(loss_reward)
print(loss_embedding)
print(na)

0
tensor(2.8329, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[0.8998, 0.8998],
         [0.9000, 0.8999],
         [0.8995, 0.9000],
         [0.8998, 0.8997],
         [0.9000, 0.8996],
         [0.8996, 0.8998],
         [0.9002, 0.8994],
         [1.0000, 1.0000]],

        [[0.8998, 0.8998],
         [0.9000, 0.8999],
         [0.8995, 0.9000],
         [0.8998, 0.8997],
         [0.9000, 0.8996],
         [0.8996, 0.8998],
         [0.9002, 0.8994],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [206]:
#steps: 1000
print(loss_reward)
print(loss_embedding)
print(na)

0
tensor(1.1696, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[-2.9414e-01,  2.5817e-03],
         [-3.0123e-01, -4.9969e-04],
         [-2.9671e-01, -9.1164e-03],
         [-2.8929e-01, -2.4488e-02],
         [-2.6265e-01, -4.3764e-02],
         [-2.3138e-01,  4.8827e-03],
         [-2.3820e-01,  2.1441e-02],
         [ 1.0000e+00,  1.0000e+00]],

        [[-2.9414e-01,  2.5817e-03],
         [-3.0123e-01, -4.9969e-04],
         [-2.9671e-01, -9.1164e-03],
         [-2.8929e-01, -2.4488e-02],
         [-2.6265e-01, -4.3764e-02],
         [-2.3138e-01,  4.8827e-03],
         [-2.3820e-01,  2.1441e-02],
         [ 1.0000e+00,  1.0000e+00]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [208]:
#steps: 5000
print(loss_reward)
print(loss_embedding)
print(na)

0
tensor(1.1696, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[-2.9424e-01,  2.7234e-03],
         [-3.0127e-01,  1.3389e-04],
         [-2.9732e-01, -9.3588e-03],
         [-2.9111e-01, -2.4961e-02],
         [-2.6100e-01, -4.6579e-02],
         [-2.3155e-01,  4.7419e-03],
         [-2.3852e-01,  2.1526e-02],
         [ 1.0000e+00,  1.0000e+00]],

        [[-2.9424e-01,  2.7234e-03],
         [-3.0127e-01,  1.3389e-04],
         [-2.9732e-01, -9.3588e-03],
         [-2.9111e-01, -2.4961e-02],
         [-2.6100e-01, -4.6579e-02],
         [-2.3155e-01,  4.7419e-03],
         [-2.3852e-01,  2.1526e-02],
         [ 1.0000e+00,  1.0000e+00]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [181]:
#steps: 100
print(loss_reward)
print(loss_embedding)
print(na)

tensor(0.0693, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(0.3587, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[0.3825, 0.8234],
         [0.6761, 1.9818],
         [0.8022, 1.0552],
         [0.5206, 2.0053],
         [0.9305, 0.9750],
         [1.3733, 1.1619],
         [0.1467, 0.6618],
         [1.0000, 1.0000]],

        [[0.3825, 0.8234],
         [0.6761, 1.9818],
         [0.8022, 1.0552],
         [0.5206, 2.0053],
         [0.9305, 0.9750],
         [1.3733, 1.1619],
         [0.1467, 0.6618],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [107]:
#steps: 30
print(loss_reward)
print(loss_embedding)
print(na)

tensor(0.8319, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
tensor([[[1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [0.9900, 0.9900],
         [1.0000, 1.0000]],

        [[1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [0.9900, 0.9900],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)


In [101]:
na

tensor([[[1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [0.9800, 0.9800],
         [1.0000, 1.0000]],

        [[1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [1.0000, 1.0000],
         [0.9800, 0.9800],
         [1.0000, 1.0000]]], device='cuda:0', dtype=torch.float64,
       requires_grad=True)

In [85]:
loss_reward

tensor(0.7650, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

In [83]:
loss_embedding

tensor(0., device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

In [73]:
loss, na, ns = optimize_sequence(actions=actions, seq_embeddings=seq_embeddings, mask=mask, goal_embeddings=goal_embeddings, steps=1, detach=True, current_step=1, score_weight=1)


TypeError: optimize_sequence() got an unexpected keyword argument 'score_weight'

In [65]:
actions

tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]], device='cuda:0', dtype=torch.float64, requires_grad=True)