In [1]:
# Notebook support for argparse: clear sys.argv before the arguments are parsed
import sys; sys.argv=['']; del sys

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

import os
import time
import random
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import apex

from collections import deque
from torchinfo import summary

import tools
from configurator import generate_args, get_arg_dict
from th_logger import TBXLogger as TBLogger

# Env deps: Soundspaces and Habitat
from habitat.datasets import make_dataset
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class
from ss_baselines.common.utils import images_to_video_with_audio

# Custom ActorCritic agent for PPO
from models import ActorCritic, ActorCritic_DeepEthologyVirtualRodent, \
    Perceiver_GWT_GWWM_ActorCritic, Perceiver_GWT_AttGRU_ActorCritic

# Dataset utils
from torch.utils.data import IterableDataset, DataLoader
import compress_pickle as cpkl

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# This variant will sample one single (sub) seuqence of an episode as a trajectoyr
# and add zero paddign to the rest
class BCIterableDataset3(IterableDataset):
    def __init__(self, dataset_path, batch_length, seed=111):
        self.seed = seed
        self.batch_length = batch_length
        self.dataset_path = dataset_path

        # Read episode filenames in the dataset path
        self.ep_filenames = os.listdir(dataset_path)
        if "dataset_statistics.bz2" in self.ep_filenames:
            self.ep_filenames.remove("dataset_statistics.bz2")
        
        print(f"Initialized IterDset with {len(self.ep_filenames)} episodes.")
    
    def __iter__(self):
        batch_length = self.batch_length
        while True:
            # Sample one episode file
            idx = th.randint(len(self.ep_filenames), ())
            ep_filename = self.ep_filenames[idx]
            ep_filepath = os.path.join(self.dataset_path, ep_filename)
            with open(ep_filepath, "rb") as f:
                edd = cpkl.load(f)
            is_success = edd["info_list"][-1]["success"]
            last_action = edd["action_list"][-1]
            print(f"Sampled traj idx: {idx}; Length: {edd['ep_length']}; Success: {is_success}; Last act: {last_action}")
            
            # edd_start = th.randint(0, edd["ep_length"]-20, ()).item() # Sample start of sub-squence for this episode
            # NOTE: the following sampling might not leverage long-term trajectories well.
            edd_start = 0 # Given that we have short trajectories, just start at the beginning anyway
            edd_end = min(edd_start + batch_length, edd["ep_length"])
            subseq_len = edd_end - edd_start
            
            horizon = subseq_len

            obs_list = {
                k: np.zeros([batch_length, *np.shape(v)[1:]]) for k,v in edd["obs_list"].items()
            }
            action_list, reward_list, done_list, depad_mask_list = \
                np.zeros([batch_length, 1]), \
                np.zeros([batch_length, 1]), \
                np.zeros([batch_length, 1]), \
                np.zeros((batch_length, 1)).astype(np.bool8)

            for k, v in edd["obs_list"].items():
                obs_list[k][:horizon] = v[edd_start:edd_end]
                # Adjust the shape of obs_list["depth"] from (128,128 -> (128, 128, 1))
            obs_list["depth"] = obs_list["depth"][:, :, None]
            action_list[:horizon] = np.array(edd["action_list"][edd_start:edd_end])[:, None]
            reward_list[:horizon] = np.array(edd["reward_list"][edd_start:edd_end])[:, None]
            done_list[:horizon] = np.array(edd["done_list"][edd_start:edd_end])[:, None]
            depad_mask_list[:horizon] = True

            yield obs_list, action_list, reward_list, done_list, depad_mask_list
    
def make_dataloader3(dataset_path, batch_size, batch_length, seed=111, num_workers=2):
    def worker_init_fn(worker_id):
        # worker_seed = th.initial_seed() % (2 ** 32)
        worker_seed = 133754134 + worker_id

        random.seed(worker_seed)
        np.random.seed(worker_seed)

    th_seed_gen = th.Generator()
    th_seed_gen.manual_seed(133754134 + seed)

    dloader = iter(
        DataLoader(
            BCIterableDataset3(
                dataset_path=dataset_path, batch_length=batch_length),
                batch_size=batch_size, num_workers=num_workers,
                worker_init_fn=worker_init_fn, generator=th_seed_gen
            )
    )

    return dloader

