Skip to content

Commit

Permalink
Merge pull request #463 from olynch/acset-macro-fix
Browse files Browse the repository at this point in the history
@acset macro fix for runtime values
  • Loading branch information
epatters committed Jul 14, 2021
2 parents 6658383 + 455fff7 commit aa8032c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
42 changes: 22 additions & 20 deletions src/categorical_algebra/CSetDataStructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using ...Meta, ...Present
using ...Syntax: GATExpr, args
using ...Theories: Schema, FreeSchema, SchemaType,
CatDesc, CatDescType, ob, hom, dom, codom, codom_num,
AttrDesc, AttrDescType, data, attr, adom, acodom, data_num, attrs_by_codom
AttrDesc, AttrDescType, data, attr, attr_num, adom, acodom, data_num, attrs_by_codom

# Data types
############
Expand Down Expand Up @@ -850,8 +850,13 @@ end
```
"""
macro acset(head, body)
expr = :(init_acset($(esc(head)), $(Expr(:quote, body))))
Expr(:call, esc(:eval), expr)
@assert body.head == :block
vals = Expr(:call, :(Dict{Symbol,Any}))
for l in strip_lines(body).args
@assert l.head == :(=)
push!(vals.args, :($(Expr(:quote, l.args[1])) => $(l.args[2])))
end
:(init_acset($(esc(head)), $(esc(vals))))
end

"""
Expand All @@ -860,25 +865,22 @@ TODO: Could also rely on a @generated function that took in a "flat" named tuple
TODO: Alternative syntax for @acset input based on CSV
TODO: Actual CSV input
"""
function init_acset(T::Type{<:ACSet{CD,AD,Ts}},body) where {CD <: CatDesc, AD <: AttrDesc{CD}, Ts <: Tuple}
body = strip_lines(body)
@assert body.head == :block
code = quote
acs = $(T)()
function init_acset(T::Type{<:ACSet{CD,AD,Ts}}, initvals::Dict{Symbol,Any}) where
{CD <: CatDesc, AD <: AttrDesc{CD}, Ts <: Tuple}
acs = T()
ob_specs = filter((kv) -> kv[1] ob(CD), pairs(initvals))
hom_specs = filter((kv) -> kv[1] hom(CD), pairs(initvals))
attr_specs = filter((kv) -> kv[1] attr(AD), pairs(initvals))
for (k,v) in ob_specs
add_parts!(acs, k, Int(v))
end
for elem in body.args
lhs, rhs = @match elem begin
Expr(:(=), lhs, rhs) => (lhs,rhs)
_ => error("Every line of `@acset` must be an assignment")
end
if lhs in ob(CD)
push!(code.args, :(add_parts!(acs, $(Expr(:quote, lhs)), $(rhs))))
elseif lhs in hom(CD) || lhs in attr(AD)
push!(code.args, :(set_subpart!(acs, :, $(Expr(:quote, lhs)), $(rhs))))
end
for (k,v) in hom_specs
set_subpart!(acs, :, k, Vector{Int}(v))
end
for (k,v) in attr_specs
set_subpart!(acs, :, k, Vector{Ts.parameters[data_num(AD,codom(AD,k))]}(v))
end
push!(code.args, :(return acs))
code
acs
end

""" Map over a data type, in the style of Haskell functors
Expand Down
17 changes: 17 additions & 0 deletions test/categorical_algebra/CSetDataStructures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,23 @@ end
@test subpart(g,:,:src) == [1,2,3,4]
@test incident(g,1,:src) == [1]

function path_graph(n::Int)
@acset DecGraph{Float64} begin
V = n
E = (n-1)
src = (1:n-1)
tgt = (2:n)
dec = zeros(n-1)
end
end

pg = path_graph(30)

@test nparts(pg, :V) == 30
@test nparts(pg, :E) == 29
@test incident(pg, 1, :src) == [1]


# Test mapping
#-------------

Expand Down

0 comments on commit aa8032c

Please sign in to comment.