## Library

In [None]:
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import os
import torch
from langchain.vectorstores import FAISS
import pandas as pd
from tqdm import tqdm
from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from typing import List, Optional, Callable
from typing_extensions import TypedDict, Annotated
from langgraph.graph import END, StateGraph, START
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from typing import Any, Dict, List, Callable


from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)

# Load API key settings from a local .env file into the environment
# (presumably OPENAI_API_KEY for the ChatOpenAI model used below — confirm).
load_dotenv()

In [None]:
# Experiment settings: which seed's data split to load, and how many
# neighbours the FAISS retriever returns for each query.
seed, retrieve_k = 1, 3

## Data

In [None]:
# Load the train / validation / test splits generated for the chosen seed.
train, validation, test = (
    pd.read_csv(split_path)
    for split_path in (
        f"../seed{seed}/seed{seed}_train_쇼핑.csv",
        f"../seed{seed}/seed{seed}_validation_쇼핑.csv",
        f"../seed{seed}/seed{seed}_test_쇼핑.csv",
    )
)

## Vectorstore

In [None]:
# All embedding models are placed on the CPU.
device = torch.device("cpu")


class KUREEmbedding:
    """Minimal embedding adapter around a SentenceTransformer checkpoint.

    Provides the ``embed_documents`` / ``embed_query`` pair; only the bound
    ``embed_query`` method is handed to FAISS elsewhere in this notebook.
    """

    def __init__(self, model_name="nlpai-lab/KURE-v1"):
        # trust_remote_code lets the checkpoint ship its own model code.
        encoder = SentenceTransformer(model_name, trust_remote_code=True)
        self.model = encoder.to(device)

    def embed_documents(self, texts):
        """Encode ``texts`` in one batch and return the result as a NumPy array."""
        return self.model.encode(texts, convert_to_numpy=True)

    def embed_query(self, text):
        """Encode a single string by batching it alone and taking the first row."""
        batch = self.embed_documents([text])
        return batch[0]


class KoE5Embedding(KUREEmbedding):
    """Same adapter, defaulting to the ``nlpai-lab/KoE5`` checkpoint."""

    def __init__(self, model_name="nlpai-lab/KoE5"):
        super().__init__(model_name)

In [None]:
# Load the prebuilt FAISS index of (presumably) training cases for this seed.
vectorstore_path = f"../seed{seed}/faiss_index_seed{seed}_koe5"
if not os.path.exists(vectorstore_path):
    # Fail here rather than leaving `retriever` unbound: otherwise the first
    # graph run dies much later with an opaque NameError in the `retrieve` node.
    raise FileNotFoundError(
        f"FAISS index not found at {vectorstore_path!r}; "
        "build it before running the pipeline."
    )

embeddings = KoE5Embedding()
# KoE5Embedding is not a LangChain `Embeddings` subclass, so FAISS is given the
# bound `embed_query` function rather than the object itself.
vectorstore = FAISS.load_local(
    vectorstore_path,
    embeddings.embed_query,
    # The index is deserialized with pickle; only load indexes you built yourself.
    allow_dangerous_deserialization=True,
)
retriever = vectorstore.as_retriever(search_kwargs={"k": retrieve_k})

## Classifier

In [None]:
class CategoryClassification(BaseModel):
    # Structured-output schema for the category classifier.
    # NOTE: deliberately no docstring — LangChain uses a schema's docstring as
    # the tool description sent to the model, so adding one changes the prompt.
    prediction: str = Field(description="Predicted category of the user query")


# Shared chat model; temperature 0 keeps answers as repeatable as the API allows.
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
)
# Same model, with replies parsed into CategoryClassification objects.
structured_llm_grader = llm.with_structured_output(CategoryClassification)

## Classification Prompt & Chain

In [None]:
# Classification prompt: the system turn fixes the closed label set, and the
# human turn carries the query plus whatever context `documents` holds.
# The doubled braces escape literal JSON braces for the prompt template.
_classification_system = (
    "Classify the query into one of: 제품, 배송, 교환/반품/환불, 행사, AS, 포장, 구매, 웹사이트.\n"
    "Return only one category exactly as listed above. No other categories or explanations.\n"
    'Return in JSON: {{"prediction": "category"}}'
)
_classification_human = "Query: {question}\nRelevant cases: {documents}"

classification_prompt = ChatPromptTemplate.from_messages(
    [("system", _classification_system), ("human", _classification_human)]
)

# Prompt -> LLM whose output is parsed into CategoryClassification.
classification_grader = classification_prompt | structured_llm_grader

## Heuristic Filtering: Relevance Grader Schema

In [None]:
class GradeDocuments(BaseModel):
    # Structured-output schema for the per-document relevance grader.
    # NOTE: deliberately no docstring — it would be sent to the model as the
    # tool description. The field is a free-form str, not an enum.
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")


# NOTE: this rebinds `structured_llm_grader`. `classification_grader` was built
# from the earlier CategoryClassification-bound value and keeps using it; only
# chains composed after this point pick up the GradeDocuments version.
structured_llm_grader = llm.with_structured_output(GradeDocuments)

