In [1]:
import torch
import torch as th
import torch.nn as nn
from torch import Tensor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class VerboseExecution(nn.Module):
    """Wrap a model and print the output shape of each of its direct children.

    A forward hook is registered on every module yielded by
    ``model.named_children()`` (top level only, not nested submodules). Each time
    such a child runs, its hook prints ``"<child name>: <output shape>"``.
    Tuple/list outputs (e.g. from recurrent layers) are printed as a list of
    their elements' shapes instead of crashing on ``.shape``.

    The hooks are attached to ``model`` itself; call :meth:`remove_hooks` to
    detach them again.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self._hook_handles = []

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            # Kept for backward compatibility: each child is tagged with its name.
            layer.__name__ = name
            self._hook_handles.append(
                layer.register_forward_hook(self._make_print_hook(name))
            )

    @staticmethod
    def _describe(output):
        """Return a tensor's shape, a list of nested descriptions for tuples/lists,
        or the type name for anything else."""
        if isinstance(output, Tensor):
            return output.shape
        if isinstance(output, (tuple, list)):
            return [VerboseExecution._describe(item) for item in output]
        return type(output).__name__

    @staticmethod
    def _make_print_hook(name):
        """Build a forward hook that prints ``name`` and the described output."""
        def hook(_module, _inputs, output):
            print(f"{name}: {VerboseExecution._describe(output)}")
        return hook

    def remove_hooks(self) -> None:
        """Detach every hook this wrapper registered on the wrapped model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped model unchanged (the hooks do the printing)."""
        return self.model(x)

from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    """Record the outputs of selected submodules during a forward pass.

    Forward hooks are registered on the modules named in ``layers`` (names as
    produced by ``model.named_modules()``, e.g. ``"ca.ff_self.0"``). After a
    forward pass, ``self._features[name]`` holds the latest output of that module
    (``torch.empty(0)`` until it has run at least once).

    Args:
        model: the model to instrument; hooks are attached to it in place, so
            calling ``model`` directly also fills ``_features``.
        layers: names of the submodules to record. Any iterable is accepted,
            including one-shot iterators such as generators.

    Raises:
        KeyError: if a name in ``layers`` is not a submodule of ``model``.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once: `layers` is iterated twice below, and a generator
        # would be exhausted by the first pass, silently registering no hooks.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}
        # Handles returned by register_forward_hook, kept so hooks can be removed.
        self._hook_handles = []

        # Build the name -> module lookup once instead of once per layer.
        modules_by_name = dict(self.model.named_modules())
        for layer_id in self.layers:
            if layer_id not in modules_by_name:
                raise KeyError(f"{layer_id!r} is not a submodule of the wrapped model")
            layer = modules_by_name[layer_id]
            self._hook_handles.append(
                layer.register_forward_hook(self.save_outputs_hook(layer_id))
            )

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the module's output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def remove_hooks(self) -> None:
        """Detach every hook this extractor registered on the wrapped model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the model and return the (shared, mutable) dict of recorded outputs."""
        _ = self.model(x)
        return self._features

