In [15]:
# Java fixture parsed by the cells below. The line numbers reported later are
# 0-based rows inside this exact string, so editing it shifts every recorded
# result. Several statements are split across lines, which is what produces
# the multi-line spans in the outputs.
test_code = """import test.de.ChainClass;

public class TestChainedClass {
    public void testMethod() {
        ChainClass obj 
        = new ChainClass();
        obj.getInner1().doSomething();
        obj.getInner2()
        .doSomething2();

        TestClass2 obj2 = new TestClass2();
        obj2.test_obj2();
        
        obj
        .test_obj
        ();
    }

    public StaticClass ohSoStatic(){
        int staticReturn = StaticClass.countToOne();
        return StaticClass.staticMethodForReturn();
    }
}
"""

In [16]:
from tree_sitter import Language, Parser
import tree_sitter_java as tsjava

# Build the Java grammar first, then hand it to a fresh parser.
JAVA_LANGUAGE = Language(tsjava.language())
parser = Parser()
parser.language = JAVA_LANGUAGE

# The parser consumes bytes, so the fixture is UTF-8 encoded before parsing.
tree = parser.parse(test_code.encode("utf8"))
root_node = tree.root_node

In [17]:
from typing import Optional
from tree_sitter import Node


# Matches local variable declarations whose initializer is an object creation
# expression, e.g. `Foo x = new Foo();`. Captures:
#   @type                  - the type_identifier child of the declaration
#   @variable_name         - the declared identifier
#   @declarator_value_type - the type_identifier inside the `new` expression
#   @declaration           - the whole local_variable_declaration node
object_creation_query = JAVA_LANGUAGE.query(
    """
    (
  (local_variable_declaration 
    (type_identifier) @type
    (variable_declarator
      (identifier) @variable_name
      (object_creation_expression
        (type_identifier) @declarator_value_type
      )
    )
  ) @declaration
)
    """
)


def get_text(
    capture: "dict[str, Node | list[Node]] | Node", var_name: Optional[str] = None
) -> str:
    """Return the UTF-8 decoded source text of a captured node.

    Args:
        capture: Either the capture mapping of a query match (capture name ->
            node, or list of nodes) or, when ``var_name`` is omitted, the node
            itself.
        var_name: Name of the capture to look up in ``capture``. When ``None``,
            ``capture`` is treated as the node.

    Returns:
        The node's text decoded as UTF-8. When the capture holds a list of
        nodes, the text of the first node is returned.

    Raises:
        ValueError: If ``var_name`` is not present in ``capture``, the capture
            is an empty list, or the node carries no text.
    """
    if var_name is None:
        node = capture
    else:
        # ``.get`` lets a missing capture reach the descriptive ValueError
        # below instead of surfacing as a bare KeyError from indexing.
        node = capture.get(var_name)
        if node is None:
            raise ValueError(f"Node {var_name} not found in capture")

    # A capture may map to a list of nodes rather than a single node; the
    # first node is used in that case.
    if isinstance(node, (list, tuple)):
        if not node:
            raise ValueError(f"Capture {var_name} is empty")
        node = node[0]

    text = node.text
    if text is None:
        raise ValueError("Node has no text to decode")
    return text.decode("utf-8")


object_creation_captures = object_creation_query.matches(tree.root_node)

# One record per `Type name = new Type(...)` declaration, keyed by the variable
# name. Line numbers are the 0-based rows of the declaration node; "calls" is
# filled in by later cells.
defined_variables = {
    get_text(capture, "variable_name"): {
        "declarator_value_type": get_text(capture, "declarator_value_type"),
        "type": get_text(capture, "type"),
        "calls": [],
        "definition_line_start": capture["declaration"].start_point[0],
        "definition_line_end": capture["declaration"].end_point[0],
    }
    for _, capture in object_creation_captures
}


# Matches local variable declarations initialised from a call of the form
# `<identifier>.<method>(...)`, e.g. `int x = StaticClass.countToOne();`.
# Captures @variable_name, @method_object, @method_name and @declaration.
# NOTE(review): the pattern only requires an identifier before the dot, so a
# call on an instance variable presumably matches as well - confirm.
static_method_invocation_query = JAVA_LANGUAGE.query(
    """
(
  (local_variable_declaration 
    (variable_declarator
      (identifier) @variable_name
      (method_invocation
        object: (identifier) @method_object
        name: (identifier) @method_name
        arguments: (argument_list)
      )
    )
  ) @declaration
)

    """
)
static_method_captures = static_method_invocation_query.matches(tree.root_node)
# The pattern index of each match is unused; binding it to `_` keeps the
# builtin `id` from being shadowed for the rest of the kernel session.
for _, capture in static_method_captures:
    defined_variables[get_text(capture, "variable_name")] = {
        "declarator_value_type": get_text(capture, "method_object"),
        # TODO: the query does not capture the declared type yet, so a
        # placeholder is stored.
        "type": "todo",
        # The initialising call is itself recorded as the first call, with the
        # span of the whole declaration.
        "calls": [
            {
                "start_line": capture["declaration"].start_point[0],
                "end_line": capture["declaration"].end_point[0],
                "method_name": get_text(capture, "method_name"),
            }
        ],
        "definition_line_start": capture["declaration"].start_point[0],
        "definition_line_end": capture["declaration"].end_point[0],
    }

