<a href="https://colab.research.google.com/github/MounikaNallamothu11/program_analysis/blob/AST_Static_branch/Static_Analysis_using_AST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import re

class ASTNode:
    def __init__(self, type, name, children=None):
        self.type = type  # e.g., 'class', 'method', 'call'
        self.name = name
        self.children = children if children else []

    def add_child(self, child):
        self.children.append(child)

    def __repr__(self, level=0):
        ret = "\t" * level + f"{self.type}: {self.name}\n"
        for child in self.children:
            ret += child.__repr__(level + 1)
        return ret


def parse_java_code(code):
    """
    Parses a limited structure of Java code to build a basic AST.
    Identifies classes, methods, and method calls.
    """
    lines = code.splitlines()
    root = ASTNode("root", "root")

    class_node = None
    method_node = None
    detected_calls = set()
    method_dependency_dict = {}
    for line in lines:
        line = line.strip()

        # Check for class declaration
        class_match = re.match(r'(?:\b(public|private|protected)\s+)?\bclass\s+(\w+)', line)
        if class_match:
            class_name = class_match.group(2)
            class_node = ASTNode("class", class_name)
            root.add_child(class_node)
            continue

        # Check for method declaration
        method_match = re.match(r'\b(public|private|protected|static|final|synchronized|native|abstract)?\s*(\w+)\s+(\w+)\s*\((.*?)\)\s*\{', line)
        if method_match:
            method_name = method_match.group(3)+'('+method_match.group(4)+')'
            method_node = ASTNode("method", method_name)
            detected_calls.clear()  # Reset for a new method
            if class_node:
                class_node.add_child(method_node)
                method_dependency_dict[method_node.name]=[]
                detected_calls.add(method_name)
            continue

        # Check for object method calls (e.g., account.getBalance())
        call_match = re.search(r'\b(\w+)\.(\w+)\s*\((.*?)\)', line)
        if call_match and method_node:
            method_name = call_match.group(2)  # Extract method name only
            if method_name not in detected_calls:  # Check for duplicates
                #print(f"Detected method call: {method_name}")
                call_node = ASTNode("call", method_name)
                method_node.add_child(call_node)
                method_dependency_dict[method_node.name].append(call_node.name)  # Mark as detected

        standalone_call_match = re.search(r'\b(\w+)\s*\((.*?)\)\s*;', line)
        if standalone_call_match and method_node:
            method_name = standalone_call_match.group(1)
            if method_name not in detected_calls:  # Check for duplicates
                #print(f"Detected standalone method call: {method_name}")
                call_node = ASTNode("call", method_name)
                method_node.add_child(call_node)
                detected_calls.add(method_name)
                method_dependency_dict[method_node.name].append(call_node.name)

    #print(method_dependency_dict)
    return root


"""def find_test_methods(test_ast, changed_methods, modified_ast):

    #Identifies unit tests that call the changed methods.

    relevant_tests = []
    #print(changed_methods)
    #print(test_ast)
    for class_node in test_ast.children:
        for method_node in class_node.children:
            if method_node.type == "method" and method_node.name.startswith("test"):
                if method_node.name.replace('test', '').lower() in [method.lower() for method in changed_methods]:
                  relevant_tests.append(method_node.name)
                '''for call_node in method_node.children:
                    #print(call_node.name)
                    if call_node.type == "call" and call_node.name in changed_methods:
                        print(method_node.name)
                        print(call_node.name)
                        relevant_tests.append(method_node.name)
                        break'''

    return relevant_tests"""


def extract_changed_methods(code_ast):
    """
    Extracts the names of all methods in the given AST.
    """
    changed_methods = []
    for class_node in code_ast.children:
        for method_node in class_node.children:
            if method_node.type == "method":
                changed_methods.append(method_node.name)
                for call_node in method_node.children:
                    if call_node.type == "call":
                        changed_methods.append(call_node.name)
    return changed_methods


