Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for finite fields #12

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/Problem/SpkProblem.jl
Expand Up @@ -139,11 +139,11 @@ mutable struct Problem{IT<:BlasInt, FT}
end

"""
Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=0.0, info = "") where {IT, FT}
Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=zero(FT), info = "") where {IT, FT}

Construct a problem.
"""
function Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=0.0, info = "") where {IT<:BlasInt, FT}
function _Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=zero(FT), info = "") where {IT<:BlasInt, FT}
lenlink = nnz
lenhead = ncols
lenrhs = nrows
Expand All @@ -158,7 +158,7 @@ function Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=0.0, info = "") where
x = fill(zero(FT), lenhead)
link = fill(zero(IT), lenlink)
rowsubs = fill(zero(IT), lenlink)
values = fill(_BIGGY(), lenlink)
values = fill(_BIGGY(FT), lenlink)
rscales = FT[]
cscales = FT[]

Expand All @@ -167,6 +167,10 @@ function Problem(nrows::IT, ncols::IT, nnz::IT=2500, z::FT=0.0, info = "") where
rscales, cscales, x, rhs)
end

Problem(nrows,ncols) = _Problem(nrows,ncols,2500,zero(Float64))
Problem(nrows,ncols,nnz) = _Problem(nrows,ncols,nnz,zero(Float64))
Problem(nrows,ncols,nnz,z) = _Problem(nrows,ncols,nnz,z)

"""
inaij!(p::Problem{IT,FT}, rnum, cnum, aij=zero(FT)) where {IT,FT}

Expand All @@ -184,7 +188,7 @@ function inaij!(p::Problem{IT,FT}, rnum, cnum, aij=zero(FT)) where {IT<:BlasInt,
p.lenlink = max(2 * p.lenlink, 3 * p.ncols)
p.link = __extend(p.link, p.lenlink)
p.rowsubs = __extend(p.rowsubs, p.lenlink)
p.values = __extend(p.values, p.lenlink, _BIGGY())
p.values = __extend(p.values, p.lenlink, _BIGGY(FT))
end

p.nrows = max(rnum, p.nrows)
Expand All @@ -208,7 +212,7 @@ function inaij!(p::Problem{IT,FT}, rnum, cnum, aij=zero(FT)) where {IT<:BlasInt,
break
end
if (p.rowsubs[ptr] == rnum)
if (p.values[ptr] == _BIGGY())
if (p.values[ptr] == _BIGGY(FT))
p.nnz = p.nnz + 1
p.values[ptr] = aij
if (rnum == cnum)
Expand Down Expand Up @@ -411,7 +415,7 @@ function makerhs!(p::Problem, x::Vector{FT}, mtype = "T") where {FT}
p.x .= FT.(1:p.ncols)
end

p.rhs .= 0.0
p.rhs .= zero(FT)
res = deepcopy(p.rhs)

computeresidual(p, res, p.x, mtype)
Expand Down Expand Up @@ -480,7 +484,7 @@ function computeresidual(p::Problem, res::Vector{FT}, xin::Vector{FT} = FT[], mt
ptr = p.head[cnum]; t = x[cnum]
while (ptr > 0)
rnum = p.rowsubs[ptr]; temp = p.values[ptr]
if (temp != _BIGGY())
if (temp != _BIGGY(FT))
r[rnum] -= t * temp
if (rnum != cnum && flag == 1)
u = x[rnum]
Expand Down
10 changes: 5 additions & 5 deletions src/SparseMethod/SpkLUFactor.jl
Expand Up @@ -65,7 +65,7 @@ function _lufactor!(n::IT, nsuper::IT, xsuper::Vector{IT}, snode::Vector{IT}, xl
@assert length(xunz) == (n + 1)
@assert length(ipvt) == n

ONE = FT(1.0)
ONE = one(FT)

link = fill(zero(IT), nsuper)
lngth = fill(zero(IT), nsuper)
Expand Down Expand Up @@ -278,8 +278,8 @@ function _lulsolve!(nsuper::IT, xsuper::Vector{IT}, xlindx::Vector{IT}, lindx::V
# - - - - - - - - - -
# constants.
# - - - - - - - - - -
ONE = FT(1.0)
ZERO = FT(0.0)
ONE = one(FT)
ZERO = zero(FT)

if (nsuper <= 0) return false; end

Expand Down Expand Up @@ -331,8 +331,8 @@ function _luusolve!(n::IT, nsuper::IT, xsuper::Vector{IT}, xlindx::Vector{IT}, l
# integer :: length, maxlength
# real(double), dimension(:), allocatable :: temp

ONE = FT(1.0)
ZERO = FT(0.0)
ONE = one(FT)
ZERO = zero(FT)

if (nsuper <= 0) return; end

Expand Down
35 changes: 30 additions & 5 deletions src/Utilities/GenericBlasLapackFragments.jl
Expand Up @@ -42,27 +42,45 @@ Base.getindex(A::StridedReshape,i)= @inbounds A.v[i]
Base.setindex!(A::StridedReshape,v,i)= @inbounds A.v[i]=v


@static if VERSION< v"1.7"
abstract type PivotingStrategy end
struct RowNonZero <: PivotingStrategy end
struct RowMaximum <: PivotingStrategy end
struct NoPivot <: PivotingStrategy end
lupivottype(::Type{T}) where {T} = RowMaximum()
end


@static if VERSION >= v"1.7" && VERSION< v"1.9"
struct RowNonZero <: LinearAlgebra.PivotingStrategy end
lupivottype(::Type{T}) where {T} = RowMaximum()
end


@static if VERSION >= v"1.9"
import LinearAlgebra: lupivottype
end

#
# LU factorization adapted from generic_lufact! (https://github.com/JuliaLang/LinearAlgebra.jl/blob/main/src/lu.jl).
# LU factorization adapted from generic_lufact! (https://github.com/JuliaLang/LinearAlgebra.jl/blob/main/src/lu.jl)
# with support of RowNonZero pivoting for finite fields etc.
# Originally it is (like many other LA operators) defined for StridedMatrix which is a union and not an abstract type,
# so we cannot use that code directly. See https://github.com/JuliaLang/julia/issues/2345 for some discussion about this.
#
# Modifications:
# - Use ipiv passed instead of creating one
# - No need to return LU object
# - Remove unused parameters - always do pivoting anyway
#
#
function ggetrf!(m,n,a::AbstractVector{FT},lda,ipiv) where FT
function ggetrf!(m,n,a::AbstractVector{FT},lda,ipiv; pivot::Union{NoPivot,RowMaximum,RowNonZero}=lupivottype(FT)) where FT

minmn = min(m,n)
A=StridedReshape(a,lda)

begin
for k = 1:minmn
# find index max
kp = k
if k < m # pivot === RowMaximum() &&
if pivot === RowMaximum() && k < m
amax = abs(A[k,k])
for i = k+1:m
absi = abs(A[i,k])
Expand All @@ -71,6 +89,13 @@ function ggetrf!(m,n,a::AbstractVector{FT},lda,ipiv) where FT
amax = absi
end
end
elseif pivot === RowNonZero()
for i = k:m
if !iszero(A[i,k])
kp = i
break
end
end
end
ipiv[k] = kp
if !iszero(A[kp,k])
Expand Down
2 changes: 1 addition & 1 deletion src/Utilities/SpkUtilities.jl
Expand Up @@ -4,7 +4,7 @@ See the comments below in the interface section of the module.
"""
module SpkUtilities

