In [3]:
# Notebook support for argparse
import sys; sys.argv=['']; del sys

# Prototyping dataset reading for the BC experiments
import random
import numpy as np

import os
import uuid
import datetime
import pickle as pkl
import compress_pickle as cpkl

import torch
import torch as th
import torch.nn as nn
from torch import Tensor

# General config related
from configurator import get_arg_dict, generate_args

# Env config related
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class

In [5]:
# region: Generating additional hyparams
# Declarative list of the hyper-parameters handled by this notebook / script.
# Each entry is built with `get_arg_dict(name, type, default, ...)` and the whole
# list is turned into the `args` namespace by `generate_args` below.
CUSTOM_ARGS = [
    # General hyper parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 1_000_000),

    # SS env config
    get_arg_dict("config-path", str, "env_configs/audiogoal_rgb_nocont.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.2), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
    ## Agent network params
    get_arg_dict("agent-type", str, "ss-default", metatype="choice",
        choices=["ss-default", "deep-etho",
                    "perceiver-gwt-gwwm", "perceiver-gwt-attgru"]),
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states
    ## Perceiver / PerceiverIO params: TODO: num_latents, latent_dim, etc...
    get_arg_dict("pgwt-latent-type", str, "randn", metatype="choice",
        choices=["randn", "zeros"]), # Presumably how the latents are initialized -- TODO confirm
    get_arg_dict("pgwt-latent-learned", bool, True, metatype="bool"),
    get_arg_dict("pgwt-depth", int, 1), # Depth of the Perceiver
    get_arg_dict("pgwt-num-latents", int, 8),
    get_arg_dict("pgwt-latent-dim", int, 64),
    get_arg_dict("pgwt-cross-heads", int, 1),
    get_arg_dict("pgwt-latent-heads", int, 4),
    get_arg_dict("pgwt-cross-dim-head", int, 64),
    get_arg_dict("pgwt-latent-dim-head", int, 64),
    get_arg_dict("pgwt-weight-tie-layers", bool, False, metatype="bool"),
    get_arg_dict("pgwt-ff", bool, False, metatype="bool"),
    get_arg_dict("pgwt-num-freq-bands", int, 6),
    # NOTE(review): declared type is `int` but the default is the float `10.`;
    # confirm against `get_arg_dict` which of the two is intended.
    get_arg_dict("pgwt-max-freq", int, 10.),
    get_arg_dict("pgwt-use-sa", bool, False, metatype="bool"),
    ## Perceiver Modality Embedding related
    get_arg_dict("pgwt-mod-embed", int, 0), # Learnable modality embeddings
    ## Additional modalities
    get_arg_dict("pgwt-ca-prev-latents", bool, True, metatype="bool"), # if True, passes the prev latent to CA as KV input data

    # Logging params
    # NOTE: While supported, video logging is expensive because the RGB generation in the
    # envs hogs a lot of GPU, especially with multiple envs
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)

# Additional PPO overrides: derived quantities that are not exposed as arguments.
# batch_size is the number of transitions in one rollout across all envs.
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)

# Load environment config.
# The loader is picked from the config path: any path containing "savi" goes
# through the SAVi config loader, everything else through the av_nav one.
is_SAVi = "savi" in args.config_path
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)
# endregion: Generating additional hyparams

# TODO: make it not require the environment instantiation

# Override a few environment parameters coming from the .yaml env config.
# The config is unlocked with defrost(), edited, then locked again with freeze().
env_config.defrost()

# One process per parallel env; a small value also makes script startup faster for debugging.
env_config.NUM_PROCESSES = args.num_envs

# Optional overrides, currently disabled:
# env_config.USE_SYNC_VECENV = True
# env_config.USE_VECENV = False
# env_config.CONTINUOUS = args.env_continuous

## In case video saving is enabled, make sure the RGB videos are available too
## (no override is applied for this yet).

env_config.freeze()
# print(env_config)

# Environment instantiation
# envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
# single_observation_space = envs.observation_spaces[0]
# single_action_space = envs.action_spaces[0]

# single_observation_space, single_action_space

# # Loading pretrained agent
# import models
# from models import ActorCritic, Perceiver_GWT_GWWM_ActorCritic

In [6]:
# Directory holding one compressed pickle file per recorded episode.
# Plain string literal: there are no placeholders, so no f-prefix is needed.
DATASET_DIR_PATH = "ppo_gru_dset_2022_09_21__1000000_STEPS"

