In [37]:
!pip install -q faiss-cpu gradio wikipedia pypdf sentence-transformers transformers accelerate

In [38]:
import numpy as np
import faiss
import wikipedia
from pypdf import PdfReader
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline

In [39]:
# Select the GPU when one is visible. The Hugging Face pipeline takes an int
# device index (0 = first GPU, -1 = CPU); SentenceTransformer takes a string.
device = 0 if torch.cuda.is_available() else -1

# Sentence embedder used both for indexing chunks and for encoding questions.
embedder = SentenceTransformer(
    "all-MiniLM-L6-v2",
    device="cpu" if device == -1 else "cuda",
)

# Seq2seq model that writes the final answer from the retrieved context.
generator = pipeline(
    "text2text-generation",
    model="google/flan-t5-base",
    device=device,
)

Device set to use cuda:0


In [40]:
# --- Retrieval state: replaced whenever a new source is loaded ---
INDEX = None   # FAISS index over DOCS; None until a source has been loaded
DOCS = []      # list[str] chunks; DOCS[i] is the text behind index row i

# --- Tunables ---
CHUNK_SIZE = 450      # words per chunk
CHUNK_OVERLAP = 60    # words shared by consecutive chunks
TOP_K = 4             # chunks retrieved per question
HISTORY_TURNS = 4     # most recent chat turns fed back into the prompt
# NOTE(review): TOP_K * CHUNK_SIZE (~1800 words) plus history is far longer than
# the ~512-token inputs flan-t5-base is usually run with — confirm answer quality
# or shrink CHUNK_SIZE / TOP_K.

In [41]:
def split_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split ``text`` into overlapping windows of whitespace-separated words.

    The window advances by ``chunk_size - overlap`` words (at least 1), so
    consecutive chunks share ``overlap`` words. Splitting stops at the first
    window that reaches the end of the text: any later window would only be a
    tail fragment already fully contained in that one.

    Args:
        text: Source text. ``None`` or blank text yields no chunks.
        chunk_size: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks.

    Returns:
        List of chunk strings (words re-joined with single spaces).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    words = (text or "").split()
    step = max(1, chunk_size - overlap)
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # This window already covers the last word; every later window would be
        # a strict sub-span of it (duplicate content in the index), so stop.
        if start + chunk_size >= len(words):
            break
    return chunks

def build_faiss_index(chunks):
    """Embed ``chunks`` with the global ``embedder`` and index them in FAISS.

    Uses a flat (exhaustive) L2 index. Row ``i`` of the index corresponds to
    ``chunks[i]``, which is how ``retrieve`` maps search hits back to text.
    The loaders in this notebook only call it with a non-empty list.
    """
    # FAISS expects float32 rows; SentenceTransformer normally already returns
    # float32, but cast explicitly to be safe.
    vectors = embedder.encode(
        chunks,
        convert_to_numpy=True,
        normalize_embeddings=False,
    ).astype(np.float32)

    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index

def retrieve(question, k=TOP_K):
    """Return up to ``k`` loaded chunks closest to ``question``.

    Args:
        question: Query text. It is embedded exactly like the corpus chunks
            (``normalize_embeddings=False``) so the L2 distances are comparable.
        k: Maximum number of chunks to return. Values <= 0 yield ``[]``.

    Returns:
        Chunks from ``DOCS`` in the order FAISS returns them (nearest first).
        Returns ``[]`` when no corpus is loaded.
    """
    # Read each global once, up front. The search and the id->chunk lookup then
    # use the same objects, instead of re-reading DOCS after a slow encode/search.
    index, docs = INDEX, DOCS
    if index is None or not docs:
        return []

    # FAISS does not accept k <= 0. It pads with -1 ids when k exceeds the number
    # of stored vectors, so clamp k up front.
    k = min(k, index.ntotal)
    if k <= 0:
        return []

    q_emb = embedder.encode(
        [question], convert_to_numpy=True, normalize_embeddings=False
    ).astype(np.float32)
    _, ids = index.search(q_emb, k)
    # Keep the bounds check as a safety net against -1 / stale ids.
    return [docs[i] for i in ids[0] if 0 <= i < len(docs)]

