# VICReg Loss Function

> As implemented in the paper.

In [None]:
#| default_exp losses.vicreg

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#| export

import torch
from torch.nn import functional as F

from mawm.models.utils import flatten_conv_output
from functools import reduce
import operator
from mawm.models.misc import Projector



In [None]:
#| export
class VICReg(torch.nn.Module):
    """VICReg-style loss module that owns the projector applied to encodings.

    The constructor only stores configuration and builds the projector; the
    loss computations live in separate methods of this class.
    """

    def __init__(
        self,
        cfg,
        repr_dim: int,
        pred_attr: str = "state",
        name_prefix: str = "",
    ):
        """Store the configuration and build the projector.

        Args:
            cfg: Configuration object. ``cfg.loss.vicreg.projector`` is read
                here to pick the projector architecture; the whole object is
                kept on ``self.cfg``.
            repr_dim: Size of the representation. An ``int`` is used as-is;
                a tuple or list of per-axis sizes (e.g. ``(ch, h, w)``) is
                collapsed to the product of its entries.
            pred_attr: Stored on the instance as ``self.pred_attr``.
            name_prefix: Stored on the instance as ``self.name_prefix``.
        """
        super().__init__()
        # Accept a shape given as a tuple or a list. The initial value 1
        # makes an empty shape collapse to 1 instead of raising.
        if isinstance(repr_dim, (tuple, list)):
            repr_dim = reduce(operator.mul, repr_dim, 1)
        self.cfg = cfg
        self.name_prefix = name_prefix
        self.pred_attr = pred_attr
        self.projector = Projector(
            arch=cfg.loss.vicreg.projector,
            embedding=repr_dim,
        )

    

In [None]:
#| export
@patch
def __call__(self: VICReg, encodings, state_predictions):
    """Compute the similarity, variance and covariance loss terms.

    Both inputs are indexed along dim 0 by time step: index 0 is used for
    the per-batch variance/covariance terms, indices ``1:`` for the
    prediction error and the across-time terms.

    Args:
        encodings: Encoder outputs, time-major. The ``cov_per_feature``
            branch requires the projected encodings to be 5-D
            ``(T, B, ch, h, w)``.
        state_predictions: Predicted states, sliced the same way as
            ``encodings``.

    Returns:
        dict with keys ``total_loss``, ``sim_loss``, ``std_loss``,
        ``cov_loss``, ``sim_loss_t``, ``std_loss_t`` and ``cov_loss_t``.
        ``total_loss`` is the sum of the six terms, each weighted by its
        coefficient from ``cfg.loss.vicreg``.
    """
    vicreg_cfg = self.cfg.loss.vicreg

    # Prediction error between encodings and predictions, skipping step 0.
    sim_loss = (encodings[1:] - state_predictions[1:]).pow(2).mean()

    if vicreg_cfg.sim_coeff_t:
        # Squared difference between consecutive encodings.
        sim_loss_t = (encodings[1:] - encodings[:-1]).pow(2).mean()
    else:
        # Placeholder created on the inputs' device and dtype, so every
        # entry of the returned dict lives alongside the real losses.
        sim_loss_t = encodings.new_zeros(1)

    encodings = self.projector(encodings)

    # presumably collapses the trailing (ch, h, w) axes into one feature
    # axis, giving (T, B, ch * h * w) -- TODO confirm against the helper
    flat_encodings = flatten_conv_output(encodings)

    std_loss = self.std_loss(flat_encodings[:1])

    if vicreg_cfg.cov_per_feature:
        _, B, ch, h, w = encodings.shape
        # reshape (1, bs, ch, h, w) --> (h*w, bs, ch)
        per_feature_encodings = (
            encodings[:1].reshape(1, B, ch, h * w).permute(0, 3, 1, 2).squeeze(0)
        )
        cov_loss = self.cov_loss(per_feature_encodings)
    else:
        # flat first step: (1, bs, ch * h * w)
        cov_loss = self.cov_loss(flat_encodings[:1])

    # Move time to dim 1 so the statistics are taken across time steps.
    encodings_over_time = flat_encodings[1:].permute(1, 0, 2)  # (bs, T, repr)
    std_loss_t = self.std_loss(encodings_over_time, across_time=True)
    cov_loss_t = self.cov_loss(encodings_over_time, across_time=True)

    total_loss = (
        vicreg_cfg.sim_coeff * sim_loss
        + vicreg_cfg.cov_coeff * cov_loss.mean()
        + vicreg_cfg.std_coeff * std_loss.mean()
        + vicreg_cfg.cov_coeff_t * cov_loss_t.mean()
        + vicreg_cfg.std_coeff_t * std_loss_t.mean()
        + vicreg_cfg.sim_coeff_t * sim_loss_t.mean()
    )

    return {
        "total_loss": total_loss,
        "sim_loss": sim_loss,
        "std_loss": std_loss.mean(),
        "cov_loss": cov_loss.mean(),
        "sim_loss_t": sim_loss_t.mean(),
        "std_loss_t": std_loss_t.mean(),
        "cov_loss_t": cov_loss_t.mean(),
    }

