diff --git a/lang/py/.gitignore b/lang/py/.gitignore index a689956457b..1f829ef8d86 100644 --- a/lang/py/.gitignore +++ b/lang/py/.gitignore @@ -1,14 +1,15 @@ *.egg-info/ +.coverage* .eggs/ -build/ -lib/ -userlogs/ +.tox avro/HandshakeRequest.avsc avro/HandshakeResponse.avsc avro/VERSION.txt avro/interop.avsc +avro/test/interop/ avro/tether/InputProtocol.avpr avro/tether/OutputProtocol.avpr -avro/test/interop/ -.tox +build/ dist +lib/ +userlogs/ diff --git a/lang/py/avro/schema.py b/lang/py/avro/schema.py index 60b3a28685c..f41852dab1e 100644 --- a/lang/py/avro/schema.py +++ b/lang/py/avro/schema.py @@ -62,6 +62,11 @@ # Constants # +# The name portion of a fullname, record field names, and enum symbols must: +# start with [A-Za-z_] +# subsequently contain only [A-Za-z0-9_] +_BASE_NAME_PATTERN = re.compile(r'(?:^|\.)[A-Za-z_][A-Za-z0-9_]*$') + PRIMITIVE_TYPES = ( 'null', 'boolean', @@ -133,6 +138,14 @@ class AvroWarning(UserWarning): class IgnoredLogicalType(AvroWarning): """Warnings for unknown or invalid logical types.""" + +def validate_basename(basename): + """Raise InvalidName if the given basename is not a valid name.""" + if not _BASE_NAME_PATTERN.search(basename): + raise InvalidName("{!s} is not a valid Avro name because it " + "does not match the pattern {!s}".format( + basename, _BASE_NAME_PATTERN.pattern)) + # # Base Classes # @@ -184,14 +197,11 @@ def to_json(self, names): """ raise Exception("Must be implemented by subclasses.") + + class Name(object): """Class to describe Avro name.""" - # The name portion of a fullname, record field names, and enum symbols must: - # start with [A-Za-z_] - # subsequently contain only [A-Za-z0-9_] - _base_name_pattern = re.compile(r'(?:^|\.)[A-Za-z_][A-Za-z0-9_]*$') - _full = None def __init__(self, name_attr, space_attr, default_space): @@ -222,10 +232,7 @@ def __init__(self, name_attr, space_attr, default_space): def _validate_fullname(self, fullname): for name in fullname.split('.'): - if not self._base_name_pattern.search(name): - raise InvalidName("{!s} is not a valid Avro name because it " - "does not match the pattern {!s}".format( - name, self._base_name_pattern.pattern)) + validate_basename(name) def __eq__(self, other): """Equality of names is defined on the fullname and is case-sensitive.""" @@ -540,15 +547,18 @@ def __eq__(self, that): class EnumSchema(NamedSchema): - def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None): - # Ensure valid ctor args - if not isinstance(symbols, list): - fail_msg = 'Enum Schema requires a JSON array for the symbols property.' - raise AvroException(fail_msg) - elif False in [isinstance(s, basestring) for s in symbols]: - fail_msg = 'Enum Schema requires all symbols to be JSON strings.' - raise AvroException(fail_msg) - elif len(set(symbols)) < len(symbols): + def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=None, validate_enum_symbols=True): + """ + @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names. + """ + if validate_enum_symbols: + for symbol in symbols: + try: + validate_basename(symbol) + except InvalidName: + raise InvalidName("An enum symbol must be a valid schema name.") + + if len(set(symbols)) < len(symbols): fail_msg = 'Duplicate symbol: %s' % symbols raise AvroException(fail_msg) @@ -557,7 +567,8 @@ def __init__(self, name, namespace, symbols, names=None, doc=None, other_props=N # Add class members self.set_prop('symbols', symbols) - if doc is not None: self.set_prop('doc', doc) + if doc is not None: + self.set_prop('doc', doc) # read-only properties symbols = property(lambda self: self.get_prop('symbols')) @@ -919,11 +930,12 @@ def make_logical_schema(logical_type, type_, other_props): warnings.warn(warning) return None -def make_avsc_object(json_data, names=None): +def make_avsc_object(json_data, names=None, validate_enum_symbols=True): """ Build Avro Schema from data parsed out of JSON string. @arg names: A Names object (tracks seen names and default space) + @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names. """ if names is None: names = Names() @@ -953,7 +965,7 @@ def make_avsc_object(json_data, names=None): elif type == 'enum': symbols = json_data.get('symbols') doc = json_data.get('doc') - return EnumSchema(name, namespace, symbols, names, doc, other_props) + return EnumSchema(name, namespace, symbols, names, doc, other_props, validate_enum_symbols) elif type in ['record', 'error']: fields = json_data.get('fields') doc = json_data.get('doc') @@ -990,8 +1002,11 @@ def make_avsc_object(json_data, names=None): raise SchemaParseException(fail_msg) # TODO(hammer): make method for reading from a file? -def parse(json_string): - """Constructs the Schema from the JSON text.""" +def parse(json_string, validate_enum_symbols=True): + """Constructs the Schema from the JSON text. + + @arg validate_enum_symbols: If False, will allow enum symbols that are not valid Avro names. + """ # parse the JSON try: json_data = json.loads(json_string) @@ -1007,4 +1022,4 @@ def parse(json_string): names = Names() # construct the Avro Schema object - return make_avsc_object(json_data, names) + return make_avsc_object(json_data, names, validate_enum_symbols) diff --git a/lang/py/avro/test/test_schema.py b/lang/py/avro/test/test_schema.py index 8f082032ed0..9fcc0b65f45 100644 --- a/lang/py/avro/test/test_schema.py +++ b/lang/py/avro/test/test_schema.py @@ -85,11 +85,13 @@ class InvalidTestSchema(ValidTestSchema): ENUM_EXAMPLES = [ ValidTestSchema({"type": "enum", "name": "Test", "symbols": ["A", "B"]}), + ValidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["nowhitespace"]}), InvalidTestSchema({"type": "enum", "name": "Status", "symbols": "Normal Caution Critical"}), InvalidTestSchema({"type": "enum", "name": [0, 1, 1, 2, 3, 5, 8], "symbols": ["Golden", "Mean"]}), InvalidTestSchema({"type": "enum", "symbols" : ["I", "will", "fail", "no", "name"]}), InvalidTestSchema({"type": "enum", "name": "Test", "symbols": ["AA", "AA"]}), + InvalidTestSchema({"type": "enum", "name": "AVRO2174", "symbols": ["white space"]}), ] ARRAY_EXAMPLES = [ @@ -441,6 +443,26 @@ def test_fixed_decimal_invalid_max_precision(self): self.assertIsInstance(fixed_decimal, schema.FixedSchema) self.assertNotIsInstance(fixed_decimal, schema.DecimalLogicalSchema) + def test_parse_invalid_symbol(self): + """Disabling enumschema symbol validation should allow invalid symbols to pass.""" + test_schema_string = json.dumps({ + "type": "enum", "name": "AVRO2174", "symbols": ["white space"]}) + + try: + case = schema.parse(test_schema_string, validate_enum_symbols=True) + except schema.InvalidName: + pass + else: + self.fail("When enum symbol validation is enabled, " + "an invalid symbol should raise InvalidName.") + + try: + case = schema.parse(test_schema_string, validate_enum_symbols=False) + except schema.InvalidName: + self.fail("When enum symbol validation is disabled, " + "an invalid symbol should not raise InvalidName.") + + class SchemaParseTestCase(unittest.TestCase): """Enable generating parse test cases over all the valid and invalid example schema.""" @@ -479,6 +501,7 @@ def parse_invalid(self): else: self.fail("Invalid schema should not have parsed: {!s}".format(self.test_schema)) + class RoundTripParseTestCase(unittest.TestCase): """Enable generating round-trip parse test cases over all the valid test schema."""