In [None]:
import re
from collections import defaultdict
from copy import deepcopy
from typing import List

import numpy as np
import tiktoken
from langchain_core.documents import Document
from nltk.tokenize import sent_tokenize
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# -----------------------
# CONFIG
# -----------------------
MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers model used for embeddings
MAX_TOKENS = 400                 # default per-chunk token budget (counted by tok_len)
OVERLAP_SENTENCES = 1            # trailing sentences carried into the next chunk
SIMILARITY_THRESHOLD = 0.75      # min cosine similarity to join a sentence group

# Matches clause references such as "Section 3", "Clause 4.2" or
# "Article 7.1(a)" (case-insensitive). All groups are non-capturing so that
# ``findall`` yields the whole reference instead of a tuple of sub-groups.
CLAUSE_RE = re.compile(
    r"(?:Section|Clause|Article)\s+\d+(?:\.\d+)*(?:\([a-z]\))*",
    re.I,
)

enc = tiktoken.get_encoding("cl100k_base")


def tok_len(t: str) -> int:
    """Return the number of ``cl100k_base`` tokens in ``t``."""
    return len(enc.encode(t))


# -----------------------
# HELPERS
# -----------------------
def extract_clauses(text, pattern=None):
    """Return every clause reference found in ``text``, in order of appearance.

    Args:
        text: Text to scan.
        pattern: Optional compiled regex to use instead of ``CLAUSE_RE``
            (e.g. to also recognise "Schedule" or "Annex" references).

    Returns:
        A list with the full matched text of each reference,
        e.g. ``["Section 3.2.1(a)", "Article 7"]``.
    """
    regex = CLAUSE_RE if pattern is None else pattern
    # Use the whole match (group 0). ``findall`` returns tuples of the
    # capture groups when the pattern has any, and joining those drops the
    # leading number ("Section 3.2.1(a)" -> "Section.1(a)").
    return [m.group(0) for m in regex.finditer(text)]

def find_offset(full, chunk, cursor):
    """Return the index of the first occurrence of ``chunk`` in ``full`` at or after ``cursor``.

    When ``chunk`` does not occur in that range, ``cursor`` itself is
    returned, so callers always receive an integer position.
    """
    try:
        return full.index(chunk, cursor)
    except ValueError:
        # Not found from ``cursor`` onward: fall back to the cursor.
        return cursor

# -----------------------
# SEMANTIC CHUNKER
# -----------------------
def _sentence_spans(text, sentences):
    """Return a ``(start, end)`` character span in ``text`` for each sentence."""
    spans = []
    cursor = 0
    for sent in sentences:
        # Assumes sent_tokenize returns verbatim slices of ``text`` in order
        # (verify for the configured tokenizer); if a sentence is not found,
        # find_offset falls back to the current cursor.
        start = find_offset(text, sent, cursor)
        end = min(start + len(sent), len(text))
        spans.append((start, end))
        cursor = end
    return spans


def _greedy_groups(sim_matrix, threshold):
    """Greedily cluster sentence indices around the earliest unassigned seed."""
    n = sim_matrix.shape[0]
    assigned = set()
    groups = []
    for seed in range(n):
        if seed in assigned:
            continue
        group = [seed]
        assigned.add(seed)
        for j in range(seed + 1, n):
            if j not in assigned and sim_matrix[seed, j] >= threshold:
                group.append(j)
                assigned.add(j)
        groups.append(group)  # already ascending: every j > seed
    return groups


def _build_chunk(doc, doc_idx, chunk_idx, chunk_text, start, end):
    """Wrap ``chunk_text`` in a new Document with ``doc``'s metadata plus chunk info."""
    meta = deepcopy(doc.metadata or {})
    meta.update({
        "chunk_id": f"{doc_idx}_{chunk_idx}",
        "start_char": start,
        "end_char": end,
        "token_count": tok_len(chunk_text),
        "clauses": extract_clauses(chunk_text),
    })
    return Document(page_content=chunk_text, metadata=meta)


