In [29]:
import os
import PyPDF2
import nltk
import torch
import gradio as gr
from nltk.tokenize import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModelForCausalLM

In [35]:
# sent_tokenize needs the Punkt sentence-boundary models. Older NLTK releases
# load them from the 'punkt' package, while recent releases look for
# 'punkt_tab' instead, so fetch both to keep the notebook working on either.
nltk.download('punkt')
nltk.download('punkt_tab')

class SimpleRAG:
    """Minimal retrieval-augmented generation pipeline over uploaded PDFs.

    Sentences extracted from the PDFs are indexed with TF-IDF. For each
    question the most similar sentences are retrieved and passed to a causal
    language model as the context it is asked to answer from.

    Attributes:
        documents: list of ``(sentence, source_pdf_path)`` tuples, row-aligned
            with ``tfidf_matrix``.
        vectorizer: the ``TfidfVectorizer`` fitted on the indexed sentences.
        tfidf_matrix: TF-IDF matrix of the indexed sentences, or ``None`` when
            nothing is indexed.
        tokenizer / model: the Hugging Face tokenizer and causal LM.
    """

    DEFAULT_MODEL_NAME = "NousResearch/Llama-2-7b-chat-hf"
    NO_DOCUMENTS_MESSAGE = (
        "No documents are indexed yet. Please upload and process at least one PDF first."
    )

    def __init__(self, model_name=DEFAULT_MODEL_NAME):
        """Create an empty index and load the tokenizer and model.

        Args:
            model_name: Hugging Face model id (or local path) to load. Defaults
                to the Llama-2 7B chat checkpoint.
        """
        self.documents = []
        self.vectorizer = TfidfVectorizer()
        self.tfidf_matrix = None

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.float16
        )

    def extract_text_from_pdf(self, pdf_path):
        """Return the text of every page of ``pdf_path``, one page per line block.

        Args:
            pdf_path: path of the PDF file to read.

        Returns:
            The concatenated page texts, each followed by a newline.
        """
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # extract_text() can yield None/empty for pages without a text
            # layer (e.g. scanned images); treat those as empty instead of
            # failing on `None + "\n"`.
            page_texts = [(page.extract_text() or "") + "\n" for page in reader.pages]
        return "".join(page_texts)

    def process_uploaded_pdfs(self, pdf_files):
        """Rebuild the sentence index from the given uploaded PDFs.

        Args:
            pdf_files: iterable of uploaded files. Each item is either a path
                (``str`` / ``os.PathLike``) or an object exposing the path as
                ``.name``. ``None`` or an empty iterable clears the index.

        The previous ``documents`` / ``tfidf_matrix`` pair is only replaced
        after the new matrix has been fitted, so an exception while reading or
        fitting does not leave the two out of sync.
        """
        documents = []
        for pdf_file in pdf_files or []:
            # Uploads may arrive as temp-file wrappers (with .name) or as
            # plain path strings depending on the UI layer; accept both.
            if isinstance(pdf_file, (str, os.PathLike)):
                pdf_path = os.fspath(pdf_file)
            else:
                pdf_path = pdf_file.name
            text = self.extract_text_from_pdf(pdf_path)
            documents.extend((sent, pdf_path) for sent in sent_tokenize(text))

        if not documents:
            # Nothing to index: fitting TF-IDF on an empty corpus would raise.
            self.documents = []
            self.tfidf_matrix = None
            return

        tfidf_matrix = self.vectorizer.fit_transform([sent for sent, _ in documents])
        self.documents = documents
        self.tfidf_matrix = tfidf_matrix

    def retrieve_relevant_context(self, query, top_k=3):
        """Return the ``top_k`` indexed sentences most similar to ``query``.

        Args:
            query: the user's question.
            top_k: maximum number of sentences to return.

        Returns:
            A list of ``(sentence, source_pdf_path)`` tuples ordered from most
            to least similar; empty when nothing is indexed or ``top_k <= 0``.
        """
        # Guard top_k <= 0 explicitly: `[-0:]` would select *every* row.
        if top_k <= 0 or self.tfidf_matrix is None or not self.documents:
            return []

        query_vec = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vec, self.tfidf_matrix).flatten()
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.documents[i] for i in top_indices]

    def generate_answer_from_llama(self, context, query, max_new_tokens=256):
        """Use the language model to answer ``query`` from ``context``.

        Args:
            context: retrieved text the model is told to answer from.
            query: the user's question.
            max_new_tokens: upper bound on the number of generated tokens
                (the prompt does not count against it).

        Returns:
            The generated answer text, stripped of surrounding whitespace.
        """
        input_prompt = f"Do not use your own knowledge answer only from context provided to you Context: {context}\n\nQuestion: {query}\nAnswer:"

        # Send the inputs to wherever the model was placed rather than
        # assuming a CUDA device is present.
        inputs = self.tokenizer(input_prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        output = self.model.generate(
            **inputs,
            # max_length would include the prompt tokens, leaving little or no
            # room for the answer once context is added.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # Decode only the continuation. Splitting the full text on "Answer:"
        # breaks whenever the context or the question contains that marker.
        generated_tokens = output[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def answer_question(self, query):
        """Answer ``query`` using the indexed PDFs.

        Returns:
            The model's answer, or ``NO_DOCUMENTS_MESSAGE`` when there is no
            indexed content to retrieve from.
        """
        relevant_contexts = self.retrieve_relevant_context(query)
        if not relevant_contexts:
            return self.NO_DOCUMENTS_MESSAGE

        combined_context = " ".join(context for context, _ in relevant_contexts)
        return self.generate_answer_from_llama(combined_context, query)





[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [31]:
# Module-level pipeline instance shared by the UI callbacks defined in later
# cells. Constructing it loads the tokenizer and model weights (see
# SimpleRAG.__init__), so this cell is the slow one on a fresh kernel.
rag = SimpleRAG()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [32]:


def upload_pdfs(pdf_files):
    """UI callback: index the uploaded PDFs and return a status message.

    Args:
        pdf_files: the uploaded files from the file component, or ``None`` /
            an empty list when the button is clicked with nothing selected.

    Returns:
        A human-readable status string for the upload status textbox.
    """
    # Clicking "Upload" with no selection passes None; report it instead of
    # failing on len(None) / iteration.
    if not pdf_files:
        return "No PDFs selected. Please choose at least one PDF to upload."

    try:
        rag.process_uploaded_pdfs(pdf_files)
    except Exception as exc:
        # Surface the failure (unreadable PDF, no extractable text, ...) in
        # the status box rather than as an opaque UI error.
        return f"Failed to process PDFs: {exc}"

    return f"{len(pdf_files)} PDF(s) uploaded and processed successfully."

def ask_question(query):
    """UI callback: answer ``query`` using the shared RAG pipeline.

    Args:
        query: the text typed into the question box.

    Returns:
        The generated answer, or a prompt to enter a question when the input
        is empty or only whitespace.
    """
    # Skip retrieval and (slow) generation for blank input.
    if not query or not query.strip():
        return "Please enter a question."

    return rag.answer_question(query)



In [33]:
# Build the Gradio interface. Components are created inside the Blocks
# context in the order they should appear on the page.
with gr.Blocks() as ui:
    gr.Markdown("# Llama2 RAG Chatbot")

    # Upload row: multi-file PDF picker next to the button that triggers indexing.
    with gr.Row():
        pdf_files = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple")
        upload_btn = gr.Button("Upload PDFs")

    # Status line filled with the string returned by upload_pdfs.
    upload_output = gr.Textbox(label="PDF Upload Status")
    upload_btn.click(upload_pdfs, inputs=[pdf_files], outputs=[upload_output])

    # Question row: free-text query next to the button that triggers answering.
    with gr.Row():
        query = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
        ask_btn = gr.Button("Ask")

    # The answer returned by ask_question is rendered as Markdown.
    answer_output = gr.Markdown()
    ask_btn.click(ask_question, inputs=[query], outputs=[answer_output])



In [34]:
# Start the Gradio app. The guard is true when this cell runs in a notebook
# kernel or as a script, and false if the code is imported as a module.
if __name__ == "__main__":
    ui.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://17ff09e5fe84fba572.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
