# BERT를 활용한 Dense Passage Retrieval 실습

## Requirements

In [None]:
!pip install datasets
!pip install transformers

## 데이터셋 로딩

* KorQuAD v1 (`squad_kor_v1`) 데이터셋 다운로드

In [None]:
from datasets import load_dataset

# Load KorQuAD v1 ("squad_kor_v1"); the cells below read its 'train' and 'validation' splits.
dataset = load_dataset("squad_kor_v1")

In [None]:
# Unique passages of the train split (several questions share one context).
# dict.fromkeys de-duplicates while keeping first-seen order, so corpus[0] is the same
# passage on every run; list(set(...)) depends on the per-process string hash seed.
corpus = list(dict.fromkeys(example['context'] for example in dataset['train']))
len(corpus)

## 토크나이저 준비 - Hugging Face에서 제공하는 tokenizer 이용

* BERT multilingual 모델(`bert-base-multilingual-cased`) 사용

In [None]:
from transformers import AutoTokenizer
import numpy as np

# Multilingual cased BERT; the same checkpoint name is reused below to load both encoders.
model_checkpoint = "bert-base-multilingual-cased"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [None]:
# Notebook cell: evaluating the bare name displays the tokenizer object's repr.
tokenizer

* 입력을 tokenize한 뒤 다시 decoding 하기
  * `padding="max_length"` : 모델의 최대 길이까지 padding token으로 채움
  * `truncation=True` : 최대 길이보다 길면 자름

In [None]:
sample_passage = corpus[0]
print(sample_passage)
# Encode with padding/truncation, then map the ids back to text to see exactly
# what the model will receive.
tokenized_input = tokenizer(sample_passage, padding="max_length", truncation=True)
tokenizer.decode(tokenized_input['input_ids'])

## Dense encoder (BERT) 학습

* package 가져오고 seed 지정

In [None]:
from tqdm import tqdm, trange
import argparse
import random
import torch
import torch.nn.functional as F
from transformers import BertModel, BertPreTrainedModel, AdamW, TrainingArguments, get_linear_schedule_with_warmup

# Seed every RNG this notebook draws from: torch (CPU and CUDA) for the
# DataLoader's RandomSampler, NumPy for the training subset, and Python's
# `random` for picking the validation query.
TORCH_SEED = 3532812018032770127
for seed_torch_rng in (torch.manual_seed, torch.cuda.manual_seed):
    seed_torch_rng(TORCH_SEED)
np.random.seed(324)
random.seed(2021)

* 학습 데이터 준비
  * 128개를 sample 함 (전체 train 데이터 길이 범위에서 128개의 index를 뽑음)
  * 뽑은 index로 data를 가져와서 training dataset을 만듦

In [None]:
# Use subset (128 example) of original training dataset 
sample_idx = np.random.choice(range(len(dataset['train'])), 128)
training_dataset = dataset['train'][sample_idx]

* tokenization 하기

In [None]:
from torch.utils.data import (DataLoader, RandomSampler, TensorDataset)

# Questions and passages share the same encoding settings: pad/truncate to the
# model's max length and return PyTorch tensors.
encode_kwargs = dict(padding="max_length", truncation=True, return_tensors='pt')
q_seqs = tokenizer(training_dataset['question'], **encode_kwargs)
p_seqs = tokenizer(training_dataset['context'], **encode_kwargs)


* 학습에 사용하기 위해 dataset을 `TensorDataset`으로 변환
  * q_seqs와 p_seqs를 하나의 dataset으로 합침
  * 학습할 때 batch 단위로 접근하기 편하도록 형태를 바꿈

In [None]:
# Column order matters: train() reads batch[0:3] as the passage inputs and
# batch[3:6] as the question inputs.
seq_fields = ('input_ids', 'attention_mask', 'token_type_ids')
train_dataset = TensorDataset(*(p_seqs[field] for field in seq_fields),
                              *(q_seqs[field] for field in seq_fields))

* BERT encoder 학습
  * BERT encoder를 직접 구현
    * [CLS] token에 해당하는 pooled embedding(`BertModel` 출력의 두 번째 값)만 가져오면 됨

In [None]:
class BertEncoder(BertPreTrainedModel):
  """BERT encoder that maps each input sequence to a single embedding vector.

  This notebook instantiates it twice, once as the passage encoder and once as
  the question encoder. `forward` returns element 1 of the `BertModel` output,
  i.e. BERT's pooled output derived from the [CLS] token.
  """

  def __init__(self, config):
    super().__init__(config)
    # Keep the attribute name `bert`: pre-trained weights are loaded into it by
    # from_pretrained (presumably matched via the base model prefix; TODO confirm).
    self.bert = BertModel(config)
    self.init_weights()

  def forward(self, input_ids, attention_mask=None, token_type_ids=None):
    """Run vanilla BERT and return the pooled [CLS] embedding.

    The result is (batch_size, emb_dim), per the shape comments in train().
    """
    encoded = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    pooled_cls = encoded[1]
    return pooled_cls


