# Dynamics module

> Build dynamics predictors from their base architectures.

In [None]:
#| default_exp models.dynamics.__init__

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn

In [None]:
#| export
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.distributions.normal import Normal
import numpy as np

from mawm.models.misc import build_mlp
from mawm.models.utils import *
from mawm.models.dynamics.enums import PredictorConfig, PredictorOutput


In [None]:
#| export
from mawm.models.dynamics.predictor import MLPPredictor, ConvPredictor, RNNPredictor, RNNPredictorV2, RNNPredictorV3, RNNPredictorBurnin
def build_predictor(
    config: PredictorConfig,
    repr_dim: Union[int, tuple],
    action_dim: int,
    pred_propio_dim: Union[int, tuple],
    pred_obs_dim: Union[int, tuple],
    backbone_ln: Optional[torch.nn.Module] = None,
):
    """Instantiate the predictor selected by ``config.predictor_arch``.

    Supported values of ``config.predictor_arch`` and the class they build:

    - ``"mlp"``        -> ``MLPPredictor``
    - ``"conv2"``      -> ``ConvPredictor``
    - ``"rnn"``        -> ``RNNPredictor``
    - ``"rnnV2"``      -> ``RNNPredictorV2``
    - ``"rnnV3"``      -> ``RNNPredictorV3``
    - ``"rnn_burnin"`` -> ``RNNPredictorBurnin``

    Args:
        config: Predictor configuration. Besides ``predictor_arch`` this reads
            ``predictor_subclass``, ``rnn_layers``, ``prior_arch``,
            ``posterior_arch``, ``posterior_drop_p``, ``posterior_input_type``,
            ``posterior_input_dim``, ``predictor_ln`` and the ``z_*`` latent
            settings. All of them are read up front; each architecture only
            forwards the subset it needs.
        repr_dim: Representation size. Forwarded as ``repr_dim`` to the
            ``mlp``/``conv2`` predictors, and as ``hidden_size`` (plus
            ``output_size`` for ``rnn_burnin``) to the RNN predictors.
            NOTE(review): the notebook example passes a shape tuple such as
            ``(18, 15, 15)`` for ``conv2``; RNN variants presumably expect an int.
        action_dim: Action dimensionality. For ``rnnV2``/``rnnV3`` it is also
            used as the RNN ``input_size``.
        pred_propio_dim: Forwarded to the ``mlp``, ``conv2`` and ``rnnV2``
            predictors only.
        pred_obs_dim: Forwarded to the ``mlp``, ``conv2`` and ``rnnV2``
            predictors only.
        backbone_ln: Optional module forwarded to the ``mlp`` and ``rnnV2``
            predictors only.

    Returns:
        The constructed predictor instance.

    Raises:
        ValueError: If ``config.predictor_arch`` is not one of the supported
            values listed above.
    """
    arch = config.predictor_arch
    predictor_subclass = config.predictor_subclass
    rnn_layers = config.rnn_layers
    prior_arch = config.prior_arch
    posterior_arch = config.posterior_arch
    z_dim = config.z_dim
    z_min_std = config.z_min_std
    z_discrete = config.z_discrete
    z_discrete_dists = config.z_discrete_dists
    z_discrete_dim = config.z_discrete_dim
    posterior_drop_p = config.posterior_drop_p
    predictor_ln = config.predictor_ln
    posterior_input_type = config.posterior_input_type
    posterior_input_dim = config.posterior_input_dim

    if arch == "mlp":
        predictor = MLPPredictor(
            config=config,
            repr_dim=repr_dim,
            action_dim=action_dim,
            pred_propio_dim=pred_propio_dim,
            pred_obs_dim=pred_obs_dim,
            backbone_ln=backbone_ln,
        )
    # NOTE: a legacy "conv" arch (PixelPredictorConv) used to be handled here;
    # it is currently disabled and falls through to the ValueError below.
    elif arch == "conv2":
        predictor = ConvPredictor(
            config=config,
            repr_dim=repr_dim,
            predictor_subclass=predictor_subclass,
            z_discrete=z_discrete,
            z_discrete_dists=z_discrete_dists,
            z_discrete_dim=z_discrete_dim,
            z_dim=z_dim,
            z_min_std=z_min_std,
            posterior_drop_p=posterior_drop_p,
            prior_arch=prior_arch,
            posterior_arch=posterior_arch,
            posterior_input_type=posterior_input_type,
            posterior_input_dim=posterior_input_dim,
            action_dim=action_dim,
            pred_propio_dim=pred_propio_dim,
            pred_obs_dim=pred_obs_dim,
        )
    elif arch == "rnn":
        predictor = RNNPredictor(
            hidden_size=repr_dim,
            num_layers=rnn_layers,
            action_dim=action_dim,
            z_dim=z_dim,
        )
    elif arch == "rnnV2":
        predictor = RNNPredictorV2(
            config=config,
            hidden_size=repr_dim,
            num_layers=rnn_layers,
            input_size=action_dim,
            z_dim=z_dim,
            z_min_std=z_min_std,
            posterior_drop_p=posterior_drop_p,
            predictor_ln=predictor_ln,
            prior_arch=prior_arch,
            posterior_arch=posterior_arch,
            posterior_input_type=posterior_input_type,
            posterior_input_dim=posterior_input_dim,
            pred_propio_dim=pred_propio_dim,
            pred_obs_dim=pred_obs_dim,
            backbone_ln=backbone_ln,
            action_dim=action_dim,
        )
    elif arch == "rnnV3":
        predictor = RNNPredictorV3(
            hidden_size=repr_dim,
            num_layers=rnn_layers,
            input_size=action_dim,
        )
    elif arch == "rnn_burnin":
        predictor = RNNPredictorBurnin(
            hidden_size=repr_dim,
            output_size=repr_dim,
            num_layers=rnn_layers,
            action_dim=action_dim,
            z_dim=z_dim,
        )
    else:
        # Fail loudly: without this branch an unknown arch left `predictor`
        # unbound and surfaced as an opaque UnboundLocalError at the return.
        supported = ("mlp", "conv2", "rnn", "rnnV2", "rnnV3", "rnn_burnin")
        raise ValueError(
            f"Unknown predictor_arch {arch!r}; expected one of {supported}"
        )

    return predictor

