### Imports

In [None]:
import os
import dill
import pickle
import pandas as pd
from typing import List

# Lift pandas' display limits (row count, column count, cell text width) by
# setting each of them to None.
for _display_option in ("display.max_rows", "display.max_columns", "display.max_colwidth"):
    pd.set_option(_display_option, None)

from langchain_core.output_parsers import StrOutputParser,JsonOutputParser
from langchain_core.runnables import RunnableLambda

# Data

### Question

In [None]:
# Load the evaluation sheet and keep only the question id and the question text.
df = pd.read_excel("../evaluation/xxxxx.xlsx")[["index","problem"]]

### Novel data

In [None]:
from components.data import get_novel_data, get_novel_summary, get_novel_docs
# Raw novel data; read later as texts_dic[<file name>]["content"] when counting keywords.
texts_dic = get_novel_data()
# Documents derived from the raw texts with chunk_size=512 and chunk_overlap=128.
# NOTE(review): `documents_dic` is not referenced again in the visible notebook —
# presumably it was only needed when the vector/BM25 DBs were built; confirm.
documents_dic = get_novel_docs(texts_dic,chunk_size = 512, chunk_overlap = 128)
# Novel summaries; read later as novels_dic[<file name>]["summary"].
novels_dic = get_novel_summary()

# LLMs

In [None]:
from langchain_openai import ChatOpenAI
# Main model: drives the title-keyword extraction, summary-answer and finalize
# chains, and is the model handed to the Python REPL agent.
llm         = ChatOpenAI(model="gpt-4o-2024-08-06",temperature=0)
# Smaller model: drives the router, keyword-maker, entity-filter and HyDE chains.
llm_4o_mini = ChatOpenAI(model="gpt-4o-mini",temperature=0)
# Model used only by the RAG answer chain (currently configured exactly like `llm`).
llm_rag     = ChatOpenAI(model="gpt-4o-2024-08-06",temperature=0)

In [None]:
from langchain_openai import OpenAIEmbeddings
# NOTE(review): `embeddings` is not referenced anywhere else in the visible
# notebook — confirm it is still needed.
embeddings = OpenAIEmbeddings()

from langchain_huggingface.embeddings import HuggingFaceEmbeddings
# Embedding model passed to every FAISS.load_local call in this notebook;
# queries sent to those indexes are prefixed with "query: ".
embeddings_e5 = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Knowledge DBs

### FAISS for title classification

In [None]:
from langchain_community.vectorstores import FAISS
# Load the seven FAISS indexes (faiss_1 ... faiss_7) used for title classification.
faiss_list_e5_ = [
    FAISS.load_local(
        folder_path = "./dbs/faiss_db_e5",
        index_name = f"faiss_{index_no}",
        embeddings = embeddings_e5,
        allow_dangerous_deserialization = True
    )
    for index_no in range(1, 8)
]

# Merge: the first store absorbs all the others (it is mutated in place), so a
# single store covers every index.
faiss_db_e5 = None
for store in faiss_list_e5_:
    if not faiss_db_e5:
        faiss_db_e5 = store
        continue
    faiss_db_e5.merge_from(store)

# Only the single best hit is requested from this retriever.
faiss_retriever_e5_for_title_classification = faiss_db_e5.as_retriever(
    search_kwargs = {"k":1}
)

### BM25 for title classification

In [None]:
import requests
from janome.tokenizer import Tokenizer
tokenizer = Tokenizer()

# Japanese stopword dictionary (stopwords-iso), one word per line.
url = "https://raw.githubusercontent.com/stopwords-iso/stopwords-ja/master/stopwords-ja.txt"
# Fail fast: a timeout prevents the cell from hanging on a stalled connection,
# and raise_for_status() prevents an HTTP error page from silently becoming the
# stopword list.
_stopwords_response = requests.get(url, timeout=30)
_stopwords_response.raise_for_status()
stopwords_jp = set(_stopwords_response.text.split("\n"))

def remove_stopwords(tokens: List[str]) -> List[str]:
    """Return the entries of ``tokens`` that are not in ``stopwords_jp``, in their original order."""
    kept: List[str] = []
    for word in tokens:
        if word in stopwords_jp:
            continue
        kept.append(word)
    return kept

