In [None]:
import os

import wandb

from datasets import load_dataset, load_from_disk
from huggingface_hub import login, list_models
from transformers import DecisionTransformerConfig, Trainer, TrainingArguments

from trainable_dt import DecisionTransformerGymDataCollator, TrainableDT

# Silence every Python warning raised from this point on.
# Note that this also hides deprecation notices from the libraries imported above.
import warnings
warnings.filterwarnings('ignore')

from access_tokens import HF_WRITE_TOKEN, WANDB_TOKEN

In [None]:
# Authenticate with the Hugging Face Hub using the token imported from access_tokens.
login(token=HF_WRITE_TOKEN)
# Authenticate with Weights & Biases using its API key.
wandb.login(key=WANDB_TOKEN)
# Name of the W&B project, exposed through an environment variable.
# NOTE(review): presumably picked up by the Trainer's wandb reporting (report_to="wandb") — confirm.
os.environ["WANDB_PROJECT"] = 'ARDT-Project'

## Configs

In [None]:
# Available environment names, keyed by a small integer index.
# The selected name is passed as the second argument to load_dataset further down.
envs = dict(enumerate((
    "walker2d-expert-v2",
    "halfcheetah-expert-v2",
)))

# Environment used for the rest of the notebook.
chosen_env = envs[1]

## Loading and exploring the dataset: halfcheetah (expert)

Some notes:
* This is a multi-dimensional, continuous environment. States are represented by 17 continuous dimensions; actions are represented by 6 continuous dimensions.
* The state space includes the positions and velocities of multiple body parts of the robotic cheetah, which are continuous, unbounded, real-valued quantities.
* The action space consists of torques applied to the joints, which are real-valued and thus continuous. They are however limited to the interval [-1, 1]. 

For more details: https://www.gymlibrary.dev/environments/mujoco/half_cheetah/

In [None]:
# Load the replay dataset configuration selected by `chosen_env` and keep only its 'train' split.
dataset = load_dataset("edbeeching/decision_transformer_gym_replay", chosen_env)['train']
# Alternative source: a dataset previously saved to local disk.
# dataset = load_from_disk("./rarl_halfcheetah_v1")

In [None]:
# Peek at the first row to understand how the dataset is laid out.
# The row is fetched once instead of re-decoding dataset[0] for every print.
first_row = dataset[0]
print("Dataset elements: ", first_row.keys())
print("Number of steps: ", len(first_row['observations']))
print("Size of state representation: ", len(first_row['observations'][0]))

# Rename only when the source column exists and the target does not:
# rename_columns raises on a missing column, which made this cell fail for
# datasets that already expose 'actions' and on any second run of the cell.
# NOTE(review): 'pr_actions' / 'adv_actions' appear to be the column names of the
# locally saved dataset (see the commented-out load_from_disk call) — confirm.
if 'pr_actions' in dataset.column_names and 'actions' not in dataset.column_names:
    dataset = dataset.rename_columns({'pr_actions': 'actions'})
    first_row = dataset[0]  # re-read so the renamed column is visible
print("Size of action representation: ", len(first_row['actions'][0]))
# print("Size of action representation: ", len(dataset[0]['pr_actions'][0]))
# print("Size of action representation: ", len(dataset[0]['adv_actions'][0]))

print("Reward type: ", type(first_row['rewards'][0]))
print("Done flag: ", type(first_row['dones'][0]))
print("Rewards len: ", len(first_row['rewards']))
print("Dones len: ", len(first_row['dones']))

## Processing the dataset

In [None]:
# Context length handed to the data collator (context_size=...).
CONTEXT_SIZE = 20
# Scaling factor handed to the data collator (returns_scale=...).
RETURNS_SCALE = 1_000.0

While most datasets on the hub are ready to use out of the box, sometimes we wish to perform some additional processing or modification of the dataset. 

In this case we wish to match the authors' implementation (from the original paper); that is, we need to:
* Normalize each feature by subtracting the mean and dividing by the standard deviation.
* Pre-compute discounted returns for each trajectory.
* Scale the rewards and returns by a factor of 1000.
* Augment the dataset sampling distribution so it takes into account the length of the expert agent's trajectories.

In order to perform this dataset preprocessing, we will use a custom Data Collator.

In [None]:
# The data collator (DecisionTransformerGymDataCollator) is defined in trainable_dt.py.

## Create a trainable Decision Transformer (HF is not trainable by default)

In [None]:
# The trainable model class (TrainableDT) is defined in trainable_dt.py.

## Setting up the model

In [None]:
# putting together the model we just built
# Build the collator first: every model hyperparameter below is read from it.
collator = DecisionTransformerGymDataCollator(
    dataset,
    context_size=CONTEXT_SIZE,
    returns_scale=RETURNS_SCALE,
)

# Mirror the collator's attributes into the model configuration.
# The state statistics are converted to plain lists before being stored in the config.
config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    act_dim=collator.act_dim,
    max_ep_len=collator.max_ep_len,
    context_size=collator.context_size,
    state_mean=list(collator.state_mean),
    state_std=list(collator.state_std),
    scale=collator.scale,
)

# Instantiate the trainable Decision Transformer from that configuration.
model = TrainableDT(config)

## Training the model

In [None]:
# Must be set to an actual name before running the training cell, which refuses to run otherwise.
# The value is reused there as output_dir, run_name and hub_model_id.
model_name = None

In [None]:
# Fail fast with a clear, specific error: the name is reused as the output
# directory, the W&B run name and the Hub model id, so an unset or empty value
# must not reach TrainingArguments.
if not model_name:
    raise ValueError("Please provide a model name")

# We use the same hyperparameters as in the authors' original implementation, but train for fewer iterations.
training_args = TrainingArguments(
    output_dir=model_name,
    # Keep every dataset column: the custom collator consumes the raw rows.
    remove_unused_columns=False,
    num_train_epochs=300,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
    # NOTE(review): use_mps_device is documented as deprecated in recent transformers
    # releases (MPS is selected automatically when available) — confirm against the
    # installed version before removing; kept here to preserve current behavior.
    use_mps_device=True,
    report_to="wandb",
    push_to_hub=True,
    run_name=model_name,
    hub_model_id=model_name,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

# Run the training loop, then persist the resulting model.
trainer.train()
trainer.save_model()