# Phase 1: Knowledge Graph Construction

In [9]:
# --- Cài đặt ---
# --- CÀI ĐẶT MÔI TRƯỜNG ---
# Cập nhật hệ thống và cài các công cụ cần thiết (chỉ chạy lần đầu)
!apt-get update -qq && apt-get install -y build-essential git > /dev/null 2>&1

# Cài đặt các thư viện Python cần thiết
!pip install -q -U \
    neo4j \
    tree-sitter \
    sentence-transformers \
    langchain \
    langchain-openai \
    langchain-community \
    tqdm \
    python-dotenv

# Cài đặt các gói ngữ pháp Tree-sitter đã biên dịch sẵn
!pip install -q \
    tree-sitter-python \
    tree-sitter-java \
    tree-sitter-javascript \
    tree-sitter-typescript \
    tree-sitter-cpp \
    tree-sitter-c

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)


In [50]:
# --- Imports ---
import os
import glob
import json
import fnmatch
import requests 
import re
from tqdm.auto import tqdm
from typing import List, Dict, Tuple, Set, Optional
from langchain_openai import ChatOpenAI
from neo4j import GraphDatabase
from langchain_core.messages import HumanMessage

from sentence_transformers import SentenceTransformer
from tree_sitter import Language, Parser, Query, QueryCursor
from kaggle_secrets import UserSecretsClient 

# --- Configuration (Ví dụ) ---
class Config:
    REPO_ROOT_DIR = '/kaggle/input/chatbot/chatbotai_v1-main'
    NEO4J_URI_SECRET = "NEO4J_URI"
    NEO4J_USER_SECRET = "NEO4J_USER"
    NEO4J_PASSWORD_SECRET = "NEO4J_PASSWORD"
    OPENAI_API_KEY_SECRET = "OPENAI_API_KEY"
    OPENAI_MODEL = "gpt-4o"
    # OLLAMA_BASE_URL = "http://192.168.92.23:11434"
    # OLLAMA_MODEL = "qwen3-embedding:8b" 
    ENCODER_MODEL = 'all-MiniLM-L6-v2'
    BATCH_SIZE = 500 # Kích thước lô cho Neo4j ingest
    MAX_CODE_LENGTH = 2048 # Giới hạn độ dài code lưu trữ
    ENABLE_PROGRESS_BAR = True # Bật/tắt thanh tiến trình
    SKIP_FOLDERS = ['.*', '*venv*', 'node_modules', '__pycache__', 'build', 'dist', 'docs', 'tests', 'examples'] # Thư mục bỏ qua
    SKIP_FILES = ['*setup.py', '*config.js', '*.min.js', '*.css', '*.html', '*.md', '*.txt', '*.json', '*.yaml', '*.yml', 'LICENSE', '.*'] # File bỏ qua

In [22]:
try:
    from tree_sitter_python import language as python_lang_loader
    from tree_sitter_java import language as java_lang_loader
    from tree_sitter_javascript import language as js_lang_loader
    # import tree_sitter_typescript.typescript as ts
    # import tree_sitter_typescript.tsx as tsx

    print("Đang tải các đối tượng Ngôn ngữ...")

    PY_LANG = Language(python_lang_loader())
    JAVA_LANG = Language(java_lang_loader())
    JS_LANG = Language(js_lang_loader()) # Dùng cho cả .js và .jsx, .ts, .tsx
    # TS_LANG = Language(ts.language())
    # TSX_LANG = Language(tsx.language())

    # Dùng js cho cả ts và tsx
    # *** ĐÂY LÀ PHẦN SỬA LỖI QUAN TRỌNG ***
    LANGUAGE_MAP = {
    '.py': (PY_LANG, 'python'),
    '.java': (JAVA_LANG, 'java'),
    '.js': (JS_LANG, 'javascript'),
    '.jsx': (JS_LANG, 'javascript')
    # '.ts': (TS_LANG, 'typescript'),
    # '.tsx': (TSX_LANG, 'typescriptreact')
}

    print(f"✓ Đã tải {len(LANGUAGE_MAP)} ngôn ngữ/đuôi file.")

except Exception as e:
    print(f"⚠ Lỗi tải ngôn ngữ: {e}")
    raise e

Đang tải các đối tượng Ngôn ngữ...
✓ Đã tải 4 ngôn ngữ/đuôi file.


In [23]:
# --- Phần 2: Trình kết nối Neo4j (Database Connector) ---

class Neo4jGraphConstructor:
    """Constructor cho Neo4j Knowledge Graph"""

    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        print("✓ Connected to Neo4j AuraDB")

    def close(self):
        self.driver.close()
        print("✓ Disconnected from Neo4j")

    def run_cypher_query(self, query: str, params: Optional[Dict] = None) -> Optional[List]:
        with self.driver.session() as session:
            try:
                result = session.run(query, params)
                return [record for record in result]
            except Exception as e:
                print(f"⚠ Cypher error: {e}")
                return None

    def clear_database(self):
        """Xóa toàn bộ database (cẩn thận!)"""
        print("\n🗑️  Clearing database...")
        self.run_cypher_query("MATCH (n) DETACH DELETE n")
        print("✓ Database cleared")

    def create_constraints_and_indexes(self):
        """Tạo constraints và indexes theo schema cải thiện"""
        print("\n🔧 Creating constraints and indexes...")

        # Constraints (Bao gồm Placeholder, Definition, Attribute, Documentation)
        constraints = [
            "CREATE CONSTRAINT IF NOT EXISTS FOR (f:File) REQUIRE f.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Class) REQUIRE c.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (fn:Function) REQUIRE fn.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (m:Method) REQUIRE m.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (p:Placeholder) REQUIRE p.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (d:GeneratedDescription) REQUIRE d.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (cd:ClassDefinition) REQUIRE cd.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (fd:FunctionDefinition) REQUIRE fd.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (md:MethodDefinition) REQUIRE md.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (a:Attribute) REQUIRE a.fqn IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (doc:Documentation) REQUIRE doc.fqn IS UNIQUE" # Thêm cho Documentation
        ]
        for constraint in constraints:
            self.run_cypher_query(constraint)

        # Full-text index (Bao gồm Placeholder)
        self.run_cypher_query("""
            CREATE FULLTEXT INDEX names IF NOT EXISTS
            FOR (n:Class|Function|Method|Attribute|Placeholder)
            ON EACH [n.name]
            OPTIONS { indexConfig: { `fulltext.analyzer`: 'standard' } }
        """)

        # Vector index 1 (cho GeneratedDescription)
        self.run_cypher_query("""
            CREATE VECTOR INDEX generated_descriptions IF NOT EXISTS
            FOR (n:GeneratedDescription)
            ON (n.embedding)
            OPTIONS { indexConfig: {
                `vector.dimensions`: 384,
                `vector.similarity_function`: 'cosine'
            }}
        """)

        # Vector index 2 (cho Documentation)
        self.run_cypher_query("""
            CREATE VECTOR INDEX documentation_embeddings IF NOT EXISTS
            FOR (n:Documentation)
            ON (n.embedding)
            OPTIONS { indexConfig: {
                `vector.dimensions`: 384,
                `vector.similarity_function`: 'cosine'
            }}
        """)

        print("✓ Constraints and indexes created")

    def ingest_data(self, nodes: List[Dict], relationships: List[Dict]):
        """Ingest nodes và relationships vào Neo4j theo lô"""
        print(f"\n📥 Ingesting data to Neo4j...")
        print(f"  - Nodes to ingest: {len(nodes)}")
        print(f"  - Relationships to ingest: {len(relationships)}")

        # Batch insert nodes
        self._batch_insert_nodes(nodes)
        # Batch insert relationships
        self._batch_insert_relationships(relationships)

        print("✓ Data ingestion complete")

    def _batch_insert_nodes(self, nodes: List[Dict]):
        """Batch insert nodes"""
        query = """
        UNWIND $nodes AS node_data
        MERGE (n {fqn: node_data.fqn})
        ON CREATE SET n += node_data, n:Node
        ON MATCH SET n += node_data, n:Node
        WITH n, node_data
        // Chỉ gọi addLabels nếu node_data.type tồn tại và hợp lệ
        WHERE node_data.type IS NOT NULL AND node_data.type <> ''
        CALL apoc.create.addLabels(n, [node_data.type]) YIELD node
        RETURN count(node)
        """
        total_inserted = 0
        iterable = range(0, len(nodes), Config.BATCH_SIZE)
        if Config.ENABLE_PROGRESS_BAR:
             iterable = tqdm(iterable, desc="  Ingesting Nodes", unit="batch")

        for i in iterable:
            batch = nodes[i:i + Config.BATCH_SIZE]
            result = self.run_cypher_query(query, params={'nodes': batch})
            if result:
                 total_inserted += result[0]['count(node)'] if result[0]['count(node)'] else 0
                 # Cập nhật progress bar nếu dùng tqdm
                 # if isinstance(iterable, tqdm): iterable.set_postfix({"inserted": total_inserted})


        print(f"  ✓ Processed {len(nodes)} nodes (actual inserts depend on MERGE)")


    def _batch_insert_relationships(self, relationships: List[Dict]):
        """Batch insert relationships"""
        if not relationships:
            print("  ✓ No relationships to insert.")
            return

        query = """
        UNWIND $relationships AS rel_data
        MATCH (source {fqn: rel_data.source_fqn})
        MATCH (target {fqn: rel_data.target_fqn})
        // Sử dụng MERGE thay vì CALL apoc.create.relationship để tránh trùng lặp
        // Cần đảm bảo rel_data.type là tên hợp lệ
        CALL apoc.merge.relationship(source, rel_data.type, {}, rel_data.properties, target) YIELD rel
        RETURN count(rel)
        """
        # Lưu ý: CALL apoc.merge.relationship yêu cầu APOC.
        # Nếu không có APOC, bạn cần tạo query riêng cho từng loại relationship:
        # MERGE (source)-[r:<REL_TYPE>]->(target) SET r += rel_data.properties

        total_inserted = 0
        iterable = range(0, len(relationships), Config.BATCH_SIZE)
        if Config.ENABLE_PROGRESS_BAR:
            iterable = tqdm(iterable, desc="  Ingesting Relationships", unit="batch")

        for i in iterable:
            batch = relationships[i:i + Config.BATCH_SIZE]
            result = self.run_cypher_query(query, params={'relationships': batch})
            if result:
                total_inserted += result[0]['count(rel)'] if result[0]['count(rel)'] else 0
                # if isinstance(iterable, tqdm): iterable.set_postfix({"inserted": total_inserted})


        print(f"  ✓ Processed {len(relationships)} relationships (actual inserts depend on MERGE)")


    def verify_graph(self):
        """Verify graph statistics"""
        print("\n📊 Graph Statistics:")
        # Node counts by type (lấy label đầu tiên ngoài :Node)
        node_query = """
        MATCH (n)
        WHERE size(labels(n)) > 1
        RETURN labels(n)[1] as type, count(n) as count
        ORDER BY count DESC
        """
        results = self.run_cypher_query(node_query)
        if results:
            print("\n  Node Types (excluding :Node):")
            for record in results:
                print(f"    {record['type']}: {record['count']}")
        # Relationship counts by type (Giữ nguyên)
        rel_query = "MATCH ()-[r]->() RETURN type(r) as type, count(r) as count ORDER BY count DESC"
        results = self.run_cypher_query(rel_query)
        if results:
            print("\n  Relationship Types:")
            for record in results:
                print(f"    {record['type']}: {record['count']}")
        # Total counts (Giữ nguyên)
        total_query = "MATCH (n) WITH count(n) as node_count MATCH ()-[r]->() RETURN node_count, count(r) as rel_count"
        results = self.run_cypher_query(total_query)
        if results:
            record = results[0]
            print(f"\n  Total Nodes: {record['node_count']}")
            print(f"  Total Relationships: {record['rel_count']}")