In [None]:
# BM25 retriever used for title classification, restored from a dill pickle.
# NOTE(review): dill/pickle deserialization can execute arbitrary code — only
# load this file from a trusted source.
# NOTE(review): the pickled retriever presumably depends on `tokenizer` /
# `remove_stopwords` defined in the previous cell — confirm before reordering cells.
with open('./dbs/bm25.pickle', mode='rb') as f:
    bm25_retriever_for_title_classification = dill.load(f)

### FAISS for RAG

In [None]:
def insert_entity_in_metadata(faiss_db,entity_list):
    """Attach one entity annotation to every document held by a FAISS store.

    The i-th docstore id (in ``index_to_docstore_id`` iteration order) receives
    ``entity_list[i]`` under the ``"entity"`` metadata key.

    Args:
        faiss_db: Vector store whose documents are annotated in place.
        entity_list: One entry per indexed vector.

    Returns:
        The same ``faiss_db`` object, after mutation.

    Raises:
        AssertionError: If ``entity_list`` and the index differ in size.
    """
    doc_ids = list(faiss_db.index_to_docstore_id.values())
    assert len(entity_list) == faiss_db.index.ntotal
    for position, doc_id in enumerate(doc_ids):
        document = faiss_db.docstore._dict[doc_id]
        document.metadata["entity"] = entity_list[position]
    return faiss_db

In [None]:
# ベクトル検索
from langchain_community.vectorstores import FAISS
# Load the RAG FAISS indexes (faiss_1 ... faiss_7). For indexes 3 and 5 the
# pickled entity annotations are attached to the document metadata.
faiss_list_e5 = []
for index_no in range(1, 8):
    store = FAISS.load_local(
        folder_path = "./dbs/faiss_db_e5_v2",
        index_name = f"faiss_{index_no}",
        embeddings = embeddings_e5,
        allow_dangerous_deserialization = True
    )

    if index_no == 3 or index_no == 5:
        entity_path = f"./dbs/entity/{index_no}.pickle"
        assert os.path.isfile(entity_path)

        with open(entity_path, 'rb') as entity_file:
            entities = pickle.load(entity_file)

        # Write the entities into the store's document metadata.
        store = insert_entity_in_metadata(store, entities)

    faiss_list_e5.append(store)

In [None]:
# One retriever per RAG index, each asking for the top 50 hits.
faiss_retriever_e5_list = list(
    map(lambda store: store.as_retriever(search_kwargs = {"k":50}), faiss_list_e5)
)

### BM25 for RAG

In [None]:
import dill
save_folder = "./dbs/bm25_db_v2"

# Load the BM25 retrievers (bm25_1 ... bm25_7). For retrievers 3 and 5 the
# pickled entity annotations are attached to the document metadata.
bm25_retriever_list = []
for k in range(1,7+1):
    # NOTE(review): dill/pickle deserialization can execute arbitrary code —
    # only load these files from a trusted source.
    with open(f'{save_folder}/bm25_{k}.pickle', mode='rb') as f:
        bm25_retriever = dill.load(f)

    if k in [3,5]:
        save_filename = f"./dbs/entity/{k}.pickle"
        assert os.path.isfile(save_filename)

        with open(save_filename, 'rb') as file:
            entity_list = pickle.load(file)

        # Same size guard as insert_entity_in_metadata applies to the FAISS
        # stores: a length mismatch would silently mis-align entities and docs.
        assert len(entity_list) == len(bm25_retriever.docs), (
            f"bm25_{k}: {len(bm25_retriever.docs)} docs but {len(entity_list)} entities"
        )
        for doc, entity in zip(bm25_retriever.docs, entity_list):
            doc.metadata["entity"] = entity

    # Ask for the top 50 hits. Setting `k` stays best-effort, but a failure is
    # reported instead of being swallowed silently, and KeyboardInterrupt /
    # SystemExit are no longer caught.
    try:
        bm25_retriever.k = 50
    except Exception as exc:
        print(f"[warn] could not set k=50 on bm25_{k}: {exc!r}")
    bm25_retriever_list.append(bm25_retriever)

# Runnables

### Identify Novel Title Chain

