Skip to content

Commit

Permalink
add kron tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JaredCrean2 committed Jul 9, 2016
1 parent 292efa9 commit 51b0186
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 17 deletions.
30 changes: 13 additions & 17 deletions src/mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1211,16 +1211,13 @@ end
Users *must* call restore when done with a MatRow, before attempting to
create another one.
"""
type MatRow{T, mtype}
mat::Mat{T, mtype} # the matrix to which the rows belong
immutable MatRow{T}
mat::C.Mat{T} # the matrix to which the rows belong
row::Int
ref_ncols::Ref{PetscInt} # reference to the number of columns
ref_cols::Ref{Ptr{PetscInt}} # reference to the column indices
ref_vals::Ref{Ptr{T}} # reference to the values at the column indices
ncols::Int
cols::Array{PetscInt, 1}
vals::Array{T, 1}

cols_ptr::Ptr{PetscInt}
vals_ptr::Ptr{T}
#=
function MatRow(A::Mat{T}, row::Integer, ref_ncols::Ref{PetscInt}, ref_cols::Ref{Ptr{PetscInt}}, ref_vals::Ref{Ptr{T}})
ncols = ref_ncols[]
cols = pointer_to_array(ref_cols[], ncols)
Expand All @@ -1231,6 +1228,7 @@ type MatRow{T, mtype}
return obj
end
=#
end

"""
Expand All @@ -1242,7 +1240,7 @@ function MatRow{T, mtype}(A::Mat{T, mtype}, row::Integer)
ref_cols = Ref{Ptr{PetscInt}}()
ref_vals = Ref{Ptr{T}}()
chk(C.MatGetRow(A.p, row-1, ref_ncols, ref_cols, ref_vals))
return MatRow{T, mtype}(A, row, ref_ncols, ref_cols, ref_vals)
return MatRow{T}(A.p, row, ref_ncols[], ref_cols[], ref_vals[])
end

"""
Expand All @@ -1260,19 +1258,19 @@ function count_row_nz{T}(A::Mat{T}, row::Integer)
return ncols
end
function restore{T}(row::MatRow{T})
if !PetscFinalized(T) && !isfinalized(row.mat)
chk(C.MatRestoreRow(row.mat.p, row.row-1, row.ref_ncols, row.ref_cols, row.ref_vals))
# if !PetscFinalized(T) && !isfinalized(row.mat)
C.MatRestoreRow(row.mat, row.row-1, Ref(PetscInt(row.ncols)), Ref{Ptr{PetscInt}}(C_NULL), Ref{Ptr{T}}(C_NULL))

return nothing # return type stability
end
# return nothing # return type stability
# end
end

### indexing on a MatRow ###
import Base: length, size
length(A::MatRow) = A.ncols
size(A::MatRow) = (A.ncols,)
getcol(A::MatRow, i) = A.cols[i] + 1
getval(A::MatRow, i) = A.vals[i]
getcol(A::MatRow, i) = unsafe_load(A.cols_ptr, i) + 1
getval(A::MatRow, i) = unsafe_load(A.vals_ptr, i)


import Base.kron
Expand All @@ -1287,8 +1285,6 @@ function kron{T}(A::Mat{T, C.MATSEQAIJ}, B::Mat{T, C.MATSEQAIJ})

Am = size(A, 1); An = size(A, 2)
Bm = size(B, 1); Bn = size(B, 2)
println("size(A) = ", Am, ", ", An)
println("size(B) = ", Bm, ", ", Bm)
# step 1: figure out size, sparsity pattern of result
A_nz = zeros(Int, Am)
B_nz = zeros(Int, Bm)
Expand Down
49 changes: 49 additions & 0 deletions test/mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -546,4 +546,53 @@ end
@test PETSc.count_row_nz(mat, 3) == 3
end

@testset "kron" begin
function testrun(Aj, Bj)
# A and B are julia matrices of some kind
A = Mat(Aj); B = Mat(Bj)
assemble(A); assemble(B)
Cj = kron(Aj, Bj)
C = kron(A, B)
assemble(C)
m, n = size(C)
C2 = C[1:m, 1:n]
@test C2 Cj
end

# case 1
Aj = ST[1. 0 0; 0 1 0; 0 0 1]
testrun(Aj, Aj)

# case 2
Aj = ST[1. 2 0; 3 4 5; 0 6 7]
testrun(Aj, Aj)


n = 10
Aj = rand(ST, n, n)
testrun(Aj, Aj)


m = 3
n = 2
Aj = rand(ST, m, n)
testrun(Aj, Aj)


m1 = 3
n1 = 2
m2 = 4
n2 = 5
Aj = rand(ST, m1, n1)
Bj = rand(ST, m2, n2)
testrun(Aj, Bj)

m1 = 5
m2 = 7
n1 = 8
n2 = 10
Aj = sprand(m1, n1, 0.01)
Bj = sprand(m2, n2, 0.01)
testrun(Aj, Bj)
end
end

0 comments on commit 51b0186

Please sign in to comment.