* model을 instantiate하기 (학습의 시작점이 되는 pre-trained 가중치 불러오기)
  * 같은 pre-trained model을 passage encoder와 question encoder로 각각 가져오기

In [None]:
# load pre-trained model on cuda (if available)
p_encoder = BertEncoder.from_pretrained(model_checkpoint)
q_encoder = BertEncoder.from_pretrained(model_checkpoint)

if torch.cuda.is_available(): # GPU 사용
  p_encoder.cuda()
  q_encoder.cuda()

* train function 구현

In [None]:
def train(args, num_neg, dataset, p_model, q_model):
  """Fine-tune a passage encoder and a question encoder with in-batch negatives.

  Each batch holds B (question, positive passage) pairs. Every question is
  scored against all B passages by dot product. The passage at the same batch
  position is the positive, and the other B - 1 passages act as negatives
  (cross-entropy over the B x B similarity matrix, with the diagonal as targets).

  Args:
    args: `TrainingArguments`-like object. Reads per_device_train_batch_size,
      weight_decay, learning_rate, adam_epsilon, gradient_accumulation_steps,
      num_train_epochs and warmup_steps.
    num_neg: Unused. Kept for call compatibility; only in-batch negatives are used.
    dataset: `TensorDataset` with columns (p_input_ids, p_attention_mask,
      p_token_type_ids, q_input_ids, q_attention_mask, q_token_type_ids).
    p_model: Passage encoder, called with input_ids / attention_mask /
      token_type_ids and returning (batch_size, emb_dim).
    q_model: Question encoder with the same interface. Assumed to be on the
      same device as p_model.

  Returns:
    (p_model, q_model), trained in place.
  """
  # Send batches to wherever the models already live (CPU or GPU).
  device = next(p_model.parameters()).device

  # Dataloader: reshuffle the examples every epoch.
  train_sampler = RandomSampler(dataset)
  train_dataloader = DataLoader(dataset, sampler=train_sampler, batch_size=args.per_device_train_batch_size)

  # Optimizer: no weight decay on biases and LayerNorm weights. Group order is
  # p_model (decay, no-decay), then q_model (decay, no-decay).
  no_decay = ['bias', 'LayerNorm.weight']
  optimizer_grouped_parameters = []
  for model in (p_model, q_model):
    named_params = list(model.named_parameters())
    optimizer_grouped_parameters.append(
        {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay})
    optimizer_grouped_parameters.append(
        {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0})
  # torch.optim.AdamW instead of the deprecated transformers.AdamW.
  optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

  # Schedule length = number of optimizer updates. An update happens every
  # `accum_steps` batches and once more for a partial group at the end of an epoch.
  accum_steps = max(1, args.gradient_accumulation_steps)
  num_epochs = int(args.num_train_epochs)
  num_batches = len(train_dataloader)
  updates_per_epoch = (num_batches + accum_steps - 1) // accum_steps
  t_total = updates_per_epoch * num_epochs
  scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

  # Start training!
  p_model.zero_grad()
  q_model.zero_grad()
  torch.cuda.empty_cache()

  # Train mode on the models passed in (not on any global encoder).
  p_model.train()
  q_model.train()

  train_iterator = trange(num_epochs, desc="Epoch")
  for _ in train_iterator:
    epoch_iterator = tqdm(train_dataloader, desc="Iteration")

    for step, batch in enumerate(epoch_iterator):
      batch = tuple(t.to(device) for t in batch)

      # Columns follow the TensorDataset order: passage fields, then question fields.
      p_inputs = {'input_ids': batch[0],
                  'attention_mask': batch[1],
                  'token_type_ids': batch[2]}
      q_inputs = {'input_ids': batch[3],
                  'attention_mask': batch[4],
                  'token_type_ids': batch[5]}

      p_outputs = p_model(**p_inputs)  # (batch_size, emb_dim)
      q_outputs = q_model(**q_inputs)  # (batch_size, emb_dim)

      # (batch_size, emb_dim) x (emb_dim, batch_size) = (batch_size, batch_size)
      sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1))

      # Positive for question i is passage i (the diagonal). Size from the actual
      # batch, because the last batch can be smaller than per_device_train_batch_size.
      targets = torch.arange(sim_scores.size(0), device=device)

      loss = F.nll_loss(F.log_softmax(sim_scores, dim=1), targets)
      epoch_iterator.set_postfix(loss=loss.item())

      # Average gradients over the accumulation group before stepping.
      (loss / accum_steps).backward()
      if (step + 1) % accum_steps == 0 or step + 1 == num_batches:
        optimizer.step()
        scheduler.step()
        p_model.zero_grad()
        q_model.zero_grad()

      torch.cuda.empty_cache()

  return p_model, q_model




