In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, sys
import pandas as pd
import re
import numpy as np
import ujson as json
import jsonlines
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from IPython.core.display import display, HTML, Markdown
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.symbols.type_symbols import TypeSymbols
from bootleg.symbols.kg_symbols import KGSymbols
from nltk.stem import PorterStemmer
# Shared Porter stemmer instance; later cells use it to stem the sport keyword
# list and the tokens of each sentence so both sides are compared in stemmed form.
ps = PorterStemmer()
def printmd(string):
    """Render ``string`` as Markdown in the notebook output area."""
    rendered = Markdown(string)
    display(rendered)
# Register tqdm's pandas integration (progress_apply and friends).
tqdm.pandas()
# Widen the notebook container so wide frames are readable.
display(HTML("<style>.container { width:90% !important; }</style>"))
# pandas display limits: long cell text, every column, and up to 5000 rows.
pd.set_option('display.max_colwidth', 500)
pd.options.display.max_columns = None
pd.options.display.max_rows = 5000

In [None]:
# Dataset root and the directory the re-sliced files are meant to be written to.
# NOTE(review): these are absolute, machine-specific paths; adjust for your environment.
input_dir = Path("/dfs/scratch0/lorr1/projects/bootleg-data/data/korealiases_title_0122")
output_dir = input_dir / "resliced"
# exist_ok=True replaces the check-then-create pattern and is safe to re-run.
output_dir.mkdir(parents=True, exist_ok=True)

# Use context managers so the JSON files are closed immediately rather than
# leaving open handles for the garbage collector.
with open(input_dir / "entity_db/entity_mappings/alias2qids.json") as a2q_file:
    a2q = json.load(a2q_file)
entity_dump = EntitySymbols(load_dir=input_dir / "entity_db/entity_mappings")

# Type and knowledge-graph symbol tables, all loaded from the shared embedding directory.
emb_dir = Path('/dfs/scratch0/lorr1/projects/bootleg-data/embs')
types_hy = TypeSymbols(entity_dump, emb_dir, max_types=3, type_vocab_file="hyena_vocab.json", type_file="hyena_types_1229.json")
types_wd = TypeSymbols(entity_dump, emb_dir, max_types=3, type_vocab_file="wikidatatitle_to_typeid_1229.json", type_file="wikidata_types_1229.json")
types_rel = TypeSymbols(entity_dump, emb_dir, max_types=50, type_vocab_file="relation_to_typeid_1229.json", type_file="kg_relation_types_1229.json")
kg_syms = KGSymbols(entity_dump, emb_dir, "kg_adj_1229.txt")

with open(input_dir / "entity_db/entity_mappings/qid2title.json") as q2title_file:
    q2title = json.load(q2title_file)
# Reverse lookup; if two QIDs share a title, the one iterated last wins.
title2q = {v:k for k,v in q2title.items()}

In [None]:
def any_word_contained_in_type(regexes, type_name):
    """Return True if any pattern in ``regexes`` matches somewhere in ``type_name``.

    Args:
        regexes: iterable of regular-expression patterns accepted by ``re.search``.
        type_name: string the patterns are searched in.

    Returns:
        bool: True as soon as one pattern matches, False if none do (or if
        ``regexes`` is empty).
    """
    # A generator (rather than a list) lets any() stop at the first match
    # instead of running every regex unconditionally.
    return any(re.search(pattern, type_name) is not None for pattern in regexes)

def any_word_in_any_type_set(regexes, type_names):
    """Return True if any pattern in ``regexes`` matches any name in ``type_names``.

    Args:
        regexes: collection of regular-expression patterns; it is handed to
            ``any_word_contained_in_type`` once per type name.
        type_names: iterable of type-name strings to test.

    Returns:
        bool: True as soon as one type name matches, False otherwise (including
        when ``type_names`` is empty).
    """
    # Generator form lets any() stop at the first matching type name instead of
    # evaluating every name up front.
    return any(any_word_contained_in_type(regexes, type_name) for type_name in type_names)

def cand_idx_has_types_with_regexes(regexes, cand_types):
    """Return the positions in ``cand_types`` whose type names match any regex.

    Args:
        regexes: collection of regular-expression patterns.
        cand_types: sequence where each element is the collection of type names
            for one candidate.

    Returns:
        list[int]: indices, in ascending order, of the candidates for which
        ``any_word_in_any_type_set`` reports a match.
    """
    return [
        position
        for position, names_for_candidate in enumerate(cand_types)
        if any_word_in_any_type_set(regexes, names_for_candidate)
    ]

