In [1]:
import re
import tiktoken
from typing import List, Dict, Tuple
from copy import deepcopy
from langchain_core.documents import Document

In [2]:
# -------------------------
# CONFIG
# -------------------------

# Token budget at which the splitter flushes a chunk, and the intended
# token overlap between consecutive chunks.
MAX_TOKENS = 500
OVERLAP_TOKENS = 80

# A line is treated as a section heading when any of these matches at its
# start (they are OR-ed together into one regex by detect_sections).
SECTION_PATTERNS = [
    r"^SECTION\s+\d+",       # "SECTION <number>"
    r"^ARTICLE\s+\d+",       # "ARTICLE <number>"
    r"^\d+\.",               # leading number followed by a dot, e.g. "3."
    r"^[A-Z][A-Z\s]{5,}$",   # 6+ characters, capitals and whitespace only
]

# Case-insensitive matcher for clause references: a keyword, a number,
# optional ".N" sub-numbers and optional "(x)" single-letter suffixes.
CLAUSE_RE = re.compile(
    r"(Section|Clause|Article)\s+\d+(\.\d+)*(\([a-z]\))*",
    flags=re.IGNORECASE,
)

# Shared tokenizer used for every token count in this notebook.
enc = tiktoken.get_encoding("cl100k_base")

In [3]:
# -------------------------
# HELPERS
# -------------------------

def tok_len(text):
    """Return the number of tokens the shared encoder ``enc`` produces for ``text``.

    Document text is arbitrary input, so it may literally contain a
    special-token marker such as ``<|endoftext|>``. With tiktoken's default
    ``disallowed_special="all"`` that raises ``ValueError``; passing an empty
    collection makes the encoder treat such markers as ordinary text so that
    counting never aborts the split.

    Args:
        text: The string to measure.

    Returns:
        The token count as an ``int``.
    """
    return len(enc.encode(text, disallowed_special=()))


def split_sentences(text):
    """Split ``text`` into sentences at whitespace that follows ``.``, ``!`` or ``?``.

    The terminating punctuation stays attached to its sentence and the
    separating whitespace is discarded. Sentences are returned exactly as
    they appear in ``text`` (no stripping), so each one is still a substring
    of the input.

    Args:
        text: The string to split.

    Returns:
        A list of sentence strings in their original order. Empty or
        whitespace-only input yields an empty list.
    """
    pieces = re.split(r'(?<=[.!?])\s+', text)
    # re.split yields '' for empty input and after trailing whitespace;
    # such pieces are not sentences and would become empty chunks downstream.
    return [piece for piece in pieces if piece.strip()]


def detect_sections(text):
    """Tag every line of ``text`` with the section heading it falls under.

    A line whose stripped form matches any entry of ``SECTION_PATTERNS``
    becomes the active heading (and is itself tagged with that heading).
    Lines before the first heading are tagged ``"INTRO"``.

    Args:
        text: The full document text; it is split on ``"\\n"``.

    Returns:
        A list of ``(line, heading)`` tuples, one per line, in order. The
        ``line`` element is the original unstripped line.
    """
    heading_re = re.compile("|".join(SECTION_PATTERNS))
    tagged_lines = []
    active_heading = "INTRO"

    for raw_line in text.split("\n"):
        candidate = raw_line.strip()
        if heading_re.match(candidate) is not None:
            active_heading = candidate
        tagged_lines.append((raw_line, active_heading))

    return tagged_lines


def extract_clauses(text, pattern=None):
    """Return every clause reference found in ``text``, in order of appearance.

    Each result is the full matched text (e.g. ``"Section 4.2(a)"``).
    ``findall`` is deliberately not used: on a pattern with capture groups it
    returns only the groups, which drops the clause number and the
    whitespace between the keyword and the number.

    Args:
        text: The string to scan.
        pattern: Optional compiled regular expression to use instead of the
            module-level ``CLAUSE_RE``.

    Returns:
        A list of matched substrings; empty when nothing matches.
    """
    clause_re = CLAUSE_RE if pattern is None else pattern
    return [match.group(0) for match in clause_re.finditer(text)]

In [4]:
# -------------------------
# CORE SPLITTER
# -------------------------

def _group_section_runs(tagged_lines):
    """Collapse (line, title) pairs into ordered (title, lines) runs of consecutive lines."""
    runs = []
    for line, title in tagged_lines:
        if runs and runs[-1][0] == title:
            runs[-1][1].append(line)
        else:
            runs.append((title, [line]))
    return runs


