Skip to content

Commit

Permalink
Fix boolean type infer, fix infer type stability
Browse files Browse the repository at this point in the history
  • Loading branch information
Wimmerer committed Jul 6, 2021
1 parent 8e45498 commit e19133f
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 42 deletions.
7 changes: 4 additions & 3 deletions src/SuiteSparseGraphBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ include("descriptors.jl")
include("indexutils.jl")


const GBVecOrMat = Union{GBVector, GBMatrix}
const GBMatOrTranspose = Union{GBMatrix, Transpose{<:Any, <:GBMatrix}}
const GBArray = Union{GBVector, GBMatOrTranspose}
const GBVecOrMat{T} = Union{GBVector{T}, GBMatrix{T}}
const GBMatOrTranspose{T} = Union{GBMatrix{T}, Transpose{<:Any, GBMatrix{T}}}
const GBArray{T} = Union{GBVector{T}, GBMatOrTranspose{T}}
const ptrtogbtype = Dict{Ptr, AbstractGBType}()

const GrBOp = Union{
Expand Down Expand Up @@ -125,4 +125,5 @@ function __init__()
end
end

include("operators/ztypes.jl")
end #end of module
24 changes: 4 additions & 20 deletions src/operations/ewise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,7 @@ function emul(
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
if op isa GrBOp
t = ztype(op)
else
t = optype(u, v)
end
t = inferoutputtype(u, v, op)
w = GBVector{t}(size(u))
return emul!(w, u, v, op; mask , accum, desc)
end
Expand Down Expand Up @@ -136,11 +132,7 @@ function emul(
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
if op isa GrBOp
t = ztype(op)
else
t = optype(A, B)
end
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A))
return emul!(C, A, B, op; mask, accum, desc)
end
Expand Down Expand Up @@ -235,11 +227,7 @@ function eadd(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
if op isa GrBOp
t = ztype(op)
else
t = optype(eltype(u), eltype(v))
end
t = inferoutputtype(u, v, op)
w = GBVector{t}(size(u))
return eadd!(w, u, v, op; mask, accum, desc)
end
Expand Down Expand Up @@ -282,11 +270,7 @@ function eadd(
desc = nothing
)
op, mask, accum, desc = _handlectx(op, mask, accum, desc, BinaryOps.PLUS)
if op isa GrBOp
t = ztype(op)
else
t = optype(A, B)
end
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A))
return eadd!(C, A, B, op; mask, accum, desc)
end
Expand Down
6 changes: 1 addition & 5 deletions src/operations/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,7 @@ function LinearAlgebra.kron(
desc = nothing
)
op = _handlectx(op, ctxop, BinaryOps.TIMES)
if op isa GrBOp
t = ztype(op)
else
t = optype(A, B)
end
t = inferoutputtype(A, B, op)
C = GBMatrix{t}(size(A,1) * size(B, 1), size(A, 2) * size(B, 2))
kron!(C, A, B, op; mask, accum, desc)
return C
Expand Down
9 changes: 6 additions & 3 deletions src/operations/map.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ function Base.map(
op::UnaryUnion, A::GBArray;
mask = nothing, accum = nothing, desc = nothing
)
return map!(op, similar(A), A; mask, accum, desc)
t = inferoutputtype(A, op)
return map!(op, similar(A, t), A; mask, accum, desc)
end

