In [17]:
import json
import lucene
import shelve
import random
import itertools
import jsonlines

from tqdm import tqdm
from glob import glob
from argparse import ArgumentParser

from tqdm import tqdm
from typing import List
from java.nio.file import Paths
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.index import Term, DirectoryReader, IndexReader
from org.apache.lucene.search import IndexSearcher, Query, BoostQuery, BooleanQuery, BooleanClause, TermQuery

In [2]:
def get_query(field: str, tokens: List[str]) -> Query:
    """Build a disjunctive query over a single index field.

    Args:
        field: Name of the index field the terms are looked up in.
        tokens: Terms to search for; each one becomes an optional (SHOULD) clause.

    Returns:
        A ``BooleanQuery`` with one ``TermQuery`` clause per token.
    """
    builder = BooleanQuery.Builder()
    for tok in tokens:
        term_clause = TermQuery(Term(field, tok))
        builder.add(term_clause, BooleanClause.Occur.SHOULD)
    return builder.build()


def get_combined_query(fields: List[str], tokens_list: List[List[str]],
                       weights: List[float], is_mandatory: List[bool]) -> Query:
    """Combine per-field disjunctive queries into one boosted boolean query.

    The four arguments are parallel sequences: position ``i`` of each describes
    one sub-query.

    Args:
        fields: Index field name of each sub-query.
        tokens_list: Terms searched in the corresponding field.
        weights: Boost applied to the corresponding sub-query.
        is_mandatory: Whether the corresponding sub-query is a MUST clause
            (``True``) or a SHOULD clause (``False``).

    Returns:
        A ``BooleanQuery`` holding one boosted clause per field.

    Raises:
        AssertionError: If the four sequences differ in length.
    """
    assert len(fields) == len(tokens_list) == len(weights) == len(is_mandatory)
    builder = BooleanQuery.Builder()
    for field_name, field_tokens, boost, mandatory in zip(fields, tokens_list, weights, is_mandatory):
        if mandatory:
            occur = BooleanClause.Occur.MUST
        else:
            occur = BooleanClause.Occur.SHOULD
        boosted_query = BoostQuery(get_query(field_name, field_tokens), boost)
        builder.add(boosted_query, occur)
    return builder.build()

# Index CGW texts (for debugging only)

In [20]:
# Copy every doc of every split into an on-disk shelve keyed by doc id,
# so that single docs can later be fetched by id without rescanning the JSONL files.
split_names = ['train', 'test', 'dev']
with shelve.open('../data/cgw/shelve/cgw.shelve') as db:
    for split_name in split_names:
        split_path = f'../data/cgw/jsonl/{split_name}.jsonl'
        with jsonlines.open(split_path) as reader:
            for doc in tqdm(reader):
                doc_id = doc['id']
                db[doc_id] = doc

1454306it [02:32, 9540.44it/s] 
127505it [00:13, 9251.69it/s]
163083it [00:17, 9526.34it/s]


# Query schemas

In [3]:
# Command-line args
# NOTE(review): presumably these mirror the arguments of a script version of this notebook — confirm.

# Directory of the Lucene index opened by the "Initialize lucene" cell
index_dir = '../data/cgw/preds_args/lucene_index'
# NOTE: this value is overwritten for each schema inside the retrieval loop below
schema_path = '../data/schemas/descrs/election.txt'
# Boosts passed to get_combined_query for the 'preds', 'args' and 'preds_args' sub-queries
p_w = 1.0
a_w = 1.0
pa_w = 5.0
# Whether each sub-query is a mandatory (MUST) clause rather than an optional (SHOULD) one
p_must = True
a_must = True
pa_must = False

In [4]:
# Initialize lucene and the JVM, then open a searcher over the index in `index_dir`
lucene.initVM()
index_path = Paths.get(index_dir)
directory = SimpleFSDirectory.open(index_path)
index_reader = DirectoryReader.open(directory)
searcher = IndexSearcher(index_reader)

In [5]:
# Sanity check: report how many docs the opened index holds
num_indexed_docs = searcher.getIndexReader().numDocs()
print(f'Index contains {num_indexed_docs} docs')

Index contains 1744894 docs


In [6]:
# Retrieve schema-related docs
#
# For each schema:
#   1. read its description file (one "<arg0> <arg1> <pred_idx>" triple per line),
#   2. query the index with the predicates, arguments and predicate-argument pairs,
#   3. record the ids of all matching docs ("relevant") and the leading part of the
#      returned ranking ("related").

all_relevant_docs_ids = set()
schemas_relevant_docs_ids = {}
schemas_related_docs_ids = {}

