diff --git a/pyupgrade/_plugins/legacy.py b/pyupgrade/_plugins/legacy.py index 4c7897ea..568eebe0 100644 --- a/pyupgrade/_plugins/legacy.py +++ b/pyupgrade/_plugins/legacy.py @@ -131,19 +131,42 @@ def visit_Call(self, node: ast.Call) -> None: isinstance(node.func, ast.Name) and node.func.id == 'super' and len(node.args) == 2 and - isinstance(node.args[0], ast.Name) and isinstance(node.args[1], ast.Name) and # there are at least two scopes len(self._scopes) >= 2 and - # the second to last scope is the class in arg1 - isinstance(self._scopes[-2].node, ast.ClassDef) and - node.args[0].id == self._scopes[-2].node.name and # the last scope is a function where the first arg is arg2 isinstance(self._scopes[-1].node, FUNC_TYPES) and self._scopes[-1].node.args.args and node.args[1].id == self._scopes[-1].node.args.args[0].arg ): - self.super_offsets.add(ast_to_offset(node)) + args = node.args[0] + scope = len(self._scopes) - 2 + current_scope = self._scopes[scope] + # if in nested classes, all names in arg1 must match the scopes + while ( + isinstance(args, ast.Attribute) and + scope > 0 and + isinstance(current_scope.node, ast.ClassDef) and + args.attr == current_scope.node.name + ): + args = args.value + scope -= 1 + current_scope = self._scopes[scope] + # now check if it is outer most class and its name match + if ( + isinstance(args, ast.Name) and + isinstance(current_scope.node, ast.ClassDef) and + args.id == current_scope.node.name and + # an enclosing scope cannot be a class + ( + scope == 0 or + not isinstance( + self._scopes[scope - 1].node, + ast.ClassDef, + ) + ) + ): + self.super_offsets.add(ast_to_offset(node)) self.generic_visit(node) diff --git a/tests/features/super_test.py b/tests/features/super_test.py index 67b28aaa..e795526c 100644 --- a/tests/features/super_test.py +++ b/tests/features/super_test.py @@ -20,6 +20,19 @@ 'class C(Base):\n' ' def f(self):\n' ' super(Base, self).f()\n', + 'class Outer:\n' # common nesting + ' class C(Base):\n' + ' def f(self):\n' + ' super(C, self).f()\n', + 'class Outer:\n' # higher levels of nesting + ' class Inner:\n' + ' class C(Base):\n' + ' def f(self):\n' + ' super(Inner.C, self).f()\n', + 'class Outer:\n' # super arg1 nested in unrelated name + ' class C(Base):\n' + ' def f(self):\n' + ' super(some_module.Outer.C, self).f()\n', # super outside of a class (technically legal!) 'def f(self):\n' @@ -87,12 +100,36 @@ def test_fix_super_noop(s): 'class Outer:\n' ' class C(Base):\n' ' def f(self):\n' - ' super (C, self).f()\n', + ' super (Outer.C, self).f()\n', 'class Outer:\n' ' class C(Base):\n' ' def f(self):\n' ' super().f()\n', ), + ( + 'def f():\n' + ' class Outer:\n' + ' class C(Base):\n' + ' def f(self):\n' + ' super(Outer.C, self).f()\n', + 'def f():\n' + ' class Outer:\n' + ' class C(Base):\n' + ' def f(self):\n' + ' super().f()\n', + ), + ( + 'class A:\n' + ' class B:\n' + ' class C:\n' + ' def f(self):\n' + ' super(A.B.C, self).f()\n', + 'class A:\n' + ' class B:\n' + ' class C:\n' + ' def f(self):\n' + ' super().f()\n', + ), ( 'class C(Base):\n' ' f = lambda self: super(C, self).f()\n',