In [2]:
import os
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict

from dotenv import load_dotenv
from IPython.display import Markdown, display
from llama_index.core import SimpleDirectoryReader
from llama_index.core.schema import Document
from llama_index.core.text_splitter import CodeSplitter
from llama_index.packs.code_hierarchy import CodeHierarchyNodeParser
load_dotenv()


def print_python(python_text):
    """Render ``python_text`` as a fenced Python code block in the notebook.

    Args:
        python_text: Source code to display.

    The closing fence must sit on its own line to terminate the Markdown code
    block, so a newline is appended when the text does not already end with
    one.
    """
    if not python_text.endswith("\n"):
        python_text += "\n"
    display(Markdown("```python\n" + python_text + "```"))

In [16]:
class GlobalGraphInfo:
    """Mutable bookkeeping shared by every parser while one graph is built.

    Attributes are plain dictionaries that the parsers and the graph
    constructor fill in as files are processed.
    """

    # llama-index node id -> {"level": int, "generated_id": str}
    visited_nodes: dict
    # dotted node path -> {"id": ..., "type": ..., "node": ...}
    imports: dict
    # import information of files that produced no nodes
    import_aliases: dict
    # declared name; kept in sync with ``auto_loaded_imports`` below
    autoloaded_modules: dict
    # historical spelling that ``__init__`` has always created
    auto_loaded_imports: dict
    # generated node id of a class -> its inheritances
    inheritances: dict
    # identifier mixed into every generated node id
    entity_id: str
    aliases: dict

    def __init__(self, entity_id: str):
        """Create empty bookkeeping for the entity identified by ``entity_id``."""
        self.visited_nodes = {}
        self.imports = {}
        self.import_aliases = {}
        # The class declared ``autoloaded_modules`` while the constructor only
        # created ``auto_loaded_imports``; expose both names backed by the same
        # dictionary so either spelling works.
        self.autoloaded_modules = {}
        self.auto_loaded_imports = self.autoloaded_modules
        self.inheritances = {}
        self.entity_id = entity_id
        self.aliases = {}

### Utility functions

In [18]:
from pathlib import Path
import re
import hashlib

def _skip_file(path: Path) -> bool:
    # skip lock files
    path = path.name
    if path.endswith("lock") or path == "package-lock.json" or path == "yarn.lock":
        return True
    # skip tests and legacy directories
    if path in ["legacy", "test"] and self.skip_tests:
        return True
    # skip hidden files
    if path.startswith("."):
        return True
    # skip images
    if path.endswith(".png") or path.endswith(".jpg"):
        return True
    return False

def _remove_non_ascii(text):
    # Define the regular expression pattern to match ascii characters
    pattern = re.compile(r"[^\x00-\x7F]+")
    # Replace ascii characters with an empty string
    cleaned_text = pattern.sub("", text)
    return cleaned_text

def _skip_directory(directory: Path) -> bool:
    # skip hidden directories
    if directory.name.startswith("."):
        return True
    return directory == "__pycache__" or directory == "node_modules"



## --- BASE PARSER
def generate_node_id(path: str, company_id: str):
    """Derive a deterministic node id from a path and a company id.

    The id is the hexadecimal MD5 digest of ``"<company_id>:<path>"`` encoded
    as UTF-8.
    """
    payload = f"{company_id}:{path}".encode("utf-8")
    return hashlib.md5(payload).hexdigest()

def is_package(path: str) -> bool:
    """Return ``True`` when ``path`` contains an ``__init__.py`` entry."""
    init_file = os.path.join(path, "__init__.py")
    return os.path.exists(init_file)

## Format node functions

In [52]:
import os
from dataclasses import dataclass
from typing import List, Optional
from enum import Enum

from llama_index.core.schema import BaseNode

# ---- Schema

