In [52]:
!pip install javalang graphviz esprima



In [53]:
import os
import ast
import javalang
from graphviz import Digraph

repo_path = "testFiles"
code_dict = {}

#Loades Py and Java Code and copies it in a dict
for filename in os.listdir(repo_path):
    filepath = os.path.join(repo_path, filename)
    if filename.endswith((".py", ".java")):
        with open(filepath, "r") as f:
            code_dict[filename] = f.read()

# Creates CallChian of Python Functions
class PythonCallVisitor(ast.NodeVisitor):
    def __init__(self, filename):
        self.filename = filename
        self.call_chain = {}

    def visit_FunctionDef(self, node):
        func_name = f"{self.filename}:{node.name}"
        calls = []
        for n in ast.walk(node):
            if isinstance(n, ast.Call):
                if isinstance(n.func, ast.Name):
                    calls.append(n.func.id)
                elif isinstance(n.func, ast.Attribute):
                    calls.append(n.func.attr)
        self.call_chain[func_name] = calls
        self.generic_visit(node)

#Creates CallChain of Java Functions by travesrsing the AST
class JavaCallVisitor:
    def __init__(self, filename):
        self.filename = filename
        self.call_chain = {}

    def visit(self, tree):
        for _, node in tree.filter(javalang.tree.MethodDeclaration):
            method_name = f"{self.filename}:{node.name}"
            calls = []
            if node.body:
                for path, subnode in node:
                    if isinstance(subnode, javalang.tree.MethodInvocation):
                        if subnode.qualifier:
                            calls.append(f"{subnode.qualifier}.{subnode.member}")
                        else:
                            calls.append(subnode.member)
            self.call_chain[method_name] = calls

#Creates DFD of Python File by traversing the AST
class PythonDFDVisitor(ast.NodeVisitor):
    def __init__(self, filename):
        self.filename = filename
        self.flows = []
        self.current_func = None

    def visit_FunctionDef(self, node):
        self.current_func = f"{self.filename}:{node.name}"
        for n in ast.walk(node):
            # Capture return values
            if isinstance(n, ast.Return):
                if isinstance(n.value, ast.Name):
                    data_repr = n.value.id
                elif isinstance(n.value, ast.Constant):
                    data_repr = repr(n.value.value)
                elif isinstance(n.value, ast.BinOp):
                    data_repr = ast.unparse(n.value)
                else:
                    data_repr = type(n.value).__name__
                self.flows.append((self.current_func, data_repr, "(returns)"))

            elif isinstance(n, ast.Call) and isinstance(n.func, ast.Name):
                callee = n.func.id
                for arg in n.args:
                    if isinstance(arg, ast.Name):
                        data_repr = arg.id
                    elif isinstance(arg, ast.Constant):
                        data_repr = repr(arg.value)
                    elif isinstance(arg, ast.BinOp):
                        data_repr = ast.unparse(arg)
                    else:
                        data_repr = type(arg).__name__
                    self.flows.append((self.current_func, data_repr, callee))

        self.generic_visit(node)


#Creates DFD of JAVA code by traversing AST
class JavaDFDExtractor:
    def __init__(self, filename):
        self.filename = filename
        self.flows = []

    def extract(self, tree):
        for type_decl in tree.types:
            if isinstance(type_decl, javalang.tree.ClassDeclaration):
                for method in type_decl.methods:
                    func_name = f"{self.filename}:{method.name}"
                    for path, node in method:
                        # Return statements
                        if isinstance(node, javalang.tree.ReturnStatement):
                            if node.expression:
                                if isinstance(node.expression, javalang.tree.MemberReference):
                                    data_repr = node.expression.member
                                elif isinstance(node.expression, javalang.tree.Literal):
                                    data_repr = node.expression.value
                                else:
                                    data_repr = type(node.expression).__name__
                                self.flows.append((func_name, data_repr, "(returns)"))

                        elif isinstance(node, javalang.tree.MethodInvocation):
                            callee = node.member
                            if node.arguments:
                                for arg in node.arguments:
                                    if isinstance(arg, javalang.tree.MemberReference):
                                        data_repr = arg.member
                                    elif isinstance(arg, javalang.tree.Literal):
                                        data_repr = arg.value
                                    else:
                                        data_repr = str(arg)
                                    self.flows.append((func_name, data_repr, callee))
                            else:
                                self.flows.append((func_name, "()", callee))

In [54]:
repo_call_chain = {}
repo_flows = []

