In [None]:
import gym
import torch.optim as optim

from planet.models.determinstic_state import DeterministicStateModel
from planet.models.stochastic_state import StochasticStateModel
from planet.models.observation import ObservationModel
from planet.models.reward import RewardModel
from planet.models.encoder import EncoderModel

from planet.utils.wrappers import RepeatActionWrapper
from planet.utils.seed import set_seed
from planet.utils.envs import make_env
from planet.trainer import PlanetTrainer


# Reload edited Python modules (e.g. the project's planet.* code) automatically before each cell runs.
# NOTE(review): these magics are usually placed before the imports so the project modules are tracked from the start.
%load_ext autoreload
%autoreload 2

In [None]:
# Set the global random seed for reproducibility via the project's set_seed helper.
# NOTE(review): set_seed lives in planet.utils.seed (not visible here); whether it also seeds the
# gym environment / action-space RNG is unconfirmed — the env reset later in the notebook passes no seed.
set_seed(0)

In [None]:
# Build the training environment through the project's make_env factory.
# HalfCheetah-v4 is the active target. Other ids previously tried here (directly via gym.make):
#   BipedalWalker-v3 (hardcore=False), Pendulum-v1 (with skip=2 via make_env),
#   LunarLander-v2 (continuous=True), Ant-v4, Walker2d-v4.
# `action_repeat` is passed to make_env as "skip" and is also handed to the trainer config.
# NOTE(review): "skip" is assumed to be the number of times each action is repeated — confirm in make_env.
action_repeat = 4
env_config = dict(env_type="gym", id="HalfCheetah-v4", skip=action_repeat)
env = make_env(env_config)

In [None]:
# Infer model input/output sizes from the environment.
# A real reset observation is used (rather than observation_space) so the size matches
# whatever make_env's wrappers actually emit.
observation, info = env.reset()
# observation_size is passed to the encoder/observation models as a single int,
# so a 1-D observation vector is required; a 2-D observation would silently yield the wrong size.
assert len(observation.shape) == 1, f"expected a 1-D state observation, got shape {observation.shape}"
observation_size = observation.shape[0]

# Read the action dimensionality straight from the action space instead of drawing a
# throwaway env.action_space.sample(), which needlessly advances the action-space RNG.
action_size = env.action_space.shape[0]

# Model sizes shared by every component below.
# NOTE(review): these look like the PlaNet paper defaults (30-dim stochastic state,
# 200-dim deterministic/hidden state, 200-unit hidden layers) — confirm against the model classes.
state_size = 30
hidden_state_size = 200
hidden_layer_size = 200

In [None]:
# Quick inspection: display the action dimensionality inferred from the environment.
action_size

In [None]:
# Deterministic state model: the only component here that receives action_size.
# It is moved to the GPU with .cuda().
# NOTE(review): this requires a CUDA device — there is no CPU fallback in this notebook.
det_state_model = DeterministicStateModel(
    action_size=action_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
det_state_model = det_state_model.cuda()

In [None]:
# Stochastic state model, configured only from the latent sizes, then moved to the GPU.
# NOTE(review): .cuda() requires a CUDA device.
stoch_state_model = StochasticStateModel(
    hidden_layer_size=hidden_layer_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
)
stoch_state_model = stoch_state_model.cuda()

In [None]:
# Observation model: configured with the latent sizes plus the flat observation size.
# It is moved to the GPU afterwards.
# NOTE(review): .cuda() requires a CUDA device.
obs_model = ObservationModel(
    observation_size=observation_size,
    hidden_layer_size=hidden_layer_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
)
obs_model = obs_model.cuda()

In [None]:
# Reward model, configured only from the latent sizes, then moved to the GPU.
# NOTE(review): .cuda() requires a CUDA device.
reward_obs_model = RewardModel(
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
reward_obs_model = reward_obs_model.cuda()

In [None]:
# Encoder model: configured with the observation size and both latent sizes.
# It is moved to the GPU afterwards.
# NOTE(review): presumably encodes an observation into the stochastic state — confirm in EncoderModel.
# .cuda() requires a CUDA device.
enc_model = EncoderModel(
    state_size=state_size,
    observation_size=observation_size,
    hidden_layer_size=hidden_layer_size,
    hidden_state_size=hidden_state_size,
)
enc_model = enc_model.cuda()

In [None]:
# Registry of every model component; handed to the trainer as-is.
models = {
    "det_state_model": det_state_model,
    "stoch_state_model": stoch_state_model,
    "obs_model": obs_model,
    "reward_obs_model": reward_obs_model,
    "enc_model": enc_model,
}

lr = 6e-4
# One flat list of every trainable parameter, derived from `models` so a component added to
# the dict is automatically optimized. Iteration follows the dict's insertion order
# (det, stoch, obs, reward, enc), i.e. the same order as listing the models by hand.
# Kept as a list (not a generator) because it is consumed twice: by Adam here and by the trainer config.
all_params = [param for model in models.values() for param in model.parameters()]

# A single Adam optimizer over the joint parameter set of all components.
optimizers = {
    "all_params": optim.Adam(all_params, lr=lr),
}



In [None]:
# Assemble the PlaNet trainer from the models, the optimizer and a nested config.
# NOTE(review): the single-letter train_config keys presumably follow the PlaNet paper's
# notation (Hafner et al., 2019) — confirm against PlanetTrainer:
#   S = seed episodes, C = collect interval (update steps between collected episodes),
#   B = batch size, L = sequence chunk length, H = planning horizon,
#   I = planner optimization iterations, J = candidate sequences per iteration,
#   K = top candidates refit per iteration.
trainer = PlanetTrainer(
    models=models,
    optimizers=optimizers,
    config=dict(
        env_config=env_config,
        train_config=dict(
            S=5,
            train_steps=1_000,
            C=100,
            B=50,
            L=50,
            H=15,
            I=10,
            J=1000,
            K=100,
            log_interval=1,
            action_noise=0.1,
            free_nats=2.0,
            checkpoint_dir="checkpoints-state",
            max_episode_length=1000,
            action_repeat=action_repeat,
            # The same parameter list the Adam optimizer was built on.
            all_params=all_params,
        ),
        state_config=dict(
            hidden_state_size=hidden_state_size,
            state_size=state_size,
            action_size=action_size,
        ),
        eval_config=dict(
            eval_interval=25,
            num_eval_episodes=5,
        ),
    ),
)

In [None]:
# Run training with the configuration assembled above.
trainer.fit()

In [None]:
# Leftover inspection cell: displays the observation size inferred from the environment earlier in the notebook.
observation_size