Enabling kernels to use PyTree inputs #288

Open · wants to merge 2 commits into base: main
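
Summary: this PR allows `gpx.Dataset` inputs, kernels, and mean functions to accept PyTree-valued inputs (e.g. a dict of feature arrays) in place of a single array, and adds a worked example fitting a GP to a padded graph dataset.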
89 changes: 89 additions & 0 deletions docs/examples/pytree_kernels.py
@@ -0,0 +1,89 @@
# %%
import jax.numpy as jnp
import datasets as ds
import gpjax as gpx
from jax import jit
import optax as ox
import jax.random as jr
from jaxtyping import PyTree
import matplotlib.pyplot as plt

# %% [markdown]
# Now load a graph dataset and pad it

# %%
gd = ds.load_dataset("graphs-datasets/AQSOL")

gd = gd.map(
    lambda x: {
        "num_edges": len(x["edge_index"][0]),
    }
)
gd.set_format("jax")

max_num_edges = max([gd[i]["num_edges"].max() for i in gd])
max_num_nodes = max([gd[i]["num_nodes"].max() for i in gd])

small_gd = (
    gd["train"]
    .select(range(100))
    .map(
        lambda x: {
            "num_edges": len(x["edge_index"][0]),
        }
    )
)


def pad_edge_attr_node_feat(x):
    # Zero-pad each graph's node features and edge attributes to the maximum
    # sizes in the dataset, so every example has fixed-shape leaves.
    nf = (
        jnp.zeros(max_num_nodes).at[: len(x["node_feat"])].set(x["node_feat"].squeeze())
    )
    ea = (
        jnp.zeros(max_num_edges).at[: len(x["edge_attr"])].set(x["edge_attr"].squeeze())
    )
    return {"node_feat": nf, "edge_attr": ea}


small_gd = small_gd.map(pad_edge_attr_node_feat)

# Prepare the dataset for GPJax: X is a PyTree (here a dict of arrays)
# rather than a single array.
D = gpx.Dataset(X={i: small_gd[i] for i in ("node_feat", "edge_attr")}, y=small_gd["y"])

# %% [markdown]
# Now define a naive Graph kernel that takes node and edge features


# %%
class GraphKern(gpx.AbstractKernel):
    def __call__(self, x1: PyTree, x2: PyTree, **kwargs):
        # Sum of RBF kernels over the node-feature and edge-attribute leaves.
        return gpx.kernels.RBF()(x1["node_feat"], x2["node_feat"]) + gpx.kernels.RBF()(
            x1["edge_attr"], x2["edge_attr"]
        )
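
# %% [markdown]
# A quick, illustrative sanity check (a sketch, using the padded dataset
# built above): the kernel consumes dictionary-valued inputs directly.

# %%
kern = GraphKern()
x1 = {k: D.X[k][0] for k in ("node_feat", "edge_attr")}
x2 = {k: D.X[k][1] for k in ("node_feat", "edge_attr")}
print(kern(x1, x2))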


# %% [markdown]
# And we're ready to fit a model!

# %%
meanf = gpx.mean_functions.Zero()
prior = gpx.Prior(mean_function=meanf, kernel=GraphKern())
likelihood = gpx.Gaussian(num_datapoints=D.n)
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
posterior = prior * likelihood

opt_posterior, mll_history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=600,
    safe=True,
    key=jr.PRNGKey(0),
)

# %%
plt.plot(mll_history)
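
# %% [markdown]
# Because `predict` also unpacks PyTree inputs, the posterior can be queried
# at dictionary-valued test points. A minimal sketch, reusing the training
# inputs as test inputs:

# %%
latent_dist = opt_posterior.predict(D.X, train_data=D)
predictive_dist = opt_posterior.likelihood(latent_dist)
print(predictive_dist.mean().shape)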

# %%
42 changes: 35 additions & 7 deletions gpjax/dataset.py
@@ -14,8 +14,14 @@
 # ==============================================================================

 from dataclasses import dataclass
+from typing import (
+    Any,
+    Callable,
+    Tuple,
+    TypeVar,
+    Union,
+)

 from beartype.typing import Optional
+import jax
 import jax.numpy as jnp
 from jaxtyping import Num
 from simple_pytree import Pytree
@@ -43,7 +49,7 @@ def __post_init__(self) -> None:
     def __repr__(self) -> str:
         r"""Returns a string representation of the dataset."""
         repr = (
-            f"- Number of observations: {self.n}\n- Input dimension:"
+            f"- Number of observations: {self.n}\n- Input dimension (sum over PyTree):"
             f" {self.in_dim}\n- Output dimension: {self.out_dim}"
         )
         return repr
@@ -72,12 +78,14 @@ def __add__(self, other: "Dataset") -> "Dataset":
     @property
     def n(self) -> int:
         r"""Number of observations."""
-        return self.X.shape[0]
+        return jax.tree_util.tree_leaves(self.X)[0].shape[0]

     @property
     def in_dim(self) -> int:
         r"""Dimension of the inputs, $`X`$."""
-        return self.X.shape[1]
+        return jax.tree_util.tree_reduce(
+            lambda a, b: a + b, jax.tree_map(lambda a: a.shape[1], self.X), 0
+        )

     @property
     def out_dim(self) -> int:
@@ -89,15 +97,17 @@ def _check_shape(
     X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
 ) -> None:
     r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
-    if X is not None and y is not None and X.shape[0] != y.shape[0]:
-        raise ValueError(
-            "Inputs, X, and outputs, y, must have the same number of rows."
-            f" Got X.shape={X.shape} and y.shape={y.shape}."
-        )
+    if X is not None and y is not None:
+        len_ok, X_length = _check_all_leaves_const(lambda a: len(a), len(y), X)
+        if not len_ok:
+            raise ValueError(
+                "Inputs, X, and outputs, y, must have the same number of rows."
+                f" Got len(y)={len(y)} and len(X)={X_length}."
+            )

-    if X is not None and X.ndim != 2:
-        raise ValueError(
-            f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X.ndim}."
-        )
+    if X is not None:
+        dim_ok, X_dim = _check_all_leaves_const(lambda a: a.ndim, 2, X)
+        if not dim_ok:
+            raise ValueError(
+                f"Inputs, X, must be a 2-dimensional array. Got X.ndim={X_dim}."
+            )

     if y is not None and y.ndim != 2:
@@ -106,6 +116,24 @@ def _check_shape(
         )


+T = TypeVar("T")
+
+
+def _check_all_leaves_const(
+    extract_value: Callable[[Any], T],
+    equal_to: T,
+    X: Optional[Union[Pytree, Num[Array, "..."]]],
+) -> Tuple[bool, Any]:
+    r"""Checks that `extract_value` returns `equal_to` on every leaf of `X`,
+    returning the result together with the per-leaf values for error messages."""
+    values = jax.tree_map(extract_value, X)
+
+    return (
+        jax.tree_util.tree_reduce(
+            lambda a, b: a and b, jax.tree_map(lambda a: a == equal_to, values), True
+        ),
+        values,
+    )
+
+
 __all__ = [
     "Dataset",
 ]
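
A minimal sketch of the behaviour these properties enable (hypothetical shapes, assuming the definitions above):

    import jax.numpy as jnp
    import gpjax as gpx

    # X may now be a PyTree (here a dict) whose leaves share a leading axis.
    X = {"a": jnp.ones((10, 2)), "b": jnp.ones((10, 3))}
    y = jnp.ones((10, 1))
    D = gpx.Dataset(X=X, y=y)
    assert D.n == 10      # leading axis of the first leaf
    assert D.in_dim == 5  # sum of trailing dimensions over the leaves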
5 changes: 3 additions & 2 deletions gpjax/gps.py
@@ -23,6 +23,7 @@
     Callable,
     Optional,
 )
+import jax
 import jax.numpy as jnp
 from jax.random import (
     PRNGKey,
@@ -481,7 +482,7 @@ def predict(
         x, y, n = train_data.X, train_data.y, train_data.n

         # Unpack test inputs
-        t, n_test = test_inputs, test_inputs.shape[0]
+        t, n_test = test_inputs, jax.tree_util.tree_leaves(test_inputs)[0].shape[0]

         # Observation noise σ²
         obs_noise = self.likelihood.obs_noise
@@ -655,7 +656,7 @@ def predict(
         Lx = Kxx.to_root()

         # Unpack test inputs
-        t, n_test = test_inputs, test_inputs.shape[0]
+        t, n_test = test_inputs, jax.tree_util.tree_leaves(test_inputs)[0].shape[0]

         # Compute terms of the posterior predictive distribution
         Ktx = kernel.cross_covariance(t, x)
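
The same pattern as in `dataset.py`: for PyTree test inputs, the number of test points is read from the leading axis of the first leaf. A minimal sketch (hypothetical shapes):

    import jax
    import jax.numpy as jnp

    t = {"node_feat": jnp.zeros((5, 7)), "edge_attr": jnp.zeros((5, 3))}
    n_test = jax.tree_util.tree_leaves(t)[0].shape[0]  # -> 5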
3 changes: 2 additions & 1 deletion gpjax/mean_functions.py
@@ -23,6 +23,7 @@
     List,
     Union,
 )
+import jax
 import jax.numpy as jnp
 from jaxtyping import (
     Float,
@@ -147,7 +148,7 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
         -------
         Float[Array, "1"]: The evaluated mean function.
         """
-        return jnp.ones((x.shape[0], 1)) * self.constant
+        return jnp.ones((jax.tree_util.tree_leaves(x)[0].shape[0], 1)) * self.constant


 @dataclasses.dataclass