In [None]:
import tiktoken
import re
from copy import deepcopy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_openai import AzureChatOpenAI
from langchain_community.chat_models import ChatOllama

In [3]:
# =========================
# CONFIG
# =========================

# Chat backend returned by get_llm(): True -> Azure OpenAI, False -> local Ollama.
USE_AZURE = True   # switch provider here
# Model name handed to ChatOllama when USE_AZURE is False.
OLLAMA_MODEL = "llama3"
# Model name handed to OllamaEmbeddings when the vector store is built.
EMBED_MODEL = "nomic-embed-text"

# Passed to the text splitter as chunk_size / chunk_overlap.
# NOTE(review): the splitter is constructed without a custom length function,
# so these are presumably character counts rather than token counts -- confirm.
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 150


# =========================
# TOKENIZER
# =========================

enc = tiktoken.get_encoding("cl100k_base")


def token_len(t):
    """Return how many tokens the ``cl100k_base`` encoding produces for ``t``."""
    return len(enc.encode(t))


# =========================
# CLAUSE DETECTOR
# =========================

# Matches references such as "Section 4", "clause 4.2" or "Article 12.3.1(b)(c)".
# Matching is case-insensitive (re.I). The capture groups are kept as-is so the
# pattern itself is unchanged; callers should read the whole match (group 0).
CLAUSE_RE = re.compile(
    r"(Section|Clause|Article)\s+\d+(\.\d+)*(\([a-z]\))*",
    re.I
)


def extract_clauses(text):
    """Return every clause reference found in ``text``, in order of appearance.

    Args:
        text: String to scan.

    Returns:
        A list of the matched substrings exactly as they appear in ``text``
        (e.g. ``["Section 4.2", "clause 10(a)"]``); empty if nothing matches.
    """
    # findall() on a pattern with capture groups yields tuples of the groups,
    # which omit the leading number and only keep the *last* repetition of each
    # repeated group ("Section 4.2" -> ("Section", ".2", "")). Use the full
    # match instead so the reference is returned intact.
    return [match.group(0) for match in CLAUSE_RE.finditer(text)]

In [4]:
# =========================
# SPLIT + ENRICH
# =========================

def find_offsets(full, chunks):
    """Locate each chunk inside ``full`` and return its character span.

    Chunks are expected in document order and may overlap one another (the
    splitter in this notebook is configured with a non-zero overlap), so the
    search for a chunk resumes just after the *start* of the previously located
    chunk rather than after its end.

    Args:
        full: The complete source text.
        chunks: Iterable of chunk strings, in the order they occur in ``full``.

    Returns:
        A list of ``(start, end)`` tuples, one per chunk, with ``end`` exclusive.
        If a chunk cannot be found verbatim, a best-effort span starting at the
        end of the previous span is returned for it.
    """
    offsets = []
    search_from = 0   # earliest index at which the next chunk may start
    prev_end = 0      # end of the most recently emitted span (fallback anchor)

    for chunk in chunks:
        start = full.find(chunk, search_from)
        if start == -1:
            # Not present verbatim: keep the original best-effort guess and do
            # not advance the search cursor based on a guessed position.
            start = prev_end
        elif chunk:
            # A later chunk starts strictly after this one, but may begin
            # before this one ends when chunks overlap.
            search_from = start + 1

        end = start + len(chunk)
        offsets.append((start, end))
        prev_end = end

    return offsets


