In [1]:
!pip install gradio PyPDF2 python-docx transformers faiss-cpu



In [3]:
!pip install datasets



In [21]:
import gradio as gr
import PyPDF2
import docx
import os
import pandas as pd
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, DPRQuestionEncoderTokenizer, DPRQuestionEncoder
from datasets import Dataset, Features, Value, Sequence
import faiss
import numpy as np
import torch

# Constants
# Extensions accepted by handle_upload (compared against the lower-cased extension).
SUPPORTED_FILE_TYPES = [".txt", ".pdf", ".docx"]
# Upper bound on the uploaded file's size on disk, in bytes.
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10 MB

# Initialize the correct tokenizer and model
# All four objects are loaded once at import / cell-execution time and are
# shared as module globals by the functions below.
# `tokenizer` / `model`: used by generate_response to build inputs, generate and decode.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
# The DPR question encoder embeds BOTH the uploaded documents
# (add_to_faiss_index) and the user queries (generate_response).
# NOTE(review): DPR also ships a separate context encoder for passages;
# confirm that embedding documents with the question encoder is intended.
question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
question_encoder_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
# No retriever is attached here; retrieval is done manually through the FAISS index.
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

# Must equal the width of the encoder's `pooler_output`.
# NOTE(review): presumably 768 for this checkpoint -- verify against the model config.
dimension = 768  # Dimension of embeddings

# FAISS index for document embeddings
# Flat (exhaustive) index ranked by L2 distance.
# NOTE(review): confirm L2 is the intended metric for these embeddings.
index = faiss.IndexFlatL2(dimension)

# Document storage
# document_texts[i] is the text whose embedding was the i-th vector added to
# `index`; generate_response relies on that positional correspondence.
document_texts = []
# Raw embedding arrays, appended in the same order. Within the visible code it
# is only read as an "has anything been uploaded yet?" check.
embeddings_list = []

