In [1]:
# Runtime dependencies for this notebook: FAISS (vector index), pyngrok (tunnel),
# the Hugging Face LLM stack, and Flask (HTTP API).
# NOTE(review): pyngrok and the transformers stack are each installed more than once
# below, and no versions are pinned — consider consolidating and pinning.
!pip install faiss-cpu pyngrok
!pip install -q transformers accelerate bitsandbytes sentencepiece
# Second pass upgrades the LLM stack and additionally installs huggingface_hub.
!pip -q install --upgrade transformers accelerate bitsandbytes huggingface_hub sentencepiece
!pip -q install flask pyngrok

In [2]:
import transformers
import torch
import sentence_transformers
import faiss
import flask
import pyngrok
import requests
import re
import os, time, threading, json
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer, CrossEncoder
from typing import List, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from flask import Flask, request, jsonify
import yaml

In [3]:
# Raw (non-HTML) URL of the knowledge-base markdown; it is fetched directly with requests.
# (A previous assignment pointing at the github.com "/tree/" HTML page was dead code:
# it was overwritten by this line before ever being used.)
RAW_URL = "https://raw.githubusercontent.com/AntoineAbouJanab/Livedrops---AntoineAouJanab/main/docs/prompting/knowledge-base.md"
# GitHub "blob" page URL of the prompt templates; download_prompts() converts it to
# its raw.githubusercontent.com form before fetching.
PROMPTS_URL = "https://github.com/AntoineAbouJanab/Livedrops---AntoineAouJanab/blob/main/docs/prompting/assistant-prompts.yml"

In [4]:
# Download the knowledge-base markdown from RAW_URL.
headers = {}
# A timeout keeps a stalled connection from hanging the notebook indefinitely
# (same value download_prompts() uses).
resp = requests.get(RAW_URL, headers=headers, timeout=30)
resp.raise_for_status()  # fail loudly on 4xx/5xx instead of parsing an error page
markdown_text = resp.text
print("✅ Fetched markdown, length:", len(markdown_text))
# Show only the first 500 characters as a sanity check.
print(markdown_text[:500], "...")

In [5]:
def parse_docs(md_text: str):
    """Split knowledge-base markdown into one record per ``## Document N: Title`` heading.

    Args:
        md_text: Full markdown text containing zero or more document headings.

    Returns:
        A list of dicts with keys ``title``, ``content`` and ``id`` (``"doc<N>"``),
        in the order the headings appear. Text before the first heading is ignored.
    """
    headings = list(
        re.finditer(r'(?m)^##\s*Document\s+(\d+)\s*:\s*(.+?)\s*$', md_text)
    )
    # A document body runs from the end of its heading to the start of the next
    # heading; the last body runs to the end of the text.
    body_ends = [h.start() for h in headings[1:]] + [len(md_text)]

    parsed = []
    for heading, body_end in zip(headings, body_ends):
        body = md_text[heading.end():body_end].strip()
        # Drop horizontal-rule separator lines ("---") between documents.
        body = re.sub(r'(?m)^\s*---\s*$', '', body).strip()
        parsed.append({
            "title": heading.group(2).strip(),
            "content": body,
            "id": f"doc{heading.group(1).strip()}",
        })
    return parsed

# Parse the fetched markdown into document records.
KNOWLEDGE_BASE = parse_docs(markdown_text)
# Fail here with a clear message: an empty knowledge base would otherwise only
# surface later as an obscure error while building the embedding index.
if not KNOWLEDGE_BASE:
    raise ValueError("No '## Document N: Title' sections found in the fetched markdown.")
print(f"✅ Parsed {len(KNOWLEDGE_BASE)} documents")
# Print a compact id/title listing instead of dumping every document body.
print("\n".join(f"- {doc['id']}: {doc['title']}" for doc in KNOWLEDGE_BASE))

In [6]:
# Sentence-embedding model used for both documents and queries.
EMBED_MODEL_NAME = "BAAI/bge-base-en-v1.5"
embedder = SentenceTransformer(EMBED_MODEL_NAME)
print("Embedder ready →", EMBED_MODEL_NAME)

# Parallel lists: position i in each list refers to the same knowledge-base document.
ids = [doc["id"] for doc in KNOWLEDGE_BASE]
titles = [doc["title"] for doc in KNOWLEDGE_BASE]
texts = ["{}\n\n{}".format(doc["title"], doc["content"]) for doc in KNOWLEDGE_BASE]

