# General configuration

In [1]:
# Notebook support or argpase
import sys; sys.argv=['']; del sys
%matplotlib inline

In [2]:
# General config related
import os
import umap
import copy
import time
import random
import rsatoolbox
import numpy as np
import matplotlib as mpl
import compress_pickle as cpkl

# Custom imports
from configurator import get_arg_dict, generate_args

# ML deps
import apex
import torch
import torch as th
import torch.nn as nn
from torch import Tensor
import torch.nn.functional as F

# Plottign deps
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec # TODO: move to the top
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Env config related
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class
from ss_baselines.common.utils import plot_top_down_map

# Dataset utils
from torch.utils.data import IterableDataset, DataLoader
import compress_pickle as cpkl

# Loading pretrained agent
import tools
import models
from models import ActorCritic, Perceiver_GWT_GWWM_ActorCritic

mpl.rcParams["figure.facecolor"] = "white"
mpl.rcParams["axes.facecolor"] = "white"
mpl.rcParams["savefig.facecolor"] = "white" 

  from .autonotebook import tqdm as notebook_tqdm
2023-06-14 14:28:04.921747: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-14 14:28:04.987921: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-06-14 14:28:04.987944: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [3]:
# region: Generating additional hyparams
CUSTOM_ARGS = [
    # General hyper parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 1_000_000),
    
    # Behavior cloning gexperiment config
    get_arg_dict("dataset-path", str, "SAVI_Oracle_Dataset_v0"),

    # SS env config
    get_arg_dict("config-path", str, "env_configs/savi/savi_ss1.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient step for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.0), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
    ## Agent network params
    get_arg_dict("agent-type", str, "ss-default", metatype="choice",
        choices=["ss-default", "perceiver-gwt-gwwm"]),
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states 
    ## Perceiver / PerceiverIO params: TODO: num_latnets, latent_dim, etc...
    get_arg_dict("pgwt-latent-type", str, "randn", metatype="choice",
        choices=["randn", "zeros"]), # Depth of the Perceiver
    get_arg_dict("pgwt-latent-learned", bool, True, metatype="bool"),
    get_arg_dict("pgwt-depth", int, 1), # Depth of the Perceiver
    get_arg_dict("pgwt-num-latents", int, 8),
    get_arg_dict("pgwt-latent-dim", int, 64),
    get_arg_dict("pgwt-cross-heads", int, 1),
    get_arg_dict("pgwt-latent-heads", int, 4),
    get_arg_dict("pgwt-cross-dim-head", int, 64),
    get_arg_dict("pgwt-latent-dim-head", int, 64),
    get_arg_dict("pgwt-weight-tie-layers", bool, False, metatype="bool"),
    get_arg_dict("pgwt-ff", bool, False, metatype="bool"),
    get_arg_dict("pgwt-num-freq-bands", int, 6),
    get_arg_dict("pgwt-max-freq", int, 10.),
    get_arg_dict("pgwt-use-sa", bool, False, metatype="bool"),
    ## Peceiver Modality Embedding related
    get_arg_dict("pgwt-mod-embed", int, 0), # Learnable modality embeddings
    ## Additional modalities
    get_arg_dict("pgwt-ca-prev-latents", bool, False, metatype="bool"), # if True, passes the prev latent to CA as KV input data

    ## Special BC
    get_arg_dict("prev-actions", bool, False, metatype="bool"),
    get_arg_dict("burn-in", int, 0), # Steps used to init the latent state for RNN component
    get_arg_dict("batch-chunk-length", int, 0), # For gradient accumulation
    get_arg_dict("dataset-ce-weights", bool, True, metatype="bool"), # If True, will read CEL weights based on action dist. from the 'dataset_statistics.bz2' file.
    get_arg_dict("ce-weights", float, None, metatype="list"), # Weights for the Cross Entropy loss

    # Eval protocol
    get_arg_dict("eval", bool, True, metatype="bool"),
    get_arg_dict("eval-every", int, int(1.5e4)), # Every X frames || steps sampled
    get_arg_dict("eval-n-episodes", int, 5),

    # Logging params
    # NOTE: While supported, video logging is expensive because the RGB generation in the
    # envs hogs a lot of GPU, especially with multiple envs 
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)

# Additional PPO overrides
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)

