In [1]:
import sympy

In [2]:
def constant_folding(expr, state_vars, start_index=0):
    """
    Simplifies an expression by folding constants.

    Returns: (folded_expression, next_constant_index)
    """
    # Expand expression (multiply out)
    expanded_exp = sympy.expand(expr, state_vars)

    # Define a list to store the folded-constant terms and a dictionary for constant mapping
    folded_terms = []
    constant_map = {}
    new_constant_index = start_index

    # Separate out terms at +
    terms_list = sympy.Add.make_args(expanded_exp)

    # --- Perform multiplicative constant folding ---
    for term in terms_list:
        # Isolate constants in term and symbols in term
        constant_part, state_part = term.as_independent(*state_vars)

        # If the constant part is a number, keep it as is.
        if not constant_part.free_symbols:
            updated_constant_part = constant_part

        # Fold the constants if they are symbolic
        elif constant_part in constant_map:
            # Use the existing mapped constant
            updated_constant_part = constant_map[constant_part]
        else:
            # Create a new symbol for the folded constant
            updated_constant_part = sympy.Symbol(f"K{new_constant_index}")

            # Store the constant mapping
            constant_map[constant_part] = updated_constant_part
            new_constant_index += 1

        # Put the folded constant and the state part back together
        folded_term = updated_constant_part * state_part

        # Add the new folded term to the list
        folded_terms.append(folded_term)

    # Put all the folded terms back together
    folded_exp = sum(folded_terms)

    # --- Perform additive constant folding ---
    coeff_accumulator = {}  # Key: state_part, Value: sum of K coefficients

    # Separate out terms at +
    terms_list = sympy.Add.make_args(folded_exp)

    for term in terms_list:
        # Split coefficient from state
        coeff, state_part = term.as_independent(*state_vars)

        # Add to dictionary
        if state_part in coeff_accumulator:
            coeff_accumulator[state_part] += coeff
        else:
            coeff_accumulator[state_part] = coeff

    # Storage for folded terms
    final_folded_terms = []

    for state_part, coefficient in coeff_accumulator.items():

        # If the combined coefficient is purely a number, don't fold it.
        if not coefficient.free_symbols:
            updated_coefficient = coefficient

        # Check if coefficient is already a pure K symbol
        elif coefficient in constant_map.values():
            updated_coefficient = coefficient

        # Create a new K-symbol if new coeff combination
        else:
            while sympy.Symbol(f"K{new_constant_index}") in constant_map.values():
                new_constant_index += 1

            new_K_symbol = sympy.Symbol(f"K{new_constant_index}")

            # Map the sum to the new symbol for tracking
            constant_map[coefficient] = new_K_symbol

            updated_coefficient = new_K_symbol

            # Increment const index
            new_constant_index += 1

        # Recombine the fully folded coefficient with the state part
        final_folded_term = updated_coefficient * state_part
        final_folded_terms.append(final_folded_term)

    # Sort the terms alphabetically by their state variable string (i, r, s...)
    final_folded_terms.sort(key=lambda x: str(x.as_independent(*state_vars)[1]))

    # Create the substitution map based on the sorted order
    rename_map = {}
    final_b_index = start_index

    for term in final_folded_terms:
        # Extract the coefficient (which is a K symbol or a number)
        coeff, state_part = term.as_independent(*state_vars)

        # Rename symbolic constants
        if coeff.free_symbols:
            # If we haven't renamed this K yet, give it the next B number
            if coeff not in rename_map:
                new_b = sympy.Symbol(f"B{final_b_index}")
                rename_map[coeff] = new_b
                final_b_index += 1

    # Construct the final expression
    temp_sum = sum(final_folded_terms)
    final_folded_exp_reindexed = temp_sum.subs(rename_map)

    # Calculate final outputs
    final_consts = list(rename_map.values())

    # Sort B-consts to look nice in the list [B0, B1, B2]
    final_consts.sort(key=lambda s: int(s.name[1:]))

    return final_folded_exp_reindexed, final_b_index, final_consts


