In [1]:
import torch as th
from active_critic.learner.active_critic_learner import ActiveCriticLearner, ACLScores
from active_critic.learner.active_critic_args import ActiveCriticLearnerArgs
from active_critic.policy.active_critic_policy import ActiveCriticPolicy
from active_critic.utils.gym_utils import make_dummy_vec_env, make_vec_env, parse_sampled_transitions, sample_expert_transitions, DummyExtractor, new_epoch_reach, sample_new_episode
from active_critic.utils.pytorch_utils import make_part_obs_data, count_parameters, build_tf_horizon_mask
from active_critic.utils.dataset import DatasetAC
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from active_critic.utils.dataset import DatasetAC
from active_critic.model_src.whole_sequence_model import (
    WholeSequenceModel)
from active_critic.model_src.transformer import (
    ModelSetup)
from active_critic.policy.active_critic_policy import ActiveCriticPolicySetup, ActiveCriticPolicy
from active_critic.model_src.state_model import *


from gym import Env
# Fix PyTorch's global RNG seed so repeated runs of this notebook start from
# the same random state.
th.manual_seed(0)

def make_acps(seq_len, extractor, new_epoch, batch_size = 2, device='cpu', horizon = 0):
    """Create and populate an ``ActiveCriticPolicySetup``.

    Args:
        seq_len: Stored as ``epoch_len`` and used as the length of both masks.
        extractor: Stored unchanged as ``extractor``.
        new_epoch: Stored unchanged as ``new_epoch``.
        batch_size: Stored as ``batch_size``.
        device: Stored as ``device`` and used to allocate both masks.
        horizon: Forwarded to ``build_tf_horizon_mask`` for ``pred_mask``.

    Returns:
        The configured ``ActiveCriticPolicySetup`` instance.
    """
    setup = ActiveCriticPolicySetup()

    # Boolean mask of shape [seq_len, 1]. It has a single column, so writing
    # to column -1 switches every entry on.
    optimisation_mask = th.zeros([seq_len, 1], device=device, dtype=bool)
    optimisation_mask[:, -1] = 1

    # Attribute values, listed in the order they are assigned.
    settings = {
        'device': device,
        'epoch_len': seq_len,
        'extractor': extractor,
        'new_epoch': new_epoch,
        'opt_steps': 100,
        'inference_opt_lr': 1e-1,
        'optimizer_class': th.optim.Adam,
        'optimize': True,
        'batch_size': batch_size,
        'pred_mask': build_tf_horizon_mask(seq_len=seq_len, horizon=horizon, device=device),
        'opt_mask': optimisation_mask,
    }
    for attribute, value in settings.items():
        setattr(setup, attribute, value)
    return setup

def _build_state_model(arch, device, lr):
    """Build one ``StateModel`` from a fresh ``StateModelArgs`` object."""
    model_args = StateModelArgs()
    model_args.arch = arch
    model_args.device = device
    model_args.lr = lr
    return StateModel(args=model_args)


def setup_opt_state(batch_size, seq_len, device='cpu'):
    """Create the 'reach' environment and an ``ActiveCriticPolicy`` for it.

    Four ``StateModel`` networks (actor, critic, emitter, predictor) are
    built, each from its own ``StateModelArgs`` instance. Previously the
    predictor was accidentally constructed from the emitter's args object,
    leaving ``predictor_args`` unused and the two models sharing one args
    instance.

    Args:
        batch_size: Forwarded to ``make_acps`` and returned unchanged.
        seq_len: Passed to ``make_vec_env`` and ``make_acps``; returned
            unchanged.
        device: Device string set on every model and on the policy setup.

    Returns:
        Tuple ``(ac, acps, batch_size, seq_len, env, expert)``.
    """
    num_cpu = 1
    env, expert = make_vec_env('reach', num_cpu, seq_len=seq_len)
    action_dim = env.action_space.shape[0]
    embed_dim = 10
    lr = 1e-3

    # Construction order (actor, critic, emitter, predictor) is kept as in the
    # original so any seeded parameter initialisation draws in the same order.
    actor = _build_state_model([200, 200, action_dim], device, lr)
    critic = _build_state_model([200, 200, 1], device, lr)
    emitter = _build_state_model([200, 200, embed_dim], device, lr)
    predictor = _build_state_model([200, 200, embed_dim], device, lr)

    acps = make_acps(
        seq_len=seq_len, extractor=DummyExtractor(), new_epoch=new_epoch_reach, device=device, batch_size=batch_size)
    # Override the make_acps default of 100 optimisation steps.
    acps.opt_steps = 10
    acps.clip = True
    ac = ActiveCriticPolicy(observation_space=env.observation_space,
                            action_space=env.action_space,
                            actor=actor,
                            critic=critic,
                            predictor=predictor,
                            emitter=emitter,
                            acps=acps)
    return ac, acps, batch_size, seq_len, env, expert


