In [1]:
"""
RAG Pharma Chatbot - Gradio app (single-file)

Features:
- Ask user to upload a PDF (pharma doc)
- Extract text, chunk with overlap
- Compute embeddings using SentenceTransformers
- Store embeddings in FAISS (on-disk optional)
- Retrieve top-k relevant chunks
- Call Groq LLM (via groq-python or HTTP) to generate final answer using retrieved context
- Simple Gradio GUI

Environment variables expected:
- GROQ_API_KEY : your Groq API key (note: the config cell below reads the key from the Colab secret `KEY_GEN_AI_HEC_GROQ`)
- GROQ_MODEL : model id (e.g. "compound-beta" or "llama-3.1-8b-instant")
- SENTENCE_TRANSFORMER_MODEL : (optional) default: "all-MiniLM-L6-v2"

"""


'\nRAG Pharma Chatbot - Gradio app (single-file)\n\nFeatures:\n- Ask user to upload a PDF (pharma doc)\n- Extract text, chunk with overlap\n- Compute embeddings using SentenceTransformers\n- Store embeddings in FAISS (on-disk optional)\n- Retrieve top-k relevant chunks\n- Call Groq LLM (via groq-python or HTTP) to generate final answer using retrieved context\n- Simple Gradio GUI\n\nEnvironment variables expected:\n- GROQ_API_KEY : your Groq API key\n- GROQ_MODEL : model id (e.g. "compound-beta" or "llama-3.1-8b-instant")\n- SENTENCE_TRANSFORMER_MODEL : (optional) default: "all-MiniLM-L6-v2"\n\n'

In [2]:
!pip install pypdf2 sentence-transformers faiss-cpu groq python-dotenv
!pip install -U gradio websockets
!pip install PyPDF2 pymupdf



In [3]:

import os
import tempfile
import uuid
import json
from typing import List, Tuple
import fitz

import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import gradio as gr
from PyPDF2 import PdfReader
from dataclasses import dataclass
from dotenv import load_dotenv
import nltk

# sent_tokenize() needs Punkt tokenizer data. Older NLTK releases ship it as
# "punkt"; newer releases look for "punkt_tab" instead, so fetch both.
# nltk.download() reports (rather than raises) when a resource id is unknown.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)

