Skip to content

Commit

Permalink
UDT/Monoid/Bugfix Update (#86)
Browse files Browse the repository at this point in the history
* UnaryOp rework, remove need to @Unop

* Removal of BinaryOp, Monoid update in progress

* monoids working on the surface

* passing all except random map tests

* fix promotion

* rm debug prints

* rename

* fix #85, fix #83

* fix #82

* work towards #81

* fixes #80, fixes #76

* fix #77
  • Loading branch information
Will Kimmerer committed Aug 16, 2022
1 parent 7aeaf5b commit ba33380
Show file tree
Hide file tree
Showing 39 changed files with 613 additions and 431 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.7.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
HyperSparseMatrices = "c7efdb1c-7caa-4c7d-9b5e-9093f9323c7c"
KLU = "ef3ab10e-7fda-4108-b977-705223b18434"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,7 +20,6 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
ChainRulesCore = "1"
HyperSparseMatrices = "0.2"
MacroTools = "0.5"
Preferences = "1"
SSGraphBLAS_jll = "6.2.1"
Expand Down
4 changes: 1 addition & 3 deletions docs/src/binaryops.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ Internally functions are lowered like this:
```@repl
using SuiteSparseGraphBLAS
op = BinaryOp(+)
typedop = op(Int64, Int64)
typedop = binaryop(+, Int64, Int64)
eadd(GBVector([1,2]), GBVector([3,4]), typedop)
```
Expand Down
3 changes: 1 addition & 2 deletions docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ Operators are lowered from a Julia function to a container like `BinaryOp` or `S
using SuiteSparseGraphBLAS
```
```@repl operators
b = BinaryOp(+)
b(Int32)
b = binaryop(+, Int32)
s = Semiring(max, +)
s(Float64)
Expand Down
4 changes: 2 additions & 2 deletions docs/src/udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ GraphBLAS supports users to supply functions as operators. Constructors exported

- `UnaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
- `BinaryOp(name::String, fn::Function, [type | types | ztype, xtype | ztypes, xtypes])`
- `Monoid(name::String, binop::Union{AbstractBinaryOp, GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::[GrB_BinaryOp | AbstractBinaryOp])`
- `Monoid(name::String, binop::Union{GrB_BinaryOp}, id::T, terminal::T = nothing)`: all types must be the same.
- `Semiring(name::String, add::[GrB_Monoid | AbstractMonoid], mul::GrB_BinaryOp)`

`GrB_` prefixed arguments are typed operators, such as the result of `UnaryOps.COS[Float64]`.
Type arguments may be single types or vectors of types.
Expand Down
4 changes: 1 addition & 3 deletions docs/src/unaryops.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ Internally functions are lowered like this:
```@repl
using SuiteSparseGraphBLAS
op = UnaryOp(sin)
typedop = op(Float64)
op = unaryop(sin, Float64)
map(typedop, GBVector([1.5, 0, pi]))
```
Expand Down
4 changes: 1 addition & 3 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ using Serialization
using StorageOrders

export ColMajor, RowMajor, storageorder #reexports from StorageOrders

using HyperSparseMatrices
include("abstracts.jl")
include("libutils.jl")

Expand Down Expand Up @@ -101,7 +99,7 @@ include("oriented.jl")
export SparseArrayCompat
export LibGraphBLAS
# export UnaryOps, BinaryOps, Monoids, Semirings #Submodules
export UnaryOp, BinaryOp, Monoid, Semiring #UDFs
export unaryop, binaryop, Monoid, semiring #UDFs
export Descriptor #Types
export gbset, gbget # global and object specific options.
# export xtype, ytype, ztype #Determine input/output types of operators
Expand Down
42 changes: 21 additions & 21 deletions src/abstractgbarray.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# AbstractGBArray functions:
function SparseArrays.nnz(A::AbsGBArrayOrTranspose)
function SparseArrays.nnz(A::GBArrayOrTranspose)
nvals = Ref{LibGraphBLAS.GrB_Index}()
@wraperror LibGraphBLAS.GrB_Matrix_nvals(nvals, gbpointer(parent(A)))
return Int64(nvals[])
end

Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
Base.eltype(::Type{GBArrayOrTranspose{T}}) where{T} = T

"""
empty!(v::GBVector)
Expand All @@ -14,32 +14,32 @@ Base.eltype(::Type{AbstractGBArray{T}}) where{T} = T
Clear all the entries from the GBArray.
Does not modify the type or dimensions.
"""
function Base.empty!(A::AbsGBArrayOrTranspose)
function Base.empty!(A::GBArrayOrTranspose)
@wraperror LibGraphBLAS.GrB_Matrix_clear(gbpointer(parent(A)))
return A
end

function Base.Matrix(A::AbstractGBMatrix)
function Base.Matrix(A::GBArrayOrTranspose)
sparsity = sparsitystatus(A)
T = copy(A) # We copy here to 1. avoid densifying A, and 2. to avoid destroying A.
return unpack!(T, Dense())
end

function Base.Vector(v::AbstractGBVector)
function Base.Vector(v::GBVectorOrTranspose)
sparsity = sparsitystatus(v)
T = copy(v) # avoid densifying v and destroying v.
return unpack!(T, Dense())
end

function SparseArrays.SparseMatrixCSC(A::AbstractGBArray)
function SparseArrays.SparseMatrixCSC(A::GBArrayOrTranspose)
sparsity = sparsitystatus(A)
T = copy(A) # avoid changing sparsity of A and destroying it.
return unpack!(T, SparseMatrixCSC)
end

function SparseArrays.SparseVector(v::AbstractGBVector)
function SparseArrays.SparseVector(v::GBVectorOrTranspose)
sparsity = sparsitystatus(v)
T = copy(A) # avoid changing sparsity of v and destroying it.
T = copy(v) # avoid changing sparsity of v and destroying it.
return unpack!(T, SparseVector)
end

Expand Down Expand Up @@ -94,7 +94,7 @@ for T ∈ valid_vec
function build(A::AbstractGBMatrix{$T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{$T};
combine = +
)
combine = BinaryOp(combine)($T)
combine = binaryop(combine, $T)
I isa Vector || (I = collect(I))
J isa Vector || (J = collect(J))
X isa Vector || (X = collect(X))
Expand Down Expand Up @@ -181,7 +181,7 @@ function build(
A::AbstractGBMatrix{T}, I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}, X::AbstractVector{T};
combine = +
) where {T}
combine = BinaryOp(combine)(T)
combine = binaryop(combine, T)
I isa Vector || (I = collect(I))
J isa Vector || (J = collect(J))
X isa Vector || (X = collect(X))
Expand Down Expand Up @@ -314,7 +314,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
# Keywords
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
`size(M) == size(A)`.
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
- `desc::Union{Nothing, Descriptor} = nothing`
Expand All @@ -325,7 +325,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`assign!`](@ref) except that
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(A) != size(mask)`.
"""
function subassign!(
C::AbstractGBArray, A::AbstractGBArray, I, J;
C::AbstractGBArray, A::GBArrayOrTranspose, I, J;
mask = nothing, accum = nothing, desc = nothing
)
I, ni = idx(I)
Expand Down Expand Up @@ -397,7 +397,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
# Keywords
- `mask::Union{Nothing, GBMatrix} = nothing`: mask where
`size(M) == size(C)`.
- `accum::Union{Nothing, Function, AbstractBinaryOp} = nothing`: binary accumulator operation
- `accum::Union{Nothing, Function} = nothing`: binary accumulator operation
where `C[i,j] = accum(C[i,j], T[i,j])` where T is the result of this function before accum is applied.
- `desc::Union{Nothing, Descriptor} = nothing`
Expand All @@ -408,7 +408,7 @@ Assign a submatrix of `A` to `C`. Equivalent to [`subassign!`](@ref) except that
- `GrB_DIMENSION_MISMATCH`: If `size(A) != (max(I), max(J))` or `size(C) != size(mask)`.
"""
function assign!(
C::AbstractGBMatrix, A::AbstractGBVector, I, J;
C::AbstractGBMatrix, A::GBArrayOrTranspose, I, J;
mask = nothing, accum = nothing, desc = nothing
)
I, ni = idx(I)
Expand All @@ -417,16 +417,16 @@ function assign!(
I = decrement!(I)
J = decrement!(J)
# we know A isn't adjoint/transpose on input
desc = _handledescriptor(desc)
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(A), I, ni, J, nj, desc)
desc = _handledescriptor(desc; in1=A)
@wraperror LibGraphBLAS.GrB_Matrix_assign(gbpointer(C), mask, getaccum(accum, eltype(C)), gbpointer(parent(A)), I, ni, J, nj, desc)
increment!(I)
increment!(J)
return A
end

function assign!(C::AbstractGBArray, x, I, J;
function assign!(C::AbstractGBArray{T}, x, I, J;
mask = nothing, accum = nothing, desc = nothing
)
) where T
x = typeof(x) === T ? x : convert(T, x)
I, ni = idx(I)
J, nj = idx(J)
Expand Down Expand Up @@ -467,7 +467,7 @@ end
Base.eltype(::Type{AbstractGBVector{T}}) where{T} = T

function Base.deleteat!(v::AbstractGBVector, i)
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 1)
@wraperror LibGraphBLAS.GrB_Matrix_removeElement(gbpointer(v), decrement!(i), 0)
return v
end

Expand Down Expand Up @@ -520,7 +520,7 @@ for T ∈ valid_vec
I isa Vector || (I = collect(I))
X isa Vector || (X = collect(X))
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
combine = BinaryOp(combine)($T)
combine = binaryop(combine, $T)
decrement!(I)
@wraperror LibGraphBLAS.$func(
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),
Expand Down Expand Up @@ -606,7 +606,7 @@ function build(v::AbstractGBVector{T}, I::Vector{<:Integer}, X::Vector{T}; combi
I isa Vector || (I = collect(I))
X isa Vector || (X = collect(X))
length(X) == length(I) || DimensionMismatch("I and X must have the same length")
combine = BinaryOp(combine)(T)
combine = binaryop(combine, T)
decrement!(I)
@wraperror LibGraphBLAS.GrB_Matrix_build_UDT(
Ptr{LibGraphBLAS.GrB_Vector}(gbpointer(v)),
Expand Down
3 changes: 0 additions & 3 deletions src/abstracts.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
abstract type AbstractGBType end
abstract type AbstractDescriptor end
abstract type AbstractOp end
abstract type AbstractUnaryOp <: AbstractOp end
abstract type AbstractBinaryOp <: AbstractOp end
abstract type AbstractSelectOp <: AbstractOp end
abstract type AbstractMonoid <: AbstractOp end
abstract type AbstractSemiring <: AbstractOp end
abstract type AbstractTypedOp{Z} end

abstract type AbstractGBArray{T, N, F} <: AbstractSparseArray{T, UInt64, N} end
Expand Down
2 changes: 1 addition & 1 deletion src/chainrules/chainruleutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ using ChainRulesCore
const RealOrComplex = Union{Real, Complex}

# LinearAlgebra.norm doesn't like the nothings.
LinearAlgebra.norm(A::GBArray, p::Real=2) = norm(nonzeros(A), p)
LinearAlgebra.norm(A::GBVecOrMat, p::Real=2) = norm(nonzeros(A), p)
20 changes: 10 additions & 10 deletions src/chainrules/ewiserules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
function frule(
(_, ΔA, ΔB, _)::Tuple,
::typeof(emul),
A::GBArray,
B::GBArray,
A::AbstractGBArray,
B::AbstractGBArray,
::typeof(*)
)
Ω = emul(A, B, *)
∂Ω = emul(unthunk(ΔA), B, *) + emul(unthunk(ΔB), A, *)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::GBArray, B::GBArray)
function frule((_, ΔA, ΔB)::Tuple, ::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
return frule((nothing, ΔA, ΔB, nothing), emul, A, B, *)
end

function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray, ::typeof(*))
function timespullback(ΔΩ)
∂A = emul(unthunk(ΔΩ), B)
∂B = emul(unthunk(ΔΩ), A)
Expand All @@ -23,7 +23,7 @@ function rrule(::typeof(emul), A::GBArray, B::GBArray, ::typeof(*))
return emul(A, B, *), timespullback
end

function rrule(::typeof(emul), A::GBArray, B::GBArray)
function rrule(::typeof(emul), A::AbstractGBArray, B::AbstractGBArray)
Ω, fullpb = rrule(emul, A, B, *)
emulpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, emulpb
Expand All @@ -39,19 +39,19 @@ end
function frule(
(_, ΔA, ΔB, _)::Tuple,
::typeof(eadd),
A::GBArray,
B::GBArray,
A::AbstractGBArray,
B::AbstractGBArray,
::typeof(+)
)
Ω = eadd(A, B, +)
∂Ω = eadd(unthunk(ΔA), unthunk(ΔB), +)
return Ω, ∂Ω
end
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::GBArray, B::GBArray)
function frule((_, ΔA, ΔB)::Tuple, ::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
return frule((nothing, ΔA, ΔB, nothing), eadd, A, B, +)
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray, ::typeof(+))
function pluspullback(ΔΩ)
return (
NoTangent(),
Expand All @@ -63,7 +63,7 @@ function rrule(::typeof(eadd), A::GBArray, B::GBArray, ::typeof(+))
return eadd(A, B, +), pluspullback
end

function rrule(::typeof(eadd), A::GBArray, B::GBArray)
function rrule(::typeof(eadd), A::AbstractGBArray, B::AbstractGBArray)
Ω, fullpb = rrule(eadd, A, B, +)
eaddpb(ΔΩ) = fullpb(ΔΩ)[1:3]
return Ω, eaddpb
Expand Down
8 changes: 4 additions & 4 deletions src/chainrules/maprules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ macro scalarapplyrule(func, derivative)
(_, _, $(esc(:ΔA)))::Tuple,
::typeof(apply),
::typeof($(func)),
$(esc(:A))::GBArray
$(esc(:A))::AbstractGBArray
)
$(esc()) = apply($(esc(func)), $(esc(:A)))
return $(esc()), $(esc(derivative)) .* unthunk($(esc(:ΔA)))
end
function ChainRulesCore.rrule(
::typeof(apply),
::typeof($(func)),
$(esc(:A))::GBArray
$(esc(:A))::AbstractGBArray
)
$(esc()) = apply($(esc(func)), $(esc(:A)))
function applyback($(esc(:ΔA)))
Expand Down Expand Up @@ -75,10 +75,10 @@ function frule(
(_, _, ΔA)::Tuple,
::typeof(apply),
::typeof(identity),
A::GBArray
A::AbstractGBArray
)
return (A, ΔA)
end
function rrule(::typeof(apply), ::typeof(identity), A::GBArray)
function rrule(::typeof(apply), ::typeof(identity), A::AbstractGBArray)
return A, (ΔΩ) -> (NoTangent(), NoTangent(), ΔΩ)
end
Loading

0 comments on commit ba33380

Please sign in to comment.