Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 74 additions & 14 deletions src/Probabilistic/ProbCircuits.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#####################
# Probabilistic circuits
#####################

abstract type ProbΔNode{O} <: DecoratorΔNode{O} end
abstract type ProbLeafNode{O} <: ProbΔNode{O} end
abstract type ProbInnerNode{O} <: ProbΔNode{O} end
Expand All @@ -10,15 +9,15 @@ mutable struct ProbLiteral{O} <: ProbLeafNode{O}
origin::O
data
bit::Bool
ProbLiteral{O}(o::O) where O = new{O}(o, nothing, false)
ProbLiteral(n) = new{node_type(n)}(n, nothing, false)
end

mutable struct Prob⋀{O} <: ProbInnerNode{O}
origin::O
children::Vector{<:ProbΔNode{<:O}}
data
bit::Bool
Prob⋀{O}(o::O,c) where O = new{O}(o,c, nothing, false)
Prob⋀(n, children) = new{node_type(n)}(n, children, nothing, false)
end

mutable struct Prob⋁{O} <: ProbInnerNode{O}
Expand All @@ -27,7 +26,7 @@ mutable struct Prob⋁{O} <: ProbInnerNode{O}
log_thetas::Vector{Float64}
data
bit::Bool
Prob⋁{O}(o::O,c,lt) where O = new{O}(o,c,lt, nothing, false)
Prob⋁(n, children) = new{node_type(n)}(n, children, some_vector(Float64, length(children)), nothing, false)
end

const ProbΔ{O} = AbstractVector{<:ProbΔNode{<:O}}
Expand All @@ -39,39 +38,47 @@ Base.eltype(::Type{ProbΔ{O}}) where {O} = ProbΔNode{<:O}
#####################

import LogicCircuits.GateType # make available for extension
import LogicCircuits.node_type

@inline GateType(::Type{<:ProbLiteral}) = LiteralGate()
@inline GateType(::Type{<:Prob⋀}) = ⋀Gate()
@inline GateType(::Type{<:Prob⋁}) = ⋁Gate()

@inline node_type(::ProbΔNode) = ProbΔNode

#####################
# constructors and conversions
#####################

# for some unknown reason, making the type parameter O be part of this outer constructer as `Prob⋁{O}` does not work. It gives `UndefVarError: O not defined`. Hence pass it as an argument...
function Prob⋁(::Type{O}, origin::O, children::Vector{<:ProbΔNode{<:O}}) where {O}
Prob⋁{O}(origin, children, some_vector(Float64, length(children)))
end
const ProbCache = Dict{ΔNode, ProbΔNode}

function ProbΔ2(circuit::Δ)::ProbΔ
node2dag(ProbΔ2(circuit[end]))
end

const ProbCache = Dict{ΔNode, ProbΔNode}
function ProbΔ2(circuit::ΔNode)::ProbΔNode
f_con(n) = error("Cannot construct a probabilistic circuit from constant leafs: first smooth and remove unsatisfiable branches.")
f_lit(n) = ProbLiteral(n)
f_a(n, cn) = Prob⋀(n, cn)
f_o(n, cn) = Prob⋁(n, cn)
foldup_aggregate(circuit, f_con, f_lit, f_a, f_o, ProbΔNode{node_type(circuit)})
end

function ProbΔ(circuit::Δ, cache::ProbCache = ProbCache())

O = grapheltype(circuit) # type of node in the origin
sizehint!(cache, length(circuit)*4÷3)

pc_node(::LiteralGate, n::ΔNode) = ProbLiteral{O}(n)
pc_node(::LiteralGate, n::ΔNode) = ProbLiteral(n)
pc_node(::ConstantGate, n::ΔNode) = error("Cannot construct a probabilistic circuit from constant leafs: first smooth and remove unsatisfiable branches.")

pc_node(::⋀Gate, n::ΔNode) = begin
children = map(c -> cache[c], n.children)
Prob⋀{O}(n, children)
Prob⋀(n, children)
end

pc_node(::⋁Gate, n::ΔNode) = begin
children = map(c -> cache[c], n.children)
Prob⋁(O, n, children)
Prob⋁(n, children)
end

map(circuit) do node
Expand Down Expand Up @@ -99,6 +106,59 @@ prob_origin(n::DecoratorΔNode)::ProbΔNode = origin(n, ProbΔNode)
"Return the first origin that is a probabilistic circuit"
prob_origin(c::DecoratorΔ)::ProbΔ = origin(c, ProbΔNode)

function estimate_parameters2(pc::ProbΔ, data::XData{Bool}; pseudocount::Float64)
Logical.pass_up_down2(pc, data)
w = (data isa PlainXData) ? nothing : weights(data)
estimate_parameters_cached2(pc, w; pseudocount=pseudocount)
end

function estimate_parameters_cached2(pc::ProbΔ, w; pseudocount::Float64)
flow(n) = Float64(sum(sum(n.data)))
children_flows(n) = sum.(map(c -> c.data[1] .& n.data[1], children(n)))

