In [8]:
# %pdb on
import json
import re
import numpy as np
import copy
from tqdm import tqdm_notebook 
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from redis import StrictRedis

In [9]:
# Redis client for the server at localhost:6379 (database 0); other cells of
# this notebook read lists from it with `db.lrange`.
db = StrictRedis(host='localhost', port=6379, db=0)

In [3]:
# Load the training set from ./hotpot_train_v1.1.json. The encoding is given
# explicitly so that the read does not depend on the platform's locale
# encoding (JSON text is UTF-8).
with open('./hotpot_train_v1.1.json', 'r', encoding='utf-8') as fin:
    train_set = json.load(fin)
print('Finish Reading! len = ', len(train_set))

Finish Reading! len =  90447


In [6]:
# Serialise the first 5000 training examples into a smaller file.
with open('./hotpot_train_v1.1_small.json', mode='w') as fout:
    fout.write(json.dumps(train_set[:5000]))

In [17]:
# Auxiliary / modal verbs; a question starting with one of them followed by a
# space is treated as a "general" question by judge_question_type.
_GENERAL_WORDS = ['is', 'are', 'am', 'was', 'were', 'have', 'has', 'had', 'can', 'could',
                  'shall', 'will', 'should', 'would', 'do', 'does', 'did', 'may', 'might', 'must', 'ought', 'need', 'dare']
_GENERAL_WORDS += [x.capitalize() for x in _GENERAL_WORDS]
# The trailing space sits outside the group so that it applies to EVERY
# alternative. Joining with ' |' left the last word ('Dare') without it, which
# made e.g. 'Daredevil ...' match.
GENERAL_WD = re.compile('(?:' + '|'.join(_GENERAL_WORDS) + ') ')

def judge_question_type(q : str, G = GENERAL_WD) -> int:
    """Classify a question string into one of three integer types.

    Returns 2 when the question contains the substring ' or ', 1 when the
    pattern ``G`` matches at the start of the question, and 0 otherwise.
    """
    if ' or ' in q:
        return 2
    return 1 if G.match(q) else 0

In [12]:
# print(judge_question_type('Who has a longer middle name, Alice Walker or Michael Herr?'))
# Sanity check of the Redis contents: fetch the whole list stored under this
# key and print its first element decoded to text.
print(db.lrange('Miami Gardens, Florida', 0, -1)[0].decode())

Miami Gardens is a suburban city located in north-central Miami-Dade County, Florida.


In [None]:
# Name of the pretrained BERT model. In the visible part of the notebook it is
# only referenced by the commented-out tokenizer construction below.
BERT_MODEL = 'bert-base-uncased'
# tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)

In [48]:
# print(tokenizer.tokenize('Brian \"Boosh\" Boucher (pronounced \"Boo-shay\") (born January 2, 1977) is a retired American professional ice hockey goaltender, who played 13 seasons in the National Hockey League (NHL) for the Philadelphia Flyers, Phoenix Coyotes, Calgary Flames, Chicago Blackhawks, Columbus Blue Jackets, San Jose Sharks, and Carolina Hurricanes.'))

In [77]:
from hotpot_evaluate_v1 import normalize_answer, f1_score
from fuzzywuzzy import fuzz, process as fuzzy_process

