In [1]:
import sys
import os
# Walk up from the notebook's working directory to the `ardt` project folder and
# put it on sys.path so the project-local imports (`model.*`, `utils.*`,
# `access_tokens`) in the next cell resolve.
# Named `project_root` rather than `dir` so the builtin `dir()` stays usable.
project_root = os.path.abspath('')
while not project_root.endswith('ardt'):
    parent = os.path.dirname(project_root)
    if parent == project_root:
        # dirname() of the filesystem root returns the root itself, so without
        # this check the loop would spin forever when run outside the project.
        raise RuntimeError(
            "Could not locate the 'ardt' project directory above "
            f"{os.path.abspath('')!r}; start the notebook from inside the repository."
        )
    project_root = parent
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import os
import wandb

from datasets import load_dataset, load_from_disk
from huggingface_hub import login, list_models
from transformers import DecisionTransformerConfig, Trainer, TrainingArguments

from model.trainable_dt import TrainableDT
from model.ardt_vanilla import SingleAgentRobustDT
from model.ardt_full import TwoAgentRobustDT
from model.ardt_utils import DecisionTransformerGymDataCollator

from utils.config_utils import find_root_dir

from access_tokens import HF_WRITE_TOKEN, WANDB_TOKEN

In [None]:
# Authenticate against the Hugging Face Hub and Weights & Biases using the
# tokens from the local `access_tokens` module, direct W&B runs to the project
# via the WANDB_PROJECT environment variable, and resolve the repository root
# used for dataset and output paths below.
login(token=HF_WRITE_TOKEN)
wandb.login(key=WANDB_TOKEN)
os.environ.update({"WANDB_PROJECT": 'ARDT-Project'})
ARDT_DIR = find_root_dir()

## Configs

In [None]:
# Environments this notebook can target, keyed by selection index (0, 1, ...).
envs = dict(enumerate([
    "walker2d-expert-v2",
    "halfcheetah-expert-v2",
]))

# NOTE(review): the dataset loaded further down uses a hard-coded halfcheetah
# path, so changing this index alone does not switch the training data.
chosen_env = envs[1]

In [None]:
# Agent classes selectable by index; `chosen_agent` is instantiated with the
# model config when the model is set up.
agent = {
    0: SingleAgentRobustDT,
    1: TwoAgentRobustDT,
    2: TrainableDT
}

chosen_agent = agent[1]

# Name prefix per agent class. Every prefix ends with "-" because it is
# concatenated directly with the environment name when building model names
# (e.g. "ardt-full-" + "halfcheetah"); previously only the vanilla prefix had
# the separator. Any other class (TrainableDT) falls back to "dt-", mirroring
# the original catch-all branch.
model_name_prefix = {
    TwoAgentRobustDT: "ardt-full-",
    SingleAgentRobustDT: "ardt-vanilla-",
}.get(chosen_agent, "dt-")

## Loading and exploring the dataset

In [None]:
# Load the training trajectories from a dataset saved on local disk.
# Alternative (RARL-generated data):
#   dataset = load_from_disk(f"{ARDT_DIR}/datasets/rarl_halfcheetah_v1")
#   dataset_name = "rarl"
dataset = load_from_disk(
    f"{ARDT_DIR}/datasets/d4rl_expert_halfcheetah"
)
# Short tag appended to the model/run name.
dataset_name = "d4rl"

In [None]:
# Alternative: load the RARL dataset from the Hugging Face Hub instead of disk.
# NOTE(review): recent `datasets` releases deprecate `use_auth_token` in favour
# of `token`; `load_dataset` without `split=` presumably returns a DatasetDict,
# so `dataset[0]` below would need e.g. `["train"]` first, and `dataset_name`
# is not set here — confirm before re-enabling.
# dataset = load_dataset(f"afonsosamarques/rarl_halfcheetah_v1", use_auth_token=True)

In [None]:
# Sanity-check the first trajectory: available fields, episode length, and the
# per-step sizes/types that the collator consumes.
first_episode = dataset[0]  # index once instead of re-fetching the row per print
print("Dataset elements: ", first_episode.keys())
print("Number of steps: ", len(first_episode['observations']))
print("Size of state representation: ", len(first_episode['observations'][0]))
# The two action streams previously shared one label, making the outputs
# indistinguishable; label each with its field name.
print("Size of pr_actions representation: ", len(first_episode['pr_actions'][0]))
print("Size of adv_actions representation: ", len(first_episode['adv_actions'][0]))
print("Reward type: ", type(first_episode['rewards'][0]))
print("Done flag: ", type(first_episode['dones'][0]))
print("Rewards len: ", len(first_episode['rewards']))
print("Dones len: ", len(first_episode['dones']))

## Setting up the model

