In [None]:
# Install the runtime dependencies used by this notebook.
# NOTE(review): the first command installs faiss-cpu and the second installs faiss-gpu.
# Both packages provide the same `faiss` module, so installing both in one environment
# is presumably unintended -- confirm which build the target machine needs and keep one.
!pip install sentence-transformers faiss-cpu transformers torch accelerate safetensors psutil
!pip install faiss-gpu


In [None]:

import re
import time
import math
import psutil
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple

import numpy as np

try:
    import faiss
except Exception as e:
    raise ImportError("Faiss required (faiss-cpu or faiss-gpu). Install with pip. " + str(e))

try:
    import torch
    from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
except Exception as e:
    raise ImportError("torch + transformers required. Install with pip. " + str(e))



In [None]:
# --- Pipeline configuration --------------------------------------------------
# Hugging Face model id loaded by Embedder for chunk and query embeddings.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# Embedding width; presumably the output size of EMBED_MODEL_NAME -- verify if the model changes.
EMBED_DIM = 384
# Hugging Face model id loaded by load_mistral for answer generation.
MISTRAL_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# chunk_document starts a new chunk when adding a unit would exceed this many characters...
MAX_CHARS_PER_CHUNK = 1500
# ...but only if the current chunk already holds at least this many characters.
MIN_CHARS_PER_CHUNK = 400
# Paragraphs longer than this are split at sentence boundaries before chunking.
MAX_CHARS_PER_SUBUNIT = 700
# Upper bound on the number of IVF lists used by build_faiss_index.
FAISS_TARGET_NLIST = 256
# When True, build_faiss_index moves the index to GPU 0 if faiss reports a GPU.
FAISS_USE_GPU_IF_AVAILABLE = True
# Character budget for the excerpts that generate_answer_rag puts into the prompt.
RAG_MAX_CONTEXT_CHARS = 8000
# Torch device used for the embedding model and for Mistral dtype/device_map selection.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
@dataclass
class Unit:
    """One structural piece of a document produced by segment_into_units.

    The segmenter emits three kinds of unit: "paragraph", "list" and "heading".
    """
    type: str  # unit kind: "paragraph", "list" or "heading"
    text: str  # stripped text of the unit
    level: Optional[int] = None  # heading depth; only set for "heading" units


@dataclass
class Chunk:
    """A retrieval-sized group of units taken from a single document.

    Attributes:
        doc_id: identifier of the source document.
        chunk_id: identifier of this chunk (built from doc_id and position by chunk_document).
        section_path: heading trail that applies to the chunk.
        units: the units grouped into this chunk, in document order.
        char_len: running sum of the unit text lengths (separators not included).
        position: index of the chunk within its document.
        embedding: vector assigned by embed_chunks, or None before embedding.
    """
    doc_id: str
    chunk_id: str
    section_path: List[str]
    units: List[Unit] = field(default_factory=list)
    char_len: int = 0
    position: int = 0
    embedding: Optional[np.ndarray] = None

    @property
    def text(self) -> str:
        """Return the texts of all units joined by a blank line."""
        pieces = [unit.text for unit in self.units]
        return "\n\n".join(pieces)


In [None]:
# Numbered heading such as "1.2 Title": digits with optional dotted sub-numbers, then
# whitespace. A trailing dot after the number ("1. Title") does not match this pattern.
HEADING_RE_NUM = re.compile(r"^\s*\d+(\.\d+)*\s+")
# Markdown-style heading: one to six '#' characters followed by whitespace.
HEADING_RE_HASH = re.compile(r"^\s*#{1,6}\s+")
# Whole line of at least 8 characters drawn from uppercase letters, digits, space and , ; : -
HEADING_RE_ALLCAPS = re.compile(r"^[A-Z0-9 ,;:\-]{8,}$")
# Sentence boundary: whitespace that directly follows '.', '!' or '?'.
SENTENCE_SPLIT_RE = re.compile(r'(?<=[.!?])\s+')