In [43]:
# --- Phần 4: Repository Parser ---

class RepositoryParser:
    """Parse toàn bộ repository, lọc bỏ file/folder không cần thiết"""

    def __init__(self, repo_path: str, language_map: Dict):
        self.repo_path = os.path.abspath(repo_path) # Chuẩn hóa đường dẫn
        self.language_map = language_map
        self.all_nodes: List[Dict] = []
        self.all_relationships: List[Dict] = []

    def parse_repository(self) -> Tuple[List[Dict], List[Dict]]:
        """Parse tất cả files trong repo"""
        print(f"\n📂 Parsing repository: {self.repo_path}")

        files_to_parse = self._collect_files_with_os_walk() # Sử dụng os.walk
        print(f"Found {len(files_to_parse)} files to parse after filtering")

        iterable = files_to_parse
        if Config.ENABLE_PROGRESS_BAR:
            iterable = tqdm(files_to_parse, desc="Parsing files", unit="file")

        for file_path in iterable:
            self._parse_file(file_path)

        print(f"\n✓ Parsing complete!")
        print(f"  - Total nodes extracted: {len(self.all_nodes)}")
        print(f"  - Total relationships extracted: {len(self.all_relationships)}")

        return self.all_nodes, self.all_relationships

    def _collect_files_with_os_walk(self) -> List[str]:
        """Thu thập files dùng os.walk và lọc bỏ"""
        collected_files = []
        exclude_dir_patterns = Config.SKIP_FOLDERS
        exclude_file_patterns = Config.SKIP_FILES

        for root, dirs, files in os.walk(self.repo_path, topdown=True):
            # Lọc bỏ thư mục không mong muốn
            # ':' modifies dirs in-place
            dirs[:] = [d for d in dirs if not any(fnmatch.fnmatch(d, pattern) for pattern in exclude_dir_patterns)]

            for filename in files:
                # Lọc bỏ file không mong muốn
                if any(fnmatch.fnmatch(filename, pattern) for pattern in exclude_file_patterns):
                    continue

                # Chỉ lấy file có đuôi hỗ trợ
                file_ext = os.path.splitext(filename)[1]
                if file_ext in self.language_map:
                    collected_files.append(os.path.join(root, filename))

        return sorted(collected_files)


    def _parse_file(self, file_path: str):
        """Parse một file - SỬA LỖI: Đọc file dưới dạng bytes ('rb')"""
        ext = os.path.splitext(file_path)[1]
        lang_tuple = self.language_map.get(ext)
        if not lang_tuple: return

        try:
            # SỬA LỖI: Mở file ở chế độ read-bytes ('rb')
            # và không chỉ định encoding
            with open(file_path, 'rb') as f:
                code_content_bytes = f.read()

            # Giới hạn kích thước file
            if len(code_content_bytes) > 500 * 1024: # Giới hạn 500KB
                 print(f"  🟡 Skipping large file: {file_path} ({len(code_content_bytes)/1024:.1f} KB)")
                 return

            if not code_content_bytes.strip(): # Bỏ qua file trống
                 print(f"  🟡 Skipping empty file: {file_path}")
                 return

            # SỬA LỖI: Truyền 'code_content_bytes' (là bytes)
            # vào constructor của parser
            parser = EnhancedTreeSitterParser(file_path, code_content_bytes, lang_tuple)
            nodes, rels = parser.parse()

            self.all_nodes.extend(nodes)
            self.all_relationships.extend(rels)

        except UnicodeDecodeError:
             # Lỗi này bây giờ không nên xảy ra ở đây,
             # nhưng có thể xảy ra trong get_node_text
             print(f"  ⚠ UnicodeDecodeError parsing {file_path}. Skipping.")
        except Exception as e:
            print(f"  ⚠ Error parsing {file_path}: {e}")

In [52]:
# -*- coding: utf-8 -*-
import os, fnmatch, re
from typing import Dict, List, Tuple, Optional, Set
from tree_sitter import Parser, Query, QueryCursor

# Assumed available in your environment
# from your_config import Config

