In [None]:
# Import built-in modules for file handling, code parsing, and AST manipulation.
import os         # For interacting with the operating system (e.g., file paths, directory traversal)
import ast        # For parsing Python source code into an Abstract Syntax Tree (AST)
import jedi       # For static code analysis and extracting definitions from source code
import networkx as nx  # For constructing and managing graphs (nodes and edges)
from neo4j import GraphDatabase  # For connecting and pushing data to a Neo4j graph database
from pathlib import Path
import logging

logger = logging.getLogger(__name__)

# ---------- Helper functions ----------

def find_python_files(root: str):
    """
    Yield every ``.py`` file found recursively under ``root``.

    Two kinds of files are skipped:
      - files located inside a directory named 'tests' (case-insensitive)
        anywhere *below* ``root``;
      - package marker files named ``__init__.py``.

    Only the part of each path relative to ``root`` is inspected, so a
    project that itself lives under a directory called 'tests' is still
    scanned.

    Args:
        root: Directory to search.

    Yields:
        pathlib.Path: Path of each matching file, prefixed with ``root``.
    """
    root_path = Path(root)
    for path in root_path.rglob("*.py"):
        if path.name == "__init__.py":
            continue
        # Look only at the directories between root and the file; checking
        # the full path would reject everything when root itself sits under
        # a directory named 'tests'.
        parent_dirs = path.relative_to(root_path).parts[:-1]
        if any(part.lower() == "tests" for part in parent_dirs):
            continue
        yield path



def get_definitions_info(defs, tree, source, file_path):
    """
    Build a dictionary describing every function, method and class in ``defs``.

    Args:
        defs: Iterable of Jedi name objects (as produced by
            ``Script.get_names``). Each must expose ``type``, ``name``,
            ``full_name``, ``line`` and a ``parent()`` method.
        tree: AST of the same source, or None. When None, every ``code``
            entry is an empty string.
        source: Source text the AST was parsed from.
        file_path: Value stored under the ``file`` key of every entry.

    Returns:
        dict mapping ``full_name`` to a dict with the keys ``id``, ``name``,
        ``type`` ('class', 'function' or 'method'), ``file``, ``line``,
        ``code``, ``definition`` (the original name object) and ``parent_id``
        (full name of the enclosing function/class, or None).
    """
    # Index the AST definition nodes once by (name, line) instead of walking
    # the whole tree again for every definition. setdefault keeps the first
    # node met in walk order, matching a first-match linear search.
    nodes_by_position = {}
    if tree is not None:
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                nodes_by_position.setdefault((node.name, node.lineno), node)

    definitions = {}

    for d in defs:
        # Consider only functions, methods and classes that have a usable id.
        if d.type not in ('function', 'class') or not d.full_name:
            continue

        # Resolve the enclosing scope. parent_type / parent_id are reset on
        # every iteration so a failed lookup can never leak the previous
        # definition's parent into this one.
        parent_type = None
        parent_id = None
        try:
            parent = d.parent()
            if parent is not None:
                parent_type = parent.type
                if parent_type in ('function', 'class'):
                    parent_id = parent.full_name
        except Exception as e:
            logging.getLogger(__name__).warning(
                "Error getting parent for %s: %s", d.name, e
            )
            parent_type = None
            parent_id = None

        # A function whose direct parent is a class is reported as a method.
        node_type = "method" if (d.type == "function" and parent_type == "class") else d.type

        # Extract the source code of the definition, if the AST has it.
        code_segment = ""
        ast_node = nodes_by_position.get((d.name, d.line))
        if ast_node is not None:
            code_segment = ast.get_source_segment(source, ast_node) or ""

        definitions[d.full_name] = {
            'id': d.full_name,
            'name': d.name,
            'type': node_type,
            'file': file_path,
            'line': d.line,
            'code': code_segment,
            'definition': d,
            'parent_id': parent_id
        }

    return definitions


def find_call_nodes(tree):
    """
    Collect every ``ast.Call`` node reachable from ``tree``.

    The nodes are returned in the order ``ast.walk`` visits them.
    """
    return [candidate for candidate in ast.walk(tree) if isinstance(candidate, ast.Call)]

def get_class_context(script, lineno, col_offset):
    """
    Return the nearest enclosing class context for a source position.

    Starts from the context Jedi reports at (lineno, col_offset) and climbs
    through ``parent()`` until a context of type 'class' is reached.
    Returns None when no class encloses the position.
    """
    scope = script.get_context(lineno, col_offset)

    # Climb outwards until we either hit a class or run out of parents.
    while scope and scope.type != 'class':
        scope = scope.parent()

    return scope or None