#Creates Object of above classes and traverse the ast
for filename, code in code_dict.items():
      if filename.endswith(".py"):
          tree = ast.parse(code)
          call_visitor = PythonCallVisitor(filename)
          call_visitor.visit(tree)
          repo_call_chain.update(call_visitor.call_chain)
          dfd_visitor = PythonDFDVisitor(filename)
          dfd_visitor.visit(tree)
          repo_flows.extend(dfd_visitor.flows)

      elif filename.endswith(".java"):
          tree = javalang.parse.parse(code)
          call_visitor = JavaCallVisitor(filename)
          call_visitor.visit(tree)
          repo_call_chain.update(call_visitor.call_chain)
          dfd_extractor = JavaDFDExtractor(filename)
          dfd_extractor.extract(tree)
          repo_flows.extend(dfd_extractor.flows)

In [55]:
print("Data Flow (DFD):\n")
for src, data, dst in repo_flows:
    print(f"{src} --[{data}]--> {dst}")

#Creates edges from the repo_call_chain and renders the call-chain graph
dot_calls = Digraph(format="png", engine="sfdp")
dot_calls.attr(size="100,100", dpi="200")
for func, calls in repo_call_chain.items():
    func_name = func.split(":")[1]
    for called_func in calls:
        dot_calls.edge(func_name, called_func)
call_path = "call_chain"
dot_calls.render(call_path, cleanup=True)

#Creates edges from repo_flows and renders the DFD
dot_dfd = Digraph(format="png", engine="sfdp")
dot_dfd.attr(size="100,100", dpi = "200")
for src, data, dst in repo_flows:
    src_name = src.split(":")[1]
    dst_name = dst.split(":")[1] if ":" in dst else dst
    label = f"{data}"
    dot_dfd.edge(src_name, dst_name, label=label)
dfd_path = "DFD"
dot_dfd.render(dfd_path, cleanup=True)

Data Flow (DFD):

