<a href="https://colab.research.google.com/github/Chuck2Win/temp/blob/main/dense_passage_retrieval_by_myself.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
! pip install sentencepiece
! pip install transformers

Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[?25l[K     |▎                               | 10 kB 26.8 MB/s eta 0:00:01[K     |▌                               | 20 kB 25.0 MB/s eta 0:00:01[K     |▉                               | 30 kB 19.9 MB/s eta 0:00:01[K     |█                               | 40 kB 16.2 MB/s eta 0:00:01[K     |█▍                              | 51 kB 7.7 MB/s eta 0:00:01[K     |█▋                              | 61 kB 8.9 MB/s eta 0:00:01[K     |██                              | 71 kB 8.6 MB/s eta 0:00:01[K     |██▏                             | 81 kB 9.6 MB/s eta 0:00:01[K     |██▍                             | 92 kB 10.0 MB/s eta 0:00:01[K     |██▊                             | 102 kB 7.6 MB/s eta 0:00:01[K     |███                             | 112 kB 7.6 MB/s eta 0:00:01[K     |███▎                            | 122 kB 7.6 MB/s eta 0:00:01[K     |███▌     

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
cd /content/gdrive/MyDrive/dense_passage_retriever

/content/gdrive/MyDrive/dense_passage_retriever


In [29]:
import os
import json
import torch
import pickle
from tqdm import tqdm
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader,DistributedSampler,RandomSampler,SequentialSampler
import torch.nn.functional as F
import logging
    
from transformers import DistilBertModel
import argparse
from datetime import datetime

from retrieval.dense_retrieval.model import *
from utils.data_utils import load_jsonl,DPR_Train_Dataset, DPR_Context_Dataset
from utils.tools import str2bool
from retrieval.dense_retrieval.dense_retrieval import DPR_Retrieval
from utils.metrics import compute_topk_accuracy
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from utils.tokenization_kobert import KoBertTokenizer

In [30]:
# -*- coding: utf-8 -*-

# parser
parser = argparse.ArgumentParser()

# test name
parser.add_argument('--test_name', type=str,default =  'ok')

# data
parser.add_argument('--data_path', type=str, default =r'./data/book/도서_valid/contexts.json')
parser.add_argument('--train_data', type=str,default =  r'./data/book/도서_valid/train_data_hard_ctxs_add.jsonl')
parser.add_argument('--val_data', type=str,default =  r'./data/book/도서_valid/val_data.jsonl')
parser.add_argument('--n_hard_negative_ctxs', type=int,default =  1)
parser.add_argument('--output_dir', type=str, default = './output/1210')
parser.add_argument('--top_n', type=int, default = 20)
parser.add_argument('--passage_embedding_path', type=str, default = './output/1210/passage_embedding')

# logging 관련
parser.add_argument('--logging_term', type=int, default = 1000)

# 학습 관련
parser.add_argument('--epochs', type=int, default = 3)
parser.add_argument('--batch_size', type=int, default = 8)
parser.add_argument('--lr', type=float, default = 1e-4)
parser.add_argument('--warmup', type=int, default = 4000)
parser.add_argument('--fp16', type=str2bool, default = True)

# 데이터 관련
parser.add_argument('--context_max_length',type= int, default = 512)
parser.add_argument('--question_max_length', type=int, default = 128)

# distributed 관련
parser.add_argument('--local_rank', type=int, default = -1)
parser.add_argument('--distributed', type=bool, default = False)


_StoreAction(option_strings=['--distributed'], dest='distributed', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, help=None, metavar=None)

In [31]:
def get_log(args):
    global logger1, logger2
    logger1 = logging.getLogger('file') # 적지 않으면 root로 생성
    logger2 = logging.getLogger('stream') # 적지 않으면 root로 생성

    # 2. logging level 지정 - 기본 level Warning
    logger1.setLevel(logging.INFO)
    logger2.setLevel(logging.INFO)

    # 3. logging formatting 설정 - 문자열 format과 유사 - 시간, logging 이름, level - messages
    formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] >> %(message)s')

    # 4. handler : log message를 지정된 대상으로 전달하는 역할.
    # SteamHandler : steam(terminal 같은 console 창)에 log message를 보냄
    # FileHandler : 특정 file에 log message를 보내 저장시킴.
    # handler 정의
    stream_handler = logging.StreamHandler()
    # handler에 format 지정
    stream_handler.setFormatter(formatter)
    # logger instance에 handler 삽입
    logger2.addHandler(stream_handler)
    os.makedirs(args.output_dir,exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(args.output_dir,'%s_%s.txt'%(datetime.today(),args.test_name)), encoding='utf-8')
    file_handler.setFormatter(formatter)
    logger1.addHandler(file_handler)


In [32]:
def val_collate_fn(batch):
    questions = []
    positive_ctxs_ids = []
    for data in batch:
        positive_ctxs_ids.extend(data['positive_ctxs_ids'])
        questions.append(data['question'])
    return questions, positive_ctxs_ids

In [33]:
def validate(args, tokenizer, passage_encoder, question_encoder, dataloader, epoch):
    dpr_retrieval = DPR_Retrieval(args, tokenizer, passage_encoder, question_encoder)
    tmp = args.passage_embedding_path 
    args.passage_embedding_path = args.passage_embedding_path + '_%d'%epoch
    answers = []
    predicts = []
    for data in tqdm(dataloader,desc='encoding'):
        question, label = data
        answers.extend(label)
        top_n=dpr_retrieval.retrieve(question, args.top_n)
        predicts.extend(top_n)
    args.passage_embedding_path = tmp
    return answers, predicts

In [34]:
args,_ = parser.parse_known_args()
get_log(args)
tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert')
distilbert_model = DistilBertModel.from_pretrained('monologg/distilkobert')

# passage encoder
encoder_p = encoder(distilbert_model)
# question encoder
encoder_q = encoder(distilbert_model)
# dense passage retriever model   
model = dpr_encoder(encoder_p, encoder_q)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'KoBertTokenizer'.
Some weights of the model checkpoint at monologg/distilkobert were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClass

In [35]:
# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# data
train_data = load_jsonl(args.train_data)
val_data = load_jsonl(args.val_data)

# sanity check
os.makedirs(args.output_dir, exist_ok = True)


43920it [00:02, 15299.56it/s]
1000it [00:00, 27304.53it/s]


In [36]:
# distributed 관련
if args.distributed:
    assert torch.cuda.is_avaiable()
    assert torch.cuda.device_count()>1
    # 이 프로세스가 어느 gpu에 할당되는지 명시
    torch.cuda.set_device(args.local_rank)
    # 통신을 위한 초기화
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],output_device = args.local_rank)
    
else:
    if torch.cuda.is_available():
        model.cuda()
# train
train_dataset=DPR_Train_Dataset(args, train_data, tokenizer)
train_sampler = DistributedSampler(train_dataset) if args.distributed else RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset,batch_size = args.batch_size, sampler = train_sampler, collate_fn = train_dataset._collate_fn)
# val
val_dataloader = DataLoader(val_data,batch_size = args.batch_size, collate_fn=val_collate_fn)


optimizer = torch.optim.Adam(model.parameters(),args.lr)
linear_scheduler = lambda step: min(1/args.warmup*step,1.)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = linear_scheduler)
criterion = criterion_sim

if args.fp16:
    scaler = GradScaler()

In [37]:
# train
global_step = 0
best_score = -float('inf')
for epoch in range(1, args.epochs+1):
    # BERT
    if args.distributed:
        train_sampler.set_epoch(epoch)
    model.train()
    Loss_t = 0.
    iter_bar = tqdm(train_dataloader, desc='step')
    for data in iter_bar:
        optimizer.zero_grad()
        if torch.cuda.is_available():
            question_input_ids, question_attention_masks, context_input_ids, context_attention_masks, context_indice = [i.cuda() for i in data]
        else:
            question_input_ids, question_attention_masks, context_input_ids, context_attention_masks, context_indice = data
        if args.fp16:
            with autocast():
                P,Q = model.forward(context_input_ids, context_attention_masks, question_input_ids, question_attention_masks)
                loss,_ = criterion(Q,P,context_indice)
                Loss_t+=loss.item()
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                
        else:
            P,Q = model.forward(context_input_ids, context_attention_masks, question_input_ids, question_attention_masks)
            loss,_ = criterion(Q,P,context_indice)
            Loss_t+=loss.item()
            loss.backward()
            optimizer.step()
            
            
        global_step+=1           
        scheduler.step()
        torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1.0)          
        
        iter_bar.set_postfix({'epoch':epoch, 'global_step':global_step, 'lr':f"{scheduler.get_last_lr()[0]:.10f}", 'last_loss':f'{loss.item():.5f}','epoch_loss':f'{Loss_t/len(train_dataloader):.5f}'})
        if global_step%args.logging_term == 0:
            logger2.info(iter_bar)
        # validation
    actual, predict = validate(args, tokenizer, model.passage_encoder, model.question_encoder, val_dataloader, epoch)
    acc=compute_topk_accuracy(actual,predict)
    logger1.info(f'epoch : {epoch} ----- Val_acc : {acc[-1]:.5f}')
    logger2.info(f'epoch : {epoch} ----- Val_acc : {acc[-1]:.5f}')
    if acc[-1]>best_score:
        best_score = acc[-1]
        torch.save(model,os.path.join(args.output_dir,'best_model'))
    logger1.info(f'epoch : {epoch} ----- Train_Loss : {Loss_t/len(train_dataloader):.5f}')
    logger2.info(f'epoch : {epoch} ----- Train_Loss : {Loss_t/len(train_dataloader):.5f}')
    # 저장시 - gpu 0번 것만 저장 - barrier 필수
    if args.local_rank in [-1,0]:
        torch.save(model,os.path.join(args.output_dir,'model_epoch_%d'%epoch))
    # torch.distributed.barrier()
logger1.info('train_end')
logger2.info('train end')

step:  18%|█▊        | 999/5490 [04:36<20:39,  3.62it/s, epoch=1, global_step=1000, lr=0.0000250000, last_loss=1.11632, epoch_loss=0.31606][2021-12-10 02:24:32,732][stream][INFO] >> step:  18%|█▊        | 999/5490 [04:36<20:39,  3.62it/s, epoch=1, global_step=1000, lr=0.0000250000, last_loss=1.11632, epoch_loss=0.31606]
step:  36%|███▋      | 1999/5490 [09:12<16:01,  3.63it/s, epoch=1, global_step=2000, lr=0.0000500000, last_loss=0.54974, epoch_loss=0.46894][2021-12-10 02:29:08,723][stream][INFO] >> step:  36%|███▋      | 1999/5490 [09:12<16:01,  3.63it/s, epoch=1, global_step=2000, lr=0.0000500000, last_loss=0.54974, epoch_loss=0.46894]
step:  55%|█████▍    | 2999/5490 [13:48<11:28,  3.62it/s, epoch=1, global_step=3000, lr=0.0000750000, last_loss=0.27976, epoch_loss=0.60001][2021-12-10 02:33:44,769][stream][INFO] >> step:  55%|█████▍    | 2999/5490 [13:48<11:28,  3.62it/s, epoch=1, global_step=3000, lr=0.0000750000, last_loss=0.27976, epoch_loss=0.60001]
step:  73%|███████▎  | 3999/54

KeyboardInterrupt: ignored