In [1]:
import gradio as gr
# httpx==0.24.1

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers import pipeline, AutoTokenizer
# from langchain.llms.huggingface_pipeline import HuggingFacePipeline 
from bigdl.llm.langchain.embeddings import TransformersEmbeddings
from bigdl.llm.langchain.llms import TransformersLLM
import PyPDF2 # pdf reader
from pypdf import PdfReader
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter,RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate

In [3]:
# Embedding model used by process_file() to vectorise PDF chunks for FAISS.
# Loaded once at module level from the local checkpoint directory.
embeddings = TransformersEmbeddings.from_model_id(
    model_id="/data/vicuna-7b-v1.5",
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.86s/it]
2024-01-20 02:15:07,852 - INFO - Converting the current model to sym_int4 format......


In [6]:
# Already-loaded LLM wrappers, keyed by model path. Re-running this cell
# resets the cache.
_LLM_CACHE = {}


def load_llm(model_id="/data/vicuna-7b-v1.5"):
    """Return the LangChain LLM wrapper for the checkpoint at ``model_id``.

    The wrapper is created on the first call for a given path and reused
    afterwards, so callers that build a new chain per question (see
    ``process_file``) do not reload the model every time.

    Args:
        model_id: Path (or id) of the checkpoint handed to
            ``TransformersLLM.from_model_id``. Defaults to the local
            vicuna-7b-v1.5 directory.

    Returns:
        The ``TransformersLLM`` instance for ``model_id``.
    """
    if model_id not in _LLM_CACHE:
        _LLM_CACHE[model_id] = TransformersLLM.from_model_id(
            model_id=model_id,
            model_kwargs={"temperature": 0, "max_length": 1024, "trust_remote_code": True},
        )
    return _LLM_CACHE[model_id]

In [9]:
def add_text(history, text):
    """Append the user's question to the chat history as a new, unanswered turn.

    Args:
        history: Existing list of chat turns; it is not modified.
        text: The question typed by the user.

    Returns:
        A new list: ``history`` followed by ``(text, '')``.

    Raises:
        gr.Error: If ``text`` is empty.
    """
    if text:
        pending_turn = (text, '')
        return history + [pending_turn]
    raise gr.Error('Enter text')

In [8]:
def upload_file(files):
    """Pass the uploaded files straight through, unchanged.

    Wired as an event handler whose output is the file panel, so returning
    the input is what makes the uploads show up there.
    """
    uploaded = files
    return uploaded

In [5]:
def _extract_pdf_text(files):
    """Return the text of every page of every uploaded PDF, newline-separated."""
    page_texts = []
    for file in files:
        pdf = PyPDF2.PdfReader(file.name)
        for page in pdf.pages:
            # Guard against a falsy extraction result so the join cannot fail.
            page_texts.append(page.extract_text() or "")
    # Joining with a newline keeps the last word of one page from fusing
    # with the first word of the next.
    return "\n".join(page_texts)


def process_file(files):
    """Build a retrieval QA chain over the text of the uploaded PDFs.

    The text of all pages is split into overlapping chunks (600 characters,
    200 overlap), embedded with the module-level ``embeddings`` model into a
    FAISS store, and wrapped in a 'stuff' RetrievalQA chain that uses a
    custom prompt and a fresh ConversationBufferMemory.

    Args:
        files: Uploaded file objects; each must expose the path as ``.name``.

    Returns:
        The RetrievalQA chain; call it with ``{"query": ...}``.

    Raises:
        gr.Error: If no text could be extracted from the PDFs, since there
            would be nothing to index.
    """
    pdf_text = _extract_pdf_text(files)
    if not pdf_text.strip():
        # Without this check the splitter yields no documents and the
        # vector store is built from an empty list.
        raise gr.Error('No extractable text found in the uploaded PDF(s)')

    # split into smaller chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=200)
    splits = text_splitter.create_documents([pdf_text])

    # embed the chunks and store them in a FAISS vector store
    vectorstore_db = FAISS.from_documents(splits, embeddings)

    # custom prompt: retrieved context, conversation history, then the question
    custom_prompt_template = """You have been given the following documents to answer the user's question.
    If you do not have information from the files given to answer the questions just say I don't have information from the given files to answer. Do not try to make up an answer.
    Context: {context}
    History: {history}
    Question: {question}

    Helpful answer:
    """
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=["question", "context", "history"])

    # QA chain with memory; the memory object is created new on every call
    qa_chain_with_memory = RetrievalQA.from_chain_type(llm=load_llm(),
                                                       chain_type='stuff',
                                                       return_source_documents=True,
                                                       retriever=vectorstore_db.as_retriever(),
                                                       chain_type_kwargs={"verbose": True,
                                                                          "prompt": prompt,
                                                                          "memory": ConversationBufferMemory(
                                                                              input_key="question",
                                                                              memory_key="history",
                                                                              return_messages=True) })
    return qa_chain_with_memory