In [None]:
# --- Optimisation --------------------------------------------------------------
BATCH_SIZE = 32
N_EPOCHS = 300

# --- Sequence / return handling (passed to the collator and model config) -----
CONTEXT_SIZE = 20
MAX_RETURN = 15_000.0
RETURNS_SCALE = 1_000.0

# NOTE(review): warm-up length is derived from RETURNS_SCALE / BATCH_SIZE
# (int(31.25) * 25 == 775), which couples it to the return scale rather than to
# the dataset size; confirm this is intended. It only reaches the model config,
# while TrainingArguments uses warmup_ratio.
WARMUP_STEPS = int(RETURNS_SCALE / BATCH_SIZE) * 25

# --- Run bookkeeping -----------------------------------------------------------
EVAL_ITERS = 1                 # NOTE(review): appears unused in this notebook
WANDB_PROJECT = "ARDT-Project" # NOTE(review): unused; the env var is set from a literal
TRACEBACK = False              # NOTE(review): appears unused in this notebook
SUFFIX = '-test'               # appended to the "agents" output directory

# --- Loss weights forwarded to the model config as lambda1 / lambda2 ----------
LAMBDA1 = 0.05
LAMBDA2 = 0.8

In [None]:
# Build the collator from the dataset; it supplies the dimensions and state
# normalisation statistics that go into the model config. Then instantiate the
# selected agent class with that config (last line displays the model).
collator = DecisionTransformerGymDataCollator(
    dataset=dataset,
    context_size=CONTEXT_SIZE,
    returns_scale=RETURNS_SCALE,
)

config = DecisionTransformerConfig(
    state_dim=collator.state_dim,
    pr_act_dim=collator.pr_act_dim,
    adv_act_dim=collator.adv_act_dim,
    max_ep_len=collator.max_ep_len,
    context_size=collator.context_size,
    # Cast to built-in floats: the config is presumably JSON-serialised when the
    # model is saved / pushed to the Hub, and numpy scalars such as float32 are
    # not JSON-serialisable. NOTE(review): assumes state_mean/state_std are 1-D
    # sequences of numeric scalars — confirm against the collator.
    state_mean=[float(v) for v in collator.state_mean],
    state_std=[float(v) for v in collator.state_std],
    # NOTE(review): both `scale` (from the collator) and `returns_scale` are
    # stored; check whether they are meant to be the same value.
    scale=collator.scale,
    lambda1=LAMBDA1,
    lambda2=LAMBDA2,
    warmup_steps=WARMUP_STEPS,
    returns_scale=RETURNS_SCALE,
    max_return=MAX_RETURN,
)

model = chosen_agent(config)
model

## Training the model

In [None]:
# Optional automatic naming: derive the next version number from the models
# already published under the author's Hub account. Kept for reference.
# NOTE(review): model ids are sorted lexicographically, so "...-v10" sorts
# before "...-v9" and the "latest" version can be wrong — revisit before use.
# my_env_name = model_name_prefix + chosen_env.split("-")[0]
# models = sorted([m.modelId.split("/")[-1] for m in list_models(author="afonsosamarques")])
# models = [m for m in models if my_env_name in m]
# if len(models) > 0:
#     latest_version = [m.split("-")[-3 if "lambda" in m else -1][1:] for m in models][-1]
#     new_version = "v" + str(int(latest_version) + 1)
# else:
#     new_version = "v0"
# model_name = my_env_name + "-" + new_version + "-rarl"
# print(model_name)

# Fill in a run name before training. Together with the dataset tag it forms
# the output directory, the W&B run name and the Hub repo id.
model_name = ""
full_model_name = f"{model_name}-{dataset_name}"

In [None]:
# Refuse to train without a run name. The placeholder is the empty string, so
# an `is None` check never fired and training proceeded with an output dir and
# Hub repo id of just "-<dataset_name>".
if not model_name:
    raise ValueError("Please provide a model name")

# we use the same hyperparameters as in the authors' original implementation, but train for fewer iterations
training_args = TrainingArguments(
    output_dir=f"{ARDT_DIR}/agents{SUFFIX}/{full_model_name}",
    remove_unused_columns=False,
    num_train_epochs=N_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=1e-4,
    # FIXME: warm-up is disabled here, while WARMUP_STEPS is only passed through
    # the model config — confirm which warm-up schedule is intended.
    warmup_ratio=0.0,
    max_grad_norm=0.25,
    # NOTE(review): requires an Apple-Silicon MPS backend; check the installed
    # transformers version for deprecation of this flag.
    use_mps_device=True,
    push_to_hub=True,
    dataloader_num_workers=1,
    log_level="info",
    logging_steps=1,
    report_to="wandb",
    run_name=full_model_name,
    hub_model_id=full_model_name,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()
trainer.save_model()