In [6]:
import os
import ast
import jedi
import networkx as nx
from neo4j import GraphDatabase

In [10]:
import os
import ast
import jedi
import networkx as nx
from neo4j import GraphDatabase

# ---------- Helper functions ----------

def find_python_files(root):
    """Yield all .py files under root, skipping any in a directory named 'tests'."""
    for dirpath, dirnames, filenames in os.walk(root):
        # Remove any directories named 'tests' (case-insensitive) from traversal.
        dirnames[:] = [d for d in dirnames if d.lower() != "tests"]
        for filename in filenames:
            if filename.endswith('.py'):
                yield os.path.join(dirpath, filename)

def get_definitions(file_path):
    """
    Use Jedi to extract definitions from a file.
    Returns a dict mapping a unique id to a dict with properties:
      id, name, type (class, function, or method), file, and line.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return {}
    script = jedi.Script(source, path=file_path)
    defs = script.get_names(all_scopes=True, definitions=True, references=False)
    definitions = {}
    for d in defs:
        if d.type in ('function', 'class'):
            full_name = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
            # For functions, if full_name contains a dot, assume it's a method.
            if d.type == 'function' and '.' in (d.full_name or ''):
                if 'initialize_lags' in d.full_name:
                    print(d.full_name)
                node_type = 'method'
            else:
                node_type = d.type  # 'function' or 'class'
            definitions[full_name] = {
                'id': full_name,
                'name': d.name,
                'type': node_type,
                'file': file_path,
                'line': d.line,
            }
    return definitions

def add_parent_pointers(node, parent=None):
    """Recursively add a 'parent' attribute to AST nodes."""
    for child in ast.iter_child_nodes(node):
        child.parent = node
        add_parent_pointers(child, node)

def get_call_nodes(file_path):
    """
    Parse the AST of a file and return all Call nodes (plus the AST tree).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return [], None
    tree = ast.parse(source, filename=file_path)
    add_parent_pointers(tree)
    calls = []
    class CallVisitor(ast.NodeVisitor):
        def visit_Call(self, node):
            calls.append(node)
            self.generic_visit(node)
    CallVisitor().visit(tree)
    return calls, tree