## Heuristic Filtering Prompt & Chain

In [None]:
# Relevance-grader instructions. Kept verbatim: the spacing and "\n" escapes
# are part of the text sent to the model.
system = """You are a grader assessing relevance of a retrieved document to a user question. \n 
    It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""

_grade_human = "Retrieved document: \n\n {documents} \n\n User question: {question}"
grade_prompt = ChatPromptTemplate.from_messages([("system", system), ("human", _grade_human)])

# One-document relevance check: prompt -> LLM parsed into GradeDocuments.
retrieval_grader = grade_prompt | structured_llm_grader

## Summary

In [None]:
# Summary instructions (Korean). For each label (배송, 웹사이트, 행사, 구매, AS,
# 교환/반품/환불, 제품, 포장), the model is told to summarize the case as
# "this inquiry belongs to <label>-related inquiries". It is then told to
# summarize the following sentences according to these rules.
system = """
    - Label 값이 배송이면, 해당 문의가 배송 관련 문의에 포함된다고 요약하세요.
    - Label 값이 웹사이트이면, 해당 문의가 웹사이트 관련 문의에 포함된다고 요약하세요.
    - Label 값이 행사이면, 해당 문의가 행사 관련 문의에 포함된다고 요약하세요.
    - Label 값이 구매이면, 해당 문의가 구매 관련 문의에 포함된다고 요약하세요.
    - Label 값이 AS이면, 해당 문의가 AS 관련 문의에 포함된다고 요약하세요.
    - Label 값이 교환/반품/환불이면, 해당 문의가 교환/반품/환불 관련 문의에 포함된다고 요약하세요.
    - Label 값이 제품이면, 해당 문의가 제품 관련 문의에 포함된다고 요약하세요.
    - Label 값이 포장이면, 해당 문의가 포장 관련 문의에 포함된다고 요약하세요.  
    다음 문장들을 위 기준에 맞춰 요약하세요.
     """

# The human turn receives the retrieved cases (not a question). It previously
# asked the model to "Formulate an improved question", which contradicted the
# summarization instructions above.
summary_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        (
            "human",
            "Here are the retrieved cases: \n\n {documents} \n Summarize them according to the rules above.",
        ),
    ]
)

# Prompt -> plain LLM -> raw text; the result is a single summary string.
summary_rewriter = summary_prompt | llm | StrOutputParser()

## LangGraph State

In [None]:
class GraphState(TypedDict):
    # The user query being classified (the graph's input).
    question: Annotated[str, "Question"]
    # Structured output of `classification_grader` as stored by `generate`
    # (the object whose `.prediction` is read), not a plain string.
    generation: Annotated[Any, "LLM Generation"]
    # Changes type across the run: the retrieved Document list after
    # `retrieve`, the relevance-filtered list after `grade_documents`, and a
    # single summary string after `summary`. Hence `Any` instead of List[str].
    documents: Annotated[Any, "Retrieved Documents"]
    # Category label predicted by `generate`.
    prediction: Annotated[str, "prediction result"]

## Node

In [None]:
def retrieve(state):
    """Graph node: look up cases similar to the query via the FAISS retriever.

    Returns a partial state update that sets ``documents`` to the retriever's
    results (``retrieve_k`` of them, per the retriever's search kwargs).
    """
    query = state["question"]
    hits = retriever.invoke(query)
    return {"documents": hits}


def generate(state):
    """Graph node: classify the query, using ``documents`` as extra context.

    ``documents`` may be a document list, an empty list, or a summary string,
    depending on which path reached this node. Returns both the structured
    output object (``generation``) and its ``prediction`` label.
    """
    query, context = state["question"], state["documents"]
    result = classification_grader.invoke({"question": query, "documents": context})
    return {"generation": result, "prediction": result.prediction}


def grade_documents(state):
    """Graph node: keep only the retrieved documents the LLM grader deems relevant.

    Each document's ``page_content`` is graded independently against the
    question by ``retrieval_grader``. A document is kept when its
    ``binary_score`` is "yes". The comparison ignores case and surrounding
    whitespace: ``binary_score`` is a free-form ``str`` field whose description
    only *asks* for 'yes'/'no', so replies like "Yes" must not drop relevant
    documents. Document order is preserved.

    Returns a partial state update replacing ``documents`` with the kept list
    (empty if nothing was relevant or nothing was retrieved).
    """
    question = state["question"]
    documents = state["documents"] or []

    def _is_relevant(doc):
        # One grader call per document.
        score = retrieval_grader.invoke(
            {"question": question, "documents": doc.page_content}
        )
        return score.binary_score.strip().lower() == "yes"

    return {"documents": [doc for doc in documents if _is_relevant(doc)]}


def summary(state):
    """Graph node: collapse the kept documents into one summary string.

    The update *replaces* ``documents`` with the ``str`` produced by
    ``summary_rewriter``, which ends in a StrOutputParser.
    """
    summarized = summary_rewriter.invoke({"documents": state["documents"]})
    return {"documents": summarized}

## Conditional Edge

In [None]:
def decide_to_generate(state):
    """Routing function used after ``grade_documents``.

    Returns "relevant" when at least one document survived grading (the graph
    then goes through ``summary``), otherwise "not relevant" (straight to
    ``generate``). An empty list or ``None`` counts as "not relevant".
    """
    # Only `documents` is needed. The old bare `state["question"]` expression
    # had no effect except a possible KeyError.
    return "relevant" if state["documents"] else "not relevant"

## Graph Construction

In [None]:
workflow = StateGraph(GraphState)

# Register the four nodes.
for _node_name, _node_fn in (
    ("retrieve", retrieve),
    ("generate", generate),
    ("grade_documents", grade_documents),
    ("summary", summary),
):
    workflow.add_node(_node_name, _node_fn)

# Fixed edges: START -> retrieve -> grade_documents, and summary -> generate.
for _src, _dst in (
    (START, "retrieve"),
    ("retrieve", "grade_documents"),
    ("summary", "generate"),
):
    workflow.add_edge(_src, _dst)

# After grading, detour through `summary` only if some documents were kept.
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {"relevant": "summary", "not relevant": "generate"},
)

workflow.add_edge("generate", END)

app = workflow.compile()

In [None]:
def invoke_graph_to_dataframe(
    graph: "CompiledStateGraph",
    inputs: dict,
    config: "RunnableConfig",
    node_names: Optional[List[str]] = None,
    callback: Optional[Callable] = None,
):
    """Run ``graph`` once and return its ``prediction`` updates as a DataFrame.

    The graph is streamed with ``stream_mode="updates"`` and ``subgraphs=True``.
    Every per-node update is flattened into a (namespace, node_name, key, value)
    record; list values are stringified. Only values written under the
    ``"prediction"`` key are returned.

    Args:
        graph: Compiled graph to execute.
        inputs: Initial state passed to ``graph.stream``.
        config: Run configuration passed to ``graph.stream``.
        node_names: If non-empty, only updates from these nodes are recorded
            and passed to ``callback``. ``None`` or empty records every node.
        callback: Optional function called with
            ``{"node": node_name, "content": update}`` for each recorded update.

    Returns:
        DataFrame with a single ``prediction`` column, one row per prediction
        update observed. It is empty (but still has the column) if there were none.
    """
    # `None` default instead of a shared mutable `[]`; a set for membership tests.
    wanted = set(node_names) if node_names else None
    records = []

    def format_namespace(namespace):
        # Keep the part of the innermost namespace before the first ":".
        return namespace[-1].split(":")[0] if len(namespace) > 0 else "root graph"

    for namespace, chunk in graph.stream(
        inputs, config, stream_mode="updates", subgraphs=True
    ):
        for node_name, node_chunk in chunk.items():
            if wanted is not None and node_name not in wanted:
                continue
            if callback is not None:
                callback({"node": node_name, "content": node_chunk})
            formatted_namespace = format_namespace(namespace)

            if isinstance(node_chunk, dict):
                for k, v in node_chunk.items():
                    records.append(
                        {
                            "namespace": formatted_namespace,
                            "node_name": node_name,
                            "key": k,
                            "value": str(v) if isinstance(v, list) else v,
                        }
                    )
            elif node_chunk is not None:
                for item in node_chunk:
                    records.append(
                        {
                            "namespace": formatted_namespace,
                            "node_name": node_name,
                            "key": None,
                            "value": item,
                        }
                    )

    # Fix the columns up front, so an empty stream (or a filter that excludes
    # every node) yields an empty result instead of a KeyError on "key".
    df = pd.DataFrame(records, columns=["namespace", "node_name", "key", "value"])
    predictions = df.loc[df["key"] == "prediction", "value"]
    return pd.DataFrame({"prediction": predictions.to_numpy()})

## Prediction

In [None]:
results = []
config = RunnableConfig(recursion_limit=10, configurable={"thread_id": 1234})

# Run the graph once per test query, tagging each prediction frame with the
# gold category for that query.
labelled_queries = zip(test["text"], test["category"])
for question, answer in tqdm(labelled_queries, total=len(test), desc="Processing"):
    df_results = invoke_graph_to_dataframe(app, {"question": question}, config)
    df_results["category"] = answer
    results.append(df_results)

final_df = pd.concat(results, ignore_index=True)

## Evaluation

In [None]:
y_true, y_pred = final_df["category"], final_df["prediction"]

# Macro averages weight every category equally; the weighted F1 weights each
# category by its support.
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1_macro = (
    metric(y_true, y_pred, average="macro")
    for metric in (precision_score, recall_score, f1_score)
)
f1_weighted = f1_score(y_true, y_pred, average="weighted")

conf_matrix = confusion_matrix(y_true, y_pred)

print("\n===== Classification Performance Results =====")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-score (Macro): {f1_macro:.4f}")
print(f"F1-score (Weighted): {f1_weighted:.4f}")

print("\n===== Classification Confusion Matrix =====")
print(conf_matrix)

print("\n===== Detailed Classification Report =====")
print(classification_report(y_true, y_pred))