In [8]:
import json
import os
import datasets
import tqdm
import MeCab
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import CountVectorizer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import random

In [20]:
# Queries and their positive passages: the "dev" split of the "ja" configuration
# of miracl/miracl. The access token is read from the HF_ACCESS_TOKEN environment
# variable; os.environ[...] raises KeyError if that variable is not set.
full_ds = datasets.load_dataset(
    "miracl/miracl", "ja", use_auth_token=os.environ["HF_ACCESS_TOKEN"], split="dev"
)

# All corpus texts for the "ja" configuration. No split is selected here, so
# later cells index the result with ["train"].
full_corpus = datasets.load_dataset("miracl/miracl-corpus", "ja")



In [7]:
# Shared MeCab tagger created with the "-Owakati" option; tokenize_jp splits the
# string it returns on whitespace to obtain tokens.
mecab = MeCab.Tagger("-Owakati")

def tokenize_jp(text: str) -> list[str]:
    """Tokenize text with the notebook-level MeCab tagger.

    Args:
        text: Raw text to tokenize.

    Returns:
        The tokens as a list of strings, i.e. the tagger output split on
        whitespace (which also drops any trailing newline).
    """
    # mecab.parse returns a single string; str.split() turns it into the token list.
    return mecab.parse(text).split()

In [17]:
# Number of rows in the "train" split of the loaded corpus.
len(full_corpus["train"])

6953614

In [24]:
def get_corpus(query_size, corpus_size):
    """Sample queries and build two equally sized, tokenized document corpora.

    Reads the notebook-level ``full_ds`` and ``full_corpus`` and tokenizes every
    query and document text with ``tokenize_jp``. Sampling uses the global
    ``random`` module state.

    Args:
        query_size: Number of queries to draw at random from ``full_ds``.
        corpus_size: Number of documents in each of the two returned corpora.

    Returns:
        A tuple ``(ds, query_texts, train_corpus, test_corpus)`` where

        * ``ds`` is the sampled subset of ``full_ds``;
        * ``query_texts`` holds the tokenized query of each row of ``ds``,
          in the same order;
        * ``train_corpus`` is ``corpus_size`` randomly drawn documents, none
          of which is a positive passage of a sampled query;
        * ``test_corpus`` is every distinct positive passage of the sampled
          queries followed by randomly drawn non-positive documents (disjoint
          from ``train_corpus``) so that the total is ``corpus_size``.

        Each document is a dict ``{"docid": ..., "text": <token list>}``.

    Raises:
        AssertionError: If either corpus does not contain exactly
            ``corpus_size`` documents, e.g. when the sampled queries have more
            distinct positive passages than ``corpus_size``.
    """
    # Randomly select the queries.
    ds = full_ds.select(random.sample(range(len(full_ds)), query_size))

    positive_corpus_json = []
    query_texts = []
    done_docids = set()
    for item in ds:
        query_texts.append(tokenize_jp(item["query"]))
        for pp in item["positive_passages"]:
            docid = pp["docid"]
            if docid in done_docids:
                # The passage is already in the corpus via another query.
                continue
            # Record the docid so a passage shared by several sampled queries is
            # added only once (the guard above is ineffective without this).
            done_docids.add(docid)
            positive_corpus_json.append({
                "text": tokenize_jp(pp["text"]),
                "docid": docid
            })
    positive_docids = set(done_docids)

    # Randomly select the corpus documents. Oversample by the number of
    # positives so that at least 2 * corpus_size documents remain even if every
    # positive happens to be drawn and is filtered out.
    max_corpus_size = corpus_size*2 + len(positive_docids)
    train_split = full_corpus["train"]
    sampled_indices = random.sample(range(len(train_split)), max_corpus_size)
    corpus_without_positive = train_split.select(sampled_indices).filter(
        lambda x: x["docid"] not in positive_docids
    )
    corpus_without_positive_json = [
        {"docid": doc["docid"], "text": tokenize_jp(doc["text"])}
        for doc in corpus_without_positive
    ]

    # The first corpus_size negatives form the train corpus; the test corpus is
    # the positives topped up with the following negatives.
    train_corpus = corpus_without_positive_json[:corpus_size]
    test_corpus = positive_corpus_json + corpus_without_positive_json[corpus_size:corpus_size*2-len(positive_corpus_json)]
    assert len(test_corpus) == corpus_size
    assert len(train_corpus) == corpus_size
    return ds, query_texts, train_corpus, test_corpus


