In [1]:
from collections import defaultdict
import itertools
import json
from operator import *
import pickle
import re

import matplotlib.pylab as pl
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from scipy.sparse import csr_matrix

In [2]:
input_dir = '/data/yu_gai/cfq'
output_dir = '/work/yu_gai/cfq'

In [3]:
df = sqlCtx.read.parquet(f'{input_dir}/dataset.parquet').sort('index').persist()

In [4]:
df.columns

['complexityMeasures',
 'expectedResponse',
 'expectedResponseWithMids',
 'index',
 'question',
 'questionPatternModEntities',
 'questionTemplate',
 'questionWithBrackets',
 'questionWithMids',
 'ruleIds',
 'ruleTree',
 'sparql',
 'sparqlPattern',
 'sparqlPatternModEntities']

In [6]:
split_ids = !ls {input_dir}/splits | grep json
for split_id in [s.replace('.json', '') for s in split_ids]:
    split = json.load(open(f'{input_dir}/splits/{split_id}.json'))
    np.savez(f'{output_dir}/splits/{split_id}', **{k : np.array(v) for k, v in split.items()})
    print(split_id, len(split['trainIdxs']), len(split['devIdxs']), len(split['testIdxs']), df.count())

mcd1 95743 11968 11968 239357
mcd2 95743 11968 11968 239357
mcd3 95743 11968 11968 239357
query_complexity_split 100654 9512 9512 239357
query_pattern_split 94600 12489 12589 239357
question_complexity_split 98999 10339 10340 239357
question_pattern_split 95654 12115 11909 239357
random_split 95744 11967 11967 239357


In [7]:
w = r"(:?[a-zA-Z]+|M[0-9]|'s|,)"
p = re.compile(fr'(:?{w} )+{w}')
df.rdd.map(lambda r: re.match(p, r['questionPatternModEntities']).string).zip(df.rdd.map(lambda r: r['questionPatternModEntities'])).map(lambda r: r[0] == r[1]).reduce(and_)

True

In [8]:
def replace(q):
    for s in [
        'art director',
        'country of nationality',
        'costume designer',
        'executive producer',
        'executive produce',
        'executive produced',
        'film director',
        'film distributor',
        'film editor',
        'film producer',
        'production company',
    ]:
        q = q.replace(s, s.replace(' ', ''))
    return q

df = df.withColumn('questionPatternModEntities', udf(replace, StringType())('questionPatternModEntities')).persist()
df.rdd.map(lambda r: len(r['questionPatternModEntities'].split(' ')) == len(r['questionTemplate'].split(' '))).reduce(and_)

True

In [9]:
roles = df.rdd.flatMap(lambda r: re.findall(r'\[[^\]]+\]', r['questionTemplate'])).distinct().collect()
_

True

In [10]:
r = '(?:%s)' % '|'.join(fr'\[{role[1 : -1]}\]' for role in roles)
p = re.compile(fr'{r} and {r}|(?:{r} , )+and {r}')

def find_grps(t):
    toks = t.split(' ')
    lens = np.array(list(map(len, toks)))
    ends = np.cumsum(lens) + np.arange(len(toks))
    starts = ends - lens

    if re.search(p, t) is None:
        grps = [[i] for i in range(len(toks))]
        return grps, grps
    
    m_start, m_end = zip(*([m.start(), m.end()] for m in re.finditer(p, t)))
    hit = False
    grps = []
    for idx, [start, end] in enumerate(zip(starts, ends)):
        if start in m_start:
            hit = True
            grps.append([])
        if hit:
            grps[-1].append(idx)
        else:
            grps.append([idx])
        if end in m_end:
            hit = False
    
    for start, end, grp in zip(m_start, m_end, (grp for grp in grps if len(grp) > 1)):
        assert t[start : end] == ' '.join(toks[idx] for idx in grp)
    
    grp2pos = [grp if len(grp) == 1 else [idx for idx in grp if toks[idx] in roles] for grp in grps]

    return grps, grp2pos

In [11]:
d = {}

In [12]:
get = lambda rdd, i: rdd.map(lambda r: r[i])
collect = lambda rdd: np.array(rdd.collect())
fcollect = lambda rdd: np.array(rdd.flatMap(lambda r: r).collect())

In [13]:
SEP = '[SEP]'
PAD = '[PAD]'
def mapper(r):
    toks = r['questionPatternModEntities'].split(' ')
    entities = sorted(set(re.findall(r'M[0-0]', r['questionPatternModEntities'])))
    variables = sorted(set(re.findall(r'\?x[0-9]', r['sparqlPatternModEntities'])))
    concepts = []
    for line in r['sparqlPatternModEntities'].split('\n')[1 : -1]:
        if 'FILTER' not in line:
            [[concept, *_]] = re.findall(r'^[^ ]+ [^ ]+ ([^ ]+)( .)?$', line)
            if concept.startswith('ns:'):
                concepts.append(concept)
    concepts = sorted(set(concepts))
    tail = [SEP] + concepts + [SEP] + variables
    seq = toks + tail
    isconcept = len(toks) * [False] + len(concepts) * [True] + len(variables) * [False]
    isvariable = len(toks) * [False] + len(concepts) * [False] + len(variables) * [True]    
    grps, grp2pos = find_grps(r['questionTemplate'] + ' ' + ' '.join(tail))
    pos2grp = list(itertools.chain(len(grp) * [idx] for idx, grp in enumerate(grps)))
    return seq, isconcept, isvariable, grp2pos, pos2grp

