# 📊 Fund Review Automation

This notebook performs the full pipeline:
- Connect to Neo4j: Interface with a graph database to store and query legal content as a knowledge graph.

- Construct a Legal Knowledge Graph: Transform unstructured legal documents (e.g., constitutions, prospectuses, shareholder agreements) into a structured graph using custom document splitters, entity/relation extraction, and semantic linking.

- Run Louvain Community Detection: Identify hierarchical legal taxonomies by discovering clusters of semantically or referentially related legal content.

- Label and Analyze Communities: Use generative AI to assign interpretable labels to communities of legal texts, forming a multidimensional taxonomy for fund review reports.

- GraphRAG : Use a first version of GraphRAG to query the knowledge graph and retrieve legal texts.

## Constitution

In [7]:
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
import re
import asyncio

@dataclass
class HierarchyLevel:
    """Defines a hierarchy level with its regular expression extraction and metadata key."""
    name: str
    pattern: str
    metadata_key: str
    title_key: Optional[str] = None
    strip_header_pattern: Optional[str] = None

class ConstitutionSplitter(TextSplitter):
    """Splitter for constitution documents."""

    def __init__(self, llm, graph_id: str, overlap_percentage: float = 0.2) -> None:
        self.llm = llm
        self.graph_id = graph_id
        #self.overlap_percentage = max(0.0, min(1.0, overlap_percentage))
        
        # Define the hierarchy levels in order : section -> article
        self.hierarchy_levels = [
            HierarchyLevel(
                name="section",
                pattern=r'^Section\s+(\d+)\s+[–\-]\s+([^\n]+)',
                metadata_key="section_num",
                title_key="section_title",
                strip_header_pattern=r'^Section\s+{num}\s+[–\-]\s+{title}\n?'
            ),
            HierarchyLevel(
                name="article", 
                pattern=r'^(?:[ \t]*)(\d+)\s+[–\-]\s+([^\n]+)',
                metadata_key="article_num",
                title_key="article_title",
                strip_header_pattern=r'^{num}\s+[–\-]\s+{title}\n?'
            )
        ]

        model_name = self.llm.model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def split_text(self, text: str) -> TextChunks:
        """Synchronous wrapper for the main async processing method."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            pass
        return asyncio.run(self.run(text.strip()))

    async def run(self, text: str) -> TextChunks:
        """Main processing method using recursive hierarchy processing."""
        
        cleaned_text = re.sub(r'\\n?', '', text).strip()
        
        #header_info = self._extract_header(cleaned_text)
        header_info = {
            "document_id": self.graph_id,
            "source": "legal_document",
            "document_type": "legal_constitution",
            "document_title": f"Legal Document {self.graph_id}"
        }
        base_metadata = {**header_info, "graph_id": self.graph_id}
        
        chunks = []
        chunk_index = [0]  # Use list to allow modification in nested calls
        
        
        # Process hierarchical content recursively
        self._process_hierarchy_level(
            content=cleaned_text,
            level_index=0,
            metadata=base_metadata,
            chunks=chunks,
            chunk_index=chunk_index
        )
        
        return TextChunks(chunks=chunks)


    def _process_hierarchy_level(self, content: str, level_index: int, metadata: dict, 
                                chunks: List, chunk_index: List[int]):
        """
        Process recursively a hierarchy level. 
        Args:
            content: Text content to process at this level
            level_index: Current hierarchy level (0=section, 1=article)
            metadata: Accumulated metadata from parent levels
            chunks: List of chunks
            chunk_index: Current chunk index
        """
        
        # Base case: no more hierarchy levels to process
        if level_index >= len(self.hierarchy_levels):
            if content.strip():
                chunks.append(self._create_chunk(chunk_index[0], content, metadata))
                chunk_index[0] += 1
            return
        
        current_level = self.hierarchy_levels[level_index]
        items = self._extract_items_at_level(content, current_level)
        
        # If no items found at this level, try the next level or create chunk
        if not items:
            self._process_hierarchy_level(content, level_index + 1, metadata, chunks, chunk_index)
            return
        
        # Process each item at this level
        for item_id, item_title, item_content in items:
            # Build metadata for this level, {**metadata} to do a shallow copy
            current_metadata = {**metadata}
            current_metadata[current_level.metadata_key] = item_id
            if current_level.title_key and item_title:
                current_metadata[current_level.title_key] = item_title
            current_metadata["chunk_type"] = current_level.name
            
            # Strip header if needed
            processed_content = self._strip_level_header(
                item_content, current_level, item_id, item_title
            )
            
            # Recursively process the next level
            self._process_hierarchy_level(
                content=processed_content,
                level_index=level_index + 1,
                metadata=current_metadata,
                chunks=chunks,
                chunk_index=chunk_index
            )

    def _extract_items_at_level(self, content: str, level: HierarchyLevel) -> List[Tuple[str, Optional[str], str]]:
        """
        Extract items at a specific hierarchy level. An item can be a section or an article.
        Args:
            content: Text content to search for items
            level: HierarchyLevel object defining the pattern and metadata keys
        
        Returns:
            List of (item_id, item_title, item_content) tuples
        """
        
        matches = list(re.finditer(level.pattern, content, re.MULTILINE))
        if not matches:
            return []
        
        items = []
        for i, match in enumerate(matches):
            # Extract ID and title based on pattern groups
            if level.title_key:
                item_id, item_title = match.group(1), match.group(2).strip() # group(1) is ID, group(2) is title
            else:  # Only has ID (paragraphs)
                if match.lastindex >= 1:
                    item_id = match.group(1)
                else:
                    match.group('id')
                item_title = None
            
            # Extract content block
            start_pos = match.start()
            if i < len(matches) - 1:
                # Get content until the next match
                end_pos = matches[i + 1].start()
            else:
                end_pos = len(content)
                
            item_content = content[start_pos:end_pos].strip()
            
            if item_content:
                items.append((item_id, item_title, item_content))
        
        return items

    def _strip_level_header(self, content: str, level: HierarchyLevel, 
                           item_id: str, item_title: Optional[str]) -> str:
        """Strip the header from content at a specific level."""
        if not level.strip_header_pattern:
            # For enumerated items, strip the leading pattern
            pattern = level.pattern
            return re.sub(pattern, '', content, count=1, flags=re.MULTILINE).strip()
        
        # For sections/articles with titles
        if item_title:
            pattern = level.strip_header_pattern.format(
                num=re.escape(item_id),
                title=re.escape(item_title)
            )
        else:
            pattern = level.strip_header_pattern.format(num=re.escape(item_id))
        
        return re.sub(pattern, '', content, count=1, flags=re.MULTILINE).strip()

    def _extract_header(self, text: str) -> Dict[str, str]:
        """Extract document header information with required fields for lexical graph."""
        header = {
            "document_id": self.graph_id,
            "source": "legal_document",
            "document_type": "legal_constitution"
        }
        
        return header

    def _create_chunk(self, index: int, text: str, metadata: dict) -> TextChunk:
        """Creates a TextChunk with cleaned text and hierarchical metadata."""
        
        # Clean the text
        clean_text = re.sub(r'\s+', ' ', text).strip()
        
        # Determine level
        chunk_type = metadata.get("chunk_type", "content")
        level_mapping = {
            "document_root": "document",
            "section": "section",
            "article": "article", 
        }
        level = level_mapping.get(chunk_type, chunk_type)
        
        # Build hierarchical path
        path_parts = self._build_hierarchical_path(metadata, level)
        
        final_metadata = {
            **metadata,
            "level": level,
            "hierarchical_path": " → ".join(path_parts) if path_parts else "Document Content",
        }
        
        return TextChunk(index=index, text=clean_text, metadata=final_metadata)

    def _build_hierarchical_path(self, metadata: dict, level: str) -> List[str]:
        """Build the hierarchical path for a chunk."""
        path_parts = []
        
        if level == "document" and "document_title" in metadata:
            path_parts.append(metadata["document_title"])
        else:
            if "section_title" in metadata:
                path_parts.append(f"Section {metadata.get('section_num','?')} – {metadata.get('section_title','Untitled Section')}")
            if "article_title" in metadata:
                path_parts.append(f"Article {metadata.get('article_num','?')} – {metadata.get('article_title','Untitled Article')}")
        
        return path_parts

# Prospectus

In [2]:
import re
import asyncio
from typing import List, Dict
import tiktoken
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.types import TextChunk, TextChunks

class ProspectusSplitter(TextSplitter):
    """Splitter for prospectus documents."""

    def __init__(self, llm, graph_id: str, overlap_percentage: float = 0.2) -> None:
        self.llm = llm
        self.graph_id = graph_id
        self.overlap_percentage = overlap_percentage
        if not 0.0 <= overlap_percentage <= 1.0:
            raise ValueError("overlap_percentage must be between 0.0 and 1.0")

        # Get the encoding for the model being used
        model_name = self.llm.model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Fallback to cl100k_base encoding (used by many OpenAI models)
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

        # Document structure regular expressions
        self.patterns = {
            # Major sections with all-caps titles (e.g., "UNDERSTANDING INVESTMENT POWERS")
            'major_section': r'(?:\n|\A)([A-Z][A-Z\s]+[A-Z])(?:\n|$)(.*?)(?=(?:\n|\A)[A-Z][A-Z\s]+[A-Z](?:\n|$)|\Z)',

            # Sub-sections with Title Case (e.g., "What to Know About Investment Policy")
            'sub_section': r'(?:\n|\A)([A-Z][a-zA-Z\s]+[a-zA-Z])[:\.\n](.*?)(?=(?:\n|\A)[A-Z][a-zA-Z\s]+[a-zA-Z][:\.\n]|\Z)',

            # Tables that often begin with headers and have structured data
            'table': r'(?:\n|\A)([A-Za-z\s]+)\n((?:[^\n]+\n)+?(?=\n|\Z))',

            # Definitions (e.g., "Class: Limited partner interests...")
            'definition': r'(?:\n|\A)([A-Za-z\s]+)\.\s+(.*?)(?=(?:\n|\A)[A-Za-z\s]+\.\s+|\Z)',

            # Bullet points
            'bullet_points': r'(?:\n|\A)(?:•|\*|\-)\s+(.*?)(?=(?:\n|\A)(?:•|\*|\-)|\Z)',

            # For numbered items
            'numbered_item': r'(?:\n|\A)(\d+\.)\s+(.*?)(?=(?:\n|\A)\d+\.|\Z)',
        }

    async def run(self, text: str) -> TextChunks:
        # Identify the document type and extract metadata
        document_metadata = self._extract_document_metadata(text)

        # Extract major sections from the document
        major_sections = self._extract_major_sections(text)

        # If no clear major sections found, fall back to sub-sections
        if not major_sections:
            major_sections = self._extract_sub_sections(text)

        # If still no clear structure, fallback to paragraphs
        if not major_sections:
            major_sections = self._extract_paragraphs(text)

        
        all_chunks = []
        # Process each section into chunks
        for section_idx, (section_title, section_content) in enumerate(major_sections):
            # Add metadata about the section
            metadata = {
                **document_metadata,
                "graph_id": self.graph_id,
                "section_title": section_title,
                "section_index": section_idx,
            }

            # Tokenize the section
            tokens = self.tokenizer.encode(section_content)
            max_tokens = min(
                4096,  # Default context window
                self.llm.model_params.get("max_tokens", 3000) * 3,
            )

            # If the section fits in one chunk
            if len(tokens) <= max_tokens:
                all_chunks.append(
                    TextChunk(
                        index=len(all_chunks),
                        text=section_content,
                        metadata=metadata
                    )
                )
            else:
                # Need to split the section, try by sub-sections first
                sub_sections = self._extract_sub_sections(section_content)

                if sub_sections:
                    # Process each sub-section
                    current_chunk_text = ""
                    current_chunk_tokens = 0
                    subsection_metadata = metadata.copy()
                    subsection_metadata["subsections"] = []

                    for subsec_title, subsec_content in sub_sections:
                        subsec_tokens = self.tokenizer.encode(subsec_content)

                        # If adding this subsection would exceed max tokens, create a new chunk
                        if current_chunk_tokens + len(subsec_tokens) > max_tokens and current_chunk_text:
                            all_chunks.append(
                                TextChunk(
                                    index=len(all_chunks),
                                    text=current_chunk_text,
                                    metadata=subsection_metadata
                                )
                            )
                            current_chunk_text = f"[Continued from {section_title}]\n\n" # Indicate continuation from prev chunk
                            current_chunk_tokens = len(self.tokenizer.encode(current_chunk_text))
                            subsection_metadata = metadata.copy()
                            subsection_metadata["continued"] = True
                            subsection_metadata["subsections"] = []

                        # Add the subsection to the current chunk
                        current_chunk_text += subsec_content + "\n\n"
                        current_chunk_tokens += len(subsec_tokens) + len(self.tokenizer.encode("\n\n"))

                        # Track subsection information
                        subsection_metadata["subsections"].append(subsec_title)

                    # Add the final chunk if there is a content
                    if current_chunk_text:
                        all_chunks.append(
                            TextChunk(
                                index=len(all_chunks),
                                text=current_chunk_text,
                                metadata=subsection_metadata
                            )
                        )
                else:
                    # No subsection found : try bullet points or fall back to overlapping chunks
                    bullet_points = self._extract_bullet_points(section_content)

                    if bullet_points and len(bullet_points) > 1:
                        # Process bullet points as chunks
                        current_chunk_text = ""
                        current_chunk_tokens = 0
                        bullet_metadata = metadata.copy()
                        bullet_metadata["bullet_points"] = True

                        for bullet_title, bullet_content in bullet_points:
                            bullet_tokens = self.tokenizer.encode(bullet_content)

                            # If adding this bullet exceed max tokens, create a new chunk
                            if current_chunk_tokens + len(bullet_tokens) > max_tokens and current_chunk_text:
                                all_chunks.append(
                                    TextChunk(
                                        index=len(all_chunks),
                                        text=current_chunk_text,
                                        metadata=bullet_metadata
                                    )
                                )
                                current_chunk_text = f"[Continued from {section_title} - Bullet Points]\n\n"
                                current_chunk_tokens = len(self.tokenizer.encode(current_chunk_text))
                                bullet_metadata = metadata.copy()
                                bullet_metadata["continued"] = True
                                bullet_metadata["bullet_points"] = True

                            # Add the bullet to the current chunk
                            current_chunk_text += bullet_content + "\n\n"
                            current_chunk_tokens += len(bullet_tokens) + len(self.tokenizer.encode("\n\n"))

                        # Add the final chunk if there's content
                        if current_chunk_text:
                            all_chunks.append(
                                TextChunk(
                                    index=len(all_chunks),
                                    text=current_chunk_text,
                                    metadata=bullet_metadata
                                )
                            )
                    else:
                        # Fall back to overlapping chunks
                        chunk_size = int(max_tokens * (1 - self.overlap_percentage))
                        overlap_size = int(max_tokens * self.overlap_percentage)

                        for i in range(0, len(tokens), chunk_size):
                            start_index = max(0, i)
                            end_index = min(i + chunk_size + overlap_size, len(tokens))
                            current_chunk_tokens = tokens[start_index:end_index]
                            current_chunk_text = self.tokenizer.decode(current_chunk_tokens)

                            chunk_metadata = metadata.copy()
                            if i > 0:
                                chunk_metadata["continued"] = True
                                current_chunk_text = f"[Continued from {section_title}]\n\n{current_chunk_text}"

                            all_chunks.append(
                                TextChunk(
                                    index=len(all_chunks),
                                    text=current_chunk_text,
                                    metadata=chunk_metadata
                                )
                            )

        # Identify and mark special chunks like tables, definitions.
        all_chunks = self._mark_special_chunks(all_chunks)

        return TextChunks(chunks=all_chunks)

    def _extract_document_metadata(self, text: str) -> Dict:
        """Extract key metadata from the document."""
        
        metadata = {
            "document_type": "Offering Document",
            "document_title": ""
        }

        # Try to extract the document title
        title_pattern = r'(?:\A|\n)([^\n]{5,150}(?:Fund|Offering|Prospectus|Document)[^\n]{0,50})'
        title_match = re.search(title_pattern, text[:1000], re.IGNORECASE)
        if title_match:
            metadata["document_title"] = title_match.group(1).strip()

        # Try to extract the fund name
        fund_pattern = r'(?:\A|\n)([A-Za-z0-9\s]+(?:Fund|Partnership|Trust|Company))'
        fund_match = re.search(fund_pattern, text[:1000], re.IGNORECASE)
        if fund_match:
            metadata["fund_name"] = fund_match.group(1).strip()

        return metadata

    def _extract_major_sections(self, text: str) -> List[tuple]:
        """Extract major sections from the document."""
        
        matches = re.finditer(self.patterns['major_section'], text, re.DOTALL)
        sections = []

        prev_end = 0
        for match in matches:
            # If there is a content before the first section, add it
            if prev_end == 0 and match.start() > 0:
                intro_content = text[:match.start()].strip()
                if intro_content:
                    sections.append(("Introduction", intro_content))

            section_title = match.group(1).strip()
            section_content = match.group(0).strip()

            sections.append((section_title, section_content))
            prev_end = match.end()

        # If there's a content after the last section, add it
        if prev_end < len(text) and text[prev_end:].strip():
            sections.append(("Additional information", text[prev_end:].strip()))

        return sections

    def _extract_sub_sections(self, text: str) -> List[tuple]:
        """Extract sub-sections from a section of text."""
        
        matches = re.finditer(self.patterns['sub_section'], text, re.DOTALL)
        sub_sections = []

        for match in matches:
            subsec_title = match.group(1).strip()
            subsec_content = match.group(0).strip()

            sub_sections.append((subsec_title, subsec_content))

        return sub_sections
    
    
    def _extract_paragraphs(self, text: str) -> List[tuple]:
        """Extract paragraphs when no clear section structure is found."""
        
        paragraphs = re.split(r'\n\s*\n', text)
        result = []

        for i, para in enumerate(paragraphs):
            if para.strip():
                # Try to a title from the first line
                lines = para.strip().split('\n')
                if lines and len(lines[0]) < 100:
                    title = lines[0].strip()
                else:
                    words = para.strip().split()[:5]
                    
                    if words:
                        title = " ".join(words) + "..."
                    else:
                        f"Paragraph {i+1}"

                result.append((title, para.strip()))

        return result

    def _extract_bullet_points(self, text: str) -> List[tuple]:
        """Extract bullet points from a text section."""
        
        matches = re.finditer(self.patterns['bullet_points'], text, re.DOTALL)
        bullets = []

        for i, match in enumerate(matches):
            bullet_content = match.group(0).strip()
            bullet_text = match.group(1).strip()

            words = bullet_text.split()
            title = " ".join(words[:min(5, len(words))]) + "..."

            bullets.append((f"Bullet {i+1}: {title}", bullet_content))

        return bullets


    def _mark_special_chunks(self, chunks: List[TextChunk]) -> List[TextChunk]:
        """Identify chunks that contain tables, definitions, or numbered items."""
        
        for chunk in chunks:
            # Check for tables
            table_matches = re.findall(r'(?:\n|\A)([A-Za-z\s]+)\n((?:[^\n]+\n)+?(?=\n|\Z))', chunk.text, re.DOTALL)
            if table_matches:
                if chunk.metadata is None:
                    chunk.metadata = {}
                chunk.metadata["contains_tables"] = True
                chunk.metadata["table_titles"] = [title.strip() for title, _ in table_matches]

            # Check for definitions
            definitions = []
            for match in re.finditer(self.patterns['definition'], chunk.text, re.DOTALL):
                term = match.group(1)
                definition = match.group(2)
                if term and definition:
                    definitions.append((term.strip(), definition.strip()))

            if definitions:
                if chunk.metadata is None:
                    chunk.metadata = {}
                chunk.metadata["contains_definitions"] = True
                chunk.metadata["definitions"] = [term for term, _ in definitions]

            # Check for numbered items
            numbered_items = re.findall(self.patterns['numbered_item'], chunk.text, re.DOTALL)
            if numbered_items:
                if chunk.metadata is None:
                    chunk.metadata = {}
                chunk.metadata["contains_numbered_items"] = True
                chunk.metadata["item_count"] = len(numbered_items)

        return chunks

    def split_text(self, text: str) -> TextChunks:
        """Synchronously calls the async run method."""
        return asyncio.run(self.run(text))

# Shareholders Agreement

In [9]:
from dataclasses import dataclass
from typing import Optional, List, Tuple, Dict, Any
import re
import asyncio
import tiktoken

@dataclass
class HierarchyLevel:
    """Defines a hierarchy level with its regular expression extraction and metadata key."""
    name: str
    pattern: str
    metadata_key: str
    title_key: Optional[str] = None
    strip_header_pattern: Optional[str] = None
    level_number: int = 0  # Added this missing field


class ShareholdersAgreementSplitter(TextSplitter):
    """Improved splitter for shareholders documents using recursive hierarchy processing."""

    def __init__(self, llm, graph_id: str, overlap_percentage: float = 0.2) -> None:
        self.llm = llm
        self.graph_id = graph_id
        self.overlap_percentage = overlap_percentage
        if not 0.0 <= overlap_percentage <= 1.0:
            raise ValueError("overlap_percentage must be between 0.0 and 1.0")

        # Define the hierarchy levels in order of nesting
        self.hierarchy_levels = [
            # Level 1: Main numbered clauses (e.g., "1. Definitions and Interpretation")
            HierarchyLevel(
                name="main_clause",
                pattern=r'^(\d+)\.\s+([A-Z][A-Za-z0-9\s\-\(\),;&\']+?)(?:\n|$)',
                metadata_key="main_clause_num",
                title_key="main_clause_title",
                strip_header_pattern=r'^{num}\.\s+{title}\n?',
                level_number=1
            ),
            # Level 2: First-level sub-clauses (e.g., "1.1 Definitions")
            HierarchyLevel(
                name="sub_clause",
                pattern=r'^(\d+\.\d+)\s+([A-Z][A-Za-z0-9\s\-\(\),;&\']+?)(?:\n|$)',
                metadata_key="sub_clause_num",
                title_key="sub_clause_title",
                strip_header_pattern=r'^{num}\s+{title}\n?',
                level_number=2
            )
        ]

        # Regular expressions for recitals, schedules and annexures
        self.special_patterns = {
            'recital': r'^([A-Z])\.\s+(.*)',
            'schedule': r'^Schedule\s+(\d+)\s*([A-Za-z0-9\s\-\(\)]*)',
            'annexure': r'^(?:Annexure|ANNEXURE)\s+([A-Z])[:\s]*([^\n]*)',
            'schedule_part': r'^PART\s+([A-Z]+)'
        }

        # TOC patterns for filtering
        self.toc_patterns = [
            r'^(\d+(?:\.\d+)*)\.\s+([^\n]+?)\.{2,}\s*\d+\s*$',
            r'^(\d+(?:\.\d+)*)\.\s+([A-Za-z][A-Za-z0-9\s\-\(\),;&\']+?)\s+\d+\s*$',
            r'^(\d+(?:\.\d+)*)\.\s+([^\t\n]+?)\t+\d+\s*$',
            r'^(\d+(?:\.\d+)*)\.\s+([A-Za-z][^\n]+?)\s{2,}\d+\s*$',
            r'^([ivxlcdm]+)\.\s+([A-Za-z][^\n]+?)\s+\d+\s*$',
            r'^([A-Z])\.\s+([A-Za-z][^\n]+?)\s+\d+\s*$'
        ]

        model_name = self.llm.model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            self.tokenizer = tiktoken.get_encoding("cl100k_base")

    def split_text(self, text: str) -> TextChunks:
        """Synchronous wrapper for the main async processing method."""
        return asyncio.run(self.run(text))

    async def run(self, text: str) -> TextChunks:
        """Main processing method using recursive hierarchy processing."""
        
        # Normalize line endings and remove TOC
        cleaned_text = text.replace('\r\n', '\n').replace('\r', '\n')
        filtered_text = self._extract_non_toc_text(cleaned_text)
        
        # Extract header information
        header_info = self._extract_document_header(filtered_text)
        base_metadata = {**header_info, "graph_id": self.graph_id}
        
        chunks = []
        chunk_index = [0]
        
        # First, process recitals separately
        recitals_text, main_content = self._separate_recitals(filtered_text)
        if recitals_text:
            self._process_recitals(recitals_text, base_metadata, chunks, chunk_index)
        
        # Then process main hierarchical content
        if main_content:
            self._process_hierarchy_level(
                content=main_content,
                level_index=0,
                metadata=base_metadata.copy(),
                chunks=chunks,
                chunk_index=chunk_index
            )
        
        # Finally, process schedules and annexures
        schedules_text = self._extract_schedules_content(filtered_text)
        if schedules_text:
            self._process_schedules(schedules_text, base_metadata, chunks, chunk_index)
        
        return TextChunks(chunks=chunks)

    def _process_hierarchy_level(self, content: str, level_index: int, metadata: dict, 
                                chunks: List, chunk_index: List[int]):
        """
        Recursively process a hierarchy level.
        Args:
            content: Text content to process at this level
            level_index: Current hierarchy level index
            metadata: Accumulated metadata from parent levels
            chunks: List of chunks to append to
            chunk_index: Current chunk index counter
        """
        
        # Base case: no more hierarchy levels to process
        if level_index >= len(self.hierarchy_levels):
            if content.strip():
                chunk = self._create_chunk(chunk_index[0], content, metadata)
                if self._should_split_chunk(chunk):
                    sub_chunks = self._split_large_content(content, metadata, chunk_index[0])
                    chunks.extend(sub_chunks)
                    chunk_index[0] += len(sub_chunks)
                else:
                    chunks.append(chunk)
                    chunk_index[0] += 1
            return
        
        current_level = self.hierarchy_levels[level_index]
        items = self._extract_items_at_level(content, current_level)
        
        # If no items found at this level, try the next level or create chunk
        if not items:
            self._process_hierarchy_level(content, level_index + 1, metadata, chunks, chunk_index)
            return
        
        # Process each item at this level
        for item_id, item_title, item_content in items:
            # Build metadata for this level
            current_metadata = {**metadata}
            current_metadata[current_level.metadata_key] = item_id
            if current_level.title_key and item_title:
                current_metadata[current_level.title_key] = item_title
            current_metadata["chunk_type"] = current_level.name
            current_metadata["level"] = current_level.level_number
            
            # Add hierarchical path
            current_metadata["hierarchical_path"] = self._build_hierarchical_path(current_metadata)
            
            # Strip header if needed
            processed_content = self._strip_level_header(
                item_content, current_level, item_id, item_title
            )
            
            # Recursively process the next level
            self._process_hierarchy_level(
                content=processed_content,
                level_index=level_index + 1,
                metadata=current_metadata,
                chunks=chunks,
                chunk_index=chunk_index
            )

    def _extract_items_at_level(self, content: str, level: HierarchyLevel) -> List[Tuple[str, Optional[str], str]]:
        """Extract items at a specific hierarchy level."""
        
        matches = list(re.finditer(level.pattern, content, re.MULTILINE))
        if not matches:
            return []
        
        items = []
        for i, match in enumerate(matches):
            # Skip if this looks like a TOC entry
            if self._is_toc_entry(match.group(0)):
                continue
                
            # Extract ID and title based on pattern groups
            if level.title_key:
                item_id = match.group(1)
                item_title = match.group(2).strip() if match.group(2) else None
            else:
                item_id = match.group(1)
                item_title = None
            
            # Extract content block
            start_pos = match.start()
            if i < len(matches) - 1:
                # Find next non-TOC match
                next_match = None
                for j in range(i + 1, len(matches)):
                    if not self._is_toc_entry(matches[j].group(0)):
                        next_match = matches[j]
                        break
                end_pos = next_match.start() if next_match else len(content)
            else:
                end_pos = len(content)
                
            item_content = content[start_pos:end_pos].strip()
            
            if item_content:
                items.append((item_id, item_title, item_content))
        
        return items

    def _strip_level_header(self, content: str, level: HierarchyLevel, 
                           item_id: str, item_title: Optional[str]) -> str:
        """Strip the header from content at a specific level."""
        if not level.strip_header_pattern:
            return content
        
        if item_title and '{title}' in level.strip_header_pattern:
            pattern = level.strip_header_pattern.format(
                num=re.escape(item_id),
                title=re.escape(item_title),
                id=re.escape(item_id)
            )
        else:
            pattern = level.strip_header_pattern.format(
                num=re.escape(item_id),
                id=re.escape(item_id)
            )
        
        return re.sub(pattern, '', content, count=1, flags=re.MULTILINE).strip()

    def _process_recitals(self, recitals_text: str, base_metadata: dict, 
                         chunks: List, chunk_index: List[int]):
        """Process recitals section."""
        recital_matches = list(re.finditer(self.special_patterns['recital'], recitals_text, re.MULTILINE))
        
        for match in recital_matches:
            recital_id = match.group(1)
            recital_content = match.group(2).strip()
            
            metadata = {
                **base_metadata,
                "chunk_type": "recital",
                "level": 1,
                "recital_id": recital_id,
                "hierarchical_path": f"Recital {recital_id}"
            }
            
            chunk = self._create_chunk(chunk_index[0], f"Recital {recital_id}: {recital_content}", metadata)
            chunks.append(chunk)
            chunk_index[0] += 1

    def _process_schedules(self, schedules_text: str, base_metadata: dict, 
                          chunks: List, chunk_index: List[int]):
        """Process schedules and annexures."""
        schedule_sections = self._extract_schedule_sections(schedules_text)
        
        for section in schedule_sections:
            metadata = {
                **base_metadata,
                "chunk_type": section["type"],
                "level": 1,
                f"{section['type']}_id": section["id"],
                f"{section['type']}_title": section.get("title", ""),
                "hierarchical_path": f"{section['type'].title()} {section['id']}"
            }
            
            if self._should_split_chunk_content(section["content"]):
                sub_chunks = self._split_large_content(section["content"], metadata, chunk_index[0])
                chunks.extend(sub_chunks)
                chunk_index[0] += len(sub_chunks)
            else:
                chunk = self._create_chunk(chunk_index[0], section["content"], metadata)
                chunks.append(chunk)
                chunk_index[0] += 1

    def _separate_recitals(self, text: str) -> Tuple[str, str]:
        """Separate recitals from main content."""
        lines = text.split('\n')
        recitals_lines = []
        main_content_start = 0
        
        for i, line in enumerate(lines):
            if re.match(self.special_patterns['recital'], line.strip()):
                recitals_lines.append(line)
            elif re.match(r'^\d+\.\s+', line.strip()) and recitals_lines:
                main_content_start = i
                break
        
        if recitals_lines:
            recitals_text = '\n'.join(recitals_lines)
        else: 
            recitals_text = ""
        main_content = '\n'.join(lines[main_content_start:]) if main_content_start > 0 else text
        
        return recitals_text, main_content

    def _extract_schedules_content(self, text: str) -> str:
        """Extract schedules and annexures content."""
        lines = text.split('\n')
        schedule_start = None
        
        for i, line in enumerate(lines):
            if re.match(self.special_patterns['schedule'], line.strip()) or \
               re.match(self.special_patterns['annexure'], line.strip()):
                schedule_start = i
                break
        
        return '\n'.join(lines[schedule_start:]) if schedule_start is not None else ""

    def _extract_schedule_sections(self, text: str) -> List[Dict[str, Any]]:
        """Extract individual schedule/annexure sections."""
        sections = []
        lines = text.split('\n')
        current_section = None
        
        for line in lines:
            line_stripped = line.strip()
            
            # Check for schedule
            schedule_match = re.match(self.special_patterns['schedule'], line_stripped)
            if schedule_match:
                if current_section:
                    sections.append(current_section)
                current_section = {
                    'type': 'schedule',
                    'id': schedule_match.group(1),
                    'title': schedule_match.group(2).strip() if schedule_match.group(2) else "",
                    'content': line
                }
                continue
            
            # Check for annexure
            annexure_match = re.match(self.special_patterns['annexure'], line_stripped)
            if annexure_match:
                if current_section:
                    sections.append(current_section)
                current_section = {
                    'type': 'annexure',
                    'id': annexure_match.group(1),
                    'title': annexure_match.group(2).strip() if annexure_match.group(2) else "",
                    'content': line
                }
                continue
            
            # Add content to current section
            if current_section and line_stripped:
                current_section['content'] += '\n' + line
        
        if current_section:
            sections.append(current_section)
        
        return sections

    def _extract_non_toc_text(self, text: str) -> str:
        """Extract text excluding the Table of Contents section."""
        toc_start, toc_end = self._identify_toc_boundaries(text)
        
        if toc_start is None:
            return text
        
        lines = text.split('\n')
        
        if toc_end is not None:
            before_toc = lines[:toc_start] if toc_start > 0 else []
            after_toc = lines[toc_end + 1:] if toc_end + 1 < len(lines) else []
            filtered_lines = before_toc + after_toc
        else:
            before_toc = lines[:toc_start] if toc_start > 0 else []
            first_clause = self._find_first_real_clause_after_toc(lines, toc_start)
            if first_clause is not None:
                after_toc = lines[first_clause:]
                filtered_lines = before_toc + after_toc
            else:
                filtered_lines = before_toc
        
        return '\n'.join(filtered_lines)

    def _identify_toc_boundaries(self, text: str) -> Tuple[Optional[int], Optional[int]]:
        """Identify the start and end boundaries of the Table of Contents."""
        lines = text.split('\n')
        toc_start = None
        toc_end = None
        
        # Look for TOC start
        for i, line in enumerate(lines):
            if re.search(r'(?i)table\s+of\s+contents?|^contents?$', line.strip()):
                toc_start = i
                break
        
        if toc_start is not None:
            consecutive_toc_lines = 0
            for i in range(toc_start + 1, min(toc_start + 100, len(lines))):
                line_stripped = lines[i].strip()
                if not line_stripped:
                    continue
                
                if self._is_toc_entry(line_stripped):
                    consecutive_toc_lines += 1
                    toc_end = i
                else:
                    if consecutive_toc_lines > 0 and self._is_clause(line_stripped):
                        break
                    consecutive_toc_lines = 0
        
        return toc_start, toc_end

    def _is_toc_entry(self, line: str) -> bool:
        """Check if a line appears to be a Table of Contents entry."""
        return any(re.match(pattern, line) for pattern in self.toc_patterns)

    def _is_clause(self, line: str) -> bool:
        """Check if a line is a clause."""
        clause_patterns = [
            r'^(\d+)\.\s+([A-Z][A-Za-z0-9\s\-\(\),;&\']+?)$',
            r'^(\d+\.\d+)\s+([A-Z][A-Za-z0-9\s\-\(\),;&\']+?)$',
            r'^([A-Z])\.\s+([A-Za-z][^\d]*?)$'
        ]
        
        for pattern in clause_patterns:
            match = re.match(pattern, line)
            if match and not re.search(r'\s+\d+$', match.group(2) if len(match.groups()) > 1 else match.group(1)):
                return True
        return False

    def _find_first_real_clause_after_toc(self, lines: List[str], toc_start: int) -> Optional[int]:
        """Find the first real clause after TOC start."""
        for i in range(toc_start + 1, len(lines)):
            if self._is_clause(lines[i].strip()):
                return i
        return None

    def _extract_document_header(self, text: str) -> Dict[str, str]:
        """Extract document header information."""
        return {
            "document_type": "ShareholdersAgreement",
            "document_title": f"Shareholders Agreement {self.graph_id}",
            "source": "legal_document"
        }

    def _build_hierarchical_path(self, metadata: dict) -> str:
        """Build the hierarchical path for a chunk."""
        path_parts = []
        
        if "main_clause_num" in metadata and "main_clause_title" in metadata:
            path_parts.append(f"Clause {metadata['main_clause_num']} – {metadata['main_clause_title']}")
        
        if "sub_clause_num" in metadata and "sub_clause_title" in metadata:
            path_parts.append(f"Sub-clause {metadata['sub_clause_num']} – {metadata['sub_clause_title']}")
        
        return " → ".join(path_parts) if path_parts else "Document Content"

    def _create_chunk(self, index: int, text: str, metadata: dict) -> TextChunk:
        """Create a TextChunk with cleaned text and metadata."""
        clean_text = re.sub(r'\s+', ' ', text).strip()
        
        final_metadata = {
            **metadata,
            "component_type": self._determine_component_type(metadata)
        }
        
        return TextChunk(index=index, text=clean_text, metadata=final_metadata)

    def _determine_component_type(self, metadata: dict) -> str:
        """Determine the component type based on metadata."""
        chunk_type = metadata.get("chunk_type", "content")
        
        component_mapping = {
            "main_clause": "MainClause",
            "sub_clause": "SubClause", 
            "recital": "Recital",
            "schedule": "Schedule",
            "annexure": "Annexure"
        }
        
        return component_mapping.get(chunk_type, "Content")

    def _should_split_chunk(self, chunk: TextChunk) -> bool:
        """Check if a chunk should be split due to size."""
        tokens = self.tokenizer.encode(chunk.text)
        max_tokens = min(4096, self.llm.model_params.get("max_tokens", 3000) * 3)
        return len(tokens) > max_tokens

    def _should_split_chunk_content(self, content: str) -> bool:
        """Check if content should be split due to size."""
        tokens = self.tokenizer.encode(content)
        max_tokens = min(4096, self.llm.model_params.get("max_tokens", 3000) * 3)
        return len(tokens) > max_tokens

    def _split_large_content(self, content: str, metadata: dict, start_index: int) -> List[TextChunk]:
        """Split large content into smaller chunks with overlap."""
        tokens = self.tokenizer.encode(content)
        max_tokens = min(4096, self.llm.model_params.get("max_tokens", 3000) * 3)
        
        chunk_size = int(max_tokens * (1 - self.overlap_percentage))
        overlap_size = int(max_tokens * self.overlap_percentage)
        
        chunks = []
        for i in range(0, len(tokens), chunk_size):
            start_idx = max(0, i)
            end_idx = min(i + chunk_size + overlap_size, len(tokens))
            chunk_tokens = tokens[start_idx:end_idx]
            chunk_text = self.tokenizer.decode(chunk_tokens)
            
            chunk_metadata = metadata.copy()
            if i > 0:
                chunk_metadata["continued"] = True
                chunk_text = f"[Continued from {metadata.get('hierarchical_path', 'previous section')}]\n\n{chunk_text}"
            
            chunks.append(self._create_chunk(start_index + len(chunks), chunk_text, chunk_metadata))
        
        return chunks

# Ontology

In [4]:
# Source: https://github.com/jbarrasa/goingmeta/blob/main/session31/python/utils.py

from neo4j_graphrag.experimental.components.schema import (
    SchemaBuilder,
    SchemaConfig,
    SchemaEntity,
    SchemaProperty,
    SchemaRelation,
)
from rdflib.namespace import OWL, RDF, RDFS
from rdflib import Graph

def getLocalPart(uri):
    pos = uri.rfind("#")
    if pos < 0:
        pos = uri.rfind("/")
    if pos < 0:
        pos = uri.rindex(":")
    return uri[pos + 1 :]

def getPropertiesForClass(g, cat):
    props = []
    for dtp in g.subjects(RDFS.domain, cat):
        if (dtp, RDF.type, OWL.DatatypeProperty) in g:
            propName = getLocalPart(dtp)
            propDesc = next(g.objects(dtp, RDFS.comment), "")
            props.append(SchemaProperty(name=propName, type="STRING", description=propDesc))
    return props

def getSchemaFromOnto(g) -> SchemaConfig:
    schema_builder = SchemaBuilder()
    classes = {}
    entities = []
    rels = []
    triples = []

    for cat in g.subjects(RDF.type, OWL.Class):
        classes[cat] = None
        label = getLocalPart(cat)
        props = getPropertiesForClass(g, cat)
        entities.append(SchemaEntity(label=label, description=next(g.objects(cat, RDFS.comment), ""), properties=props))

    for cat in g.objects(None, RDFS.domain):
        if cat not in classes.keys():
            classes[cat] = None
            label = getLocalPart(cat)
            props = getPropertiesForClass(g, cat)
            entities.append(SchemaEntity(label=label, description=next(g.objects(cat, RDFS.comment), ""), properties=props))

    for cat in g.objects(None, RDFS.range):
        if not (cat.startswith("http://www.w3.org/2001/XMLSchema#") or cat in classes.keys()):
            classes[cat] = None
            label = getLocalPart(cat)
            props = getPropertiesForClass(g, cat)
            entities.append(SchemaEntity(label=label, description=next(g.objects(cat, RDFS.comment), ""), properties=props))

    for op in g.subjects(RDF.type, OWL.ObjectProperty):
        relname = getLocalPart(op)
        rels.append(SchemaRelation(label=relname, properties=[], description=next(g.objects(op, RDFS.comment), "")))

    for op in g.subjects(RDF.type, OWL.ObjectProperty):
        relname = getLocalPart(op)
        doms = [getLocalPart(dom) for dom in g.objects(op, RDFS.domain) if dom in classes.keys()]
        rans = [getLocalPart(ran) for ran in g.objects(op, RDFS.range) if ran in classes.keys()]
        for d in doms:
            for r in rans:
                triples.append((d, relname, r))

    return schema_builder.create_schema_model(entities=entities, relations=rels, potential_schema=triples)

# Knowledge Graph

In [None]:
import asyncio
import nest_asyncio
import neo4j
from typing import Dict
from pathlib import Path
import numpy as np
import faiss
from rdflib import Graph

from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.llm import OpenAILLM


URI = "neo4j://localhost:7687"
USERNAME = "neo4j"
PASSWORD = ""
AUTH = (USERNAME, PASSWORD)

def clear_database(driver):
    with driver.session() as session:
        session.run("MATCH (n) DETACH DELETE n")
        print("Database cleaned.")

def initialize_vector_index(driver):
    try:
        with driver.session() as session:
            session.run("""
            CALL db.index.vector.createNodeIndex(
                'legal_text_embeddings',
                'Paragraph',
                'embedding',
                1536,
                'cosine'
            )
            """)
            print("Vectorial index created.")
    except Exception as e:
        print(f"Unable to create index: {e}")

def fetch_chunks(tx, graph_id):
    """
    Fetch chunks with embeddings from a specific knowledge graph.

    Args:
        tx: Neo4j transaction
        graph_id: Identifier for the knowledge graph

    Returns:
        List of tuples containing (node_id, embedding)
    """
    query = """
    MATCH (c:Chunk)
    WHERE c.graph_id = $graph_id AND c.embedding IS NOT NULL
    RETURN elementId(c) AS id, c.embedding AS embedding, c.text AS text
    """
    result = tx.run(query, graph_id=graph_id)

    # Since embeddings are already lists in Neo4j, we can use them directly
    return [(r["id"], r["embedding"], r["text"]) for r in result]

def create_similarity_links(tx, matches):
    """
    Create SIMILAR relationships between chunks based on similarity scores.

    Args:
        tx: Neo4j transaction
        matches: Dictionary mapping source node IDs to list of (target_id, similarity_score) tuples
    """
    
    print("matches: ", matches.items())
    batch = []
    for src_id, targets in matches.items():
        for tgt_id, score in targets:
            if score >= 0.5:  # Similarity threshold
                batch.append({"src": src_id, "tgt": tgt_id, "score": float(score)})

    if batch:
        tx.run("""
            UNWIND $batch AS row
            MATCH (a:Chunk), (b:Chunk)
            WHERE elementId(a) = row.src AND elementId(b) = row.tgt
            MERGE (a)-[r:SIMILAR]->(b)
            SET r.score = row.score
        """, batch=batch)

    return len(batch)

# Run the pipeline for processing legal documents
async def define_and_run_pipeline(
    neo4j_driver: neo4j.Driver,
    llm: LLMInterface,
    embedder,
    neo4j_schema
) -> Dict:

    legal_path = Path("legal_corpus/")
    nest_asyncio.apply()

    structure_splitter = None
    for file_path in legal_path.glob("*.txt"):
        with open(file_path, "r", encoding="utf-8") as file:
            print("Processing file:", file_path.name)
            text = file.read()
            if file_path.name == "1Constitution.txt" or file_path.name == "1Constitution_small.txt":
                structure_splitter = ConstitutionSplitter(llm=llm, graph_id="Constitution")
            if file_path.name == "2Prospectus.txt" or file_path.name == "2Prospectus_small.txt":
                structure_splitter = ProspectusSplitter(llm=llm, graph_id="Prospectus")
            if file_path.name == "3Agreement.txt" or file_path.name == "3Agreement_small.txt":
                structure_splitter = ShareholdersAgreementSplitter(llm=llm, graph_id="Agreement")

            structure_pipeline = SimpleKGPipeline(
                llm=llm,
                driver=neo4j_driver,
                text_splitter=structure_splitter, 
                #text_splitter=FixedSizeSplitter(chunk_size=2500, chunk_overlap=10),
                #embedder=SentenceTransformerEmbeddings(model='all-MiniLM-L6-v2'),
                embedder=embedder,
                entities=list(neo4j_schema.entities.values()),
                relations=list(neo4j_schema.relations.values()),
                potential_schema=neo4j_schema.potential_schema,
                on_error="IGNORE",
                from_pdf=False
            )

            asyncio.run(structure_pipeline.run_async(text=text))

def merge_graphs(driver, graph_id_a, graph_id_b):
    stats = {"relationships_created": 0}

    with driver.session() as session:
        # Fetch chunks from both graphs
        print(f"Fetching chunks from graph {graph_id_a}")
        chunks_a = session.execute_read(fetch_chunks, graph_id_a)
        print(f"Found {len(chunks_a)} chunks with embeddings")

        print(f"Fetching chunks from graph {graph_id_b}")
        chunks_b = session.execute_read(fetch_chunks, graph_id_b)
        print(f"Found {len(chunks_b)} chunks with embeddings")

        if not chunks_a or not chunks_b:
            print("Warning: One or both graphs have no chunks with embeddings")
            #return stats

        # Extract IDs and embeddings
        ids_a, emb_a, texts_a = zip(*chunks_a) if chunks_a else ([], [], [])
        ids_b, emb_b, texts_b = zip(*chunks_b) if chunks_b else ([], [], [])

        # Convert to numpy arrays
        emb_a = np.array(emb_a).astype('float32')
        emb_b = np.array(emb_b).astype('float32')

        # Handle empty arrays
        if emb_a.size == 0 or emb_b.size == 0:
            print("Warning: No embeddings found in one or both graphs")
            #return stats

        # Normalize for cosine similarity
        faiss.normalize_L2(emb_a)
        faiss.normalize_L2(emb_b)

        # Build similarity index
        d = emb_a.shape[1]  # Embedding dimension
        print(f"Building FAISS index with {len(ids_a)} vectors of dimension {d}...")
        index = faiss.IndexFlatIP(d)  # Inner product = cosine similarity for normalized vectors
        index.add(emb_a)

        # Perform similarity search
        print(f"Finding top {3} similar chunks for each chunk in graph graph_id_b")
        k = min(3, len(ids_a))  # Make sure k isn't larger than available chunks
        scores, indices = index.search(emb_b, k)
        print(f"Found {len(indices)} similar chunks for each chunk in graph graph_id_b")
        print(f"Scores: {scores}")
        print(f"Indices: {indices}")

        # Prepare matches
        matches = {}
        for i in range(len(ids_b)):
            matches[ids_b[i]] = [(ids_a[indices[i][j]], scores[i][j])
                                for j in range(k)
                                if scores[i][j] >= 0.5]

        print(f"Prepared {len(matches)} matches")

        # Store links in Neo4j
        print("Creating similarity links in Neo4j...")
        stats["relationships_created"] = session.execute_write(
            create_similarity_links, matches
        )

        # Also create links in the reverse direction (B → A)
        # Build reverse index
        index_reverse = faiss.IndexFlatIP(d)
        index_reverse.add(emb_b)

        # Search A chunks in B
        scores_reverse, indices_reverse = index_reverse.search(emb_a, k)
        print(f"Found {len(indices_reverse)} similar chunks for each chunk in graph graph_id_a")
        print(f"Scores: {scores_reverse}")
        print(f"Indices: {indices_reverse}")

        # Prepare reverse matches
        matches_reverse = {}
        for i in range(len(ids_a)):
            matches_reverse[ids_a[i]] = [(ids_b[indices_reverse[i][j]], scores_reverse[i][j])
                                      for j in range(k)
                                      if scores_reverse[i][j] >= 0.5]

        print(f"Prepared {len(matches_reverse)} reverse matches")

        # Store reverse links
        print("Creating reverse similarity links...")
        stats["relationships_created"] += session.execute_write(
            create_similarity_links, matches_reverse
        )

        # Verify links in database
        count_query = """
        MATCH ()-[r:SIMILAR]->()
        RETURN count(r) as count
        """
        try:
            result = session.run(count_query).single()
            stats["total_similar_links"] = result["count"] if result else 0
        except Exception as e:
            print(f"Warning: Couldn't count total relationships: {e}")
            stats["total_similar_links"] = 0

        print(f"✅ Created {stats['relationships_created']} cross-graph similarity links")


async def main():
    g = Graph()
    
    # Get the schema from the ontology file
    neo4j_schema = getSchemaFromOnto(g.parse("ontology.ttl"))
    # Set up the llm
    llm = OpenAILLM(
        api_key="hide",
        model_name="gpt-4o-mini",
        #model_name="gpt-4.1",
        model_params={
            "max_tokens": 5000,
            "response_format": {"type": "json_object"},
            "temperature": 0,
        },
    )
    
    embedder = OpenAIEmbeddings(model="text-embedding-3-small", api_key="hide")
    
    with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
        print("Clearing database...")
        clear_database(driver)

        print("Initializing vector index...")
        initialize_vector_index(driver)

        print("Building knowledge graph...")
        res = await define_and_run_pipeline(driver, llm, embedder, neo4j_schema)

        # Merge graphs A => B => C
        print("Merging graphs...")
        merge_graphs(driver, "Constitution", "Prospectus")
        merge_graphs(driver, "Prospectus", "Agreement")

res = await main()

Clearing database...
Database cleaned.
Initializing vector index...
Unable to create index: {code: Neo.ClientError.Procedure.ProcedureCallFailed} {message: Failed to invoke procedure `db.index.vector.createNodeIndex`: Caused by: org.neo4j.kernel.api.exceptions.schema.EquivalentSchemaRuleAlreadyExistsException: An equivalent index already exists, 'Index( id=3, name='legal_text_embeddings', type='VECTOR', schema=(:Paragraph {embedding}), indexProvider='vector-2.0' )'.}
Building knowledge graph...
Processing file: 1Constitution.txt
Processing file: 2Prospectus.txt


LLM response has improper format for chunk_index=9
LLM response has improper format for chunk_index=78


Processing file: 3Agreement.txt
Merging graphs...
Fetching chunks from graph Constitution
Found 135 chunks with embeddings
Fetching chunks from graph Prospectus
Found 96 chunks with embeddings
Building FAISS index with 135 vectors of dimension 1536...
Finding top 3 similar chunks for each chunk in graph graph_id_b
Found 96 similar chunks for each chunk in graph graph_id_b
Scores: [[0.5367637  0.5284603  0.52012867]
 [0.6259637  0.5880318  0.5496872 ]
 [0.5779646  0.5323864  0.53067183]
 [0.5355469  0.5157231  0.51360005]
 [0.45598543 0.39387617 0.38689408]
 [0.6267467  0.5671647  0.55572885]
 [0.30426    0.27630615 0.25364658]
 [0.4731607  0.2943825  0.26604623]
 [0.26658404 0.2323254  0.22931841]
 [0.4000994  0.39882344 0.38583723]
 [0.54203016 0.5396124  0.5233353 ]
 [0.46757698 0.45210445 0.4381077 ]
 [0.44245687 0.44150355 0.4395835 ]
 [0.66028553 0.65227264 0.59230876]
 [0.45640653 0.4534918  0.44557342]
 [0.528533   0.46090922 0.44478583]
 [0.60506886 0.5140003  0.51134825]
 [0.4

# Taxonomy

In [None]:
import pandas as pd
from neo4j import GraphDatabase
import re
import logging
import random
import json
from datetime import datetime
import openai
from typing import Dict, List
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class LegalTaxonomyExtractor:
    def __init__(self, uri, username, password):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))
        
    def close(self):
        self.driver.close()
        
    def run_query(self, query, params=None):
        with self.driver.session() as session:
            result = session.run(query, params or {})
            return list(result)
            
    def optimize_similarity_relationships(self):
        """Ensure similarity relationships have weight properties."""
        query = """
        MATCH (a)-[r:SIMILAR]->(b)
        WHERE r.score IS NOT NULL AND r.weight IS NULL
        SET r.weight = r.score
        RETURN count(r) as updated
        """
        result = self.run_query(query)
        logger.info(f"Optimized {result[0]['updated']} SIMILAR relationships")
        
        # For relationships that might not have a score
        query = """
        MATCH (a)-[r:SIMILAR]->(b)
        WHERE r.score IS NULL AND r.weight IS NULL
        SET r.weight = 0.5, r.score = 0.5
        RETURN count(r) as updated
        """
        result = self.run_query(query)
        logger.info(f"Set default weights for {result[0]['updated']} SIMILAR relationships")
        
    def create_graph_projection(self, graph_name="legal-taxonomy-graph"):
        """Create an in-memory graph projection for community detection."""
        if graph_name == "legal-taxonomy-graph":
            timestamp = int(datetime.now().timestamp() * 1000)
            graph_name = f"legal-taxonomy-graph-{timestamp}"
        
        # First, check if a graph with this name already exists
        check_query = """
        CALL gds.graph.exists($graph_name) YIELD exists
        RETURN exists
        """
        result = self.run_query(check_query, {"graph_name": graph_name})
        
        if result and result[0]["exists"]:
            # If graph exists, drop it first
            drop_query = """
            CALL gds.graph.drop($graph_name)
            YIELD graphName
            RETURN graphName
            """
            self.run_query(drop_query, {"graph_name": graph_name})
            logger.info(f"Dropped existing graph projection: {graph_name}")
        
        # Create new graph projection
        query = """
        CALL gds.graph.project(
            $graph_name,
            '*',
            {
                relType: {
                    type: '*',
                    orientation: 'UNDIRECTED',
                    properties: {}
                }
            }
        )
        YIELD graphName, nodeCount, relationshipCount
        RETURN graphName, nodeCount, relationshipCount
        """
        result = self.run_query(query, {"graph_name": graph_name})
        
        if result:
            logger.info(f"Created graph projection: {result[0]['graphName']} with {result[0]['nodeCount']} nodes and {result[0]['relationshipCount']} relationships")
            return graph_name
        else:
            logger.error("Failed to create graph projection")
            return None
        
    def run_community_detection(self, graph_name="legal-taxonomy-graph", limit=42, community_node_limit=10):
        """Run Louvain community detection algorithm on the projected graph."""
        query = """
        CALL gds.louvain.stream($graph_name, {
            relationshipWeightProperty: null,
            includeIntermediateCommunities: true,
            seedProperty: ''
        })
        YIELD nodeId, communityId AS community, intermediateCommunityIds AS communities
        WITH gds.util.asNode(nodeId) AS node, community, communities
        WITH community, communities, collect(node) AS nodes
        WITH community, communities, nodes, size(nodes) AS size
        ORDER BY size DESC
        LIMIT toInteger($limit)
        UNWIND nodes[0..$community_node_limit] AS node
        RETURN 
            node.id AS nodeId,
            node.text AS text,
            community,
            communities AS hierarchicalCommunities,
            size
        """
        
        result = self.run_query(query, {
            "graph_name": graph_name,
            "limit": limit,
            "community_node_limit": community_node_limit
        })
        
        logger.info(f"Detected communities with {len(result)} nodes")
        
        community_data = []
        for record in result:
            community_data.append({
                'nodeId': record['nodeId'],
                'text': record['text'] if record['text'] else '',
                'community': record['community'],
                'hierarchicalCommunities': record['hierarchicalCommunities'],
                'size': record['size']
            })
        
        return pd.DataFrame(community_data)
    
    def generate_category_name(self, client, texts, community_id):
        # Use a sample of texts to avoid token limits
        sample_size = min(5, len(texts))
        sample_texts = random.sample(texts, sample_size) if len(texts) > sample_size else texts
                        
        # Prepare prompt for the LLM
        prompt = f"""Below are {sample_size} legal text snippets that belong to the same category. 
        Please generate a concise, specific legal category name (2-5 words) that accurately captures 
        what these texts have in common:

        {'-' * 40}
        {'\n'.join([f"{i+1}. {text[:300]}..." if len(text) > 300 else f"{i+1}. {text}" for i, text in enumerate(sample_texts)])}
        {'-' * 40}
        
        Respond with ONLY the category name, nothing else."""
        
        # Call the LLM API
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a legal taxonomy expert. Generate concise, specific category names for groups of legal texts."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=20
        )
        
        # Extract and clean the category name
        category_name = response.choices[0].message.content.strip()
        # Remove quotes if present
        category_name = category_name.strip('"\'')
            
        logger.info(f"Generated category name for community {community_id}: '{category_name}'")
        return category_name
    
    
    def create_taxonomy_categories(self, community_df, min_community_size=3):
        """Create taxonomy category nodes from communities using LLM for labeling."""
        
        communities = community_df.groupby('community')
        categories = []
        try:
            openai_api_key = "hide"
            client = openai.OpenAI(api_key=openai_api_key)
            
            # Process each community
            for community_id, group in communities:
                if len(group) < min_community_size:
                    continue
                    
                community_texts = group['text'].tolist()
                community_texts = [t for t in community_texts if t and isinstance(t, str)]
                
                if not community_texts:
                    continue
                
                try:
                    # Generate category name using LLM
                    category_name = self.generate_category_name(client, community_texts, community_id)
                except Exception as e:
                    logger.error(f"Error generating category name with LLM: {e}")
                    # Fallback to keyword extraction
                    #keywords = self.extract_keywords(community_texts)
                    category_name = "not found"
                
                categories.append({
                    'id': f'category-{community_id}',
                    'name': category_name,
                    'size': len(group)
                })
            
        except Exception as e:
            logger.error(f"Error setting up LLM: {e}")
        
        # Create category nodes in neo4j
        if categories:
            params = {"categories": categories}
            query = """
            UNWIND $categories AS category
            MERGE (t:TaxonomyCategory {id: category.id})
            SET t.name = category.name, 
                t.size = category.size,
                t.keywords = split(category.name, ' ')
            RETURN count(t) as created
            """
            result = self.run_query(query, params)
            logger.info(f"Created {result[0]['created']} taxonomy category nodes with LLM-generated names")
        
        return categories
    
    def link_nodes_to_categories(self, community_df):
        """Link nodes to their taxonomy categories."""
        
        # Prepare relationships data
        relationships = []
        for _, row in community_df.iterrows():
            relationships.append({
                'nodeId': row['nodeId'],
                'categoryId': f"category-{row['community']}"
            })
        
        # Create relationships : link nodes to categories
        if relationships:
            params = {"relationships": relationships}
            query = """
            UNWIND $relationships AS rel
            MATCH (d {id: rel.nodeId})
            MATCH (t:TaxonomyCategory {id: rel.categoryId})
            MERGE (d)-[:BELONGS_TO]->(t)
            RETURN count(*) as created
            """
            result = self.run_query(query, params)
            logger.info(f"Created {result[0]['created']} BELONGS_TO relationships")
    
    def create_constraints_and_indexes(self):
        """Create constraints and indexes."""
        # Create unique constraint on TaxonomyCategory.id if it doesn't exist
        try:
            query = """
            CREATE CONSTRAINT taxonomy_category_id IF NOT EXISTS
            FOR (t:TaxonomyCategory) REQUIRE t.id IS UNIQUE
            """
            self.run_query(query)
            logger.info("Created constraint on TaxonomyCategory.id")
        except Exception as e:
            logger.warning(f"Could not create constraint: {e}")
        
        # Create index on TaxonomyCategory.name for text search
        try:
            query = """
            CREATE INDEX taxonomy_category_name IF NOT EXISTS
            FOR (t:TaxonomyCategory) ON (t.name)
            """
            self.run_query(query)
            logger.info("Created index on TaxonomyCategory.name")
        except Exception as e:
            logger.warning(f"Could not create index: {e}")
    
    def extract_taxonomy(self):
        logger.info("Starting taxonomy extraction process")
        
        self.optimize_similarity_relationships()
        
        graph_name = self.create_graph_projection()
        if not graph_name:
            return False
        
        community_df = self.run_community_detection(graph_name)
        
        categories = self.create_taxonomy_categories(community_df)
        
        self.link_nodes_to_categories(community_df)
        
        self.create_constraints_and_indexes()
        
        logger.info("Taxonomy extraction completed")
        return True
            
    def print_hierarchical_taxonomy(self):
        """Print all taxonomy categories in a hierarchical bullet point format,
        ensuring all categories are displayed."""
        
        # Get all taxonomy categories
        query = """
            MATCH (t:TaxonomyCategory)
            RETURN t.id as id, t.name as name, t.size as size
            ORDER BY t.size DESC
        """
        all_categories = self.run_query(query)
        
        # Track categories already displayed
        displayed_categories = set()
        
        print("Generated Hierarchical Taxonomy :")
        
        # First, find all root categories (those that have no parents)
        query = """
            MATCH (t:TaxonomyCategory)
            WHERE NOT (t)-[:SUBCATEGORY_OF]->()
            RETURN t.id as id, t.name as name, t.size as size
            ORDER BY t.size DESC
        """
        root_categories = self.run_query(query)
        
        # Process each root category and its hierarchy
        for r in root_categories:
            category_id = r['id']
            category_name = r['name']
            size = r['size']
            
            # Print the root category
            print(f"* {category_name} ({size} nodes)")
            displayed_categories.add(category_id)
            
            # Recursively print subcategories
            self._print_subcategories(category_id, 1, displayed_categories)

    def _print_subcategories(self, category_id, depth, visited_nodes):
        """Recursively print subcategories of a given category."""
        # Get all categories related to this one
        query = """
            MATCH (child:TaxonomyCategory)-[:SUBCATEGORY_OF]->(parent:TaxonomyCategory {id: $category_id})
            RETURN child.id as id, child.name as name, child.size as size
            ORDER BY child.size DESC
        """
        result = self.run_query(query, {"category_id": category_id})
        
        if not result:
            # No subcategories found for this category
            return
        
        # Print each related category with proper indentation
        for r in result:
            sub_id = r['id']
            sub_name = r['name']
            size = r['size']
            
            # Create indentation
            indent = "  " * depth
            
            # Check for cycles
            if sub_id in visited_nodes:
                #print(f"{indent}* {sub_name} ({size} nodes) [CYCLE DETECTED]")
                continue
            
            # Print subcategory
            print(f"{indent}* {sub_name} ({size} nodes)")
            visited_nodes.add(sub_id)
            
            # Recursively print its subcategories
            self._print_subcategories(sub_id, depth + 1, visited_nodes)
        
    def generate_hierarchical_taxonomy_with_llm(self, client=None):
        """Generate a proper hierarchical taxonomy using LLM to organize categories."""
        
        if client is None:
            try:
                openai_api_key = "hide"
                client = openai.OpenAI(api_key=openai_api_key)
            except Exception as e:
                logger.error(f"Error setting up OpenAI client: {e}")
                return False
        
        # Get all existing categories
        query = """
        MATCH (t:TaxonomyCategory)
        RETURN t.id as id, t.name as name, t.size as size
        ORDER BY t.size DESC
        """
        result = self.run_query(query)
        
        if not result:
            logger.error("No taxonomy categories found")
            return False
        
        # Prepare category names for LLM
        category_names = [r['name'] for r in result]
        category_info = {r['name']: {'id': r['id'], 'size': r['size']} for r in result}
        
        # Generate hierarchy structure using LLM
        hierarchy_structure = self._generate_taxonomy_structure(client, category_names)
        
        if not hierarchy_structure:
            logger.error("Failed to generate hierarchy structure")
            return False
        
        # Create the hierarchy in Neo4j
        success = self._create_llm_generated_hierarchy(hierarchy_structure, category_info)
        
        return success
    
    def _generate_taxonomy_structure(self, client, category_names: List[str]) -> Dict:
        """Use LLM to generate a hierarchical taxonomy structure."""
        
        categories_text = '\n'.join([f"* {name}" for name in category_names])
        
        prompt = f"""Below are legal category names that should be grouped under a tree-like structure (taxonomy):

        {categories_text}

        Please generate a hierarchical taxonomy with several levels. Create logical parent categories that group related subcategories together. 

        Requirements:
        1. Create a root category that encompasses all legal categories
        2. Create 4-6 main parent categories under the root
        3. Group the existing categories under appropriate parent categories
        4. You can create additional intermediate levels if needed
        5. Each existing category should appear exactly once in the hierarchy
        6. Use clear, professional legal terminology for parent category names

        Return the result as a JSON structure with this format:
        {{
        "root": {{
            "name": "Root Category Name",
            "children": {{
            "Parent Category 1": {{
                "name": "Parent Category 1 Full Name",
                "children": {{
                "Subcategory 1": {{"name": "Existing Category Name"}},
                "Subcategory 2": {{"name": "Another Existing Category Name"}}
                }}
            }},
            "Parent Category 2": {{
                "name": "Parent Category 2 Full Name", 
                "children": {{
                "Subcategory 3": {{"name": "Yet Another Existing Category Name"}}
                }}
            }}
            }}
        }}
        }}

        Make sure every existing category from the input list appears exactly once in the JSON structure."""

        try:
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "You are a legal taxonomy expert. Create well-structured hierarchical taxonomies for legal document categorization. Always return valid JSON."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.3,
                max_tokens=2000
            )
            
            response_text = response.choices[0].message.content.strip()
            
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if json_match:
                json_text = json_match.group()
                hierarchy_structure = json.loads(json_text)
                logger.info("Successfully generated hierarchy structure with LLM")
                return hierarchy_structure
            else:
                logger.error("No JSON structure found in LLM response")
                return None
                
        except Exception as e:
            logger.error(f"Error generating taxonomy structure with LLM: {e}")
            return None
    
    def _create_llm_generated_hierarchy(self, hierarchy_structure: Dict, category_info: Dict) -> bool:
        """Create the LLM-generated hierarchy in Neo4j."""
        
        try:
            new_categories = []
            category_mappings = []
            
            def process_hierarchy_level(node_data, parent_id=None, level=0):
                """Recursively process hierarchy levels."""
                
                if isinstance(node_data, dict) and 'name' in node_data:
                    category_name = node_data['name']
                    
                    # Generate ID for this category
                    if category_name in category_info:
                        # Existing leaf category
                        category_id = category_info[category_name]['id']
                        size = category_info[category_name]['size']
                    else:
                        # New parent category
                        category_id = f"category-parent-{len(new_categories)}"
                        size = 0 
                        new_categories.append({
                            'id': category_id,
                            'name': category_name,
                            'size': size,
                            'is_parent': True
                        })
                    
                    # Record parent-child relationship
                    if parent_id:
                        category_mappings.append({
                            'child_id': category_id,
                            'parent_id': parent_id
                        })
                    
                    # Process children if they exist
                    if 'children' in node_data and node_data['children']:
                        for child_key, child_data in node_data['children'].items():
                            process_hierarchy_level(child_data, category_id, level + 1)
                    
                    return category_id
                
                return None
            
            # Process the hierarchy starting from root
            if 'root' in hierarchy_structure:
                root_data = hierarchy_structure['root']
                process_hierarchy_level(root_data)
            else:
                logger.error("No root found in hierarchy structure")
                return False
            
            # Create new parent categories in Neo4j
            if new_categories:
                params = {"categories": new_categories}
                query = """
                UNWIND $categories AS category
                MERGE (t:TaxonomyCategory {id: category.id})
                SET t.name = category.name, 
                    t.size = category.size,
                    t.keywords = split(category.name, ' '),
                    t.is_parent = category.is_parent
                RETURN count(t) as created
                """
                result = self.run_query(query, params)
                logger.info(f"Created {result[0]['created']} new parent taxonomy categories")
            
            # Create hierarchy relationships in Neo4j
            if category_mappings:
                params = {"mappings": category_mappings}
                query = """
                UNWIND $mappings AS mapping
                MATCH (child:TaxonomyCategory {id: mapping.child_id})
                MATCH (parent:TaxonomyCategory {id: mapping.parent_id})
                MERGE (child)-[:SUBCATEGORY_OF]->(parent)
                RETURN count(*) as created
                """
                result = self.run_query(query, params)
                logger.info(f"Created {result[0]['created']} SUBCATEGORY_OF relationships")
            
            # Update parent category sizes based on their children
            self._update_parent_category_sizes()
            
            logger.info("Successfully created LLM-generated hierarchical taxonomy")
            return True
            
        except Exception as e:
            logger.error(f"Error creating LLM-generated hierarchy: {e}")
            return False
    
    def _update_parent_category_sizes(self):
        """Update parent category sizes based on the sum of their children's sizes."""
        query = """
        MATCH (parent:TaxonomyCategory)<-[:SUBCATEGORY_OF*]-(leaf:TaxonomyCategory)
        WHERE NOT ()-[:SUBCATEGORY_OF]->(leaf)
        WITH parent, sum(leaf.size) as total_size
        SET parent.size = total_size
        RETURN parent.name as name, total_size
        """
        result = self.run_query(query)
        logger.info(f"Updated sizes for {len(result)} parent categories")
    
    
    def build_hierarchical_taxonomy(self):
        """Complete process to rebuild taxonomy with LLM-generated hierarchy."""
        logger.info("Starting LLM-based taxonomy hierarchy rebuild")
        
        success = self.generate_hierarchical_taxonomy_with_llm()
        
        if success:
            logger.info("Successfully rebuilt taxonomy with LLM-generated hierarchy")
            self.print_hierarchical_taxonomy()
        else:
            logger.error("Failed to rebuild taxonomy with LLM hierarchy")
        
        return success
        