def semantic_chunk_text(text: str, chunk_size: int = 5, overlap: int = 1) -> List[str]:
    """Split text into overlapping sentence-based chunks for better semantic coherence.

    Args:
        text: Raw document text.
        chunk_size: Maximum number of sentences per chunk; must be >= 1.
        overlap: Number of sentences shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size`` so the window always moves forward.

    Returns:
        The chunks in document order, each one the space-joined sentences of a
        window. Empty or whitespace-only text yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    # Validate first: a zero step makes range() raise a cryptic error and a
    # negative step would silently produce no chunks at all.
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    if not text or not text.strip():
        return []

    sentences = nltk.sent_tokenize(text)
    step = chunk_size - overlap
    chunks = []
    for i in range(0, len(sentences), step):
        chunk = " ".join(sentences[i:i + chunk_size])
        chunks.append(chunk.strip())
    return chunks


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:

# Optional: try to use groq python client if installed, otherwise fallback to simple http call
# _HAS_GROQ records which path GroqLLM will take; `requests` is only imported
# (and therefore only defined) when the client import fails.
# NOTE(review): the published `groq` package appears to export `Groq` rather than
# `GroqClient` -- confirm. If so, this import always fails and the HTTP fallback
# is the only path that ever runs.
try:
    from groq import GroqClient
    _HAS_GROQ = True
except Exception:
    import requests
    _HAS_GROQ = False


In [5]:
# Configuration: API key and model names.
# In Colab the key is read from the notebook secret KEY_GEN_AI_HEC_GROQ. Outside
# Colab (or when that secret is empty) fall back to the GROQ_API_KEY environment
# variable, so the notebook also runs as a plain script.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

GROQ_API_KEY = userdata.get('KEY_GEN_AI_HEC_GROQ') if userdata is not None else None
if not GROQ_API_KEY:
    GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = os.getenv("GROQ_MODEL", "compound-beta")
SENT_MODEL = os.getenv("SENTENCE_TRANSFORMER_MODEL", "all-MiniLM-L6-v2")

In [6]:

# Simple dataclass to hold chunks and metadata
@dataclass
class DocChunk:
    """Container for one chunk of document text plus its identifier and metadata.

    NOTE(review): nothing in the visible part of this notebook instantiates
    DocChunk; the pipeline stores plain ``{"text": ...}`` dicts instead.
    """
    id: str  # chunk identifier (presumably unique per chunk -- confirm with callers)
    text: str  # the chunk's text content
    metadata: dict  # free-form extra information attached to the chunk



In [7]:

# ---------- PDF extraction ----------

def extract_text_from_pdf(file_path):
    """Return the plain text of every page of a PDF, concatenated in page order.

    Args:
        file_path: Path of the PDF file, passed straight to ``fitz.open``.

    Returns:
        The concatenated page text with leading/trailing whitespace stripped.
    """
    doc = fitz.open(file_path)
    try:
        # Join once instead of repeated string concatenation inside the loop.
        text = "".join(page.get_text("text") for page in doc)
    finally:
        # Release the file handle even if text extraction fails part-way.
        doc.close()
    return text.strip()

In [8]:

# ---------- Chunking utilities ----------

def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split text into fixed-size character chunks with overlap.

    Naive character-based chunker (safe for PDFs); adjust ``chunk_size`` if a
    token-based budget is needed.

    Args:
        text: Text to split.
        chunk_size: Maximum number of characters per chunk; must be >= 1.
        overlap: Number of characters shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size`` so the window always moves forward.

    Returns:
        Whitespace-stripped, non-empty chunks in document order. Empty input
        yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    if not text:
        return []

    text_len = len(text)
    step = chunk_size - overlap
    chunks = []
    for start in range(0, text_len, step):
        piece = text[start:start + chunk_size].strip()
        if piece:
            chunks.append(piece)
        # Stop once a window reaches the end of the text. The previous version
        # stepped back by `overlap` from the end and therefore never terminated.
        if start + chunk_size >= text_len:
            break
    return chunks


In [9]:


# ---------- Embeddings and FAISS index ----------

class FaissStore:
    """FAISS inner-product index paired with a row-id -> metadata mapping.

    Vectors are expected to be L2-normalised by the caller so that the inner
    product computed by ``IndexFlatIP`` equals cosine similarity.
    """

    def __init__(self, dim: int, index_path: str = None):
        """Create an empty store for vectors of dimensionality ``dim``.

        ``index_path`` is only stored on the instance; ``save``/``load`` take an
        explicit path prefix.
        """
        self.dim = dim
        self.index = faiss.IndexFlatIP(self.dim)  # use inner product with normalized vectors for cosine
        self.id_map = {}  # mapping from int index -> metadata
        self.next_id = 0  # id the next added vector will receive
        self.index_path = index_path

    def add(self, vectors: np.ndarray, metadatas: List[dict]):
        """Append vectors and their metadata (one metadata dict per vector row).

        Raises:
            ValueError: If the number of vectors and metadata entries differ,
                which would silently misalign search results.
        """
        # vectors should be L2-normalized if using IndexFlatIP for cosine
        vectors = np.ascontiguousarray(vectors, dtype="float32")
        if len(vectors) != len(metadatas):
            raise ValueError(
                f"got {len(vectors)} vectors but {len(metadatas)} metadata entries"
            )
        n_before = self.index.ntotal
        self.index.add(vectors)
        for i, meta in enumerate(metadatas):
            self.id_map[n_before + i] = meta
        # Keep the counter in sync with the index size.
        self.next_id = self.index.ntotal

    def search(self, q_vector: np.ndarray, top_k: int = 5, min_score: float = 0.3):
        """Return ``(metadata, score)`` pairs for the best matches of each query row.

        Hits with a negative index (FAISS padding when fewer than ``top_k``
        vectors exist) or a score below ``min_score`` are dropped.
        """
        if self.index.ntotal == 0:
            return []
        D, I = self.index.search(q_vector, top_k)
        results = []
        for score_list, idx_list in zip(D, I):
            for score, idx in zip(score_list, idx_list):
                if idx < 0 or score < min_score:
                    continue
                meta = self.id_map.get(int(idx), {})
                results.append((meta, float(score)))
        return results

    def save(self, path_prefix: str):
        """Write the index to ``<prefix>.index`` and the metadata to ``<prefix>.meta.json``."""
        faiss.write_index(self.index, path_prefix + ".index")
        with open(path_prefix + ".meta.json", "w", encoding="utf-8") as f:
            json.dump(self.id_map, f)

    def load(self, path_prefix: str):
        """Restore the index and metadata previously written by ``save``."""
        self.index = faiss.read_index(path_prefix + ".index")
        with open(path_prefix + ".meta.json", "r", encoding="utf-8") as f:
            raw_map = json.load(f)
        # JSON object keys are always strings; convert back to int so the
        # int lookups in search() find their metadata again.
        self.id_map = {int(key): meta for key, meta in raw_map.items()}
        self.dim = self.index.d
        self.next_id = self.index.ntotal



In [10]:

# ---------- LLM (Groq) wrapper ----------

class GroqLLM:
    """Thin wrapper that sends a system + user prompt to a Groq chat model.

    Uses the groq Python client when the module-level import succeeded
    (``_HAS_GROQ``); otherwise posts to Groq's OpenAI-compatible HTTP endpoint.
    """

    def __init__(self, api_key: str, model: str = "compound-beta"):
        """Store credentials and, when available, build the SDK client."""
        self.api_key = api_key
        self.model = model
        if _HAS_GROQ:
            self.client = GroqClient(api_key=api_key)

    def chat_completion(self, system_prompt: str, user_prompt: str, max_tokens: int = 512):
        """Return the assistant's reply text for the given prompts.

        Args:
            system_prompt: Instructions sent with the ``system`` role.
            user_prompt: Content sent with the ``user`` role.
            max_tokens: Upper bound on generated tokens.

        Raises:
            requests.HTTPError: On a non-2xx response (HTTP fallback path only).
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        if _HAS_GROQ:
            # The Groq SDK mirrors the OpenAI client layout: chat.completions.create.
            response = self.client.chat.completions.create(
                model=self.model, messages=messages, max_tokens=max_tokens
            )
            return response.choices[0].message.content

        # fallback raw HTTP compatible with OpenAI-like endpoint
        url = "https://api.groq.com/openai/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        # A timeout prevents a stalled connection from hanging the UI forever.
        r = requests.post(url, headers=headers, json=body, timeout=60)
        r.raise_for_status()
        data = r.json()
        # compatible parsing for OpenAI-style response
        return data["choices"][0]["message"]["content"]



