From d3cdb475d92d418fece6bdc3c5363ea70479b1e7 Mon Sep 17 00:00:00 2001 From: Jon Parise Date: Wed, 4 Aug 2021 16:10:03 -0700 Subject: [PATCH] Avoid stack exhaustion in superclass_names() Because this function is recursive, it can exhaust the stack if the class hierarchy contains duplicate names. Avoid that using simple memoization. Also, these two supporting functions can be classmethods because they don't use any instance-level state. --- src/pep8ext_naming.py | 33 +++++++++++++++++---------------- testsuite/N818.py | 4 ++++ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/src/pep8ext_naming.py b/src/pep8ext_naming.py index b227495..7893147 100644 --- a/src/pep8ext_naming.py +++ b/src/pep8ext_naming.py @@ -296,23 +296,24 @@ class ClassNameCheck(BaseASTCheck): N801 = "class name '{name}' should use CapWords convention" N818 = "exception name '{name}' should be named with an Error suffix" - def get_class_def(self, name, parents): + @classmethod + def get_classdef(cls, name, parents): for parent in parents: - for definition in parent.body: - is_class_definition = isinstance(definition, ast.ClassDef) - if is_class_definition and definition.name == name: - return definition - - def superclass_names(self, name, parents): - class_ids = set() - class_def = self.get_class_def(name, parents) - if not class_def: - return class_ids - for base in class_def.bases: - if hasattr(base, "id"): - class_ids.add(base.id) - class_ids.update(self.superclass_names(base.id, parents)) - return class_ids + for node in parent.body: + if isinstance(node, ast.ClassDef) and node.name == name: + return node + + @classmethod + def superclass_names(cls, name, parents, _names=None): + names = _names or set() + classdef = cls.get_classdef(name, parents) + if not classdef: + return names + for base in classdef.bases: + if isinstance(base, ast.Name) and base.id not in names: + names.add(base.id) + names.update(cls.superclass_names(base.id, parents, names)) + return names def visit_classdef(self, node, parents, ignore=None): name = node.name diff --git a/testsuite/N818.py b/testsuite/N818.py index 1d0ae10..755bf2d 100644 --- a/testsuite/N818.py +++ b/testsuite/N818.py @@ -28,3 +28,7 @@ class Mixin: pass class MixinActionClass(Mixin, MixinError): pass +#: Okay +from decimal import Decimal +class Decimal(Decimal): + pass