In [None]:
# Keyword extraction for title identification: prompt -> main LLM -> parsed JSON.
# The parsed output is passed to pred_title_by_counts in identify_novel_title.
from components.novel_classify import novel_find_keyword_prompt,find_most_common,pred_title_by_counts,match_title,majority_vote
keyword_extraction_for_identify_novel_runnable = novel_find_keyword_prompt | llm | JsonOutputParser()

In [None]:
# Router: the small model returns a structured RouteQuery; create_chain reads
# its `datasource` field and handles the values "retrieve" and "count_string".
from components.router import RouteQuery,router_prompt
structured_llm_router = llm_4o_mini.with_structured_output(RouteQuery)
router_runnable = router_prompt | structured_llm_router

In [None]:
# Keyword counter: the small model turns the question into a search keyword
# (plain string); create_chain then counts it in the novel's full text with
# func_count_string_all_text.
from components.keyword_counter import make_keyword_prompt,func_count_string_all_text
make_keyword_runnable = make_keyword_prompt | llm_4o_mini | StrOutputParser()

### Self-Query

In [None]:
from components.metadata_filter import entity_filter_prompt,has_common_element,entiry_filter
# Entity filter: invoked with {"question": ...}; create_chain reads the keys
# "use_filter" and "entity_filter" from the parsed JSON.
entity_filter_runnable = entity_filter_prompt | llm_4o_mini | JsonOutputParser()

### Reranking

In [None]:
from langchain_cohere import CohereRerank
# Multilingual reranker. create_chain passes at most 20 documents per call, so
# top_n=1000 is presumably meant to keep every document and only reorder them
# — TODO confirm against the CohereRerank documentation.
cohere_reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n = 1000)

### Python Agent

In [None]:
from components.python_repl import python_prompt,create_python_repl_llm,get_python_repl_executed_results

### Summary

In [None]:
from components.summary import summary_answer_prompt
# Tries to answer from the novel summary alone. create_chain reads the keys
# "answer" and "reason" from the parsed JSON; an "answer" equal to "回答不可"
# sends the question to the retrieval / keyword-count paths instead.
summary_answer_runnable = summary_answer_prompt | llm | JsonOutputParser()

### RAG

In [None]:
# RAG chain: invoked with {"context": ..., "question": ...} and returns the
# answer as a plain string.
from components.utils import format_docs,reciprocal_rank_fusion
from components.rag import rag_prompt
rag_chain = rag_prompt | llm_rag | StrOutputParser()

### HyDE (Hypothetical Document Embeddings)

In [None]:
from components.hyde import hyde_prompt, get_novel_info
# HyDE: from the novel's title, author, summary, an example passage and the
# question, the small model writes a text that create_chain then uses as an
# additional FAISS query.
hyde_chain = hyde_prompt | llm_4o_mini | StrOutputParser()

### Document Sort and Refinement

In [None]:
from components.utils import sort_sequential_chunks,group_continuous_positions,merge_document_contents
# Post-processing for retrieved chunks, applied in this order:
# sort_sequential_chunks -> group_continuous_positions -> merge_document_contents.
refine_docs_runnable = (
    RunnableLambda(sort_sequential_chunks) | RunnableLambda(group_continuous_positions) | RunnableLambda(merge_document_contents)
)

### Finalize Answer

In [None]:
from components.finalize_answer import finalize_prompt
# Final step: invoked with {"question": ..., "answers": ...} (the formatted
# "- AgentN: ..." lines) and returns the final answer as a plain string.
finalize_chain = finalize_prompt | llm | StrOutputParser()

## Construct Graph

In [None]:
# Novel titles in file order: the title at 1-based position i maps to "<i>.txt".
novel_titles = ['流行暗殺節', '不如帰', 'カインの末裔', '競漕', '芽生', 'サーカスの怪人', '死生に関するいくつかの断想']
title2name = {
    title: f"{file_no}.txt" for file_no, title in enumerate(novel_titles, start=1)
}
    
