In [1]:
import pandas as pd
import numpy as np
import random
import os
import json

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import transformers
from transformers import AdamW
from transformers import AutoModel, AutoTokenizer, AutoConfig

from model import DPR

from datasets import load_dataset
from haystack.utils import print_answers
from haystack.nodes import ElasticsearchRetriever, FARMReader
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline



INFO - haystack.document_stores.base -  Numba not found, replacing njit() with no-op implementation. Enable it with 'pip install numba'.
INFO - haystack.modeling.model.optimization -  apex not found, won't use it. See https://nvidia.github.io/apex/


In [2]:
def seed_everything(seed):
    """Seed every RNG used in this notebook so runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    every visible CUDA device), and switches cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to every generator.

    Note:
        ``PYTHONHASHSEED`` set here only influences subprocesses started after
        this call; the running interpreter's hash seed is fixed at startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current device; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick algorithms nondeterministically,
    # which would undo deterministic=True above.
    torch.backends.cudnn.benchmark = False

def load_data(path):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        path: Path to a ``.jsonl`` file holding one JSON value per line.

    Returns:
        list: The decoded values, in file order.

    Blank or whitespace-only lines (e.g. an extra newline at end of file) are
    skipped rather than passed to ``json.loads``, which would raise on them.
    The file is read as UTF-8 regardless of the platform's default encoding.
    """
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

In [3]:
# Fix every RNG up front so everything below is reproducible.
SEED = 42
seed_everything(SEED)

# One JSON record per line in each split.
train_data = load_data('./data/specific_train.jsonl')
valid_data = load_data('./data/specific_val.jsonl')

# Show the schema of a single training record.
train_data[0].keys()

dict_keys(['query', 'passage', 'target', 'relevant_span', 'context_idx'])

In [4]:
def to_dpr_records(rows):
    """Convert raw JSONL rows into Haystack's DPR training JSON format.

    Each row becomes one training example: ``query`` is the question,
    ``relevant_span`` the single answer string, and ``passage`` the single
    positive context. Without ``positive_ctxs`` the DPR processor has no
    passage to pair with the query and trains against empty contexts.

    Args:
        rows: Records as returned by ``load_data`` (need ``query``,
            ``relevant_span`` and ``passage`` keys).

    Returns:
        list[dict]: One DPR-format dict per input row, in input order.
    """
    records = []
    for row in tqdm(rows):
        records.append({
            'question': row['query'],
            'answers': [row['relevant_span']],
            # NOTE(review): assumes 'passage' holds the gold passage text as a str — confirm.
            'positive_ctxs': [{'title': '', 'text': row['passage']}],
            # No negatives are available in the source data.
            'negative_ctxs': [],
            'hard_negative_ctxs': [],
        })
    return records


train_list = to_dpr_records(train_data)
valid_list = to_dpr_records(valid_data)

with open('./data/dpr_train.json', 'w') as f:
    json.dump(train_list, f)

with open('./data/dpr_dev.json', 'w') as f:
    json.dump(valid_list, f)


100%|██████████| 1095/1095 [00:00<00:00, 590253.55it/s]
100%|██████████| 237/237 [00:00<00:00, 493325.09it/s]


In [5]:
from haystack.nodes import DensePassageRetriever
from haystack.utils import fetch_archive_from_http
from haystack.document_stores import InMemoryDocumentStore


# Pretrained DPR encoder checkpoints (question side and passage side);
# they are fine-tuned on our data in the training cell.
query_model = "facebook/dpr-question_encoder-single-nq-base"
passage_model = "facebook/dpr-ctx_encoder-single-nq-base"

# Maximum token lengths fed to each encoder.
MAX_QUERY_TOKENS = 64
MAX_PASSAGE_TOKENS = 256

document_store = InMemoryDocumentStore()
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model=query_model,
    passage_embedding_model=passage_model,
    max_seq_len_query=MAX_QUERY_TOKENS,
    max_seq_len_passage=MAX_PASSAGE_TOKENS,
)
print('fin')

INFO - haystack.modeling.utils -  Using devices: CUDA:0
INFO - haystack.modeling.utils -  Number of GPUs: 1
INFO - haystack.modeling.utils -  Using devices: CUDA:0
INFO - haystack.modeling.utils -  Number of GPUs: 1
INFO - haystack.modeling.model.language_model -  LOADING MODEL
INFO - haystack.modeling.model.language_model -  Could not find facebook/dpr-question_encoder-single-nq-base locally.
INFO - haystack.modeling.model.language_model -  Looking on Transformers Model Hub (in local cache and online)...
INFO - haystack.modeling.model.language_model -  Loaded facebook/dpr-question_encoder-single-nq-base
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.
INFO - haystack.modeling.model.language_model -  LOADING MODEL
INFO - 

fin


In [6]:
doc_dir = "./data/"
train_filename = "dpr_train.json"
dev_filename = "dpr_dev.json"
save_dir = './save_model/'

# Effective batch per optimizer step = batch_size * grad_acc_steps = 128.
# NOTE(review): gradient accumulation does not enlarge DPR's in-batch negative
# pool, which is limited to the 4 examples of each forward pass.
retriever.train(
    data_dir=doc_dir,
    train_filename=train_filename,
    dev_filename=dev_filename,
    # NOTE(review): the dev split doubles as the test split, so the final
    # "test" metrics are not from held-out data.
    test_filename=dev_filename,
    n_epochs=3,
    batch_size=4,
    grad_acc_steps=32,
    save_dir=save_dir,
    evaluate_every=16,
    num_positives=1,
    # The exported DPR files contain no hard negatives; with the default of 1,
    # every query would be padded with an empty-string "negative" passage.
    num_hard_negatives=0,
)

INFO - haystack.modeling.data_handler.data_silo -  
Loading data into the data silo ... 
              ______
               |o  |   !
   __          |:`_|---'-.
  |__|______.-/ _ \-----.|       
 (o)(o)------'\ _ /     ( )      
 
INFO - haystack.modeling.data_handler.data_silo -  LOADING TRAIN DATA
INFO - haystack.modeling.data_handler.data_silo -  Loading train set from: data/dpr_train.json 
INFO - haystack.modeling.data_handler.data_silo -  Got ya 7 parallel workers to convert 1095 dictionaries to pytorch datasets (chunksize = 32)...
INFO - haystack.modeling.data_handler.data_silo -   0    0    0    0    0    0    0 
INFO - haystack.modeling.data_handler.data_silo -  /w\  /w\  /w\  /w\  /w\  /w\  /w\
INFO - haystack.modeling.data_handler.data_silo -  /'\  / \  /'\  /'\  / \  / \  /'\
Preprocessing Dataset data/dpr_train.json: 100%|██████████| 1095/1095 [00:01<00:00, 1058.20 Dicts/s]
INFO - haystack.modeling.data_handler.data_silo -  
INFO - haystack.modeling.data_handler.data_silo 