In [None]:
#| export
@patch
def std_loss(self: VICReg, x: torch.Tensor, across_time=False):
    """Hinge loss that penalises a standard deviation below a margin.

    The standard deviation is taken over dim 1 of ``x`` and the hinge
    ``relu(margin - std)`` is averaged over the last dim.

    Args:
        x: Input tensor; statistics are computed along dim 1.
        across_time: Selects which coefficient and margin are read from
            ``cfg.loss.vicreg``: ``std_coeff_t`` / ``std_margin_t`` when
            True, ``std_coeff`` / ``std_margin`` otherwise.

    Returns:
        The hinge loss with dim 1 and the last dim reduced, or a zero tensor
        of shape ``(1,)`` on ``x``'s device and dtype when the selected
        coefficient is falsy (term disabled).
    """
    vicreg_cfg = self.cfg.loss.vicreg
    coeff = vicreg_cfg.std_coeff_t if across_time else vicreg_cfg.std_coeff
    if not coeff:
        # Term disabled: skip the computation and return a placeholder that
        # lives alongside ``x`` rather than always on CPU.
        return x.new_zeros(1)

    x = x - x.mean(dim=1, keepdim=True)  # centre each feature along dim 1
    # The small constant keeps sqrt differentiable when the variance is zero.
    # NOTE(review): torch.var defaults to the unbiased estimator, so a single
    # sample along dim 1 yields NaN -- confirm callers never pass that.
    std = torch.sqrt(x.var(dim=1) + 0.0001)
    std_margin = vicreg_cfg.std_margin_t if across_time else vicreg_cfg.std_margin
    return torch.mean(F.relu(std_margin - std), dim=-1)

In [None]:
#| export
@patch
def cov_loss(self: VICReg, x: torch.Tensor, across_time=False):
    """Penalise the off-diagonal entries of the feature covariance matrix.

    For every index along dim 0 a covariance matrix is built over the
    samples in dim 1; the sum of its squared off-diagonal entries is then
    divided by the number of features.

    Args:
        x: 3-D tensor (required by the einsum below); dim 1 indexes the
            samples and the last dim the features.
        across_time: Selects which coefficient is read from
            ``cfg.loss.vicreg``: ``cov_coeff_t`` when True, ``cov_coeff``
            otherwise.

    Returns:
        One loss value per index along dim 0, or a zero tensor of shape
        ``(1,)`` on ``x``'s device and dtype when the selected coefficient
        is falsy (term disabled).
    """
    vicreg_cfg = self.cfg.loss.vicreg
    coeff = vicreg_cfg.cov_coeff_t if across_time else vicreg_cfg.cov_coeff
    if not coeff:
        # Term disabled: skip the computation and return a placeholder that
        # lives alongside ``x`` rather than always on CPU.
        return x.new_zeros(1)

    batch_size = x.shape[1]
    num_features = x.shape[-1]

    x = x - x.mean(dim=1, keepdim=True)

    # One (features x features) covariance matrix per index along dim 0.
    # NOTE(review): divides by zero when dim 1 holds a single sample.
    cov = torch.einsum("bki,bkj->bij", x, x) / (batch_size - 1)
    diagonals = torch.einsum("bii->bi", cov).pow(2).sum(dim=-1)

    # Sum of all squared entries minus the squared diagonal leaves only the
    # off-diagonal contribution.
    cov_loss = (cov.pow(2).sum(dim=[-1, -2]) - diagonals).div(num_features)
    if vicreg_cfg.adjust_cov:
        # The original paper divides by num_features only; dividing by
        # (num_features - 1) as well normalises by the number of
        # off-diagonal elements, (num_features - 1) * num_features.
        cov_loss = cov_loss / (num_features - 1)

    return cov_loss