# Embed every document, then L2-normalise each row so that an inner product
# equals cosine similarity (the epsilon guards against division by zero).
embs = embedder.encode(texts, show_progress_bar=True, convert_to_numpy=True)
embs = embs.astype("float32")
embs /= np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12

# Exact inner-product index over the normalised embeddings.
d = embs.shape[1]
faiss_index = faiss.IndexFlatIP(d)
faiss_index.add(embs)

In [7]:
def mmr(query_vec, cand_vecs, lambda_mult=0.7, top_k=5):
    """Greedy Maximal Marginal Relevance selection over candidate vectors.

    Each round picks the candidate maximising
    ``lambda_mult * sim(query, c) - (1 - lambda_mult) * max sim(c, already picked)``,
    where similarity is the dot product. The first pick is simply the candidate
    most similar to the query.

    Args:
        query_vec: 1-D query vector.
        cand_vecs: 2-D array with one candidate vector per row.
        lambda_mult: Weight of query relevance versus redundancy.
        top_k: Maximum number of candidates to select.

    Returns:
        Row indices into ``cand_vecs`` in selection order (at most ``top_k``).
    """
    relevance = cand_vecs @ query_vec
    remaining = list(range(cand_vecs.shape[0]))
    picked = []
    while remaining and len(picked) < top_k:
        if picked:
            # Redundancy = highest similarity to anything already picked.
            redundancy = (cand_vecs[remaining] @ cand_vecs[picked].T).max(axis=1)
            mmr_scores = lambda_mult * relevance[remaining] - (1 - lambda_mult) * redundancy
            winner = remaining[int(np.argmax(mmr_scores))]
        else:
            winner = max(remaining, key=lambda i: relevance[i])
        picked.append(winner)
        remaining.remove(winner)
    return picked

# Cross-encoder used by rerank_with_cross_encoder() to score (query, text) pairs.
RERANKER_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(RERANKER_NAME)

def rerank_with_cross_encoder(query, cand_texts):
    """Score candidate texts against the query with the cross-encoder.

    Args:
        query: The user query string.
        cand_texts: Iterable of candidate passage texts.

    Returns:
        ``(order, scores)`` where ``scores[i]`` is the cross-encoder score of
        ``cand_texts[i]`` and ``order`` lists candidate positions from highest
        to lowest score. Both are empty arrays when there are no candidates.
    """
    cand_texts = list(cand_texts)
    if not cand_texts:
        # Nothing to score: skip the model call and return empty, well-typed arrays.
        return np.array([], dtype=int), np.array([], dtype="float32")
    pairs = [(query, t) for t in cand_texts]
    # asarray guarantees the unary minus below works even if predict() returns a list.
    scores = np.asarray(cross_encoder.predict(pairs))
    order = np.argsort(-scores)  # negate so argsort yields descending score order
    return order, scores

def decide_k_from_scores(scores, max_k=3, margin=4, ratio=0.85, min_score=-5):
    """Decide how many of the top-ranked documents to return (0, 1 or 2).

    Args:
        scores: Relevance scores sorted from best to worst.
        max_k: Upper bound on the returned count; it is honoured by every branch.
        margin: If the best score beats the runner-up by at least this much,
            only the best document is kept.
        ratio: Currently unused; kept for interface compatibility.
        min_score: If the best score is below this value, nothing is returned.

    Returns:
        ``0`` when there are no scores, the best score is below ``min_score`` or
        ``max_k`` is not positive; ``1`` when there is a single score or a clear
        winner; otherwise ``min(2, max_k)``.
    """
    if max_k <= 0 or len(scores) == 0 or scores[0] < min_score:
        return 0
    if len(scores) == 1:
        return 1
    best, runner_up = float(scores[0]), float(scores[1])
    if (best - runner_up) >= margin:
        return 1
    return min(2, max_k)

