Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sqlmesh/core/linter/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
from functools import reduce
from sqlmesh.core.model import Model
from sqlmesh.core.linter.rule import Rule, RuleViolation
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
from sqlmesh.core.console import LinterConsole, get_console

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -74,6 +74,7 @@ def lint_model(
violation_msg=violation.violation_msg,
model=model,
violation_type="error",
violation_range=violation.violation_range,
)
for violation in error_violations
] + [
Expand All @@ -82,6 +83,7 @@ def lint_model(
violation_msg=violation.violation_msg,
model=model,
violation_type="warning",
violation_range=violation.violation_range,
)
for violation in warn_violations
]
Expand Down Expand Up @@ -149,7 +151,8 @@ def __init__(
violation_msg: str,
model: Model,
violation_type: t.Literal["error", "warning"],
violation_range: t.Optional[Range] = None,
) -> None:
super().__init__(rule, violation_msg)
super().__init__(rule, violation_msg, violation_range)
self.model = model
self.violation_type = violation_type
115 changes: 115 additions & 0 deletions sqlmesh/core/linter/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from pathlib import Path

from sqlmesh.core.linter.rule import Position, Range
from sqlmesh.utils.pydantic import PydanticModel
import typing as t


class TokenPositionDetails(PydanticModel):
"""
Details about a token's position in the source code in the structure provided by SQLGlot.

Attributes:
line (int): The line that the token ends on.
col (int): The column that the token ends on.
start (int): The start index of the token.
end (int): The ending index of the token.
"""

line: int
col: int
start: int
end: int

@staticmethod
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
return TokenPositionDetails(
line=meta["line"],
col=meta["col"],
start=meta["start"],
end=meta["end"],
)

def to_range(self, read_file: t.Optional[t.List[str]]) -> Range:
"""
Convert a TokenPositionDetails object to a Range object.

In the circumstances where the token's start and end positions are the same,
there is no need for a read_file parameter, as the range can be derived from the token's
line and column. This is an optimization to avoid unnecessary file reads and should
only be used when the token represents a single character or position in the file.

If the token's start and end positions are different, the read_file parameter is required.

:param read_file: List of lines from the file. Optional
:return: A Range object representing the token's position
"""
if self.start == self.end:
# If the start and end positions are the same, we can create a range directly
return Range(
start=Position(line=self.line - 1, character=self.col - 1),
end=Position(line=self.line - 1, character=self.col),
)

if read_file is None:
raise ValueError("read_file must be provided when start and end positions differ.")

# Convert from 1-indexed to 0-indexed for line only
end_line_0 = self.line - 1
end_col_0 = self.col

# Find the start line and column by counting backwards from the end position
start_pos = self.start
end_pos = self.end

# Initialize with the end position
start_line_0 = end_line_0
start_col_0 = end_col_0 - (end_pos - start_pos + 1)

# If start_col_0 is negative, we need to go back to previous lines
while start_col_0 < 0 and start_line_0 > 0:
start_line_0 -= 1
start_col_0 += len(read_file[start_line_0])
# Account for newline character
if start_col_0 >= 0:
break
start_col_0 += 1 # For the newline character

# Ensure we don't have negative values
start_col_0 = max(0, start_col_0)
return Range(
start=Position(line=start_line_0, character=start_col_0),
end=Position(line=end_line_0, character=end_col_0),
)


def read_range_from_file(file: Path, text_range: Range) -> str:
"""
Read the file and return the content within the specified range.

Args:
file: Path to the file to read
text_range: The range of text to extract

Returns:
The content within the specified range
"""
with file.open("r", encoding="utf-8") as f:
lines = f.readlines()

# Ensure the range is within bounds
start_line = max(0, text_range.start.line)
end_line = min(len(lines), text_range.end.line + 1)

if start_line >= end_line:
return ""

# Extract the relevant portions of each line
result = []
for i in range(start_line, end_line):
line = lines[i]
start_char = text_range.start.character if i == text_range.start.line else 0
end_char = text_range.end.character if i == text_range.end.line else len(line)
result.append(line[start_char:end_char])

return "".join(result)
32 changes: 29 additions & 3 deletions sqlmesh/core/linter/rule.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import abc
from dataclasses import dataclass

from sqlmesh.core.model import Model

Expand All @@ -22,6 +23,22 @@ class RuleLocation(PydanticModel):
start_line: t.Optional[int] = None


@dataclass(frozen=True)
class Position:
"""The position of a rule violation in a file, the position follows the LSP standard."""

line: int
character: int
Comment on lines +30 to +31
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does "character" mean column offset in the line? Why not use column or something similar so that the naming is consistent?