# Tensorize current observation, store to rollout data
def tensorize_obs_dict(obs, device, observations=None, rollout_step=None):
    obs_th = {}
    for obs_field, _ in obs[0].items():
        v_th = th.Tensor(np.array([step_obs[obs_field] for step_obs in obs], dtype=np.float32)).to(device)
        # in SS1.0, the dcepth observations comes as [B, 128, 128, 1, 1], so fix that
        if obs_field == "depth" and v_th.dim() == 5:
            v_th = v_th.squeeze(-1)
        obs_th[obs_field] = v_th
        # Special case when doing the rollout, also stores the 
        if observations is not None:
            observations[obs_field][rollout_step] = v_th
    
    return obs_th

In [4]:
# region: Generating additional hyparams
CUSTOM_ARGS = [
    # General hyper parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 1_000_000),
    
    # Behavior cloning gexperiment config
    get_arg_dict("dataset-path", str, "SAVI_Oracle_Dataset_v0"),

    # SS env config
    get_arg_dict("config-path", str, "env_configs/savi/savi_ss1.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.0), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
    ## Agent network params
    get_arg_dict("agent-type", str, "ss-default", metatype="choice",
        choices=["ss-default", "deep-etho",
                    "perceiver-gwt-gwwm", "perceiver-gwt-attgru"]),
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states 
    ## Perceiver / PerceiverIO params: TODO: num_latnets, latent_dim, etc...
    get_arg_dict("pgwt-latent-type", str, "randn", metatype="choice",
        choices=["randn", "zeros"]), # Depth of the Perceiver
    get_arg_dict("pgwt-latent-learned", bool, True, metatype="bool"),
    get_arg_dict("pgwt-depth", int, 1), # Depth of the Perceiver
    get_arg_dict("pgwt-num-latents", int, 8),
    get_arg_dict("pgwt-latent-dim", int, 64),
    get_arg_dict("pgwt-cross-heads", int, 1),
    get_arg_dict("pgwt-latent-heads", int, 4),
    get_arg_dict("pgwt-cross-dim-head", int, 64),
    get_arg_dict("pgwt-latent-dim-head", int, 64),
    get_arg_dict("pgwt-weight-tie-layers", bool, False, metatype="bool"),
    get_arg_dict("pgwt-ff", bool, False, metatype="bool"),
    get_arg_dict("pgwt-num-freq-bands", int, 6),
    get_arg_dict("pgwt-max-freq", int, 10.),
    get_arg_dict("pgwt-use-sa", bool, False, metatype="bool"),
    ## Peceiver Modality Embedding related
    get_arg_dict("pgwt-mod-embed", int, 0), # Learnable modality embeddings
    ## Additional modalities
    get_arg_dict("pgwt-ca-prev-latents", bool, False, metatype="bool"), # if True, passes the prev latent to CA as KV input data

    ## Special BC
    get_arg_dict("prev-actions", bool, False, metatype="bool"),
    get_arg_dict("burn-in", int, 0), # Steps used to init the latent state for RNN component
    get_arg_dict("batch-chunk-length", int, 0), # For gradient accumulation
    get_arg_dict("ce-weights", float, None, metatype="list"), # Weights for the Cross Entropy loss
    get_arg_dict("dataset-ce-weights", bool, True, metatype="bool"),

    # Eval protocol
    get_arg_dict("eval", bool, True, metatype="bool"),
    get_arg_dict("eval-every", int, int(1.5e4)), # Every X frames || steps sampled
    get_arg_dict("eval-n-episodes", int, 5),

    # Logging params
    # NOTE: While supported, video logging is expensive because the RGB generation in the
    # envs hogs a lot of GPU, especially with multiple envs 
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)
# endregion: Generating additional hyparams

# Load environment config
is_SAVi = str.__contains__(args.config_path, "savi")
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)

# Additional PPO overrides
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)

