In [1]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import SubsetRandomSampler

from time import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Both data files live under the 'aihub' folder next to the working directory.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
file_path = os.path.join(parent_directory, 'aihub', 'dataset.tsv')
testfile_path = os.path.join(parent_directory, 'aihub/test', 'test_dataset.tsv')


def _read_rows(path):
    """Read a UTF-8 text file and return its rows without the first line.

    Every line is stripped and split once on '|', so a row holds at most two
    fields. The first row is discarded (presumably a header -- confirm against
    the data files).
    """
    with open(path, 'r', encoding='utf-8') as handle:
        rows = [raw.strip().split('|', 1) for raw in handle]
    return rows[1:]


data = _read_rows(file_path)
testdata = _read_rows(testfile_path)

In [3]:
# Number of rows kept from each file after its first line was dropped.
len(data), len(testdata)

(249689, 3000)

In [4]:
class DPRDataset(Dataset):
    """Index-aligned (question, passage) pairs for dual-encoder training.

    Args:
        questions: Sequence of question strings.
        passages: Sequence of passage strings; ``passages[i]`` is the positive
            passage for ``questions[i]``.
        tokenizer: Object exposing ``batch_encode_plus``; used by ``collate_fn``
            to turn a batch of strings into padded tensors.
    """

    def __init__(self, questions, passages, tokenizer):
        self.passages = passages
        self.questions = questions
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of pairs (taken from the passage list)."""
        return len(self.passages)

    def __getitem__(self, index):
        """Return the raw ``(question, passage)`` strings at ``index``."""
        passage = self.passages[index]
        question = self.questions[index]
        return question, passage

    def collate_fn(self, batch):
        """Tokenize a list of ``(question, passage)`` items into two padded batches.

        Returns:
            ``(question_inputs, passage_inputs)`` -- the tokenizer output for the
            questions and for the passages, in that order.
        """
        # Items arrive as (question, passage), so unpack in that same order;
        # swapping the names here would feed each encoder the other side's text.
        questions, passages = zip(*batch)
        question_inputs = self.tokenizer.batch_encode_plus(
            list(questions), padding=True, truncation=True, return_tensors="pt"
        )
        passage_inputs = self.tokenizer.batch_encode_plus(
            list(passages), padding=True, truncation=True, return_tensors="pt"
        )
        return question_inputs, passage_inputs

In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [6]:
# The tokenizer and both encoders start from the same pretrained checkpoint.
tokenizer = BertTokenizer.from_pretrained("kykim/bert-kor-base")

# Separate weights for passages and questions; each is moved onto `device`.
p_model = BertModel.from_pretrained("kykim/bert-kor-base")
p_model.to(device)

q_model = BertModel.from_pretrained("kykim/bert-kor-base")
q_model.to(device)

In [7]:
# Pairs per batch. The training loop scores every question against every
# passage in its batch, so this also fixes the number of in-batch negatives
# per question at batch_size - 1.
batch_size = 4

In [8]:
# Field 0 of every row is used as the question, field 1 as its passage.
questions = [row[0] for row in data]
passages = [row[1] for row in data]

# Hold out 10% of the pairs for validation. The fixed random_state makes the
# split identical on every run, so validation losses stay comparable across
# kernel restarts.
train_questions, valid_questions, train_passages, valid_passages = train_test_split(
    questions, passages, test_size=0.1, random_state=42
)

train_dataset = DPRDataset(train_questions, train_passages, tokenizer)
valid_dataset = DPRDataset(valid_questions, valid_passages, tokenizer)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=valid_dataset.collate_fn)
# Each loader yields (question_inputs, passage_inputs). Tensors such as
# input_ids / attention_mask are [bs, L], where L is the longest sequence in
# that batch (padding=True) after truncation -- not a fixed 512.

In [9]:
# Number of batches each loader yields per pass.
len(train_dataloader), len(valid_dataloader)

(56180, 6243)

In [10]:
# One AdamW optimizer per encoder. torch.optim.AdamW is used instead of the
# deprecated transformers.AdamW; weight_decay=0.0 keeps that class's default
# (torch's own default would be 0.01).
p_optimizer = optim.AdamW(p_model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)
q_optimizer = optim.AdamW(q_model.parameters(), lr=2e-5, eps=1e-8, weight_decay=0.0)

epochs = 10

# The schedulers are stepped once per batch, so the schedule length is the
# total number of batches over all epochs. With no warmup the learning rate
# decays linearly from its initial value to zero.
total_steps = len(train_dataloader) * epochs
p_scheduler = get_linear_schedule_with_warmup(p_optimizer,
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps)
q_scheduler = get_linear_schedule_with_warmup(q_optimizer,
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps)



In [11]:
def _in_batch_nll_loss(question_v, passage_v):
    """In-batch negative log-likelihood for a batch of paired embeddings.

    Scores every question against every passage in the batch with a dot
    product, applies a log-softmax over the passages, and treats passage i as
    the correct answer for question i.

    Args:
        question_v: Tensor of question embeddings, one row per pair.
        passage_v: Tensor of passage embeddings, row-aligned with question_v.

    Returns:
        Scalar loss tensor (mean over the batch).
    """
    scores = torch.matmul(question_v, passage_v.transpose(0, 1))
    log_probs = F.log_softmax(scores, dim=1)
    # The positive passage for question i sits on the diagonal.
    targets = torch.arange(question_v.shape[0], device=scores.device)
    return F.nll_loss(log_probs, targets)


inf_loss = float('inf')  # lowest average validation loss seen so far
cnt = 0                  # consecutive epochs without a new best validation loss

for epoch in range(epochs):
    q_model.train()
    p_model.train()
    total_loss = 0

    for batch in tqdm(train_dataloader, total = len(train_dataloader), desc="training", leave = False):
        b_question, b_passage = tuple(t.to(device) for t in batch)

        p_optimizer.zero_grad()
        q_optimizer.zero_grad()

        question_v = q_model(**b_question).pooler_output
        passage_v = p_model(**b_passage).pooler_output

        loss = _in_batch_nll_loss(question_v, passage_v)

        loss.backward()
        p_optimizer.step()
        q_optimizer.step()

        p_scheduler.step()
        q_scheduler.step()
        total_loss += loss.item()
    print(f'epoch : {epoch+1}/{epochs}, train loss : {total_loss / len(train_dataloader)}')

    q_model.eval()
    p_model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(valid_dataloader, total = len(valid_dataloader), desc = "validing", leave = False):
            b_question, b_passage = tuple(t.to(device) for t in batch)

            question_v = q_model(**b_question).pooler_output
            passage_v = p_model(**b_passage).pooler_output

            # .item() keeps the running total a plain float, so the log line
            # below prints a number rather than a tensor repr.
            valid_loss += _in_batch_nll_loss(question_v, passage_v).item()

    avg_valid_loss = valid_loss / len(valid_dataloader)
    print(f'epoch : {epoch+1}/{epochs}, valid loss : {avg_valid_loss}')

    # Checkpoint only when validation improves. Starting from inf, the first
    # epoch always counts as an improvement.
    if avg_valid_loss < inf_loss:
        cnt = 0
        inf_loss = avg_valid_loss
        torch.save(q_model.state_dict(), 'DPR_q.pth')
        torch.save(p_model.state_dict(), 'DPR_p.pth')
    else:
        # Early stopping: give up after three epochs in a row without a new best.
        cnt += 1
        if cnt > 2:
            break

                                                               

KeyboardInterrupt: 