In [23]:
# Load one randomly chosen episode file and print the structure of its content.
# NOTE: the choice uses NumPy's global RNG, so a different episode is shown on
# each run unless that RNG was seeded beforehand.
ep_filenames = os.listdir(DATASET_DIR_PATH)
ep_filename = np.random.choice(ep_filenames)

ep_filepath = os.path.join(DATASET_DIR_PATH, ep_filename)
# NOTE(review): `cpkl.load` presumably unpickles the file content; only load
# episode files that come from a trusted source.
with open(ep_filepath, "rb") as f:
    ep_data_dict = cpkl.load(f)

# Summarize every entry: nested dicts are expanded one level, lists are
# reported by shape, ints by value, and anything else by type so that no
# entry is silently skipped.
for k, v in ep_data_dict.items():
    if isinstance(v, dict):
        print(f"{k} dict:")
        for kk, vv in v.items():
            print(f"\t {kk} -> {np.shape(vv)}; type: {type(vv)}")
    elif isinstance(v, list):
        print(f"{k} -> {np.shape(v)}; type: list")
    elif isinstance(v, int):
        print(f"{k} -> {v}; type: int")
    else:
        print(f"{k} -> type: {type(v)}")

obs_list dict:
	 rgb -> (34, 128, 128, 3); type: <class 'list'>
	 audiogoal -> (34, 2, 16000); type: <class 'list'>
	 spectrogram -> (34, 65, 26, 2); type: <class 'list'>
action_list -> (34, 1); type: list
done_list -> (34,); type: list
reward_list -> (34,); type: list
info_list -> (34,); type: list
ep_length -> 34; type: int


In [94]:
from torch.utils.data import IterableDataset, DataLoader

## Shape of the saved ep_data_dict, for reference
# obs_list dict:
# 	 rgb -> (94, 128, 128, 3)
# 	 audiogoal -> (94, 2, 16000)
# 	 spectrogram -> (94, 65, 26, 2)
# action_list -> (94, 1)
# done_list -> (94,)
# reward_list -> (94,)
# info_list -> (94,)
# ep_length -> 94

class BCIterableDataset(IterableDataset):
    """Endless stream of fixed-length trajectory chunks for behaviour cloning.

    Every item yielded by the iterator is a chunk of exactly `batch_length`
    steps. A chunk is assembled by repeatedly sampling an episode file
    uniformly at random (with replacement) from `dataset_path` and writing its
    steps back to back into pre-allocated buffers; the last sampled episode is
    truncated so that the chunk never exceeds `batch_length` steps.

    Args:
        dataset_path: Directory containing one episode file per episode. Each
            file is read with `cpkl.load` and is expected to hold a dict with
            the keys "obs_list" (dict of per-step observations), "action_list",
            "reward_list", "done_list" and "ep_length".
        batch_length: Number of steps in each yielded chunk.
        seed: Stored on the instance as `self.seed`. It is not used by the
            sampling itself, which draws from torch's global RNG.

    Yields:
        A tuple `(obs_list, action_list, reward_list, done_list)` where
        `obs_list` is a dict with the keys "rgb" `(batch_length, 128, 128, 3)`,
        "audiogoal" `(batch_length, 2, 16000)` and "spectrogram"
        `(batch_length, 65, 26, 2)`, and the three other entries are arrays of
        shape `(batch_length, 1)`. All buffers are float64 NumPy arrays
        (the `np.zeros` default). The observation shapes are hard-coded.
    """

    def __init__(self, dataset_path, batch_length, seed=111):
        self.seed = seed
        self.batch_length = batch_length
        self.dataset_path = dataset_path

        # Read episode filenames in the dataset path. They are sorted because
        # os.listdir returns entries in arbitrary order, which would make the
        # index -> episode mapping differ between machines / runs.
        self.ep_filenames = sorted(os.listdir(dataset_path))
        print(f"Initialized IterDset with {len(self.ep_filenames)} episodes.")

    def __iter__(self):
        batch_length = self.batch_length

        # Fail with an explicit message instead of the opaque error that
        # sampling an index from an empty range would raise.
        if not self.ep_filenames:
            raise ValueError(
                f"No episode files found in dataset path: {self.dataset_path}")

        while True:
            # Fresh buffers for every chunk, so that previously yielded chunks
            # are never overwritten.
            # Observation shapes are hardcoded for now, make flexible later.
            obs_list = {
                "rgb": np.zeros([batch_length, 128, 128, 3]),
                "audiogoal": np.zeros([batch_length, 2, 16000]),
                "spectrogram": np.zeros([batch_length, 65, 26, 2])
            }
            action_list = np.zeros([batch_length, 1])
            reward_list = np.zeros([batch_length, 1])
            done_list = np.zeros([batch_length, 1])

            ssf = 0 # Steps filled so far
            # Sample episodes until there is enough data for one chunk.
            while ssf < batch_length:
                idx = int(th.randint(len(self.ep_filenames), ()))
                print(f"Sampled traj idx: {idx}")
                # Resolve the file relative to this dataset's own directory,
                # not to a notebook-level global.
                ep_filepath = os.path.join(
                    self.dataset_path, self.ep_filenames[idx])
                # NOTE(review): `cpkl.load` presumably unpickles the content;
                # only read episode files from a trusted source.
                with open(ep_filepath, "rb") as f:
                    edd = cpkl.load(f)

                # Append the episode to the chunk, truncated to the space left.
                rs = batch_length - ssf # Remaining steps
                n_take = min(rs, edd["ep_length"])
                horizon = ssf + n_take

                for k, v in edd["obs_list"].items():
                    obs_list[k][ssf:horizon] = v[:n_take]
                action_list[ssf:horizon] = edd["action_list"][:n_take]
                # Rewards and dones are stored as flat per-step lists; add a
                # trailing axis to match the (batch_length, 1) buffers.
                reward_list[ssf:horizon] = \
                    np.array(edd["reward_list"][:n_take])[:, None]
                done_list[ssf:horizon] = \
                    np.array(edd["done_list"][:n_take])[:, None]

                ssf = horizon

            yield obs_list, action_list, reward_list, done_list
    
