In [None]:
import torch as th
from active_critic.learner.active_critic_learner import ActiveCriticLearner, ACLScores
from active_critic.learner.active_critic_args import ActiveCriticLearnerArgs
from active_critic.policy.active_critic_policy import ActiveCriticPolicy
from active_critic.utils.gym_utils import make_vec_env, DummyExtractor, new_epoch_reach
from active_critic.utils.pytorch_utils import build_tf_horizon_mask
from active_critic.utils.dataset import DatasetAC
from active_critic.policy.active_critic_policy import ActiveCriticPolicySetup, ActiveCriticPolicy
from active_critic.model_src.state_model import *
from active_critic.model_src.whole_sequence_model import WholeSequenceModel, WholeSequenceModelArgs
from active_critic.model_src.transformer import ModelSetup


from gym import Env
import random

import numpy as np

# Single seed for the whole notebook. Seed every RNG that may be touched:
# torch for model initialisation, plus numpy / `random` in case other code in
# the pipeline (e.g. the environments — unverified) draws from them.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)

def make_wsm_setup(seq_len, d_output, device='cpu'):
    """Build the argument object for a ``WholeSequenceModel``.

    Args:
        seq_len: sequence length written to ``model_setup.seq_len``.
        d_output: output dimension written to ``model_setup.d_output``.
        device: torch device string written to ``model_setup.device``.

    Returns:
        A ``WholeSequenceModelArgs`` whose ``model_setup`` uses 1 head,
        d_hid = d_model = 200, 3 layers and dropout 0, trained with
        ``th.optim.Adam`` at lr 5e-4 and no extra optimizer kwargs.
    """
    wsm = WholeSequenceModelArgs()
    wsm.model_setup = ModelSetup()
    wsm.model_setup.d_output = d_output
    wsm.model_setup.nhead = 1
    wsm.model_setup.d_hid = 200
    wsm.model_setup.d_model = 200
    wsm.model_setup.nlayers = 3
    wsm.model_setup.seq_len = seq_len
    wsm.model_setup.dropout = 0
    wsm.lr = 5e-4
    wsm.model_setup.device = device
    wsm.optimizer_class = th.optim.Adam
    wsm.optimizer_kwargs = {}
    return wsm

def make_acps(seq_len, extractor, new_epoch, batch_size = 2, device='cpu', horizon = 0):
    """Build the ``ActiveCriticPolicySetup`` consumed by ``ActiveCriticPolicy``.

    Args:
        seq_len: episode length; stored as ``epoch_len`` and used as the
            length of both masks.
        extractor: feature extractor stored on the setup.
        new_epoch: callable stored as ``acps.new_epoch``.
        batch_size: batch size stored on the setup.
        device: torch device for the setup and its mask tensors.
        horizon: horizon passed to ``build_tf_horizon_mask``. A falsy value
            (the default ``0``) falls back to ``seq_len``, which is what this
            function always used before.

    Returns:
        The populated ``ActiveCriticPolicySetup``.
    """
    acps = ActiveCriticPolicySetup()
    acps.device = device
    acps.epoch_len = seq_len
    acps.extractor = extractor
    acps.new_epoch = new_epoch
    acps.opt_steps = 100
    acps.inference_opt_lr = 1e-2
    acps.optimizer_class = th.optim.SGD
    acps.optimize = True
    acps.batch_size = batch_size
    # `horizon` used to be accepted but ignored; honour it when explicitly set.
    mask_horizon = horizon if horizon else seq_len
    acps.pred_mask = build_tf_horizon_mask(seq_len=seq_len, horizon=mask_horizon, device=device)
    # All-True (seq_len, 1) boolean mask. The former `opt_mask[:, -1] = 1`
    # was a no-op on an all-ones tensor and has been dropped.
    acps.opt_mask = th.ones([seq_len, 1], device=device, dtype=th.bool)
    acps.opt_goal = True
    acps.optimize_goal_emb_acts = False
    acps.goal_label_multiplier = 0
    return acps

def _make_state_model(out_dim, device, lr, hidden_dim=200):
    """Construct a ``StateModel`` with arch ``[hidden_dim, out_dim]`` on ``device``."""
    model_args = StateModelArgs()
    model_args.arch = [hidden_dim, out_dim]
    model_args.device = device
    model_args.lr = lr
    return StateModel(args=model_args)


def setup_opt_state(batch_size, seq_len, device='cpu'):
    """Build an ``ActiveCriticPolicy`` for the 'reach' task plus its environment.

    Args:
        batch_size: batch size stored on the policy setup.
        seq_len: episode length for the env and the sequence model.
        device: torch device for all sub-models.

    Returns:
        ``(ac, acps, batch_size, seq_len, env, expert)``: the policy, its setup
        object, the two inputs echoed back, and the single-worker training env
        together with the expert returned by ``make_vec_env``.
    """
    num_cpu = 1
    env, expert = make_vec_env('reach', num_cpu, seq_len=seq_len)
    action_dim = env.action_space.shape[0]
    embed_dim = 20
    lr = 5e-4

    # Keep the original construction order (actor, critic, inverse critic,
    # emitter, predictor) so weight initialisation under the global torch
    # seed is unchanged.
    actor = _make_state_model(action_dim, device, lr)
    critic = _make_state_model(1, device, lr)
    inv_critic = _make_state_model(embed_dim + action_dim, device, lr)
    emitter = _make_state_model(embed_dim, device, lr)

    predictor_args = make_wsm_setup(
        seq_len=seq_len, d_output=embed_dim, device=device)
    # Explicit so the predictor width does not silently follow changes to
    # make_wsm_setup's defaults.
    predictor_args.model_setup.d_hid = 200
    predictor_args.model_setup.d_model = 200
    predictor = WholeSequenceModel(args=predictor_args)

    acps = make_acps(
        seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_reach,
        device=device, batch_size=batch_size)
    acps.clip = True
    ac = ActiveCriticPolicy(observation_space=env.observation_space,
                            action_space=env.action_space,
                            actor=actor,
                            critic=critic,
                            predictor=predictor,
                            emitter=emitter,
                            inverse_critic=inv_critic,
                            acps=acps)
    return ac, acps, batch_size, seq_len, env, expert