class MyNodeType(Enum):
    """Kinds of node that can appear in the code graph.

    Each member's value is the same string as its name.
    """

    # Nodes produced from the contents of a source file.
    CODE_BLOCK = "CODE_BLOCK"
    FUNCTION = "FUNCTION"
    CLASS = "CLASS"
    FILE = "FILE"
    # Directory nodes: PACKAGE is used when the directory is a package,
    # FOLDER otherwise.
    PACKAGE = "PACKAGE"
    FOLDER = "FOLDER"


# ---- Attributes
@dataclass
class BaseAttributes:
    """Attributes shared by every node kind."""

    # Display name of the node (a scope name or a path basename).
    name: str
    # Text carried by the node, when there is any.
    text: Optional[str] = None
    # Function calls associated with the node.
    function_calls: Optional[List[str]] = None
    # Id of the file node this node belongs to.
    file_node_id: Optional[str] = None

@dataclass
class CodeBlockAttributes(BaseAttributes):
    """Attributes of a ``CODE_BLOCK`` node."""

    # Signature of the enclosing scope; defaults to None.
    signature: Optional[str] = None

@dataclass
class FunctionAttributes(BaseAttributes):
    """Attributes of a ``FUNCTION`` node."""

    # Signature of the function scope; defaults to None.
    signature: Optional[str] = None

@dataclass
class ClassAttributes(BaseAttributes):
    """Attributes of a ``CLASS`` node."""

    # Signature of the class scope; defaults to None.
    signature: Optional[str] = None
    # Inheritances recorded for the class; defaults to None.
    inheritances: Optional[List[str]] = None

@dataclass
class FileAttributes(BaseAttributes):
    """Attributes of a ``FILE`` node."""

    # Path of the file as supplied by the caller; defaults to None.
    path: Optional[str] = None

@dataclass
class DirectoryAttributes(BaseAttributes):
    """Attributes of a ``PACKAGE`` or ``FOLDER`` node."""

    # Directory path; defaults to None.
    path: Optional[str] = None
    # Depth of the directory in the scanned tree; defaults to None.
    level: Optional[int] = None


# ---- Node
@dataclass
class MyNode:
    """A graph node: a node kind together with its attribute record."""

    # Kind of the node.
    type: MyNodeType
    # Attribute record; one of the BaseAttributes subclasses.
    attributes: BaseAttributes

@dataclass
class Relationship:
    """A typed edge between two nodes, identified by their ids.

    NOTE(review): the visible code builds edges as plain dicts with
    ``sourceId``/``targetId``/``type`` keys and does not use this class --
    confirm whether it is still needed.
    """

    # Id of the node the edge starts from.
    source_id: str
    # Id of the node the edge points to.
    target_id: str
    # Edge label.
    type: str