if issomething(w)
flow_w(n) = sum(Float64.(n.data[1]) .* w)
children_flows_w(n) = sum.(map(c -> Float64.(c.data[1] .& n.data[1]) .* w, children(n)))
flow = flow_w
children_flows = children_flows_w
end

estimate_parameters_node2(n::ProbΔNode) = ()
function estimate_parameters_node2(n::Prob⋁)
if num_children(n) == 1
n.log_thetas .= 0.0
else
smoothed_flow = flow(n) + pseudocount
uniform_pseudocount = pseudocount / num_children(n)
n.log_thetas .= log.((children_flows(n) .+ uniform_pseudocount) ./ smoothed_flow)
@assert isapprox(sum(exp.(n.log_thetas)), 1.0, atol=1e-6) "Parameters do not sum to one locally"
# normalize away any leftover error
n.log_thetas .- logsumexp(n.log_thetas)
end
end

foreach(estimate_parameters_node2, pc)
end

function log_likelihood_per_instance2(pc::ProbΔ, data::XData{Bool})
Logical.pass_up_down2(pc, data)
log_likelihood_per_instance_cached(pc, data)
end

function log_likelihood_per_instance_cached(pc::ProbΔ, data::XData{Bool})
log_likelihoods = zeros(num_examples(data))
indices = some_vector(Bool, num_examples(data))::BitVector
for n in pc
if n isa Prob⋁ && num_children(n) != 1 # other nodes have no effect on likelihood
foreach(n.children, n.log_thetas) do c, log_theta
indices = n.data[1] .& c.data[1]
view(log_likelihoods, indices::BitVector) .+= log_theta # see MixedProductKernelBenchmark.jl
end
end
end
log_likelihoods
end

function estimate_parameters(pc::ProbΔ, data::XBatches{Bool}; pseudocount::Float64)
estimate_parameters(AggregateFlowΔ(pc, aggr_weight_type(data)), data; pseudocount=pseudocount)
end
Expand Down Expand Up @@ -421,4 +481,4 @@ end
# for child in node.children
# mpe_simulate(child, inst)
# end
# end
# end
2 changes: 1 addition & 1 deletion src/Probabilistic/Probabilistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ initial_mixture_model, estimate_parameters_from_aggregates, compute_ensemble_log
expectation_step, maximization_step, expectation_step_batch, train_mixture_with_structure, check_parameter_integrity,
ll_per_instance_per_component, ll_per_instance_for_ensemble,estimate_parameters_cached,
sample,
MPE, MAP,
MPE, MAP,prob_origin,

# ProbFlowCircuits
marginal_pass_up, marginal_pass_down, marginal_pass_up_down,
Expand Down
16 changes: 8 additions & 8 deletions src/StructureLearner/CircuitBuilder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ
prob_cache = ProbCache()
parent = parent_vector(clt)

prob_children(n)::Vector{<:ProbΔNode{<:LogicalΔNode}} =
copy_with_eltype(map(c -> prob_cache[c], n.children), ProbΔNode{<:LogicalΔNode})
prob_children(n)::Vector{<:ProbΔNode{<:node_type(n)}} =
copy_with_eltype(map(c -> prob_cache[c], n.children), ProbΔNode{<:node_type(n)})

"default order of circuit node, from left to right: +/1 -/0"

Expand All @@ -39,8 +39,8 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ
neg = LiteralNode(-var2lit(ln))
node_cache[var2lit(ln)] = pos
node_cache[-var2lit(ln)] = neg
pos2 = ProbLiteral{LiteralNode}(pos)
neg2 = ProbLiteral{LiteralNode}(neg)
pos2 = ProbLiteral(pos)
neg2 = ProbLiteral(neg)
push!(lin, pos2)
push!(lin, neg2)
prob_cache[pos] = pos2
Expand All @@ -56,7 +56,7 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ
#build logical ciruits
temp = ⋁Node([node_cache[lit] for lit in [var2lit(c), - var2lit(c)]])
push!(logical_nodes, temp)
n = Prob⋁(LogicalΔNode,temp, prob_children(temp))
n = Prob⋁(temp, prob_children(temp))
prob_cache[temp] = n
n.log_thetas = zeros(Float64, 2)
cpt = get_prop(clt, c, :cpt)
Expand All @@ -73,7 +73,7 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ
leaf = node_cache[indicator]
temp = ⋀Node(vcat([leaf], children))
node_cache[indicator] = temp
n = Prob⋀{LogicalΔNode}(temp, prob_children(temp))
n = Prob⋀(temp, prob_children(temp))
prob_cache[temp] = n
push!(lin, n)
end
Expand All @@ -90,7 +90,7 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ
"compile root, add another disjunction node"
function compile_root(root::Var)
temp = ⋁Node([node_cache[s] for s in [var2lit(root), -var2lit(root)]])
n = Prob⋁(LogicalΔNode, temp, prob_children(temp))
n = Prob⋁(temp, prob_children(temp))
prob_cache[temp] = n
n.log_thetas = zeros(Float64, 2)
cpt = get_prop(clt, root, :cpt)
Expand All @@ -102,7 +102,7 @@ function compile_prob_circuit_from_clt(clt::CLT)::ProbΔ

