# JEPA

> A JEPA model composed of a backbone encoder and a dynamics predictor.

In [None]:
#| default_exp models.jepa

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from dataclasses import dataclass
from typing import Optional, NamedTuple

import torch

from functools import reduce
import operator



In [None]:
#| export
from mawm.core import get_cls
def get_model(model_name: str, type = "vision", params: Optional[dict] = None):
    """Instantiate a model class looked up by name.

    Args:
        model_name: name of the class to resolve inside ``mawm.models.<type>``.
        type: sub-package of ``mawm.models`` to search (e.g. ``"vision"``).
            It shadows the builtin ``type`` inside this function; the name is
            kept so existing keyword callers keep working.
        params: keyword arguments forwarded to the class constructor.
            ``None`` (the default) means no arguments. A ``None`` sentinel is
            used instead of a shared mutable ``{}`` default.

    Returns:
        The constructed model instance.
    """
    model_cls = get_cls(f"mawm.models.{type}", str(model_name))
    return model_cls(**(params or {}))

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from mawm.core import get_cls

class JEPA(nn.Module):
    """JEPA world model: a vision backbone (encoder) plus a dynamics predictor.

    Both sub-modules are resolved by class name through ``get_cls``:
    ``cfg.backbone.arch`` is looked up in ``mawm.models.vision`` and
    ``cfg.predictor.arch`` in ``mawm.models.dynamics``.

    Args:
        cfg: model config exposing ``backbone`` and ``predictor`` sub-configs,
            each with at least an ``arch`` field. The sub-configs are passed to
            the constructors as ``config=``.
        input_dim: observation shape forwarded to the backbone constructor.
            NOTE(review): the default ``(42, 42, 3)`` is channels-last, while the
            notebook builds the model with ``(3, 42, 42)`` and feeds
            ``[B, T, C, H, W]`` tensors. Confirm which layout backbones expect.
        repr_dim: optional representation shape. It is not used to build the
            sub-modules; the predictor is always sized from
            ``backbone.repr_dim``. When left as ``None``, it is filled in from
            the backbone so that ``self.repr_dim`` reports the real shape.
        action_dim: action vector size passed to the predictor constructor.
    """

    def __init__(
        self,
        cfg,
        input_dim = (42, 42, 3),
        repr_dim = None,
        action_dim = 5,
    ):
        super().__init__()

        self.cfg = cfg
        self.input_dim = input_dim
        self.repr_dim = repr_dim
        self.action_dim = action_dim
        self.backbone, self.dynamics = self.get_models()

        # Without this, `self.repr_dim` stayed None even though the backbone
        # knows its representation shape.
        if self.repr_dim is None:
            self.repr_dim = self.backbone.repr_dim

    def get_models(self):
        """Build and return ``(backbone, dynamics)`` from ``self.cfg``.

        The predictor is constructed with the backbone's ``repr_dim`` so the
        two modules agree on the representation shape.
        """
        backbone_cls = get_cls("mawm.models.vision", str(self.cfg.backbone.arch))
        backbone = backbone_cls(config=self.cfg.backbone, input_dim=self.input_dim)

        dynamics_cls = get_cls("mawm.models.dynamics", str(self.cfg.predictor.arch))
        dynamics = dynamics_cls(
            config=self.cfg.predictor,
            repr_dim=backbone.repr_dim,
            action_dim=self.action_dim,
        )
        return backbone, dynamics
        

In [None]:
#| hide
from omegaconf import OmegaConf


In [None]:
#| hide
# Load the experiment config and display its `model.backbone` section.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

cfg.model.backbone

{'arch': 'MeNet6', 'backbone_subclass': 'd4rl_a', 'backbone_mlp': None, 'backbone_width_factor': 2, 'input_dim': 4, 'channels': 3, 'position_dim': 2, 'position_encoder_arch': '2-8-16'}

In [None]:
#| hide
# Re-load the config and build a JEPA from its `model` section, with
# channels-first 3x42x42 observations and 5-dimensional actions.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")
model = JEPA(cfg.model, input_dim=(3, 42, 42), action_dim=5)

In [None]:
#| export
@patch
def forward(
        self: JEPA,
        x: torch.Tensor,  # [B, T, C, H, W] observations (or representations if repr_input)
        pos: torch.Tensor = None,  # [B, T, 2]
        repr_input: bool = False,
        actions: torch.Tensor = None,  # [B, T, action_dim]
        msgs: torch.Tensor = None,  # [B, T, 32]
        T: int = None,
        goal: torch.Tensor = None,
    ):
    """Encode a batch-first sequence and roll the dynamics predictor over it.

    Args:
        x: observations ``[B, T, C, H, W]``. When ``repr_input`` is True,
            ``x`` is used as-is as the representation and is not re-encoded.
        pos: positions passed to the backbone as ``position=``. Ignored when
            ``repr_input`` is True.
        repr_input: skip the backbone and treat ``x`` as the representation.
        actions: batch-first per-step actions ``[B, T, action_dim]``. Required.
        msgs: batch-first per-step messages ``[B, T, ...]``. Required.
        T: forwarded unchanged to ``self.dynamics.forward_multiple``
            (presumably the number of rollout steps; confirm there).
        goal: currently unused.

    Returns:
        ``(z0, Z)``. ``z0`` is the representation sequence, time-major
        ``[T, B, ...]``. ``Z`` is whatever ``forward_multiple`` returns.

    Raises:
        TypeError: if ``actions`` or ``msgs`` is not provided. Without this
            check they would fail deep inside the tensor ops.
    """
    if actions is None or msgs is None:
        raise TypeError("JEPA.forward requires both `actions` and `msgs` tensors")

    z0 = x if repr_input else self.backbone(x, position=pos)

    # Batch-first -> time-major: swap axes 0 and 1. Unlike the previous
    # 'b t c h w' einsum, this also works for non-spatial (e.g. flat)
    # representations.
    z0 = z0.transpose(0, 1)
    actions = actions.transpose(0, 1)
    msgs = msgs.transpose(0, 1)

    # The last action/message has no successor frame, so it is dropped.
    # TODO: check if it should be msgs[1:]
    Z = self.dynamics.forward_multiple(z0, actions[:-1], msgs[:-1], T)

    return z0, Z

In [None]:
#| hide
# Inspect the dynamics (predictor) sub-module.
model.dynamics 

ConvPredictor(
  (layers): Sequential(
    (0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): GroupNorm(4, 32, eps=1e-05, affine=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): GroupNorm(4, 64, eps=1e-05, affine=True)
    (5): ReLU()
    (6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (action_encoder): Sequential(
    (0): Linear(in_features=5, out_features=32, bias=True)
    (1): Expander2D()
  )
  (msg_encoder): Expander2D()
)

In [None]:
#| hide
# Representation shape reported by the backbone (no batch/time axes).
model.backbone.repr_dim

(32, 15, 15)

In [None]:
#| hide
# Random smoke-test inputs: B sequences of T frames each.
B = 16
T = 6
x = torch.randn(B, T, 3, 42, 42)  # observations [B, T, C, H, W]
inp_pos = torch.randint(0, 15, (B, T, 2))  # integer 2-D positions in [0, 15)
actions = torch.randn(B, T, 5)  # [B, T, action_dim]
msgs = torch.randn(B, T, 32)  # [B, T, 32]
# T = actions.size(1) - 1 because forward feeds actions[:-1] / msgs[:-1]
# (T - 1 transitions) to the predictor.
z0, Z = model(x, pos=inp_pos, repr_input=False, actions=actions, msgs=msgs, T=actions.size(1) - 1)

In [None]:
#| hide
from torchviz import make_dot
# Autograd graph of the rollout output Z, labelled with parameter names.
graph = make_dot(Z, params=dict(model.named_parameters()), show_saved=True)


In [None]:
#| hide
# Save the autograd graph as ../graphs/jepa_model_graph.png.
import os

os.makedirs("../graphs", exist_ok=True)
# graphviz appends ".<format>" to `filename`, so pass the stem only.
# Passing "...png" produced "jepa_model_graph.png.png" plus a DOT source
# file misleadingly named "jepa_model_graph.png".
graph.render("../graphs/jepa_model_graph", format="png")

'../graphs/jepa_model_graph.png.png'

In [None]:
#| hide
# Both outputs are time-major: [T, B, C, H, W].
z0.shape, Z.shape

(torch.Size([6, 16, 32, 15, 15]), torch.Size([6, 16, 32, 15, 15]))

In [None]:
#| shape
from einops import rearrange

# Drop the last time step, then flatten (batch, time) into rows and
# (C, H, W) into columns: [B * (T - 1), C * H * W].
# NOTE(review): `#| shape` is not a standard nbdev directive; `#| hide` may
# have been intended.
Z0 = rearrange(z0[:-1], 't b c h w -> (b t) (c h w)')
print(Z0.shape)
Z0.mean(0).shape

torch.Size([80, 7200])


torch.Size([7200])

In [None]:
#| hide
# Dummy second embedding of width 32, used to illustrate the shape mismatch
# with Z0 (width 7200).
C = torch.randn(B, T-1, 32)

If you try to execute the following code:

```python
inv_loss = (Z0.mean(0) - C).square().mean()
```

It raises an error because `Z0.mean(0)` has shape `[7200]` while `C` has shape `[B, T-1, 32]`: their last dimensions (7200 vs. 32) differ, so they cannot be broadcast together.

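One solution is to project both into a shared embedding space (here 128-dimensional) before computing the invariance loss. The `nn.Linear` projections below are created inline and are untrained, purely for illustration. In a real model they should be registered as sub-modules so their weights are learned: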

```python
proj_z = torch.nn.Linear(Z0.size(-1), 128)(Z0)
proj_c = torch.nn.Linear(C.size(-1), 128)(C)
inv_loss = (proj_z.mean(0) - proj_c).square().mean()
```

In [None]:
#| hide
import torch
# Project both embeddings into a shared 128-d space using freshly initialised
# (untrained) linear layers. proj_z is [80, 128], and proj_z.mean(0) ([128])
# broadcasts against proj_c ([B, T-1, 128]).
proj_z = torch.nn.Linear(Z0.size(-1), 128)(Z0)
proj_c = torch.nn.Linear(C.size(-1), 128)(C)
inv_loss = (proj_z.mean(0) - proj_c).square().mean()

In [None]:
#| export
@patch
def update_ema(self: JEPA):
    """Move the EMA backbone towards the online backbone.

    Applies ``ema = m * ema + (1 - m) * online`` parameter-wise, where the
    momentum ``m`` is read from the model config's ``momentum`` field. Does
    nothing when ``m`` is missing, ``None``, or ``<= 0``.

    If ``self.backbone_ema`` does not exist yet, it is created as a frozen
    deep copy of ``self.backbone``. That copy already equals the online
    weights, so no update is applied on that call.
    """
    # JEPA.__init__ stores its config as `self.cfg` and never sets
    # `self.config`. Prefer an externally attached `config` if present, so any
    # caller that relied on it keeps working, otherwise use `self.cfg`.
    cfg = getattr(self, "config", None)
    if cfg is None:
        cfg = self.cfg
    momentum = getattr(cfg, "momentum", None)
    if not momentum or momentum <= 0:
        return

    ema = getattr(self, "backbone_ema", None)
    if ema is None:
        import copy
        # NOTE(review): JEPA.__init__ never creates `backbone_ema`, so it is
        # built lazily here. Consider creating it in __init__ so it exists
        # before optimizers / checkpoints see the model.
        ema = copy.deepcopy(self.backbone)
        ema.requires_grad_(False)
        self.backbone_ema = ema
        return

    with torch.no_grad():
        for param, ema_param in zip(self.backbone.parameters(), ema.parameters()):
            ema_param.mul_(momentum).add_(param, alpha=1 - momentum)

In [None]:
#| hide
# Regenerate the library modules from the `#| export` cells
# (nbdev_export with its default settings).
import nbdev; nbdev.nbdev_export()