In [2]:
# Notebook support or argpase
import sys; sys.argv=['']; del sys

In [3]:
import os
import cv2
import time
import random
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import apex

from collections import deque
from torchinfo import summary

import tools
from configurator import generate_args, get_arg_dict
from th_logger import TBXLogger as TBLogger

# Env deps: Soundspaces and Habitat
from habitat.datasets import make_dataset
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class
from ss_baselines.common.utils import images_to_video_with_audio

# Custom ActorCritic agent for PPO
from models import ActorCritic, Perceiver_GWT_GWWM_ActorCritic

# Dataset utils
from torch.utils.data import IterableDataset, DataLoader
import compress_pickle as cpkl

# region: Generating additional hyparams
CUSTOM_ARGS = [
    # General hyper parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 1_000_000),
    
    # Behavior cloning gexperiment config
    get_arg_dict("dataset-path", str, "SAVI_Oracle_Dataset_v0"),

    # SS env config
    get_arg_dict("config-path", str, "env_configs/savi/savi_ss1.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.0), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
    ## Agent network params
    get_arg_dict("agent-type", str, "ss-default", metatype="choice",
        choices=["ss-default", "perceiver-gwt-gwwm"]),
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states 
    ## Perceiver / PerceiverIO params: TODO: num_latnets, latent_dim, etc...
    get_arg_dict("pgwt-latent-type", str, "randn", metatype="choice",
        choices=["randn", "zeros"]), # Depth of the Perceiver
    get_arg_dict("pgwt-latent-learned", bool, True, metatype="bool"),
    get_arg_dict("pgwt-depth", int, 1), # Depth of the Perceiver
    get_arg_dict("pgwt-num-latents", int, 8),
    get_arg_dict("pgwt-latent-dim", int, 64),
    get_arg_dict("pgwt-cross-heads", int, 1),
    get_arg_dict("pgwt-latent-heads", int, 4),
    get_arg_dict("pgwt-cross-dim-head", int, 64),
    get_arg_dict("pgwt-latent-dim-head", int, 64),
    get_arg_dict("pgwt-weight-tie-layers", bool, False, metatype="bool"),
    get_arg_dict("pgwt-ff", bool, False, metatype="bool"),
    get_arg_dict("pgwt-num-freq-bands", int, 6),
    get_arg_dict("pgwt-max-freq", int, 10.),
    get_arg_dict("pgwt-use-sa", bool, False, metatype="bool"),
    ## Peceiver Modality Embedding related
    get_arg_dict("pgwt-mod-embed", int, 0), # Learnable modality embeddings
    ## Additional modalities
    get_arg_dict("pgwt-ca-prev-latents", bool, False, metatype="bool"), # if True, passes the prev latent to CA as KV input data

    ## Special BC
    get_arg_dict("prev-actions", bool, False, metatype="bool"),
    get_arg_dict("burn-in", int, 0), # Steps used to init the latent state for RNN component
    get_arg_dict("batch-chunk-length", int, 0), # For gradient accumulation
    get_arg_dict("dataset-ce-weights", bool, True, metatype="bool"), # If True, will read CEL weights based on action dist. from the 'dataset_statistics.bz2' file.
    get_arg_dict("ce-weights", float, None, metatype="list"), # Weights for the Cross Entropy loss

    ## SSL Support
    get_arg_dict("obs-center", bool, False, metatype="bool"), # Centers the rgb_observations' range to [-0.5,0.5]
    get_arg_dict("ssl-tasks", str, None, metatype="list"), # Expects something like ["rec-rgb-vis", "rec-depth", "rec-spectr"]
    get_arg_dict("ssl-task-coefs", float, None, metatype="list"), # For each ssl-task, specifies the loss coeff. during computation

    # Eval protocol
    get_arg_dict("eval", bool, False, metatype="bool"),
    get_arg_dict("eval-every", int, int(1.5e4)), # Every X frames || steps sampled
    get_arg_dict("eval-n-episodes", int, 5),

    # Logging params
    # NOTE: While supported, video logging is expensive because the RGB generation in the
    # envs hogs a lot of GPU, especially with multiple envs 
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)
# endregion: Generating additional hyparams

# Load environment config
is_SAVi = str.__contains__(args.config_path, "savi")
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)