def detect_heading_level(line: str) -> Optional[int]:
    """Guess the heading depth of ``line``.

    The checks run in this order:
      1. Markdown hashes -> depth is the number of leading '#'.
      2. Dotted numbering ("2.3 Title") -> depth is the number of dots plus one.
      3. ALL-CAPS line -> fixed depth 2.

    Returns:
        The detected depth, or None when the line does not look like a heading.
    """
    hash_match = HEADING_RE_HASH.match(line)
    if hash_match is not None:
        # The matched prefix is "<spaces><1-6 hashes><spaces>", so counting '#'
        # in it yields the number of leading hashes.
        return hash_match.group(0).count("#")

    if HEADING_RE_NUM.match(line) is not None:
        first_token = line.split()[0]
        return first_token.count(".") + 1

    return 2 if HEADING_RE_ALLCAPS.match(line.strip()) else None


def segment_into_units(text: str) -> List[Unit]:
    """Split raw document text into paragraph, list and heading units.

    Lines are classified one at a time:
      * blank line          -> ends the paragraph being accumulated;
      * bullet / "N." line  -> a "list" unit of its own;
      * heading-like line   -> a "heading" unit (see detect_heading_level);
      * anything else       -> appended to the current paragraph.

    Consecutive paragraph lines are joined with single spaces.

    Args:
        text: the full document text.

    Returns:
        The units in document order (empty list for empty text).
    """
    # Bullet markers are '-', '*' and the bullet character U+2022. The code point is
    # written as an escape so the pattern cannot be corrupted by a wrong file encoding.
    bullet_re = re.compile(r"^\s*[-*\u2022]\s+")
    # NOTE: "1. Title" style lines match here and are therefore emitted as list
    # items before heading detection gets a chance to look at them.
    numbered_item_re = re.compile(r"^\s*\d+\.\s+")

    units: List[Unit] = []
    pending: List[str] = []

    def flush_para() -> None:
        """Emit the accumulated paragraph lines (if any) as one unit."""
        if pending:
            joined = " ".join(part.strip() for part in pending).strip()
            if joined:
                units.append(Unit(type="paragraph", text=joined))
            pending.clear()

    # splitlines() already removes line terminators, so no extra rstrip is needed.
    for line in text.splitlines():
        stripped = line.strip()
        if not stripped:
            flush_para()
            continue

        if bullet_re.match(line) or numbered_item_re.match(line):
            flush_para()
            units.append(Unit(type="list", text=stripped))
            continue

        level = detect_heading_level(stripped)
        if level is not None:
            flush_para()
            units.append(Unit(type="heading", text=stripped, level=level))
            continue

        pending.append(line)

    flush_para()
    return units


def split_into_sentences(text: str) -> List[str]:
    """Split ``text`` at SENTENCE_SPLIT_RE boundaries and drop empty fragments.

    Each returned sentence is stripped of surrounding whitespace.
    """
    sentences: List[str] = []
    for fragment in SENTENCE_SPLIT_RE.split(text.strip()):
        cleaned = fragment.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences


def split_long_paragraph_unit(unit: Unit, max_chars_per_subunit: int) -> List[Unit]:
    """Break an over-long paragraph into smaller paragraph units at sentence boundaries.

    Sentences are packed greedily: a sentence joins the current group as long as
    the space-joined group stays within ``max_chars_per_subunit`` characters.
    A single sentence that is longer than the limit is kept whole.

    Args:
        unit: the unit to split; non-paragraph or short units are returned unchanged.
        max_chars_per_subunit: maximum length of each resulting paragraph.

    Returns:
        A list of paragraph units covering the original text, or ``[unit]`` when
        no split is needed.
    """
    if unit.type != "paragraph" or len(unit.text) <= max_chars_per_subunit:
        return [unit]

    pieces: List[Unit] = []
    group: List[str] = []
    group_len = 0  # exact length of " ".join(group)

    for sentence in split_into_sentences(unit.text):
        # Length the group would have after adding this sentence: one separator
        # is needed only when the group is non-empty. Tracking the exact joined
        # length avoids closing a group that would still fit the limit.
        joined_len = group_len + 1 + len(sentence) if group else len(sentence)
        if group and joined_len > max_chars_per_subunit:
            pieces.append(Unit(type="paragraph", text=" ".join(group)))
            group = [sentence]
            group_len = len(sentence)
        else:
            group.append(sentence)
            group_len = joined_len

    if group:
        pieces.append(Unit(type="paragraph", text=" ".join(group)))
    return pieces