def get_full_name_of_call(script, call_node):
    """
    Return the fully qualified name of the callable targeted by ``call_node``.

    Two strategies are used:
      - ``self.method(...)``: the name is built from the enclosing class
        reported by Jedi plus the attribute name, i.e. ``<class>.<method>``.
      - any other call: Jedi's ``goto`` (following imports) is asked to
        resolve the name being called, and the first result is used.

    Args:
        script: Jedi script (anything exposing ``goto`` and ``get_context``).
        call_node: The ``ast.Call`` node to resolve.

    Returns:
        The full name as a string, or None when it cannot be resolved.
    """
    func = call_node.func

    # Handle self method calls
    if (
        isinstance(func, ast.Attribute)
        and isinstance(func.value, ast.Name)
        and func.value.id == 'self'
    ):
        class_context = get_class_context(script, call_node.lineno, call_node.col_offset)
        # Require a real class name so we never emit an id like "None.method".
        if class_context and class_context.full_name:
            return f"{class_context.full_name}.{func.attr}"
        return None

    # Normal calls: point Jedi at the name that is actually being called.
    # A Call node starts where its whole expression starts, so for
    # `pkg.mod.func()` the call's own column points at `pkg`; the called
    # name is the LAST identifier of the attribute chain.
    line, column = call_node.lineno, call_node.col_offset
    if isinstance(func, ast.Attribute):
        end_line = getattr(func, "end_lineno", None)
        end_col = getattr(func, "end_col_offset", None)
        if end_line is not None and end_col:
            # end_col_offset is exclusive; step back onto the attribute name.
            line, column = end_line, end_col - 1
    elif isinstance(func, ast.Name):
        line, column = func.lineno, func.col_offset
    # NOTE(review): ast offsets are UTF-8 byte offsets; on lines containing
    # non-ASCII text before the call they may not match Jedi's columns.

    definitions = script.goto(line, column, follow_imports=True)
    if definitions:
        return definitions[0].full_name

    return None


def get_enclosing_definition(script, lineno, col_offset):
    """
    Ask Jedi for the context at (lineno, col_offset) and return it only when
    it is a function or a class.

    The result identifies which definition a call site belongs to. None is
    returned when the context is anything else (e.g. module level), when
    Jedi reports no context, or when the lookup raises.
    """
    try:
        scope = script.get_context(line=lineno, column=col_offset)
        is_definition = bool(scope) and scope.type in ('function', 'class')
    except Exception:
        return None

    return scope if is_definition else None

def get_call_pair_id(call_nodes_list, script):
    """
    Pair every call site with the definition that contains it.

    For each call node, the enclosing function/class is looked up; calls
    with no enclosing definition are dropped. The remaining ones produce a
    dict ``{"caller_id": ..., "candidate_id": ...}`` where ``candidate_id``
    is the resolved callee name and may be None when resolution fails.
    """
    pairs = []

    for node in call_nodes_list:
        # Definition in whose body the call appears.
        enclosing = get_enclosing_definition(script, node.lineno, node.col_offset)
        if enclosing:
            callee_id = get_full_name_of_call(script, node)
            pairs.append({"caller_id": enclosing.full_name, "candidate_id": callee_id})

    return pairs


def parse_source(source: str, filename: str):
    """
    Turn Python source text into an AST.

    ``filename`` is passed to the parser and used in the log message.
    A SyntaxError is logged and then re-raised to the caller.
    """
    try:
        module_ast = ast.parse(source, filename=filename)
    except SyntaxError as e:
        logger.error(f"Syntax error in {filename}: {e}")
        raise
    return module_ast


def get_definitions_calls(source, tree, file_path, project):
    """
    Extract the definitions and the call pairs of one file using Jedi.

    Args:
        source: Source text of the file.
        tree: AST parsed from ``source``, or None. Without an AST no code
            segments and no call pairs can be produced.
        file_path: Path of the file, passed to Jedi and stored on each
            definition.
        project: Jedi project used to resolve names.

    Returns:
        A tuple ``(definitions, call_pairs_list)``:
          - definitions: dict mapping a unique ID to a dictionary containing
              - id: Unique identifier for the definition
              - name: The name of the function, class, or method
              - type: Type ('class', 'function', or 'method')
              - file: File path where the definition is found
              - line: Line number of the definition
              - code: Source code segment for the definition
              - definition: The Jedi name object
              - parent_id: ID of the enclosing function/class, or None
          - call_pairs_list: list of dicts with the keys ``caller_id`` and
            ``candidate_id``, one per call found inside a function or class.
    """
    # Use Jedi to extract all definitions in the source code
    script = jedi.Script(source, path=file_path, project=project)
    defs = script.get_names(all_scopes=True, definitions=True, references=False)

    definitions = get_definitions_info(defs=defs, tree=tree, source=source, file_path=file_path)

    # get_definitions_info accepts a missing AST; mirror that here instead
    # of failing inside ast.walk(None).
    call_ast_nodes = find_call_nodes(tree) if tree is not None else []

    call_pairs_list = get_call_pair_id(call_ast_nodes, script)

    return definitions, call_pairs_list