def get_enclosing_definition(file_path, lineno, col_offset):
    """
    Use Jedi to get the innermost definition that encloses the given position.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        script = jedi.Script(source, path=file_path)
        context = script.get_context(line=lineno, column=col_offset)
        if context and context.type in ('function', 'class'):
            return context
    except Exception:
        return None
    return None

def extract_call_target_name(call_node):
    """
    Extract the target function name from a Call node.
    If the function is called directly (e.g. foo()), return 'foo'.
    If it's an attribute call (e.g. obj.bar()), return 'bar'.
    """
    if isinstance(call_node.func, ast.Name):
        return call_node.func.id
    elif isinstance(call_node.func, ast.Attribute):
        return call_node.func.attr
    return None

# ---------- Build the Call Graph ----------

def build_call_graph(project_root):
    """
    Build a call graph for all Python files under project_root.
    Nodes represent definitions (functions, methods, classes) with properties.
    Edges:
      - "call": a call from one definition to a target (matched by name)
      - "nested": when a definition is declared inside another.
    Returns a NetworkX DiGraph.
    """
    G = nx.DiGraph()
    all_defs = {}

    # Pass 1: Gather all definitions.
    for file_path in find_python_files(project_root):
        defs = get_definitions(file_path)
        for def_id, info in defs.items():
            all_defs[def_id] = info
            G.add_node(def_id, **info)

    # Pass 2: Add nested definition edges.
    for file_path in find_python_files(project_root):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                source = f.read()
            tree = ast.parse(source, filename=file_path)
            add_parent_pointers(tree)
        except Exception:
            continue
        # Use AST walk to find nested definitions.
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                try:
                    script = jedi.Script(source, path=file_path)
                    defs = script.get_names(all_scopes=True, definitions=True, references=False)
                except Exception:
                    continue
                current_id = None
                for d in defs:
                    if d.name == node.name and d.line == node.lineno:
                        current_id = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
                        break
                parent = getattr(node, 'parent', None)
                parent_id = None
                while parent:
                    if isinstance(parent, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                        for pd in defs:
                            if pd.name == parent.name and pd.line == parent.lineno:
                                parent_id = pd.full_name or f"{os.path.basename(file_path)}:{pd.name}@{pd.line}"
                                break
                        if parent_id:
                            break
                    parent = getattr(parent, 'parent', None)
                if current_id and parent_id:
                    if parent_id in G.nodes and current_id in G.nodes:
                        G.add_edge(parent_id, current_id, relation='nested')

    # Pass 3: Add call edges.
    for file_path in find_python_files(project_root):
        calls, _ = get_call_nodes(file_path)
        for call in calls:
            target_name = extract_call_target_name(call)
            if not target_name:
                continue
            caller_context = get_enclosing_definition(file_path, call.lineno, call.col_offset)
            if not caller_context:
                continue
            caller_id = caller_context.full_name or f"{os.path.basename(file_path)}:{caller_context.name}@{caller_context.line}"
            candidate_ids = [did for did, info in all_defs.items() if info['name'] == target_name]
            for candidate_id in candidate_ids:
                if caller_id in G.nodes and candidate_id in G.nodes:
                    G.add_edge(caller_id, candidate_id, relation='call')
    return G

# ---------- Neo4j Integration ----------

def push_graph_to_neo4j(G, uri="bolt://localhost:7687", user="neo4j", password="mike_pass"):
    """
    Push the NetworkX graph G to a Neo4j database.
    Each node is created with a label according to its type:
      - Class, Function, or Method.
    Each edge uses the relationship type from its 'relation' property.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))
    with driver.session() as session:
        # Clear existing nodes (use with caution in production)
        session.run("MATCH (n) DETACH DELETE n")
        # Create nodes.
        for node_id, data in G.nodes(data=True):
            if data.get('type') == 'class':
                label = "Class"
            elif data.get('type') == 'method':
                label = "Method"
            else:
                label = "Function"
            session.run(
                f"""
                CREATE (n:{label} {{id: $id, name: $name, file: $file, line: $line}})
                """,
                id=node_id,
                name=data.get('name'),
                file=data.get('file'),
                line=data.get('line')
            )
        # Create relationships.
        for source, target, data in G.edges(data=True):
            rel_type = data.get('relation', 'call').upper()  # e.g. 'CALL' or 'NESTED'
            session.run(
                f"""
                MATCH (a {{id: $source}}), (b {{id: $target}})
                CREATE (a)-[r:{rel_type}]->(b)
                """,
                source=source,
                target=target
            )
    driver.close()
    print("Graph successfully pushed to Neo4j.")

# ---------- Notebook Main Execution ----------

# Set the project root directory (make sure to exclude tests directories)
project_root = "C:/Projects/codebase_rag/.venv/Lib/site-packages/skforecast"  # <-- Update this!
graph = build_call_graph(project_root)
print(f"Graph has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.")
push_graph_to_neo4j(graph)


.venv.Lib.site-packages.skforecast.model_selection._utils.initialize_lags_grid
.venv.Lib.site-packages.skforecast.utils.utils.initialize_lags
Graph has 375 nodes and 1906 edges.
Graph successfully pushed to Neo4j.


In [1]:
import os
import ast
import jedi
import networkx as nx
from neo4j import GraphDatabase

# ---------- Helper functions ----------

def find_python_files(root):
    """Yield all .py files under root, skipping any in a directory named 'tests'."""
    for dirpath, dirnames, filenames in os.walk(root):
        # Remove any directories named 'tests' (case-insensitive) from traversal.
        dirnames[:] = [d for d in dirnames if d.lower() != "tests"]
        for filename in filenames:
            if filename.endswith('.py'):
                yield os.path.join(dirpath, filename)

def add_parent_pointers(node, parent=None):
    """Recursively add a 'parent' attribute to AST nodes."""
    for child in ast.iter_child_nodes(node):
        child.parent = node
        add_parent_pointers(child, node)

