# A quick demo for the analysis of KL divergence and Mahalanobis distance cost functions

This script compares the properties of our newly proposed KL divergence cost function (Eq. (S11)) with those of the Mahalanobis distance cost function (Eq. (S12)).

Both cost functions share the same fundamental objective: making the predictive function $f(\cdot ; \boldsymbol{\theta})$ close to a uniform vector $\mathcal{U}$ over the context input space.

Although the experimental results in Sections 4.2 to 4.5 verify that both cost functions improve model performance, we also found that in some experimental settings, such as the TinyImageNet detection task in Table 3, the KL divergence cost consistently outperforms the Mahalanobis distance cost.

We attribute this to the Mahalanobis distance excessively increasing or decreasing the regularization strength on certain context samples, which in turn affects the optimization process.

We illustrate this with Appendix Fig. S5:

![Appendix Fig. S5](../figs/FigS5.png)

Consider a binary classification task whose training data contain positive (red) and negative (blue) samples. The context set $\mathcal{C}=\left\{\widehat{\mathbf{x}}_1, \widehat{\mathbf{x}}_2, \widehat{\mathbf{x}}_3, \widehat{\mathbf{x}}_4\right\}$ (green) consists of four points, each with a 2D input feature $\widehat{\mathbf{x}}_m \in \mathbb{R}^2$. The classifier is parameterized by $\boldsymbol{\theta}$; for simplicity, its predictive function $f\left(\widehat{\mathbf{x}}_m ; \boldsymbol{\theta}\right)$ is written as $f\left(\widehat{\mathbf{x}}_m\right)$. In binary classification, $[1,0]$ and $[0,1]$ denote negative and positive predictions, respectively, and $\mathcal{U}=[0.5,0.5]$ is the uniform output. Ideally, every context sample is predicted as the uniform vector, which corresponds to zero cost. From left to right, the output of one context sample at a time is changed to $[1,0]$, which increases the cost.

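For (a), the Mahalanobis distance cost function (Eq. (S12)) is as follows, where $f(\widehat{\mathbf{X}} ; \boldsymbol{\theta})_k \in \mathbb{R}^M$ collects the $k$-th output over all $M$ context points, $h(\cdot)$ is the feature extractor, and $s$ is a small constant that keeps $C(\widehat{\mathbf{X}})$ positive definite: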

$$
\operatorname{cost}(\mathcal{C}, \boldsymbol{\theta})=\beta \sum_{k=1}^K\left(f(\widehat{\mathbf{X}} ; \boldsymbol{\theta})_k-\mathcal{U}\right)^{\top} C^{-1}(\widehat{\mathbf{X}})\left(f(\widehat{\mathbf{X}} ; \boldsymbol{\theta})_k-\mathcal{U}\right), \text { where } C(\widehat{\mathbf{X}})=h(\widehat{\mathbf{X}}) h^{\top}(\widehat{\mathbf{X}})+s \mathbf{I}
$$

- Changing the prediction of $\widehat{\mathbf{x}}_2$ or $\widehat{\mathbf{x}}_4$ produces a large cost (5.79 and 8.95, respectively), i.e., a strong regularization strength. Since these points are relatively close to the training distribution, pushing them toward uniform outputs may hurt prediction accuracy.
- Conversely, changing the prediction of $\widehat{\mathbf{x}}_1$ or $\widehat{\mathbf{x}}_3$ produces a small cost (2.69 each). Since these points are relatively far from the training distribution, their weak influence may limit the model's ability to extend its OOD discrimination capability to the whole OOD space.
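The spread of these values can be traced to the diagonal of $C^{-1}(\widehat{\mathbf{X}})$. When only the $m$-th context point moves from $\mathcal{U}$ to $[1,0]$, each output column deviates from $\mathcal{U}$ only in row $m$, by $\delta_k = \pm 0.5$, so Eq. (S12) collapses to $\beta \sum_k \delta_k^2 \,[C^{-1}]_{mm} = 0.5\,\beta\,[C^{-1}]_{mm}$. The sketch below checks this closed form against a direct evaluation of Eq. (S12), assuming (as the demo cells do) that the raw 2D context coordinates double as the features $h(\widehat{\mathbf{x}}_m)$, with $s = 0.05$ and $\beta = 1$.

```python
import torch

# Context points as used in the demo cells; assumed to double as the features h(x_m)
X = torch.tensor([[2.0, 4.0], [4.0, 4.0], [4.0, 2.0], [2.0, 2.0]], dtype=torch.float64)
s, beta = 0.05, 1.0  # jitter and weight, matching the demo's defaults

C = X @ X.T + s * torch.eye(X.size(0), dtype=X.dtype)  # C = h h^T + s I
C_inv = torch.linalg.inv(C)

# Deviation of the prediction [1, 0] from the uniform vector [0.5, 0.5]
delta = torch.tensor([1.0, 0.0], dtype=X.dtype) - 0.5

# Closed form: moving only point m gives beta * sum_k delta_k^2 * C_inv[m, m]
closed_form = beta * (delta ** 2).sum() * torch.diagonal(C_inv)

# Direct evaluation of Eq. (S12) with point m moved and all other points kept uniform
brute_force = []
for m in range(X.size(0)):
    f = torch.full((X.size(0), 2), 0.5, dtype=X.dtype)
    f[m] = torch.tensor([1.0, 0.0], dtype=X.dtype)
    D = f - 0.5  # [M, K] deviations from U
    brute_force.append(beta * sum(D[:, k] @ C_inv @ D[:, k] for k in range(2)))
brute_force = torch.stack(brute_force)

for m in range(X.size(0)):
    print(f"x_{m + 1}: closed form = {closed_form[m].item():.2f}, Eq. (S12) = {brute_force[m].item():.2f}")
```

In this toy setting the weight on each context point depends only on the features through $C$, not on how far its prediction moved, which is why equal prediction changes yield unequal costs.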

In contrast, for (b) our proposed KL divergence cost function (Eq. (S11)):

$$
\operatorname{cost}(\mathcal{C}, \boldsymbol{\theta})=\beta \sum_{m=1}^M \operatorname{KL}\left(f\left(\widehat{\mathbf{x}}_m ; \boldsymbol{\theta}\right) \| \mathcal{U}\right), \text { where } \operatorname{KL}\left(f\left(\widehat{\mathbf{x}}_m ; \boldsymbol{\theta}\right) \| \mathcal{U}\right)=\sum_{k=1}^K f\left(\widehat{\mathbf{x}}_m ; \boldsymbol{\theta}\right)_k \log \frac{f\left(\widehat{\mathbf{x}}_m ; \boldsymbol{\theta}\right)_k}{\mathcal{U}_k}
$$

- The cost is the same (4.23) no matter which context point changes, which provides a more stable optimization process.
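A note on the value 4.23: in the demo cells, `torch.nn.functional.kl_div(torch.log(Y + 1e-30), uniform_dist)` evaluates $\sum \mathcal{U}_k(\log \mathcal{U}_k - \log f_k)$, i.e. the reverse direction $\mathrm{KL}(\mathcal{U}\,\|\,f)$, averaged over all $M \times K$ entries by the default `'mean'` reduction, so its magnitude is dominated by the `1e-30` floor inside the logarithm. Evaluating Eq. (S11) exactly as written gives $\log 2 \approx 0.69$ for a $[1,0]$ prediction instead. Both versions are identical across the four context points, which is the property being argued here. The sketch below computes both; the helper name `kl_cost_s11` is ours, not from the original code.

```python
import math
import torch
import torch.nn.functional as F

def kl_cost_s11(f, beta=1.0):
    """Eq. (S11): beta * sum_m sum_k f_mk * log(f_mk / U_k), with U_k = 1/K.

    torch.xlogy returns 0 wherever f_mk == 0, matching the 0 * log 0 = 0 convention.
    """
    K = f.size(1)
    return beta * torch.xlogy(f, f * K).sum()  # f / U_k == f * K

M, K = 4, 2
for m in range(M):
    f = torch.full((M, K), 0.5)
    f[m] = torch.tensor([1.0, 0.0])  # move only the m-th context point to [1, 0]
    exact = kl_cost_s11(f).item()
    # Same quantity as the demo cells, with the default 'mean' reduction written out as sum / numel
    demo = (F.kl_div(torch.log(f + 1e-30), torch.full_like(f, 0.5), reduction="sum") / f.numel()).item()
    print(f"x_{m + 1}: Eq. (S11) = {exact:.4f}, demo value = {demo:.2f}")
```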

**The following code computes both cost functions as the prediction of one context point at a time is changed from $\mathcal{U}$ to $[1, 0]$.**

```python
import torch
```

Note that the Mahalanobis distance cost function here differs in implementation from the one in the script `helpers.py`, but the two are equivalent. Here, the cost is computed directly from Appendix Eq. (S12), whereas `helpers.py` approximates it as a _Gaussian process (GP)_ to speed up computation.
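One way to see the link to the GP view: for a single class $k$, the quadratic form $(f_k-\mathcal{U})^\top C^{-1}(f_k-\mathcal{U})$ equals $-2\log\mathcal{N}(f_k;\mathcal{U},C)$ up to terms that do not depend on $f$, where $\mathcal{N}(\mathcal{U}, C)$ is a Gaussian (a GP with kernel $hh^\top + s\mathbf{I}$ evaluated at the context points). The sketch below only checks this standard identity with `torch.distributions.MultivariateNormal` on the demo's toy inputs; it is not a reproduction of `helpers.py`, whose code is not shown here.

```python
import math
import torch
from torch.distributions import MultivariateNormal

X = torch.tensor([[2.0, 4.0], [4.0, 4.0], [4.0, 2.0], [2.0, 2.0]], dtype=torch.float64)
M = X.size(0)
C = X @ X.T + 0.05 * torch.eye(M, dtype=X.dtype)        # kernel h h^T + s I on the context set
u = torch.full((M,), 0.5, dtype=X.dtype)                # uniform target for class k
f_k = torch.tensor([1.0, 0.5, 0.5, 0.5], dtype=X.dtype) # class-k outputs, first point moved

d = f_k - u
quad = d @ torch.linalg.solve(C, d)  # Mahalanobis term of Eq. (S12) for one class

gp = MultivariateNormal(loc=u, covariance_matrix=C)
# log N(f; u, C) = -0.5 * quad - 0.5 * logdet(C) - 0.5 * M * log(2 pi)
from_density = -2.0 * gp.log_prob(f_k) - torch.logdet(C) - M * math.log(2 * math.pi)

print(f"quadratic form = {quad.item():.6f}, from Gaussian log-density = {from_density.item():.6f}")
```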

```python
def mahalanobis_distance_cost_function(preds_f, preds_feature, beta=1):
    """
    Mahalanobis distance cost function, computed directly from Appendix Eq. (S12).

    :param preds_f: Model's predictive outputs [bts, K] (class probabilities in this demo)
    :param preds_feature: Feature extractor output h(X) [bts, ndim]
    :param beta: Cost function weight scalar parameter
    """
    K = preds_f.size(1)  # K, number of classes
    # [bts, K] context set targets: every output should be close to the uniform value 1/K
    y = torch.ones_like(preds_f) / K
    cov = preds_feature @ preds_feature.T  # [bts, bts]
    # Add a small jitter term (s = 0.05) to ensure positive definiteness
    cov = cov + torch.eye(cov.size(0), device=cov.device, dtype=cov.dtype) * 0.05
    cov_inv = torch.inverse(cov)  # [bts, bts], inverse of the covariance matrix

    # Calculate the Mahalanobis distance cost function using Appendix Eq. (S12)
    cost = torch.tensor(0.0, device=preds_f.device)  # initialize the cost
    for k in range(K):
        f_k = preds_f[:, k]  # [bts], the k-th class output
        y_k = y[:, k]  # [bts], the k-th class target
        diff = f_k - y_k
        cost = cost + beta * (diff @ cov_inv @ diff)  # Appendix Eq. (S12)

    return cost
```
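The loop over $k$ above can also be written without an explicit inverse: stacking the deviations into $D = f(\widehat{\mathbf{X}}) - \mathcal{U} \in \mathbb{R}^{M\times K}$, Eq. (S12) equals $\beta\,\mathrm{tr}(D^\top C^{-1} D)$, and `torch.linalg.solve` is generally more numerically stable than forming `torch.inverse(cov)`. This is a sketch under a hypothetical name, `mahalanobis_cost_vectorized`, keeping the same jitter $s = 0.05$ and uniform target $1/K$; it is checked against an explicit per-class loop with the first context point moved to $[1, 0]$.

```python
import torch

def mahalanobis_cost_vectorized(preds_f, preds_feature, beta=1.0, s=0.05):
    """Eq. (S12) as beta * trace(D^T C^{-1} D), with D = f - 1/K and C = h h^T + s I."""
    M, K = preds_f.shape
    eye = torch.eye(M, dtype=preds_feature.dtype, device=preds_feature.device)
    C = preds_feature @ preds_feature.T + s * eye
    D = preds_f - 1.0 / K                # [M, K] deviations from the uniform vector
    C_inv_D = torch.linalg.solve(C, D)   # solves C Z = D for all K columns at once
    return beta * (D * C_inv_D).sum()    # sum_k D[:, k]^T C^{-1} D[:, k]

X = torch.tensor([[2.0, 4.0], [4.0, 4.0], [4.0, 2.0], [2.0, 2.0]], dtype=torch.float64)
Y = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], dtype=torch.float64)

# Reference: explicit loop over classes with an explicit inverse, mirroring the function above
C_inv = torch.inverse(X @ X.T + 0.05 * torch.eye(4, dtype=X.dtype))
loop = sum((Y[:, k] - 0.5) @ C_inv @ (Y[:, k] - 0.5) for k in range(2))

print(f"vectorized = {mahalanobis_cost_vectorized(Y, X).item():.2f}, loop = {loop.item():.2f}")
```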

**Case (i):** Changing the prediction of the first context point to $f(\widehat{\mathbf{x}}_1) = [1, 0]$ results in:
- (a) Mahalanobis distance cost of $\operatorname{cost}=2.69$
- (b) KL divergence cost of $\operatorname{cost}=4.23$.

```python
X = torch.tensor([
    [2.0, 4.0],
    [4.0, 4.0],
    [4.0, 2.0],
    [2.0, 2.0],
], dtype=torch.float32)  # context inputs, also passed as the features h(X)
Y = torch.tensor([
    [1.0, 0.0],  # first context point moved to [1, 0]
    [0.5, 0.5],
    [0.5, 0.5],
    [0.5, 0.5],
], dtype=torch.float32)
uniform_dist = torch.full_like(Y, 0.5)  # uniform target U = [0.5, 0.5] for every context point
print("Changing the first context point to [1, 0] results in:")
print(f"Mahalanobis distance cost = {mahalanobis_distance_cost_function(Y, X):.2f}")
# kl_div(log f, U) computes U * (log U - log f) averaged over all M*K entries
# (default 'mean' reduction); the 1e-30 keeps log(0) finite
print(f"KL divergence cost = {torch.nn.functional.kl_div(torch.log(Y + 1e-30), uniform_dist):.2f}")
```

```text
Changing the first context point to [1, 0] results in:
Mahalanobis distance cost = 2.69
KL divergence cost = 4.23
```


**Case (ii):** Changing the prediction of the second context point to $f(\widehat{\mathbf{x}}_2) = [1, 0]$ results in:
- (a) Mahalanobis distance cost of $\operatorname{cost}=5.79$
- (b) KL divergence cost of $\operatorname{cost}=4.23$.

```python
X = torch.tensor([
    [2.0, 4.0],
    [4.0, 4.0],
    [4.0, 2.0],
    [2.0, 2.0],
], dtype=torch.float32)
Y = torch.tensor([
    [0.5, 0.5],
    [1.0, 0.0],  # second context point moved to [1, 0]
    [0.5, 0.5],
    [0.5, 0.5],
], dtype=torch.float32)
uniform_dist = torch.full_like(Y, 0.5)
print("Changing the second context point to [1, 0] results in:")
print(f"Mahalanobis distance cost = {mahalanobis_distance_cost_function(Y, X):.2f}")
print(f"KL divergence cost = {torch.nn.functional.kl_div(torch.log(Y + 1e-30), uniform_dist):.2f}")
```

```text
Changing the second context point to [1, 0] results in:
Mahalanobis distance cost = 5.79
KL divergence cost = 4.23
```


**Case (iii):** Changing the prediction of the third context point to $f(\widehat{\mathbf{x}}_3) = [1, 0]$ results in:
- (a) Mahalanobis distance cost of $\operatorname{cost}=2.69$
- (b) KL divergence cost of $\operatorname{cost}=4.23$.

```python
X = torch.tensor([
    [2.0, 4.0],
    [4.0, 4.0],
    [4.0, 2.0],
    [2.0, 2.0],
], dtype=torch.float32)
Y = torch.tensor([
    [0.5, 0.5],
    [0.5, 0.5],
    [1.0, 0.0],  # third context point moved to [1, 0]
    [0.5, 0.5],
], dtype=torch.float32)
uniform_dist = torch.full_like(Y, 0.5)
print("Changing the third context point to [1, 0] results in:")
print(f"Mahalanobis distance cost = {mahalanobis_distance_cost_function(Y, X):.2f}")
print(f"KL divergence cost = {torch.nn.functional.kl_div(torch.log(Y + 1e-30), uniform_dist):.2f}")
```

```text
Changing the third context point to [1, 0] results in:
Mahalanobis distance cost = 2.69
KL divergence cost = 4.23
```


**Case (iv):** Changing the prediction of the fourth context point to $f(\widehat{\mathbf{x}}_4) = [1, 0]$ results in:
- (a) Mahalanobis distance cost of $\operatorname{cost}=8.95$
- (b) KL divergence cost of $\operatorname{cost}=4.23$.

```python
X = torch.tensor([
    [2.0, 4.0],
    [4.0, 4.0],
    [4.0, 2.0],
    [2.0, 2.0],
], dtype=torch.float32)
Y = torch.tensor([
    [0.5, 0.5],
    [0.5, 0.5],
    [0.5, 0.5],
    [1.0, 0.0],  # fourth context point moved to [1, 0]
], dtype=torch.float32)
uniform_dist = torch.full_like(Y, 0.5)
print("Changing the fourth context point to [1, 0] results in:")
print(f"Mahalanobis distance cost = {mahalanobis_distance_cost_function(Y, X):.2f}")
print(f"KL divergence cost = {torch.nn.functional.kl_div(torch.log(Y + 1e-30), uniform_dist):.2f}")
```

```text
Changing the fourth context point to [1, 0] results in:
Mahalanobis distance cost = 8.95
KL divergence cost = 4.23
```


These experiments show that the Mahalanobis distance cost function assigns different regularization strengths to different context samples (from 2.69 to 8.95 here), which in turn affects the model's optimization process.

In contrast, our newly proposed KL divergence cost function applies the same regularization strength to every context sample, which can make FSVI learning more effective.

The script `fsviContextSpecification.ipynb` provides more detailed experimental evidence for this point.