Skip to content

Commit

Permalink
Solve linear equations with #solve
Browse files Browse the repository at this point in the history
Conflicts:
	spec/math_spec.rb
  • Loading branch information
v0dro authored and Colin J. Fuller committed Jan 28, 2015
1 parent 58d3a00 commit 4241d24
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 6 deletions.
59 changes: 59 additions & 0 deletions ext/nmatrix/math.cpp
Expand Up @@ -230,6 +230,49 @@ namespace nm {
} }
} }


/*
* Solve a system of linear equations using forward-substution followed by
* back substution from the LU factorization of the matrix of co-efficients.
* Replaces x_elements with the result. Works only with non-integer, non-object
* data types.
*
* args - r -> The number of rows of the matrix.
* lu_elements -> Elements of the LU decomposition of the co-efficients
* matrix, as a contiguos array.
* b_elements -> Elements of the the right hand sides, as a contiguous array.
* x_elements -> The array that will contain the results of the computation.
* pivot -> Positions of permuted rows.
*/
template <typename DType>
void solve(const int r, const void* lu_elements, const void* b_elements, void* x_elements, const int* pivot) {
int ii = 0, ip;
DType sum;

const DType* matrix = reinterpret_cast<const DType*>(lu_elements);
const DType* b = reinterpret_cast<const DType*>(b_elements);
DType* x = reinterpret_cast<DType*>(x_elements);

for (int i = 0; i < r; ++i) { x[i] = b[i]; }
for (int i = 0; i < r; ++i) { // forward substitution loop
ip = pivot[i];
sum = x[ip];
x[ip] = x[i];

if (ii != 0) {
for (int j = ii - 1;j < i; ++j) { sum = sum - matrix[i * r + j] * x[j]; }
}
else if (sum != 0.0) {
ii = i + 1;
}
x[i] = sum;
}

for (int i = r - 1; i >= 0; --i) { // back substitution loop
sum = x[i];
for (int j = i + 1; j < r; j++) { sum = sum - matrix[i * r + j] * x[j]; }
x[i] = sum/matrix[i * r + i];
}
}


/* /*
* Calculates in-place inverse of A_elements. Uses Gauss-Jordan elimination technique. * Calculates in-place inverse of A_elements. Uses Gauss-Jordan elimination technique.
Expand Down Expand Up @@ -1735,6 +1778,22 @@ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dty
ttable[dtype](M, elements, lda, result); ttable[dtype](M, elements, lda, result);
} }


/*
* C accessor for solving a system of linear equations.
*/
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv) {
int* pivot = new int[RARRAY_LEN(ipiv)];

for (int i = 0; i < RARRAY_LEN(ipiv); ++i) {
pivot[i] = FIX2INT(rb_ary_entry(ipiv, i));
}

NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::solve, void, const int, const void*, const void*, void*, const int*);

ttable[NM_DTYPE(x)](NM_SHAPE0(b), NM_STORAGE_DENSE(lu)->elements,
NM_STORAGE_DENSE(b)->elements, NM_STORAGE_DENSE(x)->elements, pivot);
}

/* /*
* C accessor for calculating an in-place inverse. * C accessor for calculating an in-place inverse.
*/ */
Expand Down
1 change: 1 addition & 0 deletions ext/nmatrix/math/math.h
Expand Up @@ -103,6 +103,7 @@ extern "C" {
/* /*
* C accessors. * C accessors.
*/ */
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result); void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype); void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype); void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
Expand Down
24 changes: 24 additions & 0 deletions ext/nmatrix/ruby_nmatrix.c
Expand Up @@ -155,6 +155,7 @@ static VALUE matrix_multiply_scalar(NMATRIX* left, VALUE scalar);
static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right); static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right);
static VALUE nm_multiply(VALUE left_v, VALUE right_v); static VALUE nm_multiply(VALUE left_v, VALUE right_v);
static VALUE nm_det_exact(VALUE self); static VALUE nm_det_exact(VALUE self);
static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv);
static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang); static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang);
static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb); static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb);
static VALUE nm_complex_conjugate_bang(VALUE self); static VALUE nm_complex_conjugate_bang(VALUE self);
Expand Down Expand Up @@ -263,6 +264,7 @@ void Init_nmatrix() {
rb_define_method(cNMatrix, "supershape", (METHOD)nm_supershape, 0); rb_define_method(cNMatrix, "supershape", (METHOD)nm_supershape, 0);
rb_define_method(cNMatrix, "offset", (METHOD)nm_offset, 0); rb_define_method(cNMatrix, "offset", (METHOD)nm_offset, 0);
rb_define_method(cNMatrix, "det_exact", (METHOD)nm_det_exact, 0); rb_define_method(cNMatrix, "det_exact", (METHOD)nm_det_exact, 0);
rb_define_private_method(cNMatrix, "__solve__", (METHOD)nm_solve, 4);
rb_define_protected_method(cNMatrix, "__inverse__", (METHOD)nm_inverse, 2); rb_define_protected_method(cNMatrix, "__inverse__", (METHOD)nm_inverse, 2);
rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 3); rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 3);
rb_define_method(cNMatrix, "complex_conjugate!", (METHOD)nm_complex_conjugate_bang, 0); rb_define_method(cNMatrix, "complex_conjugate!", (METHOD)nm_complex_conjugate_bang, 0);
Expand Down Expand Up @@ -2961,7 +2963,29 @@ static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right) {
return to_return; return to_return;
} }


