In [1]:
import os
# Expose only GPU index 1 to this process. The variable is set ahead of the
# torch import below so the restriction is in place before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch.optim as optim
from planet.models.determinstic_state import DeterministicStateModel
from planet.models.stochastic_state import StochasticStateModel
from planet.models.reward import RewardModel
from planet.models.encoder import ImageEncoderModel
from planet.models.observation import ImageObservationModel
from planet.trainer import PlanetTrainer

from planet.utils.wrappers import RepeatActionWrapper, GymPixelWrapper, ImagePreprocessorWrapper
from planet.utils.seed import set_seed
from planet.utils.envs import make_env


%load_ext autoreload
%autoreload 2

In [2]:
# Fix the random seed up front so that runs are reproducible.
SEED = 0
set_seed(SEED)

In [5]:
# Environment description consumed by make_env (and later handed to the trainer).
env_config = {
    "env_type": "dm_control",
    "domain_name": "cheetah",
    "task_name": "run",
    "render_kwargs": {"width": 64, "height": 64},
}

# Create the environment before anything touches it; the original cell used
# `env` ahead of this assignment, which fails on a fresh kernel.
env = make_env(env_config)

# Quick sanity check of the first observation.
# NOTE(review): this assumes reset() returns an object exposing `.observation`
# (dm_control-style TimeStep), while the next cell uses the gym-style
# `env.action_space` -- confirm against what make_env actually returns.
time_step = env.reset()
print(time_step.observation)

In [None]:
# Model dimensions handed to the model constructors in the following cells.
state_size, hidden_state_size = 30, 200
observation_size, hidden_layer_size = 1024, 300

# Draw one sample action and take its leading dimension as the action size.
action = env.action_space.sample()
action_size = action.shape[0]

In [None]:
# Constructor arguments for the deterministic state model.
det_state_kwargs = dict(
    action_size=action_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
# Build the model and move it to the GPU in one step.
det_state_model = DeterministicStateModel(**det_state_kwargs).cuda()

In [None]:
# Constructor arguments for the stochastic state model.
stoch_state_kwargs = dict(
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
# Build the model and move it to the GPU in one step.
stoch_state_model = StochasticStateModel(**stoch_state_kwargs).cuda()

In [None]:
# Constructor arguments for the image observation model.
obs_model_kwargs = dict(
    observation_size=observation_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
)
# Build the model and move it to the GPU in one step.
obs_model = ImageObservationModel(**obs_model_kwargs).cuda()

In [None]:
# Constructor arguments for the reward model.
reward_model_kwargs = dict(
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
# Build the model and move it to the GPU in one step.
reward_obs_model = RewardModel(**reward_model_kwargs).cuda()

In [None]:
# Constructor arguments for the image encoder model.
enc_model_kwargs = dict(
    observation_size=observation_size,
    state_size=state_size,
    hidden_state_size=hidden_state_size,
    hidden_layer_size=hidden_layer_size,
)
# Build the model and move it to the GPU in one step.
enc_model = ImageEncoderModel(**enc_model_kwargs).cuda()

In [None]:
# Registry of every sub-model, keyed by the names the trainer receives.
models = dict(
    det_state_model=det_state_model,
    stoch_state_model=stoch_state_model,
    obs_model=obs_model,
    reward_obs_model=reward_obs_model,
    enc_model=enc_model,
)

# Adam hyperparameters.
lr, eps = 1e-3, 1e-4

# A single Adam optimizer over the parameters of all models. Iterating
# models.values() follows dict insertion order, so the parameter list is
# concatenated in the same model order as the registry above.
optimizers = {
    "all_params": optim.Adam(
        [param for model in models.values() for param in model.parameters()],
        lr=lr,
        eps=eps,
    ),
}



In [None]:
# Full trainer configuration, assembled section by section.
# NOTE(review): the single-letter keys presumably follow the PlaNet paper's
# hyperparameter notation -- confirm their meaning against PlanetTrainer.
trainer_config = dict(
    env_config=env_config,
    train_config=dict(
        S=5,
        train_steps=100000,
        C=100,
        B=50,
        L=50,
        H=12,
        I=10,
        J=1000,
        K=100,
        log_interval=1,
        action_noise=0.3,
        free_nats=3.0,
        checkpoint_dir="checkpoints",
    ),
    state_config=dict(
        hidden_state_size=hidden_state_size,
        state_size=state_size,
        action_size=action_size,
    ),
    eval_config=dict(
        eval_interval=10,
        num_eval_episodes=10,
    ),
)

# Wire the models, the optimizer(s) and the configuration into the trainer.
trainer = PlanetTrainer(models=models, optimizers=optimizers, config=trainer_config)

In [12]:
# Invoke the trainer's fit() entry point, using the models, optimizers and
# config that were passed to PlanetTrainer.
trainer.fit()