def search_enhanced(query, ann_k=50, mmr_top=10, lambda_mult=0.7, use_reranker=True, max_return_k=3):
    """Retrieve the most relevant knowledge-base documents for a query.

    Pipeline: embed the query, search the FAISS index, diversify the hits with
    MMR, optionally rerank them with the cross-encoder, then let
    ``decide_k_from_scores`` choose how many of the top hits to return.

    Args:
        query: User question.
        ann_k: Number of nearest neighbours requested from FAISS (capped at the
            number of indexed documents).
        mmr_top: Maximum number of candidates kept after MMR.
        lambda_mult: MMR relevance/diversity trade-off.
        use_reranker: If True, order by cross-encoder score; otherwise by the
            dot product between the query and document embeddings.
        max_return_k: Upper bound passed to ``decide_k_from_scores``.

    Returns:
        A list of dicts with keys ``id``, ``title``, ``score``, ``preview`` (first
        200 characters, newlines flattened) and ``text`` (the full indexed text,
        so prompt builders are not limited to the truncated preview).
    """
    if len(ids) == 0:
        # Nothing indexed: avoid asking FAISS for zero neighbours.
        return []

    q = embedder.encode([query], convert_to_numpy=True)[0].astype("float32")
    q /= (np.linalg.norm(q) + 1e-12)  # normalise so inner product == cosine similarity
    _, idxs = faiss_index.search(np.array([q], dtype="float32"), min(ann_k, len(ids)))
    idxs = idxs[0]
    cand_vecs = embs[idxs]
    sel = mmr(q, cand_vecs, lambda_mult=lambda_mult, top_k=min(mmr_top, len(idxs)))
    mmr_idxs = idxs[sel]  # map MMR's candidate-local positions back to corpus indices
    mmr_texts = [texts[i] for i in mmr_idxs]
    mmr_titles = [titles[i] for i in mmr_idxs]
    print("MMR selected", mmr_titles)
    if use_reranker:
        order, ce_scores = rerank_with_cross_encoder(query, mmr_texts)
        final_idxs = [mmr_idxs[i] for i in order]
        final_scores = [float(ce_scores[i]) for i in order]
        print(f"final idxs {final_idxs}")
        print(f"final scores {final_scores}")
    else:
        final_scores = (embs[mmr_idxs] @ q).tolist()
        order = np.argsort(-np.array(final_scores))
        final_idxs = [mmr_idxs[i] for i in order]
        final_scores = [final_scores[i] for i in order]
    k = decide_k_from_scores(final_scores, max_k=max_return_k, margin=4, ratio=0.85)
    results = []
    for i in range(min(k, len(final_idxs))):
        idx = final_idxs[i]
        results.append({
            "id": ids[idx],
            "title": titles[idx],
            "score": float(final_scores[i]),
            "preview": texts[idx][:200].replace("\n", " ") + "...",
            # Full text: build_final_prompt() prefers 'text' over 'preview', so the
            # LLM receives the whole document rather than a 200-character snippet.
            "text": texts[idx],
        })
    return results

# Retrieval smoke test: run the search once and reuse the results, instead of
# repeating the embedding/rerank work for the same query.
demo_query = "How do I synchronize inventory via API with my warehouse system?"
demo_results = search_enhanced(demo_query)
for result in demo_results:
    print(result["title"])
print(len(demo_results))

In [8]:
def to_raw_github(url: str) -> str:
    """Convert a github.com ``/blob/`` file URL to its raw.githubusercontent.com form.

    URLs that already contain ``raw.githubusercontent.com`` are returned unchanged.

    Raises:
        ValueError: If the URL is neither a raw URL nor a github.com blob URL.
    """
    if "raw.githubusercontent.com" not in url:
        blob = re.match(r"https?://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.+)", url)
        if blob is None:
            raise ValueError("Provide a valid GitHub file URL or raw URL.")
        # Captured groups are (user, repo, branch, path) in that order.
        return "https://raw.githubusercontent.com/" + "/".join(blob.groups())
    return url

def download_prompts(src_url: str, save_path: str = "/content/assistant-prompts.yml") -> str:
    """Download the prompts YAML from GitHub and save it locally.

    Args:
        src_url: github.com blob URL or raw URL of the YAML file.
        save_path: Destination file path; missing parent directories are created.

    Returns:
        ``save_path``, unchanged.

    Raises:
        ValueError: If ``src_url`` is not a recognised GitHub file URL.
        requests.HTTPError: If the download returns an error status.
    """
    raw_url = to_raw_github(src_url)
    headers = {}
    # Optional token from the environment; sent only when it is set.
    token = os.environ.get("GITHUB_TOKEN")
    if token:
        headers["Authorization"] = f"token {token}"
    r = requests.get(raw_url, headers=headers, timeout=30)
    r.raise_for_status()
    target = Path(save_path)
    # The default path assumes a "/content" directory; create the parent so the
    # write does not fail when that directory does not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(r.text, encoding="utf-8")
    return save_path