/*
* Solve the system of linear equations when passed the LU factorized matrix
* of the matrix of co-effcients and the column-matrix of right hand sides.
* Does no error checking of its own. Expects it all to be done in Ruby. See
* #solve in math.rb for details. Modifies x.
*
* == Arguments
*
* self - The NMatrix object calling this function
* lu - LU Decomoposition of self. Values never change.
* b - The vector of right hand sides. Values never change.
* x - The vector of variables to found. The computed values are stored in this.
* ipiv - The pivot array of the LU factorized matrix.
*
* == Notes
*
* LAPACK free.
*/
static VALUE nm_solve(VALUE self, VALUE lu, VALUE b, VALUE x, VALUE ipiv) {
nm_math_solve(lu, b, x, ipiv);


return x;
}
/* /*
* Calculate the inverse of a matrix with in-place Gauss-Jordan elimination. * Calculate the inverse of a matrix with in-place Gauss-Jordan elimination.
* Inverse will fail if the largest element in any column in zero. * Inverse will fail if the largest element in any column in zero.
Expand Down
52 changes: 46 additions & 6 deletions lib/nmatrix/math.rb
Expand Up @@ -180,18 +180,58 @@ def factorize_cholesky
# call-seq: # call-seq:
# factorize_lu -> ... # factorize_lu -> ...
# #
# LU factorization of a matrix. # LU factorization of a matrix. Optionally return the permutation matrix.
# # Note that computing the permutation matrix will introduce a slight memory
# and time overhead.
#
# == Arguments
#
# +with_permutation_matrix+ - If set to *true* will return the permutation
# matrix alongwith the LU factorization as a second return value.
#
# FIXME: For some reason, getrf seems to require that the matrix be transposed first -- and then you have to transpose the # FIXME: For some reason, getrf seems to require that the matrix be transposed first -- and then you have to transpose the
# FIXME: result again. Ideally, this would be an in-place factorize instead, and would be called nm_factorize_lu_bang. # FIXME: result again. Ideally, this would be an in-place factorize instead, and would be called nm_factorize_lu_bang.
# #
def factorize_lu def factorize_lu with_permutation_matrix=nil
raise(NotImplementedError, "only implemented for dense storage") unless self.stype == :dense raise(NotImplementedError, "only implemented for dense storage") unless self.stype == :dense
raise(NotImplementedError, "matrix is not 2-dimensional") unless self.dimensions == 2 raise(NotImplementedError, "matrix is not 2-dimensional") unless self.dimensions == 2


t = self.transpose t = self.transpose
NMatrix::LAPACK::clapack_getrf(:row, t.shape[0], t.shape[1], t, t.shape[1]) pivot = NMatrix::LAPACK::clapack_getrf(:row, t.shape[0], t.shape[1], t, t.shape[1])
t.transpose return t.transpose unless with_permutation_matrix

[t.transpose, FactorizeLUMethods.permutation_matrix_from(pivot)]
end

# Solve a system of linear equations where *self* is the matrix of co-efficients
# and *b* is the vertical vector of right hand sides. Only works with dense
# matrices and non-integer, non-object data types.
#
# == Arguments
#
# +b+ - Vector of Right Hand Sides.
#
# == Usage
#
# a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
# b = NMatrix.new [2,1], [9,8], dtype: dtype
# a.solve(b)
def solve b
raise ArgumentError, "b must be a column vector" if b.shape[1] != 1
raise ArgumentError, "number of rows of b must equal number of cols of self" if
self.shape[1] != b.shape[0]
raise ArgumentError, "only works with dense matrices" if self.stype != :dense
raise ArgumentError, "only works for non-integer, non-object dtypes" if
integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?

x = b.clone_structure
clone = self.clone
t = clone.transpose # transpose because of the getrf anomaly described above.
pivot = t.lu_decomposition!
t = t.transpose

__solve__(t, b, x, pivot)
x
end end


# #
Expand Down
10 changes: 10 additions & 0 deletions lib/nmatrix/nmatrix.rb
Expand Up @@ -298,6 +298,16 @@ def rational_dtype?
[:rational32, :rational64, :rational128].include?(self.dtype) [:rational32, :rational64, :rational128].include?(self.dtype)
end end


##
# call-seq:
#
# object_dtype?() -> Boolean
#
# Checks if dtype is a ruby object
def object_dtype?
dtype == :object
end



# #
# call-seq: # call-seq:
Expand Down
19 changes: 19 additions & 0 deletions spec/math_spec.rb
Expand Up @@ -379,4 +379,23 @@
end end
end end
end end

context "#solve" do
NON_INTEGER_DTYPES.each do |dtype|
next if dtype == :object # LU factorization doesnt work for :object yet
it "solves linear equation for dtype #{dtype}" do
a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
b = NMatrix.new [2,1], [9,8], dtype: dtype

expect(a.solve(b)).to eq(NMatrix.new [2,1], [2,3], dtype: dtype)
end

it "solves linear equation for #{dtype} (non-symmetric matrix)" do
a = NMatrix.new [3,3], [1,2,3,5,6,7,3,5,3], dtype: dtype
b = NMatrix.new [3,1], [2,3,4], dtype: dtype

expect(a.solve(b)).to be_within(0.01).of(NMatrix.new([3,1], [-1.437,1.62,0.062], dtype: dtype))
end unless [:rational32, :rational64, :rational128].include?(dtype)
end
end
end end

0 comments on commit 4241d24

Please sign in to comment.