# Kernel Formatting Demonstration

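This notebook demonstrates the new kernel formatting capabilities, that is, how kernels are rendered as text by `str()`, which is what both `print` and f-string interpolation use. The examples progress from single kernels to deeply nested expressions.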
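In Python, `f"{k}"` calls `format(k, "")`, which falls back to `str(k)` (the `__str__` method) unless the class overrides `__format__`. The `!r` conversion selects `__repr__` instead. A minimal sketch using a hypothetical `ToyKernel` class (not part of kernax) shows the distinction:

```python
class ToyKernel:
    """Hypothetical stand-in for a kernel; only illustrates __str__ vs __repr__."""

    def __init__(self, length_scale):
        self.length_scale = length_scale

    def __str__(self):
        # Human-readable form, used by str(), print() and plain f-string interpolation
        return f"Toy(length_scale={self.length_scale})"

    def __repr__(self):
        # Debug form, used by repr() and by the !r conversion in f-strings
        return f"ToyKernel(length_scale={self.length_scale!r})"


k = ToyKernel(1.0)
print(f"{k}")    # Toy(length_scale=1.0)
print(f"{k!r}")  # ToyKernel(length_scale=1.0)
```

So the `print(f"...{kernel}")` cells below exercise the same formatting as calling `str()` directly.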

In [None]:
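import sys
# Make the kernax package in the parent directory importable (assumes the
# notebook is run from a subdirectory of the repository checkout)
sys.path.insert(0, '..')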

import jax.numpy as jnp
from kernax import SEKernel, LinearKernel, ConstantKernel
from kernax.WrapperKernels import ExpKernel, LogKernel, DiagKernel, NegKernel
from kernax.OperatorKernels import SumKernel, ProductKernel

## 1. Simple Kernels with Scalar Parameters

Simple kernels display the values of their scalar hyperparameters in a compact, readable form.

In [None]:
# Constant kernel with scalar
k_const = ConstantKernel(value=2.5)
print(f"ConstantKernel: {k_const}")
print()

In [None]:
# SE (Squared Exponential) kernel
k_se = SEKernel(length_scale=1.0)
print(f"SEKernel: {k_se}")
print()

In [None]:
# Linear kernel with multiple parameters
k_linear = LinearKernel(variance_b=0.5, variance_v=1.0, offset_c=0.0)
print(f"LinearKernel: {k_linear}")
print()

## 2. Kernels with Multidimensional Parameters

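When a parameter is an array, it is summarized as `[mean ± std]_shape` instead of being printed element by element: the mean and standard deviation of its entries, followed by the array shape written as a subscript.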
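To make the `[mean ± std]_shape` summary concrete, here is a stdlib-only sketch of such a formatter for nested lists. The function name `summarize`, the two-decimal precision, the population standard deviation (`ddof=0`, the default of `jnp.std`), and the `ₓ` separator between dimensions are all assumptions; kernax's exact rounding and layout may differ.

```python
import math

# Unicode subscript digits, plus "ₓ" (U+2093) to separate dimensions (an assumption)
_SUBSCRIPTS = str.maketrans("0123456789x", "₀₁₂₃₄₅₆₇₈₉ₓ")


def _flatten(values):
    # Recursively flatten nested lists into a flat list of floats
    if isinstance(values, (list, tuple)):
        return [x for v in values for x in _flatten(v)]
    return [float(values)]


def _shape(values):
    # Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)
    shape = []
    while isinstance(values, (list, tuple)):
        shape.append(len(values))
        values = values[0]
    return tuple(shape)


def summarize(values) -> str:
    """Hypothetical helper: render an array parameter as [mean ± std]_shape."""
    flat = _flatten(values)
    mean = sum(flat) / len(flat)
    # Population standard deviation (ddof=0)
    std = math.sqrt(sum((x - mean) ** 2 for x in flat) / len(flat))
    shape = "x".join(str(d) for d in _shape(values)).translate(_SUBSCRIPTS)
    return f"[{mean:.2f} ± {std:.2f}]{shape}"


print(summarize([1.0, 2.0, 3.0]))           # [2.00 ± 0.82]₃
print(summarize([[1.0, 2.0], [3.0, 4.0]]))  # [2.50 ± 1.12]₂ₓ₂
```

The summary keeps the string length independent of the array size, which matters once array parameters appear inside long kernel expressions.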

In [None]:
# Constant kernel with 1D array
k_const_array = ConstantKernel(value=jnp.array([1.0, 2.0, 3.0]))
print(f"ConstantKernel (1D array): {k_const_array}")
print()

In [None]:
# Constant kernel with 2D array
k_const_matrix = ConstantKernel(value=jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(f"ConstantKernel (2D array): {k_const_matrix}")
print()

In [None]:
# SE kernel with vector length_scale (for ARD)
k_se_vector = SEKernel(length_scale=jnp.array([0.5, 1.0, 1.5, 2.0]))
print(f"SEKernel (vector length_scale): {k_se_vector}")
print()

## 3. Kernel Operators - Addition

The `+` operator creates a `SumKernel`, which is printed as an infix sum of its operands.
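Python evaluates `a + b + c` as `(a + b) + c`, so unless the constructor flattens its operands, a chain of `+` produces nested binary sums. One way to still print such a chain flat is to collect the operands recursively. A minimal sketch with hypothetical `Leaf` and `Sum` classes (not kernax's `SumKernel`):

```python
class Kernel:
    """Hypothetical minimal kernel node; not the kernax class hierarchy."""

    def __add__(self, other):
        return Sum(self, other)


class Leaf(Kernel):
    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Sum(Kernel):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def terms(self):
        # Collect operands of nested sums: ((a + b) + c) -> a, b, c
        for side in (self.left, self.right):
            if isinstance(side, Sum):
                yield from side.terms()
            else:
                yield side

    def __str__(self):
        return " + ".join(str(t) for t in self.terms())


k = Leaf("SE(1.0)") + Leaf("SE(2.0)") + Leaf("Constant(1.0)")
print(k)  # SE(1.0) + SE(2.0) + Constant(1.0)
```

The tree stays nested (`k.left` is itself a `Sum`); only the printed form is flattened.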

In [None]:
# Addition of two simple kernels
k_sum = SEKernel(length_scale=1.0) + ConstantKernel(value=0.5)
print(f"SE + Constant: {k_sum}")
print()

In [None]:
# Addition of multiple kernels
k_sum_multi = SEKernel(length_scale=1.0) + SEKernel(length_scale=2.0) + ConstantKernel(value=1.0)
print(f"SE + SE + Constant: {k_sum_multi}")
print()

## 4. Kernel Operators - Subtraction

Subtraction is implemented with `NegKernel` internally, but it is displayed as ordinary subtraction.
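A minimal sketch of how subtraction can be stored as a sum with a negated right operand yet printed with `-`. The `Leaf`, `Neg`, and `Sum` classes are hypothetical stand-ins, not kernax's `NegKernel` or `SumKernel`, and the assumption that `a - b` becomes `a + Neg(b)` is illustrative:

```python
class Kernel:
    """Hypothetical minimal kernel node; not the kernax class hierarchy."""

    def __add__(self, other):
        return Sum(self, other)

    def __sub__(self, other):
        # a - b is stored as a + Neg(b); there is no dedicated difference node
        return Sum(self, Neg(other))

    def __neg__(self):
        return Neg(self)


def grouped(k):
    # A sum must be parenthesized when it is negated or subtracted
    return f"({k})" if isinstance(k, Sum) else str(k)


class Leaf(Kernel):
    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Neg(Kernel):
    def __init__(self, inner):
        self.inner = inner

    def __str__(self):
        return f"-{grouped(self.inner)}"


class Sum(Kernel):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        # Render "a + Neg(b)" as "a - b" rather than "a + -b"
        if isinstance(self.right, Neg):
            return f"{self.left} - {grouped(self.right.inner)}"
        return f"{self.left} + {self.right}"


se, const = Leaf("SE"), Leaf("Constant")
k_diff = se - const
print(k_diff)                       # SE - Constant
print(type(k_diff.right).__name__)  # Neg
print(-se)                          # -SE
print(se - (se + const))            # SE - (SE + Constant)
```

The last line shows why the subtracted operand needs grouping: without parentheses it would read as `SE - SE + Constant`.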

In [None]:
# Kernel subtraction
k_diff = SEKernel(length_scale=1.0) - ConstantKernel(value=0.5)
print(f"SE - Constant: {k_diff}")
print()

In [None]:
# Kernel negation
k_neg = -SEKernel(length_scale=1.0)
print(f"Negation: {k_neg}")
print()

## 5. Kernel Operators - Multiplication

The `*` operator creates a `ProductKernel`.
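Products interact with sums and negations through precedence: a factor that binds more loosely than `×` must be parenthesized. A minimal sketch of that rule, with hypothetical `Leaf`, `Sum`, `Neg`, and `Product` classes and an assumed precedence ordering (sum < negation < product < atom); it is not kernax's implementation:

```python
class Kernel:
    """Hypothetical minimal kernel node; not the kernax class hierarchy."""

    prec = 4  # atoms bind tightest

    def __add__(self, other):
        return Sum(self, other)

    def __mul__(self, other):
        return Product(self, other)

    def __neg__(self):
        return Neg(self)


def wrap(child, parent_prec):
    # Parenthesize a child that binds more loosely than its parent operator
    text = str(child)
    return f"({text})" if child.prec < parent_prec else text


class Leaf(Kernel):
    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Sum(Kernel):
    prec = 1

    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return f"{wrap(self.left, self.prec)} + {wrap(self.right, self.prec)}"


class Neg(Kernel):
    prec = 2

    def __init__(self, inner):
        self.inner = inner

    def __str__(self):
        return f"-{wrap(self.inner, self.prec)}"


class Product(Kernel):
    prec = 3

    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return f"{wrap(self.left, self.prec)} × {wrap(self.right, self.prec)}"


se, lin, const = Leaf("SE"), Leaf("Linear"), Leaf("Constant")
print(se * lin)         # SE × Linear
print(se * (-const))    # SE × (-Constant)
print((se + se) * lin)  # (SE + SE) × Linear
print(se + se * lin)    # SE + SE × Linear
```

Parentheses appear only where the tree structure would otherwise be misread, so `SE + SE × Linear` needs none while `(SE + SE) × Linear` does.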

In [None]:
# Product of two kernels
k_prod = SEKernel(length_scale=1.0) * LinearKernel(variance_b=0.5, variance_v=1.0, offset_c=0.0)
print(f"SE × Linear: {k_prod}")
print()

In [None]:
# Product with a negated kernel (the negated factor is wrapped in parentheses)
k_prod_neg = SEKernel(length_scale=1.0) * (-ConstantKernel(value=2.0))
print(f"SE × (-Constant): {k_prod_neg}")
print()

## 6. Wrapper Kernels

Wrappers apply a transformation to a kernel and are displayed in functional notation, such as `Exp(...)`.
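A minimal sketch of functional notation for wrappers. The `Leaf`, `Sum`, `Wrapper`, `Exp`, and `Log` classes here are hypothetical and only model printing, not what kernax's `ExpKernel` or `LogKernel` compute:

```python
class Leaf:
    """Hypothetical atomic kernel used only for printing."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Sum:
    """Hypothetical binary sum node, printed infix."""

    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return f"{self.left} + {self.right}"


class Wrapper:
    """Hypothetical wrapper: prints as Name(inner)."""

    name = "Wrapper"

    def __init__(self, inner):
        self.inner = inner

    def __str__(self):
        # The call parentheses already delimit the argument, so an inner
        # sum needs no extra grouping: Exp(SE + Constant), not Exp((SE + Constant))
        return f"{self.name}({self.inner})"


class Exp(Wrapper):
    name = "Exp"


class Log(Wrapper):
    name = "Log"


se, const = Leaf("SE"), Leaf("Constant")
print(Exp(se))              # Exp(SE)
print(Exp(Log(se)))         # Exp(Log(SE))
print(Exp(Sum(se, const)))  # Exp(SE + Constant)
```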

In [None]:
# Exp wrapper
k_exp = ExpKernel(SEKernel(length_scale=1.0))
print(f"Exp wrapper: {k_exp}")
print()

In [None]:
# Log wrapper
k_log = LogKernel(SEKernel(length_scale=1.0))
print(f"Log wrapper: {k_log}")
print()

In [None]:
# Diag wrapper (creates diagonal matrix)
k_diag = DiagKernel(ConstantKernel(value=1.0))
print(f"Diag wrapper: {k_diag}")
print()

## 7. Simple Combinations

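Operators and wrappers can be combined: a wrapper can take a composite kernel as its argument, and wrapped kernels can be used as operands.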

In [None]:
# Wrapper on a sum
k_exp_sum = ExpKernel(SEKernel(length_scale=1.0) + ConstantKernel(value=0.5))
print(f"Exp(SE + Constant): {k_exp_sum}")
print()

In [None]:
# Sum of wrappers
k_sum_wrappers = ExpKernel(SEKernel(length_scale=1.0)) + LogKernel(
    LinearKernel(variance_b=1.0, variance_v=1.0, offset_c=0.0)
)
print(f"Exp(SE) + Log(Linear): {k_sum_wrappers}")
print()

In [None]:
# Product with wrapper
k_prod_wrapper = SEKernel(length_scale=1.0) * DiagKernel(ConstantKernel(value=2.0))
print(f"SE × Diag(Constant): {k_prod_wrapper}")
print()

## 8. Complex Nested Structures

Demonstration of deeply nested kernel structures.

In [None]:
# Multiple wrapping levels
k_nested_wrapper = ExpKernel(LogKernel(SEKernel(length_scale=1.0)))
print(f"Double wrapper: {k_nested_wrapper}")
print()

In [None]:
# Complex combination: (Exp(SE) + Log(Linear)) × Diag(Constant)
k_complex1 = (
    ExpKernel(SEKernel(length_scale=1.0))
    + LogKernel(LinearKernel(variance_b=1.0, variance_v=1.0, offset_c=0.0))
) * DiagKernel(ConstantKernel(value=2.0))
print("Complex combination 1:")
print(k_complex1)
print()

In [None]:
# Expression with nested subtraction and multiplication
k_complex2 = (SEKernel(length_scale=1.0) + SEKernel(length_scale=2.0)) * (
    LinearKernel(variance_b=1.0, variance_v=1.0, offset_c=0.0)
    - ConstantKernel(value=0.5)
)
print("Complex combination 2:")
print(k_complex2)
print()

In [None]:
# Very nested structure with all operator types
k_very_complex = ExpKernel(
    (SEKernel(length_scale=1.0) * ConstantKernel(value=2.0)) +
    DiagKernel(LinearKernel(variance_b=0.5, variance_v=1.0, offset_c=0.0)) -
    LogKernel(SEKernel(length_scale=0.5))
)
print("Very complex structure:")
print(k_very_complex)
print()

## 9. Expressions with Multidimensional Parameters

Complex combinations including vector/matrix parameters.

In [None]:
# Sum of kernels with vector parameters
k_vector_params = (
    SEKernel(length_scale=jnp.array([1.0, 2.0, 3.0]))
    + ConstantKernel(value=jnp.array([0.5, 0.5, 0.5]))
)
print("SE (vector) + Constant (vector):")
print(k_vector_params)
print()

In [None]:
# Wrapper of kernel with matrix parameters
k_matrix_wrapper = ExpKernel(ConstantKernel(value=jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])))
print("Exp(Constant matrix):")
print(k_matrix_wrapper)
print()

## 10. Edge Cases and Long Expressions

Edge cases that stress the formatter: long operator chains and deeply parenthesized expressions.

In [None]:
# Long chain of additions
k_long_chain = (
    SEKernel(length_scale=1.0) + SEKernel(length_scale=2.0)
    + SEKernel(length_scale=3.0) + ConstantKernel(value=1.0)
    + ConstantKernel(value=2.0)
)
print("Long chain:")
print(k_long_chain)
print()

In [None]:
# Expression with multiple levels of parentheses
k_parentheses = (
    (SEKernel(length_scale=1.0) + ConstantKernel(value=1.0))
    * (LinearKernel(variance_b=1.0, variance_v=1.0, offset_c=0.0)
       - ConstantKernel(value=0.5))
) + ExpKernel(SEKernel(length_scale=0.5))
print("Expression with multiple parentheses:")
print(k_parentheses)
print()

In [None]:
# Maximum complexity expression
k_max_complexity = ExpKernel(
    LogKernel(
        (SEKernel(length_scale=jnp.array([1.0, 2.0])) * DiagKernel(ConstantKernel(value=3.0))) +
        (LinearKernel(variance_b=1.0, variance_v=2.0, offset_c=0.5) - 
         ExpKernel(SEKernel(length_scale=0.5)))
    )
) * ConstantKernel(value=jnp.array([[1.0, 2.0], [3.0, 4.0]]))

print("Maximum complexity expression:")
print(k_max_complexity)
print()

## Conclusion

This notebook demonstrated the kernel formatting capabilities:

1. **Simple kernels**: Clear display of hyperparameters
2. **Multidimensional parameters**: `[mean ± std]_shape` summaries, with the array shape written as a subscript
3. **Operators**: Intuitive notation for `+`, `-`, and `*`
4. **Wrappers**: Functional notation (Exp, Log, Diag, etc.)
5. **Combinations**: Operator precedence is respected, with parentheses inserted where grouping requires them
6. **Nesting**: Support for deeply nested structures

The formatting system makes kernel expressions readable while preserving their mathematical structure.