diff --git a/src/abstractdatagraph.jl b/src/abstractdatagraph.jl index e03cfef..f6d46db 100644 --- a/src/abstractdatagraph.jl +++ b/src/abstractdatagraph.jl @@ -1,6 +1,16 @@ using Dictionaries: set!, unset! using Graphs: - Graphs, AbstractEdge, AbstractGraph, IsDirected, add_edge!, edges, ne, nv, vertices + Graphs, + AbstractEdge, + AbstractGraph, + IsDirected, + add_edge!, + a_star, + edges, + ne, + nv, + steiner_tree, + vertices using NamedGraphs.GraphsExtensions: GraphsExtensions, incident_edges, vertextype using NamedGraphs.SimilarType: similar_type using SimpleTraits: SimpleTraits, Not, @traitfn @@ -131,6 +141,20 @@ function outdegree(graph::AbstractDataGraph, vertex::Integer) return outdegree(underlying_graph(graph), vertex) end +# Fix for ambiguity error with `AbstractGraph` version +function Graphs.a_star( + graph::AbstractDataGraph, source::Integer, destination::Integer, args... +) + return a_star(underlying_graph(graph), source, destination, args...) +end + +# Fix for ambiguity error with `AbstractGraph` version +@traitfn function Graphs.steiner_tree( + graph::AbstractDataGraph::(!IsDirected), term_vert::Vector{<:Integer}, args... +) + return steiner_tree(underlying_graph(graph), term_vert, args...) +end + @traitfn GraphsExtensions.directed_graph(graph::AbstractDataGraph::IsDirected) = graph @traitfn function GraphsExtensions.directed_graph(graph::AbstractDataGraph::(!IsDirected)) diff --git a/test/runtests.jl b/test/runtests.jl index d58bf89..d37d69c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using DataGraphs: using Dictionaries: AbstractIndices, Dictionary, Indices, dictionary using Graphs: add_edge!, + a_star, bfs_tree, connected_components, degree, @@ -27,6 +28,7 @@ using Graphs: outdegree, path_graph, src, + steiner_tree, vertices using Graphs.SimpleGraphs: SimpleDiGraph, SimpleEdge, SimpleGraph using GraphsFlows: GraphsFlows @@ -411,6 +413,17 @@ using DataGraphs: is_arranged @test ps.parents == dictionary([1 => 1, 2 => 1, 3 => 2, 4 => 3]) @test ps.pathcounts == dictionary([1 => 1.0, 2 => 1.0, 3 => 1.0, 4 => 1.0]) end + @testset "a_star" begin + g = DataGraph(named_grid(4)) + path = a_star(g, 1, 3) + @test path == NamedEdge.([1 => 2, 2 => 3]) + end + @testset "steiner_tree" begin + g = DataGraph(named_grid(5)) + t = steiner_tree(g, [2, 4]) + @test nv(t) == 3 + @test ne(t) == 2 + end @testset "GraphsFlows.mincut (vertextype=$(eltype(verts))" for verts in ( [1, 2, 3, 4], ["A", "B", "C", "D"] )