def static_analysis(modified_code_path, unit_test_path):
    """
    Main function to perform static analysis and find relevant unit tests.
    Reads input from the provided file paths.
    """
    # Read files
    with open(modified_code_path, "r") as modified_file:
        modified_code = modified_file.read()
        #print(modified_code)

    with open(unit_test_path, "r") as test_file:
        unit_test_code = test_file.read()

    # Parse the modified code and unit test code into ASTs
    modified_ast = parse_java_code(modified_code)
    print(modified_ast)
    unit_test_ast = parse_java_code(unit_test_code)
    #print(unit_test_ast)
    # Extract changed methods from the modified code
    changed_methods = extract_changed_methods(modified_ast)
    changed_methods = set(changed_methods)
    changed_methods.discard('println')
    print(changed_methods)
    # Find relevant test methods
    #relevant_tests = find_test_methods(unit_test_ast, changed_methods,modified_ast)

    return changed_methods


if __name__ == "__main__":
    # Input: File paths for the modified code and unit tests
    modified_code_path = "/content/new_code.java"  # Replace with your file path
    unit_test_path = "/content/tests.java"  # Replace with your file path

    # Perform static analysis
    relevant_tests = static_analysis(modified_code_path, unit_test_path)
    #print("Relevant Unit Tests:", relevant_tests)


root: root
	class: BankAccount
		method: BankAccount(long accountNr, double initialBalance)
		method: deposit(double amount)
			call: add
		method: add(double amount)
		method: transfer(double amount, BankAccount destinationAccount)
			call: withdraw
		method: calculateInterest(int years)
			call: ArithmeticException
			call: getBalance
			call: getBalance
			call: getBalance
			call: println
			call: println
		method: test()
		method: test2()
		method: getBalance()
		method: withdraw(double amount)
		method: getAccountNumber()

{'ArithmeticException', 'getBalance', 'test2()', 'getAccountNumber()', 'add(double amount)', 'transfer(double amount, BankAccount destinationAccount)', 'calculateInterest(int years)', 'getBalance()', 'withdraw(double amount)', 'withdraw', 'add', 'BankAccount(long accountNr, double initialBalance)', 'test()', 'deposit(double amount)'}


In [4]:
import os
import re

class ASTNode:
    def __init__(self, type, name, children=None):
        self.type = type  # e.g., 'class', 'method', 'call'
        self.name = name
        self.children = children if children else []

    def add_child(self, child):
        self.children.append(child)

    def __repr__(self, level=0):
        ret = "\t" * level + f"{self.type}: {self.name}\n"
        for child in self.children:
            ret += child.__repr__(level + 1)
        return ret


def parse_java_code(code, symbol_table):
    """
    Parses a limited structure of Java code to build a basic AST.
    Identifies classes, methods, and method calls.
    """
    lines = code.splitlines()
    root = ASTNode("root", "root")

    class_node = None
    method_node = None
    for line in lines:
        line = line.strip()

        # Check for class declaration
        class_match = re.match(r'(?:\b(public|private|protected)\s+)?\bclass\s+(\w+)', line)
        if class_match:
            class_name = class_match.group(2)
            class_node = ASTNode("class", class_name)
            root.add_child(class_node)
            symbol_table[class_name] = class_node  # Register class in symbol table
            continue

        # Check for method declaration
        method_match = re.match(r'\b(public|private|protected|static|final|synchronized|native|abstract)?\s*(\w+)\s+(\w+)\s*\((.*?)\)\s*\{', line)
        if method_match:
            method_name = method_match.group(3) + '(' + method_match.group(4) + ')'
            method_node = ASTNode("method", method_name)
            if class_node:
                class_node.add_child(method_node)
                symbol_table[f"{class_node.name}.{method_node.name}"] = method_node  # Register method
            continue

        # Check for object method calls (e.g., account.getBalance())
        call_match = re.search(r'\b(\w+)\.(\w+)\s*\((.*?)\)', line)
        if call_match and method_node:
            object_name = call_match.group(1)
            method_name = call_match.group(2)
            call_node = ASTNode("call", f"{object_name}.{method_name}")
            method_node.add_child(call_node)

        # Check for standalone method calls (e.g., print())
        standalone_call_match = re.search(r'\b(\w+)\s*\((.*?)\)\s*;', line)
        if standalone_call_match and method_node:
            method_name = standalone_call_match.group(1)
            call_node = ASTNode("call", method_name)
            method_node.add_child(call_node)

    return root


