In [1]:
# Copyright (c) 2023 Qualcomm Technologies, Inc.
# All Rights Reserved.

"""Batch-Shaping example notebook"""

from functools import partial
from typing import Optional

import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation

from batchshaping.batch_shaping_loss import Prior, gbas_loss
from batchshaping.utils import set_seed, to_numpy
from batchshaping.viz_utils import init_anim_plot, plot_gt_cdf, plot_gt_pdf

# Apply seaborn's "darkgrid" style to the figures created after this point.
sns.set_style("darkgrid")

# Batch-Shaping

The **Batch-Shaping loss** (BaS) is a probability distribution matching tool derived from the Cramér–von Mises goodness of fit criterion. It minimizes the difference between the cumulative distribution function (CDF) of the target distribution, and the empirical CDF of the current data points: 

$$
\mathcal{L}(x) = \frac{1}{N} \sum_{i=1}^N \left( \hat{F}(x_i) - F^\ast(x_{i}; \phi) \right)^2
$$

where $i$ is the sample index, $N$ is the number of samples, $F^\ast$ is the CDF of the target prior distribution with parameters $\phi$, and $\hat{F}$ is the empirical CDF of the data, which can be estimated as $\hat{F}(x_i) = \frac{i}{N}$, assuming that the samples are sorted in increasing order ($x_1 \le \dots \le x_N$).




See [Batch-Shaping Loss for Learning Conditional-Channel Gated Networks, Bejnordi et al, ICLR 2020](https://arxiv.org/abs/1907.06627) and [MSViT: Dynamic Mixed-Scale Tokenization for Vision Transformers, Havtorn et al, arXiv 2023](http://arxiv.org/abs/2307.02321) for more details.


## A. The Batch-Shaping loss in practice

In the following examples, we illustrate how the Batch-Shaping loss can be used to match a set of data points $x \in [0, 1]$ to a Relaxed Bernoulli distribution (aka. Binary Concrete) or to a Beta distribution. Each animation depicts the evolution of the data points (on the $x$ axis, with an arbitrary, evenly spaced ordinate value for readability) together with their empirical CDF. The two bottom plots show, at the same time, the evolution of the training loss and of the empirical probability density function (PDF).


In [2]:
# Number of data points optimized in every example of this notebook.
NUM_POINTS = 500

# Fix the random seed up front so that the examples below are reproducible.
set_seed(42)

In [3]:
def train_bas(
    num_points: int,
    num_epochs: int,
    prior: Prior,
    prior_param1: float,
    prior_param2: Optional[float] = None,
    lr: float = 0.1,
    init_range: float = 4,
) -> FuncAnimation:
    """Build an animation that fits data points to a prior with the Batch-Shaping loss.

    ``num_points`` free parameters are drawn uniformly in
    ``[-init_range, init_range)``, passed through a sigmoid and optimized with
    Adam so that ``gbas_loss`` against the requested prior is minimized.

    No optimization happens inside this call: every animation frame runs
    exactly one optimizer step, so the training is executed when the returned
    animation is rendered (e.g. by ``ani.save(...)``). Rendering the same
    animation again resumes from the already-updated parameters.

    Args:
        num_points: Number of data points to optimize.
        num_epochs: Number of animation frames, i.e. of optimizer steps.
        prior: Target prior, forwarded to ``gbas_loss`` and to the plot helpers.
        prior_param1: First parameter of the prior (forwarded as is).
        prior_param2: Optional second parameter of the prior (forwarded as is).
        lr: Learning rate of the Adam optimizer.
        init_range: Half-width of the uniform initialization interval of the
            pre-sigmoid values.

    Returns:
        The (not yet rendered) ``FuncAnimation`` driving the training.
    """
    # Init plot
    fig, data_ax, loss_ax, pdf_ax, data_plot, cdf_pred_plot, loss_plot = init_anim_plot(num_epochs)
    losses = []

    def init_func() -> None:
        """Draw the static artists and scale the loss axis from the initial loss."""
        plot_gt_cdf(data_ax, prior, prior_param1, prior_param2)
        data_ax.legend(loc="upper center", ncol=3, bbox_to_anchor=(0.5, 1.09))
        # Only the scalar value is needed here: do not build an autograd graph.
        with torch.no_grad():
            # NOTE(review): a leading axis is added here but not in `one_step`;
            # presumably equivalent because the loss is called with dim=None -- confirm.
            initial_loss = loss_fn(torch.sigmoid(data)[None, :]).item()
        loss_ax.set_ylim([-0.005, initial_loss + 0.05])

    # Init data points (will be fed as input to sigmoid for support of [0, 1]).
    # The initial values need no gradient tracking: the Parameter created below
    # is the leaf tensor that is actually optimized.
    init_values = (torch.rand(num_points, device=torch.device("cpu")) * 2 - 1) * init_range
    # Fixed, evenly spaced ordinates: only used to spread the points vertically.
    data_ys = np.linspace(0.0, 1.0, init_values.shape[0])
    data = torch.nn.Parameter(init_values, requires_grad=True)

    # Init optimizer
    optimizer = torch.optim.Adam([data], lr=lr)
    loss_fn = partial(
        gbas_loss, prior=prior, prior_param1=prior_param1, prior_param2=prior_param2, dim=None
    )

    # Train update
    def one_step(i: int) -> None:
        """Run one optimizer step and refresh the animated artists for frame ``i``."""
        # Optimize
        optimizer.zero_grad()
        out = torch.sigmoid(data)
        loss = loss_fn(out)
        loss_value = loss.item()
        print(
            f"\r[step {i + 1:04d} / {num_epochs:04d}] loss = {loss_value:.2e}",
            end="" if i < num_epochs - 1 else "\n",
        )
        losses.append(loss_value)
        loss.backward()
        optimizer.step()

        # update plots (`out` holds the values from before this step's update)
        out_np = to_numpy(out)
        loss_plot.set_data(np.arange(i + 1), losses[: i + 1])
        data_plot.set_offsets(np.vstack((out_np, data_ys)).T)
        cdf_x = np.sort(out_np)
        cdf_y = np.arange(cdf_x.shape[0]) / cdf_x.shape[0]
        cdf_pred_plot.set_data(cdf_x, cdf_y)
        # The histogram cannot be updated in place: clear and redraw the PDF axis.
        pdf_ax.cla()
        plot_gt_pdf(pdf_ax, prior, prior_param1, prior_param2)
        pdf_ax.hist(
            cdf_x,
            label="PDF (data)",
            color="xkcd:turquoise",
            rwidth=0.9,
            bins=20,
            density=True,
            stacked=True,
        )
        pdf_ax.legend(loc="upper center")

    ani = FuncAnimation(fig, one_step, frames=num_epochs, init_func=init_func)
    # Close this animation's figure explicitly (not merely the "current" one) so
    # that the notebook does not also display a static, empty copy of it.
    plt.close(fig)
    return ani

#### Example 1: Relaxed Bernoulli with mean 0.2 and temperature 0.25

In [4]:
%%time
NUM_EPOCHS = 100
LR = 0.1
PRIOR = Prior.RELAXED_BERNOULLI
MEAN = 0.2
TEMPERATURE = 0.25

ANI = train_bas(
    NUM_POINTS,
    num_epochs=NUM_EPOCHS,
    lr=LR,
    prior=PRIOR,
    prior_param1=MEAN,
    prior_param2=TEMPERATURE,
)
ANI.save(f"outputs/batch_shaping_{PRIOR.name.lower()}_p1={MEAN}_p2={TEMPERATURE}.gif", fps=10)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0100 / 0100] loss = 2.38e-02
CPU times: user 28.5 s, sys: 721 ms, total: 29.2 s
Wall time: 28.8 s