# Gradient accumulation support
if args.batch_chunk_length == 0:
    args.batch_chunk_length = args.num_envs

# Experiment logger
tblogger = TBLogger(exp_name=args.exp_name, args=args)
print(f"# Logdir: {tblogger.logdir}")
should_log_training_stats = tools.Every(args.log_training_stats_every)
should_eval = tools.Every(args.eval_every)

# Seeding
random.seed(args.seed)
np.random.seed(args.seed)
th.manual_seed(args.seed)
th.cuda.manual_seed_all(args.seed)
th.backends.cudnn.deterministic = args.torch_deterministic
# th.backends.cudnn.benchmark = args.cudnn_benchmark

# Set device as GPU
device = tools.get_device(args) if (not args.cpu and th.cuda.is_available()) else th.device("cpu")

# Overriding some envs parametes from the .yaml env config
env_config.defrost()
# NOTE: using less environments for eval to save up system memory -> run more experiment at thte same time
env_config.NUM_PROCESSES = 2 # Corresponds to number of envs, makes script startup faster for debugs
# env_config.CONTINUOUS = args.env_continuous
## In caes video saving is enabled, make sure there is also the rgb videos
agent_extra_rgb = False
if args.save_videos:
    # For RGB video sensors
    if "RGB_SENSOR" not in env_config.SENSORS:
        env_config.SENSORS.append("RGB_SENSOR")
        # Indicates to the agent that RGB obs should not be used as observational inputs
        agent_extra_rgb = True
    # For Waveform to generate audio over the videos
    if "AUDIOGOAL_SENSOR" not in env_config.TASK_CONFIG.TASK.SENSORS:
        env_config.TASK_CONFIG.TASK.SENSORS.append("AUDIOGOAL_SENSOR")
env_config.freeze()
# print(env_config)

# Environment instantiation
envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
single_observation_space = envs.observation_spaces[0]
single_action_space = envs.action_spaces[0]

2023-05-18 18:23:02.207129: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-18 18:23:04,295 Initializing dataset SemanticAudioNav
2023-05-18 18:23:04,313 Initializing dataset SemanticAudioNav


# Logdir: ./logs/_seed_111__2023_05_18_18_23_01_823608.musashi


2023-05-18 18:23:23,560 initializing sim SoundSpacesSim
2023-05-18 18:23:23,944 Initializing task SemanticAudioNav
2023-05-18 18:23:24,025 Initializing dataset SemanticAudioNav
2023-05-18 18:23:49,643 initializing sim SoundSpacesSim
2023-05-18 18:23:50,052 Initializing task SemanticAudioNav


In [8]:
# TODO: delete the envrionemtsn / find a more efficient method to do this

# TODO: make the ActorCritic components parameterizable through comand line ?
if args.agent_type == "ss-default":
    agent = ActorCritic(single_observation_space, single_action_space,
        args.hidden_size, extra_rgb=agent_extra_rgb, prev_actions=args.prev_actions).to(device)
elif args.agent_type == "perceiver-gwt-gwwm":
    agent = Perceiver_GWT_GWWM_ActorCritic(single_observation_space, single_action_space,
        args, extra_rgb=agent_extra_rgb).to(device)
elif args.agent_type == "perceiver-gwt-attgru":
    agent = Perceiver_GWT_AttGRU_ActorCritic(single_observation_space, single_action_space,
        args, extra_rgb=agent_extra_rgb).to(device)
elif args.agent_type == "deep-etho":
    raise NotImplementedError(f"Unsupported agent-type:{args.agent_type}")
    # TODO: support for storing the rnn_hidden_statse, so that the policy 
    # that takes in the 'core_modules' 's rnn hidden output can also work.
    agent = ActorCritic_DeepEthologyVirtualRodent(single_observation_space,
            single_action_space, 512).to(device)
else:
    raise NotImplementedError(f"Unsupported agent-type:{args.agent_type}")

if not args.cpu and th.cuda.is_available():
    # TODO: GPU only. But what if we still want to use the default pytorch one ?
    optimizer = apex.optimizers.FusedAdam(agent.parameters(), lr=args.lr, eps=1e-5, weight_decay=args.optim_wd)