def main():
    NEO4J_URI = "neo4j://localhost:7687"
    NEO4J_USER = "neo4j"
    NEO4J_PASSWORD = ""
    
    try:
        extractor = LegalTaxonomyExtractor(NEO4J_URI, NEO4J_USER, NEO4J_PASSWORD)
        
        extractor.extract_taxonomy()

        extractor.print_hierarchical_taxonomy()
        
        extractor.build_hierarchical_taxonomy()
        
        
    except Exception as e:
        logger.error(f"Error: {e}")
    finally:
        if 'extractor' in locals():
            extractor.close()


if __name__ == "__main__":
    main()

2025-06-10 22:29:41,507 - INFO - Starting taxonomy extraction process
2025-06-10 22:29:45,708 - INFO - Optimized 498 SIMILAR relationships
2025-06-10 22:29:45,826 - INFO - Set default weights for 0 SIMILAR relationships
2025-06-10 22:29:46,223 - INFO - Created graph projection: legal-taxonomy-graph-1749587385827 with 2301 nodes and 8196 relationships
2025-06-10 22:29:47,474 - INFO - Detected communities with 380 nodes
2025-06-10 22:29:49,093 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-06-10 22:29:49,101 - INFO - Generated category name for community 325: 'Investor Share Redemption Laws'
2025-06-10 22:29:50,254 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-06-10 22:29:50,256 - INFO - Generated category name for community 415: 'Share Conversion Regulations'
2025-06-10 22:29:51,015 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-06-10 22:29:51,018