def normalize_units_by_sentence(units: List[Unit], max_chars_per_subunit: int) -> List[Unit]:
    """Return ``units`` with every over-long paragraph replaced by its sentence-level pieces.

    Units that are not paragraphs, or that already fit within
    ``max_chars_per_subunit`` characters, are passed through unchanged.
    """
    normalized: List[Unit] = []
    for unit in units:
        too_long = unit.type == "paragraph" and len(unit.text) > max_chars_per_subunit
        pieces = split_long_paragraph_unit(unit, max_chars_per_subunit) if too_long else [unit]
        normalized += pieces
    return normalized


In [None]:
def update_section_path(section_path: List[str], level: int, heading_text: str) -> List[str]:
    """Return the heading trail after entering a heading at ``level``.

    The first ``level - 1`` entries of ``section_path`` are kept as ancestors and
    ``heading_text`` is placed after them. The input list is not modified.
    """
    ancestors = section_path[: level - 1]
    return ancestors + [heading_text]


def chunk_document(
    units: List[Unit],
    doc_id: str,
    max_chars_per_chunk: int = MAX_CHARS_PER_CHUNK,
    min_chars_per_chunk: int = MIN_CHARS_PER_CHUNK,
) -> List[Chunk]:
    """Group units into chunks, breaking at headings and at the size limit.

    Rules:
      * A heading closes the current chunk (if it has content) and updates the
        section path; the heading text itself is not stored in any chunk.
      * A unit that would push the chunk past ``max_chars_per_chunk`` starts a
        new chunk, unless the current chunk is still shorter than
        ``min_chars_per_chunk`` (in which case it is allowed to overflow).
      * A unit larger than the limit on its own still gets stored, as a chunk
        may not be empty.

    Every chunk records the section path that is in effect for its content,
    i.e. the path *including* the heading that precedes it.

    Args:
        units: segmented units in document order.
        doc_id: identifier copied into each chunk and used to build chunk ids.
        max_chars_per_chunk: soft upper bound on the summed unit lengths.
        min_chars_per_chunk: chunks below this size are not closed on overflow.

    Returns:
        The chunks in document order; positions are consecutive from 0.
    """
    chunks: List[Chunk] = []
    section_path: List[str] = []

    def new_chunk(pos: int) -> Chunk:
        # Reads the section path in effect at the time of the call.
        return Chunk(doc_id=doc_id, chunk_id=f"{doc_id}_chunk_{pos}", section_path=list(section_path), position=pos)

    current = new_chunk(0)

    for unit in units:
        if unit.type == "heading" and unit.level is not None:
            if current.units:
                chunks.append(current)
            # Update the path BEFORE creating the next chunk so that the chunk
            # is labelled with the section it actually belongs to. An empty
            # current chunk always has position == len(chunks), so recreating
            # it here simply refreshes its section path.
            section_path = update_section_path(section_path, unit.level, unit.text)
            current = new_chunk(len(chunks))
            continue

        unit_len = len(unit.text)
        overflows = current.char_len + unit_len > max_chars_per_chunk
        if overflows and current.units and current.char_len >= min_chars_per_chunk:
            chunks.append(current)
            current = new_chunk(len(chunks))

        current.units.append(unit)
        current.char_len += unit_len

    if current.units:
        chunks.append(current)
    return chunks