In [3]:
# from perceiver_gwt_gwwm import Perceiver_GWT_GWWM
from perceiver_gwt_gwwm import SelfAttention, CrossAttention
class Perceiver_GWT_GWWM(nn.Module):
    """Perceiver-style recurrent state encoder.

    A set of ``num_latents`` latent vectors of size ``latent_dim`` plays the role
    of an RNN hidden state. At each step the previous latents are passed with the
    current observation tokens to ``self.ca`` (cross-attention), optionally
    followed by ``self.sa`` (self-attention over the latents). Where an episode
    starts (mask == 0) the previous latents are replaced by ``self.latents``.

    Keyword-only args:
        input_dim: per-token key/value dim given to the cross-attention (before
            the optional modality embedding is appended).
        latent_type: initialisation of ``self.latents``: "randn" or "zeros".
        latent_learned: whether ``self.latents`` requires grad.
        num_latents: number of latent vectors N.
        latent_dim: size D of each latent vector.
        cross_heads: number of heads of the cross-attention.
        latent_heads: number of heads of the optional self-attention.
        hidden_size: size of each of the 2 chunks a flat [B, 2 * hidden_size]
            observation vector is split into.
        mod_embed: size of the learned modality embedding concatenated to each
            token's features; 0 disables it.
        use_sa: apply self-attention over the latents after cross-attention.
        ca_prev_latents: append the flattened previous latents as an extra
            key/value token; requires ``num_latents * latent_dim == input_dim``.

    Raises:
        ValueError: if ``latent_type`` is not "randn" or "zeros".
        AssertionError: if ``ca_prev_latents`` is set and the dims do not match.
    """

    def __init__(
        self,
        *,
        input_dim,
        latent_type = "randn",
        latent_learned = True,
        num_latents = 8,
        latent_dim = 64,
        cross_heads = 1, # LucidRains's implm. uses 1 by default
        latent_heads = 4, # LucidRains's implm. uses 8 by default
        hidden_size = 512, # Dim of the visual / audio encoder outputs
        mod_embed = 0, # Dimension of learned modality embeddings
        use_sa = False,
        ca_prev_latents = False
    ):
        super().__init__()
        # Fail fast: an unknown latent_type used to skip both initialisation
        # branches below and crash later with an opaque AttributeError.
        if latent_type not in ("randn", "zeros"):
            raise ValueError(
                f"latent_type must be 'randn' or 'zeros', got {latent_type!r}")

        self.input_dim = input_dim
        self.num_latents = num_latents # N
        self.latent_dim = latent_dim # D
        self.latent_type = latent_type
        self.latent_learned = latent_learned

        self.mod_embed = mod_embed
        self.hidden_size = hidden_size
        self.use_sa = use_sa
        self.ca_prev_latents = ca_prev_latents

        # Cross Attention
        if self.ca_prev_latents:
            # The flattened latents are appended as one extra key/value token,
            # so their size must equal the per-token feature size.
            assert num_latents * latent_dim == input_dim, \
                f"input_dim=={input_dim} and num_latents * latent_dim=={num_latents * latent_dim} must match"
        self.ca = CrossAttention(latent_dim, input_dim + mod_embed,
            n_heads=cross_heads, skip_q=True) # If not skipping, usually blows
        # Self Attention
        if self.use_sa:
            self.sa = SelfAttention(latent_dim, n_heads=latent_heads)

        # Modality embedding: one learned vector per key/value token
        # (2 observation chunks, +1 when the previous latents are appended).
        if self.mod_embed:
            self.modality_embeddings = nn.Parameter(torch.randn(1, 2 + int(ca_prev_latents), mod_embed))

        # Initial latents, swapped in for prev_latents where an episode starts
        # (comparable to an RNN's initial hidden state).
        if latent_type == "randn":
            # torch.randn is kept (even though overwritten) so the RNG stream,
            # and therefore seeded runs, stay identical.
            init_latents = torch.randn(1, num_latents, latent_dim)
            # As per original paper: N(0, 0.02), clamped to [-2, 2].
            # NOTE(review): at std 0.02 the +/-2.0 clamp practically never applies.
            with th.no_grad():
                init_latents.normal_(0.0, 0.02).clamp_(-2.0, 2.0)
        else: # "zeros" (validated above)
            init_latents = torch.zeros(1, num_latents, latent_dim)

        self.latents = nn.Parameter(init_latents, requires_grad=latent_learned)

    def seq_forward(self, data, prev_latents, masks):
        """Run single_forward step by step over a flattened, time-major sequence.

        Args:
            data: [T * B, feat_dim]; rows t * B ... (t + 1) * B - 1 belong to step t.
            prev_latents: [B, num_latents, latent_dim], the latents before step 0.
            masks: T * B episode masks (reshaped to [T, B, 1]); 0 resets the latents.

        Returns:
            (x, latents) of every step, stacked time-major and flattened to
            [T * B, ...]; ``x`` is the flattened updated latents of each step.

        Raises:
            ValueError: if the number of data rows is not a multiple of B.
        """
        # TODO: a more optimal method to process sequences of same length together ?
        T_B, feat_dim = data.shape
        B = prev_latents.shape[0]
        if B == 0 or T_B % B != 0:
            raise ValueError(
                f"data has {T_B} rows, which is not a multiple of the batch size "
                f"B={B} given by prev_latents")
        T = T_B // B
        latents = prev_latents.clone()

        data = data.reshape(T, B, feat_dim)
        masks = masks.reshape(T, B, 1)

        x_list, latents_list = [], []
        for t in range(T):
            x, latents = self.single_forward(data[t], latents, masks[t])

            x_list.append(x.clone())
            latents_list.append(latents.clone())

        x_seq = th.stack(x_list, dim=0).flatten(start_dim=0, end_dim=1) # [T * B, N * D]
        latents_seq = th.stack(latents_list, dim=0).flatten(start_dim=0, end_dim=1) # [T * B, ...latents shape]

        return x_seq, latents_seq

    def single_forward(self, data, prev_latents, masks):
        """Advance the latents by one step.

        Args:
            data: [B, 2 * hidden_size] flat features (split into 2 tokens of
                ``hidden_size``) or already-tokenised [B, 2, hidden_size].
            prev_latents: [B, num_latents, latent_dim].
            masks: [B, 1]; 0 marks the first step of a new episode, for which
                the latents are reset to ``self.latents``.

        Returns:
            (x_flat, x): the updated latents flattened from dim 1 on, and as
            returned by the attention blocks (presumably [B, N, D]).
        """
        b = data.shape[0] # Batch size

        if data.dim() == 2:
            data = data.reshape(b, 2, self.hidden_size) # [B, 2 * H] -> [B, 2, H]

        if self.ca_prev_latents:
            # NOTE: flattened latents dim must equal dim of audio, vision features
            # NOTE(review): this uses prev_latents *before* the episode reset below,
            # so latents of a finished episode reach the first step of the next one.
            data = th.cat([data, prev_latents.flatten(start_dim=1)[:, None, :]], dim=1) # [B, 2, H] and [B, 1, N * D] -> [B, 3, H == N * D]

        if self.mod_embed:
            # Append the per-token modality embedding along the feature dim.
            data = th.cat([data, self.modality_embeddings.repeat(b, 1, 1)], dim=2)

        # If the current step is the start of a new episode, the mask is 0 and
        # the latents are reset to self.latents; otherwise they are kept.
        episode_mask = masks[:, :, None] # [B, 1] -> [B, 1, 1], broadcast over N, D
        prev_latents = episode_mask * prev_latents + \
            (1. - episode_mask) * self.latents.repeat(b, 1, 1)

        # Cross Attention between the latents and the observation tokens
        x, _ = self.ca(prev_latents, data)

        # Self Attention over the latents
        if self.use_sa:
            x, _ = self.sa(x)

        return x.flatten(start_dim=1), x

    def forward(self, data, prev_latents, masks):
        """
            - data: observation features, either
                - [NUM_ENVS, feat_dim] for a single step, or
                - [NUM_STEPS * NUM_ENVS, feat_dim], time-major, for a sequence
            - prev_latents: previous latents [NUM_ENVS, num_latents, latent_dim]
            - masks: not Perceiver mask, but end-of-episode signaling mask
              (0 resets the latents to self.latents)
                - shape of [NUM_ENVS, 1] if single step forward
                - NUM_STEPS * NUM_ENVS entries (e.g. [NUM_STEPS * NUM_ENVS, 1])
                  if sequence forward
            Single-step vs. sequence mode is chosen by comparing the first dims
            of data and prev_latents. Returns (flattened latents, latents); see
            single_forward / seq_forward.
        """
        if data.size(0) == prev_latents.size(0):
            return self.single_forward(data, prev_latents, masks)
        else:
            return self.seq_forward(data, prev_latents, masks)

