In [None]:
from datasets import load_dataset, load_from_disk
from huggingface_hub import login, list_models
from transformers import DecisionTransformerConfig, Trainer, TrainingArguments

from model.ardt_vanilla import SingleAgentRobustDT
from model.ardt_full import TwoAgentRobustDT
from model.ardt_utils import DecisionTransformerGymDataCollator

#
import warnings
warnings.filterwarnings('ignore')

from access_tokens import HF_WRITE_TOKEN

In [None]:
# Authenticate with the Hugging Face Hub using the write token imported from `access_tokens`.
# Later cells list models on the Hub and train with `push_to_hub=True`.
login(token=HF_WRITE_TOKEN)

## Configs

In [None]:
# Selectable environment identifiers, addressed by a small integer index.
envs = dict(enumerate((
    "walker2d-expert-v2",
    "halfcheetah-expert-v2",
)))

# Index 1 picks the halfcheetah entry.
chosen_env = envs[1]

In [None]:
# Selectable model classes, addressed by a small integer index.
agent = dict(enumerate((SingleAgentRobustDT, TwoAgentRobustDT)))

# Index 1 picks the two-agent model.
chosen_agent = agent[1]

# The Hub model name starts with a prefix that identifies which model class was trained.
if chosen_agent == TwoAgentRobustDT:
    model_name_prefix = "ardt-"
else:
    model_name_prefix = "ardt-vanilla-"

## Loading and exploring the dataset

In [None]:
# Load the dataset from a local directory.
# NOTE(review): this path is hard-coded and does not depend on `chosen_env`;
# keep the two in sync when switching environments.
dataset = load_from_disk("./datasets/rarl_halfcheetah_v1")

In [None]:
# # from hf
# dataset = load_dataset(f"afonsosamarques/rarl_halfcheetah_v1", use_auth_token=True)

In [None]:
# Inspect the first row of the dataset to sanity-check its fields and sizes.
# The row is fetched once instead of re-indexing the dataset on every line.
first_episode = dataset[0]

print("Dataset elements: ", first_episode.keys())
print("Number of steps: ", len(first_episode['observations']))
print("Size of state representation: ", len(first_episode['observations'][0]))
# The two action fields get distinct labels so their sizes can be told apart in the output.
print("Size of pr action representation: ", len(first_episode['pr_actions'][0]))
print("Size of adv action representation: ", len(first_episode['adv_actions'][0]))
print("Reward type: ", type(first_episode['rewards'][0]))
# This line prints the type of the flag, not its value, so the label says so.
print("Done flag type: ", type(first_episode['dones'][0]))
print("Rewards len: ", len(first_episode['rewards']))
print("Dones len: ", len(first_episode['dones']))

## Setting up the model

In [None]:
# Passed to the data collator as `returns_scale`.
RETURNS_SCALE = 1000.0
# Passed to the data collator as `context_size`.
CONTEXT_SIZE = 20
# Number of training epochs (passed to TrainingArguments as `num_train_epochs`).
N_EPOCHS = 300
# Passed to the model config as `warmup_epochs`.
WARMUP_EPOCHS = 50

In [None]:
# The collator is created first because the model config reads several of its attributes.
collator = DecisionTransformerGymDataCollator(
    dataset,
    context_size=CONTEXT_SIZE,
    returns_scale=RETURNS_SCALE,
)

# Collect the config fields in a dict so related fields can be grouped and annotated.
config_kwargs = {
    # Fields read from the collator.
    "state_dim": collator.state_dim,
    "pr_act_dim": collator.pr_act_dim,
    "adv_act_dim": collator.adv_act_dim,
    "max_ep_len": collator.max_ep_len,
    "context_size": collator.context_size,
    "state_mean": list(collator.state_mean),
    "state_std": list(collator.state_std),
    "scale": collator.scale,
    # Fields set directly in this notebook.
    "lambda1": 0.05,
    "lambda2": 10.0,
    "warmup_epochs": WARMUP_EPOCHS,
    "max_return": 1000,  # FIXME completely random, potentially not needed
}
config = DecisionTransformerConfig(**config_kwargs)

# Instantiate the selected model class; the bare name on the last line displays its repr.
model = chosen_agent(config)
model

## Training the model

In [None]:
# Build the Hub name for this run: "<prefix><env>-v<N>-rarl", where N is one more than the
# highest version already published under this prefix/environment (or 0 if there is none).
my_env_name = model_name_prefix + chosen_env.split("-")[0]
published = [m.modelId.split("/")[-1] for m in list_models(author="afonsosamarques")]
models = sorted(name for name in published if my_env_name in name)

# Read the version from the "v<digits>" segment wherever it appears in the name, rather than
# from a fixed position: names produced by this cell end in "-rarl", so the last segment is
# not the version. Segments such as "vanilla" are skipped because the text after "v" is not
# a number.
existing_versions = [
    int(part[1:])
    for name in models
    for part in name.split("-")
    if part.startswith("v") and part[1:].isdecimal()
]

# Compare versions numerically: in a lexicographic sort "v10" comes before "v9", which would
# cause an existing version number to be reused. If no matching name carries a parseable
# version, start from "v0".
if existing_versions:
    new_version = "v" + str(max(existing_versions) + 1)
else:
    new_version = "v0"
model_name = my_env_name + "-" + new_version + "-rarl"
print(model_name)

In [None]:
# Hyperparameters match the authors' original implementation, but training runs for fewer iterations.
optimiser_settings = dict(
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    max_grad_norm=0.25,
)

# Hub settings: the model is pushed under the name computed for this run.
hub_settings = dict(
    push_to_hub=True,
    hub_model_id=model_name,
)

training_args = TrainingArguments(
    output_dir=f"./agents/{model_name}",
    remove_unused_columns=False,
    num_train_epochs=N_EPOCHS,
    per_device_train_batch_size=64,
    use_mps_device=True,
    report_to="none",
    **optimiser_settings,
    **hub_settings,
)

# The collator built earlier is passed as the Trainer's data collator.
trainer = Trainer(model=model, args=training_args, train_dataset=dataset, data_collator=collator)

# Train, then save the model.
trainer.train()
trainer.save_model()
# trainer.push_to_hub()