# VICReg Loss Function

> VICReg (Variance-Invariance-Covariance Regularization) loss, implemented as described in the paper.

In [None]:
#| default_exp losses.vicreg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export

import torch
from torch.nn import functional as F

from mawm.models.utils import flatten_conv_output
from functools import reduce
import operator
from mawm.models.misc import Projector
from einops import rearrange


In [None]:
#| export
class VICReg(torch.nn.Module):
    """Container for the VICReg objective's configuration.

    The loss computation itself (``__call__``, ``std_loss``, ``cov_loss``) is
    attached to this class with fastcore's ``@patch`` in later cells.
    """

    def __init__(
        self,
        cfg,
        repr_dim: "int | tuple[int, ...] | list[int] | None" = (18, 15, 15),
        pred_attr: str = "state",
        name_prefix: str = "",
    ):
        """
        Args:
            cfg: configuration object. The loss methods read their coefficients
                and margins from ``cfg.loss.vicreg``.
            repr_dim: representation size. Pass either a flat int or a shape
                such as ``(C, H, W)``, which is collapsed to the product of its
                entries. ``None`` is stored unchanged.
            pred_attr: name of the prediction attribute. It is stored only;
                the loss methods in this module do not read it.
            name_prefix: stored only; the loss methods in this module do not
                read it.
        """
        super().__init__()
        if isinstance(repr_dim, (tuple, list)):
            # Collapse a shape like (C, H, W) into the flat feature count C*H*W.
            # The initial value 1 keeps an empty shape from raising.
            repr_dim = reduce(operator.mul, repr_dim, 1)
        self.repr_dim = repr_dim
        self.cfg = cfg
        self.name_prefix = name_prefix
        self.pred_attr = pred_attr
        # NOTE: a Projector head (built from cfg.loss.vicreg.projector) used to
        # be created here. It is disabled, so the losses operate directly on
        # the encoder outputs passed to __call__.

    

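The paper combines several losses:
- The similarity (prediction) loss between the predicted and target representations:

$$
\mathcal{L}_{\mathrm{sim}}=\sum_{k=1}^{K} \sum_{t=0}^{H} \frac{1}{N} \sum_{b=1}^{N}\left\|\hat{Z}_{t, b}^{k}-Z_{t, b}\right\|_2^2
$$

In the code below this term is `sim_loss`. The squared error is averaged over the feature dimensions and over the valid transitions $t \ge 1$ selected by the mask.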


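- The three VICReg regularization losses:
  - Invariance loss.
  - Variance loss.
  - Covariance loss.

  In the code:
  - `std_loss` and `cov_loss` are the variance and covariance terms, computed across the batch at the first time step.
  - `std_loss_t` and `cov_loss_t` are their across-time counterparts, computed per sample over its valid steps.
  - `sim_loss_t` is a temporal invariance term between consecutive encodings.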

In [None]:
#| export
@patch
@patch
def __call__(self: VICReg, encodings, state_predictions, mask=None):
    """Compute all VICReg loss terms for one rollout.

    Args:
        encodings: target encodings, unpacked below as (T, B, C, H, W). The
            per-step squared error is averaged over dims 2-4.
        state_predictions: predicted encodings with the same shape as
            ``encodings``. Step 0 is ignored; only steps t >= 1 are compared.
        mask: optional (T, B) validity mask (nonzero = valid step). A
            transition t-1 -> t contributes only if both of its steps are
            valid. If ``None``, every step is treated as valid.

    Returns:
        dict with ``total_loss`` (the sum weighted by the ``cfg.loss.vicreg``
        coefficients) and the individual terms ``sim_loss``, ``std_loss``,
        ``cov_loss``, ``sim_loss_t``, ``std_loss_t``, ``cov_loss_t``, each
        reduced to a scalar.
    """
    vic = self.cfg.loss.vicreg
    device = encodings.device

    if mask is None:
        # The signature advertises an optional mask: treat every step as valid.
        valid_mask = encodings.new_ones(encodings.shape[:2])  # (T, B)
    else:
        valid_mask = mask.to(device)  # (T, B), time-first
    # A transition t-1 -> t counts only if both endpoints are valid.
    transition_mask = valid_mask[1:] * valid_mask[:-1]  # (T-1, B)
    # Clamp so that a fully masked batch does not divide by zero.
    n_transitions = transition_mask.sum().clamp_min(1)

    # Prediction error for steps t >= 1, averaged over (C, H, W).
    diff = (encodings[1:] - state_predictions[1:]).pow(2).mean(dim=(2, 3, 4))  # (T-1, B)
    sim_loss = (diff * transition_mask).sum() / n_transitions

    if vic.sim_coeff_t:
        # Temporal invariance: distance between consecutive encodings.
        diff_t = (encodings[1:] - encodings[:-1]).pow(2).mean(dim=(2, 3, 4))  # (T-1, B)
        sim_loss_t = (diff_t * transition_mask).sum() / n_transitions
    else:
        # Create on the input's device, like the other zero placeholders.
        sim_loss_t = torch.zeros(1, device=device)

    flat_encodings = flatten_conv_output(encodings)  # [T, B, D]
    # The batch-wise variance term uses only the first time step.
    std_loss = self.std_loss(flat_encodings[:1])

    if vic.cov_per_feature:
        T, B, ch, h, w = encodings.shape
        # (1, B, ch, h, w) -> (h*w, B, ch): one channel covariance per spatial location.
        per_feature_encodings = (
            encodings[:1].reshape(1, B, ch, h * w).permute(0, 3, 1, 2).squeeze(0)
        )
        cov_loss = self.cov_loss(per_feature_encodings)
    else:
        # (1, B, D): a single covariance over the flattened features of step 0.
        cov_loss = self.cov_loss(flat_encodings[:1])

    # Across-time variance/covariance: computed per sample over its valid steps t >= 1.
    flat_enc = flat_encodings[1:].permute(1, 0, 2)  # (B, T-1, D)
    valid = valid_mask[1:].permute(1, 0)  # (B, T-1)

    std_losses, cov_losses = [], []
    for sample_enc, sample_valid in zip(flat_enc, valid):
        keep = sample_valid.bool()
        if keep.sum() > 1:  # a variance over time needs at least 2 steps
            x = sample_enc[keep].unsqueeze(0)  # (1, T_valid, D)
            std_losses.append(self.std_loss(x, across_time=True))
            cov_losses.append(self.cov_loss(x, across_time=True))

    if std_losses:
        std_loss_t = torch.stack(std_losses).mean()
        cov_loss_t = torch.stack(cov_losses).mean()
    else:
        std_loss_t = cov_loss_t = torch.zeros(1, device=flat_enc.device)

    total_loss = (
        vic.sim_coeff * sim_loss
        + vic.cov_coeff * cov_loss.mean()
        + vic.std_coeff * std_loss.mean()
        + vic.cov_coeff_t * cov_loss_t.mean()
        + vic.std_coeff_t * std_loss_t.mean()
        + vic.sim_coeff_t * sim_loss_t.mean()
    )

    return {
        "total_loss": total_loss,
        "sim_loss": sim_loss,
        "std_loss": std_loss.mean(),
        "cov_loss": cov_loss.mean(),
        "sim_loss_t": sim_loss_t.mean(),
        "std_loss_t": std_loss_t.mean(),
        "cov_loss_t": cov_loss_t.mean(),
    }

In [None]:
import torch
import torch.nn.functional as F
# Scratch check of the tensor shapes flowing through the std (variance) loss,
# using an input laid out as (T, B, D) = (40, 128, 512).
std_margin = 1.0
x = torch.randn(40, 128, 512)
print(x[:1].shape)  # first step only, the slice __call__ feeds to std_loss

x -= x.mean(dim=1, keepdim=True)  # centre each (t, feature) over the batch axis
print(x.shape)

std = (x.var(dim=1) + 0.0001).sqrt()  # (T, D)
print(std.shape)

std_loss = F.relu(std_margin - std).mean(dim=-1)  # hinge averaged over features: (T,)
print(std_loss.shape)

torch.Size([1, 128, 512])
torch.Size([40, 128, 512])
torch.Size([40, 512])
torch.Size([40])


In [None]:
#| export
@patch
@patch
def std_loss(self:VICReg, x: torch.Tensor, across_time=False):
    """Hinge penalty that keeps the per-feature standard deviation above a margin.

    Statistics are taken over dim 1 of ``x``, and the hinge is averaged over the
    last dim. ``__call__`` passes (1, B, D) for the batch-wise term and
    (1, T_valid, D) with ``across_time=True`` for the temporal term.

    Returns:
        ``mean(relu(margin - std), dim=-1)``. When the matching coefficient
        (``std_coeff``, or ``std_coeff_t`` if ``across_time``) is falsy, it
        returns ``zeros([1])`` on ``x``'s device instead.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    vic = self.cfg.loss.vicreg

    coeff = vic.std_coeff_t if across_time else vic.std_coeff
    if not coeff:
        return torch.zeros([1], device= centered.device)

    # The small epsilon keeps sqrt differentiable when the variance is zero.
    spread = torch.sqrt(centered.var(dim=1) + 0.0001)
    margin = vic.std_margin_t if across_time else vic.std_margin
    return torch.mean(F.relu(margin - spread), dim=-1)

In [None]:
#| export
@patch
@patch
def cov_loss(self: VICReg, x: torch.Tensor, across_time=False):
    """Penalty on the off-diagonal entries of the feature covariance.

    ``x`` is read as (G, N, D). For each of the G leading slices, a D x D
    covariance is estimated over the N entries of dim 1. ``__call__`` passes
    (1, B, D), (H*W, B, C) or (1, T_valid, D).

    Returns:
        Per-slice sum of squared off-diagonal covariances divided by D, shape
        (G,). It is further divided by D - 1 when ``adjust_cov`` is set. When
        the matching coefficient (``cov_coeff``, or ``cov_coeff_t`` if
        ``across_time``) is falsy, it returns ``zeros([1])`` on ``x``'s device
        instead.
    """
    n_samples, n_features = x.shape[1], x.shape[-1]
    centered = x - x.mean(dim=1, keepdim=True)
    vic = self.cfg.loss.vicreg

    coeff = vic.cov_coeff_t if across_time else vic.cov_coeff
    if not coeff:
        return torch.zeros([1], device= centered.device)

    # Unbiased covariance per leading slice: (G, D, D).
    cov = torch.einsum("bki,bkj->bij", centered, centered) / (n_samples - 1)
    diag_sq = torch.diagonal(cov, dim1=-2, dim2=-1).pow(2).sum(dim=-1)

    penalty = (cov.pow(2).sum(dim=[-1, -2]) - diag_sq).div(n_features)
    if vic.adjust_cov:
        # The original paper divides by num_features only. The off-diagonal
        # contains (num_features - 1) * num_features elements, so this extra
        # division normalises by that count.
        penalty = penalty / (n_features - 1)
    return penalty

In [None]:
#| hide
from omegaconf import OmegaConf


In [None]:
#| hide
from mawm.models.jepa import JEPA
# Experiment config: cfg.model builds the JEPA; cfg.loss.vicreg drives VICReg.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")
model = JEPA(
    cfg.model,
    input_dim=(3, 42, 42),
    action_dim=5,
)

In [None]:
#| hide
import torch
# Dummy batch-first inputs: B trajectories of T steps with 3x42x42 frames.
B, T = 16, 6
x = torch.randn(B, T, 3, 42, 42)
pos = torch.randint(0, 15, (B, T, 2))

actions = torch.randn(B, T, 5)
msgs = torch.randn(B, T, 32)

# T=... is presumably the number of predicted steps (T - 1 here) — confirm in JEPA.
z0, Z = model(
    x,
    pos=pos,
    repr_input=False,
    actions=actions,
    msgs=msgs,
    T=actions.size(1) - 1,
)

In [None]:
#| hide
# Shapes of the predicted (Z) and target (z0) encodings.
Z.shape, z0.shape

(torch.Size([6, 16, 32, 15, 15]), torch.Size([6, 16, 32, 15, 15]))

In [None]:
# Target encodings for steps t >= 1: the slice __call__ compares against predictions.
z0[1:].shape

torch.Size([5, 16, 32, 15, 15])

In [None]:
#| hide
# Validity mask laid out as (T, B): mask[t, b] == 1 marks step t of sample b as valid.
# Only the first three steps are (partly) valid; later steps stay all-zero.
mask = torch.zeros(T, B)
for step, start in enumerate((3, 4, 2)):
    mask[step, start:] = 1
mask.shape

torch.Size([6, 16])

In [None]:
# Full (T, B) validity mask.
mask

tensor([[0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
# mask = rearrange(mask, 't b -> b t')

In [None]:
# Still (T, B): the transpose in the previous cell is commented out, and
# __call__ expects the time-first layout.
mask

tensor([[0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
# Targets are the encoder outputs; predictions come from the rollout.
encodings, state_predictions = z0, Z

In [None]:
#| hide
# repr_dim is not needed by the loss terms while the projector is disabled.
loss = VICReg(
    cfg,
    name_prefix="JEPA",
    repr_dim=None,
)


In [None]:
#| hide
# Evaluate the objective on the dummy rollout and display every term.
loss_dict = loss(encodings, state_predictions, mask=mask)
loss_dict

{'total_loss': tensor(3.0981, grad_fn=<AddBackward0>),
 'sim_loss': tensor(0.3333, grad_fn=<DivBackward0>),
 'std_loss': tensor(0.8302, grad_fn=<MeanBackward0>),
 'cov_loss': tensor(0.0001, grad_fn=<MeanBackward0>),
 'sim_loss_t': tensor(0.0833, grad_fn=<MeanBackward0>),
 'std_loss_t': tensor(0.8620, grad_fn=<MeanBackward0>),
 'cov_loss_t': tensor(0.)}

In [None]:
#| hide
import nbdev

# Write this notebook's `#| export` cells to the library module (default_exp: losses.vicreg).
nbdev.nbdev_export()