def format_prompt(history, retrieved_chunks, question, history_turns=None):
    """Build the grounded instruction prompt for the generator.

    Args:
        history: Completed chat turns as ``(user, assistant)`` pairs, oldest
            first. The UI passes a list of ``[user, assistant]`` lists. ``None``
            or empty means no history.
        retrieved_chunks: Context strings returned by ``retrieve``.
        question: The current user question.
        history_turns: How many of the most recent turns to include. ``None``
            (the default) uses the global ``HISTORY_TURNS``. ``0`` or less
            disables history entirely.

    Returns:
        The prompt: instructions, optional conversation, context, question.
    """
    turns = HISTORY_TURNS if history_turns is None else history_turns

    # Keep the last few turns so follow-up questions have context. Guard
    # turns <= 0 explicitly: history[-0:] is the *whole* list, not an empty one.
    recent = history[-turns:] if history and turns > 0 else []
    history_str = ""
    for user_msg, answer in recent:
        # A turn whose reply never arrived carries None. Render it as blank
        # rather than the literal text "None".
        reply = "" if answer is None else answer
        history_str += f"User: {user_msg}\nAssistant: {reply}\n"

    context = "\n\n---\n".join(retrieved_chunks)

    # Trailing blank line separates the conversation from "Context:", like the
    # other prompt sections.
    conversation_part = f"Conversation so far:\n{history_str}\n" if history_str else ""

    return (
        "You are a helpful assistant that must answer ONLY using the provided context. "
        "If the answer cannot be found in the context, say you don't know.\n\n"
        f"{conversation_part}"
        f"Context:\n{context}\n\n"
        f"User question: {question}\n\n"
        "Give a concise, accurate answer grounded strictly in the context."
    )



In [42]:
def _install_corpus(text):
    """Chunk and index ``text``, then make it the active corpus (DOCS/INDEX).

    The FAISS index is built before either global is reassigned. If embedding
    or indexing raises, the previously loaded DOCS and INDEX are left untouched
    and still belong together. Text that yields no chunks clears the corpus
    (``INDEX = None``, ``DOCS = []``).

    Returns:
        The number of chunks now loaded; 0 when ``text`` produced none.
    """
    global INDEX, DOCS
    chunks = split_text(text)
    if not chunks:
        INDEX, DOCS = None, []
        return 0
    new_index = build_faiss_index(chunks)
    INDEX, DOCS = new_index, chunks
    return len(chunks)


def load_pdf(file):
    """Load an uploaded PDF, chunk its text, and make it the active corpus.

    Args:
        file: The value ``gr.File`` passes to the handler. This is either a path
            string or an object exposing the temp-file path as ``.name``, or
            ``None`` when nothing is uploaded.

    Returns:
        A status message for the "PDF Status" textbox.
    """
    if file is None:
        return "⚠️ Please upload a PDF first."
    # Depending on the Gradio version the upload arrives as a path string or as
    # a file-like object with .name; accept either.
    path = getattr(file, "name", file)
    try:
        reader = PdfReader(path)
        # `or ""` guards pages whose extraction yields nothing usable.
        text = "\n".join(page.extract_text() or "" for page in reader.pages)
    except Exception as exc:  # corrupt / encrypted / non-PDF upload: report it in the UI
        return f"⚠️ Could not read the PDF: {exc}"
    n_chunks = _install_corpus(text)
    if not n_chunks:
        return "⚠️ Could not extract text from the PDF."
    return f"✅ Loaded PDF with {n_chunks} chunks."


def load_wikipedia(topic):
    """Fetch a Wikipedia article by title, chunk it, and make it the active corpus.

    The exact title is tried first. The library's auto-suggest lookup is used
    only as a fallback when no page has that title, because auto-suggest can
    swap an already-correct title for a different search suggestion.

    Args:
        topic: Article title typed by the user; surrounding whitespace is ignored.

    Returns:
        A status message for the "Wikipedia Status" textbox.
    """
    topic = (topic or "").strip()
    if not topic:
        return "⚠️ Enter a Wikipedia topic."
    try:
        try:
            page = wikipedia.page(topic, auto_suggest=False)
        except wikipedia.PageError:
            page = wikipedia.page(topic)
        text = page.content
    except wikipedia.DisambiguationError as e:
        return f"⚠️ Multiple pages found. Try a more specific title. Examples: {e.options[:8]}"
    except wikipedia.PageError:
        return "⚠️ Page not found. Try another title."
    except wikipedia.exceptions.WikipediaException as e:
        # Other library errors (e.g. HTTP timeouts) become a status message
        # instead of an unhandled exception in the UI.
        return f"⚠️ Wikipedia request failed: {e}"
    n_chunks = _install_corpus(text)
    if not n_chunks:
        return "⚠️ No content found on that page."
    return f"✅ Loaded Wikipedia article with {n_chunks} chunks."