# Additional PPO overrides
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)

# Gradient accumulation support
if args.batch_chunk_length == 0:
    args.batch_chunk_length = args.num_envs

# Seeding
random.seed(args.seed)
np.random.seed(args.seed)
th.manual_seed(args.seed)
th.cuda.manual_seed_all(args.seed)
th.backends.cudnn.deterministic = args.torch_deterministic
# th.backends.cudnn.benchmark = args.cudnn_benchmark

# Set device as GPU
device = tools.get_device(args) if (not args.cpu and th.cuda.is_available()) else th.device("cpu")

# Overriding some envs parametes from the .yaml env config
env_config.defrost()
## Override default seed
env_config.SEED = env_config.TASK_CONFIG.SEED = env_config.TASK_CONFIG.SIMULATOR.SEED = args.seed

env_config.TASK_CONFIG.SIMULATOR.USE_RENDERED_OBSERVATIONS = False
# For smoother video, set CONTINUOUS_VIEW_CHANGE to True, and get the additional frames in obs_dict["intermediate"]
env_config.TASK_CONFIG.SIMULATOR.CONTINUOUS_VIEW_CHANGE = False

env_config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH = 256
env_config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT = 256
env_config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH = 256
env_config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT = 256

# NOTE: using less environments for eval to save up system memory -> run more experiment at the same time
env_config.NUM_PROCESSES = 1 # Corresponds to number of envs, makes script startup faster for debugs
# env_config.CONTINUOUS = args.env_continuous
## In caes video saving is enabled, make sure there is also the rgb videos
agent_extra_rgb = False
if args.save_videos:
    # For RGB video sensors
    if "RGB_SENSOR" not in env_config.SENSORS:
        env_config.SENSORS.append("RGB_SENSOR")
        # Indicates to the agent that RGB obs should not be used as observational inputs
        agent_extra_rgb = True
    # For Waveform to generate audio over the videos
    if "AUDIOGOAL_SENSOR" not in env_config.TASK_CONFIG.TASK.SENSORS:
        env_config.TASK_CONFIG.TASK.SENSORS.append("AUDIOGOAL_SENSOR")
# Add support for TOP_DOWN_MAP
# NOTE: it seems to induce "'DummySimulator' object has no attribute 'pathfinder'" error
# If top down map really needed, probably have to run the env without pre-rendered observations ?
# env_config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")

env_config.freeze()

# Environment instantiation
if args.eval:
    # In case there is need for eval, instantiate some environments
    envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
    single_observation_space = envs.observation_spaces[0]
    single_action_space = envs.action_spaces[0]
else:
    # Otherwise, just use dummy obs. and act. spaces for agent structure init
    from gym import spaces
    single_action_space = spaces.Discrete(4)
    single_observation_space = spaces.Dict({
        # "rgb": spaces.Box(shape=[128,128,3], low=0, high=255, dtype=np.uint8),
        # "depth": spaces.Box(shape=[128,128,1], low=0, high=255, dtype=np.uint8),
        "audiogoal": spaces.Box(shape=[2,16000], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32),
        "spectrogram": spaces.Box(shape=[65,26,2], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32)
    })

# Override the observation space for "rgb" and "depth" from (256,256,C) to (128,128,C)
from gym import spaces
if "RGB_SENSOR" in env_config.SENSORS:
    single_observation_space["rgb"] = spaces.Box(shape=[128,128,3], low=0, high=255, dtype=np.uint8)
if "DEPTH_SENSOR" in env_config.SENSORS:
    single_observation_space["rgb"] = spaces.Box(shape=[128,128,1], low=0, high=255, dtype=np.uint8)


In [4]:
single_observation_space

