In [1]:
from time import time
from datetime import timedelta
from copy import deepcopy

import random
import numpy as np
import pandas as pd
from ml_metrics import mapk

import torch
from torch.optim import AdamW
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForMultipleChoice

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Random seed
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# CUDA device
use_cuda_device = 0
torch.cuda.set_device(use_cuda_device)
print("Using CUDA device: %d" % torch.cuda.current_device())

Using CUDA device: 0


# Setting

In [3]:
# Input files
document_csv_path = './documents.csv'
training_csv_path = './train_queries.csv'
testing_csv_path = './test_queries.csv'

# Input limitation
max_query_length = 64
max_input_length = 512
num_negatives = 3   # num. of negative documents to pair with a positive document

# Model finetuning
model_name_or_path = "bert-base-uncased"
max_epochs = 1
learning_rate = 3e-5
dev_set_ratio = 0.2   # make a ratio of training set as development set for rescoring weight sniffing
max_patience = 0      # earlystop if avg. loss on development set doesn't decrease for num. of epochs
batch_size = 2    # num. of inputs = 8 requires ~9200 MB VRAM (num. of inputs = batch_size * (num_negatives + 1))
num_workers = 2   # num. of jobs for pytorch dataloader

# Save paths
save_model_path = "models/bert_base_uncased"  # assign `None` for not saving the model
save_submission_path = "bm25_bert_rescoring.csv"
K = 1000   # for MAP@K


# Preparing

In [4]:
# Build and save BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained(model_name_or_path)
if save_model_path is not None:
    save_tokenizer_path = "%s/tokenizer" % (save_model_path)
    tokenizer.save_pretrained(save_tokenizer_path)

# Collect mapping of all document id and text
doc_id_to_text = {}
doc_df = pd.read_csv(document_csv_path)
doc_df.fillna("<Empty Document>", inplace=True)
id_text_pair = zip(doc_df["doc_id"], doc_df["doc_text"])
for i, pair in enumerate(id_text_pair, start=1):
    doc_id, doc_text = pair
    doc_id_to_text[doc_id] = doc_text
    
    print("Progress: %d/%d\r" % (i, len(doc_df)), end='')
    
doc_df.tail()

Progress: 1/100000Progress: 2/100000Progress: 3/100000Progress: 4/100000Progress: 5/100000Progress: 6/100000Progress: 7/100000Progress: 8/100000Progress: 9/100000Progress: 10/100000Progress: 11/100000Progress: 12/100000Progress: 13/100000Progress: 14/100000Progress: 15/100000Progress: 16/100000Progress: 17/100000Progress: 18/100000Progress: 19/100000Progress: 20/100000Progress: 21/100000Progress: 22/100000Progress: 23/100000Progress: 24/100000Progress: 25/100000Progress: 26/100000Progress: 27/100000Progress: 28/100000Progress: 29/100000Progress: 30/100000Progress: 31/100000Progress: 32/100000Progress: 33/100000Progress: 34/100000Progress: 35/100000Progress: 36/100000Progress: 37/100000Progress: 38/100000Progress: 39/100000Progress: 40/100000Progress: 41/100000Progress: 42/100000Progress: 43/100000Progress: 44/100000Progress: 45/100000Progress: 46/100000Progress: 47/100000Progress: 48/100000Progress: 49/100000Progress: 50/100000Progress:

Progress: 4570/100000Progress: 4571/100000Progress: 4572/100000Progress: 4573/100000Progress: 4574/100000Progress: 4575/100000Progress: 4576/100000Progress: 4577/100000Progress: 4578/100000Progress: 4579/100000Progress: 4580/100000Progress: 4581/100000Progress: 4582/100000Progress: 4583/100000Progress: 4584/100000Progress: 4585/100000Progress: 4586/100000Progress: 4587/100000Progress: 4588/100000Progress: 4589/100000Progress: 4590/100000Progress: 4591/100000Progress: 4592/100000Progress: 4593/100000Progress: 4594/100000Progress: 4595/100000Progress: 4596/100000Progress: 4597/100000Progress: 4598/100000Progress: 4599/100000Progress: 4600/100000Progress: 4601/100000Progress: 4602/100000Progress: 4603/100000Progress: 4604/100000Progress: 4605/100000Progress: 4606/100000Progress: 4607/100000Progress: 4608/100000Progress: 4609/100000Progress: 4610/100000Progress: 4611/100000Progress: 4612/100000Progress: 4613/100000Progress: 4614/100000Progress: 

Progress: 8965/100000Progress: 8966/100000Progress: 8967/100000Progress: 8968/100000Progress: 8969/100000Progress: 8970/100000Progress: 8971/100000Progress: 8972/100000Progress: 8973/100000Progress: 8974/100000Progress: 8975/100000Progress: 8976/100000Progress: 8977/100000Progress: 8978/100000Progress: 8979/100000Progress: 8980/100000Progress: 8981/100000Progress: 8982/100000Progress: 8983/100000Progress: 8984/100000Progress: 8985/100000Progress: 8986/100000Progress: 8987/100000Progress: 8988/100000Progress: 8989/100000Progress: 8990/100000Progress: 8991/100000Progress: 8992/100000Progress: 8993/100000Progress: 8994/100000Progress: 8995/100000Progress: 8996/100000Progress: 8997/100000Progress: 8998/100000Progress: 8999/100000Progress: 9000/100000Progress: 9001/100000Progress: 9002/100000Progress: 9003/100000Progress: 9004/100000Progress: 9005/100000Progress: 9006/100000Progress: 9007/100000Progress: 9008/100000Progress: 9009/100000Progress: 