def is_method_in_file(file_path, name, lineno):
    """
    Parse file_path and return True if there is a function definition
    with the given name and lineno that is nested inside a ClassDef.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        tree = ast.parse(source, filename=file_path)
        add_parent_pointers(tree)
    except Exception as e:
        print(f"Error in is_method_in_file for {file_path}: {e}")
        return False
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name and node.lineno == lineno:
            parent = getattr(node, 'parent', None)
            while parent:
                if isinstance(parent, ast.ClassDef):
                    return True
                parent = getattr(parent, 'parent', None)
    return False

def get_definitions(file_path):
    """
    Use Jedi to extract definitions from a file.
    Returns a dict mapping a unique id to a dict with properties:
      id, name, type (class, function, or method), file, and line.
    """
    definitions = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return definitions
    script = jedi.Script(source, path=file_path)
    defs = script.get_names(all_scopes=True, definitions=True, references=False)
    for d in defs:
        if d.type in ('function', 'class'):
            full_name = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
            if d.type == 'function':
                # Check using AST whether this function is nested in a class.
                if is_method_in_file(file_path, d.name, d.line):
                    node_type = 'method'
                else:
                    node_type = 'function'
            else:
                node_type = 'class'
            definitions[full_name] = {
                'id': full_name,
                'name': d.name,
                'type': node_type,
                'file': file_path,
                'line': d.line,
            }
    return definitions

def get_call_nodes(file_path):
    """
    Parse the AST of a file and return all Call nodes (plus the AST tree).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return [], None
    tree = ast.parse(source, filename=file_path)
    add_parent_pointers(tree)
    calls = []
    class CallVisitor(ast.NodeVisitor):
        def visit_Call(self, node):
            calls.append(node)
            self.generic_visit(node)
    CallVisitor().visit(tree)
    return calls, tree

def get_enclosing_definition(file_path, lineno, col_offset):
    """
    Use Jedi to get the innermost definition that encloses the given position.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        script = jedi.Script(source, path=file_path)
        context = script.get_context(line=lineno, column=col_offset)
        if context and context.type in ('function', 'class'):
            return context
    except Exception:
        return None
    return None

def extract_call_target_name(call_node):
    """
    Extract the target function name from a Call node.
    If the function is called directly (e.g. foo()), return 'foo'.
    If it's an attribute call (e.g. obj.bar()), return 'bar'.
    """
    if isinstance(call_node.func, ast.Name):
        return call_node.func.id
    elif isinstance(call_node.func, ast.Attribute):
        return call_node.func.attr
    return None

# ---------- Build the Call Graph ----------

def build_call_graph(project_root):
    """
    Build a call graph for all Python files under project_root.
    Nodes represent definitions (functions, methods, classes) with properties.
    Edges:
      - "call": a call from one definition to a target (matched by name)
      - "nested": when a definition is declared inside another.
    Returns a NetworkX DiGraph.
    """
    G = nx.DiGraph()
    all_defs = {}

    # Pass 1: Gather all definitions.
    for file_path in find_python_files(project_root):
        defs = get_definitions(file_path)
        for def_id, info in defs.items():
            all_defs[def_id] = info
            G.add_node(def_id, **info)

    # Pass 2: Add nested definition edges.
    for file_path in find_python_files(project_root):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                source = f.read()
            tree = ast.parse(source, filename=file_path)
            add_parent_pointers(tree)
        except Exception:
            continue
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                try:
                    script = jedi.Script(source, path=file_path)
                    defs = script.get_names(all_scopes=True, definitions=True, references=False)
                except Exception:
                    continue
                current_id = None
                for d in defs:
                    if d.name == node.name and d.line == node.lineno:
                        current_id = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
                        break
                parent = getattr(node, 'parent', None)
                parent_id = None
                while parent:
                    if isinstance(parent, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                        for pd in defs:
                            if pd.name == parent.name and pd.line == parent.lineno:
                                parent_id = pd.full_name or f"{os.path.basename(file_path)}:{pd.name}@{pd.line}"
                                break
                        if parent_id:
                            break
                    parent = getattr(parent, 'parent', None)
                if current_id and parent_id:
                    if parent_id in G.nodes and current_id in G.nodes:
                        G.add_edge(parent_id, current_id, relation='nested')

    # Pass 3: Add call edges.
    for file_path in find_python_files(project_root):
        calls, _ = get_call_nodes(file_path)
        for call in calls:
            target_name = extract_call_target_name(call)
            if not target_name:
                continue
            caller_context = get_enclosing_definition(file_path, call.lineno, call.col_offset)
            if not caller_context:
                continue
            caller_id = caller_context.full_name or f"{os.path.basename(file_path)}:{caller_context.name}@{caller_context.line}"
            candidate_ids = [did for did, info in all_defs.items() if info['name'] == target_name]
            for candidate_id in candidate_ids:
                if caller_id in G.nodes and candidate_id in G.nodes:
                    G.add_edge(caller_id, candidate_id, relation='call')
    return G

# ---------- Neo4j Integration ----------

def push_graph_to_neo4j(G, uri="bolt://localhost:7687", user="neo4j", password="Mike_pass"):
    """
    Push the NetworkX graph G to a Neo4j database.
    Each node is created with a label according to its type:
      - Class, Function, or Method.
    Each edge uses the relationship type from its 'relation' property.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))
    with driver.session() as session:
        # Clear existing nodes (use with caution in production)
        session.run("MATCH (n) DETACH DELETE n")
        # Create nodes.
        for node_id, data in G.nodes(data=True):
            if data.get('type') == 'class':
                label = "Class"
            elif data.get('type') == 'method':
                label = "Method"
            else:
                label = "Function"
            session.run(
                f"""
                CREATE (n:{label} {{id: $id, name: $name, file: $file, line: $line}})
                """,
                id=node_id,
                name=data.get('name'),
                file=data.get('file'),
                line=data.get('line')
            )
        # Create relationships.
        for source, target, data in G.edges(data=True):
            rel_type = data.get('relation', 'call').upper()  # e.g. 'CALL' or 'NESTED'
            session.run(
                f"""
                MATCH (a {{id: $source}}), (b {{id: $target}})
                CREATE (a)-[r:{rel_type}]->(b)
                """,
                source=source,
                target=target
            )
    driver.close()
    print("Graph successfully pushed to Neo4j.")