def identify_novel_title(query):
    """Decide which novel a question refers to.

    If ``match_title`` finds a title in the query, that title is returned
    directly. Otherwise three candidates are produced — from the top-1 FAISS
    hit, from the top-1 BM25 hit, and from LLM-extracted keywords via
    ``pred_title_by_counts`` — and ``majority_vote`` picks the result.

    Args:
        query: The question text.

    Returns:
        The title returned by ``match_title`` or by ``majority_vote``.
    """
    # Fast path: match_title found a title ("該当なし" is its no-match value).
    matched_title = match_title(query)
    if matched_title != "該当なし":
        return matched_title

    # Candidate 1: top-1 FAISS hit (query carries the "query: " prefix).
    faiss_hits = faiss_retriever_e5_for_title_classification.invoke("query: "+query)
    faiss_vote = find_most_common(faiss_hits[:1])

    # Candidate 2: top-1 BM25 hit.
    bm25_hits = bm25_retriever_for_title_classification.invoke(query)
    bm25_vote = find_most_common(bm25_hits[:1])

    # Candidate 3: keywords extracted by the LLM, resolved by pred_title_by_counts.
    extracted_keywords = keyword_extraction_for_identify_novel_runnable.invoke(query)
    keyword_vote = pred_title_by_counts(extracted_keywords)

    # Decide by majority vote.
    return majority_vote([faiss_vote, bm25_vote, keyword_vote])
        
        
# Title classification, routing and answer generation for a single question
def _should_use_python_agent(query, title):
    """Return True when the question should be handled by the Python REPL agent.

    Only the novels 'カインの末裔' and '芽生' use the entity filter. For them the
    agent is triggered when the filter is enabled and names at least one of the
    entity types "COUNTRY" or "GEONAME".

    Raises:
        ValueError: If the filter's "entity_filter" value is neither a string
            nor a list.
    """
    if title not in ['カインの末裔','芽生']:
        return False

    entity_filter_out = entity_filter_runnable.invoke({
        "question":query
    })
    if not entity_filter_out["use_filter"]:
        return False

    entity = entity_filter_out["entity_filter"]
    if isinstance(entity,str):
        entity = [entity]
    elif not isinstance(entity,list):
        raise ValueError(f"unexpected entity_filter value: {entity!r}")

    return any(e in ["COUNTRY","GEONAME"] for e in entity)


def _answer_by_retrieval(query, title):
    """Answer via FAISS + BM25 + HyDE retrieval; return ``(answers, references)``."""
    # Position of the novel; retriever lists share this order, get_novel_info is 1-based.
    index = novel_titles.index(title)

    # Dense retrieval
    faiss_retriever = faiss_retriever_e5_list[index]
    faiss_docs = faiss_retriever.invoke("query: "+query)

    # Sparse retrieval
    bm25_docs = bm25_retriever_list[index].invoke(query)

    # Dense retrieval with the HyDE-generated text as the query
    title_,author,summary,example_text = get_novel_info(index+1)
    hyde_query = hyde_chain.invoke({
        "title":title_, "author":author, "summary":summary, "example":example_text, "question":query
    })
    hyde_faiss_docs = faiss_retriever.invoke("query: "+hyde_query)

    if _should_use_python_agent(query, title):
        # Hand the merged, refined documents to the Python REPL agent.
        docs_ = refine_docs_runnable.invoke([*faiss_docs, *bm25_docs, *hyde_faiss_docs])
        repl_tool,python_repl_runnable = create_python_repl_llm(local_val=format_docs(docs_),llm=llm)
        tool_call_response = python_repl_runnable.invoke({"query":query})
        merge_context = get_python_repl_executed_results(tool_call_response,repl_tool)
    else:
        # Rerank the first 20 hits of each retriever.
        faiss_docs         = cohere_reranker.compress_documents(faiss_docs[:20],query)
        bm25_docs          = cohere_reranker.compress_documents(bm25_docs[:20],query)
        hyde_faiss_docs    = cohere_reranker.compress_documents(hyde_faiss_docs[:20],query)

        # Reciprocal rank fusion over the top 10 of each list; keep the best 10.
        merge_docs = reciprocal_rank_fusion([faiss_docs[:10],bm25_docs[:10],hyde_faiss_docs[:10]])
        merge_docs = merge_docs[:10]

        # Format for the RAG prompt.
        merge_context = format_docs(merge_docs)

    # Generate the answer from the collected context.
    merge_answer = rag_chain.invoke({"context": merge_context, "question": query})

    # Keep answer and context together for logging.
    references = [merge_answer+"<><><>"+merge_context]
    # Shape expected by the finalize chain.
    answers = f"- Agent1: {merge_answer}"
    return answers, references


