Skip to content

Commit

Permalink
add @inline annotations to [getindex|setindex!](::ACSet, ...)
Browse files Browse the repository at this point in the history
These functions can be significantly optimized with constant information
available within their arguments, so it would be beneficial to enforce
constant propagation for them.
This can be achieved using `Base.@constprop :aggressive` or `@inline`,
and since these functions are likely to be used frequently, I think
promoting inline expansion with `@inline` is a good approach,
so I chose to use `@inline`.

- fixes AlgebraicJulia#121
  • Loading branch information
aviatesk committed May 19, 2024
1 parent 9ac3b78 commit 5e74b8b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
19 changes: 9 additions & 10 deletions src/ACSetInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export ACSet, acset_schema, acset_name, dom_parts, codom_parts, subpart_type,
add_part!, add_parts!, set_subpart!, set_subparts!, clear_subpart!,
rem_part!, rem_parts!, cascading_rem_part!, cascading_rem_parts!, gc!,
copy_parts!, copy_parts_only!, disjoint_union, tables, pretty_tables,
@acset, constructor, undefined_subparts, PartsType, DenseParts, MarkAsDeleted,
@acset, constructor, undefined_subparts, PartsType, DenseParts, MarkAsDeleted,
rem_free_vars!, parts_type

using MLStyle: @match
Expand Down Expand Up @@ -170,9 +170,8 @@ function subpart(acs, part, names::AbstractVector{Symbol})
end
end

Base.getindex(acs::ACSet, part, name) = subpart(acs, part, name)
Base.getindex(acs::ACSet, name) = subpart(acs, name)

@inline Base.getindex(acs::ACSet, part, name) = subpart(acs, part, name)
@inline Base.getindex(acs::ACSet, name) = subpart(acs, name)

