From b486d299c60c848d3c5c7988283df32493402efb Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 23 May 2024 18:46:45 +0530 Subject: [PATCH] Replace `lex_starts_at` with `Tokens` in the formatter --- crates/ruff_benchmark/benches/formatter.rs | 9 +-- crates/ruff_python_formatter/src/cli.rs | 4 +- crates/ruff_python_formatter/src/context.rs | 14 ++++- crates/ruff_python_formatter/src/lib.rs | 16 +++--- crates/ruff_python_formatter/src/range.rs | 1 + .../src/statement/suite.rs | 7 ++- .../src/string/docstring.rs | 2 +- crates/ruff_python_formatter/src/verbatim.rs | 57 +++++++++---------- crates/ruff_python_parser/src/lib.rs | 5 ++ crates/ruff_wasm/src/lib.rs | 7 +-- 10 files changed, 64 insertions(+), 58 deletions(-) diff --git a/crates/ruff_benchmark/benches/formatter.rs b/crates/ruff_benchmark/benches/formatter.rs index d24b30fe300e3..9b3bff8b6ac35 100644 --- a/crates/ruff_benchmark/benches/formatter.rs +++ b/crates/ruff_benchmark/benches/formatter.rs @@ -57,13 +57,8 @@ fn benchmark_formatter(criterion: &mut Criterion) { b.iter(|| { let options = PyFormatOptions::from_extension(Path::new(case.name())) .with_preview(PreviewMode::Enabled); - let formatted = format_module_ast( - program.syntax(), - program.comment_ranges(), - case.code(), - options, - ) - .expect("Formatting to succeed"); + let formatted = format_module_ast(&program, case.code(), options) + .expect("Formatting to succeed"); formatted.print().expect("Printing to succeed") }); diff --git a/crates/ruff_python_formatter/src/cli.rs b/crates/ruff_python_formatter/src/cli.rs index 012c3024ddf8e..e7b1835824103 100644 --- a/crates/ruff_python_formatter/src/cli.rs +++ b/crates/ruff_python_formatter/src/cli.rs @@ -62,8 +62,8 @@ pub fn format_and_debug_print(source: &str, cli: &Cli, source_path: &Path) -> Re }); let source_code = SourceCode::new(source); - let formatted = format_module_ast(program.syntax(), program.comment_ranges(), source, options) - .context("Failed to format node")?; + let formatted = + format_module_ast(&program, source, options).context("Failed to format node")?; if cli.print_ir { println!("{}", formatted.document().display(source_code)); } diff --git a/crates/ruff_python_formatter/src/context.rs b/crates/ruff_python_formatter/src/context.rs index 3d5f23590adc3..61fc302ccbde9 100644 --- a/crates/ruff_python_formatter/src/context.rs +++ b/crates/ruff_python_formatter/src/context.rs @@ -3,6 +3,7 @@ use crate::other::f_string_element::FStringExpressionElementContext; use crate::PyFormatOptions; use ruff_formatter::{Buffer, FormatContext, GroupId, IndentWidth, SourceCode}; use ruff_python_ast::str::Quote; +use ruff_python_parser::Tokens; use ruff_source_file::Locator; use std::fmt::{Debug, Formatter}; use std::ops::{Deref, DerefMut}; @@ -12,6 +13,7 @@ pub struct PyFormatContext<'a> { options: PyFormatOptions, contents: &'a str, comments: Comments<'a>, + tokens: &Tokens, node_level: NodeLevel, indent_level: IndentLevel, /// Set to a non-None value when the formatter is running on a code @@ -28,11 +30,17 @@ pub struct PyFormatContext<'a> { } impl<'a> PyFormatContext<'a> { - pub(crate) fn new(options: PyFormatOptions, contents: &'a str, comments: Comments<'a>) -> Self { + pub(crate) fn new( + options: PyFormatOptions, + contents: &'a str, + comments: Comments<'a>, + tokens: &Tokens, + ) -> Self { Self { options, contents, comments, + tokens, node_level: NodeLevel::TopLevel(TopLevelStatementPosition::Other), indent_level: IndentLevel::new(0), docstring: None, @@ -69,6 +77,10 @@ impl<'a> PyFormatContext<'a> { &self.comments } + pub(crate) fn tokens(&self) -> &Tokens { + self.tokens + } + /// Returns a non-None value only if the formatter is running on a code /// snippet within a docstring. /// diff --git a/crates/ruff_python_formatter/src/lib.rs b/crates/ruff_python_formatter/src/lib.rs index c35c9caa0d29b..49927c0d7f369 100644 --- a/crates/ruff_python_formatter/src/lib.rs +++ b/crates/ruff_python_formatter/src/lib.rs @@ -6,7 +6,7 @@ use ruff_formatter::prelude::*; use ruff_formatter::{format, write, FormatError, Formatted, PrintError, Printed, SourceCode}; use ruff_python_ast::AstNode; use ruff_python_ast::Mod; -use ruff_python_parser::{parse, AsMode, ParseError, ParseErrorType}; +use ruff_python_parser::{parse, AsMode, ParseError, ParseErrorType, Program}; use ruff_python_trivia::CommentRanges; use ruff_source_file::Locator; @@ -114,23 +114,22 @@ pub fn format_module_source( ) -> Result { let source_type = options.source_type(); let program = parse(source, source_type.as_mode())?; - let formatted = format_module_ast(program.syntax(), program.comment_ranges(), source, options)?; + let formatted = format_module_ast(&program, source, options)?; Ok(formatted.print()?) } pub fn format_module_ast<'a>( - module: &'a Mod, - comment_ranges: &'a CommentRanges, + program: &'a Program, source: &'a str, options: PyFormatOptions, ) -> FormatResult>> { let source_code = SourceCode::new(source); - let comments = Comments::from_ast(module, source_code, comment_ranges); + let comments = Comments::from_ast(program.syntax(), source_code, program.comment_ranges()); let locator = Locator::new(source); let formatted = format!( - PyFormatContext::new(options, locator.contents(), comments), - [module.format()] + PyFormatContext::new(options, locator.contents(), comments, program.tokens()), + [program.syntax().format()] )?; formatted .context() @@ -201,8 +200,7 @@ def main() -> None: let source_path = "code_inline.py"; let program = parse(source, source_type.as_mode()).unwrap(); let options = PyFormatOptions::from_extension(Path::new(source_path)); - let formatted = - format_module_ast(program.syntax(), program.comment_ranges(), source, options).unwrap(); + let formatted = format_module_ast(&program, source, options).unwrap(); // Uncomment the `dbg` to print the IR. // Use `dbg_write!(f, []) instead of `write!(f, [])` in your formatting code to print some IR diff --git a/crates/ruff_python_formatter/src/range.rs b/crates/ruff_python_formatter/src/range.rs index 5c63b179b04ff..75479ba26e464 100644 --- a/crates/ruff_python_formatter/src/range.rs +++ b/crates/ruff_python_formatter/src/range.rs @@ -81,6 +81,7 @@ pub fn format_range( options.with_source_map_generation(SourceMapGeneration::Enabled), source, comments, + program.tokens(), ); let (enclosing_node, base_indent) = diff --git a/crates/ruff_python_formatter/src/statement/suite.rs b/crates/ruff_python_formatter/src/statement/suite.rs index 6cf752ea6707c..b28835caa540d 100644 --- a/crates/ruff_python_formatter/src/statement/suite.rs +++ b/crates/ruff_python_formatter/src/statement/suite.rs @@ -859,16 +859,17 @@ def trailing_func(): pass "; - let module = parse_module(source).unwrap(); + let program = parse_module(source).unwrap(); let context = PyFormatContext::new( PyFormatOptions::default(), source, - Comments::from_ranges(module.comment_ranges()), + Comments::from_ranges(program.comment_ranges()), + program.tokens(), ); let test_formatter = - format_with(|f: &mut PyFormatter| module.suite().format().with_options(level).fmt(f)); + format_with(|f: &mut PyFormatter| program.suite().format().with_options(level).fmt(f)); let formatted = format!(context, [test_formatter]).unwrap(); let printed = formatted.print().unwrap(); diff --git a/crates/ruff_python_formatter/src/string/docstring.rs b/crates/ruff_python_formatter/src/string/docstring.rs index 9ae80f7f0645f..e0bb8d7e26d28 100644 --- a/crates/ruff_python_formatter/src/string/docstring.rs +++ b/crates/ruff_python_formatter/src/string/docstring.rs @@ -1558,7 +1558,7 @@ fn docstring_format_source( crate::Comments::from_ast(program.syntax(), source_code, program.comment_ranges()); let locator = Locator::new(source); - let ctx = PyFormatContext::new(options, locator.contents(), comments) + let ctx = PyFormatContext::new(options, locator.contents(), comments, program.tokens()) .in_docstring(docstring_quote_style); let formatted = crate::format!(ctx, [program.syntax().format()])?; formatted diff --git a/crates/ruff_python_formatter/src/verbatim.rs b/crates/ruff_python_formatter/src/verbatim.rs index 94635802ef767..a6bb434fdfca4 100644 --- a/crates/ruff_python_formatter/src/verbatim.rs +++ b/crates/ruff_python_formatter/src/verbatim.rs @@ -1,13 +1,13 @@ use std::borrow::Cow; use std::iter::FusedIterator; +use std::slice::Iter; use unicode_width::UnicodeWidthStr; use ruff_formatter::{write, FormatError}; use ruff_python_ast::AnyNodeRef; use ruff_python_ast::Stmt; -use ruff_python_parser::lexer::{lex_starts_at, LexResult}; -use ruff_python_parser::{Mode, Tok}; +use ruff_python_parser::{self as parser, TokenKind}; use ruff_python_trivia::lines_before; use ruff_source_file::Locator; use ruff_text_size::{Ranged, TextRange, TextSize}; @@ -725,13 +725,13 @@ struct FormatVerbatimStatementRange { impl Format> for FormatVerbatimStatementRange { fn fmt(&self, f: &mut Formatter>) -> FormatResult<()> { - let lexer = lex_starts_at( - &f.context().source()[self.verbatim_range], - Mode::Module, - self.verbatim_range.start(), + let logical_lines = LogicalLinesIter::new( + f.context() + .tokens() + .tokens_in_range(self.verbatim_range) + .iter(), + self.verbatim_range, ); - - let logical_lines = LogicalLinesIter::new(lexer, self.verbatim_range); let mut first = true; for logical_line in logical_lines { @@ -784,43 +784,47 @@ impl Format> for FormatVerbatimStatementRange { } } -struct LogicalLinesIter { - lexer: I, +struct LogicalLinesIter<'a> { + tokens: Iter<'a, parser::Token>, // The end of the last logical line last_line_end: TextSize, // The position where the content to lex ends. content_end: TextSize, } -impl LogicalLinesIter { - fn new(lexer: I, verbatim_range: TextRange) -> Self { +impl LogicalLinesIter<'_> { + fn new(tokens: Iter<'_, parser::Token>, verbatim_range: TextRange) -> Self { Self { - lexer, + tokens, last_line_end: verbatim_range.start(), content_end: verbatim_range.end(), } } } -impl Iterator for LogicalLinesIter -where - I: Iterator, -{ +impl Iterator for LogicalLinesIter { type Item = FormatResult; fn next(&mut self) -> Option { let mut parens = 0u32; let (content_end, full_end) = loop { - match self.lexer.next() { - Some(Ok((token, range))) => match token { - Tok::Newline => break (range.start(), range.end()), + match self.tokens.next() { + Some(token) if token.kind() == TokenKind::Unknown => { + return Some(Err(FormatError::syntax_error( + "Unexpected token when lexing verbatim statement range.", + ))) + } + Some(token) => match token.kind() { + TokenKind::Newline => break (token.start(), token.end()), // Ignore if inside an expression - Tok::NonLogicalNewline if parens == 0 => break (range.start(), range.end()), - Tok::Lbrace | Tok::Lpar | Tok::Lsqb => { + TokenKind::NonLogicalNewline if parens == 0 => { + break (token.start(), token.end()) + } + TokenKind::Lbrace | TokenKind::Lpar | TokenKind::Lsqb => { parens = parens.saturating_add(1); } - Tok::Rbrace | Tok::Rpar | Tok::Rsqb => { + TokenKind::Rbrace | TokenKind::Rpar | TokenKind::Rsqb => { parens = parens.saturating_sub(1); } _ => {} @@ -839,11 +843,6 @@ where None }; } - Some(Err(_)) => { - return Some(Err(FormatError::syntax_error( - "Unexpected token when lexing verbatim statement range.", - ))) - } } }; @@ -857,7 +856,7 @@ where } } -impl FusedIterator for LogicalLinesIter where I: Iterator {} +impl FusedIterator for LogicalLinesIter where I: Iterator {} /// A logical line or a comment (or form feed only) line struct LogicalLine { diff --git a/crates/ruff_python_parser/src/lib.rs b/crates/ruff_python_parser/src/lib.rs index d65842c31bd9f..5dfe22048372a 100644 --- a/crates/ruff_python_parser/src/lib.rs +++ b/crates/ruff_python_parser/src/lib.rs @@ -408,6 +408,11 @@ impl Tokens { /// The range `4..10` would return a slice of `Name`, `Lpar`, `Rpar`, and `Colon` tokens. But, /// if either the start or end position of the given range doesn't match any of the tokens /// (like `5..10` or `4..12`), the returned slice will be empty. + /// + /// ## Note + /// + /// The returned slice can contain the [`TokenKind::Unknown`] token if there was a lexical + /// error encountered within the given range. pub fn tokens_in_range(&self, range: TextRange) -> &[Token] { let Ok(start) = self.binary_search_by_key(&range.start(), Ranged::start) else { return &[]; diff --git a/crates/ruff_wasm/src/lib.rs b/crates/ruff_wasm/src/lib.rs index 88ff6fc845a4b..57ca9eb50b1d4 100644 --- a/crates/ruff_wasm/src/lib.rs +++ b/crates/ruff_wasm/src/lib.rs @@ -293,11 +293,6 @@ impl<'a> ParsedModule<'a> { .to_format_options(PySourceType::default(), self.source_code) .with_source_map_generation(SourceMapGeneration::Enabled); - format_module_ast( - self.program.syntax(), - self.program.comment_ranges(), - self.source_code, - options, - ) + format_module_ast(&self.program, self.source_code, options) } }