In [None]:
import warnings

import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import login
from transformers import DecisionTransformerConfig

from access_tokens import WRITE_TOKEN
from ardt_full import TwoAgentRobustDT
from ardt_naive import SingleAgentRobustDT
from render_frame import RenderFrame

warnings.filterwarnings('ignore')

## Config

In [None]:
# Gymnasium environment ids this notebook can evaluate on, selectable by index.
envs_in_gym = dict(enumerate(("Walker2d-v4", "HalfCheetah-v4")))

# Default target return for each environment. It is divided by the model's
# returns_scale before being used to condition the policy.
default_tr_per_1000 = {
    "Walker2d-v4": 7_200,
    "HalfCheetah-v4": 12_000,
}

chosen_env = envs_in_gym[1]  # "HalfCheetah-v4"
env_target_per_1000 = default_tr_per_1000[chosen_env]

In [None]:
hf_model_to_use = "dt-halfcheetah-adrt-v1"

# Index -> agent class: 0 is the single-agent (naive) robust DT,
# 1 is the two-agent (full) robust DT.
agent = dict(enumerate((SingleAgentRobustDT, TwoAgentRobustDT)))

# Checkpoints named "ardt-naive*" use the single-agent class; everything else
# uses the two-agent class.
# FIXME the "dt" prefix is also routed to the single-agent class, only for
# backwards compatibility with older checkpoint names.
if hf_model_to_use.startswith(("ardt-naive", "dt")):
    chosen_agent = agent[0]
else:
    chosen_agent = agent[1]

## Loading model

In [None]:
# Checkpoint source.
# - Hub branch: logs in with WRITE_TOKEN and loads afonsosamarques/<hf_model_to_use>.
# - Local branch (default): loads from ./<hf_model_to_use>.
LOAD_FROM_HUB = False

if LOAD_FROM_HUB:
    login(token=WRITE_TOKEN)
    model_path = f"afonsosamarques/{hf_model_to_use}"
    config = DecisionTransformerConfig.from_pretrained(model_path, use_auth_token=True)
    # from_pretrained is a classmethod: call it on the class directly rather than
    # first building a randomly initialised model that is immediately discarded.
    model = chosen_agent.from_pretrained(model_path, use_auth_token=True)
    # Hub repo ids are "<user>/<name>"; keep only the last dash-separated part.
    model_name = model.config._name_or_path.split("-")[-1]
else:
    model_path = f"./{hf_model_to_use}"
    # A local directory needs no auth token.
    config = DecisionTransformerConfig.from_pretrained(model_path)
    model = chosen_agent.from_pretrained(model_path)
    model_name = model.config._name_or_path

# The original cell detected MPS and then unconditionally overrode it with CPU.
# Keep CPU as the default; set USE_MPS = True to opt in to the MPS backend
# when it is available.
USE_MPS = False
device = torch.device("mps" if USE_MPS and torch.backends.mps.is_available() else "cpu")
model.to(device)

## Testing model

