In [None]:
import sympy as sp
from moment import *
from latex_helper import *

# Running Example (example 1)

In [None]:
import time
import sympy as sp

###############################################################################
# Running Example
###############################################################################

def example_running():
    """Running example (example 1): closure and closed-form solution of E[Y_t^2].

    SDE in 2D, state variables x1 = X, x2 = Y:
        dX_t = -X_t dt + dW_t^(1)
        dY_t = (-2 Y_t + X_t + X_t^2) dt + X_t dW_t^(2)

    Builds the moment closure S for the target multi-index alpha = (0, 2)
    with ``moment_closure_algorithm``. Solves the resulting moment ODE system
    from the zero initial state with ``solve_moment_system``. Prints timings,
    the index set, the target moment and LaTeX renderings. Returns None.
    """
    x1, x2 = sp.symbols('x1 x2')
    vars_ = (x1, x2)

    b_vec = [
        -x1,
        -sp.Integer(2)*x2 + x1 + x1**2
    ]

    sigma_mat = [
        [sp.Integer(1), sp.Integer(0)],
        [sp.Integer(0), x1]
    ]

    # Target moment: m_{0,2}(t) = E[Y_t^2] = E[x2^2]
    alpha = (0, 2)

    # ---- Timing: closure (Algorithm 1) ----
    t0 = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    t1 = time.perf_counter()
    closure_time = t1 - t0

    print("\n=== running example ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for s in pretty_print_S(S, vars_):
        print("  ", s)

    # print("\nA matrix (size {}x{}):".format(*A_mat.shape))
    # sp.pprint(A_mat)
    # print("\nc vector:")
    # sp.pprint(c_vec)

    t = sp.symbols('t', real=True)

    # Initial condition (X_0, Y_0) = (0, 0) => all moments in S are 0 at t=0
    m0_vec = sp.zeros(len(S), 1)

    # ---- Timing: solving ODE system ----
    t2 = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    t3 = time.perf_counter()
    ode_time = t3 - t2

    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    # Look the target up by value rather than assuming it is S[0]: the
    # ordering of S is chosen by moment_closure_algorithm, and the other
    # examples in this notebook all use S.index(alpha) as well.
    idx = S.index(alpha)
    m_0_2_t = sp.simplify(m_t[idx, 0])

    print("\nSolution m_{0,2}(t) = E[Y_t^2]:")
    sp.pprint(m_0_2_t)

    # -------- LaTeX output --------
    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    print("\nLaTeX: ODE system")
    print(latex_moment_ode_system(S, A_mat, c_vec))

    print("\nLaTeX: solution for all moments in S")
    print(latex_moment_solutions(S, m_t))

    print("\nLaTeX: single moment m_{0,2}(t)")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")


example_running()


In [None]:
import time
import sympy as sp

###############################################################################
# ou-env Example
###############################################################################

def example_OUenv():
    """OU-env example: closure and solution for a high-order moment of Y.

    Uses the same 2D SDE as the running example (x1 = X, x2 = Y):
        dX_t = -X_t dt + dW_t^(1)
        dY_t = (-2 Y_t + X_t + X_t^2) dt + X_t dW_t^(2)
    The target is alpha = (0, 10), i.e. E[Y_t^10], and the start is
    (X_0, Y_0) = (0, 0). Prints timings, the index set S, the target moment
    and its LaTeX form. Returns None.
    """
    sym_x, sym_y = sp.symbols('x1 x2')
    state = (sym_x, sym_y)

    drift = [-sym_x, -sp.Integer(2)*sym_y + sym_x + sym_x**2]
    diffusion = [[sp.Integer(1), 0], [0, sym_x]]

    # Other targets that were tried: (0, 2), (0, 3), (0, 4), (0, 5)
    alpha = (0, 10)

    # ---- Algorithm 1: compute the closed index set S (timed) ----
    tic = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(drift, diffusion, alpha, state)
    closure_time = time.perf_counter() - tic

    print("\n=== OU-env example ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for entry in pretty_print_S(S, state):
        print("  ", entry)

    # Zero initial state, so every moment in S vanishes at t = 0
    t = sp.symbols('t', real=True)
    m0_vec = sp.zeros(len(S), 1)

    # ---- Solve the linear moment ODE system (timed) ----
    tic = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    ode_time = time.perf_counter() - tic

    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    target_row = S.index(alpha)
    m_alpha = sp.simplify(m_t[target_row, 0])

    print("\nSolution m_{alpha}(t) = E[X^alpha]:")
    sp.pprint(m_alpha)

    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    # The full ODE system and all-moment LaTeX dumps are disabled here;
    # call latex_moment_ode_system / latex_moment_solutions to print them.
    print("\nLaTeX: E[X^alpha]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")

example_OUenv()

# Case study 1: Consensus System

In [None]:
import time
import sympy as sp

def example_concensus_system():
    """Case study 1: consensus system, target moment E[X1_t X2_t].

    2D SDE with independent Brownian motions:
        dX1_t = (-2 X1_t + X2_t) dt + X1_t dW_t^(1)
        dX2_t = ( X1_t - 2 X2_t) dt + X2_t dW_t^(2)
    The initial state is deterministic, X_0 = (1, 0). Prints the closure S,
    the matrix A and vector c, timings, the target moment m_{1,1}(t) and
    LaTeX output. Returns None.
    """
    # State variables
    x1, x2 = sp.symbols('x1 x2')
    vars_ = (x1, x2)

    # Drift b(x)
    b_vec = [
        - sp.Integer(2)*x1 + x2,
         x1 - sp.Integer(2)*x2
    ]

    # Diffusion sigma(x): diagonal, independent Brownian motions
    sigma_mat = [
        [ x1, sp.Integer(0)],
        [sp.Integer(0), x2]
    ]

    # Target monomial x1 * x2  ->  m_{1,1}(t) = E[X1_t X2_t]
    alpha = (1, 1)

    # ---- Timing: closure (Algorithm 1) ----
    t0 = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    t1 = time.perf_counter()
    closure_time = t1 - t0

    print("\n=== concensus system example ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for s in pretty_print_S(S, vars_):
        print("  ", s)

    print("\nA matrix (size {}x{}):".format(*A_mat.shape))
    sp.pprint(A_mat)
    print("\nc vector:")
    sp.pprint(c_vec)

    t = sp.symbols('t', real=True)

    # Initial state X_0 = (1, 0). It is deterministic, so
    # m_beta(0) = 1^{beta_1} * 0^{beta_2}. SymPy evaluates 0**0 as 1, which
    # gives the constant moment m_{0,0}(0) = 1.
    x10 = sp.Integer(1)
    x20 = sp.Integer(0)
    m0_list = []
    for beta in S:
        val = (x10 ** beta[0]) * (x20 ** beta[1])
        m0_list.append(val)
    m0_vec = sp.Matrix(m0_list)

    # ---- Timing: solving ODE system ----
    t2 = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    t3 = time.perf_counter()
    ode_time = t3 - t2

    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    # m_alpha(t) = m_{1,1}(t), the moment of the monomial x1 * x2
    idx = S.index(alpha)
    m_alpha = sp.simplify(m_t[idx, 0])

    print("\nSolution m_alpha:")
    sp.pprint(m_alpha)

    # -------- LaTeX output --------
    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    print("\nLaTeX: ODE system for the moments")
    print(latex_moment_ode_system(S, A_mat, c_vec))

    print("\nLaTeX: solution for all moments in S")
    print(latex_moment_solutions(S, m_t))

    print("\nLaTeX: E[X_alpha]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")


example_concensus_system()


# Vehicle platoon

In [None]:
import time
import sympy as sp

def example_vehicle_platoon_simple():
    """Nonlinear two-vehicle platoon: closure and solution for moments of p1, p2.

    State ordered as (p1, v1, p2, v2), with a1 = a2 = 1, u1 = 1, k = 1:
        dp1 = v1 dt
        dv1 = (-v1 + 1) dt
        dp2 = v2 dt
        dv2 = (-v2 + 1/2 + (v1 - 1)^2) dt + (1/10) dW^(2)
    Only v2 is driven by noise. The initial state is deterministic,
    (p1, v1, p2, v2) = (1, 0, 0, 0). For every entry of ``targets``, prints
    the closure S, the matrix A and vector c, timings, the target moment and
    LaTeX output. Returns None.
    """
    # State variables ordered as (p1, v1, p2, v2)
    p1, v1, p2, v2 = sp.symbols('p1 v1 p2 v2')
    vars_ = (p1, v1, p2, v2)

    # Drift b(x)
    b_vec = [
        v1,                                      # dp1
        -v1 + 1,                                 # dv1
        v2,                                      # dp2
        -v2 + sp.Rational(1, 2) + (v1 - 1)**2,   # dv2
        # Alternative coupling through the spacing error (p1 - p2):
        # -v2 + (p1 - p2) + (v1 - 1)**2
    ]

    # Diffusion matrix sigma(x) (4x2). The columns correspond to W^(1) and
    # W^(2). Only v2 is noisy, with coefficient 1/10 on W^(2); the W^(1)
    # column is identically zero.
    sigma_mat = [
        [0, 0],                    # dp1
        [0, 0],                    # dv1
        [0, 0],                    # dp2
        [0, sp.Rational(1, 10)],   # dv2
    ]

    # Initial state (p1, v1, p2, v2) = (1, 0, 0, 0)
    p1_0, v1_0, p2_0, v2_0 = 1, 0, 0, 0

    # Target monomials (uncomment entries to compute further moments)
    targets = [
        # ((1, 0, 0, 0), "p_1"),  # monomial p1
        # ((0, 0, 1, 0), "p_2"),  # monomial p2
        # ((2, 0, 0, 0), "p_1^2"),  # monomial p1^2
        ((0, 0, 2, 0), "p_2^2"),  # monomial p2^2
        # ((1, 0, 1, 0), "p_1p_2"),  # monomial p1p2
    ]

    t = sp.symbols('t', real=True)

    for alpha, label in targets:
        # --------- closure S (Algorithm 1) ----------
        t0 = time.perf_counter()
        S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
        t1 = time.perf_counter()

        print(f"\n=== Nonlinear vehicle platoon ({label}) ===")
        print(f"Target monomial: {label}")
        print(f"Size of S: |S| = {len(S)}")
        print(f"Time for obtaining closure S: {t1 - t0:.6f} seconds")

        print("Index set S (multi-indices and corresponding monomials):")
        for s_str in pretty_print_S(S, vars_):
            print("  ", s_str)

        print("\nA matrix (size {}x{}):".format(*A_mat.shape))
        sp.pprint(A_mat)
        print("\nc vector:")
        sp.pprint(c_vec)

        # Initial moment vector from the deterministic initial state.
        # Python evaluates 0 ** 0 == 1, so the constant moment starts at 1.
        m0_list = []
        for beta in S:
            val = (p1_0 ** beta[0]) * (v1_0 ** beta[1]) \
                  * (p2_0 ** beta[2]) * (v2_0 ** beta[3])
            m0_list.append(val)
        m0_vec = sp.Matrix(m0_list)

        # --------- solve ODE system ----------
        t2 = time.perf_counter()
        m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
        t3 = time.perf_counter()
        print(f"\nTime for solving ODE system: {t3 - t2:.6f} seconds")

        # Extract the target moment E[label(t)]
        idx = S.index(alpha)
        moment_expr = sp.simplify(m_t[idx, 0])

        print(f"\nSolution for E[{label}(t)]:")
        sp.pprint(moment_expr)

        # --------- optional LaTeX output ----------
        print("\nLaTeX: index set S")
        print(latex_index_set(S))

        print("\nLaTeX: ODE system")
        print(latex_moment_ode_system(S, A_mat, c_vec))

        print("\nLaTeX: solution for all moments in S")
        print(latex_moment_solutions(S, m_t))

        print(f"\nLaTeX: E[{label}(t)]")
        print("$" + latex_single_moment(S, m_t, alpha) + "$")

# To run:
example_vehicle_platoon_simple()


# Gene example

In [None]:
import time
import sympy as sp

def example_gene_expression_network():
    """Gene expression network: closure and solution for E[X_1 X_5^2].

    Five-dimensional gene regulatory SDE (from the figure in the paper). The
    diffusion matrix is diagonal, i.e. one independent Brownian motion per
    component. Starting from X_0 = 0, it builds the moment closure for
    alpha = (1, 0, 0, 0, 2) and solves the moment ODE system. It then prints
    timings, the index set, the target moment and its LaTeX form.
    Returns None.
    """
    vars_ = sp.symbols('x1 x2 x3 x4 x5')
    x1, x2, x3, x4, x5 = vars_

    # Drift b(x)
    b_vec = [
        -x1 + 1,
        sp.Rational(12, 10)* x1 - sp.Rational(8, 10) * x2,
        x2 - sp.Rational(7, 10) * x3 + sp.Rational(2, 10) * x1**2,
        sp.Rational(9, 10) * x3 - sp.Rational(6, 10) * x4 + sp.Rational(1, 10) * x1 * x2,
        sp.Rational(8, 10)* x4 - sp.Rational(5, 10) * x5 + sp.Rational(15, 100) * x3**2 + sp.Rational(5, 100) * x1**3,
    ]

    # Diagonal diffusion: row i only carries noise on its own W^(i)
    sigma_mat = [
        [sp.Rational(1, 2), sp.Integer(0), sp.Integer(0), sp.Integer(0), sp.Integer(0)],
        [sp.Integer(0), sp.Rational(3, 10) * x1 + sp.Rational(2, 5), 0, 0, 0],
        [sp.Integer(0), sp.Integer(0), sp.Rational(1, 2) * x2 + sp.Rational(1, 10) * x1**2, sp.Integer(0), sp.Integer(0)],
        [sp.Integer(0), sp.Integer(0), sp.Integer(0), sp.Rational(2, 5) * x3 + sp.Rational(1, 5) * x2**2, sp.Integer(0)],
        [sp.Integer(0), sp.Integer(0), sp.Integer(0), sp.Integer(0), sp.Rational(3, 10) * x4 + sp.Rational(1, 10) * x3**2 + sp.Rational(5, 100) * x1**3],
    ]

    # Other targets that were tried: (1, 0, 0, 0, 1) and (0, 0, 0, 0, 2)
    alpha = (1, 0, 0, 0, 2)

    # ---- Algorithm 1 (timed) ----
    started = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    closure_time = time.perf_counter() - started

    print("\n=== Gene expression network example ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for monomial_line in pretty_print_S(S, vars_):
        print("  ", monomial_line)

    # X_0 = 0, so the whole initial moment vector is zero
    t = sp.symbols('t', real=True)
    m0_vec = sp.zeros(len(S), 1)

    # ---- Moment ODE solve (timed) ----
    started = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    ode_time = time.perf_counter() - started

    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    m_alpha = sp.simplify(m_t[S.index(alpha), 0])

    print("\nSolution m_{alpha}(t) = E[X^alpha]:")
    sp.pprint(m_alpha)

    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    # The full ODE system and all-moment LaTeX dumps are disabled here;
    # call latex_moment_ode_system / latex_moment_solutions to print them.
    print("\nLaTeX: E[X^alpha]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")

example_gene_expression_network()

# Non-Pro-solvable System (not used in the paper)

In [None]:
import time
import sympy as sp

def example_2d_non_prosolvable_x1_2_x2_2():
    """Non-pro-solvable 2D SDE: closure and solution for E[X1_t^2 X2_t^2].

        dX1_t = -X1_t^3 dt + sqrt(2) X1_t^2 dW_t^(1)
        dX2_t = -X2_t dt + 1 * dW_t^(2)
    The Brownian motions are independent and the initial condition is
    (X1_0, X2_0) = (1, 1). Prints S, A, c, timings, m_{2,2}(t) and its
    LaTeX form. Returns None.
    """
    x1, x2 = sp.symbols('x1 x2')
    vars_ = (x1, x2)

    # Model parameters lambda = 1, sigma = 1 (sqrt(2) kept exact via SymPy)
    lam, sig = sp.Integer(1), sp.Integer(1)

    b_vec = [-x1**3, -lam * x2]
    sigma_mat = [
        [sp.sqrt(2) * x1**2, 0],
        [0, sig],
    ]

    # Target monomial x1^2 x2^2 -> m_{2,2}(t)
    alpha = (2, 2)

    # ---- Algorithm 1 (timed) ----
    mark = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    closure_time = time.perf_counter() - mark

    print("\n=== 2D non–pro-solvable SDE example (moment x1^2 x2^2) ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for row_text in pretty_print_S(S, vars_):
        print("  ", row_text)

    print("\nA matrix (size {}x{}):".format(*A_mat.shape))
    sp.pprint(A_mat)
    print("\nc vector:")
    sp.pprint(c_vec)

    t = sp.symbols('t', real=True)

    # Deterministic start (1, 1): m_beta(0) = 1^{beta_1} * 1^{beta_2}
    start_state = (sp.Integer(1), sp.Integer(1))
    m0_vec = sp.Matrix([
        (start_state[0] ** beta[0]) * (start_state[1] ** beta[1])
        for beta in S
    ])

    # ---- Moment ODE solve (timed) ----
    mark = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    ode_time = time.perf_counter() - mark
    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    # m_{2,2}(t) = E[X1_t^2 X2_t^2]
    m_2_2_t = sp.simplify(m_t[S.index(alpha), 0])

    print("\nSolution m_{2,2}(t) = E[X_{1,t}^2 X_{2,t}^2]:")
    sp.pprint(m_2_2_t)

    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    print("\nLaTeX: ODE system")
    print(latex_moment_ode_system(S, A_mat, c_vec))

    print("\nLaTeX: solution for all moments in S")
    print(latex_moment_solutions(S, m_t))

    print("\nLaTeX: E[X_{1,t}^2 X_{2,t}^2]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")

# To run:
example_2d_non_prosolvable_x1_2_x2_2()


# Oscillator

In [None]:
import time
import sympy as sp

def example_oscillator():
    """Stochastic oscillator: closure and solution for E[X2_t X3_t^2].

    Drift and diffusion use exact rationals (0.3 = 3/10, 0.8 = 4/5,
    0.2 = 1/5, 0.5 = 1/2):
        dX1_t = X2_t dt
        dX2_t = (-3/10 X2_t - X1_t + 4/5 X3_t^2) dt + 1/5 X2_t dW_t^(1)
        dX3_t = -X3_t dt + 1/2 dW_t^(2)
    Starts from (0, 0, 0). Prints S, A, c, timings, the target moment
    alpha = (0, 1, 2) and its LaTeX form. Returns None.
    """
    x1, x2, x3 = sp.symbols('x1 x2 x3')
    vars_ = (x1, x2, x3)

    b_vec = [
        x2,
        -sp.Rational(3, 10)*x2 - x1 + sp.Rational(4, 5)*x3**2,
        -x3,
    ]

    # 3x2 diffusion: X1 is noise-free, X2 is hit by W^(1), X3 by W^(2)
    sigma_mat = [
        [0, 0],
        [sp.Rational(1, 5)*x2, 0],
        [0, sp.Rational(1, 2)],
    ]

    alpha = (0, 1, 2)

    # ---- Algorithm 1 (timed) ----
    before = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    closure_time = time.perf_counter() - before

    print("\n=== osclliator SDE example  ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for text in pretty_print_S(S, vars_):
        print("  ", text)

    print("\nA matrix (size {}x{}):".format(*A_mat.shape))
    sp.pprint(A_mat)
    print("\nc vector:")
    sp.pprint(c_vec)

    t = sp.symbols('t', real=True)

    # Deterministic start at the origin; SymPy gives 0**0 == 1 for the
    # constant moment and 0 for every other entry.
    origin = (sp.Integer(0), sp.Integer(0), sp.Integer(0))
    m0_vec = sp.Matrix([
        (origin[0] ** beta[0]) * (origin[1] ** beta[1]) * (origin[2] ** beta[2])
        for beta in S
    ])

    # ---- Moment ODE solve (timed) ----
    before = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    ode_time = time.perf_counter() - before
    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    m_alpha = sp.simplify(m_t[S.index(alpha), 0])

    print("\nSolution m_alpha(t) = E[X_alpha]:")
    sp.pprint(m_alpha)

    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    print("\nLaTeX: ODE system")
    print(latex_moment_ode_system(S, A_mat, c_vec))

    print("\nLaTeX: solution for all moments in S")
    print(latex_moment_solutions(S, m_t))

    print("\nLaTeX: E[X_alpha]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")

# To run:
example_oscillator()


# Coupled3d

In [None]:
import time
import sympy as sp

def example_coupled_3d_box():
    """Coupled 3D "box" system: closure and solution for E[X1_t^2 X2_t^2].

        dX1_t = (-1/2 X1_t - X1_t X2_t - 1/2 X1_t X2_t^2) dt + X1_t (1 + X2_t) dW_t^(1)
        dX2_t = (-X2_t + X3_t) dt + 3/10 X3_t dW_t^(2)
        dX3_t = (X2_t - X3_t) dt + 3/10 X2_t dW_t^(3)
    The three Brownian motions are independent and the initial state is
    (0, 0, 0). The target is alpha = (2, 2, 0), i.e. m_{2,2,0}(t). Prints S,
    A, c, timings, the target moment and LaTeX output. Returns None.
    """
    x1, x2, x3 = sp.symbols('x1 x2 x3')
    vars_ = (x1, x2, x3)

    # Exact rational coefficients
    half, three_tenths = sp.Rational(1, 2), sp.Rational(3, 10)

    b_vec = [
        -half * x1 - x1 * x2 - half * x1 * x2**2,
        -x2 + x3,
        x2 - x3,
    ]

    # Diagonal 3x3 diffusion, one Brownian motion per coordinate
    sigma_mat = [
        [x1 * (1 + x2), 0, 0],
        [0, three_tenths * x3, 0],
        [0, 0, three_tenths * x2],
    ]

    # Target monomial x1^2 x2^2
    alpha = (2, 2, 0)

    # ---- Algorithm 1 (timed) ----
    clock = time.perf_counter()
    S, A_mat, c_vec = moment_closure_algorithm(b_vec, sigma_mat, alpha, vars_)
    closure_time = time.perf_counter() - clock

    print("\n=== 3D box system example (moment x1^2 x2^2) ===")
    print(f"Size of S: |S| = {len(S)}")
    print(f"Time for obtaining closure S: {closure_time:.6f} seconds")

    print("Index set S (multi-indices and corresponding monomials):")
    for pretty in pretty_print_S(S, vars_):
        print("  ", pretty)

    print("\nA matrix (size {}x{}):".format(*A_mat.shape))
    sp.pprint(A_mat)
    print("\nc vector:")
    sp.pprint(c_vec)

    t = sp.symbols('t', real=True)

    # Build m(0) from the deterministic start (0, 0, 0)
    z1, z2, z3 = sp.Integer(0), sp.Integer(0), sp.Integer(0)
    initial_moments = []
    for beta in S:
        initial_moments.append((z1 ** beta[0]) * (z2 ** beta[1]) * (z3 ** beta[2]))
    m0_vec = sp.Matrix(initial_moments)

    # ---- Moment ODE solve (timed) ----
    clock = time.perf_counter()
    m_t = solve_moment_system(A_mat, c_vec, m0_vec, t)
    ode_time = time.perf_counter() - clock
    print(f"\nTime for solving ODE system: {ode_time:.6f} seconds")

    # m_{2,2,0}(t) = E[X1_t^2 X2_t^2]
    position = S.index(alpha)
    m_2_2_0_t = sp.simplify(m_t[position, 0])

    print("\nSolution m_{2,2,0}(t) = E[X_{1,t}^2 X_{2,t}^2]:")
    sp.pprint(m_2_2_0_t)

    print("\nLaTeX: index set S")
    print(latex_index_set(S))

    print("\nLaTeX: ODE system")
    print(latex_moment_ode_system(S, A_mat, c_vec))

    print("\nLaTeX: solution for all moments in S")
    print(latex_moment_solutions(S, m_t))

    print("\nLaTeX: E[X_{1,t}^2 X_{2,t}^2]")
    print("$" + latex_single_moment(S, m_t, alpha) + "$")

# To run:
example_coupled_3d_box()