def build_ast_from_folder(folder_path):
    """
    Parses all Java files in a folder to build a unified AST.
    Handles inter-class method calls using a symbol table.
    """
    symbol_table = {}  # Map class and method names to AST nodes
    folder_ast = ASTNode("root", folder_path.split("/")[-1])

    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".java"):
                file_path = os.path.join(root, file)
                with open(file_path, "r") as java_file:
                    code = java_file.read()
                    file_ast = parse_java_code(code, symbol_table)
                    folder_ast.add_child(file_ast)

    # Resolve inter-class method calls
    for class_node in folder_ast.children:
        for method_node in class_node.children:
            if method_node.type == "method":
                for call_node in method_node.children:
                    if call_node.type == "call" and "." in call_node.name:
                        object_name, method_name = call_node.name.split(".", 1)
                        for class_key in symbol_table.keys():
                            if class_key.startswith(object_name):
                                full_method_name = f"{class_key}.{method_name}"
                                if full_method_name in symbol_table:
                                    call_node.name = full_method_name  # Update with fully qualified name

    return folder_ast


if __name__ == "__main__":
    folder_path = "/content/outer"  # Replace with your folder path

    # Build the AST for the folder
    folder_ast = build_ast_from_folder(folder_path)
    print(folder_ast)


root: outer
	root: root
		class: Calculator
			method: add(int a, int b)
			method: subtract(int a, int b)
			method: multiply(int a, int b)
			method: divide(int a, int b)
				call: out.println
				call: println
			method: square(int a)
	root: root
		class: BankAccount
			method: BankAccount(long accountNr, double initialBalance)
			method: deposit(double amount)
				call: add
			method: add(double amount)
			method: transfer(double amount, BankAccount destinationAccount)
				call: withdraw
			method: calculateInterest(int years)
				call: ArithmeticException
				call: ArithmeticException
				call: account.getBalance
				call: account.getBalance
				call: getBalance
				call: out.println
				call: println
			method: test()
			method: test2()
			method: getBalance()
			method: withdraw(double amount)
			method: getAccountNumber()
	root: root
		class: MathOperations
			method: MathOperations()
				call: Calculator
			method: sumOfSquares(int a, int b)
				call: calculator.square
				cal

In [98]:
import re

