In [2]:
from collections import defaultdict
from itertools import *
import json
from operator import *
import pickle
import re

import matplotlib.pylab as pl
import nltk
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from scipy.sparse import csr_matrix

In [3]:
# Directory the dataset parquet and split JSON files are read from, and directory derived artifacts are written to.
input_dir = '/data/yu_gai/cfq'
output_dir = '/work/yu_gai/cfq/data/cfq'

# Load the dataset ordered by its 'index' column and keep it persisted for the repeated passes below.
# `sqlCtx` is not defined in this notebook — presumably supplied by the PySpark kernel; confirm.
df = sqlCtx.read.parquet(f'{input_dir}/dataset.parquet').sort('index').persist()
df.columns

['complexityMeasures',
 'expectedResponse',
 'expectedResponseWithMids',
 'index',
 'question',
 'questionPatternModEntities',
 'questionTemplate',
 'questionWithBrackets',
 'questionWithMids',
 'ruleIds',
 'ruleTree',
 'sparql',
 'sparqlPattern',
 'sparqlPatternModEntities']

In [4]:
splits = {}
split_ids = !ls {input_dir}/splits | grep json
for split_id in [s.replace('.json', '') for s in split_ids]:
    split = splits[split_id] = json.load(open(f'{input_dir}/splits/{split_id}.json'))
    np.savez(f'{output_dir}/splits/{split_id}', **{k : np.array(v) for k, v in split.items()})
    print(split_id, len(split['trainIdxs']), len(split['devIdxs']), len(split['testIdxs']), df.count())

mcd1 95743 11968 11968 239357
mcd2 95743 11968 11968 239357
mcd3 95743 11968 11968 239357
query_complexity_split 100654 9512 9512 239357
query_pattern_split 94600 12489 12589 239357
question_complexity_split 98999 10339 10340 239357
question_pattern_split 95654 12115 11909 239357
random_split 95744 11967 11967 239357


In [5]:
def replace(q):
    """Fuse each known multi-word phrase in ``q`` into a single space-free token.

    Phrases are substituted one after another in the order listed, e.g.
    'art director' becomes 'artdirector'. Text that contains none of the
    phrases is returned unchanged.
    """
    phrases = (
        'art director',
        'country of nationality',
        'costume designer',
        'executive producer',
        'executive produce',
        'executive produced',
        'film director',
        'film distributor',
        'film editor',
        'film producer',
        'production company',
    )
    # Insertion order of the dict keeps the substitutions in the listed order.
    fused = {phrase: phrase.replace(' ', '') for phrase in phrases}
    for phrase, joined in fused.items():
        q = q.replace(phrase, joined)
    return q

# Overwrite 'questionPatternModEntities' with the phrase-fused text produced by `replace`.
df = df.withColumn('questionPatternModEntities', udf(replace, StringType())('questionPatternModEntities')).persist()
# Sanity check: in every row the rewritten pattern and 'questionTemplate' split into the same number of tokens.
df.rdd.map(lambda r: len(r['questionPatternModEntities'].split(' ')) == len(r['questionTemplate'].split(' '))).reduce(and_)

True

In [6]:
def at(i):
    """Return a getter that looks up key/index ``i`` on its argument."""
    def getter(x):
        return x[i]
    return getter


def k1(r):
    """Pair ``r`` with the count 1, as a two-element list."""
    return [r, 1]


def unique(rdd):
    """Collect the distinct elements of ``rdd`` as a sorted list."""
    distinct_items = rdd.distinct().collect()
    return sorted(distinct_items)


def count(rdd):
    """Return a dict mapping each element of ``rdd`` to its number of occurrences."""
    # NOTE: this name rebinds `count`, hiding the one pulled in by `from itertools import *`.
    pairs = rdd.map(k1).reduceByKey(add).collect()
    return dict(pairs)


def collect(rdd):
    """Collect ``rdd`` to the driver as a numpy array."""
    return np.array(rdd.collect())


def flat_collect(rdd):
    """Flatten one level of nesting in ``rdd`` and collect the result as a numpy array."""
    flattened = rdd.flatMap(lambda r: r)
    return np.array(flattened.collect())

In [7]:
# Sentinel tokens: SEP opens and NIL closes the tail appended to each token/tag sequence by the feature mapper.
SEP, NIL = '{SEP}', '{NIL}'

In [8]:
def _load_pickle(path):
    """Unpickle the object stored at ``path``, closing the file as soon as loading finishes."""
    # NOTE: pickle can execute arbitrary code on load; only read files this pipeline wrote itself.
    with open(path, 'rb') as pickle_file:
        return pickle.load(pickle_file)


