Skip to content

Commit

Permalink
Add function rmse_matrices() to submodule metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayitzin committed Aug 14, 2023
1 parent 1560da0 commit a87eea3
Showing 1 changed file with 129 additions and 0 deletions.
129 changes: 129 additions & 0 deletions ahrs/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,132 @@ def rmse(x: np.ndarray, y: np.ndarray):
if x.ndim > 1:
return np.sqrt(np.nanmean((x-y)**2, axis=1))
return np.sqrt(np.nanmean((x-y)**2))

def rmse_matrices(A: np.ndarray, B: np.ndarray, element_wise: bool = False) -> np.ndarray:
"""
Root Mean Square Error between two arrays (matrices)
Parameters
----------
A : np.ndarray
First M-by-N matrix or array of k M-by-N matrices.
B : np.ndarray
Second M-by-N matrix or array of k M-by-N matrices.
element_wise : bool, default: False
If True, calculate RMSE element-wise, and return an M-by-N array of
RMSEs.
Returns
-------
rmse : float or np.ndarray
Root Mean Square Error between the two matrices, or array of k RMSEs
between the two arrays of matrices.
Raises
------
ValueError
If the comparing arrays do not have the same shape.
Notes
-----
If the input arrays are 2-dimensional matrices, the RMSE is calculated as:
.. math::
RMSE = \\sqrt{\\frac{1}{MN}\\sum_{i=1}^{M}\\sum_{j=1}^{N}(A_{ij} - B_{ij})^2}
If the input arrays are arrays of 2-dimensional matrices (3-dimensional
array), the RMSE is calculated as:
.. math::
RMSE = \\sqrt{\\frac{1}{k}\\sum_{l=1}^{k}\\frac{1}{MN}\\sum_{i=1}^{M}\\sum_{j=1}^{N}(A_{ij}^{(l)} - B_{ij}^{(l)})^2}
where :math:`k` is the number of matrices in the arrays.
If the option ``element_wise`` is set to ``True``, the RMSE is calculated
element-wise, and an M-by-N array of RMSEs is returned. The following calls
are equivalent:
.. code-block:: python
rmse = rmse_matrices(A, B, element_wise=True)
rmse = np.sqrt(np.nanmean((A-B)**2, axis=0))
If the inputs are arrays of matrices (3-dimensional arrays), its call is
also equivalent to:
.. code-block:: python
rmse = np.zeros_like(A[0])
for i in range(A.shape[1]):
for j in range(A.shape[2]):
rmse[i, j] = np.sqrt(np.nanmean((A[:, i, j]-B[:, i, j])**2))
If the inputs are 2-dimensional matrices, the following calls would return
the same result:
.. code-block:: python
rmse_matrices(A, B)
rmse_matrices(A, B, element_wise=False)
rmse_matrices(A, B, element_wise=True)
Examples
--------
.. code-block:: python
>>> C = np.random.random((4, 3, 2)) # Array of four 3-by-2 matrices
>>> C.view()
array([[[0.2816407 , 0.30850589],
[0.44618209, 0.33081522],
[0.7994625 , 0.07377569]],
[[0.35549399, 0.47050713],
[0.94168683, 0.50388058],
[0.70023837, 0.77216167]],
[[0.79897129, 0.28555452],
[0.892488 , 0.71476669],
[0.19071524, 0.4123666 ]],
[[0.86301978, 0.14686002],
[0.98784823, 0.26129908],
[0.46982206, 0.88037599]]])
>>> D = np.random.random((4, 3, 2)) # Array of four 3-by-2 matrices
>>> D.view()
array([[[0.71560918, 0.34100321],
[0.92518341, 0.50741267],
[0.30730944, 0.19173378]],
[[0.31846657, 0.08578454],
[0.62643489, 0.84014104],
[0.7111152 , 0.95428613]],
[[0.8101591 , 0.9584096 ],
[0.91118705, 0.71203119],
[0.58217189, 0.45598271]],
[[0.79837603, 0.09954558],
[0.26532781, 0.55711476],
[0.03909648, 0.10787888]]])
>>> rmse_matrices(C[0], D[0]) # RMSE between first matrices
0.3430603410873006
>>> rmse_matrices(C, D) # RMSE between each of the four matrices
array([0.34306034, 0.25662067, 0.31842239, 0.48274156])
>>> rmse_matrices(C, D, element_wise=True) # RMSE element-wise along first dimension
array([[0.22022923, 0.3886001 ],
[0.46130561, 0.2407136 ],
[0.38114819, 0.40178899]])
>>> rmse_matrices(C[0], D[0], element_wise=True)
0.3430603410873006
>>> rmse_matrices(C[0], D[0])
0.3430603410873006
"""
A = np.copy(A)
B = np.copy(B)
if A.shape != B.shape:
raise ValueError("Both arrays must have the same shape.")
if A.ndim == 2:
return np.sqrt(np.mean((A - B)**2))
if element_wise:
return np.sqrt(np.mean((A - B)**2, axis=0))
return np.sqrt(np.mean(np.mean((A - B)**2, axis=2), axis=1))

0 comments on commit a87eea3

Please sign in to comment.