In [None]:
import sys

sys.path.insert(0, "..")
sys.path.insert(0, "../disc_cop")
sys.path.insert(0, "../avg_corr")

import _pickle as pickle
import gymnasium as gym
import numpy as np
import os
import torch

from tqdm.notebook import tqdm

from disc_cop.utils import Buffer

from disc_cop.envs import ENV_ID_TO_NAME

In [None]:
class PPOBuffer:
    """
    Fixed-capacity, array-backed store of per-timestep transitions together
    with two log-probabilities per step (``logbev`` and ``logtarg`` --
    presumably the behaviour- and target-policy log-probabilities of the
    stored action; confirm against the code that filled the buffer).

    ``finish_path`` fills ``prod_buf`` for the most recent trajectory from the
    per-step differences ``logtarg - logbev``; ``sample`` draws random
    mini-batches, optionally leaving out one fold of the stored data.

    NOTE(review): ``__init__`` and ``finish_path`` call ``core.combined_shape``
    and ``core.discount_cumsum``, but no ``core`` module is imported in the
    visible cells of this notebook. Both methods raise ``NameError`` unless
    ``core`` is brought into scope elsewhere. Unpickling an existing instance
    is unaffected, since it does not run ``__init__``.
    """

    def __init__(self, obs_dim, act_dim, size, fold, gamma=0.99):
        """
        Allocate zero-filled storage for ``size`` timesteps.

        Args:
            obs_dim: observation shape, forwarded to ``core.combined_shape``.
            act_dim: action shape, forwarded to ``core.combined_shape``.
            size: capacity of the buffer in timesteps.
            fold: number of folds ``sample`` splits the stored data into.
            gamma: stored on the instance; not read by any method of this class.
        """
        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.tim_buf = np.zeros(size, dtype=np.int32)
        self.logtarg_buf = np.zeros(size, dtype=np.float32)
        self.prod_buf = np.zeros(size, dtype=np.float32)
        self.logbev_buf = np.zeros(size, dtype=np.float32)
        self.gamma = gamma
        self.fold = fold
        # ptr: next free slot; path_start_idx: first slot of the open trajectory.
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def store(self, obs, next_obs,act, rew, tim, logbev, logtarg):
        """
        Append one timestep of agent-environment interaction to the buffer.

        Args:
            obs: observation written to ``obs_buf``.
            next_obs: observation written to ``next_obs_buf``.
            act: action written to ``act_buf``.
            rew: reward written to ``rew_buf``.
            tim: integer written to ``tim_buf`` (presumably the step index
                within the trajectory).
            logbev: value written to ``logbev_buf``.
            logtarg: value written to ``logtarg_buf``.

        Raises:
            AssertionError: if the buffer is already full.
        """
        assert self.ptr < self.max_size  # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.next_obs_buf[self.ptr] = next_obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.tim_buf[self.ptr] = tim
        self.logbev_buf[self.ptr] = logbev
        self.logtarg_buf[self.ptr] = logtarg
        self.ptr += 1

    def finish_path(self):
        """
        Close the trajectory stored in slots ``[path_start_idx, ptr)``.

        For that slice only, computes ``deltas = logtarg - logbev`` and sets
        ``prod_buf`` to ``[0, c[0], ..., c[-2]]`` where
        ``c = core.discount_cumsum(deltas, 1)``. Afterwards the next
        trajectory is taken to start at the current ``ptr``.
        """
        path_slice = slice(self.path_start_idx, self.ptr)

        # Only this trajectory's entries are needed, so slice before
        # subtracting instead of differencing the whole buffer on every call.
        deltas = self.logtarg_buf[path_slice] - self.logbev_buf[path_slice]
        self.prod_buf[path_slice] = np.append(0, core.discount_cumsum(deltas, 1)[:-1])

        self.path_start_idx = self.ptr

    def sample(self,batch_size,fold_num):
        """
        Draw ``batch_size`` stored timesteps at random (with replacement).

        With ``fold > 1`` the first ``ptr`` slots are treated as consecutive
        chunks of ``interval = int(ptr / fold)`` slots and the chunk with
        index ``fold_num`` is never drawn. Otherwise every one of the first
        ``ptr`` slots is eligible and ``fold_num`` is ignored.

        Returns:
            dict of ``torch.float32`` tensors with keys ``obs``, ``act``,
            ``prod``, ``tim``, ``logbev`` and ``logtarg``. The buffer itself
            is not modified.
        """
        interval = int(self.ptr / self.fold)
        if self.fold>1:
            # Draw from a range that is one chunk short, then shift every
            # index at or beyond the held-out chunk past it.
            ind = np.random.randint(self.ptr-interval, size=batch_size)
            ind = ind + np.where(ind >= fold_num * interval, 1, 0) * interval
        else:
            ind = np.random.randint(self.ptr, size=batch_size)

        data = dict(obs=self.obs_buf[ind], act=self.act_buf[ind],
                    prod=self.prod_buf[ind],tim=self.tim_buf[ind],
                    logbev=self.logbev_buf[ind], logtarg=self.logtarg_buf[ind])
        return {k: torch.as_tensor(v, dtype=torch.float32) for k, v in data.items()}

    def delete_last_traj(self):
        """Discard the unfinished trajectory by rewinding ``ptr`` to its start."""
        self.ptr =self.path_start_idx