In [9]:
# Top-level keys that load_prompts() requires in the prompts YAML; choose_prompt()
# looks up exactly these three keys.
REQUIRED_KEYS = ["base_retrieval_prompt", "multi_document_synthesis", "refusal_no_context"]

def load_prompts(yaml_path: str):
    """Load the prompts YAML file and check that the required keys are present.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        The parsed YAML content (an empty dict if the file parses to a falsy value).

    Raises:
        FileNotFoundError: If ``yaml_path`` does not exist.
        KeyError: If any key listed in ``REQUIRED_KEYS`` is missing.
    """
    prompts_file = Path(yaml_path)
    if not prompts_file.exists():
        raise FileNotFoundError(f"Prompts file not found: {yaml_path}")

    loaded = yaml.safe_load(prompts_file.read_text(encoding="utf-8"))
    if not loaded:
        # An empty file parses to None; treat it as an empty mapping.
        loaded = {}

    absent = []
    for key in REQUIRED_KEYS:
        if key not in loaded:
            absent.append(key)
    if absent:
        raise KeyError(f"Missing keys in prompts YAML: {absent}")
    return loaded

In [10]:
def choose_prompt(retrieved_docs: List[Dict[str, Any]], prompts: Dict[str, str]) -> str:
    """Pick the prompt template that matches the number of retrieved documents.

    Zero documents selects ``refusal_no_context``, exactly one selects
    ``base_retrieval_prompt``, and two or more select ``multi_document_synthesis``.
    """
    doc_count = len(retrieved_docs)
    if doc_count == 0:
        template_key = "refusal_no_context"
    elif doc_count == 1:
        template_key = "base_retrieval_prompt"
    else:
        template_key = "multi_document_synthesis"
    return prompts[template_key]

def build_final_prompt(user_question: str, retrieved_docs: List[Dict[str, Any]], prompts: Dict[str, str]) -> str:
    """Assemble the full LLM prompt: template, user question, context and sources.

    The template comes from ``choose_prompt``. With no documents the prompt holds
    only the template and the question. With one document it includes a single
    ``### Source`` section. With several it includes a ``### Retrieved Context``
    section listing every document with its score. A ``### Sources`` list of
    unique titles is appended whenever there is at least one document.

    Args:
        user_question: The question as typed by the user.
        retrieved_docs: Retrieval results; each dict may provide ``title``, ``id``,
            ``score`` and a body under ``text``, ``content`` or ``preview``.
        prompts: Mapping of template names to template text.

    Returns:
        The final prompt string.
    """

    def _body_of(doc: Dict[str, Any]):
        # First non-empty of text / content / preview.
        return doc.get('text') or doc.get('content') or doc.get('preview', '')

    def _sources_block(docs: List[Dict[str, Any]], limit: int = 5) -> str:
        # Unique display names, highest score first, stopping once `limit` is reached.
        if not docs:
            return ""
        ranked = sorted(docs, key=lambda doc: doc.get("score", 0.0), reverse=True)
        labels = []
        for doc in ranked:
            label = doc.get("title") or doc.get("id") or "Untitled"
            if label not in labels:
                labels.append(label)
                if len(labels) >= limit:
                    break
        if not labels:
            return ""
        return "### Sources\n" + "\n".join(f"- {label}" for label in labels)

    instructions = choose_prompt(retrieved_docs, prompts)
    preamble = f"{instructions}\n\nUser question:\n{user_question}\n\n"
    doc_count = len(retrieved_docs)

    if doc_count == 0:
        return preamble + _sources_block(retrieved_docs)

    if doc_count == 1:
        only = retrieved_docs[0]
        context_block = "\n".join([
            "### Source",
            f"Title: {only.get('title','')}",
            f"ID: {only.get('id','')}",
            "",
            f"{_body_of(only)}",
        ])
        return f"{preamble}{context_block}\n\n{_sources_block(retrieved_docs)}"

    sections = [
        "\n".join([
            "---",
            f"[Doc {position}] Title: {doc.get('title','')}",
            f"ID: {doc.get('id','')}",
            f"Score: {float(doc.get('score',0.0)):.4f}",
            "",
            f"{_body_of(doc)}",
        ])
        for position, doc in enumerate(retrieved_docs, start=1)
    ]
    context_block = "\n".join(sections)

    return (
        f"{preamble}### Retrieved Context ({doc_count} docs)\n"
        f"{context_block}\n\n{_sources_block(retrieved_docs)}"
    )