def collect_partial_state_vars(expr, base_state_vars=None):
    """
    Collect state variables for partial expressions:
      - base state variables: {s, i, r} (if not provided)
      - any M{int} symbols present in the expression
    """
    if base_state_vars is None:
        s_sym, i_sym, r_sym = sympy.symbols("s i r")
        base_state_vars = {s_sym, i_sym, r_sym}
    else:
        base_state_vars = set(base_state_vars)

    # Add all M{int} symbols referenced in expr
    m_syms = {
        sym
        for sym in expr.free_symbols
        if sym.name.startswith("M") and sym.name[1:].isdigit()
    }
    return base_state_vars | m_syms


def constant_folding_partial(expr, state_vars_base=None, start_index=0):
    """
    Constant folding for partial expressions that may contain M0, M1, ... tokens.

    Args:
        expr (sympy.Expr): The partial expression to fold.
        state_vars_base (set[sympy.Symbol] | None):
            Base state variables (defaults to s, i, r if None).
        start_index (int): Starting index for introduced K/B symbols.

    Returns:
        tuple: (folded_expression, next_constant_index, folded_B_constants)
    """
    state_vars = collect_partial_state_vars(expr, state_vars_base)
    return constant_folding(expr=expr, state_vars=state_vars, start_index=start_index)

In [3]:
# Test out partial constant folding

# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

examples = [
    'M3*(M0*M1 + M2 + i) + M4',
    'M0*(C0 + 2*i) + M1',
    'C0*C1*r**2',
    'M2*(C0*C1 - M0 + M1)'
]

for example in examples:
    expr = sympy.sympify(example)
    print(f"Original expression:                   {expr}")
    folded_expr = constant_folding_partial(expr, state_vars)
    print(f"Final folded expression:               {folded_expr[0]}")
    print()

Original expression:                   M3*(M0*M1 + M2 + i) + M4
Final folded expression:               M0*M1*M3 + M2*M3 + M3*i + M4

Original expression:                   M0*(C0 + 2*i) + M1
Final folded expression:               B0*M0 + 2*M0*i + M1

Original expression:                   C0*C1*r**2
Final folded expression:               B0*r**2

Original expression:                   M2*(C0*C1 - M0 + M1)
Final folded expression:               B0*M2 - M0*M2 + M1*M2



In [2]:
def constant_folding(expr, state_vars):
    # Expand expression (multiply out)
    expanded_exp = sympy.expand(expr, state_vars)

    # Define a list to store the folded-constant terms and a dictionary for constant mapping
    folded_terms = []
    constant_map = {}
    new_constant_index = 0

    # Separate out terms at +
    terms_list = sympy.Add.make_args(expanded_exp)

    # --- Perform multiplicative constant folding ---
    for term in terms_list:
        # Isolate constants in term and symbols in term
        constant_part, state_part = term.as_independent(*state_vars)

        # If the constant part is a number, keep it as is.
        if not constant_part.free_symbols:
            updated_constant_part = constant_part

        # Fold the constants if they are symbolic
        elif constant_part in constant_map:
            # Use the existing mapped constant
            updated_constant_part = constant_map[constant_part]
        else:
            # Create a new symbol for the folded constant
            updated_constant_part = sympy.Symbol(f"K{new_constant_index}")

            # Store the constant mapping
            constant_map[constant_part] = updated_constant_part
            new_constant_index += 1

        # Put the folded constant and the state part back together
        folded_term = updated_constant_part * state_part

        # Add the new folded term to the list
        folded_terms.append(folded_term)

    # Put all the folded terms back together
    folded_exp = sum(folded_terms)

    # --- Perform additive constant folding ---
    coeff_accumulator = {}  # Key: state_part, Value: sum of K coefficients

    # Separate out terms at +
    terms_list = sympy.Add.make_args(folded_exp)

    for term in terms_list:
        # Split coefficient from state
        coeff, state_part = term.as_independent(*state_vars)

        # Add to dictionary
        if state_part in coeff_accumulator:
            coeff_accumulator[state_part] += coeff
        else:
            coeff_accumulator[state_part] = coeff

    # Storage for folded terms
    final_folded_terms = []

    for state_part, coefficient in coeff_accumulator.items():

        # If the combined coefficient is purely a number, don't fold it.
        if not coefficient.free_symbols:
            updated_coefficient = coefficient

        # Check if coefficient is already a pure K symbol
        elif coefficient in constant_map.values():
            updated_coefficient = coefficient

        # Create a new K-symbol if new coeff combination
        else:
            while sympy.Symbol(f"K{new_constant_index}") in constant_map.values():
                new_constant_index += 1

            new_K_symbol = sympy.Symbol(f"K{new_constant_index}")

            # Map the sum to the new symbol for tracking
            constant_map[coefficient] = new_K_symbol

            updated_coefficient = new_K_symbol

        # Recombine the fully folded coefficient with the state part
        final_folded_term = updated_coefficient * state_part
        final_folded_terms.append(final_folded_term)

    # Put the final folded terms together
    final_folded_exp = sum(final_folded_terms)

    print(f"Original expression:                   {expr}")
    print(f"Expression after multiplicative step : {folded_exp}")
    print(f"Final folded expression:               {final_folded_exp}")

