<a href="https://colab.research.google.com/github/MounikaNallamothu11/program_analysis/blob/AST_Static_branch/Static_Analysis_using_AST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import re

class ASTNode:
    def __init__(self, type, name, children=None):
        self.type = type  # e.g., 'class', 'method', 'call'
        self.name = name
        self.children = children if children else []

    def add_child(self, child):
        self.children.append(child)

    def __repr__(self, level=0):
        ret = "\t" * level + f"{self.type}: {self.name}\n"
        for child in self.children:
            ret += child.__repr__(level + 1)
        return ret


def parse_java_code(code):
    """
    Parses a limited structure of Java code to build a basic AST.
    Identifies classes, methods, and method calls.
    """
    lines = code.splitlines()
    root = ASTNode("root", "root")

    class_node = None
    method_node = None
    detected_calls = set()
    method_dependency_dict = {}
    for line in lines:
        line = line.strip()

        # Check for class declaration
        class_match = re.match(r'\bclass\s+(\w+)', line)
        if class_match:
            class_name = class_match.group(1)
            class_node = ASTNode("class", class_name)
            root.add_child(class_node)
            continue

        # Check for method declaration
        method_match = re.match(r'\b(public|private|protected|static|final|synchronized|native|abstract)?\s*(\w+)\s+(\w+)\s*\((.*?)\)\s*\{', line)
        if method_match:
            method_name = method_match.group(3)
            method_node = ASTNode("method", method_name)
            detected_calls.clear()  # Reset for a new method
            if class_node:
                class_node.add_child(method_node)
                method_dependency_dict[method_node.name]=[]
                detected_calls.add(method_name)
            continue

        # Check for object method calls (e.g., account.getBalance())
        call_match = re.search(r'\b(\w+)\.(\w+)\s*\((.*?)\)', line)
        if call_match and method_node:
            method_name = call_match.group(2)  # Extract method name only
            if method_name not in detected_calls:  # Check for duplicates
                #print(f"Detected method call: {method_name}")
                call_node = ASTNode("call", method_name)
                method_node.add_child(call_node)
                method_dependency_dict[method_node.name].append(call_node.name)  # Mark as detected

        standalone_call_match = re.search(r'\b(\w+)\s*\((.*?)\)\s*;', line)
        if standalone_call_match and method_node:
            method_name = standalone_call_match.group(1)
            if method_name not in detected_calls:  # Check for duplicates
                #print(f"Detected standalone method call: {method_name}")
                call_node = ASTNode("call", method_name)
                method_node.add_child(call_node)
                detected_calls.add(method_name)
                method_dependency_dict[method_node.name].append(call_node.name)

    #print(method_dependency_dict)
    return root


def find_test_methods(test_ast, changed_methods, modified_ast):
    """
    Identifies unit tests that call the changed methods.
    """
    relevant_tests = []
    #print(changed_methods)
    #print(test_ast)
    for class_node in test_ast.children:
        for method_node in class_node.children:
            if method_node.type == "method" and method_node.name.startswith("test"):
                if method_node.name.replace('test', '').lower() in [method.lower() for method in changed_methods]:
                  relevant_tests.append(method_node.name)
                '''for call_node in method_node.children:
                    #print(call_node.name)
                    if call_node.type == "call" and call_node.name in changed_methods:
                        print(method_node.name)
                        print(call_node.name)
                        relevant_tests.append(method_node.name)
                        break'''

    return relevant_tests


def extract_changed_methods(code_ast):
    """
    Extracts the names of all methods in the given AST.
    """
    changed_methods = []
    for class_node in code_ast.children:
        for method_node in class_node.children:
            if method_node.type == "method":
                changed_methods.append(method_node.name)
                for call_node in method_node.children:
                    if call_node.type == "call":
                        changed_methods.append(call_node.name)
    return changed_methods


def static_analysis(modified_code_path, unit_test_path):
    """
    Main function to perform static analysis and find relevant unit tests.
    Reads input from the provided file paths.
    """
    # Read files
    with open(modified_code_path, "r") as modified_file:
        modified_code = modified_file.read()
        #print(modified_code)

    with open(unit_test_path, "r") as test_file:
        unit_test_code = test_file.read()

    # Parse the modified code and unit test code into ASTs
    modified_ast = parse_java_code(modified_code)
    print(modified_ast)
    unit_test_ast = parse_java_code(unit_test_code)
    #print(unit_test_ast)
    # Extract changed methods from the modified code
    changed_methods = extract_changed_methods(modified_ast)
    changed_methods = set(changed_methods)
    changed_methods.discard('println')
    print(changed_methods)
    # Find relevant test methods
    relevant_tests = find_test_methods(unit_test_ast, changed_methods,modified_ast)

    return relevant_tests


if __name__ == "__main__":
    # Input: File paths for the modified code and unit tests
    modified_code_path = "/content/code.java"  # Replace with your file path
    unit_test_path = "/content/tests.java"  # Replace with your file path

    # Perform static analysis
    relevant_tests = static_analysis(modified_code_path, unit_test_path)
    print("Relevant Unit Tests:", relevant_tests)


root: root
	class: BankAccount
		method: sumPositiveBalances
			call: getBalance
			call: getBalance
			call: getBalance
			call: println
			call: println
		method: deposit
			call: add
		method: add

{'deposit', 'add', 'getBalance', 'sumPositiveBalances'}
Relevant Unit Tests: ['testDeposit', 'testAdd', 'testGetBalance', 'testSumPositiveBalances']
