In [1]:
import os
from rank_bm25 import BM25Okapi
import numpy as np
import pickle
from tqdm import tqdm
from konlpy.tag import Mecab
import faiss

import torch
from torch import nn, optim
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Raw string: the path contains "\m", which is not a valid escape sequence in a
# normal string literal and triggers a DeprecationWarning/SyntaxWarning on
# recent Python versions. The resulting path value is unchanged.
mecab = Mecab(r"C:\mecab\mecab-ko-dic")


def mecab_tokenizer(sent):
    """Split a sentence into tokens with the module-level Mecab tagger.

    Args:
        sent: The text to tokenize.

    Returns:
        Whatever ``mecab.morphs(sent)`` returns (per KoNLPy, a list of
        morpheme strings).
    """
    return mecab.morphs(sent)

In [3]:
def _read_pairs(path, sep):
    """Yield stripped (key, value) pairs from a UTF-8 text file.

    Each non-blank line is split once on ``sep``. Blank lines (for example a
    trailing newline at the end of the file) are skipped instead of crashing.

    Raises:
        ValueError: If a non-blank line does not contain ``sep``; the message
            names the file and the 1-based line number.
    """
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            parts = line.split(sep, 1)
            if len(parts) != 2:
                raise ValueError(
                    f"{path}:{line_no}: expected two fields separated by {sep!r}"
                )
            yield parts[0].strip(), parts[1].strip()


# The dataset directory 'aihub' is expected next to the notebook's working directory.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)

file_path = os.path.join(parent_directory, 'aihub')

# Passage collection: doc_id[i] is the identifier of corpus[i].
doc_id = []
corpus = []     # every passage

for docid_, corpus_ in _read_pairs(os.path.join(file_path, 'collection.tsv'), '||'):
    doc_id.append(docid_)
    corpus.append(corpus_)

# Question id -> question text (stored as a plain string).
questions = {}

for qid, question in _read_pairs(os.path.join(file_path, 'questions.tsv'), '\t'):
    if qid in questions:
        # Keep the first occurrence, but say which id was duplicated.
        print(f'duplicate question id ignored: {qid}')
    else:
        questions[qid] = question

# Test relevance judgements: one (question id, passage id) pair per line.
qids = []
docid = []
qid_docid = {}

for qid, docid_ in _read_pairs(os.path.join(file_path, 'test', 'qrels_test.tsv'), '\t'):
    qids.append(qid)
    docid.append(docid_)
    qid_docid[qid] = docid_   # a repeated qid keeps only its last passage id

# Test question texts, in the same order as qids.
ctxt = [questions[qid] for qid in qids]

In [4]:
# Sanity check: number of test questions (ctxt) and number of passages (corpus).
len(ctxt), len(corpus)

(3000, 124535)

In [5]:
# Tokenize every passage with Mecab (progress shown by tqdm), then build the
# BM25 index over the tokenized passages.
tokenized_corpus = list(map(mecab_tokenizer, tqdm(corpus)))

# BM25 hyper-parameters passed to BM25Okapi below.
k1, b = 0.9, 0.4
bm25 = BM25Okapi(tokenized_corpus, k1=k1, b=b)

100%|██████████| 124535/124535 [00:51<00:00, 2426.60it/s]


In [6]:
# Use the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [7]:
# Passage (p_model) and question (q_model) encoders start from the same
# pretrained Korean BERT and then receive their own fine-tuned weights.
p_model = BertModel.from_pretrained("kykim/bert-kor-base")
q_model = BertModel.from_pretrained("kykim/bert-kor-base")

# map_location=device lets checkpoints that were saved from GPU tensors load on
# a CPU-only machine, matching the cuda/cpu fallback used to choose `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
p_model.load_state_dict(torch.load('DPR_p_small.pth', map_location=device))
q_model.load_state_dict(torch.load('DPR_q_small.pth', map_location=device))

p_model.to(device)
q_model.to(device)

tokenizer = BertTokenizer.from_pretrained("kykim/bert-kor-base")

In [8]:
# Encode each test question with the question encoder, one question per
# forward pass, with gradients disabled.
q_model.eval()

question_embs = []
with torch.no_grad():
    for question_text in tqdm(ctxt):
        encoded = tokenizer(question_text, padding=True, truncation=True, return_tensors="pt")
        encoded = encoded.to(device)
        question_embs.append(q_model(**encoded).pooler_output)

