# SciRuby/nmatrix forked from mohawkjohn/nmatrix

Solve linear equations with #solve

```Conflicts:
spec/math_spec.rb```
v0dro authored and cjfuller committed Jan 26, 2015
1 parent 58d3a00 commit 4241d241ca7744ca2ca5e090782588581160d42b
Showing with 159 additions and 6 deletions.
1. +59 −0 ext/nmatrix/math.cpp
2. +1 −0 ext/nmatrix/math/math.h
3. +24 −0 ext/nmatrix/ruby_nmatrix.c
4. +46 −6 lib/nmatrix/math.rb
5. +10 −0 lib/nmatrix/nmatrix.rb
6. +19 −0 spec/math_spec.rb
 @@ -230,6 +230,49 @@ namespace nm { } } /* * Solve a system of linear equations using forward-substution followed by * back substution from the LU factorization of the matrix of co-efficients. * Replaces x_elements with the result. Works only with non-integer, non-object * data types. * * args - r -> The number of rows of the matrix. * lu_elements -> Elements of the LU decomposition of the co-efficients * matrix, as a contiguos array. * b_elements -> Elements of the the right hand sides, as a contiguous array. * x_elements -> The array that will contain the results of the computation. * pivot -> Positions of permuted rows. */ template void solve(const int r, const void* lu_elements, const void* b_elements, void* x_elements, const int* pivot) { int ii = 0, ip; DType sum; const DType* matrix = reinterpret_cast(lu_elements); const DType* b = reinterpret_cast(b_elements); DType* x = reinterpret_cast(x_elements); for (int i = 0; i < r; ++i) { x[i] = b[i]; } for (int i = 0; i < r; ++i) { // forward substitution loop ip = pivot[i]; sum = x[ip]; x[ip] = x[i]; if (ii != 0) { for (int j = ii - 1;j < i; ++j) { sum = sum - matrix[i * r + j] * x[j]; } } else if (sum != 0.0) { ii = i + 1; } x[i] = sum; } for (int i = r - 1; i >= 0; --i) { // back substitution loop sum = x[i]; for (int j = i + 1; j < r; j++) { sum = sum - matrix[i * r + j] * x[j]; } x[i] = sum/matrix[i * r + i]; } } /* * Calculates in-place inverse of A_elements. Uses Gauss-Jordan elimination technique. @@ -1735,6 +1778,22 @@ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dty ttable[dtype](M, elements, lda, result); } /* * C accessor for solving a system of linear equations. 
*/ void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv) { int* pivot = new int[RARRAY_LEN(ipiv)]; for (int i = 0; i < RARRAY_LEN(ipiv); ++i) { pivot[i] = FIX2INT(rb_ary_entry(ipiv, i)); } NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::solve, void, const int, const void*, const void*, void*, const int*); ttable[NM_DTYPE(x)](NM_SHAPE0(b), NM_STORAGE_DENSE(lu)->elements, NM_STORAGE_DENSE(b)->elements, NM_STORAGE_DENSE(x)->elements, pivot); } /* * C accessor for calculating an in-place inverse. */
 @@ -103,6 +103,7 @@ extern "C" { /* * C accessors. */ void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv); void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result); void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype); void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
 @@ -155,6 +155,7 @@ static VALUE matrix_multiply_scalar(NMATRIX* left, VALUE scalar); static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right); static VALUE nm_multiply(VALUE left_v, VALUE right_v); static VALUE nm_det_exact(VALUE self); static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv); static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang); static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb); static VALUE nm_complex_conjugate_bang(VALUE self); @@ -263,6 +264,7 @@ void Init_nmatrix() { rb_define_method(cNMatrix, "supershape", (METHOD)nm_supershape, 0); rb_define_method(cNMatrix, "offset", (METHOD)nm_offset, 0); rb_define_method(cNMatrix, "det_exact", (METHOD)nm_det_exact, 0); rb_define_private_method(cNMatrix, "__solve__", (METHOD)nm_solve, 4); rb_define_protected_method(cNMatrix, "__inverse__", (METHOD)nm_inverse, 2); rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 3); rb_define_method(cNMatrix, "complex_conjugate!", (METHOD)nm_complex_conjugate_bang, 0); @@ -2961,7 +2963,29 @@ static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right) { return to_return; } /* * Solve the system of linear equations when passed the LU factorized matrix * of the matrix of coefficients and the column-matrix of right hand sides. * Does no error checking of its own. Expects it all to be done in Ruby. See * #solve in math.rb for details. Modifies x. * * == Arguments * * self - The NMatrix object calling this function * lu - LU decomposition of self. Values never change. * b - The vector of right hand sides. Values never change. * x - The vector of variables to be found. The computed values are stored in this. * ipiv - The pivot array of the LU factorized matrix. * * == Returns * * The same x that was passed in. * * == Notes * * LAPACK free. */ static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv) { nm_math_solve(lu, b, x, ipiv); return x; } /* * Calculate the inverse of a matrix with in-place Gauss-Jordan elimination. 
* Inverse will fail if the largest element in any column is zero.
 # @@ -180,18 +180,58 @@ def factorize_cholesky
  #
  # call-seq:
  #     factorize_lu -> ...
  #
  # LU factorization of a matrix. Optionally return the permutation matrix.
  # Note that computing the permutation matrix will introduce a slight memory
  # and time overhead.
  #
  # == Arguments
  #
  # +with_permutation_matrix+ - If set to *true* will return the permutation
  #   matrix along with the LU factorization as a second return value.
  #
  # == Returns
  #
  # The factorized matrix, or a two-element Array of the factorized matrix and
  # the permutation matrix when +with_permutation_matrix+ is truthy.
  #
  # FIXME: For some reason, getrf seems to require that the matrix be transposed first -- and then you have to transpose the
  # FIXME: result again. Ideally, this would be an in-place factorize instead, and would be called nm_factorize_lu_bang.
  #
  def factorize_lu(with_permutation_matrix = nil)
    raise(NotImplementedError, "only implemented for dense storage") unless self.stype == :dense
    raise(NotImplementedError, "matrix is not 2-dimensional") unless self.dimensions == 2

    t     = self.transpose
    pivot = NMatrix::LAPACK::clapack_getrf(:row, t.shape[0], t.shape[1], t, t.shape[1])
    return t.transpose unless with_permutation_matrix

    [t.transpose, FactorizeLUMethods.permutation_matrix_from(pivot)]
  end

  # Solve a system of linear equations where *self* is the matrix of coefficients
  # and *b* is the vertical vector of right hand sides. Only works with dense
  # matrices and non-integer, non-object data types.
  #
  # == Arguments
  #
  # +b+ - Vector of Right Hand Sides.
  #
  # == Returns
  #
  # A new matrix with the structure of +b+ holding the solution.
  #
  # == Raises
  #
  # ArgumentError - if +b+ is not a column vector, if the row count of +b+ does
  #   not equal the column count of *self*, if *self* or +b+ is not dense, if
  #   either has an integer or object dtype, if *self* is not square, or if
  #   the dtypes of *self* and +b+ differ.
  #
  # == Usage
  #
  #   a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
  #   b = NMatrix.new [2,1], [9,8], dtype: dtype
  #   a.solve(b)
  def solve(b)
    raise ArgumentError, "b must be a column vector" if b.shape[1] != 1
    raise ArgumentError, "number of rows of b must equal number of cols of self" if self.shape[1] != b.shape[0]
    raise ArgumentError, "only works with dense matrices" if self.stype != :dense

    unsupported_dtype = integer_dtype? || object_dtype? || b.integer_dtype? || b.object_dtype?
    raise ArgumentError, "only works for non-integer, non-object dtypes" if unsupported_dtype

    # The guards below protect the native routine behind __solve__, which does
    # no checking of its own: it indexes the LU elements as an r-by-r array
    # (r being the row count of b), reads b through its dense storage, and
    # interprets every buffer with a single element type.
    raise ArgumentError, "self must be a square matrix" if shape[0] != shape[1]
    raise ArgumentError, "b must be a dense matrix" if b.stype != :dense
    raise ArgumentError, "self and b must have the same dtype" if dtype != b.dtype

    x = b.clone_structure

    # Factorize a transposed copy so that self is left untouched; the
    # transposes are needed because of the getrf anomaly described in the
    # FIXME on factorize_lu.
    lu    = self.clone.transpose
    pivot = lu.lu_decomposition!
    lu    = lu.transpose

    __solve__(lu, b, x, pivot)
    x
  end

  #
 # @@ -298,6 +298,16 @@ def rational_dtype?
  # (diff context, tail of the preceding method) [:rational32, :rational64, :rational128].include?(self.dtype) end

  ##
  # call-seq:
  #
  #   object_dtype?() -> Boolean
  #
  # Checks if dtype is a Ruby object, i.e. whether dtype equals :object.
  def object_dtype?
    dtype == :object
  end

  #
  # call-seq:
 # @@ -379,4 +379,23 @@
  # (diff context, closes blocks that begin above this hunk) end end end

  # Checks NMatrix#solve against small systems with known solutions, once for
  # every dtype in NON_INTEGER_DTYPES except :object.
  context "#solve" do
    NON_INTEGER_DTYPES.each do |dtype|
      next if dtype == :object # LU factorization doesn't work for :object yet

      it "solves linear equation for dtype #{dtype}" do
        a = NMatrix.new([2, 2], [3, 1, 1, 2], dtype: dtype)
        b = NMatrix.new([2, 1], [9, 8], dtype: dtype)

        expect(a.solve(b)).to eq(NMatrix.new([2, 1], [2, 3], dtype: dtype))
      end

      # The approximate (be_within) comparison is not registered for the
      # rational dtypes.
      unless [:rational32, :rational64, :rational128].include?(dtype)
        it "solves linear equation for #{dtype} (non-symmetric matrix)" do
          a = NMatrix.new([3, 3], [1, 2, 3, 5, 6, 7, 3, 5, 3], dtype: dtype)
          b = NMatrix.new([3, 1], [2, 3, 4], dtype: dtype)

          expect(a.solve(b)).to be_within(0.01).of(NMatrix.new([3, 1], [-1.437, 1.62, 0.062], dtype: dtype))
        end
      end
    end
  end

  # (diff context, closes a block that begins above this hunk) end