
dev: Decouple kernel computation class initialisation from kernel. #293

Closed
daniel-dodd opened this issue Jun 7, 2023 · 1 comment · Fixed by #328
Labels
enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@daniel-dodd
Member

Currently:

The kernel computation is a dataclass that takes the kernel as its argument:

from dataclasses import dataclass


@dataclass
class AbstractKernelComputation:
    r"""Abstract class for kernel computations."""

    kernel: "gpjax.kernels.base.AbstractKernel"  # noqa: F821

This means that, when defining a kernel's compute_engine, we cannot initialise the computation before initialising the kernel, so we have to pass the uninitialised class itself, i.e. compute_engine = DenseKernelComputation:

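from dataclasses import dataclass
from typing import Type

# Assumed import paths (GPJax 0.6.x layout).
from gpjax.base import Module, static_field
from gpjax.kernels.computations import AbstractKernelComputation, DenseKernelComputation


@dataclass
class AbstractKernel(Module):
    r"""Base kernel class."""

    compute_engine: Type[AbstractKernelComputation] = static_field(
        DenseKernelComputation
    )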
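To make the coupling concrete, here is a toy, pure-Python sketch of the current pattern. `RBF`, `DenseComputationCoupled` and the list-based Gram matrix are hypothetical stand-ins, not GPJax code (`Module` and `static_field` are left out to keep it dependency-free). The point is only that the kernel must store the computation *class* and build an instance around itself on every call.

```python
import math
from dataclasses import dataclass
from typing import List, Type

Matrix = List[List[float]]


@dataclass
class DenseComputationCoupled:
    # Current pattern (toy version): the computation object stores its kernel.
    kernel: "RBF"

    def gram(self, x: List[float]) -> Matrix:
        return [[self.kernel(a, b) for b in x] for a in x]


@dataclass
class RBF:
    lengthscale: float = 1.0
    # Only the class can be stored here: an instance would need a kernel,
    # which does not exist yet while the kernel itself is being defined.
    compute_engine: Type[DenseComputationCoupled] = DenseComputationCoupled

    def __call__(self, a: float, b: float) -> float:
        return math.exp(-((a - b) ** 2) / (2 * self.lengthscale**2))

    def gram(self, x: List[float]) -> Matrix:
        # The engine is instantiated around `self` on every call.
        return self.compute_engine(self).gram(x)


K = RBF(lengthscale=0.5).gram([0.0, 1.0, 2.0])
```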

Proposal:

  • Remove the kernel argument from AbstractKernelComputation so that the class can be initialised independently of the kernel.
  • Instead, pass the kernel pytree as an argument to the gram and cross_covariance methods of AbstractKernelComputation.
  • This would allow us to pass DenseKernelComputation() instead of DenseKernelComputation as the compute_engine.
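A minimal sketch of the proposed shape, under the same toy assumptions (hypothetical `DenseComputationSketch` and `RBF` classes, plain Python lists rather than JAX arrays). The computation carries no kernel field, so an instance can serve as the default `compute_engine`, and `gram`/`cross_covariance` receive the kernel as their first argument.

```python
import math
from dataclasses import dataclass
from typing import List

Matrix = List[List[float]]


class DenseComputationSketch:
    # Proposed pattern (toy version): no kernel field, so it can be built up front.

    def gram(self, kernel: "RBF", x: List[float]) -> Matrix:
        # The kernel is passed in explicitly instead of being stored on self.
        return [[kernel(a, b) for b in x] for a in x]

    def cross_covariance(self, kernel: "RBF", x: List[float], y: List[float]) -> Matrix:
        return [[kernel(a, b) for b in y] for a in x]


@dataclass
class RBF:
    lengthscale: float = 1.0
    # An instance is now a valid default; it is stateless, so sharing it is harmless.
    compute_engine: DenseComputationSketch = DenseComputationSketch()

    def __call__(self, a: float, b: float) -> float:
        return math.exp(-((a - b) ** 2) / (2 * self.lengthscale**2))

    def gram(self, x: List[float]) -> Matrix:
        return self.compute_engine.gram(self, x)

    def cross_covariance(self, x: List[float], y: List[float]) -> Matrix:
        return self.compute_engine.cross_covariance(self, x, y)


K = RBF(lengthscale=0.5).gram([0.0, 1.0, 2.0])
```

Because the engine no longer needs the kernel at construction time, the kernel simply hands itself to the engine's methods.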

daniel-dodd added the labels enhancement (New feature or request) and good first issue (Good for newcomers) on Jun 7, 2023
frazane self-assigned this on Jun 23, 2023
@frazane
Contributor

frazane commented Jun 25, 2023

The refactoring proposed by @daniel-dodd is easy to implement (PR coming soon). However, it would be even nicer at some point to also avoid problems with circular imports. If we want correct type annotations in the gpjax.kernels.computations submodules, we have to import kernel classes such as AbstractKernel, which creates a circular dependency because compute_engine is part of a kernel.
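To see why a runtime import of `AbstractKernel` inside the computations submodule breaks, here is a self-contained reproduction. It writes a throwaway package, `toy_gp` (a hypothetical name), to a temporary directory: its `kernels` module imports `computations` for the compute engine, while `computations` imports `kernels` only for a type annotation.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Throwaway package mirroring kernels -> computations (compute_engine)
# and computations -> kernels (for the AbstractKernel annotation).
root = Path(tempfile.mkdtemp())
pkg = root / "toy_gp"
pkg.mkdir()
(pkg / "__init__.py").write_text("")
(pkg / "kernels.py").write_text(textwrap.dedent("""
    from toy_gp.computations import DenseKernelComputation

    class AbstractKernel:
        compute_engine = DenseKernelComputation()
"""))
(pkg / "computations.py").write_text(textwrap.dedent("""
    from toy_gp.kernels import AbstractKernel  # runtime import, needed only for typing

    class DenseKernelComputation:
        def gram(self, kernel: AbstractKernel, x):
            ...
"""))

sys.path.insert(0, str(root))
importlib.invalidate_caches()
try:
    importlib.import_module("toy_gp.kernels")
    error = None
except ImportError as err:
    # computations.py runs while kernels.py is only partially initialised,
    # so AbstractKernel does not exist yet.
    error = str(err)

print(error)
```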

A typing.TYPE_CHECKING block would avoid the problem for static type checkers, but, for some reason I don't fully understand yet, it does not work for runtime type checking with beartype. Possible solutions include selectively disabling runtime checking or a relatively large refactoring.
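One plausible explanation for the beartype behaviour: `typing.TYPE_CHECKING` is `False` at runtime, so the guarded import never runs and the name `AbstractKernel` never exists in the module. A static checker only reads the source, but any tool that resolves annotations at runtime has to look the name up. The sketch below uses the standard-library `typing.get_type_hints` as a stand-in for that runtime resolution; beartype's own machinery differs, so this is an assumption about the cause, not a diagnosis. `DenseComputationSketch` is a hypothetical toy class.

```python
import typing
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Seen by static checkers such as mypy; never executed at runtime,
    # so it cannot cause a circular import (gpjax need not even be installed).
    from gpjax.kernels.base import AbstractKernel


class DenseComputationSketch:
    # String annotation, as in AbstractKernelComputation above.
    def gram(self, kernel: "AbstractKernel", x: List[float]) -> List[List[float]]:
        return [[kernel(a, b) for b in x] for a in x]


def resolve_annotations():
    """Evaluate the annotations the way a runtime type checker must."""
    try:
        return typing.get_type_hints(DenseComputationSketch.gram)
    except NameError as err:
        # AbstractKernel was never bound in this module's namespace.
        return err


result = resolve_annotations()
print(type(result).__name__, result)
```

Calling `gram` itself still works; only resolving the annotation fails, which is exactly the step a runtime checker cannot skip.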


2 participants