# Learning Optimal Linear Regularizers

## Matthew Streeter

ICML 2019

https://arxiv.org/abs/1902.07234

---
# Introduction

### Loss Function in Training

empirical training loss + regularization penalty

### Regularizer

Explicit: L1, L2, ...

Implicit: dropout, early stopping, ...

### Optimal Regularizer

The regularizer that provides the tightest possible bound on the **generalization gap** (i.e., the difference between test loss and training loss).

### Model of this Paper

<img src="model.jpg" width="550">

---
# Notation

1. $\theta$: parameters of a model
2. $\Theta$: a set of models
3. $D$: an unknown distribution
4. $z$: an instance, i.e., a (feature, label) pair drawn from $D$
5. $l$: loss function
6. $L(\theta) = \mathbb{E}_{z \sim D}[l(z,\theta)]$: expected loss on new data (the test loss)
7. $\hat{L}(\theta) = \frac{1}{n}\sum_{i=1}^{n}l(z_i,\theta)$: average loss over the $n$ training examples $z_1, \dots, z_n$
8. $R$: regularizer
9. $f(\theta) = \hat{L}(\theta) + R(\theta)$: objective function to minimize
10. $\hat{\theta} = \arg\min_{\theta \in \Theta}\{f(\theta)\}$: the model selected by training
11. $\theta^{*} = \arg\min_{\theta \in \Theta}\{L(\theta)\}$: the model with the lowest expected loss
12. $L(\hat{\theta}) - L(\theta^{*})$: excess test loss
13. choose $R$ so that the excess test loss is as small as possible

---
# Regularizers and Generalization Bounds

If $R(\theta) = L(\theta) - \hat{L}(\theta)$, then by the definition of $f$ we have $f(\theta) = L(\theta)$ for every $\theta$, so $\hat{\theta} = \theta^{*}$ and $\min\{f(\theta)\} = L(\theta^{*})$.

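Because $L$ is, by definition, the expected loss over **all** data drawn from $D$, it implicitly accounts for **unseen** data. With this regularizer, minimizing $f$ during training is the same as minimizing the loss on unseen data; in other words, we would be minimizing test loss directly on the training set. This $R$ depends on the unknown $L$, however, so it cannot be computed in practice.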
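A toy illustration of why the oracle regularizer $R(\theta) = L(\theta) - \hat{L}(\theta)$ works, on a hypothetical three-model class where the test loss is (unrealistically) known. All numbers are made up:

```python
# Hypothetical finite model class with made-up losses; in practice L is unknown.
test_loss = {"a": 0.40, "b": 0.35, "c": 0.50}   # L(theta)
train_loss = {"a": 0.30, "b": 0.33, "c": 0.10}  # \hat{L}(theta)


def minimize_f(regularizer):
    """Return the model minimizing f(theta) = train loss + R(theta)."""
    return min(train_loss, key=lambda t: train_loss[t] + regularizer(t))


def oracle_regularizer(t):
    """R(theta) = L(theta) - L_hat(theta); requires the true test loss."""
    return test_loss[t] - train_loss[t]


theta_hat = minimize_f(oracle_regularizer)    # f(theta) == L(theta) here
theta_star = min(test_loss, key=test_loss.get)
print(theta_hat, theta_star)       # b b
print(minimize_f(lambda t: 0.0))   # c  (no regularizer: picks the overfit model)
```

Without regularization, training selects "c", which has the lowest training loss but the highest test loss. The oracle penalty makes $f$ coincide with $L$, so training lands on $\theta^{*}$.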

Use a regularizer that accurately **estimates** the generalization gap.

> *A good regularizer is one that provides an upper bound on the generalization gap that is tight at near-optimal points.*

During training, we usually fix a regularizer $R$ and then, for that fixed $R$, search over many candidate models $\theta$ for the best one, $\hat{\theta}$.

For a fixed $R$, define the *slack* of $R$ at a point $\theta$ as
$$\Delta(\theta) \equiv f(\theta) - L(\theta)$$

For a fixed $R$, $\hat{\theta}$ is the best model under $f$; every other $\theta$ falls short of it by some amount.

For any $\theta$, define the *suboptimality* as
$$S(\theta) \equiv f(\theta) - f(\hat{\theta})$$

<img src="fig1.jpg" width="500">

Recall that the excess test loss is $L(\hat{\theta}) - L(\theta^{*})$. Combining this with the definitions of $\Delta$ and $S$ gives
<img src="SAS.jpg">

We refer to the quantity $SAS(\theta)$ as *suboptimality-adjusted slack*.

$L(\hat{\theta}) - L(\theta^{*}) = \max_{\theta \in \Theta}\{SAS(\theta)\}$

That is, the excess test loss of a hypothesis $\hat{\theta}$ obtained by minimizing $f$ is the **worst-case suboptimality-adjusted slack**. An optimal regularizer is therefore one that minimizes this quantity. 
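Expanding the definitions shows why this holds. The derivation below takes $SAS(\theta) \equiv \Delta(\theta) - \Delta(\hat{\theta}) - S(\theta)$ as the quantity pictured above (an assumption about the figure's content, consistent with the definitions of $\Delta$ and $S$):

$$
\begin{aligned}
SAS(\theta) &\equiv \Delta(\theta) - \Delta(\hat{\theta}) - S(\theta) \\
&= \big[f(\theta) - L(\theta)\big] - \big[f(\hat{\theta}) - L(\hat{\theta})\big] - \big[f(\theta) - f(\hat{\theta})\big] \\
&= L(\hat{\theta}) - L(\theta)
\end{aligned}
\quad\Longrightarrow\quad
\max_{\theta \in \Theta} SAS(\theta) = L(\hat{\theta}) - \min_{\theta \in \Theta} L(\theta) = L(\hat{\theta}) - L(\theta^{*})
$$

Since $\theta^{*}$ minimizes $L$, the maximum is attained at $\theta = \theta^{*}$.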

Proposition 1 makes this precise ($\mathcal{R}$ is the set of candidate regularizers, and $\hat{\theta}(R)$ is the minimizer of $f$ under regularizer $R$):

**Proposition 1**
$$R^{*} = \arg\min_{R \in \mathcal{R}}\big\{L(\hat{\theta}(R))\big\} = \arg\min_{R \in \mathcal{R}}\big\{\max_{\theta \in \Theta}\{SAS(\theta;R)\}\big\}$$
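A brute-force sanity check of Proposition 1 on a toy problem. Everything here is hypothetical: a finite $\Theta$ of 20 models with made-up training and test losses, a single feature $q(\theta)$, and a small family of regularizers $R = \lambda q(\theta)$. It also assumes $L$ is known, which is exactly what is unavailable in practice. It confirms that the worst-case SAS equals the excess test loss for every $\lambda$, and that minimizing either one selects the same model:

```python
import math
import random

random.seed(0)
MODELS = range(20)  # a hypothetical finite model class Theta
# Made-up losses and features, for illustration only.
train = [random.uniform(0.0, 1.0) for _ in MODELS]            # \hat{L}(theta)
test = [train[t] + random.uniform(0.0, 0.5) for t in MODELS]  # L(theta)
q = [random.uniform(0.0, 1.0) for _ in MODELS]                # feature q(theta)
LAMS = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0]  # candidate regularizers R = lam * q


def f(t, lam):
    """Regularized training objective f(theta) = L_hat(theta) + lam * q(theta)."""
    return train[t] + lam * q[t]


def theta_hat(lam):
    """Model selected by training with regularizer lam * q."""
    return min(MODELS, key=lambda t: f(t, lam))


def sas(t, lam):
    """Suboptimality-adjusted slack: Delta(t) - Delta(theta_hat) - S(t)."""
    th = theta_hat(lam)
    slack_t = f(t, lam) - test[t]        # Delta(t)
    slack_hat = f(th, lam) - test[th]    # Delta(theta_hat)
    subopt = f(t, lam) - f(th, lam)      # S(t)
    return slack_t - slack_hat - subopt


def worst_sas(lam):
    return max(sas(t, lam) for t in MODELS)


theta_star = min(MODELS, key=lambda t: test[t])
for lam in LAMS:
    excess = test[theta_hat(lam)] - test[theta_star]
    # Worst-case SAS equals excess test loss (up to floating-point rounding).
    assert math.isclose(worst_sas(lam), excess, abs_tol=1e-12)

best_by_test = min(LAMS, key=lambda lam: test[theta_hat(lam)])
best_by_sas = min(LAMS, key=worst_sas)
print(theta_hat(best_by_test) == theta_hat(best_by_sas))  # True
```

The comparison is between selected models rather than $\lambda$ values, because several $\lambda$ can select the same model and tie.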

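We now have an expression for the optimal $R^{*}$, but we still cannot solve it directly. $\Delta$ is defined in terms of $L$, and we do not know $L$ for every model $\theta \in \Theta$ (in principle each could be estimated arbitrarily well given unlimited data and compute, but not in practice), so we cannot compute $\max_{\theta \in \Theta}\{SAS(\theta)\}$ exactly.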

However, we can compute the validation loss of a small set of models $\theta \in \Theta_0 \subset \Theta$. This lets us compute an approximately optimal regularizer
$$R^{*} \approx \arg\min_{R \in \mathcal{R}}\big\{\max_{\theta \in \Theta_0}\{\widehat{SAS}(\theta;R)\}\big\}$$
where $\widehat{SAS}$ is an estimate of $SAS$ that uses validation loss as a proxy for test loss, and $\Theta_0$ as a proxy for $\Theta$.
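Expanding $\Delta$ and $S$ gives $SAS(\theta; R) = L(\hat{\theta}(R)) - L(\theta)$. Suppose the estimate mirrors this form, so that $\widehat{SAS}(\theta; R) \equiv V(\hat{\theta}_0(R)) - V(\theta)$, where $V$ is the average validation loss and $\hat{\theta}_0(R) \equiv \arg\min_{\theta \in \Theta_0} f(\theta)$. This is a sketch under that assumption and not necessarily the paper's exact definition. Under it, the min-max problem reduces to minimizing the validation loss of the selected model:

$$
\max_{\theta \in \Theta_0} \widehat{SAS}(\theta; R)
= V(\hat{\theta}_0(R)) - \min_{\theta \in \Theta_0} V(\theta)
\quad\Longrightarrow\quad
\arg\min_{R \in \mathcal{R}} \max_{\theta \in \Theta_0} \widehat{SAS}(\theta; R) = \arg\min_{R \in \mathcal{R}} V(\hat{\theta}_0(R))
$$

The term $\min_{\theta \in \Theta_0} V(\theta)$ does not depend on $R$, so this matches the objective $\hat{R}$ used in the Algorithm section.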

---
# Learning Linear Regularizers

### Linear Regularizers

$R(\theta; \lambda) = \lambda \cdot q(\theta)$

where $\lambda$ is a vector of hyperparameters and $q$ is a fixed function mapping a model $\theta$ to a feature vector; $R$ is linear in $\lambda$.
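For instance, the familiar combined L1/L2 penalty $\lambda_1\|\theta\|_1 + \lambda_2\|\theta\|_2^2$ is a linear regularizer with $q(\theta) = \langle \|\theta\|_1, \|\theta\|_2^2 \rangle$. A minimal Python sketch (the function names are hypothetical):

```python
def q(theta):
    """Feature vector of a parameter vector theta: (L1 norm, squared L2 norm)."""
    l1 = sum(abs(w) for w in theta)
    l2_sq = sum(w * w for w in theta)
    return (l1, l2_sq)


def linear_regularizer(theta, lam):
    """R(theta; lam) = lam . q(theta): linear in lam, possibly nonlinear in theta."""
    return sum(l * feat for l, feat in zip(lam, q(theta)))


theta = [0.5, -2.0, 1.0]
print(q(theta))                                 # (3.5, 5.25)
print(linear_regularizer(theta, (0.5, 0.25)))   # 3.0625
```

Because $R$ is linear in $\lambda$, learning the regularizer amounts to choosing the weight vector $\lambda$ for a fixed set of features.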

#### Example: Coin Flips

Each coin $i$ has an unknown bias $p_i \sim Beta(\alpha, \beta)$, and $\theta_i \in (0, 1)$ is the model's estimate of $p_i$.

$R(\theta) = LogitBeta(\theta) = -\frac{1}{n}\sum_{i}\big[\alpha \log(\theta_i) + \beta \log(1-\theta_i)\big]$

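$q(\theta) = \big\langle -\frac{1}{n}\sum_{i}\log(\theta_i),\ -\frac{1}{n}\sum_{i}\log(1-\theta_i) \big\rangle$ with $\lambda = \langle \alpha, \beta \rangle$, so that $R = \lambda \cdot q(\theta)$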
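The point of this example is that LogitBeta, although nonlinear in $\theta$, is linear in $\lambda = \langle \alpha, \beta \rangle$, so it fits the template $R(\theta;\lambda) = \lambda \cdot q(\theta)$. A small Python check, with hypothetical helper names and assuming each $\theta_i \in (0,1)$ is an estimated coin bias:

```python
import math


def q_logitbeta(theta):
    """Features (-mean log theta_i, -mean log(1 - theta_i)); theta_i must lie in (0, 1)."""
    n = len(theta)
    return (-sum(math.log(t) for t in theta) / n,
            -sum(math.log(1.0 - t) for t in theta) / n)


def logitbeta_direct(theta, alpha, beta):
    """LogitBeta written out directly from the formula above."""
    n = len(theta)
    return -sum(alpha * math.log(t) + beta * math.log(1.0 - t) for t in theta) / n


theta = [0.2, 0.5, 0.9]
alpha, beta = 2.0, 3.0
q1, q2 = q_logitbeta(theta)
via_linear = alpha * q1 + beta * q2   # lambda . q(theta)
print(math.isclose(via_linear, logitbeta_direct(theta, alpha, beta)))  # True
```

Note that $\alpha$ and $\beta$ appear only as weights on the two features, never inside $q$.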

### Algorithm

$V(\theta)$: average loss on validation set

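We wish to find the regularizer $\hat{R} \equiv \arg\min_{R \in \mathcal{R}} \{V(\hat{\theta}_{0}(R))\}$, where $\hat{\theta}_{0}(R) \equiv \arg\min_{\theta \in \Theta_0}\{f(\theta)\}$ is the best model in $\Theta_0$ under regularizer $R$.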

<img src="LearnLinReg.jpg" width=500>
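The pictured LearnLinReg algorithm is not reproduced here. As a simpler stand-in, the sketch below minimizes the same objective $V(\hat{\theta}_0(R))$ by brute-force grid search over $\lambda$. For each candidate $\lambda$ it finds $\hat{\theta}_0(R)$, the model in $\Theta_0$ minimizing training loss $+\ \lambda \cdot q(\theta)$, and it keeps the $\lambda$ whose selected model has the lowest validation loss. All function names and data are hypothetical. Because it only checks grid points, it can miss the best $\lambda$.

```python
import itertools


def make_grid(values_per_dim):
    """All lam vectors whose i-th coordinate comes from values_per_dim[i]."""
    return list(itertools.product(*values_per_dim))


def fit(train_loss, features, lam):
    """Index of theta_hat_0(R): the model in Theta_0 minimizing train loss + lam . q."""
    def f(i):
        return train_loss[i] + sum(l * x for l, x in zip(lam, features[i]))
    return min(range(len(train_loss)), key=f)


def grid_search_regularizer(train_loss, val_loss, features, grid):
    """Return (lam, V(theta_hat_0)) for the grid point with the lowest validation loss."""
    best_lam, best_val = None, float("inf")
    for lam in grid:
        v = val_loss[fit(train_loss, features, lam)]
        if v < best_val:  # strict: keep the first lam on ties
            best_lam, best_val = lam, v
    return best_lam, best_val


# Toy Theta_0 with three models and one feature each (all numbers made up).
train_loss = [0.10, 0.20, 0.30]
val_loss = [0.50, 0.25, 0.40]
features = [(4.0,), (1.0,), (0.0,)]
grid = make_grid([[0.0, 0.01, 0.05, 0.1, 0.5]])
print(grid_search_regularizer(train_loss, val_loss, features, grid))  # ((0.05,), 0.25)
```

Grid search costs $|\text{grid}| \cdot |\Theta_0|$ objective evaluations, and the grid grows exponentially with the number of features. That scaling is one reason to prefer a dedicated algorithm.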

#### Theorem 1
<img src="Theorm1.jpg" width=500>

#### Hyperparameter Tuning
<img src="TuneReg.jpg" width=500>