# Mean Formatting Demonstration

This notebook shows how kernax mean functions are rendered as strings by `str()`, which is also what `print()` and the f-strings below display. Examples progress from simple to complex.

In [1]:
import sys

sys.path.insert(0, "..")

import jax.numpy as jnp

from kernax import AffineMean, ConstantMean, LinearMean, ZeroMean
from kernax.wrappers import ExpModule, LogModule

## 1. Simple Means with Scalar Parameters

Simple means print their class name followed by each hyperparameter, rounded to two decimals.
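
A minimal pure-Python sketch of this display, assuming each hyperparameter is printed as `name=value` rounded to two decimals, as in the outputs below. `ToyMean`, its `display_name` attribute, and the `Toy*` classes are hypothetical and are not kernax's implementation; the `Toy` prefix avoids shadowing the kernax classes imported above.

```python
from dataclasses import dataclass, fields


class ToyMean:
    """Hypothetical base class: str() shows a display name and every field to 2 decimals."""

    display_name = "Mean"

    def __str__(self) -> str:
        args = ", ".join(f"{f.name}={getattr(self, f.name):.2f}" for f in fields(self))
        return f"{self.display_name}({args})"


@dataclass
class ToyZeroMean(ToyMean):
    display_name = "ZeroMean"  # unannotated, so not a dataclass field


@dataclass
class ToyAffineMean(ToyMean):
    display_name = "AffineMean"
    slope: float
    intercept: float


print(ToyZeroMean())                             # ZeroMean()
print(ToyAffineMean(slope=2.0, intercept=-1.0))  # AffineMean(slope=2.00, intercept=-1.00)
```

The dataclass is used only so the hyperparameters can be enumerated with `dataclasses.fields`; kernax may collect them differently.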

In [2]:
# Zero mean (no parameters)
m_zero = ZeroMean()
print(f"ZeroMean: {m_zero}")
print()

ZeroMean: ZeroMean()



In [3]:
# Constant mean with scalar
m_const = ConstantMean(constant=2.5)
print(f"ConstantMean: {m_const}")
print()

ConstantMean: ConstantMean(constant=2.50)



In [4]:
# Linear mean with scalar slope
m_linear = LinearMean(slope=1.5)
print(f"LinearMean: {m_linear}")
print()

LinearMean: LinearMean(slope=1.50)



In [5]:
# Affine mean with slope and intercept
m_affine = AffineMean(slope=2.0, intercept=-1.0)
print(f"AffineMean: {m_affine}")
print()

AffineMean: AffineMean(slope=2.00, intercept=-1.00)



## 2. Means with Multidimensional Parameters

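When a parameter is an array, it is summarized as `[mean ± std]_shape`: the mean and standard deviation over all of its entries, followed by the array shape in subscript digits (e.g. `₃` for shape `(3,)`, `₂ₓ₂` for shape `(2, 2)`).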
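
The summary can be reproduced with the standard library. This sketch assumes the statistics are taken over all entries and that the standard deviation is the population one (ddof=0); that assumption is consistent with every output in this notebook, e.g. `[1.0, 2.0, 0.5]` gives `[1.17 ± 0.62]₃`. `summarize_param` is a hypothetical helper, not part of kernax, and takes nested Python lists instead of JAX arrays.

```python
import statistics

# Unicode subscript digits for the shape, e.g. (2, 2) -> "₂ₓ₂"
SUBSCRIPT_DIGITS = str.maketrans("0123456789", "₀₁₂₃₄₅₆₇₈₉")


def shape_and_values(x):
    """Return (shape, flat list of floats) for a number or a nested list of numbers."""
    if not isinstance(x, (list, tuple)):
        return (), [float(x)]
    parts = [shape_and_values(item) for item in x]
    values = [v for _, vals in parts for v in vals]
    return (len(x),) + parts[0][0], values


def summarize_param(x):
    """Hypothetical helper mimicking the '[mean ± std]_shape' display."""
    shape, values = shape_and_values(x)
    if not shape:  # scalars are printed as plain two-decimal numbers
        return f"{values[0]:.2f}"
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population std (ddof=0)
    dims = "ₓ".join(str(d).translate(SUBSCRIPT_DIGITS) for d in shape)
    return f"[{mean:.2f} ± {std:.2f}]{dims}"


print(summarize_param([1.0, 2.0, 0.5]))           # [1.17 ± 0.62]₃
print(summarize_param([[1.0, 2.0], [3.0, 4.0]]))  # [2.50 ± 1.12]₂ₓ₂
```

Only the mean and spread survive in the string, so arrays with different entries but the same statistics print identically.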

In [6]:
# Linear mean with vector slope (for multi-dimensional inputs)
m_linear_vec = LinearMean(slope=jnp.array([1.0, 2.0, 0.5]))
print(f"LinearMean (vector slope): {m_linear_vec}")
print()

LinearMean (vector slope): LinearMean(slope=[1.17 ± 0.62]₃)



In [7]:
# Affine mean with vector slope and scalar intercept
m_affine_vec = AffineMean(slope=jnp.array([1.0, 0.5]), intercept=3.0)
print(f"AffineMean (vector slope): {m_affine_vec}")
print()

AffineMean (vector slope): AffineMean(slope=[0.75 ± 0.25]₂, intercept=3.00)



In [8]:
# Constant mean with 1D array
m_const_array = ConstantMean(constant=jnp.array([1.0, 2.0, 3.0]))
print(f"ConstantMean (1D array): {m_const_array}")
print()

ConstantMean (1D array): ConstantMean(constant=[2.00 ± 0.82]₃)



## 3. Mean Operators - Addition

The `+` operator creates a `SumModule`, which is printed as an infix sum of its terms.
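
Operator overloading is one way to get this behavior. In the sketch below (hypothetical `Toy*` classes, not kernax's `SumModule`), `__add__` builds a sum node, `__radd__` handles `scalar + mean` by wrapping the scalar in a constant mean as described in the cells below, and `__str__` joins the terms with ` + `.

```python
class ToyMean:
    """Hypothetical base class: '+' builds a sum node instead of evaluating anything."""

    def __add__(self, other):
        return ToySum(self, as_mean(other))

    def __radd__(self, other):
        # Called for `scalar + mean`; the scalar becomes the left-hand term.
        return ToySum(as_mean(other), self)


class ToyConstant(ToyMean):
    def __init__(self, constant):
        self.constant = float(constant)

    def __str__(self):
        return f"ConstantMean(constant={self.constant:.2f})"


class ToyLinear(ToyMean):
    def __init__(self, slope):
        self.slope = float(slope)

    def __str__(self):
        return f"LinearMean(slope={self.slope:.2f})"


class ToySum(ToyMean):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        # Addition is associative, so a flat infix join is unambiguous.
        return f"{self.left} + {self.right}"


def as_mean(x):
    """Wrap plain numbers in a constant mean; pass means through unchanged."""
    return x if isinstance(x, ToyMean) else ToyConstant(x)


print(3.0 + ToyLinear(1.0))  # ConstantMean(constant=3.00) + LinearMean(slope=1.00)
```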

In [9]:
# Addition of two simple means
m_sum = LinearMean(slope=1.0) + ConstantMean(constant=0.5)
print(f"Linear + Constant: {m_sum}")
print()

Linear + Constant: LinearMean(slope=1.00) + ConstantMean(constant=0.50)



In [10]:
# radd: scalar + mean auto-converts scalar to ConstantMean
m_radd = 3.0 + LinearMean(slope=1.0)
print(f"3.0 + Linear: {m_radd}")
print()

3.0 + Linear: ConstantMean(constant=3.00) + LinearMean(slope=1.00)



In [11]:
# Addition of multiple means
m_sum_multi = ZeroMean() + LinearMean(slope=2.0) + ConstantMean(constant=1.0)
print(f"Zero + Linear + Constant: {m_sum_multi}")
print()

Zero + Linear + Constant: ZeroMean() + LinearMean(slope=2.00) + ConstantMean(constant=1.00)



## 4. Mean Operators - Subtraction

Subtraction is implemented internally with a `NegModule`, but it is printed as ordinary subtraction (`a - b`).
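
A sketch of how a stored negation can still print as subtraction, assuming `a - b` is represented as `a + Neg(b)`. `Expr`, `Leaf`, `Neg`, and `Sum` are hypothetical classes, and kernax's `NegModule` may be organized differently; `Leaf` simply holds the text a concrete mean would print.

```python
class Expr:
    """Hypothetical expression node; subclasses implement __str__."""

    def __add__(self, other):
        return Sum(self, other)

    def __sub__(self, other):
        # Assumption: a - b is stored as a + Neg(b), in the spirit of NegModule.
        return Sum(self, Neg(other))

    def __neg__(self):
        return Neg(self)


class Leaf(Expr):
    """Stands in for a concrete mean; holds the text that mean would print."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Neg(Expr):
    def __init__(self, inner):
        self.inner = inner

    def __str__(self):
        return f"- {self.inner}"


class Sum(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        # Print `a + Neg(b)` as `a - b` rather than `a + - b`.
        if isinstance(self.right, Neg):
            return f"{self.left} - {self.right.inner}"
        return f"{self.left} + {self.right}"


diff = Leaf("AffineMean(slope=2.00, intercept=1.00)") - Leaf("ConstantMean(constant=0.50)")
print(diff)                       # AffineMean(slope=2.00, intercept=1.00) - ConstantMean(constant=0.50)
print(type(diff.right).__name__)  # Neg
```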

In [12]:
# Mean subtraction
m_diff = AffineMean(slope=2.0, intercept=1.0) - ConstantMean(constant=0.5)
print(f"Affine - Constant: {m_diff}")
print()

Affine - Constant: AffineMean(slope=2.00, intercept=1.00) - ConstantMean(constant=0.50)



In [13]:
# Mean negation
m_neg = -LinearMean(slope=1.0)
print(f"Negation: {m_neg}")
print()

Negation: - LinearMean(slope=1.00)



## 5. Mean Operators - Multiplication

The `*` operator creates a `ProductModule`.
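
To illustrate the parenthesization in the last cell of this section, here is a sketch in which a product wraps any negated factor in parentheses, giving `a * (- b)` instead of `a * - b`. The classes are hypothetical stand-ins, not kernax's `ProductModule`.

```python
class Expr:
    """Hypothetical expression node supporting '*' and unary '-'."""

    def __mul__(self, other):
        return Product(self, other)

    def __neg__(self):
        return Neg(self)


class Leaf(Expr):
    """Stands in for a concrete mean; holds the text that mean would print."""

    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class Neg(Expr):
    def __init__(self, inner):
        self.inner = inner

    def __str__(self):
        return f"- {self.inner}"


class Product(Expr):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return f"{as_factor(self.left)} * {as_factor(self.right)}"


def as_factor(term):
    """Wrap a negated factor in parentheses; leave everything else as-is."""
    return f"({term})" if isinstance(term, Neg) else str(term)


print(Leaf("LinearMean(slope=1.00)") * (-Leaf("ConstantMean(constant=2.00)")))
# LinearMean(slope=1.00) * (- ConstantMean(constant=2.00))
```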

In [14]:
# Product of two means
m_prod = ConstantMean(constant=2.0) * LinearMean(slope=1.5)
print(f"Constant × Linear: {m_prod}")
print()

Constant × Linear: ConstantMean(constant=2.00) * LinearMean(slope=1.50)



In [15]:
# rmul: scalar * mean auto-converts scalar to ConstantMean
m_rmul = 2.0 * LinearMean(slope=1.0)
print(f"2.0 × Linear: {m_rmul}")
print()

2.0 × Linear: ConstantMean(constant=2.00) * LinearMean(slope=1.00)



In [16]:
# Product with negation (adds parentheses)
m_prod_neg = LinearMean(slope=1.0) * (-ConstantMean(constant=2.0))
print(f"Linear × (-Constant): {m_prod_neg}")
print()

Linear × (-Constant): LinearMean(slope=1.00) * (- ConstantMean(constant=2.00))



## 6. Wrapper Means

Wrappers apply a transformation to the output of a mean and are printed in functional notation, e.g. `Exp(...)`.
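
A minimal sketch of a wrapper, assuming it evaluates the inner mean and then applies a pointwise function such as `math.exp`. `ToyWrapper` and the one-dimensional `ToyLinear` are hypothetical; the string follows the same `Name(inner)` pattern as the outputs below.

```python
import math


class ToyLinear:
    """Hypothetical 1-D linear mean m(x) = slope * x."""

    def __init__(self, slope):
        self.slope = slope

    def __call__(self, x):
        return self.slope * x

    def __str__(self):
        return f"LinearMean(slope={self.slope:.2f})"


class ToyWrapper:
    """Hypothetical wrapper: applies `fn` to the inner mean's output."""

    def __init__(self, name, fn, inner):
        self.name, self.fn, self.inner = name, fn, inner

    def __call__(self, x):
        return self.fn(self.inner(x))

    def __str__(self):
        # Functional notation: Name(inner)
        return f"{self.name}({self.inner})"


m_exp = ToyWrapper("Exp", math.exp, ToyLinear(1.0))
print(m_exp)       # Exp(LinearMean(slope=1.00))
print(m_exp(0.0))  # 1.0
```

Because the wrapper only delegates to `str(self.inner)`, nesting wrappers yields nested notation such as `Exp(Log(...))` without any extra logic.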

In [17]:
# Exp wrapper
m_exp = ExpModule(LinearMean(slope=1.0))
print(f"Exp wrapper: {m_exp}")
print()

Exp wrapper: Exp(LinearMean(slope=1.00))



In [18]:
# Log wrapper
m_log = LogModule(ConstantMean(constant=2.0))
print(f"Log wrapper: {m_log}")
print()

Log wrapper: Log(ConstantMean(constant=2.00))



In [19]:
# Exp on an affine mean
m_exp_affine = ExpModule(AffineMean(slope=1.0, intercept=0.5))
print(f"Exp(Affine): {m_exp_affine}")
print()

Exp(Affine): Exp(AffineMean(slope=1.00, intercept=0.50))



## 7. Simple Combinations

Operators and wrappers can be combined; each part keeps its own notation.

In [20]:
# Wrapper on a sum
m_exp_sum = ExpModule(LinearMean(slope=1.0) + ConstantMean(constant=0.5))
print(f"Exp(Linear + Constant): {m_exp_sum}")
print()

Exp(Linear + Constant): Exp(LinearMean(slope=1.00) + ConstantMean(constant=0.50))



In [21]:
# Sum of wrappers
m_sum_wrapped = ExpModule(LinearMean(slope=1.0)) + LogModule(ConstantMean(constant=2.0))
print(f"Exp(Linear) + Log(Constant): {m_sum_wrapped}")
print()

Exp(Linear) + Log(Constant): Exp(LinearMean(slope=1.00)) + Log(ConstantMean(constant=2.00))



In [22]:
# LinearMean + ConstantMean is mathematically equivalent to AffineMean, but str() shows the structure as built
m_linear_const = LinearMean(slope=2.0) + ConstantMean(constant=3.0)
m_affine_ref = AffineMean(slope=2.0, intercept=3.0)
print(f"LinearMean + ConstantMean: {m_linear_const}")
print(f"AffineMean (reference):    {m_affine_ref}")
print()

LinearMean + ConstantMean: LinearMean(slope=2.00) + ConstantMean(constant=3.00)
AffineMean (reference):    AffineMean(slope=2.00, intercept=3.00)
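
The equivalence can be written out under the usual definitions, which are an assumption here since the notebook shows only the parameters: a linear mean $m(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}$, a constant mean $m(\mathbf{x}) = c$, and an affine mean $m(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + b$.

$$
m_{\text{lin}}(\mathbf{x}) + m_{\text{const}}(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} + c = m_{\text{aff}}(\mathbf{x};\ \mathbf{w},\, b = c)
$$

With $\mathbf{w} = 2$ and $c = 3$, both objects compute $2x + 3$; the printed strings differ only because `str()` reports how each object was constructed.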



## 8. Complex Nested Structures

Demonstration of deeply nested mean structures.

In [23]:
# Double wrapping
m_double_wrap = ExpModule(LogModule(LinearMean(slope=1.0)))
print(f"Double wrapper Exp(Log(Linear)): {m_double_wrap}")
print()

Double wrapper Exp(Log(Linear)): Exp(Log(LinearMean(slope=1.00)))



In [24]:
# Complex combination: 2 * Linear + Constant
m_complex1 = 2.0 * LinearMean(slope=1.0) + ConstantMean(constant=1.0)
print(f"2 × Linear + Constant: {m_complex1}")
print()

2 × Linear + Constant: ConstantMean(constant=2.00) * LinearMean(slope=1.00) + ConstantMean(constant=1.00)



In [25]:
# Complex: Exp(Linear) * Constant - ZeroMean
m_complex2 = ExpModule(LinearMean(slope=1.0)) * ConstantMean(constant=2.0) - ZeroMean()
print(f"Exp(Linear) × Constant - Zero: {m_complex2}")
print()

Exp(Linear) × Constant - Zero: Exp(LinearMean(slope=1.00)) * ConstantMean(constant=2.00) - ZeroMean()



In [26]:
# Very nested structure: Exp of a product of two sums (the printed string omits the parentheses around the sums)
m_very_complex = ExpModule(
    (LinearMean(slope=2.0) + ConstantMean(constant=1.0))
    * (AffineMean(slope=1.0, intercept=-0.5) - LogModule(ConstantMean(constant=3.0)))
)
print("Very complex structure:")
print(f"{m_very_complex}")
print()

Very complex structure:
Exp(LinearMean(slope=2.00) + ConstantMean(constant=1.00) * AffineMean(slope=1.00, intercept=-0.50) - Log(ConstantMean(constant=3.00)))



## 9. Expressions with Multidimensional Parameters

In [27]:
# Sum of means with vector parameters
m_vec_params = LinearMean(slope=jnp.array([1.0, 2.0, 3.0])) + ConstantMean(
    constant=jnp.array([0.5, 0.5, 0.5])
)
print("Linear (vector) + Constant (vector):")
print(f"{m_vec_params}")
print()

Linear (vector) + Constant (vector):
LinearMean(slope=[2.00 ± 0.82]₃) + ConstantMean(constant=[0.50 ± 0.00]₃)



In [28]:
# Exp wrapper of a mean with vector slope
m_exp_vec = ExpModule(AffineMean(slope=jnp.array([1.0, 0.5]), intercept=0.0))
print("Exp(Affine with vector slope):")
print(f"{m_exp_vec}")
print()

Exp(Affine with vector slope):
Exp(AffineMean(slope=[0.75 ± 0.25]₂, intercept=0.00))



## 10. Edge Cases and Long Expressions

In [29]:
# Long chain of additions
m_long_chain = (
    LinearMean(slope=1.0)
    + LinearMean(slope=2.0)
    + LinearMean(slope=3.0)
    + ConstantMean(constant=1.0)
    + ConstantMean(constant=2.0)
)
print("Long chain:")
print(f"{m_long_chain}")
print()

Long chain:
LinearMean(slope=1.00) + LinearMean(slope=2.00) + LinearMean(slope=3.00) + ConstantMean(constant=1.00) + ConstantMean(constant=2.00)



In [30]:
# Expression with multiple levels of parentheses (the printed string omits the parentheses around the two sums)
m_parentheses = (
    (LinearMean(slope=1.0) + ConstantMean(constant=1.0))
    * (AffineMean(slope=2.0, intercept=-1.0) - ConstantMean(constant=0.5))
) + ExpModule(LinearMean(slope=0.5))
print("Expression with multiple parentheses:")
print(f"{m_parentheses}")
print()

Expression with multiple parentheses:
LinearMean(slope=1.00) + ConstantMean(constant=1.00) * AffineMean(slope=2.00, intercept=-1.00) - ConstantMean(constant=0.50) + Exp(LinearMean(slope=0.50))



In [31]:
# Maximum complexity expression
m_max_complexity = ExpModule(
    LogModule(
        (LinearMean(slope=jnp.array([1.0, 2.0])) * ConstantMean(constant=3.0))
        + (
            AffineMean(slope=1.0, intercept=0.5)
            - ExpModule(LinearMean(slope=0.5))
        )
    )
) * ConstantMean(constant=jnp.array([[1.0, 2.0], [3.0, 4.0]]))

print("Maximum complexity expression:")
print(f"{m_max_complexity}")
print()

Maximum complexity expression:
Exp(Log(LinearMean(slope=[1.50 ± 0.50]₂) * ConstantMean(constant=3.00) + AffineMean(slope=1.00, intercept=0.50) - Exp(LinearMean(slope=0.50)))) * ConstantMean(constant=[2.50 ± 1.12]₂ₓ₂)



## Conclusion

This notebook demonstrated the mean function formatting capabilities:

1. **Simple means**: Clear display of hyperparameters (`ConstantMean`, `LinearMean`, `AffineMean`)
2. **Multidimensional parameters**: `[mean ± std]_shape` format, with the shape written in subscript digits
3. **Operators**: Intuitive notation for `+`, `-`, and `*`
4. **Wrappers**: Functional notation (`Exp(...)`, `Log(...)`)
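5. **Combinations**: Simple mixes such as `2.0 * Linear + Constant` print correctly and negated factors are parenthesized (`a * (- b)`), but sums nested inside a product are printed without parentheses (sections 8 and 10), so those strings do not show the original grouping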
6. **Nesting**: Support for arbitrarily complex structures

The formatting system makes mean expressions readable while preserving their mathematical structure.
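
In the outputs of sections 8 and 10, products of sums such as `(Linear + Constant) * (Affine - ...)` are printed without parentheses, so the strings read as if `*` applied only to the neighboring terms. A common remedy is to give each node a precedence and parenthesize any operand that binds more loosely than its parent. The sketch below shows that idea with hypothetical `Expr`/`Leaf`/`BinOp` classes; it is not kernax's implementation and covers only `+` and `*` (subtraction would additionally need its right operand of equal precedence wrapped).

```python
class Expr:
    """Hypothetical expression node; higher `precedence` binds tighter."""

    precedence = 3  # leaves and function-style wrappers never need parentheses

    def __add__(self, other):
        return BinOp("+", 1, self, other)

    def __mul__(self, other):
        return BinOp("*", 2, self, other)


class Leaf(Expr):
    def __init__(self, text):
        self.text = text

    def __str__(self):
        return self.text


class BinOp(Expr):
    def __init__(self, symbol, precedence, left, right):
        self.symbol, self.precedence = symbol, precedence
        self.left, self.right = left, right

    def __str__(self):
        return f"{self.wrap(self.left)} {self.symbol} {self.wrap(self.right)}"

    def wrap(self, child):
        # Parenthesize an operand that binds more loosely than this operator.
        text = str(child)
        return f"({text})" if child.precedence < self.precedence else text


lin, const, aff = Leaf("Linear"), Leaf("Constant"), Leaf("Affine")
print((lin + const) * aff)  # (Linear + Constant) * Affine
print(lin + const * aff)    # Linear + Constant * Affine
```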