Progress: 13964/100000Progress: 13965/100000Progress: 13966/100000Progress: 13967/100000Progress: 13968/100000Progress: 13969/100000Progress: 13970/100000Progress: 13971/100000Progress: 13972/100000Progress: 13973/100000Progress: 13974/100000Progress: 13975/100000Progress: 13976/100000Progress: 13977/100000Progress: 13978/100000Progress: 13979/100000Progress: 13980/100000Progress: 13981/100000Progress: 13982/100000Progress: 13983/100000Progress: 13984/100000Progress: 13985/100000Progress: 13986/100000Progress: 13987/100000Progress: 13988/100000Progress: 13989/100000Progress: 13990/100000Progress: 13991/100000Progress: 13992/100000Progress: 13993/100000Progress: 13994/100000Progress: 13995/100000Progress: 13996/100000Progress: 13997/100000Progress: 13998/100000Progress: 13999/100000Progress: 14000/100000Progress: 14001/100000Progress: 14002/100000Progress: 14003/100000Progress: 14004/100000Progress: 14005/100000Progress: 14006/100000Progress: 1

Progress: 18770/100000Progress: 18771/100000Progress: 18772/100000Progress: 18773/100000Progress: 18774/100000Progress: 18775/100000Progress: 18776/100000Progress: 18777/100000Progress: 18778/100000Progress: 18779/100000Progress: 18780/100000Progress: 18781/100000Progress: 18782/100000Progress: 18783/100000Progress: 18784/100000Progress: 18785/100000Progress: 18786/100000Progress: 18787/100000Progress: 18788/100000Progress: 18789/100000Progress: 18790/100000Progress: 18791/100000Progress: 18792/100000Progress: 18793/100000Progress: 18794/100000Progress: 18795/100000Progress: 18796/100000Progress: 18797/100000Progress: 18798/100000Progress: 18799/100000Progress: 18800/100000Progress: 18801/100000Progress: 18802/100000Progress: 18803/100000Progress: 18804/100000Progress: 18805/100000Progress: 18806/100000Progress: 18807/100000Progress: 18808/100000Progress: 18809/100000Progress: 18810/100000Progress: 18811/100000Progress: 18812/100000Progress: 1

Progress: 23463/100000Progress: 23464/100000Progress: 23465/100000Progress: 23466/100000Progress: 23467/100000Progress: 23468/100000Progress: 23469/100000Progress: 23470/100000Progress: 23471/100000Progress: 23472/100000Progress: 23473/100000Progress: 23474/100000Progress: 23475/100000Progress: 23476/100000Progress: 23477/100000Progress: 23478/100000Progress: 23479/100000Progress: 23480/100000Progress: 23481/100000Progress: 23482/100000Progress: 23483/100000Progress: 23484/100000Progress: 23485/100000Progress: 23486/100000Progress: 23487/100000Progress: 23488/100000Progress: 23489/100000Progress: 23490/100000Progress: 23491/100000Progress: 23492/100000Progress: 23493/100000Progress: 23494/100000Progress: 23495/100000Progress: 23496/100000Progress: 23497/100000Progress: 23498/100000Progress: 23499/100000Progress: 23500/100000Progress: 23501/100000Progress: 23502/100000Progress: 23503/100000Progress: 23504/100000Progress: 23505/100000Progress: 2

Progress: 28019/100000Progress: 28020/100000Progress: 28021/100000Progress: 28022/100000Progress: 28023/100000Progress: 28024/100000Progress: 28025/100000Progress: 28026/100000Progress: 28027/100000Progress: 28028/100000Progress: 28029/100000Progress: 28030/100000Progress: 28031/100000Progress: 28032/100000Progress: 28033/100000Progress: 28034/100000Progress: 28035/100000Progress: 28036/100000Progress: 28037/100000Progress: 28038/100000Progress: 28039/100000Progress: 28040/100000Progress: 28041/100000Progress: 28042/100000Progress: 28043/100000Progress: 28044/100000Progress: 28045/100000Progress: 28046/100000Progress: 28047/100000Progress: 28048/100000Progress: 28049/100000Progress: 28050/100000Progress: 28051/100000Progress: 28052/100000Progress: 28053/100000Progress: 28054/100000Progress: 28055/100000Progress: 28056/100000Progress: 28057/100000Progress: 28058/100000Progress: 28059/100000Progress: 28060/100000Progress: 28061/100000Progress: 2

Progress: 32704/100000Progress: 32705/100000Progress: 32706/100000Progress: 32707/100000Progress: 32708/100000Progress: 32709/100000Progress: 32710/100000Progress: 32711/100000Progress: 32712/100000Progress: 32713/100000Progress: 32714/100000Progress: 32715/100000Progress: 32716/100000Progress: 32717/100000Progress: 32718/100000Progress: 32719/100000Progress: 32720/100000Progress: 32721/100000Progress: 32722/100000Progress: 32723/100000Progress: 32724/100000Progress: 32725/100000Progress: 32726/100000Progress: 32727/100000Progress: 32728/100000Progress: 32729/100000Progress: 32730/100000Progress: 32731/100000Progress: 32732/100000Progress: 32733/100000Progress: 32734/100000Progress: 32735/100000Progress: 32736/100000Progress: 32737/100000Progress: 32738/100000Progress: 32739/100000Progress: 32740/100000Progress: 32741/100000Progress: 32742/100000Progress: 32743/100000Progress: 32744/100000Progress: 32745/100000Progress: 32746/100000Progress: 3