# Build a small standalone state encoder and record intermediate outputs of its
# cross-attention block with FeatureExtractor.
DEMO_BATCH_SIZE = 3  # number of dummy envs in the inspection forward pass
HIDDEN_SIZE = 512  # size of each of the 2 modality feature chunks

state_encoder = Perceiver_GWT_GWWM(
    input_dim=HIDDEN_SIZE,
    num_latents=8,
    latent_dim=64,
    cross_heads=1,
    latent_heads=4,
    # Modality embedding related
    mod_embed=0,
    hidden_size=HIDDEN_SIZE,
    use_sa=False,
    ca_prev_latents=False,
)

# Submodules of the cross-attention block whose outputs are recorded. The hooks
# are attached to `state_encoder` itself, so calling it directly below fills
# `_state_encoder._features`.
_state_encoder = FeatureExtractor(state_encoder, layers=[
    "ca.ln_q",
    "ca.ln_kv",
    "ca.ff_self.0",  # Layer Norm
    "ca.ff_self.1",  # Linear
    "ca.ff_self.2",  # GELU
])

# Inspection only: no autograd graph is needed.
with th.no_grad():
    prev_latents = state_encoder.latents.repeat(DEMO_BATCH_SIZE, 1, 1)  # repeat already copies
    obs_feats = th.zeros(DEMO_BATCH_SIZE, HIDDEN_SIZE * 2)  # 2 concatenated modality chunks
    state_feats, next_latents = state_encoder(
        obs_feats, prev_latents, th.ones([DEMO_BATCH_SIZE, 1])
    )