# ---------- Build the Call Graph ----------
def build_call_graph(project_root):
    """
    Build a call graph for all Python files under the project_root directory.
    The graph is represented as a NetworkX directed graph (DiGraph):
      - Nodes represent definitions (functions, methods, classes) with their properties.
      - Edges carry a 'relation' attribute:
          "nested_in": from a definition to the function/class it is declared
                       inside (e.g., method -> class).
          "call": from a definition to a function/method/class it calls.
    Files that cannot be read or parsed are logged and skipped.
    Returns the constructed graph.

    NOTE: a DiGraph holds one edge per (source, target) pair, so when a
    definition both is nested in and calls the same target, the 'call'
    relation added last replaces 'nested_in' on that edge.
    """
    log = logging.getLogger(__name__)

    # Initialize an empty directed graph
    G = nx.DiGraph()
    all_defs = {}
    all_calls = []
    project = jedi.Project(project_root)

    # Create nodes: Gather all definitions from all Python files
    for file_path in find_python_files(project_root):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                source = f.read()
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError for non-UTF-8 files.
            log.error(f"Error reading {file_path}: {e}")
            continue

        try:
            tree = parse_source(source, str(file_path))
        except SyntaxError:
            # parse_source already logged the error; one unparsable file
            # must not abort the whole graph, same as an unreadable one.
            continue

        defs, calls_list = get_definitions_calls(source=source, tree=tree, file_path=str(file_path), project=project)

        for def_id, info in defs.items():
            all_defs[def_id] = info
            # Add the definition as a node with its attributes
            G.add_node(def_id, **info)

        # Store call relationships until all nodes are created
        all_calls.extend(calls_list)

    # Add nested definition edges. The parent must already be a node:
    # add_edge would otherwise silently create a node with no attributes.
    for def_id, info in all_defs.items():
        parent_id = info.get('parent_id')
        if parent_id and parent_id in G:
            G.add_edge(def_id, parent_id, relation='nested_in')

    # Add call relationship edges, only between definitions found in the project
    for call in all_calls:
        caller_id = call.get("caller_id")
        candidate_id = call.get("candidate_id")
        if caller_id in G.nodes and candidate_id in G.nodes:
            G.add_edge(caller_id, candidate_id, relation='call')

    return G

# ---------- Neo4j Integration ----------

def push_graph_to_neo4j(G, uri="bolt://localhost:7687", user="neo4j", password="Mike_pass"):
    """
    Push the NetworkX graph G to a Neo4j database.

    WARNING: every existing node and relationship in the target database is
    deleted before the graph is written.

    - Each node is created with a label according to its type:
        "Class" for classes, "Method" for methods, and "Function" for anything else.
      The properties id, name, file, line and code are stored on the node.
    - Each edge is created with a relationship type equal to its upper-cased
      'relation' property (e.g., 'CALL' or 'NESTED_IN'); 'CALL' is used when
      the property is missing.

    Args:
        G: Graph to export.
        uri: Bolt URI of the Neo4j server.
        user: Neo4j user name.
        password: Neo4j password.
            NOTE(review): a real-looking credential is hard-coded as the
            default; prefer passing it explicitly from configuration.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            session.run("MATCH (n) DETACH DELETE n")

            for node_id, data in G.nodes(data=True):
                node_type = data.get('type')
                label = ("Class" if node_type == 'class'
                         else "Method" if node_type == 'method'
                         else "Function")

                # The label is interpolated into the query text; it is always
                # one of the three fixed values above. All data goes through
                # query parameters.
                session.run(
                    f"""
                CREATE (n:{label} {{id: $id, name: $name, file: $file, line: $line, code: $code}})
                """,
                    id=node_id,
                    name=data.get('name'),
                    file=data.get('file'),
                    line=data.get('line'),
                    code=data.get('code')
                )

            for source, target, data in G.edges(data=True):
                # The relationship type is interpolated into the query text
                # and comes from the graph's own 'relation' attribute.
                rel_type = data.get('relation', 'call').upper()
                session.run(
                    f"""
                MATCH (a {{id: $source}}), (b {{id: $target}})
                CREATE (a)-[r:{rel_type}]->(b)
                """,
                    source=source,
                    target=target
                )
    finally:
        # Release the connection pool even when a query fails.
        driver.close()

# ---------- Notebook Main Execution ----------
# Root directory of the package to analyse (machine-specific absolute path).
project_root = "C:/Projects/codebase_rag/.venv/Lib/site-packages/skforecast"  

# Build the call graph for all Python files under the project root.
graph = build_call_graph(project_root)

# Print a summary of the graph: number of nodes and edges.
print(f"Graph has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.")

# Push the constructed graph to the Neo4j database using the default
# connection settings of push_graph_to_neo4j.
# NOTE: that function deletes all existing nodes in the database first.
push_graph_to_neo4j(graph)


Graph has 375 nodes and 1906 edges.