Progress: 37287/100000Progress: 37288/100000Progress: 37289/100000Progress: 37290/100000Progress: 37291/100000Progress: 37292/100000Progress: 37293/100000Progress: 37294/100000Progress: 37295/100000Progress: 37296/100000Progress: 37297/100000Progress: 37298/100000Progress: 37299/100000Progress: 37300/100000Progress: 37301/100000Progress: 37302/100000Progress: 37303/100000Progress: 37304/100000Progress: 37305/100000Progress: 37306/100000Progress: 37307/100000Progress: 37308/100000Progress: 37309/100000Progress: 37310/100000Progress: 37311/100000Progress: 37312/100000Progress: 37313/100000Progress: 37314/100000Progress: 37315/100000Progress: 37316/100000Progress: 37317/100000Progress: 37318/100000Progress: 37319/100000Progress: 37320/100000Progress: 37321/100000Progress: 37322/100000Progress: 37323/100000Progress: 37324/100000Progress: 37325/100000Progress: 37326/100000Progress: 37327/100000Progress: 37328/100000Progress: 37329/100000Progress: 3

Progress: 41961/100000Progress: 41962/100000Progress: 41963/100000Progress: 41964/100000Progress: 41965/100000Progress: 41966/100000Progress: 41967/100000Progress: 41968/100000Progress: 41969/100000Progress: 41970/100000Progress: 41971/100000Progress: 41972/100000Progress: 41973/100000Progress: 41974/100000Progress: 41975/100000Progress: 41976/100000Progress: 41977/100000Progress: 41978/100000Progress: 41979/100000Progress: 41980/100000Progress: 41981/100000Progress: 41982/100000Progress: 41983/100000Progress: 41984/100000Progress: 41985/100000Progress: 41986/100000Progress: 41987/100000Progress: 41988/100000Progress: 41989/100000Progress: 41990/100000Progress: 41991/100000Progress: 41992/100000Progress: 41993/100000Progress: 41994/100000Progress: 41995/100000Progress: 41996/100000Progress: 41997/100000Progress: 41998/100000Progress: 41999/100000Progress: 42000/100000Progress: 42001/100000Progress: 42002/100000Progress: 42003/100000Progress: 4

Progress: 44974/100000Progress: 44975/100000Progress: 44976/100000Progress: 44977/100000Progress: 44978/100000Progress: 44979/100000Progress: 44980/100000Progress: 44981/100000Progress: 44982/100000Progress: 44983/100000Progress: 44984/100000Progress: 44985/100000Progress: 44986/100000Progress: 44987/100000Progress: 44988/100000Progress: 44989/100000Progress: 44990/100000Progress: 44991/100000Progress: 44992/100000Progress: 44993/100000Progress: 44994/100000Progress: 44995/100000Progress: 44996/100000Progress: 44997/100000Progress: 44998/100000Progress: 44999/100000Progress: 45000/100000Progress: 45001/100000Progress: 45002/100000Progress: 45003/100000Progress: 45004/100000Progress: 45005/100000Progress: 45006/100000Progress: 45007/100000Progress: 45008/100000Progress: 45009/100000Progress: 45010/100000Progress: 45011/100000Progress: 45012/100000Progress: 45013/100000Progress: 45014/100000Progress: 45015/100000Progress: 45016/100000Progress: 4

Progress: 49909/100000Progress: 49910/100000Progress: 49911/100000Progress: 49912/100000Progress: 49913/100000Progress: 49914/100000Progress: 49915/100000Progress: 49916/100000Progress: 49917/100000Progress: 49918/100000Progress: 49919/100000Progress: 49920/100000Progress: 49921/100000Progress: 49922/100000Progress: 49923/100000Progress: 49924/100000Progress: 49925/100000Progress: 49926/100000Progress: 49927/100000Progress: 49928/100000Progress: 49929/100000Progress: 49930/100000Progress: 49931/100000Progress: 49932/100000Progress: 49933/100000Progress: 49934/100000Progress: 49935/100000Progress: 49936/100000Progress: 49937/100000Progress: 49938/100000Progress: 49939/100000Progress: 49940/100000Progress: 49941/100000Progress: 49942/100000Progress: 49943/100000Progress: 49944/100000Progress: 49945/100000Progress: 49946/100000Progress: 49947/100000Progress: 49948/100000Progress: 49949/100000Progress: 49950/100000Progress: 49951/100000Progress: 4

