In [1]:
import random
from dataclasses import dataclass

import numpy as np
import pickle
import torch

from datasets import load_dataset, load_from_disk
from huggingface_hub import login, list_models
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

from ardt import SingleAgentRobustDT, DecisionTransformerGymDataCollator
from access_tokens import WRITE_TOKEN

import warnings
# Install an "ignore" filter for every warning category for the rest of the
# kernel session.
# NOTE(review): this is a blanket filter. Consider narrowing it by category or
# module so that genuinely useful warnings (e.g. from NumPy reductions) are not
# hidden.
warnings.filterwarnings('ignore')

In [2]:
# login(token=WRITE_TOKEN)

## Configs

In [3]:
# Available dataset configurations, keyed by index.
# NOTE: `chosen_env` is only consumed by commented-out cells (the Hub download
# and the model-name lookup). The dataset actually trained on is the
# hard-coded on-disk toy dataset.
envs = {
    0: "walker2d-expert-v2",
    1: "halfcheetah-expert-v2",
}

chosen_env = envs[1]

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so that
# Restart & Run All is reproducible. The model is instantiated before the
# Trainer is created, so its initial weights presumably depend on the global
# torch RNG seeded here.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Loading and exploring the dataset: halfcheetah (expert), toy adversarial variant

Some notes:
* This is a multi-dimensional, continuous environment. States are represented by 17 continuous dimensions; actions are represented by 6 continuous dimensions.
* The state space includes the positions and velocities of multiple body parts of the robotic cheetah, which are continuous, unbounded, real-valued quantities.
* The action space consists of torques applied to the joints, which are real-valued and thus continuous. They are, however, limited to the interval [-1, 1].

For more details: https://www.gymlibrary.dev/environments/mujoco/half_cheetah/

In [4]:
# dataset = load_dataset("edbeeching/decision_transformer_gym_replay", chosen_env)['train']

In [5]:
# Disabled: builds the perturbed dataset saved to ./datasets/toy_dataset
# (the path loaded in the next cell) from the original 'actions' column.
# For every action vector a (per dimension):
#   adv = a * 0.1 * (random sign in {-1, +1}) * U[0, 1)   -> |adv| < 10% of |a|
#   pr  = a + adv
# The 'actions' column is then replaced by 'pr_actions' and 'adv_actions'.
# NOTE(review): the result depends on the global NumPy RNG state; seed it
# before regenerating if the dataset must be reproducible.
# pr_actions = dataset['actions']
# new_pr_actions = []
# adv_actions =[]
# for tr in pr_actions:
#     pr_l = []
#     adv_l = []
#     for a in tr:
#         adv = np.array(a) * 0.1 * np.random.choice([-1, 1], size=len(a)) * np.random.rand(len(a))
#         pr = np.array(a) + adv
#         pr_l.append(list(pr))
#         adv_l.append(list(adv))
#     new_pr_actions.append(pr_l)
#     adv_actions.append(adv_l)

# dataset = dataset.add_column('pr_actions', new_pr_actions)
# dataset = dataset.add_column('adv_actions', adv_actions)
# dataset = dataset.remove_columns(['actions'])
# dataset.save_to_disk('./datasets/toy_dataset')

In [6]:
# Toy adversarial dataset: original transitions with the action split into a
# perturbed protagonist action ('pr_actions') and the perturbation itself
# ('adv_actions').
TOY_DATASET_PATH = "./datasets/toy_dataset"
dataset = load_from_disk(TOY_DATASET_PATH)

# Fail fast if the on-disk copy is stale (e.g. still has a single 'actions'
# column) instead of failing later inside the collator or the exploration cell.
_required_columns = {'observations', 'rewards', 'dones', 'pr_actions', 'adv_actions'}
_missing_columns = _required_columns - set(dataset.column_names)
assert not _missing_columns, (
    f"dataset at {TOY_DATASET_PATH} is missing columns: {sorted(_missing_columns)}"
)

In [7]:
# Disabled: reads every pickled record from the RARL rollout file.
# NOTE(review): the earlier draft called `print(pickle.load(f))` outside the
# `try`. That consumed every other record without appending it, and, depending
# on the record count, let the final EOFError escape. The loop below loads
# each record exactly once.
# Only unpickle files you produced yourself: pickle.load can execute arbitrary code.
# dataset_entries = []
# with open("./datasets/halfcheetah-rarl-v2.dat", "rb") as f:
#     while True:
#         try:
#             dataset_entries.append(pickle.load(f))
#         except EOFError:
#             break