function Base.map!(
Expand Down Expand Up @@ -54,7 +55,8 @@ function Base.map(
op::BinaryUnion, x, A::GBArray;
mask = nothing, accum = nothing, desc = nothing
)
return map!(op, similar(A), x, A; mask, accum, desc)
t = inferoutputtype(A, op)
return map!(op, similar(A, t), x, A; mask, accum, desc)
end

function Base.map!(
Expand Down Expand Up @@ -83,7 +85,8 @@ function Base.map(
op::BinaryUnion, A::GBArray, x;
mask = nothing, accum = nothing, desc = nothing
)
return map!(op, similar(A), A, x; mask, accum, desc)
t = inferoutputtype(A, op)
return map!(op, similar(A, t), A, x; mask, accum, desc)
end

function Base.broadcasted(::typeof(+), u::GBArray, x::valid_union;
Expand Down
6 changes: 1 addition & 5 deletions src/operations/mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ function mul(
desc = nothing
)
op = _handlectx(op, ctxop, Semirings.PLUS_TIMES)
if op isa libgb.GrB_Semiring
t = ztype(op)
else
t = optype(A, B)
end
t = inferoutputtype(A, B, op)
if A isa GBVector && B isa GBMatOrTranspose
C = GBVector{t}(size(B, 2))
elseif A isa GBMatOrTranspose && B isa GBVector
Expand Down
15 changes: 14 additions & 1 deletion src/operations/operationutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,21 @@ function optype(atype, btype)
end
end

optype(A::GBArray, B::GBArray) = optype(eltype(A), eltype(B))
optype(::GBArray{T}, ::GBArray{U}) where {T, U} = optype(T, U)

function inferoutputtype(A::GBArray{T}, B::GBArray{U}, op::AbstractOp) where {T, U}
t = optype(A, B)
return ztype(op, t)
end
function inferoutputtype(::GBArray{T}, op::AbstractOp) where {T}
return ztype(op, T)
end
function inferoutputtype(::GBArray{T}, op) where {T}
return ztype(op)
end
function inferoutputtype(::GBArray{T}, ::GBArray{U}, op) where {T, U}
return ztype(op)
end
function _handlectx(ctx, ctxvar, default = nothing)
if ctx === nothing || ctx === missing
ctx2 = get(ctxvar)
Expand Down
6 changes: 3 additions & 3 deletions src/operators/binaryops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ end

Base.show(io::IO, ::MIME"text/plain", u::libgb.GrB_BinaryOp) = gxbprint(io, u)

xtype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)])
ytype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)])
ztype(op::BinaryUnion) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)])
xtype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_xtype(op)])
ytype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ytype(op)])
ztype(op::libgb.GrB_BinaryOp) = tojuliatype(ptrtogbtype[libgb.GxB_BinaryOp_ztype(op)])