_BIGGY() = typemax(Float64)
_BIGGY(T) = typemax(T)


# __extend(v::Vector, newlen::Integer, flagval=zero(eltype(v)))
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Expand Up @@ -2,6 +2,7 @@
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
DataDrop = "aa547a04-dd37-49ab-8e73-656744f8a8fc"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GaloisFields = "8d0d7f98-d412-5cd4-8397-071c807280aa"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5"
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Expand Up @@ -36,3 +36,7 @@ end
include("test_generic.jl")
end

@time @testset "Galois" begin
include("test_galois.jl")
end

48 changes: 48 additions & 0 deletions test/test_galois.jl
@@ -0,0 +1,48 @@
module GaloisTest
using Test
using LinearAlgebra
using SparseArrays
using Sparspak
using Sparspak.SpkProblem: insparse!, outsparse
using Sparspak.SpkSparseSolver: SparseSolver, solve!
using GaloisFields

@static if VERSION < v"1.9"
Sparspak.GenericBlasLapackFragments.lupivottype(::Type{T}) where T<:GaloisFields.AbstractGaloisField= Sparspak.GenericBlasLapackFragments.RowNonZero()
else
LinearAlgebra.lupivottype(::Type{T}) where T<:GaloisFields.AbstractGaloisField= LinearAlgebra.RowNonZero()
end

Sparspak.SpkUtilities._BIGGY(::Type{T}) where T<:GaloisFields.AbstractGaloisField=zero(T)
Base.abs(x::T) where T<:GaloisFields.AbstractGaloisField=x
Base.isless(x::T,y::T) where T<:GaloisFields.AbstractGaloisField=x.n<y.n



function _test(T,n)
# need some scalable (random ?) invertible test matrices here
spm0 = sprand(Int8, n, n, 1/n)
spm0 = -spm0 - spm0' + 40 * LinearAlgebra.I
spm=SparseMatrixCSC(n,n,spm0.colptr,spm0.rowval,T.(spm0.nzval))
@show typeof(spm)
x=rand(T,n)
b=spm*x
p = Sparspak.SpkProblem.Problem(n, n, nnz(spm), zero(T))
Sparspak.SpkProblem.insparse!(p, spm);
Sparspak.SpkProblem.infullrhs!(p, b);
s = SparseSolver(p)
solve!(s)
@test x == p.x
end


const F1013=@GaloisField 1013

_test(F1013, 100)

const F127=@GaloisField 127
_test(F127, 10)


end

4 changes: 2 additions & 2 deletions test/test_problem.jl
Expand Up @@ -44,10 +44,10 @@ function _test()
@test norm(a - a1) / norm(a) < 1.0e-9
a = __extend(a, 2, 2)
a1 = [0.3814043930778628 0.07295808459382358 Inf Inf Inf; 0.5435778459423668 0.018608588657332392 Inf Inf Inf; Inf Inf Inf Inf Inf]
a = __extend(a, 3, 5, SpkUtilities._BIGGY())
a = __extend(a, 3, 5, SpkUtilities._BIGGY(Float64))
@test a == a1
return true
end

_test()
end # module
end # module
4 changes: 2 additions & 2 deletions test/test_utilities.jl
Expand Up @@ -53,7 +53,7 @@ function _test()
@test norm(a - a1) / norm(a) < 1.0e-9
a = __extend(a, 2, 2)
a1 = [0.3814043930778628 0.07295808459382358 Inf Inf Inf; 0.5435778459423668 0.018608588657332392 Inf Inf Inf; Inf Inf Inf Inf Inf]
a = __extend(a, 3, 5, SpkUtilities._BIGGY())
a = __extend(a, 3, 5, SpkUtilities._BIGGY(Float64))
@test a == a1
return true
end
Expand Down Expand Up @@ -107,4 +107,4 @@ end # module
# end

# _test()
# end # module
# end # module