# Index MS MARCO dataset with docT5query

* https://github.com/castorini/docTTTTTquery#Predicting-Queries-from-Passages-T5-Inference-with-PyTorch

In [3]:
from elasticsearch import Elasticsearch
import logging
from typing import Dict, List, Optional
import json
#from trec_tools.trec_car.read_data import *

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration



In [4]:
# Mapping used when the index is created.
INDEX_SETTINGS = {
    "mappings": {
        "properties": {
            # Passage text: analyzed with the built-in "english" analyzer,
            # with term vectors enabled.
            "passage": {
                "type": "text",
                "term_vector": "yes",
                "analyzer": "english",
            },
            # Source collection label written by the indexing function
            # ("msmarco" or "car").
            "origin": {"type": "text"},
        },
    },
}

In [5]:
# Name of the Elasticsearch index that the cells below (re)create, fill and query.
INDEX_NAME = "cast_d2q"

In [6]:
# Raise the threshold of the 'elasticsearch' logger to WARNING so that
# lower-severity records from the client library are not shown in the notebook.
es_logger = logging.getLogger('elasticsearch')
es_logger.setLevel(logging.WARNING)

# Client created with no arguments, i.e. the library's default connection settings.
es = Elasticsearch()

In [7]:
# Start from an empty index: drop any existing index with this name, then
# create it again with the mapping in INDEX_SETTINGS.
# WARNING: re-running this cell deletes everything previously indexed under INDEX_NAME.
# The index is passed by keyword in all three calls (the positional form is not
# accepted by clients whose API parameters are keyword-only).
if es.indices.exists(index=INDEX_NAME):
    es.indices.delete(index=INDEX_NAME)
es.indices.create(index=INDEX_NAME, body=INDEX_SETTINGS)



{'acknowledged': True, 'shards_acknowledged': True, 'index': 'cast_doct5query'}

In [13]:
def index_msmarco_doct5query(filepath: str, es: Elasticsearch, index: str, car: bool = False) -> str:
    """Index passages from a TSV file, expanded with docT5query predictions.

    Each line of ``filepath`` is expected to be ``<id>\\t<passage>``. For every
    passage, queries are sampled from the
    ``castorini/doc2query-t5-base-msmarco`` model and appended to the passage
    text (space-separated); the expanded text is bulk-indexed with ``origin``
    set to ``"msmarco"``.

    When ``car`` is true, the paragraphs read from
    ``../data/dedup.articles-paragraphs.cbor`` are indexed afterwards, without
    expansion, with ``origin`` set to ``"car"``.

    Queries are drawn with top-k sampling (``do_sample=True``), so the
    expansions are not deterministic across runs.

    Args:
        filepath: Path to the UTF-8 encoded, tab-separated passage file.
        es: Elasticsearch client used for the bulk requests.
        index: Name of the target index.
        car: Whether to also index the TREC CAR paragraph file.

    Returns:
        The string ``"finished"``.
    """
    batch_size = 2000  # documents per bulk request
    num_queries = 3    # sampled queries appended to each passage

    # Init docT5query
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('castorini/doc2query-t5-base-msmarco')
    model = T5ForConditionalGeneration.from_pretrained('castorini/doc2query-t5-base-msmarco')
    model.to(device)

    bulk_data = []

    def flush() -> None:
        """Send the pending bulk actions, if any, and empty the buffer."""
        # An empty bulk body is not a valid request, so only send when there
        # is something to index.
        if bulk_data:
            es.bulk(index=index, body=bulk_data, refresh=True)
            bulk_data.clear()

    # Index and process
    with open(filepath, "r", encoding="utf-8") as docs:
        for i, line in enumerate(docs):
            print("number: {}".format(i))
            fields = line.split("\t")
            if len(fields) < 2:
                # Blank lines are skipped silently; anything else without a
                # tab is reported instead of raising IndexError mid-run.
                if line.strip():
                    logging.warning("Skipping line %d of %s: expected '<id>\\t<passage>'", i, filepath)
                continue
            doc_id = fields[0]
            passage = fields[1].strip()

            # docT5query
            input_ids = tokenizer.encode(passage, return_tensors='pt').to(device)
            outputs = model.generate(
                input_ids=input_ids,
                max_length=64,
                do_sample=True,
                top_k=10,
                num_return_sequences=num_queries)
            queries = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

            # Separate the passage and every query with a space so the last
            # passage token is not fused with the first query token.
            expanded_passage = " ".join([passage] + queries)

            bulk_data.append({"index": {"_index": index, "_id": doc_id}})
            bulk_data.append({"passage": expanded_passage, "origin": "msmarco"})
            # Two entries (action + source) per document.
            if len(bulk_data) >= 2 * batch_size:
                flush()
    flush()

    if car:
        # NOTE(review): iter_paragraphs is not defined in this notebook as it
        # stands; presumably it comes from the commented-out
        # `trec_tools.trec_car.read_data` import in the first cell. Restore
        # that import before calling with car=True.
        with open('../data/dedup.articles-paragraphs.cbor', 'rb') as cbor_file:
            for para in iter_paragraphs(cbor_file):
                bulk_data.append({"index": {"_index": index, "_id": para.para_id}})
                bulk_data.append({"passage": para.get_text(), "origin": "car"})
                if len(bulk_data) >= 2 * batch_size:
                    flush()
        # Index the last, partially filled batch as well.
        flush()
    return "finished"

In [14]:
# Index the passages from the TSV file only; the CAR paragraph file is not indexed.
index_msmarco_doct5query(
    filepath="../data/passages1000.tsv",
    es=es,
    index=INDEX_NAME,
    car=False,
)

number: 0
number: 1
number: 2
number: 3
number: 4
number: 5
number: 6
number: 7
number: 8
number: 9
number: 10
number: 11
number: 12
number: 13
number: 14
number: 15
number: 16
number: 17
number: 18
number: 19
number: 20
number: 21
number: 22
number: 23
number: 24
number: 25
number: 26
number: 27


In [9]:
# Spot-check: fetch the document stored under id "8840824" from INDEX_NAME and display it.
es.get(index=INDEX_NAME, id="8840824")

{'_index': 'cast_base',
 '_type': '_doc',
 '_id': '8840824',
 '_version': 1,
 '_seq_no': 1,
 '_primary_term': 1,
 'found': True,
 '_source': {'passage': 'H 2 and C are the reducing agents(note: H 2 is oxidized, so CuO is an oxidizing agent) Oxidation-Reduction Reactions In all cases: â\x80¢ If something is oxidized, something must be reduced. â\x80¢ Redox reactions move e - .2Ag + (aq) + Cu(s) 2 Ag(s) + Cu 2+ (aq) â\x80¢ OOxxidation-reduction = redox.Here: â\x80¢ Cu changes to Cu 2+ .',
  'origin': 'msmarco'}}