In [1]:
from deap import gp
import numpy as np
import os
import sys
sys.path.append(os.path.abspath('../..'))  # or the full path to the "project" directory. This hack should be really fixed
from gpbr.gp.funcs import pow2, sqrtabs, expplusone


pset = gp.PrimitiveSet("main", 1)
pset.addPrimitive(np.add, 2)
# pset.addPrimitive(np.subtract, 2)
pset.addPrimitive(np.multiply, 2)
pset.addPrimitive(np.cos, 1)
pset.addPrimitive(np.sin, 1)
pset.addPrimitive(sqrtabs, 1)
# pset.addPrimitive(pow2, 1)
pset.addPrimitive(expplusone, 1)
pset.addEphemeralConstant('rand', (np.random.rand, 1)[0])
# pset.addTerminal(np.pi, 'pi')

pset.renameArguments(ARG0="s")

In [2]:
from deap import creator, base, tools, algorithms
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMin)

In [3]:
toolbox = base.Toolbox()
toolbox.register('expr', gp.genHalfAndHalf, pset=pset, min_=3, max_=6)
toolbox.register('individual', tools.initIterate, creator.Individual, toolbox.expr)
toolbox.register('population', tools.initRepeat, list, toolbox.individual)
toolbox.register('compile', gp.compile, pset=pset)

In [95]:
ind = toolbox.individual()
print(ind)

expplusone(add(cos(sin(add(0.5041008352677898, s))), sin(multiply(sin(0.9508309438650218), expplusone(s)))))


In [59]:
test_set = np.linspace(0, 2*np.pi, 10)
test_set

array([0.        , 0.6981317 , 1.3962634 , 2.0943951 , 2.7925268 ,
       3.4906585 , 4.1887902 , 4.88692191, 5.58505361, 6.28318531])

In [60]:
ind_compiled =toolbox.compile(expr=ind)
print(ind_compiled(0.5))  # example of use

0.9878138859066259


In [61]:
ind_compiled(test_set)

array([0.99858852, 0.59359999, 0.55574055, 0.94666834, 0.61315501,
       0.56489499, 0.9129983 , 0.97851449, 0.69337964, 0.99858852])

In [88]:
import deap.gp as gp

def evaluate_subtrees(tree, pset, **kwargs):
    """
    Evaluate the GP tree and compute results for every subtree using a stack-based approach,
    inspired by the __str__ method in PrimitiveTree.
    
    :param tree: deap.gp.PrimitiveTree, the tree to evaluate.
    :param pset: deap.gp.PrimitiveSet, the primitive set used to define functions and terminals.
    :param args: Variable arguments representing the inputs for ARG0, ARG1, etc.
    :return: A tuple (root_value, subtree_values), where root_value is the result of the entire tree,
             and subtree_values is a list where subtree_values[i] is the result of the subtree rooted at tree[i].
    """
    subtree_values = [None] * len(tree)
    stack = []
    
    for i in range(len(tree)):
        stack.append((i, []))
        
        while stack and len(stack[-1][1]) == tree[stack[-1][0]].arity:
            idx, child_vals = stack.pop()
            node = tree[idx]
            
            if node.arity == 0:  # Terminal
                if node.name.startswith('ARG'):
                    val = kwargs[node.value]
                else:
                    val = node.value
            else:  # Primitive
                func = pset.context[node.name]
                val = func(*child_vals)
            
            subtree_values[idx] = val
            
            if stack:
                stack[-1][1].append(val)
    
    root_value = subtree_values[0]
    return root_value, subtree_values

In [89]:
val, subtree_vals =evaluate_subtrees(ind, pset, s=test_set)

In [90]:
val

array([0.99858852, 0.59359999, 0.55574055, 0.94666834, 0.61315501,
       0.56489499, 0.9129983 , 0.97851449, 0.69337964, 0.99858852])

In [91]:
subtree_vals

[array([0.99858852, 0.59359999, 0.55574055, 0.94666834, 0.61315501,
        0.56489499, 0.9129983 , 0.97851449, 0.69337964, 0.99858852]),
 array([ 0.05313782, -0.93527145,  0.98154286,  0.32806269,  0.91074802,
        -0.97049034, -0.42022223, -0.20766746, -0.80462758,  0.05313782]),
 array([ 4.76555184,  9.06300689, 12.75879828, 11.32982632,  6.70888981,
         3.38513301,  2.00448654,  1.77998616,  2.50584432,  4.76555184]),
 array([ 0.56141334,  1.20420095,  1.54622109,  1.42743875,  0.90343348,
         0.2193932 , -0.30461206, -0.42339441, -0.08137427,  0.56141334]),
 array([ 0.00000000e+00,  6.42787610e-01,  9.84807753e-01,  8.66025404e-01,
         3.42020143e-01, -3.42020143e-01, -8.66025404e-01, -9.84807753e-01,
        -6.42787610e-01, -2.44929360e-16]),
 array([0.        , 0.6981317 , 1.3962634 , 2.0943951 , 2.7925268 ,
        3.4906585 , 4.1887902 , 4.88692191, 5.58505361, 6.28318531]),
 0.5614133415411786,
 0.5960927048090293]