# ---------- Notebook Main Execution ----------

# Set the project root directory (make sure to exclude tests directories)
project_root = "C:/Projects/codebase_rag/.venv/Lib/site-packages/skforecast"  # <-- Update this!
graph = build_call_graph(project_root)
print(f"Graph has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.")
push_graph_to_neo4j(graph)


Graph has 375 nodes and 1906 edges.
Graph successfully pushed to Neo4j.


In [71]:
import os
import ast
import jedi
import networkx as nx
from neo4j import GraphDatabase

# ---------- Helper functions ----------

def find_python_files(root):
    """Yield all .py files under root, skipping any in a directory named 'tests'."""
    for dirpath, dirnames, filenames in os.walk(root):
        # Remove any directories named 'tests' (case-insensitive) from traversal.
        dirnames[:] = [d for d in dirnames if d.lower() != "tests"]
        for filename in filenames:
            if filename.endswith('.py'):
                yield os.path.join(dirpath, filename)

def add_parent_pointers(node, parent=None):
    """Recursively add a 'parent' attribute to AST nodes."""
    for child in ast.iter_child_nodes(node):
        child.parent = node
        add_parent_pointers(child, node)

def is_method_in_file(file_path, name, lineno):
    """
    Parse file_path and return True if there is a function definition
    with the given name and lineno that is nested inside a ClassDef.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        tree = ast.parse(source, filename=file_path)
        add_parent_pointers(tree)
    except Exception as e:
        print(f"Error in is_method_in_file for {file_path}: {e}")
        return False
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name and node.lineno == lineno:
            parent = getattr(node, 'parent', None)
            while parent:
                if isinstance(parent, ast.ClassDef):
                    return True
                parent = getattr(parent, 'parent', None)
    return False

def get_definitions(file_path):
    """
    Use Jedi to extract definitions from a file.
    Returns a dict mapping a unique id to a dict with properties:
      id, name, type (class, function, or method), file, line, and code.
    """
    definitions = {}
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return definitions
    try:
        tree = ast.parse(source, filename=file_path)
    except Exception as e:
        tree = None
    script = jedi.Script(source, path=file_path)
    defs = script.get_names(all_scopes=True, definitions=True, references=False)
    for d in defs:
        if d.type in ('function', 'class'):
            full_name = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
            if d.type == 'function':
                # Check using AST whether this function is nested in a class.
                if is_method_in_file(file_path, d.name, d.line):
                    node_type = 'method'
                else:
                    node_type = 'function'
            else:
                node_type = 'class'
            code_segment = ""
            if tree is not None:
                for node in ast.walk(tree):
                    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and node.name == d.name and node.lineno == d.line:
                        code_segment = ast.get_source_segment(source, node)
                        break
            definitions[full_name] = {
                'id': full_name,
                'name': d.name,
                'type': node_type,
                'file': file_path,
                'line': d.line,
                'code': code_segment,
            }
    return definitions

def get_call_nodes(file_path):
    """
    Parse the AST of a file and return all Call nodes (plus the AST tree).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return [], None
    tree = ast.parse(source, filename=file_path)
    add_parent_pointers(tree)
    calls = []
    class CallVisitor(ast.NodeVisitor):
        def visit_Call(self, node):
            calls.append(node)
            self.generic_visit(node)
    CallVisitor().visit(tree)
    return calls, tree

