In [None]:
import sys
import os
# Make project-local packages importable by adding the project root to sys.path.
# The root is the nearest ancestor of the current working directory (inclusive)
# whose path ends with 'ardt'.
project_root = os.path.abspath('')
while not project_root.endswith('ardt'):
    parent_dir = os.path.dirname(project_root)
    # dirname() of the filesystem root returns the root itself; without this
    # check the loop would spin forever when the notebook is started outside
    # the project tree.
    if parent_dir == project_root:
        raise RuntimeError(
            f"Could not find a directory ending with 'ardt' above {os.path.abspath('')!r}; "
            "start the notebook from inside the project."
        )
    project_root = parent_dir
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd

from datasets import load_dataset, load_from_disk
from huggingface_hub import login

from utils.helpers import find_root_dir

from access_tokens import HF_WRITE_TOKEN

In [None]:
ARDT_DIR = find_root_dir()

datasets_dirname = "datasets-all"
datasets_dirpath = f"{ARDT_DIR}/{datasets_dirname}"

# Collect every sub-directory directly under `datasets_dirpath`.
# Names and full paths are kept as two parallel, alphabetically sorted lists;
# plain files in that folder are skipped.
dataset_dirnames = sorted(
    entry
    for entry in os.listdir(datasets_dirpath)
    if os.path.isdir(f"{datasets_dirpath}/{entry}")
)
dataset_dirpaths = sorted(
    f"{datasets_dirpath}/{entry}" for entry in dataset_dirnames
)


In [None]:
def compute_sum(ds):
    """Sum the values stored under ``ds['rewards']``.

    Args:
        ds: Mapping with a ``'rewards'`` entry holding an iterable of numbers
            (presumably one episode record handed over by ``Dataset.map`` --
            confirm against the dataset schema).

    Returns:
        dict: ``{'returns': <sum of ds['rewards']>}``.
    """
    total_return = sum(ds['rewards'])
    return {'returns': total_return}

def get_name(dataset_dirname):
    """Build a human-readable label of the form "<Env> <Algo>[ Level]".

    The directory name is split on underscores:

    * The algorithm label comes from a known prefix (``arrl_sgld``, ``arrl``,
      ``ppo``); otherwise the first token is capitalized.
    * The environment label is the second-to-last token, capitalized, with
      every remaining ``c`` and ``d`` upper-cased
      (``halfcheetah`` -> ``HalfCheetah``, ``walker2d`` -> ``Walker2D``).
    * ``" Level"`` is appended when the last token is exactly ``"level"``.

    Args:
        dataset_dirname: Underscore-separated dataset directory name.

    Returns:
        str: The display label.

    Raises:
        IndexError: If the name contains no underscore (there is no
            second-to-last token).
    """
    tokens = dataset_dirname.split("_")

    # Order matters: "arrl_sgld" must be tested before its prefix "arrl".
    known_prefixes = (
        ("arrl_sgld", "AR-DDPG-SGLD"),
        ("arrl", "AR-DDPG"),
        ("ppo", "PPO"),
    )
    for prefix, label in known_prefixes:
        if dataset_dirname.startswith(prefix):
            algo = label
            break
    else:
        algo = tokens[0].capitalize()

    env_name = tokens[-2].capitalize().replace("c", "C").replace("d", "D")

    suffix = ""
    if tokens[-1] == "level":
        suffix = " " + tokens[-1].capitalize()

    return f"{env_name} {algo}{suffix}"

# Build one long-format frame with a row per dataset record:
#   returns     -- the value produced by compute_sum for that record
#   Environment -- the display label from get_name(), used as the facet key
# The outer loop over environments groups datasets of the same environment
# next to each other in the resulting frame.
per_dataset_frames = []
for env in ['halfcheetah', 'hopper', 'walker2d']:
    for dataset_dirpath, dataset_dirname in zip(dataset_dirpaths, dataset_dirnames):
        if env not in dataset_dirname:
            continue
        dataset = load_from_disk(dataset_dirpath)
        # Map the dataset that was just loaded, so the cell runs on a fresh
        # kernel without relying on a leftover `ds_to_vis` variable.
        # NOTE(review): assumes load_from_disk returns a single Dataset with a
        # 'rewards' column (not a DatasetDict) -- confirm.
        ds_to_vis = dataset.map(compute_sum)
        returns = ds_to_vis['returns']
        per_dataset_frames.append(pd.DataFrame({
            'returns': returns,
            'Environment': [get_name(dataset_dirname)] * len(returns),
        }))

if per_dataset_frames:
    # Concatenate once instead of growing the frame inside the loop;
    # ignore_index gives every row a unique label.
    df = pd.concat(per_dataset_frames, ignore_index=True)
else:
    # No matching dataset directories: keep the expected (empty) schema.
    df = pd.DataFrame(columns=["returns", "Environment"])

num_cols = 4  # Number of columns in the grid

# One histogram of returns per dataset label, wrapped into rows of `num_cols`.
g = sns.FacetGrid(df, col="Environment", col_wrap=num_cols, sharex=False, sharey=False)
g.map_dataframe(sns.histplot, x="returns", bins=100, color='blue')
g.set_axis_labels("Return")
g.set_titles(col_template="{col_name} Dataset")

# Give every subplot within the same visual row a common y-range of
# [0, tallest upper limit in that row]. Rows are consecutive slices of
# `num_cols` axes taken from the grid's flat axes array.
for row_start in range(0, len(g.axes), num_cols):
    row_axes = g.axes[row_start:row_start + num_cols]
    y_max = max([0] + [ax.get_ylim()[1] for ax in row_axes])
    for ax in row_axes:
        ax.set_ylim(0, y_max)

plt.show()