Bug Report
GPJax version: 0.8.0
Current behavior:
I am trying to update the existing integration of GeometricKernels with GPJax so that it works with newer versions of GPJax. It works fine with GPJax 0.6.9, but with the current GPJax 0.8.0 I hit two problems.
The first one is exactly #397, which, although quite annoying, can be worked around by downgrading `tensorflow` to version 2.13.
The second one is illustrated in the Related code section below. I believe it is caused by `plum-dispatch`, which we use extensively in GeometricKernels to support multiple backends. GPJax uses `cola`, which in turn relies on `cola-plum-dispatch`, a fork of `plum-dispatch`. This unmaintained fork installs into the same `plum` namespace (which seems like a terrible sin), so it gets overridden by the actual `plum` that GeometricKernels uses, causing the error below. I believe this is similar to this issue.
Expected behavior:
I am not sure how to fix this, but it seems important to fix, as otherwise GPJax is incompatible with any other library that relies on `plum-dispatch`, which is quite popular.
Steps to reproduce:
See below.
Related code:
It is enough to run this snippet:
```python
# Import a backend, we use jax in this example.
import jax.numpy as jnp
import jax
import gpjax as gpx

# Import the geometric_kernels backend.
import geometric_kernels
import geometric_kernels.jax
```
which leads to
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 4
      2 import jax.numpy as jnp
      3 import jax
----> 4 import gpjax as gpx
      6 # Import the geometric_kernels backend.
      7 import geometric_kernels

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/__init__.py:15
      1 # Copyright 2022 The GPJax Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax import (
     16     base,
     17     decision_making,
     18     gps,
     19     integrators,
     20     kernels,
     21     likelihoods,
     22     mean_functions,
     23     objectives,
     24     variational_families,
     25 )
     26 from gpjax.base import (
     27     Module,
     28     param_field,
     29 )
     30 from gpjax.citation import cite

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/__init__.py:15
      1 # Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
---> 15 from gpjax.decision_making.decision_maker import (
     16     AbstractDecisionMaker,
     17     UtilityDrivenDecisionMaker,
     18 )
     19 from gpjax.decision_making.posterior_handler import PosteriorHandler
     20 from gpjax.decision_making.search_space import (
     21     AbstractSearchSpace,
     22     ContinuousSearchSpace,
     23 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/decision_maker.py:32
     29 import jax.random as jr
     31 from gpjax.dataset import Dataset
---> 32 from gpjax.decision_making.posterior_handler import PosteriorHandler
     33 from gpjax.decision_making.search_space import AbstractSearchSpace
     34 from gpjax.decision_making.utility_functions import (
     35     AbstractUtilityFunctionBuilder,
     36     ThompsonSampling,
     37 )

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/decision_making/posterior_handler.py:25
     23 import gpjax as gpx
     24 from gpjax.dataset import Dataset
---> 25 from gpjax.gps import (
     26     AbstractLikelihood,
     27     AbstractPosterior,
     28     AbstractPrior,
     29 )
     30 from gpjax.objectives import AbstractObjective
     31 from gpjax.typing import KeyArray

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/gpjax/gps.py:26
     18 from typing import overload
     20 from beartype.typing import (
     21     Any,
     22     Callable,
     23     Optional,
     24     Union,
     25 )
---> 26 import cola
     27 from cola.linalg.decompositions.decompositions import Cholesky
     28 import jax.numpy as jnp

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/__init__.py:11
      9 __all__ = []
     10 # for loader, module_name, is_pkg in pkgutil.walk_packages(__path__):
---> 11 import_from_all("fns", globals(), __all__, __name__)
     12 import_from_all("annotations", globals(), __all__, __name__)
     13 import_from_all("linalg", globals(), __all__, __name__)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/utils/__init__.py:36, in import_from_all(module_name, namespace, _all, _name)
     32 def import_from_all(module_name, namespace, _all, _name):
     33     """Import all functions from module.__all__ into the namespace and add to __all__.
     34     example usage: import_every("operators",globals(),__all__,__name__)
     35     """
---> 36     module = importlib.import_module('.' + module_name, package=_name)
     37     if not hasattr(module, "__all__"):
     38         logging.debug(f"empty {module_name}.__all__")

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/importlib/__init__.py:126, in import_module(name, package)
    124         break
    125     level += 1
--> 126 return _bootstrap._gcd_import(name[level:], package, level)

File ~/anaconda3/envs/gkconda_newjax/lib/python3.10/site-packages/cola/fns.py:127
    122 @dispatch
    123 def transpose(A: Dense):
    124     return Dense(A.A.T)
--> 127 @dispatch(cond=lambda A: A.isa(cola.SelfAdjoint))
    128 def transpose(A: LinearOperator):
    129     # dangerous, TODO: fix when A is complex or unify transpose and adjoint
    130     return A
    133 @dispatch
    134 def transpose(A: Triangular):

TypeError: Dispatcher.__call__() got an unexpected keyword argument 'cond'
```
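The last frame shows that the error is raised while `cola` itself is being imported, at its `@dispatch(cond=...)` decorator, so `gpjax` and `geometric_kernels` need not be imported at all to reproduce it once the upstream `plum` is in place. A minimal probe, sketched here with a hypothetical helper `probe_import`, imports a single module and reports whether it loads, is missing, or fails with this kind of `TypeError`.

```python
import importlib


def probe_import(module_name: str) -> str:
    """Hypothetical helper: import `module_name` and classify the outcome."""
    try:
        importlib.import_module(module_name)
    except ModuleNotFoundError as err:
        if err.name != module_name:
            raise  # a dependency is missing, not the module itself
        return "not installed"
    except TypeError as err:
        # With the upstream plum installed, cola fails at import time with
        # "Dispatcher.__call__() got an unexpected keyword argument 'cond'".
        return f"broken: {err}"
    return "ok"


if __name__ == "__main__":
    print("cola:", probe_import("cola"))
```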