I also suggest we start documenting some details, like whether it's 1-based indexing and stuff, so reading the docstring makes it obvious to understand what's going on.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idem, just want to specify LSP standard



@dataclass(frozen=True)
class Range:
"""The range of a rule violation in a file. The range follows the LSP standard."""

start: Position
end: Position


class _Rule(abc.ABCMeta):
def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule:
attrs["name"] = clsname.lower()
Expand All @@ -45,9 +62,15 @@ def summary(self) -> str:
"""A summary of what this rule checks for."""
return self.__doc__ or ""

def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation:
def violation(
self,
violation_msg: t.Optional[str] = None,
violation_range: t.Optional[Range] = None,
) -> RuleViolation:
"""Create a RuleViolation instance for this rule"""
return RuleViolation(rule=self, violation_msg=violation_msg or self.summary)
return RuleViolation(
rule=self, violation_msg=violation_msg or self.summary, violation_range=violation_range
)

def get_definition_location(self) -> RuleLocation:
"""Return the file path and position information for this rule.
Expand Down Expand Up @@ -79,9 +102,12 @@ def __repr__(self) -> str:


class RuleViolation:
def __init__(self, rule: Rule, violation_msg: str) -> None:
def __init__(
self, rule: Rule, violation_msg: str, violation_range: t.Optional[Range] = None
) -> None:
self.rule = rule
self.violation_msg = violation_msg
self.violation_range = violation_range

def __repr__(self) -> str:
return f"{self.rule.name}: {self.violation_msg}"
23 changes: 20 additions & 3 deletions sqlmesh/core/linter/rules/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import typing as t

from sqlglot.expressions import Star
from sqlglot.helper import subclasses

from sqlmesh.core.linter.rule import Rule, RuleViolation
from sqlmesh.core.linter.helpers import TokenPositionDetails
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
from sqlmesh.core.linter.definition import RuleSet
from sqlmesh.core.model import Model, SqlModel

Expand All @@ -15,10 +17,25 @@ class NoSelectStar(Rule):
"""Query should not contain SELECT * on its outer most projections, even if it can be expanded."""

def check_model(self, model: Model) -> t.Optional[RuleViolation]:
# Only applies to SQL models, as other model types do not have a query.
if not isinstance(model, SqlModel):
return None

return self.violation() if model.query.is_star else None
if model.query.is_star:
violation_range = self._get_range(model)
return self.violation(violation_range=violation_range)
return None

def _get_range(self, model: SqlModel) -> t.Optional[Range]:
"""Get the range of the violation if available."""
try:
if len(model.query.expressions) == 1 and isinstance(model.query.expressions[0], Star):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this instead call find(Star) on the model's query? Note that SQL like t.* won't be picked up by this check because the projection is a Column that contains a Star child in this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be improved, but I'm hesitant to do so right away because it adds complexity. For me, I'm just in the simplest case of returning a range, which I think we can expand on over time.

There's also a lot of other changes so rather do this later.

return TokenPositionDetails.from_meta(model.query.expressions[0].meta).to_range(
None
)
except Exception:
pass

return None


class InvalidSelectStarExpansion(Rule):
Expand Down
2 changes: 1 addition & 1 deletion sqlmesh/lsp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sqlmesh.core.model.definition import SqlModel
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
from sqlmesh.lsp.custom import RenderModelEntry, ModelForRendering
from sqlmesh.lsp.custom import ModelForRendering
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
from sqlmesh.lsp.uri import URI

Expand Down
25 changes: 19 additions & 6 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,24 @@ def _diagnostic_to_lsp_diagnostic(
) -> t.Optional[types.Diagnostic]:
if diagnostic.model._path is None:
return None
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
lines = file.readlines()
if not diagnostic.violation_range:
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
lines = file.readlines()
range = types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=len(lines) - 1, character=len(lines[-1])),
)
else:
range = types.Range(
start=types.Position(
line=diagnostic.violation_range.start.line,
character=diagnostic.violation_range.start.character,
),
end=types.Position(
line=diagnostic.violation_range.end.line,
character=diagnostic.violation_range.end.character,
),
)

# Get rule definition location for diagnostics link
rule_location = diagnostic.rule.get_definition_location()
Expand All @@ -665,10 +681,7 @@ def _diagnostic_to_lsp_diagnostic(

# Use URI format to create a link for "related information"
return types.Diagnostic(
range=types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=len(lines), character=len(lines[-1])),
),
range=range,
message=diagnostic.violation_msg,
severity=types.DiagnosticSeverity.Error
if diagnostic.violation_type == "error"
Expand Down
Loading