Skip to content

Commit

Permalink
feat: support pyo3 TypedDict
Browse files Browse the repository at this point in the history
fix #389
  • Loading branch information
HerringtonDarkholme committed Nov 6, 2023
1 parent 86b1e40 commit 4da1aef
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 8 deletions.
10 changes: 6 additions & 4 deletions crates/pyo3/ast_grep_pyo3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, TypedDict, Literal, Dict, Union
from typing import List, TypedDict, Literal, Dict, Union, Mapping
from .ast_grep_pyo3 import SgNode, SgRoot, Pos, Range

class Pattern(TypedDict):
Expand Down Expand Up @@ -34,15 +34,17 @@ class Rule(RuleWithoutNot, TypedDict("Not", {"not": "Rule"}, total=False)):
# Relational Rule Related
StopBy = Union[Literal["neighbor"], Literal["end"], Rule]

class Relation(Rule, total=False):
# Relation do NOT inherit from Rule due to pyright bug
# see tests/test_rule.py
class Relation(RuleWithoutNot, TypedDict("Not", {"not": "Rule"}, total=False), total=False):
stopBy: StopBy
field: str

class Config(TypedDict, total=False):
rule: Rule
constraints: Dict[str, Dict]
constraints: Dict[str, Mapping]
utils: Dict[str, Rule]
transform: Dict[str, Dict]
transform: Dict[str, Mapping]

__all__ = [
"Rule",
Expand Down
1 change: 1 addition & 0 deletions crates/pyo3/ast_grep_pyo3/ast_grep_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SgNode:
def follows(self, **rule: Unpack[Rule]) -> bool: ...
def get_match(self, meta_var: str) -> Optional[SgNode]: ...
def get_multiple_matches(self, meta_var: str) -> List[SgNode]: ...
def get_transformed(self, meta_var: str) -> Optional[str]: ...
def __getitem__(self, meta_var: str) -> SgNode: ...

# Tree Traversal
Expand Down
8 changes: 8 additions & 0 deletions crates/pyo3/src/py_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ impl SgNode {
.collect()
}

fn get_transformed(&self, meta_var: &str) -> Option<String> {
self
.inner
.get_env()
.get_transformed(meta_var)
.map(|n| String::from_utf8_lossy(n).to_string())
}

/*---------- Tree Traversal ----------*/
fn get_root(&self) -> Py<SgRoot> {
self.root.clone()
Expand Down
104 changes: 100 additions & 4 deletions crates/pyo3/tests/test_rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ast_grep_pyo3 import SgRoot, Rule, Config, Relation
from ast_grep_pyo3 import SgRoot, Rule, Config, Relation, Pattern

source = """
function test() {
Expand Down Expand Up @@ -47,10 +47,106 @@ def test_not_rule():
node = root.find(**rule)
assert node

def test_relational_rule():
relation: Relation = Relation(kind="function_declaration", stopBy="end")
def test_relational_dict():
relation: Relation = {"kind": "function_declaration", "stopBy": "end"}
node = root.find(
pattern="let a = 123\n",
inside=relation,
)
assert node
assert node
node = root.find(
pattern="let a = 123\n",
inside={"kind": "function_declaration", "stopBy": "end"},
)
assert node

def test_relational_rule():
node = root.find(
pattern="let a = 123\n",
inside=Relation(kind="function_declaration", stopBy="end"),
)
assert node

def test_complex_config_dict():
node = root.find({
"rule": {
"pattern": "let $A = $B",
"regex": "123",
"not": {
"regex": "456"
},
},
"constraints": {
"A": {
"pattern": "a"
}
},
"transform": {
"C": {
"substring": {
"source": "$B",
"startChar": 1,
"endChar": -1,
}
}
}
})
assert node
assert node.get_transformed("C") == "2"

def test_complex_config_dict_not_found():
node = root.find({
"rule": {
"pattern": "let $A = $B",
"regex": "123",
"not": {
"regex": "456"
},
},
"constraints": {
"A": {
"pattern": "a"
},
"B": {
"regex": "222"
},
},
"transform": {
"C": {
"substring": {
"source": "$B",
"startChar": 1,
"endChar": -1,
}
}
}
})
assert not node

def test_complex_config():
node = root.find(Config(
rule=Rule(pattern="let $A = $B", regex="123"),
constraints=dict(A=Rule(pattern="a")),
transform=dict(C={
"substring": {
"source": "$B",
"startChar": 1,
}
})
))
assert node
assert node.text() == "let a = 123"
assert node.get_transformed("C") == "23"

def test_pattern():
node = root.find(pattern={
"context": "let a = 123",
"selector": "variable_declarator"
})
assert node
assert node.text() == "a = 123"
node2 = root.find(pattern=Pattern(
context="let a = 123",
selector="variable_declarator",
))
assert node == node2

0 comments on commit 4da1aef

Please sign in to comment.