From ca4d1c46efb02e1fa630e1de704e3c7c459086ea Mon Sep 17 00:00:00 2001 From: tcl Date: Thu, 15 Oct 2015 23:46:31 +0200 Subject: [PATCH] Added a custom equality check and finished up the ocean_level changes so that tests are passed. --- tests/serialization_test.py | 10 +++++----- worldengine/cli/main.py | 2 +- worldengine/common.py | 31 +++++++++++++++++++++++++++++++ worldengine/world.py | 13 ++----------- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/tests/serialization_test.py b/tests/serialization_test.py index 2fcce3c1..61584882 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -1,9 +1,9 @@ import unittest from worldengine.plates import Step, world_gen from worldengine.world import World +from worldengine.common import _equal import tempfile import os -import numpy def _sort(l): @@ -28,7 +28,7 @@ def test_pickle_serialize_unserialize(self): self.assertEqual(w.ocean, unserialized.ocean) self.assertEqual(w.biome, unserialized.biome) self.assertEqual(w.humidity, unserialized.humidity) - self.assertTrue(numpy.array_equiv(w.irrigation, unserialized.irrigation)) + self.assertTrue(_equal(w.irrigation, unserialized.irrigation)) self.assertEqual(w.permeability, unserialized.permeability) self.assertEqual(w.watermap, unserialized.watermap) self.assertEqual(w.precipitation, unserialized.precipitation) @@ -36,7 +36,7 @@ def test_pickle_serialize_unserialize(self): self.assertEqual(w.sea_depth, unserialized.sea_depth) self.assertEquals(w.seed, unserialized.seed) self.assertEquals(w.n_plates, unserialized.n_plates) - self.assertAlmostEqual(w.ocean_level, unserialized.ocean_level) + self.assertTrue(_equal(w.ocean_level, unserialized.ocean_level)) self.assertEquals(w.lake_map, unserialized.lake_map) self.assertEquals(w.river_map, unserialized.river_map) self.assertEquals(w.step, unserialized.step) @@ -52,7 +52,7 @@ def test_protobuf_serialize_unserialize(self): self.assertEqual(w.ocean, unserialized.ocean) self.assertEqual(w.biome, unserialized.biome) self.assertEqual(w.humidity, unserialized.humidity) - self.assertTrue(numpy.array_equiv(w.irrigation, unserialized.irrigation)) + self.assertTrue(_equal(w.irrigation, unserialized.irrigation)) self.assertEqual(w.permeability, unserialized.permeability) self.assertEqual(w.watermap, unserialized.watermap) self.assertEqual(w.precipitation, unserialized.precipitation) @@ -60,7 +60,7 @@ def test_protobuf_serialize_unserialize(self): self.assertEqual(w.sea_depth, unserialized.sea_depth) self.assertEquals(w.seed, unserialized.seed) self.assertEquals(w.n_plates, unserialized.n_plates) - self.assertAlmostEqual(w.ocean_level, unserialized.ocean_level) + self.assertTrue(_equal(w.ocean_level, unserialized.ocean_level)) self.assertEquals(w.lake_map, unserialized.lake_map) self.assertEquals(w.river_map, unserialized.river_map) self.assertEquals(w.step, unserialized.step) diff --git a/worldengine/cli/main.py b/worldengine/cli/main.py index feeb483c..8c54abdf 100644 --- a/worldengine/cli/main.py +++ b/worldengine/cli/main.py @@ -103,7 +103,7 @@ def generate_plates(seed, world_name, output_dir, width, height, # Generate images filename = '%s/plates_%s.png' % (output_dir, world_name) sea_level = find_threshold_f(world.elevation['data'], 1.0 - ocean_level, - ocean=None, max=1.0, mindist=0.00001) + ocean=None, max=1.0, mindist=0.000005) draw_simple_elevation_on_file(world.elevation['data'], filename, width, height, sea_level) diff --git a/worldengine/common.py b/worldengine/common.py index 0decf046..3fd7c200 100644 --- a/worldengine/common.py +++ b/worldengine/common.py @@ -1,5 +1,6 @@ import sys import copy +import numpy #for the _equal method only # ---------------- # Global variables @@ -120,3 +121,33 @@ def array_to_matrix(array, width, height): for x in range(width): matrix[y].append(array[y * width + x]) return matrix + +def _equal(a, b): + #recursion on subclasses of types: tuple, list, dict + #specifically checks : float, ndarray + if type(a) is float and type(b) is float:#float + return(numpy.allclose(a, b)) + elif type(a) is numpy.ndarray and type(b) is numpy.ndarray:#ndarray + return(numpy.array_equiv(a, b))#alternative for float-arrays: numpy.allclose(a, b[, rtol, atol]) + elif isinstance(a, dict) and isinstance(b, dict):#dict + if len(a) != len(b): + return(False) + t = True + for key, val in a.items(): + if key not in b: + return(False) + t = _equal(val, b[key]) + if not t: + return(False) + return(t) + elif (isinstance(a, list) and isinstance(b, list)) or (isinstance(a, tuple) and isinstance(b, tuple)):#list, tuples + if len(a) != len(b): + return(False) + t = True + for vala, valb in zip(a, b): + t = _equal(vala, valb) + if not t: + return(False) + return(t) + else:#fallback + return(a == b) diff --git a/worldengine/world.py b/worldengine/world.py index 3fb80339..90918fb4 100644 --- a/worldengine/world.py +++ b/worldengine/world.py @@ -15,6 +15,7 @@ from worldengine.basic_map_operations import random_point import worldengine.protobuf.World_pb2 as Protobuf from worldengine.step import Step +from worldengine.common import _equal from worldengine.version import __version__ class World(object): @@ -37,17 +38,7 @@ def __init__(self, name, width, height, seed, num_plates, ocean_level, # def __eq__(self, other): - test = True - sd = self.__dict__ - od = other.__dict__ - for key, val in sd.items(): - if type(val) is numpy.ndarray: - test = numpy.array_equiv(val, od[key])#interesting alternative: numpy.allclose(a, b[, rtol, atol]) - else: - test = (val == od[key]) - if not test: - break - return test + return _equal(self.__dict__, other.__dict__) # # Serialization/Unserialization