In [3]:
# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

examples = [
    "C0 + C1",
    "C0*s + C1*s",
    "C0*s + s",
    "C0*s*r + C1*(C2*r*s + C3)",
    "C0*r*s + C1*(C2*r*s + C3)+C4*s",
    "C0*s*i*C1(C2+C3)",
    "C1*r + i - r",
    "C1*r - 2*i - r",
    "C1*(C2 + i) - C3",
    "(C2 + C3)*(C4 + r)",
    "s*(C0 - C1 + i**2)",
    "C0*s**2 + C1*s**2 + 5*s**2",
    "C0*i**3 - C1*i**3 + i**3",
    "C0*s**2 + C1*s + C2*s**2 + s",
    "C0*s*i + C1*i*s",
    "C0*s*i*r - 2*s*i*r + C1*s*r*i",
    "s*i*(C0 + C1) + s*i*C2",
    "C0 * (s + C1 * (i - C2 * r))",
    "(C0*s + C1*i) * (C2*s + C3*i)",
    "C0 * (s + 1) * (i + 1)",
    "s + i + r",
    "2*s - 2*s",
    "C0*s + s",
    "2*i + C1*i - 5*i",
]

for example in examples:
    expr = sympy.sympify(example)
    constant_folding(expr, state_vars)
    print()

Original expression:                   C0 + C1
Expression after multiplicative step : K0 + K1
Final folded expression:               K2

Original expression:                   C0*s + C1*s
Expression after multiplicative step : K0*s + K1*s
Final folded expression:               K2*s

Original expression:                   C0*s + s
Expression after multiplicative step : K0*s + s
Final folded expression:               K1*s

Original expression:                   C0*r*s + C1*(C2*r*s + C3)
Expression after multiplicative step : K0 + K1*r*s + K2*r*s
Final folded expression:               K0 + K3*r*s

Original expression:                   C0*r*s + C1*(C2*r*s + C3) + C4*s
Expression after multiplicative step : K0 + K1*s + K2*r*s + K3*r*s
Final folded expression:               K0 + K1*s + K4*r*s

Original expression:                   C0*i*s*C1(C2 + C3)
Expression after multiplicative step : K0*i*s
Final folded expression:               K0*i*s

Original expression:                   C1*r + i -

In [4]:
# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

expr = sympy.simplify("C0*r*s + C1*(C2*r*s + C3)+C4*s - i")

# Expand expression (multiply out)
expanded_exp = sympy.expand(expr, state_vars)
print("Expanded exp", expanded_exp)

# Define a list to store the folded-constant terms and a dictionary for constant mapping
folded_terms = []
constant_map = {}
new_constant_index = 0