# Each vocabulary file holds a pair; the second item is discarded for the attribute and role vocabularies.
idx2tok, tok2idx = _load_pickle(f'{output_dir}/tok-vocab.pickle')
idx2tag, tag2idx = _load_pickle(f'{output_dir}/tag-vocab.pickle')
idx2typ, typ2idx = _load_pickle(f'{output_dir}/typ-vocab.pickle')
idx2attr, _ = _load_pickle(f'{output_dir}/attr-vocab.pickle')
roles, _ = _load_pickle(f'{output_dir}/role-vocab.pickle')

In [9]:
def find_rel(line):
    """Parse one line of a query pattern into a ``(src, typ, dst)`` triple.

    Two line shapes are accepted, each optionally followed by a trailing `` .``:
      * ``FILTER ( a != b )`` -> ``(a, '!=', b)``
      * ``a rel b``           -> ``(a, rel, b)``

    Raises:
        ValueError: if the line matches neither shape.
    """
    # The optional terminator is a literal " ." — the dot is escaped so that an
    # arbitrary trailing character is no longer silently accepted and dropped.
    if 'FILTER' in line:
        m = re.match(r'^FILTER \( ([^ ]+) != ([^ ]+) \)( \.)?$', line)
        if m is None:
            raise ValueError(f'cannot parse FILTER line: {line!r}')
        src, dst = m.group(1, 2)
        return src, '!=', dst

    m = re.match(r'^([^ ]+) ([^ ]+) ([^ ]+)( \.)?$', line)
    if m is None:
        raise ValueError(f'cannot parse relation line: {line!r}')
    src, typ, dst = m.group(1, 2, 3)
    return src, typ, dst

# Alternation over every role, each with its first and last character (presumably the brackets) re-added as escaped literals.
# NOTE(review): this module-level `r` is not read by `p` below — the generator expression there binds its own `r`.
r = '(?:%s)' % '|'.join(fr'\[{role[1 : -1]}\]' for role in roles)  # TODO
# p = re.compile(fr'{r} and {r}|(?:{r} , )+and {r}')
# One alternative per role X: either "X and X", or one or more "X , " followed by "and X".
p = re.compile('|'.join(fr'{r} and {r}|(?:{r} , )+and {r}' for r in [fr'\[{role[1 : -1]}\]' for role in roles]))
def grp_by_tag(tags):
    """Partition tag positions into groups, merging each coordinated run of one role.

    The tags are joined with single spaces and scanned with the module-level
    pattern ``p``. Every match that mentions exactly one entry of ``roles``
    becomes one group holding all tag positions it spans; every other position
    forms a singleton group. Groups are returned in sequence order as lists of
    indices into ``tags``.
    """
    widths = np.array([len(tag) for tag in tags])
    # Character offsets of each tag inside the joined string: one separator
    # precedes every tag except the first.
    stops = np.cumsum(widths) + np.arange(len(tags))
    begins = stops - widths

    joined = ' '.join(tags)

    def single_role(text):
        return sum(role in text for role in roles) == 1

    spans = [(m.start(), m.end()) for m in re.finditer(p, joined) if single_role(m.group())]
    if not spans:
        return [[i] for i in range(len(tags))]

    span_begins, span_stops = zip(*spans)
    grps = []
    inside = False
    for i, (begin, stop) in enumerate(zip(begins, stops)):
        if begin in span_begins:
            # A tag that opens a matched span starts a fresh merged group.
            inside = True
            grps.append([i])
        elif inside:
            grps[-1].append(i)
        else:
            grps.append([i])
        if stop in span_stops:
            inside = False

    # Each merged group must reproduce exactly the text of its match.
    merged = (grp for grp in grps if len(grp) > 1)
    for begin, stop, grp in zip(span_begins, span_stops, merged):
        assert joined[begin : stop] == ' '.join(tags[i] for i in grp)

    return grps

