In [1]:
%%writefile rag.py

from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.document_loaders import PyPDFLoader

from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.vectorstores.utils import filter_complex_metadata

class ChatPDF:
    """Retrieval-augmented question answering over ingested PDF files.

    PDFs are loaded, split into overlapping chunks, embedded with a
    HuggingFace embedding model and stored in a Chroma vector store. A
    LangChain pipeline then retrieves relevant chunks for each question and
    passes them, together with the question, to an Ollama chat model.
    """

    # Populated by ``ingest`` and reset by ``clear``; ``None`` until at least
    # one PDF with extractable text has been ingested.
    vector_store = None
    retriever = None
    chain = None

    def __init__(
        self,
        llm_model: str = "llama3",
        embedding_model: str = "BAAI/bge-m3",
        device: str = "cuda",
    ):
        """Create the chat model, text splitter, prompt and embedding model.

        Args:
            llm_model: Ollama model tag used to generate answers.
            embedding_model: HuggingFace model name used to embed chunks.
            device: Device string passed to the embedding model.
        """
        self.model = ChatOllama(model=llm_model)
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=100)
        self.prompt = PromptTemplate.from_template(
            """
            <s> [INST] You are an assistant for question-answering tasks. Use the following pieces of retrieved context
            to answer the question. If you don't know the answer, just say that you don't know. Use three sentences
             maximum and keep the answer concise. [/INST] </s>
            [INST] Question: {question}
            Context: {context}
            Answer: [/INST]
            """
        )
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs={"device": device},
            # Emit unit-length vectors.
            encode_kwargs={"normalize_embeddings": True},
        )

    @staticmethod
    def _format_docs(docs) -> str:
        """Join the text of retrieved documents into one context string."""
        return "\n\n".join(doc.page_content for doc in docs)

    def ingest(self, pdf_file_path: str):
        """Load a PDF, chunk and embed it, and (re)build the QA chain.

        The first call that yields chunks creates the vector store. Later
        calls add their chunks to that same store, so every PDF ingested since
        the last ``clear`` remains searchable.

        Args:
            pdf_file_path: Path to a PDF file on local disk.
        """
        docs = PyPDFLoader(file_path=pdf_file_path).load()
        chunks = self.text_splitter.split_documents(docs)
        # Drop metadata values the vector store cannot hold (non-scalar types).
        chunks = filter_complex_metadata(chunks)
        if not chunks:
            # No text could be extracted; leave any existing store/chain as-is.
            return

        if self.vector_store is None:
            self.vector_store = Chroma.from_documents(documents=chunks, embedding=self.embeddings)
        else:
            self.vector_store.add_documents(chunks)

        self.retriever = self.vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "k": 3,
                "score_threshold": 0.5,
            },
        )

        # Format retrieved Documents as plain text; otherwise the prompt would
        # receive the Python repr of a list of Document objects.
        self.chain = (
            {"context": self.retriever | self._format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.model
            | StrOutputParser()
        )

    def ask(self, query: str):
        """Answer ``query`` from the ingested documents.

        Returns:
            The model's answer as a string, or a reminder message when no
            document has been ingested yet.
        """
        if not self.chain:
            return "Please, add a PDF document first."

        return self.chain.invoke(query)

    def clear(self):
        """Discard all ingested documents and the retriever/chain built on them."""
        if self.vector_store is not None:
            # The default Chroma collection may be shared in-process with later
            # Chroma instances (chromadb-version dependent), so delete it
            # explicitly instead of only dropping the Python reference.
            self.vector_store.delete_collection()
        self.vector_store = None
        self.retriever = None
        self.chain = None

Overwriting rag.py


In [4]:
%%writefile app.py

import os
import tempfile
import streamlit as st
from streamlit_chat import message
from rag import ChatPDF

st.set_page_config(page_title="ChatPDF")


def display_messages():
    """Render the chat history and reserve a slot for the thinking spinner.

    Each history entry is a ``(text, is_user)`` pair; its position in the
    list is used as the widget key.
    """
    st.subheader("Chat")
    for position, entry in enumerate(st.session_state["messages"]):
        text, from_user = entry
        message(text, is_user=from_user, key=str(position))
    # Empty placeholder that process_input uses to show a spinner while the
    # assistant is answering.
    st.session_state["thinking_spinner"] = st.empty()


def process_input():
    """Send the typed message to the assistant and record both sides of the turn.

    Blank or whitespace-only input is ignored. The user's text (stripped) and
    the assistant's reply are appended to the message history as
    ``(text, is_user)`` pairs.
    """
    raw = st.session_state["user_input"]
    if not raw or not raw.strip():
        return

    user_text = raw.strip()
    with st.session_state["thinking_spinner"], st.spinner("Thinking"):
        agent_text = st.session_state["assistant"].ask(user_text)

    history = st.session_state["messages"]
    history.append((user_text, True))
    history.append((agent_text, False))


def read_and_save_file():
    """Reset the chat and ingest every uploaded PDF into the assistant.

    Each upload is written to a temporary file because ``ingest`` takes a
    filesystem path. The temporary copy is always removed afterwards, even if
    ingestion raises.
    """
    st.session_state["assistant"].clear()  # drop previously ingested documents
    st.session_state["messages"] = []  # drop the previous conversation
    st.session_state["user_input"] = ""  # clear the message input field

    for file in st.session_state["file_uploader"]:
        # delete=False keeps the file on disk after this handle is closed so it
        # can be reopened by path; it is removed explicitly in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False) as tf:
            tf.write(file.getbuffer())  # copy the uploaded bytes
            file_path = tf.name

        try:
            with st.session_state["ingestion_spinner"], st.spinner(f"Ingesting {file.name}"):
                st.session_state["assistant"].ingest(file_path)
        finally:
            os.remove(file_path)


def page():
    """Build the ChatPDF page: PDF uploader, ingestion spinner slot, chat, input box."""
    # Initialise per-session objects once. Check the specific keys rather than
    # whether session_state is empty: widget keys ("file_uploader",
    # "user_input") and spinner placeholders are stored there as well.
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
    if "assistant" not in st.session_state:
        st.session_state["assistant"] = ChatPDF()

    st.header("ChatPDF")

    st.subheader("Upload a document")
    st.file_uploader(
        "Upload document",
        type=["pdf"],
        key="file_uploader",
        on_change=read_and_save_file,
        label_visibility="collapsed",
        accept_multiple_files=True,
    )

    # Placeholder that read_and_save_file uses to show ingestion progress.
    st.session_state["ingestion_spinner"] = st.empty()

    display_messages()
    st.text_input("Message", key="user_input", on_change=process_input)


if __name__ == "__main__":
    # Entry point when the script is executed as __main__ (e.g. `streamlit run app.py`).
    page()

Overwriting app.py