# Separate out terms at +
terms_list = sympy.Add.make_args(expanded_exp)

# --- Perform multiplicative constant folding ---
for term in terms_list:
    print("Term", term)
    # Isolate constants in term and symbols in term
    constant_part, state_part = term.as_independent(*state_vars)
    print("Const", constant_part, "State", state_part)

    # If the constant part is a number, keep it as is.
    if not constant_part.free_symbols:
        updated_constant_part = constant_part

    # Fold the constants if they are symbolic
    elif constant_part in constant_map:
        # Use the existing mapped constant
        updated_constant_part = constant_map[constant_part]
    else:
        # Create a new symbol for the folded constant
        updated_constant_part = sympy.Symbol(f"K{new_constant_index}")

        # Store the constant mapping
        constant_map[constant_part] = updated_constant_part
        new_constant_index += 1

    print("Updated constant part", updated_constant_part)

    # Put the folded constant and the state part back together
    folded_term = updated_constant_part * state_part

    # Add the new folded term to the list
    folded_terms.append(folded_term)

# Put all the folded terms back together
folded_exp = sum(folded_terms)

# --- Perform additive constant folding ---
coeff_accumulator = {}  # Key: state_part, Value: sum of K coefficients

# Separate out terms at +
terms_list = sympy.Add.make_args(folded_exp)
print("terms_list", terms_list)

for term in terms_list:
    # Split coefficient from state
    coeff, state_part = term.as_independent(*state_vars)

    # Add to dictionary
    if state_part in coeff_accumulator:
        coeff_accumulator[state_part] += coeff
    else:
        coeff_accumulator[state_part] = coeff

# Storage for folded terms
final_folded_terms = []

print("coeff_accumulator", coeff_accumulator)

for state_part, coefficient in coeff_accumulator.items():

    # If the combined coefficient is purely a number, don't fold it.
    if not coefficient.free_symbols:
        updated_coefficient = coefficient

    # Check if coefficient is already a pure K symbol
    elif coefficient in constant_map.values():
        updated_coefficient = coefficient

    # Create a new K-symbol if new coeff combination
    else:
        while sympy.Symbol(f"K{new_constant_index}") in constant_map.values():
            new_constant_index += 1

        new_K_symbol = sympy.Symbol(f"K{new_constant_index}")

        # Map the sum to the new symbol for tracking
        constant_map[coefficient] = new_K_symbol

        updated_coefficient = new_K_symbol
        new_constant_index += 1

    # Recombine the fully folded coefficient with the state part
    final_folded_term = updated_coefficient * state_part
    final_folded_terms.append(final_folded_term)

# Put the final folded terms together
final_folded_exp = sum(final_folded_terms)

print(f"Original expression:                   {expr}")
print(f"Expression after multiplicative step : {folded_exp}")
print(f"Final folded expression:               {final_folded_exp}")

print(constant_map)
print(final_folded_exp.free_symbols - state_vars)

Expanded exp C0*r*s + C1*C2*r*s + C1*C3 + C4*s - i
Term -i
Const -1 State i
Updated constant part -1
Term C1*C3
Const C1*C3 State 1
Updated constant part K0
Term C4*s
Const C4 State s
Updated constant part K1
Term C0*r*s
Const C0 State r*s
Updated constant part K2
Term C1*C2*r*s
Const C1*C2 State r*s
Updated constant part K3
terms_list (K0, -i, K1*s, K2*r*s, K3*r*s)
coeff_accumulator {1: K0, i: -1, s: K1, r*s: K2 + K3}
Original expression:                   C0*r*s + C1*(C2*r*s + C3) + C4*s - i
Expression after multiplicative step : K0 + K1*s + K2*r*s + K3*r*s - i
Final folded expression:               K0 + K1*s + K4*r*s - i
{C1*C3: K0, C4: K1, C0: K2, C1*C2: K3, K2 + K3: K4}
{K0, K1, K4}


$L_x = 2x$