# Load environment config
is_SAVi = str.__contains__(args.config_path, "savi")
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)
# endregion: Generating additional hyparams

# Load environment config
is_SAVi = str.__contains__(args.config_path, "savi")
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)

# Additional PPO overrides
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)

# Gradient accumulation support
if args.batch_chunk_length == 0:
    args.batch_chunk_length = args.num_envs

# Seeding
random.seed(args.seed)
np.random.seed(args.seed)
th.manual_seed(args.seed)
th.cuda.manual_seed_all(args.seed)
th.backends.cudnn.deterministic = args.torch_deterministic
# th.backends.cudnn.benchmark = args.cudnn_benchmark

# Set device as GPU
device = tools.get_device(args) if (not args.cpu and th.cuda.is_available()) else th.device("cpu")

# Overriding some envs parametes from the .yaml env config
env_config.defrost()
## Override default seed
env_config.SEED = env_config.TASK_CONFIG.SEED = env_config.TASK_CONFIG.SIMULATOR.SEED = args.seed

env_config.TASK_CONFIG.SIMULATOR.USE_RENDERED_OBSERVATIONS = False
# For smoother video, set CONTINUOUS_VIEW_CHANGE to True, and get the additional frames in obs_dict["intermediate"]
env_config.TASK_CONFIG.SIMULATOR.CONTINUOUS_VIEW_CHANGE = False

env_config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.WIDTH = 256
env_config.TASK_CONFIG.SIMULATOR.RGB_SENSOR.HEIGHT = 256
env_config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.WIDTH = 256
env_config.TASK_CONFIG.SIMULATOR.DEPTH_SENSOR.HEIGHT = 256

# NOTE: using less environments for eval to save up system memory -> run more experiment at the same time
env_config.NUM_PROCESSES = 1 # Corresponds to number of envs, makes script startup faster for debugs
# env_config.CONTINUOUS = args.env_continuous
## In caes video saving is enabled, make sure there is also the rgb videos
agent_extra_rgb = False
if args.save_videos:
    # For RGB video sensors
    if "RGB_SENSOR" not in env_config.SENSORS:
        env_config.SENSORS.append("RGB_SENSOR")
        # Indicates to the agent that RGB obs should not be used as observational inputs
        agent_extra_rgb = True
    # For Waveform to generate audio over the videos
    if "AUDIOGOAL_SENSOR" not in env_config.TASK_CONFIG.TASK.SENSORS:
        env_config.TASK_CONFIG.TASK.SENSORS.append("AUDIOGOAL_SENSOR")
# Add support for TOP_DOWN_MAP
# NOTE: it seems to induce "'DummySimulator' object has no attribute 'pathfinder'" error
# If top down map really needed, probably have to run the env without pre-rendered observations ?
# env_config.TASK_CONFIG.TASK.MEASUREMENTS.append("TOP_DOWN_MAP")

env_config.freeze()

## Compute action coefficient for CEL of BC
dataset_stats_filepath = f"{args.dataset_path}/dataset_statistics.bz2"
# Override dataset statistics if the file already exists
if os.path.exists(dataset_stats_filepath):
    with open(dataset_stats_filepath, "rb") as f:
        dataset_statistics = cpkl.load(f)
    
N_CATEGORIES = len(dataset_statistics["category_counts"].keys())
N_SCENES = len(dataset_statistics["scene_counts"].keys())

In [4]:
# Fake environment instantiation to create the agent models later on

# TODO: add adaptive creation of single_observation_space so that RGB and RGBD based variants
# can be evaluated at thet same time
from gym import spaces
single_action_space = spaces.Discrete(4)
single_observation_space = spaces.Dict({
    "rgb": spaces.Box(shape=[128,128,3], low=0, high=255, dtype=np.uint8),
    "depth": spaces.Box(shape=[128,128,1], low=0, high=255, dtype=np.uint8),
    "audiogoal": spaces.Box(shape=[2,16000], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32),
    "spectrogram": spaces.Box(shape=[65,26,2], low=-3.4028235e+38, high=3.4028235e+38, dtype=np.float32)
})
# single_observation_space = envs.observation_spaces[0]
# single_action_space = envs.action_spaces[0]

single_observation_space, single_action_space

