In [1]:
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os, io, uuid, re
import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
from pypdf import PdfReader
import requests
import nest_asyncio, uvicorn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Runtime configuration. Every setting has a default and can be overridden
# through the environment variable named in the first argument.

# URL that /generate forwards the final prompt to.
TINYLLAMA_API = os.getenv("TINYLLAMA_API", "http://localhost:8011/generate")
# Directory handed to the persistent Chroma client.
PERSIST_DIR = os.getenv("RAG_STORE", "./rag_store")
# Name of the Chroma collection that stores the ingested chunks.
COLLECTION_NAME = os.getenv("RAG_COLLECTION", "tinyllama_docs")
# Model identifier passed to SentenceTransformer for embeddings.
EMB_MODEL_NAME = os.getenv("EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2")


In [3]:
# Application object; the route handlers in the following cells register on it.
app = FastAPI(version="1.0.0", title="TinyLlama RAG Service")

# CORS is left fully open: any origin, any method, any header.
# NOTE(review): fine for local experiments; tighten before exposing the service.
_open_cors = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_open_cors)

In [4]:
# Embedding model, instantiated once when the cell runs and reused by every request.
emb_model = SentenceTransformer(EMB_MODEL_NAME)


def embed_texts(texts: List[str]) -> List[List[float]]:
    """Embed each string in ``texts`` with the shared sentence-transformer model.

    Args:
        texts: Strings to embed, in order.

    Returns:
        One embedding per input string, converted to plain Python lists of
        floats. ``normalize_embeddings=True`` is passed to the encoder.
    """
    vectors = emb_model.encode(texts, normalize_embeddings=True)
    return vectors.tolist()


In [5]:
# On-disk Chroma client rooted at PERSIST_DIR, with anonymized telemetry turned off.
client = chromadb.PersistentClient(path=PERSIST_DIR, settings=Settings(anonymized_telemetry=False))

# Open the collection, creating it with a cosine HNSW space on first use.
# get_or_create_collection replaces the previous `try: get / except Exception: create`
# pattern, which treated *any* failure (corrupt store, permission error, ...) as
# "collection missing" and then masked the real error behind a second failure.
collection = client.get_or_create_collection(COLLECTION_NAME, metadata={"hnsw:space": "cosine"})


In [6]:
# Sentence boundary: a run of whitespace that directly follows '.', '!' or '?'.
SENT_SPLIT = re.compile(r"(?<=[.!?])\s+")

def chunk_text(text: str, chunk_size: int = 700, overlap: int = 120) -> List[str]:
    """Split ``text`` into sentence-aligned chunks with a character overlap.

    Sentences are packed greedily: a sentence joins the current chunk while the
    combined length (including one joining space) stays within ``chunk_size``.
    Afterwards every chunk except the first is prefixed with the last
    ``overlap`` characters of the preceding chunk.

    Args:
        text: Input text; empty or ``None`` yields an empty list.
        chunk_size: Target maximum number of characters per chunk before the
            overlap prefix is added. A single sentence longer than this is
            kept whole rather than split.
        overlap: Number of trailing characters of the previous chunk to
            prepend. Values ``<= 0`` disable the overlap.

    Returns:
        The non-blank chunks, in document order.
    """
    sents = SENT_SPLIT.split(text.strip()) if text else []

    # Greedy sentence packing.
    chunks: List[str] = []
    cur = ""
    for s in sents:
        if len(cur) + len(s) + 1 <= chunk_size:
            cur = (cur + " " + s).strip()
        else:
            if cur:
                chunks.append(cur)
            cur = s
    if cur:
        chunks.append(cur)

    # Prefix each chunk with the tail of its predecessor.
    with_overlap: List[str] = []
    prev_tail = ""
    for ch in chunks:
        piece = (prev_tail + " " + ch).strip() if prev_tail else ch
        with_overlap.append(piece)
        # `ch[-0:]` is the *whole* string, so overlap == 0 used to duplicate the
        # entire previous chunk; negative values sliced from the wrong end.
        prev_tail = ch[-overlap:] if overlap > 0 else ""
    return [c for c in with_overlap if c.strip()]