In [None]:
def get_action(model, states, pr_actions, adv_actions, rewards, returns_to_go, timesteps, device):
    """Predict protagonist and adversary actions for the latest timestep.

    The input histories are processed in four steps:

    1. Reshape each history to batch size 1.
    2. Truncate to the last ``model.config.context_size`` steps.
    3. Normalise states with the config's ``state_mean`` / ``state_std``.
    4. Left-pad with zeros up to the context length.

    The result is passed to ``model.forward`` together with an attention mask
    that is 0 over the padding and 1 over the real steps.

    Args:
        model: agent exposing ``config`` (``state_dim``, ``pr_act_dim``,
            ``adv_act_dim``, ``context_size``, ``state_mean``, ``state_std``) and
            a ``forward(...)`` returning ``(pr_action_preds, adv_action_preds)``.
        states: observation history, reshapeable to ``(1, -1, state_dim)``.
        pr_actions: protagonist action history, reshapeable to ``(1, -1, pr_act_dim)``.
        adv_actions: adversary action history, reshapeable to ``(1, -1, adv_act_dim)``.
        rewards: forwarded to ``model.forward`` unchanged (not reshaped,
            truncated or padded).
        returns_to_go: return-to-go history, reshapeable to ``(1, -1, 1)``.
        timesteps: integer timestep history, reshapeable to ``(1, -1)``.
        device: device on which normalisation constants and padding are created.

    Returns:
        ``(pr_action, adv_action)``: the model's predictions at the last
        position of the context window.
    """
    # NOTE this implementation does not condition on past rewards.
    cfg = model.config
    context = cfg.context_size

    # Reshape to the batch-of-one input format and keep the last `context` steps.
    states = states.reshape(1, -1, cfg.state_dim)[:, -context:]
    pr_actions = pr_actions.reshape(1, -1, cfg.pr_act_dim)[:, -context:]
    adv_actions = adv_actions.reshape(1, -1, cfg.adv_act_dim)[:, -context:]
    returns_to_go = returns_to_go.reshape(1, -1, 1)[:, -context:]
    timesteps = timesteps.reshape(1, -1)[:, -context:]

    # Normalise states. torch.tensor accepts plain lists and numpy arrays alike,
    # so no numpy round-trip is needed; numpy was previously used without being
    # imported.
    state_mean = torch.tensor(cfg.state_mean, dtype=torch.float32, device=device)
    state_std = torch.tensor(cfg.state_std, dtype=torch.float32, device=device)
    states = (states - state_mean) / state_std

    n_real = states.shape[1]
    padlen = context - n_real  # >= 0 thanks to the truncation above

    def left_pad(x, dtype=None):
        # Prepend `padlen` zero steps along dim 1, matching x's trailing dims.
        zeros = torch.zeros((1, padlen, *x.shape[2:]), dtype=dtype, device=device)
        return torch.cat([zeros, x], dim=1)

    attention_mask = torch.cat(
        [
            torch.zeros(padlen, dtype=torch.long, device=device),
            torch.ones(n_real, dtype=torch.long, device=device),
        ]
    ).reshape(1, -1)
    states = left_pad(states).float()
    pr_actions = left_pad(pr_actions).float()
    adv_actions = left_pad(adv_actions).float()
    returns_to_go = left_pad(returns_to_go).float()
    timesteps = left_pad(timesteps, dtype=torch.long)

    pr_action_preds, adv_action_preds = model.forward(
        is_train=False,
        states=states,
        pr_actions=pr_actions,
        adv_actions=adv_actions,
        rewards=rewards,
        returns_to_go=returns_to_go,
        timesteps=timesteps,
        attention_mask=attention_mask,
        return_dict=False,
    )

    return pr_action_preds[0, -1], adv_action_preds[0, -1]

In [None]:
with torch.no_grad():
    env = gym.make(chosen_env, render_mode="rgb_array")
    # RenderFrame comes from the local render_frame module. It is given a
    # per-model output directory (presumably where frames are written — confirm).
    env = RenderFrame(env, "./" + model_name + "-output-frames")
    state, _ = env.reset()

    # FIXME compatibility: older configs lack returns_scale; fall back to 1000.0.
    returns_scale = model.config.returns_scale if "returns_scale" in model.config.to_dict() else 1000.0
    episode_return, episode_length = 0, 0

    # Rollout histories:
    # - target_return and timesteps grow along dim 1, as (1, T) rows.
    # - states, pr_actions, adv_actions and rewards grow along dim 0.
    target_return = torch.tensor(env_target_per_1000 / returns_scale, device=device, dtype=torch.float32).reshape(1, 1)
    states = torch.from_numpy(state).reshape(1, model.config.state_dim).to(device=device, dtype=torch.float32)
    pr_actions = torch.zeros((0, model.config.pr_act_dim), device=device, dtype=torch.float32)
    adv_actions = torch.zeros((0, model.config.adv_act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    for t in range(model.config.max_ep_len):
        # Append zero placeholders for this step; they are overwritten below.
        pr_actions = torch.cat([pr_actions, torch.zeros((1, model.config.pr_act_dim), device=device)], dim=0)
        adv_actions = torch.cat([adv_actions, torch.zeros((1, model.config.adv_act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        pr_action, adv_action = get_action(
            model,
            states,
            pr_actions,
            adv_actions,
            rewards,
            target_return,
            timesteps,
            device,
        )
        pr_actions[-1] = pr_action
        adv_actions[-1] = adv_action

        # Only the protagonist's action is sent to the environment. The
        # adversary's prediction is recorded in the history but not applied here.
        action = pr_action.detach().cpu().numpy()
        state, reward, terminated, truncated, _ = env.step(action)

        # torch.as_tensor converts the observation without needing numpy in scope.
        cur_state = torch.as_tensor(state, dtype=torch.float32, device=device).reshape(1, model.config.state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # Decrease the (scaled) return-to-go by the reward just collected.
        pred_return = target_return[0, -1] - (reward / returns_scale)
        target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

        episode_return += reward
        episode_length += 1

        # Gymnasium reports termination and time-limit truncation separately.
        # Stop on either, so we never keep stepping a finished episode.
        if terminated or truncated:
            break

# Display the episode summary as the cell's output.
episode_return, episode_length

In [None]:
# Presumably replays the frames recorded by the RenderFrame wrapper during the
# rollout. play() is defined in the local render_frame module and its behavior
# is not visible here — confirm.
env.play()