Generated Hierarchical Taxonomy :
* Investment Prospectus Documentation (30 nodes)


2025-06-10 22:30:13,524 - INFO - Starting LLM-based taxonomy hierarchy rebuild


* Investment Risk Disclosures (30 nodes)
* Investor Share Redemption Laws (28 nodes)
* Investment Subscription Procedures (26 nodes)
* Investment and Shareholding Regulations (24 nodes)
* Equity Issuance Legal Texts (20 nodes)
* Investment Fund Legal Documentation (19 nodes)
* Company Asset Management Laws (18 nodes)
* Corporate Acquisition Agreements (18 nodes)
* Equity Securities Regulations (17 nodes)
* Corporate Finance and Taxation Law (16 nodes)
* Corporate Meeting Regulations (16 nodes)
* Bank Resolution Risks Law (10 nodes)
* ESG Integration Policies (10 nodes)
* Corporate Structure Legislation (10 nodes)
* Contractual Terms and Conditions (10 nodes)
* Default Management Procedures (10 nodes)
* Share Conversion Regulations (9 nodes)
* Shareholder Rights and Transactions (9 nodes)
* Investment Partnership Agreements (8 nodes)
* Corporate Insurance Policies (8 nodes)
* Securities Conversion Regulations (8 nodes)
* Financial Law Documentation (8 nodes)
* Share Transfer Agreements 