Progress: 53847/100000Progress: 53848/100000Progress: 53849/100000Progress: 53850/100000Progress: 53851/100000Progress: 53852/100000Progress: 53853/100000Progress: 53854/100000Progress: 53855/100000Progress: 53856/100000Progress: 53857/100000Progress: 53858/100000Progress: 53859/100000Progress: 53860/100000Progress: 53861/100000Progress: 53862/100000Progress: 53863/100000Progress: 53864/100000Progress: 53865/100000Progress: 53866/100000Progress: 53867/100000Progress: 53868/100000Progress: 53869/100000Progress: 53870/100000Progress: 53871/100000Progress: 53872/100000Progress: 53873/100000Progress: 53874/100000Progress: 53875/100000Progress: 53876/100000Progress: 53877/100000Progress: 53878/100000Progress: 53879/100000Progress: 53880/100000Progress: 53881/100000Progress: 53882/100000Progress: 53883/100000Progress: 53884/100000Progress: 53885/100000Progress: 53886/100000Progress: 53887/100000Progress: 53888/100000Progress: 53889/100000Progress: 5

Progress: 58582/100000Progress: 58583/100000Progress: 58584/100000Progress: 58585/100000Progress: 58586/100000Progress: 58587/100000Progress: 58588/100000Progress: 58589/100000Progress: 58590/100000Progress: 58591/100000Progress: 58592/100000Progress: 58593/100000Progress: 58594/100000Progress: 58595/100000Progress: 58596/100000Progress: 58597/100000Progress: 58598/100000Progress: 58599/100000Progress: 58600/100000Progress: 58601/100000Progress: 58602/100000Progress: 58603/100000Progress: 58604/100000Progress: 58605/100000Progress: 58606/100000Progress: 58607/100000Progress: 58608/100000Progress: 58609/100000Progress: 58610/100000Progress: 58611/100000Progress: 58612/100000Progress: 58613/100000Progress: 58614/100000Progress: 58615/100000Progress: 58616/100000Progress: 58617/100000Progress: 58618/100000Progress: 58619/100000Progress: 58620/100000Progress: 58621/100000Progress: 58622/100000Progress: 58623/100000Progress: 58624/100000Progress: 5

Progress: 63190/100000Progress: 63191/100000Progress: 63192/100000Progress: 63193/100000Progress: 63194/100000Progress: 63195/100000Progress: 63196/100000Progress: 63197/100000Progress: 63198/100000Progress: 63199/100000Progress: 63200/100000Progress: 63201/100000Progress: 63202/100000Progress: 63203/100000Progress: 63204/100000Progress: 63205/100000Progress: 63206/100000Progress: 63207/100000Progress: 63208/100000Progress: 63209/100000Progress: 63210/100000Progress: 63211/100000Progress: 63212/100000Progress: 63213/100000Progress: 63214/100000Progress: 63215/100000Progress: 63216/100000Progress: 63217/100000Progress: 63218/100000Progress: 63219/100000Progress: 63220/100000Progress: 63221/100000Progress: 63222/100000Progress: 63223/100000Progress: 63224/100000Progress: 63225/100000Progress: 63226/100000Progress: 63227/100000Progress: 63228/100000Progress: 63229/100000Progress: 63230/100000Progress: 63231/100000Progress: 63232/100000Progress: 6

Progress: 68434/100000Progress: 68435/100000Progress: 68436/100000Progress: 68437/100000Progress: 68438/100000Progress: 68439/100000Progress: 68440/100000Progress: 68441/100000Progress: 68442/100000Progress: 68443/100000Progress: 68444/100000Progress: 68445/100000Progress: 68446/100000Progress: 68447/100000Progress: 68448/100000Progress: 68449/100000Progress: 68450/100000Progress: 68451/100000Progress: 68452/100000Progress: 68453/100000Progress: 68454/100000Progress: 68455/100000Progress: 68456/100000Progress: 68457/100000Progress: 68458/100000Progress: 68459/100000Progress: 68460/100000Progress: 68461/100000Progress: 68462/100000Progress: 68463/100000Progress: 68464/100000Progress: 68465/100000Progress: 68466/100000Progress: 68467/100000Progress: 68468/100000Progress: 68469/100000Progress: 68470/100000Progress: 68471/100000Progress: 68472/100000Progress: 68473/100000Progress: 68474/100000Progress: 68475/100000Progress: 68476/100000Progress: 6

Progress: 72285/100000Progress: 72286/100000Progress: 72287/100000Progress: 72288/100000Progress: 72289/100000Progress: 72290/100000Progress: 72291/100000Progress: 72292/100000Progress: 72293/100000Progress: 72294/100000Progress: 72295/100000Progress: 72296/100000Progress: 72297/100000Progress: 72298/100000Progress: 72299/100000Progress: 72300/100000Progress: 72301/100000Progress: 72302/100000Progress: 72303/100000Progress: 72304/100000Progress: 72305/100000Progress: 72306/100000Progress: 72307/100000Progress: 72308/100000Progress: 72309/100000Progress: 72310/100000Progress: 72311/100000Progress: 72312/100000Progress: 72313/100000Progress: 72314/100000Progress: 72315/100000Progress: 72316/100000Progress: 72317/100000Progress: 72318/100000Progress: 72319/100000Progress: 72320/100000Progress: 72321/100000Progress: 72322/100000Progress: 72323/100000Progress: 72324/100000Progress: 72325/100000Progress: 72326/100000Progress: 72327/100000Progress: 7