Dict(audiogoal:Box(-3.4028235e+38, 3.4028235e+38, (2, 16000), float32), spectrogram:Box(-3.4028235e+38, 3.4028235e+38, (65, 26, 2), float32), rgb:Box(0, 255, (128, 128, 1), uint8))

In [None]:
class VisualCNN4(nn.Module):
    def __init__(self, obsjervation_space, output_size, extra_rgb, obs_center=False):
        super().__init__()
        if "rgb" in observation_space.spaces and not extra_rgb:
            self._n_input_rgb = observation_space.spaces["rgb"].shape[2]
        else:
            self._n_input_rgb = 0

        if "depth" in observation_space.spaces:
            self._n_input_depth = observation_space.spaces["depth"].shape[2]
        else:
            self._n_input_depth = 0

        self.output_size = output_size
        self.obs_center = obs_center

        # kernel size for different CNN layers
        self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)]

        # strides for different CNN layers
        self._cnn_layers_stride = [(4, 4), (2, 2), (2, 2)]

        if self._n_input_rgb > 0:
            cnn_dims = np.array(
                observation_space.spaces["rgb"].shape[:2], dtype=np.float32
            )
        elif self._n_input_depth > 0:
            cnn_dims = np.array(
                observation_space.spaces["depth"].shape[:2], dtype=np.float32
            )

        if self.is_blind:
            self.cnn = nn.Sequential()
        else:
            for kernel_size, stride in zip(
                self._cnn_layers_kernel_size, self._cnn_layers_stride
            ):
                self.cnn_dims = cnn_dims = conv_output_dim(
                    dimension=cnn_dims,
                    padding=np.array([0, 0], dtype=np.float32),
                    dilation=np.array([1, 1], dtype=np.float32),
                    kernel_size=np.array(kernel_size, dtype=np.float32),
                    stride=np.array(stride, dtype=np.float32),
                )

            self.cnn = nn.Sequential(
                nn.Conv2d(
                    in_channels=self._n_input_rgb + self._n_input_depth,
                    out_channels=32,
                    kernel_size=self._cnn_layers_kernel_size[0],
                    stride=self._cnn_layers_stride[0],
                ),
                nn.ReLU(True),
                nn.Conv2d(
                    in_channels=32,
                    out_channels=64,
                    kernel_size=self._cnn_layers_kernel_size[1],
                    stride=self._cnn_layers_stride[1],
                ),
                nn.ReLU(True),
                nn.Conv2d(
                    in_channels=64,
                    out_channels=64,
                    kernel_size=self._cnn_layers_kernel_size[2],
                    stride=self._cnn_layers_stride[2],
                ),
                #  nn.ReLU(True),
                Flatten(),
                nn.Linear(64 * cnn_dims[0] * cnn_dims[1], 2048),
                nn.ReLU(True),
            )
            self.linear = nn.Sequential(
                nn.Linear(2048, output_size),
                nn.ReLU(True)
            )

        layer_init(self.cnn)

    @property
    def is_blind(self):
        return self._n_input_rgb + self._n_input_depth == 0

    def forward(self, observations):
        cnn_input = []
        if self._n_input_rgb > 0:
            rgb_observations = observations["rgb"]
            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
            rgb_observations = rgb_observations.permute(0, 3, 1, 2)
            rgb_observations = rgb_observations / 255.0  # normalize RGB
            if self.obs_center:
                rgb_observations -= 0.5
            cnn_input.append(rgb_observations)

        if self._n_input_depth > 0:
            depth_observations = observations["depth"]
            # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH]
            depth_observations = depth_observations.permute(0, 3, 1, 2)
            if self.obs_center:
                depth_observations -= 0.5
            cnn_input.append(depth_observations)

        cnn_input = th.cat(cnn_input, dim=1)

        cnn_output = self.cnn(cnn_input)
        features = self.linear(cnn_output)

        return features, cnn_output

encoder = VisualCNN4(single_observation_space, args.hidden_size, extra_rgb=False, obs_center=args.obs_center)