# Concatenate the per-question outputs along dim 0 and move them to a NumPy array.
question_embs = torch.cat(question_embs, dim=0).cpu().numpy()
print()
print(question_embs.shape)  #(3000, 768)

100%|██████████| 3000/3000 [00:30<00:00, 97.70it/s] 


(3000, 768)





In [14]:
# Encode each passage in the corpus with the passage encoder, one passage per
# forward pass, with gradients disabled.
p_model.eval()

collection_embs = []
with torch.no_grad():
    for passage in tqdm(corpus):
        encoded = tokenizer(passage, padding=True, truncation=True, return_tensors="pt")
        encoded = encoded.to(device)
        collection_embs.append(p_model(**encoded).pooler_output)

# Concatenate the per-passage outputs along dim 0 and move them to a NumPy array.
collection_embs = torch.cat(collection_embs, dim=0).cpu().numpy()
print()
print(collection_embs.shape)  # (124535, 768)

100%|██████████| 124535/124535 [30:04<00:00, 69.01it/s]



(124535, 768)


In [15]:
# Width of the embedding vectors given to the FAISS index.
dimension = 768

def normalize_vectors(vectors):
    """Scale each row of ``vectors`` to unit L2 norm.

    Args:
        vectors: A 2-D array, or a sequence of equal-length 1-D vectors
            (it is converted with ``np.asarray``).

    Returns:
        A 2-D ndarray with the same shape whose rows have L2 norm 1.
        All-zero rows are returned unchanged (all zeros) instead of producing
        NaN/inf from a division by zero.
    """
    vectors = np.asarray(vectors)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Avoid 0/0: dividing a zero row by 1 leaves it as zeros.
    norms[norms == 0] = 1
    return vectors / norms

In [18]:
n_top = 100   # BM25 candidates kept per question
k = 20        # candidates returned by the dense re-ranking step

# Number of questions whose answer passage appears within the top 1/2/5/10/20.
recall1, recall2, recall5, recall10, recall20 = 0, 0, 0, 0, 0

for idx, qid in tqdm(enumerate(qids), total=len(qids)):
    answer = qid_docid[qid]
    cur_question = questions[qid]
    cur_question_emb = question_embs[idx]

    # Stage 1: BM25 scores of this question against every passage.
    tokenized_question = mecab_tokenizer(cur_question)
    doc_scores = bm25.get_scores(tokenized_question)

    # Corpus indices of the n_top highest-scoring passages, best first.
    top_indices = np.argsort(doc_scores)[-n_top:][::-1]

    # Stage 2: inner-product search of the question embedding against the
    # row-normalized embeddings of the BM25 candidates only.
    top_collection_embs = normalize_vectors(collection_embs[top_indices])
    index = faiss.IndexFlatIP(dimension)
    index.add(top_collection_embs)

    _, indices = index.search(cur_question_emb.reshape(1, -1), k)

    # Rank (0-based) at which the answer passage was returned, if at all.
    # Positions in `indices` refer to rows of top_collection_embs, so they are
    # mapped back to corpus positions through top_indices.
    hit_rank = None
    for k_idx, i in enumerate(indices[0]):
        if i < 0:
            # FAISS pads with -1 when fewer than k vectors are indexed; a
            # negative value must not be used to index top_indices.
            break
        if doc_id[top_indices[i]] == answer:
            hit_rank = k_idx
            break

    if hit_rank is None:
        continue

    # Explicit thresholds keep each counter correct even if k is changed
    # (previously any hit beyond rank 10 was counted as recall@20).
    if hit_rank < 1:
        recall1 += 1
    if hit_rank < 2:
        recall2 += 1
    if hit_rank < 5:
        recall5 += 1
    if hit_rank < 10:
        recall10 += 1
    if hit_rank < 20:
        recall20 += 1


100%|██████████| 3000/3000 [42:46<00:00,  1.17it/s]


In [19]:
# Report each hit counter as a fraction of the number of test questions.
num_queries = len(qids)
for cutoff, hits in ((1, recall1), (2, recall2), (5, recall5), (10, recall10), (20, recall20)):
    print(f'recall@{cutoff} : {hits/num_queries}')


recall@1 : 0.20033333333333334
recall@2 : 0.303
recall@5 : 0.4836666666666667
recall@10 : 0.6263333333333333
recall@20 : 0.7763333333333333


In [17]:
# np.save('question_embs.npy', question_embs)
# np.save('collection_embs.npy', collection_embs)

# loaded_question_embs = np.load('question_embs.npy')
# loaded_collection_embs = np.load('collection_embs.npy')