2025-06-10 22:30:40,109 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-06-10 22:30:40,113 - INFO - Successfully generated hierarchy structure with LLM
2025-06-10 22:30:40,219 - INFO - Created 7 new parent taxonomy categories
2025-06-10 22:30:40,335 - INFO - Created 29 SUBCATEGORY_OF relationships
2025-06-10 22:30:40,561 - INFO - Updated sizes for 7 parent categories
2025-06-10 22:30:40,561 - INFO - Successfully created LLM-generated hierarchical taxonomy
2025-06-10 22:30:40,562 - INFO - Successfully rebuilt taxonomy with LLM-generated hierarchy


Generated Hierarchical Taxonomy :
* Legal Document Categories (342 nodes)
  * Investment Law and Regulations (137 nodes)
    * Investment Prospectus Documentation (30 nodes)
    * Investment Risk Disclosures (30 nodes)
    * Investment Subscription Procedures (26 nodes)
    * Investment and Shareholding Regulations (24 nodes)
    * Investment Fund Legal Documentation (19 nodes)
    * Investment Partnership Agreements (8 nodes)
  * Corporate Law and Governance (86 nodes)
    * Company Asset Management Laws (18 nodes)
    * Corporate Acquisition Agreements (18 nodes)
    * Corporate Finance and Taxation Law (16 nodes)
    * Corporate Meeting Regulations (16 nodes)
    * Corporate Structure Legislation (10 nodes)
    * Corporate Insurance Policies (8 nodes)
  * Securities and Equity Law (71 nodes)
    * Equity Issuance Legal Texts (20 nodes)
    * Equity Securities Regulations (17 nodes)
    * Share Conversion Regulations (9 nodes)
    * Shareholder Rights and Transactions (9 nodes)
    *

