Skip to content

Commit

Permalink
Improve error handling in ROLEQ and add unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayitzin committed May 6, 2022
1 parent 8953c71 commit 7690bab
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 12 deletions.
81 changes: 69 additions & 12 deletions ahrs/filters/roleq.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@

import numpy as np
from ..common.orientation import ecompass
from ..common.mathfuncs import cosd, sind
from ..common.mathfuncs import cosd
from ..common.mathfuncs import sind

def _assert_iterables(item, item_name: str = 'iterable'):
if not isinstance(item, (list, tuple, np.ndarray)):
raise TypeError(f"{item_name} must be given as an array. Got {type(item)}")

class ROLEQ:
"""
Expand Down Expand Up @@ -140,22 +145,24 @@ def __init__(self,
frame: str = 'NED',
**kwargs
):
self.gyr = gyr
self.acc = acc
self.mag = mag
self.a = weights if weights is not None else np.ones(2)
self.Q = None
self.frequency = kwargs.get('frequency', 100.0)
self.Dt = kwargs.get('Dt', 1.0/self.frequency)
self.q0 = kwargs.get('q0')
self.frame = frame
self.gyr: np.ndarray = gyr
self.acc: np.ndarray = acc
self.mag: np.ndarray = mag
self.a: np.ndarray = weights if weights is not None else np.ones(2)
self.frequency: float = kwargs.get('frequency', 100.0)
self.Dt: float = kwargs.get('Dt', (1.0/self.frequency) if self.frequency else 0.01)
self.q0: np.ndarray = kwargs.get('q0')
self.frame: str = frame
# Reference measurements
self._set_reference_frames(magnetic_ref, self.frame)
self._assert_validity_of_inputs()
# Estimate all quaternions if data is given
if self.acc is not None and self.gyr is not None and self.mag is not None:
self.Q = self._compute_all()

def _set_reference_frames(self, mref: float, frame: str = 'NED'):
if not isinstance(frame, str):
raise TypeError(f"'frame' must be a string. Got {type(frame)}.")
if frame.upper() not in ['NED', 'ENU']:
raise ValueError(f"Invalid frame '{frame}'. Try 'NED' or 'ENU'")
#### Magnetic Reference Vector ####
Expand All @@ -166,16 +173,64 @@ def _set_reference_frames(self, mref: float, frame: str = 'NED'):
wmm = WMM(latitude=MUNICH_LATITUDE, longitude=MUNICH_LONGITUDE, height=MUNICH_HEIGHT)
cd, sd = cosd(wmm.I), sind(wmm.I)
self.m_ref = np.array([sd, 0.0, cd]) if frame.upper() == 'NED' else np.array([0.0, cd, -sd])
elif isinstance(mref, bool):
raise TypeError(f"'mref' must be a float or numpy.ndarray. Got {type(mref)}.")
elif isinstance(mref, (int, float)):
# Use given magnetic dip angle (in degrees)
cd, sd = cosd(mref), sind(mref)
self.m_ref = np.array([sd, 0.0, cd]) if frame.upper() == 'NED' else np.array([0.0, cd, -sd])
else:
elif isinstance(mref, (list, tuple, np.ndarray)):
# Magnetic reference is given as a vector
self.m_ref = np.copy(mref)
else:
raise TypeError(f"Invalid magnetic reference type. Try float, int, list, tuple or numpy.ndarray")
if self.m_ref.shape != (3,):
raise ValueError(f"Magnetic reference vector must be of shape (3,). Got {self.m_ref.shape}.")
if np.linalg.norm(self.m_ref) == 0.0:
raise ValueError(f"Magnetic reference vector must not be zero.")
self.m_ref /= np.linalg.norm(self.m_ref)
#### Gravitational Reference Vector ####
#### Gravitational Reference Vector ####
self.a_ref = np.array([0.0, 0.0, -1.0]) if frame.upper() == 'NED' else np.array([0.0, 0.0, 1.0])

def _assert_validity_of_inputs(self):
"""Asserts the validity of the inputs."""
# Assert floats
for item in ["frequency", "Dt"]:
if isinstance(self.__getattribute__(item), bool):
raise TypeError(f"Parameter '{item}' must be numeric.")
if not isinstance(self.__getattribute__(item), (int, float)):
raise TypeError(f"Parameter '{item}' is not a non-zero number.")
if self.__getattribute__(item) <= 0.0:
raise ValueError(f"Parameter '{item}' must be a non-zero number.")
# Assert arrays
for item in ['gyr', 'acc', 'mag', 'a', 'm_ref', 'a_ref', 'q0']:
if self.__getattribute__(item) is not None:
if isinstance(self.__getattribute__(item), bool):
raise TypeError(f"Parameter '{item}' must be an array of numeric values.")
_assert_iterables(self.__getattribute__(item), item)
self.__setattr__(item, np.copy(self.__getattribute__(item)))
if self.acc is not None and self.mag is None:
raise ValueError("If 'acc' is given, 'mag' must also be given.")
if self.mag is not None and self.acc is None:
raise ValueError("If 'mag' is given, 'acc' must also be given.")
if self.q0 is not None:
if self.q0.ndim != 1:
raise ValueError(f"Parameter 'q0' must be a 1-dimensional array.")
if self.q0.shape != (4,):
raise ValueError(f"Parameter 'q0' must be an array of shape (4,). It is {self.q0.shape}.")
if not np.allclose(np.linalg.norm(self.q0), 1.0):
raise ValueError(f"Parameter 'q0' must be a versor (norm equal to 1.0). Its norm is equal to {np.linalg.norm(self.q0)}.")
# Assert weights
if self.a.shape != (2,):
raise ValueError(f"Dimension of 'weights' must be (2,). Got {self.a.shape}.")
for item in self.a:
if not isinstance(item, (int, float)):
raise TypeError(f"'weights' must be an array of numeric values. Got {type(item)}.")
if item < 0.0:
raise ValueError(f"'weights' must be non-negative. Got {item}.")
if not any(self.a > 0):
raise ValueError("'weights' must contain positive values.")

def _compute_all(self) -> np.ndarray:
"""
Estimate the quaternions given all data.
Expand Down Expand Up @@ -294,6 +349,8 @@ def oleq(self, acc: np.ndarray, mag: np.ndarray, q_omega: np.ndarray) -> np.ndar
Final quaternion.
"""
_assert_iterables(acc, 'Gravitational acceleration vector')
_assert_iterables(mag, 'Geomagnetic field vector')
a_norm = np.linalg.norm(acc)
m_norm = np.linalg.norm(mag)
if not a_norm > 0 or not m_norm > 0: # handle NaN
Expand Down
72 changes: 72 additions & 0 deletions tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,5 +987,77 @@ def test_estimation(self):
orientation = ahrs.filters.ROLEQ(gyr=self.gyros, acc=self.Rg, mag=self.Rm)
self.assertLess(np.nanmean(ahrs.utils.metrics.qad(self.Qts, orientation.Q)), self.noise_sigma*10.0)

def test_wrong_input_vectors(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=1.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc="self.Rg")
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=True)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=1.0, mag=2.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=2.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=1.0, mag=self.Rm)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc="self.Rg", mag="self.Rm")
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg[0], mag=True)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=True, mag=[1.0, 2.0, 3.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=[1.0, 2.0, 3.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, mag=[2.0, 3.0, 4.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=[1.0, 2.0], mag=[2.0, 3.0, 4.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=[1.0, 2.0, 3.0, 4.0], mag=[2.0, 3.0, 4.0, 5.0])

def test_wrong_magnetic_reference(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref='34.5')
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=False)
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=['34.5'])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=('34.5',))
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=[1.0, 2.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=[0.0, 0.0, 0.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, magnetic_ref=[[1.0], [2.0], [3.0]])

def test_wrong_input_frame(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame=1)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame=1.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame=True)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame=['NED'])
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame=('NED',))
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frame='NWU')

def test_wrong_weights(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=1)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=1.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=True)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights="[1.0, 1.0]")
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=['1.0', '1.0'])
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=['1.0', 1.0])
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[1.0, '1.0'])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[[1.0], [1.0]])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[[1.0, 1.0]])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[1.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[0.5, -0.5])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=[0.0, 0.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, weights=np.zeros(4))

def test_wrong_input_frequency(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency="100.0")
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency=[100.0])
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency=(100.0,))
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency=True)
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency=0.0)
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, frequency=-100.0)

def test_wrong_input_Dt(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt="0.01")
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt=[0.01])
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt=(0.01,))
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt=True)
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt=0.0)
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, Dt=-0.01)

def test_wrong_initial_quaternion(self):
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=1)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=1.0)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=True)
self.assertRaises(TypeError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0="[1.0, 0.0, 0.0, 0.0]")
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=[1.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=[1.0, 0.0, 0.0])
self.assertRaises(ValueError, ahrs.filters.ROLEQ, gyr=self.gyros, acc=self.Rg, mag=self.Rm, q0=np.zeros(4))

if __name__ == '__main__':
unittest.main()

0 comments on commit 7690bab

Please sign in to comment.