Skip to content

Commit

Permalink
Merge pull request #886 from AlgebraicJulia/fix_overlap_iterator
Browse files Browse the repository at this point in the history
Overlap iterator output span order bugfix + ignoring of specific attrtypes
  • Loading branch information
epatters committed Feb 3, 2024
2 parents 303d0b2 + 1f64c8d commit 819b5f5
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 39 deletions.
7 changes: 4 additions & 3 deletions src/categorical_algebra/CSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1322,12 +1322,13 @@ preimage(f::ACSetTransformation,Y::StructACSet) =

"""
For any ACSet, X, a canonical map A→X where A has distinct variables for all
subparts.
attributes valued in attrtypes present in `abstract` (by default: all attrtypes)
"""
function abstract_attributes(X::ACSet)
function abstract_attributes(X::ACSet, abstract=nothing)
S = acset_schema(X)
abstract = isnothing(abstract) ? attrtypes(S) : abstract
A = copy(X)
comps = Dict{Any, Any}(map(attrtypes(S)) do at
comps = Dict{Any, Any}(map(abstract) do at
rem_parts!(A, at, parts(A, at))
comp = Union{AttrVar, attrtype_type(X, at)}[]
for (f, d, _) in attrs(S; to=at)
Expand Down
63 changes: 36 additions & 27 deletions src/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -606,13 +606,18 @@ function Base.iterate(Sub::SubobjectIterator, state=SubobjectIteratorState())
end
end


struct OverlapIterator
top::ACSet
others::Vector{ACSet}
function OverlapIterator(Xs::Vector{T}) where T<:ACSet
t, o... = sort(Xs, by=total_parts)
new(t, o)
acsets::Vector{ACSet}
top::Int
abstract::Vector{Symbol}
function OverlapIterator(Xs::Vector{T}; abstract=true) where T<:ACSet
S = acset_schema(first(Xs))
abstract_attrs = if abstract isa Bool
abstract ? attrtypes(S) : Symbol[]
else
abstract
end
new(Xs, argmin(total_parts.(Xs)), abstract_attrs)
end
end
Base.eltype(::Type{OverlapIterator}) = Multispan
Expand Down Expand Up @@ -641,7 +646,7 @@ independently. This is the maps from A into all the other objects as well as the
automorphisms of A.
"""
function Base.iterate(Sub::OverlapIterator, state=nothing)
state = isnothing(state) ? OverlapIteratorState(Sub.top) : state
state = isnothing(state) ? OverlapIteratorState(Sub.acsets[Sub.top]) : state
# if we are not computing overlaps from a particular subobj,
if isnothing(state.curr_subobj) # pick the next subobj
isnothing(state.maps) || error("Inconsistent overlapiterator state")
Expand All @@ -653,7 +658,7 @@ function Base.iterate(Sub::OverlapIterator, state=nothing)
end
elseif isnothing(state.maps) # compute all the maps out of curr subobj
subobj = state.curr_subobj
abs_subobj = abstract_attributes(dom(subobj)) subobj
abs_subobj = abstract_attributes(dom(subobj), Sub.abstract) subobj
Y = dom(abs_subobj)
# don't repeat work if already computed syms/maps for something iso to Y
for res in state.seen
Expand All @@ -663,26 +668,29 @@ function Base.iterate(Sub::OverlapIterator, state=nothing)
return (Multispan(map(m->σm, res)), state)
end
end
maps = Vector{ACSetTransformation}[[abs_subobj]]
# Compute the automorphisms so that we can remove spurious symmetries
syms = isomorphisms(Y, Y)
# Get monic maps from Y into each of the objects. The first comes for free
maps = Vector{ACSetTransformation}[[abs_subobj]]
for X in Sub.others
fs = homomorphisms(Y, X; monic=ob(acset_schema(Y)))
real_fs = Set() # quotient fs via automorphisms of Y
for f in fs
if all(rf->all-> forcef) != force(rf), syms), real_fs)
push!(real_fs, f)
maps = Vector{ACSetTransformation}[]
for (i, X) in enumerate(Sub.acsets)
if i == Sub.top
push!(maps, [abs_subobj])
else
fs = homomorphisms(Y, X; monic=ob(acset_schema(Y)))
real_fs = Set() # quotient fs via automorphisms of Y
for f in fs
if all(rf->all-> forcef) != force(rf), syms), real_fs)
push!(real_fs, f)
end
end
if isempty(real_fs)
break # this subobject of Xs[1] does not have common overlap w/ all Xs
else
push!(maps, collect(real_fs))
end
end
if isempty(real_fs)
break # this subobject of Xs[1] does not have common overlap w/ all Xs
else
push!(maps,collect(real_fs))
end
end
if length(maps) == length(Sub.others) + 1
if length(maps) == length(Sub.acsets)
state.maps = Iterators.Stateful(Iterators.product(maps...))
else
state.curr_subobj = nothing
Expand All @@ -696,8 +704,8 @@ function Base.iterate(Sub::OverlapIterator, state=nothing)
end
end

partial_overlaps(Xs::Vector{T}) where T<:ACSet = OverlapIterator(Xs)
partial_overlaps(Xs::ACSet...) = Xs |> collect |> partial_overlaps
partial_overlaps(Xs::Vector{T}; abstract=true) where T<:ACSet = OverlapIterator(Xs; abstract)
partial_overlaps(Xs::ACSet...; abstract=true) = partial_overlaps(collect(Xs); abstract)

""" Compute the Maximimum Common C-Sets from a vector of C-Sets.
Expand All @@ -711,8 +719,8 @@ these are all returned.
If there are attributes, we ignore these and use variables in the apex of the
overlap.
"""
function maximum_common_subobject(Xs::Vector{T}) where T <: ACSet
it = partial_overlaps(Xs)
function maximum_common_subobject(Xs::Vector{T}; abstract=true) where T <: ACSet
it = partial_overlaps(Xs; abstract)
osize = -1
res = DefaultDict(()->[])
for overlap in it
Expand All @@ -725,7 +733,8 @@ function maximum_common_subobject(Xs::Vector{T}) where T <: ACSet
return res
end

maximum_common_subobject(Xs::T...) where T <: ACSet = maximum_common_subobject(collect(Xs))
maximum_common_subobject(Xs::T...; abstract=true) where T <: ACSet =
maximum_common_subobject(collect(Xs); abstract)


end # module
46 changes: 37 additions & 9 deletions test/categorical_algebra/HomSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,39 +162,63 @@ subG, subobjs = subobject_graph(G) |> collect
@test length(incident(subG, 1, :src)) == 1 # ⊤ is terminal

# Graph and ReflexiveGraph should have same subobject structure
subG = subobject_graph(path_graph(Graph, 2)) |> first
subG, _ = subobject_graph(path_graph(Graph, 2))
subRG, sos = subobject_graph(path_graph(ReflexiveGraph, 2))
@test all(is_natural, hom.(sos))
@test is_isomorphic(subG, subRG)

# Partial overlaps
G,H = path_graph.(Graph, 2:3)
os = collect(partial_overlaps(G,G))
G, H = path_graph.(Graph, 2:3)
os = collect(partial_overlaps(G, G))
@test length(os) == 7 # ⊤, ••, 4× •, ⊥

po = partial_overlaps([G,H])
po = partial_overlaps([G, H])
@test length(collect(po))==12 # 2×⊤, 3ו•, 6× •, ⊥
@test all(m -> apex(m) == G, Iterators.take(po, 2)) # first two are •→•
@test all(m -> apex(m) == Graph(2),
Iterators.take(Iterators.drop(po, 2), 3)) # next three are • •

# Partial overlaps with attributes

@present SchVELabeledGraph <: SchGraph begin
VL::AttrType; EL::AttrType; vlabel::Attr(V,VL); elabel::Attr(E,EL)
end

@acset_type VELabeledGraph(SchVELabeledGraph) <: AbstractGraph
const LGraph = VELabeledGraph{Bool,Bool}

G = @acset LGraph begin
V=3; E=2; src=[1,2]; tgt=[2,3]; vlabel=Bool[0,1,1]; elabel=Bool[0,1]
end
H = @acset LGraph begin
V=3; E=2; src=[1,2]; tgt=[2,3]; vlabel=Bool[0,0,1]; elabel=Bool[0,0]
end
os = partial_overlaps(G, H); # abstract=true
@test count(apx->nparts(apx,:E)==2, apex.(os)) == 1
os = partial_overlaps(G, H; abstract=[:VL]);
@test count(apx->nparts(apx,:E)==2, apex.(os)) == 0
@test count(apx->nparts(apx,:E)==1, apex.(os)) == 4
os = partial_overlaps(G, H; abstract=false);
@test count(apx->nparts(apx,:E)==2, apex.(os)) == 0
@test count(apx->nparts(apx,:E)==1, apex.(os)) == 1

# Maximum Common C-Set
######################

const WG = WeightedGraph{Bool}
"""
Searching for overlaps: •→•→•↺ vs ↻•→•→•
Two results: •→•→• || •↺ •→•
"""
g1 = @acset WeightedGraph{Bool} begin
g1 = @acset WG begin
V=3; E=3; src=[1,1,2]; tgt=[1,2,3]; weight=[true,false,false]
end
g2 = @acset WeightedGraph{Bool} begin
g2 = @acset WG begin
V=3; E=3; src=[1,2,3]; tgt=[2,3,3]; weight=[true,false,false]
end
apex1 = @acset WeightedGraph{Bool} begin
apex1 = @acset WG begin
V=3; E=2; Weight=2; src=[1,2]; tgt=[2,3]; weight=AttrVar.(1:2)
end
apex2 = @acset WeightedGraph{Bool} begin
apex2 = @acset WG begin
V=3; E=2; Weight=2; src=[1,3]; tgt=[2,3]; weight=AttrVar.(1:2)
end

Expand All @@ -213,4 +237,8 @@ results = first(is_iso1) ? results : reverse(results)
@test collect(R2[:V]) == [3,1,2]
@test L2(apx2) == Subobject(g1, V=[1,2,3], E=[1,3])

# If we demand equality on attributes, max overlap is one false edge and all vertices.
exp = @acset WG begin V=3; E=1; src=1; tgt=2; weight=[false] end
@test first(first(maximum_common_subobject(g1, g2; abstract=false))) == exp

end # module

0 comments on commit 819b5f5

Please sign in to comment.