class EnhancedTreeSitterParser:
    """
    Parser theo paper, đã FIX:
    - Đọc nội dung file dạng bytes để tránh lỗi encoding.
    - Docstring: query nới lỏng + xác thực đúng vị trí, làm sạch theo ngôn ngữ.
    - CALLS: truy vấn đúng node theo grammar từng ngôn ngữ (identifier/attribute/member_expression),
      chuẩn hóa tên hàm được gọi để tránh cắt chuỗi sai (ví dụ 'lude_router(ch').
    """

    def __init__(self, file_path: str, code_content_bytes: bytes, lang_tuple: Tuple):
        self.file_path = file_path
        self.code_content_bytes = code_content_bytes
        self.language_obj = lang_tuple[0]
        self.language_name = lang_tuple[1]

        self.parser = Parser()
        self.parser.language = self.language_obj

        self.nodes: List[Dict] = []
        self.relationships: List[Dict] = []
        self._processed_fqns: Set[str] = set()
        self.queries: Dict = self._init_queries()

    def _init_queries(self) -> Dict:
        """
        Khởi tạo query hợp lệ cho từng ngôn ngữ.
        Lưu ý: KHÔNG dùng 'function: @call_name' trống — phải nêu rõ node type.
        """
        return {
            'python': {
                'class': "(class_definition name: (identifier) @class_name)",
                'function': "(function_definition name: (identifier) @function_name)",
                'method': "(class_definition body: (block (function_definition name: (identifier) @method_name) @method_node) @class_node)",
                'attribute': "(class_definition body: (block (expression_statement (assignment left: (identifier) @attr_name)))) @class_node",

                'import_from': "(import_from_statement module_name: (dotted_name) @module)",
                'import_simple': "(import_statement name: (dotted_name) @module)",

                # CALLS: 2 dạng — gọi trực tiếp và qua thuộc tính
                'call_py_identifier': "(call function: (identifier) @call_name)",
                'call_py_attribute': "(call function: (attribute attribute: (identifier) @call_attr_name))",

                # Docstring: string đầu block
                'docstring': """
                    (function_definition body: (block (expression_statement (string) @docstring))) @func
                    (class_definition body: (block (expression_statement (string) @docstring))) @class
                """
            },

            'java': {
                'class': "(class_declaration name: (identifier) @class_name)",
                'function': "(method_declaration name: (identifier) @function_name)",
                'method': "(class_declaration body: (class_body (method_declaration name: (identifier) @method_name) @method_node) @class_node)",
                'attribute': "(class_declaration body: (class_body (field_declaration declarator: (variable_declarator name: (identifier) @attr_name)))) @class_node",

                'import_from': "(import_declaration (scoped_identifier) @module)",

                # CALLS: Java dùng method_invocation
                'call_java_name': "(method_invocation name: (identifier) @call_name)",
                'call_java_scoped': "(method_invocation name: (scoped_identifier) @call_scoped)",

                # Docstring/JavaDoc: comment ngay trước khai báo
                'docstring': """
                    (method_declaration (comment) @docstring) @func
                    (class_declaration (comment) @docstring) @class
                """
            },

            'javascript': {
                'class': "(class_declaration name: (identifier) @class_name)",
                'function': "(function_declaration name: (identifier) @function_name)",
                'method': "(class_declaration body: (class_body (method_definition name: (property_identifier) @method_name) @method_node) @class_node)",
                'attribute': "(class_declaration body: (class_body (public_field_definition name: (property_identifier) @attr_name))) @class_node",

                'import_from': "(import_statement source: (string) @module)",

                # CALLS: 2 dạng — gọi trực tiếp và qua member_expression
                'call_js_identifier': "(call_expression function: (identifier) @call_name)",
                'call_js_member': "(call_expression function: (member_expression) @call_member)",

                # JSDoc: comment ngay trước function/class/method
                'docstring': """
                    (function_declaration (comment) @docstring) @func
                    (class_declaration (comment) @docstring) @class
                    (method_definition (comment) @docstring) @func
                """
            },

            # Nếu cần TypeScript/TSX, bạn có thể copy block JS sang và đổi language_name tương ứng
        }

    def get_node_text(self, node) -> str:
        """Trích xuất text an toàn từ bytes."""
        snippet_bytes = self.code_content_bytes[node.start_byte:node.end_byte]
        return snippet_bytes.decode("utf-8", errors="ignore")

    def parse(self) -> Tuple[List[Dict], List[Dict]]:
        tree = self.parser.parse(self.code_content_bytes)
        root_node = tree.root_node

        self._create_file_node()

        # Phase 1: Definitions
        self._run_query('class', root_node)
        self._run_query('function', root_node)
        self._run_query('method', root_node)
        self._run_query('attribute', root_node)

        # Phase 2: Docstrings
        self._run_query('docstring', root_node)

        # Phase 3: Dependencies
        self._run_query('import_from', root_node)
        self._run_query('import_simple', root_node)

        # Phase 4: Calls — chạy tất cả query bắt đầu bằng 'call_'
        for qname in list(self.queries.get(self.language_name, {}).keys()):
            if qname.startswith('call_'):
                self._run_query(qname, root_node)

        return self.nodes, self.relationships

    def _create_file_node(self):
        if self.file_path not in self._processed_fqns:
            self.nodes.append({
                'fqn': self.file_path, 'type': 'File',
                'name': os.path.basename(self.file_path), 'path': self.file_path
            })
            self._processed_fqns.add(self.file_path)

    def _run_query(self, query_type: str, root_node):
        query_str = self.queries.get(self.language_name, {}).get(query_type)
        if not query_str:
            return
        try:
            query = Query(self.language_obj, query_str)
            cursor = QueryCursor(query)
            captures = cursor.captures(root_node)
            captures_by_name = captures
        except Exception as e:
            # In ra file + query để debug nhanh
            print(f"Lỗi tree-sitter query (File: {self.file_path}, Query: {query_type}): {e}")
            return

        handlers = {
            'class': self._handle_class,
            'function': self._handle_function,
            'method': self._handle_method,
            'attribute': self._handle_attribute,
            'import_from': self._handle_import,
            'import_simple': self._handle_import,

            # CALLS
            'call_py_identifier': self._handle_call,
            'call_py_attribute': self._handle_call,
            'call_java_name': self._handle_call,
            'call_java_scoped': self._handle_call,
            'call_js_identifier': self._handle_call,
            'call_js_member': self._handle_call,

            # Docstrings
            'docstring': self._handle_docstring,
        }
        handler = handlers.get(query_type)
        if handler:
            handler(captures_by_name)

    # ==========================
    # Definition handlers
    # ==========================
    def _handle_class(self, caps: Dict):
        class_name_nodes = caps.get('class_name', [])
        for node in class_name_nodes:
            class_name = self.get_node_text(node)
            class_fqn = f"{self.file_path}::{class_name}"
            if class_fqn in self._processed_fqns: continue

            self.nodes.append({'fqn': class_fqn, 'type': 'Class', 'name': class_name})
            self.relationships.append({'source_fqn': self.file_path, 'target_fqn': class_fqn, 'type': 'DEFINES_CLASS', 'properties': {}})

            class_def_fqn = f"DEF::{class_fqn}"
            class_node_def = node.parent
            self.nodes.append({'fqn': class_def_fqn, 'type': 'ClassDefinition', 'code': self.get_node_text(class_node_def)[:Config.MAX_CODE_LENGTH]})
            self.relationships.append({'source_fqn': class_fqn, 'target_fqn': class_def_fqn, 'type': 'HAS_DEFINITION', 'properties': {}})

            self._processed_fqns.add(class_fqn)
            self._processed_fqns.add(class_def_fqn)

    def _handle_function(self, caps: Dict):
        function_name_nodes = caps.get('function_name', [])
        for node in function_name_nodes:
            if self._is_inside_class(node): continue
            func_name = self.get_node_text(node)
            func_fqn = f"{self.file_path}::{func_name}"
            if func_fqn in self._processed_fqns: continue

            self.nodes.append({'fqn': func_fqn, 'type': 'Function', 'name': func_name})
            self.relationships.append({'source_fqn': self.file_path, 'target_fqn': func_fqn, 'type': 'DEFINES_FUNCTION', 'properties': {}})

            func_def_fqn = f"DEF::{func_fqn}"
            func_node_def = node.parent
            self.nodes.append({'fqn': func_def_fqn, 'type': 'FunctionDefinition', 'code': self.get_node_text(func_node_def)[:Config.MAX_CODE_LENGTH]})
            self.relationships.append({'source_fqn': func_fqn, 'target_fqn': func_def_fqn, 'type': 'HAS_DEFINITION', 'properties': {}})

            self._processed_fqns.add(func_fqn)
            self._processed_fqns.add(func_def_fqn)

    def _handle_method(self, caps: Dict):
        method_nodes = caps.get('method_name', [])
        class_nodes = caps.get('class_node', [])
        if not method_nodes or not class_nodes: return

        for method_node in method_nodes:
            containing_class = None
            for class_node in class_nodes:
                if class_node.start_byte <= method_node.start_byte <= class_node.end_byte:
                    containing_class = class_node
                    break
            if not containing_class: continue

            class_name_node = containing_class.child_by_field_name('name')
            if not class_name_node: continue

            class_name = self.get_node_text(class_name_node)
            class_fqn = f"{self.file_path}::{class_name}"

            method_name = self.get_node_text(method_node)
            method_fqn = f"{class_fqn}::{method_name}"
            if method_fqn in self._processed_fqns: continue

            self.nodes.append({'fqn': method_fqn, 'type': 'Method', 'name': method_name})
            self.relationships.append({'source_fqn': class_fqn, 'target_fqn': method_fqn, 'type': 'HAS_METHOD', 'properties': {}})

            method_def_fqn = f"DEF::{method_fqn}"
            method_node_def = method_node.parent
            self.nodes.append({'fqn': method_def_fqn, 'type': 'MethodDefinition', 'code': self.get_node_text(method_node_def)[:Config.MAX_CODE_LENGTH]})
            self.relationships.append({'source_fqn': method_fqn, 'target_fqn': method_def_fqn, 'type': 'HAS_DEFINITION', 'properties': {}})

            self._processed_fqns.add(method_fqn)
            self._processed_fqns.add(method_def_fqn)

    def _handle_attribute(self, caps: Dict):
        attr_nodes = caps.get('attr_name', [])
        class_nodes = caps.get('class_node', [])
        if not attr_nodes or not class_nodes: return

        for attr_node in attr_nodes:
            containing_class = None
            for class_node in class_nodes:
                if class_node.start_byte <= attr_node.start_byte <= class_node.end_byte:
                    containing_class = class_node
                    break
            if not containing_class: continue

            class_name_node = containing_class.child_by_field_name('name')
            if not class_name_node: continue

            class_name = self.get_node_text(class_name_node)
            class_fqn = f"{self.file_path}::{class_name}"

            attr_name = self.get_node_text(attr_node)
            attr_fqn = f"{class_fqn}::{attr_name}"
            if attr_fqn in self._processed_fqns: continue

            self.nodes.append({'fqn': attr_fqn, 'type': 'Attribute', 'name': attr_name})
            self.relationships.append({'source_fqn': class_fqn, 'target_fqn': attr_fqn, 'type': 'HAS_ATTRIBUTE', 'properties': {}})

            self._processed_fqns.add(attr_fqn)

    def _handle_import(self, caps: Dict):
        module_nodes = caps.get('module', [])
        for node in module_nodes:
            module_name = self.get_node_text(node).strip('\'"')
            module_fqn = f"MODULE::{module_name}"
            if module_fqn not in self._processed_fqns:
                self.nodes.append({'fqn': module_fqn, 'type': 'Placeholder', 'name': module_name, 'category': 'module'})
                self._processed_fqns.add(module_fqn)
            self.relationships.append({'source_fqn': self.file_path, 'target_fqn': module_fqn, 'type': 'IMPORTS', 'properties': {'module_name': module_name}})

    # ==========================
    # CALLS handler (multi-language)
    # ==========================
    def _handle_call(self, caps: Dict):
        """
        Chuẩn hóa tên hàm được gọi:
        - Python: identifier → lấy trực tiếp; attribute → lấy phần cuối sau dấu '.'
        - JS: identifier → trực tiếp; member_expression → lấy phần cuối (nếu có).
        - Java: method_invocation name → lấy identifier/scoped_identifier cuối.
        """
        # Gom tất cả capture key có thể chứa tên gọi
        nodes = []
        for key in ['call_name', 'call_attr_name', 'call_member', 'call_scoped']:
            nodes.extend(caps.get(key, []))

        for node in nodes:
            raw_text = self.get_node_text(node).strip()

            # Python
            if self.language_name == 'python':
                if node.type == 'identifier':
                    call_name = raw_text
                elif node.type == 'attribute':
                    # attribute node trả về toàn cụm "obj.method" → lấy phần cuối
                    call_name = raw_text.split('.')[-1]
                else:
                    call_name = raw_text

            # JavaScript
            elif self.language_name == 'javascript':
                if node.type == 'identifier':
                    call_name = raw_text
                elif node.type == 'member_expression':
                    # Có thể dạng "obj.method" hoặc phức tạp hơn
                    # Thường text chứa dấu '.' → lấy phần cuối
                    call_name = raw_text.split('.')[-1]
                else:
                    call_name = raw_text

            # Java
            elif self.language_name == 'java':
                # method_invocation name: identifier/scoped_identifier
                # Nếu scoped dạng "pkg.Class.method" → lấy phần cuối
                call_name = raw_text.split('.')[-1]

            else:
                call_name = raw_text

            caller_fqn = self._get_caller_context(node)
            callee_fqn = f"CALL::{call_name}"

            if callee_fqn not in self._processed_fqns:
                self.nodes.append({'fqn': callee_fqn, 'type': 'Placeholder', 'name': call_name, 'category': 'call'})
                self._processed_fqns.add(callee_fqn)

            self.relationships.append({'source_fqn': caller_fqn, 'target_fqn': callee_fqn, 'type': 'CALLS', 'properties': {'call_name': call_name}})

    # ==========================
    # Docstring handlers
    # ==========================
    def _clean_docstring(self, text: str) -> str:
        """Dọn docstring/comment theo ngôn ngữ bằng regex/strip phù hợp."""
        if self.language_name == 'python':
            # Bóc phần giữa triple quotes hoặc single quotes
            m = re.match(r'^[rRuU]*([\'"]{3})(.*?)([\'"]{3})$', text, re.S)
            if m:
                return m.group(2).strip()
            m = re.match(r'^[rRuU]*([\'"])(.*?)([\'"])$', text, re.S)
            if m:
                return m.group(2).strip()
            return text.strip()

        elif self.language_name == 'java':
            t = text.strip()
            if t.startswith("/**"):
                t = t[3:]
            if t.endswith("*/"):
                t = t[:-2]
            lines = [line.lstrip('*').strip() for line in t.splitlines()]
            return "\n".join(lines).strip()

        elif self.language_name in ['javascript', 'typescript', 'typescriptreact']:
            t = text.strip()
            if t.startswith("/**"):
                t = t[3:]
            if t.endswith("*/"):
                t = t[:-2]
            if t.startswith("//"):
                t = t[2:]
            lines = [line.lstrip('*').strip() for line in t.splitlines()]
            return "\n".join(lines).strip()

        return text.strip()

    def _process_docstring_node(self, parent_node, doc_node):
        name_node = parent_node.child_by_field_name('name')
        if not name_node: return

        parent_name = self.get_node_text(name_node)
        parent_fqn = self._build_fqn_from_node(parent_node, parent_name)
        if not parent_fqn: return

        def_fqn = f"DEF::{parent_fqn}"
        doc_fqn = f"DOC::{parent_fqn}"

        if def_fqn not in self._processed_fqns:
            return
        if doc_fqn in self._processed_fqns:
            return

        doc_text = self._clean_docstring(self.get_node_text(doc_node))
        if len(doc_text) < 5:
            return

        self.nodes.append({'fqn': doc_fqn, 'type': 'Documentation', 'text': doc_text})
        self._processed_fqns.add(doc_fqn)
        self.relationships.append({'source_fqn': def_fqn, 'target_fqn': doc_fqn, 'type': 'HAS_DOCUMENTATION', 'properties': {}})

    def _handle_docstring(self, caps: Dict):
        doc_nodes = caps.get('docstring', [])
        func_parents = caps.get('func', [])
        class_parents = caps.get('class', [])
        parent_nodes = func_parents + class_parents
        if not doc_nodes or not parent_nodes: return

        # Gán docstring đúng parent bằng khoảng byte
        docs_by_parent: Dict[int, List] = {}
        for dnode in doc_nodes:
            for p in parent_nodes:
                if p.start_byte <= dnode.start_byte <= p.end_byte:
                    docs_by_parent.setdefault(p.id, []).append(dnode)
                    break

        for parent in parent_nodes:
            candidates = docs_by_parent.get(parent.id, [])
            if not candidates: continue

            valid = None
            if self.language_name == 'python':
                body = parent.child_by_field_name('body')
                if not body or body.type != 'block': continue
                # statement đầu tiên không phải comment
                first_stmt = None
                for ch in body.children:
                    if ch.type not in ('comment', 'line_comment', 'block_comment'):
                        first_stmt = ch
                        break
                if not first_stmt: continue
                if first_stmt.type == 'expression_statement' and first_stmt.child_count > 0:
                    if first_stmt.children[0].type == 'string':
                        valid = first_stmt.children[0]

            elif self.language_name in ['java', 'javascript']:
                body = parent.child_by_field_name('body')
                body_start = body.start_byte if body else parent.end_byte
                for d in candidates:
                    if d.end_byte <= body_start:
                        valid = d
                        break

            if valid:
                self._process_docstring_node(parent, valid)

    # ==========================
    # Helpers
    # ==========================
    def _is_inside_class(self, node) -> bool:
        current = node.parent
        while current:
            if current.type in ('class_definition', 'class_declaration'):
                return True
            current = current.parent
        return False

    def _get_caller_context(self, node) -> str:
        current = node.parent
        while current:
            ntype = current.type
            name_node = current.child_by_field_name('name')
            if name_node and ntype in ('function_definition', 'method_declaration', 'function_declaration', 'method_definition'):
                func_name = self.get_node_text(name_node)
                class_parent = current.parent
                while class_parent:
                    if class_parent.type in ('class_definition', 'class_declaration'):
                        cls_name_node = class_parent.child_by_field_name('name')
                        if cls_name_node:
                            cls = self.get_node_text(cls_name_node)
                            return f"{self.file_path}::{cls}::{func_name}"
                        break
                    class_parent = class_parent.parent
                return f"{self.file_path}::{func_name}"
            if ntype in ('class_definition', 'class_declaration') and name_node:
                cls = self.get_node_text(name_node)
                return f"{self.file_path}::{cls}"
            current = current.parent
        return self.file_path

    def _build_fqn_from_node(self, node, name: str) -> Optional[str]:
        if node.type in ('function_definition', 'method_declaration', 'function_declaration', 'method_definition'):
            current = node.parent
            while current:
                if current.type in ('class_definition', 'class_declaration'):
                    cls_name_node = current.child_by_field_name('name')
                    if cls_name_node:
                        cls = self.get_node_text(cls_name_node)
                        return f"{self.file_path}::{cls}::{name}"
                    break
                current = current.parent
            return f"{self.file_path}::{name}"
        elif node.type in ('class_definition', 'class_declaration'):
            return f"{self.file_path}::{name}"
        return None