def fuzzy_retrive(entity, pool):
    """Return the member of ``pool`` that best corresponds to ``entity``.

    The strategy depends on the size of ``pool``:

    * more than 100 candidates (fullwiki): exact lookup only. For a
      mapping-like pool ``entity`` is returned when ``pool.get(entity)`` is
      truthy; for any other container, when ``entity in pool``.
    * otherwise (distractor mode / links of the original wiki): every
      candidate is scored with ``f1_score(entity, candidate)`` and the first
      candidate reaching the highest strictly positive F1 is returned.

    Args:
        entity: the name to look up.
        pool: candidate names, either a mapping keyed by name or a plain
            container (list, set, tuple, ...) of names.

    Returns:
        The matching candidate, or ``None`` when nothing matches.
    """
    if len(pool) > 100:
        # fullwiki, exact match
        # TODO: test ``entity (annotation)'' and find the most like one
        if hasattr(pool, 'get'):
            found = pool.get(entity)
        else:
            # Plain containers have no ``.get``; fall back to a membership test
            # instead of raising AttributeError.
            found = entity in pool
        return entity if found else None

    # distractor mode or use link in original wiki, no need to consider ``entity (annotation)''
    # Mappings are searched by key; any other container is iterated directly.
    candidates = pool.keys() if hasattr(pool, 'keys') else pool
    f1max, ret = 0, None
    for t in candidates:
        # f1_score yields (f1, precision, recall); only the F1 value is used.
        f1, _precision, _recall = f1_score(entity, t)
        if f1 > f1max:
            f1max, ret = f1, t
    return ret

def find_near_matches(w, sentence):
    """Find the words of ``sentence`` that are most similar to ``w``.

    Every whitespace-separated word is scored with the mean of
    ``fuzz.ratio`` and ``fuzz.partial_ratio`` against ``w``.

    Returns:
        The ``(start, end)`` character spans of all words sharing the best
        score, or an empty list when that best score is not above 85.
    """
    best_score = 0
    spans = []
    cursor = 0
    for token in sentence.split():
        # Skip the whitespace in front of the token to find its offset.
        cursor = sentence.index(token[0], cursor)
        span = (cursor, cursor + len(token))
        cursor = span[1]
        score = (fuzz.ratio(w, token) + fuzz.partial_ratio(w, token)) / 2
        if score > best_score:
            best_score, spans = score, [span]
        elif score == best_score:
            spans.append(span)
    if best_score > 85:
        return spans
    return []

def dp(a, b): # a source, b long text
    """Find the substring of ``b`` that best matches ``a`` (approximate search).

    Dynamic programme over ``f[i, j]``: the cheapest cost of aligning
    ``a[:i + 1]`` with some substring of ``b`` that ends at index ``j``.
    ``start[i, j]`` records where that substring begins. Costs:

    * aligning ``a[i]`` with a different character of ``b``: 1
    * leaving ``a[i]`` unmatched: 1
    * absorbing an extra character of ``b`` into the match: 0.5
    * beginning the match directly after an alphanumeric character of ``b``
      (i.e. in the middle of a word): 10

    Args:
        a: the (short) source string; must not be empty.
        b: the (long) text to search in; must not be empty.

    Returns:
        ``([start, end], score)`` where ``b[start:end]`` is the best matching
        span and ``score`` is its cost divided by ``len(a)`` (0 is a perfect
        match).

    Raises:
        IndexError: if ``a`` is empty.
        ValueError: if ``b`` is empty while ``a`` is not.
    """
    n, m = len(a), len(b)
    f = np.zeros((n, m))
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
    start = np.zeros((n, m), dtype=int)
    for j in range(m):
        f[0, j] = int(a[0] != b[j])
        if j > 0 and b[j - 1].isalnum():
            # Penalise matches that would start in the middle of a word.
            f[0, j] += 10
        start[0, j] = j
    for i in range(1, n):
        for j in range(m):
            # (0, i-1) + del(i) ~ (start[j], j)
            f[i, j] = f[i - 1, j] + 1
            start[i, j] = start[i - 1, j]
            if j == 0:
                continue
            # Align a[i] with b[j] (free when the characters are equal).
            diag = f[i - 1, j - 1] + int(a[i] != b[j])
            if f[i, j] > diag:
                f[i, j] = diag
                start[i, j] = start[i - 1, j - 1]

            # Extend the match over b[j] without consuming a character of a.
            if f[i, j] > f[i, j - 1] + 0.5:
                f[i, j] = f[i, j - 1] + 0.5
                start[i, j] = start[i, j - 1]
    # The best match for the whole of ``a`` may end anywhere in ``b``.
    r = np.argmin(f[n - 1])
    ret = [start[n - 1, r], r + 1]
    score = f[n - 1, r] / n
    return (ret, score)

