Implement distributional anchor regression in glum #573

Open
mlondschien opened this issue Sep 28, 2022 · 0 comments

@mlondschien
Contributor

I am interested in domain generalization (DG, also "external validity") of statistical / machine learning models.
Anchor regression [1] is a recent idea interpolating between ordinary least squares (OLS) and instrumental variable (IV) regression. [3] propose ways to generalize anchor regression to more general distributions (including classification). [2] is a "nice-to-read" summary, including ideas on how to extend to non-linear settings.

To my knowledge, no efficient implementations of anchor regression or anchor classification exist. I'd be interested in contributing this to my favorite GLM library but would need some guidance.

What is Anchor Regression?

Anchor regression improves the DG / external validity of OLS by adding a regularization term penalizing the correlation between a so-called anchor variable and the regression's residuals. The anchor variable is assumed to be exogenous to the system, i.e., not directly causally affected by covariates, the outcome, or relevant hidden variables. See the following causal graph:

```mermaid
graph LR
A --> U & X & Y
U --> X & Y
X --> Y
```

What is an anchor? Say we are interested in predicting health outcomes in the ICU. Possibly valid anchor variables would be the hospital id (one-hot encoded) or some transformation of the time of year. The choice of anchor depends on the application. If we want to predict out of time, but for the same hospitals as seen in training, using time of year as the anchor suffices, and the hospital id should be included in the covariates ($X$). If, however, we want to generalize across hospitals (i.e., predict for unseen hospitals), we need to include the hospital id as an anchor (and exclude it from the covariates). A similar example would be insurance, with geographical location and time of year as anchors.
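For a one-hot encoded anchor such as the hospital id, projecting onto the column space of the anchor matrix simply replaces each value by the mean over its hospital. The following NumPy sketch illustrates this on made-up residuals; `one_hot`, `proj` and the toy data are hypothetical helpers for illustration, not part of glum.

```python
import numpy as np


def one_hot(groups):
    """One-hot encode integer group labels (here: hypothetical hospital ids)."""
    levels, idx = np.unique(groups, return_inverse=True)
    A = np.zeros((len(groups), len(levels)))
    A[np.arange(len(groups)), idx] = 1.0
    return A


def proj(A, v):
    """Project v onto the column space of A via least squares."""
    coef, *_ = np.linalg.lstsq(A, v, rcond=None)
    return A @ coef


hospital = np.array([0, 0, 1, 1, 1, 2])  # toy hospital id per patient
residuals = np.array([1.0, 3.0, 0.0, 3.0, 6.0, -2.0])
A = one_hot(hospital)
# Each entry becomes its hospital's mean residual: approximately [2, 2, 3, 3, 3, -2].
print(proj(A, residuals))
```

So for this anchor, a penalty on the projected residuals acts on the per-hospital mean residuals.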

Write $P_A$ for the $\ell_2$-projection onto the column-space of $A$ (i.e., $P_A(\cdot) = \mathbb{E}[\cdot \mid A]$) and let $\gamma>0$. In a regression setting, the anchor regression solution is given by:

$$
b^\gamma = \underset{b}{\arg\min}\ \mathbb{E}_\mathrm{train}\left[\left((\mathrm{Id} - P_A)(Y - X^T b)\right)^2\right] + \gamma\, \mathbb{E}_\mathrm{train}\left[\left(P_A(Y - X^T b)\right)^2\right].
$$

Given samples from $P_\mathrm{train}$, write $\Pi_A$ for the projection onto the column space of the (sample) anchor matrix $A$. Then $b^\gamma$ can be estimated as

$$ \hat b^\gamma = \underset{b}{\arg\min}\ \|(\mathrm{Id} - \Pi_A)(Y - X^T b)\|_2^2 + \gamma\, \|\Pi_A (Y - X^T b)\|_2^2. \tag{1} $$
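As a minimal sketch of the empirical objective in NumPy, assuming an $n \times p$ design matrix `X` (so `X @ b` plays the role of $X^T b$) and computing products with $\Pi_A$ by least squares rather than forming $\Pi_A$; the function name `anchor_objective` and the random data are hypothetical:

```python
import numpy as np


def proj(A, v):
    """Pi_A v computed by least squares, without forming the n x n matrix Pi_A."""
    coef, *_ = np.linalg.lstsq(A, v, rcond=None)
    return A @ coef


def anchor_objective(X, y, A, b, gamma):
    """||(Id - Pi_A)(y - X b)||^2 + gamma * ||Pi_A (y - X b)||^2."""
    r = y - X @ b
    r_anchor = proj(A, r)  # Pi_A r
    r_perp = r - r_anchor  # (Id - Pi_A) r
    return np.sum(r_perp ** 2) + gamma * np.sum(r_anchor ** 2)


rng = np.random.default_rng(0)
n, p, q = 50, 3, 2
X = rng.normal(size=(n, p))
y = rng.normal(size=n)
A = rng.normal(size=(n, q))
b = rng.normal(size=p)

rss = np.sum((y - X @ b) ** 2)
print(np.isclose(anchor_objective(X, y, A, b, 1.0), rss))  # True
```

With $\gamma = 1$ the two terms add up to the ordinary residual sum of squares, because $(\mathrm{Id} - \Pi_A) r$ and $\Pi_A r$ are orthogonal.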

[1] show that the anchor regression solution minimizes the worst-case risk over a class of distribution shifts induced through the anchor variable. Here, $\gamma$ controls the size of this class: it contains shifts up to $\sqrt{\gamma}$ times as strong as those seen in the training data [1, Theorem 1].

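In an instrumental variable (IV) setting (no direct causal effects $A \to U$ and $A \to Y$, and a "sufficiently strong" effect $A \to X$), anchor regression interpolates between OLS and IV regression: $\gamma = 1$ recovers OLS, and $\hat b^\gamma$ converges to the IV solution as $\gamma \to \infty$. This is because the penalty $\gamma \|\Pi_A (Y - X^T b)\|_2^2$ increasingly forces the residuals to be uncorrelated with $A$, and the IV solution can be written as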

$$ \hat b^{\mathrm{IV}} = \underset{b \colon \mathrm{Cor}(A, X^T b - Y)=0}{\arg\min} \|Y - X^T b\|_2^2. $$
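To illustrate the IV solution, here is a small simulation following the causal graph above, with $A$ acting only on $X$ as in the IV setting. It assumes mean-zero variables without an intercept and as many anchors as covariates, in which case the sample constraint becomes $A^T(Y - Xb) = 0$ and determines $b$ uniquely. The data-generating process is made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000
A = rng.normal(size=(n, 1))  # anchor / instrument
H = rng.normal(size=(n, 1))  # hidden confounder (U in the graph)
X = A + H + rng.normal(size=(n, 1))  # A -> X, U -> X
y = 2.0 * X[:, 0] + 3.0 * H[:, 0] + rng.normal(size=n)  # causal effect of X on Y is 2

# OLS is biased by the hidden confounder.
b_ols, *_ = np.linalg.lstsq(X, y, rcond=None)
# Just-identified IV: solve A^T (y - X b) = 0.
b_iv = np.linalg.solve(A.T @ X, A.T @ y)

print(b_ols, b_iv)
```

The hidden confounder pushes the OLS slope toward 3 ($\mathrm{Cov}(X, Y) / \mathrm{Var}(X) = 9 / 3$ in this simulation), while the IV estimate recovers the causal effect of 2 up to sampling noise.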

In low-dimensional settings, (1) can be solved by ordinary least squares of $\tilde Y$ on $\tilde X$, using the transformation

$$ \tilde X := (\mathrm{Id} - \Pi_A) X + \sqrt{\gamma} \Pi_A X \ \ \textrm{ and }\ \ \tilde Y := (\mathrm{Id} - \Pi_A)Y + \sqrt{\gamma} \Pi_A Y, $$

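where $\Pi_A = A (A^T A)^{-1} A^T$. Since $\mathrm{Id} - \Pi_A$ and $\Pi_A$ project onto orthogonal subspaces, $\|\tilde Y - \tilde X^T b\|_2^2$ equals the objective in (1). The matrix $\Pi_A$ itself never needs to be formed: $\Pi_A X$ and $\Pi_A Y$ are the fitted values of least-squares regressions of $X$ and $Y$ on $A$.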
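A minimal NumPy sketch of this transformation, without any regularization penalty and assuming $\gamma \ge 0$; `anchor_regression` is a hypothetical helper, not glum API. It uses a toy IV-type data-generating process (anchor $A \to X$, hidden $H \to X, Y$, causal effect 2):

```python
import numpy as np


def proj(A, M):
    """Pi_A M via least squares (vectors or matrices); Pi_A is never formed."""
    coef, *_ = np.linalg.lstsq(A, M, rcond=None)
    return A @ coef


def anchor_regression(X, y, A, gamma):
    """Anchor regression as OLS on transformed data (sketch, gamma >= 0)."""
    s = np.sqrt(gamma)
    X_t = X - proj(A, X) + s * proj(A, X)  # (Id - Pi_A) X + sqrt(gamma) Pi_A X
    y_t = y - proj(A, y) + s * proj(A, y)  # (Id - Pi_A) y + sqrt(gamma) Pi_A y
    coef, *_ = np.linalg.lstsq(X_t, y_t, rcond=None)
    return coef


# Toy data: A -> X, hidden H -> X and Y, causal effect of X on Y is 2.
rng = np.random.default_rng(0)
n = 10_000
A = rng.normal(size=(n, 1))
H = rng.normal(size=(n, 1))
X = A + H + rng.normal(size=(n, 1))
y = 2.0 * X[:, 0] + 3.0 * H[:, 0] + rng.normal(size=n)

for gamma in [0.0, 1.0, 10.0, 1e6]:
    print(gamma, anchor_regression(X, y, A, gamma))
```

As $\gamma$ grows, the estimates move from the OLS solution ($\gamma = 1$) toward the IV solution. Using `lstsq` for the projections avoids forming the $n \times n$ matrix $\Pi_A$.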

What is Distributional Anchor Regression?

[2] present ideas on how to generalize anchor regression from OLS to GLMs. In particular, if $f$ denotes the raw scores (the linear predictor), they propose to use the residuals

$$ r = \frac{d}{d f} \ell(f, y). $$

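For $f = X^T \beta$ and $\ell(f, y) = \frac{1}{2}(y - f)^2$, this gives $r = f - y$, the usual residual up to a sign (which does not affect $\|\Pi_A r\|_2^2$), so the approach reduces to anchor regression. For logistic regression, with $Y \in \{-1, 1\}$ and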

$$ \ell(f, y) = - \sum_i \log(1 + \exp(-y_i f_i)), $$

this yields residuals

$$ r_i = \frac{\partial}{\partial f_i} \ell(f, y) = y_i (1 + \exp(y_i f_i))^{-1} = \tilde y_i - p_i, $$

where $\tilde y_i = \frac{y_i + 1}{2} \in \{0, 1\}$ and $p_i = (1 + \exp(-f_i))^{-1}$.
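A quick numerical check (a sketch using NumPy and `scipy.special.expit`; the random scores and labels are made up) that the two expressions for $r_i$ agree and match a finite-difference derivative of $\ell$:

```python
import numpy as np
from scipy.special import expit  # logistic sigmoid 1 / (1 + exp(-f))


def loglik(f, y):
    """l(f, y) = -sum_i log(1 + exp(-y_i f_i)) for labels y in {-1, 1}."""
    return -np.sum(np.log1p(np.exp(-y * f)))


rng = np.random.default_rng(1)
f = rng.normal(size=8)
y = rng.choice([-1, 1], size=8)

r_score = y / (1 + np.exp(y * f))  # y_i (1 + exp(y_i f_i))^{-1}
r_response = (y + 1) / 2 - expit(f)  # y_tilde_i - p_i

# Central finite differences of l with respect to each f_i.
eps = 1e-6
r_fd = np.array([
    (loglik(f + eps * e, y) - loglik(f - eps * e, y)) / (2 * eps)
    for e in np.eye(len(f))
])

print(np.allclose(r_score, r_response), np.allclose(r_fd, r_score, atol=1e-6))  # True True
```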

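Define $\ell^\gamma(f, y) := \ell(f, y) + (\gamma - 1) \|\Pi_A r\|_2^2$. The factor $\gamma - 1$ mirrors (1), since $\|(\mathrm{Id} - \Pi_A) r\|_2^2 + \gamma \|\Pi_A r\|_2^2 = \|r\|_2^2 + (\gamma - 1) \|\Pi_A r\|_2^2$. The gradient of the anchor loss with respect to $f$ is given as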

$$ \frac{d}{d f_i} \ell^\gamma(f, y) = y_i (1 + \exp(y_i f_i))^{-1} - 2 (\gamma - 1) (\Pi_A r)_i p_i (1 - p_i). $$

The Hessian with respect to $f$ is (not pretty)

$$
\frac{\partial^2}{\partial f_i \partial f_j} \ell^\gamma(f, y) = -\mathbb{1}_{\{i = j\}}\, p_i (1 - p_i) \left(1 + 2(\gamma - 1)(1 - 2p_i)(\Pi_A r)_i\right) + 2(\gamma - 1)\, p_i (1 - p_i)\, p_j (1 - p_j)\, (\Pi_A)_{i, j}.
$$

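If $f = X^T \beta$, then (here, $\cdot$ is matrix multiplication, while $\odot$ as well as the exponential and the reciprocal act elementwise on vectors)

$$ \frac{\partial}{\partial \beta} \ell^\gamma(X^T\beta, y) = \left(y \odot (1 + \exp(y \odot f))^{-1} - 2(\gamma - 1)\, (\Pi_A r) \odot p \odot (1 - p)\right)^T \cdot X $$

and

$$ \frac{\partial^2}{\partial \beta\, \partial \beta^T} \ell^\gamma(X^T\beta, y) = -X^T \cdot \mathrm{diag}\left(p \odot (1 - p) \odot \left(1 + 2(\gamma - 1)(1 - 2p) \odot \Pi_A r\right)\right) \cdot X + 2(\gamma - 1)\, X^T \cdot \mathrm{diag}(p \odot (1-p)) \cdot \Pi_A \cdot \mathrm{diag}(p \odot (1-p)) \cdot X. $$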

Computational considerations

Here is some NumPy code calculating and testing the above derivatives:

```python
import numpy as np
import pytest
from scipy.optimize import approx_fprime  # vector-valued f requires SciPy >= 1.9


def predictions(f):
    # Logistic sigmoid p = 1 / (1 + exp(-f)).
    return 1 / (1 + np.exp(-f))


def proj(A, f):
    # Projection of f onto the column space of A, without forming Pi_A.
    return np.dot(A, np.linalg.lstsq(A, f, rcond=None)[0])


def proj_matrix(A):
    # Explicit n x n projection matrix Pi_A = A (A^T A)^{-1} A^T.
    return np.dot(np.dot(A, np.linalg.inv(A.T @ A)), A.T)


def loss(X, beta, y, A, gamma):
    # Anchor loss l^gamma(X beta, y) for labels y in {-1, 1}.
    f = X @ beta
    r = (y / 2 + 0.5) - predictions(f)  # residuals y_tilde - p
    return -np.sum(np.log1p(np.exp(-y * f))) + (gamma - 1) * np.sum(proj(A, r) ** 2)


def grad(X, beta, y, A, gamma):
    # Gradient of the anchor loss with respect to beta.
    f = X @ beta
    p = predictions(f)
    r = (y / 2 + 0.5) - p

    return (r - 2 * (gamma - 1) * proj(A, r) * p * (1 - p)) @ X


def hess(X, beta, y, A, gamma):
    # Hessian with respect to beta: X^T (diagonal + dense) X.
    f = X @ beta
    p = predictions(f)
    r = (y / 2 + 0.5) - p
    w = p * (1 - p)
    diag = -np.diag(w * (1 + 2 * (gamma - 1) * (1 - 2 * p) * proj(A, r)))
    # dense[i, j] = w_i * (Pi_A)_{ij} * w_j
    dense = w[:, np.newaxis] * proj_matrix(A) * w[np.newaxis, :]

    return X.T @ (diag + 2 * (gamma - 1) * dense) @ X


@pytest.mark.parametrize("gamma", [0, 0.1, 0.8, 1, 5])
def test_grad_hess(gamma):
    rng = np.random.default_rng(0)
    n = 100
    p = 10
    q = 3

    X = rng.normal(size=(n, p))
    beta = rng.normal(size=p)

    y = 2 * rng.binomial(1, 0.5, n) - 1

    A = rng.normal(size=(n, q))

    # Compare analytic derivatives against finite differences.
    approx_grad = approx_fprime(beta, lambda b: loss(X, b, y, A, gamma))
    np.testing.assert_allclose(approx_grad, grad(X, beta, y, A, gamma), 1e-5)

    approx_hess = approx_fprime(beta, lambda b: grad(X, b, y, A, gamma), 1e-7)
    np.testing.assert_allclose(approx_hess, hess(X, beta, y, A, gamma), 1e-5)
```

I understand that glum implements different solvers. As $\ell_1$-regularization is popular in the robustness community, the IRLS solver (which handles $\ell_1$ penalties via coordinate descent) is the most interesting one here.

As I understand it, forming the full $n \times n$ projection matrix above can be avoided with a thin QR decomposition $A = QR$, since then $\Pi_A = Q Q^T$. However, your implementation never actually computes the Hessian, but rather an approximation, and it appears to depend heavily on the Hessian having the form $X^T D X$ for some diagonal $D$. This is no longer the case here: the Hessian in $f$ is a diagonal matrix plus the dense term $2(\gamma - 1) W \Pi_A W$ with $W = \mathrm{diag}(p(1 - p))$, which has rank at most the number of columns of $A$.
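A sketch of how the QR idea could be used to assemble the exact Hessian in $\beta$ (the formula above) without the $n \times n$ matrix: with $\Pi_A = QQ^T$, the dense part becomes $BB^T$ with $B = X^T W Q$, which has only as many columns as $A$. The function names are hypothetical, `A` is assumed to have full column rank, and this does not address how such a Hessian would be used inside glum's solvers.

```python
import numpy as np


def hess_lowrank(X, beta, y, A, gamma):
    """Hessian of the anchor loss w.r.t. beta, never forming the n x n matrix Pi_A.

    Assumes A has full column rank, so the reduced QR factor Q gives Pi_A = Q Q^T.
    """
    f = X @ beta
    prob = 1 / (1 + np.exp(-f))
    w = prob * (1 - prob)
    r = (y + 1) / 2 - prob
    Q, _ = np.linalg.qr(A)  # reduced QR: Q is n x q with orthonormal columns
    proj_r = Q @ (Q.T @ r)  # Pi_A r in O(n q)
    d = -w * (1 + 2 * (gamma - 1) * (1 - 2 * prob) * proj_r)
    diag_part = (X.T * d) @ X  # X^T diag(d) X
    B = X.T @ (w[:, None] * Q)  # X^T diag(w) Q
    return diag_part + 2 * (gamma - 1) * B @ B.T  # low-rank correction


def hess_dense(X, beta, y, A, gamma):
    """Reference version that forms Pi_A explicitly (O(n^2) memory)."""
    f = X @ beta
    prob = 1 / (1 + np.exp(-f))
    w = prob * (1 - prob)
    r = (y + 1) / 2 - prob
    Pi = A @ np.linalg.solve(A.T @ A, A.T)
    d = -w * (1 + 2 * (gamma - 1) * (1 - 2 * prob) * (Pi @ r))
    return X.T @ (np.diag(d) + 2 * (gamma - 1) * w[:, None] * Pi * w[None, :]) @ X


rng = np.random.default_rng(0)
n, p, q = 200, 5, 3
X = rng.normal(size=(n, p))
beta = rng.normal(size=p)
y = 2 * rng.binomial(1, 0.5, n) - 1
A = rng.normal(size=(n, q))
print(np.allclose(hess_lowrank(X, beta, y, A, 5.0), hess_dense(X, beta, y, A, 5.0)))  # True
```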

Summary

Anchor regression interpolates between OLS and IV regression to improve a model's robustness to distribution shifts.
Distributional anchor regression is a generalization to GLMs. To my knowledge, no efficient solver for distributional anchor regression exists.

Is this something you would be interested in integrating into glum? How complex would this be? Are there any hurdles (e.g., the dense Hessian) that would prevent the use of existing methods?

References

[1] Rothenhäusler, D., N. Meinshausen, P. Bühlmann, and J. Peters (2021). Anchor regression: Heterogeneous data meet causality. Journal of the Royal Statistical Society Series B (Statistical Methodology) 83(2), 215–246.

[2] Bühlmann, P. (2020). Invariance, causality and robustness. Statistical Science 35(3), 404–426.

[3] Kook, L., B. Sick, and P. Bühlmann (2022). Distributional anchor regression. Statistics and Computing 32(3), 1–19.