defined_variables

{'obj': {'declarator_value_type': 'ChainClass',
  'type': 'ChainClass',
  'calls': [],
  'definition_line_start': 4,
  'definition_line_end': 5},
 'obj2': {'declarator_value_type': 'TestClass2',
  'type': 'TestClass2',
  'calls': [],
  'definition_line_start': 10,
  'definition_line_end': 10},
 'staticReturn': {'declarator_value_type': 'StaticClass',
  'type': 'todo',
  'calls': [{'start_line': 19, 'end_line': 19, 'method_name': 'countToOne'}],
  'definition_line_start': 19,
  'definition_line_end': 19}}

In [18]:
# Matches expression statements that are a call made directly on an
# identifier, e.g. `obj2.test_obj2();`. Captures @object, @method and
# @method_invocation (the whole call node).
# NOTE(review): `#is-not?` is a property predicate in tree-sitter's query
# language rather than a regex test; if the aim was to reject objects whose
# text contains a newline, `#not-match?` may have been intended - confirm.
query2 = JAVA_LANGUAGE.query(
    """
(expression_statement
  (method_invocation
    object: (identifier) @object
    name: (identifier) @method
    arguments: (argument_list)) @method_invocation
  (#is-not? @object ".*\\n")
)
"""
)


captures = query2.matches(tree.root_node)


# Attach every direct call to the record of the variable it is made on.
# The pattern index is bound to `_` so the builtin `id` is not shadowed.
for _, nodes in captures:
    if nodes.get("object") is None:
        continue

    object_definition = defined_variables.get(get_text(nodes, "object"))
    if object_definition is None:
        # The receiver is not one of the tracked variables.
        continue

    method_details = {
        "start_line": nodes["method_invocation"].start_point[0],
        "end_line": nodes["method_invocation"].end_point[0],
        "method_name": get_text(nodes, "method"),
    }
    # Append through the record already fetched instead of looking it up again.
    object_definition["calls"].append(method_details)
    print("Appended", method_details)

Appended {'start_line': 11, 'end_line': 11, 'method_name': 'test_obj2'}
Appended {'start_line': 13, 'end_line': 15, 'method_name': 'test_obj'}


In [19]:
# Matches a two-step call chain that starts at an identifier, e.g.
# `obj.getInner1().doSomething()`. Captures:
#   @initial_object_name - the identifier the chain starts from
#   @first_method_name   - the method called on that identifier
#   @second_method_name  - the method called on the first call's result
#   @method_invocation   - the outer (second) call node
query_chained_method_invocation = JAVA_LANGUAGE.query(
    """
(
  (method_invocation
    object: (method_invocation
      object: (identifier) @initial_object_name
      name: (identifier) @first_method_name
      arguments: (argument_list)
    )
    name: (identifier) @second_method_name
    arguments: (argument_list)
  ) @method_invocation
)
"""
)

nested_matches = query_chained_method_invocation.matches(tree.root_node)


# Record both methods of each two-step chain on the variable the chain starts
# from. The pattern index is bound to `_` so the builtin `id` is not shadowed.
for _, nodes in nested_matches:
    variable_record = defined_variables.get(get_text(nodes, "initial_object_name"))
    if variable_record is None:
        # The chain starts from an identifier that is not a tracked variable;
        # skip it rather than fail with a KeyError, as the direct-call cell does.
        continue

    # Both calls are recorded with the span of the outer invocation.
    start_line = nodes["method_invocation"].start_point[0]
    end_line = nodes["method_invocation"].end_point[0]

    for capture_name in ("first_method_name", "second_method_name"):
        variable_record["calls"].append(
            {
                "start_line": start_line,
                "end_line": end_line,
                "method_name": get_text(nodes, capture_name),
            }
        )


#   for key in nodes:
#     print(key, nodes[key].text)

In [20]:
# Display the per-variable records, now including the calls appended by the
# direct-call and chained-call cells.
defined_variables

