# In [None]:
import streamlit as st
from st_multimodal_chatinput import multimodal_chatinput
import openai
from bokeh.models.widgets import Button
from bokeh.models import CustomJS
from streamlit_bokeh_events import streamlit_bokeh_events
from gtts import gTTS
from langchain.chains import history_aware_retriever
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from getpass import getpass
from langchain_openai import OpenAIEmbeddings

import os 
import io
import base64
import sys 

# Placeholder credentials / tracing settings.
# NOTE(review): nothing in the visible code exports these to os.environ, so as
# written they do not configure LangChain tracing; the OpenAI key that is
# actually used comes from st.secrets below.
OPEN_AI_KEY = "YOUR_OPENAI_API_KEY"
LANGCHAIN_API_KEY = "YOUR_LANGCHAIN_API_KEY"
LANGCHAIN_PROJECT = "YOUR_PROJECTNAME"
LANGCHAIN_TRACING_V2 = "true"
#os.environ["OPENAI_API_KEY"] = OPEN_AI_KEY

# Page configuration is issued before anything else touches Streamlit
# (including the st.secrets reads below).
st.set_page_config(
    page_title="농업용 코파일럿",
    #page_icon="🧊",
    initial_sidebar_state="expanded",
)

client = openai
client.api_key = st.secrets['openai_api_key']

# vector DB path
vector_path = "./capstone/capstone-design/산학캡스톤/index_store"
# Pass the key explicitly: the OPENAI_API_KEY environment export above is
# commented out, so relying on the environment would fail unless the variable
# happens to be set outside the app. Every other OpenAI client in this file
# already reads the key from st.secrets.
embeddings = OpenAIEmbeddings(openai_api_key=st.secrets['openai_api_key'])

# Open the persisted Chroma index and expose it as a top-3 similarity retriever.
vector_index = Chroma(persist_directory=vector_path, embedding_function=embeddings)

retriever = vector_index.as_retriever(
    search_type="similarity",
    search_kwargs={"k": 3},
)
# Response request function
def get_completion(prompt):
    """Bundle the stored conversation and the new prompt into one payload.

    Every entry of ``st.session_state['chat']`` is converted to a
    ``{'content', 'role'}`` dict and collected under ``"history"``; the
    prompt is stored under ``'query'``.

    NOTE: the actual answer generation (an OpenAI completion call, later a
    ``qa_chain.run`` on this payload) is disabled in this version, so the
    payload itself is what callers receive.

    Args:
        prompt: The user's latest message.

    Returns:
        dict with the keys ``"history"`` (list of message dicts) and
        ``'query'`` (the prompt).
    """
    history = [
        {'content': entry.msg, 'role': entry.sender}
        for entry in st.session_state['chat']
    ]
    return {"history": history, 'query': prompt}

# TTS request function
def text_speech(text):
    """Speak *text* in Korean through an autoplaying audio element.

    The text is synthesized with gTTS into an in-memory MP3, base64-encoded,
    and embedded in the page as an HTML ``<audio>`` tag via ``st.markdown``.

    Args:
        text: What to speak. Non-string values are converted with ``str()``
            so that a structured payload does not crash the page; ``None``
            and empty/whitespace-only text are skipped.

    Returns:
        None. When there is nothing to speak, no audio element is rendered.
    """
    if text is None:
        return
    # gTTS tokenizes with string methods, so anything else must be coerced.
    spoken = text if isinstance(text, str) else str(text)
    if not spoken.strip():
        # gTTS refuses empty input; skip instead of raising mid-render.
        return

    tts = gTTS(text=spoken, lang='ko')

    # Render the speech into memory rather than a temporary file.
    speech_bytes = io.BytesIO()
    tts.write_to_fp(speech_bytes)

    # Inline the MP3 as a base64 data URI so no static file has to be served.
    b64 = base64.b64encode(speech_bytes.getvalue()).decode('utf-8')
    md = f"""
            <audio id="audioTag" controls autoplay>
            <source src="data:audio/mp3;base64,{b64}"  type="audio/mpeg" format="audio/mpeg">
            </audio>
            """
    st.markdown(
        md,
        unsafe_allow_html=True,
    )
    
# Sidebar
sidebar = st.sidebar

sidebar.header("Chatbot")
sidebar.text("copilot")

#st.session_state['api_key'] = True
# Only checks that a key is configured (non-empty); it is not validated
# against the OpenAI API here.
if not st.secrets['openai_api_key']:
    sidebar.error(":x: API 인증 안됨")
else :
    sidebar.success(":white_check_mark: API 인증 완료")

sidebar.subheader("Models and parameters")

# NOTE(review): "모델3" looks like a placeholder rather than a real model id —
# confirm before shipping.
model = sidebar.selectbox(
    label="모델 선택",
    options=["gpt-3.5-turbo", "gpt-4-turbo", "모델3"]
)


params = sidebar.expander("Parameters")

# temperature
# The OpenAI chat API accepts temperatures in [0, 2]; the previous upper bound
# of 5.00 let users pick values the API rejects.
temperature = params.slider(
    label="temperature",
    min_value=0.01,
    max_value=2.00,
    step=0.01
)

