In [1]:
# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.

"""Generalized Batch-Shaping example notebook"""

from functools import partial
from pathlib import Path
from typing import Optional

import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

from batchshaping.batch_shaping_loss import Prior, gbas_loss
from batchshaping.utils import set_seed, to_numpy, warmup_factory
from batchshaping.viz_utils import init_gbas_anim_plot, plot_gt_cdf, plot_gt_pdf

# Global seaborn/matplotlib look for every figure produced in this notebook.
sns.set_style(style="darkgrid")

# Generalized Batch-Shaping


### Introduction
The original Batch-Shaping loss only acts on 1-dimensional inputs $x$, which can be a limiting factor. For instance, in gated networks we may want to impose a different sparsity prior (hence a different target CDF in the Batch-Shaping loss) for each layer. In such cases, we can instead minimize a sum of per-dimension CDF losses:

$$
\mathcal{L}_{BaS,multi}(x^1, \dots, x^T) = \sum_{k=1}^{T} \frac{1}{N} \sum_{i=1}^N \left( \hat{F}(x^k_i) - F^{\ast, k}(x^k_{i}; \phi^k) \right)^2
$$

However, defining the target prior parameters $\phi^k$ for each dimension is cumbersome and assumes a lot of prior knowledge. Instead, we introduce the **generalized Batch-Shaping loss** (gBaS), in which the priors' parameters are learned and controlled by an additional hyperprior:

$$
\mathcal{L}_{hyperprior}(\phi) = \frac{1}{T} \sum_{k=1}^T \left( \hat{F}(\phi^k) - \mathcal{F}^{\ast}(\phi^k; \psi) \right)^2
$$

where $\hat{F}$ still denotes the empirical CDF, and $\mathcal{F}^{\ast}$ is a prior over the parameters $\phi$, with its own parameters $\psi$.

