Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions ext/nmatrix/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ extern "C" {
// Math Functions //
////////////////////

namespace nm {
namespace nm {
namespace math {

/*
Expand Down Expand Up @@ -335,18 +335,18 @@ namespace nm {
for (int row = k + 1; row < M; ++row) {
typename MagnitudeDType<DType>::type big;
big = magnitude( matrix[M*row + k] ); // element below the temp pivot

if ( big > akk ) {
interchange = row;
akk = big;
akk = big;
}
}
}

if (interchange != k) { // check if rows need flipping
DType temp;

for (int col = 0; col < M; ++col) {
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
}
}

Expand All @@ -360,7 +360,7 @@ namespace nm {
DType pivot = matrix[k * (M + 1)];
matrix[k * (M + 1)] = (DType)(1); // set diagonal as 1 for in-place inversion

for (int col = 0; col < M; ++col) {
for (int col = 0; col < M; ++col) {
// divide each element in the kth row with the pivot
matrix[k*M + col] = matrix[k*M + col] / pivot;
}
Expand All @@ -369,7 +369,7 @@ namespace nm {
if (kk == k) continue;

DType dum = matrix[k + M*kk];
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
for (int col = 0; col < M; ++col) {
matrix[M*kk + col] = matrix[M*kk + col] - matrix[M*k + col] * dum;
}
Expand All @@ -384,7 +384,7 @@ namespace nm {

for (int row = 0; row < M; ++row) {
NM_SWAP(matrix[row * M + row_index[k]], matrix[row * M + col_index[k]],
temp);
temp);
}
}
}
Expand All @@ -410,14 +410,14 @@ namespace nm {
DType sum_of_squares, *p_row, *psubdiag, *p_a, scale, innerproduct;
int i, k, col;

// For each column use a Householder transformation to zero all entries
// For each column use a Householder transformation to zero all entries
// below the subdiagonal.
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
col++) {
// Calculate the signed square root of the sum of squares of the
// elements below the diagonal.

for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
p_a += nrows, i++) {
sum_of_squares += *p_a * *p_a;
}
Expand Down Expand Up @@ -447,7 +447,7 @@ namespace nm {
*p_a -= u[k] * innerproduct;
}
}

// Postmultiply QA by Q
for (p_row = a, i = 0; i < nrows; p_row += nrows, i++) {
for (innerproduct = 0.0, k = col + 1; k < nrows; k++) {
Expand Down Expand Up @@ -485,7 +485,7 @@ namespace nm {
B[0] = A[lda+1] / det;
B[1] = -A[1] / det;
B[ldb] = -A[lda] / det;
B[ldb+1] = -A[0] / det;
B[ldb+1] = A[0] / det;

} else if (M == 3) {
// Calculate the exact determinant.
Expand Down Expand Up @@ -1313,7 +1313,7 @@ void nm_math_hessenberg(VALUE a) {
NULL, NULL, // does not support Complex
NULL // no support for Ruby Object
};

ttable[NM_DTYPE(a)](NM_SHAPE0(a), NM_STORAGE_DENSE(a)->elements);
}
/*
Expand Down
58 changes: 57 additions & 1 deletion lib/nmatrix/math.rb
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,62 @@ def invert
end
alias :inverse :invert

#
# call-seq:
# invert_exact! -> NMatrix
#
# Calulates inverse_exact of a matrix of size 2 or 3.
# Only works on dense matrices.
#
# * *Raises* :
# - +StorageTypeError+ -> only implemented on dense matrices.
# - +ShapeError+ -> matrix must be square.
# - +DataTypeError+ -> cannot invert an integer matrix in-place.
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
#
def invert_exact!
raise(StorageTypeError, "invert only works on dense matrices currently") unless self.dense?
raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype?
#No internal implementation of getri, so use this other function
n = self.shape[0]
if n>3
raise(NotImplementedError, "Cannot find exact inverse of matrix of size greater than 3")
else
clond=self.clone
__inverse_exact__(clond, n, n)
end
end

#
# call-seq:
# invert_exact -> NMatrix
#
# Make a copy of the matrix, then invert using exact_inverse
#
# * *Returns* :
# - A dense NMatrix. Will be the same type as the input NMatrix,
# except if the input is an integral dtype, in which case it will be a
# :float64 NMatrix.
#
# * *Raises* :
# - +StorageTypeError+ -> only implemented on dense matrices.
# - +ShapeError+ -> matrix must be square.
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
#
def invert_exact
#write this in terms of invert_exact! so plugins will only have to overwrite
#invert_exact! and not invert_exact
if self.integer_dtype?
cloned = self.cast(dtype: :float64)
cloned.invert_exact!
else
cloned = self.clone
cloned.invert_exact!
end
end
alias :inverse_exact :invert_exact

#
# call-seq:
# adjugate! -> NMatrix
Expand Down Expand Up @@ -1393,4 +1449,4 @@ def dtype_for_floor_or_ceil
self.__dense_map__ { |l| l.send(op,rhs) }
end
end
end
end
17 changes: 17 additions & 0 deletions spec/math_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,23 @@

expect(a.invert).to be_within(err).of(b)
end

it "should correctly find exact inverse" do
pending("not yet implemented for NMatrix-JRuby") if jruby?
a = NMatrix.new(:dense, 3, [1,2,3,0,1,4,5,6,0], dtype)
b = NMatrix.new(:dense, 3, [-24,18,5,20,-15,-4,-5,4,1], dtype)

expect(a.invert_exact).to be_within(err).of(b)
end

it "should correctly find exact inverse" do
pending("not yet implemented for NMatrix-JRuby") if jruby?
a = NMatrix.new(:dense, 2, [1,3,3,8,], dtype)
b = NMatrix.new(:dense, 2, [-8,3,3,-1], dtype)

expect(a.invert_exact).to be_within(err).of(b)
end

end
end

Expand Down