Progress: 76794/100000Progress: 76795/100000Progress: 76796/100000Progress: 76797/100000Progress: 76798/100000Progress: 76799/100000Progress: 76800/100000Progress: 76801/100000Progress: 76802/100000Progress: 76803/100000Progress: 76804/100000Progress: 76805/100000Progress: 76806/100000Progress: 76807/100000Progress: 76808/100000Progress: 76809/100000Progress: 76810/100000Progress: 76811/100000Progress: 76812/100000Progress: 76813/100000Progress: 76814/100000Progress: 76815/100000Progress: 76816/100000Progress: 76817/100000Progress: 76818/100000Progress: 76819/100000Progress: 76820/100000Progress: 76821/100000Progress: 76822/100000Progress: 76823/100000Progress: 76824/100000Progress: 76825/100000Progress: 76826/100000Progress: 76827/100000Progress: 76828/100000Progress: 76829/100000Progress: 76830/100000Progress: 76831/100000Progress: 76832/100000Progress: 76833/100000Progress: 76834/100000Progress: 76835/100000Progress: 76836/100000Progress: 7

Progress: 81457/100000Progress: 81458/100000Progress: 81459/100000Progress: 81460/100000Progress: 81461/100000Progress: 81462/100000Progress: 81463/100000Progress: 81464/100000Progress: 81465/100000Progress: 81466/100000Progress: 81467/100000Progress: 81468/100000Progress: 81469/100000Progress: 81470/100000Progress: 81471/100000Progress: 81472/100000Progress: 81473/100000Progress: 81474/100000Progress: 81475/100000Progress: 81476/100000Progress: 81477/100000Progress: 81478/100000Progress: 81479/100000Progress: 81480/100000Progress: 81481/100000Progress: 81482/100000Progress: 81483/100000Progress: 81484/100000Progress: 81485/100000Progress: 81486/100000Progress: 81487/100000Progress: 81488/100000Progress: 81489/100000Progress: 81490/100000Progress: 81491/100000Progress: 81492/100000Progress: 81493/100000Progress: 81494/100000Progress: 81495/100000Progress: 81496/100000Progress: 81497/100000Progress: 81498/100000Progress: 81499/100000Progress: 8

Progress: 85340/100000Progress: 85341/100000Progress: 85342/100000Progress: 85343/100000Progress: 85344/100000Progress: 85345/100000Progress: 85346/100000Progress: 85347/100000Progress: 85348/100000Progress: 85349/100000Progress: 85350/100000Progress: 85351/100000Progress: 85352/100000Progress: 85353/100000Progress: 85354/100000Progress: 85355/100000Progress: 85356/100000Progress: 85357/100000Progress: 85358/100000Progress: 85359/100000Progress: 85360/100000Progress: 85361/100000Progress: 85362/100000Progress: 85363/100000Progress: 85364/100000Progress: 85365/100000Progress: 85366/100000Progress: 85367/100000Progress: 85368/100000Progress: 85369/100000Progress: 85370/100000Progress: 85371/100000Progress: 85372/100000Progress: 85373/100000Progress: 85374/100000Progress: 85375/100000Progress: 85376/100000Progress: 85377/100000Progress: 85378/100000Progress: 85379/100000Progress: 85380/100000Progress: 85381/100000Progress: 85382/100000Progress: 8

Progress: 89840/100000Progress: 89841/100000Progress: 89842/100000Progress: 89843/100000Progress: 89844/100000Progress: 89845/100000Progress: 89846/100000Progress: 89847/100000Progress: 89848/100000Progress: 89849/100000Progress: 89850/100000Progress: 89851/100000Progress: 89852/100000Progress: 89853/100000Progress: 89854/100000Progress: 89855/100000Progress: 89856/100000Progress: 89857/100000Progress: 89858/100000Progress: 89859/100000Progress: 89860/100000Progress: 89861/100000Progress: 89862/100000Progress: 89863/100000Progress: 89864/100000Progress: 89865/100000Progress: 89866/100000Progress: 89867/100000Progress: 89868/100000Progress: 89869/100000Progress: 89870/100000Progress: 89871/100000Progress: 89872/100000Progress: 89873/100000Progress: 89874/100000Progress: 89875/100000Progress: 89876/100000Progress: 89877/100000Progress: 89878/100000Progress: 89879/100000Progress: 89880/100000Progress: 89881/100000Progress: 89882/100000Progress: 8

Progress: 94491/100000Progress: 94492/100000Progress: 94493/100000Progress: 94494/100000Progress: 94495/100000Progress: 94496/100000Progress: 94497/100000Progress: 94498/100000Progress: 94499/100000Progress: 94500/100000Progress: 94501/100000Progress: 94502/100000Progress: 94503/100000Progress: 94504/100000Progress: 94505/100000Progress: 94506/100000Progress: 94507/100000Progress: 94508/100000Progress: 94509/100000Progress: 94510/100000Progress: 94511/100000Progress: 94512/100000Progress: 94513/100000Progress: 94514/100000Progress: 94515/100000Progress: 94516/100000Progress: 94517/100000Progress: 94518/100000Progress: 94519/100000Progress: 94520/100000Progress: 94521/100000Progress: 94522/100000Progress: 94523/100000Progress: 94524/100000Progress: 94525/100000Progress: 94526/100000Progress: 94527/100000Progress: 94528/100000Progress: 94529/100000Progress: 94530/100000Progress: 94531/100000Progress: 94532/100000Progress: 94533/100000Progress: 9

