In [2]:
from collections import defaultdict
from itertools import *
import json
from operator import *
import os
import pickle
import re

import matplotlib.pylab as pl
import nltk
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from scipy.sparse import csr_matrix

In [3]:
# Raw CFQ release (dataset.parquet + splits/) is read from input_dir; preprocessed
# arrays and vocabularies are written under output_dir.
input_dir = '/data/yu_gai/cfq'
output_dir = '/work/yu_gai/cfq/data/cfq'

# `sqlCtx` is the Spark SQL context supplied by the notebook environment (not defined here).
# Rows are sorted by 'index', presumably so RDD order lines up with the row ids used in the
# split files -- confirm.
df = sqlCtx.read.parquet(f'{input_dir}/dataset.parquet').sort('index').persist()
df.columns

['complexityMeasures',
 'expectedResponse',
 'expectedResponseWithMids',
 'index',
 'question',
 'questionPatternModEntities',
 'questionTemplate',
 'questionWithBrackets',
 'questionWithMids',
 'ruleIds',
 'ruleTree',
 'sparql',
 'sparqlPattern',
 'sparqlPatternModEntities']

In [4]:
# Load each split definition (JSON with trainIdxs / devIdxs / testIdxs lists of row
# indices), re-save it as .npz, and print the part sizes next to the dataset size.
splits = {}
split_dir = f'{input_dir}/splits'
# Pure-Python listing instead of the `!ls | grep json` shell magic: portable, no shell,
# and only picks files that actually end in .json.
split_ids = sorted(name for name in os.listdir(split_dir) if name.endswith('.json'))
n_rows = df.count()  # Spark action: run it once, not once per split
for split_id in [s[:-len('.json')] for s in split_ids]:
    # `with` closes the handle (json.load(open(...)) leaked it).
    with open(f'{split_dir}/{split_id}.json') as f:
        split = splits[split_id] = json.load(f)
    np.savez(f'{output_dir}/splits/{split_id}', **{k : np.array(v) for k, v in split.items()})
    print(split_id, len(split['trainIdxs']), len(split['devIdxs']), len(split['testIdxs']), n_rows)

mcd1 95743 11968 11968 239357
mcd2 95743 11968 11968 239357
mcd3 95743 11968 11968 239357
query_complexity_split 100654 9512 9512 239357
query_pattern_split 94600 12489 12589 239357
question_complexity_split 98999 10339 10340 239357
question_pattern_split 95654 12115 11909 239357
random_split 95744 11967 11967 239357


In [5]:
def replace(q):
    """Glue multi-word role phrases into single tokens ("art director" -> "artdirector").

    Substitutions run sequentially in the listed order, each on the output of the
    previous one. The result is checked below to have as many space-separated tokens
    as questionTemplate.
    """
    phrases = (
        'art director',
        'country of nationality',
        'costume designer',
        'executive producer',
        'executive produce',
        # Never matches: 'executive produce' above already rewrote this prefix.
        'executive produced',
        'film director',
        'film distributor',
        'film editor',
        'film producer',
        'production company',
    )
    result = q
    for phrase in phrases:
        glued = ''.join(phrase.split(' '))
        result = result.replace(phrase, glued)
    return result

# Apply `replace` to every question pattern through a Spark string -> string UDF.
df = df.withColumn('questionPatternModEntities', udf(replace, StringType())('questionPatternModEntities')).persist()
# Sanity check: after gluing phrases, each question pattern has exactly as many
# space-separated tokens as its questionTemplate (True means every row agrees).
df.rdd.map(lambda r: len(r['questionPatternModEntities'].split(' ')) == len(r['questionTemplate'].split(' '))).reduce(and_)

True

In [6]:
def at(i):
    """Return a getter that indexes its argument with `i` (like operator.itemgetter(i))."""
    def _get(x):
        return x[i]
    return _get


def k1(r):
    """Pair a record with a count of one, ready for reduceByKey(add)."""
    return [r, 1]


def unique(rdd):
    """Sorted list of the distinct elements of `rdd`."""
    return sorted(rdd.distinct().collect())


def count(rdd):
    """Histogram {element: number of occurrences} of `rdd`."""
    pairs = rdd.map(k1).reduceByKey(add)
    return dict(pairs.collect())


def collect(rdd):
    """Collect `rdd` into a numpy array."""
    return np.array(rdd.collect())


def flat_collect(rdd):
    """Flatten an RDD of iterables and collect the elements into a numpy array."""
    return np.array(rdd.flatMap(lambda r: r).collect())