{'obj': {'declarator_value_type': 'ChainClass',
  'type': 'ChainClass',
  'calls': [{'start_line': 13, 'end_line': 15, 'method_name': 'test_obj'},
   {'start_line': 6, 'end_line': 6, 'method_name': 'getInner1'},
   {'start_line': 6, 'end_line': 6, 'method_name': 'doSomething'},
   {'start_line': 7, 'end_line': 8, 'method_name': 'getInner2'},
   {'start_line': 7, 'end_line': 8, 'method_name': 'doSomething2'}],
  'definition_line_start': 4,
  'definition_line_end': 5},
 'obj2': {'declarator_value_type': 'TestClass2',
  'type': 'TestClass2',
  'calls': [{'start_line': 11, 'end_line': 11, 'method_name': 'test_obj2'}],
  'definition_line_start': 10,
  'definition_line_end': 10},
 'staticReturn': {'declarator_value_type': 'StaticClass',
  'type': 'todo',
  'calls': [{'start_line': 19, 'end_line': 19, 'method_name': 'countToOne'}],
  'definition_line_start': 19,
  'definition_line_end': 19}}

In [1]:
# Matches any method invocation that has an `object` field, whatever kind of
# node that object is. Captures @object, @method, @arguments and
# @method_invocation (the whole call node).
# Requires JAVA_LANGUAGE from the parser-setup cell; the NameError recorded
# for this cell is what happens when it runs before that cell.
method_invocation_query = JAVA_LANGUAGE.query(
    """
    (
      (method_invocation
        object: (_) @object
        name: (identifier) @method
        arguments: (argument_list) @arguments
      ) @method_invocation

    )

    """
)


def get_coords(node, key=None):
    """Return the ``(start_point, end_point)`` pair of a node.

    When ``key`` is given, ``node`` is indexed with it first and the span of
    the resulting object is returned instead.
    """
    target = node if key is None else node[key]
    return (target.start_point, target.end_point)


def process_method_invocations(node):
    """Collect every method invocation found in ``node``'s subtree.

    Args:
        node: The tree-sitter node to search, typically the tree's root node.

    Returns:
        A list of dicts, one per invocation in the order the query reports
        them, each with the keys ``object`` (source text of the receiver),
        ``method_name``, ``start_line`` and ``end_line`` (row components of
        the invocation's start and end points). Each invocation appears once.
    """
    method_invocations = []
    seen_spans = set()

    # The query is evaluated over the whole subtree, so one pass already
    # reports nested invocations. Re-querying each captured invocation would
    # match that invocation and its inner calls again and duplicate them.
    for _, capture in method_invocation_query.matches(node):
        span = get_coords(capture, "method_invocation")
        # Guard against the same invocation being reported more than once.
        if span in seen_spans:
            continue
        seen_spans.add(span)

        invocation_node = capture["method_invocation"]
        method_invocations.append(
            {
                "object": get_text(capture, "object"),
                "method_name": get_text(capture, "method"),
                "start_line": invocation_node.start_point[0],
                "end_line": invocation_node.end_point[0],
            }
        )

    return method_invocations


# Walk the whole tree once and list what was found.
all_method_invocations = process_method_invocations(tree.root_node)

print(all_method_invocations)

# Group the invocations by the source text of their receiver.
defined_variables_in_tree = {}

for invocation in all_method_invocations:
    receiver_entry = defined_variables_in_tree.setdefault(
        invocation["object"], {"calls": []}
    )
    receiver_entry["calls"].append(
        {
            field: invocation[field]
            for field in ("method_name", "start_line", "end_line")
        }
    )

# Report each receiver followed by the calls made on it.
for var_name, details in defined_variables_in_tree.items():
    print(f"Variable: {var_name}")
    for call in details["calls"]:
        print(
            f"  Method: {call['method_name']} (Lines: {call['start_line']}-{call['end_line']})"
        )

# %%
print(defined_variables_in_tree)

NameError: name 'JAVA_LANGUAGE' is not defined

In [None]:
# Online Query: https://github.com/paul-gauthier/aider/blob/main/aider/queries/tree-sitter-java-tags.scm https://github.com/paul-gauthier/aider/blob/main/LICENSE.txt
# The text below is tree-sitter query syntax, not Python. It is kept in a
# string so that this code cell no longer fails with a SyntaxError when the
# notebook is run top to bottom.
JAVA_TAGS_QUERY = """
(class_declaration
  name: (identifier) @name.definition.class) @definition.class

(method_declaration
  name: (identifier) @name.definition.method) @definition.method

(method_invocation
  name: (identifier) @name.reference.call
  arguments: (argument_list) @reference.call)

(interface_declaration
  name: (identifier) @name.definition.interface) @definition.interface

(type_list
  (type_identifier) @name.reference.implementation) @reference.implementation

(object_creation_expression
  type: (type_identifier) @name.reference.class) @reference.class

(superclass (type_identifier) @name.reference.class) @reference.class
"""