In [2]:
# Notebook support or argpase
import sys; sys.argv=['']; del sys

In [3]:
# Custom PPO implementation with Soundspaces 2.0
# Borrows from 
## - CleanRL's PPO LSTM: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py
## - SoundSpaces AudioNav Baselines: https://github.com/facebookresearch/sound-spaces/tree/main/ss_baselines/av_nav

import time
import random
import numpy as np
import torch as th
import torch.nn as nn
from collections import deque, defaultdict

# import tools
from configurator import generate_args, get_arg_dict
from th_logger import TBXLogger as TBLogger

# Env deps: Soundspaces and Habitat
from habitat.datasets import make_dataset
from ss_baselines.savi.config.default import get_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class

# Custom ActorCritic agent for PPO
from models import ActorCritic

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# region: Generating additional hyparams
CUSTOM_ARGS = [
    # General hyepr parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 10_000_000),
    
    # SS env config
    get_arg_dict("config-path", str, "env_configs/savi/savi.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.2), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    ## Agent network params
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states 

    # Logging params
    # TODO: Eval that has a separate environment and is called eval-every 100K steps to generate a single
    # video to disk / TB / Wandb ?
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e4)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/"), # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)

# Load environment config
env_config = get_config(config_paths=args.config_path)

# Overriding some envs parametes from the .yaml env config
env_config.defrost()
env_config.NUM_PROCESSES = 1 # Corresponds to number of envs, makes script startup faster for debugs
env_config.freeze()

# endregion: Generating additional hyparams

In [5]:
print(env_config)
# env_config.ENV_NAME
# env_config.TASK_CONFIG.DATASET

BASE_TASK_CONFIG_PATH: env_configs/savi/base_semantic_audiogoal.yaml
CHECKPOINT_FOLDER: data/models/output/data
CHECKPOINT_INTERVAL: 50
CMD_TRAILING_OPTS: []
CONTINUOUS: True
DEBUG: False
DISPLAY_RESOLUTION: 128
ENV_NAME: AudioNavRLEnv
EVAL:
  SPLIT: val
  USE_CKPT_CONFIG: True
EVAL_CKPT_PATH_DIR: data/models/output/data
EXTRA_RGB: False
LOG_FILE: data/models/output/train.log
LOG_INTERVAL: 10
MODEL_DIR: data/models/output
NUM_PROCESSES: 1
NUM_UPDATES: 20000
RL:
  DDPPO:
    backbone: custom_resnet18
    distrib_backend: GLOO
    num_recurrent_layers: 1
    pretrained: True
    pretrained_weights: data/models/savi/data/ckpt.XXX.pth
    reset_critic: False
    rnn_type: GRU
    sync_frac: 0.6
  DISTANCE_REWARD_SCALE: 1.0
  PPO:
    BELIEF_PREDICTOR:
      audio_only: False
      current_pred_only: False
      lr: 0.001
      normalize_category_distribution: False
      online_training: True
      train_encoder: False
      use_label_belief: True
      use_location_belief: True
      weig

In [7]:
# Environment instantiation
envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
single_observation_space = envs.observation_spaces[0]
single_action_space = envs.action_spaces[0]

2023-05-11 10:09:38,935 Initializing dataset SemanticAudioNav
2023-05-11 10:09:38,949 Initializing dataset SemanticAudioNav


In [7]:
list(single_observation_space.keys())

['audiogoal', 'category', 'depth', 'pose', 'rgb', 'spectrogram']

In [9]:
single_observation_space["depth"]
single_observation_space["audiogoal"]

Box(-3.4028235e+38, 3.4028235e+38, (2, 16000), float32)

In [8]:
observations = envs.reset()

In [9]:
for obs_sensor, obs_value in observations[0].items():
    print(f"{obs_sensor} -> {obs_value.shape}")

depth -> (128, 128, 1, 1)
rgb -> (128, 128, 3)
audiogoal -> (2, 16000)
spectrogram -> (65, 26, 2)
category -> (21,)
pose -> (4,)


In [10]:
done = False
outputs = envs.step([0, 0])

In [11]:
outputs[0][0] # env_0, obs_dict
outputs[0][0]["depth"].shape

(128, 128, 1, 1)

In [12]:
list(single_observation_space.keys())

['audiogoal', 'category', 'depth', 'pose', 'rgb', 'spectrogram']