-
Notifications
You must be signed in to change notification settings - Fork 1
Define a generic dense
function
#72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/src/abstractsparsearrayinterface.jl b/src/abstractsparsearrayinterface.jl
index f36dfba..586f00a 100644
--- a/src/abstractsparsearrayinterface.jl
+++ b/src/abstractsparsearrayinterface.jl
@@ -41,15 +41,15 @@ unstoredsimilar(a::AbstractArray) = a
# about the array (such as which device it is on).
using TypeParameterAccessors: unspecify_type_parameters, unwrap_array, unwrap_array_type
function densetype(arraytype::Type{<:AbstractArray})
- return unspecify_type_parameters(unwrap_array_type(arraytype))
+ return unspecify_type_parameters(unwrap_array_type(arraytype))
end
# TODO: Ideally this would be defined as `densetype(typeof(a))` but that is less general right now since `unwrap_array_type` is defined on fewer arrays, since it is based on `parentype` rather than `parent`.
function densetype(a::AbstractArray)
- return unspecify_type_parameters(typeof(unwrap_array(a)))
+ return unspecify_type_parameters(typeof(unwrap_array(a)))
end
using GPUArraysCore: @allowscalar
function dense(a::AbstractArray)
- return @allowscalar convert(densetype(a), a)
+ return @allowscalar convert(densetype(a), a)
end
# Minimal interface for `SparseArrayInterface`.
diff --git a/test/test_dense.jl b/test/test_dense.jl
index 6211cb5..3acc147 100644
--- a/test/test_dense.jl
+++ b/test/test_dense.jl
@@ -6,36 +6,36 @@ using Test: @test, @testset
elts = (Float32, ComplexF64)
arrayts = (Array, JLArray)
@testset "dense (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts, elt in elts
- dev(x) = adapt(arrayt, x)
+ dev(x) = adapt(arrayt, x)
- @testset "SparseArrayDOK" begin
- s = sparsezeros(elt, 3, 4)
- s[1, 2] = 2
- d = dense(s)
- @test d isa Matrix{elt}
- @test d == [0 2 0 0; 0 0 0 0; 0 0 0 0]
- end
-
- @testset "Custom sparse array" begin
- struct MySparseArrayDOK{T,N,S<:AbstractVector{T}} <: AbstractArray{T,N}
- storedvalues::S
- storedindices::Dict{CartesianIndex{N},Int}
- size::NTuple{N,Int}
- end
- Base.size(a::MySparseArrayDOK) = a.size
- function Base.getindex(a::MySparseArrayDOK{<:Any,N}, I::Vararg{Int,N}) where {N}
- storageindex = get(a.storedindices, CartesianIndex(I), nothing)
- isnothing(storageindex) && return zero(eltype(a))
- return a.storedvalues[storageindex]
+ @testset "SparseArrayDOK" begin
+ s = sparsezeros(elt, 3, 4)
+ s[1, 2] = 2
+ d = dense(s)
+ @test d isa Matrix{elt}
+ @test d == [0 2 0 0; 0 0 0 0; 0 0 0 0]
end
- Base.parent(a::MySparseArrayDOK) = a.storedvalues
- s = MySparseArrayDOK(
- dev(elt[2, 4]), Dict([CartesianIndex(1, 2) => 1, CartesianIndex(3, 4) => 2]), (3, 4)
- )
- d = dense(s)
- @show typeof(d)
- @test d isa arrayt{elt,2}
- @test d == dev(elt[0 2 0 0; 0 0 0 0; 0 0 0 4])
- end
+ @testset "Custom sparse array" begin
+ struct MySparseArrayDOK{T, N, S <: AbstractVector{T}} <: AbstractArray{T, N}
+ storedvalues::S
+ storedindices::Dict{CartesianIndex{N}, Int}
+ size::NTuple{N, Int}
+ end
+ Base.size(a::MySparseArrayDOK) = a.size
+ function Base.getindex(a::MySparseArrayDOK{<:Any, N}, I::Vararg{Int, N}) where {N}
+ storageindex = get(a.storedindices, CartesianIndex(I), nothing)
+ isnothing(storageindex) && return zero(eltype(a))
+ return a.storedvalues[storageindex]
+ end
+ Base.parent(a::MySparseArrayDOK) = a.storedvalues
+
+ s = MySparseArrayDOK(
+ dev(elt[2, 4]), Dict([CartesianIndex(1, 2) => 1, CartesianIndex(3, 4) => 2]), (3, 4)
+ )
+ d = dense(s)
+ @show typeof(d)
+ @test d isa arrayt{elt, 2}
+ @test d == dev(elt[0 2 0 0; 0 0 0 0; 0 0 0 4])
+ end
end |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #72 +/- ##
==========================================
+ Coverage 75.86% 75.93% +0.06%
==========================================
Files 9 9
Lines 692 698 +6
==========================================
+ Hits 525 530 +5
- Misses 167 168 +1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
dense
function
No description provided.