In [None]:
# Hyper-parameters for the custom train() loop. train() reads only the training
# fields (batch size, learning rate, epochs, weight decay, plus the library defaults
# for adam_epsilon / warmup_steps / gradient_accumulation_steps).
# No evaluation_strategy: train() never evaluates, and newer transformers releases
# renamed it to eval_strategy and later dropped the old name, so passing it fails there.
args = TrainingArguments(
    output_dir="dense_retireval",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    weight_decay=0.01
)


In [None]:
# train() uses only in-batch negatives and ignores num_neg. It must still be defined,
# otherwise this call raises NameError.
num_neg = 0
p_encoder, q_encoder = train(args, num_neg, train_dataset, p_encoder, q_encoder)

## Dense Embedding을 활용하여 passage retrieval 실습

* validation set 사용
  * train set으로 학습했기 때문에, 평가에는 학습에 쓰지 않은 validation set을 사용함

In [None]:

# Candidate pool: the first 10 unique validation passages in dataset order, so the
# pool is the same on every run (list(set(...)) order depends on the string hash
# seed), plus the ground-truth passage of one randomly chosen query.
valid_corpus = list(dict.fromkeys(example['context'] for example in dataset['validation']))[:10]
sample_idx = random.choice(range(len(dataset['validation'])))
query = dataset['validation'][sample_idx]['question']
ground_truth = dataset['validation'][sample_idx]['context']

## Add the correct passage if it is not already in the pool.
if ground_truth not in valid_corpus:
  valid_corpus.append(ground_truth)

print(query)
print(ground_truth, '\n\n')

# valid_corpus

In [None]:
def to_cuda(batch):
  """Return a tuple with every tensor of `batch` moved to the default CUDA device."""
  moved = []
  for tensor in batch:
    moved.append(tensor.cuda())
  return tuple(moved)

* 각 passage에 대한 embedding 확보하기

In [None]:
# Encode the query and every candidate passage on the device the encoders live on;
# a hard-coded 'cuda' would fail on CPU-only machines.
device = next(p_encoder.parameters()).device

with torch.no_grad():
  p_encoder.eval()
  q_encoder.eval()

  q_seqs_val = tokenizer([query], padding="max_length", truncation=True, return_tensors='pt').to(device)
  q_emb = q_encoder(**q_seqs_val).to('cpu')  # (num_query, emb_dim)

  p_embs = []
  for passage in valid_corpus:
    p_seq = tokenizer(passage, padding="max_length", truncation=True, return_tensors='pt').to(device)
    p_embs.append(p_encoder(**p_seq).to('cpu'))  # (1, emb_dim)

## Stack into one matrix. torch.cat keeps (num_passage, emb_dim) even for a single
## passage, unlike squeeze(), and avoids the slow list-of-ndarrays -> Tensor conversion.
p_embs = torch.cat(p_embs, dim=0)  # (num_passage, emb_dim)

print(p_embs.size(), q_emb.size())  # e.g. (11, 768) (1, 768): 11 passages, 768-dim embeddings, 1 question

* similarity score 계산하기
  * 한 개의 query(question)에 대한 passage들의 유사도

In [None]:
# Dot-product similarity of the single query against every passage: (1, num_passage).
dot_prod_scores = torch.matmul(q_emb, torch.transpose(p_embs, 0, 1))
print(dot_prod_scores.size())

## Passage indices sorted by score, highest first (descending). Take row 0 rather
## than squeeze() so rank stays 1-D (num_passage,) even with a single passage.
rank = torch.argsort(dot_prod_scores, dim=1, descending=True)[0]
print(dot_prod_scores)
print(rank)

In [None]:
k = 5
print("[Search query]\n", query, "\n")
print("[Ground truth passage]")
print(ground_truth, "\n")

# Scores for the single query, (num_passage,). Row 0 instead of squeeze() so a
# one-passage pool still gives a 1-D tensor.
query_scores = dot_prod_scores[0]
# Never request more passages than the candidate pool holds.
for i in range(min(k, len(valid_corpus))):
  passage_idx = int(rank[i])
  print("Top-%d passage with score %.4f" % (i+1, query_scores[passage_idx]))
  print(valid_corpus[passage_idx])