In [None]:
# #| hide
# from omegaconf import OmegaConf

In [None]:
# #| hide
# cfg = OmegaConf.load("../cfgs/check.yaml")
# cfg.hjepa.level1.predictor

{'residual': True, 'action_encoder_arch': 'id', 'predictor_arch': 'conv2', 'predictor_subclass': 'd4rl_b_p', 'rnn_converter_arch': '', 'rnn_layers': 1, 'rnn_state_dim': 512, 'z_dim': 0, 'z_min_std': 0.1}

In [None]:
# #| hide
# predictor: PredictorConfig = PredictorConfig(**cfg.hjepa.level1.predictor)
# model = build_predictor(
#     config=predictor,
#     repr_dim=(18, 15, 15),
#     action_dim=1,
#     pred_propio_dim=None,
#     pred_obs_dim=(3, 42, 42),
#     backbone_ln=None,
# )

In [None]:
# #| hide
# model

ConvPredictor(
  (final_ln): Identity()
  (layers): Sequential(
    (0): Conv2d(19, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): GroupNorm(4, 32, eps=1e-05, affine=True)
    (2): ReLU()
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): GroupNorm(4, 32, eps=1e-05, affine=True)
    (5): ReLU()
    (6): Conv2d(32, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (action_encoder): Expander2D()
)

In [None]:
# #| hide
# T = 8
# BS = 16
# state_encs = torch.randn(T, BS, 18, 15, 15)
# act = torch.randn(T-1, BS, 1)

In [None]:
# #| hide
# state_encs.shape[1]

16

In [None]:
# #| hide
# out = model.forward_multiple(state_encs, act, T= 5)


In [None]:
#| hide
# Regenerate the library module from this notebook's `#| export` cells via nbdev.
import nbdev

nbdev.nbdev_export()