In [5]:
list(final_folded_exp.free_symbols - state_vars)

[K0, K1, K4]

In [6]:
def count_vars(expr):
    # Base case: if the node is a Symbol, check if it's one we want to count
    if expr.is_Symbol:
        return 1 if expr in expr.free_symbols else 0

    # Recursive step: Sum the counts of all arguments (children)
    # This replaces the explicit loop of preorder_traversal
    return sum(count_vars(arg) for arg in expr.args)

In [7]:
# count number of rules
terminal_rules = count_vars(final_folded_exp)
non_terminal_rules = sympy.count_ops(final_folded_exp)
print(f"Terminal rules = {terminal_rules}")
print(f"Non-terminal rules = {non_terminal_rules}")
print(f"Total rules {terminal_rules+non_terminal_rules}")

Terminal rules = 7
Non-terminal rules = 6
Total rules 13


In [8]:
# count number of rules
terminal_rules = count_vars(expr)
non_terminal_rules = sympy.count_ops(expr)
print(f"Terminal rules = {terminal_rules}")
print(f"Non-terminal rules = {non_terminal_rules}")
print(f"Total rules {terminal_rules+non_terminal_rules}")

Terminal rules = 11
Non-terminal rules = 10
Total rules 21


In [9]:
# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

expr = sympy.simplify("C0*r*s**3 + C1 + C4*s - i**2")

# count number of rules
print("Original expression")
terminal_rules = count_vars(expr)
non_terminal_rules = sympy.count_ops(expr)
print(f"Terminal rules = {terminal_rules}")
print(f"Non-terminal rules = {non_terminal_rules}")
print(f"Total rules {terminal_rules+non_terminal_rules}")

Original expression
Terminal rules = 7
Non-terminal rules = 8
Total rules 15


In [10]:
expr_mul = expr.replace(
    lambda x: x.is_Pow and x.exp > 0,
    lambda x: sympy.Mul(*[x.base] * x.exp, evaluate=False),
)

In [11]:
# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

expr = sympy.simplify("C0*r*s**3 + C1 + C4*s - i**2")
# C0*r*s*s*s + C1 + C4*s - i*i

terms_list = sympy.Add.make_args(expr)

num_operations = 0  # number of non-terminal rules
num_variable_uses = 0  # number of terminal rules

# Count addition operations (number of terms - 1)
if len(terms_list) > 1:
    num_operations += len(terms_list) - 1

for term in terms_list:
    # Isolate constants and symbols in term
    constant_part, state_part = term.as_independent(*state_vars)

    # Count variable use of constants
    if constant_part.free_symbols:
        num_variable_uses += 1

    # Check multiplication operation between constants and states
    if constant_part.free_symbols and state_part.free_symbols:
        num_operations += 1

    # Deal with states which may be composed of various factors
    if state_part.free_symbols:
        # Count multiplication operation between states
        if state_part.is_Mul:
            num_operations += len(state_part.args) - 1

        # Separate the state into separate factors
        state_factors = state_part.args if state_part.is_Mul else [state_part]
        for state_factor in state_factors:
            # Handle power terms
            if state_factor.is_Pow and state_factor.exp > 0:
                num_variable_uses += state_factor.exp
                num_operations += state_factor.exp - 1

            # Count variable use of state factor (non-power term)
            elif state_factor.free_symbols:
                num_variable_uses += 1

print("\nnum_operations", num_operations)
print("num_variable_uses", num_variable_uses)
print("total rules", num_operations + num_variable_uses)


num_operations 9
num_variable_uses 10
total rules 19