In [10]:
def _mapper(r):
    """Turn one dataset row into graph and sequence features.

    Reads the row's 'sparqlPatternModEntities', 'questionPatternModEntities'
    and 'questionTemplate' fields.

    Returns a 6-tuple ``(filters, ent2attr, gr_rels, seq, mem, idx2grp)``:
      filters  -- ``[src, dst]`` entity-index pairs for every ``!=`` relation.
      ent2attr -- array of shape (number of entity groups, len(idx2attr)) with a
                  1 where a relation's destination is an entry of ``idx2attr``.
      gr_rels  -- ``[src, dst, type index]`` for relations whose type is in ``idx2typ``.
      seq      -- tag id of the first member of every tag group.
      mem      -- token ids per group; merged groups keep only role-tagged members.
      idx2grp  -- sorted group indices that contain an entity.
    """
    # The first and last lines of the pattern are skipped; every line in between is one relation.
    rels = [find_rel(line) for line in r['sparqlPatternModEntities'].split('\n')[1 : -1]]
    # Raw-string pattern: '\d' and '\?' inside a plain literal are invalid escape sequences.
    ents = sorted({x for src, _, dst in rels for x in (src, dst) if re.match(r'M\d|\?x\d', x)})

    # `ents` is already sorted, so the filtered variables keep that order.
    tail = [SEP] + [ent for ent in ents if ent.startswith('?x')] + [NIL]
    toks = r['questionPatternModEntities'].split(' ') + tail
    tags = r['questionTemplate'].split(' ') + tail
    grps = grp_by_tag(tags)

    seq = [tag2idx[tags[grp[0]]] for grp in grps]
    mem = [[tok2idx[toks[grp[0]]]] if len(grp) == 1 else
           [tok2idx[toks[idx]] for idx in grp if tags[idx] in roles] for grp in grps]

    # Map every entity to the group of its last occurrence in the token sequence.
    ent2grp = {}
    for tok, grp_idx in zip(toks, chain(*(len(grp) * [grp_idx] for grp_idx, grp in enumerate(grps)))):
        if tok in ents:
            ent2grp[tok] = grp_idx
    idx2grp = sorted(set(ent2grp.values()))
    ent2idx = {ent: idx2grp.index(ent2grp[ent]) for ent in ents}

    # filters
    filters = [[ent2idx[src], ent2idx[dst]] for src, typ, dst in rels if typ == '!=']

    # attributes
    ent2attr = np.zeros([len(idx2grp), len(idx2attr)])
    for src, _, dst in rels:
        if dst in idx2attr:
            ent2attr[ent2idx[src], idx2attr.index(dst)] = 1

    # groundable relations
    gr_rels = [[ent2idx[src], ent2idx[dst], idx2typ.index(typ)] for src, typ, dst in rels if typ in idx2typ]

    return filters, ent2attr, gr_rels, seq, mem, idx2grp

# Per-row feature tuples: (filters, ent2attr, gr_rels, seq, mem, idx2grp).
rdd = df.rdd.map(_mapper).cache()

# Flatten the per-row features into arrays; each n_* array holds the per-row (or per-group) lengths of a flattened array.
dat = {}
dat['n_filter'] = collect(rdd.map(at(0)).map(len))
dat['filter'] = collect(rdd.flatMap(at(0)))
# Attribute matrices of all rows stacked vertically.
dat['attr'] = np.vstack(rdd.map(at(1)).collect())
dat['n_rel'] = collect(rdd.map(at(2)).map(len))
# Source, destination and type index of every groundable relation.
dat['src'] = collect(rdd.flatMap(at(2)).map(at(0)))
dat['dst'] = collect(rdd.flatMap(at(2)).map(at(1)))
dat['typ'] = collect(rdd.flatMap(at(2)).map(at(2)))
dat['seq'] = collect(rdd.flatMap(at(3)))
# Number of groups per row, members per group, and the flattened member token ids.
dat['n_grp'] = collect(rdd.map(at(4)).map(len))
dat['n_mem'] = collect(rdd.flatMap(at(4)).map(len))
dat['mem'] = collect(rdd.flatMap(at(4)).flatMap(lambda r: r))
dat['n'] = collect(rdd.map(at(5)).map(len))
dat['idx2grp'] = collect(rdd.flatMap(at(5)))

In [11]:
# Context-free grammar over template tags with start symbol S -> NP; used below to find noun-phrase spans.
grammar = nltk.CFG.fromstring("""
    S -> NP
    DT -> 'a' | 'an'
    JJ -> 'ADJECTIVE_SIMPLE'
    N -> 'NP_SIMPLE' | 'ROLE_SIMPLE' | JJ N
    NP -> 'entity' | DT N | 'What' | 'What' N | 'Which' N | 'Who' | NP POS N | NP PP NP | NP 'whose' N
    POS -> "'s"
    PP -> 'of'
""")
# Chart parser shared by find_noun_phrases.
parser = nltk.ChartParser(grammar)

