# `mdmm` tutorial
The purpose of this notebook is to guide you to train a model with a loss function of several parameters in a proper mathematical way. The idea comes from the paper [Constrained Differential Optimization](https://papers.nips.cc/paper/1987/file/a87ff679a2f3e71d9181a67b7542122c-Paper.pdf), and the implementation from the [mdmm package Github](https://github.com/crowsonkb/mdmm)

The guide is best illustrated through the **[VICReg](https://arxiv.org/abs/2105.04906)** example where your input is split into two views and you are asked to minimize three loss terms: `variance`, `invariance` and `covariance`.

In [1]:
! pip install mdmm

Collecting mdmm
  Downloading mdmm-0.1.3-py3-none-any.whl (5.7 kB)
Installing collected packages: mdmm
Successfully installed mdmm-0.1.3


In [2]:
import warnings

import mdmm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import GravNetConv

warnings.simplefilter("ignore", UserWarning)

# define the global base device
if torch.cuda.device_count():
    device = torch.device("cuda:0")
    print(f"Will use {torch.cuda.get_device_name(device)}")
else:
    device = "cpu"
    print("Will use cpu")

Will use NVIDIA GeForce GTX 1080 Ti


Load a processed `.pt` clic file.

In [3]:
data = torch.load("/../ssl-jet-vol-v2/toptagging/train/processed/data_0.pt")
print(f"num of clic events {len(data)}")

num of clic events 100001


In [7]:
# build a data loader
batch_size = 256

loader = DataLoader(data, batch_size, shuffle=True)
for batch in loader:
    print(f"A single event: \n {batch}")
    break

input_dim = batch.x.shape[-1]

A single event: 
 DataBatch(x=[12533, 7], y=[256], batch=[12533], ptr=[257])


# Preparation

In [8]:
def translate_jets(batch, width=1.0, device='cpu'):
    width = 1.0
    batch = batch.to("cpu")
    X = batch.x.numpy()
    mask = (X[:,2] > 0).reshape(X.shape[0],1)
    ptp_eta = np.ptp(X[:,1], axis=-1, keepdims=True)
    ptp_phi = np.ptp(X[:,0], axis=-1, keepdims=True)
    low_eta  = -width*ptp_eta
    high_eta = +width*ptp_eta
    low_phi = np.maximum(-width*ptp_phi, -np.pi-np.min(X[:,2]).reshape(ptp_phi.shape))
    high_phi = np.minimum(+width*ptp_phi, +np.pi-np.max(X[:,2]).reshape(ptp_phi.shape))
    shift_eta = mask*np.random.uniform(low=low_eta, high=high_eta, size=(X.shape[0], 1))
    shift_phi = mask*np.random.uniform(low=low_phi, high=high_phi, size=(X.shape[0], 1))
    shift = np.hstack((shift_eta, shift_phi, np.zeros((X.shape[0], 5))))
    X = X + shift
    X = torch.tensor(X).to(device)
    batch.x = X.float()
    return batch.to(device)

In [5]:
def event_augmentation(batch):
    """
    Takes events of the form Batch() and splits them into two Batch() objects representing the two views.

    In this example, the first view is tracks and the second view is clusters.
    """

    

    view1 = translate_jets(batch)
    view2 = translate_jets(batch)

    return view1, view2

In [6]:
view1, view2 = event_augmentation(batch)
print(f"view1: {view1}")
print(f"view2: {view2}")

view1: Batch(x=[2428, 17], ygen=[2428, 6], ygen_id=[2428], ycand=[2428, 6], ycand_id=[2428], batch=[2428])
view2: Batch(x=[4503, 17], ygen=[4503, 6], ygen_id=[4503], ycand=[4503, 6], ycand_id=[4503], batch=[4503])


# Setup the VICReg model (GravNet-based)

In [7]:
class VICReg(nn.Module):
    def __init__(self, encoder, decoder):
        super(VICReg, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.augmentation = event_augmentation

    def forward(self, event):
        # seperate the two views
        view1, view2 = self.augmentation(event)

        # encode to retrieve the representations
        view1_representations, view2_representations = self.encoder(view1, view2)

        # simple MLP decoder
        view1_embeddings = self.decoder(view1_representations)
        view2_embeddings = self.decoder(view2_representations)

        # global pooling to be able to compute a loss between views of different dimensionalities
        view1_embeddings = global_mean_pool(view1_embeddings, view1.batch)
        view2_embeddings = global_mean_pool(view2_embeddings, view2.batch)

        return view1_embeddings, view2_embeddings


class ENCODER(nn.Module):
    """The Encoder part of VICReg which attempts to learn useful latent representations of the two views."""

    def __init__(
        self,
        input_dim,
        width=126,
        embedding_dim=34,
        num_convs=2,
    ):
        super(ENCODER, self).__init__()

        self.act = nn.ELU

        # 1. different MLP for each view
        self.nn1 = nn.Sequential(
            nn.Linear(input_dim, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, embedding_dim),
        )
        self.nn2 = nn.Sequential(
            nn.Linear(17, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, embedding_dim),
        )

        # 2. same GNN for each view
        self.convs = nn.ModuleList()
        for i in range(num_convs):
            self.convs.append(
                GravNetConv(
                    embedding_dim,
                    embedding_dim,
                    space_dimensions=4,
                    propagate_dimensions=22,
                    k=8,
                )
            )

    def forward(self, view1, view2):
        view1_representations = self.nn1(view1.x.float())
        view2_representations = self.nn2(view2.x.float())

        # perform a series of graph convolutions
        for num, conv in enumerate(self.convs):
            view1_representations = conv(view1_representations, view1.batch)
            view2_representations = conv(view2_representations, view2.batch)

        return view1_representations, view2_representations


class DECODER(nn.Module):
    """The Decoder part of VICReg which attempts to expand the learned latent representations
    of the two views into a space where a loss can be computed."""

    def __init__(
        self,
        embedding_dim=34,
        width=126,
        output_dim=200,
    ):
        super(DECODER, self).__init__()

        self.act = nn.ELU

        # DECODER
        self.expander = nn.Sequential(
            nn.Linear(embedding_dim, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, width),
            self.act(),
            nn.Linear(width, output_dim),
        )

    def forward(self, x):
        return self.expander(x)


In [8]:
vicreg_encoder = ENCODER(input_dim, embedding_dim=34)
vicreg_decoder = DECODER(embedding_dim=34, output_dim=200)

vicreg = VICReg(vicreg_encoder, vicreg_decoder)
vicreg.to(device)

VICReg(
  (encoder): ENCODER(
    (nn1): Sequential(
      (0): Linear(in_features=17, out_features=126, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=126, out_features=126, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=126, out_features=126, bias=True)
      (5): ELU(alpha=1.0)
      (6): Linear(in_features=126, out_features=34, bias=True)
    )
    (nn2): Sequential(
      (0): Linear(in_features=17, out_features=126, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=126, out_features=126, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=126, out_features=126, bias=True)
      (5): ELU(alpha=1.0)
      (6): Linear(in_features=126, out_features=34, bias=True)
    )
    (convs): ModuleList(
      (0): GravNetConv(34, 34, k=8)
      (1): GravNetConv(34, 34, k=8)
    )
  )
  (decoder): DECODER(
    (expander): Sequential(
      (0): Linear(in_features=34, out_features=126, bias=True)
      (1): ELU(alpha=1.0)
  

# Setup the loss terms

In [9]:
def off_diagonal(x):
    """Copied from VICReg paper github https://github.com/facebookresearch/vicreg/"""
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class CovLoss(nn.Module):
    def forward(self, view1, view2):
        N = view1.size(0)  # batch size
        D = view1.size(1)  # dim of representations

        view1 = view1 - view1.mean(dim=0)
        view2 = view2 - view2.mean(dim=0)

        cov_view1 = (view1.T @ view1) / (N - 1)
        cov_view2 = (view2.T @ view2) / (N - 1)

        loss = off_diagonal(cov_view1).pow_(2).sum().div(D) + off_diagonal(cov_view2).pow_(2).sum().div(D)
        return loss


class VarLoss(nn.Module):
    def forward(self, view1, view2):
        view1 = view1 - view1.mean(dim=0)
        view2 = view2 - view2.mean(dim=0)

        # variance loss
        std_view1 = torch.sqrt(view1.var(dim=0) + 1e-10)
        std_view2 = torch.sqrt(view2.var(dim=0) + 1e-10)

        loss = torch.mean(F.relu(1 - std_view1)) / 2 + torch.mean(F.relu(1 - std_view2)) / 2
        return loss

In [10]:
crit_invar = nn.MSELoss()
crit_var = VarLoss()
crit_cov = CovLoss()

max_var = 1e-5 * batch_size
max_cov = 50 * batch_size

constraints = []
constraints.append(mdmm.MaxConstraint(lambda: crit_var(view1_embeddings, view2_embeddings), max_var))
constraints.append(mdmm.MaxConstraint(lambda: crit_cov(view1_embeddings, view2_embeddings), max_cov, scale=1e4))

mdmm_module = mdmm.MDMM(constraints)
optimizer = mdmm_module.make_optimizer(vicreg.parameters(), lr=1e-4)

# Run a training loop

In [11]:
for i, batch in enumerate(loader):
    # run VICReg forward pass to get the embeddings
    view1_embeddings, view2_embeddings = vicreg(batch.to(device))

    # compute the invariance loss which is contrained by the other loss terms
    loss = batch_size * crit_invar(view1_embeddings, view2_embeddings)
    mdmm_return = mdmm_module(loss)

    # backprop
    for param in vicreg.parameters():
        param.grad = None
    mdmm_return.value.backward()

    optimizer.step()
    print(f"constrained invariance loss: {loss.detach():.2f}")

    if i == 10:
        break

constrained invariance loss: 5.43
constrained invariance loss: 3.92
constrained invariance loss: 2.88
constrained invariance loss: 2.14
constrained invariance loss: 1.79
constrained invariance loss: 1.47
constrained invariance loss: 1.18
constrained invariance loss: 0.98
constrained invariance loss: 0.82
constrained invariance loss: 0.67
constrained invariance loss: 0.59