![batch_shaping_relaxed_bernoulli_p1=0.2_p2=0.25](outputs/batch_shaping_relaxed_bernoulli_p1=0.2_p2=0.25.gif)

#### Example 2:  (Symmetric) Beta distribution with parameters a = 0.9 and b = 1 - a = 0.1

In [5]:
%%time
NUM_EPOCHS = 120
LR = 0.1
PRIOR = Prior.SYMBETA
MEAN = 0.9

ANI = train_bas(NUM_POINTS, num_epochs=NUM_EPOCHS, lr=LR, prior=PRIOR, prior_param1=MEAN)
ANI.save(f"outputs/batch_shaping_{PRIOR.name.lower()}_p1={MEAN}.gif", fps=10)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0120 / 0120] loss = 1.86e-02
CPU times: user 33.8 s, sys: 609 ms, total: 34.4 s
Wall time: 33.9 s


![batch_shaping_symbeta_p1=0.9.gif](outputs/batch_shaping_symbeta_p1=0.9.gif)

#### Example 3: Beta distribution with parameters a = 2 and b = 6

In [6]:
%%time
NUM_EPOCHS = 100
LR = 0.1
PRIOR = Prior.BETA
BETA_A = 2.0
BETA_B = 6.0

ANI = train_bas(
    NUM_POINTS, num_epochs=NUM_EPOCHS, lr=LR, prior=PRIOR, prior_param1=BETA_A, prior_param2=BETA_B
)
ANI.save(f"outputs/batch_shaping_{PRIOR.name.lower()}_p1={BETA_A}_p2={BETA_B}.gif", fps=10)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0100 / 0100] loss = 3.14e-06
CPU times: user 28.4 s, sys: 577 ms, total: 29 s
Wall time: 28.5 s