def num_words_in_sentence(words, sentence):
    """Count the tokens of ``sentence`` whose stem is contained in ``words``.

    The sentence is lower-cased and split on whitespace; each token is stemmed
    with the module-level stemmer ``ps`` before the membership test.

    Args:
        words: container of (already stemmed) words to look for.
        sentence: text to scan.

    Returns:
        int: number of matching tokens (repeated tokens count each time).
    """
    tokens = sentence.lower().split()
    return sum(1 for token in tokens if ps.stem(token) in words)

In [None]:
#SLICE FUNCTIONS
# Regexes searched in type names. team_wd_type_res is consumed by is_in_team; the
# negative lookahead keeps "club season" / "team season" types out of the team slice.
team_wd_type_res = [r"club(?! season)", r"team(?! season)"]
# The three pattern lists below are not referenced by the slice functions visible
# in this notebook.
location_wd_type_res = [r"^country$", r"city", r"town"]
country_wd_type_res = [r"^country$"]
person_wd_type_res = [r"^human$"]
# Sport vocabulary, stored in stemmed form so it can be compared against stemmed
# sentence tokens (see num_words_in_sentence). "championship" was previously
# misspelled as "champoinship" and therefore could never match.
sport_words = [ps.stem(w) for w in ["played", "match", "team", "club", "matches", "cricket", "soccer", "league",
                                    "cup", "football", "play", "teams", "championship", "series", "goal", "scored",
                                    "score", "win", "winner", "defense", "offense", "coach", "penalty", "tournament",
                                    "fifa", "forward", "defender", "faced", "faces"]]
