From 1c4feb9cb8dfbcf64635ccc150686911272d744b Mon Sep 17 00:00:00 2001 From: ishii-norimi Date: Sat, 22 Nov 2025 19:21:08 +0900 Subject: [PATCH] Refactor variance calculation in Matrix class for improved performance and update tests. --- lib/util/matrix.js | 17 +++++++++++++---- tests/lib/util/matrix.test.js | 6 +++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/lib/util/matrix.js b/lib/util/matrix.js index a7c5e594..e0ca37b2 100644 --- a/lib/util/matrix.js +++ b/lib/util/matrix.js @@ -1802,21 +1802,30 @@ export default class Matrix { * @returns {Matrix | number} Variance values */ variance(axis = -1, ddof = 0) { - const m = this.mean(axis) if (axis < 0) { - return this._value.reduce((acc, v) => acc + (v - m) ** 2, 0) / (this.length - ddof) + let v = 0 + let v2 = 0 + const den = this.length - ddof + for (let i = 0; i < this.length; i++) { + v += this._value[i] + v2 += this._value[i] ** 2 + } + return v2 / den - (v / den) ** 2 } let v_step = axis === 0 ? 1 : this.cols let s_step = axis === 0 ? this.cols : 1 const new_size = [].concat(this._size) new_size[axis] = 1 + const den = this._size[axis] - ddof const mat = Matrix.zeros(...new_size) for (let n = 0, nv = 0; n < mat.length; n++, nv += v_step) { let v = 0 + let v2 = 0 for (let i = 0; i < this._size[axis]; i++) { - v += (this._value[i * s_step + nv] - m._value[n]) ** 2 + v += this._value[i * s_step + nv] + v2 += this._value[i * s_step + nv] ** 2 } - mat._value[n] = v / (this._size[axis] - ddof) + mat._value[n] = v2 / den - (v / den) ** 2 } return mat } diff --git a/tests/lib/util/matrix.test.js b/tests/lib/util/matrix.test.js index 98e446b3..c6e40fe2 100644 --- a/tests/lib/util/matrix.test.js +++ b/tests/lib/util/matrix.test.js @@ -2663,7 +2663,7 @@ describe('Matrix', () => { [4, 5, 6], ] const org = new Matrix(2, 3, data) - expect(org.variance()).toBe(17.5 / 6) + expect(org.variance()).toBeCloseTo(17.5 / 6) }) test('axis 0', () => { @@ -2685,7 +2685,7 @@ describe('Matrix', () => { const org = new Matrix(2, 3, data) const prod = org.variance(1) expect(prod.sizes).toEqual([2, 1]) - expect(prod.value).toEqual([8 / 3, 8 / 3]) + expect(prod.value).toEqual([expect.closeTo(8 / 3), expect.closeTo(8 / 3)]) }) }) @@ -2718,7 +2718,7 @@ describe('Matrix', () => { const org = new Matrix(2, 3, data) const prod = org.std(1) expect(prod.sizes).toEqual([2, 1]) - expect(prod.value).toEqual([Math.sqrt(8 / 3), Math.sqrt(8 / 3)]) + expect(prod.value).toEqual([expect.closeTo(Math.sqrt(8 / 3)), expect.closeTo(Math.sqrt(8 / 3))]) }) })