In [None]:
from jax import numpy as jnp
import jax
import transformers
from transformers import TrainingArguments, Trainer

import data
import modeling_bart
import arguments
import datasets

import os

In [None]:
transformers.__version__

In [None]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

In [None]:
TRAIN_DIR = '/home/arthur/Workplace/URA/jimmy/mores_plus/training'
TRAIN_DATA = '/home/arthur/Workplace/URA/jimmy/mores_plus/training/train_dataset'
DEV_DATA = '/home/arthur/Workplace/URA/jimmy/mores_plus/training/dev_dataset'
RANK_SCORE_PATH = '/home/arthur/Workplace/URA/jimmy/mores_plus/training'

PATH_TO_TSV = '/home/arthur/Workplace/URA/jimmy/mores_plus/training/msmarco-docs.tsv'

data_args = arguments.DataArguments(train_dir=TRAIN_DIR,train_path=TRAIN_DATA,dev_path=DEV_DATA,rank_score_path=RANK_SCORE_PATH)
reranker_args = arguments.RerankerTrainingArguments(output_dir=os.path.join(TRAIN_DIR,'output'))

In [None]:
config = transformers.BartConfig()
tokenizer = transformers.BartTokenizer.from_pretrained("facebook/bart-base")
model = modeling_bart.FlaxBartMoresRanker(config=config)

In [None]:
train_dataset = data.GroupedTrainDataset(args=data_args,path_to_tsv=PATH_TO_TSV,tokenizer=tokenizer,train_args=reranker_args)

In [None]:
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset
    # eval_dataset=small_eval_dataset,
    # compute_metrics=compute_metrics,
)

In [None]:
datasets.load_dataset(data_files='training/msmarco-docs.tsv')

In [1]:
from transformers import AutoTokenizer
import json
import os
from collections import defaultdict
import datasets
import random
from tqdm import tqdm

args = {'tokenizer_name':'facebook/bart-base', 'rank_file':'training/run.msmarco-passage.bm25.train.tsv', 'json_dir':'training/json_files',
'n_sample':10,'sample_from_top':100,'qrel':'training/document/msmarco-doctrain-qrels.tsv.gz','query_collection': 'training/document/msmarco-doctrain-queries.tsv',
'doc_collection':'training/document/msmarco-docs.tsv'}

def read_qrel():
    import gzip, csv
    qrel = {}
    with gzip.open(args['qrel'], 'rt', encoding='utf8') as f:
        tsvreader = csv.reader(f, delimiter=" ")
        for [topicid, _, docid, rel] in tsvreader:
            assert rel == "1"
            if topicid in qrel:
                qrel[topicid].append(docid)
            else:
                qrel[topicid] = [docid]
    return qrel


qrel = read_qrel()
rankings = defaultdict(list)
no_judge = set()
with open(args['rank_file']) as f:
    for l in f:
        qid, pid, rank = l.split()
        if qid not in qrel:
            no_judge.add(qid)
            continue
        if pid in qrel[qid]:
            continue
        # append passage if & only if it is not juddged relevant but ranks high
        rankings[qid].append(pid)

print(f'{len(no_judge)} queries not judged and skipped', flush=True)

columns = ['did', 'url', 'title', 'body']
collection = args['doc_collection']
collection = datasets.load_dataset(
    'csv',
    data_files=collection,
    column_names=['did', 'url', 'title', 'body'],
    delimiter='\t',
    ignore_verifications=True,
)['train']
qry_collection = args['query_collection']
qry_collection = datasets.load_dataset(
    'csv',
    data_files=qry_collection,
    column_names=['qid', 'qry'],
    delimiter='\t',
    ignore_verifications=True,
)['train']

doc_map = {x['did']: idx for idx, x in enumerate(collection)}
qry_map = {str(x['qid']): idx for idx, x in enumerate(qry_collection)}

tokenizer = AutoTokenizer.from_pretrained(args['tokenizer_name'], use_fast=True)

out_file = args['rank_file']
if out_file.endswith('.tsv') or out_file.endswith('.txt'):
    out_file = out_file[:-4]
out_file = os.path.join(args['json_dir'], os.path.split(out_file)[1])
out_file = out_file + '.group.json'

queries = list(rankings.keys())

  from .autonotebook import tqdm as notebook_tqdm


441704 queries not judged and skipped


Using custom data configuration default-862e81ca74ac2f89
Found cached dataset csv (/home/arthur/.cache/huggingface/datasets/csv/default-862e81ca74ac2f89/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)
100%|██████████| 1/1 [00:00<00:00, 20.35it/s]
Using custom data configuration default-dab58c8e68de5015
Found cached dataset csv (/home/arthur/.cache/huggingface/datasets/csv/default-dab58c8e68de5015/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)
100%|██████████| 1/1 [00:00<00:00, 1184.83it/s]


In [30]:
with open(out_file, 'w') as f:
    for qid in tqdm(queries):
        # pick from top of the full initial ranking
        negs = rankings[qid][:args['sample_from_top']]
        # shuffle if random flag is on
        random.shuffle(negs)
        # pick n samples
        negs = negs[:args['n_sample']]

        neg_encoded = []
        for neg in negs:
            idx = doc_map['D'+neg]
            print(idx)
            item = collection[idx]
            did, url, title, body = (item[k] for k in columns)
            url, title, body = map(lambda v: v if v else '', [url, title, body])
            encoded_neg = tokenizer.encode(
                url + tokenizer.sep_token + title + tokenizer.sep_token + body,
                add_special_tokens=False,
                max_length=args.truncate,
                truncation=True
            )
            neg_encoded.append({
                'passage': encoded_neg,
                'pid': neg,
            })
        pos_encoded = []
        for pos in qrel[qid]:
            idx = doc_map[pos]
            item = collection[idx]
            did, url, title, body = (item[k] for k in columns)
            url, title, body = map(lambda v: v if v else '', [url, title, body])
            encoded_pos = tokenizer.encode(
                url + tokenizer.sep_token + title + tokenizer.sep_token + body,
                add_special_tokens=False,
                max_length=args.truncate,
                truncation=True
            )
            pos_encoded.append({
                'passage': encoded_pos,
                'pid': pos,
            })
        q_idx = qry_map[qid]
        query_dict = {
            'qid': qid,
            'query': tokenizer.encode(
                qry_collection[q_idx]['qry'],
                add_special_tokens=False,
                max_length=args.truncate,
                truncation=True),
        }
        item_set = {
            'qry': query_dict,
            'pos': pos_encoded,
            'neg': neg_encoded,
        }
        f.write(json.dumps(item_set) + '\n')

  0%|          | 0/367007 [00:00<?, ?it/s]


KeyError: 'D7282917'

3213834