In [12]:
def constant_folding(expr, state_vars):
    # Expand expression (multiply out)
    expanded_exp = sympy.expand(expr, state_vars)

    # Define a list to store the folded-constant terms and a dictionary for constant mapping
    folded_terms = []
    constant_map = {}
    new_constant_index = 0

    # Separate out terms at +
    terms_list = sympy.Add.make_args(expanded_exp)

    # --- Perform multiplicative constant folding ---
    for term in terms_list:
        # Isolate constants in term and symbols in term
        constant_part, state_part = term.as_independent(*state_vars)

        # If the constant part is a number, keep it as is.
        if not constant_part.free_symbols:
            updated_constant_part = constant_part

        # Fold the constants if they are symbolic
        elif constant_part in constant_map:
            # Use the existing mapped constant
            updated_constant_part = constant_map[constant_part]
        else:
            # Create a new symbol for the folded constant
            updated_constant_part = sympy.Symbol(f"K{new_constant_index}")

            # Store the constant mapping
            constant_map[constant_part] = updated_constant_part
            new_constant_index += 1

        # Put the folded constant and the state part back together
        folded_term = updated_constant_part * state_part

        # Add the new folded term to the list
        folded_terms.append(folded_term)

    # Put all the folded terms back together
    folded_exp = sum(folded_terms)

    # --- Perform additive constant folding ---
    coeff_accumulator = {}  # Key: state_part, Value: sum of K coefficients

    # Separate out terms at +
    terms_list = sympy.Add.make_args(folded_exp)

    for term in terms_list:
        # Split coefficient from state
        coeff, state_part = term.as_independent(*state_vars)

        # Add to dictionary
        if state_part in coeff_accumulator:
            coeff_accumulator[state_part] += coeff
        else:
            coeff_accumulator[state_part] = coeff

    # Storage for folded terms
    final_folded_terms = []

    for state_part, coefficient in coeff_accumulator.items():

        # If the combined coefficient is purely a number, don't fold it.
        if not coefficient.free_symbols:
            updated_coefficient = coefficient

        # Check if coefficient is already a pure K symbol
        elif coefficient in constant_map.values():
            updated_coefficient = coefficient

        # Create a new K-symbol if new coeff combination
        else:
            while sympy.Symbol(f"K{new_constant_index}") in constant_map.values():
                new_constant_index += 1

            new_K_symbol = sympy.Symbol(f"K{new_constant_index}")

            # Map the sum to the new symbol for tracking
            constant_map[coefficient] = new_K_symbol

            updated_coefficient = new_K_symbol

        # Recombine the fully folded coefficient with the state part
        final_folded_term = updated_coefficient * state_part
        final_folded_terms.append(final_folded_term)

    # Put the final folded terms together
    final_folded_exp = sum(final_folded_terms)

    print(f"Original expression:  {expr}")
    print(f"Final expression:     {final_folded_exp}")

    return final_folded_exp


def rules_count(expr, state_vars) -> int:
    """Count the total number of production rules used to build an expression."""
    terms_list = sympy.Add.make_args(expr)

    num_operations = 0  # number of non-terminal rules
    num_variable_uses = 0  # number of terminal rules

    # Only count unique constants once as a variable use
    constants_seen = set()

    # Count addition operations (number of terms - 1)
    if len(terms_list) > 1:
        num_operations += len(terms_list) - 1

    for term in terms_list:
        # Isolate constants and symbols in term
        constant_part, state_part = term.as_independent(*state_vars)

        # Constant multiplied by states
        coeff = 1  # No repetition (i.e. 3s = s+s+s)
        if constant_part.is_Integer and state_part.free_symbols:
            coeff = abs(int(constant_part))
            if coeff > 1:
                num_operations += coeff - 1

        # Count variable use of constants
        if constant_part.free_symbols and (constant_part not in constants_seen):
            # Add constant to seen set
            constants_seen.add(constant_part)
            num_variable_uses += 1

        # Check multiplication operation between constants and states
        if constant_part.free_symbols and state_part.free_symbols:
            num_operations += 1

        # Deal with states which may be composed of various factors
        if state_part.free_symbols:
            # Count multiplication operation between states
            if state_part.is_Mul:
                # if coeff > 1: 3*i*s = i*s + i*s + i*s
                num_operations += (len(state_part.args) - 1) * coeff

            # Separate the state into separate factors
            state_factors = state_part.args if state_part.is_Mul else [state_part]
            for state_factor in state_factors:
                # Handle power terms
                if state_factor.is_Pow and state_factor.exp > 0:
                    # Apply coeff multiplier to both variable uses and operations
                    num_variable_uses += state_factor.exp * coeff
                    num_operations += (state_factor.exp - 1) * coeff

                # Count variable use of state factor (non-power term)
                elif state_factor.free_symbols:
                    num_variable_uses += 1 * coeff

    print("\nnum_operations", num_operations)
    print("num_variable_uses", num_variable_uses)
    print("total rules", num_operations + num_variable_uses, "\n")

    return num_operations + num_variable_uses

