In [1]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import SubsetRandomSampler
from torch.cuda.amp import autocast, GradScaler

from time import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)

file_path = os.path.join(parent_directory, 'aihub', 'dataset.tsv')
# Join every path component separately instead of embedding a '/' in one of them.
testfile_path = os.path.join(parent_directory, 'aihub', 'test', 'test_dataset.tsv')


def read_pipe_pairs(path):
    """Read a '|'-separated ``question|passage`` file, skipping its header row.

    Each line is split on the first '|' only, so the passage part may itself
    contain '|'. Blank lines are skipped. A non-blank line without any '|'
    raises ValueError here, instead of producing a 1-element row that would
    only fail later at ``row[1]``.

    Returns:
        list of ``[question, passage]`` lists.
    """
    rows = []
    with open(path, 'r', encoding='utf-8') as f:
        next(f, None)  # header row
        for lineno, line in enumerate(f, start=2):
            line = line.strip()
            if not line:
                continue
            parts = line.split('|', 1)
            if len(parts) != 2:
                raise ValueError(f'{path}:{lineno}: expected "question|passage", got {line!r}')
            rows.append(parts)
    return rows


data = read_pipe_pairs(file_path)
testdata = read_pipe_pairs(testfile_path)

In [3]:
tuple(map(len, (data, testdata)))

(249689, 3000)

In [4]:
class DPRDataset(Dataset):
    """Aligned (question, passage) pairs for dense-retriever training with in-batch negatives.

    Args:
        questions: sequence of question strings.
        passages: sequence of passage strings; ``passages[i]`` is the positive
            passage for ``questions[i]``.
        tokenizer: tokenizer used by ``collate_fn`` to batch-encode the text
            (must provide ``batch_encode_plus``).

    Raises:
        ValueError: if ``questions`` and ``passages`` have different lengths.
    """

    def __init__(self, questions, passages, tokenizer):
        if len(questions) != len(passages):
            raise ValueError(
                f"questions and passages must be aligned, got {len(questions)} and {len(passages)}"
            )
        self.questions = questions
        self.passages = passages
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        """Return the ``(question, passage)`` pair at ``index``."""
        return self.questions[index], self.passages[index]

    def collate_fn(self, batch):
        """Tokenize a list of ``(question, passage)`` pairs from ``__getitem__``.

        Returns:
            ``(question_inputs, passage_inputs)``: padded/truncated tokenizer
            outputs with PyTorch tensors (``return_tensors="pt"``).
        """
        # Unpack in the same order __getitem__ packs: (question, passage).
        # The previous code unpacked as (passages, questions), silently swapping roles.
        questions, passages = zip(*batch)
        question_inputs = self.tokenizer.batch_encode_plus(
            list(questions), padding=True, truncation=True, return_tensors="pt"
        )
        passage_inputs = self.tokenizer.batch_encode_plus(
            list(passages), padding=True, truncation=True, return_tensors="pt"
        )
        return question_inputs, passage_inputs

In [5]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [6]:
MODEL_NAME = "kykim/bert-kor-base"

# A single BERT encoder is shared by questions and passages.
q_model = BertModel.from_pretrained(MODEL_NAME).to(device)

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

In [7]:
# Also the number of in-batch negatives + 1 per question in the loss.
batch_size: int = 8

In [8]:
questions = [row[0] for row in data]
passages = [row[1] for row in data]

# Fixed random_state so the train/valid split (and therefore the "best" checkpoint
# chosen on the validation loss) is reproducible across kernel restarts.
train_questions, valid_questions, train_passages, valid_passages = train_test_split(
    questions, passages, test_size=0.1, random_state=42
)

train_dataset = DPRDataset(train_questions, train_passages, tokenizer)
valid_dataset = DPRDataset(valid_questions, valid_passages, tokenizer)

# NOTE(review): the loss uses the other passages in a batch as negatives. If the
# same passage text appears for two questions in one batch, it becomes a false negative.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.collate_fn)
# Each batch is (question_inputs, passage_inputs). padding=True pads input_ids /
# attention_mask to the longest sequence in that batch, i.e. [bs, seq_len], where
# seq_len is capped by truncation=True at the tokenizer's model_max_length.

In [9]:
tuple(map(len, (train_dataloader, valid_dataloader)))

(28090, 3122)

In [10]:
# transformers.AdamW is deprecated in favor of torch.optim.AdamW.
# weight_decay=0.0 matches transformers.AdamW's default (torch's default is 0.01).
q_optimizer = optim.AdamW(q_model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)

epochs = 100

# The scheduler is stepped once per training batch, so it spans every batch of every epoch.
total_steps = len(train_dataloader) * epochs
# No warmup: the learning rate decays linearly from 2e-5 to 0 over total_steps.
q_scheduler = get_linear_schedule_with_warmup(
    q_optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)



In [11]:
def in_batch_nll(question_v, passage_v):
    """In-batch-negatives loss for row-aligned question/passage embeddings.

    Row i of the score matrix holds the scores of question i against every
    passage in the batch. Passage i is the positive and all other passages act
    as negatives, so the target for row i is column i.

    Args:
        question_v: 2-D tensor of question embeddings, one row per example.
        passage_v: 2-D tensor of passage embeddings, row-aligned with ``question_v``.

    Returns:
        Scalar mean negative log-likelihood of the positive passages.
    """
    # Raw dot product (not cosine similarity): the embeddings are not normalized.
    scores = question_v @ passage_v.T
    targets = torch.arange(scores.size(0), device=scores.device)
    # cross_entropy == log_softmax followed by nll_loss.
    return F.cross_entropy(scores, targets)


scaler = GradScaler()

best_valid_loss = float('inf')
# Early stopping: stop once more than this many consecutive epochs fail to improve
# the validation loss. None disables it, so all `epochs` are trained.
early_stop_patience = None
epochs_without_improvement = 0
checkpoint_path = 'DPR_single.pth'

for epoch in range(epochs):
    q_model.train()
    total_loss = 0.0

    for b_question, b_passage in tqdm(train_dataloader, desc="training", leave=False):
        b_question = b_question.to(device)
        b_passage = b_passage.to(device)

        q_optimizer.zero_grad()

        with autocast():
            # The same encoder embeds both questions and passages.
            question_v = q_model(**b_question).pooler_output
            passage_v = q_model(**b_passage).pooler_output
            loss = in_batch_nll(question_v, passage_v)

        scaler.scale(loss).backward()
        scaler.step(q_optimizer)
        scaler.update()
        q_scheduler.step()

        total_loss += loss.item()
    print(f'epoch : {epoch+1}/{epochs}, train loss : {total_loss / len(train_dataloader)}')

    q_model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for b_question, b_passage in tqdm(valid_dataloader, desc = "validing", leave = False):
            b_question = b_question.to(device)
            b_passage = b_passage.to(device)

            question_v = q_model(**b_question).pooler_output
            passage_v = q_model(**b_passage).pooler_output
            # .item() keeps the running sum a Python float, not a CUDA tensor.
            valid_loss += in_batch_nll(question_v, passage_v).item()
    valid_loss_result = valid_loss / len(valid_dataloader)
    print(f'epoch : {epoch+1}/{epochs}, valid loss : {valid_loss_result}')

    if valid_loss_result < best_valid_loss:
        best_valid_loss = valid_loss_result
        epochs_without_improvement = 0
        torch.save(q_model.state_dict(), checkpoint_path)
    else:
        epochs_without_improvement += 1
        if early_stop_patience is not None and epochs_without_improvement > early_stop_patience:
            break

training:   0%|          | 0/28090 [00:00<?, ?it/s]

training:   1%|          | 199/28090 [00:48<1:51:02,  4.19it/s]

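# Inference

Encode the passage collection and the test questions with the fine-tuned encoder, index the passages with FAISS, and report recall@1/10/20/100.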

In [None]:
import os
from transformers import BertTokenizer
from transformers import BertModel
import torch
import faiss
import numpy as np
from tqdm import tqdm

In [None]:
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)

file_path = os.path.join(parent_directory, 'aihub')


def _read_pairs(path, sep):
    """Yield a stripped ``(key, value)`` pair per line of ``path``, split on the first ``sep``."""
    with open(path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            fields = raw.split(sep, 1)
            yield fields[0].strip(), fields[1].strip()


# qid -> question text
questions = dict(_read_pairs(os.path.join(file_path, 'questions.tsv'), '\t'))

# Passage ids and texts, kept in file order (FAISS row i <-> pids[i]).
collection = list(_read_pairs(os.path.join(file_path, 'collection.tsv'), '||'))
pids = [pid_ for pid_, _text in collection]
passages = [text for _pid, text in collection]

# qid -> gold pid. NOTE(review): a qid listed more than once keeps only its last pid.
answers = dict(_read_pairs(os.path.join(file_path, 'test', 'qrels_test.tsv'), '\t'))

In [None]:
tuple(map(len, (questions, pids, passages, answers)))

In [None]:
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(device)

In [None]:
# The training loop above saves its best weights to 'DPR_single.pth'.
# The previous path ('DPR_q.pth') pointed at a file this notebook never writes.
checkpoint_path = 'DPR_single.pth'

test_q_model = BertModel.from_pretrained("kykim/bert-kor-base")
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the model is moved to `device` afterwards.
state_dict = torch.load(checkpoint_path, map_location='cpu')
test_q_model.load_state_dict(state_dict)
test_q_model.to(device)

In [None]:
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path="kykim/bert-kor-base")

In [None]:
# Texts per forward pass. Batching replaces ~124k single-example passes;
# lower this if the GPU runs out of memory.
encode_batch_size = 128

p_dvs = []

test_q_model.eval()

with torch.no_grad():
    for start in tqdm(range(0, len(passages), encode_batch_size)):
        p_batch = passages[start:start + encode_batch_size]
        # padding=True pads to the longest passage in this batch; the attention
        # mask keeps padding from affecting the real tokens.
        p_input = tokenizer(p_batch, padding=True, truncation=True, return_tensors="pt").to(device)
        # Move each batch to CPU so the full embedding matrix never sits in GPU memory.
        p_dvs.append(test_q_model(**p_input).pooler_output.cpu())

p_dvs = torch.cat(p_dvs, dim=0).numpy()
print()
print(p_dvs.shape)  # (124535, 768)

In [None]:
# Texts per forward pass; lower this if the GPU runs out of memory.
encode_batch_size = 128

# Row i of q_dvs corresponds to the i-th qid of `answers`; the recall cell relies on this order.
q_texts = [questions[qid] for qid in answers]

q_dvs = []

test_q_model.eval()

with torch.no_grad():
    for start in tqdm(range(0, len(q_texts), encode_batch_size)):
        q_batch = q_texts[start:start + encode_batch_size]
        q_input = tokenizer(q_batch, padding=True, truncation=True, return_tensors="pt").to(device)
        q_dvs.append(test_q_model(**q_input).pooler_output.cpu())

q_dvs = torch.cat(q_dvs, dim=0).numpy()
print()
print(q_dvs.shape)  # (3000, 768)

In [None]:
def normalize_vectors(vectors, eps=1e-12):
    """Return a row-wise L2-normalized, C-contiguous float32 copy of a 2-D array.

    Rows whose norm is below ``eps`` are divided by ``eps`` instead, avoiding
    division by zero (NaN rows). FAISS expects contiguous float32 input.
    """
    vectors = np.asarray(vectors, dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return np.ascontiguousarray(vectors / np.maximum(norms, eps), dtype=np.float32)


# NOTE(review): training scored pairs with a raw (unnormalized) dot product, but
# with normalize=True retrieval ranks by cosine similarity.
# - Normalizing the queries never changes a query's ranking; it only rescales
#   that query's scores.
# - Normalizing the passages does change the ranking.
# Set normalize=False to rank by the training-time dot product instead.
normalize = True

# Take the dimension from the embeddings (768 for bert-kor-base) instead of hard-coding it.
dimension = p_dvs.shape[1]
index = faiss.IndexFlatIP(dimension)

if normalize:
    p_dvs = normalize_vectors(p_dvs)
    q_search = normalize_vectors(q_dvs)
else:
    p_dvs = np.ascontiguousarray(p_dvs, dtype=np.float32)
    q_search = np.ascontiguousarray(q_dvs, dtype=np.float32)
index.add(p_dvs)

# FAISS pads results with -1 when k exceeds the number of indexed vectors.
k = min(100, index.ntotal)
_, indices = index.search(q_search, k)

In [None]:
# recall@K = fraction of test questions whose gold passage id appears among the
# top-K retrieved passages. Row i of `indices` belongs to the i-th (qid, pid)
# pair of `answers`, the same order used to build `q_dvs`.
cutoffs = [1, 10, 20, 100]
hits = {cutoff: 0 for cutoff in cutoffs}

for idx, (qid, pid) in tqdm(enumerate(answers.items()), total = len(answers), desc = 'testing', leave = False):
    # FAISS pads missing results with -1 (only at the tail), so dropping those
    # entries keeps the ranks of the real results and never reads pids[-1].
    retrieved = [pids[i] for i in indices[idx] if i >= 0]
    try:
        rank = retrieved.index(pid)  # 0-based position of the first match
    except ValueError:
        continue  # gold passage not retrieved in the top-k
    for cutoff in cutoffs:
        if rank < cutoff:
            hits[cutoff] += 1

print()
for cutoff in cutoffs:
    print(f'recall@{cutoff} : {hits[cutoff] / len(answers)}')