In [11]:

# ---------- Main RAG pipeline helpers ----------

class RAGPipeline:
    """End-to-end retrieval-augmented pipeline: embed PDF chunks, retrieve, answer."""

    # Queries must contain at least one of these substrings to be treated as in-domain.
    PHARMA_KEYWORDS = {"drug", "dosage", "contraindication", "tablet", "capsule", "injection", "pharma", "prescription"}

    def __init__(self, emb_model_name: str = SENT_MODEL):
        """Load the embedding model, create an empty vector store and the LLM wrapper."""
        self.emb_model = SentenceTransformer(emb_model_name)
        # determine dim by encoding a dummy
        v = self.emb_model.encode(["hello world"], convert_to_numpy=True)
        self.dim = v.shape[1]
        self.store = FaissStore(dim=self.dim)
        self.llm = GroqLLM(api_key=GROQ_API_KEY, model=GROQ_MODEL)

    def process_document(self, file_path: str, batch_size: int = 64):
        """Read, chunk, embed and store a PDF document, replacing any previous one.

        Args:
            file_path: Path to the PDF file.
            batch_size: Number of chunks embedded per encoder call.

        Returns:
            The number of chunks indexed.
        """
        text = extract_text_from_pdf(file_path)

        # --- Chunk into small manageable pieces (character-based for now) ---
        chunks = [text[i:i + 500] for i in range(0, len(text), 500)]
        print(f"Total chunks: {len(chunks)}")

        # --- Reset the store so only the newly uploaded document is searchable ---
        self.store = FaissStore(dim=self.dim)

        # --- Process in batches to prevent memory overflow ---
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            embeddings = self.emb_model.encode(batch, convert_to_numpy=True)
            # Normalise for cosine similarity; the epsilon (same as in query())
            # avoids NaN rows when a chunk embeds to the zero vector.
            norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-9
            embeddings = (embeddings / norms).astype("float32")
            self.store.add(embeddings, [{"text": chunk} for chunk in batch])

        print(f"Added {len(chunks)} chunks to FAISS index.")
        return len(chunks)

    def query(self, query_text: str, top_k: int = 5, max_tokens: int = 512, min_score: float = 0.3):
        """Answer a question using only chunks retrieved from the indexed document.

        Args:
            query_text: The user's question.
            top_k: Maximum number of chunks to retrieve.
            max_tokens: Generation limit passed to the LLM.
            min_score: Minimum cosine similarity a chunk needs to be used.

        Returns:
            ``(message, results)``. ``message`` is the LLM answer followed by a
            confidence line and a preview of the retrieved chunks, or a fixed
            refusal text. ``results`` is the list of ``(metadata, score)`` pairs
            used (empty on refusal).
        """
        # --- Domain filter ---
        # PHARMA_KEYWORDS is a class attribute, so it must be reached through self;
        # the bare name raised NameError on every query.
        lowered = query_text.lower()
        if not any(kw in lowered for kw in self.PHARMA_KEYWORDS):
            return "This question does not appear pharma-related. Please ask a domain-specific query.", []

        # --- Embed & normalize ---
        q_emb = self.emb_model.encode([query_text], convert_to_numpy=True)
        q_emb = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-9)

        # --- Search with threshold ---
        # Forward min_score: otherwise the store applies its own 0.3 default and a
        # lower caller threshold would never take effect.
        results = self.store.search(q_emb.astype('float32'), top_k, min_score=min_score)
        results = [(meta, score) for meta, score in results if score >= min_score]

        if not results:
            return "I don’t know. The document does not contain relevant information.", []

        # --- Build context ---
        context_pieces = [meta.get("text", "") for meta, score in results if meta.get("text", "")]
        context = "\n\n---\n\n".join(context_pieces)

        # --- Strict system prompt ---
        system_prompt = (
            "You are a pharma domain assistant. ONLY use the provided context. "
            "If the context does not contain the answer, respond strictly with: "
            "'I don’t know. Please check the source document.' "
            "Do not guess or provide information outside the context."
        )

        user_prompt = f"Context:\n{context}\n\nQuestion: {query_text}"

        answer = self.llm.chat_completion(system_prompt=system_prompt, user_prompt=user_prompt, max_tokens=max_tokens)

        # --- Confidence reporting ---
        top_score = max(score for _, score in results)
        confidence_msg = f"Top match score = {top_score:.2f} → {'high confidence' if top_score > 0.6 else 'low confidence'}"

        # --- Sources preview ---
        sources_text = "\n\nRetrieved chunks:\n"
        for i, (meta, score) in enumerate(results):
            snippet = meta.get("text", "")[:200].replace("\n", " ")
            sources_text += f"[{i+1}] score={score:.4f} snippet={snippet}...\n"

        return f"{answer}\n\n{confidence_msg}\n\n{sources_text}", results