for name, feat in _state_encoder._features.items():
    shapes = feat.shape if isinstance(feat, th.Tensor) else [f.shape for f in feat]
    print(f"{name} -> {shapes}")



ca.ln_q -> torch.Size([3, 8, 64])
ca.ln_kv -> torch.Size([3, 2, 512])
ca.ff_self.0 -> torch.Size([3, 8, 64])
ca.ff_self.1 -> torch.Size([3, 8, 64])
ca.ff_self.2 -> torch.Size([3, 8, 64])


## Prototype hooking a complete agent and distinguishing between its various parts

In [4]:
# PPO GRU Agent
from ss_baselines.av_nav.config import get_config
from ss_baselines.common.env_utils import construct_envs
from ss_baselines.common.environments import get_env_class
from models import ActorCritic, Perceiver_GWT_GWWM_ActorCritic

config_path = "env_configs/audiogoal_depth_nocont.yaml"
env_config = get_config(config_paths=config_path)

# Overrides applied on top of the .yaml env config. NUM_PROCESSES is the number
# of envs; a single one makes script startup faster while debugging.
env_config.defrost()
for _override_key, _override_value in {
    "NUM_PROCESSES": 1,
    "USE_SYNC_VECENV": True,
}.items():
    setattr(env_config, _override_key, _override_value)
env_config.freeze()

# Instantiate the environment(s) and keep the spaces of the first one
# (presumably identical across all envs).
envs = construct_envs(env_config, get_env_class(env_config.ENV_NAME))
single_observation_space, single_action_space = (
    envs.observation_spaces[0],
    envs.action_spaces[0],
)

2022-09-12 16:44:30,257 Initializing dataset AudioNav
2022-09-12 16:44:30,266 Initializing dataset AudioNav
2022-09-12 16:45:37,825 initializing sim SoundSpacesSim
2022-09-12 16:45:38,082 Initializing task AudioNav


In [12]:
# PPO GRU Agent

# region: inspect the structure of the PPO GRU agent
# agent = ActorCritic(single_observation_space, 
#                     single_action_space,
#                     512, extra_rgb=False); agent

# for i, nm in enumerate(agent.named_modules()):
#     if i: # Skips the first '' that contains the whole model
#         print(nm) # tuple: (name, nn.Module)

# List of layers of which to capture the intermediate outputs
layer_names = [
    *[f"visual_encoder.cnn.{i}" for i in range(8)], # Shared arch. for GRU and PGWT
    *[f"audio_encoder.cnn.{i}" for i in range(8)], # Shared arch. for GRU and PGWT
    "action_distribution.linear", # Shared arch. for any type of agent
    "critic.fc", # Shared arch. for any type of agent

    "state_encoder", # Either GRU or PGWT based. Shape different based on nature of cell though.
]
# endregion: inspect the structure of the PPO GRU agent

# region: Prototyping the PPO GRU agent with layer output recording
def describe_shapes(value):
    """Recursively replace tensors by their shapes; tuples/lists become lists.

    Handles recorded outputs that are lists of tensors as well as tuple outputs
    (e.g. a recurrent state encoder's (output, hidden) pair), which a plain
    `.shape` access would crash on.
    """
    if isinstance(value, th.Tensor):
        return value.shape
    if isinstance(value, (list, tuple)):
        return [describe_shapes(item) for item in value]
    return type(value).__name__

HIDDEN_SIZE = 512  # passed to ActorCritic and used as the prev_states size

agent = ActorCritic(single_observation_space,
                    single_action_space,
                    HIDDEN_SIZE, extra_rgb=False,
                    analysis_layers=layer_names)

