Permalink
Browse files

Solve linear equations with #solve

Conflicts:
	spec/math_spec.rb
  • Loading branch information...
v0dro authored and cjfuller committed Jan 26, 2015
1 parent 58d3a00 commit 4241d241ca7744ca2ca5e090782588581160d42b
Showing with 159 additions and 6 deletions.
  1. +59 −0 ext/nmatrix/math.cpp
  2. +1 −0 ext/nmatrix/math/math.h
  3. +24 −0 ext/nmatrix/ruby_nmatrix.c
  4. +46 −6 lib/nmatrix/math.rb
  5. +10 −0 lib/nmatrix/nmatrix.rb
  6. +19 −0 spec/math_spec.rb
View
@@ -230,6 +230,49 @@ namespace nm {
}
}
/*
 * Solves the linear system A * x = b from the LU factorization of the
 * coefficient matrix A: a forward-substitution pass (which also applies the
 * row permutation) followed by a back-substitution pass. The solution is
 * written into x_elements; b_elements is only read. Works only with
 * non-integer, non-object data types.
 *
 * args - r           -> The number of rows of the matrix.
 *        lu_elements -> Elements of the LU decomposition of the coefficient
 *                       matrix, as a contiguous array addressed as [i * r + j].
 *        b_elements  -> Elements of the right hand sides, as a contiguous array.
 *        x_elements  -> The array that receives the results of the computation.
 *        pivot       -> Positions of permuted rows.
 */
template <typename DType>
void solve(const int r, const void* lu_elements, const void* b_elements, void* x_elements, const int* pivot) {
  const DType* lu  = reinterpret_cast<const DType*>(lu_elements);
  const DType* rhs = reinterpret_cast<const DType*>(b_elements);
  DType* x         = reinterpret_cast<DType*>(x_elements);
  DType acc;

  // Start from a copy of the right hand side; everything below works on x.
  for (int k = 0; k < r; ++k) { x[k] = rhs[k]; }

  // Forward substitution. first_nonzero is the first row whose permuted
  // right-hand-side entry was non-zero (-1 until one is seen); columns before
  // it multiply zeros, so the inner loop starts there.
  int first_nonzero = -1;
  for (int row = 0; row < r; ++row) {
    const int swap_row = pivot[row];

    // Apply the row interchange recorded for this row.
    acc         = x[swap_row];
    x[swap_row] = x[row];

    if (first_nonzero >= 0) {
      for (int col = first_nonzero; col < row; ++col) {
        acc = acc - lu[row * r + col] * x[col];
      }
    } else if (acc != 0.0) {
      first_nonzero = row;
    }

    x[row] = acc;
  }

  // Back substitution, from the last row upwards, dividing by the diagonal.
  for (int row = r - 1; row >= 0; --row) {
    acc = x[row];
    for (int col = row + 1; col < r; ++col) {
      acc = acc - lu[row * r + col] * x[col];
    }
    x[row] = acc / lu[row * r + row];
  }
}
/*
* Calculates in-place inverse of A_elements. Uses Gauss-Jordan elimination technique.
@@ -1735,6 +1778,22 @@ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dty
ttable[dtype](M, elements, lda, result);
}
/*
 * C accessor for solving a system of linear equations.
 *
 * Copies the Ruby array ipiv into a temporary C int array, then dispatches to
 * the nm::math::solve instantiation selected by the dtype of x. The element
 * buffers of lu, b and x are all accessed through NM_STORAGE_DENSE and are
 * interpreted with x's dtype, so callers must pass dense matrices of one
 * common dtype. The system size is taken from NM_SHAPE0(b).
 *
 * The temporary pivot buffer is released before returning.
 */
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv) {
  // Read the length once; RARRAY_LEN yields a long, so index with a long too.
  const long pivot_count = RARRAY_LEN(ipiv);

  int* pivot = new int[pivot_count];
  for (long i = 0; i < pivot_count; ++i) {
    pivot[i] = FIX2INT(rb_ary_entry(ipiv, i));
  }

  NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::solve, void, const int, const void*, const void*, void*, const int*);

  ttable[NM_DTYPE(x)](NM_SHAPE0(b), NM_STORAGE_DENSE(lu)->elements,
                      NM_STORAGE_DENSE(b)->elements, NM_STORAGE_DENSE(x)->elements, pivot);

  // The pivot copy is only needed for the duration of the call above.
  delete[] pivot;
}
/*
* C accessor for calculating an in-place inverse.
*/
View
@@ -103,6 +103,7 @@ extern "C" {
/*
* C accessors.
*/
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
View
@@ -155,6 +155,7 @@ static VALUE matrix_multiply_scalar(NMATRIX* left, VALUE scalar);
static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right);
static VALUE nm_multiply(VALUE left_v, VALUE right_v);
static VALUE nm_det_exact(VALUE self);
static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv);
static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang);
static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb);
static VALUE nm_complex_conjugate_bang(VALUE self);
@@ -263,6 +264,7 @@ void Init_nmatrix() {
rb_define_method(cNMatrix, "supershape", (METHOD)nm_supershape, 0);
rb_define_method(cNMatrix, "offset", (METHOD)nm_offset, 0);
rb_define_method(cNMatrix, "det_exact", (METHOD)nm_det_exact, 0);
rb_define_private_method(cNMatrix, "__solve__", (METHOD)nm_solve, 4);
rb_define_protected_method(cNMatrix, "__inverse__", (METHOD)nm_inverse, 2);
rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 3);
rb_define_method(cNMatrix, "complex_conjugate!", (METHOD)nm_complex_conjugate_bang, 0);
@@ -2961,7 +2963,29 @@ static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right) {
return to_return;
}
/*
 * Solve the system of linear equations when passed the LU factorized matrix
 * of the matrix of coefficients and the column-matrix of right hand sides.
 * Does no error checking of its own. Expects it all to be done in Ruby. See
 * #solve in math.rb for details. Modifies x.
 *
 * == Arguments
 *
 * self - The NMatrix object calling this function (not used by this function).
 * lu - LU decomposition of self. Values never change.
 * b - The vector of right hand sides. Values never change.
 * x - The vector of variables to be found. The computed values are stored in this.
 * ipiv - The pivot array of the LU factorized matrix.
 *
 * == Returns
 *
 * x, after nm_math_solve has been called with it.
 *
 * == Notes
 *
 * LAPACK free.
 */