def find_noun_phrases(tags):
    """Greedily locate non-overlapping spans of ``tags`` that the grammar parses.

    Scanning left to right, the longest parseable span starting at the current
    position is taken and the scan resumes right after it; a position with no
    parseable span is skipped. Returns a list of ``[start, end, trees]`` with
    ``end`` exclusive and ``trees`` the parses of ``tags[start:end]``.
    """
    hits = []
    n = len(tags)
    lo = 0
    while lo < n:
        resume = lo + 1
        # Try the longest candidate first so the first success is the maximal span.
        for hi in range(n, lo, -1):
            try:
                parses = list(parser.parse(tags[lo : hi]))
            except ValueError:
                # The parser rejected this slice; treat it as having no parse.
                continue
            if parses:
                hits.append([lo, hi, parses])
                resume = hi
                break
        lo = resume
    return hits
            
# Smoke test: print every parse tree of each noun-phrase span found in a one-tag input.
for start, end, trees in find_noun_phrases("What".split(' ')):
# for start, end, trees in find_noun_phrases("entity 's ADJECTIVE_SIMPLE ADJECTIVE_SIMPLE ADJECTIVE_SIMPLE ROLE_SIMPLE".split(' ')):
    for tree in trees:
        tree.pretty_print()

 S  
 |   
 NP 
 |   
What



In [12]:
def _mapper(r):
    """Collapse the noun-phrase spans of one record's tag sequence into placeholder tags.

    ``r`` must unpack into six items; only the fourth (the tag-id sequence) is used.

    Returns a dict with:
      tags         -- one id per collapsed slot: the bracket-stripped vocabulary
                      index for an ordinary tag, ``len(idx2tag)`` for a parsed
                      noun phrase, ``len(idx2tag) + 1`` for the span between SEP and NIL.
      noun_phrases -- per span, the vocabulary indices of the tags it covers.
      positions    -- indices into ``tags`` that hold a placeholder.
      pos_noun, issep -- per span: one entry per covered tag, then one trailing
                      entry (``0`` and ``True`` respectively).
      pos_tag, pos_nps, isnp -- one entry per tag of the original sequence.
    """
    _, _, _, seq, _, _ = r
    idx2tag_ = [tag.replace('[', '').replace(']', '') for tag in idx2tag]
    n_vocab = len(idx2tag_)
    templ = [idx2tag_[tag_id] for tag_id in seq]

    hits = find_noun_phrases(templ)
    idx_sep = templ.index(SEP)
    idx_nil = templ.index(NIL)
    if idx_nil - idx_sep > 1:
        # Whatever sits between SEP and NIL forms one extra span with no parse tree.
        hits.append([idx_sep + 1, idx_nil, None])

    starts, ends, trees = zip(*hits)

    tags, pos_noun, pos_tag, pos_nps, issep, isnp = [], [], [], [], [], []
    cursor = 0
    while cursor < len(templ):
        # Exactly one entry is appended to `tags` per iteration, so this is the slot being filled.
        slot = len(tags)
        if cursor in starts:
            which = starts.index(cursor)
            stop = ends[which]
            width = stop - cursor
            # Offset into the span-level arrays, read before they grow.
            base = len(issep)
            tags.append(n_vocab + 1 if trees[which] is None else n_vocab)
            pos_noun += [*range(width), 0]
            pos_tag += [0] * width
            pos_nps += range(base, base + width)
            issep += [False] * width + [True]
            isnp += [True] * width
            cursor = stop
        else:
            tags.append(idx2tag_.index(templ[cursor]))
            isnp.append(False)
            pos_nps.append(0)
            pos_tag.append(slot)
            cursor += 1

    noun_phrases = [[idx2tag_.index(tag) for tag in templ[lo : hi]] for lo, hi, _ in hits]
    positions = [where for where, tag_id in enumerate(tags) if tag_id >= n_vocab]

    return {'tags' : tags,
            'noun_phrases' : noun_phrases,
            'positions' : positions,
            'pos_noun' : pos_noun,
            'pos_tag' : pos_tag,
            'pos_nps' : pos_nps,
            'issep' : issep,
            'isnp' : isnp}

# Per-row noun-phrase features (dicts) derived from the cached feature tuples.
rdd_np = rdd.map(_mapper).cache()
# Collapsed tag sequences: per-row lengths and the flattened ids.
dat['len_tag'] = collect(rdd_np.map(at('tags')).map(len))
dat['seq_tag'] = collect(rdd_np.flatMap(at('tags')))
# Number of spans per row, then per-span lengths and the flattened span contents.
dat['len_np'] = collect(rdd_np.map(at('noun_phrases')).map(len))
rdd_np_ = rdd_np.flatMap(at('noun_phrases')).cache()
dat['len_noun'] = collect(rdd_np_.map(len))
dat['seq_noun'] = flat_collect(rdd_np_)
dat['pos_noun'] = collect(rdd_np.flatMap(at('pos_noun')))
dat['pos_tag'] = collect(rdd_np.flatMap(at('pos_tag')))
# Note the key rename: 'pos_np' stores the mapper's 'positions' field.
dat['pos_np'] = collect(rdd_np.flatMap(at('positions')))
dat['pos_nps'] = collect(rdd_np.flatMap(at('pos_nps')))
dat['issep'] = collect(rdd_np.flatMap(at('issep')))
dat['isnp'] = collect(rdd_np.flatMap(at('isnp')))
dat['len_nps'] = collect(rdd_np.map(at('issep')).map(len))
# Write every array gathered in `dat` to a single .npz archive.
np.savez(f'{output_dir}/data', **dat)