# GraphRAG

In [None]:
from neo4j_graphrag.indexes import create_vector_index, upsert_vector
import neo4j

VECTOR_STORE_NAME = "legal_corpus"
DIMENSION = 1536

def drop_vector_index(driver, index_name):
    """
    Drop a vector index from Neo4j database using various methods
    depending on Neo4j version.
    
    Args:
        driver: Neo4j driver instance
        index_name: Name of the vector index to drop
    """
    try:
        with driver.session() as session:
            try:
                session.run(f"DROP INDEX {index_name}")
                print(f"Vector index '{index_name}' dropped successfully using standard DROP INDEX.")
                return True
            except Exception as e1:
                print(f"Standard DROP INDEX failed: {e1}")
            
    except Exception as e:
        print(f"Error dropping vector index '{index_name}': {e}")
        return False

def get_chunk_embeddings(driver) -> list:
    """
    Retrieve all Chunk nodes along with their embedding properties.
    Returns a list of tuples (node_id, embedding).
    """
    with driver.session() as session:
        result = session.run(
            """
            MATCH (c:Chunk)
            WHERE c.embedding IS NOT NULL
            RETURN id(c) AS node_id, c.embedding AS embedding
            """
        )

        embeddings_data = [
            (record["node_id"], record["embedding"]) for record in result
        ]

        return embeddings_data

