Skip to content

Commit

Permalink
feat: implement better refinement
Browse files Browse the repository at this point in the history
part of #389
  • Loading branch information
HerringtonDarkholme committed Oct 31, 2023
1 parent 98872b0 commit 1f84b77
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 32 deletions.
12 changes: 6 additions & 6 deletions crates/pyo3/ast_grep_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ class SgNode:
def text(self) -> str: ...

# Search Refinement
def matches(self, m: str) -> bool: ...
def inside(self, m: str) -> bool: ...
def has(self, m: str) -> bool: ...
def precedes(self, m: str) -> bool: ...
def follows(self, m: str) -> bool: ...
def matches(self, **rule: Unpack[Rule]) -> bool: ...
def inside(self, **rule: Unpack[Rule]) -> bool: ...
def has(self, **rule: Unpack[Rule]) -> bool: ...
def precedes(self, **rule: Unpack[Rule]) -> bool: ...
def follows(self, **rule: Unpack[Rule]) -> bool: ...
def get_match(self, meta_var: str) -> Optional[SgNode]: ...
def get_multiple_matches(self, meta_var: str) -> List[SgNode]: ...

# Tree Traversal
def get_root(self) -> SgRoot: ...
def find(self, config=None, **kwargs: Unpack[Rule]) -> SgNode: ...
def find_all(self, config=None, **kwargs: Unpack[Rule]) -> List[SgNode]: ...
def find_all(self, config=None, **kwargs: Unpack[Rule]) -> List[SgNode]: ...
58 changes: 37 additions & 21 deletions crates/pyo3/src/py_node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::range::Range;
use crate::SgRoot;

use ast_grep_config::{SerializableRule, SerializableRuleCore};
use ast_grep_config::SerializableRuleCore;
use ast_grep_core::{NodeMatch, StrDoc};
use ast_grep_language::SupportLang;

Expand Down Expand Up @@ -52,24 +52,39 @@ impl SgNode {
}

/*---------- Search Refinement ----------*/
fn matches(&self, m: String) -> bool {
self.inner.matches(&*m)
#[pyo3(signature = (**kwargs))]
fn matches(&self, kwargs: Option<&PyDict>) -> bool {
let config = config_from_rule(self.inner.lang(), kwargs.unwrap());
let matcher = config.get_matcher(&Default::default()).unwrap();
self.inner.matches(matcher)
}

fn inside(&self, m: String) -> bool {
self.inner.inside(&*m)
#[pyo3(signature = (**kwargs))]
fn inside(&self, kwargs: Option<&PyDict>) -> bool {
let config = config_from_rule(self.inner.lang(), kwargs.unwrap());
let matcher = config.get_matcher(&Default::default()).unwrap();
self.inner.inside(matcher)
}

fn has(&self, m: String) -> bool {
self.inner.has(&*m)
#[pyo3(signature = (**kwargs))]
fn has(&self, kwargs: Option<&PyDict>) -> bool {
let config = config_from_rule(self.inner.lang(), kwargs.unwrap());
let matcher = config.get_matcher(&Default::default()).unwrap();
self.inner.has(matcher)
}

fn precedes(&self, m: String) -> bool {
self.inner.precedes(&*m)
#[pyo3(signature = (**kwargs))]
fn precedes(&self, kwargs: Option<&PyDict>) -> bool {
let config = config_from_rule(self.inner.lang(), kwargs.unwrap());
let matcher = config.get_matcher(&Default::default()).unwrap();
self.inner.precedes(matcher)
}

fn follows(&self, m: String) -> bool {
self.inner.follows(&*m)
#[pyo3(signature = (**kwargs))]
fn follows(&self, kwargs: Option<&PyDict>) -> bool {
let config = config_from_rule(self.inner.lang(), kwargs.unwrap());
let matcher = config.get_matcher(&Default::default()).unwrap();
self.inner.follows(matcher)
}

fn get_match(&self, meta_var: &str) -> Option<Self> {
Expand Down Expand Up @@ -237,14 +252,8 @@ impl SgNode {
if let Some(config) = config {
config_from_dict(lang, config)
} else {
let rule = rule_from_dict(kwargs.unwrap());
SerializableRuleCore {
language: *lang,
rule,
constraints: None,
utils: None,
transform: None,
}
// TODO: remove unwrap
config_from_rule(lang, kwargs.unwrap())
}
}
}
Expand All @@ -254,6 +263,13 @@ fn config_from_dict(lang: &SupportLang, dict: &PyDict) -> SerializableRuleCore<S
depythonize(dict).unwrap()
}

fn rule_from_dict(dict: &PyDict) -> SerializableRule {
depythonize(dict).unwrap()
fn config_from_rule(lang: &SupportLang, dict: &PyDict) -> SerializableRuleCore<SupportLang> {
let rule = depythonize(dict).unwrap();
SerializableRuleCore {
language: *lang,
rule,
constraints: None,
utils: None,
transform: None,
}
}
24 changes: 19 additions & 5 deletions crates/pyo3/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,33 @@ def test_text():
assert node.text() == "123"

def test_matches():
pass
node = root.find(pattern="let $A = $B")
assert node.matches(kind="lexical_declaration")
assert not node.matches(kind="number")
assert node.matches(pattern="let a = 123")
assert not node.matches(pattern="let b = 456")

def test_inside():
pass
node = root.find(pattern="let $A = $B")
assert node.inside(kind="function_declaration")
assert not node.inside(kind="function")

def test_has():
pass
node = root.find(pattern="let $A = $B")
assert node.has(pattern="123")
assert node.has(kind="number")
assert not node.has(kind="function")

def test_precedes():
pass
node = root.find(pattern="let $A = $B\n")
assert node.precedes(pattern="let b = 456\n")
assert node.precedes(pattern="let c = 789\n")
assert not node.precedes(pattern="notExist")

def test_follows():
pass
node = root.find(pattern="let b = 456\n")
assert node.follows(pattern="let a = 123\n")
assert not node.follows(pattern="let c = 789\n")

def test_get_match():
node = root.find(pattern="let $A = $B")
Expand Down

0 comments on commit 1f84b77

Please sign in to comment.