Progress: 99268/100000Progress: 99269/100000Progress: 99270/100000Progress: 99271/100000Progress: 99272/100000Progress: 99273/100000Progress: 99274/100000Progress: 99275/100000Progress: 99276/100000Progress: 99277/100000Progress: 99278/100000Progress: 99279/100000Progress: 99280/100000Progress: 99281/100000Progress: 99282/100000Progress: 99283/100000Progress: 99284/100000Progress: 99285/100000Progress: 99286/100000Progress: 99287/100000Progress: 99288/100000Progress: 99289/100000Progress: 99290/100000Progress: 99291/100000Progress: 99292/100000Progress: 99293/100000Progress: 99294/100000Progress: 99295/100000Progress: 99296/100000Progress: 99297/100000Progress: 99298/100000Progress: 99299/100000Progress: 99300/100000Progress: 99301/100000Progress: 99302/100000Progress: 99303/100000Progress: 99304/100000Progress: 99305/100000Progress: 99306/100000Progress: 99307/100000Progress: 99308/100000Progress: 99309/100000Progress: 99310/100000Progress: 9

Unnamed: 0,doc_id,doc_text
99995,LA123190-0105,CLERKS AT 13 STORES ARRESTED AFTER MINORS BUY ...
99996,LA123190-0108,LOOKING TO 1991; \n THE NEW YEAR PROMISES TRE...
99997,LA123190-0117,"LOCAL; \n GIRL, 14, DIES IN DRIVE-BY INCIDENT..."
99998,LA123190-0119,"GREECE, ISRAEL HIT BY EXODUS FROM ALBANIA \n ..."
99999,LA123190-0124,<Empty Document>


# Train 

# Split a ratio of training set as development set

In [5]:
train_df = pd.read_csv(training_csv_path)
dev_df, train_df = np.split(train_df, [int(dev_set_ratio*len(train_df))])
dev_df.reset_index(drop=True, inplace=True)
train_df.reset_index(drop=True, inplace=True)

print("train_df shape:", train_df.shape)
print("dev_df shape:", dev_df.shape)
train_df.tail()

train_df shape: (96, 5)
dev_df shape: (24, 5)


Unnamed: 0,query_id,query_text,pos_doc_ids,bm25_top1000,bm25_top1000_scores
91,641,Valdez wildlife marine life,FT911-1460 FT931-15213 FT931-16010 FT933-7162 ...,LA120989-0014 LA032390-0003 LA040889-0009 LA03...,34.16304495 32.97577181 31.31040724 30.8527172...
92,642,Tiananmen Square protesters,FBIS3-1941 FBIS3-2223 FBIS3-2224 FBIS3-26281 F...,FT922-10319 FT931-8730 FT942-5501 FBIS4-24379 ...,32.38429409 30.71831856 29.63771818 29.4676000...
93,648,family leave law,FBIS3-43072 FBIS3-61562 FBIS4-25261 FR940323-0...,FR941202-0-00181 FR941202-0-00176 FR941202-0-0...,24.51293307 23.98772391 23.42756181 23.0616218...
94,649,computer viruses,FBIS3-40468 FBIS3-42979 FBIS3-43017 FBIS4-5044...,FT944-9024 FBIS4-50440 FT921-5724 FT941-13624 ...,27.84369436 27.24267123 26.98326939 26.9108106...
95,650,tax evasion indicted,LA011689-0065 LA012589-0008 LA012889-0016 LA02...,LA040889-0060 LA053189-0041 LA092590-0146 LA06...,29.72207523 27.98961258 27.73561512 27.3372072...


# Build instances for training/development set