In [None]:

# ---------- Gradio GUI ----------

# Single module-level pipeline instance used by the upload and query handlers below.
rag = RAGPipeline()

def handle_upload(file_obj):
    """Index an uploaded PDF and report the outcome.

    Args:
        file_obj: Value supplied by the file component: either an object
            exposing the path as ``.name`` or the path itself; ``None`` when
            nothing was uploaded.

    Returns:
        A ``(summary, status)`` pair of strings.
    """
    if file_obj is None:
        return "", "No file uploaded"

    # Use the object's .name when it has one, otherwise treat the value as the path.
    pdf_path = getattr(file_obj, "name", file_obj)

    # Hand the path straight to the RAG pipeline for chunking and indexing.
    chunk_count = rag.process_document(pdf_path)

    summary = f"Processed PDF, created {chunk_count} chunks."
    return summary, "Upload successful"


def handle_query(question: str):
    """Answer a user question through the RAG pipeline.

    Args:
        question: Text typed by the user; may be empty or ``None``.

    Returns:
        A single string for the answer box. ``rag.query`` already appends the
        confidence line and the "Retrieved chunks" preview to its message, so
        nothing is added here (doing so listed every source twice).
    """
    if not question or not question.strip():
        return "Please enter a question."
    answer, _results = rag.query(question, top_k=5)
    return answer


# Build the UI. Each event handler's return values must line up one-to-one with
# the `outputs` list of the listener that calls it.
with gr.Blocks(title="RAG Pharma Chatbot") as demo:
    gr.Markdown("# RAG Pharma Chatbot\nUpload a pharma PDF, then ask domain questions. Uses SentenceTransformer embeddings + FAISS + Groq LLM.")
    with gr.Row():
        pdf_in = gr.File(label="Upload PDF", file_types=[".pdf"])
        upload_btn = gr.Button("Process PDF")
    # handle_upload returns two strings (summary, status), so it needs two outputs.
    upload_summary_out = gr.Textbox(label="Processing summary", lines=1)
    status_out = gr.Textbox(label="Status", lines=1)

    with gr.Row():
        query_in = gr.Textbox(label="Ask a question", placeholder="e.g. What is the recommended dosage for drug X?")
        query_btn = gr.Button("Ask")
    # handle_query returns one string; the confidence line is part of that text,
    # so no separate confidence box is wired.
    answer_out = gr.Textbox(label="Answer", lines=10)

    upload_btn.click(fn=handle_upload, inputs=[pdf_in], outputs=[upload_summary_out, status_out])
    # Register the query listener once; a second registration ran the pipeline
    # (and the LLM call) twice per click.
    query_btn.click(fn=handle_query, inputs=[query_in], outputs=[answer_out])


if __name__ == "__main__":
    # share=True publishes a temporary public URL; set it to False to keep the app private.
    demo.launch(server_name="0.0.0.0", share=True, debug=True)


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://902c167c90f0d115c4.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