In [26]:
CORPUS_SIZE = 500
QUERY_SIZE = 30
N_TRIALS = 100  # number of independent resampling runs
TOP_N = 5       # cutoff n used for recall@n


def calc_result(test_corpus, vectorizer, n=5, queries=None, query_ds=None):
    """Compute the mean recall@n of the sampled queries against ``test_corpus``.

    Queries and documents are vectorized with ``vectorizer``; documents are
    ranked per query by descending cosine similarity. For each query, the score
    is the fraction of its positive passages present in ``test_corpus`` that are
    ranked within the top ``n``. The result is the mean of these per-query
    scores.

    Args:
        test_corpus: List of ``{"docid": ..., "text": <token list>}`` documents
            to retrieve from.
        vectorizer: A fitted vectorizer exposing ``transform``.
        n: Rank cutoff.
        queries: Tokenized queries, aligned row by row with ``query_ds``.
            Defaults to the notebook-level ``query_texts``.
        query_ds: Query rows providing ``query_id`` and ``positive_passages``.
            Defaults to the notebook-level ``ds``.

    Returns:
        The mean over queries of the per-query recall@n.
    """
    # Fall back to the notebook-level variables so existing calls that pass only
    # (test_corpus, vectorizer, n) keep working.
    if queries is None:
        queries = query_texts
    if query_ds is None:
        query_ds = ds

    test_matrix = vectorizer.transform([doc["text"] for doc in test_corpus])
    query_matrix = vectorizer.transform(queries)

    # Similarity matrix (queries x documents); each row of ranking_matrix lists
    # document indices from most to least similar.
    similarity_matrix = cosine_similarity(query_matrix, test_matrix)
    ranking_matrix = np.argsort(similarity_matrix, axis=1)[:, ::-1]

    test_docid2indice = {item["docid"]: i for i, item in enumerate(test_corpus)}

    query_result = []
    for item, ranking in zip(query_ds, ranking_matrix):
        # Invert the ranking once per query: rank_of_doc[doc_index] is the
        # position of that document in the ranking. This replaces a list
        # conversion plus linear search for every positive passage.
        rank_of_doc = np.empty_like(ranking)
        rank_of_doc[ranking] = np.arange(ranking.size)
        ranks = [
            int(rank_of_doc[test_docid2indice[pp["docid"]]])
            for pp in item["positive_passages"]
            if pp["docid"] in test_docid2indice
        ]
        query_result.append({
            "query_id": item["query_id"],
            "ranks": ranks
        })

    # recall@n: per query, the share of its positives ranked above the cutoff;
    # then averaged over queries.
    recall_at_n = np.mean([np.mean([1 if rank < n else 0 for rank in item["ranks"]]) for item in query_result])
    return recall_at_n


train_recalls = []
test_recalls = []
for _ in tqdm.tqdm(range(N_TRIALS)):
    ds, query_texts, train_corpus, test_corpus = get_corpus(QUERY_SIZE, CORPUS_SIZE)

    # Both vectorizers share one fixed vocabulary (every token seen in the
    # queries and in both corpora), so they differ only in the corpus the IDF
    # weights are fitted on.
    full_vocabulary = set()
    for query_text in query_texts:
        full_vocabulary.update(query_text)
    for doc in train_corpus + test_corpus:
        full_vocabulary.update(doc["text"])

    # Texts are already token lists, so the analyzer is the identity.
    train_vectorizer = TfidfVectorizer(analyzer=lambda x: x, vocabulary=full_vocabulary)
    train_vectorizer.fit([x["text"] for x in train_corpus])

    test_vectorizer = TfidfVectorizer(analyzer=lambda x: x, vocabulary=full_vocabulary)
    test_vectorizer.fit([x["text"] for x in test_corpus])

    # Retrieval always runs over test_corpus; only the fitting corpus differs.
    train_recall = calc_result(test_corpus, train_vectorizer, TOP_N, queries=query_texts, query_ds=ds)
    test_recall = calc_result(test_corpus, test_vectorizer, TOP_N, queries=query_texts, query_ds=ds)

    train_recalls.append(train_recall)
    test_recalls.append(test_recall)

# Compute the mean and standard deviation over all runs.
train_recall_mean = np.mean(train_recalls)
train_recall_std = np.std(train_recalls)
test_recall_mean = np.mean(test_recalls)
test_recall_std = np.std(test_recalls)
# Display the results.
print(f"train recall@{TOP_N}: {train_recall_mean} ± {train_recall_std}")
print(f"test recall@{TOP_N}: {test_recall_mean} ± {test_recall_std}")

  0%|          | 0/100 [00:00<?, ?it/s]

