## import

In [3]:
import os
import warnings

# Silence warnings before the third-party imports below so that warnings raised
# while importing them are hidden as well.
warnings.simplefilter('ignore')

import json
import random
import re
from functools import partial
from getpass import getpass

import numpy as np
import pandas as pd
from tqdm import tqdm

# Credentials: never hard-code the API key. Reuse the value already exported in
# the environment and prompt only when it is missing, so a real key is never
# overwritten by a placeholder.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OPENAI_API_KEY: ")

# Reproducibility: seed every RNG module imported here.
seed = 0
random.seed(seed)
np.random.seed(seed)

# Show DataFrames in full (no column, row or cell-width truncation).
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option("display.max_colwidth", None)

In [None]:
# from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain_openai import OpenAIEmbeddings
# OpenAI embedding model; referenced by the (commented-out) FAISS cells below.
embedding_model = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    # chunk_size=1000,  # optional override, currently left at the library default
)

## 前処理

In [124]:
from preprocess import preprocessing,remove_stopwords

In [125]:
# `List` must be in scope when the function below is defined, because its
# return annotation is evaluated at definition time.
from typing import List

from janome.tokenizer import Tokenizer

# Top-level Janome part-of-speech categories kept for BM25 indexing:
# nouns, verbs and adjectives.
POS_TO_KEEP = ("名詞", "動詞", "形容詞")

# One shared tokenizer instance, reused by every call.
tokenizer = Tokenizer()


def preprocess_jp_with_morph(text: str) -> List[str]:
    """Split Japanese text into content words for BM25.

    Runs Janome morphological analysis, keeps the surface form of every token
    whose top-level part of speech is in ``POS_TO_KEEP``, then removes stop
    words with ``remove_stopwords``.

    Args:
        text: Japanese text. Newlines are not removed beforehand.

    Returns:
        The filtered token surface forms, i.e. whatever ``remove_stopwords``
        returns (presumably a list of str — confirm in ``preprocess``).
    """
    # part_of_speech is comma-separated; its first field is the top-level category.
    kept_words = [
        token.surface
        for token in tokenizer.tokenize(text)
        if token.part_of_speech.split(",")[0] in POS_TO_KEEP
    ]
    return remove_stopwords(kept_words)

In [126]:
from langchain_core.documents import Document
def series2doc(data):
    """Convert one reference-table row into a LangChain ``Document``.

    The row's ``Reference`` text becomes the page content and its ``ID`` is
    stored under the ``id`` metadata key, which the evaluation later reads
    back (as a string) to compare against the ground-truth IDs.
    """
    content = data["Reference"]
    doc_id = data["ID"]
    return Document(page_content=content, metadata={"id": doc_id})

## Retrieverの準備

In [209]:
from langchain.retrievers.bm25 import BM25Retriever


def build_bm25_retriever(docs, k: int = 50) -> BM25Retriever:
    """Build a BM25 retriever over ``docs`` with Janome-based tokenisation.

    Text is tokenised by ``preprocess_jp_with_morph`` (nouns, verbs and
    adjectives only, stop words removed).

    Args:
        docs: Iterable of LangChain ``Document`` objects to index.
        k: Number of documents each query returns.

    Returns:
        The configured ``BM25Retriever``.
    """
    retriever = BM25Retriever.from_documents(
        docs,
        preprocess_func=preprocess_jp_with_morph,
    )
    retriever.k = k
    return retriever


# `documents` is only created by the evaluation loop further down, so it does
# not exist on a fresh kernel. Build the standalone retriever only when it is
# available instead of failing with NameError on Restart & Run All.
if "documents" in globals():
    bm25_retriever = build_bm25_retriever(documents)

In [210]:
# query = 'パスワードを忘れた'
# docs = bm25_retriever.invoke(input=query)
# docs[0]

## データセット読み込み（動作確認としてBM25で実行）

In [None]:
from langchain.retrievers.bm25 import BM25Retriever
from evaluation import average_recall_at_k

In [236]:
# Root directory containing one sub-directory per evaluation dataset.
# Alternative location: "./RAG評価データセット/dataset"
data_dir = "./dataset"

# Datasets to evaluate; each sub-directory holds QA.csv and Reference.csv.
# Currently disabled: "01_Abema", "02_Tokyo", "04_Amagasaki",
# and "05_mr-tydi" (the data is too large).
data_list = ["03_Wikipedia"]

In [None]:
import ast

# Cut-offs k at which average recall@k is reported. The retriever returns
# max(k_list) documents per query so every cut-off can be evaluated.
k_list = [1, 3, 5, 10, 15, 20, 50]


def _to_gt_id_list(value) -> list:
    """Normalise one ground-truth ``ID`` cell to a list of string IDs.

    Cells holding several correct IDs are stored like ``"[3, 7]"`` and are
    parsed with ``ast.literal_eval`` (literals only, no code execution); any
    other value becomes a one-element list. Every ID is converted to ``str``
    so it compares equal to the stringified ``doc.metadata['id']`` values.
    """
    text = str(value)
    if text.lstrip().startswith("["):
        return [str(item) for item in ast.literal_eval(text.strip())]
    return [text]