In [11]:
import time
def generate_bot_response(history,query, btn):
    """Answer ``query`` from the uploaded PDFs and stream it into the chat.

    Called after ``add_text`` has appended the pending turn, so the last
    entry of ``history`` is the question with an empty answer slot.

    Args:
        history: Chat turns; the last one is updated in place with the answer.
        query: The user's question.
        btn: Files from the upload button; empty or None means nothing uploaded.

    Yields:
        ``(history, '')`` once per answer character; the empty string is
        the new value for the question textbox.

    Raises:
        gr.Error: If no PDF has been uploaded.
    """
    if not btn:
        raise gr.Error(message='Upload a PDF')

    # NOTE: process_file rebuilds the vector store and the chain (including
    # its ConversationBufferMemory) on every question.
    qa_chain_with_memory = process_file(btn)
    bot_response = qa_chain_with_memory({"query": query})

    # Rebuild the last turn as a list instead of mutating it in place:
    # add_text appends a tuple, and a tuple rejects item assignment.
    user_message, partial_answer = history[-1]
    partial_answer = partial_answer or ''

    # simulate streaming, one character at a time
    for char in bot_response['result']:
        partial_answer += char
        history[-1] = [user_message, partial_answer]
        time.sleep(0.05)
        yield history, ''

# The Gradio interface: a chat panel, a panel listing the uploaded PDFs with
# an upload button, and a question box with an 'Ask' button.
with gr.Blocks() as demo:
    with gr.Row():
            with gr.Row():
              # Chatbot interface
              # NOTE(review): the label says DeciLM-7B-instruct, but load_llm()
              # loads /data/vicuna-7b-v1.5 - confirm which name is intended.
              chatbot = gr.Chatbot(label="DeciLM-7B-instruct bot",
                                   value=[],
                                   elem_id='chatbot')
            with gr.Row():
              # Panel listing the uploaded PDFs
              file_output = gr.File(label="Your PDFs")

              with gr.Column():
                # PDF upload button (accepts several .pdf files at once)
                btn = gr.UploadButton("📁 Upload a PDF(s)",
                                      file_types=[".pdf"],
                                      file_count="multiple")

    with gr.Column():
        with gr.Column():
          # Question input field
          txt = gr.Text(show_label=False, placeholder="Enter question")

        with gr.Column():
          # Button that submits the question to the bot
          submit_btn = gr.Button('Ask')

    # Event handler for uploading a PDF: show the uploaded files in the panel
    btn.upload(fn=upload_file, inputs=[btn], outputs=[file_output])

    # Event handler for submitting a question and generating the response:
    #   1. add_text appends the question to the chat (runs with queue=False);
    #   2. on success, generate_bot_response streams the answer into the chat
    #      and writes '' to the question box;
    #   3. on success, upload_file re-sends the uploaded files to the panel.
    submit_btn.click(
        fn= add_text,
        inputs=[chatbot, txt],
        outputs=[chatbot],
        queue=False
        ).success(
          fn=generate_bot_response,
          inputs=[chatbot, txt, btn],
          outputs=[chatbot, txt]
        ).success(
          fn=upload_file,
          inputs=[btn],
          outputs=[file_output]
        )

In [25]:
# Turn on the request queue, then start the server on port 7860.
# NOTE(review): share=True together with server_name="0.0.0.0" makes the app
# reachable from outside this machine - confirm that is intended.
demo.queue().launch(
    share=True,
    inbrowser=True,
    server_name="0.0.0.0",
    server_port=7860,
)

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----


KeyboardInterrupt: 

In [24]:
# Stop the running Gradio server (needed before launch() can be called again
# with different parameters).
demo.close()

KeyboardInterrupt: 