# In [1]:
using LinearAlgebra
using JuMP
using HiGHS
using ProgressMeter
import Base: show

# Region

# In [3]:
"""
    Region

Node of the region tree. Holds an activation pattern, the inequalities
(`Dlw * x ≤ glw`-style slopes/intercepts) attached to the node, the subset of
them kept as active, an affine projection (`Alw`, `clw`), and tree links.

Constructors:

- `Region(; input_dim)` builds a root node: identity projection of size
  `input_dim × input_dim`, zero intercept, no inequalities (`0 × input_dim`),
  no parent, `layer_number = 0`.
- `Region(activation)` builds a detached node holding only `activation`.
  `parent` is `nothing`, `children` is empty, `q_tilde` is empty,
  `bounded = true` and `layer_number = 0`. The matrix/vector fields (`Alw`,
  `clw`, `Dlw`, `glw`, `Dlw_active`, `glw_active`) are left undefined and must
  be assigned by the caller.
"""
mutable struct Region
    # Region attributes
    qlw::Vector{Int}                # Activation pattern
    q_tilde::Vector{Int}            # Active bits indices
    bounded::Bool

    # Inequalities and projections
    Alw::Matrix{Float64}            # Slope projection matrix
    clw::Vector{Float64}            # Intercept projection matrix

    Dlw::Matrix{Float64}            # Slopes of inequalities
    glw::Vector{Float64}            # Intercept of inequalities

    Dlw_active::Matrix{Float64}     # Active slopes
    glw_active::Vector{Float64}     # Active intercepts

    # Tree attributes
    parent::Union{Region, Nothing}
    children::Vector{Region}

    # Utility attributes
    layer_number::Int

    # -- Constructors --
    function Region(; input_dim::Int)
        this = new()
        this.qlw = Int[]
        this.q_tilde = Int[]
        this.bounded = true

        # Identity projection for root
        this.Alw = Matrix{Float64}(I, input_dim, input_dim)
        this.clw = zeros(Float64, input_dim)

        # No inequalities yet, but keep the column count equal to input_dim.
        this.Dlw = Matrix{Float64}(undef, 0, input_dim)
        this.glw = Float64[]

        this.Dlw_active = this.Dlw
        this.glw_active = this.glw

        this.parent = nothing
        this.children = Region[]
        this.layer_number = 0
        return this
    end

    function Region(activation::Vector{Int})
        this = new()
        this.qlw = activation
        this.q_tilde = Int[]
        this.bounded = true
        # Start detached: `parent` would otherwise be #undef and reading it
        # (e.g. while walking up the tree) throws an UndefRefError.
        this.parent = nothing
        this.children = Region[]
        # Deterministic default instead of an uninitialized Int.
        this.layer_number = 0
        # Alw/clw/Dlw/glw/Dlw_active/glw_active are intentionally left
        # undefined: their dimensions are not known here.
        return this
    end
end

"""
    add_child!(parent::Region, child::Region)

Attach `child` beneath `parent`: record `parent` as the child's parent and
append `child` to `parent.children`. Returns `parent.children`.
"""
function add_child!(parent::Region, child::Region)
    siblings = parent.children
    child.parent = parent
    return push!(siblings, child)
end

"""
    get_children(r::Region)

Return the vector of child regions of `r` (the stored vector itself, not a copy).
"""
get_children(r::Region) = r.children

# Matches Python: get_path_inequalities(self)
"""
    get_path_inequalities(r::Region) -> (D_path, g_path)

Stack the active inequalities (`Dlw_active`, `glw_active`) of `r` and of every
non-root ancestor of `r`, ordered from the node nearest the root down to `r`.
The root's own constraints are not included.

When `r` has no parent, returns an empty `0 × n` matrix and an empty vector.
Here `n` is the column count of `r.Dlw_active` (or `0` if that field is
unset), so the result stays dimensionally compatible with an input vector.
"""
function get_path_inequalities(r::Region)
    D_list = Matrix{Float64}[]
    g_list = Vector{Float64}[]

    # Walk up towards the root; the root itself (parent === nothing) is skipped.
    node = r
    while node.parent !== nothing
        push!(D_list, node.Dlw_active)
        push!(g_list, node.glw_active)
        node = node.parent
    end

    if isempty(D_list)
        # Preserve the column count so `D_path * x` remains well-defined.
        ncols = isdefined(r, :Dlw_active) ? size(r.Dlw_active, 2) : 0
        return Matrix{Float64}(undef, 0, ncols), Float64[]
    end

    # Collected leaf -> root; flip in place to root -> leaf order.
    reverse!(D_list)
    reverse!(g_list)

    return reduce(vcat, D_list), reduce(vcat, g_list)