def split_and_enrich(docs):
    """Split each document into chunks and attach citation metadata to each.

    Every input document is split on its own. Each resulting chunk gets a copy
    of the parent document's metadata plus: ``chunk_id`` (``"<doc index>_<chunk
    index>"``), ``page`` (the parent's ``page`` if present, otherwise the
    document index), ``start_char`` / ``end_char`` (from ``find_offsets``),
    ``token_count`` and ``clauses``.

    Args:
        docs: Sequence of documents exposing ``page_content`` and ``metadata``.

    Returns:
        ``(all_chunks, reverse_map)`` where ``all_chunks`` is the list of
        enriched ``Document`` chunks and ``reverse_map`` maps each ``chunk_id``
        to the metadata dict that was passed to that chunk's ``Document``.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n","\n",". "," ",""],
        chunk_overlap=CHUNK_OVERLAP,
        chunk_size=CHUNK_SIZE,
    )

    enriched_chunks = []
    meta_by_chunk_id = {}

    for page_idx, doc in enumerate(docs):
        pieces = text_splitter.split_documents([doc])
        spans = find_offsets(
            doc.page_content,
            [piece.page_content for piece in pieces],
        )

        for chunk_idx, (piece, span) in enumerate(zip(pieces, spans)):
            start, end = span
            body = piece.page_content
            chunk_id = f"{page_idx}_{chunk_idx}"

            # Start from the parent document's metadata (not the split's), and
            # copy it so chunks never share mutable state with the parent.
            meta = deepcopy(doc.metadata or {})
            meta["chunk_id"] = chunk_id
            # Keep an existing page number; otherwise use the document index.
            meta["page"] = meta.get("page", page_idx)
            meta["start_char"] = start
            meta["end_char"] = end
            meta["token_count"] = token_len(body)
            # NOTE(review): this value is a list. Some vector stores only
            # accept scalar metadata values -- confirm it survives indexing in
            # the store used by build_vector_store.
            meta["clauses"] = extract_clauses(body)

            enriched_chunks.append(Document(page_content=body, metadata=meta))
            meta_by_chunk_id[chunk_id] = meta

    return enriched_chunks, meta_by_chunk_id

In [5]:
# =========================
# CITATION MERGE
# =========================

def merge_spans(spans, gap=80):
    """Merge citation spans that overlap or sit within ``gap`` characters.

    Two spans are merged only when they belong to the same page *and* the same
    source; spans without a ``"source"`` key are treated as sharing one
    (unknown) source. A merged span keeps every field of the earliest span in
    the group, with ``end_char`` extended to cover the whole group.

    Args:
        spans: Iterable of dicts with at least ``"page"``, ``"start_char"`` and
            ``"end_char"``; ``"source"`` is optional. The input dicts are not
            modified.
        gap: Maximum distance, in characters, between the end of one span and
            the start of the next for them to be merged.

    Returns:
        A new list of new dicts, ordered by page, then source, then
        ``start_char``.
    """
    def _source_key(span):
        # Missing/None/empty sources compare equal; str() keeps the sort key
        # comparable even if a source is not a plain string.
        return str(span.get("source") or "")

    ordered = sorted(
        spans,
        key=lambda s: (s["page"], _source_key(s), s["start_char"]),
    )
    merged = []

    for span in ordered:
        last = merged[-1] if merged else None

        if (
            last is not None
            and span["page"] == last["page"]
            # Same page number in two different documents is not the same place.
            and _source_key(span) == _source_key(last)
            and span["start_char"] <= last["end_char"] + gap
        ):
            last["end_char"] = max(last["end_char"], span["end_char"])
        else:
            # Copy so extending end_char later never mutates the caller's dict.
            merged.append(dict(span))

    return merged

In [6]:
# =========================
# ANSWER FORMATTER
# =========================

def format_answer(answer, docs):
    """Bundle an answer with citation spans derived from the supporting docs.

    Args:
        answer: The answer value to return unchanged under ``"answer"``.
        docs: Iterable of objects whose ``metadata`` mapping contains
            ``chunk_id``, ``start_char`` and ``end_char`` (required) and
            optionally ``source``, ``page`` and ``clauses``.

    Returns:
        ``{"answer": answer, "citations": [...]}`` where the citations are the
        per-document spans after being passed through ``merge_spans``.

    Raises:
        KeyError: If a document's metadata lacks a required key.
    """
    def _to_span(doc):
        meta = doc.metadata
        # Fall back to a one-element list so a missing/empty "clauses" yields None.
        clause_list = meta.get("clauses") or [None]
        return {
            "chunk_id": meta["chunk_id"],
            "source": meta.get("source"),
            "page": meta.get("page"),
            "start_char": meta["start_char"],
            "end_char": meta["end_char"],
            "clause": clause_list[0],
        }

    citations = merge_spans([_to_span(doc) for doc in docs])

    return {"answer": answer, "citations": citations}

In [7]:
# =========================
# BUILD VECTOR STORE
# =========================

def build_vector_store(chunks):
    """Embed ``chunks`` with Ollama and index them in a Chroma collection.

    The embedding model is taken from ``EMBED_MODEL`` and the collection is
    named ``"rag_citations"``.

    Args:
        chunks: Documents to index (passed straight to ``Chroma.from_documents``).

    Returns:
        The ``Chroma`` vector store produced by ``from_documents``.

    Note:
        Embedding needs a reachable Ollama server; the run recorded in this
        notebook failed with "Connection refused" on ``localhost:11434``.
        NOTE(review): chunk metadata built in this notebook includes a
        list-valued ``clauses`` entry -- confirm the store accepts it.
    """
    return Chroma.from_documents(
        chunks,
        embedding=OllamaEmbeddings(model=EMBED_MODEL),
        collection_name="rag_citations",
    )

In [8]:
# =========================
# LLM SELECTOR
# =========================

def get_llm():
    """Return the chat model selected by the ``USE_AZURE`` flag.

    Returns:
        A ``ChatOllama`` for ``OLLAMA_MODEL`` when ``USE_AZURE`` is falsy,
        otherwise an ``AzureChatOpenAI`` for the ``"gpt-4o"`` deployment.
        Both are created with ``temperature=0``.
    """
    if not USE_AZURE:
        return ChatOllama(model=OLLAMA_MODEL, temperature=0)

    return AzureChatOpenAI(
        azure_deployment="gpt-4o",
        api_version="2024-02-01",
        temperature=0,
    )

In [9]:
# =========================
# RAG PIPELINE
# =========================

def rag_query(db, query, k=4):
    """Answer ``query`` from the vector store and return it with citations.

    Retrieves the ``k`` best-matching chunks, joins their text into a context
    block, asks the chat model from ``get_llm()`` to answer from that context,
    and passes the reply plus the retrieved chunks to ``format_answer``.

    Args:
        db: Vector store exposing ``as_retriever``.
        query: The user's question.
        k: Number of chunks to retrieve. Defaults to 4, the value that was
            previously hard-coded.

    Returns:
        The dict produced by ``format_answer``: ``{"answer": ..., "citations": [...]}``.
    """
    retriever = db.as_retriever(search_kwargs={"k": k})
    docs = retriever.invoke(query)

    context = "\n\n".join(d.page_content for d in docs)

    llm = get_llm()

    # The prompt text (including its leading indentation) is sent to the model
    # verbatim, so it is kept exactly as it was.
    prompt = f"""
        Answer using ONLY the context.
        Cite supporting facts.

        Context:
        {context}

        Question:
        {query}
    """

    answer = llm.invoke(prompt).content

    return format_answer(answer, docs)

In [14]:
# =========================
# DEMO RUN
# =========================

# Single definition of the input file so the text that is read and the
# "source" recorded in the metadata cannot drift apart.
# NOTE(review): this is an absolute, machine-specific path, and the file name
# suggests shopping data while the query below asks about a contract section --
# confirm this is the intended input.
SAMPLE_PATH = "/home/jeremy/Documents/Work/Learning/fastapi/llm/4-document_loader/sample_data/shopping_behavior_updated.txt"

# Read through a context manager so the file handle is always closed, and pin
# the encoding instead of relying on the platform default.
with open(SAMPLE_PATH, encoding="utf-8") as sample_file:
    sample_text = sample_file.read()

documents = [
    Document(
        page_content=sample_text,
        metadata={"source": SAMPLE_PATH, "page": 1}
    )
]

chunks, reverse_map = split_and_enrich(documents)
db = build_vector_store(chunks)

result = rag_query(db, "What does Section 4.2 say about termination?")

print(result)

ValueError: Error raised by inference endpoint: HTTPConnectionPool(host='localhost', port=11434): Max retries exceeded with url: /api/embeddings (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x78fbd29ffcb0>: Failed to establish a new connection: [Errno 111] Connection refused'))