rdd = df.rdd.map(mapper).cache()
seq_rdd = get(rdd, 0).cache()
idx2tok = sorted(seq_rdd.flatMap(lambda r: r).distinct().collect() + [PAD])
tok2idx = dict(map(reversed, enumerate(idx2tok)))

d['n_tok'] = collect(seq_rdd.map(len))
d['seq'] = fcollect(seq_rdd.map(lambda r: [tok2idx[tok] for tok in r]))
d['isconcept'], d['isvariable'] = fcollect(get(rdd, 1)), fcollect(get(rdd, 2))
pos2grp = get(rdd, 3).cache()
d['n_grp'] = collect(grp2pos.map(len))
d['pos2grp'] = collect(get(rdd, 4))

In [14]:
entities = sorted(tok for tok in idx2tok if re.match(r'^M[0-9]$', tok))
concepts = sorted(tok for tok in idx2tok if tok.startswith('ns:'))
variables = sorted(tok for tok in idx2tok if re.match(r'^\?x[0-9]$', tok))
sp_toks = sc.broadcast(set(entities + concepts + variables))
def mapper(r):
    d = defaultdict(list)
    for i, tok in enumerate(r):
        if tok in sp_toks.value:
            d[tok].append(i)

    n = len(d)
    tok = [tok2idx[tok] for tok in sorted(d)]
    n_idx = [len(d[k]) for k in sorted(d)]
    idx = sum((d[k] for k in sorted(d)), [])

    return n, tok, n_idx, idx

rdd = seq_rdd.map(mapper).cache()
d['n'] = collect(get(rdd, 0))
d['tok'] = fcollect(get(rdd, 1))
d['n_idx'] = fcollect(get(rdd, 2))
d['idx'] = fcollect(get(rdd, 3))

In [15]:
def mapper_rel(r):
    src, rel, dst = [], [], []
    for line in r['sparqlPatternModEntities'].split('\n')[1 : -1]:
        if 'FILTER' in line:
            [[src_, dst_, *_]] = re.findall(r'^FILTER \( ([^ ]+) != ([^ ]+) \)( .)?$', line)
            src.append(src_)
            rel.append('!=')
            dst.append(dst_)
        else:
            [[src_, rel_, dst_, *_]] = re.findall(r'^([^ ]+) ([^ ]+) ([^ ]+)( .)?$', line)
            src.append(src_)
            rel.append(rel_)
            dst.append(dst_)

    u, inv = np.unique(src + dst, return_inverse=True)
    src, dst = np.split(np.arange(len(u))[inv], 2)
    return src, rel, dst

rdd = df.rdd.map(mapper_rel).cache()
d['src'], d['dst'] = fcollect(get(rdd, 0)), fcollect(get(rdd, 2))
rel_rdd = get(rdd, 1).cache()
d['m'] = collect(rel_rdd.map(len))
rel_rdd = rel_rdd.flatMap(lambda r: r).cache()
idx2rel = sorted(rel_rdd.distinct().collect())
rel2idx = {rel : idx for idx, rel in enumerate(idx2rel)}
d['rel'] = collect(rel_rdd.map(lambda r: rel2idx[r]))

In [16]:
pickle.dump([idx2tok, tok2idx], open(f'{output_dir}/tok-vocab.pickle', 'wb'))
pickle.dump([idx2rel, rel2idx], open(f'{output_dir}/rel-vocab.pickle', 'wb'))
np.savez(f'{output_dir}/data', **d)

### Maximum multiplicity

In [17]:
def mapper(r):
    src, _, dst = r
    _, c = np.unique(np.vstack([src, dst]), return_counts=True, axis=1)
    return c.max()

rdd.map(mapper).reduce(max)

9

## Variable prediction

In [18]:
d = {}
tok_rdd = df.rdd.map(lambda r: [tok2idx[tok] for tok in r['questionPatternModEntities'].split(' ')])
d['seq'] = fcollect(tok_rdd)
d['n_tok'] = collect(tok_rdd.map(len))
d['n_var'] = collect(df.rdd.map(lambda r: len(set(re.findall(r'\?x[0-9]', r['sparqlPatternModEntities'])))))
np.savez(f'{output_dir}/nvar', **d)

## Concept prediction

In [19]:
def mapper(r):
    concepts = []
    for line in r['sparqlPatternModEntities'].split('\n')[1 : -1]:
        if 'FILTER' not in line:
            [[concept, *_]] = re.findall(r'^[^ ]+ [^ ]+ ([^ ]+)( .)?$', line)
            if concept.startswith('ns:'):
                concepts.append(concept)
    return set(concepts)

In [20]:
d = {}
tok_rdd = df.rdd.map(lambda r: [tok2idx[tok] for tok in r['questionPatternModEntities'].split(' ')])
d['seq'] = fcollect(tok_rdd)
d['n_tok'] = collect(tok_rdd.map(len))

con = df.rdd.map(mapper).cache()
d['n_con'] = collect(con.map(len))
idx2con = sorted(con.flatMap(lambda r: r).distinct().collect())
con2idx = {con : idx for idx, con in enumerate(idx2con)}
d['con'] = fcollect(con.map(lambda r: sorted(con2idx[con] for con in r)))

np.savez(f'{output_dir}/concept', **d)