In [11]:
# Download and load the prompt templates; `prompts` is the global that
# rag_generate() reads later in the notebook.
local_yaml_path = download_prompts(PROMPTS_URL)
prompts = load_prompts(local_yaml_path)
# NOTE(review): `retrieved_docs` is assigned here but not used by the call below.
retrieved_docs = []
# Demo: build and print the final prompt for a sample question.
user_question = "what is my name"
final_prompt = build_final_prompt(user_question, search_enhanced(user_question), prompts)
print(final_prompt)

In [12]:
# Authenticate with the Hugging Face Hub before the model cell downloads
# meta-llama/Llama-3.1-8B-Instruct.
# NOTE(review): login() with no arguments presumably prompts for a token, which
# would block an unattended "Run All" — confirm.
from huggingface_hub import login
login()

In [13]:
# Load the chat model with 4-bit quantisation.
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
compute_dtype = torch.float16
# 4-bit NF4 quantisation with double quantisation; compute dtype is float16.
bnb_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
)
# Optional token from the environment; None is passed through when it is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print(f"Loading {MODEL_ID} in 4-bit NF4...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    use_fast=True,
    token=HF_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=compute_dtype,
    quantization_config=bnb_4bit,
    attn_implementation="sdpa",
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)
# Fall back to the EOS token as the padding token when none is configured.
if model.generation_config.pad_token_id is None:
    model.generation_config.pad_token_id = tokenizer.eos_token_id
model.eval()
print("✅ Llama 3.1 8B Instruct is ready.")
# Report the GPU name and currently allocated memory when CUDA is available.
if torch.cuda.is_available():
    torch.cuda.synchronize()
    print("GPU:", torch.cuda.get_device_name(0))
    print(f"VRAM in use: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")

In [14]:
def llama_generate(prompt, max_new_tokens=120, temperature=None, min_new_tokens=16):
    """Generate a reply to a single user message with the loaded chat model.

    Args:
        prompt: Text sent as the only ``user`` message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: If a positive number, sample with this temperature and
            ``top_p=0.95``; otherwise decode greedily.
        min_new_tokens: Minimum number of tokens to generate.

    Returns:
        The decoded newly generated text, without the prompt and special tokens.
    """
    messages = [{"role": "user", "content": prompt}]
    # return_dict=True also yields the attention mask, which is passed to
    # generate() explicitly instead of leaving it to be inferred.
    encoded = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
    ).to(model.device)
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")

    # Stop on EOS and, if the tokenizer knows it, on the "<|eot_id|>" token.
    terminators = [tokenizer.eos_token_id]
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.eos_token_id:
            terminators.append(eot_id)
    except Exception:
        # Best effort: the extra terminator is optional, so fall back to EOS only.
        pass

    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        use_cache=True,
        eos_token_id=terminators if len(terminators) > 1 else terminators[0],
        do_sample=False,
    )
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask
    if temperature is not None and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)

    model.eval()
    with torch.inference_mode():
        out = model.generate(**gen_kwargs)
    # Keep only the tokens produced after the prompt.
    gen_only = out[0, input_ids.shape[-1]:]
    return tokenizer.decode(gen_only, skip_special_tokens=True).strip()

import time
# Smoke test: one short generation, timed with wall-clock time.
# NOTE(review): `resp` previously held the HTTP response from the markdown fetch;
# it is rebound to a string here.
q = "Say hi in one short sentence."
t0 = time.time()
resp = llama_generate(q, max_new_tokens=60, min_new_tokens=12)
print(repr(resp))
print(f"\nElapsed: {time.time() - t0:.2f}s")