def read_file_bytes(file: UploadFile) -> str:
    """Return the textual content of an uploaded file.

    Files whose name ends in ``.pdf`` (case-insensitive) are parsed with
    ``PdfReader`` and the text of all pages is joined with newlines. Anything
    else is decoded as UTF-8, dropping bytes that cannot be decoded.

    Args:
        file: The upload to read; its underlying stream is consumed.

    Returns:
        The extracted or decoded text (possibly empty).
    """
    # An upload may arrive without a filename; treat it as plain text instead of
    # failing with AttributeError on `None.lower()`.
    name = (file.filename or "").lower()

    if name.endswith(".pdf"):
        import logging

        pdf = PdfReader(io.BytesIO(file.file.read()))
        pages = []
        for page_number, page in enumerate(pdf.pages, start=1):
            try:
                pages.append(page.extract_text() or "")
            except Exception as exc:
                # Best effort: one unreadable page must not abort the whole
                # document, but leave a trace instead of failing silently.
                logging.getLogger(__name__).warning(
                    "Skipping page %d of %r: text extraction failed (%s)",
                    page_number, file.filename, exc,
                )
        return "\n".join(pages)

    # .txt, .md, etc.
    return file.file.read().decode("utf-8", errors="ignore")



In [7]:
class GenerateIn(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # User question; also embedded as the retrieval query when use_rag is true.
    prompt: str
    # When false, retrieval is skipped and the prompt is forwarded unchanged.
    use_rag: bool = True
    # Requested number of retrieved chunks; the /generate handler clamps it to 1..10.
    top_k: int = 4
    # Optional replacement for the default system instruction of the RAG prompt.
    system_prompt: Optional[str] = None

class GenerateOut(BaseModel):
    """Response body returned by ``POST /generate``."""

    # Text returned by the model API, or an error message if the call failed.
    response: str
    # One entry per retrieved chunk: filename, namespace, chunk_index and score.
    sources: List[Dict[str, Any]]
    # Size of the retrieved context, counted as whitespace-separated words.
    context_tokens: int


In [8]:
@app.post("/ingest")
async def ingest(files: List[UploadFile] = File(...), namespace: str = Form("default")):
    """Chunk, embed and store the uploaded files in the vector collection.

    Args:
        files: One or more uploaded documents (PDF or UTF-8 text).
        namespace: Label stored in each chunk's metadata.

    Returns:
        A dict with the status, the total number of chunks added, the
        collection name, and the filenames that produced no text
        (``files_skipped``).
    """
    added = 0
    skipped: List[Optional[str]] = []
    for f in files:
        raw = read_file_bytes(f)
        chunks = chunk_text(raw)
        if not chunks:
            # Empty files or image-only PDFs yield no chunks. Chroma rejects an
            # add() with empty id lists, which used to fail the whole request;
            # skip the file and report it instead.
            skipped.append(f.filename)
            continue
        # Random ids: re-ingesting the same file adds new entries rather than
        # replacing the old ones.
        ids = [str(uuid.uuid4()) for _ in chunks]
        metadata = [
            {"filename": f.filename, "namespace": namespace, "chunk_index": i}
            for i in range(len(chunks))
        ]
        embeddings = embed_texts(chunks)
        collection.add(ids=ids, documents=chunks, embeddings=embeddings, metadatas=metadata)
        added += len(chunks)
    return {
        "status": "ok",
        "chunks_added": added,
        "collection": COLLECTION_NAME,
        "files_skipped": skipped,
    }




In [9]:
def build_rag_prompt(query: str, contexts: List[str], system_prompt: Optional[str]) -> str:
    """Assemble the tagged prompt sent to the model for a RAG request.

    The result consists of a system section, a context section listing each
    retrieved passage under a ``[Doc N]`` header (numbered from 1), the user
    question, and an opening assistant tag.

    Args:
        query: The user's question.
        contexts: Retrieved passages, in ranking order.
        system_prompt: Custom system instruction; when ``None`` or empty the
            built-in "answer only from context" instruction is used.

    Returns:
        The complete prompt string.
    """
    if system_prompt:
        system = system_prompt
    else:
        system = "You are a precise assistant. Answer using ONLY the provided context. If the answer isn't in the context, say you don't know."

    numbered_docs = []
    for number, passage in enumerate(contexts, start=1):
        numbered_docs.append(f"[Doc {number}]\n" + passage)
    joined = "\n\n".join(numbered_docs)

    system_block = f"<|system|>\n{system}\n</|system|>\n"
    context_block = f"<|context|>\n{joined}\n</|context|>\n"
    user_block = f"<|user|>\n{query}\n</|user|>\n"
    return system_block + context_block + user_block + "<|assistant|>"



In [10]:
@app.post("/generate", response_model=GenerateOut)
async def generate(body: GenerateIn):
    """Answer a prompt, optionally grounding it in retrieved document chunks.

    With ``body.use_rag`` set, the prompt is embedded, the nearest chunks are
    fetched from the collection (``top_k`` clamped to 1..10) and wrapped into a
    RAG prompt; otherwise the raw prompt is forwarded. The final prompt is
    posted to ``TINYLLAMA_API``.

    Args:
        body: Request payload (prompt, RAG switch, top_k, optional system prompt).

    Returns:
        GenerateOut with the model text, the retrieved sources and the context
        size in whitespace-separated words. If the model API call fails, the
        error is reported inside ``response`` instead of raising.
    """
    # Imported locally so this cell does not depend on edits to the import cell.
    from fastapi.concurrency import run_in_threadpool

    sources: List[Dict[str, Any]] = []
    contexts: List[str] = []

    if body.use_rag:
        res = collection.query(
            query_embeddings=embed_texts([body.prompt]),
            n_results=max(1, min(10, body.top_k)),
            include=["documents", "metadatas", "distances"],
        )
        # Results are nested per query; only one query is sent, so take index 0.
        # `or [[]]` also covers a key that is present but None/empty.
        docs = (res.get("documents") or [[]])[0]
        mets = (res.get("metadatas") or [[]])[0]
        dists = (res.get("distances") or [[]])[0]
        for doc, meta, dist in zip(docs, mets, dists):
            # Entries stored without metadata come back as None.
            meta = meta or {}
            contexts.append(doc)
            sources.append({
                "filename": meta.get("filename"),
                "namespace": meta.get("namespace"),
                "chunk_index": meta.get("chunk_index"),
                "score": float(1 - dist)  # cosine similarity approx
            })

    final_prompt = build_rag_prompt(body.prompt, contexts, body.system_prompt) if body.use_rag else body.prompt

    def _call_model() -> str:
        """Blocking HTTP round-trip to the model API."""
        resp = requests.post(TINYLLAMA_API, json={"prompt": final_prompt}, timeout=120)
        resp.raise_for_status()
        return resp.json().get("response", "")

    # Forward to your existing TinyLlama API. `requests` is synchronous and may
    # wait up to 120 s; run it in the thread pool so the event loop keeps
    # serving other requests in the meantime.
    try:
        model_text = await run_in_threadpool(_call_model)
    except Exception as e:
        model_text = f"[RAG] Error calling TinyLlama API: {e}"

    # Word count of the retrieved context (not model tokens).
    context_tokens = sum(len(c.split()) for c in contexts)
    return GenerateOut(response=model_text, sources=sources, context_tokens=context_tokens)


In [11]:
@app.get("/")
def root():
    """Return a small descriptor of the running service and its storage settings."""
    descriptor = dict(
        service="tinyllama-rag",
        collection=COLLECTION_NAME,
        persist_dir=PERSIST_DIR,
    )
    return descriptor

# Run: uvicorn rag_backend:app --port 8010 --reload

In [None]:
# Jupyter already runs an event loop; nest_asyncio lets uvicorn start inside it.
nest_asyncio.apply()  # allow running inside Jupyter

# Bind address and port follow the same env-override convention as the other
# settings; the defaults keep the previous behaviour (0.0.0.0:8010).
# NOTE(review): 0.0.0.0 listens on every interface; set RAG_HOST=127.0.0.1 to
# keep the service local.
uvicorn.run(
    app,
    host=os.environ.get("RAG_HOST", "0.0.0.0"),
    port=int(os.environ.get("RAG_PORT", "8010")),
)

INFO:     Started server process [16564]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8010 (Press CTRL+C to quit)


INFO:     127.0.0.1:57907 - "POST /ingest HTTP/1.1" 200 OK
INFO:     127.0.0.1:57920 - "POST /generate HTTP/1.1" 200 OK
INFO:     127.0.0.1:58624 - "POST /ingest HTTP/1.1" 200 OK
INFO:     127.0.0.1:58632 - "POST /generate HTTP/1.1" 200 OK