def _recall_table(y_true_list, y_pred_list, k_values) -> pd.DataFrame:
    """Return a one-column ("value") frame of average recall@k, rounded to 3 places."""
    index_list = [f"Average recall@{k}" for k in k_values]
    recall_list = [
        round(average_recall_at_k(y_true_list, y_pred_list, k), 3)
        for k in k_values
    ]
    return pd.DataFrame(recall_list, index=index_list, columns=["value"])


for data_path in data_list:
    print(f"TEST on {data_path} begins.")

    # 1. Load the QA pairs and the reference corpus.
    print("1. Reading Data...")
    df_qa = pd.read_csv(os.path.join(data_dir, data_path, "QA.csv"), index_col=0)
    df_ref = pd.read_csv(os.path.join(data_dir, data_path, "Reference.csv"), index_col=0)

    # Normalise the ground truth of every row (not just the first), stringifying
    # list elements so they match the retriever's string IDs.
    is_multi_id = df_qa["ID"].astype(str).str.lstrip().str.startswith("[")
    if is_multi_id.any():
        print("適用")  # at least one row has several correct IDs
    gt_id_lists = df_qa["ID"].map(_to_gt_id_list)

    # Reference documents to index.
    documents = df_ref.apply(series2doc, axis=1)
    print(f"   N of Docs: {len(documents)}")

    if df_qa.empty:
        print("   No queries; skipping evaluation.")
        continue

    # 2. Retriever under test — replace this section to evaluate another retriever.
    print("2. Building Retriever...")
    # Morphological tokenisation keeping nouns, verbs and adjectives.
    retriever = BM25Retriever.from_documents(
        documents,
        preprocess_func=preprocess_jp_with_morph,
    )
    retriever.k = max(k_list)

    # 3. Retrieve for every query, then score once over all queries.
    print("3. Evaluating...")
    y_true_list = gt_id_lists.tolist()
    y_pred_list = [
        [str(doc.metadata["id"]) for doc in retriever.invoke(input=query)]
        for query in tqdm(df_qa["Query"])
    ]

    df_result = _recall_table(y_true_list, y_pred_list, k_list)
    display(df_result)

## memo (scratch cells from earlier experiments)

In [None]:
# from functools import partial
# from langchain.retrievers.bm25 import BM25Retriever

# # シノニム拡張なしa
# bm25_retriever = BM25Retriever.from_texts(documents,preprocess_func=partial(preprocess_jp, sysnonim_extention=False))
# bm25_retriever.k = 10
# docs_without_sysnonim_extention = bm25_retriever.get_relevant_documents(query)

# # シノニム拡張あり
# bm25_retriever = BM25Retriever.from_texts(documents,preprocess_func=partial(preprocess_jp, sysnonim_extention=False))
# # 検索用のシノニム拡張ありverの前処理を設定
# bm25_retriever.preprocess_func = partial(preprocess_jp, sysnonim_extention=True) 
# bm25_retriever.k = 10
# docs_with_sysnonim_extention = bm25_retriever.get_relevant_documents(query)

In [4]:
# from langchain.vectorstores import FAISS
# db = FAISS.from_documents(docs, embedding_model)
# db.save_local("./2030JapanDigitalInnovation_byMcKinsey")

In [10]:
# from langchain.vectorstores import FAISS
# db = FAISS.load_local("./2030JapanDigitalInnovation_byMcKinsey", embedding_model)
# faiss_retriever = db.as_retriever()

In [244]:
# faiss_retriever_small = db_512.as_retriever() # .as_retriever(search_kwargs={"k": 7})
# faiss_retriever_small.search_kwargs['k'] = 7

In [None]:
# import MeCab
# def preprocess_jp(text: str) -> List[str]:
#     """日本語テキストを前処理してトークンリストを返す"""

#     # 改行コードの除去
#     text = text.replace("\n", "")

#     # Mecabで特定の品詞のみのトークンリストを作成
#     pos_list = ["名詞", "動詞", "形容詞"]
#     tagger = MeCab.Tagger()
#     node = tagger.parseToNode(text)

#     word_list = []
#     while node:
#         pos = node.feature.split(",")[0]
#         if pos in pos_list:
#             word = node.surface
#             word_list.append(word)
#         node = node.next
#     return remove_stopwords(word_list)

In [103]:
# FIXME(review): this cell fails on a fresh kernel. `BM25Retriever_`, `texts`,
# `metadata` and `preprocess_jp` are not defined by any active cell visible in
# this notebook (`preprocess_jp` only appears in the commented-out MeCab cell).
# Presumably it should use `BM25Retriever` and `preprocess_jp_with_morph` with
# texts/metadata from a preceding cell — confirm before re-enabling.
# bm25_search = BM25Retriever(metadata=metadata,docs=docs,preprocess_func=preprocess_jp)  # this form would also work.
bm25_retriever = BM25Retriever_.from_texts(texts,metadata,preprocess_func=preprocess_jp)
# Number of documents returned per query.
bm25_retriever.k = 5

In [141]:
# bm25_retriever.get_relevant_documents("マッキンゼー社とはについて解説してください。")