In [7]:
# Sentinel strings for the tag/token sequences. NIL is appended after every question,
# before its ?x variables; SEP is not referenced by name in the code shown here.
SEP, NIL = '{SEP}', '{NIL}'

In [8]:
def _load_pickle(path):
    """Unpickle one object from `path`, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code -- only use it on vocab files this
    pipeline wrote itself.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)

# Each vocab file unpacks into two objects: an index -> symbol list and a
# symbol -> index mapping (presumably written like symb-vocab.pickle -- confirm).
tok_vocab = idx2tok, tok2idx = _load_pickle(f'{output_dir}/tok-vocab.pickle')
tag_vocab = idx2tag, tag2idx = _load_pickle(f'{output_dir}/tag-vocab.pickle')
typ_vocab = idx2typ, typ2idx = _load_pickle(f'{output_dir}/typ-vocab.pickle')
idx2attr, _ = _load_pickle(f'{output_dir}/attr-vocab.pickle')
roles, _ = _load_pickle(f'{output_dir}/role-vocab.pickle')

In [9]:
def find_rel(line):
    """Parse one SPARQL pattern line into a (source, relation, destination) triple.

    A line mentioning FILTER is read as ``FILTER ( a != b )`` and yields
    ``(a, '!=', b)``. Any other line is read as ``subject predicate object`` and
    yields those three tokens. A trailing " ." (any single char after a space) is
    ignored.

    Raises:
        ValueError: if the line does not match the expected shape.
    """
    if 'FILTER' not in line:
        [groups] = re.findall(r'^([^ ]+) ([^ ]+) ([^ ]+)( .)?$', line)
        src, typ, dst = groups[:3]
        return src, typ, dst
    [groups] = re.findall(r'^FILTER \( ([^ ]+) != ([^ ]+) \)( .)?$', line)
    src, dst = groups[:2]
    return src, '!=', dst

# One regex per role tag, matching the tag text literally ("[X]" -> r"\[X\]").
# re.escape gives the same pattern as the former hand-built r'\[...\]' for plain names,
# and also neutralises any regex metacharacters inside a role name.
_role_res = [re.escape(role) for role in roles]
# Alternation over all role tags (not used by grp_by_tag; kept for interactive use).
r = '(?:%s)' % '|'.join(_role_res)
# Coordinations of a single role: "R and R" or "R , R , ... , and R".
p = re.compile('|'.join(fr'{rr} and {rr}|(?:{rr} , )+and {rr}' for rr in _role_res))
def grp_by_tag(tags, pattern=None, role_list=None):
    """Partition tag positions into groups, merging coordinated role tags.

    A run such as ``[R] and [R]`` or ``[R] , [R] , and [R]`` that matches `pattern`
    and mentions exactly one role of `role_list` becomes a single group holding all
    of its positions (including the ``and`` / ``,`` tokens). Every other tag forms a
    singleton group.

    Args:
        tags: list of tag strings (a question template split on spaces).
        pattern: compiled coordination regex; defaults to the module-level ``p``.
        role_list: role tag strings; defaults to the module-level ``roles``.

    Returns:
        List of groups, each a list of indices into `tags`. Groups are in order and
        together cover every index exactly once.

    Raises:
        AssertionError: (when assertions are enabled) if a match does not line up
            with tag boundaries.
    """
    if pattern is None:
        pattern = p
    if role_list is None:
        role_list = roles

    lens = np.array([len(tag) for tag in tags], dtype=int)
    # Character span of each tag inside ' '.join(tags): one separator per preceding tag.
    ends = np.cumsum(lens) + np.arange(len(tags))
    starts = ends - lens

    text = ' '.join(tags)

    def _mentions_one_role(s):
        return sum(role in s for role in role_list) == 1

    matches = [m for m in re.finditer(pattern, text) if _mentions_one_role(m.group())]
    if not matches:
        return [[i] for i in range(len(tags))]

    match_starts = [m.start() for m in matches]
    match_ends = [m.end() for m in matches]
    # Sets give O(1) boundary lookups (tuple membership was O(#matches) per tag).
    start_set, end_set = set(match_starts), set(match_ends)

    grps = []
    in_match = False
    for idx, (start, end) in enumerate(zip(starts.tolist(), ends.tolist())):
        if start in start_set:
            in_match = True
            grps.append([])
        if in_match:
            grps[-1].append(idx)
        else:
            grps.append([idx])
        if end in end_set:
            in_match = False

    # Every coordination contains "and", so each one yields exactly one multi-tag group.
    multi = [grp for grp in grps if len(grp) > 1]
    for start, end, grp in zip(match_starts, match_ends, multi):
        assert text[start : end] == ' '.join(tags[idx] for idx in grp), \
            f'coordination {text[start : end]!r} does not align with tag boundaries'

    return grps

In [10]:
def _mapper(r):
    """Return the question template with each coordinated role group collapsed to its first tag.

    Only questionTemplate matters here. The SPARQL / entity parsing previously done
    in this function was never used, so it is omitted (it also used invalid escape
    sequences such as 'M\\d' in a non-raw string).
    """
    tags = r['questionTemplate'].split(' ')
    grps = grp_by_tag(tags)
    # Represent each group by the tag at its first position.
    return ' '.join(tags[grp[0]] for grp in grps)

# Number of distinct templates once coordinations are collapsed.
df.rdd.map(_mapper).distinct().count()

18378

def _mapper(r):
    """Map a row to (index, template) with every complex noun phrase replaced by NP_COMPLEX.

    Coordinated role groups are first collapsed to their first tag (square brackets
    removed). Each maximal span recognised by `find_noun_phrases` (defined in the
    grammar cell, which must have been run first) is then replaced by a single
    'NP_COMPLEX' token.
    """
    tags = r['questionTemplate'].split(' ')
    grps = grp_by_tag(tags)

    templ = [tags[grp[0]].replace('[', '').replace(']', '') for grp in grps]
    # start -> end of every noun-phrase span. This is an empty dict when nothing parses;
    # the former `zip(*hits)` unpacking raised ValueError in that case.
    span_end = {start: end for start, end, _ in find_noun_phrases(templ)}
    idx, seq = 0, []
    while idx < len(templ):
        if idx in span_end:
            seq.append('NP_COMPLEX')
            idx = span_end[idx]
        else:
            seq.append(templ[idx])
            idx += 1

    return r['index'], ' '.join(seq)

templs = df.rdd.map(_mapper).cache()

def _split_templates(rdd_templs, split):
    """{part name: set of templates} for one split given as {part name: list of row indices}."""
    out = {}
    for part, row_ids in split.items():
        # A set makes each membership test O(1); `in` on the raw list scanned ~10^5 ids per row.
        ids = set(row_ids)
        out[part] = set(rdd_templs.filter(lambda r, ids=ids: r[0] in ids).map(at(1)).collect())
    return out

uniq_templs1, uniq_templs2, uniq_templs3 = [_split_templates(templs, splits[f'mcd{i}']) for i in [1, 2, 3]]

In [12]:
def _mapper(r):
    """Encode one row as grouped tag/token indices plus an entity graph.

    The question template and pattern are extended with a NIL separator followed by
    the sorted ?x variables of the SPARQL query, then grouped with `grp_by_tag`.

    Returns a dict with:
      'seq'     -- tag2idx index of the first tag of every group;
      'mem'     -- for every group except the ?x variable groups, its tok2idx token
                   indices (the single token of a singleton group, only the role
                   tokens of a coordination);
      'idx2grp' -- sorted group positions holding an entity, i.e. the graph's nodes;
      'rel'     -- [src node, dst node, type index] for each SPARQL triple whose type
                   is in idx2typ; node ids index into 'idx2grp'.
    """
    dat = {}

    # Drop the first and last lines (presumably the SELECT header and closing brace -- confirm).
    rels = list(map(find_rel, r['sparqlPatternModEntities'].split('\n')[1 : -1]))
    srcs, typs, dsts = zip(*rels)
    # Entities are constants (M<digit>...) and variables (?x<digit>...). Raw strings avoid the
    # invalid '\d' / '\?' escapes (DeprecationWarning, SyntaxWarning on Python 3.12+).
    ents = sorted({x for x in chain(srcs, dsts) if re.match(r'M\d', x) or re.match(r'\?x\d', x)})
    seq_var = [ent for ent in ents if ent.startswith('?x')]  # `ents` is already sorted
    ent_set = set(ents)

    # Variables follow a NIL separator so that each gets its own group/position.
    toks = r['questionPatternModEntities'].split(' ') + [NIL] + seq_var
    tags = r['questionTemplate'].split(' ') + [NIL] + seq_var
    grps = grp_by_tag(tags)

    dat['seq'] = [tag2idx[tags[grp[0]]] for grp in grps]
    dat['mem'] = [[tok2idx[toks[grp[0]]]] if len(grp) == 1 else
                  [tok2idx[toks[idx]] for idx in grp if tags[idx] in roles]
                  for grp in grps if not toks[grp[0]].startswith('?x')]

    # Group position of each entity token; positions are visited in increasing order,
    # so the last occurrence of an entity wins.
    ent2grp = {}
    for grp_idx, grp in enumerate(grps):
        for tok_idx in grp:
            if toks[tok_idx] in ent_set:
                ent2grp[toks[tok_idx]] = grp_idx
    idx2grp = dat['idx2grp'] = sorted(set(ent2grp.values()))
    grp2node = {grp_idx: node for node, grp_idx in enumerate(idx2grp)}
    ent2idx = {ent : grp2node[ent2grp[ent]] for ent in ents}

    # Triples whose relation type is not in the type vocabulary are dropped.
    dat['rel'] = [[ent2idx[src], ent2idx[dst], idx2typ.index(typ)] for src, typ, dst in rels if typ in idx2typ]

    return dat

# Encode every row, then store the results as ragged arrays: a per-item length array
# (n_*) plus the concatenation of all items.
dat, rdd = {}, df.rdd.map(_mapper).cache()
dat['n_rel'] = collect(rdd.map(at('rel')).map(len))
rdd_rel = rdd.flatMap(at('rel')).cache()
# Split the [src, dst, typ] triples into three parallel arrays.
dat['src'], dat['dst'], dat['typ'] = [collect(rdd_rel.map(at(i))) for i in range(3)]
dat['n_grp'] = collect(rdd.map(at('mem')).map(len))
rdd_mem = rdd.flatMap(at('mem')).cache()
# Two levels of raggedness: groups per row (n_grp), then tokens per group (n_mem).
dat['n_mem'], dat['mem'] = collect(rdd_mem.map(len)), flat_collect(rdd_mem)
dat['n'] = collect(rdd.map(at('idx2grp')).map(len))
dat['idx2grp'] = collect(rdd.flatMap(at('idx2grp')))

In [13]:
# Small CFG over bracket-stripped template tags that recognises noun phrases:
# determiner + optional adjectives + simple noun/role, wh-words, entities,
# possessives ("'s"), "of" attachments and "whose" clauses. Word sequences the
# grammar does not cover make parse() raise ValueError, which find_noun_phrases catches.
grammar = nltk.CFG.fromstring("""
    S -> NP
    DT -> 'a' | 'an'
    JJ -> 'ADJECTIVE_SIMPLE'
    N -> 'NP_SIMPLE' | 'ROLE_SIMPLE' | JJ N
    NP -> 'entity' | DT N | 'What' | 'What' N | 'Which' N | 'Who' | NP POS N | NP PP NP | NP 'whose' N
    POS -> "'s"
    PP -> 'of'