In [None]:
def constant_folding(expr, state_vars, start_index=0):
    """
    Simplifies an expression by folding constants.

    Returns: (folded_expression, next_constant_index)
    """
    # Expand expression (multiply out)
    expanded_exp = sympy.expand(expr, state_vars)

    # Define a list to store the folded-constant terms and a dictionary for constant mapping
    folded_terms = []
    constant_map = {}
    new_constant_index = start_index

    # Separate out terms at +
    terms_list = sympy.Add.make_args(expanded_exp)

    # --- Perform multiplicative constant folding ---
    for term in terms_list:
        # Isolate constants in term and symbols in term
        constant_part, state_part = term.as_independent(*state_vars)

        # If the constant part is a number, keep it as is.
        if not constant_part.free_symbols:
            updated_constant_part = constant_part

        # Fold the constants if they are symbolic
        elif constant_part in constant_map:
            # Use the existing mapped constant
            updated_constant_part = constant_map[constant_part]
        else:
            # Create a new symbol for the folded constant
            updated_constant_part = sympy.Symbol(f"K{new_constant_index}")

            # Store the constant mapping
            constant_map[constant_part] = updated_constant_part
            new_constant_index += 1

        # Put the folded constant and the state part back together
        folded_term = updated_constant_part * state_part

        # Add the new folded term to the list
        folded_terms.append(folded_term)

    # Put all the folded terms back together
    folded_exp = sum(folded_terms)

    # --- Perform additive constant folding ---
    coeff_accumulator = {}  # Key: state_part, Value: sum of K coefficients

    # Separate out terms at +
    terms_list = sympy.Add.make_args(folded_exp)

    for term in terms_list:
        # Split coefficient from state
        coeff, state_part = term.as_independent(*state_vars)

        # Add to dictionary
        if state_part in coeff_accumulator:
            coeff_accumulator[state_part] += coeff
        else:
            coeff_accumulator[state_part] = coeff

    # Storage for folded terms
    final_folded_terms = []

    for state_part, coefficient in coeff_accumulator.items():

        # If the combined coefficient is purely a number, don't fold it.
        if not coefficient.free_symbols:
            updated_coefficient = coefficient

        # Check if coefficient is already a pure K symbol
        elif coefficient in constant_map.values():
            updated_coefficient = coefficient

        # Create a new K-symbol if new coeff combination
        else:
            while sympy.Symbol(f"K{new_constant_index}") in constant_map.values():
                new_constant_index += 1

            new_K_symbol = sympy.Symbol(f"K{new_constant_index}")

            # Map the sum to the new symbol for tracking
            constant_map[coefficient] = new_K_symbol

            updated_coefficient = new_K_symbol

            # Increment const index
            new_constant_index += 1

        # Recombine the fully folded coefficient with the state part
        final_folded_term = updated_coefficient * state_part
        final_folded_terms.append(final_folded_term)

    # Sort the terms alphabetically by their state variable string (i, r, s...)
    final_folded_terms.sort(key=lambda x: str(x.as_independent(*state_vars)[1]))

    # Create the substitution map based on the sorted order
    rename_map = {}
    final_b_index = start_index

    for term in final_folded_terms:
        # Extract the coefficient (which is a K symbol or a number)
        coeff, state_part = term.as_independent(*state_vars)

        # Rename symbolic constants
        if coeff.free_symbols:
            # If we haven't renamed this K yet, give it the next B number
            if coeff not in rename_map:
                new_b = sympy.Symbol(f"B{final_b_index}")
                rename_map[coeff] = new_b
                final_b_index += 1

    # Construct the final expression
    # Summing the sorted list usually keeps order, but SymPy might shuffle visually.
    # However, B0 is now permanently attached to the alphabetical first term.
    temp_sum = sum(final_folded_terms)
    final_folded_exp_reindexed = temp_sum.subs(rename_map)

    # Calculate final outputs
    final_consts = list(rename_map.values())

    # Sort B-consts to look nice in the list [B0, B1, B2]
    final_consts.sort(key=lambda s: int(s.name[1:]))

    print(f"Original expression:  {expr}")
    print(f"Final expression:     {final_folded_exp_reindexed}")

    return final_folded_exp_reindexed, final_b_index, final_consts