else:
    optimizer = th.optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5, weight_decay=args.optim_wd)

optimizer.zero_grad()

# Dataset loading
dloader = make_dataloader3(args.dataset_path, batch_size=args.num_envs,
                            batch_length=args.num_steps, seed=args.seed)

## Compute action coefficient for CEL of BC
dataset_stats_filepath = f"{args.dataset_path}/dataset_statistics.bz2"
# Override dataset statistics if the file already exists
if os.path.exists(dataset_stats_filepath):
    with open(dataset_stats_filepath, "rb") as f:
        dataset_statistics = cpkl.load(f)

# Reads args.ce_weights if passed
ce_weights = args.ce_weights

# In case args.dataset_ce_weights is True,
# override the args.ce_weigths manual setting
if args.dataset_ce_weights:
    # TODO: make some assert on 1) the existence of the "dataset_statistics.bz2" file
    # and 2) that it contains the "action_cel_coefs" of proper dimension
    ce_weights = [dataset_statistics["action_cel_coefs"][a] for a in range(4)]

if ce_weights is not None:
    # TODO: assert it has same shape as action space.
    ce_weights = th.Tensor(ce_weights).to(device)

# Info logging
summary(agent)
print("")
print(agent)
print("")

Initialized IterDset with 2357 episodes.
Sampled traj idx: 786; Length: 14; Success: 1.0; Last act: 0
Sampled traj idx: 1740; Length: 15; Success: 1.0; Last act: 0

ActorCritic(
  (visual_encoder): VisualCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (5): Flatten()
      (6): Linear(in_features=2304, out_features=512, bias=True)
      (7): ReLU(inplace=True)
    )
  )
  (audio_encoder): AudioCNN(
    (cnn): Sequential(
      (0): Conv2d(2, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): Flatten()
      (6): Linear(in_features=2496, out_features=512, bias=True)
      (7): ReLU(inpl

Sampled traj idx: 2273; Length: 12; Success: 1.0; Last act: 0
Sampled traj idx: 153; Length: 18; Success: 1.0; Last act: 0
Sampled traj idx: 532; Length: 8; Success: 1.0; Last act: 0
Sampled traj idx: 725; Length: 21; Success: 1.0; Last act: 0
Sampled traj idx: 2309; Length: 11; Success: 1.0; Last act: 0
Sampled traj idx: 1671; Length: 13; Success: 1.0; Last act: 0
Sampled traj idx: 636; Length: 12; Success: 1.0; Last act: 0
Sampled traj idx: 104; Length: 15; Success: 1.0; Last act: 0
Sampled traj idx: 506; Length: 35; Success: 1.0; Last act: 0
Sampled traj idx: 2002; Length: 15; Success: 1.0; Last act: 0
Sampled traj idx: 1598; Length: 25; Success: 1.0; Last act: 0
Sampled traj idx: 76; Length: 11; Success: 1.0; Last act: 0
Sampled traj idx: 725; Length: 21; Success: 1.0; Last act: 0
Sampled traj idx: 291; Length: 37; Success: 1.0; Last act: 0
Sampled traj idx: 808; Length: 19; Success: 1.0; Last act: 0
Sampled traj idx: 1474; Length: 17; Success: 1.0; Last act: 0
Sampled traj idx: 17

In [9]:
ce_weights

tensor([4.6824, 0.3606, 2.1973, 1.7908], device='cuda:0')

In [None]:
obs_list, action_list, _, done_list, depad_mask_list = \
    [ {k: th.Tensor(v).float().to(device) for k,v in b.items()} if isinstance(b, dict) else 
        b.float().to(device) for b in next(dloader)] # NOTE this will not suport "audiogoal" waveform audio, only rgb / depth / spectrogram

In [None]:
action_list.shape, done_list.shape, depad_mask_list.shape, obs_list["rgb"].shape, obs_list["depth"].shape, obs_list["spectrogram"].shape, obs_list["audiogoal"].shape

In [None]:
from pprint import pprint
pprint(dataset_statistics)