diff --git a/src/Arpack.jl b/src/Arpack.jl index ebc3f3a..674428e 100644 --- a/src/Arpack.jl +++ b/src/Arpack.jl @@ -222,10 +222,10 @@ struct AtA_or_AAt{T,S} <: AbstractArray{T, 2} buffer::Vector{T} end -function AtA_or_AAt(A::AbstractMatrix{T}) where T +function AtA_or_AAt(A) + T = eltype(A) Tnew = typeof(zero(T)/sqrt(one(T))) - Anew = convert(AbstractMatrix{Tnew}, A) - AtA_or_AAt{Tnew,typeof(Anew)}(Anew, Vector{Tnew}(undef, max(size(A)...))) + return AtA_or_AAt{Tnew,typeof(A)}(A, Vector{Tnew}(undef, max(size(A)...))) end function LinearAlgebra.mul!(y::StridedVector{T}, A::AtA_or_AAt{T}, x::StridedVector{T}) where T diff --git a/test/runtests.jl b/test/runtests.jl index 4f43a68..0797101 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -281,3 +281,17 @@ end @test_throws MethodError eigs(big.(rand(1:10, 10, 10)), rand(1:10, 10, 10)) @test_throws MethodError svds(big.(rand(1:10, 10, 8))) end + +struct MyOp{S} + mat::S +end +Base.size(A::MyOp) = size(A.mat) +Base.size(A::MyOp, i::Integer) = size(A.mat, i) +Base.eltype(A::MyOp) = Float64 +Base.:*(A::MyOp, B::AbstractMatrix) = A.mat*B +LinearAlgebra.mul!(y::AbstractVector, A::MyOp, x::AbstractVector) = mul!(y, A.mat, x) +LinearAlgebra.adjoint(A::MyOp) = MyOp(adjoint(A.mat)) +@testset "svds for non-AbstractMatrix" begin + A = MyOp(randn(10, 9)) + @test svds(A, v0 = ones(9))[1].S == svds(A.mat, v0 = ones(9))[1].S +end