![batch_shaping_beta_p1=2_p2=6](outputs/batch_shaping_beta_p1=2_p2=6.gif)

## B. Comparison to the KL-divergence

The KL divergence (or its symmetric extension, the Jensen-Shannon divergence) is a popular tool for matching density functions.

$$
KL(p, q) = \sum_x p(x) \log \left(\frac{p(x)}{q(x)}\right)
$$

However, it requires estimating the empirical PDF of the data, which is non-trivial for continuous data. Instead, the most common use cases of KL assume a known parametric distribution on the data, typically from the same distribution family as the target PDF ([see PyTorch's implementation of KL divergence for all supported distribution pairs](https://pytorch.org/docs/stable/distributions.html#module-torch.distributions.kl)). In contrast, the Batch-Shaping loss directly estimates the empirical CDF of the data, which is easy to compute for both continuous and discrete data in practice.


## C. Batch-Shaping for discrete data and finite support distribution

Batch-Shaping can also be used to match discrete distributions. However, directly trying to match the corresponding CDF may sometimes lead to vanishing gradients: in parts of the space where the target probability density function is close or equal to zero (i.e., outside of the support), the gradient flowing to the data has a very low magnitude, leading to slow training.
Therefore, it is important to tune training hyperparameters accordingly in some cases. We illustrate this potential issue in two examples below.


#### Example 1: Edges of a Gaussian distribution
We first take the example of a Gaussian distribution: Data points at the edges of the support where the density is almost zero do move towards where the mass of the distribution lies, but rather slowly. 
In the example, we improve training speed with a better choice of the initial clipping range to avoid regions with low density of the prior.

In [7]:
%%time
NUM_EPOCHS = 200
LR = 0.1
PRIOR = Prior.NORMAL
BETA_A = 0.6
BETA_B = 0.1

ANI = train_bas(
    NUM_POINTS,
    num_epochs=NUM_EPOCHS,
    lr=LR,
    prior=PRIOR,
    prior_param1=BETA_A,
    prior_param2=BETA_B,
    init_range=4,
)
ANI.save("outputs/test_outside_support.gif", fps=10)

ANI = train_bas(
    NUM_POINTS,
    num_epochs=NUM_EPOCHS,
    lr=LR,
    prior=PRIOR,
    prior_param1=BETA_A,
    prior_param2=BETA_B,
    init_range=1,
)
ANI.save("outputs/test_inside_support.gif", fps=10)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0200 / 0200] loss = 1.29e-03


MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0200 / 0200] loss = 1.79e-08
CPU times: user 1min 52s, sys: 2.03 s, total: 1min 54s
Wall time: 1min 53s


**Points lying in low density regions**

![test_outside_support](outputs/test_outside_support.gif)

**Better initialization to avoid low density regions**

![test_inside_support](outputs/test_inside_support.gif)

#### Example 2: Relaxed Bernoulli with low temperature

Another example is the Relaxed Bernoulli distribution (aka Binary Concrete): When the temperature $t \rightarrow 0$, the distribution becomes closer to the discrete Bernoulli distribution. In this setting, we do not observe any particular issue regarding training speed, even at low temperatures.

In [8]:
%%time
NUM_EPOCHS = 200
LR = 0.1
PRIOR = Prior.RELAXED_BERNOULLI
MEAN = 0.8

T1 = 0.4
ANI = train_bas(
    NUM_POINTS, num_epochs=NUM_EPOCHS, lr=LR, prior=PRIOR, prior_param1=MEAN, prior_param2=T1
)
ANI.save("outputs/test_high_temperature.gif", fps=10)

T2 = 0.01
ANI = train_bas(
    NUM_POINTS, num_epochs=NUM_EPOCHS, lr=LR, prior=PRIOR, prior_param1=MEAN, prior_param2=T2
)
ANI.save("outputs/test_low_temperature.gif", fps=10)

MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0200 / 0200] loss = 2.70e-03


MovieWriter ffmpeg unavailable; using Pillow instead.


[step 0200 / 0200] loss = 1.65e-01
CPU times: user 1min 51s, sys: 1.98 s, total: 1min 53s
Wall time: 1min 51s


### With higher temperature $t = 0.4$

![test_high_temperature](outputs/test_high_temperature.gif)

### With lower temperature $t = 0.01$

![test_low_temperature](outputs/test_low_temperature.gif)