def _answer_by_keyword_count(query, title):
    """Answer by counting a generated keyword in the novel text; return ``(answers, references)``."""
    # Build the search keyword.
    keyword = make_keyword_runnable.invoke(query)
    # Full text of the novel to search in.
    text = texts_dic[title2name[title]]["content"]
    # Number of hits.
    n_keyword = func_count_string_all_text(keyword,text)
    return f"- Agent1: {n_keyword}", [n_keyword]


def create_chain(inputs):
    """Answer one question about the novels.

    Steps: identify the novel title, try to answer from the novel summary, and
    only if the summary answer is "回答不可" ask the router which fallback to use
    ("retrieve" or "count_string"). The collected answer is then passed to the
    finalize chain.

    Args:
        inputs: Mapping with the question under the key "query".

    Returns:
        dict with the keys "query", "title", "answers", "final_answer" and
        "references".

    Raises:
        KeyError: If the identified title is not one of the known novels.
        NotImplementedError: If the router returns an unsupported datasource.
        ValueError: If the entity filter returns an unexpected value type.
    """
    # Question (query)
    query = inputs["query"]

    # Identify the title
    title = identify_novel_title(query)
    if title not in title2name:
        raise KeyError(f"unknown novel title: {title!r}")

    # Answer from the summary first (saves tokens)
    summary_answer = summary_answer_runnable.invoke({
        "title":title,
        "summary":novels_dic[title2name[title]]["summary"],
        "question":query
    })

    if summary_answer["answer"] != "回答不可":
        # The summary was enough.
        answers = f"- Agent1: {summary_answer['answer']}"
        references = summary_answer["reason"]
    else:
        # Routing is only needed on this path, so the router LLM call is made
        # here rather than for every question.
        route = router_runnable.invoke(query)
        if route.datasource == "retrieve":
            answers, references = _answer_by_retrieval(query, title)
        elif route.datasource == "count_string":
            answers, references = _answer_by_keyword_count(query, title)
        else:
            raise NotImplementedError(f"unsupported route: {route.datasource!r}")

    # Build the final answer
    final_answer = finalize_chain.invoke({
        "question": query, "answers": answers,
    })

    return {
        "query": query,
        "title": title,
        "answers": answers,
        "final_answer": final_answer,
        "references": references
    }

In [None]:
# Wrap the whole pipeline (create_chain) as a single Runnable.
main_chain = RunnableLambda(create_chain)

# For debugging: runs only the title-identification step.
title_chain = RunnableLambda(identify_novel_title)

In [None]:
# from IPython.display import Image, display
# display(Image(main_chain.get_graph().draw_mermaid_png()))

## Execute the whole RAG Runnable

In [None]:
# Batch processing: run the whole pipeline once per question in the sheet.
answers = main_chain.batch([{"query": problem} for problem in df["problem"]])

In [None]:
# Final results: per-question agent answers, final answers, and references.
first_answers = [result["answers"] for result in answers]
final_answers = [result["final_answer"] for result in answers]
# References are stringified and stripped of real and escaped newlines so each
# one stays on a single line.
references    = [
    str(result["references"]).replace("\n","").replace("\\n","")
    for result in answers
]

In [None]:
import tiktoken
def get_num_tokens(text):
    """Return how many tokens ``text`` has under the tiktoken encoding for 'gpt-4o-2024-08-06'."""
    return len(tiktoken.encoding_for_model('gpt-4o-2024-08-06').encode(text))

# Ten largest token counts among the final answers, in descending order.
sorted(map(get_num_tokens, final_answers), reverse=True)[:10]

In [None]:
# Add the final answers and the logged references as new columns.
# This mutates `df` in place.
df["予測"] = final_answers
df["コンテキスト"] = references

In [None]:
# Output path of the submission file (the timestamp part is a placeholder).
# NOTE(review): ":" is not a valid filename character on Windows — confirm the
# target platform before filling in the timestamp.
save_file = "../evaluation/submit/2024xxxx-xx:xx.csv"
# Refuse to overwrite an existing submission.
assert not os.path.exists(save_file)
# Write the three columns without the DataFrame index and without a header row.
df[["index","予測","コンテキスト"]].to_csv(save_file,index=False,header=False)