diff --git a/lib/nmatrix/math.rb b/lib/nmatrix/math.rb index 05c9c981..5c667ca6 100644 --- a/lib/nmatrix/math.rb +++ b/lib/nmatrix/math.rb @@ -112,6 +112,71 @@ def invert end alias :inverse :invert + + # + # call-seq: + # pinv -> NMatrix + # + # Compute the Moore-Penrose pseudo-inverse of a matrix using its + # singular value decomposition (SVD). + # + # This function requires the nmatrix-atlas gem installed. + # + # * *Arguments* : + # - +tolerance(optional)+ -> Cutoff for small singular values. + # + # * *Returns* : + # - Pseudo-inverse matrix. + # + # * *Raises* : + # - +NotImplementedError+ -> If called without nmatrix-atlas or nmatrix-lapacke gem. + # - +TypeError+ -> If called without float or complex data type. + # + # * *Examples* : + # + # a = NMatrix.new([2,2],[1,2, + # 3,4], dtype: :float64) + # a.pinv # => [ [-2.0000000000000018, 1.0000000000000007] + # [1.5000000000000016, -0.5000000000000008] ] + # + # b = NMatrix.new([4,1],[1,2,3,4], dtype: :float64) + # b.pinv # => [ [ 0.03333333, 0.06666667, 0.99999999, 0.13333333] ] + # + # == References + # + # * https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_pseudoinverse + # * G. Strang, Linear Algebra and Its Applications, 2nd Ed., Orlando, FL, Academic Press + # + def pinv(tolerance = 1e-15) + raise DataTypeError, "pinv works only with matrices of float or complex data type" unless + [:float32, :float64, :complex64, :complex128].include?(dtype) + if self.complex_dtype? + u, s, vt = self.complex_conjugate.gesvd # singular value decomposition + else + u, s, vt = self.gesvd + end + rows = self.shape[0] + cols = self.shape[1] + if rows < cols + u_reduced = u + vt_reduced = vt[0..rows - 1, 0..cols - 1].transpose + else + u_reduced = u[0..rows - 1, 0..cols - 1] + vt_reduced = vt.transpose + end + largest_singular_value = s.max.to_f + cutoff = tolerance * largest_singular_value + (0...[rows, cols].min).each do |i| + s[i] = 1 / s[i] if s[i] > cutoff + s[i] = 0 if s[i] <= cutoff + end + multiplier = u_reduced.dot(NMatrix.diagonal(s.to_a)).transpose + vt_reduced.dot(multiplier) + end + alias :pseudo_inverse :pinv + alias :pseudoinverse :pinv + + # # call-seq: # getrf! -> Array diff --git a/spec/math_spec.rb b/spec/math_spec.rb index 127c7b13..290f2459 100644 --- a/spec/math_spec.rb +++ b/spec/math_spec.rb @@ -332,6 +332,63 @@ end end + NON_INTEGER_DTYPES.each do |dtype| + next if dtype == :object + context dtype do + err = Complex(1e-3, 1e-3) + it "should correctly invert a 2x2 matrix" do + if dtype == :complex64 || dtype == :complex128 + a = NMatrix.new([2, 2], [Complex(16, 81), Complex(91, 51), \ + Complex(13, 54), Complex(71, 24)], dtype: dtype) + b = NMatrix.identity(2, dtype: dtype) + + begin + expect(a.dot(a.pinv)).to be_within(err).of(b) + rescue NotImplementedError + pending "Suppressing a NotImplementedError when the atlas plugin is not available" + end + + else + a = NMatrix.new([2, 2], [141, 612, 9123, 654], dtype: dtype) + b = NMatrix.identity(2, dtype: dtype) + + begin + expect(a.dot(a.pinv)).to be_within(err).of(b) + rescue NotImplementedError + pending "Suppressing a NotImplementedError when the atlas plugin is not available" + end + end + end + + it "should verify a.dot(b.dot(a)) == a and b.dot(a.dot(b)) == b" do + if dtype == :complex64 || dtype == :complex128 + a = NMatrix.new([3, 2], [Complex(94, 11), Complex(87, 51), Complex(82, 39), \ + Complex(45, 16), Complex(25, 32), Complex(91, 43) ], dtype: dtype) + + begin + b = a.pinv # pseudo inverse + expect(a.dot(b.dot(a))).to be_within(err).of(a) + expect(b.dot(a.dot(b))).to be_within(err).of(b) + rescue NotImplementedError + pending "Suppressing a NotImplementedError when the atlas plugin is not available" + end + + else + a = NMatrix.new([3, 3], [9, 4, 52, 12, 52, 1, 3, 55, 6], dtype: dtype) + + begin + b = a.pinv # pseudo inverse + expect(a.dot(b.dot(a))).to be_within(err).of(a) + expect(b.dot(a.dot(b))).to be_within(err).of(b) + rescue NotImplementedError + pending "Suppressing a NotImplementedError when the atlas plugin is not available" + end + end + end + end + end + + # TODO: Get it working with ROBJ too [:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |left_dtype| [:byte,:int8,:int16,:int32,:int64,:float32,:float64].each do |right_dtype|