In [4]:
from tree_sitter import Language, Parser
import os

# Compile the tree-sitter Java grammar into a shared library and load it.
# This needs a checkout of the tree-sitter-java repository at
# ../tree-sitter-java/ (the path is relative to the working directory).
Language.build_library(
  "build/language-java.so",
  ["../tree-sitter-java/"]
)

# Grammar handle that the extraction functions in this notebook pass to Parser.set_language().
JAVA_LANGUAGE = Language("build/language-java.so", "java")

def get_node_text(node, lines):
  """Return the text covered by `node`, read from a list of source lines.

  Args:
    node: Object exposing `start_point` and `end_point` as (row, column)
      pairs; the end column is exclusive.
    lines: The source split into lines, indexed by row.

  Returns:
    The covered text. Spans over several rows are joined with '\n'.
  """
  begin, finish = node.start_point, node.end_point
  begin_row, begin_col = begin[0], begin[1]
  finish_row, finish_col = finish[0], finish[1]

  # Multi-row span: tail of the first row, whole middle rows, head of the last row.
  if begin_row != finish_row:
    pieces = [lines[begin_row][begin_col:]]
    pieces += [lines[row] for row in range(begin_row + 1, finish_row)]
    pieces.append(lines[finish_row][:finish_col])
    return '\n'.join(pieces)

  # Single-row span: a plain column slice.
  return lines[begin_row][begin_col:finish_col]

def extract_java_fragment(filename, line):
  """Render a Java file with every method stubbed out except the one at `line`.

  The file is parsed with tree-sitter and the syntax tree is walked
  depth-first. Comments are dropped, the method or constructor whose row
  range contains `line` is emitted verbatim, every other method or
  constructor is reduced to its signature followed by `{ ... }`, and all
  remaining leaf tokens are emitted as-is with a separator chosen by
  `whitespace`.

  Args:
    filename: Path of the Java source file (read as UTF-8).
    line: Row used to select the method to keep. It is compared directly
      with tree-sitter's zero-based row numbers, without adjustment.

  Returns:
    The reassembled source text as a str.
  """
  with open(filename, 'r', encoding='utf8') as file:
    code = file.read()

  # tree-sitter reports start_byte/end_byte as offsets into the encoded
  # buffer it parsed, so keep that buffer and slice it (not the decoded str);
  # otherwise every offset after a non-ASCII character is shifted.
  source = code.encode('utf8')

  parser = Parser()
  parser.set_language(JAVA_LANGUAGE)
  tree = parser.parse(source)

  root_node = tree.root_node

  def node_text(start_byte, end_byte):
    # Decode the source bytes in [start_byte, end_byte).
    return source[start_byte:end_byte].decode('utf8')

  def find_function_node(node, line_number):
    # Depth-first search for the outermost method/constructor whose row
    # range contains line_number; None when there is no such node.
    if node.start_point[0] <= line_number <= node.end_point[0]:
      if node.type == 'method_declaration' or node.type == 'constructor_declaration':
        return node
      for child in node.children:
        result = find_function_node(child, line_number)
        if result is not None:
          return result
    return None

  function_node = find_function_node(root_node, line)
  last_start_point = (0, 0)
  last_end_point = (0, 0)
  # NOTE(review): this print and the three at the top of process_node are
  # debugging traces; they are kept so the function's output is unchanged.
  print(root_node.sexp())

  def process_node(node):
    print(node)
    print(last_start_point)
    print(last_end_point)

    result = ""

    # Skip comments (the Java grammar names them line_comment/block_comment)
    if node.type in ('comment', 'line_comment', 'block_comment'):
      result = ""

    # If it's a function and not the selected one, stub it out
    elif (node.type == 'method_declaration' or node.type == 'constructor_declaration') and node != function_node:
      body = node.child_by_field_name('body')
      if body is None:
        # Abstract/interface methods have no body to elide: keep the
        # declaration as written instead of failing on None.
        result = node_text(node.start_byte, node.end_byte)
      else:
        result = node_text(node.start_byte, body.start_byte) + '{ ... }'

    # The selected function and every leaf token are copied verbatim
    elif node == function_node or not node.children:
      result = node_text(node.start_byte, node.end_byte)
    else:
      for child in node.children:
        result += process_node(child)

    return result + whitespace(node)

  def whitespace(node):
    # Choose the separator appended after `node`: a newline when the node
    # starts on a later row than the previously recorded node ended, a space
    # when it starts at a greater column, otherwise nothing. The node's
    # extent is then recorded for the next call.
    nonlocal last_start_point
    nonlocal last_end_point

    result = ""

    if node.start_point[0] > last_end_point[0]:
      result = "\n"
    elif node.start_point[1] > last_end_point[1]:
      result = " "

    last_start_point = node.start_point
    last_end_point = node.end_point

    return result

  return process_node(root_node)

In [56]:
# Run the extractor on ComplexA.java and print the reassembled source.
# NOTE(review): extract_java_fragment compares this argument directly with
# tree-sitter's zero-based rows — confirm that 1+8 is the intended row.
result = extract_java_fragment('ComplexA.java', 1+8)
print(result)

