Skip to content

Commit

Permalink
test: add test for get_match and wrong usages
Browse files Browse the repository at this point in the history
part of #389
  • Loading branch information
HerringtonDarkholme committed Nov 4, 2023
1 parent 2ad3463 commit 89aec8d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
13 changes: 10 additions & 3 deletions crates/pyo3/ast_grep_pyo3.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, TypedDict, Unpack, Literal
from typing import List, Optional, TypedDict, Unpack, Literal, overload

class Pattern(TypedDict):
selector: str
Expand Down Expand Up @@ -66,11 +66,18 @@ class SgNode:
def follows(self, **rule: Unpack[Rule]) -> bool: ...
def get_match(self, meta_var: str) -> Optional[SgNode]: ...
def get_multiple_matches(self, meta_var: str) -> List[SgNode]: ...
def __getitem__(self, meta_var: str) -> SgNode: ...

# Tree Traversal
def get_root(self) -> SgRoot: ...
def find(self, config=None, **kwargs: Unpack[Rule]) -> SgNode: ...
def find_all(self, config=None, **kwargs: Unpack[Rule]) -> List[SgNode]: ...
@overload
def find(self, config=None) -> SgNode: ...
@overload
def find(self, **kwargs: Unpack[Rule]) -> SgNode: ...
@overload
def find_all(self, config=None) -> List[SgNode]: ...
@overload
def find_all(self, **kwargs: Unpack[Rule]) -> List[SgNode]: ...
def field(self, name: str) -> Optional[SgNode]: ...
def parent(self) -> Optional[SgNode]: ...
def child(self, nth: int) -> Optional[SgNode]: ...
Expand Down
9 changes: 9 additions & 0 deletions crates/pyo3/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def test_get_match():
assert rng.start.line == 1
assert rng.start.column == 6

def test_must_get_match():
node = root.find(pattern="let $A = $B")
a = node["A"]
assert a is not None
assert a.text() == "a"
rng = a.range()
assert rng.start.line == 1
assert rng.start.column == 6


def test_get_multi_match():
node = root.find(pattern="function test() { $$$STMT }")
Expand Down
40 changes: 40 additions & 0 deletions crates/pyo3/tests/test_wrong_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from ast_grep_pyo3 import SgRoot
import pytest

source = """
function test() {
let a = 123
let b = 456
let c = 789
}
""".strip()
sg = SgRoot(source, "javascript")
root = sg.root()

def test_wrong_use_pattern_as_dict():
with pytest.raises(TypeError):
root.find("let $A = 123")

def test_get_unfound_match():
node = root.find(pattern="let $A = 123")
with pytest.raises(KeyError):
node["B"]

# TODO: remove BaseException
def test_wrong_rule_key():
with pytest.raises(BaseException):
root.find(pat="not") # type: ignore

def test_no_rule_key():
with pytest.raises(BaseException):
root.find()

def test_error_for_invalid_kind():
with pytest.raises(BaseException):
root.find(kind="nonsense")

def test_no_error_for_invalid_pattern():
with pytest.raises(BaseException):
root.find(pattern="$@!!--l3**+no//nsense")
# but not this
assert not root.find(pattern="@test")

0 comments on commit 89aec8d

Please sign in to comment.