# Lower-cased English month names.
months = [m.lower() for m in ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']]

def is_in_airport(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if ``wd_types`` contains an airport type.

    Only ``wd_types`` is inspected; the other parameters exist so every slice
    function shares the signature that ``get_slices`` calls. A type must equal
    'airport' or 'international airport' exactly to count.
    """
    airport_types = {'international airport', 'airport'}
    return bool(airport_types & set(wd_types))

def is_in_historical(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if any name in ``wd_types`` contains "historical".

    Only ``wd_types`` is inspected; the remaining parameters are part of the
    shared slice-function signature.
    """
    historical_types = [type_name for type_name in wd_types if "historical" in type_name]
    return len(historical_types) > 0

def is_in_ethnic(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if any name in ``wd_types`` contains "ethnic".

    Only ``wd_types`` is inspected; the remaining parameters are part of the
    shared slice-function signature.
    """
    found = False
    for type_name in wd_types:
        # Test every name (no early exit) and remember whether any matched.
        found = ("ethnic" in type_name) or found
    return found

def is_in_tournament(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if any name in ``wd_types`` contains "tournament".

    Only ``wd_types`` is inspected; the remaining parameters are part of the
    shared slice-function signature.
    """
    match_count = sum("tournament" in type_name for type_name in wd_types)
    return match_count > 0

def is_in_team(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if any name in ``wd_types`` matches a team regex.

    Delegates to ``any_word_in_any_type_set`` with the module-level
    ``team_wd_type_res`` patterns. Only ``wd_types`` is inspected; the remaining
    parameters are part of the shared slice-function signature.
    """
    is_team = any_word_in_any_type_set(regexes=team_wd_type_res, type_names=wd_types)
    return is_team

def is_in_natsoccersports(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if ``wd_types`` contains the exact type
    'national association football team'.

    Only ``wd_types`` is inspected; the remaining parameters are part of the
    shared slice-function signature.
    """
    distinct_types = set(wd_types)
    return 'national association football team' in distinct_types

def is_in_football_type(al_idx, sent_idx, sentence, qid, alias, span, title, hy_types, wd_types, es):
    """Slice predicate: True if any name in ``wd_types`` contains
    "American football team" (case-sensitive substring test).

    Only ``wd_types`` is inspected; the remaining parameters are part of the
    shared slice-function signature.
    """
    football_types = list(filter(lambda type_name: "American football team" in type_name, wd_types))
    return bool(football_types)

In [None]:
def get_slices(line, slice_names, es, types_hy, types_wd, types_rel):
    """Compute slice membership for every mention of one data line.

    For each name in ``slice_names`` the module-level function ``is_in_<name>``
    is looked up and called once per mention.

    Args:
        line: dict with the keys 'sentence', 'spans', 'qids', 'aliases' and
            'sent_idx_unq'; the list-valued entries are indexed per mention.
        slice_names: iterable of slice names; each needs a matching global
            ``is_in_<name>`` function (a KeyError is raised otherwise).
        es: object providing ``get_title(qid)``; also forwarded to the predicate.
        types_hy: object providing ``get_types(qid)``.
        types_wd: object providing ``get_types(qid)``.
        types_rel: accepted but not used by this function.

    Returns:
        dict mapping slice name -> {str(mention index): 1.0 or 0.0}.
    """
    sentence = line['sentence']
    spans = line["spans"]
    qids = line['qids']
    titles = [es.get_title(qid) for qid in qids]
    hy_types_per_mention = [types_hy.get_types(qid) for qid in qids]
    wd_types_per_mention = [types_wd.get_types(qid) for qid in qids]

    membership = {}
    for slice_name in slice_names:
        predicate = globals()[f"is_in_{slice_name}"]
        flags = {}
        for mention_idx in range(len(qids)):
            in_slice = predicate(
                mention_idx,
                line["sent_idx_unq"],
                sentence,
                line["qids"][mention_idx],
                line["aliases"][mention_idx],
                spans[mention_idx],
                titles[mention_idx],
                hy_types_per_mention[mention_idx],
                wd_types_per_mention[mention_idx],
                es,
            )
            flags[str(mention_idx)] = 1.0 if in_slice else 0.0
        membership[slice_name] = flags
    return membership

In [None]:
# Data splits to re-slice, resolved against the dataset root.
in_files = [input_dir / split_file for split_file in ("dev.jsonl", "test.jsonl")]
# Slices to compute; get_slices looks up a global is_in_<name> function for each entry.
slice_names = ["airport", "historical", "ethnic", "tournament", "team", "natsoccersports", "football_type"]

In [None]:
def get_num_lines(in_file):
    """Return the number of lines in the text file at ``in_file``.

    The file is streamed line by line, so it is never held in memory as a whole.
    """
    with open(in_file, "r") as handle:
        return sum(1 for _ in handle)

# Re-slice every input file: add the new slice memberships to each line, write the
# line to output_dir, and collect per-slice totals and pairwise overlap counts.
for in_f_name in in_files:
    # Entries of in_files may already be absolute paths; joining them onto
    # input_dir is then a no-op, while relative names are resolved against it.
    in_path = input_dir / in_f_name
    # Build the output name from the bare file stem. Deriving it from the full
    # input path produced an absolute name, which made os.path.join drop
    # output_dir and write the sliced file next to the input instead.
    out_f_name = f"{in_path.stem}_sliced.jsonl"
    out_path = output_dir / out_f_name
    num_lines = get_num_lines(in_path)
    print(f"Reading in {in_f_name}")
    total_mens = 0
    slice_totals = defaultdict(int)
    slice_overlaps = defaultdict(lambda: defaultdict(int))
    with jsonlines.open(str(out_path), "w") as out_f, jsonlines.open(str(in_path), "r") as in_f:
        for line in tqdm(in_f, total=num_lines):
            new_slices = get_slices(line, slice_names, entity_dump, types_hy, types_wd, types_rel)
            # Keep the slices already on the line; new ones overwrite same-named entries.
            old_slices = line["slices"]
            for s in new_slices:
                old_slices[s] = new_slices[s]
            line["slices"] = old_slices
            total_mens += len(line["aliases"])
            for s, membership in new_slices.items():
                slice_totals[s] += sum(membership.values())
                # Count mentions that belong to both s and every other slice s2
                # (membership values are 0.0 / 1.0, hence the 0.5 threshold).
                for s2, other_membership in new_slices.items():
                    if s2 == s:
                        continue
                    for al_idx_str, in_other in other_membership.items():
                        if in_other > 0.5 and membership[al_idx_str] > 0.5:
                            slice_overlaps[s][s2] += 1
            out_f.write(line)
    print(f"Wrote out to {out_path} with {total_mens} mentions and slice totals {json.dumps(slice_totals, indent=4)} and overlaps of {json.dumps(slice_overlaps, indent=4)}")