# Coder


## Import


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import ast
from pathlib import Path
from typing import Any

from haystack import Document, component
from loguru import logger as lg
from rich import print as rprint

from coder.config.coder_config import CoderConfig

## Setup config


In [None]:
# Load the project configuration; the ingestion cell below reads its input folder
# from `cc.paths` (as `cc.paths.src_fol`).
cc = CoderConfig()
rprint(cc.paths)

## Ingest data


In [4]:
class CodeFileDocumentExtractor:
    """
    Tool class that processes a single Python file and splits the code into several chunks.
    Chunks are based on:
      - Top-level functions,
      - Top-level classes,
      - Methods inside classes,
      - And code outside these definitions (global code).

    For each chunk, the following metadata is saved:
      - 'chunk_number': Sequential number of the chunk.
      - 'first_line': The first line of code in the chunk.
      - 'start_line' and 'end_line': The boundaries in the original file.
      - 'type': Type of the chunk (e.g. 'global', 'function', 'class', 'method').
    """

    def __init__(self, file_path: Path):
        self.file_path = file_path

    def _parse_code(self, code: str) -> ast.AST:
        """Parses the source code into an AST."""
        return ast.parse(code)

    def _get_top_level_defs(
        self, tree: ast.AST
    ) -> list[tuple[int, int, str, str, ast.AST]]:
        """
        Returns a list of top-level definitions as tuples:
        (start_line, end_line, type, name, node)
        """
        defs = []
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                defs.append(
                    (
                        node.lineno,
                        getattr(node, "end_lineno", node.lineno),
                        "function",
                        node.name,
                        node,
                    )
                )
            elif isinstance(node, ast.ClassDef):
                defs.append(
                    (
                        node.lineno,
                        getattr(node, "end_lineno", node.lineno),
                        "class",
                        node.name,
                        node,
                    )
                )
        return sorted(defs, key=lambda x: x[0])

    def _extract_global_intervals(
        self,
        code_lines: list[str],
        top_level_defs: list[tuple[int, int, str, str, ast.AST]],
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Determines intervals (line ranges) for global code segments that are not covered by any
        top-level definition.
        """
        intervals = []
        last_end = 1
        for start, end, typ, name, node in top_level_defs:
            if start > last_end:
                intervals.append((last_end, start - 1, {"type": "global"}))
            last_end = max(last_end, end + 1)
        if last_end <= len(code_lines):
            intervals.append((last_end, len(code_lines), {"type": "global"}))
        return intervals

    def _process_function(
        self, start: int, end: int, name: str
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """Returns the interval for a function."""
        return [(start, end, {"type": "function", "name": name})]

    def _process_class(
        self, class_node: ast.ClassDef
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Processes a class node, splitting it into:
          - Header (code before the first method),
          - Method intervals,
          - Gaps between methods (middle), and
          - Footer (code after the last method).
        """
        intervals = []
        class_start = class_node.lineno
        class_end = getattr(class_node, "end_lineno", class_node.lineno)
        methods = []
        for child in class_node.body:
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)):
                methods.append(
                    (
                        child.lineno,
                        getattr(child, "end_lineno", child.lineno),
                        "method",
                        child.name,
                    )
                )
        methods.sort(key=lambda x: x[0])
        if methods:
            # Class header: before the first method.
            if class_start < methods[0][0]:
                intervals.append(
                    (
                        class_start,
                        methods[0][0] - 1,
                        {"type": "class", "name": class_node.name, "part": "header"},
                    )
                )
            prev_end = methods[0][1]
            intervals.append(
                (
                    methods[0][0],
                    methods[0][1],
                    {"type": "method", "class": class_node.name, "name": methods[0][3]},
                )
            )
            for i in range(1, len(methods)):
                # Gap between methods.
                if prev_end + 1 < methods[i][0]:
                    intervals.append(
                        (
                            prev_end + 1,
                            methods[i][0] - 1,
                            {
                                "type": "class",
                                "name": class_node.name,
                                "part": "middle",
                            },
                        )
                    )
                intervals.append(
                    (
                        methods[i][0],
                        methods[i][1],
                        {
                            "type": "method",
                            "class": class_node.name,
                            "name": methods[i][3],
                        },
                    )
                )
                prev_end = methods[i][1]
            # Class footer: after the last method.
            if prev_end < class_end:
                intervals.append(
                    (
                        prev_end + 1,
                        class_end,
                        {"type": "class", "name": class_node.name, "part": "footer"},
                    )
                )
        else:
            intervals.append(
                (class_start, class_end, {"type": "class", "name": class_node.name})
            )
        return intervals

    def _extract_defs_intervals(
        self, tree: ast.AST
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Processes all top-level definitions (functions and classes) and returns their intervals.
        """
        intervals = []
        top_level_defs = self._get_top_level_defs(tree)
        for start, end, typ, name, node in top_level_defs:
            if typ == "function":
                intervals.extend(self._process_function(start, end, name))
            elif typ == "class":
                intervals.extend(self._process_class(node))
        return intervals

    def _create_documents_from_intervals(
        self, code_lines: list[str], intervals: list[tuple[int, int, dict[str, Any]]]
    ) -> list[Document]:
        """
        Converts each interval (chunk) into a Haystack Document with rich metadata.
        """
        documents = []
        chunk_number = 1
        for start, end, meta in intervals:
            if start > end or start < 1 or end > len(code_lines):
                continue
            chunk_lines = code_lines[start - 1 : end]  # Convert to 0-indexed.
            content = "\n".join(chunk_lines)
            first_line = chunk_lines[0] if chunk_lines else ""
            meta_updated = meta.copy()
            meta_updated.update(
                {
                    "chunk_number": chunk_number,
                    "first_line": first_line,
                    "start_line": start,
                    "end_line": end,
                }
            )
            documents.append(Document(content=content, meta=meta_updated))
            chunk_number += 1
        return documents

    def extract_documents(self) -> list[Document]:
        """
        Orchestrates the extraction process:
          1. Reads the file and splits it into lines.
          2. Parses the code into an AST.
          3. Extracts intervals for definitions and global code.
          4. Converts intervals into Haystack Documents.
        """
        code = self.file_path.read_text(encoding="utf-8")
        code_lines = code.splitlines()
        try:
            tree = self._parse_code(code)
        except Exception as e:
            # If parsing fails, return the whole file as a single chunk with an error.
            return [
                Document(
                    content=code,
                    meta={
                        "error": str(e),
                        "chunk_number": 1,
                        "first_line": code_lines[0] if code_lines else "",
                        "start_line": 1,
                        "end_line": len(code_lines),
                    },
                )
            ]

        # Get intervals for definitions and global code.
        defs_intervals = self._extract_defs_intervals(tree)
        top_level_defs = self._get_top_level_defs(tree)
        global_intervals = self._extract_global_intervals(code_lines, top_level_defs)
        intervals = defs_intervals + global_intervals
        intervals.sort(key=lambda x: (x[0], x[1]))

        # Create documents from intervals.
        documents = self._create_documents_from_intervals(code_lines, intervals)
        return documents

In [5]:
@component
class CodeFolderIngestor:
    """
    Pipeline component that accepts a list of folder paths, uses the CodeFileDocumentExtractor
    for each Python file found (recursively) in each folder, and aggregates the resulting
    Haystack Documents.

    Every document's metadata is extended with 'file_path' so chunks coming from different
    files can be told apart. Files that cannot be read (OS error or invalid UTF-8) are
    skipped with a warning instead of aborting the whole ingestion.
    """

    @component.output_types(documents=list[Document])
    def run(self, folder_paths: list[Path]) -> dict[str, list[Document]]:
        """
        Extracts code chunks from every `*.py` file under the given folders.

        Args:
            folder_paths: Folders to scan recursively; plain strings are accepted too.

        Returns:
            `{"documents": [...]}` with the chunks of all files, files in sorted order.
        """
        documents: list[Document] = []
        for folder_path in folder_paths:
            # rglob yields files in filesystem order; sorting makes the output deterministic.
            for file_path in sorted(Path(folder_path).rglob("*.py")):
                try:
                    file_documents = CodeFileDocumentExtractor(
                        file_path
                    ).extract_documents()
                except (OSError, UnicodeDecodeError) as e:
                    lg.warning("Skipping {}: {}", file_path, e)
                    continue
                # Rebuild each Document (instead of mutating meta in place) so anything
                # Haystack derives from meta at construction time also sees the file path —
                # presumably the auto-generated id; verify against the installed version.
                documents.extend(
                    Document(
                        content=doc.content,
                        meta={**doc.meta, "file_path": str(file_path)},
                    )
                    for doc in file_documents
                )
        return {"documents": documents}

In [None]:
# Run the ingestor over the project's source folder and keep the chunks in `documents`
# (inspected by the next cell).
folders = [cc.paths.src_fol]
ingestor = CodeFolderIngestor()
# Calling `run` directly (outside a Haystack Pipeline) returns its output dict.
result = ingestor.run(folders)
documents = result["documents"]
print(f"Ingested {len(documents)} documents from the provided folder paths.")

In [None]:
# Quick look at the first chunk's content (first 100 chars) and the third chunk's metadata.
# Assumes at least three chunks were ingested; otherwise this raises IndexError.
documents[0].content[:100], documents[2].meta

# Dump (scratch area: alternative versions of the extractor, kept for reference)


In [None]:
class CodeFileDocumentExtractor:
    """
    Tool class that processes a single Python file and splits the code into several chunks.
    The splitting is based on:
      - Top-level functions,
      - Top-level classes,
      - Methods inside classes,
      - And any code outside these definitions (global code).

    A chunk for a decorated function, class or method starts at its first decorator, so
    decorators stay attached to the definition they belong to.

    For each chunk, the following metadata is saved:
      - 'chunk_number': Sequential number of the chunk.
      - 'first_line': The first line of code in the chunk.
      - 'start_line' and 'end_line': The 1-based, inclusive boundaries in the original file.
      - 'type': Type of the chunk (e.g. 'global', 'function', 'class', 'method').
      - Other details such as the name of the function/class or, for class chunks, which part (header, middle, footer).

    If the file cannot be parsed, a single whole-file chunk is returned that carries the same
    line-bound keys plus an 'error' message.
    """

    def __init__(self, file_path: Path):
        self.file_path = file_path

    @staticmethod
    def _split_lines(code: str) -> list[str]:
        """
        Splits source text into lines numbered the same way as the AST's line numbers.

        Only LF, CRLF and CR end a line here. `str.splitlines` additionally breaks on form
        feeds and Unicode line/paragraph separators, which the parser does not count as
        line ends, so every later chunk would be shifted against the AST line numbers.
        """
        lines = code.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        if lines[-1] == "":
            # A trailing newline ends the last line; it does not start a new, empty one.
            lines.pop()
        return lines

    @staticmethod
    def _def_start(node: ast.AST) -> int:
        """
        Returns the first line of a function or class definition, decorators included.

        `node.lineno` points at the `def`/`class` line itself, so without this the
        decorator lines would end up in a neighbouring global/header/middle chunk.
        """
        decorator_lines = [d.lineno for d in getattr(node, "decorator_list", [])]
        return min([node.lineno, *decorator_lines])

    def extract_documents(self) -> list[Document]:
        """Reads the file, splits it into chunks and returns one Haystack Document per chunk."""
        code = self.file_path.read_text(encoding="utf-8")
        code_lines = self._split_lines(code)
        try:
            tree = ast.parse(code)
        except Exception as e:
            # If the file fails to parse, return the whole file as one chunk with an error
            # message. Line bounds are included so every chunk exposes the same keys.
            return [
                Document(
                    content=code,
                    meta={
                        "error": str(e),
                        "chunk_number": 1,
                        "first_line": code_lines[0] if code_lines else "",
                        "start_line": 1,
                        "end_line": len(code_lines),
                    },
                )
            ]

        # We'll accumulate intervals (start_line, end_line, metadata) for each chunk.
        intervals: list[tuple[int, int, dict[str, Any]]] = []

        def add_interval(start: int, end: int, meta: dict[str, Any]) -> None:
            # Empty ranges (start > end) are dropped, which also skips absent gaps.
            if start <= end:
                intervals.append((start, end, meta))

        # === 1. Top-level definitions (functions & classes) ===
        # Each item: (start, end, type, name, node); start includes decorators. The node is
        # kept so classes need no second search through tree.body.
        top_level_defs = []
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                typ = "function"
            elif isinstance(node, ast.ClassDef):
                typ = "class"
            else:
                continue
            top_level_defs.append(
                (
                    self._def_start(node),
                    getattr(node, "end_lineno", node.lineno),
                    typ,
                    node.name,
                    node,
                )
            )
        top_level_defs.sort(key=lambda x: x[0])

        # === 2. Global code: segments outside of any top-level definition ===
        last_end = 1
        for start, end, *_ in top_level_defs:
            if start > last_end:
                add_interval(last_end, start - 1, {"type": "global"})
            last_end = max(last_end, end + 1)
        if last_end <= len(code_lines):
            add_interval(last_end, len(code_lines), {"type": "global"})

        # === 3. Process top-level definitions ===
        for start, end, typ, name, node in top_level_defs:
            if typ == "function":
                add_interval(start, end, {"type": "function", "name": name})
                continue
            # For a class, also consider methods (only direct children).
            methods = sorted(
                (
                    (
                        self._def_start(child),
                        getattr(child, "end_lineno", child.lineno),
                        child.name,
                    )
                    for child in node.body
                    if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
                ),
                key=lambda m: m[0],
            )
            if not methods:
                add_interval(start, end, {"type": "class", "name": name})
                continue
            prev_end = start - 1
            for index, (method_start, method_end, method_name) in enumerate(methods):
                # Class code before this method: the header for the first method,
                # otherwise the gap after the previous method (dropped when empty).
                add_interval(
                    prev_end + 1,
                    method_start - 1,
                    {
                        "type": "class",
                        "name": name,
                        "part": "header" if index == 0 else "middle",
                    },
                )
                add_interval(
                    method_start,
                    method_end,
                    {"type": "method", "class": name, "name": method_name},
                )
                prev_end = method_end
            # Class footer: from end of last method to the end of the class.
            add_interval(
                prev_end + 1, end, {"type": "class", "name": name, "part": "footer"}
            )

        # === 4. Sort intervals by start line ===
        intervals.sort(key=lambda x: (x[0], x[1]))

        # === 5. Create documents for each interval ===
        documents = []
        chunk_number = 1
        for start, end, meta in intervals:
            if start > end or start < 1 or end > len(code_lines):
                continue
            chunk_lines = code_lines[start - 1 : end]  # converting to 0-indexed
            content = "\n".join(chunk_lines)
            first_line = chunk_lines[0] if chunk_lines else ""
            meta_updated = meta.copy()
            meta_updated.update(
                {
                    "chunk_number": chunk_number,
                    "first_line": first_line,
                    "start_line": start,
                    "end_line": end,
                }
            )
            documents.append(Document(content=content, meta=meta_updated))
            chunk_number += 1

        return documents


# Example usage: chunk a single file and print every chunk.
# In a Jupyter/IPython kernel `__name__` is also "__main__", so this runs with the cell.
if __name__ == "__main__":
    # NOTE(review): machine-specific path — replace with a file on your system.
    file_path = Path("/home/pmn/repos/coder/src/coder/config/coder_paths.py")
    if not file_path.is_file():
        print(f"Example file not found, skipping: {file_path}")
    else:
        extractor = CodeFileDocumentExtractor(file_path)
        chunks = extractor.extract_documents()
        for doc in chunks:
            # Read line bounds defensively: a parse-failure chunk may not carry them.
            start_line = doc.meta.get("start_line", "?")
            end_line = doc.meta.get("end_line", "?")
            print(f"Chunk {doc.meta['chunk_number']} (lines {start_line}-{end_line}):")
            if "error" in doc.meta:
                print(f"Parse error: {doc.meta['error']}")
            print(f"First line: {doc.meta['first_line']}")
            print(f"Content:\n{doc.content}")
            print("----")

In [None]:
class CodeFileDocumentExtractor:
    """
    Tool class that processes a single Python file and splits the code into several chunks.
    Chunks are based on:
      - Top-level functions,
      - Top-level classes,
      - Methods inside classes,
      - And code outside these definitions (global code).

    A chunk for a decorated function, class or method starts at its first decorator, so
    decorators stay attached to the definition they belong to.

    For each chunk, the following metadata is saved:
      - 'chunk_number': Sequential number of the chunk.
      - 'first_line': The first line of code in the chunk.
      - 'start_line' and 'end_line': The 1-based, inclusive boundaries in the original file.
      - 'type': Type of the chunk (e.g. 'global', 'function', 'class', 'method').
      - 'name' for functions, classes and methods; 'class' (the owning class) for methods;
        'part' ('header', 'middle' or 'footer') for class code outside its methods.

    If the file cannot be parsed, a single whole-file chunk is returned whose metadata also
    contains an 'error' message.
    """

    def __init__(self, file_path: Path):
        self.file_path = file_path

    @staticmethod
    def _split_lines(code: str) -> list[str]:
        """
        Splits source text into lines numbered the same way as the AST's line numbers.

        Only LF, CRLF and CR end a line here. `str.splitlines` additionally breaks on form
        feeds and Unicode line/paragraph separators, which the parser does not count as
        line ends, so every later chunk would be shifted against the AST line numbers.
        """
        lines = code.replace("\r\n", "\n").replace("\r", "\n").split("\n")
        if lines[-1] == "":
            # A trailing newline ends the last line; it does not start a new, empty one.
            lines.pop()
        return lines

    @staticmethod
    def _def_start(node: ast.AST) -> int:
        """
        Returns the first line of a function or class definition, decorators included.

        `node.lineno` points at the `def`/`class` line itself, so without this the
        decorator lines would end up in a neighbouring global/header/middle chunk.
        """
        decorator_lines = [d.lineno for d in getattr(node, "decorator_list", [])]
        return min([node.lineno, *decorator_lines])

    def _parse_code(self, code: str) -> ast.Module:
        """Parses the source code into an AST."""
        return ast.parse(code)

    def _get_top_level_defs(
        self, tree: ast.Module
    ) -> list[tuple[int, int, str, str, ast.AST]]:
        """
        Returns a list of top-level definitions, sorted by start line, as tuples:
        (start_line, end_line, type, name, node)
        where start_line includes any decorators and type is 'function' or 'class'.
        """
        defs = []
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                typ = "function"
            elif isinstance(node, ast.ClassDef):
                typ = "class"
            else:
                continue
            defs.append(
                (
                    self._def_start(node),
                    getattr(node, "end_lineno", node.lineno),
                    typ,
                    node.name,
                    node,
                )
            )
        return sorted(defs, key=lambda x: x[0])

    def _extract_global_intervals(
        self,
        code_lines: list[str],
        top_level_defs: list[tuple[int, int, str, str, ast.AST]],
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Determines intervals (line ranges) for global code segments that are not covered by any
        top-level definition.
        """
        intervals = []
        last_end = 1
        for start, end, *_ in top_level_defs:
            if start > last_end:
                intervals.append((last_end, start - 1, {"type": "global"}))
            last_end = max(last_end, end + 1)
        if last_end <= len(code_lines):
            intervals.append((last_end, len(code_lines), {"type": "global"}))
        return intervals

    def _process_function(
        self, start: int, end: int, name: str
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """Returns the interval for a function."""
        return [(start, end, {"type": "function", "name": name})]

    def _process_class(
        self, class_node: ast.ClassDef
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Processes a class node, splitting it into:
          - Header (code before the first method, including class decorators),
          - Method intervals (including method decorators),
          - Gaps between methods (middle), and
          - Footer (code after the last method).
        A class without methods becomes a single 'class' chunk.
        """
        class_name = class_node.name
        class_start = self._def_start(class_node)
        class_end = getattr(class_node, "end_lineno", class_node.lineno)
        methods = sorted(
            (
                (
                    self._def_start(child),
                    getattr(child, "end_lineno", child.lineno),
                    child.name,
                )
                for child in class_node.body
                if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef))
            ),
            key=lambda m: m[0],
        )
        if not methods:
            return [(class_start, class_end, {"type": "class", "name": class_name})]

        intervals = []
        # Last line already assigned to a chunk; starts just before the class.
        prev_end = class_start - 1
        for index, (method_start, method_end, method_name) in enumerate(methods):
            # Class code before this method: the header for the first method,
            # otherwise the gap after the previous method.
            if prev_end + 1 < method_start:
                part = "header" if index == 0 else "middle"
                intervals.append(
                    (
                        prev_end + 1,
                        method_start - 1,
                        {"type": "class", "name": class_name, "part": part},
                    )
                )
            intervals.append(
                (
                    method_start,
                    method_end,
                    {"type": "method", "class": class_name, "name": method_name},
                )
            )
            prev_end = method_end
        # Class footer: after the last method.
        if prev_end < class_end:
            intervals.append(
                (
                    prev_end + 1,
                    class_end,
                    {"type": "class", "name": class_name, "part": "footer"},
                )
            )
        return intervals

    def _extract_defs_intervals(
        self,
        tree: ast.AST,
        top_level_defs: "list[tuple[int, int, str, str, ast.AST]] | None" = None,
    ) -> list[tuple[int, int, dict[str, Any]]]:
        """
        Processes all top-level definitions (functions and classes) and returns their intervals.

        `top_level_defs` may be passed in when already computed by `_get_top_level_defs`;
        otherwise it is computed from `tree`.
        """
        if top_level_defs is None:
            top_level_defs = self._get_top_level_defs(tree)
        intervals = []
        for start, end, typ, name, node in top_level_defs:
            if typ == "function":
                intervals.extend(self._process_function(start, end, name))
            elif typ == "class":
                intervals.extend(self._process_class(node))
        return intervals

    def _create_documents_from_intervals(
        self, code_lines: list[str], intervals: list[tuple[int, int, dict[str, Any]]]
    ) -> list[Document]:
        """
        Converts each interval (chunk) into a Haystack Document with rich metadata.
        Intervals that are empty or fall outside the file are skipped.
        """
        documents = []
        chunk_number = 1
        for start, end, meta in intervals:
            if start > end or start < 1 or end > len(code_lines):
                continue
            chunk_lines = code_lines[start - 1 : end]  # Convert to 0-indexed.
            content = "\n".join(chunk_lines)
            first_line = chunk_lines[0] if chunk_lines else ""
            meta_updated = meta.copy()
            meta_updated.update(
                {
                    "chunk_number": chunk_number,
                    "first_line": first_line,
                    "start_line": start,
                    "end_line": end,
                }
            )
            documents.append(Document(content=content, meta=meta_updated))
            chunk_number += 1
        return documents

    def extract_documents(self) -> list[Document]:
        """
        Orchestrates the extraction process:
          1. Reads the file and splits it into lines (numbered like the AST).
          2. Parses the code into an AST.
          3. Extracts intervals for definitions and global code.
          4. Converts intervals into Haystack Documents.
        """
        code = self.file_path.read_text(encoding="utf-8")
        code_lines = self._split_lines(code)
        try:
            tree = self._parse_code(code)
        except Exception as e:
            # Deliberate best-effort: if parsing fails for any reason, return the whole
            # file as a single chunk with the error recorded instead of aborting.
            return [
                Document(
                    content=code,
                    meta={
                        "error": str(e),
                        "chunk_number": 1,
                        "first_line": code_lines[0] if code_lines else "",
                        "start_line": 1,
                        "end_line": len(code_lines),
                    },
                )
            ]

        # Get intervals for definitions and global code (top-level defs computed once).
        top_level_defs = self._get_top_level_defs(tree)
        intervals = self._extract_defs_intervals(tree, top_level_defs)
        intervals += self._extract_global_intervals(code_lines, top_level_defs)
        intervals.sort(key=lambda x: (x[0], x[1]))

        # Create documents from intervals.
        return self._create_documents_from_intervals(code_lines, intervals)


# Example usage: chunk a single file and print every chunk.
# In a Jupyter/IPython kernel `__name__` is also "__main__", so this runs with the cell;
# results go to `file_documents` so the notebook-level `documents` from the ingestion
# cell is not overwritten.
if __name__ == "__main__":
    # NOTE(review): machine-specific path — replace with a file on your system.
    file_path = Path("/home/pmn/repos/coder/src/coder/config/coder_paths.py")
    if not file_path.is_file():
        print(f"Example file not found, skipping: {file_path}")
    else:
        extractor = CodeFileDocumentExtractor(file_path)
        file_documents = extractor.extract_documents()
        for doc in file_documents:
            print(
                f"Chunk {doc.meta['chunk_number']} (lines {doc.meta.get('start_line', '?')}-{doc.meta.get('end_line', '?')}):"
            )
            if "error" in doc.meta:
                print(f"Parse error: {doc.meta['error']}")
            print(f"First line: {doc.meta['first_line']}")
            print(f"Content:\n{doc.content}")
            print("----")