In [50]:
for el in ind:
    print(el.name)

sin
sqrtabs
expplusone
add
ARG0
ARG0


In [66]:
def is_2pi_periodic(values, tolerance=1e-5):
    if isinstance(values, (np.integer, int, float)):
        return True
    return np.allclose(values[0], values[-1], atol=tolerance)

In [70]:
for el in subtree_vals:
    print(is_2pi_periodic(el))

True
True
True
True
True
False
True
True


In [85]:
str(ind)

'cos(cos(expplusone(add(sin(s), sin(0.5960927048090293)))))'

In [87]:
[str(gp.PrimitiveTree(ind[ind.searchSubtree(i)])) for i in range(len(ind))]

['cos(cos(expplusone(add(sin(s), sin(0.5960927048090293)))))',
 'cos(expplusone(add(sin(s), sin(0.5960927048090293))))',
 'expplusone(add(sin(s), sin(0.5960927048090293)))',
 'add(sin(s), sin(0.5960927048090293))',
 'sin(s)',
 's',
 'sin(0.5960927048090293)',
 '0.5960927048090293']

In [79]:
ind.searchSubtree(2)
ind[ind.searchSubtree(4)]

[<deap.gp.Primitive at 0x1bd57264180>, <deap.gp.Terminal at 0x1bd57253240>]

In [None]:
# Test: compare DEAP compiled subtree outputs with custom evaluate_subtrees
import numpy as np
import deap.gp as gp
import traceback

# import project helpers
from gpbr.gp.funcs import sqrtabs, expplusone
from gpbr.gp.evaluators import evaluate_subtrees

# Build primitive set similar to the notebook
pset = gp.PrimitiveSet("main", 1)
pset.addPrimitive(np.add, 2)
pset.addPrimitive(np.multiply, 2)
pset.addPrimitive(np.cos, 1)
pset.addPrimitive(np.sin, 1)
pset.addPrimitive(sqrtabs, 1)
pset.addPrimitive(expplusone, 1)
pset.addEphemeralConstant('rand', lambda: np.random.rand())
pset.renameArguments(ARG0='s')

# Test inputs
test_set = np.linspace(0, 2*np.pi, 10)

# Example expressions (use argument name 's' because pset was renamed)
examples = [
    "add(sin(s), 1.0)",
    "multiply(s, cos(s))",
    "expplusone(s)",
    "add(multiply(s, s), 0.5)",
    "sqrtabs(s)",
    "add(add(sin(s), cos(s)), multiply(s, 0.1))",
]

print(f"Running {len(examples)} examples; test_set shape: {test_set.shape}")

for expr in examples:
    print('\nExpression:', expr)
    try:
        tree = gp.PrimitiveTree.from_string(expr, pset)
    except Exception as e:
        print('  Failed to parse into PrimitiveTree:', e)
        traceback.print_exc()
        continue

    # Evaluate all subtrees using the custom evaluator
    try:
        root_val, subtree_vals = evaluate_subtrees(tree, pset, s=test_set)
    except Exception as e:
        print('  evaluate_subtrees failed:', e)
        traceback.print_exc()
        continue

    all_ok = True
    for i in range(len(tree)):
        node = tree[i]
        sl = tree.searchSubtree(i)
        subtree = gp.PrimitiveTree(tree[sl])

        # Compile subtree with DEAP
        try:
            compiled_fn = gp.compile(subtree, pset)
        except Exception as e:
            print(f"  Failed to compile subtree {i} ({node.name}):", e)
            traceback.print_exc()
            all_ok = False
            continue

        # Evaluate compiled subtree
        try:
            compiled_val = compiled_fn(test_set)
        except Exception as e:
            # Some compiled functions might expect scalar input — try elementwise
            try:
                compiled_val = np.array([compiled_fn(x) for x in test_set])
            except Exception as e2:
                print(f"  Failed to run compiled subtree {i} ({node.name}):", e)
                traceback.print_exc()
                all_ok = False
                continue

        eval_val = subtree_vals[i]

        # Normalize to numpy arrays
        a = np.asarray(compiled_val)
        b = np.asarray(eval_val)

        # Compare shapes and numeric closeness
        shapes_match = a.shape == b.shape
        nums_match = False
        if shapes_match:
            try:
                nums_match = np.allclose(a, b, equal_nan=True)
            except Exception:
                nums_match = False

        if not (shapes_match and nums_match):
            all_ok = False
            print(f"  MISMATCH at node {i} ({node.name}): compiled shape {a.shape}, evaluator shape {b.shape}")
            # show short samples
            def sample(arr):
                try:
                    arr = np.asarray(arr)
                    if arr.size > 5:
                        return arr.flatten()[:5]
                    return arr
                except Exception:
                    return str(arr)
            print('    compiled sample =', sample(a))
            print('    evaluator sample=', sample(b))

    print('  OK' if all_ok else '  FAIL')

print('\nDone.')