From d58a744513d910e8664699537ba69c0fa5a6e658 Mon Sep 17 00:00:00 2001 From: Yuduo Date: Tue, 14 Jan 2020 14:35:08 -0800 Subject: [PATCH] Minor fix. --- coremltools/models/datatypes.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/coremltools/models/datatypes.py b/coremltools/models/datatypes.py index da2efa8c3..6652aaefc 100644 --- a/coremltools/models/datatypes.py +++ b/coremltools/models/datatypes.py @@ -116,13 +116,13 @@ def __init__(self, key_type = None): global _simple_type_remap if key_type in _simple_type_remap: key_type = _simple_type_remap[key_type] - + if not isinstance(key_type, (Int64, String)): raise TypeError("Key type for dictionary must be either string or integer.") self.key_type = key_type - _DatatypeBase.__init__(self, "Dictionary", "Dictionary({})".format(repr(self.key_type), None)) + _DatatypeBase.__init__(self, "Dictionary", "Dictionary({})".format(repr(self.key_type)), None) _simple_type_remap = {int: Int64(), @@ -140,19 +140,19 @@ def _is_valid_datatype(datatype_instance): """ Returns true if datatype_instance is a valid datatype object and false otherwise. """ - + # Remap so we can still use the python types for the simple cases global _simple_type_remap if datatype_instance in _simple_type_remap: return True - + # Now set the protobuf from this interface. if isinstance(datatype_instance, (Int64, Double, String, Array)): return True elif isinstance(datatype_instance, Dictionary): kt = datatype_instance.key_type - + if isinstance(kt, (Int64, String)): return True @@ -161,8 +161,8 @@ def _is_valid_datatype(datatype_instance): def _normalize_datatype(datatype_instance): """ - Translates a user specified datatype to an instance of the ones defined above. - + Translates a user specified datatype to an instance of the ones defined above. + Valid data types are passed through, and the following type specifications are translated to the proper instances: @@ -170,19 +170,19 @@ def _normalize_datatype(datatype_instance): int, "Int64" -> Int64() float, "Double" -> Double() - If a data type is not recognized, then an error is raised. + If a data type is not recognized, then an error is raised. """ global _simple_type_remap if datatype_instance in _simple_type_remap: return _simple_type_remap[datatype_instance] - + # Now set the protobuf from this interface. if isinstance(datatype_instance, (Int64, Double, String, Array)): return datatype_instance elif isinstance(datatype_instance, Dictionary): kt = datatype_instance.key_type - + if isinstance(kt, (Int64, String)): return datatype_instance @@ -217,7 +217,7 @@ def _set_datatype(proto_type_obj, datatype_instance): proto_type_obj.dictionaryType.MergeFromString(b'') kt = datatype_instance.key_type - + if isinstance(kt, Int64): proto_type_obj.dictionaryType.int64KeyType.MergeFromString(b'') elif isinstance(kt, String): @@ -229,4 +229,3 @@ def _set_datatype(proto_type_obj, datatype_instance): raise TypeError("Datatype parameter not recognized; must be an instance " "of datatypes.{Double, Int64, String, Dictionary, Array}, or " "python int, float, or str types.") -