def main():
    URI = "neo4j://localhost:7687"
    USERNAME = "neo4j"
    PASSWORD = "hide"
    AUTH = (USERNAME, PASSWORD)

    driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)

    drop_vector_index(driver, VECTOR_STORE_NAME)

    create_vector_index(
        driver,
        name=VECTOR_STORE_NAME,
        label="Legal_corpus",
        embedding_property="vectorProperty",
        dimensions=DIMENSION,
        similarity_fn="cosine",
    )

    try:
        # Fetch all chunk nodes along with their embeddings from the knowledge graph.
        embeddings_data = get_chunk_embeddings(driver)

        if not embeddings_data:
            print("No Chunk nodes with embeddings found.")
            return

        print(f"Adding {len(embeddings_data)} embeddings to vector store '{VECTOR_STORE_NAME}'...")

        successful_ops_counter = 0

        # Process and upsert each node's embedding into the vector store.
        # This transfers embeddings data from standard node properties into a dedicated Neo4j vector index
        for node_id, embedding in embeddings_data:
            upsert_vector(
                driver,
                node_id=node_id,
                embedding_property="embedding",
                vector=embedding,
            )
            successful_ops_counter += 1

        print(f"Successfully added {successful_ops_counter} embeddings to the vector store.")

    except Exception as e:
        print(f"Error: {e}")
        
    finally:
        driver.close()