# top_p
# NOTE(review): top_p and max_length are collected but not passed to the LLM
# anywhere in the visible code.
top_p = params.slider(
    label="top_p",
    min_value=0.01,
    max_value=1.00,
    step=0.01,
    value=0.90
)

# max_length
max_length = params.slider(
    label= "max_length",
    min_value=32,
    max_value=128,
    step = 1,
    value=120
)

# sidebar.button(
#     label= "Clear Chat History"
# )

# model setting

llm = ChatOpenAI(temperature=temperature,
                 model_name = model,
                 openai_api_key = st.secrets['openai_api_key']
                 )

# Retrieval-augmented QA chain over the vector-store retriever.
qa_chain = RetrievalQA.from_llm(
    llm=llm,
    retriever=retriever,
)

# `model` is the selected model *name* (a str) and has no bind_tools, so the
# previous `model.bind_tools` raised AttributeError on every run. The tool
# binder lives on the LLM object; fall back to None when the installed
# LangChain version does not provide it.
weather_chain = getattr(llm, "bind_tools", None)

# Chat message record
class chat:
    """One message shown in the conversation.

    Attributes:
        img: Optional image payload attached to the message.
        msg: Message body.
        sender: Role of the author (the app uses 'user' and 'assistant').
        isTTS: Set to True for replies that were also spoken aloud.
    """

    # Class-level defaults, kept so attribute access on the class itself
    # (e.g. ``chat.isTTS``) keeps working.
    img = None
    msg: "str | None" = None
    sender: "str | None" = None
    isTTS = None

    def __init__(self, img=None, msg=None, sender=None, isTTS=None):
        """Create a message; every field is optional and defaults to None."""
        self.img = img
        self.msg = msg
        self.sender = sender
        # Previously isTTS could only be assigned after construction.
        self.isTTS = isTTS

    def __repr__(self):
        return (
            f"{type(self).__name__}(img={self.img!r}, msg={self.msg!r}, "
            f"sender={self.sender!r}, isTTS={self.isTTS!r})"
        )
        

# Seed a fresh session with the assistant's greeting.
if 'chat' not in st.session_state:
    st.session_state['chat'] = [chat(msg = "무엇을 도와드릴까요?", sender='assistant')]  # first message

chatContainer = st.container(height=450)
userInput = multimodal_chatinput()

# Replay the stored conversation inside the scrollable container.
with chatContainer:
    for past in st.session_state['chat']:
        with st.chat_message(past.sender):
            if past.img:
                st.image(past.img)
            st.write(past.msg)

# Text of the last handled submission; the input component's value is
# compared against it so the same submission is not processed twice.
if "userinput_check" not in st.session_state:
    st.session_state['userinput_check'] = None

if userInput and userInput['text'] != st.session_state['userinput_check']:
    submitted_text = userInput['text']
    submitted_images = userInput['images']

    # Record the user's turn.
    user_turn = chat(msg=submitted_text, sender='user')
    if submitted_images:
        user_turn.img = submitted_images
    st.session_state['chat'].append(user_turn)
    st.session_state['userinput_check'] = submitted_text

    # Show the user's message.
    with chatContainer, st.chat_message('user'):
        if submitted_images:
            st.image(submitted_images)
        st.write(submitted_text)

    # Chatbot reply.
    generation = get_completion(submitted_text)
    response_message = generation

    response = chat(msg=response_message, sender='assistant')
    st.session_state['chat'].append(response)

    # Show the reply.
    with chatContainer, st.chat_message('assistant'):
        st.write(response_message)

    userInput = None

# Speech-to-text input
# Last transcript that was handled; used to avoid reprocessing the same
# recognition result.
if "tts_check" not in st.session_state:
    st.session_state['tts_check'] = None

# Browser-side handler: starts speech recognition and dispatches every final
# transcript as a "GET_TEXT" DOM event.
_STT_JS = """
    var recognition = new webkitSpeechRecognition();
    recognition.continuous = true;
    recognition.interimResults = true;

    recognition.onresult = function (e) {
        var value = "";
        for (var i = e.resultIndex; i < e.results.length; ++i) {
            if (e.results[i].isFinal) {
                value += e.results[i][0].transcript;
            }
        }
        if ( value != "") {
            document.dispatchEvent(new CustomEvent("GET_TEXT", {detail: value}));
        }
    }
    recognition.start();
    """

stt_button = Button(label="말하기", width=100, button_type="success")
stt_button.js_on_event("button_click", CustomJS(code=_STT_JS))

with sidebar:
    result = streamlit_bokeh_events(
        stt_button,
        events="GET_TEXT",
        key="listen",
        refresh_on_update=False,
        override_height=40,
        debounce_time=0,
    )

