In [None]:
# %%
import torch.optim as optim
from planet.models.determinstic_state import DeterministicStateModel
from planet.models.stochastic_state import StochasticStateModel
from planet.models.reward import RewardModel
from planet.models.encoder import EncoderModel
from planet.models.observation import ObservationModel
from planet.utils.seed import set_seed
from planet.utils.envs import make_env
from planet.trainer import load_models

In [None]:
# set seed for reproducibility
set_seed(13)
# Forwarded into config["train_config"]["free_nats"] further down.
free_nats = 3.0
# Used both as the environment's "skip" value and as
# config["train_config"]["action_repeat"].
action_repeat = 4
# NOTE(review): this value is not read anywhere in the visible notebook --
# the training config hard-codes 0.1 and the rollout below is called with
# action_noise=None. Confirm which value is intended.
action_noise = 0.3

# Environment description handed to make_env.
# presumably "skip" is the number of times each action is repeated per
# step -- TODO confirm against planet.utils.envs.make_env.
env_config = {
    "env_type": "gym", 
    "skip": action_repeat,
    "id": "HalfCheetah-v4", 
    # rgb_array rendering so env.render() yields frames for the GIF export.
    "kwargs": {"render_mode": "rgb_array"}
}
env = make_env(env_config)

In [None]:
# --- environment dimensions ---------------------------------------------
# Reset once so the length of the observation vector can be read off.
observation, info = env.reset()
observation_size = observation.shape[0]

# Draw one random action only to learn the length of the action vector.
action = env.action_space.sample()
action_size = action.shape[0]

# --- latent / network widths --------------------------------------------
state_size = 30
hidden_state_size = 200
hidden_layer_size = 300

# Size keyword arguments that every model constructor below receives.
shared_sizes = dict(
    hidden_state_size=hidden_state_size,
    state_size=state_size,
    hidden_layer_size=hidden_layer_size,
)

# --- model instances (all moved to the GPU) -----------------------------
# Construction order is kept fixed so seeded initialisation is unchanged.
det_state_model = DeterministicStateModel(
    action_size=action_size, **shared_sizes
).cuda()

stoch_state_model = StochasticStateModel(**shared_sizes).cuda()

obs_model = ObservationModel(
    observation_size=observation_size, **shared_sizes
).cuda()

reward_obs_model = RewardModel(**shared_sizes).cuda()

enc_model = EncoderModel(
    observation_size=observation_size, **shared_sizes
).cuda()

# Name -> module registry; these keys are what the rollout code looks up.
models = {
    "det_state_model": det_state_model,
    "stoch_state_model": stoch_state_model,
    "obs_model": obs_model,
    "reward_obs_model": reward_obs_model,
    "enc_model": enc_model,
}

# --- single Adam optimizer over every model's parameters ----------------
lr = 6e-4
# Flattened in the registry's insertion order:
# det -> stoch -> obs -> reward -> enc.
all_params = [
    param for model in models.values() for param in model.parameters()
]
optimizers = {"all_params": optim.Adam(all_params, lr=lr)}

# Restore model (and optimizer) state from the latest checkpoint file.
load_models(models, optimizers, "checkpoints/latest_model.pth")

In [None]:
# Single configuration dictionary shared by the cells below.
config={
    "env_config": env_config,
    "train_config": {
        # S, C, B, L are not read anywhere in the visible notebook.
        # presumably seed episodes / collect interval / batch size / chunk
        # length as in the PlaNet paper -- TODO confirm against planet.trainer.
        "S": 5,
        "train_steps": 2_000,
        "C": 100,
        "B": 50,
        "L": 50,
        # H, I, J, K are forwarded verbatim to latent_planning by
        # collect_episode. presumably planning horizon, optimisation
        # iterations, candidates per iteration and top-K kept -- TODO confirm
        # against planet.planning.planner.
        "H": 15,
        "I": 10,
        "J": 1000,
        "K": 100,
        "log_interval": 1,
        # NOTE(review): hard-coded 0.1 while the module-level `action_noise`
        # variable is 0.3; confirm which one is intended.
        "action_noise": 0.1,
        "free_nats": free_nats,
        "checkpoint_dir": "checkpoints",
        # collect_episode runs max_episode_length // action_repeat steps.
        "max_episode_length": 1000,
        "action_repeat": action_repeat,
        # Flat list of every model parameter (same list given to Adam).
        "all_params": all_params
    },
    "state_config": {
        # Sizes read by collect_episode to build the initial hidden state and
        # to tell the planner the action dimensionality.
        "hidden_state_size": hidden_state_size,
        "state_size": state_size,
        "action_size": action_size,
    },
    # Not read anywhere in the visible notebook.
    "eval_config": {
        "eval_interval": 25,
        "num_eval_episodes": 5,
    }
}

In [None]:
import torch
from tqdm import tqdm
from typing import Any, Optional
from planet.trainer import _set_models_eval
from planet.planning.planner import latent_planning