In [47]:
# --- Phần 5: Code Enricher (LLM + Embeddings) ---
# [ĐÃ CẬP NHẬT: Embed cả Docstring]


class CodeEnricher:
    """Làm giàu nodes bằng LLM descriptions và embeddings cho Docstrings"""

    def __init__(self, llm_client, encoder: SentenceTransformer):
        self.llm = llm_client
        self.encoder = encoder
        print("✓ CodeEnricher initialized")

    def enrich_nodes(self, nodes: List[Dict], relationships: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """
        Pipeline làm giàu: Embed Docstrings, Generate LLM Descriptions + Embeddings
        """
        print("\n🔧 Enriching nodes...")

        # Tạo map để truy cập nhanh node data bằng FQN
        node_map = {n['fqn']: n for n in nodes}
        # Tạo map để tìm Definition từ Documentation FQN
        doc_fqn_to_def_fqn = {r['target_fqn']: r['source_fqn']
                              for r in relationships if r['type'] == 'HAS_DOCUMENTATION'}

        doc_nodes_to_embed = []
        def_nodes_for_llm = []

        # Phân loại các node cần xử lý
        for node in nodes:
             if node['type'] == 'Documentation' and 'embedding' not in node and node.get('text'):
                  doc_nodes_to_embed.append(node)
             elif node['type'].endswith('Definition'):
                  # Chỉ enrich nếu chưa có description tương ứng
                  desc_fqn = f"DESC::{node['fqn']}"
                  if desc_fqn not in node_map:
                       def_nodes_for_llm.append(node)

        print(f"Found {len(doc_nodes_to_embed)} documentation nodes to embed.")
        print(f"Found {len(def_nodes_for_llm)} definition nodes for LLM description.")

        # 1. Embed Docstrings
        doc_embed_count = self._embed_documentation(doc_nodes_to_embed)

        # 2. Generate LLM Descriptions + Embeddings
        llm_desc_count = self._generate_llm_descriptions(def_nodes_for_llm, node_map, doc_fqn_to_def_fqn, nodes, relationships)

        print(f"✓ Enrichment complete! Embedded {doc_embed_count} docstrings, added {llm_desc_count} LLM descriptions.")
        return nodes, relationships

    def _embed_documentation(self, doc_nodes: List[Dict]) -> int:
        """Tạo embedding cho các node Documentation"""
        count = 0
        iterable = doc_nodes
        if Config.ENABLE_PROGRESS_BAR:
            iterable = tqdm(doc_nodes, desc="  Embedding Docstrings", unit="doc")

        texts_to_embed = [node['text'] for node in iterable if node.get('text')]
        if not texts_to_embed: return 0

        try:
             # Embed hàng loạt
             embeddings = self.encoder.encode(texts_to_embed, batch_size=32, show_progress_bar=False) # Tắt progress bar của encoder

             # Gán embedding lại cho các node
             idx = 0
             for node in iterable:
                  if node.get('text'):
                       if idx < len(embeddings):
                            node['embedding'] = embeddings[idx].tolist()
                            count += 1
                            idx += 1
        except Exception as e:
             print(f"⚠ Error embedding documentation: {e}")

        return count

    def _generate_llm_descriptions(self, def_nodes: List[Dict], node_map: Dict, doc_fqn_to_def_fqn: Dict, nodes_list: List[Dict], relationships_list: List[Dict]) -> int:
        """Tạo mô tả LLM và embedding cho các node Definition"""
        count = 0
        iterable = def_nodes
        if Config.ENABLE_PROGRESS_BAR:
            iterable = tqdm(def_nodes, desc="  Generating LLM Descriptions", unit="def")

        for def_node in iterable:
            def_fqn = def_node['fqn']
            code = def_node.get('code', '')
            if not code: continue

            # Tìm docstring từ node Documentation liên kết
            docstring_text = ""
            for doc_fqn, linked_def_fqn in doc_fqn_to_def_fqn.items():
                if linked_def_fqn == def_fqn:
                    doc_node = node_map.get(doc_fqn)
                    if doc_node and 'text' in doc_node:
                        docstring_text = doc_node['text']
                        break

            original_fqn = def_fqn.replace("DEF::", "")
            original_node_type = node_map.get(original_fqn, {}).get('type', "Code Snippet")

            # Generate description
            description = self._generate_description(code, docstring_text, original_node_type)
            if not description: continue

            # Generate embedding
            embedding = self._generate_embedding(description)

            # Tạo Node GeneratedDescription
            desc_fqn = f"DESC::{def_fqn}"
            if not any(n['fqn'] == desc_fqn for n in nodes_list):
                 desc_node = {'fqn': desc_fqn, 'type': 'GeneratedDescription',
                              'description': description, 'embedding': embedding}
                 nodes_list.append(desc_node)

                 # Tạo Relationship HAS_DESCRIPTION
                 if not any(r['source_fqn'] == def_fqn and r['type'] == 'HAS_DESCRIPTION' for r in relationships_list):
                     relationships_list.append({'source_fqn': def_fqn, 'target_fqn': desc_fqn,
                                                'type': 'HAS_DESCRIPTION', 'properties': {}})
                     count += 1
        return count


    def _generate_description(self, code: str, docstring: str, node_type: str) -> str:
        """Generate description với LLM"""
        prompt = f"""Analyze the following {node_type}. Provide a concise (1-2 sentences) description explaining its primary function or purpose. Focus on *what* it does.

DOCSTRING:
{docstring if docstring else "N/A"}

CODE:
{code[:Config.MAX_CODE_LENGTH]}

CONCISE DESCRIPTION:"""
        response = self.llm.invoke([HumanMessage(content=prompt)]) # Giảm max_tokens
        description = response.content
        return description.strip().replace("\n", " ") # Xóa xuống dòng

    def _generate_embedding(self, text: str) -> List[float]:
        """Generate embedding với Sentence Transformer"""
        try:
            embedding = self.encoder.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
             print(f"⚠ Error generating embedding for text '{text[:50]}...': {e}")
             return []

In [53]:
# --- Phần 6: Chạy Pipeline Giai đoạn 1 (Đã Cải Thiện) ---

def main():
    """Hàm chính để chạy toàn bộ pipeline Giai đoạn 1"""
    start_time = time.time() # Bắt đầu bấm giờ

    try:
        # 1. Khởi tạo
        print("--- Initializing ---")
        secrets = UserSecretsClient()
        neo4j_uri = secrets.get_secret(Config.NEO4J_URI_SECRET)
        neo4j_user = secrets.get_secret(Config.NEO4J_USER_SECRET)
        neo4j_password = secrets.get_secret(Config.NEO4J_PASSWORD_SECRET)
        openai_api_key = secrets.get_secret(Config.OPENAI_API_KEY_SECRET)
        graph_constructor = Neo4jGraphConstructor(neo4j_uri, neo4j_user, neo4j_password)
        llm = ChatOpenAI(openai_api_key=openai_api_key, model=Config.OPENAI_MODEL)
        encoder = SentenceTransformer(Config.ENCODER_MODEL)
        enricher = CodeEnricher(llm, encoder)
        repo_parser = RepositoryParser(Config.REPO_ROOT_DIR, LANGUAGE_MAP) # Truyền LANGUAGE_MAP

        # 2. Chuẩn bị DB
        graph_constructor.clear_database()
        graph_constructor.create_constraints_and_indexes()

        # 3. Parse Repository
        all_nodes, all_relationships = repo_parser.parse_repository()

        # 4. Enrich Nodes
        all_nodes, all_relationships = enricher.enrich_nodes(all_nodes, all_relationships)

        # 5. Ingest Data
        graph_constructor.ingest_data(all_nodes, all_relationships)

        # 6. Verify Graph (Optional but recommended)
        graph_constructor.verify_graph()

        # 7. Close Connection
        graph_constructor.close()

        end_time = time.time() # Kết thúc bấm giờ
        print(f"\n--- PHASE 1 COMPLETE (Improved) ---")
        print(f"Total time: {end_time - start_time:.2f} seconds")
        print("Knowledge Graph is built, enriched, and indexed.")

    except Exception as e:
        print(f"\n--- 💥 PIPELINE FAILED ---")
        print(f"Error: {e}")
        # Đảm bảo đóng kết nối nếu có lỗi
        if 'graph_constructor' in locals() and graph_constructor.driver:
            try:
                graph_constructor.close()
            except: pass # Ignore errors during closing on failure

# --- Chạy hàm main ---
import time


if __name__ == "__main__":
    main()

--- Initializing ---
✓ Connected to Neo4j AuraDB
✓ CodeEnricher initialized

🗑️  Clearing database...
✓ Database cleared

🔧 Creating constraints and indexes...
✓ Constraints and indexes created

📂 Parsing repository: /kaggle/input/chatbot/chatbotai_v1-main
Found 25 files to parse after filtering


Parsing files:   0%|          | 0/25 [00:00<?, ?file/s]

  🟡 Skipping empty file: /kaggle/input/chatbot/chatbotai_v1-main/backend/__init__.py
  🟡 Skipping empty file: /kaggle/input/chatbot/chatbotai_v1-main/backend/database/__init__.py
  🟡 Skipping empty file: /kaggle/input/chatbot/chatbotai_v1-main/backend/database/schemas/__init__.py
  🟡 Skipping empty file: /kaggle/input/chatbot/chatbotai_v1-main/backend/routers/___init__.py
  🟡 Skipping empty file: /kaggle/input/chatbot/chatbotai_v1-main/backend/utils/__init__.py

✓ Parsing complete!
  - Total nodes extracted: 1021
  - Total relationships extracted: 2069

🔧 Enriching nodes...
Found 50 documentation nodes to embed.
Found 139 definition nodes for LLM description.


  Embedding Docstrings:   0%|          | 0/50 [00:00<?, ?doc/s]

  Generating LLM Descriptions:   0%|          | 0/139 [00:00<?, ?def/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

✓ Enrichment complete! Embedded 50 docstrings, added 139 LLM descriptions.

📥 Ingesting data to Neo4j...
  - Nodes to ingest: 1160
  - Relationships to ingest: 2208


  Ingesting Nodes:   0%|          | 0/3 [00:00<?, ?batch/s]

  ✓ Processed 1160 nodes (actual inserts depend on MERGE)


  Ingesting Relationships:   0%|          | 0/5 [00:00<?, ?batch/s]

  ✓ Processed 2208 relationships (actual inserts depend on MERGE)
✓ Data ingestion complete

📊 Graph Statistics:

  Node Types (excluding :Node):
    Placeholder: 227
    Node: 159
    GeneratedDescription: 139
    Attribute: 127
    FunctionDefinition: 121
    Documentation: 50
    ClassDefinition: 18

  Relationship Types:
    CALLS: 793
    IMPORTS: 139
    HAS_DEFINITION: 139
    HAS_DESCRIPTION: 139
    HAS_ATTRIBUTE: 127
    DEFINES_FUNCTION: 121
    HAS_DOCUMENTATION: 50
    DEFINES_CLASS: 18

  Total Nodes: 841
  Total Relationships: 1526
✓ Disconnected from Neo4j

--- PHASE 1 COMPLETE (Improved) ---
Total time: 241.90 seconds
Knowledge Graph is built, enriched, and indexed.


# Phase 2: Hybrid Code Retrieval

In [None]:
import json
class CodeRetriever:
    def __init__(self, graph_constructor, llm, encoder):
        """
        Initializes the retriever with necessary components.
        Args:
            graph_constructor: Instance of Neo4jGraphConstructor (or similar) to run Cypher.
            llm: Language model instance (e.g., ChatOpenAI) for entity extraction.
            encoder: SentenceTransformer instance for creating embeddings.
        """
        self.graph = graph_constructor
        self.llm = llm
        self.encoder = encoder
        print("Code Retriever initialized.")

    # === Step 2.1: Query Processing ===
    def _get_query_vector(self, query: str) -> list[float]:
        """Encodes the user query into a vector using the sentence transformer."""
        print(f"\n1a. Encoding query: '{query}'")
        try:
            return self.encoder.encode(query, convert_to_numpy=True).tolist()
        except Exception as e:
            print(f"  Error encoding query: {e}")
            return []

    def _identify_entities(self, query: str) -> list[str]:
        """Uses LLM to identify potential code entities (names) mentioned in the query."""
        print(f"\n1b. Identifying entities in query: '{query}'")
        # Prompt asking for specific code artifacts like function/class names
        prompt = f"""
        Analyze the following user query about a software project. Identify any specific function names,
        class names, method names, variable names, or module names mentioned.
        Extract only these names. Do not include generic terms like 'function', 'class', 'file', etc.
        Return the result ONLY as a JSON list of strings. If no specific names are found, return an empty list [].

        User Query: "{query}"

        JSON List of Names:
        """
        try:
            response = self.llm.invoke(prompt)
            # Attempt to parse the JSON list from the LLM response
            content = response.content.strip()
            # Handle potential markdown code block ```json ... ```
            if content.startswith("```json"):
                content = content[7:]
            if content.endswith("```"):
                content = content[:-3]
            content = content.strip()

            entities = json.loads(content)
            # Validate the result is a list of strings
            if isinstance(entities, list) and all(isinstance(e, str) for e in entities):
                 print(f"  Identified entities: {entities}")
                 return entities
            else:
                print(f"  Warning: LLM did not return a valid JSON list of strings. Response: {content}")
                return []
        except json.JSONDecodeError:
             print(f"  Warning: LLM response was not valid JSON. Response: {response.content}")
             return []
        except Exception as e:
            print(f"  Error identifying entities: {e}")
            return []

    # === Step 2.2: Initial Retrieval (Hybrid Search) ===
    def _initial_retrieval(self, query_vector: list[float], entities: list[str], top_k: int = 5) -> set[str]:
        """
        Performs hybrid search (Vector + Full-text) against the Neo4j graph.
        Returns a set of unique FQNs (Fully Qualified Names) of the initially retrieved code nodes.
        """
        initial_fqns = set()

        # --- Semantic Search (Vector Index) ---
        if query_vector:
            print(f"\n2a. Performing Semantic Search (top {top_k})...")
            # Uses the 'descriptions' vector index created in Phase 1
            cypher_vector_search = """
                CALL db.index.vector.queryNodes('descriptions', $top_k, $query_vector) YIELD node AS descNode, score
                // Navigate from description -> definition -> actual code node
                MATCH (codeNode)-[:HAS_DEFINITION]->(defNode)-[:HAS_DESCRIPTION]->(descNode)
                WHERE codeNode:Function OR codeNode:Method OR codeNode:Class // Focus on core code elements
                RETURN codeNode.fqn AS fqn, score
                ORDER BY score DESC
            """
            try:
                results = self.graph.run_cypher_query(cypher_vector_search,
                                                     params={'top_k': top_k, 'query_vector': query_vector})
                if results:
                    found_fqns = {record['fqn'] for record in results}
                    initial_fqns.update(found_fqns)
                    print(f"  Semantic Search added {len(found_fqns)} FQNs.")
            except Exception as e:
                print(f"  Error during Semantic Search: {e}")

        # --- Full-Text Search (Keyword Index) ---
        if entities:
            print(f"\n2b. Performing Full-Text Search for entities: {entities} (top {top_k})...")
            # Build a query string suitable for Neo4j full-text (simple OR logic)
            # Escape special characters if necessary, or use parameterized approach if supported
            ft_query_string = " OR ".join(f'"{entity}"' for entity in entities) # Wrap entities in quotes if needed

            # Uses the 'names' full-text index created in Phase 1
            cypher_ft_search = """
                CALL db.index.fulltext.queryNodes('names', $query_string, {limit: $top_k}) YIELD node AS codeNode, score
                WHERE codeNode:Function OR codeNode:Method OR codeNode:Class OR codeNode:Attribute // Match indexed types
                RETURN codeNode.fqn AS fqn, score
                ORDER BY score DESC
            """
            try:
                results = self.graph.run_cypher_query(cypher_ft_search,
                                                     params={'query_string': ft_query_string, 'top_k': top_k})
                if results:
                    found_fqns = {record['fqn'] for record in results}
                    initial_fqns.update(found_fqns)
                    print(f"  Full-Text Search added {len(found_fqns)} unique FQNs.")
            except Exception as e:
                print(f"  Error during Full-Text Search: {e}")

        print(f"\n=> Initial retrieval combined set size: {len(initial_fqns)}")
        return initial_fqns

    # === Step 2.3: Graph Traversal ===
    def _retrieve_n_hop_subgraph(self, start_fqns: set[str], num_hops: int = 2) -> tuple[list[dict], list[dict]]:
        """
        Retrieves the n-hop subgraph around the starting FQNs using APOC.
        Returns lists of node dictionaries and relationship dictionaries.
        """
        nodes_list = []
        relationships_list = []
        if not start_fqns:
            print("\n3. Graph Traversal skipped: No start nodes.")
            return nodes_list, relationships_list

        print(f"\n3. Retrieving {num_hops}-hop subgraph from {len(start_fqns)} start nodes...")

        # --- SỬA LỖI CÚ PHÁP CYPHER ---
        # Chạy APOC cho mỗi start node, sau đó UNWIND và COLLECT DISTINCT
        cypher_n_hop = """
            MATCH (startNode) WHERE startNode.fqn IN $start_fqns
            CALL apoc.path.subgraphAll(startNode, {
                maxLevel: $num_hops,
                relationshipFilter: "DEFINES_CLASS>|DEFINES_FUNCTION>|<HAS_METHOD|<HAS_DEFINITION>|<HAS_ATTRIBUTE|<IMPORTS>|<CALLS>|<HAS_DESCRIPTION"
            }) YIELD nodes, relationships
            // Unwind lists returned for each start node
            UNWIND nodes AS n
            UNWIND relationships AS r
            // Collect distinct nodes and relationships across all subgraphs
            RETURN collect(DISTINCT n) AS distinctNodes,
                   collect(DISTINCT r) AS distinctRelationships
        """
        # ---------------------------------

        try:
            results = self.graph.run_cypher_query(cypher_n_hop,
                                                 params={'start_fqns': list(start_fqns), 'num_hops': num_hops})

            # Kiểm tra kết quả cẩn thận hơn
            if results and results[0] and results[0].get('distinctNodes') is not None and results[0].get('distinctRelationships') is not None:
                 all_nodes_raw = results[0]['distinctNodes']
                 all_rels_raw = results[0]['distinctRelationships']

                 # Xử lý kết quả (giữ nguyên logic chuyển đổi)
                 nodes_map = {n.id: {"id": n.id, "labels": list(n.labels), **dict(n.items())} for n in all_nodes_raw}
                 nodes_list = list(nodes_map.values())

                 for rel in all_rels_raw:
                     start_node_data = nodes_map.get(rel.start_node.id)
                     end_node_data = nodes_map.get(rel.end_node.id)
                     if start_node_data and end_node_data:
                         relationships_list.append({
                             "id": rel.id,
                             "type": rel.type,
                             "startNodeId": rel.start_node.id,
                             "endNodeId": rel.end_node.id,
                             "startNodeFqn": start_node_data.get('fqn'),
                             "endNodeFqn": end_node_data.get('fqn'),
                             "properties": dict(rel.items())
                         })

                 print(f"  Retrieved subgraph with {len(nodes_list)} nodes and {len(relationships_list)} relationships.")
            else:
                 # In rõ hơn nếu không có kết quả hoặc kết quả rỗng
                 print(f"  Graph traversal query executed but returned no distinct nodes or relationships. Result: {results}")


        except Exception as e:
            print(f"  Error during N-hop retrieval: {e}")

        return nodes_list, relationships_list


    # === Step 2.4: Filtering Sub-graph ===
    def _filter_subgraph_by_similarity(self, nodes: list[dict], relationships: list[dict], query_vector: list[float], top_k_filter: int = 20) -> tuple[list[dict], list[dict]]:
        """
        Filters the subgraph based on semantic similarity to the query.
        Prioritizes GeneratedDescription nodes and keeps directly connected structural nodes.
        """
        filtered_nodes = []
        filtered_relationships = []
        if not nodes or not query_vector or top_k_filter <= 0:
            print("\n4. Filtering skipped: No nodes, query vector, or filter size.")
            return nodes, relationships # Return original if no filtering needed

        print(f"\n4. Filtering subgraph by similarity (aiming for ~{top_k_filter} most relevant nodes)...")

        try:
            # Calculate similarity for nodes with embeddings
            node_similarities = []
            for node in nodes:
                if node.get('embedding') and node.get('labels') and 'GeneratedDescription' in node.get('labels'):
                    node_vector = node['embedding']
                    # Cosine Similarity using numpy
                    vec1 = np.array(query_vector)
                    vec2 = np.array(node_vector)
                    # Handle potential zero vectors
                    norm1 = np.linalg.norm(vec1)
                    norm2 = np.linalg.norm(vec2)
                    if norm1 == 0 or norm2 == 0:
                        similarity = 0.0
                    else:
                        similarity = np.dot(vec1, vec2) / (norm1 * norm2)
                    node_similarities.append((node['id'], similarity, node.get('fqn'))) # Store node ID and FQN

            # Sort by similarity (highest first)
            node_similarities.sort(key=lambda item: item[1], reverse=True)

            # Get IDs of top-k description nodes
            top_desc_node_ids = {node_id for node_id, score, fqn in node_similarities[:top_k_filter]}
            print(f"  Identified {len(top_desc_node_ids)} top description nodes.")

            # Find structural nodes (Code, Definition, File) connected to these top descriptions
            connected_structural_node_ids = set()
            node_id_to_data = {node['id']: node for node in nodes} # Map ID to full node data

            for rel in relationships:
                is_connected_to_top_desc = False
                other_node_id = None
                connected_node_type = None

                if rel['endNodeId'] in top_desc_node_ids and rel['type'] == 'HAS_DESCRIPTION':
                    is_connected_to_top_desc = True
                    other_node_id = rel['startNodeId'] # This should be a Definition node
                    connected_node_data = node_id_to_data.get(other_node_id)
                    if connected_node_data:
                         connected_node_type = 'Definition' # Mark type for clarity

                # Add logic here if descriptions can link from code nodes directly (based on schema)

                if is_connected_to_top_desc and other_node_id:
                     connected_structural_node_ids.add(other_node_id)
                     # Now, find the actual code node linked to this definition node
                     for r_inner in relationships:
                         if r_inner['endNodeId'] == other_node_id and r_inner['type'] == 'HAS_DEFINITION':
                              code_node_id = r_inner['startNodeId']
                              connected_structural_node_ids.add(code_node_id)
                              # Also add the file containing the code node
                              code_node_data = node_id_to_data.get(code_node_id)
                              if code_node_data:
                                   for r_file in relationships:
                                       if r_file['endNodeId'] == code_node_id and r_file['type'] in ('DEFINES_CLASS', 'DEFINES_FUNCTION'):
                                            connected_structural_node_ids.add(r_file['startNodeId'])
                                       elif r_file['startNodeId'] == code_node_id and r_file['type'] == 'HAS_METHOD': # Method links from class
                                            class_node_id = r_file['startNodeId']
                                            for r_file_class in relationships:
                                                 if r_file_class['endNodeId'] == class_node_id and r_file_class['type'] == 'DEFINES_CLASS':
                                                      connected_structural_node_ids.add(r_file_class['startNodeId'])


            print(f"  Found {len(connected_structural_node_ids)} structural nodes connected to top descriptions.")

            # Combine top description nodes and connected structural nodes
            final_node_ids = top_desc_node_ids.union(connected_structural_node_ids)

            # Filter nodes and relationships based on the final set of IDs
            filtered_nodes = [node for node in nodes if node['id'] in final_node_ids]
            filtered_relationships = [rel for rel in relationships if rel['startNodeId'] in final_node_ids and rel['endNodeId'] in final_node_ids]

            print(f"=> Filtered subgraph to {len(filtered_nodes)} nodes and {len(filtered_relationships)} relationships.")

        except ImportError:
            print("  Numpy not installed. Skipping similarity filtering.")
            return nodes, relationships # Return unfiltered if numpy unavailable
        except Exception as e:
            print(f"  Error during filtering: {e}")
            return nodes, relationships # Return unfiltered on error

        return filtered_nodes, filtered_relationships

    # === Main Retrieval Function ===
    def retrieve(self, query: str, initial_top_k: int = 5, num_hops: int = 2, filter_top_k: int = 20) -> tuple[list[dict], list[dict]]:
        """
        Executes the full retrieval pipeline: Process Query -> Initial Retrieval -> Graph Traversal -> Filter.
        Returns the filtered subgraph (list of node dicts, list of relationship dicts).
        """
        print(f"\n--- Starting Retrieval Pipeline for Query: '{query}' ---")
        # 1. Process Query
        query_vector = self._get_query_vector(query)
        entities = self._identify_entities(query)

        # 2. Initial Retrieval (Hybrid)
        initial_fqns = self._initial_retrieval(query_vector, entities, initial_top_k)
        if not initial_fqns:
            print("Pipeline stopped: Initial retrieval failed to find any relevant nodes.")
            return [], []

        # 3. Graph Traversal (N-Hop Expansion)
        subgraph_nodes, subgraph_rels = self._retrieve_n_hop_subgraph(initial_fqns, num_hops)
        if not subgraph_nodes:
            print("Pipeline stopped: Graph traversal failed to expand nodes.")
            return [], []

        # 4. Filter Subgraph (Similarity Reranking)
        filtered_nodes, filtered_rels = self._filter_subgraph_by_similarity(subgraph_nodes, subgraph_rels, query_vector, filter_top_k)

        print(f"--- Retrieval Pipeline Complete ---")
        return filtered_nodes, filtered_rels

In [None]:
# Assuming graph_constructor, llm, encoder are initialized correctly

# 1. Create the retriever instance
retriever = CodeRetriever(graph_constructor, llm, encoder)

# 2. Define your query
user_query = "How do I implement JWT authentication for the login route?"

# 3. Run the retrieval process
# Adjust parameters as needed
retrieved_nodes, retrieved_relationships = retriever.retrieve(
    user_query,
    initial_top_k=5,  # How many initial results from hybrid search
    num_hops=2,       # How far to explore the graph (context depth)
    filter_top_k=20   # How many description nodes to keep after filtering
)

# 4. Prepare context for Phase 3 (Code Generation)
def format_subgraph_for_llm(nodes: list[dict], rels: list[dict]) -> str:
    """Formats the retrieved subgraph into a string suitable for LLM context."""
    context_str = "### Retrieved Code Context:\n\n"
    node_map = {node['id']: node for node in nodes} # Map ID to node data

    # Prioritize Code elements and their definitions/descriptions
    processed_node_ids = set()

    for node in nodes:
        if node['id'] in processed_node_ids: continue

        node_labels = node.get('labels', [])
        node_type = node.get('type', node_labels[0] if node_labels else 'Unknown') # Use 'type' if available
        node_fqn = node.get('fqn', f"NodeID_{node['id']}")
        node_name = node.get('name', node_fqn)

        # Display Code Nodes (Class, Function, Method)
        if node_type in ['Class', 'Function', 'Method']:
            processed_node_ids.add(node['id'])
            context_str += f"--- {node_type}: {node_name} ---\n"
            # Find its definition and description
            def_node = None
            desc_node = None
            for rel in rels:
                 if rel['startNodeId'] == node['id'] and rel['type'] == 'HAS_DEFINITION':
                      def_node = node_map.get(rel['endNodeId'])
                      if def_node:
                           processed_node_ids.add(def_node['id'])
                           for rel2 in rels:
                               if rel2['startNodeId'] == def_node['id'] and rel2['type'] == 'HAS_DESCRIPTION':
                                    desc_node = node_map.get(rel2['endNodeId'])
                                    if desc_node: processed_node_ids.add(desc_node['id'])
                                    break
                      break # Assume only one definition

            if def_node and def_node.get('code'):
                context_str += f"```\n{def_node.get('code')}\n```\n"
            if def_node and def_node.get('docstring'):
                context_str += f"Docstring: {def_node.get('docstring')}\n"
            if desc_node and desc_node.get('description'):
                context_str += f"LLM Description: {desc_node.get('description')}\n"
            context_str += "\n"

    # Optionally add relationship information
    context_str += "### Key Relationships:\n"
    rels_to_show = ['CALLS', 'IMPORTS', 'HAS_METHOD', 'HAS_ATTRIBUTE'] # Focus on dependencies
    for rel in rels:
        if rel['type'] in rels_to_show:
             start_node = node_map.get(rel['startNodeId'])
             end_node = node_map.get(rel['endNodeId'])
             if start_node and end_node:
                  start_name = start_node.get('name', start_node.get('fqn', f"ID_{start_node['id']}"))
                  end_name = end_node.get('name', end_node.get('fqn', f"ID_{end_node['id']}"))
                  context_str += f"- {start_name} --[{rel['type']}]--> {end_name}\n"

    return context_str

# Generate the context string
context_string_for_llm = format_subgraph_for_llm(retrieved_nodes, retrieved_relationships)

print("\n=== Formatted Context for LLM (Phase 3) ===")
print(context_string_for_llm[:2000] + "\n..." if len(context_string_for_llm) > 2000 else context_string_for_llm)

# Now, you would combine this context_string_for_llm with the user_query
# and send it to your code generation LLM (Phase 3).

In [None]:
query = "How do I implement JWT authentication for the login route?"
prompt = f"""
        Analyze the following user query about a software project. Identify any specific function names,
        class names, method names, variable names, or module names mentioned.
        Extract only these names. Do not include generic terms like 'function', 'class', 'file', etc.
        Return the result ONLY as a JSON list of strings. If no specific names are found, return an empty list [].

        User Query: "{query}"

        JSON List of Names:
        """
response = llm.invoke(prompt)
print(response)