In [6]:
%%time
doc_id_to_token_ids = {}
def preprocess_df(df):
    ''' Preprocess DataFrame into training instances for BERT. '''
    instances = []
    
    # Parse CSV
    for i, row in df.iterrows():
        query_id, query_text, pos_doc_ids, bm25_top1000, _ = row
        pos_doc_id_list = pos_doc_ids.split()
        pos_doc_id_set = set(pos_doc_id_list)
        bm25_top1000_list = bm25_top1000.split()
        bm25_top1000_set = set(bm25_top1000_list)

        # Pair BM25 neg. with pos. samples
        labeled_pos_neg_list = []
        for pos_doc_id in pos_doc_id_list:
            neg_doc_id_set = bm25_top1000_set - pos_doc_id_set
            neg_doc_ids = random.sample(neg_doc_id_set, num_negatives)
            pos_position = random.randint(0, num_negatives)
            pos_neg_doc_ids = neg_doc_ids
            pos_neg_doc_ids.insert(pos_position, pos_doc_id)
            labeled_sample = (pos_neg_doc_ids, pos_position)
            labeled_pos_neg_list.append(labeled_sample)
            
        # Make query tokens for BERT
        query_tokens = tokenizer.tokenize(query_text)
        if len(query_tokens) > max_query_length:  # truncation
            query_tokens = query_tokens[:max_query_length]
        query_token_ids = tokenizer.convert_tokens_to_ids(query_tokens)
        query_token_ids.insert(0, tokenizer.cls_token_id)
        query_token_ids.append(tokenizer.sep_token_id)

        # Make input instances for all query/doc pairs
        for doc_ids, label in labeled_pos_neg_list:
            paired_input_ids = []
            paired_attention_mask = []
            paired_token_type_ids = []
            
            # Merge all pos/neg inputs as a single sample
            for doc_id in doc_ids:
                if doc_id in doc_id_to_token_ids:
                    doc_token_ids = doc_id_to_token_ids[doc_id]
                else:
                    doc_text = doc_id_to_text[doc_id]
                    doc_tokens = tokenizer.tokenize(doc_text)
                    doc_token_ids = tokenizer.convert_tokens_to_ids(doc_tokens)
                    doc_id_to_token_ids[doc_id] = doc_token_ids
                doc_token_ids.append(tokenizer.sep_token_id)

                # make input sequences for BERT
                input_ids = query_token_ids + doc_token_ids
                token_type_ids = [0 for token_id in query_token_ids]
                token_type_ids.extend(1 for token_id in doc_token_ids)
                if len(input_ids) > max_input_length:  # truncation
                    input_ids = input_ids[:max_input_length]
                    token_type_ids = token_type_ids[:max_input_length]
                attention_mask = [1 for token_id in input_ids]
                
                # convert and collect inputs as tensors
                input_ids = torch.LongTensor(input_ids)
                attention_mask = torch.FloatTensor(attention_mask)
                token_type_ids = torch.LongTensor(token_type_ids)
                paired_input_ids.append(input_ids)
                paired_attention_mask.append(attention_mask)
                paired_token_type_ids.append(token_type_ids)
            label = torch.LongTensor([label]).squeeze()
            
            # Pre-pad tensor pairs for efficiency
            paired_input_ids = pad_sequence(paired_input_ids, batch_first=True)
            paired_attention_mask = pad_sequence(paired_attention_mask, batch_first=True)
            paired_token_type_ids = pad_sequence(paired_token_type_ids, batch_first=True)

            # collect all inputs as a dictionary
            instance = {}
            instance['input_ids'] = paired_input_ids.T  # transpose for code efficiency
            instance['attention_mask'] = paired_attention_mask.T
            instance['token_type_ids'] = paired_token_type_ids.T
            instance['label'] = label
            instances.append(instance)

        print("Progress: %d/%d\r" % (i+1, len(df)), end='')
    print()
    return instances

train_instances = preprocess_df(train_df)
dev_instances = preprocess_df(dev_df)

print("num. train_instances: %d" % len(train_instances))
print("num. dev_instances: %d" % len(dev_instances))
print("input_ids.T shape:", train_instances[0]['input_ids'].T.shape)
train_instances[0]['input_ids'].T

Token indices sequence length is longer than the specified maximum sequence length for this model (926 > 512). Running this sequence through the model will result in indexing errors


Progress: 96/96
Progress: 24/24
num. train_instances: 7679
num. dev_instances: 1677
input_ids.T shape: torch.Size([4, 512])
CPU times: user 1min 5s, sys: 377 ms, total: 1min 5s
Wall time: 1min 5s


tensor([[  101,  3199,  3036,  ...,     0,     0,     0],
        [  101,  3199,  3036,  ..., 29153,  1998, 17800],
        [  101,  3199,  3036,  ...,     0,     0,     0],
        [  101,  3199,  3036,  ...,     0,     0,     0]])

# Build dataset and dataloader for PyTorch

In [7]:
class TrainingDataset(Dataset):
    def __init__(self, instances):
        self.instances = instances
    
    def __len__(self):
        return len(self.instances)
        
    def __getitem__(self, i):
        instance = self.instances[i]
        input_ids = instance['input_ids']
        attention_mask = instance['attention_mask']
        token_type_ids = instance['token_type_ids']
        label = instance['label']
        return input_ids, attention_mask, token_type_ids, label
    
def get_train_dataloader(instances, batch_size=2, num_workers=4):
    def collate_fn(batch):
        input_ids, attention_mask, token_type_ids, labels = zip(*batch)
        input_ids = pad_sequence(input_ids, batch_first=True).transpose(1,2).contiguous()  # re-transpose
        attention_mask = pad_sequence(attention_mask, batch_first=True).transpose(1,2).contiguous()
        token_type_ids = pad_sequence(token_type_ids, batch_first=True).transpose(1,2).contiguous()
        labels = torch.stack(labels)
        return input_ids, attention_mask, token_type_ids, labels
    
    dataset = TrainingDataset(instances)
    dataloader = DataLoader(dataset, collate_fn=collate_fn, shuffle=True, \
                            batch_size=batch_size, num_workers=num_workers)
    return dataloader

# Demo
dataloader = get_train_dataloader(train_instances)
for batch in dataloader:
    input_ids, attention_mask, token_type_ids, labels = batch
    break
    
print(input_ids.shape)
input_ids

torch.Size([2, 4, 512])


tensor([[[  101,  9732,  2943,  ...,  1996,  4762, 23725],
         [  101,  9732,  2943,  ...,     0,     0,     0],
         [  101,  9732,  2943,  ...,     0,     0,     0],
         [  101,  9732,  2943,  ...,     0,     0,     0]],

        [[  101,  2307,  3725,  ...,     0,     0,     0],
         [  101,  2307,  3725,  ...,  2095,  2857,  2683],
         [  101,  2307,  3725,  ...,  1011,  1011,  1011],
         [  101,  2307,  3725,  ...,  1012,  2007,  1037]]])

# Initialize and finetune BERT

In [8]:
model = BertForMultipleChoice.from_pretrained(model_name_or_path)
model.cuda()

optimizer = AdamW(model.parameters(), lr=learning_rate)
optimizer.zero_grad()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly

