Skip to content

Commit

Permalink
Merge pull request #707 from hamdanal/nested-super
Browse files Browse the repository at this point in the history
Rewrite 2-arg super call in nested class
  • Loading branch information
asottile authored Sep 26, 2022
2 parents d9461f5 + 00aa729 commit a41a733
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
33 changes: 28 additions & 5 deletions pyupgrade/_plugins/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,42 @@ def visit_Call(self, node: ast.Call) -> None:
isinstance(node.func, ast.Name) and
node.func.id == 'super' and
len(node.args) == 2 and
isinstance(node.args[0], ast.Name) and
isinstance(node.args[1], ast.Name) and
# there are at least two scopes
len(self._scopes) >= 2 and
# the second to last scope is the class in arg1
isinstance(self._scopes[-2].node, ast.ClassDef) and
node.args[0].id == self._scopes[-2].node.name and
# the last scope is a function where the first arg is arg2
isinstance(self._scopes[-1].node, FUNC_TYPES) and
self._scopes[-1].node.args.args and
node.args[1].id == self._scopes[-1].node.args.args[0].arg
):
self.super_offsets.add(ast_to_offset(node))
args = node.args[0]
scope = len(self._scopes) - 2
current_scope = self._scopes[scope]
# if in nested classes, all names in arg1 must match the scopes
while (
isinstance(args, ast.Attribute) and
scope > 0 and
isinstance(current_scope.node, ast.ClassDef) and
args.attr == current_scope.node.name
):
args = args.value
scope -= 1
current_scope = self._scopes[scope]
# now check if it is outer most class and its name match
if (
isinstance(args, ast.Name) and
isinstance(current_scope.node, ast.ClassDef) and
args.id == current_scope.node.name and
# an enclosing scope cannot be a class
(
scope == 0 or
not isinstance(
self._scopes[scope - 1].node,
ast.ClassDef,
)
)
):
self.super_offsets.add(ast_to_offset(node))

self.generic_visit(node)

Expand Down
39 changes: 38 additions & 1 deletion tests/features/super_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@
'class C(Base):\n'
' def f(self):\n'
' super(Base, self).f()\n',
'class Outer:\n' # common nesting
' class C(Base):\n'
' def f(self):\n'
' super(C, self).f()\n',
'class Outer:\n' # higher levels of nesting
' class Inner:\n'
' class C(Base):\n'
' def f(self):\n'
' super(Inner.C, self).f()\n',
'class Outer:\n' # super arg1 nested in unrelated name
' class C(Base):\n'
' def f(self):\n'
' super(some_module.Outer.C, self).f()\n',
# super outside of a class (technically legal!)
'def f(self):\n'
Expand Down Expand Up @@ -87,12 +100,36 @@ def test_fix_super_noop(s):
'class Outer:\n'
' class C(Base):\n'
' def f(self):\n'
' super (C, self).f()\n',
' super (Outer.C, self).f()\n',
'class Outer:\n'
' class C(Base):\n'
' def f(self):\n'
' super().f()\n',
),
(
'def f():\n'
' class Outer:\n'
' class C(Base):\n'
' def f(self):\n'
' super(Outer.C, self).f()\n',
'def f():\n'
' class Outer:\n'
' class C(Base):\n'
' def f(self):\n'
' super().f()\n',
),
(
'class A:\n'
' class B:\n'
' class C:\n'
' def f(self):\n'
' super(A.B.C, self).f()\n',
'class A:\n'
' class B:\n'
' class C:\n'
' def f(self):\n'
' super().f()\n',
),
(
'class C(Base):\n'
' f = lambda self: super(C, self).f()\n',
Expand Down

0 comments on commit a41a733

Please sign in to comment.