(program (import_declaration (identifier)) (class_declaration (modifiers) name: (identifier) body: (class_body (line_comment) (method_declaration (modifiers) type: (array_type element: (floating_point_type) dimensions: (dimensions)) name: (identifier) parameters: (formal_parameters (formal_parameter type: (array_type element: (floating_point_type) dimensions: (dimensions)) name: (identifier)) (formal_parameter type: (array_type element: (floating_point_type) dimensions: (dimensions)) name: (identifier))) body: (block (return_statement (array_creation_expression type: (floating_point_type) dimensions: (dimensions) value: (array_initializer (binary_expression left: (array_access array: (identifier) index: (decimal_integer_literal)) right: (array_access array: (identifier) index: (decimal_integer_literal))) (binary_expression left: (array_access array: (identifier) index: (decimal_integer_literal)) right: (array_access array: (identifier) index: (decimal_integer_literal)))))))) (line_comm

In [40]:
import tree_sitter
from tree_sitter import Language, Parser

def extract_code_fragment(filename, line_number):
    """Summarise a Java class, keeping only the method at `line_number` in full.

    Only direct members of top-level class declarations are considered:
    field declarations are copied, the method whose line range contains
    `line_number` is copied in full, and every other method is reduced to
    its signature followed by `{ ... }`. Constructors and nested types are
    not included.

    Args:
        filename: Path of the Java source file (read as UTF-8).
        line_number: 1-based line number selecting the method to keep.

    Returns:
        A str made of the import declarations, a blank line, the header of
        the first top-level class followed by " {", then the fields, the
        selected method (if any) and the method stubs, each on a line
        indented by four spaces, and a closing "}".

    Raises:
        StopIteration: If the file has no top-level class declaration.
    """
    # Initialize parser
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    # Read the Java file
    with open(filename, 'r', encoding='utf8') as file:
        code = file.read()

    # tree-sitter reports byte offsets into the buffer it parsed, so keep the
    # encoded bytes and slice those; slicing the decoded str would be wrong
    # as soon as the file contains a non-ASCII character.
    source = code.encode('utf8')
    tree = parser.parse(source)
    root_node = tree.root_node

    def node_text(start_byte, end_byte):
        # Decode the source bytes in [start_byte, end_byte).
        return source[start_byte:end_byte].decode('utf8')

    # Function to check if a node includes the specified (1-based) line
    def includes_line(node, line):
        start_line = node.start_point[0] + 1
        end_line = node.end_point[0] + 1
        return start_line <= line <= end_line

    # Function to create a stub for methods
    def method_stub(node):
        body = node.child_by_field_name('body')
        if body is None:
            # Abstract/interface methods have no body to elide: keep the
            # declaration as written instead of failing on None.
            return node_text(node.start_byte, node.end_byte).strip()
        return node_text(node.start_byte, body.start_byte) + "{ ... }"

    # Extract relevant parts
    imports, members, methods, target_method = [], [], [], None
    for node in root_node.children:
        if node.type == 'import_declaration':
            imports.append(node_text(node.start_byte, node.end_byte).strip())

        elif node.type == 'class_declaration':
            for child in node.children:
                if child.type != 'class_body':
                    continue

                for member in child.children:
                    if member.type == 'field_declaration':
                        members.append(node_text(member.start_byte, member.end_byte).strip())
                    elif member.type == 'method_declaration':
                        if includes_line(member, line_number):
                            target_method = node_text(member.start_byte, member.end_byte).strip()
                        else:
                            methods.append(method_stub(member))

    # Assemble the output
    class_declaration = next(node for node in root_node.children if node.type == 'class_declaration')
    class_body = class_declaration.child_by_field_name('body')
    class_header = node_text(class_declaration.start_byte, class_body.start_byte).strip() + " {"

    # The selected method, when found, is listed before the stubs.
    shown_methods = [target_method, *methods] if target_method else methods

    output = '\n'.join(imports) + "\n\n" + class_header + "\n"
    output += '    ' + '\n    '.join(members + shown_methods) + "\n}"
    return output

def extract_stubbed_class(filename):
    """Summarise a Java class with every method body replaced by `{ ... }`.

    Only direct members of top-level class declarations are considered:
    field declarations are copied and each method is reduced to its
    signature followed by `{ ... }`. Constructors and nested types are not
    included.

    Args:
        filename: Path of the Java source file (read as UTF-8).

    Returns:
        A str made of the import declarations, a blank line, the header of
        the first top-level class followed by " {", then the fields and the
        method stubs, each on a line indented by four spaces, and a closing
        "}".

    Raises:
        StopIteration: If the file has no top-level class declaration.
    """
    # Initialize parser
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    # Read the Java file
    with open(filename, 'r', encoding='utf8') as file:
        code = file.read()

    # tree-sitter reports byte offsets into the buffer it parsed, so keep the
    # encoded bytes and slice those; slicing the decoded str would be wrong
    # as soon as the file contains a non-ASCII character.
    source = code.encode('utf8')
    tree = parser.parse(source)
    root_node = tree.root_node

    def node_text(start_byte, end_byte):
        # Decode the source bytes in [start_byte, end_byte).
        return source[start_byte:end_byte].decode('utf8')

    # Function to create a stub for methods
    def method_stub(node):
        body = node.child_by_field_name('body')
        if body is None:
            # Abstract/interface methods have no body to elide: keep the
            # declaration as written instead of failing on None.
            return node_text(node.start_byte, node.end_byte).strip()
        return node_text(node.start_byte, body.start_byte) + "{ ... }"

    # Extract relevant parts
    imports, members, methods = [], [], []
    for node in root_node.children:
        if node.type == 'import_declaration':
            imports.append(node_text(node.start_byte, node.end_byte).strip())

        elif node.type == 'class_declaration':
            for child in node.children:
                if child.type != 'class_body':
                    continue

                for member in child.children:
                    if member.type == 'field_declaration':
                        members.append(node_text(member.start_byte, member.end_byte).strip())
                    elif member.type == 'method_declaration':
                        methods.append(method_stub(member))

    # Assemble the output
    class_declaration = next(node for node in root_node.children if node.type == 'class_declaration')
    class_body = class_declaration.child_by_field_name('body')
    class_header = node_text(class_declaration.start_byte, class_body.start_byte).strip() + " {"

    output = '\n'.join(imports) + "\n\n" + class_header + "\n"
    output += '    ' + '\n    '.join(members + methods) + "\n}"
    return output

def deindent_code_block(code_block):
    """Shift a block of code left by the indentation of its first line.

    The width of the first line's leading whitespace is measured, and up to
    that many leading whitespace characters are removed from every line.
    Only whitespace is ever removed: a line indented less than the first
    line loses just its own indentation, never any of its code.

    Args:
        code_block: The text to de-indent; lines are separated by '\n'.

    Returns:
        The de-indented text, with the same number of lines as the input.
    """
    lines = code_block.split('\n')

    # Indentation width of the first line, counted in characters.
    first_line = lines[0]
    indent_width = len(first_line) - len(first_line.lstrip())

    def strip_indent(line):
        # Cap the cut at this line's own indentation so code is never removed.
        own_indent = len(line) - len(line.lstrip())
        return line[min(indent_width, own_indent):]

    return '\n'.join(strip_indent(line) for line in lines)

def extract_method_from_line_number(filename, line_number):
    """Return the source of the Java method that covers `line_number`.

    Only methods declared directly in a top-level class are considered. The
    returned text starts at the beginning of the line on which the method
    starts and is passed through `deindent_code_block`, so the method's
    first line ends up unindented.

    Args:
        filename: Path of the Java source file (read as UTF-8).
        line_number: 1-based line number located inside the wanted method.

    Returns:
        The de-indented method source as a str, or None when no method
        covers `line_number`.
    """
    # Initialize parser
    parser = Parser()
    parser.set_language(JAVA_LANGUAGE)

    # Read the Java file
    with open(filename, 'r', encoding='utf8') as file:
        code = file.read()

    # tree-sitter reports byte offsets (and byte columns) into the buffer it
    # parsed, so keep the encoded bytes and slice those; slicing the decoded
    # str would be wrong as soon as the file contains a non-ASCII character.
    source = code.encode('utf8')
    tree = parser.parse(source)
    root_node = tree.root_node

    # Function to check if a node includes the specified (1-based) line
    def includes_line(node, line):
        start_line = node.start_point[0] + 1
        end_line = node.end_point[0] + 1
        return start_line <= line <= end_line

    target_method = None
    for node in root_node.children:
        if node.type != 'class_declaration':
            continue

        for child in node.children:
            if child.type != 'class_body':
                continue

            for member in child.children:
                if member.type == 'method_declaration' and includes_line(member, line_number):
                    # Back up to the start of the line so the method's own
                    # indentation is part of the text and can be measured
                    # when de-indenting.
                    line_start = member.start_byte - member.start_point[1]
                    target_method = source[line_start:member.end_byte].decode('utf8')

    if target_method is None:
        # No method covers the line: report that instead of failing while
        # de-indenting None.
        return None

    return deindent_code_block(target_method)


In [41]:
# Print the de-indented method that covers line 9 (1-based) of ComplexA.java,
# then the whole class with every method body stubbed out.
print(extract_method_from_line_number('ComplexA.java', 9))
print(extract_stubbed_class('ComplexA.java'))


public static double[] add(double[] a, double[] b) {
    return new double[]{a[0] + b[0], a[1] + b[1]};
}
import x;

public class Complex {
    private int member_variable;
    public static double[] add(double[] a, double[] b) { ... }
    public static double[] sub(double[] a, double[] b) { ... }
    public static double[] mul(double[] a, double[] b) { ... }
    public static double[] div(double[] a, double[] b) { ... }
}