def _sentences_chunk(doc, doc_idx, chunk_idx, sentences, spans, idxs):
    """Build a chunk from the (ascending) sentence indices ``idxs``."""
    chunk_text = " ".join(sentences[k] for k in idxs)
    # Bounding range from the first sentence's start to the last sentence's
    # end. Groups may be non-contiguous, so the range can cover sentences
    # that belong to other chunks.
    start = spans[idxs[0]][0]
    end = spans[idxs[-1]][1]
    return _build_chunk(doc, doc_idx, chunk_idx, chunk_text, start, end)


def semantic_chunking(documents: List[Document],
                      max_tokens: int = MAX_TOKENS,
                      similarity_threshold: float = SIMILARITY_THRESHOLD,
                      overlap_sentences: int = OVERLAP_SENTENCES,
                      model=None,
                      ) -> List[Document]:
    """Split documents into token-bounded chunks of semantically similar sentences.

    Each document is sentence-tokenized and embedded. Sentences are greedily
    grouped around the earliest unassigned sentence whenever their cosine
    similarity to that seed is >= ``similarity_threshold``. Each group is then
    packed into chunks of at most ``max_tokens`` tokens; a single sentence
    longer than the budget still becomes its own chunk. When the last
    ``overlap_sentences`` sentences of a flushed chunk fit alongside the next
    sentence, they are repeated at the start of the next chunk of the same
    group.

    Documents with fewer than three sentences are emitted whole as a single
    chunk (page_content is the original text).

    Args:
        documents: Documents to split.
        max_tokens: Per-chunk token budget, measured with ``tok_len``.
        similarity_threshold: Minimum cosine similarity to the group seed.
        overlap_sentences: Sentences carried over between consecutive chunks
            of the same group; 0 or less disables overlap.
        model: Optional pre-loaded SentenceTransformer to reuse across calls.
            If omitted, ``MODEL_NAME`` is loaded on first use within the call.

    Returns:
        New Documents. Each one's metadata is a deep copy of its source
        document's metadata plus ``chunk_id`` ("<doc_idx>_<chunk_idx>"),
        ``start_char``/``end_char`` (bounding range in the source text),
        ``token_count`` and ``clauses``. Because groups need not be
        contiguous, a chunk's ``page_content`` may join sentences that are
        far apart in the source.
    """
    all_chunks = []

    for doc_idx, doc in enumerate(documents):
        text = doc.page_content
        sentences = sent_tokenize(text)

        if len(sentences) < 3:
            # Too short to group: keep the text verbatim, but give it the
            # same metadata keys every other chunk carries.
            all_chunks.append(_build_chunk(doc, doc_idx, 0, text, 0, len(text)))
            continue

        if model is None:
            # Load lazily so batches of short documents never pay for it.
            model = SentenceTransformer(MODEL_NAME)

        embeddings = model.encode(sentences, normalize_embeddings=True)
        sim_matrix = cosine_similarity(embeddings)
        groups = _greedy_groups(sim_matrix, similarity_threshold)
        spans = _sentence_spans(text, sentences)

        chunk_idx = 0
        for group in groups:
            buf = []  # sentence indices of the chunk being built
            buf_tokens = 0

            for k in group:
                t = tok_len(sentences[k])

                if buf and buf_tokens + t > max_tokens:
                    all_chunks.append(
                        _sentences_chunk(doc, doc_idx, chunk_idx, sentences, spans, buf)
                    )
                    chunk_idx += 1
                    # ``buf[-0:]`` would keep everything, so 0 needs its own case.
                    buf = buf[-overlap_sentences:] if overlap_sentences > 0 else []
                    buf_tokens = sum(tok_len(sentences[j]) for j in buf)
                    if buf_tokens + t > max_tokens:
                        # The overlap alone would push the next chunk over budget.
                        buf, buf_tokens = [], 0

                buf.append(k)
                buf_tokens += t

            if buf:
                all_chunks.append(
                    _sentences_chunk(doc, doc_idx, chunk_idx, sentences, spans, buf)
                )
                chunk_idx += 1

    return all_chunks
