diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py index a8ebdec0e8..14ae1dd2ef 100644 --- a/sqlmesh/core/linter/definition.py +++ b/sqlmesh/core/linter/definition.py @@ -8,7 +8,7 @@ from collections.abc import Iterator, Iterable, Set, Mapping, Callable from functools import reduce from sqlmesh.core.model import Model -from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.linter.rule import Rule, RuleViolation, Range from sqlmesh.core.console import LinterConsole, get_console if t.TYPE_CHECKING: @@ -74,6 +74,7 @@ def lint_model( violation_msg=violation.violation_msg, model=model, violation_type="error", + violation_range=violation.violation_range, ) for violation in error_violations ] + [ @@ -82,6 +83,7 @@ def lint_model( violation_msg=violation.violation_msg, model=model, violation_type="warning", + violation_range=violation.violation_range, ) for violation in warn_violations ] @@ -149,7 +151,8 @@ def __init__( violation_msg: str, model: Model, violation_type: t.Literal["error", "warning"], + violation_range: t.Optional[Range] = None, ) -> None: - super().__init__(rule, violation_msg) + super().__init__(rule, violation_msg, violation_range) self.model = model self.violation_type = violation_type diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py new file mode 100644 index 0000000000..6d0d796c97 --- /dev/null +++ b/sqlmesh/core/linter/helpers.py @@ -0,0 +1,115 @@ +from pathlib import Path + +from sqlmesh.core.linter.rule import Position, Range +from sqlmesh.utils.pydantic import PydanticModel +import typing as t + + +class TokenPositionDetails(PydanticModel): + """ + Details about a token's position in the source code in the structure provided by SQLGlot. + + Attributes: + line (int): The line that the token ends on. + col (int): The column that the token ends on. + start (int): The start index of the token. + end (int): The ending index of the token. + """ + + line: int + col: int + start: int + end: int + + @staticmethod + def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails": + return TokenPositionDetails( + line=meta["line"], + col=meta["col"], + start=meta["start"], + end=meta["end"], + ) + + def to_range(self, read_file: t.Optional[t.List[str]]) -> Range: + """ + Convert a TokenPositionDetails object to a Range object. + + In the circumstances where the token's start and end positions are the same, + there is no need for a read_file parameter, as the range can be derived from the token's + line and column. This is an optimization to avoid unnecessary file reads and should + only be used when the token represents a single character or position in the file. + + If the token's start and end positions are different, the read_file parameter is required. + + :param read_file: List of lines from the file. Optional + :return: A Range object representing the token's position + """ + if self.start == self.end: + # If the start and end positions are the same, we can create a range directly + return Range( + start=Position(line=self.line - 1, character=self.col - 1), + end=Position(line=self.line - 1, character=self.col), + ) + + if read_file is None: + raise ValueError("read_file must be provided when start and end positions differ.") + + # Convert from 1-indexed to 0-indexed for line only + end_line_0 = self.line - 1 + end_col_0 = self.col + + # Find the start line and column by counting backwards from the end position + start_pos = self.start + end_pos = self.end + + # Initialize with the end position + start_line_0 = end_line_0 + start_col_0 = end_col_0 - (end_pos - start_pos + 1) + + # If start_col_0 is negative, we need to go back to previous lines + while start_col_0 < 0 and start_line_0 > 0: + start_line_0 -= 1 + start_col_0 += len(read_file[start_line_0]) + # Account for newline character + if start_col_0 >= 0: + break + start_col_0 += 1 # For the newline character + + # Ensure we don't have negative values + start_col_0 = max(0, start_col_0) + return Range( + start=Position(line=start_line_0, character=start_col_0), + end=Position(line=end_line_0, character=end_col_0), + ) + + +def read_range_from_file(file: Path, text_range: Range) -> str: + """ + Read the file and return the content within the specified range. + + Args: + file: Path to the file to read + text_range: The range of text to extract + + Returns: + The content within the specified range + """ + with file.open("r", encoding="utf-8") as f: + lines = f.readlines() + + # Ensure the range is within bounds + start_line = max(0, text_range.start.line) + end_line = min(len(lines), text_range.end.line + 1) + + if start_line >= end_line: + return "" + + # Extract the relevant portions of each line + result = [] + for i in range(start_line, end_line): + line = lines[i] + start_char = text_range.start.character if i == text_range.start.line else 0 + end_char = text_range.end.character if i == text_range.end.line else len(line) + result.append(line[start_char:end_char]) + + return "".join(result) diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py index 003f9b813a..84e1693bef 100644 --- a/sqlmesh/core/linter/rule.py +++ b/sqlmesh/core/linter/rule.py @@ -1,6 +1,7 @@ from __future__ import annotations import abc +from dataclasses import dataclass from sqlmesh.core.model import Model @@ -22,6 +23,22 @@ class RuleLocation(PydanticModel): start_line: t.Optional[int] = None +@dataclass(frozen=True) +class Position: + """The position of a rule violation in a file, the position follows the LSP standard.""" + + line: int + character: int + + +@dataclass(frozen=True) +class Range: + """The range of a rule violation in a file. The range follows the LSP standard.""" + + start: Position + end: Position + + class _Rule(abc.ABCMeta): def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule: attrs["name"] = clsname.lower() @@ -45,9 +62,15 @@ def summary(self) -> str: """A summary of what this rule checks for.""" return self.__doc__ or "" - def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation: + def violation( + self, + violation_msg: t.Optional[str] = None, + violation_range: t.Optional[Range] = None, + ) -> RuleViolation: """Create a RuleViolation instance for this rule""" - return RuleViolation(rule=self, violation_msg=violation_msg or self.summary) + return RuleViolation( + rule=self, violation_msg=violation_msg or self.summary, violation_range=violation_range + ) def get_definition_location(self) -> RuleLocation: """Return the file path and position information for this rule. @@ -79,9 +102,12 @@ def __repr__(self) -> str: class RuleViolation: - def __init__(self, rule: Rule, violation_msg: str) -> None: + def __init__( + self, rule: Rule, violation_msg: str, violation_range: t.Optional[Range] = None + ) -> None: self.rule = rule self.violation_msg = violation_msg + self.violation_range = violation_range def __repr__(self) -> str: return f"{self.rule.name}: {self.violation_msg}" diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py index 9f93a24236..0480683f6d 100644 --- a/sqlmesh/core/linter/rules/builtin.py +++ b/sqlmesh/core/linter/rules/builtin.py @@ -4,9 +4,11 @@ import typing as t +from sqlglot.expressions import Star from sqlglot.helper import subclasses -from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.linter.helpers import TokenPositionDetails +from sqlmesh.core.linter.rule import Rule, RuleViolation, Range from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.model import Model, SqlModel @@ -15,10 +17,25 @@ class NoSelectStar(Rule): """Query should not contain SELECT * on its outer most projections, even if it can be expanded.""" def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # Only applies to SQL models, as other model types do not have a query. if not isinstance(model, SqlModel): return None - - return self.violation() if model.query.is_star else None + if model.query.is_star: + violation_range = self._get_range(model) + return self.violation(violation_range=violation_range) + return None + + def _get_range(self, model: SqlModel) -> t.Optional[Range]: + """Get the range of the violation if available.""" + try: + if len(model.query.expressions) == 1 and isinstance(model.query.expressions[0], Star): + return TokenPositionDetails.from_meta(model.query.expressions[0].meta).to_range( + None + ) + except Exception: + pass + + return None class InvalidSelectStarExpansion(Rule): diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py index 9ecbb9a2b1..4ac55f1a22 100644 --- a/sqlmesh/lsp/context.py +++ b/sqlmesh/lsp/context.py @@ -5,7 +5,7 @@ from sqlmesh.core.model.definition import SqlModel from sqlmesh.core.linter.definition import AnnotatedRuleViolation -from sqlmesh.lsp.custom import RenderModelEntry, ModelForRendering +from sqlmesh.lsp.custom import ModelForRendering from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry from sqlmesh.lsp.uri import URI diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 942d644689..55a6280c30 100644 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -655,8 +655,24 @@ def _diagnostic_to_lsp_diagnostic( ) -> t.Optional[types.Diagnostic]: if diagnostic.model._path is None: return None - with open(diagnostic.model._path, "r", encoding="utf-8") as file: - lines = file.readlines() + if not diagnostic.violation_range: + with open(diagnostic.model._path, "r", encoding="utf-8") as file: + lines = file.readlines() + range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=len(lines) - 1, character=len(lines[-1])), + ) + else: + range = types.Range( + start=types.Position( + line=diagnostic.violation_range.start.line, + character=diagnostic.violation_range.start.character, + ), + end=types.Position( + line=diagnostic.violation_range.end.line, + character=diagnostic.violation_range.end.character, + ), + ) # Get rule definition location for diagnostics link rule_location = diagnostic.rule.get_definition_location() @@ -665,10 +681,7 @@ def _diagnostic_to_lsp_diagnostic( # Use URI format to create a link for "related information" return types.Diagnostic( - range=types.Range( - start=types.Position(line=0, character=0), - end=types.Position(line=len(lines), character=len(lines[-1])), - ), + range=range, message=diagnostic.violation_msg, severity=types.DiagnosticSeverity.Error if diagnostic.violation_type == "error" diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index b43cc56751..0ddef76cad 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -4,6 +4,11 @@ from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, + Range as SQLMeshRange, + Position as SQLMeshPosition, +) from sqlmesh.core.model.definition import SqlModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp @@ -156,12 +161,17 @@ def get_model_definitions_for_a_path( if isinstance(alias, exp.TableAlias): identifier = alias.this if isinstance(identifier, exp.Identifier): - target_range = _range_from_token_position_details( - TokenPositionDetails.from_meta(identifier.meta), read_file - ) - table_range = _range_from_token_position_details( - TokenPositionDetails.from_meta(table.this.meta), read_file - ) + target_range_sqlmesh = TokenPositionDetails.from_meta( + identifier.meta + ).to_range(read_file) + table_range_sqlmesh = TokenPositionDetails.from_meta( + table.this.meta + ).to_range(read_file) + + # Convert SQLMesh Range to LSP Range + target_range = to_lsp_range(target_range_sqlmesh) + table_range = to_lsp_range(table_range_sqlmesh) + references.append( Reference( uri=document_uri.value, # Same file @@ -203,25 +213,26 @@ def get_model_definitions_for_a_path( # Extract metadata for positioning table_meta = TokenPositionDetails.from_meta(table.this.meta) - table_range = _range_from_token_position_details(table_meta, read_file) - start_pos = table_range.start - end_pos = table_range.end + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end # If there's a catalog or database qualifier, adjust the start position catalog_or_db = table.args.get("catalog") or table.args.get("db") if catalog_or_db is not None: catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) - catalog_or_db_range = _range_from_token_position_details( - catalog_or_db_meta, read_file - ) - start_pos = catalog_or_db_range.start + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start description = generate_markdown_description(referenced_model) references.append( Reference( uri=referenced_model_uri.value, - range=Range(start=start_pos, end=end_pos), + range=Range( + start=to_lsp_position(start_pos_sqlmesh), + end=to_lsp_position(end_pos_sqlmesh), + ), markdown_description=description, ) ) @@ -229,71 +240,6 @@ def get_model_definitions_for_a_path( return references -class TokenPositionDetails(PydanticModel): - """ - Details about a token's position in the source code. - - Attributes: - line (int): The line that the token ends on. - col (int): The column that the token ends on. - start (int): The start index of the token. - end (int): The ending index of the token. - """ - - line: int - col: int - start: int - end: int - - @staticmethod - def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails": - return TokenPositionDetails( - line=meta["line"], - col=meta["col"], - start=meta["start"], - end=meta["end"], - ) - - -def _range_from_token_position_details( - token_position_details: TokenPositionDetails, read_file: t.List[str] -) -> Range: - """ - Convert a TokenPositionDetails object to a Range object. - - :param token_position_details: Details about a token's position - :param read_file: List of lines from the file - :return: A Range object representing the token's position - """ - # Convert from 1-indexed to 0-indexed for line only - end_line_0 = token_position_details.line - 1 - end_col_0 = token_position_details.col - - # Find the start line and column by counting backwards from the end position - start_pos = token_position_details.start - end_pos = token_position_details.end - - # Initialize with the end position - start_line_0 = end_line_0 - start_col_0 = end_col_0 - (end_pos - start_pos + 1) - - # If start_col_0 is negative, we need to go back to previous lines - while start_col_0 < 0 and start_line_0 > 0: - start_line_0 -= 1 - start_col_0 += len(read_file[start_line_0]) - # Account for newline character - if start_col_0 >= 0: - break - start_col_0 += 1 # For the newline character - - # Ensure we don't have negative values - start_col_0 = max(0, start_col_0) - return Range( - start=Position(line=start_line_0, character=start_col_0), - end=Position(line=end_line_0, character=end_col_0), - ) - - def get_macro_definitions_for_a_path( lsp_context: LSPContext, document_uri: URI ) -> t.List[Reference]: @@ -373,11 +319,10 @@ def get_macro_reference( try: # Get the position of the macro invocation in the source file first if hasattr(node, "meta") and node.meta: - token_details = TokenPositionDetails.from_meta(node.meta) - macro_range = _range_from_token_position_details(token_details, read_file) + macro_range = TokenPositionDetails.from_meta(node.meta).to_range(read_file) # Check if it's a built-in method - if builtin := get_built_in_macro_reference(macro_name, macro_range): + if builtin := get_built_in_macro_reference(macro_name, to_lsp_range(macro_range)): return builtin else: # Skip if we can't get the position @@ -429,7 +374,7 @@ def get_macro_reference( return Reference( uri=macro_uri.value, - range=macro_range, + range=to_lsp_range(macro_range), target_range=Range( start=Position(line=start_line - 1, character=0), end=Position(line=end_line - 1, character=get_length_of_end_line), @@ -544,3 +489,24 @@ def _position_within_range(position: Position, range: Range) -> bool: range.end.line > position.line or (range.end.line == position.line and range.end.character >= position.character) ) + + +def to_lsp_range( + range: SQLMeshRange, +) -> Range: + """ + Converts a SQLMesh Range to an LSP Range. + """ + return Range( + start=Position(line=range.start.line, character=range.start.character), + end=Position(line=range.end.line, character=range.end.character), + ) + + +def to_lsp_position( + position: SQLMeshPosition, +) -> Position: + """ + Converts a SQLMesh Position to an LSP Position. + """ + return Position(line=position.line, character=position.character) diff --git a/tests/lsp/test_diagnostics.py b/tests/lsp/test_diagnostics.py new file mode 100644 index 0000000000..96167d47e5 --- /dev/null +++ b/tests/lsp/test_diagnostics.py @@ -0,0 +1,55 @@ +from sqlmesh import Context +from sqlmesh.core.linter.helpers import read_range_from_file +from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.uri import URI + + +def test_diagnostic_on_sushi(tmp_path, copy_to_temp_path) -> None: + # Copy sushi example to a temporary directory + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Override the active_customers.sql file to introduce a linter violation + active_customers_path = sushi_path / "models" / "active_customers.sql" + # Replace SELECT customer_id, zip with SELECT * to trigger a linter violation + with active_customers_path.open("r") as f: + lines = f.readlines() + lines = [ + line.replace("SELECT customer_id, zip", "SELECT *") + if "SELECT customer_id, zip" in line + else line + for line in lines + ] + with active_customers_path.open("w") as f: + f.writelines(lines) + + # Override the config and turn the linter on + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + lines = f.readlines() + lines = [ + line.replace("enabled=False,", "enabled=True,") if "enabled=False," in line else line + for line in lines + ] + with config_path.open("w") as f: + f.writelines(lines) + + # Load the context with the temporary sushi path + context = Context(paths=[str(sushi_path)]) + lsp_context = LSPContext(context) + + # Diagnostics should be available + active_customers_uri = URI.from_path(active_customers_path) + lsp_diagnostics = lsp_context.lint_model(active_customers_uri) + + assert len(lsp_diagnostics) > 0 + + # Get the no select star diagnostic + select_star_diagnostic = [diag for diag in lsp_diagnostics if diag.rule.name == "noselectstar"] + assert len(select_star_diagnostic) == 1 + diagnostic = select_star_diagnostic[0] + + assert diagnostic.violation_range + + contents = read_range_from_file(active_customers_path, diagnostic.violation_range) + assert contents == "*"