Filter:   0%|          | 0/1073 [00:00<?, ? examples/s]

  1%|          | 1/100 [00:01<02:51,  1.73s/it]

Filter:   0%|          | 0/1055 [00:00<?, ? examples/s]

  2%|▏         | 2/100 [00:02<02:16,  1.39s/it]

Filter:   0%|          | 0/1079 [00:00<?, ? examples/s]

  3%|▎         | 3/100 [00:04<02:08,  1.33s/it]

Filter:   0%|          | 0/1049 [00:00<?, ? examples/s]

  4%|▍         | 4/100 [00:05<01:56,  1.21s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

  5%|▌         | 5/100 [00:06<01:48,  1.15s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

  6%|▌         | 6/100 [00:07<01:48,  1.15s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

  7%|▋         | 7/100 [00:08<01:45,  1.13s/it]

Filter:   0%|          | 0/1053 [00:00<?, ? examples/s]

  8%|▊         | 8/100 [00:09<01:44,  1.14s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

  9%|▉         | 9/100 [00:10<01:39,  1.09s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 10%|█         | 10/100 [00:11<01:35,  1.06s/it]

Filter:   0%|          | 0/1054 [00:00<?, ? examples/s]

 11%|█         | 11/100 [00:12<01:33,  1.05s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

 12%|█▏        | 12/100 [00:13<01:31,  1.05s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 13%|█▎        | 13/100 [00:14<01:27,  1.01s/it]

Filter:   0%|          | 0/1068 [00:00<?, ? examples/s]

 14%|█▍        | 14/100 [00:15<01:28,  1.03s/it]

Filter:   0%|          | 0/1058 [00:00<?, ? examples/s]

 15%|█▌        | 15/100 [00:16<01:25,  1.00s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 16%|█▌        | 16/100 [00:17<01:25,  1.02s/it]

Filter:   0%|          | 0/1063 [00:00<?, ? examples/s]

 17%|█▋        | 17/100 [00:18<01:26,  1.04s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 18%|█▊        | 18/100 [00:19<01:23,  1.02s/it]

Filter:   0%|          | 0/1067 [00:00<?, ? examples/s]

 19%|█▉        | 19/100 [00:20<01:28,  1.09s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 20%|██        | 20/100 [00:22<01:33,  1.17s/it]

Filter:   0%|          | 0/1086 [00:00<?, ? examples/s]

 21%|██        | 21/100 [00:23<01:32,  1.17s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

 22%|██▏       | 22/100 [00:24<01:38,  1.26s/it]

Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

 23%|██▎       | 23/100 [00:26<01:44,  1.36s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 24%|██▍       | 24/100 [00:28<01:47,  1.42s/it]

Filter:   0%|          | 0/1064 [00:00<?, ? examples/s]

 25%|██▌       | 25/100 [00:29<01:42,  1.37s/it]

Filter:   0%|          | 0/1067 [00:00<?, ? examples/s]

 26%|██▌       | 26/100 [00:30<01:36,  1.30s/it]

Filter:   0%|          | 0/1068 [00:00<?, ? examples/s]

 27%|██▋       | 27/100 [00:31<01:37,  1.34s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 28%|██▊       | 28/100 [00:33<01:32,  1.29s/it]

Filter:   0%|          | 0/1069 [00:00<?, ? examples/s]

 29%|██▉       | 29/100 [00:34<01:28,  1.25s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 30%|███       | 30/100 [00:35<01:29,  1.28s/it]

Filter:   0%|          | 0/1058 [00:00<?, ? examples/s]

 31%|███       | 31/100 [00:36<01:26,  1.25s/it]

Filter:   0%|          | 0/1077 [00:00<?, ? examples/s]

 32%|███▏      | 32/100 [00:38<01:25,  1.26s/it]

Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

 33%|███▎      | 33/100 [00:39<01:18,  1.17s/it]

Filter:   0%|          | 0/1047 [00:00<?, ? examples/s]

 34%|███▍      | 34/100 [00:40<01:13,  1.12s/it]

Filter:   0%|          | 0/1080 [00:00<?, ? examples/s]

 35%|███▌      | 35/100 [00:41<01:15,  1.17s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 36%|███▌      | 36/100 [00:42<01:11,  1.12s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 37%|███▋      | 37/100 [00:43<01:14,  1.19s/it]

Filter:   0%|          | 0/1082 [00:00<?, ? examples/s]

 38%|███▊      | 38/100 [00:44<01:14,  1.20s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 39%|███▉      | 39/100 [00:45<01:09,  1.14s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 40%|████      | 40/100 [00:47<01:08,  1.14s/it]

Filter:   0%|          | 0/1052 [00:00<?, ? examples/s]

 41%|████      | 41/100 [00:48<01:06,  1.13s/it]

Filter:   0%|          | 0/1063 [00:00<?, ? examples/s]

 42%|████▏     | 42/100 [00:49<01:04,  1.11s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 43%|████▎     | 43/100 [00:50<01:04,  1.13s/it]

Filter:   0%|          | 0/1076 [00:00<?, ? examples/s]

 44%|████▍     | 44/100 [00:51<01:01,  1.10s/it]

Filter:   0%|          | 0/1056 [00:00<?, ? examples/s]

 45%|████▌     | 45/100 [00:52<01:02,  1.13s/it]

Filter:   0%|          | 0/1073 [00:00<?, ? examples/s]

 46%|████▌     | 46/100 [00:53<01:05,  1.21s/it]

Filter:   0%|          | 0/1068 [00:00<?, ? examples/s]

 47%|████▋     | 47/100 [00:55<01:01,  1.17s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 48%|████▊     | 48/100 [00:56<00:57,  1.11s/it]

Filter:   0%|          | 0/1077 [00:00<?, ? examples/s]

 49%|████▉     | 49/100 [00:57<00:58,  1.15s/it]

Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

 50%|█████     | 50/100 [00:58<00:55,  1.12s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 51%|█████     | 51/100 [00:59<00:52,  1.08s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 52%|█████▏    | 52/100 [01:00<00:52,  1.08s/it]

Filter:   0%|          | 0/1052 [00:00<?, ? examples/s]

 53%|█████▎    | 53/100 [01:01<00:50,  1.08s/it]

Filter:   0%|          | 0/1077 [00:00<?, ? examples/s]

 54%|█████▍    | 54/100 [01:02<00:48,  1.05s/it]

Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

 55%|█████▌    | 55/100 [01:03<00:48,  1.08s/it]

Filter:   0%|          | 0/1064 [00:00<?, ? examples/s]

 56%|█████▌    | 56/100 [01:04<00:47,  1.07s/it]

Filter:   0%|          | 0/1051 [00:00<?, ? examples/s]

 57%|█████▋    | 57/100 [01:05<00:45,  1.05s/it]

Filter:   0%|          | 0/1052 [00:00<?, ? examples/s]

 58%|█████▊    | 58/100 [01:06<00:47,  1.13s/it]

Filter:   0%|          | 0/1055 [00:00<?, ? examples/s]

 59%|█████▉    | 59/100 [01:07<00:44,  1.09s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 60%|██████    | 60/100 [01:08<00:42,  1.05s/it]

Filter:   0%|          | 0/1072 [00:00<?, ? examples/s]

 61%|██████    | 61/100 [01:10<00:42,  1.08s/it]

Filter:   0%|          | 0/1056 [00:00<?, ? examples/s]

 62%|██████▏   | 62/100 [01:11<00:40,  1.06s/it]

Filter:   0%|          | 0/1058 [00:00<?, ? examples/s]

 63%|██████▎   | 63/100 [01:12<00:39,  1.08s/it]

Filter:   0%|          | 0/1050 [00:00<?, ? examples/s]

 64%|██████▍   | 64/100 [01:13<00:39,  1.11s/it]

Filter:   0%|          | 0/1054 [00:00<?, ? examples/s]

 65%|██████▌   | 65/100 [01:14<00:37,  1.06s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 66%|██████▌   | 66/100 [01:15<00:36,  1.07s/it]

Filter:   0%|          | 0/1066 [00:00<?, ? examples/s]

 67%|██████▋   | 67/100 [01:16<00:36,  1.10s/it]

Filter:   0%|          | 0/1077 [00:00<?, ? examples/s]

 68%|██████▊   | 68/100 [01:17<00:34,  1.08s/it]

Filter:   0%|          | 0/1058 [00:00<?, ? examples/s]

 69%|██████▉   | 69/100 [01:18<00:33,  1.09s/it]

Filter:   0%|          | 0/1081 [00:00<?, ? examples/s]

 70%|███████   | 70/100 [01:19<00:33,  1.12s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 71%|███████   | 71/100 [01:20<00:31,  1.08s/it]

Filter:   0%|          | 0/1061 [00:00<?, ? examples/s]

 72%|███████▏  | 72/100 [01:21<00:29,  1.05s/it]

Filter:   0%|          | 0/1054 [00:00<?, ? examples/s]

 73%|███████▎  | 73/100 [01:23<00:29,  1.09s/it]

Filter:   0%|          | 0/1051 [00:00<?, ? examples/s]

 74%|███████▍  | 74/100 [01:24<00:28,  1.08s/it]

Filter:   0%|          | 0/1063 [00:00<?, ? examples/s]

 75%|███████▌  | 75/100 [01:25<00:26,  1.05s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 76%|███████▌  | 76/100 [01:26<00:25,  1.07s/it]

Filter:   0%|          | 0/1061 [00:00<?, ? examples/s]

 77%|███████▋  | 77/100 [01:27<00:24,  1.06s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 78%|███████▊  | 78/100 [01:28<00:23,  1.08s/it]

Filter:   0%|          | 0/1060 [00:00<?, ? examples/s]

 79%|███████▉  | 79/100 [01:29<00:23,  1.13s/it]

Filter:   0%|          | 0/1067 [00:00<?, ? examples/s]

 80%|████████  | 80/100 [01:30<00:22,  1.10s/it]

Filter:   0%|          | 0/1071 [00:00<?, ? examples/s]

 81%|████████  | 81/100 [01:31<00:20,  1.10s/it]

Filter:   0%|          | 0/1055 [00:00<?, ? examples/s]

 82%|████████▏ | 82/100 [01:32<00:19,  1.07s/it]

Filter:   0%|          | 0/1055 [00:00<?, ? examples/s]

 83%|████████▎ | 83/100 [01:33<00:17,  1.05s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 84%|████████▍ | 84/100 [01:34<00:17,  1.09s/it]

Filter:   0%|          | 0/1057 [00:00<?, ? examples/s]

 85%|████████▌ | 85/100 [01:36<00:16,  1.08s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 86%|████████▌ | 86/100 [01:37<00:14,  1.05s/it]

Filter:   0%|          | 0/1071 [00:00<?, ? examples/s]

 87%|████████▋ | 87/100 [01:38<00:15,  1.21s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

 88%|████████▊ | 88/100 [01:40<00:17,  1.42s/it]

Filter:   0%|          | 0/1062 [00:00<?, ? examples/s]

 89%|████████▉ | 89/100 [01:42<00:16,  1.51s/it]

Filter:   0%|          | 0/1053 [00:00<?, ? examples/s]

 90%|█████████ | 90/100 [01:43<00:15,  1.54s/it]

Filter:   0%|          | 0/1067 [00:00<?, ? examples/s]

 91%|█████████ | 91/100 [01:45<00:13,  1.55s/it]

Filter:   0%|          | 0/1052 [00:00<?, ? examples/s]

 92%|█████████▏| 92/100 [01:46<00:12,  1.55s/it]

Filter:   0%|          | 0/1069 [00:00<?, ? examples/s]

 93%|█████████▎| 93/100 [01:48<00:10,  1.54s/it]

Filter:   0%|          | 0/1073 [00:00<?, ? examples/s]

 94%|█████████▍| 94/100 [01:49<00:08,  1.37s/it]

Filter:   0%|          | 0/1059 [00:00<?, ? examples/s]

 95%|█████████▌| 95/100 [01:50<00:06,  1.26s/it]

Filter:   0%|          | 0/1076 [00:00<?, ? examples/s]

 96%|█████████▌| 96/100 [01:51<00:04,  1.23s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

 97%|█████████▋| 97/100 [01:52<00:03,  1.17s/it]

Filter:   0%|          | 0/1055 [00:00<?, ? examples/s]

 98%|█████████▊| 98/100 [01:53<00:02,  1.11s/it]

Filter:   0%|          | 0/1080 [00:00<?, ? examples/s]

 99%|█████████▉| 99/100 [01:54<00:01,  1.11s/it]

Filter:   0%|          | 0/1065 [00:00<?, ? examples/s]

100%|██████████| 100/100 [01:55<00:00,  1.16s/it]

train recall@5: 0.8925901996151994 ± 0.04391918850477444
test recall@5: 0.8848538119288117 ± 0.046527230079738975