def fuzzy_find(entities, sentence):
    """Locate approximate mentions of ``entities`` inside ``sentence``.

    For every entity a trailing parenthesised annotation is stripped and the
    remainder is aligned against ``sentence`` with ``dp``. A candidate is only
    kept while its score stays below 0.5 (lower is better):

    1. While the last word of the item is not fuzzily contained in the matched
       text (``fuzz.partial_ratio`` < 80) that word is dropped. If words were
       dropped and text remains, ``dp`` is re-run and 0.1 is added to the
       score; if every word was dropped the entity is skipped.
    2. The same is done from the front with the first word; after that retry
       the score is at least ``1 - matched_length / len(entity)``, plus 0.1.
    3. An all-digit item is only accepted when the matched text equals it.

    Afterwards a candidate is discarded when it overlaps another candidate
    with a strictly lower score, or when its own score is above 0.2, some
    other candidate scores below 0.1, and its matched text has at most three
    words and does not start with an upper-case letter (such candidates are
    printed).

    Args:
        entities: iterable of entity name strings.
        sentence: the text to search in.

    Returns:
        A list of ``(entity, matched_text, start, end)`` tuples with
        ``matched_text == sentence[start:end]``.
    """
    ret = []
    for entity in entities:
        # Drop a trailing "(annotation)" such as "(2015 film)".
        item = re.sub(r'\(.*\)$', '', entity).strip()
        r, score = dp(item, sentence)
        if score < 0.5:
            matched = sentence[r[0]: r[1]].lower()
            final_word = item.split()[-1]
            # from end
            retry = False
            while fuzz.partial_ratio(final_word.lower(), matched) < 80:
                retry = True
                # Cut the last word and the whitespace in front of it.
                end = len(item) - len(final_word)
                while end > 0 and item[end - 1].isspace():
                    end -= 1
                if end == 0:
                    # Nothing left of the item: force a rejecting score.
                    retry = False
                    score = 1
                    break
                item = item[:end]
                final_word = item.split()[-1]
            if retry:
#                 print(entity + ' ### ' + sentence[r[0]: r[1]] + ' ### ' + item)
                r, score = dp(item, sentence)
                # Fixed penalty for having shortened the item.
                score += 0.1

            if score >= 0.5:
#                 print(entity + ' ### ' + sentence[r[0]: r[1]] + ' ### ' + item)
                continue
            del final_word
            # from start
            # NOTE(review): `matched` is not refreshed after the retry above, so
            # this check still uses the span of the first alignment -- confirm
            # that this is intended.
            retry = False
            first_word = item.split()[0]
            while fuzz.partial_ratio(first_word.lower(), matched) < 80:
                retry = True
                # Cut the first word and the whitespace behind it.
                start = len(first_word)
                while start < len(item) and item[start].isspace():
                    start += 1
                if start == len(item):
                    # Nothing left of the item: force a rejecting score.
                    retry = False
                    score = 1
                    break
                item = item[start:]
                first_word = item.split()[0]
            if retry:
#                 print(entity + ' ### ' + sentence[r[0]: r[1]] + ' ### ' + item)
                r, score = dp(item, sentence)
                # NOTE(review): the coverage term is relative to the full
                # `entity` (annotation included), not to `item` -- confirm.
                score = max(score, 1 - ((r[1] - r[0]) / len(entity)))
                score += 0.1
#             if score > 0.5:
#                 print(entity + ' ### ' + sentence[r[0]: r[1]] + ' ### ' + item)
            if score < 0.5:
                # Numbers must match exactly, not approximately.
                if item.isdigit() and sentence[r[0]: r[1]] != item:
                    continue
                ret.append((entity, sentence[r[0]: r[1]], int(r[0]), int(r[1]), score))
    # Second pass: resolve conflicts between the candidates.
    non_intersection = []
    for i in range(len(ret)):
        ok = True
        for j in range(len(ret)):
            if j != i:
                # Spans i and j overlap and j has the strictly better (lower) score.
                if not (ret[i][2] >= ret[j][3] or ret[j][2] >= ret[i][3]) and ret[j][4] < ret[i][4]:
                    ok = False
                    break
                # A weak, short, non-capitalised match is dropped when some
                # other candidate matched almost perfectly; it is printed.
                if ret[i][4] > 0.2 and ret[j][4] < 0.1 and not ret[i][1][0].isupper() and len(ret[i][1].split()) <= 3:
                    ok = False
                    print(ret[i])
                    break
        if ok:
            # The score is internal; callers only get (entity, text, start, end).
            non_intersection.append(ret[i][:4])
    return non_intersection