(Dict(audiogoal:Box(-3.4028235e+38, 3.4028235e+38, (2, 16000), float32), depth:Box(0, 255, (128, 128, 1), uint8), rgb:Box(0, 255, (128, 128, 3), uint8), spectrogram:Box(-3.4028235e+38, 3.4028235e+38, (65, 26, 2), float32)),
 Discrete(4))

# Probing Analysis Config.

In [5]:
# Define the target of probing
## "category" -> how easy to predict category based on the learned features / inputs
## "scene" -> how easy to predict scene based on the learned features / inputs
PROBING_TARGETS = {
    "category": {"n_classes": 21},
    # "scene": {"n_classes": 10}, # TODO: make this based on the dataset ?
}

# Define which fields of an agent to use for the probes
PROBING_INPUTS = ["state_encoder", "audio_encoder.cnn.7", "visual_encoder.cnn.7"]

# Define the probing "subjects", i.e. which pre-trained BC networks to probe
# also stores info. related to the path to the weights, and pretty names for the plots
MODEL_VARIANTS_TO_STATEDICT_PATH = {
    # region: Random baselines
    # Random GRU Baseline
    "ppo_gru__random": {
        "pretty_name": "GRU Random",
        "state_dict_path": ""
    },
    # Random PGWT Baseline
    # "ppo_pgwt__random": {
    #     "pretty_name": "TransRNN Random",
    #     "state_dict_path": ""
    # },
    # endregion: Random baselines


    # region: SAVi BC variants; trained using RGBD + Spectrogram ; trained up to 5M steps
    # "ppo_bc__rgbd_spectro__gru__SAVi": {
    #     "pretty_name": "[SAVi BC] PPO GRU | RGB Spectro",
    #     "state_dict_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc/"
    #         "ppo_bc__savi_ss1_rgbd_spectro__gru_seed_111__2023_06_10_16_05_39_999286.musashi"
    #         "/models/ppo_agent.4995001.ckpt.pth"
    # },
    "ppo_bc__rgbd_spectro__pgwt__SAVi": {
        "pretty_name": "[SAVi BC] PPO TransRNN | RGB Spectro",
        "state_dict_path": "/home/rousslan/random/rl/exp-logs/ss-hab-bc/"
            "ppo_bc__savi_ss1_rgbd__spectro__pgwt__dpth_1_nlats_8_latdim_64_noSA_CAnheads_1_SAnheads_4_modembed_0_CAprevlats_seed_111__2023_06_10_16_05_37_098602.musashi"
            "/models/ppo_agent.4995001.ckpt.pth"
    },
    # endregion: SAVi BC variants; trained using RGBD + Spectrogram ; trained up to 5M steps
}

# Indexable instantiated agent models (Torch agents)
MODEL_VARIANTS_TO_AGENTMODEL = {}

for k, v in MODEL_VARIANTS_TO_STATEDICT_PATH.items():
    args_copy = copy.copy(args)
    # Override args depending on the model in use
    if k.__contains__("gru"):
        agent = ActorCritic(single_observation_space, single_action_space, args.hidden_size, extra_rgb=False,
            analysis_layers=models.GRU_ACTOR_CRITIC_DEFAULT_ANALYSIS_LAYER_NAMES)
    elif k.__contains__("pgwt"):
        agent = Perceiver_GWT_GWWM_ActorCritic(single_observation_space, single_action_space, args, extra_rgb=False,
            analysis_layers=models.PGWT_GWWM_ACTOR_CRITIC_DEFAULT_ANALYSIS_LAYER_NAMES + ["state_encoder.ca.mha"])

    agent.eval()
    # Load the model weights
    # TODO: add map location device to use CPU only ?
    if v["state_dict_path"] != "":
        agent_state_dict = th.load(v["state_dict_path"])
        agent.load_state_dict(agent_state_dict)
    
    MODEL_VARIANTS_TO_AGENTMODEL[k] = agent.to(device)

In [6]:
MODEL_VARIANTS_TO_AGENTMODEL.keys() # ['ppo_gru__random', 'ppo_bc__rgbd_spectro__pgwt__SAVi']
MODEL_VARIANTS_TO_AGENTMODEL["ppo_gru__random"]