def get_enclosing_definition(file_path, lineno, col_offset):
    """
    Use Jedi to get the innermost definition that encloses the given position.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            source = f.read()
        script = jedi.Script(source, path=file_path)
        context = script.get_context(line=lineno, column=col_offset)
        if context and context.type in ('function', 'class'):
            return context
    except Exception:
        return None
    return None

def extract_call_target_name(call_node):
    """
    Extract the target function name from a Call node.
    If the function is called directly (e.g. foo()), return 'foo'.
    If it's an attribute call (e.g. obj.bar()), return 'bar'.
    """
    if isinstance(call_node.func, ast.Name):
        return call_node.func.id
    elif isinstance(call_node.func, ast.Attribute):
        return call_node.func.attr
    return None

# ---------- Build the Call Graph ----------

def build_call_graph(project_root):
    """
    Build a call graph for all Python files under project_root.
    Nodes represent definitions (functions, methods, classes) with properties.
    Edges:
      - "call": a call from one definition to a target (matched by name)
      - "nested": when a definition is declared inside another.
    Returns a NetworkX DiGraph.
    """
    G = nx.DiGraph()
    all_defs = {}

    # Pass 1: Gather all definitions.
    for file_path in find_python_files(project_root):
        defs = get_definitions(file_path)
        for def_id, info in defs.items():
            all_defs[def_id] = info
            G.add_node(def_id, **info)

    # Pass 2: Add nested definition edges.
    for file_path in find_python_files(project_root):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                source = f.read()
            tree = ast.parse(source, filename=file_path)
            add_parent_pointers(tree)
        except Exception:
            continue
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                try:
                    script = jedi.Script(source, path=file_path)
                    defs = script.get_names(all_scopes=True, definitions=True, references=False)
                except Exception:
                    continue
                current_id = None
                for d in defs:
                    if d.name == node.name and d.line == node.lineno:
                        current_id = d.full_name or f"{os.path.basename(file_path)}:{d.name}@{d.line}"
                        break
                parent = getattr(node, 'parent', None)
                parent_id = None
                while parent:
                    if isinstance(parent, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                        for pd in defs:
                            if pd.name == parent.name and pd.line == parent.lineno:
                                parent_id = pd.full_name or f"{os.path.basename(file_path)}:{pd.name}@{pd.line}"
                                break
                        if parent_id:
                            break
                    parent = getattr(parent, 'parent', None)
                if current_id and parent_id:
                    if parent_id in G.nodes and current_id in G.nodes:
                        G.add_edge(parent_id, current_id, relation='nested')

    # Pass 3: Add call edges.
    for file_path in find_python_files(project_root):
        calls, _ = get_call_nodes(file_path)
        for call in calls:
            target_name = extract_call_target_name(call)
            if not target_name:
                continue
            caller_context = get_enclosing_definition(file_path, call.lineno, call.col_offset)
            if not caller_context:
                continue
            caller_id = caller_context.full_name or f"{os.path.basename(file_path)}:{caller_context.name}@{caller_context.line}"
            candidate_ids = [did for did, info in all_defs.items() if info['name'] == target_name]
            for candidate_id in candidate_ids:
                if caller_id in G.nodes and candidate_id in G.nodes:
                    G.add_edge(caller_id, candidate_id, relation='call')
    return G

# ---------- Neo4j Integration ----------

def push_graph_to_neo4j(G, uri="bolt://localhost:7687", user="neo4j", password="Mike_pass"):
    """
    Push the NetworkX graph G to a Neo4j database.
    Each node is created with a label according to its type:
      - Class, Function, or Method.
    Each edge uses the relationship type from its 'relation' property.
    """
    driver = GraphDatabase.driver(uri, auth=(user, password))
    with driver.session() as session:
        # Clear existing nodes (use with caution in production)
        session.run("MATCH (n) DETACH DELETE n")
        # Create nodes.
        for node_id, data in G.nodes(data=True):
            if data.get('type') == 'class':
                label = "Class"
            elif data.get('type') == 'method':
                label = "Method"
            else:
                label = "Function"
            session.run(
                f"""
                CREATE (n:{label} {{id: $id, name: $name, file: $file, line: $line, code: $code}})
                """,
                id=node_id,
                name=data.get('name'),
                file=data.get('file'),
                line=data.get('line'),
                code=data.get('code')
            )
        # Create relationships.
        for source, target, data in G.edges(data=True):
            rel_type = data.get('relation', 'call').upper()  # e.g. 'CALL' or 'NESTED'
            session.run(
                f"""
                MATCH (a {{id: $source}}), (b {{id: $target}})
                CREATE (a)-[r:{rel_type}]->(b)
                """,
                source=source,
                target=target
            )
    driver.close()
    print("Graph successfully pushed to Neo4j.")

# ---------- Notebook Main Execution ----------

# Set the project root directory (make sure to exclude tests directories)
project_root = "C:/Projects/codebase_rag/.venv/Lib/site-packages/skforecast"  # <-- Update this!
graph = build_call_graph(project_root)
print(f"Graph has {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges.")
push_graph_to_neo4j(graph)


Graph has 375 nodes and 1906 edges.
Graph successfully pushed to Neo4j.


In [135]:
from langchain_neo4j import Neo4jGraph, GraphCypherQAChain

kg = Neo4jGraph(
    url="bolt://localhost:7687", username="neo4j", password="Mike_pass", database="neo4j"
)

In [136]:
import textwrap
print(textwrap.fill(kg.schema, 60))

Node properties: Class {file: STRING, id: STRING, line:
INTEGER, name: STRING, code: STRING} Function {file: STRING,
id: STRING, line: INTEGER, name: STRING, code: STRING}
Method {file: STRING, id: STRING, line: INTEGER, name:
STRING, code: STRING} Relationship properties:  The
relationships: (:Class)-[:NESTED]->(:Method)
(:Class)-[:CALL]->(:Function) (:Function)-[:CALL]->(:Method)
(:Function)-[:CALL]->(:Function)
(:Function)-[:CALL]->(:Class)
(:Function)-[:NESTED]->(:Function)
(:Method)-[:CALL]->(:Method) (:Method)-[:CALL]->(:Class)
(:Method)-[:CALL]->(:Function)
(:Method)-[:NESTED]->(:Method)


Prompting the LLM

In [138]:
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Always return the question too.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
Examples: Here are a few examples of generated Cypher statements for particular questions:

# What are the nodes directly related to example_node?
MATCH (node)-[]-(other_nodes)
WHERE node.name = 'example.node'
RETURN other_nodes

# What are the function or method nodes that have a relationship towards example_node?
MATCH (node)<-[]-(other_nodes:Function|Method)
WHERE node.name = 'example.node'
RETURN other_nodes

# What are the function or method nodes that call example_node?
MATCH (node)<-[:CALL]-(other_nodes:Function|Method)
WHERE node.name = 'example.node'
RETURN other_nodes

# What is the file that stores example?
MATCH (node)
WHERE node.name = 'example'
RETURN node.file

# Inside what function or method is example_function defined?
MATCH (node)<-[r:NESTED_IN]-(other_node)
WHERE node.name = 'example_function'
RETURN other_node

The question is:
{question}"""