Intuitively, in MSViT each prior $F^{\ast, k}$ controls the token sparsity at a given local position, and the mean and temperature of each prior are learned latent parameters. The hyperprior $\mathcal{F}^{\ast}$ is then a Gaussian whose mean controls the overall sparsity across all positions and whose variance controls how far the priors of different positions can be from one another. We introduce gBaS in more detail in [MSViT: Dynamic Mixed-Scale Tokenization for Vision Transformers, Havtorn et al., arXiv 2023](http://arxiv.org/abs/2307.02321).

### A practical example
Next, we illustrate the generalized Batch-Shaping loss on an example. Say we have $T$ layers, each with its own gate, outputting a value in $[0, 1]$ for each of the $N$ samples in the batch.
For each layer, we want the distribution of the gate outputs $x^k$ to follow a certain Relaxed Bernoulli (RB) distribution (i.e., most of the distribution's mass should be concentrated around the discrete values $\{0, 1\}$).
We do not have a prior estimate for each layer individually, so we let the model learn the mean of each RB prior. However, we do have a target gate sparsity (in this example, 30%), which we use to define the hyperprior controlling the learned means.

In summary:
  * The distribution of the learned data $x^k$ is controlled by a Relaxed Bernoulli prior $RB(\mu^k, \tau^k)$ for each layer
  * The distribution of the means $\mu$ is controlled by a hyperprior with a mean equal to the sparsity target and a given variance hyperparameter controlling the spread across layers (in the example, we use a Beta hyperprior $B(3, 7)$, whose mean is $3 / (3 + 7) = 0.3$).
  * The temperatures $\tau^k$ are also learnable but do not have a hyperprior term: we only add an $L_1$ loss term on $\tau$ to favor peakier distributions.

In [2]:
# Seed the RNGs used below: numpy's global RNG drives the data initialization in
# `train_gbas` and torch's RNG drives the prior-mean initialization.
# NOTE(review): assumes `set_seed` seeds both numpy and torch — confirm in batchshaping.utils.
SEED = 42
set_seed(SEED)

NUM_DIMENSIONS = 30  # T: number of gated layers (dimensions), one learned prior each
NUM_POINTS_PER_DIM = 500  # N: number of samples in the batch for each dimension

In [3]:
def train_gbas(
    num_dimensions: int,
    num_points_per_dim: int,
    num_epochs: int,
    hyperprior: Prior,
    hyperprior_param1: float,
    hyperprior_param2: Optional[float] = None,
    temperature_l0_lw: float = 0.0,
    num_warmup_epochs: int = 0,
    lr: float = 0.1,
    init_range: float = 4,
) -> FuncAnimation:
    """Train a toy generalized Batch-Shaping (gBaS) example and animate it.

    Random data is drawn for each of ``num_dimensions`` dimensions and optimized jointly
    with one learnable Relaxed Bernoulli prior per dimension (mean and temperature). The
    learned means are in turn regularized by ``hyperprior``. Training is lazy: every
    animation frame runs exactly one optimization step, so nothing is trained until the
    returned animation is rendered (e.g. with ``FuncAnimation.save``).

    Args:
        num_dimensions: number of dimensions / gated layers T.
        num_points_per_dim: number of samples N per dimension.
        num_epochs: number of optimization steps, i.e. animation frames.
        hyperprior: prior used in the hyperprior loss on the learned means.
        hyperprior_param1: first parameter of the hyperprior.
        hyperprior_param2: optional second parameter of the hyperprior.
        temperature_l0_lw: weight of the mean-absolute-value penalty on the learned
            temperatures (despite the name, this is an L1 penalty).
        num_warmup_epochs: number of warmup steps, forwarded to ``warmup_factory``.
        lr: Adam learning rate.
        init_range: the initial pre-sigmoid data lies in ``[-init_range, init_range]``.

    Returns:
        The animation, with its figure already closed; rendering it runs the training.
    """
    # Init plot
    (
        fig,
        data_axes,
        priors_loss_ax,
        hyperprior_loss_ax,
        hyperprior_pdf_ax,
        priors_loss_plot,
        hyperprior_loss_plot,
    ) = init_gbas_anim_plot(num_epochs)
    priors_losses = []
    hyperprior_losses = []

    # Init data points for each dimension: a narrow Gaussian around a random center in
    # [0, 1], clipped to [0, 1] and rescaled to the pre-sigmoid range [-init_range, init_range].
    init_columns = []
    for _ in range(num_dimensions):
        center = np.random.rand()
        samples = np.clip(np.random.normal(center, 0.05, size=num_points_per_dim), 0.0, 1.0)
        init_columns.append((samples * 2 - 1) * init_range)
    init_data = np.stack(init_columns, axis=1)  # (num_points_per_dim, num_dimensions)
    # Fixed y-coordinates used to spread the points vertically in the scatter plots.
    data_ys = np.linspace(0.0, 1.0, init_data.shape[0])
    data = torch.nn.Parameter(data=torch.Tensor(init_data), requires_grad=True)

    # Learnable parameters, stored pre-sigmoid: near-zero logits give prior means of
    # roughly 0.5, and zero logits give an initial temperature of exactly 0.5.
    prior = Prior.RELAXED_BERNOULLI
    priors_means = torch.nn.Parameter(
        data=torch.normal(torch.zeros(num_dimensions), 0.01), requires_grad=True
    )
    priors_temperatures = torch.nn.Parameter(data=torch.zeros(num_dimensions), requires_grad=True)

    # Init optimizer
    optimizer = torch.optim.Adam([data, priors_means, priors_temperatures], lr=lr)
    # NOTE(review): LambdaLR multiplies the base lr by the value returned by lr_lambda;
    # confirm warmup_factory returns a multiplicative factor rather than an absolute lr.
    lr_warmup = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=warmup_factory(num_warmup_epochs, lr)
    )

    # Control the data points of each dimension with the learned means and temperatures
    # (dim=1: presumably one prior per data column — confirm in gbas_loss).
    prior_loss_fn = partial(gbas_loss, prior=prior, dim=1)

    # Control the distribution of the learned means themselves
    # (dim=None: presumably pooled over all dimensions — confirm in gbas_loss).
    hyperprior_loss_fn = partial(
        gbas_loss,
        prior=hyperprior,
        prior_param1=hyperprior_param1,
        prior_param2=hyperprior_param2,
        dim=None,
    )

    def init_func() -> None:
        """Scale the two loss axes from the losses at initialization."""
        # These losses only set plot limits, so no autograd graph is needed.
        with torch.no_grad():
            priors_loss = prior_loss_fn(
                torch.sigmoid(data),
                prior_param1=torch.sigmoid(priors_means),
                prior_param2=torch.sigmoid(priors_temperatures),
            )
            hyperprior_loss = hyperprior_loss_fn(torch.sigmoid(priors_means))
        priors_loss_ax.set_ylim([-0.005, priors_loss.item() + 0.05])
        hyperprior_loss_ax.set_ylim([-0.005, hyperprior_loss.item() + 0.05])

    def one_step(i: int) -> None:
        """Run optimization step ``i`` and redraw every panel of the animation."""
        # Optimize
        optimizer.zero_grad()
        out = torch.sigmoid(data)
        means = torch.sigmoid(priors_means)
        temps = torch.sigmoid(priors_temperatures)
        ploss = prior_loss_fn(out, prior_param1=means, prior_param2=temps)
        hloss = hyperprior_loss_fn(means)
        # temps is a sigmoid output (positive), so this is an L1 penalty pushing the
        # temperatures down, i.e. towards peakier priors.
        loss = ploss + hloss + temperature_l0_lw * torch.mean(torch.abs(temps))
        print(
            f"\r[step {i + 1:04d} / {num_epochs:04d}] loss = {loss.item():.2e}",
            end="" if i < num_epochs - 1 else "\n",
        )
        priors_losses.append(ploss.item())
        hyperprior_losses.append(hloss.item())
        loss.backward()
        optimizer.step()
        lr_warmup.step()

        # Update the loss curves
        priors_loss_plot.set_data(np.arange(i + 1), priors_losses[: i + 1])
        hyperprior_loss_plot.set_data(np.arange(i + 1), hyperprior_losses[: i + 1])

        # Redraw data scatter, empirical CDF and current prior CDF for each data axis
        out_arr = to_numpy(out)
        cdf_x = np.sort(out_arr, axis=0)
        cdf_y = np.arange(cdf_x.shape[0]) / cdf_x.shape[0]
        for idx, data_ax in enumerate(data_axes):
            data_ax.cla()
            m, tau = means[idx].item(), temps[idx].item()
            data_ax.scatter(
                out_arr[:, idx], data_ys, marker="o", alpha=0.15, label="data", color="orchid"
            )
            data_ax.plot(
                cdf_x[:, idx],
                cdf_y,
                linewidth=2.5,
                linestyle="dashed",
                color="xkcd:turquoise",
                label="CDF (data)",
            )
            plot_gt_cdf(data_ax, prior, m, tau)
            data_ax.set_xlabel(
                f"Dim {idx + 1} " + r"($\mu$=" + f"{m:.2f}," + r" $\tau$=" + f"{tau:.2f})",
                fontsize=14,
            )
            if idx == 0:
                data_ax.legend(loc="upper center", ncol=3, bbox_to_anchor=(1.2, 1.28))

        # Redraw the hyperprior PDF against the histogram of the learned means
        hyperprior_pdf_ax.cla()
        plot_gt_pdf(hyperprior_pdf_ax, hyperprior, hyperprior_param1, hyperprior_param2)
        hyperprior_pdf_ax.hist(
            to_numpy(means),
            label="PDF (data)",
            color="xkcd:turquoise",
            rwidth=0.9,
            bins=20,
            density=True,
            stacked=True,
        )
        hyperprior_pdf_ax.legend(loc="upper right")
        hyperprior_pdf_ax.set_ylabel("(Hyperprior)\n" + r"PDF of learned $\mu$s", fontsize=16)

    ani = FuncAnimation(fig, one_step, frames=num_epochs, init_func=init_func)
    # Close this specific figure (not whichever happens to be current) so it is not also
    # shown as a static cell output; the animation keeps its own reference for saving.
    plt.close(fig)
    return ani

In [4]:
%%time
HYPERPRIOR = Prior.BETA
BETA_A = 3
BETA_B = 7

NUM_EPOCHS = 350
NUM_WARMUP_EPOCHS = 20
LR = 0.3
TEMPERATURE_L0_LW = 0.1

ANI = train_gbas(
    NUM_DIMENSIONS,
    NUM_POINTS_PER_DIM,
    NUM_EPOCHS,
    HYPERPRIOR,
    BETA_A,
    BETA_B,
    temperature_l0_lw=TEMPERATURE_L0_LW,
    num_warmup_epochs=NUM_WARMUP_EPOCHS,
    lr=LR,
)
ANI.save(
    f"outputs/generalized_batch_shaping_hyperprior={HYPERPRIOR.name.lower()}"
    f"_p1={BETA_A}_p2={BETA_B}.gif",
    fps=25,
    dpi=60,
)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0350 / 0350] loss = 3.52e-02
CPU times: user 6min 44s, sys: 1min 2s, total: 7min 46s
Wall time: 5min 33s


![generalized_batch_shaping_hyperprior=beta_p1=3_p2=7.gif](outputs/generalized_batch_shaping_hyperprior=beta_p1=3_p2=7.gif)