diff --git a/ocw/metrics.py b/ocw/metrics.py index fb7bce34..ee5c1f9e 100644 --- a/ocw/metrics.py +++ b/ocw/metrics.py @@ -184,8 +184,9 @@ def run(self, reference_dataset, target_dataset): .. note:: Overrides BinaryMetric.run() - :param ref_dataset: The reference dataset to use in this metric run - :type ref_dataset: ocw.dataset.Dataset object + :param reference_dataset: The reference dataset to use in this metric + run + :type reference_dataset: ocw.dataset.Dataset object :param target_dataset: The target dataset to evaluate against the reference dataset in this metric run :type target_dataset: ocw.dataset.Dataset object @@ -201,22 +202,23 @@ class RMSError(BinaryMetric): '''Calculate the Root Mean Square Difference (RMS Error), with the mean calculated over time and space.''' - def run(self, eval_dataset, ref_dataset): + def run(self, reference_dataset, target_dataset): '''Calculate the Root Mean Square Difference (RMS Error), with the mean calculated over time and space. .. note:: Overrides BinaryMetric.run() - :param eval_dataset: The dataset to evaluate against the reference - dataset - :type eval_dataset: ocw.dataset.Dataset object - :param ref_dataset: The reference dataset for the metric + :param reference_dataset: The reference dataset to use in this metric + run + :type reference_dataset: ocw.dataset.Dataset object + :param target_dataset: The target dataset to evaluate against the + reference dataset in this metric run :type target_dataset: ocw.dataset.Dataset object :returns: The RMS error, with the mean calculated over time and space ''' - sqdiff = (eval_dataset.values - ref_dataset.values) ** 2 + sqdiff = (reference_dataset.values - target_dataset.values) ** 2 return numpy.sqrt(sqdiff.mean()) diff --git a/ocw/tests/test_metrics.py b/ocw/tests/test_metrics.py index a2ca025c..affb9376 100644 --- a/ocw/tests/test_metrics.py +++ b/ocw/tests/test_metrics.py @@ -202,27 +202,27 @@ class TestRMSError(unittest.TestCase): def setUp(self): # Set metric. self.metric = metrics.RMSError() - # Initialize evaluation dataset. - self.eval_lats = np.array([10, 20, 30, 40, 50]) - self.eval_lons = np.array([5, 15, 25, 35, 45]) - self.eval_times = np.array([dt.datetime(2000, x, 1) - for x in range(1, 13)]) - self.eval_values = np.array([4] * 300).reshape(12, 5, 5) - self.eval_variable = "eval" - self.eval_dataset = Dataset(self.eval_lats, self.eval_lons, - self.eval_times, self.eval_values, self.eval_variable) # Initialize reference dataset. self.ref_lats = np.array([10, 20, 30, 40, 50]) self.ref_lons = np.array([5, 15, 25, 35, 45]) self.ref_times = np.array([dt.datetime(2000, x, 1) for x in range(1, 13)]) - self.ref_values = np.array([2] * 300).reshape(12, 5, 5) + self.ref_values = np.array([4] * 300).reshape(12, 5, 5) self.ref_variable = "ref" self.ref_dataset = Dataset(self.ref_lats, self.ref_lons, self.ref_times, self.ref_values, self.ref_variable) + # Initialize target dataset. + self.tgt_lats = np.array([10, 20, 30, 40, 50]) + self.tgt_lons = np.array([5, 15, 25, 35, 45]) + self.tgt_times = np.array([dt.datetime(2000, x, 1) + for x in range(1, 13)]) + self.tgt_values = np.array([2] * 300).reshape(12, 5, 5) + self.tgt_variable = "tgt" + self.tgt_dataset = Dataset(self.tgt_lats, self.tgt_lons, + self.tgt_times, self.tgt_values, self.tgt_variable) def test_function_run(self): - result = self.metric.run(self.eval_dataset, self.ref_dataset) + result = self.metric.run(self.ref_dataset, self.tgt_dataset) self.assertEqual(result, 2.0)