QA_GENERATION_TEMPLATE = """
You are an assistant specialized in retrieving and interpreting code snippets from a graph database.
Based on the user's question and the provided context, identify relevant pieces of code within the node properties and present them clearly.

User's Question:
{question}

Context from Database:
{context}

Extracted Code Snippets:
"""

from langchain.prompts.prompt import PromptTemplate
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], 
    template=CYPHER_GENERATION_TEMPLATE
)

QA_GENERATION_PROMPT = PromptTemplate(
    input_variables=["question", "context"],
    template=QA_GENERATION_TEMPLATE
)

In [139]:
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()
from langchain.globals import set_verbose

set_verbose(True)

openai_api_key = os.getenv("OPENAI_API_KEY")

In [140]:

os.environ["LANGSMITH_TRACING"] = "true"
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGSMITH_API_KEY"] = "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"
os.environ["LANGSMITH_PROJECT"] = "codebase_rag"

llm = ChatOpenAI()

In [141]:

print("LANGSMITH_TRACING:", os.getenv("LANGSMITH_TRACING"))
print("LANGSMITH_ENDPOINT:", os.getenv("LANGSMITH_ENDPOINT"))
print("LANGSMITH_API_KEY:", os.getenv("LANGSMITH_API_KEY"))
print("LANGSMITH_PROJECT:", os.getenv("LANGSMITH_PROJECT"))
print("OPENAI_API_KEY:", os.getenv("OPENAI_API_KEY"))