In [43]:
def chat_respond(message, history):
    """Answer ``message`` from the loaded corpus via retrieval + the generator.

    Args:
        message: The user's question.
        history: Prior completed turns as ``(user, assistant)`` pairs. They are
            forwarded to ``format_prompt`` so follow-up questions have context.

    Returns:
        The generated answer, or a status string in these cases:
        no corpus is loaded, the question is blank, or retrieval found nothing.
    """
    if INDEX is None or not DOCS:
        return "⚠️ First load a PDF or a Wikipedia article (left panel)."
    question = (message or "").strip()
    if not question:
        # An empty submit would otherwise be embedded and "answered" anyway.
        return "⚠️ Please type a question."
    retrieved = retrieve(question, k=TOP_K)
    if not retrieved:
        return "I don't know based on the available context."
    prompt = format_prompt(history, retrieved, question)
    # Greedy decoding (do_sample=False) keeps answers deterministic.
    out = generator(prompt, max_length=256, do_sample=False)[0]["generated_text"].strip()
    return out

In [44]:
# Gradio UI. Inside a Blocks context the layout follows the order in which
# components are created: a header, then a row holding a narrow "Load Source"
# column and a wider chat column, then a footer note.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📚 RAG Chatbot — PDF / Wikipedia\nAnswers are grounded **only** in the loaded content.")

    with gr.Row():
        # Left column: load a corpus from either a PDF upload or a Wikipedia title.
        # Both loaders replace the global DOCS/INDEX, so the latest load wins.
        with gr.Column(scale=1):
            gr.Markdown("### Load Source")

            pdf = gr.File(label="Upload a PDF", file_types=[".pdf"])
            pdf_status = gr.Textbox(label="PDF Status", interactive=False)
            # Index the PDF as soon as it is uploaded; load_pdf's status string
            # is shown in pdf_status.
            pdf.upload(load_pdf, inputs=pdf, outputs=pdf_status)

            gr.Markdown("**— or —**")

            wiki_box = gr.Textbox(label="Wikipedia Topic (e.g., 'Python (programming language)')")
            wiki_status = gr.Textbox(label="Wikipedia Status", interactive=False)
            # Pressing Enter in the topic box fetches and indexes the article.
            wiki_box.submit(load_wikipedia, inputs=wiki_box, outputs=wiki_status)

            gr.Markdown("Tip: After loading, switch to the chat on the right.")

        # Right column: the chat itself.
        with gr.Column(scale=2):
            gr.Markdown("### Chat")
            # History is kept as [user, assistant] pairs (see user_submit/bot_reply),
            # which format_prompt unpacks as (u, a).
            # NOTE(review): the warning echoed in this cell's output points at this
            # line — presumably about the unspecified Chatbot `type`/format; confirm
            # before upgrading Gradio, since the pair format is relied on above.
            chatbot = gr.Chatbot(height=420)
            msg = gr.Textbox(placeholder="Ask a question about the loaded content…")
            clear_btn = gr.Button("Clear Chat")

            def user_submit(user_message, chat_history):
                """Show the user's message immediately and clear the textbox.

                Appends a [user_message, None] pair; the None placeholder is
                filled in by bot_reply in the chained .then() step.
                """
                # Append user message; assistant reply computed next
                return "", chat_history + [[user_message, None]]

            def bot_reply(chat_history):
                """Fill in the assistant reply for the last (pending) turn.

                Only the completed turns (all but the last) are passed as history,
                so the pending question isn't duplicated in the prompt.
                """
                user_message = chat_history[-1][0]
                answer = chat_respond(user_message, chat_history[:-1])
                chat_history[-1][1] = answer  # fill the placeholder in place, then return the list
                return chat_history

            # Wire events: submit first echoes the question (user_submit), then
            # .then() runs the slower retrieval + generation step (bot_reply).
            msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
                bot_reply, chatbot, chatbot
            )
            # Returning None to the chatbot clears it; queue=False bypasses the
            # event queue for this trivial handler.
            clear_btn.click(lambda: None, None, chatbot, queue=False)

    # NOTE(review): "will say it doesn't know" is only an instruction in the
    # prompt (format_prompt); the code does not enforce it.
    gr.Markdown(
        "—\n**Notes**: Uses `all-MiniLM-L6-v2` for embeddings + `flan-t5-base` for answers. "
        "If the answer isn't in the context, the assistant will say it doesn't know."
    )


  chatbot = gr.Chatbot(height=420)


In [45]:
# Start the Gradio server. On hosted notebooks (e.g. Colab) Gradio switches to
# share=True automatically and prints a temporary public URL (see the output).
# NOTE(review): anyone with that public URL can upload files and query the app.
demo.launch()

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://3594bce44ee6654a90.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


