In [3]:
from dataclasses import dataclass
from typing import Dict, List, Optional
import re

In [4]:
@dataclass
class CitationSpan:
    """Record describing where a cited piece of text lives.

    Instances are created by ``build_citation_spans`` from a retrieved
    document's ``metadata`` mapping, may be widened in place by
    ``merge_adjacent_spans`` (which updates ``end_char``), and are flattened
    into plain dicts by ``format_answer``.
    """

    # Identifier of the retrieved chunk (read from metadata["chunk_id"]).
    chunk_id: str
    # Value of metadata["source"], or None when the key is absent.
    source: Optional[str]
    # Value of metadata["page"], or None when the key is absent.
    page: Optional[int]
    # Start of the span; presumably a character offset into the page or
    # source text -- TODO confirm which text the offsets are relative to.
    start_char: int
    # End of the span; used as an exclusive bound when slicing in
    # extract_span / highlight_text-style helpers -- TODO confirm.
    end_char: int
    # Value of metadata["section_title"], or None when the key is absent.
    section: Optional[str]
    # First entry of metadata["clauses"], or None when there are none.
    clause: Optional[str]

In [None]:
### Section Detector - Hierarchical section detection

# Each entry is anchored at the start of the (already stripped) line.
# Matching is case-sensitive because no flags are passed to re.compile.
SECTION_PATTERNS = [
    r"^SECTION\s+\d+",       # "SECTION 4 ..."
    r"^ARTICLE\s+\d+",       # "ARTICLE 12 ..."
    r"^\d+\.",               # numbered heading such as "3. Scope"
    r"^[A-Z][A-Z\s]{4,}$",   # a line made only of capitals/whitespace, 5+ chars
]

# One alternation over all heading patterns; used with .match() per line.
section_regex = re.compile(str.join("|", SECTION_PATTERNS))

def detect_sections(text: str):
    """Label every line of ``text`` with the heading it falls under.

    The text is split on ``"\\n"``. A line whose stripped form matches
    ``section_regex`` becomes the current heading (stored stripped) and is
    itself labelled with that heading. Lines before the first heading are
    labelled ``"INTRO"``.

    Returns:
        A list of ``(original_line, heading)`` tuples, one per line, in
        document order. The original line is kept unstripped.
    """
    labelled = []
    heading = "INTRO"

    for raw_line in text.split("\n"):
        candidate = raw_line.strip()
        if section_regex.match(candidate) is not None:
            heading = candidate
        labelled.append((raw_line, heading))

    return labelled

In [None]:
### Clause Detector

# Matches references such as "Section 3", "Clause 4.2.1" or "Article 5(b)":
# a keyword, whitespace, a number, optional ".N" sub-levels and optional
# "(x)" letter suffixes. Case-insensitive.
CLAUSE_RE = re.compile(
    r"(Section|Clause|Article|Topic)\s+\d+(\.\d+)*(\([a-z]\))*",
    re.I
)

def extract_clauses(text: str):
    """Return every clause reference found in ``text``, in order of appearance.

    Each item is the full matched substring (e.g. ``"Section 3.2(a)"``).
    Duplicates are kept. An empty list is returned when nothing matches.

    Args:
        text: The text to scan.

    Returns:
        list[str]: The matched references exactly as they appear in ``text``.
    """
    # CLAUSE_RE contains capture groups, so findall() would return tuples of
    # the *groups* only: the leading number is outside any group and the
    # repeated groups keep just their last repetition. Joining those tuples
    # produced strings like "Section.2(a)". group(0) is the whole match.
    return [match.group(0) for match in CLAUSE_RE.finditer(text)]

In [10]:
### Chunk Enricher - Add inheritance + section + clause_metadata

def enrich_chunk_metadata(parent_meta, extra_meta, chunk_text):
    """Build a chunk's metadata from its parent's metadata plus extras.

    Args:
        parent_meta: Mapping inherited from the parent document, or None.
            It is deep-copied, so the result never shares nested objects
            with the parent.
        extra_meta: Mapping of chunk-specific values that override inherited
            ones, or None for no overrides.
        chunk_text: Text of the chunk; scanned with ``extract_clauses``.
            None is treated as empty text.

    Returns:
        dict: The merged metadata. A ``"clauses"`` key holding the list from
        ``extract_clauses`` is set only when at least one clause is found.
    """
    # Imported here so this function works even if no other cell has
    # imported deepcopy (the notebook's import cell does not).
    from copy import deepcopy

    meta = deepcopy(parent_meta) if parent_meta else {}
    if extra_meta:
        meta.update(extra_meta)

    clauses = extract_clauses(chunk_text or "")
    if clauses:
        meta["clauses"] = clauses

    return meta

In [11]:
### Citation Merge

def merge_adjacent_spans(spans, gap=80):
    """Merge spans that lie close together on the same page of the same source.

    Spans are processed in ``(page, start_char)`` order. A span is folded
    into the previous span for the same ``(source, page)`` when its
    ``start_char`` is at most ``gap`` beyond that span's ``end_char``; the
    surviving span keeps its own metadata and gets the larger ``end_char``.

    Args:
        spans: Iterable of objects with ``page``, ``start_char`` and
            ``end_char`` attributes (and optionally ``source``).
        gap: Largest distance between one span's end and the next span's
            start for the two to be merged.

    Returns:
        list: New span objects (shallow copies); the input spans are never
        modified. Spans whose ``page`` is None sort before numbered pages.
    """
    from copy import copy

    def sort_key(span):
        # page is Optional: comparing None with int raises TypeError, so
        # order on "has a page" first and substitute 0 for the missing value.
        has_page = span.page is not None
        return (has_page, span.page if has_page else 0, span.start_char)

    merged = []
    # Last emitted span per (source, page); spans from different sources must
    # not be merged even when their page numbers and offsets coincide.
    open_spans = {}

    for span in sorted(spans, key=sort_key):
        key = (getattr(span, "source", None), span.page)
        last = open_spans.get(key)

        if last is not None and span.start_char <= last.end_char + gap:
            # `last` is our own copy, so widening it leaves the input intact.
            last.end_char = max(last.end_char, span.end_char)
            continue

        fresh = copy(span)
        merged.append(fresh)
        open_spans[key] = fresh

    return merged

