diff --git a/autoflake.py b/autoflake.py index ccb6f09..d4a7dfc 100755 --- a/autoflake.py +++ b/autoflake.py @@ -346,7 +346,8 @@ def filter_code(source, additional_imports=None, duplicate_key_line_numbers(messages)) if ( marked_key_line_numbers and - any_complex_duplicate_key_cases(marked_key_line_numbers, source) + any_complex_duplicate_key_cases(messages, + source) ): marked_key_line_numbers = frozenset() else: @@ -380,19 +381,22 @@ def filter_code(source, additional_imports=None, previous_line = line -def any_complex_duplicate_key_cases(marked_line_numbers, source): +def any_complex_duplicate_key_cases(messages, source): """Return True if duplicate key lines contain complex code. We don't want to bother trying to parse this stuff and get it right. """ lines = source.split('\n') - for line_number in marked_line_numbers: - line = lines[line_number - 1] - - if line.rstrip().endswith((':', '\\')): - return True + for message in messages: + line = lines[message.lineno - 1] + key = message.message_args[0] - if ':' not in line or '#' in line: + if ( + line.rstrip().endswith((':', '\\')) or + not dict_entry_has_key(line, key) or + '#' in line or + not line.rstrip().endswith(',') + ): return True diff --git a/test_autoflake.py b/test_autoflake.py index f287a29..74b33ed 100755 --- a/test_autoflake.py +++ b/test_autoflake.py @@ -492,6 +492,40 @@ def test_filter_code_should_ignore_duplicate_key_with_comments(self): (0,1): 3, } print(a) +""" + + self.assertEqual( + code, + ''.join(autoflake.filter_code(code, + remove_duplicate_keys=True))) + + def test_filter_code_should_ignore_duplicate_key_with_multiline_key(self): + """We only handle simple cases.""" + code = """\ +a = { + (0,1 + ): 1, + (0, 1): 'two', + (0,1): 3, +} +print(a) +""" + + self.assertEqual( + code, + ''.join(autoflake.filter_code(code, + remove_duplicate_keys=True))) + + def test_filter_code_should_ignore_duplicate_key_with_no_comma(self): + """We don't want to delete the line and leave a lone comma.""" + code = """\ +a = { + (0,1) : 1 + , + (0, 1): 'two', + (0,1): 3, +} +print(a) """ self.assertEqual(