MIchael CHo final project


In [30]:
import os
import gradio as gr
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA

In [None]:
os.environ["OPENAI_API_KEY"] = "REMOVED FOR SECURITY PROTECTION"

In [32]:
pdf_loader = PyPDFLoader("C:/Users/micha/the_nestle_hr_policy_pdf_2012.pdf")
documents = pdf_loader.load()


In [33]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = text_splitter.split_documents(documents)


In [34]:
embeddings = OpenAIEmbeddings()
vectordb = Chroma.from_documents(chunks, embeddings, persist_directory="chroma_db")
retriever = vectordb.as_retriever(search_kwargs={"k": 3})


In [35]:
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)


In [36]:
template = """
You are an HR assistant chatbot trained on Nestlé’s HR policy.
Answer user questions based ONLY on the HR policy content provided in the context.
If the answer is not in the policy, say "I couldn’t find that information in the HR policy."

Context:
{context}

Question: {question}

Answer:"""

prompt = PromptTemplate(
    template=template, input_variables=["context", "question"]
)

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)




In [37]:
def chatbot(query):
    result = qa_chain({"query": query})
    answer = result["result"]
    return answer

demo = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(label="Ask about Nestlé HR Policy"),
    outputs=gr.Textbox(label="Answer"),
    title="Nestlé HR Policy Chatbot",
    description="Ask questions about Nestlé’s HR policy. Powered by GPT-3.5 and Chroma."
)

if __name__ == "__main__":
    demo.launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


  result = qa_chain({"query": query})
