In [None]:
import numpy as np
import torch
import gymnasium as gym
import renderlab as rl

from huggingface_hub import login
from transformers import DecisionTransformerConfig

from trainable_dt import TrainableDT

#
import warnings
warnings.filterwarnings('ignore')

from access_tokens import WRITE_TOKEN

In [None]:
# Authenticate with the Hugging Face Hub using the token kept out of the notebook in access_tokens.py.
login(token=WRITE_TOKEN)

## Config

In [None]:
# Name of the model repository (under the "afonsosamarques" Hub namespace) to evaluate.
hf_model_to_use = "dt-halfcheetah-rarl"

# Selectable Gymnasium environment ids, keyed by integer index.
envs_in_gym = dict(enumerate(["Walker2d-v4", "HalfCheetah-v4"]))

# Default (unscaled) target return for each environment; it is divided by the model's
# returns_scale when the evaluation loop seeds the return-to-go.
default_tr_per_1000 = dict([
    ("Walker2d-v4", 7200),
    ("HalfCheetah-v4", 12000),
])

# Pick the environment (index 1 -> HalfCheetah, matching the model above) and its target.
chosen_env = envs_in_gym[1]
env_target_per_1000 = default_tr_per_1000[chosen_env]

## Loading model

In [None]:
# Run on the Apple-silicon GPU (MPS backend) when available, otherwise on the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

config = DecisionTransformerConfig.from_pretrained(f"afonsosamarques/{hf_model_to_use}", use_auth_token=True)
# `from_pretrained` is a classmethod: call it on the class rather than on a throwaway,
# randomly initialised `TrainableDT(config)` instance. The loaded weights start on the CPU,
# while the evaluation loop builds all of its input tensors on `device`, so move them there.
model = TrainableDT.from_pretrained(f"afonsosamarques/{hf_model_to_use}", use_auth_token=True).to(device)

## Testing model

In [None]:
with torch.no_grad():
    # Fresh environment with RGB rendering, wrapped by renderlab's RenderFrame so frames are
    # written to a directory named after the evaluated model (e.g. "./dt-halfcheetah-rarl-output-frames").
    env = gym.make(chosen_env, render_mode="rgb_array")
    env = rl.RenderFrame(env, "./" + model.config._name_or_path.split("/")[-1] + "-output-frames")
    state, _ = env.reset()

    # Configs saved without a `returns_scale` entry fall back to 1000.0.
    returns_scale = model.config.returns_scale if "returns_scale" in model.config.to_dict().keys() else 1000.0  # FIXME compatibility
    episode_return, episode_length = 0, 0

    # Running context fed to the Decision Transformer; each grows by one entry per step:
    #   target_return: (1, t+1) scaled return-to-go
    #   states:        (t+1, state_dim)
    #   actions:       (t, act_dim), plus one placeholder row appended at the start of a step
    #   rewards:       (t,), plus one placeholder appended at the start of a step
    #   timesteps:     (1, t+1)
    target_return = torch.tensor(env_target_per_1000 / returns_scale, device=device, dtype=torch.float32).reshape(1, 1)
    # Cast to float32 on the host first (as for every later observation): MPS does not support float64.
    states = torch.from_numpy(state.astype(np.float32)).to(device=device).reshape(1, model.config.state_dim)
    actions = torch.zeros((0, model.config.act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    for t in range(model.config.max_ep_len):
        # Zero placeholders for this step's action and reward; filled in below once the
        # action has been chosen and the reward observed.
        actions = torch.cat([actions, torch.zeros((1, model.config.act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.get_action(
            states,
            actions,
            rewards,
            target_return,
            timesteps,
            device,
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        # Gymnasium's step returns (obs, reward, terminated, truncated, info). The episode is
        # over on either flag; `truncated` is what a time limit sets, so ignoring it would keep
        # stepping an environment whose episode has already ended.
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        cur_state = torch.from_numpy(state.astype(np.float32)).to(device=device).reshape(1, model.config.state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # The return-to-go shrinks by the (scaled) reward just collected.
        pred_return = target_return[0, -1] - (reward / returns_scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.full((1, 1), t + 1, device=device, dtype=torch.long)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break

print(episode_return)

In [None]:
# `env` is the renderlab RenderFrame wrapper created in the evaluation cell; play() presumably
# displays the frames it recorded during the episode (TODO confirm against renderlab docs).
env.play()