In [20]:
# Define state symbols
s, i, r = sympy.symbols("s i r")
state_vars = {s, i, r}

examples = [
    "C0 + C1",
    "C0*s + C1*s",
    "C0*s + s",
    "C0*s*r + C1*(C2*r*s + C3)",
    "C0*r*s + C1*(C2*r*s + C3)+C4*s",
    "C0*s*i*C1(C2+C3)",
    "C1*r + i - r",
    "C1*r - 2*i - r",
    "C1*r - 2*i*s - r",
    "C1*r - 2*i*s**2 - r",
    "C1*r - 2*i*s**2 + 3*i - r",
    "C1*(C2 + i) - C3",
    "(C2 + C3)*(C4 + r)",
    "s*(C0 - C1 + i**2)",
    "C0*s**2 + C1*s**2 + 5*s**2",
    "C0*i**3 - C1*i**3 + i**3",
    "C0*s**2 + C1*s + C2*s**2 + s",
    "C0*s*i + C1*i*s",
    "C0*s*i*r - 2*s*i*r + C1*s*r*i",
    "s*i*(C0 + C1) + s*i*C2",
    "C0 * (s + C1 * (i - C2 * r))",
    "(C0*s + C1*i) * (C2*s + C3*i)",
    "C0 * (s + 1) * (i + 1)",
    "s + i + r",
    "2*i + C1*i - 5*i",
    "C0*i + C1*r + C2*s",
    "C0*r + C1*i + C2*s",
    "C0*s + C1*i + C2*r",
    "C0*s**2 + C1*s + C2*i + C3*r**2 + C4*r",
    "C0*s**2 + C1*r**2 + C2*i + C3*r + C4*s",
]

for example in examples:
    expr = sympy.sympify(example)
    expr = constant_folding(expr, state_vars)[0]
    rules = rules_count(expr, state_vars)

Original expression:  C0 + C1
Final expression:     B0

num_operations 0
num_variable_uses 1
total rules 1 

Original expression:  C0*s + C1*s
Final expression:     B0*s

num_operations 1
num_variable_uses 2
total rules 3 

Original expression:  C0*s + s
Final expression:     B0*s

num_operations 1
num_variable_uses 2
total rules 3 

Original expression:  C0*r*s + C1*(C2*r*s + C3)
Final expression:     B0 + B1*r*s

num_operations 3
num_variable_uses 4
total rules 7 

Original expression:  C0*r*s + C1*(C2*r*s + C3) + C4*s
Final expression:     B0 + B1*r*s + B2*s

num_operations 5
num_variable_uses 6
total rules 11 

Original expression:  C0*i*s*C1(C2 + C3)
Final expression:     B0*i*s

num_operations 2
num_variable_uses 3
total rules 5 

Original expression:  C1*r + i - r
Final expression:     B0*r + i

num_operations 2
num_variable_uses 3
total rules 5 

Original expression:  C1*r - 2*i - r
Final expression:     B0*r - 2*i

num_operations 3
num_variable_uses 4
total rules 7 

Original 