In [None]:
#| hide
from omegaconf import OmegaConf


In [None]:
#| hide
from mawm.models.jepa import JEPA
# Load the experiment config and build the JEPA model whose outputs are used
# below to exercise the VICReg loss.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")
model = JEPA(cfg.model, input_dim=(3, 42, 42), action_dim=1)

In [None]:
#| hide
import torch
# Random inputs for a single forward pass. No seed is set, so the values
# (and therefore the recorded loss numbers) differ between runs.
B = 16  # first axis of every input tensor below
T = 6  # second axis of every input tensor below
x = torch.randn(B, T, 3, 42, 42)  # trailing dims match input_dim=(3, 42, 42)
pos = torch.randn(B, T, 2)
actions = torch.randn(B, T, 1)  # trailing dim matches action_dim=1
msgs = torch.randn(B, T, 32)
# The T argument is one less than the number of action steps.
z0, Z = model(x, pos=pos, repr_input=False, actions=actions, msgs=msgs, T=actions.size(1) - 1)

In [None]:
#| hide
# Both model outputs come back with T first and B second (see output below).
Z.shape, z0.shape

(torch.Size([6, 16, 18, 15, 15]), torch.Size([6, 16, 18, 15, 15]))

In [None]:
# Dropping the first step leaves T - 1 = 5 steps along dim 0.
z0[1:].shape

torch.Size([5, 16, 18, 15, 15])

In [None]:
#| hide
# Recompute by hand the two similarity terms that VICReg.__call__ produces,
# as a reference for the values in the loss dict.
encodings = z0
state_predictions = Z
# Squared error between encodings and predictions, skipping the first step.
sim_loss = (encodings[1:] - state_predictions[1:]).pow(2).mean()
# Squared difference between consecutive encodings.
sim_loss_t = (encodings[1:] - encodings[:-1]).pow(2).mean()
sim_loss, sim_loss_t

(tensor(1.6168, grad_fn=<MeanBackward0>),
 tensor(0.2939, grad_fn=<MeanBackward0>))

In [None]:
#| hide
# repr_dim matches the trailing (18, 15, 15) dims of the model outputs above;
# VICReg collapses the tuple to its product before building the projector.
loss = VICReg(cfg, repr_dim=(18, 15, 15), name_prefix="JEPA")
                    

In [None]:
#| hide
# Evaluate the loss on the model outputs and display every returned term.
loss_dict =loss(encodings, state_predictions)
loss_dict

{'total_loss': tensor(27.5082, grad_fn=<AddBackward0>),
 'sim_loss': tensor(1.6168, grad_fn=<MeanBackward0>),
 'std_loss': tensor(0.6789, grad_fn=<MeanBackward0>),
 'cov_loss': tensor(0.0042, grad_fn=<MeanBackward0>),
 'sim_loss_t': tensor(0.2939, grad_fn=<MeanBackward0>),
 'std_loss_t': tensor(0.6956, grad_fn=<MeanBackward0>),
 'cov_loss_t': tensor(0.)}

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()