def make_dataloader(dataset_path, batch_size, batch_length, seed=111, num_workers=4):
    """Build an endless iterator of batched behaviour-cloning trajectory chunks.

    Args:
        dataset_path: Directory of episode files, forwarded to BCIterableDataset.
        batch_size: Number of trajectory chunks per batch (DataLoader batch size).
        batch_length: Number of steps in each trajectory chunk.
        seed: Seeds the torch generator handed to the DataLoader and is also
            forwarded to the dataset.
        num_workers: Number of DataLoader worker processes.

    Returns:
        An iterator over the DataLoader (not the DataLoader itself); call
        `next()` on it to get `(obs_batch, action_batch, reward_batch,
        done_batch)`.

    NOTE(review): with `num_workers=0` no worker is spawned, so
    `worker_init_fn` is never called and sampling in the dataset uses the main
    process' global torch RNG rather than one derived from `seed`.
    """
    def worker_init_fn(worker_id):
        # Inside a worker, torch's seed is already set by the DataLoader to a
        # per-worker value derived from its generator (hence from `seed`).
        # Reuse it for `random` and NumPy instead of a constant, so that these
        # RNGs also follow `seed` and differ between workers.
        worker_seed = th.initial_seed() % (2 ** 32)

        random.seed(worker_seed)
        np.random.seed(worker_seed)

    th_seed_gen = th.Generator()
    th_seed_gen.manual_seed(133754134 + seed)

    dataset = BCIterableDataset(
        dataset_path=dataset_path, batch_length=batch_length, seed=seed)

    dloader = iter(
        DataLoader(
            dataset,
            batch_size=batch_size, num_workers=num_workers,
            worker_init_fn=worker_init_fn, generator=th_seed_gen
        )
    )

    return dloader

In [99]:

# Smoke test: build the loader with single-chunk batches of 30 steps and pull
# two batches from it; the names below keep the content of the last batch.
dloader = make_dataloader(
    DATASET_DIR_PATH,
    batch_size=1,
    batch_length=30,
)
for _ in range(2):
    fetched_batch = next(dloader)
    obs_batch, action_batch, reward_batch, done_batch = fetched_batch

# obs_batch["rgb"].shape, obs_batch["spectrogram"].shape
# done_batch[0][..., 0].tolist()
# reward_batch[0][..., 0].tolist()
# obs_batch["rgb"][0][0, 0, 0]
# action_batch[0][..., 0].tolist()

Initialized IterDset with 4000 episodes.
Sampled traj idx: 74
Sampled traj idx: 2985
Sampled traj idx: 1244
Sampled traj idx: 3641
Sampled traj idx: 739
Sampled traj idx: 3929
Sampled traj idx: 301


Sampled traj idx: 3497
Sampled traj idx: 872
Sampled traj idx: 766
Sampled traj idx: 3926
