In [3]:
from app.core.config import settings
from app.database.mysql import get_mysql_connection
# Open the MySQL connection that MySQLMemorySaver (defined in a later cell) reads and writes through.
# NOTE(review): this single module-level connection ends up shared by every caller of `memory`;
# confirm the driver/connection settings (autocommit, thread-safety) suit that usage.
mysql_connection = get_mysql_connection()

2024-12-21 12:10:35,252 - fastapi_project - INFO - Connected to MySQL database: mydatabase on my-database.cbuom6aeon9v.ap-northeast-2.rds.amazonaws.com


In [6]:
import json

class MySQLMemorySaver:
    """Store per-session state as JSON rows in the ``chatbot_sessions`` MySQL table.

    Each row is addressed by ``(session_id, node_key)`` and holds a JSON-encoded
    ``node_value``.

    NOTE(review): rows are indexed by column name (``row["node_value"]``), so the
    connection presumably returns dict rows (e.g. a dict cursor) — confirm in
    ``get_mysql_connection``.
    """

    def __init__(self, mysql_connection):
        """Keep a reference to an open connection exposing ``cursor()``, ``commit()``
        and ``rollback()``."""
        self.connection = mysql_connection

    def save(self, session_id, node_key, node_value):
        """Insert or overwrite ``node_value`` for ``(session_id, node_key)`` and commit.

        Args:
            session_id: session the value belongs to.
            node_key: name of the value within the session (e.g. ``"conversation"``).
            node_value: any JSON-serialisable object.

        Raises:
            TypeError: if ``node_value`` cannot be JSON-serialised (nothing is written).
            Exception: any driver error from execute/commit is re-raised after the
                open transaction has been rolled back.
        """
        # ensure_ascii=False stores non-ASCII text (e.g. Korean) as-is instead of \u escapes.
        value_str = json.dumps(node_value, ensure_ascii=False)
        try:
            with self.connection.cursor() as cursor:
                # Update the row if (session_id, node_key) already exists, insert otherwise.
                # NOTE(review): relies on a UNIQUE/PRIMARY key over (session_id, node_key).
                sql = """
                INSERT INTO chatbot_sessions (session_id, node_key, node_value) 
                VALUES (%s, %s, %s)
                ON DUPLICATE KEY UPDATE node_value = VALUES(node_value), timestamp = CURRENT_TIMESTAMP
                """
                cursor.execute(sql, (session_id, node_key, value_str))
            self.connection.commit()
        except Exception:
            # A failed statement could otherwise leave the shared connection inside an
            # open, failed transaction (when autocommit is off) that later calls reuse.
            try:
                self.connection.rollback()
            except Exception:
                pass  # connection may already be unusable; surface the original error
            raise

    def load(self, session_id, node_key):
        """Return the decoded value stored for ``(session_id, node_key)``.

        Returns ``None`` when no row exists (and also when the stored JSON is ``null``).
        """
        with self.connection.cursor() as cursor:
            sql = "SELECT node_value FROM chatbot_sessions WHERE session_id = %s AND node_key = %s"
            cursor.execute(sql, (session_id, node_key))
            result = cursor.fetchone()
            if result:
                return json.loads(result["node_value"])
        return None

    def load_all(self, session_id):
        """Return ``{node_key: decoded value}`` for every row of ``session_id``.

        Returns an empty dict when the session has no rows.
        """
        with self.connection.cursor() as cursor:
            sql = "SELECT node_key, node_value FROM chatbot_sessions WHERE session_id = %s"
            cursor.execute(sql, (session_id,))
            rows = cursor.fetchall()
            return {row["node_key"]: json.loads(row["node_value"]) for row in rows}

In [8]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_core.messages import HumanMessage, AIMessage
from typing import List, Dict

In [None]:
# Initialise the MySQL-backed saver on the connection opened in the first cell.
memory = MySQLMemorySaver(mysql_connection=mysql_connection)

app = FastAPI()

# Optional per-session conversation cache kept in process memory while the server runs.
# NOTE(review): nothing in the visible cells reads or writes this cache yet.
conversation_cache = dict()

class UserInput(BaseModel):
    """Input for one chat turn: the session it belongs to and the user's new message."""

    session_id: str  # key used to load/save this session's rows via `memory`
    message: str  # new user text, appended as a HumanMessage before running the graph

def get_conversation(session_id: str):
    """Return the stored conversation dict for ``session_id``.

    Reads the ``"conversation"`` entry from MySQL via ``memory``; when nothing has
    been saved yet, returns a fresh ``{"history": []}``.
    """
    stored = memory.load(session_id, "conversation")
    return {"history": []} if stored is None else stored

def update_conversation(session_id: str, user_msg: str, ai_msg: str):
    """Append one user/AI exchange to the session's history and persist it.

    Loads the current conversation, adds the user turn followed by the AI turn,
    and saves the whole dict back under the ``"conversation"`` key.
    """
    conversation = get_conversation(session_id)
    new_turns = [
        {"role": "사용자", "message": user_msg},
        {"role": "AI", "message": ai_msg},
    ]
    conversation["history"].extend(new_turns)
    memory.save(session_id, "conversation", conversation)


def _history_to_messages(history):
    """Convert stored ``{"role", "message"}`` turns into LangChain message objects.

    Turns whose role is "사용자" (user) become ``HumanMessage``; every other role is
    treated as the AI and becomes ``AIMessage``.
    """
    return [
        HumanMessage(content=turn["message"])
        if turn["role"] == "사용자"
        else AIMessage(content=turn["message"])
        for turn in history
    ]


def process_user_input(user_input: UserInput):
    """Run one chat turn for ``user_input.session_id`` and persist the result.

    Steps:
      1. Load the stored conversation and rebuild it as LangChain messages.
      2. Append the new user message and run the graph on the full message list.
      3. Append the (user, AI) pair to the stored conversation.
      4. Save a per-turn snapshot under ``node_<n>``, where ``n`` is the number of
         exchanges that existed *before* this turn (so the first turn is ``node_0``).

    Returns:
        The conversation dict re-read from MySQL after the update (``{"history": [...]}``).

    Raises:
        HTTPException: status 500 if running the graph or reading its ``"output"`` fails.
    """
    session_id = user_input.session_id
    user_msg = user_input.message

    # Conversation so far, converted to the message objects the graph consumes.
    conv = get_conversation(session_id)
    messages = _history_to_messages(conv["history"])
    messages.append(HumanMessage(content=user_msg))

    # NOTE(review): `graph` is not defined in the visible cells; LangGraph compiled graphs
    # are normally driven with `.invoke(...)` rather than `.run(...)` — confirm.
    try:
        result = graph.run(messages=messages)
        ai_response = result["output"]
    except Exception as e:
        # NOTE(review): str(e) is sent to the HTTP client and may expose internal details.
        raise HTTPException(status_code=500, detail=str(e)) from e

    update_conversation(session_id, user_msg, ai_response)

    # `conv` is the copy loaded above; update_conversation loads and saves its own copy,
    # so this length still reflects the history *before* this turn.
    turn_index = len(conv["history"]) // 2  # one user + one AI message = one exchange
    memory.save(session_id, f"node_{turn_index}", {
        "input": {"messages": [{"content": user_msg, "type": "HumanMessage"}]},
        "output": {"response": ai_response},
    })

    # Return the updated conversation as now stored in MySQL.
    return get_conversation(session_id)
