Skip to content

Commit

Permalink
feat: add find_all support
Browse files Browse the repository at this point in the history
part of #389
  • Loading branch information
HerringtonDarkholme committed Oct 23, 2023
1 parent f1cf76e commit 097a209
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
50 changes: 36 additions & 14 deletions crates/pyo3/src/py_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,28 @@ impl SgNode {

#[pyo3(signature = (config=None, **kwargs))]
fn find(&self, config: Option<&PyDict>, kwargs: Option<&PyDict>) -> Option<Self> {
let lang = self.inner.lang();
let config = if let Some(config) = config {
config_from_dict(lang, config)
} else {
let rule = rule_from_dict(kwargs?);
SerializableRuleCore {
language: *lang,
rule,
constraints: None,
utils: None,
transform: None,
}
};
let config = self.get_config(config, kwargs);
let matcher = config.get_matcher(&Default::default()).unwrap();
let inner = self.inner.find(matcher)?;
Some(Self {
inner,
root: self.root.clone(),
})
}
#[pyo3(signature = (config=None, **kwargs))]
fn find_all(&self, config: Option<&PyDict>, kwargs: Option<&PyDict>) -> Vec<Self> {
let config = self.get_config(config, kwargs);
let matcher = config.get_matcher(&Default::default()).unwrap();
self
.inner
.find_all(matcher)
.map(|n| Self {
inner: n,
root: self.root.clone(),
})
.collect()
}

// TODO find_all
// TODO field
// TODO parent
// TODO child
Expand All @@ -132,6 +132,28 @@ impl SgNode {
// TODO prev_all
}

impl SgNode {
fn get_config(
&self,
config: Option<&PyDict>,
kwargs: Option<&PyDict>,
) -> SerializableRuleCore<SupportLang> {
let lang = self.inner.lang();
if let Some(config) = config {
config_from_dict(lang, config)
} else {
let rule = rule_from_dict(kwargs.unwrap());
SerializableRuleCore {
language: *lang,
rule,
constraints: None,
utils: None,
transform: None,
}
}
}
}

fn config_from_dict(lang: &SupportLang, dict: &PyDict) -> SerializableRuleCore<SupportLang> {
dict.set_item("language", lang.to_string()).unwrap();
depythonize(dict).unwrap()
Expand Down
11 changes: 10 additions & 1 deletion crates/pyo3/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
source = '''
function test() {
let a = 123
let b = 456
let c = 789
}
'''.strip()
sg = SgRoot(source, 'javascript')
Expand Down Expand Up @@ -34,4 +36,11 @@ def test_get_root():
assert node is not None
root2 = node.get_root()
assert root2.filename() == 'anonymous'
# assert root2 == root
# assert root2 == root

def test_get_all():
nodes = root.find_all(pattern = 'let $N = $V')
assert len(nodes) == 3
assert nodes[0].get_match('N').text() == 'a'
assert nodes[1].get_match('N').text() == 'b'
assert nodes[2].get_match('N').text() == 'c'

0 comments on commit 097a209

Please sign in to comment.