def _overlap_tail(buf, budget):
    """Return the longest suffix of ``buf`` whose summed token counts fit within ``budget``."""
    tail = []
    used = 0
    for entry in reversed(buf):
        if used + entry[3] > budget:
            break
        tail.append(entry)
        used += entry[3]
    tail.reverse()
    return tail


def _make_chunk(buf, sec_title, parent_meta, chunk_idx):
    """Build one chunk Document and its metadata from (sentence, start, end, tokens) entries."""
    chunk_text = " ".join(entry[0] for entry in buf)
    meta = deepcopy(parent_meta)
    meta.update({
        "chunk_id": f"{chunk_idx}",
        "section_title": sec_title,
        # Span of the chunk in the source text: first sentence start to
        # last sentence end.
        "start_char": buf[0][1],
        "end_char": buf[-1][2],
        "token_count": tok_len(chunk_text),
        "clauses": extract_clauses(chunk_text)
    })
    return Document(page_content=chunk_text, metadata=meta), meta


def context_aware_split(doc: Document) -> Tuple[List[Document], Dict, Dict]:
    """Split a document into section-aware, token-bounded chunks.

    The text is tagged line by line with ``detect_sections``, grouped into
    runs of consecutive lines sharing a heading, split into paragraphs on
    blank lines, and then into sentences. Sentences are packed into a chunk
    until adding the next one would exceed ``MAX_TOKENS``; the chunk is then
    emitted and the next chunk starts with trailing sentences of the previous
    one worth at most ``OVERLAP_TOKENS`` tokens (less when needed to keep the
    new chunk within ``MAX_TOKENS``). Chunks never cross a paragraph
    boundary. Whitespace-only sentences are skipped, so blank paragraphs
    produce no chunk. A single sentence longer than ``MAX_TOKENS`` is still
    emitted as one oversized chunk.

    Args:
        doc: The source document; ``page_content`` is split and ``metadata``
            is deep-copied into every chunk.

    Returns:
        A 3-tuple ``(chunks, reverse_map, stats)``:
          * ``chunks``: list of ``Document`` whose ``page_content`` is the
            chunk's sentences joined by single spaces, with metadata keys
            ``chunk_id``, ``section_title``, ``start_char``, ``end_char``,
            ``token_count`` and ``clauses`` added to the parent metadata.
            ``start_char``/``end_char`` delimit the chunk in the original
            text, so ``end_char - start_char`` can differ from
            ``len(page_content)`` where whitespace was normalised.
          * ``reverse_map``: ``chunk_id`` string -> that chunk's metadata dict.
          * ``stats``: ``{"total_chunks": ..., "sections": ...}`` where
            ``sections`` counts the consecutive section runs.
    """
    text = doc.page_content
    parent_meta = doc.metadata or {}

    # Ordered runs rather than a dict keyed by title: identical headings that
    # appear in different places must not be merged or reordered.
    section_runs = _group_section_runs(detect_sections(text))

    chunks: List[Document] = []
    reverse_map: Dict[str, Dict] = {}
    scan = 0  # position in ``text`` just past the last located sentence

    for sec_title, lines in section_runs:
        section_text = "\n".join(lines)

        for para in section_text.split("\n\n"):
            buf = []  # entries: (sentence, start_char, end_char, token_count)
            buf_tokens = 0

            for sent in split_sentences(para):
                if not sent.strip():
                    continue

                # Every sentence is an exact substring of ``text`` and
                # sentences are visited in document order, so a forward
                # search from ``scan`` locates it.
                start = text.find(sent, scan)
                if start == -1:
                    start = scan
                end = start + len(sent)
                scan = end

                t = tok_len(sent)

                if buf and buf_tokens + t > MAX_TOKENS:
                    chunk, meta = _make_chunk(buf, sec_title, parent_meta, len(chunks))
                    reverse_map[meta["chunk_id"]] = meta
                    chunks.append(chunk)

                    # Carry over trailing sentences as overlap, shrinking the
                    # budget so overlap plus the incoming sentence still fits.
                    buf = _overlap_tail(buf, min(OVERLAP_TOKENS, MAX_TOKENS - t))
                    buf_tokens = sum(entry[3] for entry in buf)

                buf.append((sent, start, end, t))
                buf_tokens += t

            # Flush whatever is left of the paragraph.
            if buf:
                chunk, meta = _make_chunk(buf, sec_title, parent_meta, len(chunks))
                reverse_map[meta["chunk_id"]] = meta
                chunks.append(chunk)

    stats = {
        "total_chunks": len(chunks),
        "sections": len(section_runs)
    }

    return chunks, reverse_map, stats


In [None]:
# chunks, reverse_map, stats = context_aware_split(document)
# vector_db.add_documents(chunks)
