## Data generation

- Parameters
- Prefix list (const placeholder)
- Prefix list (with consts)
- DE.jl expr tree; Infix str (for debug)?

Julia structuring: 
- No OOP
- Throw all functions into a module?
-> create functions here for now
- Try to use off-the-shelf DE.jl functions

In [3]:
using DynamicExpressions: OperatorEnum

In [4]:
ops = OperatorEnum((+, -, *), (sin, exp))
nuna = length(ops.unaops)
nbin = length(ops.binops)
ops
all_ops = (ops.unaops..., ops.binops...)

(sin, exp, +, -, *)

In [5]:
ops.binops[1]

+ (generic function with 239 methods)

In [6]:
nl = 1
p1 = 1
p2 = 1

ops = OperatorEnum((+, -, *), (sin, exp))

function _generate_ubi_dist(max_ops::Int)
    """
    `max_ops`: maximum number of operators
    Enumerate the number of possible unary-binary trees that can be generated from empty nodes.
    D[e][n] represents the number of different binary trees with n nodes that
    can be generated from e empty nodes, using the following recursion:
        D(0, n) = 0
        D(e, 0) = L ** e
        D(e, n) = L * D(e - 1, n) + p_1 * D(e, n - 1) + p_2 * D(e + 1, n - 1)
    """

    # enumerate possible trees
    # first generate the transposed version of D, then transpose it
    D = Vector{Vector{Int}}()
    push!(D, [0; [nl^i for i in 1:(2*max_ops+1)]])
    
    for n in 1:(2*max_ops)  # number of operators
        s = [0]
        for e in 1:(2*max_ops-n+1)  # number of empty nodes
            push!(s, 
                nl * s[e] +
                p1 * D[n][e+1] +
                p2 * D[n][e+2]
            )
        end
        push!(D, s)
    end
    
    @assert all(length(D[i]) >= length(D[i+1]) for i in 1:length(D)-1)
    
    D = [
        [D[j][i] for j in 1:length(D) if i <= length(D[j])]
        for i in 1:maximum(length(x) for x in D)
    ]
    
    return D
end

ubi_dist = _generate_ubi_dist(5)

12-element Vector{Vector{Int64}}:
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 [1, 2, 6, 22, 90, 394, 1806, 8558, 41586, 206098, 1037718]
 [1, 4, 16, 68, 304, 1412, 6752, 33028, 164512, 831620]
 [1, 6, 30, 146, 714, 3534, 17718, 89898, 461010]
 [1, 8, 48, 264, 1408, 7432, 39152, 206600]
 [1, 10, 70, 430, 2490, 14002, 77550]
 [1, 12, 96, 652, 4080, 24396]
 [1, 14, 126, 938, 6314]
 [1, 16, 160, 1296]
 [1, 18, 198]
 [1, 20]
 [1]

In [8]:
params = ...

self.ubi_dist = _generate_ubi_dist(params.max_ops)

Base.Meta.ParseError: ParseError:
# Error @ /Users/luis/Desktop/Cranmer 2024/Workplace/smallMutations/similar-expressions/src/datagen_jl/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W5sZmlsZQ==.jl:1:10
params = ...
#        └─┘ ── invalid identifier

In [7]:
using Random
using Distributions

function _sample_next_pos_ubi(nb_empty::Int, nb_ops::Int, rng = Random.MersenneTwister(123))
    """
    Sample the position of the next node (unary-binary case).
    Sample a position in {0, ..., `nb_empty` - 1}, along with an arity.
    """
    @assert nb_empty > 0
    @assert nb_ops > 0
    probs = Float64[]
    for i in 0:(nb_empty-1)
        push!(probs, 
            (nl ^ i) * p1 * ubi_dist[nb_empty - i + 1][nb_ops]
        )
    end
    for i in 0:(nb_empty-1)
        push!(probs, 
            (nl ^ i) * p2 * ubi_dist[nb_empty - i + 2][nb_ops]
        )
    end
    probs ./= ubi_dist[nb_empty + 1][nb_ops + 1]
    e = rand(rng, Categorical(probs)) - 1
    arity = e < nb_empty ? 1 : 2
    e = e % nb_empty
    return e, arity
end

_sample_next_pos_ubi(5, 2)

(2, 2)

In [8]:
struct OperatorProbEnum
    binops_probs::Vector{Float64}
    unaops_probs::Vector{Float64}
end

op_probs = OperatorProbEnum(
    [0.5, 0.5],
    [0.5, 0.5]
)

OperatorProbEnum([0.5, 0.5], [0.5, 0.5])

In [9]:
using DynamicExpressions:
    AbstractExpressionNode,
    AbstractExpression,
    AbstractNode,
    NodeSampler,
    Node,
    get_contents,
    with_contents,
    constructorof,
    copy_node,
    set_node!,
    count_nodes,
    has_constants,
    has_operators

using Random: default_rng, AbstractRNG

const DATA_TYPE = Number

function make_random_leaf(
    nfeatures::Int,
    ::Type{T},
    ::Type{N},
    rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE,N<:AbstractExpressionNode}
"""From SymbolicRegression.jl"""
    if rand(rng, Bool)  # TODO: Add probs
        return constructorof(N)(; val=randn(rng, T))
    else
        return constructorof(N)(T; feature=rand(rng, 1:nfeatures))
    end