batch_size = 2
prev_states = th.zeros(1, batch_size, HIDDEN_SIZE)
masks = th.ones(batch_size, 1)
obs_dict = {k: th.randn([batch_size, *v.shape]) for k, v in single_observation_space.items()}

# Shape inspection only: no autograd graph is needed.
with th.no_grad():
    actions, action_logprobs, dist_ent, values, latents = agent.act(obs_dict, prev_states, masks)

for k, v in agent._features.items():
    print(f"{k} -> {describe_shapes(v)}")
# endregion: Prototyping the PPO GRU agent with layer output recording

visual_encoder.cnn.0 -> [torch.Size([32, 31, 31]), torch.Size([32, 31, 31])]
visual_encoder.cnn.1 -> [torch.Size([32, 31, 31]), torch.Size([32, 31, 31])]
visual_encoder.cnn.2 -> [torch.Size([64, 14, 14]), torch.Size([64, 14, 14])]
visual_encoder.cnn.3 -> [torch.Size([64, 14, 14]), torch.Size([64, 14, 14])]
visual_encoder.cnn.4 -> [torch.Size([64, 6, 6]), torch.Size([64, 6, 6])]
visual_encoder.cnn.5 -> [torch.Size([2304]), torch.Size([2304])]
visual_encoder.cnn.6 -> [torch.Size([512]), torch.Size([512])]
visual_encoder.cnn.7 -> [torch.Size([512]), torch.Size([512])]
audio_encoder.cnn.0 -> [torch.Size([32, 31, 11]), torch.Size([32, 31, 11])]
audio_encoder.cnn.1 -> [torch.Size([32, 31, 11]), torch.Size([32, 31, 11])]
audio_encoder.cnn.2 -> [torch.Size([64, 15, 5]), torch.Size([64, 15, 5])]
audio_encoder.cnn.3 -> [torch.Size([64, 15, 5]), torch.Size([64, 15, 5])]
audio_encoder.cnn.4 -> [torch.Size([64, 13, 3]), torch.Size([64, 13, 3])]
audio_encoder.cnn.5 -> [torch.Size([2496]), torch.Size

In [None]:
# PPO PGWT Agent
# Notebook support for argparse: reset sys.argv so generate_args() does not pick
# up the Jupyter kernel's own command-line arguments (presumably it parses
# sys.argv — TODO confirm).
import sys; sys.argv=['']; del sys

from configurator import get_arg_dict, generate_args
from ss_baselines.av_nav.config import get_config
from ss_baselines.savi.config.default import get_config as get_savi_config

# region: Generating additional hyparams
CUSTOM_ARGS = [
    # General hyper parameters
    get_arg_dict("seed", int, 111),
    get_arg_dict("total-steps", int, 1_000_000),

    # SS env config
    get_arg_dict("config-path", str, "env_configs/audiogoal_depth_nocont.yaml"),

    # PPO Hyper parameters
    get_arg_dict("num-envs", int, 10), # Number of parallel envs. 10 by default
    get_arg_dict("num-steps", int, 150), # For each env, how many steps are collected to form PPO Agent rollout.
    get_arg_dict("num-minibatches", int, 1), # Number of mini-batches the rollout data is split into to make the updates
    get_arg_dict("update-epochs", int, 4), # Number of gradient steps for the policy and value networks
    get_arg_dict("gamma", float, 0.99),
    get_arg_dict("gae-lambda", float, 0.95),
    get_arg_dict("norm-adv", bool, True, metatype="bool"),
    get_arg_dict("clip-coef", float, 0.1), # Surrogate loss clipping coefficient
    get_arg_dict("clip-vloss", bool, True, metatype="bool"),
    get_arg_dict("ent-coef", float, 0.2), # Entropy loss coef; 0.2 in SS baselines
    get_arg_dict("vf-coef", float, 0.5), # Value loss coefficient
    get_arg_dict("max-grad-norm", float, 0.5),
    get_arg_dict("target-kl", float, None),
    get_arg_dict("lr", float, 2.5e-4), # Learning rate
    get_arg_dict("optim-wd", float, 0), # weight decay for adam optim
    ## Agent network params
    get_arg_dict("agent-type", str, "ss-default", metatype="choice",
        choices=["ss-default", "deep-etho",
                    "perceiver-gwt-gwwm", "perceiver-gwt-attgru"]),
    get_arg_dict("hidden-size", int, 512), # Size of the visual / audio features and RNN hidden states 
    ## Perceiver / PerceiverIO params: TODO: num_latents, latent_dim, etc...
    get_arg_dict("pgwt-latent-type", str, "randn", metatype="choice",
        choices=["randn", "zeros"]), # Initialisation of the Perceiver latents
    get_arg_dict("pgwt-latent-learned", bool, True, metatype="bool"),
    get_arg_dict("pgwt-depth", int, 1), # Depth of the Perceiver
    get_arg_dict("pgwt-num-latents", int, 8),
    get_arg_dict("pgwt-latent-dim", int, 64),
    get_arg_dict("pgwt-cross-heads", int, 1),
    get_arg_dict("pgwt-latent-heads", int, 4),
    get_arg_dict("pgwt-cross-dim-head", int, 64),
    get_arg_dict("pgwt-latent-dim-head", int, 64),
    get_arg_dict("pgwt-weight-tie-layers", bool, False, metatype="bool"),
    get_arg_dict("pgwt-ff", bool, False, metatype="bool"),
    get_arg_dict("pgwt-num-freq-bands", int, 6),
    # Default must match the declared int type: argparse does not convert
    # non-string defaults, so `10.` would stay a float.
    get_arg_dict("pgwt-max-freq", int, 10),
    get_arg_dict("pgwt-use-sa", bool, False, metatype="bool"),
    ## Perceiver Modality Embedding related
    get_arg_dict("pgwt-mod-embed", int, 0), # Learnable modality embeddings
    ## Additional modalities
    get_arg_dict("pgwt-ca-prev-latents", bool, False, metatype="bool"), # if True, passes the prev latent to CA as KV input data

    # Logging params
    # NOTE: While supported, video logging is expensive because the RGB generation in the
    # envs hogs a lot of GPU, especially with multiple envs 
    get_arg_dict("save-videos", bool, False, metatype="bool"),
    get_arg_dict("save-model", bool, True, metatype="bool"),
    get_arg_dict("log-sampling-stats-every", int, int(1.5e3)), # Every X frames || steps sampled
    get_arg_dict("log-training-stats-every", int, int(10)), # Every X model update
    get_arg_dict("logdir-prefix", str, "./logs/") # Overrides the default one
]
args = generate_args(CUSTOM_ARGS)

# Load environment config (SAVi configs are recognised by their path).
# NOTE(review): env_config is not used further in this cell; the agent below is
# built from the spaces of the envs constructed in an earlier cell.
is_SAVi = "savi" in args.config_path
if is_SAVi:
    env_config = get_savi_config(config_paths=args.config_path)
else:
    env_config = get_config(config_paths=args.config_path)

# Additional PPO overrides
args.batch_size = int(args.num_envs * args.num_steps)
args.minibatch_size = int(args.batch_size // args.num_minibatches)
# endregion: Generating additional hyparams

pgwt_agent = Perceiver_GWT_GWWM_ActorCritic(single_observation_space,
                    single_action_space,
                    args, extra_rgb=False)

# List of layers of which to capture the intermediate outputs
layer_names = [
    *[f"visual_encoder.cnn.{i}" for i in range(8)], # Shared arch. for GRU and PGWT
    *[f"audio_encoder.cnn.{i}" for i in range(8)], # Shared arch. for GRU and PGWT
    "action_distribution.linear", # Shared arch. for any type of agent
    "critic.fc", # Shared arch. for any type of agent

    "state_encoder", # Either GRU or PGWT based. Shape different based on nature of cell though.
    ## PGWT GWWM specific
    "state_encoder.ca", # This will have the output of the residual conn. self.ff_self(attention_value) + attention_value
    "state_encoder.ca.mha", # Note: output here is tuple (attention_value, attention_weight)
    "state_encoder.ca.ln_q",
    "state_encoder.ca.ln_kv",
    "state_encoder.ca.ff_self", # No residual connection output
    *[f"state_encoder.ca.ff_self.{i}" for i in range(4)], # [LayerNorm, Linear, GELU, Linear]
    ## TODO: add support for the SA layers too
]
layer_names