class format_nodes:
    """Factory helpers that build ``MyNode`` objects for each node kind.

    All helpers are static; the class only serves as a namespace.
    """

    @staticmethod
    def format_code_block_node(
        node_text: str, scope: dict, function_calls: List[str], file_node_id: str
    ) -> MyNode:
        """Build a ``CODE_BLOCK`` node.

        Args:
            node_text: Text of the code block.
            scope: Scope description; its ``"name"`` and ``"signature"`` keys
                are read.
            function_calls: Function calls associated with the block.
            file_node_id: Id of the file node that owns the block.
        """
        attributes = CodeBlockAttributes(
            name=scope["name"],
            signature=scope["signature"],
            text=node_text,
            function_calls=function_calls,
            file_node_id=file_node_id,
        )
        return MyNode(type=MyNodeType.CODE_BLOCK, attributes=attributes)

    @staticmethod
    def format_function_node(
        node_text: str, scope: dict, function_calls: List[str], file_node_id: str
    ) -> MyNode:
        """Build a ``FUNCTION`` node.

        Args:
            node_text: Text of the function.
            scope: Scope description; its ``"name"`` and ``"signature"`` keys
                are read.
            function_calls: Function calls associated with the function.
            file_node_id: Id of the file node that owns the function.
        """
        attributes = FunctionAttributes(
            name=scope["name"],
            signature=scope["signature"],
            text=node_text,
            function_calls=function_calls,
            file_node_id=file_node_id,
        )
        return MyNode(type=MyNodeType.FUNCTION, attributes=attributes)

    @staticmethod
    def format_class_node(
        node_text: str,
        scope: dict,
        file_node_id: str,
        inheritances: List[str],
        function_calls: List[str],
    ) -> MyNode:
        """Build a ``CLASS`` node.

        Args:
            node_text: Text of the class.
            scope: Scope description; its ``"name"`` and ``"signature"`` keys
                are read.
            file_node_id: Id of the file node that owns the class.
            inheritances: Inheritances recorded for the class.
            function_calls: Function calls associated with the class.
        """
        attributes = ClassAttributes(
            name=scope["name"],
            signature=scope["signature"],
            text=node_text,
            file_node_id=file_node_id,
            inheritances=inheritances,
            function_calls=function_calls,
        )
        return MyNode(type=MyNodeType.CLASS, attributes=attributes)

    @staticmethod
    def format_file_node(
        node_text: str,
        no_extension_path: str,
        function_calls: Optional[List[str]] = None,
        file_node_id: Optional[str] = None,
    ) -> MyNode:
        """Build a ``FILE`` node.

        Args:
            node_text: Text of the file.
            no_extension_path: Path stored on the node; its basename becomes
                the node name.
            function_calls: Function calls associated with the file.
            file_node_id: Optional id stored in ``file_node_id``.
        """
        attributes = FileAttributes(
            name=os.path.basename(no_extension_path),
            path=no_extension_path,
            text=node_text,
            file_node_id=file_node_id,
            function_calls=function_calls,
        )
        return MyNode(type=MyNodeType.FILE, attributes=attributes)

    @staticmethod
    def format_directory_node(path: str, package: bool, level: int) -> MyNode:
        """Build a directory node.

        Args:
            path: Directory path; stored with a trailing ``/`` and its
                basename becomes the node name.
            package: ``True`` yields a ``PACKAGE`` node, ``False`` a
                ``FOLDER`` node.
            level: Depth of the directory in the scanned tree.
        """
        attributes = DirectoryAttributes(
            name=os.path.basename(path),
            path=f"{path}/",
            level=level,
        )
        node_type = MyNodeType.PACKAGE if package else MyNodeType.FOLDER
        return MyNode(type=node_type, attributes=attributes)

### Parse file (BaseParser)

1. Holds the language-specific settings used to configure the parser.
2. Calls `SimpleDirectoryReader` and `CodeHierarchyNodeParser` from llama-index to load a file and split it into nodes.

