Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SparseArrayDOKs] Add setindex_maybe_grow! and macro @maybe_grow #1434

Merged
merged 10 commits into from
May 14, 2024
6 changes: 4 additions & 2 deletions NDTensors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -34,20 +35,20 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[extensions]
NDTensorsAMDGPUExt = "AMDGPU"
NDTensorsCUDAExt = "CUDA"
NDTensorscuTENSORExt = "cuTENSOR"
NDTensorsHDF5Ext = "HDF5"
NDTensorsMetalExt = "Metal"
NDTensorsOctavianExt = "Octavian"
NDTensorsTBLISExt = "TBLIS"
NDTensorscuTENSORExt = "cuTENSOR"

[compat]
Accessors = "0.1.33"
Expand All @@ -65,6 +66,7 @@ HDF5 = "0.14, 0.15, 0.16, 0.17"
HalfIntegers = "1"
InlineStrings = "1"
LinearAlgebra = "1.6"
MacroTools = "0.5"
mtfishman marked this conversation as resolved.
Show resolved Hide resolved
MappedArrays = "0.4"
PackageExtensionCompat = "1"
Random = "1.6"
Expand Down
44 changes: 43 additions & 1 deletion NDTensors/src/lib/SparseArrayDOKs/src/sparsearraydok.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using Accessors: @set
using Dictionaries: Dictionary, set!
using MacroTools: @capture
using ..SparseArrayInterface:
SparseArrayInterface, AbstractSparseArray, getindex_zero_function

# TODO: Parametrize by `data`?
struct SparseArrayDOK{T,N,Zero} <: AbstractSparseArray{T,N}
mutable struct SparseArrayDOK{T,N,Zero} <: AbstractSparseArray{T,N}
emstoudenmire marked this conversation as resolved.
Show resolved Hide resolved
data::Dictionary{CartesianIndex{N},T}
dims::NTuple{N,Int}
zero::Zero
Expand Down Expand Up @@ -104,3 +105,44 @@ SparseArrayDOK{T}(a::AbstractArray) where {T} = SparseArrayDOK{T,ndims(a)}(a)
function SparseArrayDOK{T,N}(a::AbstractArray) where {T,N}
return SparseArrayInterface.sparse_convert(SparseArrayDOK{T,N}, a)
end

function Base.resize!(a::SparseArrayDOK{<:Any,N}, new_size::NTuple{N,Integer}) where {N}
a.dims = new_size
return a
end

function setindex_maybe_grow!(a::SparseArrayDOK{<:Any,N}, value, I::Vararg{Int,N}) where {N}
if any(I .> size(a))
resize!(a, max.(I, size(a)))
end
a[I...] = value
return a
end

function check_siteindex!_expr(expr, macroname=:macro)
try
@assert expr.head == :(=) && expr.args[1] isa Expr && expr.args[1].head == :(ref)
catch
end
end
emstoudenmire marked this conversation as resolved.
Show resolved Hide resolved

function is_setindex!_expr(expr::Expr)
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
end
is_setindex!_expr(x) = false

is_getindex_expr(expr::Expr) = (expr.head === :ref)
is_getindex_expr(x) = false

is_assignment_expr(expr::Expr) = (expr.head === :(=))
is_assignment_expr(expr) = false

macro maybe_grow(expr)
if !is_setindex!_expr(expr)
error(
"@maybe_grow must be used with setindex! syntax (as @maybe_grow a[i,j,...] = value)"
)
end
@capture(expr, array_[indices__] = value_)
return :(setindex_maybe_grow!($(esc(array)), $value, $indices...))
end
27 changes: 26 additions & 1 deletion NDTensors/src/lib/SparseArrayDOKs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# Slicing

using Test: @test, @testset, @test_broken
using NDTensors.SparseArrayDOKs: SparseArrayDOK, SparseMatrixDOK
using NDTensors.SparseArrayDOKs:
SparseArrayDOKs, SparseArrayDOK, SparseMatrixDOK, @maybe_grow
using NDTensors.SparseArrayInterface: storage_indices, nstored
using SparseArrays: SparseMatrixCSC, nnz
@testset "SparseArrayDOK (eltype=$elt)" for elt in
Expand Down Expand Up @@ -94,5 +95,29 @@ using SparseArrays: SparseMatrixCSC, nnz
end
end
end
@testset "Maybe Grow Feature" begin
a = SparseArrayDOK{elt,2}((0, 0))
SparseArrayDOKs.setindex_maybe_grow!(a, 230, 2, 3)
@test size(a) == (2, 3)
@test a[2, 3] == 230
# Test @maybe_grow macro
@maybe_grow a[5, 5] = 550
@test size(a) == (5, 5)
@test a[2, 3] == 230
@test a[5, 5] == 550
# Test that size remains same
# if we set at an index smaller than
# the maximum size:
@maybe_grow a[3, 4] = 340
@test size(a) == (5, 5)
@test a[2, 3] == 230
@test a[5, 5] == 550
@test a[3, 4] == 340
# Test vector case
v = SparseArrayDOK{elt,1}((0,))
@maybe_grow v[5] = 50
@test size(v) == (5,)
@test v[5] == 50
end
end
end
Loading