""" Get superparts incident to part in acset.
Expand Down Expand Up @@ -280,8 +279,8 @@ end

@inline set_subparts!(acs, part; kw...) = set_subparts!(acs, part, (;kw...))

Base.setindex!(acs::ACSet, val, part, name) = set_subpart!(acs, part, name, val)
Base.setindex!(acs::ACSet, vals, name) = set_subpart!(acs, name, vals)
@inline Base.setindex!(acs::ACSet, val, part, name) = set_subpart!(acs, part, name, val)
@inline Base.setindex!(acs::ACSet, vals, name) = set_subpart!(acs, name, vals)

"""Clear a subpart in a C-set
Expand Down Expand Up @@ -401,7 +400,7 @@ Garbage collect in an acset.
For some choices of [`IDAllocator`](@ref), this function is a no-op.
"""
function gc! end
function gc! end

"""
Get a nullary callable which constructs an (empty) ACSet of the same type
Expand Down Expand Up @@ -514,13 +513,13 @@ function make_acset end
Remove all AttrType parts that are not in the image of any of the attributes.
"""
function rem_free_vars!(acs::ACSet)
for k in attrtypes(acset_schema(acs))
for k in attrtypes(acset_schema(acs))
rem_free_vars!(acs, k)
end
end

rem_free_vars!(X::ACSet, a::Symbol) = rem_parts!(X, a, filter(parts(X,a)) do p
all(f->isempty(incident(X, AttrVar(p), f)),
rem_free_vars!(X::ACSet, a::Symbol) = rem_parts!(X, a, filter(parts(X,a)) do p
all(f->isempty(incident(X, AttrVar(p), f)),
attrs(acset_schema(X); to=a, just_names=true))
end)

Expand Down
44 changes: 30 additions & 14 deletions test/ACSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function cascading_delete_is_natural(i::Int)
X = DDS(i)
X′ = copy(X)
d = cascading_rem_part!(X, :X, 1)[:X]
all(enumerate(X[])) do (i, ϕᵢ)
all(enumerate(X[])) do (i, ϕᵢ)
X′[d[i], ] == d[ϕᵢ]
end
end
Expand Down Expand Up @@ -95,7 +95,7 @@ for dds_maker in dds_makers
@test dds[1:2,] isa Vector{Int}
@test dds[] isa Vector{Int}
@test dds[[,]] isa Vector{Int}
v = dds[]
v = dds[]
v[1] = 1
@test v != dds[]
@test dds[1:2, [,]] isa Vector{Int}
Expand All @@ -115,18 +115,18 @@ for dds_maker in dds_makers
else
@test subpart(dds, ) == [2,3]
@test incident(dds, 2, ) == [1]
end
end
rem_part!(dds, :X, 2)
if dds.parts[:X] isa IntParts
@test nparts(dds, :X) == 1
@test subpart(dds, ) == [0]
else
else
@test nparts(dds, :X) == 2
end
end
rem_part!(dds, :X, 1)
if dds.parts[:X] isa IntParts
@test nparts(dds, :X) == 0
else
else
@test nparts(dds, :X) == 1
end

Expand Down Expand Up @@ -264,7 +264,7 @@ for (dgram_maker, ldgram_maker) in dgram_makers
rem_free_vars!(d) # now the added X is AttrVar(1), if IntParts
@test nparts(d, :R) == 1
A = AttrVar(only(parts(d, :R)))

@test nparts(d, :X) == 5
@test subpart(d, 1:3, :parent) == [4,4,4]
@test subpart(d, 4, :parent) == 5
Expand Down Expand Up @@ -414,7 +414,7 @@ for lset_maker in lset_makers
@test subpart_type(lset, :label) == Symbol
@test subpart_type(lset, :Label) == Symbol
@test_throws Exception subpart_type(lset, :abel)

# Labeled set with compound label (tuple).
lset = lset_maker(Tuple{Int,Int})
add_parts!(lset, :X, 2, label=[(1,1), (1,2)])
Expand Down Expand Up @@ -613,7 +613,7 @@ rem_part!(g, :X, 1)
@test g[:tgt] == [2,3]
@test g[:dec] == ["b",AttrVar(2)]

# Densify and sparsify
# Densify and sparsify
g = @acset MadDecGraph{String} begin
V = 4
E = 4
Expand Down Expand Up @@ -760,7 +760,7 @@ end

@test_throws Exception subpart(datcomp, :, (:f,:h))
@test_throws Exception subpart(datcomp, (:f,:h))
@test_throws Exception subpart(datcomp, 1, (:f,:h))
@test_throws Exception subpart(datcomp, 1, (:f,:h))

@test datcomp[:, (:f,:g)] == [3,2,1,3,2]
@test datcomp[1:5, (:f,:g)] == [3,2,1,3,2]
Expand All @@ -784,7 +784,7 @@ datcompdyn = DynamicACSet(datcomp)

@test_throws Exception subpart(datcompdyn, :, (:f,:h))
@test_throws Exception subpart(datcompdyn, (:f,:h))
@test_throws Exception subpart(datcompdyn, 1, (:f,:h))
@test_throws Exception subpart(datcompdyn, 1, (:f,:h))

@test datcompdyn[:, (:f,:g)] == [3,2,1,3,2]
@test datcompdyn[1:5, (:f,:g)] == [3,2,1,3,2]
Expand All @@ -805,7 +805,7 @@ datcompdyn = DynamicACSet(datcomp)
@test incident(datcomp, 3, (:g,)) == [1,4]
@test incident(datcomp, 1:3, (:g,)) == [3,2,1,4]
@test incident(datcomp, 1:3, (:f,:g)) == [3,2,5,1,4]
@test incident(datcomp, 1:3, (:f,:g)) == incident(datcomp, 1:3, [:f,:g])
@test incident(datcomp, 1:3, (:f,:g)) == incident(datcomp, 1:3, [:f,:g])

@test incident(datcomp, [:a,:b,:c], (:f,:g,:zattr)) == incident(datcomp, [:a,:b,:c], [:f,:g,:zattr])

Expand All @@ -824,7 +824,7 @@ datcompdyn = DynamicACSet(datcomp)

# Test @acset_type with type parameters
#--------------------------------------
# Single
# Single
@acset_type IntLabeledSet(SchLabeledSet, index=[:label]){Int}
@test isempty(IntLabeledSet())

Expand Down Expand Up @@ -863,11 +863,27 @@ SchLDDS = BasicSchema([:X], [(:Φ,:X,:X)],[:Y],[(:l,:X,:Y)])
@acset_type AbsLDDS(SchLDDS)
const LDDS = AbsLDDS{StaticVector} # DDS w/ labeled states

X = @acset LDDS begin
X = @acset LDDS begin
X = 2; Φ = [2,2]; l = [SA[:a,:b], SA[:a]]
end

@test X[(,:l)] isa Vector{<:StaticVector}
@test X[1,(,:l)] isa StaticVector

SchInference = BasicSchema([:V,:E], [(:v0,:E,:V)])
@acset_type InferenceTest(SchInference, index=[:v0])
let s = InferenceTest()
@test Base.return_types((typeof(s),)) do s
s[1, :v0]
end |> only === Int
add_part!(s, :V)
add_part!(s, :E)
call_setindex!(s) = s[1, :v0] = 1
@test call_setindex!(s) == 1
@test iszero(@allocated call_setindex!(s))
call_getindex(s) = s[1, :v0]
@test call_getindex(s) == 1
@test iszero(@allocated call_getindex(s))
end

end # module

0 comments on commit 5e74b8b

Please sign in to comment.