""")
# Chart parser returning every parse tree of a token sequence.
parser = nltk.ChartParser(grammar)

def find_noun_phrases(tags, chart_parser=None):
    """Greedily segment `tags` into maximal spans that parse as a noun phrase.

    Scanning left to right, the longest span tags[start:end] that `chart_parser`
    accepts is taken at each start position, and scanning resumes at `end`.
    Positions not covered by any parse are skipped.

    Args:
        tags: sequence of tag strings (square brackets already stripped).
        chart_parser: object with a ``parse(tokens)`` method returning parse trees,
            e.g. an nltk ChartParser; defaults to the module-level ``parser``.
            A ValueError from ``parse`` (tokens outside the grammar) counts as
            "no parse".

    Returns:
        List of ``[start, end, trees]`` in order, with ``end`` exclusive and ``trees``
        the list of all parses of ``tags[start:end]``. Spans are disjoint.
    """
    if chart_parser is None:
        chart_parser = parser

    hits = []
    start = 0
    while start < len(tags):
        next_start = start + 1  # nothing parses from here: advance one tag
        for end in range(len(tags), start, -1):  # longest candidate span first
            try:
                trees = list(chart_parser.parse(tags[start : end]))
            except ValueError:
                trees = []
            if trees:
                hits.append([start, end, trees])
                next_start = end
                break
        start = next_start
    return hits
            
# Demo: find the noun phrases of one example template and draw every parse.
# (The loop variable `tree` stays defined afterwards and is reused by a later cell.)
for start, end, trees in find_noun_phrases("a ADJECTIVE_SIMPLE NP_SIMPLE of entity".split(' ')):
    print(start, end)
    for tree in trees:
        tree.pretty_print()

0 5
                             S               
                             |                
                             NP              
            _________________|___________     
           NP                       |    |   
  _________|__________              |    |    
 |                    N             |    |   
 |          __________|______       |    |    
 DT        JJ                N      PP   NP  
 |         |                 |      |    |    
 a  ADJECTIVE_SIMPLE     NP_SIMPLE  of entity



In [21]:
# NOTE(review): hidden state -- `tree` is the loop variable leaked from the demo loop, and
# `traverse` is only defined in a later cell (In [33]); this fails on Restart & Run All.
traverse(tree)

(['S',
  'NP',
  'NP',
  'DT',
  'a',
  'N',
  'JJ',
  'ADJECTIVE_SIMPLE',
  'N',
  'NP_SIMPLE',
  'PP',
  'of',
  'NP',
  'entity'],
 [False,
  False,
  False,
  False,
  True,
  False,
  False,
  True,
  False,
  True,
  False,
  True,
  False,
  True],
 [[0, 1],
  [1, 2],
  [2, 3],
  [3, 4],
  [2, 5],
  [5, 6],
  [6, 7],
  [5, 8],
  [8, 9],
  [1, 10],
  [10, 11],
  [1, 12],
  [12, 13]])

In [17]:
# Bracket-stripped tag vocabulary plus a lookup returning the first matching index
# (same result as list.index, but O(1)); both are built once instead of once per row.
_idx2tag_bare = [tag.replace('[', '').replace(']', '') for tag in idx2tag]
_bare2idx = {tag: i for i, tag in reversed(list(enumerate(_idx2tag_bare)))}  # first occurrence wins
_nil_idx = idx2tag.index('{NIL}')

def _mapper(r):
    """Split an encoded row (needs r['seq']) into tag, noun-phrase and variable streams.

    The template is the prefix of r['seq'] before the first ?x variable; it still
    ends with the NIL separator. Each maximal noun phrase found by
    `find_noun_phrases` is replaced in 'seq_tag' by the sentinel index len(idx2tag)
    (rendered as NP_COMPLEX), and its tag indices go to 'seq_noun'. The variables
    followed by NIL form 'seq_var'.

    'pos_all' / 'istag' / 'isnoun' / 'isvar' are parallel lists with one entry per
    template position, followed by one per 'seq_var' entry. The flags say which
    stream an entry belongs to, and 'pos_all' gives its position in that stream:
    the 'seq_tag' index, the offset inside its noun phrase, or the 'seq_var' index.
    'pos_np' lists where the sentinels sit in 'seq_tag'. 'hits' keeps the raw
    noun-phrase spans with their parse trees.
    """
    dat = {}

    templ = [_idx2tag_bare[idx] for idx in takewhile(lambda idx: not idx2tag[idx].startswith('?x'), r['seq'])]
    hits = dat['hits'] = find_noun_phrases(templ)
    dat['seq_noun'] = [r['seq'][start : end] for start, end, _ in hits]
    # start -> end of each span. This is an empty dict when nothing parses;
    # the former `zip(*hits)` raised ValueError in that case.
    span_end = {start: end for start, end, _ in hits}
    np_token = len(_idx2tag_bare)  # out-of-vocabulary index reserved for NP_COMPLEX

    seq_tag = dat['seq_tag'] = []
    pos_np = dat['pos_np'] = []
    pos_all = dat['pos_all'] = []
    istag, isnoun, isvar = dat['istag'], dat['isnoun'], dat['isvar'] = [], [], []

    idx = 0
    while idx < len(templ):
        if idx in span_end:
            end = span_end[idx]
            width = end - idx

            pos_np.append(len(seq_tag))
            seq_tag.append(np_token)

            pos_all.extend(range(width))  # offset inside the noun phrase
            istag.extend(width * [False])
            isnoun.extend(width * [True])
            isvar.extend(width * [False])

            idx = end
        else:
            seq_tag.append(_bare2idx[templ[idx]])

            pos_all.append(len(seq_tag) - 1)  # position in seq_tag
            istag.append(True)
            isnoun.append(False)
            isvar.append(False)

            idx += 1

    seq_var = dat['seq_var'] = r['seq'][len(templ):] + [_nil_idx]
    pos_all.extend(range(len(seq_var)))
    istag.extend(len(seq_var) * [False])
    isnoun.extend(len(seq_var) * [False])
    isvar.extend(len(seq_var) * [True])

    return dat

# Derive the noun-phrase view of every encoded row and collect it as ragged arrays
# (len_* / n_* per-item lengths plus flattened values). rdd_np also keeps 'hits'
# (parse trees), which the tree cell further down consumes.
rdd_np = rdd.map(_mapper).cache()
dat['len_tag'] = collect(rdd_np.map(at('seq_tag')).map(len))
dat['seq_tag'] = collect(rdd_np.flatMap(at('seq_tag')))
rdd_noun = rdd_np.flatMap(at('seq_noun')).cache()
# One entry per noun phrase; n_np below gives the number of phrases per row.
dat['len_noun'] = collect(rdd_noun.map(len))
dat['seq_noun'] = flat_collect(rdd_noun)
dat['len_var'] = collect(rdd_np.map(at('seq_var')).map(len))
dat['seq_var'] = flat_collect(rdd_np.map(at('seq_var')))
dat['n_np'] = collect(rdd_np.map(at('pos_np')).map(len))
dat['pos_np'] = collect(rdd_np.flatMap(at('pos_np')))
dat['n_all'] = collect(rdd_np.map(at('pos_all')).map(len))
dat['pos_all'] = collect(rdd_np.flatMap(at('pos_all')))
dat['istag'] = collect(rdd_np.flatMap(at('istag')))
dat['isnoun'] = collect(rdd_np.flatMap(at('isnoun')))
dat['isvar'] = collect(rdd_np.flatMap(at('isvar')))
# NOTE: the tree cell further down saves to the same path again (with extra keys).
np.savez(f'{output_dir}/data', **dat)

In [15]:
# Symbol vocabulary: every nonterminal and terminal of the noun-phrase grammar, sorted.
# (`prod` rather than `p` keeps the module-level coordination regex `p` unambiguous.)
uniq_lhs = set(str(prod.lhs()) for prod in grammar.productions())
uniq_rhs = set(chain(*(map(str, prod.rhs()) for prod in grammar.productions())))
idx2symb = sorted(uniq_lhs.union(uniq_rhs))
symb2idx = {symb : idx for idx, symb in enumerate(idx2symb)}
# Context manager guarantees the pickle is flushed and closed before anything reads it.
with open(f'{output_dir}/symb-vocab.pickle', 'wb') as f:
    pickle.dump([idx2symb, symb2idx], f)
idx2symb, symb2idx

(["'s",
  'ADJECTIVE_SIMPLE',
  'DT',
  'JJ',
  'N',
  'NP',
  'NP_SIMPLE',
  'POS',
  'PP',
  'ROLE_SIMPLE',
  'S',
  'What',
  'Which',
  'Who',
  'a',
  'an',
  'entity',
  'of',
  'whose'],
 {"'s": 0,
  'ADJECTIVE_SIMPLE': 1,
  'DT': 2,
  'JJ': 3,
  'N': 4,
  'NP': 5,
  'NP_SIMPLE': 6,
  'POS': 7,
  'PP': 8,
  'ROLE_SIMPLE': 9,
  'S': 10,
  'What': 11,
  'Which': 12,
  'Who': 13,
  'a': 14,
  'an': 15,
  'entity': 16,
  'of': 17,
  'whose': 18})

In [33]:
def traverse(tree, nid=0):
    """Flatten an nltk tree into pre-order node labels, leaf flags and edges.

    Node ids are assigned in pre-order starting at `nid` (the root). Each edge is a
    ``[parent_id, child_id]`` pair; string children become leaf nodes.

    Returns:
        (labels, isleaf, edges)

    Raises:
        TypeError: for a child that is neither a subtree nor a string.
    """
    labels = [tree.label()]
    isleaf = [False]
    edges = []
    next_nid = nid + 1
    for child in tree:
        edges.append([nid, next_nid])
        if isinstance(child, nltk.tree.Tree):
            child_labels, child_isleaf, child_edges = traverse(child, next_nid)
        elif isinstance(child, str):
            child_labels, child_isleaf, child_edges = [child], [True], []
        else:
            raise TypeError()
        labels += child_labels
        isleaf += child_isleaf
        edges += child_edges
        next_nid += len(child_labels)

    return labels, isleaf, edges

def _mapper(r):
    """Flatten the first parse tree of every noun phrase in ``r['hits']``.

    Returns a dict of per-tree lists: 'symb' (symb2idx ids of the pre-order node
    labels), 'isleaf' (leaf flags), and 'src' / 'dst' (parent and child node ids of
    each edge).
    """
    symbs, leaves, srcs, dsts = [], [], [], []
    for _, _, trees in r['hits']:
        # An ambiguous phrase may have several parses; only the first one is kept.
        labels, isleaf, edges = traverse(trees[0])
        symbs.append([symb2idx[label] for label in labels])
        leaves.append(isleaf)
        edge_src, edge_dst = zip(*edges)
        srcs.append(edge_src)
        dsts.append(edge_dst)

    return {'symb': symbs, 'isleaf': leaves, 'src': srcs, 'dst': dsts}

# Flatten all noun-phrase trees as ragged arrays: n_tree / m_tree hold node / edge counts
# per tree; symb, isleaf, src_tree and dst_tree hold the concatenated per-tree values
# (node ids are local to each tree).
rdd_tree = rdd_np.map(_mapper).cache()
dat['n_tree'] = collect(rdd_tree.flatMap(at('symb')).map(len))
dat['symb'] = flat_collect(rdd_tree.flatMap(at('symb')))
dat['isleaf'] = flat_collect(rdd_tree.flatMap(at('isleaf')))
dat['m_tree'] = collect(rdd_tree.flatMap(at('src')).map(len))
dat['src_tree'] = flat_collect(rdd_tree.flatMap(at('src')))
dat['dst_tree'] = flat_collect(rdd_tree.flatMap(at('dst')))
# Re-saves data.npz with the tree arrays added to everything collected so far in `dat`.
np.savez(f'{output_dir}/data', **dat)

In [34]:
# Consistency check: every flattened tree has exactly one leaf per tag of its noun phrase.
# An empty list means no row violates this.
rdd_tree.zip(rdd_np).filter(lambda r: any(sum(isleaf) != len(seq_noun) for isleaf, seq_noun in zip(r[0]['isleaf'], r[1]['seq_noun']))).take(1)

[]

In [28]:
# Histogram of len(symb) - max(dst) per tree; a single key 1 means the largest child id
# is always the last node, i.e. node ids are contiguous.
count(rdd_tree.flatMap(lambda r: [len(symb) - max(dst) for symb, dst in zip(r['symb'], r['dst'])]))

{1: 624095}

In [29]:
# Histogram of len(symb) - max(src) per tree; the recorded output (single key 2) says the
# last parent is always the second-to-last node.
count(rdd_tree.flatMap(lambda r: [len(symb) - max(src) for symb, src in zip(r['symb'], r['src'])]))

{2: 624095}

In [22]:
# NOTE(review): the recorded output (239357,) looks stale. m_tree is built with flatMap,
# one entry per tree (624095 trees per the histograms above), so a fresh run should differ.
dat['m_tree'].shape

(239357,)

In [76]:
# Inspect all collected arrays (numpy summarises each one).
dat

{'len_var': array([2, 2, 3, ..., 2, 2, 3]),
 'seq_var': array([2, 1, 2, ..., 2, 3, 1]),
 'n_rel': array([ 3,  7,  2, ..., 12, 12,  6]),
 'src': array([2, 2, 2, ..., 1, 1, 2]),
 'dst': array([0, 1, 1, ..., 2, 2, 0]),
 'typ': array([ 3,  7, 24, ..., 24, 25, 17]),
 'n_grp': array([ 9,  8, 11, ...,  5,  5,  9]),
 'n_mem': array([1, 1, 1, ..., 1, 1, 1]),
 'mem': array([14, 21,  8, ...,  1,  2,  3]),
 'n': array([3, 3, 3, ..., 2, 2, 3]),
 'idx2grp': array([1, 6, 8, ..., 5, 7, 8]),
 'len_tag': array([5, 5, 5, ..., 4, 4, 4]),
 'seq_tag': array([10, 30, 19, ..., 19, 30,  1]),
 'len_noun': array([4, 1, 3, ..., 1, 1, 4]),
 'seq_noun': array([20,  8, 16, ..., 18, 25, 20]),
 'n_np': array([2, 2, 2, ..., 2, 2, 2]),
 'pos_np': array([1, 3, 1, ..., 2, 0, 2]),
 'n_all': array([ 9,  8, 11, ...,  5,  5,  9]),
 'pos_all': array([0, 0, 1, ..., 6, 0, 1]),
 'istag': array([ True, False, False, ...,  True, False, False]),
 'isnoun': array([False,  True,  True, ..., False, False, False]),
 'isvar': array([Fals

In [13]:
# Template string of an encoded tag-index sequence, stopping at the first '{SEP}'.
# (Current sequences contain '{NIL}' rather than '{SEP}', so the whole sequence is used;
# the trailing NIL / ?x tokens are outside the grammar and never end up inside a phrase.)
seq2templ = lambda seq: ' '.join(idx2tag[idx] for idx in takewhile(lambda idx: idx2tag[idx] != '{SEP}', seq))
def _mapper(r):
    """True iff every entity / NP_SIMPLE / ROLE_SIMPLE tag of sequence `r` lies inside a noun-phrase span."""
    tags = seq2templ(r).replace('[', '').replace(']', '').split(' ')
    hits = find_noun_phrases(tags)
    return all(any(start <= idx < end for start, end, _ in hits)
               for idx, tag in enumerate(tags) if tag in ['entity', 'NP_SIMPLE', 'ROLE_SIMPLE'])

# Rows of `rdd` are dicts produced by the encoder mapper; the tag sequence lives under
# 'seq' (the former positional at(3) raised KeyError on a dict). Summing the booleans
# counts passing rows, which should equal df.count().
rdd.map(at('seq')).map(_mapper).reduce(add)

239357

In [79]:
import sys
sys.path.append('../cfq')

from data import RaggedArray, CFQDataset

In [82]:
# Training part of the MCD1 split wrapped in CFQDataset (imported from ../cfq/data.py,
# presumably -- see the sys.path tweak above) over the collected arrays and vocabularies.
dataset = CFQDataset(splits['mcd1']['trainIdxs'], dat, tok_vocab, tag_vocab, typ_vocab)

In [89]:
# NOTE(review): this rebinds `dat` from the full collected-data dict to a single example;
# re-running earlier cells that read or save `dat` afterwards would use the wrong object.
dat = dataset[0]
# Indices >= len(idx2tag) are the noun-phrase sentinel (len(idx2tag) in the NP mapper).
' '.join(idx2tag[idx] if idx < len(idx2tag) else '[NP_COMPLEX]' for idx in dat['seq_tag'])

'Did [NP_COMPLEX] [VP_SIMPLE] [NP_COMPLEX] , [VP_SIMPLE] [NP_COMPLEX] , and [VP_SIMPLE] [NP_COMPLEX] {NIL}'

In [98]:
# Positions within seq_var of this example's variable entries.
dat['pos_all'][dat['isvar']]

array([0, 1, 2])

In [100]:
# Offsets inside their noun phrase of this example's noun-token entries.
dat['pos_all'][dat['isnoun']]

array([0, 1, 0, 1, 0, 0])

In [99]:
# Positions in seq_tag of this example's plain-tag entries.
dat['pos_all'][dat['istag']]

array([ 0,  2,  4,  5,  7,  8,  9, 11])

In [105]:
# Rebuild the flattened sequence from (pos_all, istag, isnoun, isvar): tags and variables are
# looked up by position in their stream; noun-phrase tokens are rendered as empty strings.
def _render_entry(pos, is_tag, is_noun, is_var):
    if is_tag:
        return idx2tag[dat['seq_tag'][pos]]
    if is_noun:
        return ''
    if is_var:
        return idx2tag[dat['seq_var'][pos]]
    raise RuntimeError()

seq = []
for entry in zip(dat['pos_all'], dat['istag'], dat['isnoun'], dat['isvar']):
    seq.append(_render_entry(*entry))

' '.join(seq)

'Did   [VP_SIMPLE]   , [VP_SIMPLE]  , and [VP_SIMPLE]  {NIL} ?x0 ?x1 {NIL}'

In [104]:
# Tag indices of each noun phrase of this example.
dat['seq_noun']

[array([21, 17]), array([21, 17]), array([20]), array([20])]

In [106]:
# NOTE(review): duplicate of the earlier isnoun-position cell (same expression and output).
dat['pos_all'][dat['isnoun']]

array([0, 1, 0, 1, 0, 0])

In [92]:
# NOTE(review): duplicate of the earlier seq_noun cell (same expression and output).
dat['seq_noun']

[array([21, 17]), array([21, 17]), array([20]), array([20])]

In [90]:
# NOTE(review): rebinds `templs` (earlier an RDD of (index, template) pairs) to a plain list.
# List position i equals row index i only if df's 'index' column is exactly 0..N-1 -- confirm.
templs = df.rdd.map(lambda r: r['questionTemplate']).collect()

In [91]:
# Raw questionTemplate of this example, for comparison with the rendered sequences above.
templs[dat['index']]

'Did a [NP_SIMPLE] [VP_SIMPLE] a [NP_SIMPLE] , [VP_SIMPLE] [entity] , and [VP_SIMPLE] [entity]'

In [93]:
# `src` as returned by CFQDataset for this example (its meaning is defined in data.py -- not shown here).
dat['src']

array([ 8, 12, 14, 15,  8, 12, 14, 15,  8, 12, 14, 15,  8, 12, 14, 15])

In [96]:
# Noun-token mask of this example. The recorded output (True exactly at the noun-phrase
# positions of the rendered sequence) is `isnoun`; the empty key '' raised KeyError on a fresh run.
dat['isnoun']

array([False,  True,  True, False,  True,  True, False, False,  True,
       False, False, False,  True, False, False, False, False])