In [1]:
%%writefile ko_paper_streamlit.py
import streamlit as st
import pandas as pd
from loguru import logger
from langchain.chains import ConversationalRetrievalChain
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from torch import torch
from torch import cuda, bfloat16
import transformers
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    GenerationConfig,
    pipeline,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import HuggingFacePipeline
from langchain_community.document_loaders import JSONLoader
import json
from pathlib import Path
from pprint import pprint
from langchain_community.document_loaders import DataFrameLoader
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import FAISS
from torch import cuda, bfloat16
from langchain.callbacks import get_openai_callback
from langchain.memory import StreamlitChatMessageHistory
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import torch
from transformers import GenerationConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel,AutoPeftModelForCausalLM

# Index of the fine-tuning run to load; selects ./results{num} and ./merged_peft{num}.
num = 2
# Adapter checkpoint directory written by the training run.
adapter_dir = f'./results{num}/final_checkpoint'
# Root directory for the merged (base + adapter) model of the same run.
output_dir = f'./merged_peft{num}'
# Merged checkpoint; main() loads the tokenizer from here.
output_merged_dir = os.path.join(output_dir, "final_merged_checkpoint")

# Preferred torch device: CUDA when available, otherwise CPU.
DEV = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 4-bit NF4 quantization with double quantization; matmuls computed in fp16.
# NOTE(review): bitsandbytes 4-bit loading generally requires a CUDA GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# NOTE(review): not read in the visible code -- main() defines its own local
# `model_name` for the embedding model. Kept for compatibility.
model_name = output_merged_dir
# Path passed to AutoPeftModelForCausalLM.from_pretrained in main(). It is an
# alias of adapter_dir, so the checkpoint path is defined in one place only.
adapter_path = adapter_dir

@st.cache_resource
def _load_resources():
    """Load the embedding model, Chroma store and quantized LLM once per server process.

    ``st.cache_resource`` shares the returned objects across every rerun and
    every browser session, so the heavy models are loaded a single time instead
    of once per session.

    Returns:
        ``(vector_db, llm)``: the Chroma store persisted in ``./version_10`` and a
        LangChain ``HuggingFacePipeline`` wrapping the PEFT causal-LM.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name='intfloat/multilingual-e5-large',
        # Follow the module-level DEV instead of hard-coding "cuda", so the
        # embedding model does not fail outright on a CPU-only host.
        model_kwargs={"device": DEV.type},
        encode_kwargs={'normalize_embeddings': True},
    )
    vector_db = Chroma(persist_directory="./version_10", embedding_function=embeddings)

    model = AutoPeftModelForCausalLM.from_pretrained(
        adapter_path,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
        low_cpu_mem_usage=True,
    )
    # The tokenizer comes from the merged checkpoint, not the adapter directory.
    tokenizer = AutoTokenizer.from_pretrained(output_merged_dir)
    tokenizer.eos_token = '</s>'
    pipe = pipeline("text-generation", model=model, max_new_tokens=512, tokenizer=tokenizer, device_map="auto")
    llm = HuggingFacePipeline(pipeline=pipe)
    return vector_db, llm


def main():
    """Render the Korean paper-search chatbot page and answer one chat turn per rerun.

    Streamlit re-executes the script on every user interaction, so any state
    that must survive between turns (transcript, conversation chain with its
    memory) is kept in ``st.session_state``.
    """
    st.set_page_config(
        page_title="DirChat",
        page_icon=":books:")
    st.title("한국논문검색 챗봇 :red[한국어 전용] :books:")

    for key in ("conversation", "chat_history", "processComplete"):
        if key not in st.session_state:
            st.session_state[key] = None

    vector_db, llm = _load_resources()

    # Build the chain only once per session. It owns a ConversationBufferMemory,
    # and rebuilding it on every rerun would erase the conversation history.
    if st.session_state.conversation is None:
        st.session_state.conversation = get_conversation_chain(vector_db, llm)
    st.session_state.processComplete = True

    if 'messages' not in st.session_state:
        st.session_state['messages'] = [{"role": "assistant",
                                         "content": "안녕하세요! 찾고 싶은 논문이 있으면 ~~논문 찾아줘라고 요청해보세요!"}]
    # Replay the whole transcript, because Streamlit redraws the page on every rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if query := st.chat_input("질문을 입력해주세요."):
        st.session_state.messages.append({"role": "user", "content": query})
        with st.chat_message("user"):
            st.markdown(query)
        with st.chat_message("assistant"):
            chain = st.session_state.conversation
            with st.spinner("요청한 논문을 찾는 중입니다..."):
                # The chain's input key is "question"; memory supplies chat_history.
                result = chain.invoke({"question": query})
                response = result['answer']
                st.markdown(response)

        st.session_state.messages.append({"role": "assistant", "content": response})

def get_conversation_chain(vector_db, llm, k=3, fetch_k=20):
    """Build the ConversationalRetrievalChain used by the paper-search chatbot.

    Args:
        vector_db: LangChain vector store to retrieve papers from (Chroma in ``main``).
        llm: LangChain LLM, used both to condense follow-up questions and to answer.
        k: Number of documents stuffed into the answer prompt per question.
        fetch_k: Number of nearest-neighbour candidates that MMR re-ranks to pick
            ``k`` diverse documents. Values below ``k`` are raised to ``k``. With
            ``fetch_k == k``, MMR can only re-order the plain similarity results
            and adds no diversity.

    Returns:
        The configured chain. Its memory stores history under ``chat_history``,
        and each call returns ``answer`` plus ``source_documents``.
    """
    fetch_k = max(fetch_k, k)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=vector_db.as_retriever(search_type='mmr', search_kwargs={'k': k, 'fetch_k': fetch_k}, verbose=True),
        # output_key is required because the chain returns two outputs
        # (answer + source_documents); memory must know which one to store.
        memory=ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer'),
        # get_chat_history is left at LangChain's default, which renders the
        # stored message objects as role-prefixed text for the question
        # generator. An identity function would push the raw message list
        # (its Python repr) into the string prompt.
        return_source_documents=True,
        verbose=True
    )
    # Replace the default "stuff" QA prompt text in place, keeping the same
    # {context}/{question} placeholders. In English: "You are a paper-search
    # assistant. Answer using the context. Start the answer with '관련 논문은'
    # ('The related papers are') and use the polite '~~입니다.' ending."
    conversation_chain.combine_docs_chain.llm_chain.prompt.template = """너는 논문 검색 도우미야. 맥락을 참고해서 답해줘. 답변의 시작은 "관련 논문은"으로 시작해줘 .말투는 "~~입니다."로 해.

        {context}

        질문: {question}

        유용한 답변:"""

    return conversation_chain

if __name__ == "__main__":
    # `streamlit run ko_paper_streamlit.py` executes this file as __main__ and
    # re-runs it on every interaction, so main() redraws the page each time.
    main()


Overwriting streamlit_please.py


In [None]:
import streamlit as st
import pandas as pd

In [None]:
!streamlit run ko_paper_streamlit.py --server.port {'your_port'}