Skip to content

Commit 7e82824

Browse files
committed
fix: handle decorated identifiers
1 parent 582e575 commit 7e82824

File tree

4 files changed

+159
-126
lines changed

4 files changed

+159
-126
lines changed

Diff for: src/cedarscript_editor/tree_sitter_identifier_finder.py

+53-9
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def find_identifiers(
128128
match identifier_type:
129129
case 'method':
130130
identifier_type = 'function'
131-
candidate_nodes = self.language.query(self.query_info[identifier_type].format(name=name)).captures(self.tree.root_node)
131+
_query = self.query_info[identifier_type].format(name=name)
132+
candidate_nodes = self.language.query(_query).captures(self.tree.root_node)
132133
if not candidate_nodes:
133134
return []
134135
# Convert captures to boundaries and filter by parent
@@ -198,9 +199,9 @@ def parents(self) -> list[ParentInfo]:
198199
current = self.node.parent
199200

200201
while current:
201-
# Check if current node is a container type we care about
202-
if current.type.endswith('_definition'):
203-
# Try to find the name node - exact field depends on language
202+
# Check if current node is a container type we care about - TODO exact field depends on language
203+
if current.type.endswith('_definition') and current.type != 'decorated_definition':
204+
# Try to find the name node - TODO exact field depends on language
204205
name = None
205206
for child in current.children:
206207
if child.type == 'identifier' or child.type == 'name':
@@ -242,14 +243,13 @@ def associate_identifier_parts(captures: Iterable[CaptureInfo], lines: Sequence[
242243
raise ValueError(f'Parent node not found for [{capture.capture_type} - {capture.node_type}] ({capture.node.text.decode("utf-8").strip()})')
243244
match capture_type:
244245
case 'body':
245-
parent = parent._replace(body=range_spec)
246+
parent.body=range_spec
246247
case 'docstring':
247-
parent = parent._replace(docstring=range_spec)
248+
parent.docstring=range_spec
248249
case 'decorator':
249-
parent = parent.decorators.append(range_spec)
250+
parent.append_decorator(range_spec)
250251
case _ as invalid:
251252
raise ValueError(f'Invalid capture type: {invalid}')
252-
identifier_map[parent_key] = parent
253253

254254
return sorted(identifier_map.values(), key=lambda x: x.whole.start)
255255

@@ -260,6 +260,8 @@ def find_parent_definition(node):
260260
while node.parent:
261261
node = node.parent
262262
if node.type.endswith('_definition'):
263+
if node.type == 'decorated_definition':
264+
node = node.named_children[0].next_named_sibling
263265
return node
264266
return None
265267

@@ -278,4 +280,46 @@ def capture2identifier_boundaries(captures, lines: Sequence[str]) -> list[Identi
278280
unique_captures = {}
279281
for capture in captures:
280282
unique_captures[f'{capture.range[0]}:{capture.capture_type}'] = capture
281-
return associate_identifier_parts(unique_captures.values(), lines)
283+
# unique_captures={
284+
# '157:function.decorator': CaptureInfo(capture_type='function.decorator', node=<Node type=decorator, start_point=(157, 4), end_point=(157, 17)>),
285+
# '158:function.definition': CaptureInfo(capture_type='function.definition', node=<Node type=function_definition, start_point=(158, 4), end_point=(207, 19)>),
286+
# '159:function.body': CaptureInfo(capture_type='function.body', node=<Node type=block, start_point=(159, 8), end_point=(207, 19)>)
287+
# }
288+
return associate_identifier_parts(sort_captures(unique_captures), lines)
289+
290+
def parse_capture_key(key):
291+
"""
292+
Parses the dictionary key into line number and capture type.
293+
Args:
294+
key (str): The key in the format 'line_number:capture_type'.
295+
Returns:
296+
tuple: (line_number as int, capture_type as str)
297+
"""
298+
line_number, capture_type = key.split(':')
299+
return int(line_number), capture_type.split('.')[-1]
300+
301+
def get_sort_priority():
302+
"""
303+
Returns a dictionary mapping capture types to their sort priority.
304+
Returns:
305+
dict: Capture type priorities.
306+
"""
307+
return {'definition': 1, 'decorator': 2, 'body': 3, 'docstring': 4}
308+
309+
def sort_captures(captures):
310+
"""
311+
Sorts the values of the captures dictionary by capture type and line number.
312+
Args:
313+
captures (dict): The dictionary to sort.
314+
Returns:
315+
list: Sorted list of values.
316+
"""
317+
priority = get_sort_priority()
318+
sorted_items = sorted(
319+
captures.items(),
320+
key=lambda item: (
321+
priority[parse_capture_key(item[0])[1]], # Sort by capture type priority
322+
parse_capture_key(item[0])[0] # Then by line number
323+
)
324+
)
325+
return [value for _, value in sorted_items]

Diff for: src/cedarscript_editor/tree_sitter_identifier_queries.py

+94-110
Original file line numberDiff line numberDiff line change
@@ -30,126 +30,110 @@
3030
# except KeyError:
3131
# return
3232

33+
_common_template = """
34+
; Common pattern for body and docstring capture
35+
body: (block
36+
.
37+
(expression_statement
38+
(string) @{type}.docstring)?
39+
.
40+
) @{type}.body
41+
"""
42+
43+
_definition_base_template = """
44+
name: (identifier) @_{type}_name
45+
(#match? @_{type}_name "^{{name}}$")
46+
(#set! role name)
47+
"""
3348

3449
LANG_TO_TREE_SITTER_QUERY = {
3550
"python": {
3651
'function': """
37-
; Regular and async function definitions with optional docstring
38-
(function_definition
39-
name: (identifier) @_function_name
40-
(#match? @_function_name "^{name}$")
41-
body: (block) @function.body) @function.definition
42-
52+
; Function Definitions
53+
(function_definition
54+
{definition_base}
55+
{common_body}
56+
) @function.definition
57+
58+
; Decorated Function Definitions
59+
(decorated_definition
60+
(decorator)+ @function.decorator
4361
(function_definition
44-
name: (identifier) @_function_name
45-
(#match? @_function_name "^{name}$")
46-
body: (block
47-
.
48-
(expression_statement
49-
(string) @function.docstring)?
50-
.
51-
(_)*))
52-
53-
; Decorated function definitions (including async) with optional docstring
54-
(decorated_definition
55-
(decorator)+
56-
(function_definition
57-
name: (identifier) @_function_name
58-
(#match? @_function_name "^{name}$")
59-
body: (block) @function.body)) @function.definition
60-
61-
(decorated_definition
62-
(decorator)+
63-
(function_definition
64-
name: (identifier) @_function_name
65-
(#match? @_function_name "^{name}$")
66-
body: (block
67-
.
68-
(expression_statement
69-
(string) @function.docstring)?
70-
.
71-
(_)*)))
72-
73-
; Method definitions in classes (including async and decorated) with optional docstring
74-
(class_definition
75-
body: (block
62+
{definition_base}
63+
{common_body}
64+
) @function.definition
65+
)
66+
67+
; Methods in Classes
68+
(class_definition
69+
body: (block
7670
(function_definition
77-
name: (identifier) @_function_name
78-
(#match? @_function_name "^{name}$")
79-
body: (block) @function.body) @function.definition))
80-
81-
(class_definition
82-
body: (block
83-
(function_definition
84-
name: (identifier) @_function_name
85-
(#match? @_function_name "^{name}$")
86-
body: (block
87-
.
88-
(expression_statement
89-
(string) @function.docstring)?
90-
.
91-
(_)*))))
92-
""",
71+
{definition_base}
72+
{common_body}
73+
) @function.definition
74+
)
75+
)
76+
77+
; Decorated Methods in Classes
78+
(class_definition
79+
body: (block
80+
(decorated_definition
81+
(decorator)+ @function.decorator
82+
(function_definition
83+
{definition_base}
84+
{common_body}
85+
) @function.definition
86+
)
87+
)
88+
)
89+
""".format(
90+
definition_base=_definition_base_template.format(type="function"),
91+
common_body=_common_template.format(type="function")
92+
),
9393

9494
'class': """
95-
; Regular and decorated class definitions (including nested) with optional docstring
96-
(class_definition
97-
name: (identifier) @_class_name
98-
(#match? @_class_name "^{name}$")
99-
body: (block) @class.body) @class.definition
100-
95+
; Class Definitions
96+
(class_definition
97+
{definition_base}
98+
{common_body}
99+
) @class.definition
100+
101+
; Decorated Class Definitions
102+
(decorated_definition
103+
(decorator)+ @class.decorator
101104
(class_definition
102-
name: (identifier) @_class_name
103-
(#match? @_class_name "^{name}$")
104-
body: (block
105-
.
106-
(expression_statement
107-
(string) @class.docstring)?
108-
.
109-
(_)*))
110-
111-
; Decorated class definitions
112-
(decorated_definition
113-
(decorator)+
114-
(class_definition
115-
name: (identifier) @_class_name
116-
(#match? @_class_name "^{name}$")
117-
body: (block) @class.body)) @class.definition
118-
119-
(decorated_definition
120-
(decorator)+
121-
(class_definition
122-
name: (identifier) @_class_name
123-
(#match? @_class_name "^{name}$")
124-
body: (block
125-
.
126-
(expression_statement
127-
(string) @class.docstring)?
128-
.
129-
(_)*)))
130-
131-
; Nested class definitions within other classes
132-
(class_definition
133-
body: (block
134-
(class_definition
135-
name: (identifier) @_class_name
136-
(#match? @_class_name "^{name}$")
137-
body: (block) @class.body) @class.definition))
138-
139-
(class_definition
140-
body: (block
105+
{definition_base}
106+
{common_body}
107+
) @class.definition
108+
)
109+
110+
; Nested Classes
111+
(class_definition
112+
body: (block
141113
(class_definition
142-
name: (identifier) @_class_name
143-
(#match? @_class_name "^{name}$")
144-
body: (block
145-
.
146-
(expression_statement
147-
(string) @class.docstring)?
148-
.
149-
(_)*))))
150-
"""
151-
},
152-
"kotlin": {
114+
{definition_base}
115+
{common_body}
116+
) @class.definition
117+
)
118+
)
119+
120+
; Decorated Nested Classes
121+
(class_definition
122+
body: (block
123+
(decorated_definition
124+
(decorator)+ @class.decorator
125+
(class_definition
126+
{definition_base}
127+
{common_body}
128+
) @class.definition
129+
)
130+
)
131+
)
132+
""".format(
133+
definition_base=_definition_base_template.format(type="class"),
134+
common_body=_common_template.format(type="class")
135+
)
136+
}, "kotlin": {
153137
'function': """
154138
; Regular function definitions with optional annotations and KDoc
155139
(function_declaration

Diff for: src/text_manipulation/range_spec.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections.abc import Sequence
1414
from typing import NamedTuple, TypeAlias
1515
from functools import total_ordering
16+
from dataclasses import dataclass, field
1617

1718

1819
from cedarscript_ast_parser import Marker, RelativeMarker, RelativePositionType, MarkerType, BodyOrWhole
@@ -331,8 +332,8 @@ class ParentInfo(NamedTuple):
331332

332333
ParentRestriction: TypeAlias = RangeSpec | str | None
333334

334-
335-
class IdentifierBoundaries(NamedTuple):
335+
@dataclass
336+
class IdentifierBoundaries:
336337
"""
337338
Represents the boundaries of an identifier in code, including its whole range and body range.
338339
@@ -347,8 +348,12 @@ class IdentifierBoundaries(NamedTuple):
347348
whole: RangeSpec
348349
body: RangeSpec | None = None
349350
docstring: RangeSpec | None = None
350-
decorators: list[RangeSpec] = []
351-
parents: list[ParentInfo] = []
351+
decorators: list[RangeSpec] = field(default_factory=list)
352+
parents: list[ParentInfo] = field(default_factory=list)
353+
354+
def append_decorator(self, decorator: RangeSpec):
355+
self.decorators.append(decorator)
356+
self.whole = self.whole._replace(start = min(self.whole.start, decorator.start))
352357

353358
def __str__(self):
354359
return f'IdentifierBoundaries({self.whole} (BODY: {self.body}) )'

Diff for: src/text_manipulation/text_editor_kit.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from collections.abc import Sequence
1010
from typing import Protocol, runtime_checkable
11-
from os import PathLike
11+
from os import PathLike, path
1212

1313
from cedarscript_ast_parser import Marker, RelativeMarker, RelativePositionType, Segment, MarkerType, BodyOrWhole
1414
from .range_spec import IdentifierBoundaries, RangeSpec
@@ -24,7 +24,7 @@ def read_file(file_path: str | PathLike) -> str:
2424
Returns:
2525
str: The contents of the file as a string.
2626
"""
27-
with open(file_path, 'r') as file:
27+
with open(path.normpath(file_path), 'r') as file:
2828
return file.read()
2929

3030

@@ -36,7 +36,7 @@ def write_file(file_path: str | PathLike, lines: Sequence[str]):
3636
file_path (str | PathLike): The path to the file to be written.
3737
lines (Sequence[str]): The lines to be written to the file.
3838
"""
39-
with open(file_path, 'w') as file:
39+
with open(path.normpath(file_path), 'w') as file:
4040
file.writelines([line + '\n' for line in lines])
4141

4242

0 commit comments

Comments
 (0)