In [None]:
import gymnasium as gym
import numpy as np
import torch

from huggingface_hub import login
from transformers import DecisionTransformerConfig

from model.ardt_vanilla import SingleAgentRobustDT
from model.ardt_full import TwoAgentRobustDT
from model.trainable_dt import TrainableDT
from utils.render_frame import RenderFrame

from utils.helpers import set_seed_everywhere

#
import warnings
warnings.filterwarnings('ignore')

from access_tokens import HF_WRITE_TOKEN

## Config

In [None]:
# Gymnasium MuJoCo environments available for evaluation, keyed by a small integer index.
envs_in_gym = dict(enumerate(["Walker2d-v4", "HalfCheetah-v4"]))

# Target return used to condition the model, per environment (divided by the model's
# `returns_scale` before use). Presumably expressed per 1000-step episode — confirm.
_target_returns = (
    ("Walker2d-v4", 5000),
    ("HalfCheetah-v4", 12000),  # NOTE: the DT baseline targets 6000 for this env, not 12000
)
default_tr_per_1000 = dict(_target_returns)

# Select the environment to evaluate on and look up its conditioning target.
chosen_env_idx = 1
chosen_env = envs_in_gym[chosen_env_idx]
env_target_per_1000 = default_tr_per_1000[chosen_env]

In [None]:
# Name of the checkpoint directory to evaluate; its prefix selects the agent class below.
hf_model_to_use = "ardt_full_all_plus_adv_d4rl"
# hf_model_to_use = "ardt_vanilla_all_plus_adv_d4rl"
# hf_model_to_use = "dt-halfcheetah-v2"

agent = {
    0: SingleAgentRobustDT,
    1: TwoAgentRobustDT,
    2: TrainableDT
}

# Map the checkpoint-name prefix to the model class. `is_adv` tells the rollout code whether the
# model works with separate protagonist/adversary action streams (ARDT variants) or a single
# action stream (plain DT).
if hf_model_to_use.startswith("ardt_vanilla"):
    chosen_agent, is_adv = agent[0], True
elif hf_model_to_use.startswith("ardt_full"):
    chosen_agent, is_adv = agent[1], True
elif hf_model_to_use.startswith("dt"):
    chosen_agent, is_adv = agent[2], False
else:
    # Fail with an explicit message rather than an opaque dict-lookup KeyError.
    raise ValueError(f"Model not available: {hf_model_to_use!r}")

print(chosen_agent)
print(is_adv)

## Loading model

In [None]:
# # from HF
# login(token=HF_WRITE_TOKEN)
# config = DecisionTransformerConfig.from_pretrained(f"afonsosamarques/{hf_model_to_use}", use_auth_token=True)
# model = chosen_agent.from_pretrained(f"afonsosamarques/{hf_model_to_use}", config=config, use_auth_token=True)
# model_name = model.config._name_or_path.split("-")[-1]

# from local
model_dir = f"./agents-pipeline/{hf_model_to_use}"
# No auth token is needed for a local directory.
config = DecisionTransformerConfig.from_pretrained(model_dir)
# `from_pretrained` is a classmethod: calling it on the class avoids first building a throwaway,
# randomly initialised model instance.
model = chosen_agent.from_pretrained(model_dir, config=config)
model_name = model.config._name_or_path.split("/")[-1]

# Prefer CUDA, then Apple MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)

## Testing model

In [None]:
def sample_env_params(env, mass_std=1.0, gravity_std=1.0, friction_std=0.1, min_mass=1e-3):
    """Perturb a MuJoCo environment's physics parameters in place with Gaussian noise.

    Body masses, the gravity vector and geom friction coefficients are each replaced by a draw
    from ``Normal(current_value, std)``. Draws use torch's global RNG, so they are reproducible
    once torch has been seeded.

    Unclipped draws can be non-physical: with the default stds, light bodies and small friction
    coefficients go negative with non-trivial probability. Masses are therefore clipped to at
    least ``min_mass`` and frictions to be non-negative. Gravity is not clipped.

    Args:
        env: Gymnasium environment backed by MuJoCo. Wrappers are allowed; the model is reached
            via ``env.unwrapped.model``.
        mass_std: Std of the noise added to ``body_mass``. ``0`` leaves masses untouched.
        gravity_std: Std of the noise added to ``opt.gravity``. ``0`` leaves gravity untouched.
        friction_std: Std of the noise added to ``geom_friction``. ``0`` leaves friction untouched.
        min_mass: Lower bound applied to every perturbed body mass.

    Returns:
        The same ``env`` object; its model arrays are overwritten in place.

    Raises:
        ValueError: if any std is negative.
    """
    for name, std in (("mass_std", mass_std), ("gravity_std", gravity_std), ("friction_std", friction_std)):
        if std < 0:
            raise ValueError(f"{name} must be non-negative, got {std}")

    # Go through `unwrapped`: forwarding attribute access through gymnasium wrappers is deprecated.
    model = env.unwrapped.model

    def _noisy(values, std):
        """Return a numpy sample from Normal(values, std), drawn with torch's global RNG."""
        mean = torch.tensor(np.asarray(values))
        return torch.distributions.Normal(mean, torch.ones_like(mean) * std).sample().numpy()

    # Slice-assign so the model's existing buffers are updated in place (shape/dtype preserved).
    if mass_std > 0:
        model.body_mass[:] = np.maximum(_noisy(model.body_mass, mass_std), min_mass)
    if gravity_std > 0:
        model.opt.gravity[:] = _noisy(model.opt.gravity, gravity_std)
    if friction_std > 0:
        model.geom_friction[:] = np.maximum(_noisy(model.geom_friction, friction_std), 0.0)
    # NOTE(review): body_inertia is not rescaled along with body_mass — confirm this is intended.

    return env

