In [None]:
import jax.numpy as jnp
import jax.random as jr
import pytest
from gpjax.covariance_operator import (
    CovarianceOperator,
    DenseCovarianceOperator,
    DiagonalCovarianceOperator,
    I,
)

from gpjax.kernels import (
    RBF, 
    Matern12,
    Matern32,
    Matern52,
)


def test_covariance_operator():
    """Instantiating the ``CovarianceOperator`` base class directly must raise ``TypeError``."""
    expected_error = TypeError
    with pytest.raises(expected_error):
        CovarianceOperator()


@pytest.mark.parametrize("n", [1, 10, 100])
def test_dense_covariance_operator(n):
    """Check ``DenseCovarianceOperator`` against direct ``jax.numpy`` computations.

    A random ``n x n`` positive-definite matrix is wrapped in the operator. Each
    operator method is compared with the equivalent dense computation: ``shape``,
    ``to_dense``, ``diagonal``, ``log_det``, ``trace`` and ``triangular_lower``,
    plus addition of a scaled ``DiagonalCovarianceOperator``.

    Args:
        n: Size of the square test matrix.
    """
    key = jr.PRNGKey(seed=42)
    A = jr.normal(key, (n, n))
    # A^T A on its own is only positive semi-definite, and may be ill-conditioned
    # for larger n in float32. Adding the identity guarantees a well-conditioned
    # positive-definite matrix, which keeps the Cholesky and log-det checks stable.
    dense = A.T @ A + jnp.eye(n)

    cov = DenseCovarianceOperator(matrix=dense)

    # Test shape:
    assert cov.shape == (n, n)

    # TODO: re-enable a `solve` check once it is confirmed to work. Draw `b` from
    # an independent key (e.g. `_, b_key = jr.split(key)`) rather than reusing
    # `key`, then: b = jr.normal(b_key, (n, 1));
    # assert jnp.allclose(b, dense @ cov.solve(b)).

    # Test to_dense method:
    assert jnp.allclose(dense, cov.to_dense())

    # Test diagonal method:
    assert jnp.allclose(jnp.diag(dense), cov.diagonal())

    # Test log determinant (slogdet returns (sign, log|det|); compare log|det|):
    assert jnp.allclose(jnp.linalg.slogdet(dense)[1], cov.log_det())

    # Test trace:
    assert jnp.allclose(jnp.trace(dense), cov.trace())

    # Test lower triangular (Cholesky) factor:
    assert jnp.allclose(jnp.linalg.cholesky(dense), cov.triangular_lower())

    # Test adding a scaled diagonal covariance operator to the dense operator:
    diag = DiagonalCovarianceOperator(jnp.diag(dense))
    cov_sum = cov + (diag * jnp.pi)
    assert jnp.allclose(
        dense + jnp.pi * jnp.diag(jnp.diag(dense)), cov_sum.to_dense()
    )

In [None]:
# Scratch check: build a small random matrix and wrap it in a dense covariance
# operator, so its Cholesky factor can be compared with jnp.linalg.cholesky below.
n = 3
key = jr.PRNGKey(42)
A = jr.normal(key, shape=(n, n))
dense = jnp.matmul(A.T, A)  # Gram matrix A^T A: symmetric PSD (PD almost surely).
cov = DenseCovarianceOperator(matrix=dense)

In [None]:
# Reference factor: lower Cholesky factor of `dense`, computed directly by JAX.
reference_lower = jnp.linalg.cholesky(dense)
reference_lower

In [None]:
# The operator's lower-triangular factor should match JAX's Cholesky factor of
# the same matrix. Assert this instead of comparing the two printed arrays by eye.
operator_lower = cov.triangular_lower()
assert jnp.allclose(operator_lower, jnp.linalg.cholesky(dense)), (
    "triangular_lower() disagrees with jnp.linalg.cholesky"
)
operator_lower