In [9]:
def validate(model, instances):
    total_loss = 0
    model.eval()
    dataloader = get_train_dataloader(instances, batch_size=batch_size, num_workers=num_workers)
    for batch in dataloader:
        batch = (tensor.cuda() for tensor in batch)
        input_ids, attention_mask, token_type_ids, labels = batch
        
        ''' TO-DO: 
        1. Compute the cross-entropy loss (using built-in loss of BertForMultipleChoice)
          (Hint: You need to call a function of model which takes all the 4 tensors in the batch as inputs)
          
        2. Sum up the loss of all dev-set samples
          (Hint: The built-in loss is averaged, so you should multiply it with the batch size)
        '''
        with torch.no_grad():
            outputs = model(input_ids=input_ids, 
                        token_type_ids=token_type_ids, 
                        attention_mask=attention_mask, 
                        labels=labels)
            loss = outputs[0]
        total_loss += loss * batch_size
        
    avg_loss = total_loss / len(instances)
    return avg_loss

In [10]:
patience, best_dev_loss = 0, 1e10
best_state_dict = model.state_dict()

start_time = time()
dataloader = get_train_dataloader(train_instances, batch_size=batch_size, num_workers=num_workers)
for epoch in range(1, max_epochs+1):
    model.train()
    for i, batch in enumerate(dataloader, start=1):
        batch = (tensor.cuda() for tensor in batch)
        input_ids, attention_mask, token_type_ids, labels = batch
        
        # Backpropogation
        ''' TO-DO: 
        1. Compute the cross-entropy loss (using built-in loss of BertForMultipleChoice)
          (Hint: You need to call a function of model which takes all the 4 tensors in the batch as inputs)
         
        2. Perform backpropogation on the loss (i.e. compute gradients)
        3. Optimize the model.
          (Hint: These two lines of codes can be found in PyTorch tutorial)
        '''
        outputs = model(input_ids=input_ids, 
                        token_type_ids=token_type_ids, 
                        attention_mask=attention_mask, 
                        labels=labels)

        loss = outputs[0]
        
        loss.backward()
        optimizer.step()
        
        optimizer.zero_grad()
        
        # Progress bar with timer ;-)
        elapsed_time = time() - start_time
        elapsed_time = timedelta(seconds=int(elapsed_time))
        print("Epoch: %d/%d | Batch: %d/%d | loss=%.5f | %s      \r" \
              % (epoch, max_epochs, i, len(dataloader), loss, elapsed_time), end='')
        
    # Save parameters of each epoch
    if save_model_path is not None:
        save_checkpoint_path = "%s/epoch_%d" % (save_model_path, epoch)
        model.save_pretrained(save_checkpoint_path)
        
    # Get avg. loss on development set
    print("Epoch: %d/%d | Validating...                           \r" % (epoch, max_epochs), end='')
    dev_loss = validate(model, dev_instances)
    elapsed_time = time() - start_time
    elapsed_time = timedelta(seconds=int(elapsed_time))
    print("Epoch: %d/%d | dev_loss=%.5f | %s                      " \
          % (epoch, max_epochs, dev_loss, elapsed_time))
    
    # Track best checkpoint and earlystop patience
    if dev_loss < best_dev_loss:
        patience = 0
        best_dev_loss = dev_loss
        best_state_dict = deepcopy(model.state_dict())
        if save_model_path is not None:
            model.save_pretrained(save_model_path)
    else:
        patience += 1
    
    if patience > max_patience:
        print('Earlystop at epoch %d' % epoch)
        break
        
# Restore parameters with best loss on development set
model.load_state_dict(best_state_dict)

RuntimeError: CUDA out of memory. Tried to allocate 18.00 MiB (GPU 0; 7.79 GiB total capacity; 6.41 GiB already allocated; 32.00 MiB free; 6.45 GiB reserved in total by PyTorch)
Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:272 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fb521dab1e2 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1e64b (0x7fb52200164b in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1f464 (0x7fb522002464 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1faa1 (0x7fb522002aa1 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x22f8b17 (0x7fb52450cb17 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0x22e1e34 (0x7fb5244f5e34 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #6: at::native::sum_out(at::Tensor&, at::Tensor const&, c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) + 0x108 (0x7fb55dbf6788 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::native::sum(at::Tensor const&, c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) + 0x4b (0x7fb55dbf6d9b in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x129a2b4 (0x7fb55e0ae2b4 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0xa564fe (0x7fb55d86a4fe in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::sum(at::Tensor const&, c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) + 0x123 (0x7fb55e005a53 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x2e5888b (0x7fb55fc6c88b in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0xa564fe (0x7fb55d86a4fe in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::Tensor::sum(c10::ArrayRef<long>, bool, c10::optional<c10::ScalarType>) const + 0x123 (0x7fb55e15eaf3 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x336aaaa (0x7fb56017eaaa in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #15: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x3fd (0x7fb5601843fd in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #16: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x7fb560185fa1 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #17: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x7fb56017e119 in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x7fb56d91e4ba in /home/teddy/.local/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #19: <unknown function> + 0xbd6df (0x7fb5bec9f6df in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #20: <unknown function> + 0x76db (0x7fb5c488f6db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #21: clone + 0x3f (0x7fb5c4bc871f in /lib/x86_64-linux-gnu/libc.so.6)