def make_acl(device, data_path='/home/hendrik/Documents/master_project/LokalData/WSM/', seq_len=10):
    """Build an ``ActiveCriticLearner`` on the 'reach' task.

    Args:
        device: torch device string for all models.
        data_path: value assigned to ``ActiveCriticLearnerArgs.data_path``.
            Defaults to the original author's local directory; override it
            on any other machine.
        seq_len: episode length used for both the training and eval envs.

    Returns:
        ``(acl, env, expert, seq_len, device)`` where ``env`` and ``expert``
        are the training environment and the expert created with it.
    """
    acla = ActiveCriticLearnerArgs()
    acla.data_path = data_path
    acla.device = device
    acla.extractor = DummyExtractor()
    acla.imitation_phase = False
    acla.logname = 'straight_dense_50'
    acla.tboard = True
    acla.batch_size = 32
    acla.validation_episodes = 1
    # NOTE(review): 'epsiodes' looks like a typo. Kept because the name must
    # match what ActiveCriticLearnerArgs/ActiveCriticLearner read — verify.
    acla.training_epsiodes = 1
    acla.actor_threshold = 1e-2
    acla.critic_threshold = 1e-2
    acla.predictor_threshold = 1e-2
    acla.gen_scores_threshold = 1e-1
    # Used below as the eval env's worker count: one per validation episode.
    acla.num_cpu = acla.validation_episodes

    ac, acps, _, seq_len, env, expert = setup_opt_state(
        device=device, batch_size=acla.batch_size, seq_len=seq_len)

    # Overrides make_acps' default of 100. This is set after the policy was
    # built; presumably the policy holds this same acps object — confirm.
    acps.opt_steps = 20
    acla.val_every = 1
    acla.add_data_every = 1

    # The eval env comes with its own expert. Keep the training expert paired
    # with `env` for the return value instead of overwriting it.
    eval_env, _eval_expert = make_vec_env('reach', num_cpu=acla.num_cpu, seq_len=seq_len)
    acl = ActiveCriticLearner(ac_policy=ac, env=env, eval_env=eval_env, network_args_obj=acla)
    return acl, env, expert, seq_len, device


In [None]:
# Build learner, training env, expert and config on CPU.
acl, env, expert, seq_len, device = make_acl('cpu')

In [None]:
# Train the learner. The meaning of the argument (epochs vs. steps) is defined
# by ActiveCriticLearner.train — TODO confirm.
acl.train(10)

In [21]:
# Fresh observation from the training env returned by make_acl.
obs = env.reset()

In [22]:
# Reset the policy before predicting a new episode (presumably clears its
# per-episode state/history — confirm in ActiveCriticPolicy.reset).
acl.policy.reset()

In [23]:
# Override goal_label_multiplier on the live policy (make_acps initialised it
# to 0). Assumes args_obj is the acps setup object — verify.
acl.policy.args_obj.goal_label_multiplier = 1

In [24]:
# Lower opt_steps on the live policy (make_acl set acps.opt_steps = 20).
# Assumes args_obj is that same setup object — verify.
acl.policy.args_obj.opt_steps = 10

In [25]:
# Run the policy on the reset observation; predict() prints the progress bar below.
actions = acl.policy.predict(obs)

 |----------------------------------------------------------------------------------------------------| 0.0% Predicting Epsiode

In [26]:
acl.policy.history.scores[0][0].shape

torch.Size([10, 10, 1])

In [27]:
# Inspect the policy's goal_label tensor (all ones in the output below).
acl.policy.goal_label

tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]])

In [28]:
acl.policy.history.scores[0][0]

tensor([[[0.1472],
         [0.0846],
         [0.0512],
         [0.0334],
         [0.0223],
         [0.0151],
         [0.0101],
         [0.0058],
         [0.0028],
         [0.0016]],

        [[0.1513],
         [0.0857],
         [0.0737],
         [0.0478],
         [0.0308],
         [0.0212],
         [0.0135],
         [0.0081],
         [0.0040],
         [0.0018]],

        [[0.1513],
         [0.0860],
         [0.0742],
         [0.0494],
         [0.0328],
         [0.0233],
         [0.0155],
         [0.0100],
         [0.0061],
         [0.0039]],

        [[0.1513],
         [0.0862],
         [0.0748],
         [0.0510],
         [0.0350],
         [0.0254],
         [0.0175],
         [0.0121],
         [0.0082],
         [0.0060]],

        [[0.1513],
         [0.0864],
         [0.0752],
         [0.0527],
         [0.0372],
         [0.0276],
         [0.0197],
         [0.0142],
         [0.0104],
         [0.0081]],

        [[0.1513],
         [0.0867],
  

In [32]:
acl.policy.history.scores[0][0].mean(1)[0]

tensor([0.0374])

In [33]:
acl.policy.history.scores[0][0].mean(1)[1]

tensor([0.0438])

In [34]:
acl.policy.history.scores[0][0].mean(1)[2]

tensor([0.0452])

In [35]:
acl.policy.history.scores[0][0].mean(1)[9]

tensor([0.0564])

In [None]:
# The former `th.optim.Adam()` raised TypeError because the required `params`
# argument was missing. Show Adam's constructor signature and defaults instead.
import inspect

inspect.signature(th.optim.Adam)