if result and "GET_TEXT" in result and result.get("GET_TEXT") != st.session_state['tts_check']:
    # Record the spoken message as a user turn.
    speech = chat(msg=result.get("GET_TEXT"), sender='user')
    st.session_state['chat'].append(speech)

    # Show the user's message.
    with chatContainer, st.chat_message('user'):
        st.write(result.get("GET_TEXT"))
    st.session_state['tts_check'] = result.get("GET_TEXT")

    # Chatbot reply.
    generation = get_completion(result.get("GET_TEXT"))
    response_message = generation

    response = chat(msg=response_message, sender='assistant')
    response.isTTS = True
    st.session_state['chat'].append(response)

    # Show the reply and read it aloud.
    with chatContainer, st.chat_message('assistant'):
        st.write(response_message)
        text_speech(response_message)
        

# In [None]:
import json
from langchain import Prompt, Model, Chain
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from weather import (get_weather_forecast)
from langchain_core.output_parsers import StrOutputParser

# Example setup
# NOTE(review): location_info is not defined anywhere in the visible code —
# presumably provided by another cell; confirm before running.
location = location_info
#get_weather_forecast(location)

# Chroma and retriever setup
embedding_function = OpenAIEmbeddings()
# Chroma's first positional parameter is the collection name, so the
# embedding object has to be passed by keyword; passing it positionally
# handed it over as the collection name and left the store without an
# embedding function.
# NOTE(review): no persist_directory is given, so this store does not point at
# a persisted index — confirm that is intended.
vectorstore = Chroma(embedding_function=embedding_function)
retriever = vectorstore.as_retriever()

# General-information lookup
def retrieve_general_information(query):
    """Look up *query* in the vector store and describe the hits as text.

    Args:
        query: Free-text search string.

    Returns:
        A string containing the query and the retrieved results.
    """
    # LangChain retrievers expose `invoke` (the Runnable interface); they have
    # no `retrieve` method, so the previous call raised AttributeError.
    results = retriever.invoke(query)
    return f"Retrieving information for query: {query}, Results: {results}"

# Tool definitions
def _single_string_param_tool(name, description, param, param_description):
    """Build a function-calling schema with one required string parameter."""
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {
                param: {
                    "type": "string",
                    "description": param_description,
                },
            },
            "required": [param],
        },
    }


# Schemas of the tools the model may request: a weather lookup by location
# and a vector-store retrieval by query.
functions = [
    _single_string_param_tool(
        "get_weather_forecast",
        "Get weather information up to 5 hours in the future using a weather API in a given location. Also this can get current weather information.",
        "location",
        "The city or dong, e.g. Seoul, Gangnam",
    ),
    _single_string_param_tool(
        "retrieval",
        "Use a retriever to get information from a vectorstore.",
        "query",
        "The query to retrieve information for.",
    ),
]

# Example prompt and the chat model used by the tool chain.
# The model name, temperature and API key come from the sidebar selection and
# the Streamlit secrets, as for the main app's LLM.
prompt = "Tell me about the history of Seoul."
llm = ChatOpenAI(
    model_name=model,
    temperature=temperature,
    openai_api_key=st.secrets['openai_api_key'],
)

# Chain that checks the model output for a tool call and dispatches to the
# matching function
def tool_chain(prompt_text):
    """Answer *prompt_text*, letting the model request one of the declared tools.

    The schemas in ``functions`` are bound to ``llm`` for the first call. When
    the reply requests a known tool, that function is executed and its result
    is sent back to the model to produce the final text answer. Otherwise
    (no tool call, or an unknown tool name) the question is answered through
    ``retrieve_general_information``.

    Args:
        prompt_text: The user's question.

    Returns:
        The final answer text on the tool path, or the retrieval summary
        string on the fallback path.
    """
    available_functions = {
        "get_weather_forecast": get_weather_forecast,
        "retrieval": retrieve_general_information,
    }

    # The schemas must actually be sent with the request, otherwise the model
    # can never produce tool calls. Chat models are invoked with `invoke`;
    # they have no `run` method.
    tools = [{"type": "function", "function": spec} for spec in functions]
    output = llm.bind(tools=tools).invoke(prompt_text)

    tool_calls = output.additional_kwargs.get("tool_calls")
    if not tool_calls:
        # No tool requested: fall back to plain retrieval.
        return retrieve_general_information(prompt_text)

    # Only the first requested call is handled.
    call = tool_calls[0]["function"]
    function_name = call["name"]
    function_to_call = available_functions.get(function_name)
    if function_to_call is None:
        # The model named a tool we do not implement.
        return retrieve_general_information(prompt_text)

    try:
        function_args = json.loads(call.get("arguments") or "{}")
    except json.JSONDecodeError:
        # Malformed argument JSON: continue with no arguments.
        function_args = {}

    if function_name == "get_weather_forecast":
        function_response = function_to_call(location=function_args.get("location"))
    else:
        function_response = function_to_call(
            query=function_args.get("query", prompt_text)
        )

    # Hand the tool result back to the model together with the question.
    # The earlier version piped a plain str into the model *name* (also a
    # str), which cannot form a chain, and printed instead of returning.
    follow_up = (
        f"Question: {prompt_text}\n"
        f"Function response: {function_response}"
    )
    function_chain = llm.with_retry() | StrOutputParser()
    return function_chain.invoke(follow_up)

# Run the chain on the example prompt and show what it returns.
result = tool_chain(prompt)
print(result)  # example output