function compile_independent_roots(roots::Vector{ProbΔNode})
temp = ⋀Node([c.origin for c in roots])
n = Prob⋀{LogicalΔNode}(temp, prob_children(temp))
n = Prob⋀(temp, prob_children(temp))
prob_cache[temp] = n
push!(lin, n)
temp = ⋁Node([temp])
Expand Down
12 changes: 6 additions & 6 deletions src/StructureLearner/PSDDInitializer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,16 @@ end
#####################

prob_children(n, prob_cache) =
copy_with_eltype(map(c -> prob_cache[c], n.children), ProbΔNode{<:StructLogicalΔNode{PlainVtreeNode}})
copy_with_eltype(map(c -> prob_cache[c], n.children), ProbΔNode{<:StructLogicalΔNode})

"Add leaf nodes to circuit `lin`"
function add_prob_leaf_node(var::Var, vtree::PlainVtreeLeafNode, lit_cache::LitCache, prob_cache::ProbCache, lin)::Tuple{ProbLiteral{<:StructLiteralNode{PlainVtreeNode}}, ProbLiteral{<:StructLiteralNode{PlainVtreeNode}}}
function add_prob_leaf_node(var::Var, vtree::PlainVtreeLeafNode, lit_cache::LitCache, prob_cache::ProbCache, lin)
pos = StructLiteralNode{PlainVtreeNode}( var2lit(var), vtree)
neg = StructLiteralNode{PlainVtreeNode}(-var2lit(var), vtree)
lit_cache[var2lit(var)] = pos
lit_cache[-var2lit(var)] = neg
pos2 = ProbLiteral{StructLiteralNode{PlainVtreeNode}}(pos)
neg2 = ProbLiteral{StructLiteralNode{PlainVtreeNode}}(neg)
pos2 = ProbLiteral(pos)
neg2 = ProbLiteral(neg)
prob_cache[pos] = pos2
prob_cache[neg] = neg2
push!(lin, pos2)
Expand All @@ -169,7 +169,7 @@ end
"Add prob⋀ node to circuit `lin`"
function add_prob⋀_node(children::ProbΔ, vtree::PlainVtreeInnerNode, prob_cache::ProbCache, lin)::Prob⋀
logic = Struct⋀Node{PlainVtreeNode}([c.origin for c in children], vtree)
prob = Prob⋀{StructLogicalΔNode{PlainVtreeNode}}(logic, prob_children(logic, prob_cache))
prob = Prob⋀(logic, prob_children(logic, prob_cache))
prob_cache[logic] = prob
push!(lin, prob)
return prob
Expand All @@ -178,7 +178,7 @@ end
"Add prob⋁ node to circuit `lin`"
function add_prob⋁_node(children::ProbΔ, vtree::PlainVtreeNode, thetas::Vector{Float64}, prob_cache::ProbCache, lin)::Prob⋁
logic = Struct⋁Node{PlainVtreeNode}([c.origin for c in children], vtree)
prob = Prob⋁(StructLogicalΔNode{PlainVtreeNode}, logic, prob_children(logic, prob_cache))
prob = Prob⋁(logic, prob_children(logic, prob_cache))
prob.log_thetas = log.(thetas)
prob_cache[logic] = prob
push!(lin, prob)
Expand Down
36 changes: 36 additions & 0 deletions test/Probabilistic/Benchmark.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using BenchmarkTools
using Test
using LogicCircuits
using ProbabilisticCircuits


function construct_prob_circuit()
circuit = load_smooth_logical_circuit(zoo_psdd_file("plants.psdd"));
@btime pc1 = ProbΔ(circuit);
@btime pc2 = Probabilistic.ProbΔ2(circuit[end]);
nothing
# 101.578 ms (736243 allocations: 46.16 MiB)
# 47.765 ms (369254 allocations: 26.97 MiB)
end

function estimate_parameters_bm()
data = train(dataset(twenty_datasets("plants")));
circuit = load_smooth_logical_circuit(zoo_psdd_file("plants.psdd"));

# construct circuits
@btime pc = Probabilistic.ProbΔ2(circuit);
@btime pc2 = ProbΔ(circuit);

# estimate_parameters
@btime Probabilistic.estimate_parameters2(pc, data; pseudocount=1.0);
@btime estimate_parameters(pc2, convert(XBatches, data); pseudocount=1.0);

# compute log likelihood
@btime lls = Probabilistic.log_likelihood_per_instance_cached(pc, data)
@btime lls2 = log_likelihood_per_instance(pc2, convert(XBatches, data))

# for (l, l2) in zip(lls, lls2)
# @test isapprox(l, l2, atol=1e-9)
# end
end

1 change: 1 addition & 0 deletions test/StructureLearner/PSDDInitializerTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using ProbabilisticCircuits
@test num_variables(vtree) == num_features(data)
@test check_parameter_integrity(pc)
@test num_parameters(pc) == 74

# test below has started to fail -- unclear whether that is a bug or randomness...?
# @test pc[28].log_thetas[1] ≈ -1.1870882896239272 atol=1.0e-7

Expand Down