# Inverse Dynamics loss
> Contains various loss functions used for optimization.

In [None]:
#| default_exp losses.idm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export
from mawm.models.utils import *
from mawm.models.misc import MLP
from functools import reduce
import operator



In [None]:
#| export

# Layer specifications for the convolutional action predictor, keyed by the
# `config.arch_subclass` value that `IDMLoss` uses to pick one.
# Each list is passed as-is to `build_conv` (imported from mawm.models.utils,
# not defined in this notebook), so the entry formats below are inferred from
# the module printed for variant "a" later in this notebook -- TODO confirm
# against `build_conv`:
#   (in_ch, out_ch, kernel, stride, padding) -> Conv2d (+ GroupNorm + ReLU);
#       an in_ch of -1 appears to be replaced by the input channel count.
#   ("max_pool", kernel, stride, padding)    -> MaxPool2d
#   ("fc", in_features, out_features)        -> Flatten + Linear; an
#       in_features of -1 appears to be inferred from the flattened conv output.
CONV_LAYERS_CONFIG = {
    # Three conv stages, two 2x2 poolings, 5 output logits.
    "a": [
        (-1, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("fc", -1, 5),
    ],
    # Two conv stages, two 2x2 poolings, 2 output logits.
    "b": [
        (-1, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        ("fc", -1, 2),
    ],
}

In [None]:
#| export
from einops import rearrange, repeat, einsum
class IDMLoss(torch.nn.Module):
    """Inverse Dynamics Model (IDM) objective.

    Owns an action predictor that receives the representation of the current
    state together with the representation of the next state and outputs
    action logits. The loss itself is computed in `__call__`.
    """

    def __init__(
        self, config, repr_dim, device= "cuda"
    ):
        """Build the action predictor.

        Args:
            config: Loss configuration. This constructor reads `arch`, plus
                `arch_subclass` (conv only) or `action_dim` (non-conv);
                `__call__` additionally reads `use_pred`.
            repr_dim: Shape of a single state representation. For
                `arch == "conv"` it must be an indexable shape whose first
                entry is the channel count. Otherwise it may be an int or any
                sequence of ints (tuple, list, `torch.Size`, OmegaConf list),
                which is collapsed to its product.
            device: Device the action predictor is moved to.
        """
        # Local import keeps this block self-contained within the notebook cell.
        from collections.abc import Sequence

        super().__init__()
        self.config = config

        if config.arch == "conv":
            # Current and next representations are stacked along the channel
            # axis, so the predictor sees twice as many input channels.
            input_dim = (repr_dim[0] * 2, *repr_dim[1:])
            self.action_predictor = build_conv(
                CONV_LAYERS_CONFIG[config.arch_subclass], input_dim=input_dim
            ).to(device)
        else:
            # Collapse any shape-like sequence to a flat feature count. The
            # previous tuple-only check let lists fall through to
            # `repr_dim * 2`, which repeats the list instead of doubling it.
            if isinstance(repr_dim, Sequence):
                repr_dim = reduce(operator.mul, repr_dim)
            self.action_predictor = MLP(
                arch=config.arch,
                # Flattened current and next representations are concatenated.
                input_dim=repr_dim * 2,
                output_shape=config.action_dim,
            ).to(device)


In [None]:
#| export
@patch
def __call__(self: IDMLoss, embeddings, predictions, actions):
    """Compute the inverse-dynamics cross-entropy loss.

    Each pair of consecutive time steps `(t, t + 1)` is fed to the action
    predictor, and its logits are scored against the action at step `t`.

    Args:
        embeddings: Time-first tensor `(t, b, ...)` of state representations.
            Always supplies the "next" state and, unless `config.use_pred`
            is set, the "current" state as well.
        predictions: Time-first tensor laid out like `embeddings`. Supplies
            the "current" state when `config.use_pred` is set.
        actions: Batch-first tensor `(b, t, ...)` of targets in any form
            accepted by `F.cross_entropy`. The last time step is dropped
            because it has no following state.

    Returns:
        Scalar tensor: mean cross-entropy over all `(t, b)` pairs.

    NOTE(review): patching `__call__` replaces `nn.Module.__call__`, so
    module forward hooks do not run for this loss -- confirm this is intended.
    """
    # Make actions time-first like the embeddings, drop the final step, then
    # merge (t, b) into one batch axis in the same order as `repr_input` below.
    actions = rearrange(actions, "b t ... -> t b ...")
    actions = actions[:-1].flatten(start_dim=0, end_dim=1)

    # The "current" state is either the model's prediction or the encoder
    # embedding; the "next" state always comes from the embeddings.
    curr_source = predictions if self.config.use_pred else embeddings
    curr_embeds = curr_source[:-1]
    next_embeds = embeddings[1:]

    if self.config.arch == "conv":
        # Axis 2 is the channel axis of a (t, b, c, h, w) tensor.
        repr_input = torch.cat([curr_embeds, next_embeds], dim=2)
    else:
        # `flatten_conv_output` is defined elsewhere; presumably it flattens
        # the trailing feature dims so the pair can be joined on the last axis.
        curr_embeds = flatten_conv_output(curr_embeds)
        next_embeds = flatten_conv_output(next_embeds)
        repr_input = torch.cat([curr_embeds, next_embeds], dim=-1)

    repr_input = rearrange(repr_input, "t b ... -> (t b) ...")
    actions_pred = self.action_predictor(repr_input)

    return F.cross_entropy(
        actions_pred,
        actions.to(actions_pred.device),
        reduction="mean",
    )

In [None]:
#| hide
from omegaconf import OmegaConf

In [None]:
#| hide

cfg = OmegaConf.load("../cfgs/findgoal/mawm/ablations/datasize/mawm_ds_200k.yaml")


In [None]:
#| hide
from mawm.models.jepa import JEPA
model = JEPA(cfg.model, input_dim=(3, 42, 42), action_dim=5)

In [None]:
#| hide
# Random inputs for a quick smoke test of the loss: B sequences of T steps.
B = 16
T = 40
x = torch.randn(B, T, 3, 42, 42)  # frames shaped like the model's input_dim=(3, 42, 42)
inp_pos = torch.randint(0, 15, (B, T, 2))
# One-hot float actions; the class count is inferred from the sampled values
# because `num_classes` is not passed to `F.one_hot`.
actions = F.one_hot(torch.randint(low= 0, high= 5, size= (B, T))).float()
msgs = torch.randn(B, T, 32)
# The two outputs come back time-first (see the shapes printed in the next cell).
z0, Z = model(x, pos=inp_pos, repr_input=False, actions=actions, msgs=msgs, T=actions.size(1) - 1)

In [None]:
#| hide
z0.shape, Z.shape, actions.shape

(torch.Size([40, 16, 32, 15, 15]),
 torch.Size([40, 16, 32, 15, 15]),
 torch.Size([16, 40, 5]))

In [None]:
#| hide
loss_fn = IDMLoss(cfg.loss.idm, (32, 15, 15), device= "cpu")

In [None]:
#| hide
loss_fn.action_predictor

Sequential(
  (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): GroupNorm(8, 32, eps=1e-05, affine=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): GroupNorm(8, 32, eps=1e-05, affine=True)
  (6): ReLU()
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): GroupNorm(8, 32, eps=1e-05, affine=True)
  (10): ReLU()
  (11): Flatten(start_dim=1, end_dim=-1)
  (12): Linear(in_features=288, out_features=5, bias=True)
)

In [None]:
#| hide
loss_fn(embeddings= z0, predictions= Z, actions= actions)

torch.Size([624, 5]) torch.Size([624, 5])


tensor(1.6705, grad_fn=<DivBackward1>)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()