In [4]:
import datasets
import json
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gc
import csv
import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial

# Few-shot prompt template for conversational query rewriting: one worked
# example (Example 1) followed by a slot for the conversation to rewrite.
# Placeholders: {ctx} -- presumably the prior turns, formatted like Example 1's
# "user: ... / system: ..." lines; {question} -- the latest user question.
# NOTE(review): several lines end with a trailing space; that whitespace is part
# of the prompt text the model sees, so do not strip it. The template is not yet
# referenced by the visible code.
prompt = """# Instruction: I will give you a conversation between a user and a system. You should rewrite the last question of the user into a self-contained query. 
# Example 1: 
# Context: 
user: Tell me about the benefit of Yoga? 
system: Increased flexibility, muscle strength. 
# Please rewrite the following user question: 
Does it help in reducing stress? 
# Re-written query: Does Yoga help in reducing stress? 
# Example 2: 
# Context: 
{ctx}
# Please rewrite the following user question: 
{question}
# Re-written query:
"""

def _process_passage_chunk(rows):
    passages_chunk = {}
    for row in rows:
        if not row:
            continue
        doc_id = row[0]
        passage_text = "\t".join(row[1:])
        passages_chunk[doc_id] = passage_text
    return passages_chunk

def _load_passages_multiprocess(passage_loc, n_workers=None, chunk_size=1000, max_rows=10):
    """Load a tab-separated passage collection into a ``{doc_id: text}`` dict.

    Rows are read in the main process and shipped in chunks of ``chunk_size``
    rows to a worker pool running ``_process_passage_chunk``. The per-chunk
    dicts are merged in submission order, so a repeated doc id keeps the value
    from its last occurrence.

    Args:
        passage_loc: path to the TSV file; a leading ``~`` is expanded.
        n_workers: pool size; defaults to ``cpu_count() - 1`` (at least 1).
        chunk_size: number of rows handed to a worker per task.
        max_rows: read at most this many rows from the start of the file.
            ``None`` reads the whole file; negative values read nothing.

    Returns:
        dict mapping the first column of each non-empty row to the remaining
        columns re-joined with tabs.
    """
    import os
    from itertools import islice

    if n_workers is None:
        n_workers = max(1, cpu_count() - 1)
    if max_rows is not None:
        # islice rejects negative stops; the original loop read nothing then.
        max_rows = max(0, max_rows)

    passages = {}
    # newline="" is what the csv module requires so that quoted fields
    # containing line breaks are parsed correctly.
    with open(os.path.expanduser(passage_loc), "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter='\t')
        # islice stops after exactly max_rows rows (None = no cap) without
        # pulling one extra row from the reader just to discard it.
        rows = islice(reader, max_rows)
        # NOTE(review): workers must be able to resolve _process_passage_chunk;
        # from a notebook this relies on the "fork" start method (Linux default).
        with Pool(processes=n_workers) as pool:
            results = []
            chunk = []
            for row in tqdm.tqdm(rows, desc="Loading passages", unit="row"):
                chunk.append(row)
                if len(chunk) >= chunk_size:
                    results.append(pool.apply_async(_process_passage_chunk, (chunk,)))
                    chunk = []
            if chunk:
                results.append(pool.apply_async(_process_passage_chunk, (chunk,)))

            # Collect inside the `with` block: Pool.__exit__ terminates workers.
            for r in tqdm.tqdm(results, desc="Merging chunks", unit="chunk"):
                passages.update(r.get())
    return passages

def _lookup_passage(passages, doc_id):
    """Return the passage text for ``doc_id``, or ``""`` if it is not loaded.

    Tries ``str(doc_id)`` first, then the integer-normalised form, so ids such
    as the float ``5498209.0`` or the string ``"005498209"`` still match the
    key ``"5498209"``.
    """
    if doc_id is None:
        return ""
    key = str(doc_id)
    if key not in passages:
        try:
            key = str(int(doc_id))
        except (TypeError, ValueError, OverflowError):
            pass  # not integer-like; keep the plain string form
    return passages.get(key, "")


def generate_rewrites(questions_loc, qrels_loc, passage_loc, model_name, output_name):
    """Preview the (question, gold passage) pairs intended for query rewriting.

    Work in progress: prints the first question and the passage of its first
    qrel entry, then stops. ``model_name`` and ``output_name`` are accepted for
    the planned generation step but are not used yet.

    Args:
        questions_loc: headerless TSV with two columns: row id, question text.
        qrels_loc: JSON file whose entries map doc ids to relevance, e.g.
            ``{"5498209": 1}`` per row; entries are aligned with the TSV rows by
            position. (Top-level shape presumably a dict keyed by row id;
            a JSON list is accepted too -- TODO confirm.)
        passage_loc: TSV passage collection; only the first 10,000 rows are loaded.
        model_name: model identifier for the future rewrite step (unused).
        output_name: output path for the future rewrite step (unused).
    """
    import os

    # questions file has two columns: ["0" (row id), question text]
    entries = pd.read_csv(questions_loc, sep='\t', header=None, names=["row_id", "question"])

    # Use the json module, not pd.read_json: pandas turns a dict-of-dicts into
    # a sparse frame (outer keys become columns), so each per-row dict would
    # contain every column and its first key would be the same for every row.
    with open(os.path.expanduser(qrels_loc), "r", encoding="utf-8") as f:
        qrels = json.load(f)
    qrels_items = list(qrels.items()) if isinstance(qrels, dict) else list(enumerate(qrels))

    print('loading')
    # load passages with multiprocessing to speed up large collections (only first 10k rows)
    passages = _load_passages_multiprocess(passage_loc, max_rows=10000)

    gc.collect()

    # TODO: build prompts from `prompt`, generate with `model_name`, and write
    # the rewrites to `output_name`.

    # Align each TSV row with qrels by position, not by key.
    for position, question in enumerate(entries["question"]):
        _, qrel_for_row = qrels_items[position]
        # qrel_for_row looks like {"5498209": 1}; an empty entry has no gold doc.
        doc_id = next(iter(qrel_for_row), None)
        passage = _lookup_passage(passages, doc_id)
        print("QUESTION:", question)
        print("PASSAGE:", passage)
        break  # preview only the first row for now

In [None]:
# Preview run on the TopiOCQA subset.
# NOTE(review): the output directory "topiocqa_topiocqa" differs from the
# "topiocqa_subset" directory used by the inputs -- confirm it is intended.
generate_rewrites(
    questions_loc='~/disco-conv-splade/DATA/topiocqa_subset/queries_rowid_train_all.tsv',
    qrels_loc='~/disco-conv-splade/DATA/topiocqa_subset/qrel_rowid_train.json',
    passage_loc='/home/scur1719/disco-conv-splade/DATA/topiocqa_subset/full_wiki_segments_topiocqa.tsv',
    model_name='google-t5/t5-base',
    output_name='~/disco-conv-splade/DATA/topiocqa_topiocqa/topiocqa_t5_rewrites.json',
)

loading