@torch.no_grad()
def collect_episode(
    env: Any, action_noise: Optional[float] = None
):
    """Roll out one episode with the latent planner and record rendered frames.

    The function reads the notebook-level ``models`` and ``config``
    dictionaries. Tensors are created on the device the encoder's parameters
    live on, so the rollout works for CPU as well as CUDA models.

    :param env: Environment whose ``reset()`` returns ``(obs, info)``, whose
        ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``
        and whose ``render()`` returns the current frame. Observations must be
        NumPy arrays.
    :param action_noise: Scale of the zero-mean Gaussian noise added to every
        planned action; ``None`` disables the noise.
    :return: ``(sequence, episode_reward)`` where ``sequence`` is the list of
        frames rendered *before* each environment step (the frame produced by
        the final step is therefore not included) and ``episode_reward`` is
        the sum of the rewards returned by ``env.step``.
    """
    _set_models_eval(models)

    train_config = config["train_config"]
    state_config = config["state_config"]

    # Follow the models' device instead of assuming CUDA is available.
    device = next(models["enc_model"].parameters()).device

    # reset environment
    sequence, episode_reward = [], 0
    obs, _ = env.reset()

    # initialize the deterministic hidden state with zeros (batch of one)
    hidden_state = torch.zeros(
        1, state_config["hidden_state_size"], device=device
    )

    # number of planner steps: the episode budget divided by the action repeat
    T = train_config["max_episode_length"] // train_config["action_repeat"]

    for _ in tqdm(range(T)):
        # record the frame for the state the agent is about to act in
        sequence.append(env.render())

        # posterior over the latent state given the current observation
        observation = torch.from_numpy(obs).float().unsqueeze(0).to(device)
        posterior_dist = models["enc_model"](
            hidden_state=hidden_state,
            observation=observation,
        )

        # plan the next action in latent space
        action = latent_planning(
            H=train_config["H"],
            I=train_config["I"],
            J=train_config["J"],
            K=train_config["K"],
            hidden_state=hidden_state,
            current_state_belief=posterior_dist,
            deterministic_state_model=models["det_state_model"],
            stochastic_state_model=models["stoch_state_model"],
            reward_model=models["reward_obs_model"],
            action_size=state_config["action_size"],
        )

        # add exploration noise; out-of-place so the tensor returned by the
        # planner is never mutated
        if action_noise is not None:
            action = action + torch.randn_like(action) * action_noise

        # take action in the environment
        next_obs, reward, terminated, truncated, _ = env.step(
            action.cpu().numpy()
        )

        # accumulate the reward and stop as soon as the episode ends
        episode_reward += reward
        if terminated or truncated:
            break

        # update observation
        obs = next_obs

        # advance the hidden state with a posterior sample and the action
        # that was actually executed (including any exploration noise)
        hidden_state = models["det_state_model"](
            hidden_state=hidden_state,
            state=posterior_dist.sample(),
            action=action.unsqueeze(0),
        )

    return sequence, episode_reward

In [None]:
import cv2
import imageio
import numpy as np
from typing import List, Tuple

def tensors_to_gif(
    tensor_list: List[np.ndarray],
    filename: str,
    duration=2.0,
    value_range: Tuple[int, int] = (-1, 1),
) -> None:
    """Resize every frame to 256x256 and write them out as a looping GIF.

    :param tensor_list: Frames to write, in playback order. They are passed
        to ``cv2.resize`` and then to ``imageio`` without any value rescaling
        or dtype conversion.
    :param filename: Path of the GIF file to create.
    :param duration: Per-frame duration, forwarded unchanged to
        ``imageio.mimsave`` (NOTE(review): the unit is whatever the installed
        imageio GIF writer expects -- confirm seconds vs. milliseconds).
    :param value_range: Accepted for interface compatibility; the body does
        not use it.
    """
    target_size = (256, 256)

    # Resize each frame; pixel values and dtype are left as supplied.
    frames = [cv2.resize(frame, target_size) for frame in tensor_list]

    # loop=0 is forwarded to the GIF writer together with the frame duration.
    imageio.mimsave(filename, frames, format="GIF", duration=duration, loop=0)  # type: ignore[call-overload]

In [None]:
# Per-episode returns and rendered frame sequences gathered by the loop below.
rewards = []
sequences = []

# Roll out a single episode (range(1)) with exploration noise disabled,
# keeping both its rendered frames and its total reward.
for i in range(1):
    sequence, reward = collect_episode(env, action_noise=None)
    sequences.append(sequence)
    rewards.append(reward)

In [None]:
# Mean and standard deviation of the collected episode returns
# (with the single episode collected above, the standard deviation is 0).
np.mean(rewards), np.std(rewards)

In [None]:
# Raw per-episode returns.
rewards

In [None]:
# Median of the collected episode returns.
np.median(rewards)

In [None]:
# Write the rendered frames of the first collected episode to out.gif.
tensors_to_gif(sequences[0], "out.gif")