main()

2025-06-10 22:31:35,971 - INFO - Creating vector index named 'legal_corpus'


Standard DROP INDEX failed: {code: Neo.DatabaseError.Schema.IndexDropFailed} {message: Unable to drop index called `legal_corpus`. There is no such index.}


  upsert_vector(


Adding 338 embeddings to vector store 'legal_corpus'...
Successfully added 338 embeddings to the vector store.


In [None]:
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import VectorRetriever

URI = "neo4j://localhost:7687"
USERNAME = "neo4j"
PASSWORD = "hide"
AUTH = (USERNAME, PASSWORD)
INDEX_NAME = "legal_corpus"
DATABASE = "neo4j"

driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)

# Create an Embedder object to convert the user's question into a vector. The same embedding model used for knowledge graph creation is used here.
embedder = OpenAIEmbeddings(model="text-embedding-3-small", api_key="hide")

# Initialize the retriever
retriever = VectorRetriever(driver, INDEX_NAME, embedder)

# Set up the LLM for generating answers based on the retrieved knowledge graph leal texts.
llm = OpenAILLM(
    api_key="hide",
    model_name="gpt-4o-mini",
    model_params={
        "temperature": 0,
    },
)

# Initialize the RAG pipeline
rag = GraphRAG(retriever=retriever, llm=llm)