In [None]:
# Which split of the original data to convert. It selects the input folder
# below and, when set to "test", makes the conversion loop reduce seeds
# modulo `num_seeds`.
split = "train"

# Root of the local data tree. Override it with the AVG_OPE_LOCAL_DIR
# environment variable instead of editing this machine-specific default.
local_root = os.environ.get(
    "AVG_OPE_LOCAL_DIR", "/Users/chanb/research/ualberta/Avg_OPE/local"
)

# Folder holding the original pickled buffers for the chosen split.
datasets_dir = os.path.join(local_root, "orig_{}_dataset".format(split))
# Folder the converted buffers are written to (trailing separator kept).
save_datasets_dir = os.path.join(local_root, "datasets", "")

# Modulus applied to test-split seeds by the conversion loop.
num_seeds = 10

In [None]:
# Only datasets whose file name declares this buffer size are converted.
target_buffer_size = 16000

for dataset_path in tqdm(os.listdir(datasets_dir)):
    # File names encode the run settings as "<key>-<value>-...-env-<env id>"
    # followed by an extension; strip the extension before parsing.
    variant = os.path.splitext(dataset_path)[0]

    key_val_pairs = dict()
    key = None

    entries = variant.split("-")
    for entry_i, entry in enumerate(entries):
        if entry_i % 2 == 0:
            # Even positions hold keys, odd positions hold their values.
            key = entry
            continue

        val = entry
        if key == "env":
            # Environment ids contain "-" themselves, so everything that
            # remains belongs to the env value.
            val = "-".join(entries[entry_i:])

        assert key not in key_val_pairs
        key_val_pairs[key] = val

        if key == "env":
            # The rest of the name was consumed above; stop so leftover
            # pieces of the env id are not recorded as spurious keys.
            break

    complete_path = os.path.join(datasets_dir, dataset_path)
    try:
        # NOTE(security): unpickling can execute arbitrary code; only point
        # `datasets_dir` at files you trust.
        with open(complete_path, "rb") as f:
            orig_buffer = pickle.load(f)
    except Exception as err:
        # Skip unreadable files but report why, and let KeyboardInterrupt /
        # SystemExit propagate instead of being swallowed by a bare except.
        print("{} failed to load: {!r}".format(complete_path, err))
        continue

    buffer_size = int(key_val_pairs["buffer_size"])
    max_len = int(key_val_pairs["max_length"])

    if buffer_size != target_buffer_size:
        continue

    new_buffer = Buffer(
        [orig_buffer.obs_buf.shape[-1]],
        orig_buffer.act_buf.shape[1:],
        buffer_size // max_len,
        max_len,
        1,
    )

    # NOTE(review): this walks the full capacity (`max_size`) of the original
    # buffer, i.e. it assumes the buffer was completely filled -- confirm.
    for sample_i in range(orig_buffer.max_size):
        obs = orig_buffer.obs_buf[sample_i]
        next_obs = orig_buffer.next_obs_buf[sample_i]
        act = orig_buffer.act_buf[sample_i]
        rew = orig_buffer.rew_buf[sample_i]
        tim = orig_buffer.tim_buf[sample_i]
        logbev = orig_buffer.logbev_buf[sample_i]
        logtarg = orig_buffer.logtarg_buf[sample_i]

        # NOTE(review): the argument order here (obs, act, rew, next_obs, ...)
        # differs from PPOBuffer.store -- verify against disc_cop.utils.Buffer.
        new_buffer.store(
            obs,
            act,
            rew,
            next_obs,
            tim,
            logbev,
            logtarg,
        )

        # A path is closed only when its step counter reaches max_len.
        if tim + 1 == max_len:
            new_buffer.finish_path()

    seed = int(key_val_pairs["seed"])
    if split == "test":
        seed = seed % num_seeds

    # One output folder per (environment, seed); build the path once.
    save_dir = os.path.join(
        save_datasets_dir,
        "{}-seed_{}".format(
            ENV_ID_TO_NAME[key_val_pairs["env"]],
            seed,
        ),
    )
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(
        save_dir,
        "{}-random_weight_{}-buffer_size_{}-max_len_{}.pkl".format(
            split,
            key_val_pairs["random_weight"],
            key_val_pairs["buffer_size"],
            key_val_pairs["max_length"],
        ),
    )
    # Context manager so the file is flushed and closed even if dump fails.
    with open(save_path, "wb") as f:
        pickle.dump(new_buffer, f)
