In [None]:
import sys
import os
def locate_ardt_root(start):
    """Return the closest ancestor of ``start`` (inclusive) whose path ends with 'ardt'.

    Args:
        start: Directory path from which the upward search begins.

    Returns:
        The first path in the chain ``start``, ``dirname(start)``, ... that
        ends with ``'ardt'``, or ``None`` if the top of the path is reached
        without a match.
    """
    current = start
    while not current.endswith('ardt'):
        parent = os.path.dirname(current)
        if parent == current:
            # dirname() returns its argument unchanged at the filesystem root;
            # without this check the search would loop forever whenever the
            # notebook is started outside an 'ardt' directory tree.
            return None
        current = parent
    return current


# Best-effort: make the project importable when the notebook runs from a
# sub-directory. The result is not stored under the name `dir`, so the builtin
# stays usable in later cells.
repo_root = locate_ardt_root(os.path.abspath(''))
if repo_root is None:
    print(
        "Warning: no parent directory ending in 'ardt' was found; "
        "sys.path was left unchanged, project imports may fail.",
        file=sys.stderr,
    )
elif repo_root not in sys.path:
    sys.path.append(repo_root)

In [None]:
import numpy as np
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from collections import defaultdict

from datasets import Dataset

from utils.helpers import find_root_dir

from access_tokens import HF_WRITE_TOKEN

In [None]:
# Only files whose name contains this environment id are processed.
ENV_NAME = "Walker2d-v4"

# Project root as reported by the helper (presumably the 'ardt' checkout -- confirm).
ARDT_DIR = find_root_dir()

# Folder holding the raw JSON rollouts. The trailing slash matters: file paths
# are formed by plain string concatenation (DIR + filename).
DIR = ARDT_DIR + "/datasets-raw/arrl-sgld-raw-data/"

In [None]:
def compute_sum(ds):
    """Add up the rewards of one record.

    Despite the parameter name, this is called once per example by
    ``Dataset.map``, so ``ds`` is a single row rather than a whole dataset.

    Args:
        ds: Mapping with a ``'rewards'`` entry holding an iterable of numbers.

    Returns:
        A dict with the single key ``'returns'`` set to the sum of the rewards
        (``0`` when there are none).
    """
    rewards = ds['rewards']
    episode_return = sum(rewards)
    return dict(returns=episode_return)

# Thresholds used while slicing and filtering the raw episodes.
MAX_TRAJ_STEPS = 1000   # longer episodes are cut into chunks of at most this many steps
MIN_EPISODE_STEPS = 20  # episodes with this many steps or fewer are discarded
MIN_RETURN = 300        # only trajectories whose return is strictly greater are visualised

SEPARATOR = "================================================================="

# Plain files directly inside DIR (no recursion); empty list if DIR does not exist.
filenames = next(os.walk(DIR), (None, None, []))[2]
for filename in filenames:
    if ENV_NAME not in filename:
        continue

    filepath = DIR + filename
    with open(filepath, "rb") as f:
        data = json.load(f)

    # Reports the largest episode key. `default=0` keeps a file containing an
    # empty JSON object from raising ValueError here.
    print(f"File {filename} has {max((int(k) for k in data.keys()), default=0)} episodes.")

    # Split over-long episodes into fixed-size chunks and drop very short ones.
    transformed_data = []
    for episode in data.values():
        if len(episode) > MAX_TRAJ_STEPS:
            # NOTE(review): the last chunk of a long episode can be shorter
            # than MIN_EPISODE_STEPS and is still kept -- confirm this is intended.
            transformed_data.extend(
                episode[i:i + MAX_TRAJ_STEPS]
                for i in range(0, len(episode), MAX_TRAJ_STEPS)
            )
        elif len(episode) > MIN_EPISODE_STEPS:
            transformed_data.append(episode)
    total_number_of_steps = sum(len(t) for t in transformed_data)

    print("Confirming total episodes: ", len(transformed_data))
    print("Total steps: ", total_number_of_steps)

    if not transformed_data:
        # Nothing survived the length filter. Building a column-less Dataset
        # would only fail further down, so report and move on to the next file.
        print(f"No usable trajectories in {filename}; skipping.")
        print(SEPARATOR)
        continue

    # Column-oriented layout for Dataset.from_dict: one list entry per trajectory,
    # each entry holding that trajectory's per-step values.
    d = {
        'observations': [],
        'pr_actions': [],
        'adv_actions': [],
        'rewards': [],
        'dones': [],
    }
    for t in transformed_data:
        d['observations'].append([p['state'] for p in t])
        d['pr_actions'].append([p['pr_action'] for p in t])
        d['adv_actions'].append([p['adv_action'] for p in t])
        d['rewards'].append([float(p['reward']) for p in t])
        d['dones'].append([bool(p['done']) for p in t])

    ds = Dataset.from_dict(d)
    print("Dataset info:\n", ds)

    # Attach a per-trajectory 'returns' column, then keep only the
    # higher-return trajectories for the plot. `ds` itself stays unfiltered.
    ds_to_vis = ds.map(compute_sum)
    ds_to_vis = ds_to_vis.filter(lambda x: x['returns'] > MIN_RETURN)

    print("New total episodes: ", len(ds_to_vis))
    total_number_of_steps = sum(len(r) for r in ds_to_vis['rewards'])
    # NOTE(review): the value labelled "expected" is the step count minus
    # 1000*1000, i.e. an offset from one million steps rather than an expected
    # total -- presumably a target dataset size; confirm the label.
    print("New total steps: ", total_number_of_steps, " vs expected ", total_number_of_steps - 1000*1000)

    print("Plotting returns distribution...")
    sns.displot(ds_to_vis['returns'], kind="kde", bw_adjust=0.5)
    plt.title(f"Returns distribution for {filename}")
    plt.show()

    print(SEPARATOR)
        

In [None]:
# MDP_VERSION = "arrl_sgld_nrmdp"
# ENV = "hopper"
# TYPE = "train"
# DS_VERSION = "v1"

# ds.save_to_disk(f'{ARDT_DIR}/datasets/{MDP_VERSION}_{TYPE}_{ENV}_{DS_VERSION}')