ActorCritic(
  (visual_encoder): VisualCNN(
    (cnn): Sequential(
      (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (5): Flatten()
      (6): Linear(in_features=2304, out_features=512, bias=True)
      (7): ReLU(inplace=True)
    )
  )
  (audio_encoder): AudioCNN(
    (cnn): Sequential(
      (0): Conv2d(2, 32, kernel_size=(5, 5), stride=(2, 2))
      (1): ReLU(inplace=True)
      (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
      (3): ReLU(inplace=True)
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): Flatten()
      (6): Linear(in_features=2496, out_features=512, bias=True)
      (7): ReLU(inplace=True)
    )
  )
  (state_encoder): RNNStateEncoder(
    (rnn): GRU(1024, 512)
  )
  (action_distribution): CategoricalNet(
    (linear): Linear(in_features=512, 

# Probe network definitions

In [7]:
# TODO
# - consider adding the reference to the network this probe is in charge of ?
class GenericProbeNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        # input_dim: shape of the 
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
    
    def forward(self, x):
        return self.linear(x)

In [8]:
# Instantiating probes
PROBES = {}
for probe_target_name, probe_target_info in PROBING_TARGETS.items():
    if probe_target_name not in PROBES.keys():
        PROBES[probe_target_name] = {}
    for probe_input in PROBING_INPUTS: # NOTE: maybe switch order with the MODEL_VARIANTS ???

        if probe_input not in PROBES[probe_target_name].keys():
            PROBES[probe_target_name][probe_input] = {}

        for agent_variant in MODEL_VARIANTS_TO_AGENTMODEL.keys():
            probe_input_dim = 512 # TODO: make this adapt to the actual shape of the input's probe
            probe_output_dim = probe_target_info["n_classes"]

            probe_network = GenericProbeNetwork(probe_input_dim, probe_output_dim).to(device)
            if not args.cpu and th.cuda.is_available():
                # TODO: GPU only. But what if we still want to use the default pytorch one ?
                optimizer = apex.optimizers.FusedAdam(probe_network.parameters(), lr=args.lr, eps=1e-5, weight_decay=args.optim_wd)
            else:
                optimizer = th.optim.Adam(probe_network.parameters(), lr=args.lr, eps=1e-5, weight_decay=args.optim_wd)

            PROBES[probe_target_name][probe_input][agent_variant] = {
                "probe_network": probe_network,
                "probe_optimizer": optimizer
            }

In [9]:
PROBES

{'category': {'state_encoder': {'ppo_gru__random': {'probe_network': GenericProbeNetwork(
      (linear): Linear(in_features=512, out_features=21, bias=False)
    ),
    'probe_optimizer': FusedAdam (
    Parameter Group 0
        betas: (0.9, 0.999)
        bias_correction: True
        eps: 1e-05
        lr: 0.00025
        weight_decay: 0
    )},
   'ppo_bc__rgbd_spectro__pgwt__SAVi': {'probe_network': GenericProbeNetwork(
      (linear): Linear(in_features=512, out_features=21, bias=False)
    ),
    'probe_optimizer': FusedAdam (
    Parameter Group 0
        betas: (0.9, 0.999)
        bias_correction: True
        eps: 1e-05
        lr: 0.00025
        weight_decay: 0
    )}},
  'audio_encoder.cnn.7': {'ppo_gru__random': {'probe_network': GenericProbeNetwork(
      (linear): Linear(in_features=512, out_features=21, bias=False)
    ),
    'probe_optimizer': FusedAdam (
    Parameter Group 0
        betas: (0.9, 0.999)
        bias_correction: True
        eps: 1e-05
        lr: 0

In [10]:
# Checking the instantiate PROBES
PROBES["category"]["state_encoder"]

{'ppo_gru__random': {'probe_network': GenericProbeNetwork(
    (linear): Linear(in_features=512, out_features=21, bias=False)
  ),
  'probe_optimizer': FusedAdam (
  Parameter Group 0
      betas: (0.9, 0.999)
      bias_correction: True
      eps: 1e-05
      lr: 0.00025
      weight_decay: 0
  )},
 'ppo_bc__rgbd_spectro__pgwt__SAVi': {'probe_network': GenericProbeNetwork(
    (linear): Linear(in_features=512, out_features=21, bias=False)
  ),
  'probe_optimizer': FusedAdam (
  Parameter Group 0
      betas: (0.9, 0.999)
      bias_correction: True
      eps: 1e-05
      lr: 0.00025
      weight_decay: 0
  )}}

# Loading dataset to be used for probe training

In [11]:
# NOTE / TODO: probe training might benefti from using different batch sizes ?

# This variant will fill each batch trajectory using cat.ed episode data
# There is no empty step in this batch
class BCIterableDataset3(IterableDataset):
    def __init__(self, dataset_path, batch_length, seed=111):
        self.seed = seed
        self.batch_length = batch_length
        self.dataset_path = dataset_path

        # Read episode filenames in the dataset path
        self.ep_filenames = os.listdir(dataset_path)
        if "dataset_statistics.bz2" in self.ep_filenames:
            self.ep_filenames.remove("dataset_statistics.bz2")
        
        print(f"Initialized IterDset with {len(self.ep_filenames)} episodes.")
    
    def __iter__(self):
        batch_length = self.batch_length
        while True:
            # region: Sample episode data until there is enough to fill the hole batch traj
            obs_list = {
                "depth": np.zeros([batch_length, 128, 128]), # NOTE: data was recorded using (128, 128), but ideally we should have (128, 128, 1)
                "rgb": np.zeros([batch_length, 128, 128, 3]),
                "audiogoal": np.zeros([batch_length, 2, 16000]),
                "spectrogram": np.zeros([batch_length, 65, 26, 2]),
                "category": np.zeros([batch_length, 21]),
                "pointgoal_with_gps_compass": np.zeros([batch_length, 2]),
                "pose": np.zeros([batch_length, 4]),
            }

            action_list, reward_list, done_list = \
                np.zeros([batch_length, 1]), \
                np.zeros([batch_length, 1]), \
                np.zeros([batch_length, 1])
            ssf = 0 # Step affected so far
            while ssf < batch_length:
                idx = th.randint(len(self.ep_filenames), ())
                ep_filename = self.ep_filenames[idx]
                ep_filepath = os.path.join(self.dataset_path, ep_filename)
                with open(ep_filepath, "rb") as f:
                    edd = cpkl.load(f)
                # print(f"Sampled traj idx: {idx} ; Len: {edd['ep_length']}")
                
                # Append the data to the bathc trjectory
                rs = batch_length - ssf # Reamining steps
                horizon = ssf + min(rs, edd["ep_length"])
                for k, v in edd["obs_list"].items():
                    obs_list[k][ssf:horizon] = v[:rs]
                action_list[ssf:horizon] = np.array(edd["action_list"][:rs])[:, None]
                reward_list[ssf:horizon] = np.array(edd["reward_list"][:rs])[:, None]
                done_list[ssf:horizon] = np.array(edd["done_list"][:rs])[:, None]

                ssf += edd["ep_length"]

                if ssf >= self.batch_length:
                    break

            # Adjust shape of "depth" to be [T, H, W, 1] instead of [T, H, W]
            obs_list["depth"] = obs_list["depth"][:, :, :, None]
            
            # TODO: add enough data about the scene to be able to do the probing
            # Since the dataset statistics can be accessed here too, we can generate
            # the vector of targets for the scene
            
            yield obs_list, action_list, reward_list, done_list
            # endregion: Sample episode data until there is enough to fill the hole batch traj
    
def make_dataloader3(dataset_path, batch_size, batch_length, seed=111, num_workers=2):
    def worker_init_fn(worker_id):
        # worker_seed = th.initial_seed() % (2 ** 32)
        worker_seed = 133754134 + worker_id

        random.seed(worker_seed)
        np.random.seed(worker_seed)

    th_seed_gen = th.Generator()
    th_seed_gen.manual_seed(133754134 + seed)

    dloader = iter(
        DataLoader(
            BCIterableDataset3(
                dataset_path=dataset_path, batch_length=batch_length),
                batch_size=batch_size, num_workers=num_workers,
                worker_init_fn=worker_init_fn, generator=th_seed_gen
            )
    )

    return dloader

# Instantiate the dataset object
dloader = make_dataloader3(args.dataset_path, batch_size=args.num_envs,
                            batch_length=args.num_steps, seed=args.seed, num_workers=8)

# TODO: consider pre-computing CE weights for categories / scenes to balance the CE loss ?

Initialized IterDset with 29001 episodes.


In [12]:
# Testing iteration over one batch of data for a given variant

# region: Load batch data, and related pre-processing
obs_list, action_list, _, done_list = \
    [ {k: th.Tensor(v).float().to(device) for k,v in b.items()} if isinstance(b, dict) else 
        b.float().to(device) for b in next(dloader)]

# NOTE: RGB are normalized in the VisualCNN module
# PPO networks expect input of shape T,B, ... so doing the permutation first
# then flatten over T x B dimensions. The RNN will reshape it as necessary
for k, v in obs_list.items():
    if k in ["rgb", "spectrogram", "depth"]:
        obs_list[k] = v.permute(1, 0, 2, 3, 4) # BTCHW -> TBCHW
        obs_list[k] = obs_list[k].reshape(-1, *obs_list[k].shape[-3:])
    elif k in ["audiogoal"]:
        obs_list[k] = v.permute(1, 0, 2, 3) # BTCL -> TBCL
        obs_list[k] = obs_list[k].reshape(-1, *obs_list[k].shape[-2:])
    else:
        # TODO: handle other fields like "category", etc...
        pass

action_list = action_list.permute(1, 0, 2)
done_list = done_list.permute(1, 0, 2)
mask_list = 1. - done_list

prev_actions_list = th.zeros_like(action_list)
prev_actions_list[1:] = action_list[:-1]
prev_actions_list = F.one_hot(prev_actions_list.long()[:, :, 0], num_classes=4).float()
prev_actions_list[0] = prev_actions_list[0] * 0.0

# Finally, also flatten across T x B, let the RNN do the unflattening if needs be
action_list = action_list.reshape(-1) # Because it is used for the target later
done_list = done_list.reshape(-1, 1)
mask_list = mask_list.reshape(-1, 1)
prev_actions_list = prev_actions_list.reshape(-1, 1)
# endregion: Load batch data, and related pre-processing


In [13]:
obs_list.keys() # ['depth', 'rgb', 'audiogoal', 'spectrogram', 'category', 'pointgoal_with_gps_compass', 'pose']
obs_list["depth"].shape # torch.Size([1500, 128, 128, 1])
done_list.shape # torch.Size([1500, 1])
mask_list.shape # torch.Size([1500, 1])

torch.Size([1500, 1])

In [14]:
# For each "agent_variant", iterate

agent_variant = "ppo_gru__random"
agent = MODEL_VARIANTS_TO_AGENTMODEL["ppo_gru__random"]

# This will be used to recompute the rnn_hidden_states when computiong the new action logprobs
if agent_variant.__contains__("gru"):
    rnn_hidden_state = th.zeros((1, args.batch_chunk_length, args.hidden_size), device=device)
elif agent_variant.__contains__("pgwt"):
    rnn_hidden_state = agent.state_encoder.latents.repeat(args.batch_chunk_length, 1, 1)
else:
    raise NotImplementedError(f"Unsupported agent-type:{agent_variant}")

# for agent_variant, agent_model in MODEL_VARIANTS_TO_AGENTMODEL.items():
#     if agent_variant.__contains__("gru"):
#         AGENT_RNN_HIDDEN_STATE[agent_variant] = th.zeros((1, args.num_envs, args.hidden_size), device=device)
#     elif agent_variant.__contains__("pgwt"):
#         AGENT_RNN_HIDDEN_STATE[agent_variant] = agent_model.state_encoder.latents.clone()


In [15]:
rnn_hidden_state.shape # torch.Size([1, 10, 512])

torch.Size([1, 10, 512])

In [16]:
# Forward pass through the networks
agent_outputs = agent.act(obs_list, rnn_hidden_state, masks=mask_list) #, prev_actions=prev_actions_list)

In [17]:
args.num_envs, args.num_steps

(10, 150)

In [18]:
category_list = obs_list["category"].reshape(-1, 21).argmax(axis=1)

In [19]:
list(agent._features.keys())

['visual_encoder.cnn.0',
 'visual_encoder.cnn.1',
 'visual_encoder.cnn.2',
 'visual_encoder.cnn.3',
 'visual_encoder.cnn.4',
 'visual_encoder.cnn.5',
 'visual_encoder.cnn.6',
 'visual_encoder.cnn.7',
 'audio_encoder.cnn.0',
 'audio_encoder.cnn.1',
 'audio_encoder.cnn.2',
 'audio_encoder.cnn.3',
 'audio_encoder.cnn.4',
 'audio_encoder.cnn.5',
 'audio_encoder.cnn.6',
 'audio_encoder.cnn.7',
 'action_distribution.linear',
 'critic.fc',
 'state_encoder']

In [20]:
agent._features["state_encoder"] # tuple of shape 2
agent._features["state_encoder"][0].shape # torch.Size([1500, 512]), accumulated T * B, state_features
agent._features["state_encoder"][1].shape # torch.Size([1, 10, 512]), state_features for the next step T+1, unused in our case

torch.Size([1, 10, 512])

In [21]:
agent._features["audio_encoder.cnn.7"].shape

torch.Size([1500, 512])

In [22]:
category__state_encoder__probe = GenericProbeNetwork(512, 21).to(device)

In [23]:
category__state_encoder__logits = category__state_encoder__probe(agent._features["state_encoder"][0])

In [24]:
category__state_encoder__cel = F.cross_entropy(category__state_encoder__logits, category_list)
category__state_encoder__cel

tensor(3.0344, device='cuda:0', grad_fn=<NllLossBackward0>)

In [25]:
AGENT_FEATURES__RAW = {k: {} for k in MODEL_VARIANTS_TO_AGENTMODEL.keys()}
AGENT_RNN_HIDDEN_STATE = {}

In [26]:
# TODO
# Could we maybe pre-cmpute all the foreward passes for all the model variants once,
# then we don't have to re-run those in case we train for more than one epoch ?
# Although even the "epoch" is not a real epoch, since we don't have the guarantee
# that all the steps are sampled exactly once.

# Training start
start_time = time.time()

# NOTE: this time total-steps means how many time .backward() is called on each probe
# One epoch would be equual to "DATASET_SIZE" in steps / (num_envs * num_steps)
n_updates = 0
total_updates = int(500_000 / args.num_envs / args.num_steps) # How many updates expected in total for one epoch ?
print(f"Expected number of updates: {total_updates}")

steps_per_update = args.num_envs * args.num_steps
for global_step in range(1, args.total_steps + steps_per_update, steps_per_update):
    # Load batch data
    obs_list, action_list, _, done_list = \
        [ {k: th.Tensor(v).float().to(device) for k,v in b.items()} if isinstance(b, dict) else 
            b.float().to(device) for b in next(dloader)]
    
    # NOTE: RGB are normalized in the VisualCNN module
    # PPO networks expect input of shape T,B, ... so doing the permutation first
    # then flatten over T x B dimensions. The RNN will reshape it as necessary
    for k, v in obs_list.items():
        if k in ["rgb", "spectrogram", "depth"]:
            obs_list[k] = v.permute(1, 0, 2, 3, 4) # BTCHW -> TBCHW
            obs_list[k] = obs_list[k].reshape(-1, *obs_list[k].shape[-3:])
        elif k in ["audiogoal"]:
            obs_list[k] = v.permute(1, 0, 2, 3) # BTCL -> TBCL
            obs_list[k] = obs_list[k].reshape(-1, *obs_list[k].shape[-2:])
        else:
            # TODO: handle other fields like "category", etc...
            pass
    
    action_list = action_list.permute(1, 0, 2) # TODO: probably uneeded for probing ?
    done_list = done_list.permute(1, 0, 2)
    mask_list = 1. - done_list
    
    prev_actions_list = th.zeros_like(action_list)
    prev_actions_list[1:] = action_list[:-1]
    prev_actions_list = F.one_hot(prev_actions_list.long()[:, :, 0], num_classes=4).float()
    prev_actions_list[0] = prev_actions_list[0] * 0.0

    # Finally, also flatten across T x B, let the RNN do the unflattening if needs be
    action_list = action_list.reshape(-1) # TODO: probably uneeded for probing ?
    done_list = done_list.reshape(-1, 1)
    mask_list = mask_list.reshape(-1, 1)
    prev_actions_list = prev_actions_list.reshape(-1, 1)

    # Holder for the probe losses and accs.
    probe_losses_dict = {}

    # For each "agent_variant", iterate
    # TODO: once we have more than "category", it is more efficient to iterate over the agent first,
    # Do the forward pass, then iterate over the probe target (category, scene, etc...) then
    for agent_variant, agent in MODEL_VARIANTS_TO_AGENTMODEL.items():
        # Forward pass with the agent model to collect the intermediate features
        # Stores in agent_features
        # This will be used to recompute the rnn_hidden_states when computiong the new action logprobs
        if agent_variant.__contains__("gru"):
            rnn_hidden_state = th.zeros((1, args.batch_chunk_length, args.hidden_size), device=device)
        elif agent_variant.__contains__("pgwt"):
            rnn_hidden_state = agent.state_encoder.latents.repeat(args.batch_chunk_length, 1, 1)
        else:
            raise NotImplementedError(f"Unsupported agent-type:{agent_variant}")
        
        with th.no_grad():
            agent_outputs = agent.act(obs_list, rnn_hidden_state, masks=mask_list) #, prev_actions=prev_actions_list)
        
        for probe_target_name, probe_target_dict in PROBES.items():
            # probe_target_name: "category", "scene", more generally the targeted concept of the probing
            # probe_target_dict: { "state_encoder": {"agent_variant": Torch Model} }
            for probe_target_input_name, agent_variant_probes in probe_target_dict.items():
                # probe_target_input_name: the input of the probe, such as "state_encoder", and other
                # agent_variant_probes: dict taht holds {"agent_variant": Torch Model}
                
                probe = agent_variant_probes[agent_variant]["probe_network"]
                probe_optim = agent_variant_probes[agent_variant]["probe_optimizer"]
                probe_optim.zero_grad()

                # Forward pass of the probe network itself
                if probe_target_input_name == "state_encoder":
                    probe_inputs = agent._features["state_encoder"][0]
                elif probe_target_input_name.__contains__("visual_encoder") or \
                     probe_target_input_name.__contains__("audio_encoder"):
                    probe_inputs = agent._features[probe_target_input_name]
                else:
                    raise NotImplementedError(f"Attempt to use {probe_target_input_name} as probe input.")
                
                probe_logits = probe(probe_inputs)
                
                # TODO: generate probe_targets
                if probe_target_name == "category":
                    probe_targets = obs_list["category"].reshape(-1, 21).argmax(axis=1)
                else:
                    raise NotImplementedError(f"Unsupported probe target: {probe_target_name}.")
                
                # Loss
                # TODO: CE weights depending on the probing target and such
                probe_ce_loss = F.cross_entropy(probe_logits, probe_targets)

                probe_ce_loss.backward()
                probe_optim.step()

                # Store the loss valuesl for logging later
                metric_stem = f"{probe_target_name}|{probe_target_input_name}__{agent_variant}"
                loss_name = f"{metric_stem}__probe_loss"
                probe_losses_dict[loss_name] = probe_ce_loss.item()
                acc_name = f"{metric_stem}__probe_acc"
                probe_losses_dict[acc_name] = (F.softmax(probe_logits, dim=1).argmax(1) == probe_targets).float().mean()

    # Tracking the number of NN updates (for all probes)
    n_updates += 1

    if n_updates % 10 == 0:
        for k, v in probe_losses_dict.items():
            print(f"{k}: {round(v,3)}")
        print("")

        break

Expected number of updates: 333


category|state_encoder__ppo_gru__random__probe_loss: 2.994
category|audio_encoder.cnn.7__ppo_gru__random__probe_loss: 3.044
category|visual_encoder.cnn.7__ppo_gru__random__probe_loss: 2.951
category|state_encoder__ppo_bc__rgbd_spectro__pgwt__SAVi__probe_loss: 2.889
category|audio_encoder.cnn.7__ppo_bc__rgbd_spectro__pgwt__SAVi__probe_loss: 3.029
category|visual_encoder.cnn.7__ppo_bc__rgbd_spectro__pgwt__SAVi__probe_loss: 3.01



In [52]:
probe_logits.shape
probe_targets.shape
(F.softmax(probe_logits, dim=1).argmax(1) == probe_targets).float().mean()

tensor(0.1007, device='cuda:0')