In [None]:
import json
import math
import multiprocessing
import os
import random
import re
import shutil
from collections import defaultdict

import pandas as pd
from fuzzywuzzy import fuzz
from nltk.corpus import stopwords
from tqdm import tqdm
from transformers import AutoTokenizer

from seal import FMIndex, SEALSearcher
from seal.evaluate import evaluator

### Get Corpus

In [None]:
# Build NQ_320k/corpus.json: one {'docid', 'url', 'title', 'body'} record per document, where
# body is the first paragraph plus the tag-stripped page text, truncated to MAX_BODY_WORDS words.
MAX_BODY_WORDS = 128
_TITLE_PATTERN = re.compile(r'title=.*&amp')
_TAG_PATTERN = re.compile(r'<[^<]+?>')


def _title_from_url(url):
    """Return the page title stored in ``url``'s ``title=...&amp`` parameter, underscores as spaces.

    Raises:
        ValueError: if ``url`` has no ``title=...&amp`` parameter.
    """
    match = _TITLE_PATTERN.search(url)
    if match is None:
        raise ValueError(f"no 'title=' parameter in url: {url!r}")
    # Drop the leading "title=" (6 chars) and the trailing "&amp" (4 chars).
    return ' '.join(match.group(0)[6:-4].split("_"))


def _clean_html(fragment):
    """Remove HTML tags and newlines from ``fragment`` and collapse runs of spaces."""
    text = _TAG_PATTERN.sub('', fragment).replace('\n', '').strip()
    return re.sub(' +', ' ', text)


def _split_document(document_text):
    """Return ``(abstract, content)`` for one HTML page.

    ``abstract`` is the raw text between the first ``<P>`` and the ``</P>`` that follows it
    ('' when there is no such paragraph). ``content`` is the tag-stripped text after that
    paragraph (or from the top of the page when there is none), cut at the second-to-last
    ``</Ul>`` (or at the last one when only one exists).
    """
    abs_start = document_text.find('<P>')
    abs_end = document_text.find('</P>', abs_start) if abs_start != -1 else -1
    if abs_end != -1:
        abstract = document_text[abs_start + 3:abs_end]
        content_start = abs_end + 4
    else:
        # No first paragraph: the content starts at the top of the page.
        abstract = ''
        content_start = 0
    # NOTE(review): the trailing lists are presumably navigation/reference boilerplate — confirm.
    for _ in range(2):
        cut = document_text.rfind('</Ul>')
        if cut == -1:
            break
        document_text = document_text[:cut]
    return abstract, _clean_html(document_text[content_start:])


docmap = {}
with open("../../../data/nq-data/nq-docs-sents.top.320k.json", "r", encoding="utf-8") as fin:
    for raw in tqdm(fin):
        record = json.loads(raw)
        docid, url = record['docid'], record['url']
        title = _title_from_url(url)
        abstract, content = _split_document(record['body'])
        body = " ".join((abstract + " " + content).split()[:MAX_BODY_WORDS])
        docmap[docid] = {'docid': docid, 'url': url, 'title': title, 'body': body}

os.makedirs("NQ_320k", exist_ok=True)
with open("NQ_320k/corpus.json", "w") as fout:
    for doc in docmap.values():
        fout.write(json.dumps(doc) + "\n")

### Get Training/Dev Set

In [None]:
def _load_queries(path):
    """Return ``{qid: query}`` from a TSV file whose first two columns are qid and query text.

    Blank lines are skipped.
    """
    queries = {}
    with open(path, 'r') as f:
        for row in tqdm(f):
            if not row.strip():
                continue
            fields = row.split("\t")
            queries[fields[0]] = fields[1].strip()
    return queries


def _load_qrels(path, querymap):
    """Return ``[(qid, query, docid), ...]`` for every line of a 4-column qrels TSV file.

    Only the first (qid) and third (docid) columns are used; ``querymap`` supplies the query
    text. Blank lines are skipped; an unknown qid raises KeyError.
    """
    judged = []
    with open(path, "r") as f:
        for row in tqdm(f):
            if not row.strip():
                continue
            qid, _, docid, _ = row.strip().split("\t")
            judged.append((qid, querymap[qid], docid))
    return judged