end


nfeatures = 1
T = Float64
N = Node{T}
l=make_random_leaf(nfeatures, T, N, default_rng())
l

-1.2164631331853586

In [102]:
Node{Float64}(feature=999)

x999

In [82]:
Node{Float64}(val=1.0)

1.0

In [87]:
l

-1.2164631331853586

In [95]:
op_node = Node{Float64}(; op=1, l=Node{Float64}(feature=1), r=nothing)

op_node

unary_operator[1](x1)

In [None]:
# Use empty nodes as placeholders for leaves
# Dont know total number of nodes beforehand (depends on unary-binary choices)

function _generate_tree(nb_total_ops::Int, rng::AbstractRNG=default_rng())::QuoteNode
    nb_empty = 1  # number of empty nodes
    l_leaves = 0  # left leaves - nothing states reserved for leaves
    t_leaves = 1  # total number of leaves (just used for sanity check)

    stack = [nothing]

    for nb_ops in nb_total_ops:-1:1
        skipped, arity = _sample_next_pos_ubi(nb_empty, nb_ops)

        if arity == 1
            op_index = StatsBase.sample(1:nuna, StatsBase.Weights(op_probs.unaops_probs))
            r = nothing
        else
            op_index = StatsBase.sample(1:nbin, StatsBase.Weights(op_probs.binops_probs))
            r = Node{Float64}(feature=999)  # placeholder
        end
        l = Node{Float64}(feature=999)  # placeholder
        op_node = Node{Float64}(; op=op_index, l=l, r=r)
    end
end

In [98]:
using StatsBase

"""
Create a tree with exactly `nb_total_ops` operators.

The tree is a list of operators and leaves in prefix notation (each one a Node).

TODO: Add requires_x flag

Steps:
1. Create tree stack (prefix notation, with nothing as leaves)
2. Create leaves list (same size as number of empty leaves)
3. Insert leaves into tree
"""
function _generate_expr2(nb_total_ops::Int, rng::AbstractRNG=default_rng())::Vector{String}
    stack = [nothing]
    nb_empty = 1  # number of empty nodes
    l_leaves = 0  # left leaves - nothing states reserved for leaves
    t_leaves = 1  # total number of leaves (just used for sanity check)

    # 1. Create tree
    for nb_ops in nb_total_ops:-1:1
        skipped, arity = _sample_next_pos_ubi(nb_empty, nb_ops)
        
        if arity == 1
            op_index = StatsBase.sample(1:nuna, StatsBase.Weights(op_probs.unaops_probs))
            r = nothing
        else
            op_index = StatsBase.sample(1:nbin, StatsBase.Weights(op_probs.binops_probs))
            r = Node{Float64}(val=1.0)
        end
        op_node = Node{Float64}(; op=1, l=l, r=r)

        nb_empty += (arity - 1 - skipped)
        t_leaves += arity - 1
        l_leaves += skipped
        
        # update tree
        pos = findall(x -> x === nothing, stack)[l_leaves+1]
        stack = [
            stack[1:pos-1]...,
            op_node,
            [nothing for _ in 1:arity]...,
            stack[pos+1:end]...
        ]
    end
    println(stack)

    # sanity check

    @assert count(x -> x in all_ops, stack) == nb_total_ops
    @assert count(x -> x === nothing, stack) == t_leaves

    # 2. Create leaves
    nfeatures = 1
    T = Float64
    N = Node{T}
    # l=make_random_leaf(nfeatures, T, N, default_rng())

    leaves = [make_random_leaf(nfeatures, T, N, default_rng()) for _ in 1:t_leaves]
    println(leaves)
    println(leaves)

    # 3. Insert leaves into tree
    for pos in length(stack):-1:1
        if stack[pos] === nothing
            stack = [stack[1:pos-1]..., pop!(leaves), stack[pos+1:end]...]
        end
    end
    @assert isempty(leaves)
    return stack
end

stack = _generate_expr(5, )
stack

Union{Nothing, Node{Float64}}[binary_operator[1](-1.2164631331853586, 1.0), binary_operator[1](-1.2164631331853586, 1.0), nothing, binary_operator[1](-1.2164631331853586, 1.0), nothing, binary_operator[1](-1.2164631331853586, 1.0), nothing, binary_operator[1](-1.2164631331853586, 1.0), nothing, nothing, nothing]


AssertionError: AssertionError: count((x->begin
                #= /Users/luis/Desktop/Cranmer 2024/Workplace/smallMutations/similar-expressions/src/datagen_jl/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X13sZmlsZQ==.jl:50 =#
                x in all_ops
            end), stack) == nb_total_ops

In [53]:
stack[1].degree

0x03

In [52]:
stack[1].degree == 3

true

Conversions:
- Prefix list <> DE.jl tree
- Prefix list: 


        if node.degree == 1
            node.op = rand(rng, 1:(options.nuna))
        else
            node.op = rand(rng, 1:(options.nbin))
        end

        function sample_mutation(w::MutationWeights)
            weights = convert(Vector, w)
            return StatsBase.sample(v_mutations, StatsBase.Weights(weights))
        end