for schema_name in ['ce_040_arrest', 'ce_001_protest', 'ce_005_disease_outbreak',
                    'plane_crash', 'election']:
    schema_path = f'../data/schemas/descrs/{schema_name}.txt'
    # Read the schema
    combined_tokens = [set(), # preds
                       set(), # args
                       set(), # preds_args
                      ]
    with open(schema_path) as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. trailing newlines) instead of
                # failing the 3-way unpack below with a ValueError
                continue
            arg0, arg1, pred_idx = line.split()
            pred_idx = int(pred_idx)
            # pred_idx says which of the two tokens is the predicate (0 or 1)
            assert pred_idx in range(2)
            # Wikidata refs are ignored for now
            arg0 = arg0.split(':')[0]
            arg1 = arg1.split(':')[0]
            pred = [arg0, arg1][pred_idx]
            arg = [arg0, arg1][1 - pred_idx]
            combined_tokens[0].add(pred)
            combined_tokens[1].add(arg)
            combined_tokens[2].add(f'{arg0}_{arg1}_{pred_idx}')

    # Query the docs
    query = get_combined_query(['preds', 'args', 'preds_args'], combined_tokens,
                               [p_w, a_w, pa_w], [p_must, a_must, pa_must])
    num_relevant_docs = searcher.count(query)
    # Keep 80% of the ranking for small result sets (< 3000 hits), 50% otherwise
    num_related_docs = int(0.80 * num_relevant_docs if num_relevant_docs < 3000 else 0.5 * num_relevant_docs)
    print(f'Schema "{schema_name}" has {num_relevant_docs} relevant docs and {num_related_docs} related docs')
    if num_relevant_docs > 0:
        score_docs = searcher.search(query, num_relevant_docs).scoreDocs
    else:
        # Lucene rejects a top-N search with N <= 0, so skip the search when nothing matches
        score_docs = []
    relevant_docs_ids = [searcher.doc(d.doc).get('filename') for d in score_docs]
    assert all(x.endswith('.comm') for x in relevant_docs_ids)
    # Strip the '.comm' extension to obtain the doc id
    relevant_docs_ids = [x[:-5] for x in relevant_docs_ids]
    # Hits come back ranked, so the prefix holds the best-scoring docs
    related_docs_ids = relevant_docs_ids[:num_related_docs]
    schemas_relevant_docs_ids[schema_name] = relevant_docs_ids
    schemas_related_docs_ids[schema_name] = related_docs_ids
    all_relevant_docs_ids.update(relevant_docs_ids)

Schema "ce_040_arrest" has 4979 relevant docs and 2489 related docs
Schema "ce_001_protest" has 1876 relevant docs and 1500 related docs
Schema "ce_005_disease_outbreak" has 3965 relevant docs and 1982 related docs
Schema "plane_crash" has 1974 relevant docs and 1579 related docs
Schema "election" has 25521 relevant docs and 12760 related docs


In [9]:
%%time

# Sample schema-unrelated docs ids
# Retrieve schema-related and unrelated docs
#
# For each schema, draw as many negative docs as it has related docs. Negatives are
# drawn uniformly (seeded RNG) from the docs that are not relevant to ANY schema.

schemas_related_docs = {}
schemas_unrelated_docs = {}

rnd = random.Random(0)

with shelve.open('../data/cgw/shelve/cgw.shelve') as db:
    all_docs_ids = list(db.keys())
    # Number of distinct docs the rejection sampling below can ever accept
    num_candidate_docs = sum(1 for doc_id in all_docs_ids if doc_id not in all_relevant_docs_ids)

    for schema_name, related_docs_ids in schemas_related_docs_ids.items():
        if len(related_docs_ids) > num_candidate_docs:
            # Without this check the sampling loop below would never terminate
            raise ValueError(f'Schema "{schema_name}" needs {len(related_docs_ids)} unrelated docs '
                             f'but only {num_candidate_docs} docs are not relevant to any schema')

        # Sample schema-unrelated docs ids
        # A dict serves as an insertion-ordered set: iterating a set of strings depends on
        # the per-process hash seed, which would make the order of the negatives (and of
        # the files written from them) differ between runs despite the seeded RNG.
        unrelated_docs_ids = {}
        while len(unrelated_docs_ids) != len(related_docs_ids):
            sample_doc_id = rnd.choice(all_docs_ids)
            if sample_doc_id not in all_relevant_docs_ids:
                unrelated_docs_ids[sample_doc_id] = None

        # Retrieve schema-related and unrelated docs
        schemas_related_docs[schema_name] = [db[doc_id] for doc_id in related_docs_ids]
        schemas_unrelated_docs[schema_name] = [db[doc_id] for doc_id in unrelated_docs_ids]

CPU times: user 22.4 s, sys: 1.24 s, total: 23.7 s
Wall time: 28.6 s


In [10]:
def _write_docs(path, docs):
    """Write `docs` to `path` in JSON Lines format, replacing any existing file."""
    with jsonlines.open(path, 'w') as writer:
        writer.write_all(docs)


# Write schema-related docs (positives), one file per schema
for schema_name, docs in schemas_related_docs.items():
    _write_docs(f'../data/cgw/schema_related/pos/{schema_name}.jsonl', docs)

# Write schema-unrelated docs (negatives), one file per schema
for schema_name, docs in schemas_unrelated_docs.items():
    _write_docs(f'../data/cgw/schema_related/neg/{schema_name}.jsonl', docs)

In [29]:
# Write schema-related docs ids
# Ids of all schemas are flattened into one list (schema order, then doc order)
# and written one per line, with no trailing newline.

all_schema_related_docs_ids = [
    doc['id']
    for docs in schemas_related_docs.values()
    for doc in docs
]
with open(f'../data/cgw/schema_related/pos/docs_ids.txt', 'w') as fout:
    related_ids_text = '\n'.join(all_schema_related_docs_ids)
    fout.write(related_ids_text)

# Write schema-unrelated docs ids
all_schema_unrelated_docs_ids = [
    doc['id']
    for docs in schemas_unrelated_docs.values()
    for doc in docs
]
with open(f'../data/cgw/schema_related/neg/docs_ids.txt', 'w') as fout:
    unrelated_ids_text = '\n'.join(all_schema_unrelated_docs_ids)
    fout.write(unrelated_ids_text)

---