Add NB and ZINB likelihoods #1656

Open
wants to merge 5 commits into develop
2 changes: 1 addition & 1 deletion README.md
@@ -107,7 +107,7 @@ We have a public [GPflow slack workspace](https://gpflow.slack.com/). Please use

### Contributing

All constructive input is gratefully received. For more information, see the [notes for contributors](contributing.md).
All constructive input is gratefully received. For more information, see the [notes for contributors](CONTRIBUTING.md).

### Projects using GPflow

@@ -0,0 +1,93 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,.pct.py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.3.3
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# %% [markdown]
# # Regression with over-dispersed count data
#
# This notebook demonstrates non-parametric regression modelling of over-dispersed count data, i.e., when the response variable $Y$ does not have an approximately Gaussian distribution but takes non-negative integer values, $Y \in \mathbb{N}_0$, and when overdispersion occurs, i.e., when the variance of the conditional distribution of $Y$ exceeds its mean.

# %%
import gpflow

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

# %matplotlib inline
import matplotlib.pyplot as plt

tfd = tfp.distributions
plt.rcParams["figure.figsize"] = (12, 4)


# %% [markdown]
# We first generate a synthetic data set for demonstration. We sample data using TensorFlow Probability, using an exponential covariance function to parameterize the latent GP.

# %%
def generate_data(n=30, s=20, seed=123):
    # Inputs on a regular grid in [-1, 1]
    X = np.linspace(-1, 1, n).reshape(-1, 1)
    k = gpflow.kernels.Exponential(lengthscales=0.25, variance=2.0)

    # Draw one sample path of the latent GP at the inputs X
    f = (
        tfd.MultivariateNormalTriL(np.repeat(0.0, X.shape[0]), tf.linalg.cholesky(k(X, X)))
        .sample(1, seed=seed)
        .numpy()
        .reshape((n, 1))
    )

    # Sample negative binomial counts with logits given by the latent GP
    y = tfd.NegativeBinomial(logits=f, total_count=s).sample(1, seed=seed).numpy()
    Y = y.reshape((n, 1))

    return Y, X, f


n, s = 100, 10
Y, X, f = generate_data(n, s)
p = tf.math.sigmoid(f)
mu = s * p / (1 - p)  # mean of a negative binomial with total_count=s and success probability p

fig = plt.figure()
plt.plot(X, mu, color="darkred", label="latent function value")
plt.plot(X, Y, "kx", mew=1.5, label="observed data")
plt.legend(bbox_to_anchor=(1.0, 0.5))
plt.show()
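
# %% [markdown]
# As a quick, purely illustrative sanity check (not part of the model fit below): for
# over-dispersed counts the sample variance should clearly exceed the sample mean.

# %%
print("sample mean:    ", Y.mean())
print("sample variance:", Y.var())  # noticeably larger than the mean for over-dispersed data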

# %%
train_idxs = sorted(np.random.choice(n, int(n / 2.0), False))
test_idxs = np.setdiff1d(np.arange(n), train_idxs)

# %% [markdown]
# Next, we specify a covariance function, the likelihood, and a variational GP object. Since we cannot assume to know the covariance function of the latent GP in practice, we use a Matern 3/2 kernel here.

# %%
kernel = gpflow.kernels.Matern32()
likelihood = gpflow.likelihoods.NegativeBinomial()

m = gpflow.models.VGP(data=(X[train_idxs], Y[train_idxs]), kernel=kernel, likelihood=likelihood)

# %%
opt = gpflow.optimizers.Scipy()
_ = opt.minimize(m.training_loss, m.trainable_variables, options=dict(maxiter=200))

# %%
Y_hat, Y_hat_var = m.predict_y(X[test_idxs])

plt.plot(X, mu, color="darkred", label="latent function value", alpha=0.5)
plt.plot(X[test_idxs], Y_hat, color="darkblue", label="predictive posterior mean", alpha=0.5)
plt.plot(X[test_idxs], Y_hat + 1.5 * np.sqrt(Y_hat_var), "--", lw=2, color="darkblue", alpha=0.5)
plt.plot(X[test_idxs], Y_hat - 1.5 * np.sqrt(Y_hat_var), "--", lw=2, color="darkblue", alpha=0.5)
plt.plot(X[train_idxs], Y[train_idxs], "kx", mew=1.5, label="observed data")
plt.legend(bbox_to_anchor=(1.0, 0.5))
plt.show()
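
# %% [markdown]
# As an optional, illustrative extra step (not required for the fit above), we can
# inspect the fitted overdispersion parameter of the likelihood defined earlier.

# %%
print("estimated overdispersion psi:", m.likelihood.psi.numpy())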
1 change: 1 addition & 0 deletions doc/source/notebooks/intro.md
@@ -58,6 +58,7 @@ This section explains the more complex models and features that are available in
- [Inter-domain Variational Fourier features](advanced/variational_fourier_features.ipynb): how to add new inter-domain inducing variables, at the example of representing sparse GPs in the spectral domain.
- [Manipulating kernels](advanced/kernels.ipynb): information on the covariances that are included in the library, and how you can combine them to create new ones.
- [Convolutional GPs](advanced/convolutional.ipynb): how we can use GPs with convolutional kernels for image classification.
- [Regression with over-dispersed count data](advanced/regression_with_overdispersed_count_data.ipynb): how we can use GPs for regression when the data are non-negative counts.

### Features

8 changes: 7 additions & 1 deletion gpflow/likelihoods/__init__.py
@@ -27,4 +27,10 @@
    MultiLatentTFPConditional,
)
from .scalar_continuous import Beta, Exponential, Gamma, Gaussian, StudentT
from .scalar_discrete import Bernoulli, Ordinal, Poisson
from .scalar_discrete import (
    Bernoulli,
    NegativeBinomial,
    Ordinal,
    Poisson,
    ZeroInflatedNegativeBinomial,
)
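
With these exports in place, the new likelihoods can be constructed directly from `gpflow.likelihoods`. A minimal sketch (illustrative only, assuming this branch is installed):

import gpflow

nb = gpflow.likelihoods.NegativeBinomial(psi=1.5)
zinb = gpflow.likelihoods.ZeroInflatedNegativeBinomial(theta=0.2, psi=1.5)
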
83 changes: 83 additions & 0 deletions gpflow/likelihoods/scalar_discrete.py
@@ -169,3 +169,86 @@ def _conditional_variance(self, F):
        E_y = phi @ Ys
        E_y2 = phi @ (Ys ** 2)
        return tf.reshape(E_y2 - E_y ** 2, tf.shape(F))


class NegativeBinomial(ScalarLikelihood):
    r"""
    A likelihood for count data with overdispersion. The pmf of this
    parameterization of the negative binomial distribution is given by

    .. math::

        NB(y \mid \mu, \psi) =
            \frac{\Gamma(y + \psi)}{y! \, \Gamma(\psi)}
            \left( \frac{\mu}{\mu + \psi} \right)^y
            \left( \frac{\psi}{\mu + \psi} \right)^\psi

    It is parameterized by a mean :math:`\mu = \exp(\nu)`, where the predictor
    :math:`\nu` is given by a latent GP, and by a parameter :math:`\psi` that
    controls the overdispersion.

    The expected value and variance of a negative binomially distributed random
    variable :math:`y` are :math:`\mathbb{E}[y] = \mu` and
    :math:`\mathrm{Var}[y] = \mu + \frac{\mu^2}{\psi}`.
    """

    def __init__(self, psi=1.0, invlink=tf.exp, **kwargs):
        super().__init__(**kwargs)
        self.invlink = invlink
        self.psi = Parameter(psi, transform=positive())

    def _scalar_log_prob(self, F, Y):
        mu = self.invlink(F)
        mu_psi = mu + self.psi
        psi_y = self.psi + Y
        # log NB(Y | mu, psi), split into the three factors of the pmf above
        f1 = tf.math.lgamma(psi_y) - tf.math.lgamma(Y + 1.0) - tf.math.lgamma(self.psi)
        f2 = Y * tf.math.log(mu / mu_psi)
        f3 = self.psi * tf.math.log(self.psi / mu_psi)
        return f1 + f2 + f3

    def _conditional_mean(self, F):
        return self.invlink(F)

    def _conditional_variance(self, F):
        mu = self.invlink(F)
        return mu + tf.pow(mu, 2) / self.psi


class ZeroInflatedNegativeBinomial(NegativeBinomial):
    r"""
    A likelihood for count data with overdispersion and an excess of zeros.
    The distribution of a zero-inflated negative binomial random variable
    arises as a mixture of a negative binomial and a distribution with all
    mass at zero.

    Its pmf is given by

    .. math::

        ZINB(y \mid \mu, \psi, \theta) =
            \theta \, \mathbb{I}(y = 0) + (1 - \theta) \, NB(y \mid \mu, \psi)

    where :math:`\mu` and :math:`\psi` are the mean and overdispersion parameter
    of the negative binomial component, and :math:`\theta` controls how much
    probability mass is assigned to the excess zeros.

    The expected value and variance are :math:`\mathbb{E}[y] = (1 - \theta) \mu`
    and :math:`\mathrm{Var}[y] = (1 - \theta) \mu (1 + \theta \mu + \frac{\mu}{\psi})`.
    """

    def __init__(self, theta=0.5, psi=1.0, invlink=tf.exp, **kwargs):
        super().__init__(psi, invlink, **kwargs)
        self.theta = Parameter(theta, transform=positive())

    def _scalar_log_prob(self, F, Y):
        # Indicator for the zero-inflation component
        yz = tf.cast(Y == 0.0, dtype=default_float())
        log_sup = super()._scalar_log_prob(F, Y)
        # Mixture of a point mass at zero and the negative binomial component
        lse = yz * self.theta + (1.0 - self.theta) * tf.math.exp(log_sup)
        return tf.math.log(lse)

    def _conditional_mean(self, F):
        return (1.0 - self.theta) * self.invlink(F)

    def _conditional_variance(self, F):
        mu = self.invlink(F)
        return (1.0 - self.theta) * mu * (1.0 + self.theta * mu + mu / self.psi)
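
The (mu, psi) parameterization above coincides with TFP's NegativeBinomial when total_count = psi and probs = mu / (mu + psi), and the zero-inflated pmf is the stated two-component mixture. A small sketch of how one might check this (illustrative only, not part of the diff, assuming this branch's classes are importable):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import gpflow
from gpflow.likelihoods import NegativeBinomial, ZeroInflatedNegativeBinomial

psi = 2.0
F = tf.constant([[0.3], [-1.2]], dtype=gpflow.default_float())
Y = tf.constant([[4.0], [0.0]], dtype=gpflow.default_float())
mu = tf.exp(F)

# NB log-pmf should match tfp's NegativeBinomial under this parameterization
nb = NegativeBinomial(psi=psi)
ref = tfp.distributions.NegativeBinomial(total_count=psi, probs=mu / (mu + psi)).log_prob(Y)
np.testing.assert_allclose(nb._scalar_log_prob(F, Y), ref, rtol=1e-6)

# ZINB log-pmf should equal log(theta * [y == 0] + (1 - theta) * NB(y | mu, psi))
theta = 0.3
zinb = ZeroInflatedNegativeBinomial(theta=theta, psi=psi)
mixture = theta * tf.cast(Y == 0.0, gpflow.default_float()) + (1.0 - theta) * tf.exp(ref)
np.testing.assert_allclose(zinb._scalar_log_prob(F, Y), tf.math.log(mixture), rtol=1e-6)
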
9 changes: 9 additions & 0 deletions tests/gpflow/likelihoods/test_likelihoods.py
@@ -42,6 +42,8 @@
    Ordinal,
    Poisson,
    StudentT,
    NegativeBinomial,
    ZeroInflatedNegativeBinomial,
)

tf.random.set_seed(99012)
@@ -88,6 +90,13 @@ def __repr__(self):
    LikelihoodSetup(
        Bernoulli(invlink=tf.sigmoid), Y=tf.random.uniform(Datum.Yshape, dtype=default_float()),
    ),
    LikelihoodSetup(
        NegativeBinomial(), Y=tf.random.poisson(Datum.Yshape, 100, dtype=default_float()),
    ),
    LikelihoodSetup(
        ZeroInflatedNegativeBinomial(),
        Y=tf.random.poisson(Datum.Yshape, 100, dtype=default_float()),
    ),
]

likelihood_setups = scalar_likelihood_setups + [