In [8]:
# Summarise the first episode to confirm the schema used below.
# dataset[0] is fetched once and reused rather than re-indexed on every line.
first_episode = dataset[0]
print("Dataset elements: ", first_episode.keys())
print("Number of steps: ", len(first_episode['observations']))
print("Size of state representation: ", len(first_episode['observations'][0]))
# Separate labels so the protagonist and adversary action sizes can be told apart.
print("Size of protagonist (pr) action representation: ", len(first_episode['pr_actions'][0]))
print("Size of adversary (adv) action representation: ", len(first_episode['adv_actions'][0]))
print("Reward type: ", type(first_episode['rewards'][0]))
print("Done flag: ", type(first_episode['dones'][0]))

Dataset elements:  dict_keys(['observations', 'rewards', 'dones', 'pr_actions', 'adv_actions'])
Number of steps:  1000
Size of state representation:  17
Size of action representation:  6
Size of action representation:  6
Reward type:  <class 'float'>
Done flag:  <class 'bool'>


## Setting up the model

In [9]:
# Passed to the collator as `returns_scale`; presumably the divisor applied to
# returns-to-go. Confirm in ardt.
RETURNS_SCALE = 1e3
# Passed to the collator (and from there to the model config) as `context_size`.
CONTEXT_SIZE = 20

In [10]:
# Loss-weighting hyperparameters consumed by SingleAgentRobustDT via its config.
# NOTE(review): their exact role is defined in the `ardt` package; confirm there.
LAMBDA1 = 0.3
LAMBDA2 = 10.0

# The collator exposes the dimensions and normalisation statistics (presumably
# computed from `dataset`). The config copies them so the model sees the same
# preprocessing as the training batches.
collator = DecisionTransformerGymDataCollator(dataset, context_size=CONTEXT_SIZE, returns_scale=RETURNS_SCALE)
config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    pr_act_dim=collator.pr_act_dim,
    adv_act_dim=collator.adv_act_dim,
    max_ep_len=collator.max_ep_len,
    context_size=collator.context_size,
    # Convert to plain Python floats: `list(ndarray)` keeps NumPy scalars, and
    # float32 scalars are not JSON-serialisable when the config is saved.
    # np.asarray also accepts a plain list unchanged.
    state_mean=np.asarray(collator.state_mean).tolist(),
    state_std=np.asarray(collator.state_std).tolist(),
    scale=collator.scale,
    lambda1=LAMBDA1,
    lambda2=LAMBDA2,
)
model = SingleAgentRobustDT(config)

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


## Training the model

In [11]:
# Versioned Hub name. The automatic lookup below is disabled. When re-enabled,
# it picks the next free "vN" suffix among this author's models for
# `chosen_env`.
# NOTE(review): compare versions numerically. A lexicographic sort ranks "v10"
# before "v2", so taking the last sorted entry breaks after v9.
# my_env_name = "dt-" + chosen_env.split("-")[0]
# models = [m.modelId.split("/")[-1] for m in list_models(author="afonsosamarques")]
# versions = [int(m.split("-")[-1][1:]) for m in models if my_env_name in m]
# new_version = "v" + str(max(versions) + 1) if versions else "v0"
# model_name = my_env_name + "-" + new_version
# print(model_name)

# Hard-coded and independent of `chosen_env`; bump the suffix manually per run.
model_name = "dt-halfcheetah-v3"

In [12]:
# We use the same hyperparameters as in the authors' original implementation,
# but train for fewer iterations.
training_args = TrainingArguments(
    output_dir=model_name,
    # Keep every dataset column. The custom collator presumably needs columns
    # that the model's forward() does not name.
    remove_unused_columns=False,
    num_train_epochs=250,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    # Request Apple's MPS backend only when it is actually available.
    # Hard-coding True makes TrainingArguments raise on CPU/CUDA-only machines.
    use_mps_device=torch.backends.mps.is_available(),
    report_to="none",
    push_to_hub=False,
    hub_model_id=model_name,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()
# Saves to training_args.output_dir (i.e. `model_name`).
trainer.save_model()
# trainer.push_to_hub()

  0%|          | 0/4000 [00:00<?, ?it/s]

{'loss': 0.4496, 'learning_rate': 9.722222222222223e-05, 'epoch': 31.25}
{'loss': -0.1207, 'learning_rate': 8.333333333333334e-05, 'epoch': 62.5}
{'loss': -0.189, 'learning_rate': 6.944444444444444e-05, 'epoch': 93.75}
{'loss': -0.2181, 'learning_rate': 5.555555555555556e-05, 'epoch': 125.0}
{'loss': -0.2353, 'learning_rate': 4.166666666666667e-05, 'epoch': 156.25}
{'loss': -0.2516, 'learning_rate': 2.777777777777778e-05, 'epoch': 187.5}