def generate_response(prompt, max_new_tokens=512, temperature=0.7):
    """Generate a reply to a single user message; used by the RAG and API paths.

    Args:
        prompt: Text sent as the only ``user`` message.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: If a positive number, sample with this temperature,
            ``top_p=0.95`` and ``repetition_penalty=1.1``; if ``None`` or not
            positive, decode greedily.

    Returns:
        The decoded newly generated text, without the prompt and special tokens.
    """
    messages = [{"role": "user", "content": prompt}]
    # return_dict=True also yields the attention mask, which is passed to
    # generate() explicitly instead of leaving it to be inferred.
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_ids = encoded["input_ids"]
    attention_mask = encoded.get("attention_mask")

    # Stop on EOS and, if the tokenizer knows it, on the "<|eot_id|>" token.
    terminators = [tokenizer.eos_token_id]
    try:
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        if eot_id is not None and eot_id != tokenizer.eos_token_id:
            terminators.append(eot_id)
    except Exception:
        # Best effort: the extra terminator is optional, so fall back to EOS only.
        pass

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=terminators if len(terminators) > 1 else terminators[0],
        do_sample=False,
    )
    if attention_mask is not None:
        gen_kwargs["attention_mask"] = attention_mask
    if temperature is not None and temperature > 0:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=0.95,
            repetition_penalty=1.1,
        )

    model.eval()
    with torch.inference_mode():
        outputs = model.generate(input_ids=input_ids, **gen_kwargs)
    # Keep only the tokens produced after the prompt.
    gen_only = outputs[0, input_ids.shape[-1]:]
    return tokenizer.decode(gen_only, skip_special_tokens=True).strip()

print("✅ Generation function ready")

In [23]:
def extract_doc_titles(results, unique=True, limit=5, print_debug=False):
    """Collect display names (title, else id, else ``"Untitled"``) from retrieval results.

    Args:
        results: Iterable of result dicts; a falsy value yields an empty list.
        unique: If True, skip names that were already collected.
        limit: Stop once this many names are collected; a falsy value means no limit.
        print_debug: If True, print the collected names as a bullet list.

    Returns:
        The collected names, in input order.
    """
    if not results:
        return []

    collected = []
    already_seen = set()
    for hit in results:
        label = hit.get("title") or hit.get("id") or "Untitled"
        if not (unique and label in already_seen):
            already_seen.add(label)
            collected.append(label)
            if limit and len(collected) >= limit:
                break

    if print_debug:
        print("🔎 Sources:", *[f"- {label}" for label in collected], sep="\n")
    return collected

def rag_generate(user_question, max_tokens=512):
    """Answer a question with retrieval-augmented generation.

    Retrieves documents with ``search_enhanced``, builds the prompt from the
    global ``prompts`` templates, and generates the answer.

    Args:
        user_question: The question to answer.
        max_tokens: Maximum number of new tokens for the generation step.

    Returns:
        ``(answer, sources)``: the generated text and up to five unique source titles.
    """
    hits = search_enhanced(user_question)
    final_prompt = build_final_prompt(user_question, hits, prompts)
    print(f"final prompt = {final_prompt}")
    print(f" Retrieved {len(hits)} documents")
    print(" Generating response...")
    # Tuple elements evaluate left to right: generate first, then collect titles.
    return (
        generate_response(final_prompt, max_new_tokens=max_tokens),
        extract_doc_titles(hits, unique=True, limit=5),
    )

# End-to-end demo: retrieve, build the prompt, generate, and print the answer.
# The returned `sources` list is not printed here.
user_question = "how to track my order?"
print(f"Question: {user_question}\n")
answer, sources = rag_generate(user_question)
print(f"Answer:\n{answer}\n")

In [25]:
import time, threading
from flask import Flask, request, jsonify
# Flask application that exposes the /health, /ping and /chat routes defined below.
app = Flask(__name__)

def _safe_sources(retrieved, limit=5):
    """Return unique source names, working even if ``extract_doc_titles`` is undefined.

    Delegates to ``extract_doc_titles`` when that name exists; on ``NameError`` it
    falls back to an inline version of the same title/id/"Untitled" logic.
    """
    try:
        return extract_doc_titles(retrieved, unique=True, limit=limit)
    except NameError:
        labels = []
        for hit in retrieved:
            label = hit.get("title") or hit.get("id") or "Untitled"
            if label not in labels:
                labels.append(label)
                if len(labels) >= limit:
                    break
        return labels

@app.get("/health")
def health():
    """Liveness endpoint: report status plus the loaded model's name and device."""
    payload = {
        "status": "ok",
        "model": str(getattr(model.config, "name_or_path", "llm")),
        "device": str(model.device),
    }
    return jsonify(**payload), 200

