diff --git a/src/graphinfo.jl b/src/graphinfo.jl index 571826f..4c65fc6 100644 --- a/src/graphinfo.jl +++ b/src/graphinfo.jl @@ -20,11 +20,11 @@ adjacency matrix and topologically ordered vertex list and stored. GraphInfo is instantiated using the `Model` constctor. """ -struct GraphInfo{T} <: AbstractModelTrace - input::NamedTuple{T} - value::NamedTuple{T} - eval::NamedTuple{T} - kind::NamedTuple{T} +struct GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractModelTrace + input::NamedTuple{Tnames, Tinput} + value::NamedTuple{Tnames, Tvalue} + eval::NamedTuple{Tnames, Teval} + kind::NamedTuple{Tnames, Tkind} A::SparseMatrixCSC sorted_vertices::Vector{Symbol} end @@ -55,8 +55,8 @@ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic) ``` """ -struct Model{T} <: AbstractProbabilisticProgram - g::GraphInfo{T} +struct Model{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractProbabilisticProgram + g::GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind} end function Model(;kwargs...) diff --git a/test/graphinfo.jl b/test/graphinfo.jl index 02c6e2e..d40527f 100644 --- a/test/graphinfo.jl +++ b/test/graphinfo.jl @@ -24,11 +24,11 @@ model = ( m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor # test the type of the model is correct -@test typeof(m) <: Model +@test m isa Model sorted_vertices = get_sorted_vertices(m) -@test typeof(m) == Model{Tuple(sorted_vertices)} -@test typeof(m.g) <: GraphInfo <: AbstractModelTrace -@test typeof(m.g) == GraphInfo{Tuple(sorted_vertices)} +@test m isa Model{Tuple(sorted_vertices)} +@test m.g isa GraphInfo <: AbstractModelTrace +@test m.g isa GraphInfo{Tuple(sorted_vertices)} # test the dag is correct A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0]) @@ -37,11 +37,18 @@ A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0]) @test length(m) == 5 @test eltype(m) == valtype(m) + # check the values from the NamedTuple match the values in the fields of GraphInfo vals, evals, kinds = AbstractPPL.GraphPPL.getvals(NamedTuple{Tuple(sorted_vertices)}(model)) inputs = (s2 = (), xmat = (), β = (), μ = (:xmat, :β), y = (:μ, :s2)) for (i, vn) in enumerate(keys(m)) + @inferred m[vn] + @inferred get_node_value(m, vn) + @inferred get_node_eval(m, vn) + @inferred get_nodekind(m, vn) + @inferred get_node_input(m, vn) + @test vn isa VarName @test get_node_value(m, vn) == vals[i] @test get_node_eval(m, vn) == evals[i] @@ -50,16 +57,16 @@ for (i, vn) in enumerate(keys(m)) end for node in m - @test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]} + @test node isa NamedTuple{fieldnames(GraphInfo)[1:4]} end # test Model constructor for model with single parent node single_parent_m = Model(μ = (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic)) -@test typeof(single_parent_m) == Model{(:μ, :y)} -@test typeof(single_parent_m.g) == GraphInfo{(:μ, :y)} +@test single_parent_m isa Model{(:μ, :y)} +@test single_parent_m.g isa GraphInfo{(:μ, :y)} -# test setindex +# test setindex @test_throws AssertionError set_node_value!(m, @varname(s2), [0.0]) @test_throws AssertionError set_node_value!(m, @varname(s2), (1.0,)) set_node_value!(m, @varname(s2), 1.0)