In [None]:
# Evaluate the model on `n_eval_episodes` seeded episodes. Each episode uses a freshly created
# environment whose physics parameters are randomly perturbed by `sample_env_params`.
n_eval_episodes = 10

# Configs without `returns_scale` fall back to 1000.0 (loop-invariant, so computed once).
returns_scale = model.config.returns_scale if "returns_scale" in model.config.to_dict().keys() else 1000.0  # NOTE compatibility

rets = []
for i in range(n_eval_episodes):
    seed = (i+1)*2
    env = gym.make(chosen_env, render_mode="rgb_array")
    try:
        with torch.no_grad():
            set_seed_everywhere(seed, env)

            env = sample_env_params(env)
            print("Checking that sampling worked: ")
            print(env.unwrapped.model.opt.gravity, "\n")

            # env = RenderFrame(env, "./env-sims/" + model_name)
            # Gymnasium seeds an env's internal RNG (and hence the initial state) via
            # `reset(seed=...)`; passing the seed here keeps the initial state reproducible
            # whether or not `set_seed_everywhere` already did so.
            state, _ = env.reset(seed=seed)

            episode_return, episode_length = 0, 0
            # Return-to-go conditioning, kept in the model's scaled units.
            target_return = torch.tensor(env_target_per_1000/returns_scale, device=device, dtype=torch.float32).reshape(1, 1)
            states = torch.from_numpy(state).reshape(1, model.config.state_dim).to(device=device, dtype=torch.float32)
            if is_adv:
                pr_actions = torch.zeros((0, model.config.pr_act_dim), device=device, dtype=torch.float32)
                adv_actions = torch.zeros((0, model.config.adv_act_dim), device=device, dtype=torch.float32)
            else:
                actions = torch.zeros((0, model.config.act_dim), device=device, dtype=torch.float32)
            rewards = torch.zeros(0, device=device, dtype=torch.float32)
            timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

            for t in range(model.config.max_ep_len):
                # Append zero placeholders for this step's action(s) and reward; they are
                # overwritten below once the prediction and the env reward are known.
                if is_adv:
                    pr_actions = torch.cat([pr_actions, torch.zeros((1, model.config.pr_act_dim), device=device)], dim=0)
                    adv_actions = torch.cat([adv_actions, torch.zeros((1, model.config.adv_act_dim), device=device)], dim=0)
                else:
                    actions = torch.cat([actions, torch.zeros((1, model.config.act_dim), device=device)], dim=0)

                rewards = torch.cat([rewards, torch.zeros(1, device=device)])

                if is_adv:
                    pr_action, adv_action = model.get_action(
                        states,
                        pr_actions,
                        adv_actions,
                        rewards,
                        target_return,
                        timesteps,
                        device,
                    )
                    pr_actions[-1] = pr_action
                    adv_actions[-1] = adv_action
                    # Only the protagonist's action is sent to the env; the adversary's
                    # prediction is kept in the context history.
                    action = pr_action.detach().cpu().numpy()
                else:
                    action = model.get_action(
                        states,
                        actions,
                        rewards,
                        target_return,
                        timesteps,
                        device,
                    )
                    actions[-1] = action
                    action = action.detach().cpu().numpy()

                # Gymnasium returns (obs, reward, terminated, truncated, info). Truncation (e.g. the
                # env's TimeLimit) must end the episode too; otherwise we keep stepping past the
                # horizon and inflate the return.
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                cur_state = torch.from_numpy(state.astype(np.float32)).to(device=device).reshape(1, model.config.state_dim)
                states = torch.cat([states, cur_state], dim=0)
                rewards[-1] = reward

                pred_return = target_return[0, -1] - (reward / returns_scale)
                target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
                timesteps = torch.cat([timesteps, torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)

                episode_return += reward
                episode_length += 1

                if done:
                    break

            rets.append(episode_return)
    finally:
        # NOTE(review): if RenderFrame is re-enabled and a later `env.play()` needs the env
        # open, move this close after playback.
        env.close()

# Summary over all evaluation episodes (printed once, after the loop).
print("Mean episode return: ", np.mean(rets))
print("Std episode return: ", np.std(rets))
print("Median episode return: ", np.median(rets))
print("Max episode return: ", np.max(rets))
print("Min episode return: ", np.min(rets))
    

In [None]:
# env.play()