class DependencyTracker:

    class ASTNode:
        def __init__(self, type, name, children=None):
            self.type = type  # e.g., 'class', 'method', 'call'
            self.name = name
            self.children = children if children else []

        def add_child(self, child):
            self.children.append(child)

        def __repr__(self, level=0):
            ret = "\t" * level + f"{self.type}: {self.name}\n"
            for child in self.children:
                ret += child.__repr__(level + 1)
            return ret


    def __init__(self, modified_dir) -> None:
        """
        Intializes the dependency tracker with the paths to the modified Java code and test files
        """
        self.user_defined_methods = set()
        self.folder_ast=self.build_modified_dir_ast(modified_dir)
        #self.modified_code = self.read_java_file(modified_java_file_path)
        #self.modified_tests = self.read_java_file(modified_tests_file_path)



    def read_java_file(self, file_path: str) -> str:
        with open(file_path, "r") as file:
            return file.read()


    def build_modified_dir_ast(self,folder_path):
      # Parse the modified code into an AST
        symbol_table = {}  # Map class and method names to AST nodes
        self.folder_ast = ASTNode("root", folder_path.split("/")[-1])

        for root, _, files in os.walk(folder_path):
            for file in files:
                if file.endswith(".java"):
                    file_path = os.path.join(root, file)
                    with open(file_path, "r") as java_file:
                        code = java_file.read()
                        file_ast = self.parse_java_code(code, symbol_table)
                        self.folder_ast.add_child(file_ast)

        return self.folder_ast

    def provide_all_caller_methods(self, methods: list[str]):

        #modified_ast = folder_ast #self.parse_java_code(self.modified_code)
        #print(self.folder_ast)
        # Filter out non-user-defined methods
        #print('before',self.folder_ast)
        self.folder_ast = self.filter_non_user_defined_methods(self.folder_ast)
        #print('after',self.folder_ast)
        #print(methods)
        # Extract all caller methods
        all_caller_methods = self.extract_callers(methods, self.folder_ast)

        return all_caller_methods


    def filter_non_user_defined_methods(self, root: ASTNode):
        """
        Filters out non-user-defined methods from the AST.
        """
        def dfs(node):
            if node.type == "call" and node.name not in self.user_defined_methods:
                return None
            else:
                node.children = [child for child in node.children if dfs(child)]
                return node

        return dfs(root)


    def parse_java_code(self, code, symbol_table):
      """
      Parses a limited structure of Java code to build a basic AST.
      Identifies classes, methods, and method calls.
      """
      symbol_table={}
      lines = code.splitlines()
      root = ASTNode("root", "root")

      class_node = None
      method_node = None
      for line in lines:
          line = line.strip()

          # Check for class declaration
          class_match = re.match(r'(?:\b(public|private|protected)\s+)?\bclass\s+(\w+)', line)
          if class_match:
              class_name = class_match.group(2)
              class_node = ASTNode("class", class_name)
              root.add_child(class_node)
              symbol_table[class_name] = class_node  # Register class in symbol table
              continue

          # Check for method declaration
          method_match = re.match(r'\b(public|private|protected|static|final|synchronized|native|abstract)?\s*(\w+)\s+(\w+)\s*\((.*?)\)\s*\{', line)
          if method_match:
              method_name = method_match.group(3) #+ '(' + method_match.group(4) + ')'
              method_node = ASTNode("method", method_name)
              if class_node:
                  class_node.add_child(method_node)
                  self.user_defined_methods.add(method_name)
                  symbol_table[f"{class_node.name}.{method_node.name}"] = method_node  # Register method
              continue

          # Check for object method calls (e.g., account.getBalance())
          call_match = re.search(r'\b(\w+)\.(\w+)\s*\((.*?)\)', line)
          if call_match and method_node:
              object_name = call_match.group(1)
              method_name = call_match.group(2)
              call_node = ASTNode("call", f"{object_name}.{method_name}")
              method_node.add_child(call_node)

          # Check for standalone method calls (e.g., print())
          standalone_call_match = re.search(r'\b(\w+)\s*\((.*?)\)\s*;', line)
          if standalone_call_match and method_node:
              method_name = standalone_call_match.group(1)
              call_node = ASTNode("call", method_name)
              method_node.add_child(call_node)

      return root



    def extract_callers(self, target_methods: set[str], code_ast) -> set[str]:
        """
        Extract all methods that call the given target methods.
        Input and output are sets to ensure uniqueness and efficiency.
        Returns callers in the format 'ClassName.methodName'.
        """
        callers = set()

        def find_callers(target_method, target_class):
          for folder_node in code_ast.children:
            if folder_node.type == "root":
                  for class_node in folder_node.children:
                      if class_node.type == "class":
                          current_class = class_node.name
                          for method_node in class_node.children:
                              if method_node.type == "method":
                                  full_method_name = f"{current_class}.{method_node.name}"
                                  for child in method_node.children:
                                      if child.type == "call" and child.name == target_method:
                                          if target_class == current_class or child.full_name.startswith(target_class):
                                              callers.add(full_method_name)

        for method in target_methods:
            class_name, method_name = method.split('.')
            find_callers(method_name, class_name)

        return callers



if __name__ == "__main__":
    # Input: File paths for the modified code and unit tests
    modified_code_path = "BankAccount.java"  # Replace with your file path
    unit_test_path = "BankAccountTest.java"  # Replace with your file path

    # Perform static analysis
    #relevant_tests = static_analysis(modified_code_path, unit_test_path)
    #print("Relevant Unit Tests:", relevant_tests)

In [99]:
    # DependencyTracker is a class that creates an AST of a Java file to track all caller methods of a given list of methods
    dependencyTracker = DependencyTracker('/content/outer/java')
    # Get all caller methods for the directly affected methods
    indirectly_affected_methods = dependencyTracker.provide_all_caller_methods({'MathOperations.multiply'})
    print(indirectly_affected_methods)

{'MathOperations.differenceOfProducts'}