In [43]:
from pathlib import Path
class BaseParser:
    """Parse one source file into graph nodes, edges and import information.

    The class drives llama-index's ``SimpleDirectoryReader`` and
    ``CodeHierarchyNodeParser`` and turns every resulting chunk into a plain
    dict of the form ``{"type": str, "attributes": dict}``.

    NOTE(review): the following members are used here but are not defined in
    this class; presumably concrete language parsers supply them -- confirm:
    ``signature_identifiers``, ``scopes_names``, ``relation_types_map``,
    ``_get_function_calls``, ``_get_inheritances``, ``_get_imports`` and
    ``remove_extensions``.
    """

    # Placeholder text llama-index leaves where a nested scope was cut out.
    _BREVITY_PATTERN = re.compile(r"Code replaced for brevity\. See node_id ([0-9a-fA-F-]+)")

    def __init__(
        self,
        language: str,
        wildcard: str,
        extension: str,
        import_path_separator: str = ".",
        global_graph_info: Optional[GlobalGraphInfo] = None,
    ):
        """Store the language settings.

        Args:
            language: Language name handed to ``CodeHierarchyNodeParser``.
            wildcard: Wildcard token of the language.
            extension: File extension handled by this parser.
            import_path_separator: Separator used in import paths.
            global_graph_info: Shared graph state. When omitted, a fresh
                ``{"alias": {}}`` dict is used per instance (previously one
                dict object was shared by every instance).
        """
        self.language = language
        self.wildcard = wildcard
        self.extension = extension
        self.import_path_separator = import_path_separator
        self.global_graph_info = global_graph_info if global_graph_info is not None else {"alias": {}}

    def parse(self, file_path: str, root_path: str, global_graph_info: GlobalGraphInfo, level: int):
        """Parse ``file_path`` and return its nodes, edges and imports.

        Args:
            file_path: File to parse.
            root_path: Root of the scanned tree, forwarded to ``_get_imports``.
            global_graph_info: Shared graph state updated while parsing.
            level: Level of the directory that contains the file.

        Returns:
            ``(nodes, edges, imports)``. Three empty containers are returned
            when llama-index times out or yields nothing for the file.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File {file_path} does not exist.")

        documents = SimpleDirectoryReader(
            input_files=[path],
            file_metadata=lambda x: {"filepath": x},
        ).load_data()
        if not documents:
            return [], [], {}
        # Bug related to llama-index it's safer to remove non-ascii characters. Could be removed in the future
        documents[0].text = _remove_non_ascii(documents[0].text)
        code = CodeHierarchyNodeParser(
            language=self.language,
            chunk_min_characters=3,
            signature_identifiers=self.signature_identifiers,
        )
        try:
            # Each node has a "text" attribute which is the code block, & a "node id"
            split_nodes = code.get_nodes_from_documents(documents)
        except TimeoutError:
            print(f"Timeout error: {file_path}")
            return [], [], {}
        if not split_nodes:
            return [], [], {}

        assignment_dict = {}
        # -- Make file node & relation; the first split node is the whole file
        file_node, file_relations = self.__process_node__(
            split_nodes[0], file_path, "", global_graph_info, assignment_dict, documents[0], level
        )
        file_node_id = file_node["attributes"]["node_id"]
        node_list = [file_node]
        edges_list = list(file_relations)
        # -- Make all other nodes within file
        for node in split_nodes[1:]:
            processed_node, relationships = self.__process_node__(
                node,
                file_path,
                file_node_id,
                global_graph_info,
                assignment_dict,
                documents[0],
                level,
            )
            node_list.append(processed_node)
            edges_list.extend(relationships)

        # Placeholders can only be rewritten once every node id is known.
        post_processed_node_list = [
            self._post_process_node(node, global_graph_info) for node in node_list
        ]

        imports = self._get_imports(str(path), file_node_id, root_path)

        return post_processed_node_list, edges_list, imports

    @staticmethod
    def _node_to_dict(node: MyNode) -> dict:
        """Convert a ``MyNode`` into the plain dict used throughout the graph.

        The node type is stored as its string value.
        """
        return {"type": node.type.value, "attributes": asdict(node.attributes)}

    def __process_node__(
        self,
        node: BaseNode,
        file_path: str,
        file_node_id: str,
        global_graph_info: GlobalGraphInfo,
        assignment_dict: dict,
        document: Document,
        level: int,
    ):
        """Convert one llama-index node into a graph node and its edges.

        Args:
            node: llama-index node to convert.
            file_path: Path of the file the node comes from.
            file_node_id: Id of the owning file node (empty for the file node).
            global_graph_info: Shared state; ``imports``, ``visited_nodes`` and
                ``inheritances`` are updated.
            assignment_dict: Forwarded to ``_get_function_calls``.
            document: Document whose text is used to compute line numbers.
            level: Level used when the node has no visited parent.

        Returns:
            ``(processed_node, relationships)`` where ``processed_node`` is a
            dict and ``relationships`` is a list of edge dicts.
        """
        relationships = []
        inclusive_scopes = node.metadata["inclusive_scopes"]
        scope = inclusive_scopes[-1] if inclusive_scopes else None
        type_node = scope["type"] if scope else "file"
        leaf = False
        # The formatters and _post_process_node work on the node's text,
        # not on the node object itself.
        node_text = node.text

        function_calls = self._get_function_calls(node, assignment_dict)
        is_class_node = False
        inheritances = None
        if type_node in self.scopes_names["function"]:
            core_node = format_nodes.format_function_node(node_text, scope, function_calls, file_node_id)
        elif type_node in self.scopes_names["class"]:
            is_class_node = True
            inheritances = self._get_inheritances(node)
            core_node = format_nodes.format_class_node(
                node_text, scope, file_node_id, inheritances, function_calls
            )
        else:
            core_node = format_nodes.format_file_node(node_text, file_path, function_calls)
        core_node_dict = self._node_to_dict(core_node)

        parent_level = self._get_parent_level(node, global_graph_info, level)

        node_path = self.get_node_path(node)
        parent_path = ".".join(node_path.split(".")[:-1])

        parent_id = generate_node_id(parent_path, global_graph_info.entity_id)
        node_id = generate_node_id(node_path, global_graph_info.entity_id)
        if is_class_node:
            global_graph_info.inheritances[node_id] = inheritances

        # Only scope types with a mapped relation produce an edge.
        relation = self.relation_types_map.get(scope["type"] if scope else "")
        if relation is not None:
            relationships.append(
                {
                    "sourceId": parent_id,
                    "targetId": node_id,
                    "type": relation,
                }
            )

        start_line, end_line = self._get_lines_range(
            document.text, node.metadata["start_byte"], node.metadata["end_byte"]
        )

        horizontal_attributes = {
            "start_line": start_line,
            "end_line": end_line,
            "path": node_path,
            "file_path": file_path,
            "level": parent_level + 1,
            "leaf": leaf,
            "node_id": node_id,
        }

        processed_node = {
            **core_node_dict,
            "attributes": {
                **core_node_dict["attributes"],
                **horizontal_attributes,
            },
        }

        global_graph_info.imports[node_path] = {
            "id": node_id,
            "type": processed_node["type"],
            "node": processed_node,
        }

        global_graph_info.visited_nodes[node.node_id] = {"level": parent_level + 1, "generated_id": node_id}
        return processed_node, relationships

    # Takes out the node_id from llama_index's CodeNodeHierarchyParser
    def _post_process_node(self, node: dict, global_graph_info: GlobalGraphInfo):
        """Rewrite llama-index node ids in placeholders to generated ids.

        Every ``Code replaced for brevity. See node_id <id>`` marker whose id
        is known in ``global_graph_info.visited_nodes`` is rewritten to point
        at the generated id; unknown ids are left as they are.

        Returns:
            The same ``node`` dict with its ``text`` attribute updated.
        """
        visited_nodes = global_graph_info.visited_nodes

        def _swap(match):
            mapped_generated_id = visited_nodes.get(match.group(1), {}).get("generated_id")
            if mapped_generated_id is None:
                return match.group(0)
            return f"Code replaced for brevity. See node_id {mapped_generated_id}"

        # A single substitution pass handles every marker independently.
        node["attributes"]["text"] = self._BREVITY_PATTERN.sub(_swap, node["attributes"]["text"])
        return node

    def _get_lines_range(self, file_contents, start_byte, end_byte):
        """Return the 1-based ``(start_line, end_line)`` for two offsets.

        The offsets index into ``file_contents`` directly; ``parse`` strips
        non-ASCII characters first, so byte and character offsets coincide.
        """
        start_line = file_contents.count("\n", 0, start_byte) + 1
        end_line = file_contents.count("\n", 0, end_byte) + 1

        return (start_line, end_line)

    def get_node_path(self, node: BaseNode):
        """Build the dotted path of ``node``.

        The file path (without extension, ``/`` replaced by ``.``) is followed
        by the names of all inclusive scopes, each prefixed with a dot.
        """
        file_path = node.metadata["filepath"]
        scopes = node.metadata["inclusive_scopes"]
        scopes_path = "".join("." + scope["name"] for scope in scopes)
        no_extension_path = self.remove_extensions(file_path)
        node_path = no_extension_path.replace("/", ".")

        return node_path + scopes_path

    def _get_parent_level(self, node: BaseNode, global_graph_info: GlobalGraphInfo, level: int):
        """Return the level recorded for the node's parent, else ``level``."""
        try:
            parent = node.parent_node
        except Exception:
            # Treat any failure to resolve the parent as "no parent".
            parent = None
        if not parent:
            return level
        return global_graph_info.visited_nodes.get(parent.node_id, {}).get("level", level)