In [None]:
class Embedder:
    """Sentence embedder built on a Hugging Face encoder with masked mean pooling.

    Each text is tokenised, run through the model, mean-pooled over the
    non-padding tokens and L2-normalised, so the dot product of two embeddings
    is their cosine similarity.
    """

    def __init__(self, model_name: str = EMBED_MODEL_NAME, device: str = DEVICE):
        """Load the tokenizer and model for ``model_name`` and move the model to ``device``."""
        print(f"[embedder] Loading model {model_name} on {device} ...")
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)

    def encode_batch(self, texts: List[str], batch_size: int = 64) -> np.ndarray:
        """Embed ``texts`` and return a float32 array of shape (len(texts), hidden_size).

        Args:
            texts: strings to embed; may be empty.
            batch_size: number of texts per forward pass (must be >= 1).

        Returns:
            Row-wise L2-normalised embeddings. For empty input an array with
            zero rows is returned instead of raising.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if not texts:
            # np.vstack([]) would raise; return a well-formed empty matrix instead.
            dim = getattr(self.model.config, "hidden_size", EMBED_DIM)
            return np.zeros((0, dim), dtype="float32")

        all_vecs = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            # Inputs longer than the tokenizer's limit are truncated.
            enc = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
            with torch.no_grad():
                out = self.model(**enc)
            last = out.last_hidden_state
            # Zero out padding positions, then average over the real tokens only.
            mask = enc["attention_mask"].unsqueeze(-1).to(last.dtype)
            summed = (last * mask).sum(dim=1)
            lens = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
            vecs = (summed / lens).cpu().numpy()
            norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-10
            all_vecs.append((vecs / norms).astype("float32"))
        return np.vstack(all_vecs)


def embed_chunks(chunks: List[Chunk], embedder: Embedder, batch_size: int = 64):
    """Compute an embedding for every chunk and store it in ``chunk.embedding``.

    Args:
        chunks: chunks to embed; modified in place. An empty list is a no-op.
        embedder: object whose ``encode_batch`` returns one vector per text.
        batch_size: forwarded to ``embedder.encode_batch``.

    Raises:
        RuntimeError: if the embedder returns a different number of vectors
            than there are chunks.
    """
    if not chunks:
        # Nothing to embed; avoids asking the embedder to encode an empty batch.
        return

    vecs = embedder.encode_batch([chunk.text for chunk in chunks], batch_size=batch_size)
    if len(vecs) != len(chunks):
        raise RuntimeError(f"Embedder returned {len(vecs)} vectors for {len(chunks)} chunks")
    for chunk, vec in zip(chunks, vecs):
        chunk.embedding = vec


In [None]:
def build_faiss_index(chunks: List[Chunk], use_gpu: bool = FAISS_USE_GPU_IF_AVAILABLE, target_nlist: int = FAISS_TARGET_NLIST):
    """Build an inner-product FAISS index over the chunk embeddings.

    Vectors are L2-normalised first, so inner product equals cosine similarity.
    Fewer than 64 vectors use an exact flat index; otherwise an IVF-Flat index
    is trained with ``nlist = min(target_nlist, max(8, N // 10))`` lists and
    ``nprobe = floor(sqrt(nlist))``.

    Args:
        chunks: chunks whose ``embedding`` has been filled in (see embed_chunks).
        use_gpu: move the index to GPU 0 when faiss reports at least one GPU.
        target_nlist: upper bound on the number of IVF lists.

    Returns:
        ``(index, label_map)`` where ``label_map`` maps the FAISS row id to its chunk.
        For an empty ``chunks`` list an empty flat index of width EMBED_DIM and
        an empty map are returned.

    Raises:
        ValueError: if any chunk has no embedding.
    """
    if not chunks:
        # Must be handled before stacking: np.vstack([]) raises on an empty list.
        print(f"[faiss] N=0, dim={EMBED_DIM}")
        return faiss.IndexFlatIP(EMBED_DIM), {}

    missing = [c.chunk_id for c in chunks if c.embedding is None]
    if missing:
        raise ValueError(
            f"{len(missing)} chunk(s) have no embedding (first: {missing[0]}); run embed_chunks first."
        )

    # faiss expects a C-contiguous float32 matrix.
    vectors = np.ascontiguousarray(np.vstack([c.embedding for c in chunks]), dtype="float32")
    N, dim = vectors.shape
    print(f"[faiss] N={N}, dim={dim}")

    faiss.normalize_L2(vectors)
    label_map = {i: chunk for i, chunk in enumerate(chunks)}
    gpu_available = bool(use_gpu) and faiss.get_num_gpus() > 0

    if N < 64:
        print("[faiss] Too few vectors -> using IndexFlatIP (no train)")
        index = faiss.IndexFlatIP(dim)
        if gpu_available:
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
        index.add(vectors)
        return index, label_map

    nlist = max(1, min(target_nlist, max(8, N // 10)))
    print(f"[faiss] Using IVF with nlist={nlist} (target {target_nlist})")
    quantizer = faiss.IndexFlatIP(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)

    train_size = max(nlist, min(N, 50000))
    if train_size < N:
        # Fixed seed so that repeated runs train on the same subsample.
        rng = np.random.default_rng(0)
        train_sample = vectors[rng.choice(N, train_size, replace=False)]
    else:
        train_sample = vectors
    print(f"[faiss] Training with {train_size} samples...")
    index.train(train_sample)

    print("[faiss] Adding vectors...")
    index.add(vectors)
    print(f"[faiss] ntotal={index.ntotal}")

    # Configure nprobe on the CPU index, where the attribute is always
    # available, instead of silently ignoring a failure afterwards.
    # NOTE(review): this relies on index_cpu_to_gpu carrying nprobe over to the
    # GPU copy -- confirm against the installed faiss version.
    index.nprobe = max(1, int(math.sqrt(nlist)))

    if gpu_available:
        print("[faiss] Moving index to GPU (single GPU)...")
        res = faiss.StandardGpuResources()
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        index = faiss.index_cpu_to_gpu(res, 0, index, co)

    return index, label_map


In [None]:
def retrieve_relevant_chunks(
    question: str,
    index,
    label_map: Dict[int, Chunk],
    top_k: int = 6,
    embedder: Optional[Embedder] = None,
):
    """Return the chunks most similar to ``question`` together with their scores.

    Args:
        question: the query text.
        index: FAISS index built by build_faiss_index (may be None or empty).
        label_map: maps FAISS row ids to chunks.
        top_k: maximum number of results.
        embedder: embedder used to encode the question. It should be the same
            one that produced the indexed vectors. When omitted, the
            module-level variable ``embedder`` is used for backward compatibility.

    Returns:
        A list of ``(chunk, score)`` tuples in the order returned by FAISS;
        empty when the index is missing/empty or ``top_k`` is not positive.

    Raises:
        RuntimeError: if no embedder is passed and no module-level ``embedder`` exists.
    """
    if index is None or index.ntotal == 0 or top_k <= 0:
        return []

    if embedder is None:
        # Legacy behaviour: fall back to the notebook-level global, but fail
        # with a clear message instead of a NameError when it is not defined.
        embedder = globals().get("embedder")
        if embedder is None:
            raise RuntimeError(
                "No embedder available: pass `embedder=` or define a module-level `embedder`."
            )

    q_vec = embedder.encode_batch([question], batch_size=1)
    faiss.normalize_L2(q_vec)
    k = min(top_k, index.ntotal)
    scores, ids = index.search(q_vec, k)

    # FAISS pads missing results with id -1; skip those.
    return [
        (label_map[int(row_id)], float(score))
        for row_id, score in zip(ids[0], scores[0])
        if row_id >= 0
    ]

In [None]:
# Module-level cache so the language model and tokenizer are loaded at most once.
_mistral_model = None
_mistral_tokenizer = None


def load_mistral(model_name: str = MISTRAL_MODEL):
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    On CUDA the model is loaded in float16 with ``device_map="auto"``; otherwise
    it is loaded in float32 with no device map.

    Note that once a model is cached, later calls return it regardless of the
    ``model_name`` argument.
    """
    global _mistral_model, _mistral_tokenizer

    if _mistral_model is None:
        print(f"[mistral] loading {model_name} on {DEVICE} ...")
        on_gpu = DEVICE == "cuda"
        _mistral_tokenizer = AutoTokenizer.from_pretrained(model_name)
        _mistral_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
        )

    return _mistral_model, _mistral_tokenizer


def call_mistral(system_prompt: str, user_prompt: str, max_new_tokens: int = 512):
    """Generate a greedy completion for a system + user prompt.

    The two prompts are combined into a single ``[INST] ... [/INST]`` turn.

    Args:
        system_prompt: instructions placed at the start of the turn.
        user_prompt: the user content (question and context).
        max_new_tokens: maximum number of tokens to generate.

    Returns:
        The generated text only (the prompt is not included), stripped.
    """
    model, tokenizer = load_mistral()

    # No literal "<s>" here: the Mistral tokenizer is expected to prepend the BOS
    # token itself, so writing it in the text as well would duplicate it.
    prompt = f"[INST] {system_prompt}\n\n{user_prompt} [/INST]"

    # Send inputs to the model's own device; with device_map="auto" this is not
    # necessarily the same as the global DEVICE string.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; sampling knobs (temperature/top_p) have no effect
            # in this mode, so they are not passed.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. This is robust even if the prompt
    # was truncated or the context itself contains the "[/INST]" marker.
    completion = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return completion.strip()

In [None]:
def generate_answer_rag(question: str, retrieved: List[Tuple[Chunk, float]], max_context_chars: int = RAG_MAX_CONTEXT_CHARS):
    """Answer ``question`` with the language model, grounded in the retrieved chunks.

    Chunks are added to the context in the given order, each with a header line
    (document, chunk id, section path, score). Adding stops at the first chunk
    that would push the context past ``max_context_chars``.
    """
    blocks = []
    budget_used = 0
    for chunk, score in retrieved:
        header = f"[Doc: {chunk.doc_id} | Chunk: {chunk.chunk_id} | Sec: {' > '.join(chunk.section_path)} | Score: {score:.4f}]"
        block = f"{header}\n{chunk.text}\n"
        block_len = len(block)
        if budget_used + block_len > max_context_chars:
            break
        blocks.append(block)
        budget_used += block_len

    context = "\n\n".join(blocks)
    system_prompt = "You are a helpful assistant. Answer using ONLY the provided excerpts. If not supported, say you don't know."
    user_prompt = f"Question:\n{question}\n\nExcerpts:\n{context}\n\nAnswer:"
    return call_mistral(system_prompt, user_prompt)


class RAGEngine:
    """End-to-end retrieval-augmented QA: chunk documents, index them, answer questions.

    The same embedder instance is used for indexing and for encoding queries,
    so query and chunk vectors always come from the same model.
    """

    def __init__(self, embedder: Embedder):
        """Create an engine with no index yet; call build_from_docs before answer."""
        self.chunks: List[Chunk] = []
        self.index = None
        self.label_map = {}
        self.embedder = embedder

    def build_from_docs(self, docs: Dict[str, str],
                        max_chars_per_chunk: int = MAX_CHARS_PER_CHUNK,
                        min_chars_per_chunk: int = MIN_CHARS_PER_CHUNK,
                        max_chars_per_subunit: int = MAX_CHARS_PER_SUBUNIT):
        """Segment, chunk, embed and index ``docs`` (a mapping of doc id -> text).

        Replaces any previously built index. When the documents yield no chunks,
        an empty flat index is installed so that answer() still works.
        """
        all_chunks: List[Chunk] = []
        for doc_id, text in docs.items():
            units = normalize_units_by_sentence(segment_into_units(text), max_chars_per_subunit)
            all_chunks.extend(chunk_document(units, doc_id, max_chars_per_chunk, min_chars_per_chunk))

        print(f"[rag] Created {len(all_chunks)} chunks from {len(docs)} docs.")
        self.chunks = all_chunks

        if not all_chunks:
            # Nothing to embed or index; keep the engine in a usable, empty state.
            self.index = faiss.IndexFlatIP(EMBED_DIM)
            self.label_map = {}
            return

        embed_chunks(self.chunks, self.embedder, batch_size=64)

        self.index, self.label_map = build_faiss_index(self.chunks, use_gpu=FAISS_USE_GPU_IF_AVAILABLE, target_nlist=FAISS_TARGET_NLIST)

    def _retrieve(self, question: str, top_k: int) -> List[Tuple[Chunk, float]]:
        """Search the index with the engine's own embedder; return (chunk, score) pairs."""
        if self.index.ntotal == 0 or top_k <= 0:
            return []
        # Encode with self.embedder (not a notebook-level global) so the query
        # lives in the same vector space as the indexed chunks.
        q_vec = self.embedder.encode_batch([question], batch_size=1)
        faiss.normalize_L2(q_vec)
        scores, ids = self.index.search(q_vec, min(top_k, self.index.ntotal))
        # FAISS pads missing results with id -1; skip those.
        return [
            (self.label_map[int(row_id)], float(score))
            for row_id, score in zip(ids[0], scores[0])
            if row_id >= 0
        ]

    def answer(self, question: str, top_k: int = 6) -> str:
        """Retrieve up to ``top_k`` chunks for ``question`` and generate an answer.

        Raises:
            RuntimeError: if build_from_docs has not been called.
        """
        if self.index is None:
            raise RuntimeError("Index not built")
        retrieved = self._retrieve(question, top_k)
        return generate_answer_rag(question, retrieved)


In [None]:
def mem_mb():
    """Return the resident set size (RSS) of the current process in MiB."""
    rss_bytes = psutil.Process().memory_info().rss
    bytes_per_mib = 1024 ** 2
    return rss_bytes / bytes_per_mib


def synthetic_doc(size_chars: int) -> str:
    """Return a synthetic document of exactly ``size_chars`` characters.

    The text is a fixed sentence block repeated as often as needed and cut to
    length. Non-positive sizes yield an empty string.
    """
    block = (
        "This is a synthetic sentence used for stress-testing the dynamic chunking "
        "and FAISS indexing pipeline. It contains a moderate number of characters "
        "to simulate realistic text distribution. "
    )
    if size_chars <= 0:
        return ""
    # Compute the repeat count directly (ceiling division) instead of re-summing
    # the lengths of all blocks on every iteration, which was quadratic.
    repeats = -(-size_chars // len(block))
    return (block * repeats)[:size_chars]


def stress_test(size_chars: int, embedder: Optional[Embedder] = None):
    """Run the full pipeline on a synthetic document and print timings and memory use.

    Stages: document generation, segmentation, chunking, embedding, FAISS index
    build and one retrieval. The memory delta printed after each stage is
    relative to the RSS measured at the start of the test.

    Args:
        size_chars: size of the synthetic document in characters.
        embedder: optional pre-loaded embedder to reuse. When omitted a new
            Embedder is created for this run.
    """
    print("\n" + "="*70)
    print(f"STRESS TEST: {size_chars:,} chars")
    print("="*70)
    start_mem = mem_mb()
    t0 = time.time()
    doc = synthetic_doc(size_chars)
    t1 = time.time()
    print(f"Generated doc in {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")

    t0 = time.time()
    units = segment_into_units(doc)
    units = normalize_units_by_sentence(units, MAX_CHARS_PER_SUBUNIT)
    t1 = time.time()
    print(f"Segmented -> units: {len(units)}, time: {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")

    t0 = time.time()
    chunks = chunk_document(units, doc_id="stress", max_chars_per_chunk=MAX_CHARS_PER_CHUNK, min_chars_per_chunk=MIN_CHARS_PER_CHUNK)
    t1 = time.time()
    if not chunks:
        # Nothing to embed or index (e.g. size_chars == 0); avoid np.mean([]) and stop here.
        print(f"Chunked -> chunks: 0, time: {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")
        print("Stress test done.")
        return
    print(f"Chunked -> chunks: {len(chunks)}, avg chars/chunk: {np.mean([c.char_len for c in chunks]):.1f}, time: {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")

    if embedder is None:
        print("[stress] Instantiating embedder for stress embedding (may be slow).")
        ed = Embedder()
    else:
        ed = embedder
    t0 = time.time()
    embed_chunks(chunks, ed, batch_size=128)
    t1 = time.time()
    print(f"Embedded {len(chunks)} chunks in {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")

    t0 = time.time()
    try:
        idx, lbl_map = build_faiss_index(chunks, use_gpu=FAISS_USE_GPU_IF_AVAILABLE, target_nlist=FAISS_TARGET_NLIST)
        t1 = time.time()
        print(f"Built FAISS index in {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")
    except Exception as e:
        print("FAISS build failed:", e)
        return

    q = "What does this synthetic document describe?"
    t0 = time.time()
    # Encode the query with the same embedder that produced the chunk vectors,
    # so this function does not depend on any notebook-level variable.
    q_vec = ed.encode_batch([q], batch_size=1)
    faiss.normalize_L2(q_vec)
    scores, ids = idx.search(q_vec, min(6, idx.ntotal))
    results = [
        (lbl_map[int(row_id)], float(score))
        for row_id, score in zip(ids[0], scores[0])
        if row_id >= 0  # FAISS pads missing results with -1
    ]
    t1 = time.time()
    print(f"Retrieved {len(results)} chunks in {t1-t0:.2f}s, mem delta: {mem_mb()-start_mem:.2f} MB")
    print("Stress test done.")

In [None]:
# Demo: build a small index from two inline documents and answer one question.
# `embedder` is also looked up by name as a module-level variable inside
# retrieve_relevant_chunks, so keep this variable name.
embedder = Embedder()
rag = RAGEngine(embedder)
# NOTE(review): lines such as "1. Introduction" match the numbered-list pattern in
# segment_into_units and are emitted as list items, while "1.1 Background" is
# detected as a heading -- confirm this is the intended treatment of these documents.
docs = {
    "doc_1": """
1. Introduction
This document describes the system we are building.

1.1 Background
The system handles streaming data from multiple sources.
We normalize and validate inputs before processing.

1.2 Missing Data
If values are missing, we apply simple imputation strategies
such as forward fill or mean imputation depending on the feature.

2. Methods
We use a transformer-based model with attention over time windows.
""",
    "doc_2": """
1. Overview
This document explains retry and backoff strategy.

2. Retries
We retry failed requests with exponential backoff,
capped at a maximum delay and limited number of attempts.

3. Circuit Breaking
If too many failures occur, the circuit opens and
we temporarily stop sending requests to the downstream service.
"""
}

print("[demo] Building index from demo docs...")
rag.build_from_docs(docs)
# Ask a question whose answer is in doc_1; rag.answer loads the language model on first use.
q = "How does the system handle missing data?"
print("\nQUESTION:", q)
ans = rag.answer(q, top_k=6)
print("\nANSWER:\n", ans)


In [None]:
# Run the stress test on synthetic documents of increasing size (in characters).
sizes = [100_000, 1_000_000, 5_000_000, 10_000_000]
for s in sizes:
    stress_test(s)