SQLInjectionActivity.java:onCreate --["sqli"]--> openOrCreateDatabase
SQLInjectionActivity.java:onCreate --[MODE_PRIVATE]--> openOrCreateDatabase
SQLInjectionActivity.java:onCreate --[null]--> openOrCreateDatabase
SQLInjectionActivity.java:onCreate --["DROP TABLE IF EXISTS sqliuser;"]--> execSQL
SQLInjectionActivity.java:onCreate --["CREATE TABLE IF NOT EXISTS sqliuser(user VARCHAR, password VARCHAR, credit_card VARCHAR);"]--> execSQL
SQLInjectionActivity.java:onCreate --["INSERT INTO sqliuser VALUES ('admin', 'passwd123', '1234567812345678');"]--> execSQL
SQLInjectionActivity.java:onCreate --["INSERT INTO sqliuser VALUES ('diva', 'p@ssword', '1111222233334444');"]--> execSQL
SQLInjectionActivity.java:onCreate --["INSERT INTO sqliuser VALUES ('john', 'password123', '5555666677778888');"]--> execSQL
SQLInjectionActivity.java:onCreate --["Diva-sqli"]--> d
SQLInjectionActivity.java:onCreate --[BinaryOperation(operandl=Literal(postfix_operators=[], prefix_operators=[], qu

'DFD.png'

In [56]:
import re

# Extract method/function code based on keywords
def extract_function_code(code, func_name, language):
    if language == "python":
        pattern = rf"def\s+{re.escape(func_name)}\s*\([^)]*\)\s*:[\s\S]*?(?=\ndef\s|\nclass\s|$)"
    elif language == "java":
        pattern = rf"\b{re.escape(func_name)}\s*\([^)]*\)\s*\{{[\s\S]*?\}}"

    match = re.search(pattern, code)
    return match.group(0).strip() if match else ""

In [57]:
#Determines the level of abstraction of each function in call chain based on the incoming and outgoing nodes
def analyze_call_layers(repo_call_chain):
    incoming_count = {func: 0 for func in repo_call_chain}
    outgoing_count = {func: len(calls) for func, calls in repo_call_chain.items()}

    for func, calls in repo_call_chain.items():
        for called in calls:
            for target_func in repo_call_chain:
                if target_func.endswith(f":{called}"):
                    incoming_count[target_func] += 1

    layer_map = {}
    for func in repo_call_chain:
        incoming = incoming_count[func]
        outgoing = outgoing_count[func]

        if incoming == 0 and outgoing > 0:
            layer_map[func] = "High-level (Application)"
        elif incoming > 0 and outgoing > 0:
            layer_map[func] = "Mid-level (Logic)"
        elif incoming > 0 and outgoing == 0:
            layer_map[func] = "Low-level (Utility)"
        else:
            layer_map[func] = "Isolated/Helper"

    return layer_map


In [58]:
#Extracts arguments passed from function1 to function2 in py using AST
def arguments_from_python(code: str, caller_name: str, callee_name: str):

    tree = ast.parse(code)
    class CallVisitor(ast.NodeVisitor):
        def __init__(self):
            self.args_found = []

        def visit_FunctionDef(self, node):
            if node.name == caller_name:
                for n in ast.walk(node):
                    if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == callee_name:
                        arg_values = []
                        for arg in n.args:
                            if isinstance(arg, ast.Constant):
                                arg_values.append(repr(arg.value))
                            elif isinstance(arg, ast.Name):
                                arg_values.append(arg.id)
                            else:
                                arg_values.append(ast.unparse(arg))

                        self.args_found.append(arg_values)

    visitor = CallVisitor()
    visitor.visit(tree)
    return visitor.args_found

In [59]:
import javalang

#Extracts arguments passed from method1 to method2 in java from ast
def arguments_from_java(code: str, caller_name: str, callee_name: str):

    tree = javalang.parse.parse(code)


    args_found = []

    for _, node in tree.filter(javalang.tree.MethodDeclaration):
        if node.name == caller_name:
            for path, call in node.filter(javalang.tree.MethodInvocation):
                if call.member == callee_name:
                    arg_values = []
                    for arg in call.arguments:
                        if isinstance(arg, javalang.tree.MemberReference):
                            arg_values.append(arg.member)
                        elif isinstance(arg, javalang.tree.Literal):
                            arg_values.append(arg.value)
                        else:
                            arg_values.append(str(arg))
                    args_found.append(arg_values)
    return args_found


In [60]:
#Finds the type of mapping using keywords
def find_mapping_type(lang1, lang2, code1, code2):
    code1 = code1.lower()
    code2 = code2.lower()

    if lang1 == lang2:
        return "Direct"

    jni_indicators = [
        "native", "system.loadlibrary", "jnienv", "jni_onload",
        "jniexport", "jnicall", "findclass", "getmethodid"
    ]
    if ("java" in [lang1, lang2]) and (
        any(indicator in code1 for indicator in jni_indicators)
        or any(indicator in code2 for indicator in jni_indicators)
    ):
        return "JNI Mapping"

In [61]:
import pandas as pd
layer_map = analyze_call_layers(repo_call_chain)
records = []

#Uses edges in the call chain to generate the
for func1_full, called_funcs in repo_call_chain.items():
    file1, func1 = func1_full.split(":")
    code1 = code_dict[file1]
    lang1 = "python" if file1.endswith(".py") else "java"
    func1_code = extract_function_code(code1, func1, lang1)

    for func2 in called_funcs:
        file2 = next((fname.split(":")[0] for fname in repo_call_chain if fname.endswith(":" + func2)), "unknown")
        code2 = code_dict.get(file2, "")
        lang2 = "python" if file2.endswith(".py") else "java"
        func2_code = extract_function_code(code2, func2, lang2)

        if lang1 == "python":
            args_list = arguments_from_python(code1, func1, func2)
        else:
            args_list = arguments_from_java(code1, func1, func2)

        records.append({
            "function1_name": func1,
            "function2_name": func2,
            "file_name_function1": file1,
            "file_name_function2": file2,
            "parameter_passed": args_list,
            "level_of_abstraction_function1": layer_map.get(func1_full, "Unknown"),
            "level_of_abstraction_function2": layer_map.get(
                next((fname for fname in repo_call_chain if fname.endswith(":" + func2)), "unknown"),
                "Unknown"
            ),
            "mapping_type" : find_mapping_type(lang1, lang2, func1_code, func2_code),
            "code_function1": func1_code,
            "code_function2": func2_code
        })

df = pd.DataFrame(records)
csv_path = "call_chain_links.csv"
df.to_csv(csv_path, index=False, sep=",")

display(df.head())

Unnamed: 0,function1_name,function2_name,file_name_function1,file_name_function2,parameter_passed,level_of_abstraction_function1,level_of_abstraction_function2,mapping_type,code_function1,code_function2
0,onCreate,openOrCreateDatabase,SQLInjectionActivity.java,unknown,"[[""sqli"", MODE_PRIVATE, null]]",Mid-level (Logic),Unknown,Direct,onCreate(Bundle savedInstanceState) {\n ...,
1,onCreate,mDB.execSQL,SQLInjectionActivity.java,unknown,[],Mid-level (Logic),Unknown,Direct,onCreate(Bundle savedInstanceState) {\n ...,
2,onCreate,mDB.execSQL,SQLInjectionActivity.java,unknown,[],Mid-level (Logic),Unknown,Direct,onCreate(Bundle savedInstanceState) {\n ...,
3,onCreate,mDB.execSQL,SQLInjectionActivity.java,unknown,[],Mid-level (Logic),Unknown,Direct,onCreate(Bundle savedInstanceState) {\n ...,
4,onCreate,mDB.execSQL,SQLInjectionActivity.java,unknown,[],Mid-level (Logic),Unknown,Direct,onCreate(Bundle savedInstanceState) {\n ...,
