Skip to content

Commit

Permalink
Reset model state when exiting deferred visitors (#6208)
Browse files Browse the repository at this point in the history
## Summary

Very subtle bug related to the AST traversal. Given:

```python
from __future__ import annotations

from logging import getLogger

__all__ = ("getLogger",)


def foo() -> None:
    pass
```

We end up visiting the `-> None` annotation, then reusing the state
snapshot when we go to visit the `__all__` exports, so when we visit
`"getLogger"`, we think we're inside of a deferred type annotation.

This PR changes all the deferred visitors to snapshot and restore the
state, which is a lot safer -- that way, the visitors avoid modifying
the current visitor state. (Previously, they implicitly left the visitor
state set to the state of the _last_ thing they visited.)

Closes #6207.
  • Loading branch information
charliermarsh committed Jul 31, 2023
1 parent 0fddb31 commit 6ee5cb3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Regression test: ensure that we don't treat the export entry as a typing-only reference."""
from __future__ import annotations

from logging import getLogger

__all__ = ("getLogger",)


def foo() -> None:
pass
14 changes: 14 additions & 0 deletions crates/ruff/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,7 @@ impl<'a> Checker<'a> {
}

fn visit_deferred_future_type_definitions(&mut self) {
let snapshot = self.semantic.snapshot();
while !self.deferred.future_type_definitions.is_empty() {
let type_definitions = std::mem::take(&mut self.deferred.future_type_definitions);
for (expr, snapshot) in type_definitions {
Expand All @@ -1741,9 +1742,11 @@ impl<'a> Checker<'a> {
self.visit_expr(expr);
}
}
self.semantic.restore(snapshot);
}

fn visit_deferred_type_param_definitions(&mut self) {
let snapshot = self.semantic.snapshot();
while !self.deferred.type_param_definitions.is_empty() {
let type_params = std::mem::take(&mut self.deferred.type_param_definitions);
for (type_param, snapshot) in type_params {
Expand All @@ -1757,9 +1760,11 @@ impl<'a> Checker<'a> {
}
}
}
self.semantic.restore(snapshot);
}

fn visit_deferred_string_type_definitions(&mut self, allocator: &'a typed_arena::Arena<Expr>) {
let snapshot = self.semantic.snapshot();
while !self.deferred.string_type_definitions.is_empty() {
let type_definitions = std::mem::take(&mut self.deferred.string_type_definitions);
for (range, value, snapshot) in type_definitions {
Expand Down Expand Up @@ -1803,9 +1808,11 @@ impl<'a> Checker<'a> {
}
}
}
self.semantic.restore(snapshot);
}

fn visit_deferred_functions(&mut self) {
let snapshot = self.semantic.snapshot();
while !self.deferred.functions.is_empty() {
let deferred_functions = std::mem::take(&mut self.deferred.functions);
for snapshot in deferred_functions {
Expand All @@ -1823,9 +1830,11 @@ impl<'a> Checker<'a> {
}
}
}
self.semantic.restore(snapshot);
}

fn visit_deferred_lambdas(&mut self) {
let snapshot = self.semantic.snapshot();
while !self.deferred.lambdas.is_empty() {
let lambdas = std::mem::take(&mut self.deferred.lambdas);
for (expr, snapshot) in lambdas {
Expand All @@ -1844,10 +1853,13 @@ impl<'a> Checker<'a> {
}
}
}
self.semantic.restore(snapshot);
}

/// Run any lint rules that operate over the module exports (i.e., members of `__all__`).
fn visit_exports(&mut self) {
let snapshot = self.semantic.snapshot();

let exports: Vec<(&str, TextRange)> = self
.semantic
.global_scope()
Expand Down Expand Up @@ -1890,6 +1902,8 @@ impl<'a> Checker<'a> {
}
}
}

self.semantic.restore(snapshot);
}
}

Expand Down
19 changes: 10 additions & 9 deletions crates/ruff/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ mod tests {
use crate::test::{test_path, test_snippet};
use crate::{assert_messages, settings};

#[test_case(Rule::TypingOnlyFirstPartyImport, Path::new("TCH001.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))]
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("TCH003.py"))]
#[test_case(Rule::EmptyTypeCheckingBlock, Path::new("TCH005.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_1.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_10.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_11.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_12.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_13.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_14.pyi"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_2.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_3.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_4.py"))]
Expand All @@ -27,12 +30,10 @@ mod tests {
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_7.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_8.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_9.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_10.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_11.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_12.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_13.py"))]
#[test_case(Rule::RuntimeImportInTypeCheckingBlock, Path::new("TCH004_14.pyi"))]
#[test_case(Rule::EmptyTypeCheckingBlock, Path::new("TCH005.py"))]
#[test_case(Rule::TypingOnlyFirstPartyImport, Path::new("TCH001.py"))]
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("TCH003.py"))]
#[test_case(Rule::TypingOnlyStandardLibraryImport, Path::new("snapshot.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("TCH002.py"))]
#[test_case(Rule::TypingOnlyThirdPartyImport, Path::new("strict.py"))]
fn rules(rule_code: Rule, path: &Path) -> Result<()> {
let snapshot = format!("{}_{}", rule_code.as_ref(), path.to_string_lossy());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
source: crates/ruff/src/rules/flake8_type_checking/mod.rs
---

0 comments on commit 6ee5cb3

Please sign in to comment.