@app.post("/ping")
def ping():
    """Generate directly from a prompt, without retrieval.

    Expects a JSON object with ``prompt`` (or ``question``) and an optional
    positive integer ``max_new_tokens`` (default 160). Responds with ``answer``
    and ``latency_s``, or with HTTP 400 and an ``error`` message on invalid input.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing, malformed or non-object JSON is treated as an empty request.
        data = {}
    prompt = data.get("prompt") or data.get("question") or ""
    if not prompt:
        return jsonify(error="Missing 'prompt'"), 400
    # Reject bad client values with a 400 instead of an unhandled 500.
    try:
        max_new_tokens = int(data.get("max_new_tokens", 160))
    except (TypeError, ValueError):
        return jsonify(error="'max_new_tokens' must be a positive integer"), 400
    if max_new_tokens < 1:
        return jsonify(error="'max_new_tokens' must be a positive integer"), 400
    t0 = time.time()
    out = generate_response(prompt, max_new_tokens=max_new_tokens)
    return jsonify(answer=out, latency_s=round(time.time()-t0, 3)), 200

@app.post("/chat")
def chat():
    """Answer a question with retrieval-augmented generation.

    Expects a JSON object with ``question`` (or ``prompt``) and an optional
    positive integer ``max_new_tokens`` (default 160). Responds with ``answer``,
    ``sources``, ``retrieved_count`` and ``latency_s``, or with HTTP 400 and an
    ``error`` message on invalid input.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing, malformed or non-object JSON is treated as an empty request.
        data = {}
    question = data.get("question") or data.get("prompt") or ""
    if not question:
        return jsonify(error="Missing 'question'"), 400

    # Reject bad client values with a 400 instead of an unhandled 500.
    try:
        max_new_tokens = int(data.get("max_new_tokens", 160))
    except (TypeError, ValueError):
        return jsonify(error="'max_new_tokens' must be a positive integer"), 400
    if max_new_tokens < 1:
        return jsonify(error="'max_new_tokens' must be a positive integer"), 400

    t0 = time.time()
    # rag_generate returns (answer_text, sources)
    answer_text, sources = rag_generate(
        question,
        max_tokens=max_new_tokens
    )
    latency = round(time.time() - t0, 3)

    # NOTE: retrieved_count is the number of source titles returned by
    # rag_generate, not necessarily the raw number of retrieved documents.
    return jsonify(
        answer=answer_text,            # <-- string
        sources=sources,               # <-- list[str]
        retrieved_count=len(sources),
        latency_s=latency
    ), 200

def _run():
    """Run the Flask app on all interfaces, port 5000, with threaded request handling."""
    app.run(host="0.0.0.0", port=5000, debug=False, threaded=True)

# Serve from a daemon thread so the notebook cell returns immediately.
# NOTE(review): re-running this cell starts another thread on the same port —
# presumably the second bind fails; confirm.
thread = threading.Thread(target=_run, daemon=True)
thread.start()
print("✅ Flask started on http://127.0.0.1:5000")

import requests, json, time
BASE = "http://127.0.0.1:5000"

# The server thread was started just above and may not be listening yet, so poll
# /health for up to ~10 seconds instead of failing on the first refused connection.
for _attempt in range(20):
    try:
        health_resp = requests.get(f"{BASE}/health", timeout=5)
        break
    except requests.exceptions.ConnectionError:
        time.sleep(0.5)
else:
    raise RuntimeError(f"Flask server did not start listening on {BASE}")

print("GET /health ->")
print(health_resp.json(), "\n")

# Direct generation, no retrieval.
print("POST /ping ->")
print(requests.post(f"{BASE}/ping", json={
    "prompt": "Say hi in one short sentence.",
    "max_new_tokens": 80
}).json(), "\n")

# Full RAG round trip, timed from the client side.
print("POST /chat ->")
t0 = time.time()
res = requests.post(f"{BASE}/chat", json={
    "question": "How do I create a Shoplite account and verify my email?",
    "max_new_tokens": 160
})
print(res.json())
print(f"\nElapsed: {time.time()-t0:.2f}s")

from getpass import getpass
from pyngrok import ngrok

# Read the token with getpass so the secret is not echoed into the cell output
# (and therefore not saved with the notebook), unlike input().
ngrok_token = getpass("Enter your ngrok token: ").strip()
ngrok.set_auth_token(ngrok_token)
# Close any tunnels that are already open before opening a new one.
for t in ngrok.get_tunnels():
    ngrok.disconnect(t.public_url)
# NOTE: this exposes the Flask API on port 5000 publicly; the routes above do
# not perform any authentication.
public_url = ngrok.connect(5000, "http").public_url
print("🌐 Public URL:", public_url)
print("Try POST /chat and /ping against this URL.")