In [None]:
### Highlight

def highlight_text(text, start, end, *, escape=False):
    """Wrap ``text[start:end]`` in ``<mark>`` tags and return the full string.

    Args:
        text: The string to annotate.
        start: Start index of the highlight; follows slice rules, so negative
            and out-of-range values are accepted.
        end: End index (exclusive) of the highlight; same rules as ``start``.
            If it falls before ``start`` the highlight is empty.
        escape: When True, HTML-escape the text itself (not the ``<mark>``
            tags) so the result can be inserted into a page safely. Defaults
            to False, which returns the text unchanged around the tags.

    Returns:
        str: ``text`` with ``<mark>...</mark>`` around the selected range.
    """
    # slice.indices() resolves negative/out-of-range bounds exactly as
    # text[start:end] would.
    begin, stop, _ = slice(start, end).indices(len(text))
    # With start > end the naive concatenation repeated text[end:start];
    # collapse to an empty highlight at `begin` instead.
    stop = max(begin, stop)

    before, marked, after = text[:begin], text[begin:stop], text[stop:]

    if escape:
        from html import escape as html_escape
        before, marked, after = (
            html_escape(part) for part in (before, marked, after)
        )

    return before + "<mark>" + marked + "</mark>" + after


def extract_span(text, start, end):
    """Return the part of ``text`` from ``start`` up to, not including, ``end``.

    Indices follow ordinary slice semantics.
    """
    window = slice(start, end)
    return text[window]

In [None]:
### Formatter - Frontend-ready answer formatter

def format_answer(answer_text, citation_spans):
    """Package an answer and its citations as a plain, serialisable dict.

    Args:
        answer_text: The answer to return under the ``"answer"`` key.
        citation_spans: Iterable of span objects exposing the attributes
            listed in ``exported_fields`` below.

    Returns:
        dict: ``{"answer": answer_text, "citations": [...]}`` where each
        citation is a dict with one key per exported attribute.
    """
    # Order here is the key order of every citation dict.
    exported_fields = (
        "chunk_id",
        "source",
        "page",
        "section",
        "clause",
        "start_char",
        "end_char",
    )

    citations = [
        {field: getattr(span, field) for field in exported_fields}
        for span in citation_spans
    ]

    return {"answer": answer_text, "citations": citations}

In [14]:
### RAG Citation Pipeline - This is the glue layer you call after retrieval

def build_citation_spans(retrieved_docs):
    """Turn retrieved documents into ``CitationSpan`` records.

    Each document must expose a ``metadata`` mapping. ``chunk_id``,
    ``start_char`` and ``end_char`` are required keys (a missing one raises
    ``KeyError``); ``source``, ``page`` and ``section_title`` are optional and
    default to None. The clause is the first entry of ``metadata["clauses"]``,
    or None when that key is missing or empty.

    Returns:
        list: One ``CitationSpan`` per document, in input order.
    """
    def span_from(doc):
        meta = doc.metadata
        return CitationSpan(
            chunk_id=meta["chunk_id"],
            source=meta.get("source"),
            page=meta.get("page"),
            start_char=meta["start_char"],
            end_char=meta["end_char"],
            section=meta.get("section_title"),
            # Fall back to a one-element [None] list so [0] is always valid.
            clause=(meta.get("clauses") or [None])[0],
        )

    return [span_from(doc) for doc in retrieved_docs]


def rag_citation_pipeline(answer_text, retrieved_docs):
    """Produce the final answer payload for a set of retrieved documents.

    Runs three steps in order: build citation spans from the documents,
    merge adjacent spans (with the merger's default gap), and format the
    answer together with the merged spans.

    Returns:
        Whatever ``format_answer`` returns for ``answer_text`` and the merged
        spans.
    """
    return format_answer(
        answer_text,
        merge_adjacent_spans(build_citation_spans(retrieved_docs)),
    )

In [None]:
### Integration Example

import os
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate 
from langchain_openai import AzureChatOpenAI


# Populate os.environ from a local .env file (if present) so the Azure
# settings below can be read with os.getenv.
load_dotenv()

# Azure OpenAI chat client. Every setting comes from the environment;
# os.getenv returns None for an unset variable, so a missing value is passed
# through as None rather than failing here.
llm = AzureChatOpenAI(
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
)

# Single-variable prompt; {product} is filled in when the chain runs.
prompt = PromptTemplate.from_template(
    "What is good name for a company that makes {product}"
)

# NOTE(review): LLMChain and Chain.run are marked deprecated in recent
# LangChain releases in favour of `prompt | llm` with `.invoke(...)` --
# confirm against the installed version before relying on this cell.
chain = LLMChain(llm=llm, prompt=prompt)

# Smoke test of the LLM connection; this call goes out to the Azure endpoint.
result = chain.run(product="colorful socks")
print(result)

# NOTE(review): `retriever` and `query` are not defined in any cell visible
# here; on a fresh kernel this line raises NameError unless an earlier cell
# defines them -- confirm.
retrieved = retriever.invoke(query)

# NOTE(review): `llm_answer` is likewise not defined in the visible cells.
# `result` is rebound here from the chain output above to the citation
# payload returned by rag_citation_pipeline.
result = rag_citation_pipeline(
    answer_text=llm_answer,
    retrieved_docs=retrieved
)