1. 安装 tree-sitter（以下代码基于 0.23.x 版本的 API）

In [1]:
# %pip install tree-sitter==0.23.2
# %pip install tree-sitter-c==0.23.4 # 根据目标语言选择

2. 初始化：加载 C 语言语法并创建解析器

In [2]:
from tree_sitter import Language, Parser, Node
import tree_sitter_c
# Load the C grammar: `tree_sitter_c.language()` returns the grammar object
# that `Language` wraps. The resulting `language` builds the parser here and
# is also used to compile queries (`language.query`) in later cells.
c_grammar = tree_sitter_c.language()
language = Language(c_grammar)
parser = Parser(language)

3. 分析代码
    - 3.1 解析源码，获取语法树根节点 root_node
    - 3.2 使用 Query 查询目标节点

In [3]:
# Sample C program used as the parse input for the queries below.
# It must be a raw string (r"""..."""): in a normal string Python turns each
# `\n` inside the printf format strings into a real line break. tree-sitter
# would then see C string literals split across lines, which is invalid C and
# produces ERROR nodes.
Code = r"""
#include <stdio.h>

// 函数声明
void print_hello();
int add(int a, int b);
int subtract(int a, int b);
void swap(int *a, int *b);

int main() {
    // 调用 print_hello 函数
    print_hello();

    // 调用 add 和 subtract 函数
    int num1 = 10;
    int num2 = 5;
    int sum = add(num1, num2);
    int difference = subtract(num1, num2);

    printf("Sum of %d and %d is: %d\n", num1, num2, sum);
    printf("Difference between %d and %d is: %d\n", num1, num2, difference);

    // 调用 swap 函数
    int x = 100;
    int y = 200;
    printf("Before swap: x = %d, y = %d\n", x, y);
    swap(&x, &y);
    printf("After swap: x = %d, y = %d\n", x, y);

    return 0;
}

// 函数定义
void print_hello() {
    printf("Hello, World!\n");
}

int add(int a, int b) {
    return a + b;
}

int subtract(int a, int b) {
    return a - b;
}

void swap(int *a, int *b) {
    int temp = *a;
    *a = *b;
    *b = temp;
}
"""
# tree-sitter parses bytes, not str, so encode the source as UTF-8 first.
tree = parser.parse(Code.encode("utf8"))
root_node = tree.root_node
# Fail loudly if the sample does not parse cleanly, instead of querying a tree
# that silently contains ERROR nodes.
assert not root_node.has_error, "sample C code did not parse cleanly"

In [10]:
# Example query: capture the name identifier of every function definition
# (@func_id) and of every function prototype declaration (@decl_id).
# NOTE(review): child patterns only match direct children. A function that
# returns a pointer, such as `int *f(void)`, has its function_declarator
# wrapped in a pointer_declarator, so it is not captured by this query.
query = language.query("""
    (function_definition
        (function_declarator
            (identifier)@func_id
        )
    )

    (declaration
        (function_declarator
            (identifier)@decl_id 
        )
    )
""")
# Run the query over the chosen subtree; the root node searches the whole file.
# The result maps each capture name to the list of nodes it matched.
captures = query.captures(root_node)
for capture_name, nodes in captures.items():
    print(f"---------------{capture_name}--------------")
    for node in nodes:
        # If the whole function is needed, go up two levels:
        # identifier -> function_declarator -> function_definition/declaration,
        # i.e. node.parent.parent.
        print(node.text.decode())

---------------decl_id--------------
print_hello
subtract
swap
add
---------------func_id--------------
main
subtract
print_hello
add
swap


In [12]:
# Search for one specific function by name: the #eq? predicates keep only the
# captures whose text equals `identifier`.
identifier = "add"
# The name is embedded inside a "..." string literal of the query language.
# Escape backslashes, double quotes and newlines first; otherwise a quote would
# end the literal early and break or alter the query.
escaped_identifier = (
    identifier.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
)
query = language.query(f"""
    (function_definition
        (function_declarator
            (identifier)@func_id
            (#eq? @func_id "{escaped_identifier}")
        )
    )

    (declaration
        (function_declarator
            (identifier)@decl_id 
            (#eq? @decl_id "{escaped_identifier}")
        )
    )
""")
captures = query.captures(root_node)
for capture_name, nodes in captures.items():
    print(f"---------------{capture_name}--------------")
    for node in nodes:
        # Two levels up (identifier -> function_declarator -> definition or
        # declaration) is the whole function definition or prototype.
        print(node.parent.parent.text.decode())

---------------decl_id--------------
int add(int a, int b);
---------------func_id--------------
int add(int a, int b) {
    return a + b;
}