In [13]:
def seq2templ(seq):
    """Join the tags for the ids in ``seq`` with spaces, stopping before the first '{SEP}' tag."""
    words = []
    for idx in seq:
        tag = idx2tag[idx]
        if tag == '{SEP}':
            break
        words.append(tag)
    return ' '.join(words)
def _mapper(r):
    """Return True when every 'entity', 'NP_SIMPLE' and 'ROLE_SIMPLE' tag of ``r`` lies inside a detected noun phrase.

    ``r`` is a tag-id sequence; only the part before the '{SEP}' tag is checked.
    """
    tags = seq2templ(r).replace('[', '').replace(']', '').split(' ')
    hits = find_noun_phrases(tags)
    for idx, tag in enumerate(tags):
        if tag not in ['entity', 'NP_SIMPLE', 'ROLE_SIMPLE']:
            continue
        covered = any(start <= idx < end for start, end, _ in hits)
        if not covered:
            return False
    return True

# Count the rows whose noun-like tags are all covered by a noun phrase (booleans summed via `add`).
rdd.map(at(3)).map(_mapper).reduce(add)

239357

In [71]:
# Largest id in the flattened collapsed tag sequences; compare with len(idx2tag) to see the placeholder ids.
dat['seq_tag'].max()

31

In [72]:
# Size of the tag vocabulary loaded from tag-vocab.pickle.
len(idx2tag)

30

In [65]:
# Total number of tags covered by all noun-phrase spans across the dataset.
dat['len_noun'].sum()

1779289

In [63]:
# Number of placeholder positions collected over all rows.
dat['pos_np'].shape

(627405,)

In [59]:
# Repeat of the check above: largest id in the flattened collapsed tag sequences.
dat['seq_tag'].max()

31

In [60]:
# Repeat of the check above: size of the tag vocabulary.
len(idx2tag)

30

In [47]:
# Scratch cell: inspect the mapper output for the first cached row.
# NOTE(review): this unpacking expects `_mapper` to return an 8-tuple, but the noun-phrase `_mapper`
# in this notebook returns a dict (unpacking it would yield its keys) and the latest `_mapper` returns
# a bool. The captured output presumably came from an earlier definition; confirm before re-running.
tags, noun_phrases, positions, pos_noun, pos_tag, pos_nps, issep, isnp = _mapper(rdd.take(1)[0])
tags, noun_phrases, positions, pos_noun, pos_tag, pos_nps, issep, isnp

([10, 30, 19, 30, 0, 30, 1],
 [[20, 8, 16, 18], [20], [2]],
 [1, 3, 5],
 [0, 1, 2, 3, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 2, 0, 4, 0, 6],
 [0, 0, 1, 2, 3, 0, 5, 0, 7, 0],
 [False, False, False, False, True, False, True, False, True],
 [False, True, True, True, True, False, True, False, True, False])

In [48]:
# Length check on `isnp`, a name bound only by the scratch unpacking cell.
len(isnp)

10

In [50]:
# Length check on `issep`, a name bound only by the scratch unpacking cell.
len(issep)

9

In [49]:
# Length check on `pos_nps`, a name bound only by the scratch unpacking cell.
len(pos_nps)

10

In [12]:
# Scratch cell: run the mapper on the first cached row.
# NOTE(review): the captured output (printed lines and a 6-tuple) matches no `_mapper` definition visible
# in this notebook — presumably it came from an earlier version; confirm before relying on it.
_mapper(rdd.take(1)[0])

here
0
1
5
6
7
8
9


([10, 30, 19, 30, 0, 2, 1],
 [[20, 8, 16, 18], [20]],
 [1, 3],
 [False, True, True, True, True, False, True, False, False, False],
 [0, 0, 1, 2, 3, 0, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 5, 0, 7, 8, 9])

In [None]:
# NOTE(review): `dat_np` is not assigned anywhere in the visible notebook — presumably a leftover; confirm or remove.
dat_np