# print(dp('Skiffle', 'Die Rh\u00f6ner S\u00e4uw\u00e4ntzt are a Skif, dm , fle-Bluesband from Eichenzell-L\u00fctter in Hessen, Germany.'))
# def fuzzy_find(entities, sentence):
#     items = fuzzy_process.extract(sentence, entities, scorer=fuzz.partial_token_set_ratio)
#     items = [x for x, y in items if y > 85]
#     items_matched = []
#     for item in items:
#         positions = []
#         for w in re.split('[\s,.?!]', item):
#             r = find_near_matches(w, sentence)
#             if len(r) > 0:
#                 # assume by default sorted by starts
#                 positions.append(r)
#         # To find an interval, which length is minimized
#         print(item, positions)
#         assert len(positions) > 0
#         min_len, s_min, e_min = len(sentence), -1, -1
#         while s_min < 0:
#             if len(positions) == 1:
#                 s_min, e_min = positions[0][0]
#                 break
#             for s0, e0 in positions[0]:
#                 for s_1, e_1 in positions[-1]:
#                     if s_1 <= e0:
#                         continue
#                     if e_1 - s0 >= min_len:
#                         break
#                     ok = True
# #                 last = e0
# #                 for k in range(1, len(positions) - 1):
# #                     ok = False
# #                     for s_k, e_k in positions[k]:
# #                         if last < s_k and e_k < s_1:
# #                             last = e_k
# #                             ok = True
# #                             break
# #                     if not ok:
# #                         break
#                     if ok:
#                         min_len, s_min, e_min = e_1 - s0, s0, e_1
#             if min_len > 2 * len(item): # invalid, too long
#                 positions.pop()
#                 s_min, e_min = -1, -1
#         items_matched.append(sentence[s_min: e_min])
#     return list(zip(items, items_matched))   
# Smoke tests: two fuzzy_find calls on hand-picked sentences (the second one
# contains a deliberately misspelled title) and one fuzzy_retrive call on a
# small list pool.
# NOTE(review): the output recorded for this cell shows a single line although
# three prints are issued -- it looks stale; re-run the cell to refresh it.
print(list(fuzzy_find(['Miami Gardens, Florida', 'WSCV', 'Hard Rock Stadium'], r"Hard Rock Stadium is a multipurpose football stadium located in Miami Gardens, a city north of Miami. It is the home stadium of the Miami Dolphins of the National Football League (NFL).")))
print(fuzzy_find(["19 Kids and Counting", "nine girls and 10 boys"], r" A spin-off show of \"19 kids ande counting\", it features the Duggar family: Jill Dillard, Jessa Seewald, sixteen of their seventeen siblings, and parents Jim Bob and Michelle Duggar."))
print(fuzzy_retrive('Joshua Aaron Charles', ['Jawahar Navodaya Vidyalaya Kanpur', 'Dead Poets Society', 'Josh Charles', 'Aaron1', 'josh charles']))

[('Miami Gardens, Florida', 'Miami Gardens,', 64, 78), ('Hard Rock Stadium', 'Hard Rock Stadium', 0, 17)]


In [None]:
# construct cognitive graph in training data    
def find_fact_content(bundle, title, sen_num):
    """Return sentence number ``sen_num`` of the context entry named ``title``.

    ``bundle['context']`` is scanned in order; for the first entry whose
    element 0 equals ``title``, element 1 is indexed with ``sen_num``.
    Returns ``None`` when no entry carries that title.
    """
    for entry in bundle['context']:
        if entry[0] != title:
            continue
        return entry[1][sen_num]
    return None