In [23]:
# Query the graph
#query_text = "Tell me something about risks"
#query_text = "What is the legal form of FinaStream Fund?"
#query_text = "What assets can the company invest in?"
#query_text = "Can you give the different types of shares?"
query_text = "Give all you know about the Investment Risk Disclosures?"
#query_text = "How are the Tolaxis Single Asset Funds legally organised?"
#query_text = "Provide the paragraph that explains under what conditions an investor’s shares can be forcibly redeemed by the company."
#query_text = "Provide and show the paragraph that defines the different share classes and their specific rights. Dont forget to give the provision number."
#query_text = "What can you say about the fith article of the Shareholders from the FinaStream Fund in the constitution? Is it related to anything in the offering document?"
#query_text = "Can you say about the FinaStream Fund? Give many details related in documents as possible and give the answer only from those"
response = rag.search(query_text=query_text, retriever_config={"top_k": 5})
print(response.answer)

2025-06-10 22:33:50,486 - INFO - HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
2025-06-10 22:33:58,094 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


Investment Risk Disclosures are essential documents or statements provided by investment firms to inform investors about the potential risks associated with various investment products. These disclosures aim to ensure that investors are aware of the risks they may face, allowing them to make informed decisions. Key aspects of Investment Risk Disclosures include:

1. **Types of Risks**: They typically outline various types of risks, such as market risk, credit risk, liquidity risk, operational risk, and specific risks related to particular investment products (e.g., stocks, bonds, mutual funds).

2. **Regulatory Requirements**: Many jurisdictions require investment firms to provide risk disclosures as part of their compliance with financial regulations. This is to protect investors and promote transparency in the financial markets.

3. **Clear Language**: Disclosures should be written in clear and understandable language, avoiding jargon that may confuse investors. The goal is to ensure