Skip to content

Commit

Permalink
reverting adjacency matrix function and changing test
Browse files Browse the repository at this point in the history
  • Loading branch information
PavanChaggar committed Jan 31, 2022
1 parent 5205ad5 commit d51a91d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 49 deletions.
57 changes: 12 additions & 45 deletions src/simpleppl.jl
Expand Up @@ -82,7 +82,9 @@ end

function Model(;kwargs...)
Model(values(kwargs))
end
end
# add thing here to extract inputs from anon functions then change into different NamedTuple
# add docstring because it will behave differently

function Base.show(io::IO, m::Model)
print(io, "Nodes: \n")
Expand Down Expand Up @@ -118,29 +120,15 @@ function adjacency_matrix(inputs::NamedTuple{nodes}) where {nodes}
col_inds = NamedTuple{nodes}(ntuple(identity, N))
A = spzeros(Bool, N, N)
for (row, node) in enumerate(nodes)
v_inputs = inputs[node]
setinput!(A, row, col_inds, nodes, v_inputs)
end
return A
end

function setinput!(A::SparseMatrixCSC{Bool, Int64}, row, col_inds, nodes, v_input::Symbol)
if v_input nodes
error("Parent node of $(v_input) not found in node set: $(nodes)")
end
col = col_inds[v_input]
A[row, col] = true
end

function setinput!(A::SparseMatrixCSC{Bool, Int64}, row, col_inds, nodes, v_inputs)
for input in v_inputs
if input nodes
error("Parent node of $(input) not found in node set: $(nodes)")
for input in inputs[node]
if input nodes
error("Parent node of $(input) not found in node set: $(nodes)")
end
col = col_inds[input]
A[row, col] = true
end
col = col_inds[input]
A[row, col] = true
end
A
return A
end

adjacency_matrix(m::Model) = adjacency_matrix(m.ModelState.input)
Expand Down Expand Up @@ -226,6 +214,7 @@ Base.length(m::Model) = length(nodes(m))
Base.keytype(m::Model) = eltype(keys(m))
Base.valtype(m::Model) = eltype(m)


"""
dag(m::Model)
Expand All @@ -239,26 +228,4 @@ dag(m::Model) = m.DAG.A
Returns a `Vector{Symbol}` containing the sorted vertices
of the DAG.
"""
nodes(m::Model) = m.DAG.sorted_vertices

# # General eval function
# function evalf(f::Function, m::Model)
# nodes = m.DAG.sorted_vertex_list
# symlist = keys(m.ModelState.input)
# vals = (;)
# for (i, n) in enumerate(nodes)
# node = symlist[n]
# input_nodes = m.ModelState.input[node]
# if m.ModelState.kind[node] == :Stochastic
# if length(input_nodes) == 0
# vals = merge(vals, [node=>f(m.ModelState.eval[node]())])
# elseif length(input_nodes) > 0
# inputs = [vals[n] for n in input_nodes]
# vals = merge(vals, [node=>f(m.ModelState.eval[node](inputs...))])
# end
# else
# vals = merge(vals, [node=>m.ModelState.eval[node]()])
# end
# end
# vals
# end
nodes(m::Model) = m.DAG.sorted_vertices
9 changes: 5 additions & 4 deletions test/simpleppl.jl
Expand Up @@ -17,18 +17,19 @@ model = (
y = (zeros(5), (, :s2), (μ, s2) -> MvNormal(μ, sqrt(s2)), :Stochastic)
)


# test Model constructor for model with single parent node
@test typeof(
Model(
μ = (zeros(5), (), () -> 3, :Logical),
y = (zeros(5), (), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
μ = (1.0, (), () -> 3, :Logical),
y = (1.0, (,), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
)
) == Model

# test ErrorException for parent node not being found
@test_throws ErrorException Model(
μ = (zeros(5), (), () -> 3, :Logical),
y = (zeros(5), (), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
μ = (zeros(5), (,), () -> 3, :Logical),
y = (zeros(5), (,), (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)
)

m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor
Expand Down

0 comments on commit d51a91d

Please sign in to comment.