LANGSMITH_TRACING: true
LANGSMITH_ENDPOINT: https://api.smith.langchain.com
LANGSMITH_API_KEY: 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
LANGSMITH_PROJECT: codebase_rag
OPENAI_API_KEY: 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


In [142]:
from langsmith import utils
utils.tracing_is_enabled()

True

In [143]:
# llm.invoke("Hello, world!")

In [196]:
cypherChain = GraphCypherQAChain.from_llm(
    llm = ChatOpenAI(model="gpt-4o", temperature=0),
    graph=kg,
    verbose=True,
    allow_dangerous_requests=True,
    cypher_prompt=CYPHER_GENERATION_PROMPT, 
    # qa_prompt=QA_GENERATION_PROMPT
)

In [203]:
question = """I want to modify the ForecasterRecursive class to receive an argument that allows to specify 
a string argument to select different 'differentiator's for the differentiation it implements.
Can you suggest what changes need to be made? Include code snippets
"""


In [204]:
from IPython.display import Markdown, display

def print_markdown(text):
    display(Markdown(text))

def prettyCypherChain(question: str):
    response = cypherChain.run(question)
    print_markdown(response)

prettyCypherChain(question)



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (class:Class)
WHERE class.name = 'ForecasterRecursive'
RETURN class.code[0m
Full Context:

[1m> Finished chain.[0m


To modify the `ForecasterRecursive` class to receive a string argument for selecting different 'differentiator's, you can follow these steps:

1. Add a new parameter `differentiator_type` to the `__init__` method to specify the type of differentiator.
2. Modify the initialization of the `differentiator` attribute based on the `differentiator_type` argument.
3. Update the relevant methods to use the selected differentiator.

Here is a code snippet illustrating these changes:

```python
class ForecasterRecursive(ForecasterBase):
    def __init__(
        self,
        regressor: object,
        lags: Optional[Union[int, list, np.ndarray, range]] = None,
        window_features: Optional[Union[object, list]] = None,
        transformer_y: Optional[object] = None,
        transformer_exog: Optional[object] = None,
        weight_func: Optional[Callable] = None,
        differentiation: Optional[int] = None,
        differentiator_type: str = 'default',  # New parameter
        fit_kwargs: Optional[dict] = None,
        binner_kwargs: Optional[dict] = None,
        forecaster_id: Optional[Union[str, int]] = None
    ) -> None:
        # Existing initialization code...

        # Initialize differentiator based on differentiator_type
        if self.differentiation is not None:
            if not isinstance(differentiation, int) or differentiation < 1:
                raise ValueError(
                    f"Argument `differentiation` must be an integer equal to or "
                    f"greater than 1. Got {differentiation}."
                )
            self.window_size += self.differentiation

            if differentiator_type == 'default':
                self.differentiator = TimeSeriesDifferentiator(
                    order=self.differentiation, window_size=self.window_size
                )
            elif differentiator_type == 'alternative':
                self.differentiator = AlternativeDifferentiator(
                    order=self.differentiation, window_size=self.window_size
                )
            else:
                raise ValueError(f"Unknown differentiator_type: {differentiator_type}")

    # Update other methods if necessary to use self.differentiator
```

In this example, the `differentiator_type` parameter allows you to choose between different differentiator implementations, such as `TimeSeriesDifferentiator` and `AlternativeDifferentiator`. You can add more differentiator types as needed.