def make_acl(data_path='/home/hendrik/Documents/master_project/LokalData/TransformerImitationLearning/',
             device='cpu', batch_size=2, seq_len=5):
    """Build an ``ActiveCriticLearner`` for the 'reach' task.

    The previously hard-coded values are now keyword arguments whose defaults
    equal the old constants, so ``make_acl()`` behaves as before.

    Args:
        data_path: Directory assigned to ``ActiveCriticLearnerArgs.data_path``.
            The default is a machine-specific absolute path; pass your own
            directory when running elsewhere.
        device: Device string used for the learner args and the policy.
        batch_size: Batch size forwarded to ``setup_opt_state`` (policy side).
            The learner's own ``batch_size`` stays fixed at 32.
        seq_len: Sequence length for both the training and evaluation envs.

    Returns:
        Tuple ``(acl, env, expert, seq_len, device)``. ``expert`` is the one
        returned alongside the evaluation environment, matching the original
        behaviour where the second ``make_vec_env`` call overwrote the first
        expert.
    """
    acla = ActiveCriticLearnerArgs()
    acla.data_path = data_path
    acla.device = device
    acla.extractor = DummyExtractor()
    acla.imitation_phase = False
    acla.logname = 'reach_plot_embedding'
    acla.tboard = True
    acla.batch_size = 32
    acla.val_every = 1000
    acla.add_data_every = 100
    acla.validation_episodes = 1
    # NOTE(review): spelling kept as in the original; presumably it matches the
    # attribute name read by ActiveCriticLearner. Verify before renaming.
    acla.training_epsiodes = 1
    acla.actor_threshold = 1e-2
    acla.critic_threshold = 1e-2
    acla.predictor_threshold = 1e-2
    acla.num_cpu = 1

    # The expert returned with the training env is intentionally discarded;
    # the original code overwrote it with the evaluation env's expert below.
    ac, acps, batch_size, seq_len, env, _train_expert = setup_opt_state(
        device=device, batch_size=batch_size, seq_len=seq_len)
    eval_env, eval_expert = make_vec_env('reach', num_cpu=acla.num_cpu, seq_len=seq_len)
    acl = ActiveCriticLearner(ac_policy=ac, env=env, eval_env=eval_env, network_args_obj=acla)
    return acl, env, eval_expert, seq_len, device
# Build the learner, training environment and expert once; the cells below
# use `acl` (and, in the disabled cell, `expert`, `env` and `device`).
acl, env, expert, seq_len, device = make_acl()


2022-11-03 01:30:39.572305: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
  logger.warn(
2022-11-03 01:30:52.791031: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# NOTE: the code in this cell is disabled by wrapping it in a string literal,
# so nothing here executes and the cell only echoes the string as its output.
# If enabled, it would sample expert transitions, parse them into a DatasetAC
# and hand that dataset to the learner via acl.setDatasets.
'''transitions = sample_expert_transitions(
    policy=expert.predict, env=env, episodes=100)

exp_actions, exp_observations, exp_rewards = parse_sampled_transitions(
    transitions=transitions, extractor=DummyExtractor(), device=device)
imitation_data = DatasetAC(device=device)
imitation_data.onyl_positiv = False
imitation_data.add_data(obsv=exp_observations, actions=exp_actions, reward=exp_rewards)
acl.setDatasets(train_data=imitation_data)'''

'transitions = sample_expert_transitions(\n    policy=expert.predict, env=env, episodes=100)\n\nexp_actions, exp_observations, exp_rewards = parse_sampled_transitions(\n    transitions=transitions, extractor=DummyExtractor(), device=device)\nimitation_data = DatasetAC(device=device)\nimitation_data.onyl_positiv = False\nimitation_data.add_data(obsv=exp_observations, actions=exp_actions, reward=exp_rewards)\nacl.setDatasets(train_data=imitation_data)'

In [3]:
# Run the learner's training loop for 10000 epochs. This cell is long-running
# and prints sampling / validation progress as it goes.
acl.train(epochs=10000)

Sampling transitions. 1


  actions = th.tensor(actions, dtype=th.float, device=device)


Sampling transitions. 1
/home/hendrik/Documents/master_project/LokalData/TransformerImitationLearning/reach_plot_embedding/best_validation
Success Rate: 0.0
Reward: 0.11843143403530121
training samples: 1
Sampling transitions. 1
Sampling transitions. 1
/home/hendrik/Documents/master_project/LokalData/TransformerImitationLearning/reach_plot_embedding/best_validation
Success Rate: 0.0
Reward: 0.15702444314956665
training samples: 2
Sampling transitions. 1
Sampling transitions. 1
/home/hendrik/Documents/master_project/LokalData/TransformerImitationLearning/reach_plot_embedding/best_validation
Success Rate: 0.0
Reward: 0.17711839079856873
training samples: 3
Sampling transitions. 1
Sampling transitions. 1
Success Rate: 0.0
Reward: 0.13051263988018036
training samples: 4
Sampling transitions. 1
Sampling transitions. 1
Success Rate: 0.0
Reward: 0.14634986221790314
training samples: 5
Sampling transitions. 1
Sampling transitions. 1
Success Rate: 0.0
Reward: 0.15402697026729584
training sample

In [None]:
# Show a per-module table of parameter counts for the learner's policy,
# followed by the total.
count_parameters(acl.policy)

+--------------------------------+------------+
|            Modules             | Parameters |
+--------------------------------+------------+
|  actor.model.layers.0.weight   |    2200    |
|   actor.model.layers.0.bias    |    200     |
|  actor.model.layers.2.weight   |   40000    |
|   actor.model.layers.2.bias    |    200     |
|  actor.model.layers.4.weight   |    800     |
|   actor.model.layers.4.bias    |     4      |
|  critic.model.layers.0.weight  |    2800    |
|   critic.model.layers.0.bias   |    200     |
|  critic.model.layers.2.weight  |   40000    |
|   critic.model.layers.2.bias   |    200     |
|  critic.model.layers.4.weight  |    200     |
|   critic.model.layers.4.bias   |     1      |
| predicor.model.layers.0.weight |    546     |
|  predicor.model.layers.0.bias  |     39     |
| predicor.model.layers.2.weight |    7800    |
|  predicor.model.layers.2.bias  |    200     |
| predicor.model.layers.4.weight |   40000    |
|  predicor.model.layers.4.bias  |    20

187810