In [None]:
import numpy as np
import torch
import gymnasium as gym
import renderlab as rl

from huggingface_hub import interpreter_login
from transformers import DecisionTransformerModel

In [None]:
# Log in to the Hugging Face Hub before the model is loaded; the
# `from_pretrained` call further down passes `use_auth_token=True`.
# NOTE(review): presumably prompts for an access token interactively - confirm
# against the huggingface_hub documentation.
interpreter_login()

In [None]:
# Pick the best available accelerator: CUDA first, then Apple-silicon MPS,
# and fall back to the CPU when neither is present.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the pretrained Decision Transformer checkpoint and move it to `device`.
model = DecisionTransformerModel.from_pretrained(
    "afonsosamarques/dt-halfcheetah-v1",
    use_auth_token=True,
).to(device)

In [None]:
@torch.no_grad()
def get_action(model, states, actions, rewards, returns_to_go, timesteps, device):
    """Predict the next action from the episode history collected so far.

    The histories are reshaped to a batch of one, cut down to the most recent
    ``model.config.context_size`` steps, left-padded with zeros up to that
    length and passed through ``model.forward``. States are standardised with
    ``model.config.state_mean`` and ``model.config.state_std`` before padding.

    The whole function runs under ``torch.no_grad()``: this is inference only,
    and without it the returned action would drag an autograd graph along that
    keeps growing when the caller stores the action in its history tensor.

    NOTE: this implementation does not condition on past rewards; ``rewards``
    is handed to the model exactly as received (no reshape, window or padding).

    Args:
        model: object exposing ``config`` (``state_dim``, ``act_dim``,
            ``context_size``, ``state_mean``, ``state_std``) and a ``forward``
            method returning a 3-tuple when called with ``return_dict=False``.
        states: tensor reshapeable to ``(1, T, state_dim)``.
        actions: tensor reshapeable to ``(1, T, act_dim)``.
        rewards: forwarded to ``model.forward`` unchanged.
        returns_to_go: tensor reshapeable to ``(1, T, 1)``.
        timesteps: tensor reshapeable to ``(1, T)``.
        device: device on which the normalisation constants, the padding and
            the attention mask are created.

    Returns:
        ``action_preds[0, -1]``, i.e. the last position of the first batch
        entry of the second element returned by ``model.forward``.
    """
    config = model.config
    context_size = config.context_size

    # reshape to model input format (batch of one)
    states = states.reshape(1, -1, config.state_dim)
    actions = actions.reshape(1, -1, config.act_dim)
    returns_to_go = returns_to_go.reshape(1, -1, 1)
    timesteps = timesteps.reshape(1, -1)

    # normalisation constants
    state_mean = torch.from_numpy(np.array(config.state_mean).astype(np.float32)).to(device=device)
    state_std = torch.from_numpy(np.array(config.state_std).astype(np.float32)).to(device=device)

    # retrieve window of observations based on context length
    states = states[:, -context_size:]
    actions = actions[:, -context_size:]
    returns_to_go = returns_to_go[:, -context_size:]
    timesteps = timesteps[:, -context_size:]

    # normalisation
    states = (states - state_mean) / state_std

    # left-pad all tokens to the context length; the mask is 0 on padding and
    # 1 on real steps
    seq_len = states.shape[1]
    padlen = context_size - seq_len
    attention_mask = (
        torch.cat([torch.zeros(padlen, device=device), torch.ones(seq_len, device=device)])
        .to(dtype=torch.long)
        .reshape(1, -1)
    )
    states = torch.cat([torch.zeros((1, padlen, config.state_dim), device=device), states], dim=1).float()
    actions = torch.cat([torch.zeros((1, padlen, config.act_dim), device=device), actions], dim=1).float()
    returns_to_go = torch.cat([torch.zeros((1, padlen, 1), device=device), returns_to_go], dim=1).float()
    timesteps = torch.cat([torch.zeros((1, padlen), dtype=torch.long, device=device), timesteps], dim=1)

    # forward pass; only the action predictions are needed
    _, action_preds, _ = model.forward(
        states=states,
        actions=actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return action_preds[0, -1]

In [None]:
SCALE = 1000.0   # normalisation for rewards/returns
TARGET_RETURN = 12000 / SCALE   # return the policy is conditioned on, in SCALE units

# Environment wrapped by renderlab; the output directory is named after the
# last path component of the model name.
env = gym.make("HalfCheetah-v4", render_mode="rgb_array")
env = rl.RenderFrame(env, "./" + model.config._name_or_path.split("/")[-1] + "-output-frames")
state, _ = env.reset()

# Episode history, grown by one entry per environment step.
episode_return, episode_length = 0, 0
target_return = torch.tensor(TARGET_RETURN, device=device, dtype=torch.float32).reshape(1, 1)
states = torch.from_numpy(state).reshape(1, model.config.state_dim).to(device=device, dtype=torch.float32)
actions = torch.zeros((0, model.config.act_dim), device=device, dtype=torch.float32)
rewards = torch.zeros(0, device=device, dtype=torch.float32)
timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

for t in range(model.config.max_ep_len):
    # Append zero placeholders for the current step; they are overwritten once
    # the action has been predicted and the reward observed.
    actions = torch.cat([actions, torch.zeros((1, model.config.act_dim), device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])

    # Detach before storing so the history tensor never becomes part of an
    # autograd graph that would otherwise grow with every step.
    action = get_action(
        model,
        states,
        actions,
        rewards,
        target_return,
        timesteps,
        device,
    ).detach()
    actions[-1] = action

    # Gymnasium's step returns (obs, reward, terminated, truncated, info); the
    # episode is over when either flag is set.
    state, reward, terminated, truncated, _ = env.step(action.cpu().numpy())

    cur_state = torch.from_numpy(state.astype(np.float32)).to(device=device).reshape(1, model.config.state_dim)
    states = torch.cat([states, cur_state], dim=0)
    rewards[-1] = reward

    # Lower the return-to-go by the (scaled) reward that was just collected.
    pred_return = target_return[0, -1] - (reward / SCALE)
    target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
    timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

    episode_return += reward
    episode_length += 1

    if terminated or truncated:
        break

print(f"Episode finished after {episode_length} steps with return {episode_return:.2f}")

In [None]:
# Call `play` on the `rl.RenderFrame` wrapper created in the rollout cell.
# NOTE(review): presumably replays the frames recorded during the episode as a
# video inside the notebook - confirm against the renderlab documentation.
env.play()