### Python Parser

In [29]:
class Parsers:
    """Registry that maps a file name to the parser for its language."""

    def __init__(self, global_graph_info: GlobalGraphInfo, root_path: str):
        """Create the Python parser; ``root_path`` is accepted but not used."""
        self.python = PythonParser(global_graph_info)

    def get_parser(self, path: str):
        """Return the Python parser for paths ending in ``.py``, else ``None``."""
        return self.python if path.endswith(".py") else None

class PythonParser(BaseParser):
    """``BaseParser`` preconfigured for Python source files."""

    def __init__(self, global_graph_info: GlobalGraphInfo):
        """Configure the base parser with the Python language settings."""
        super().__init__(
            language="python",
            wildcard="*",
            extension=".py",
            import_path_separator=".",
            global_graph_info=global_graph_info,
        )

### Graph Constructor

In [40]:
from llama_index.core import SimpleDirectoryReader
from typing import Optional, Set, Tuple, List, Dict
from pathlib import Path



class GraphConstructor:
    """Walk a directory tree and build the nodes and edges of a code graph.

    Directories become ``PACKAGE``/``FOLDER`` nodes, files with a known parser
    are parsed into nodes, and every other readable text file becomes a bare
    ``FILE`` node. Entries of one directory are processed in a thread pool.

    NOTE(review): ``build_graph`` calls ``_relate_imports`` and
    ``_relate_constructor_calls``, which are not defined in the visible part of
    this class -- confirm they exist elsewhere.
    """

    global_graph_info: GlobalGraphInfo
    root: str
    skip_tests: bool
    parsers: Parsers
    max_workers: int = 50

    def __init__(self, entity_id: str, root: str, max_workers: Optional[int] = None):
        """Prepare a constructor for the tree rooted at ``root``.

        Args:
            entity_id: Identifier mixed into every generated node id.
            root: Directory to scan.
            max_workers: Upper bound for worker threads; the class default
                is used when omitted.
        """
        self.global_graph_info = GlobalGraphInfo(entity_id=entity_id)
        self.parsers = Parsers(self.global_graph_info, root)
        self.root = root
        self.skip_tests = True
        if max_workers is not None:
            self.max_workers = max_workers

    def build_graph(self):
        """Scan ``self.root`` and return ``(nodes, relationships)``."""
        # process every node to create the graph structure
        print("Building graph...")
        start_time = time.time()

        nodes, relationships, imports = self.__scan_directory(self.root)

        # relate imports between file nodes
        relationships.extend(self._relate_imports(imports))
        # relate functions calls
        relationships.extend(self._relate_constructor_calls(nodes, imports))
        execution_time = time.time() - start_time
        print(f"Execution time: {execution_time} seconds")
        return nodes, relationships

    # ---- helpers

    def __scan_directory(
        self,
        path: str,
        parent_id: Optional[str] = None,
        level: int = 0,
        visited: Optional[Set[str]] = None,
    ) -> Tuple[List[Dict], List[Dict], Dict]:
        """Recursively scan ``path``.

        Args:
            path: Directory to scan.
            parent_id: Node id of the parent directory, ``None`` for the root.
            level: Depth of ``path`` below the root.
            visited: Paths already scanned; shared across the recursion.

        Returns:
            ``(nodes, relationships, imports)`` for ``path`` and everything
            below it.

        Raises:
            FileNotFoundError: ``path`` does not exist.
        """
        if visited is None:
            visited = set()

        nodes: List[Dict] = []
        relationships: List[Dict] = []
        # A dict, because the results of the entries are merged with update().
        imports: Dict = {}

        if not os.path.exists(path):
            raise FileNotFoundError(f"Directory {path} not found")
        # Match the directory *name*; a suffix test would also skip e.g. "latest".
        if self.skip_tests and os.path.basename(os.path.normpath(path)) in ("tests", "test"):
            return nodes, relationships, imports
        if path in visited:
            return nodes, relationships, imports

        visited.add(path)
        # Check if the directory is a package, logic for python
        package = is_package(path)
        core_directory_node = format_nodes.format_directory_node(path, package, level)
        directory_node_id = generate_node_id(path, self.global_graph_info.entity_id)
        # Nodes are handled as plain dicts; the type is stored as its string value.
        directory_node = {
            "type": core_directory_node.type.value,
            "attributes": {**asdict(core_directory_node.attributes), "node_id": directory_node_id},
        }
        nodes.append(directory_node)
        if parent_id is not None:
            # relationship only exists when we are recursing (DFS)
            relationships.append(
                {
                    "sourceId": parent_id,
                    "targetId": directory_node_id,
                    "type": "CONTAINS",
                }
            )
        try:
            # The context manager releases the directory handle promptly.
            with os.scandir(path) as scanner:
                entries = list(scanner)
        except PermissionError:
            print(f"Permission denied: {path}")
            return nodes, relationships, imports
        # Submit all entries to the executor
        with ThreadPoolExecutor(max_workers=min(self.max_workers, os.cpu_count() or 1)) as executor:
            future_to_entry = {
                executor.submit(self.__process_entry, entry, directory_node_id, level, visited): entry
                for entry in entries
            }
            for future in as_completed(future_to_entry):
                try:
                    entry_nodes, entry_relationships, entry_imports, entry_visited = future.result()
                    nodes.extend(entry_nodes)
                    relationships.extend(entry_relationships)
                    imports.update(entry_imports)
                    visited.update(entry_visited)
                except Exception as exc:
                    entry = future_to_entry[future]
                    print(f"Generated an exception: {entry.path} -> {exc}")
                    traceback.print_exc()
        return nodes, relationships, imports

    def __process_entry(
        self,
        entry: os.DirEntry,
        directory_node_id: str,
        level: int = 0,
        visited: Optional[Set[str]] = None,
    ) -> Tuple[List[Dict], List[Dict], Dict, Set[str]]:
        """Dispatch one directory entry to the file or directory handler.

        Args:
            entry: Entry yielded by ``os.scandir``.
            directory_node_id: Node id of the directory containing ``entry``.
            level: Depth of the containing directory.
            visited: Paths already scanned; forwarded to sub-directory scans.

        Returns:
            ``(nodes, relationships, imports, visited)`` for the entry; all
            empty when the entry is skipped or is neither file nor directory.
        """
        # Hand over a Path so the helper can read the entry's final component.
        if _skip_file(Path(entry.path)):
            return [], [], {}, set()

        if entry.is_file():
            return self.__process_entry_file(entry, directory_node_id, level)
        if entry.is_dir():
            return self.__process_entry_dir(entry, directory_node_id, level, visited)

        return [], [], {}, set()

    def __process_entry_file(self, entry: os.DirEntry, directory_node_id: str, level: int = 0):
        """Turn one file into nodes and edges.

        Files with a registered parser are parsed; any other file is read as
        UTF-8 text and stored as a single ``FILE`` node.

        Args:
            entry: File entry yielded by ``os.scandir``.
            directory_node_id: Node id of the directory containing the file.
            level: Depth of the containing directory.

        Returns:
            ``(nodes, relationships, imports, visited)``; all empty when the
            file cannot be parsed or read.
        """
        local_nodes: List[Dict] = []
        local_relationships: List[Dict] = []
        local_imports: Dict = {}
        local_visited: Set[str] = set()

        parser = self.parsers.get_parser(entry.name)
        if parser:
            entry_name = entry.name.split(parser.extension)[0]
            try:
                processed_nodes, relations, file_imports = parser.parse(
                    entry.path, self.root, global_graph_info=self.global_graph_info, level=level
                )
            except Exception:
                print(f"Error parsing file {entry.path}")
                print(traceback.format_exc())
                return local_nodes, local_relationships, local_imports, local_visited
            if processed_nodes:
                file_root_node_id = processed_nodes[0]["attributes"]["node_id"]
                local_nodes.extend(processed_nodes)
                local_relationships.extend(relations)
                local_relationships.append(
                    {
                        "sourceId": directory_node_id,
                        "targetId": file_root_node_id,
                        "type": "CONTAINS",
                    }
                )
                local_imports.update(file_imports)

                # NOTE(review): the key is assumed to be the file's path
                # without extension, with "/" replaced by "." -- confirm.
                directory_path = os.path.dirname(entry.path)
                global_import_key = os.path.join(directory_path, entry_name).replace("/", ".")
                self.global_graph_info.imports[global_import_key] = {
                    "id": file_root_node_id,
                    "type": "FILE",
                    "node": processed_nodes[0],
                }
            else:
                self.global_graph_info.import_aliases.update(file_imports)
        else:
            # else make a file node
            try:
                with open(entry.path, "r", encoding="utf-8") as file:
                    text = file.read()
            except (UnicodeDecodeError, OSError):
                # Binary or unreadable files are left out of the graph.
                print(f"Error reading file {entry.path}")
                return local_nodes, local_relationships, local_imports, local_visited

            path = str(entry.path).replace("/", ".")
            file_node = {
                "type": "FILE",
                "attributes": {
                    "node_id": generate_node_id(path, self.global_graph_info.entity_id),
                    "text": text,
                },
            }
            local_nodes.append(file_node)
            local_relationships.append(
                {
                    "sourceId": directory_node_id,
                    "targetId": file_node["attributes"]["node_id"],
                    "type": "CONTAINS",
                }
            )
        return local_nodes, local_relationships, local_imports, local_visited

    def __process_entry_dir(
        self,
        entry: os.DirEntry,
        directory_node_id: Optional[str] = None,
        level: int = 0,
        visited: Optional[Set[str]] = None,
    ) -> Tuple[List[Dict], List[Dict], Dict, Set[str]]:
        """Scan a sub-directory one level deeper.

        Args:
            entry: Directory entry yielded by ``os.scandir``.
            directory_node_id: Node id of the parent directory.
            level: Depth of the parent directory.
            visited: Paths already scanned; shared with the nested scan.

        Returns:
            ``(nodes, relationships, imports, visited)``; all empty when the
            directory is skipped. The returned ``visited`` set is always empty
            because the shared set is updated in place.
        """
        local_nodes: List[Dict] = []
        local_relationships: List[Dict] = []
        local_imports: Dict = {}
        local_visited: Set[str] = set()

        if _skip_directory(Path(entry.path)):
            return local_nodes, local_relationships, local_imports, local_visited

        sub_nodes, sub_relationships, sub_imports = self.__scan_directory(
            entry.path, directory_node_id, level + 1, visited
        )
        local_nodes.extend(sub_nodes)
        local_relationships.extend(sub_relationships)
        local_imports.update(sub_imports)
        return local_nodes, local_relationships, local_imports, local_visited