def parse_document(file_path):
    """Extract the raw text of a ``.txt``, ``.pdf`` or ``.docx`` file.

    The extension is matched case-insensitively so that this function accepts
    exactly the files that ``handle_upload`` lets through (it lower-cases the
    extension before validating it).

    Args:
        file_path: Path of the document on disk.

    Returns:
        The document text as a single string. PDF pages and DOCX paragraphs
        are separated by newlines.

    Raises:
        ValueError: If the file extension is not one of the supported types.
    """
    extension = os.path.splitext(file_path)[1].lower()

    if extension == ".txt":
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    if extension == ".pdf":
        with open(file_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # extract_text() may yield None/"" for image-only pages; treat
            # those as empty. Joining with a newline keeps the last word of
            # one page from being glued to the first word of the next.
            return "\n".join(page.extract_text() or "" for page in reader.pages)

    if extension == ".docx":
        doc = docx.Document(file_path)
        return "\n".join(para.text for para in doc.paragraphs)

    raise ValueError("Unsupported file type.")

def preprocess_text(text):
    """Normalise whitespace in the document text.

    Every run of whitespace (spaces, tabs, newlines) is collapsed to a single
    space, and leading/trailing whitespace is dropped.
    """
    # split() with no argument discards empty fields, so re-joining the
    # pieces with one space performs the whole clean-up.
    words = text.split()
    return " ".join(words)

def add_to_faiss_index(text):
    """Embed a document and register it in the FAISS index.

    The text is encoded with the module-level question encoder, the resulting
    vector is added to ``index``, and the text and its embedding are appended
    to ``document_texts`` / ``embeddings_list`` in the same position so that a
    FAISS row id can be mapped back to its text.

    Only the first 512 tokens contribute to the embedding (the tokenizer
    truncates); the full text is still stored.

    Args:
        text: Pre-processed document text.

    Raises:
        ValueError: If ``text`` is empty or whitespace-only (for example a
            scanned PDF with no extractable text). Indexing such a document
            would add a meaningless vector while reporting success.
    """
    if not text or not text.strip():
        raise ValueError("Document contains no extractable text.")

    inputs = question_encoder_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pooled = question_encoder_model(**inputs).pooler_output

    # FAISS expects a C-contiguous float32 matrix of shape (n_vectors, dim).
    embeddings = np.ascontiguousarray(pooled.cpu().numpy(), dtype=np.float32)
    index.add(embeddings)
    document_texts.append(text)
    embeddings_list.append(embeddings)

def generate_response(query):
    """Answer ``query`` with the RAG generator, using the closest uploaded document.

    Steps:
      1. Embed the query with the DPR question encoder.
      2. Look up the nearest stored document in the FAISS index.
      3. Build a generator input of the form ``<passage> // <question>``
         with the *generator* tokenizer, truncated to the model's
         ``max_combined_length``.
      4. Call ``model.generate`` with explicit ``context_input_ids``,
         ``context_attention_mask`` and ``doc_scores``. The model is loaded
         without a retriever, so all three must be supplied and
         ``input_ids`` must be omitted.

    Args:
        query: The user's question.

    Returns:
        The generated answer, or a short instruction string when the query is
        blank or nothing has been uploaded yet.
    """
    if not query or not query.strip():
        return "Please enter a question."
    if not embeddings_list:
        return "Please upload a document first."

    with torch.no_grad():
        # Encode the query (truncate so an over-long query cannot exceed the
        # encoder's position limit).
        query_inputs = question_encoder_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        query_embeddings = np.ascontiguousarray(
            question_encoder_model(**query_inputs).pooler_output.cpu().numpy(),
            dtype=np.float32,
        )

        # Perform FAISS search for the single nearest document.
        _, neighbour_ids = index.search(query_embeddings, k=1)
        best = int(neighbour_ids[0][0])
        if best < 0:  # FAISS returns -1 when the index has no vectors
            return "Please upload a document first."
        retrieved_text = document_texts[best]

        # RagTokenizer.__call__ uses the question-encoder vocabulary; the
        # generator needs ids from its own tokenizer.
        generator_tokenizer = tokenizer.generator
        max_len = model.config.max_combined_length

        # Reserve room for the question first so truncating a long document
        # can never cut the question off. Two slots go to BOS/EOS.
        question_ids = generator_tokenizer(
            model.config.doc_sep + query, add_special_tokens=False
        )["input_ids"][: max_len - 3]
        passage_budget = max_len - 2 - len(question_ids)
        passage_ids = generator_tokenizer(
            retrieved_text,
            add_special_tokens=False,
            truncation=True,
            max_length=passage_budget,
        )["input_ids"]
        context_row = (
            [generator_tokenizer.bos_token_id]
            + passage_ids
            + question_ids
            + [generator_tokenizer.eos_token_id]
        )

        # The model marginalises over `n_docs` passages per question. Only one
        # passage is available, so it is repeated with equal scores, which
        # leaves the marginal unchanged.
        n_docs = model.config.n_docs
        context_input_ids = torch.tensor([context_row] * n_docs, dtype=torch.long)
        context_attention_mask = torch.ones_like(context_input_ids)
        doc_scores = torch.zeros((1, n_docs), dtype=torch.float32)

        outputs = model.generate(
            context_input_ids=context_input_ids,
            context_attention_mask=context_attention_mask,
            doc_scores=doc_scores,
            n_docs=n_docs,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def handle_upload(file):
    """Validate an uploaded document, parse it and add it to the index.

    Args:
        file: The value produced by the ``gr.File`` component. Depending on
            the Gradio version/configuration this is either a filesystem path
            string or an object exposing the path as ``.name``; both are
            accepted. A falsy value means nothing was uploaded.

    Returns:
        A human-readable status message. Errors are reported in the returned
        string rather than raised, so they appear in the status textbox.
    """
    if not file:
        return "Please upload a document."

    file_path = file if isinstance(file, str) else file.name

    try:
        file_size = os.path.getsize(file_path)
    except OSError as e:
        # Missing or unreadable file: report it like any other processing error.
        return f"Error processing document: {str(e)}"
    if file_size > MAX_FILE_SIZE:
        return f"File size exceeds the limit of {MAX_FILE_SIZE / (1024 * 1024)} MB."

    file_ext = os.path.splitext(file_path)[1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        return f"Unsupported file type. Supported types: {', '.join(SUPPORTED_FILE_TYPES)}."

    # UI boundary: any parsing/embedding failure becomes a status message
    # instead of an unhandled exception in the Gradio callback.
    try:
        text = parse_document(file_path)
        text = preprocess_text(text)
        add_to_faiss_index(text)
        return "Document uploaded and processed successfully!"
    except Exception as e:
        return f"Error processing document: {str(e)}"

def chat(query, history):
    """Answer one user message and append the exchange to the chat history.

    Args:
        query: The user's question.
        history: List of ``(question, answer)`` tuples shown in the chatbot.
            ``None`` is treated as an empty history so the first message
            cannot fail on ``.append``.

    Returns:
        The history list with the new ``(query, response)`` pair appended.
        When a list is passed in, it is extended in place and returned.
    """
    if history is None:
        history = []
    response = generate_response(query)
    history.append((query, response))
    return history

# Gradio Interface
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    Layout: a title, a row with the file picker and the upload status box, a
    row with the chatbot and the question box, then the two action buttons.
    """
    with gr.Blocks() as app:
        gr.Markdown("# Document-Based Chatbot")

        # Row 1: document upload and its status message.
        with gr.Row():
            document_picker = gr.File(label="Upload Document")
            status_box = gr.Textbox(label="Upload Status", interactive=False)

        # Row 2: conversation so far and the input for the next question.
        with gr.Row():
            conversation = gr.Chatbot(label="Chat History")
            question_box = gr.Textbox(label="Your Question")

        # Buttons are created and wired in the same order as they appear.
        gr.Button("Upload Document").click(
            handle_upload, inputs=document_picker, outputs=status_box
        )
        gr.Button("Send").click(
            chat, inputs=[question_box, conversation], outputs=conversation
        )

    return app

# Launch the app
# Runs as soon as this cell/module executes: builds the UI, then starts the server.
demo = create_interface()
# NOTE(review): the captured Colab output suggests passing debug=True to
# launch() to see callback errors in the notebook -- consider it while debugging.
demo.launch(share=True)  # Use `share=True` to generate a public link

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'DPRQuestionEncoderTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'RagTokenizer'. 
The class this function is called from is 'BartTokenizer'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called fr

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://601a4f030e03997eea.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


