Skip to content
This repository was archived by the owner on May 12, 2021. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions ocw/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ def run(self, reference_dataset, target_dataset):
.. note::
Overrides BinaryMetric.run()

:param ref_dataset: The reference dataset to use in this metric run
:type ref_dataset: ocw.dataset.Dataset object
:param reference_dataset: The reference dataset to use in this metric
run
:type reference_dataset: ocw.dataset.Dataset object
:param target_dataset: The target dataset to evaluate against the
reference dataset in this metric run
:type target_dataset: ocw.dataset.Dataset object
Expand All @@ -201,22 +202,23 @@ class RMSError(BinaryMetric):
'''Calculate the Root Mean Square Difference (RMS Error), with the mean
calculated over time and space.'''

def run(self, eval_dataset, ref_dataset):
def run(self, reference_dataset, target_dataset):
'''Calculate the Root Mean Square Difference (RMS Error), with the mean
calculated over time and space.

.. note::
Overrides BinaryMetric.run()

:param eval_dataset: The dataset to evaluate against the reference
dataset
:type eval_dataset: ocw.dataset.Dataset object
:param ref_dataset: The reference dataset for the metric
:param reference_dataset: The reference dataset to use in this metric
run
:type reference_dataset: ocw.dataset.Dataset object
:param target_dataset: The target dataset to evaluate against the
reference dataset in this metric run
:type target_dataset: ocw.dataset.Dataset object

:returns: The RMS error, with the mean calculated over time and space
'''

sqdiff = (eval_dataset.values - ref_dataset.values) ** 2
sqdiff = (reference_dataset.values - target_dataset.values) ** 2
return numpy.sqrt(sqdiff.mean())

22 changes: 11 additions & 11 deletions ocw/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,27 +202,27 @@ class TestRMSError(unittest.TestCase):
def setUp(self):
# Set metric.
self.metric = metrics.RMSError()
# Initialize evaluation dataset.
self.eval_lats = np.array([10, 20, 30, 40, 50])
self.eval_lons = np.array([5, 15, 25, 35, 45])
self.eval_times = np.array([dt.datetime(2000, x, 1)
for x in range(1, 13)])
self.eval_values = np.array([4] * 300).reshape(12, 5, 5)
self.eval_variable = "eval"
self.eval_dataset = Dataset(self.eval_lats, self.eval_lons,
self.eval_times, self.eval_values, self.eval_variable)
# Initialize reference dataset.
self.ref_lats = np.array([10, 20, 30, 40, 50])
self.ref_lons = np.array([5, 15, 25, 35, 45])
self.ref_times = np.array([dt.datetime(2000, x, 1)
for x in range(1, 13)])
self.ref_values = np.array([2] * 300).reshape(12, 5, 5)
self.ref_values = np.array([4] * 300).reshape(12, 5, 5)
self.ref_variable = "ref"
self.ref_dataset = Dataset(self.ref_lats, self.ref_lons,
self.ref_times, self.ref_values, self.ref_variable)
# Initialize target dataset.
self.tgt_lats = np.array([10, 20, 30, 40, 50])
self.tgt_lons = np.array([5, 15, 25, 35, 45])
self.tgt_times = np.array([dt.datetime(2000, x, 1)
for x in range(1, 13)])
self.tgt_values = np.array([2] * 300).reshape(12, 5, 5)
self.tgt_variable = "tgt"
self.tgt_dataset = Dataset(self.tgt_lats, self.tgt_lons,
self.tgt_times, self.tgt_values, self.tgt_variable)

def test_function_run(self):
result = self.metric.run(self.eval_dataset, self.ref_dataset)
result = self.metric.run(self.ref_dataset, self.tgt_dataset)
self.assertEqual(result, 2.0)


Expand Down