In [None]:
import os
import sys
import warnings


def find_ancestor_dir(start, suffix):
    """Walk upwards from `start` until a path ending with `suffix` is found.

    Args:
        start: Path to begin the search from (it is checked first).
        suffix: String the returned path must end with.

    Returns:
        The first path, going from `start` towards the filesystem root, whose
        string ends with `suffix`, or None when no such ancestor exists.
    """
    current = start
    while not current.endswith(suffix):
        parent = os.path.dirname(current)
        # dirname() of the root (or of an empty string) returns its input, so
        # without this check the search would spin forever.
        if parent == current:
            return None
        current = parent
    return current


# Make the repository root importable so the project-local imports below resolve.
# A dedicated name is used instead of `dir` so the builtin is not shadowed.
repo_root = find_ancestor_dir(os.path.abspath(''), 'ardt')
if repo_root is None:
    warnings.warn(
        "No parent directory ending with 'ardt' was found; sys.path was left "
        "unchanged and project-local imports may fail."
    )
elif repo_root not in sys.path:
    sys.path.append(repo_root)

In [None]:
import numpy as np
import json

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from collections import defaultdict

from datasets import Dataset

from utils.helpers import find_root_dir

from access_tokens import HF_WRITE_TOKEN

In [None]:
# Repository root, resolved by the project helper.
ARDT_DIR = find_root_dir()

# Components of the output dataset name (combined when the dataset is saved).
MDP_VERSION, ENV, TYPE = "arrl_sgld_nrmdp", "halfcheetah", "train"
DS_VERSION = "v3"

# Folder the raw JSON file is read from (note the trailing slash).
DIR = f"{ARDT_DIR}/datasets-raw/arrl-sgld-raw-data/"

In [None]:
# Raw dump to convert; a named constant makes switching runs a one-line change.
RAW_FILENAME = "arrl_sgld_raw_dataset-HalfCheetah-v4-1408_2255.json"

# os.path.join avoids the doubled separator that DIR's trailing slash plus a
# hard-coded "/" would otherwise produce.
filepath = os.path.join(DIR, RAW_FILENAME)

with open(filepath, "rb") as f:
    data = json.load(f)

# Show the largest key of the loaded mapping, with keys parsed as integers
# (presumably episode indices -- confirm against the producer of the file).
max(int(k) for k in data)

In [None]:
EPISODE_LEN = 1000  # every stored trajectory has exactly this many steps


def pad_or_truncate_episode(steps, length=EPISODE_LEN):
    """Return a new list of exactly `length` steps built from `steps`.

    Longer episodes keep only their last `length` steps. Shorter episodes are
    left-padded with all-zero steps whose state/action shapes are copied from
    the episode's final step. The input list is never modified.

    Args:
        steps: List of step dicts with the keys 'state', 'reward',
            'pr_action', 'adv_action', 'done' and 'info'.
        length: Number of steps in the returned episode.

    Returns:
        A new list of `length` step dicts.

    Raises:
        ValueError: If `steps` is empty, since there is no step to infer the
            padding shapes from.
    """
    if not steps:
        raise ValueError("cannot pad an empty episode: no step to infer shapes from")

    if len(steps) >= length:
        # Slicing copies, so the caller's list stays untouched.
        return steps[len(steps) - length:]

    last_step = steps[-1]
    # Build the zero templates once; .tolist() below gives every padded step
    # its own independent list.
    zero_state = np.zeros_like(last_step['state'])
    zero_pr_action = np.zeros_like(last_step['pr_action'])
    zero_adv_action = np.zeros_like(last_step['adv_action'])

    # pad with all zeros to the left
    padding = [
        {
            "state": zero_state.tolist(),
            "reward": 0,
            "pr_action": zero_pr_action.tolist(),
            "adv_action": zero_adv_action.tolist(),
            "done": False,
            "info": last_step['info'],  # not a problem, is ignored
        }
        for _ in range(length - len(steps))
    ]
    return padding + steps


# One fixed-length trajectory per entry of the raw data.
transformed_data = [pad_or_truncate_episode(v) for v in data.values()]

len(transformed_data)

In [None]:
def _steps_to_columns(steps):
    """Turn a list of step dicts into a dict of per-field lists."""
    columns = defaultdict(list)
    for step in steps:
        # Raw step keys are renamed to the column names used by the dataset;
        # rewards and done flags are coerced to plain float / bool.
        renamed = (
            ('observations', step['state']),
            ('pr_actions', step['pr_action']),
            ('adv_actions', step['adv_action']),
            ('rewards', float(step['reward'])),
            ('dones', bool(step['done'])),
        )
        for column, value in renamed:
            columns[column].append(value)
    return columns


# One dict-of-lists per trajectory, in the same order as transformed_data.
trajectories = [_steps_to_columns(steps) for steps in transformed_data]

len(trajectories)

In [None]:
# Gather the per-trajectory records into one list per column, so that every
# column holds one entry per trajectory, then hand the mapping to Dataset.from_dict.
column_names = ('observations', 'pr_actions', 'adv_actions', 'rewards', 'dones')

d = defaultdict(list)
for trajectory in trajectories:
    for column in column_names:
        d[column].append(trajectory[column])

ds = Dataset.from_dict(d)
print(ds)

In [None]:
def compute_sum(ds):
    """Sum the 'rewards' entry of `ds` and return it under the key 'returns'.

    Note: despite its name, `ds` is indexed like a mapping with a 'rewards'
    sequence; when used with Dataset.map it is presumably a single example
    rather than the whole dataset.
    """
    episode_rewards = ds['rewards']
    total_return = sum(episode_rewards)
    return dict(returns=total_return)

# Derive a copy of the dataset with an extra 'returns' column (sum of rewards
# per trajectory); `ds` itself is left as is for saving.
ds_to_vis = ds.map(compute_sum)

# KDE of the returns, labelled so the figure can be read on its own.
returns_plot = sns.displot(ds_to_vis['returns'], kind="kde", bw_adjust=0.5)
returns_plot.set_axis_labels("Trajectory return (sum of rewards)", "Density")
returns_plot.set(title=f"Return distribution: {MDP_VERSION} / {ENV} / {TYPE}");

In [None]:
# Persist the dataset under <root>/datasets/, named after the MDP version,
# split type, environment and dataset version from the configuration cell.
output_path = f'{ARDT_DIR}/datasets/{MDP_VERSION}_{TYPE}_{ENV}_{DS_VERSION}'
ds.save_to_disk(output_path)