# Work on a deep copy so that train_set itself is left untouched.
test = copy.deepcopy(train_set)
for bundle in tqdm_notebook(test):
    # Titles of all supporting facts of this example.
    entities = {fact_title for fact_title, _ in bundle['supporting_facts']}
    bundle['Q_edge'] = fuzzy_find(entities, bundle['question'])
    for fact in bundle['supporting_facts']:
        try:
            title, sen_num = fact
            pool = set()
            # Gather the first '###'-separated field of every entry of the
            # Redis lists stored for sentences 0..sen_num of this title.
            for i in range(sen_num + 1):
                key = 'edges:###{}###{}'.format(i, title)
                pool.update({raw.decode().split('###')[0] for raw in db.lrange(key, 0, -1)})
            # Keep only other supporting titles, plus the answer itself.
            pool.intersection_update(entities)
            pool.add(bundle['answer'])
            pool.discard(title)
            fact.append(fuzzy_find(pool, find_fact_content(bundle, title, sen_num)))
        except IndexError:
            # Report the example id and leave this fact without an edge list.
            print(bundle['_id'])
with open('./hotpot_train_v1.1_refined.json', 'w') as fout:
    json.dump(test, fout)

HBox(children=(IntProgress(value=0, max=90447), HTML(value='')))

('2000–01 Utah Jazz season', '2000–01 NBA season', 4, 22, 0.375)
('Operation Paperclip', 'operation', 123, 132, 0.2111111111111111)
('John of Bohemia', 'of Bohemia', 59, 69, 0.43333333333333335)
('Tyndall Air Force Base', 'air force base', 6, 20, 0.4636363636363636)
('Hurricane Ivan', 'hurricane', 16, 25, 0.2111111111111111)
5a7b23ca554299042af8f703
('Clark County, Nevada', 'county in Nevada', 5, 21, 0.3142857142857143)
('Early flying machines', 'flying machines', 32, 47, 0.3857142857142857)
('Us Weekly', 'weekly', 47, 53, 0.43333333333333335)
('the quiet Beatle', 'the Beatle', 55, 65, 0.375)
('Texas country music', 'country music', 40, 53, 0.4157894736842105)
('The Bends', 'the', 89, 92, 0.43333333333333335)
('Football in Munich', 'footballin', 6, 16, 0.28181818181818186)
('the physical universe', 'the Vienna Univers', 185, 203, 0.42857142857142855)
('Spectre (2015 film)', 'actre', 24, 29, 0.42857142857142855)
('County Cork', 'county', 36, 42, 0.26666666666666666)
('Robins Air Force B

('jazz tenor saxophonist', 'tenor saxophonist', 106, 123, 0.32727272727272727)
('Olathe School District', 'school district', 48, 63, 0.4181818181818182)
('Australia national soccer team', 'national soccer team', 51, 71, 0.43333333333333335)
('Dean of Lincoln', 'of Lincoln', 43, 53, 0.43333333333333335)
('Special Boat Service', 'special', 34, 41, 0.24285714285714285)
('Computer security', 'computer', 67, 75, 0.225)
5a8d6138554299585d9e37c7
('The Guard (2011 film)', 'the', 41, 44, 0.43333333333333335)
('Franco Malerba', 'doFranco Malerba', 5, 21, 0.21428571428571427)
('Please (Pet Shop Boys album)', 'release', 48, 55, 0.4166666666666667)
('Eraring Power Station', 'power station', 13, 26, 0.4809523809523809)
('M1 carbine', 'carbine', 52, 59, 0.4)
('Journal of Applied Physics', 'journal', 5, 12, 0.24285714285714285)
('NXP Semiconductors', 'semiconductor', 59, 72, 0.37777777777777777)
('12 million', 'million', 114, 121, 0.4)
('About Time (2013 film)', 'about', 57, 62, 0.30000000000000004)
(