static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv) {
// All of the work is delegated to the C accessor declared in math.h.
nm_math_solve(lu, b, x, ipiv);
return x;
}
/*
* Calculate the inverse of a matrix with in-place Gauss-Jordan elimination.
* Inverse will fail if the largest element in any column in zero.
View
@@ -180,18 +180,58 @@ def factorize_cholesky
# call-seq:
# factorize_lu -> ...
#
# LU factorization of a matrix.
#
# LU factorization of a matrix. Optionally return the permutation matrix.
# Note that computing the permutation matrix will introduce a slight memory
# and time overhead.
#
# == Arguments
#
# +with_permutation_matrix+ - If set to *true*, the permutation matrix is
# returned along with the LU factorization, as a second return value.
#
# FIXME: For some reason, getrf seems to require that the matrix be transposed first -- and then you have to transpose the
# FIXME: result again. Ideally, this would be an in-place factorize instead, and would be called nm_factorize_lu_bang.
#
def factorize_lu
def factorize_lu with_permutation_matrix=nil
raise(NotImplementedError, "only implemented for dense storage") unless self.stype == :dense
raise(NotImplementedError, "matrix is not 2-dimensional") unless self.dimensions == 2
t = self.transpose
NMatrix::LAPACK::clapack_getrf(:row, t.shape[0], t.shape[1], t, t.shape[1])
t.transpose
t = self.transpose
pivot = NMatrix::LAPACK::clapack_getrf(:row, t.shape[0], t.shape[1], t, t.shape[1])
return t.transpose unless with_permutation_matrix
[t.transpose, FactorizeLUMethods.permutation_matrix_from(pivot)]
end
# Solve a system of linear equations where *self* is the matrix of coefficients
# and *b* is the vertical vector of right hand sides. Only works with dense,
# square matrices and non-integer, non-object data types.
#
# == Arguments
#
# +b+ - Vector of Right Hand Sides. Must be a dense column vector with as many
#       rows as *self* has columns, and with the same dtype as *self*.
#
# == Returns
#
# A new matrix with the structure of *b* holding the solution. Neither *self*
# nor *b* is modified.
#
# == Raises
#
# ArgumentError - if any of the shape, storage type or dtype requirements
#                 above is not met.
#
# == Usage
#
#   a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
#   b = NMatrix.new [2,1], [9,8], dtype: dtype
#   a.solve(b)
def solve(b)
  raise ArgumentError, "b must be a column vector" if b.shape[1] != 1
  raise ArgumentError, "number of rows of b must equal number of cols of self" if
    self.shape[1] != b.shape[0]

  # The native routine addresses the LU factors as an r x r array, with r taken
  # from the rows of b; a non-square self would be read out of bounds.
  raise ArgumentError, "self must be a square matrix" unless
    self.dimensions == 2 && self.shape[0] == self.shape[1]

  # The native routine reads the element buffers of self, b and x as dense
  # storage, so b has to be dense as well.
  raise ArgumentError, "only works with dense matrices" if
    self.stype != :dense || b.stype != :dense

  raise ArgumentError, "only works for non-integer, non-object dtypes" if
    integer_dtype? || object_dtype? || b.integer_dtype? || b.object_dtype?

  # All three buffers are interpreted with the dtype of x (cloned from b), so
  # a differing dtype on self would be misread.
  raise ArgumentError, "dtype of b must match dtype of self" if self.dtype != b.dtype

  x = b.clone_structure

  # Factorize a transposed copy, then transpose back (getrf anomaly described
  # in the FIXME on factorize_lu), leaving self untouched.
  lu    = self.clone.transpose
  pivot = lu.lu_decomposition!
  lu    = lu.transpose

  __solve__(lu, b, x, pivot)
  x
end
#
View
@@ -298,6 +298,16 @@ def rational_dtype?
[:rational32, :rational64, :rational128].include?(self.dtype)
end
##
# call-seq:
#
#     object_dtype?() -> Boolean
#
# Reports whether this matrix's dtype is :object.
def object_dtype?
  [:object].include?(self.dtype)
end
#
# call-seq:
View
@@ -379,4 +379,23 @@
end
end
end
context "#solve" do
  rational_dtypes = [:rational32, :rational64, :rational128]

  # LU factorization doesn't work for :object yet, so that dtype is left out.
  NON_INTEGER_DTYPES.reject { |candidate| candidate == :object }.each do |dtype|
    it "solves linear equation for dtype #{dtype}" do
      coefficients = NMatrix.new([2,2], [3,1,1,2], dtype: dtype)
      rhs          = NMatrix.new([2,1], [9,8], dtype: dtype)
      expected     = NMatrix.new([2,1], [2,3], dtype: dtype)

      expect(coefficients.solve(rhs)).to eq(expected)
    end

    # The non-symmetric example is compared within a 0.01 tolerance and is not
    # defined for the rational dtypes.
    unless rational_dtypes.include?(dtype)
      it "solves linear equation for #{dtype} (non-symmetric matrix)" do
        coefficients = NMatrix.new([3,3], [1,2,3,5,6,7,3,5,3], dtype: dtype)
        rhs          = NMatrix.new([3,1], [2,3,4], dtype: dtype)
        expected     = NMatrix.new([3,1], [-1.437,1.62,0.062], dtype: dtype)

        expect(coefficients.solve(rhs)).to be_within(0.01).of(expected)
      end
    end
  end
end
end

0 comments on commit 4241d24

Please sign in to comment.