def _with_document(query, docid):
    """Build a training/dev record pairing ``query`` with the title and body of ``docid``."""
    doc = docmap[docid]
    return {'query': query, 'docid': docid, 'title': doc['title'], 'body': doc['body']}


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as one JSON object per line."""
    with open(path, "w") as f:
        for record in tqdm(records):
            f.write(json.dumps(record) + "\n")


docmap = {}
with open("NQ_320k/corpus.json", "r") as fin:
    for raw in tqdm(fin):
        doc = json.loads(raw)
        docmap[doc['docid']] = doc

print("load corpus succees")

# dev: dev4retrieval.json keeps (qid, docid, query) for evaluation; dev.json adds the document text.
dev_querymap = _load_queries('../../../data/nq-data/nq-docdev-queries.tsv')
dev_judged = _load_qrels("../../../data/nq-data/nq-docdev-qrels.tsv", dev_querymap)
devset = [{'qid': qid, 'docid': docid, 'query': query} for qid, query, docid in dev_judged]
dev = [_with_document(query, docid) for _, query, docid in dev_judged]
_write_jsonl("NQ_320k/dev4retrieval.json", devset)
_write_jsonl("NQ_320k/dev.json", dev)

print("save dev set succees")

# train
train_querymap = _load_queries('../../../data/nq-data/nq-doctrain-queries.tsv')
train_judged = _load_qrels("../../../data/nq-data/nq-doctrain-qrels.tsv", train_querymap)
train = [_with_document(query, docid) for _, query, docid in train_judged]
_write_jsonl("NQ_320k/train.json", train)

print("save train set succees")

In [None]:
# Worker processes used by iterator_span's multiprocessing pool.
jobs = 50
# Spans sampled per (query, document) pair, and the inclusive range of span lengths in words.
n_samples = 10
min_length, max_length = 10, 10
# Softmax temperature applied to the span-match scores in extract_spans.
temperature = 1.0
# English stopwords; span_iterator skips passage positions holding one of these.
banned = set(stopwords.words('english'))

def iterator_title(train_or_dev):
    """Yield one ``(source, target)`` pair per example of ``NQ_320k/{train_or_dev}.json``.

    The source is ``"<query> || title || +"`` and the target is ``"<title> @@"``.
    """
    path = f"NQ_320k/{train_or_dev}.json"
    with open(path, "r") as fin:
        for raw in tqdm(fin):
            record = json.loads(raw)
            query = record['query'].strip()
            title = record['title'].strip()
            yield query + " || title || +", title + " @@"

def _iterator_span_get_arguments(train_or_dev):
    """Yield ``(body, source)`` for every example of ``NQ_320k/{train_or_dev}.json``.

    ``source`` is the stripped query followed by ``" || body || +"``.
    """
    path = f"NQ_320k/{train_or_dev}.json"
    with open(path, "r") as fin:
        for raw in tqdm(fin):
            record = json.loads(raw)
            body = record['body'].strip()
            source = record['query'].strip() + " || body || +"
            yield body, source

def span_iterator(tokens, ngrams=3, banned=banned):
    """Yield ``(start, start + ngrams)`` for every position whose token is not in ``banned``.

    The end index may run past ``len(tokens)``; callers slice, so that is harmless.
    """
    yield from (
        (start, start + ngrams)
        for start, token in enumerate(tokens)
        if token not in banned
    )

def extract_spans(text, source, n_samples, min_length, max_length, temperature=1.0):
    """Sample spans of ``text`` that fuzzily match the query part of ``source``.

    The query is everything in ``source`` before the first ``"||"``. Every passage position
    accepted by ``span_iterator`` is scored by summing ``fuzz.ratio / 100`` between the
    lower-cased passage trigram starting there and each lower-cased query trigram (shorter at the
    end of the query). ``n_samples`` start positions are drawn with softmax(score / temperature)
    probabilities. Each drawn start yields a span of ``random.randint(min_length, max_length)``
    original-case words.

    Degenerate cases:
    - An empty passage yields nothing.
    - If no position is scored, a single span from position 0 is yielded.
    - If every score is zero, all ``n_samples`` spans start at position 0.
    """
    query = source.split("||", 1)[0]
    query_tokens_lower = [t.lower() for t in query.split()]
    passage_tokens = text.split()
    if not passage_tokens:
        # Nothing to sample from; yielding "" would produce empty training targets.
        return
    passage_tokens_lower = [t.lower() for t in passage_tokens]

    matches = defaultdict(int)
    for i1 in range(len(query_tokens_lower)):
        str_1 = " ".join(query_tokens_lower[i1:i1 + 3])
        for i2, j2 in span_iterator(passage_tokens_lower, 3):
            str_2 = " ".join(passage_tokens_lower[i2:j2])
            matches[i2] += fuzz.ratio(str_1, str_2) / 100.0

    if not matches:
        indices = [0]
    else:
        indices, scores = zip(*sorted(matches.items(), key=lambda x: -(x[1])))
        if float(sum(scores)) == 0.0:
            indices, weights = [0], [1.0]
        else:
            # Subtract the max before exponentiating so small temperatures cannot overflow
            # math.exp; random.choices normalises the weights itself.
            top = max(scores)
            weights = [math.exp((float(s) - top) / temperature) for s in scores]
        indices = random.choices(indices, weights=weights, k=n_samples)

    for i in indices:
        subspan_size = random.randint(min_length, max_length)
        yield " ".join(passage_tokens[i:i + subspan_size])

def extract_spans_wrapper(args):
    """Single-argument adapter for ``pool.imap``.

    ``args`` is the positional tuple passed to ``extract_spans``. The function returns
    ``(args[1], spans)``, where ``spans`` holds the sampled spans materialised as a list.
    """
    source = args[1]
    spans = [span for span in extract_spans(*args)]
    return source, spans

def iterator_span(train_or_dev):
    """Yield ``(source, span)`` pairs for a split, sampling spans in a pool of ``jobs`` processes.

    Results arrive in input order (``imap``). Each source is repeated once per span sampled for it.
    """
    tasks = (
        (text, source, n_samples, min_length, max_length, temperature)
        for text, source in _iterator_span_get_arguments(train_or_dev)
    )
    with multiprocessing.Pool(jobs) as pool:
        for source, spans in pool.imap(extract_spans_wrapper, tasks):
            yield from ((source, target) for target in spans)

def gen_title_ex(train_or_dev):
    """Recreate ``NQ_320k/{train_or_dev}.source/.target`` with one query -> title example per line.

    Both files are opened in 'w' mode, so this truncates them. Every written line starts with a
    space.
    """
    prefix = f"NQ_320k/{train_or_dev}"
    with open(prefix + ".source", 'w') as src, open(prefix + ".target", 'w') as tgt:
        for source, target in iterator_title(train_or_dev):
            src.write(f" {source.strip()}\n")
            tgt.write(f" {target.strip()}\n")
        
def get_span_ex(train_or_dev):
    """Append query -> body-span examples to ``NQ_320k/{train_or_dev}.source/.target``.

    Both files are opened in 'a' mode, so running this twice duplicates the span examples.
    Every written line starts with a space.
    """
    prefix = f"NQ_320k/{train_or_dev}"
    with open(prefix + ".source", 'a') as src, open(prefix + ".target", 'a') as tgt:
        for source, target in iterator_span(train_or_dev):
            src.write(f" {source.strip()}\n")
            tgt.write(f" {target.strip()}\n")

In [None]:
# For each split, title examples (re)create the .source/.target files and span examples are appended.
for split_name in ("dev", "train"):
    gen_title_ex(split_name)
    get_span_ex(split_name)

In [None]:
# Report how many examples were written for each split.
for split_name in ("train", "dev"):
    os.system(f"wc -l NQ_320k/{split_name}.source")

In [None]:
# Tokens that may not start a sampled target span in the unsupervised examples (used by is_good).
# The set has its own name so that re-running the earlier cell, which binds `banned` to the NLTK
# stopword set, cannot silently change is_good. `banned` is kept as an alias.
UNSUPERVISED_BANNED = {
    "the", "The",
    "to",
    "a", "A", "an", "An",
    "he", "He", "his", "His", "him", "he's", "He's",
    "she", "She", "her", "Her", "she's", "She's",
    "it", "It", "its", "Its", "it's", "It's",
    "and", "And",
    "or", "Or",
    "this", "This",
    "that", "That",
    "those", "Those",
    "these", "These",
    '"', '""', "'", "''",
}
banned = UNSUPERVISED_BANNED


def is_good(token):
    """Return True if ``token`` may start a sampled span.

    Rejects:
    - empty tokens;
    - tokens in UNSUPERVISED_BANNED;
    - tokens ending in ``?``, ``.`` or ``!``;
    - tokens starting with ``(`` or ``[``.
    """
    if not token or token in UNSUPERVISED_BANNED:
        return False
    return not (token.endswith(('?', '.', '!')) or token.startswith(('(', '[')))

def _draw_pairs(sampler, wanted, *sampler_args, max_failures=10):
    """Yield up to ``wanted`` pairs from ``sampler(*sampler_args)``.

    The sampler returns None for a rejected draw. Sampling stops after ``max_failures`` rejections.
    """
    sampled = failures = 0
    while sampled < wanted and failures < max_failures:
        pair = sampler(*sampler_args)
        if pair is None:
            failures += 1
            continue
        yield pair
        sampled += 1


def _sample_title_pair(tokens, title, title_target, min_length_input, max_length_input,
                       min_length_output, max_length_output):
    """Draw one title pair: body span -> title, or title -> body span (each with probability ~1/2).

    Returns None when the body span would start at or past ``len(tokens)``. It also returns None
    when the span would start on a token rejected by is_good.
    """
    if random.random() > 0.5:
        len_a = random.randint(min_length_input, max_length_input)
        idx_a = random.randint(0, max(0, len(tokens) - len_a))
        source = ' '.join(tokens[idx_a:idx_a + len_a]).strip() + " || title || p"
        return source, title_target
    len_b = random.randint(min_length_output, max_length_output)
    idx_b = random.randint(0, max(0, len(tokens) - len_b))
    if idx_b >= len(tokens) or not is_good(tokens[idx_b]):
        return None
    return title.strip() + ' || body || p', ' '.join(tokens[idx_b:idx_b + len_b]).strip()


def _sample_body_pair(tokens, min_length_input, max_length_input, min_length_output, max_length_output):
    """Draw one body span -> body span pair, or None when rejected.

    A draw is rejected when both spans start at the same position. It is also rejected when the
    target span starts on a token rejected by is_good.
    """
    len_a = random.randint(min_length_input, max_length_input)
    len_b = random.randint(min_length_output, max_length_output)
    idx_a = random.randint(0, max(0, len(tokens) - len_a))
    idx_b = random.randint(0, max(0, len(tokens) - len_b))
    if idx_a == idx_b or not is_good(tokens[idx_b]):
        return None
    source = ' '.join(tokens[idx_a:idx_a + len_a]).strip() + ' || body || p'
    return source, ' '.join(tokens[idx_b:idx_b + len_b]).strip()


def preprocess_file(
    input_path,
    num_samples=1,
    num_title_samples=1,
    delimiter='@@',
    min_length_input=1,
    max_length_input=15,
    min_length_output=10,
    max_length_output=10,
    full_doc_n=0,
    ):
    """Yield unsupervised ``(source, target)`` pairs from a JSON-lines file of ``title``/``body`` records.

    For each document the generator yields, in order:
      * ``full_doc_n`` copies of ``"<body> || title || p"`` -> ``"<title> <delimiter>"``;
      * up to ``num_title_samples`` pairs mapping a random body span to the title, or the title
        (``"<title> || body || p"``) to a random body span;
      * up to ``num_samples`` pairs mapping one random body span (``"... || body || p"``) to
        another.
    Target body spans never start on a token rejected by is_good. Each sampling loop gives up
    after 10 rejected draws. Documents whose body equals the title, or whose body has no tokens,
    are skipped.
    """
    lengths = (min_length_input, max_length_input, min_length_output, max_length_output)
    with open(input_path, 'r', 2 ** 20) as f:
        for line in tqdm(f):
            doc = json.loads(line)
            text = doc['body']
            title = doc['title']
            if text == title:
                continue

            tokens = text.split()
            if not tokens:
                # An empty body would only produce examples with empty sources.
                continue

            title_target = title.strip() + " " + delimiter
            for _ in range(full_doc_n):
                yield text.strip() + " || title || p", title_target

            yield from _draw_pairs(_sample_title_pair, num_title_samples,
                                   tokens, title, title_target, *lengths)
            yield from _draw_pairs(_sample_body_pair, num_samples, tokens, *lengths)


In [None]:

# Sample unsupervised (source, target) pairs from the training documents' own text. Roughly 10%
# of sources are lower-cased. Each side is written with a leading space.
with open("NQ_320k/unsupervised.source", 'w', 2 ** 20) as src, open("NQ_320k/unsupervised.target", 'w', 2 ** 20) as tgt:
    for s, t in preprocess_file(
        "NQ_320k/train.json",
        num_samples=3,
        num_title_samples=1,
        full_doc_n=1,
        delimiter="@@",
        min_length_input=10,
        max_length_input=10,
        min_length_output=10,
        max_length_output=10,
    ):
        if random.random() < 0.1:
            s = s.lower()
        src.write(" " + s + '\n')
        tgt.write(" " + t + '\n')

# Append the unsupervised pairs to the training files. Doing this in Python makes a failed
# append raise instead of being ignored.
# NOTE: not idempotent. Re-running this cell without first regenerating
# train.source/train.target appends the pairs a second time.
for suffix in ("source", "target"):
    with open(f"NQ_320k/unsupervised.{suffix}", "rb") as extra, \
            open(f"NQ_320k/train.{suffix}", "ab") as merged:
        shutil.copyfileobj(extra, merged)

In [None]:
# Report the size of the unsupervised set and of the merged training set.
for counted_path in ("NQ_320k/unsupervised.source", "NQ_320k/train.source"):
    os.system(f"wc -l {counted_path}")

In [None]:
# Run in a shell (not in this kernel): sh ./scripts/training/preprocess_fairseq_base.sh NQ_320k BART

### Building FM-Index

In [None]:
# Each document is indexed as "<title> @@ <body>"; labels[i] is the docid of corpus[i].
with open("NQ_320k/corpus.json", "r") as fin:
    records = [json.loads(raw) for raw in tqdm(fin)]
labels = [record['docid'] for record in records]
corpus = [record['title'] + " @@ " + record['body'] for record in records]

tokenizer = AutoTokenizer.from_pretrained('../../../transformers_models/bart-base')


def preprocess(doc, max_length=767):
    """Convert one corpus string into token ids for the FM-index.

    A space is prepended before tokenising (presumably so the first word is encoded like a
    mid-sentence word by the BPE; TODO confirm). The ids are truncated to ``max_length`` without
    special tokens, and the tokenizer's EOS id is appended. Use 767 for BART-base and 1023 for
    BART-large.
    """
    ids = tokenizer(' ' + doc, add_special_tokens=False, truncation=True, max_length=max_length)['input_ids']
    return ids + [tokenizer.eos_token_id]


corpus_tokenized = [preprocess(doc) for doc in tqdm(corpus)]

# One {"label": docid, "doc": token ids} record per line, in corpus order.
with open("NQ_320k/tokenized_corpus.json", "w") as fout:
    for label, token_ids in tqdm(zip(labels, corpus_tokenized)):
        fout.write(json.dumps({'label': label, 'doc': token_ids}) + "\n")

In [None]:
# Reload the tokenized corpus so this cell does not depend on the previous one having run.
corpus_tokenized, labels = [], []
with open("NQ_320k/tokenized_corpus.json", "r") as fin:
    for raw in tqdm(fin):
        record = json.loads(raw)
        labels.append(record['label'])
        corpus_tokenized.append(record['doc'])

index = FMIndex()
index.initialize(corpus_tokenized, in_memory=True)
index.labels = labels

fm_index_path = 'NQ_320k/FM_Index/NQ_320k.base.fm_index'
# Make sure the target directory exists before saving.
os.makedirs(os.path.dirname(fm_index_path), exist_ok=True)
index.save(fm_index_path)

### Training

In [None]:
# Run in a shell (not in this kernel): sh ./scripts/training/training_fairseq_base.sh NQ_320k BART

### Evaluate

In [None]:
# Number of documents retrieved per dev query.
# NOTE(review): the @20/@100 metrics printed below can only see these top_k results, so with
# top_k = 10 they are capped by the @10 ranking — raise top_k if those numbers matter.
top_k = 10

searcher = SEALSearcher.load(
    "NQ_320k/FM_Index/NQ_320k.base.fm_index",
    "checkpoints/checkpoint_best.pt",
    device='cuda:0',
)
searcher.include_keys = True
myevaluator = evaluator()

# truth[i] holds the gold docid(s) of query i; result[i] the retrieved docids in rank order.
truth = []
result = []
with open("NQ_320k/dev4retrieval.json", "r") as f:
    for raw in tqdm(f):
        sample = json.loads(raw)
        truth.append([sample['docid']])
        result.append([doc.docid for doc in searcher.search(sample['query'], k=top_k)])

res = myevaluator.evaluate_ranking(truth, result)
print(f"mrr@5:{res['mrr5']}, mrr@10:{res['mrr10']}, mrr:{res['mrr']}, p@1:{res['p1']}, p@10:{res['p10']}, p@20:{res['p20']}, p@100:{res['p100']}, r@1:{res['r1']}, r@5:{res['r5']}, r@10:{res['r10']}, r@100:{res['r100']}")