"""
First argument: `f(x::T,y::T)::T = x`
Expand Down
2 changes: 0 additions & 2 deletions src/operators/operatorutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ function getoperator(op, t)

if op isa AbstractOp
return op[t]
elseif op isa GrBOp
return op
else
return op
end
Expand Down
109 changes: 109 additions & 0 deletions src/operators/ztypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
ztype(::AbstractOp, intype::DataType) = intype

#UnaryOps:
ztype(::Types.ISINF_T, ::DataType) = Bool
ztype(::Types.ISNAN_T, ::DataType) = Bool
ztype(::Types.ISFINITE_T, ::DataType) = Bool

ztype(::Types.CONJ_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
ztype(::Types.ABS_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
ztype(::Types.CREAL_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
ztype(::Types.CIMAG_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]
ztype(::Types.CARG_T, intype::Type{T}) where {T <: Complex} = intype.parameters[1]

ztype(::Types.POSITIONI_T, ::DataType) = Int64
ztype(::Types.POSITIONI1_T, ::DataType) = Int64
ztype(::Types.POSITIONJ_T, ::DataType) = Int64
ztype(::Types.POSITIONJ1_T, ::DataType) = Int64

#BinaryOps:
ztype(::Types.EQ_T, ::DataType) = Bool
ztype(::Types.NE_T, ::DataType) = Bool
ztype(::Types.GT_T, ::DataType) = Bool
ztype(::Types.LT_T, ::DataType) = Bool
ztype(::Types.GE_T, ::DataType) = Bool
ztype(::Types.LE_T, ::DataType) = Bool
ztype(::Types.CMPLX_T, intype::Type{T}) where {T <: AbstractFloat} = Complex{T}

ztype(::Types.FIRSTI_T, ::DataType) = Int64
ztype(::Types.FIRSTI1_T, ::DataType) = Int64
ztype(::Types.FIRSTJ_T, ::DataType) = Int64
ztype(::Types.FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.SECONDI_T, ::DataType) = Int64
ztype(::Types.SECONDI1_T, ::DataType) = Int64
ztype(::Types.SECONDJ_T, ::DataType) = Int64
ztype(::Types.SECONDJ1_T, ::DataType) = Int64

#Semirings:
ztype(::Types.LAND_EQ_T, ::DataType) = Bool
ztype(::Types.LOR_EQ_T, ::DataType) = Bool
ztype(::Types.LXOR_EQ_T, ::DataType) = Bool
ztype(::Types.EQ_EQ_T, ::DataType) = Bool
ztype(::Types.ANY_EQ_T, ::DataType) = Bool
ztype(::Types.LAND_NE_T, ::DataType) = Bool
ztype(::Types.LOR_NE_T, ::DataType) = Bool
ztype(::Types.LXOR_NE_T, ::DataType) = Bool
ztype(::Types.EQ_NE_T, ::DataType) = Bool
ztype(::Types.ANY_NE_T, ::DataType) = Bool
ztype(::Types.LAND_GT_T, ::DataType) = Bool
ztype(::Types.LOR_GT_T, ::DataType) = Bool
ztype(::Types.LXOR_GT_T, ::DataType) = Bool
ztype(::Types.EQ_GT_T, ::DataType) = Bool
ztype(::Types.ANY_GT_T, ::DataType) = Bool
ztype(::Types.LAND_LT_T, ::DataType) = Bool
ztype(::Types.LOR_LT_T, ::DataType) = Bool
ztype(::Types.LXOR_LT_T, ::DataType) = Bool
ztype(::Types.EQ_LT_T, ::DataType) = Bool
ztype(::Types.ANY_LT_T, ::DataType) = Bool
ztype(::Types.LAND_GE_T, ::DataType) = Bool
ztype(::Types.LOR_GE_T, ::DataType) = Bool
ztype(::Types.LXOR_GE_T, ::DataType) = Bool
ztype(::Types.EQ_GE_T, ::DataType) = Bool
ztype(::Types.ANY_GE_T, ::DataType) = Bool
ztype(::Types.LAND_LE_T, ::DataType) = Bool
ztype(::Types.LOR_LE_T, ::DataType) = Bool
ztype(::Types.LXOR_LE_T, ::DataType) = Bool
ztype(::Types.EQ_LE_T, ::DataType) = Bool
ztype(::Types.ANY_LE_T, ::DataType) = Bool


ztype(::Types.MIN_FIRSTI_T, ::DataType) = Int64
ztype(::Types.MAX_FIRSTI_T, ::DataType) = Int64
ztype(::Types.PLUS_FIRSTI_T, ::DataType) = Int64
ztype(::Types.TIMES_FIRSTI_T, ::DataType) = Int64
ztype(::Types.ANY_FIRSTI_T, ::DataType) = Int64
ztype(::Types.MIN_FIRSTI1_T, ::DataType) = Int64
ztype(::Types.MAX_FIRSTI1_T, ::DataType) = Int64
ztype(::Types.PLUS_FIRSTI1_T, ::DataType) = Int64
ztype(::Types.TIMES_FIRSTI1_T, ::DataType) = Int64
ztype(::Types.ANY_FIRSTI1_T, ::DataType) = Int64
ztype(::Types.MIN_FIRSTJ_T, ::DataType) = Int64
ztype(::Types.MAX_FIRSTJ_T, ::DataType) = Int64
ztype(::Types.PLUS_FIRSTJ_T, ::DataType) = Int64
ztype(::Types.TIMES_FIRSTJ_T, ::DataType) = Int64
ztype(::Types.ANY_FIRSTJ_T, ::DataType) = Int64
ztype(::Types.MIN_FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.MAX_FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.PLUS_FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.TIMES_FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.ANY_FIRSTJ1_T, ::DataType) = Int64
ztype(::Types.MIN_SECONDI_T, ::DataType) = Int64
ztype(::Types.MAX_SECONDI_T, ::DataType) = Int64
ztype(::Types.PLUS_SECONDI_T, ::DataType) = Int64
ztype(::Types.TIMES_SECONDI_T, ::DataType) = Int64
ztype(::Types.ANY_SECONDI_T, ::DataType) = Int64
ztype(::Types.MIN_SECONDI1_T, ::DataType) = Int64
ztype(::Types.MAX_SECONDI1_T, ::DataType) = Int64
ztype(::Types.PLUS_SECONDI1_T, ::DataType) = Int64
ztype(::Types.TIMES_SECONDI1_T, ::DataType) = Int64
ztype(::Types.ANY_SECONDI1_T, ::DataType) = Int64
ztype(::Types.MIN_SECONDJ_T, ::DataType) = Int64
ztype(::Types.MAX_SECONDJ_T, ::DataType) = Int64
ztype(::Types.PLUS_SECONDJ_T, ::DataType) = Int64
ztype(::Types.TIMES_SECONDJ_T, ::DataType) = Int64
ztype(::Types.ANY_SECONDJ_T, ::DataType) = Int64
ztype(::Types.MIN_SECONDJ1_T, ::DataType) = Int64
ztype(::Types.MAX_SECONDJ1_T, ::DataType) = Int64
ztype(::Types.PLUS_SECONDJ1_T, ::DataType) = Int64
ztype(::Types.TIMES_SECONDJ1_T, ::DataType) = Int64
ztype(::Types.ANY_SECONDJ1_T, ::DataType) = Int64
1 change: 1 addition & 0 deletions test/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
@test emul(m, n, BinaryOps.POW)[3, 2] == m[3,2] ^ n[3,2]
#check that the (*) op is being picked up from the semiring
@test emul(m, n, Semirings.MAX_PLUS) == emul(m, n, BinaryOps.PLUS)
@test eltype(m .== n) == Bool
end
@testset "kron" begin
m1 = GBMatrix(UInt64[1, 2, 3, 5], UInt64[1, 3, 1, 2], Int8[1, 2, 3, 5])
Expand Down

0 comments on commit e19133f

Please sign in to comment.