Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions coremltools/models/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def __init__(self, key_type = None):
global _simple_type_remap
if key_type in _simple_type_remap:
key_type = _simple_type_remap[key_type]

if not isinstance(key_type, (Int64, String)):
raise TypeError("Key type for dictionary must be either string or integer.")

self.key_type = key_type

_DatatypeBase.__init__(self, "Dictionary", "Dictionary({})".format(repr(self.key_type), None))
_DatatypeBase.__init__(self, "Dictionary", "Dictionary({})".format(repr(self.key_type)), None)


_simple_type_remap = {int: Int64(),
Expand All @@ -140,19 +140,19 @@ def _is_valid_datatype(datatype_instance):
"""
Returns true if datatype_instance is a valid datatype object and false otherwise.
"""

# Remap so we can still use the python types for the simple cases
global _simple_type_remap
if datatype_instance in _simple_type_remap:
return True

# Now set the protobuf from this interface.
if isinstance(datatype_instance, (Int64, Double, String, Array)):
return True

elif isinstance(datatype_instance, Dictionary):
kt = datatype_instance.key_type

if isinstance(kt, (Int64, String)):
return True

Expand All @@ -161,28 +161,28 @@ def _is_valid_datatype(datatype_instance):

def _normalize_datatype(datatype_instance):
"""
Translates a user specified datatype to an instance of the ones defined above.
Translates a user specified datatype to an instance of the ones defined above.

Valid data types are passed through, and the following type specifications
are translated to the proper instances:

str, "String" -> String()
int, "Int64" -> Int64()
float, "Double" -> Double()

If a data type is not recognized, then an error is raised.
If a data type is not recognized, then an error is raised.
"""
global _simple_type_remap
if datatype_instance in _simple_type_remap:
return _simple_type_remap[datatype_instance]

# Now set the protobuf from this interface.
if isinstance(datatype_instance, (Int64, Double, String, Array)):
return datatype_instance

elif isinstance(datatype_instance, Dictionary):
kt = datatype_instance.key_type

if isinstance(kt, (Int64, String)):
return datatype_instance

Expand Down Expand Up @@ -217,7 +217,7 @@ def _set_datatype(proto_type_obj, datatype_instance):
proto_type_obj.dictionaryType.MergeFromString(b'')

kt = datatype_instance.key_type

if isinstance(kt, Int64):
proto_type_obj.dictionaryType.int64KeyType.MergeFromString(b'')
elif isinstance(kt, String):
Expand All @@ -229,4 +229,3 @@ def _set_datatype(proto_type_obj, datatype_instance):
raise TypeError("Datatype parameter not recognized; must be an instance "
"of datatypes.{Double, Int64, String, Dictionary, Array}, or "
"python int, float, or str types.")