Skip to content

Commit

Permalink
Merge pull request #736 from AlgebraicJulia/map-fix
Browse files Browse the repository at this point in the history
Simplified implementation of `map` for acsets
  • Loading branch information
epatters committed Jan 31, 2023
2 parents 55eeaff + eba89ac commit c99e580
Showing 1 changed file with 51 additions and 62 deletions.
113 changes: 51 additions & 62 deletions src/acsets/DenseACSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,23 @@ function replace_colons(acs::ACSet, parts::NamedTuple{types}) where {types}
end)
end

# Type modification
###################

function empty_with_types(acs::SA, type_assignment) where {S, SA <: StructACSet{S}}
s = acset_schema(acs)
(SA.name.wrapper){[type_assignment[d] for d in attrtypes(s)]...}()
end

function empty_with_types(acs::DynamicACSet, type_assignment)
DynamicACSet(acs.name, acs.schema, type_assignment)
end

function get_type_assignment(acs::StructACSet{S,Ts}) where {S,Ts}
Dict(d => Ts.parameters[i] for (i,d) in enumerate(attrtypes(S)))
end

get_type_assignment(acs::DynamicACSet) = acs.type_assignment

# Printing
##########
Expand Down Expand Up @@ -709,38 +726,11 @@ end
#########

function Base.map(acs::ACSet; kwargs...)
_map(acs, (;kwargs...))
end

function sortunique!(x)
sort!(x)
unique!(x)
x
end

function groupby(f::Function, xs)
d = Dict{typeof(f(xs[1])),Vector{eltype(xs)}}()
for x in xs
k = f(x)
if k in keys(d)
push!(d[k],x)
else
d[k] = [x]
end
end
d
end

# Eventually should translate this to comptime so it will work with dynamic acsets
@generated function _map(acs::AT, fns::NamedTuple{map_over}) where
{S, Ts, AT<:StructACSet{S,Ts}, map_over}
s = Schema(S)
map_over = Symbol[map_over...]

q(s) = Expr(:quote, s)
s = acset_schema(acs)
fns = (;kwargs...)

mapped_attrs = intersect(attrs(s; just_names=true), map_over)
mapped_attrtypes = intersect(attrtypes(s), map_over)
mapped_attrs = intersect(attrs(s; just_names=true), keys(fns))
mapped_attrtypes = intersect(attrtypes(s), keys(fns))
mapped_attrs_from_attrtypes = [a for (a,d,c) in attrs(s) if c mapped_attrtypes]
attrs_accounted_for = sortunique!(Symbol[mapped_attrs; mapped_attrs_from_attrtypes])

Expand All @@ -751,46 +741,45 @@ end
unnaccounted_for_attrs == [] ||
error("not enough functions provided to fully transform ACSet, need functions for: $(unnaccounted_for_attrs)")

fn_applications = map(attrs_accounted_for) do a
qa = q(a)
if a mapped_attrs
:($a = (fns[$qa]).(subpart(acs, $qa)))
new_subparts = Dict(
f => (f keys(fns) ? fns[f] : fns[codom(s, f)]).(subpart(acs, f))
for f in needed_attrs)

type_assignments = get_type_assignment(acs)

new_type_assignments = Dict(map(enumerate(attrtypes(s))) do (i,d)
if d affected_attrtypes
d => mapreduce(eltype, typejoin, [new_subparts[f] for f in attrs(s, to=d, just_names=true)])
else
d = codom(s,a)
:($a = (fns[$(q(d))]).(subpart(acs, $qa)))
d => type_assignments[d]
end
end...)

new_acs = empty_with_types(acs, new_type_assignments)

for ob in objects(s)
add_parts!(new_acs, ob, nparts(acs, ob))
end

abc = groupby(a -> codom(s,a), attrs(s; just_names=true))
for f in homs(s; just_names=true)
set_subpart!(new_acs, :, f, subpart(acs, f))
end

attrtype_instantiations = map(enumerate(attrtypes(s))) do (i,d)
if d affected_attrtypes
:(mapreduce(eltype, typejoin,
$(Expr(:tuple, (:(fn_vals[$(q(a))]) for a in abc[d])...))))
for f in attrs(s; just_names=true)
if f keys(new_subparts)
set_subpart!(new_acs, :, f, new_subparts[f])
else
:($(Ts[i]))
set_subpart!(new_acs, :, f, subpart(acs, f))
end
end

quote
fn_vals = $(Expr(:tuple, fn_applications...))
new_acs = $(AT.name.wrapper){$(attrtype_instantiations...)}()
$(Expr(:block, map(objects(s)) do ob
:(add_parts!(new_acs,$(q(ob)),nparts(acs,$(q(ob)))))
end...))
$(Expr(:block, map(homs(s; just_names=true)) do f
:(set_subpart!(new_acs,$(q(f)),subpart(acs,$(q(f)))))
end...))
$(Expr(:block, map(attrs(s; just_names=true)) do a
qa = Expr(:quote, a)
if a attrs_accounted_for
:(set_subpart!(new_acs,$qa,fn_vals[$qa]))
else
:(set_subpart!(new_acs,$qa,subpart(acs,$qa)))
end
end...))
return new_acs
end
new_acs
end

function sortunique!(x)
sort!(x)
unique!(x)
x
end

end

0 comments on commit c99e580

Please sign in to comment.