end

"""
    show(io::IO, r::Region)

Print a one-line summary of `r`, preceded by a newline: its layer number, its
activation pattern (`qlw`) and its number of children.
"""
function Base.show(io::IO, r::Region)
    print(io, "\nRegion (L$(r.layer_number)) | Act: $(r.qlw) | Children: $(length(r.children))")
    return nothing
end

# Tree

# In [6]:
"""
    Tree(state_dict::Dict{String, Any})

Region tree for a network whose per-layer weights and biases are obtained from
`state_dict` via `find_hyperplanes`. The tree is created with only a root
region (no inequalities, identity projection); call `construct_tree!` to grow it.

# Throws
- `ArgumentError` if no weight matrices are found, or if the number of weight
  matrices and bias vectors differ.
"""
mutable struct Tree
    weights::Vector{Matrix{Float64}}   # one weight matrix per layer
    biases::Vector{Vector{Float64}}    # one bias vector per layer
    input_dim::Int                     # column count of the first weight matrix
    L::Int                             # number of layers
    root::Region

    function Tree(state_dict::Dict{String, Any})
        weights, biases = find_hyperplanes(state_dict)

        # Fail early with a clear message instead of a BoundsError later on.
        isempty(weights) &&
            throw(ArgumentError("state_dict yields no weight matrices"))
        length(weights) == length(biases) || throw(ArgumentError(
            "got $(length(weights)) weight matrices but $(length(biases)) bias vectors"))

        input_dim = size(weights[1], 2)
        L = length(weights)
        root = Region(input_dim=input_dim)
        return new(weights, biases, input_dim, L, root)
    end
end

"""
    construct_tree!(tree::Tree; verbose::Bool=false)

Grow `tree` layer by layer starting from `tree.root`. For layer `i`, each
region produced at the previous step is passed with `tree.weights[i]` and
`tree.biases[i]` to `find_next_layer_region_info`. Each returned
`(activation, info)` pair becomes a child `Region` with `layer_number = i`.
Stops early when a layer yields no regions. Returns `nothing`.

With `verbose = true`, prints how many regions are processed at each layer.
"""
function construct_tree!(tree::Tree; verbose::Bool=false)
    current_layer_nodes = Region[tree.root]

    for layer in 1:tree.L
        Wl = tree.weights[layer]
        bl = tree.biases[layer]
        next_layer_nodes = Region[]

        if verbose
            println("Layer $layer: Processing $(length(current_layer_nodes)) regions...")
        end

        # Regions are processed serially. If this loop is ever parallelized,
        # appends to `next_layer_nodes` will need synchronization.
        for parent in current_layer_nodes
            # Solver step (expensive).
            new_nodes_info = find_next_layer_region_info(
                parent.Dlw_active, parent.glw_active,
                parent.Alw, parent.clw,
                Wl, bl, layer
            )

            for (act, info) in new_nodes_info
                child = _region_from_info(act, info, layer)
                add_child!(parent, child)
                push!(next_layer_nodes, child)
            end
        end

        current_layer_nodes = next_layer_nodes
        isempty(current_layer_nodes) && break
    end
    return nothing
end

# Build a child `Region` at `layer` from one `info` entry returned by
# `find_next_layer_region_info` (keys: "q_tilde", "bounded", "Dlw", "glw", "Alw", "clw").
function _region_from_info(act, info, layer::Int)
    child = Region(act)
    child.q_tilde = info["q_tilde"]
    child.bounded = info["bounded"]
    child.Dlw = info["Dlw"]
    child.glw = info["glw"]

    # Keep only the rows selected by q_tilde as the active constraints.
    child.Dlw_active = info["Dlw"][info["q_tilde"], :]
    child.glw_active = info["glw"][info["q_tilde"]]

    child.Alw = info["Alw"]
    child.clw = info["clw"]
    child.layer_number = layer
    return child
end

"""
    get_regions_at_layer(tree::Tree, layer::Int)

Collect, in breadth-first order from `tree.root`, every region whose
`layer_number` equals `layer`. Children are only explored below nodes whose
`layer_number` is smaller than `layer`.
"""
function get_regions_at_layer(tree::Tree, layer::Int)
    found = Region[]
    pending = Region[tree.root]
    head = 1

    # Index-advancing queue: `pending[head:end]` are the nodes still to visit.
    while head <= length(pending)
        node = pending[head]
        head += 1
        lvl = node.layer_number
        if lvl == layer
            push!(found, node)
        elseif lvl < layer
            append!(pending, node.children)
        end
    end
    return found
end

# Output: get_regions_at_layer (generic function with 1 method)

# Solvers