In [92]:
from collections import defaultdict
from itertools import *
import json
from operator import *
import pickle
import re

import matplotlib.pylab as pl
import numpy as np
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from scipy.sparse import csr_matrix

In [2]:
input_dir = '/data/yu_gai/cfq'
output_dir = '/work/yu_gai/cfq/data/cfq'

df = sqlCtx.read.parquet(f'{input_dir}/dataset.parquet').sort('index').persist()
df.columns

['complexityMeasures',
 'expectedResponse',
 'expectedResponseWithMids',
 'index',
 'question',
 'questionPatternModEntities',
 'questionTemplate',
 'questionWithBrackets',
 'questionWithMids',
 'ruleIds',
 'ruleTree',
 'sparql',
 'sparqlPattern',
 'sparqlPatternModEntities']

In [127]:
splits = {}
split_ids = !ls {input_dir}/splits | grep json
for split_id in [s.replace('.json', '') for s in split_ids]:
    split = splits[split_id] = json.load(open(f'{input_dir}/splits/{split_id}.json'))
    np.savez(f'{output_dir}/splits/{split_id}', **{k : np.array(v) for k, v in split.items()})
    print(split_id, len(split['trainIdxs']), len(split['devIdxs']), len(split['testIdxs']), df.count())

mcd1 95743 11968 11968 239357
mcd2 95743 11968 11968 239357
mcd3 95743 11968 11968 239357
query_complexity_split 100654 9512 9512 239357
query_pattern_split 94600 12489 12589 239357
question_complexity_split 98999 10339 10340 239357
question_pattern_split 95654 12115 11909 239357
random_split 95744 11967 11967 239357


In [4]:
w = r"(:?[a-zA-Z]+|M[0-9]|'s|,)"
p = re.compile(fr'(:?{w} )+{w}')
df.rdd.map(lambda r: re.match(p, r['questionPatternModEntities']).string).zip(df.rdd.map(lambda r: r['questionPatternModEntities'])).map(lambda r: r[0] == r[1]).reduce(and_)

True

In [5]:
def replace(q):
    for s in [
        'art director',
        'country of nationality',
        'costume designer',
        'executive producer',
        'executive produce',
        'executive produced',
        'film director',
        'film distributor',
        'film editor',
        'film producer',
        'production company',
    ]:
        q = q.replace(s, s.replace(' ', ''))
    return q

df = df.withColumn('questionPatternModEntities', udf(replace, StringType())('questionPatternModEntities')).persist()
df.rdd.map(lambda r: len(r['questionPatternModEntities'].split(' ')) == len(r['questionTemplate'].split(' '))).reduce(and_)

True

In [24]:
def find_rel(line):
    if 'FILTER' in line:
        [[src, dst, *_]] = re.findall(r'^FILTER \( ([^ ]+) != ([^ ]+) \)( .)?$', line)
        return src, '!=', dst
    else:
        [[src, typ, dst, *_]] = re.findall(r'^([^ ]+) ([^ ]+) ([^ ]+)( .)?$', line)
        return src, typ, dst

In [33]:
rels.map(at(0)).distinct().sorted().collect()

AttributeError: 'PipelinedRDD' object has no attribute 'sorted'

In [187]:
rels = df.rdd.flatMap(lambda r: r['sparqlPatternModEntities'].split('\n')[1 : -1]).map(find_rel).cache()
srcs, typs, dsts = map(unique, [rels.map(at(0)), rels.map(at(1)), rels.map(at(2))])
srcs, typs, dsts

(['?x0',
  '?x1',
  '?x2',
  '?x3',
  '?x4',
  '?x5',
  'M0',
  'M1',
  'M2',
  'M3',
  'M4',
  'M5',
  'M6',
  'M7',
  'M8',
  'M9'],
 ['!=',
  '^ns:people.person.gender',
  '^ns:people.person.nationality',
  'a',
  'ns:business.employer.employees/ns:business.employment_tenure.person',
  'ns:film.actor.film/ns:film.performance.character',
  'ns:film.actor.film/ns:film.performance.film',
  'ns:film.cinematographer.film',
  'ns:film.director.film',
  'ns:film.editor.film',
  'ns:film.film.cinematography',
  'ns:film.film.costume_design_by',
  'ns:film.film.directed_by',
  'ns:film.film.distributors/ns:film.film_film_distributor_relationship.distributor',
  'ns:film.film.edited_by',
  'ns:film.film.executive_produced_by',
  'ns:film.film.film_art_direction_by',
  'ns:film.film.prequel',
  'ns:film.film.produced_by|ns:film.film.production_companies',
  'ns:film.film.sequel',
  'ns:film.film.starring/ns:film.performance.actor',
  'ns:film.film.written_by',
  'ns:film.film_art_director.film

In [188]:
rels.filter(lambda r: r[1].startswith('^')).take(10)

[('?x0', '^ns:people.person.nationality', '?x1'),
 ('?x1', '^ns:people.person.nationality', '?x2'),
 ('?x1', '^ns:people.person.nationality', '?x2'),
 ('?x0', '^ns:people.person.nationality', '?x1'),
 ('?x1', '^ns:people.person.nationality', 'M1'),
 ('?x0', '^ns:people.person.nationality', '?x1'),
 ('?x1', '^ns:people.person.nationality', '?x2'),
 ('?x0', '^ns:people.person.nationality', 'M0'),
 ('?x0', '^ns:people.person.nationality', '?x1'),
 ('?x0', '^ns:people.person.nationality', 'M0')]

In [251]:
idx2typ = sorted(typ for typ in typs if typ not in ['a', '!=', 'ns:people.person.gender', 'ns:people.person.nationality'])
typ2idx = {typ : idx for idx, typ in enumerate(idx2typ)}
idx2typ

['^ns:people.person.gender',
 '^ns:people.person.nationality',
 'ns:business.employer.employees/ns:business.employment_tenure.person',
 'ns:film.actor.film/ns:film.performance.character',
 'ns:film.actor.film/ns:film.performance.film',
 'ns:film.cinematographer.film',
 'ns:film.director.film',
 'ns:film.editor.film',
 'ns:film.film.cinematography',
 'ns:film.film.costume_design_by',
 'ns:film.film.directed_by',
 'ns:film.film.distributors/ns:film.film_film_distributor_relationship.distributor',
 'ns:film.film.edited_by',
 'ns:film.film.executive_produced_by',
 'ns:film.film.film_art_direction_by',
 'ns:film.film.prequel',
 'ns:film.film.produced_by|ns:film.film.production_companies',
 'ns:film.film.sequel',
 'ns:film.film.starring/ns:film.performance.actor',
 'ns:film.film.written_by',
 'ns:film.film_art_director.films_art_directed',
 'ns:film.film_costumer_designer.costume_design_for_film',
 'ns:film.film_distributor.films_distributed/ns:film.film_film_distributor_relationship.film',


In [65]:
for q, r in df.rdd.map(lambda r: [r['questionPatternModEntities'], r['sparqlPatternModEntities']]).filter(lambda r: '!=' in r[1]).take(10):
    print(q)
    print(r)
    print()

Did M1 employ a spouse of a character
SELECT count(*) WHERE {
?x0 ns:people.person.spouse_s/ns:people.marriage.spouse|ns:fictional_universe.fictional_character.married_to/ns:fictional_universe.marriage_of_fictional_characters.spouses ?x1 .
?x1 a ns:fictional_universe.fictional_character .
FILTER ( ?x0 != ?x1 ) .
M1 ns:business.employer.employees/ns:business.employment_tenure.person ?x0
}

Did M0 's writer , editor , cinematographer , producer , and director marry and influence M1
SELECT count(*) WHERE {
?x0 ns:film.cinematographer.film M0 .
?x0 ns:film.director.film M0 .
?x0 ns:film.editor.film M0 .
?x0 ns:film.producer.film|ns:film.production_company.films M0 .
?x0 ns:film.writer.film M0 .
?x0 ns:influence.influence_node.influenced M1 .
?x0 ns:people.person.spouse_s/ns:people.marriage.spouse|ns:fictional_universe.fictional_character.married_to/ns:fictional_universe.marriage_of_fictional_characters.spouses M1 .
FILTER ( ?x0 != M1 )
}

Did M0 's male actor marry M2
SELECT count(*) WHERE

In [48]:
dsts

['?x0',
 '?x1',
 '?x2',
 '?x3',
 '?x4',
 '?x5',
 'M0',
 'M1',
 'M2',
 'M3',
 'M4',
 'M5',
 'M6',
 'M7',
 'M8',
 'M9',
 'ns:business.employer',
 'ns:fictional_universe.fictional_character',
 'ns:film.actor',
 'ns:film.cinematographer',
 'ns:film.director',
 'ns:film.editor',
 'ns:film.film',
 'ns:film.film_art_director',
 'ns:film.film_costumer_designer',
 'ns:film.film_distributor',
 'ns:film.producer',
 'ns:film.production_company',
 'ns:film.writer',
 'ns:m.02zsn',
 'ns:m.0345h',
 'ns:m.03_3d',
 'ns:m.03rjj',
 'ns:m.059j2',
 'ns:m.05zppz',
 'ns:m.06mkj',
 'ns:m.07ssc',
 'ns:m.09c7w0',
 'ns:m.0b90_r',
 'ns:m.0d05w3',
 'ns:m.0d060g',
 'ns:m.0d0vqn',
 'ns:m.0f8l9c',
 'ns:people.person']

In [159]:
# categories
a = rels.filter(lambda r: r[1] == 'a').cache()
cats = unique(a.map(at(2)))
unique(a.map(at(0))), cats

(['?x0',
  '?x1',
  '?x2',
  '?x3',
  '?x4',
  '?x5',
  'M0',
  'M1',
  'M2',
  'M3',
  'M4',
  'M5',
  'M6'],
 ['ns:business.employer',
  'ns:fictional_universe.fictional_character',
  'ns:film.actor',
  'ns:film.cinematographer',
  'ns:film.director',
  'ns:film.editor',
  'ns:film.film',
  'ns:film.film_art_director',
  'ns:film.film_costumer_designer',
  'ns:film.film_distributor',
  'ns:film.producer',
  'ns:film.production_company',
  'ns:film.writer',
  'ns:people.person'])

In [160]:
# gender and nationality
for dst in dsts:
    if dst.startswith('ns:') and dst not in cats:
        print(dst, unique(rels.filter(lambda r: r[2] == dst).map(at(1))))

ns:m.02zsn ['ns:people.person.gender']
ns:m.0345h ['ns:people.person.nationality']
ns:m.03_3d ['ns:people.person.nationality']
ns:m.03rjj ['ns:people.person.nationality']
ns:m.059j2 ['ns:people.person.nationality']
ns:m.05zppz ['ns:people.person.gender']
ns:m.06mkj ['ns:people.person.nationality']
ns:m.07ssc ['ns:people.person.nationality']
ns:m.09c7w0 ['ns:people.person.nationality']
ns:m.0b90_r ['ns:people.person.nationality']
ns:m.0d05w3 ['ns:people.person.nationality']
ns:m.0d060g ['ns:people.person.nationality']
ns:m.0d0vqn ['ns:people.person.nationality']
ns:m.0f8l9c ['ns:people.person.nationality']


In [158]:
idx2attr = sorted(dst for dst in dsts if dst.startswith('ns:'))
idx2attr

['ns:business.employer',
 'ns:fictional_universe.fictional_character',
 'ns:film.actor',
 'ns:film.cinematographer',
 'ns:film.director',
 'ns:film.editor',
 'ns:film.film',
 'ns:film.film_art_director',
 'ns:film.film_costumer_designer',
 'ns:film.film_distributor',
 'ns:film.producer',
 'ns:film.production_company',
 'ns:film.writer',
 'ns:m.02zsn',
 'ns:m.0345h',
 'ns:m.03_3d',
 'ns:m.03rjj',
 'ns:m.059j2',
 'ns:m.05zppz',
 'ns:m.06mkj',
 'ns:m.07ssc',
 'ns:m.09c7w0',
 'ns:m.0b90_r',
 'ns:m.0d05w3',
 'ns:m.0d060g',
 'ns:m.0d0vqn',
 'ns:m.0f8l9c',
 'ns:people.person']

In [55]:
ne = rels.filter(lambda r: r[1] == '!=').cache()
unique(ne.map(at(0))), unique(ne.map(at(2)))

(['?x0',
  '?x1',
  '?x2',
  '?x3',
  '?x4',
  'M0',
  'M1',
  'M2',
  'M3',
  'M4',
  'M5',
  'M6',
  'M7'],
 ['?x0',
  '?x1',
  '?x2',
  '?x3',
  '?x4',
  '?x5',
  'M0',
  'M1',
  'M2',
  'M3',
  'M4',
  'M5',
  'M6',
  'M7',
  'M8'])

In [116]:
tok_vocab = unique(df.rdd.flatMap(lambda r: r['questionPatternModEntities'].split(' ')))
tok_vocab

["'s",
 ',',
 'American',
 'British',
 'Canadian',
 'Chinese',
 'Did',
 'Dutch',
 'French',
 'German',
 'Italian',
 'Japanese',
 'M0',
 'M1',
 'M2',
 'M3',
 'M4',
 'M5',
 'M6',
 'M7',
 'M8',
 'M9',
 'Mexican',
 'Spanish',
 'Swedish',
 'Was',
 'Were',
 'What',
 'Which',
 'Who',
 'a',
 'acquire',
 'acquired',
 'actor',
 'and',
 'artdirector',
 'by',
 'character',
 'child',
 'cinematographer',
 'company',
 'costumedesigner',
 'countryofnationality',
 'did',
 'direct',
 'directed',
 'director',
 'distribute',
 'distributed',
 'distributor',
 'edit',
 'edited',
 'editor',
 'employ',
 'employed',
 'employee',
 'employer',
 'executiveproduce',
 'executiveproduced',
 'executiveproducer',
 'female',
 'film',
 'filmdirector',
 'filmdistributor',
 'filmeditor',
 'filmproducer',
 'found',
 'founded',
 'founder',
 'gender',
 'influence',
 'influenced',
 'male',
 'married',
 'marry',
 'of',
 'parent',
 'person',
 'play',
 'played',
 'prequel',
 'produce',
 'produced',
 'producer',
 'productioncompan

In [146]:
# repeated occurence of entities
def _mapper(r):
    c = defaultdict(lambda: 0)
    for tok in r['questionPatternModEntities'].split(' '):
        if re.match('M\d', tok) is not None:
            c[tok] += 1
    return c.items()

def _fn(rdd):
    for [tok, c], n in sorted(count(rdd.flatMap(_mapper)).items()):
        if c > 1:
            print(tok, c, n)

_fn(df.rdd)

for k, v in splits.items():
    print(k)
    indices = set(chain(*splits['mcd1'].values()))
    _fn(df.rdd.filter(lambda r: r['index'] in indices))

M0 2 20
M1 2 17
M2 2 28
M3 2 13
M4 2 5
M5 2 1
mcd1
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
mcd2
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
mcd3
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
query_complexity_split
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
query_pattern_split
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
question_complexity_split
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
question_pattern_split
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1
random_split
M0 2 6
M1 2 12
M2 2 11
M3 2 5
M4 2 1


In [70]:
def _mapper(r):
    rels = list(map(find_rel, r['sparqlPatternModEntities'].split('\n')[1 : -1]))
    ends = set((src, dst) for src, typ, dst in rels if typ != '!=')
    return all((src, dst) in ends for src, typ, dst in rels if typ == '!=')

df.rdd.map(_mapper).reduce(and_)

In [47]:
sorted(rels.map(at(1)).map(k1).reduceByKey(add).collect(), key=at(1))

[('ns:film.film.costume_design_by', 56),
 ('ns:film.film.cinematography', 99),
 ('ns:film.film.film_art_direction_by', 101),
 ('^ns:people.person.gender', 1088),
 ('^ns:people.person.nationality', 2171),
 ('ns:film.film_distributor.films_distributed/ns:film.film_film_distributor_relationship.film',
  5210),
 ('ns:organization.organization.companies_acquired/ns:business.acquisition.company_acquired',
  5814),
 ('ns:film.film.prequel', 6337),
 ('ns:film.film.sequel', 6906),
 ('ns:organization.organization.acquired_by/ns:business.acquisition.acquiring_company',
  7584),
 ('ns:film.film.distributors/ns:film.film_film_distributor_relationship.distributor',
  9712),
 ('ns:film.film.starring/ns:film.performance.actor', 11668),
 ('ns:people.person.parents|ns:fictional_universe.fictional_character.parents|ns:organization.organization.parent/ns:organization.organization_relationship.parent',
  15451),
 ('ns:people.person.children|ns:fictional_universe.fictional_character.children|ns:organization

In [31]:
dsts

['?x0',
 '?x1',
 '?x2',
 '?x3',
 '?x4',
 '?x5',
 'M0',
 'M1',
 'M2',
 'M3',
 'M4',
 'M5',
 'M6',
 'M7',
 'M8',
 'M9',
 'ns:business.employer',
 'ns:fictional_universe.fictional_character',
 'ns:film.actor',
 'ns:film.cinematographer',
 'ns:film.director',
 'ns:film.editor',
 'ns:film.film',
 'ns:film.film_art_director',
 'ns:film.film_costumer_designer',
 'ns:film.film_distributor',
 'ns:film.producer',
 'ns:film.production_company',
 'ns:film.writer',
 'ns:m.02zsn',
 'ns:m.0345h',
 'ns:m.03_3d',
 'ns:m.03rjj',
 'ns:m.059j2',
 'ns:m.05zppz',
 'ns:m.06mkj',
 'ns:m.07ssc',
 'ns:m.09c7w0',
 'ns:m.0b90_r',
 'ns:m.0d05w3',
 'ns:m.0d060g',
 'ns:m.0d0vqn',
 'ns:m.0f8l9c',
 'ns:people.person']

In [141]:
at = lambda i: (lambda x: x[i])
k1 = lambda r: [r, 1]
unique = lambda rdd: sorted(rdd.distinct().collect())
count = lambda rdd: dict(rdd.map(k1).reduceByKey(add).collect())

In [81]:
roles = df.rdd.flatMap(lambda r: re.findall(r'\[[^\]]+\]', r['questionTemplate'])).distinct().collect()
roles

['[NP_SIMPLE]',
 '[entity]',
 '[ADJECTIVE_SIMPLE]',
 '[VP_SIMPLE]',
 '[ROLE_SIMPLE]']

In [84]:
def _mapper(r):
    d = defaultdict(set)
    for x, y in zip(r['questionPatternModEntities'].split(' '), r['questionTemplate'].split(' ')):
        if y in roles:
            d[y].add(x)
    return [[role, d[role]] for role in roles]

role2toks = dict(df.rdd.flatMap(_mapper).reduceByKey(set.union).collect())
role2toks

{'[NP_SIMPLE]': {'actor',
  'artdirector',
  'character',
  'cinematographer',
  'company',
  'costumedesigner',
  'film',
  'filmdirector',
  'filmdistributor',
  'filmeditor',
  'filmproducer',
  'person',
  'productioncompany',
  'screenwriter'},
 '[entity]': {'M0', 'M1', 'M2', 'M3', 'M4', 'M5', 'M6', 'M7', 'M8', 'M9'},
 '[ADJECTIVE_SIMPLE]': {'American',
  'British',
  'Canadian',
  'Chinese',
  'Dutch',
  'French',
  'German',
  'Italian',
  'Japanese',
  'Mexican',
  'Spanish',
  'Swedish',
  'female',
  'male'},
 '[VP_SIMPLE]': {'acquire',
  'acquired',
  'direct',
  'directed',
  'distribute',
  'distributed',
  'edit',
  'edited',
  'employ',
  'employed',
  'executiveproduce',
  'executiveproduced',
  'found',
  'founded',
  'influence',
  'influenced',
  'married',
  'marry',
  'play',
  'played',
  'produce',
  'produced',
  'star',
  'starred',
  'write',
  'written',
  'wrote'},
 '[ROLE_SIMPLE]': {'actor',
  'artdirector',
  'child',
  'cinematographer',
  'costumedesigne

In [90]:
special_roles = ['[VP_SIMPLE]', '[ROLE_SIMPLE]']

def _mapper(r):
    uniq_toks = set()
    for role in special_roles:
        for tok in role2toks[role]:
            if tok in r['questionPatternModEntities'].split(' '):
                uniq_toks.add(tok)
    _, typs, _ = list(zip(*map(find_rel, r['sparqlPatternModEntities'].split('\n')[1 : -1])))
    uniq_typs = set(typs)
    return list(product(uniq_toks, uniq_typs)), uniq_toks, uniq_typs

occs = df.rdd.map(_mapper).persist()
both, c_tok, c_rel = map(count, (occs.flatMap(at(0)), occs.flatMap(at(1)), occs.flatMap(at(2))))

In [114]:
def _fn(by):
    for key, grp in groupby(sorted(([tok, [typ, c]] if by == 'tok' else [typ, [tok, c]]
                                    for [tok, typ], c in both.items()), key=at(0)), at(0)):
        _, c = zip(*grp)
        print(key, (c_tok if by == 'tok' else c_rel)[key])
        print(*islice(sorted(c, key=at(1), reverse=True), 3), sep='\n')
        
_fn('tok')
print()
_fn('typ')

acquire 1830
['ns:organization.organization.companies_acquired/ns:business.acquisition.company_acquired', 1307]
['a', 1086]
['ns:organization.organization.acquired_by/ns:business.acquisition.acquiring_company', 540]
acquired 2625
['ns:organization.organization.acquired_by/ns:business.acquisition.acquiring_company', 2246]
['a', 1826]
['ns:organization.organization.companies_acquired/ns:business.acquisition.company_acquired', 809]
actor 20287
['a', 15062]
['ns:film.actor.film/ns:film.performance.character', 9219]
['ns:people.person.nationality', 7425]
artdirector 31822
['ns:film.film_art_director.films_art_directed', 22744]
['a', 17117]
['ns:film.editor.film', 9740]
child 14955
['ns:people.person.parents|ns:fictional_universe.fictional_character.parents|ns:organization.organization.parent/ns:organization.organization_relationship.parent', 14747]
['a', 9731]
['ns:people.person.nationality', 4293]
cinematographer 37256
['ns:film.cinematographer.film', 26135]
['a', 20010]
['ns:film.editor.f

In [249]:
SEP, NIL = '{SEP}', '{NIL}'

In [None]:
isvar = lambda tok: tok.startswith('?x')
idx2var = sorted(set(chain(filter(isvar, srcs), filter(isvar, dsts))))
idx2var

In [212]:
idx2tok = [SEP, NIL] + idx2var + unique(df.rdd.flatMap(lambda r: r['questionPatternModEntities'].split(' ')))
tok2idx = {tok : idx for idx, tok in enumerate(idx2tok)}
idx2tok

['{SEP}',
 '{NIL}',
 '?x0',
 '?x1',
 '?x2',
 '?x3',
 '?x4',
 '?x5',
 "'s",
 ',',
 'American',
 'British',
 'Canadian',
 'Chinese',
 'Did',
 'Dutch',
 'French',
 'German',
 'Italian',
 'Japanese',
 'M0',
 'M1',
 'M2',
 'M3',
 'M4',
 'M5',
 'M6',
 'M7',
 'M8',
 'M9',
 'Mexican',
 'Spanish',
 'Swedish',
 'Was',
 'Were',
 'What',
 'Which',
 'Who',
 'a',
 'acquire',
 'acquired',
 'actor',
 'and',
 'artdirector',
 'by',
 'character',
 'child',
 'cinematographer',
 'company',
 'costumedesigner',
 'countryofnationality',
 'did',
 'direct',
 'directed',
 'director',
 'distribute',
 'distributed',
 'distributor',
 'edit',
 'edited',
 'editor',
 'employ',
 'employed',
 'employee',
 'employer',
 'executiveproduce',
 'executiveproduced',
 'executiveproducer',
 'female',
 'film',
 'filmdirector',
 'filmdistributor',
 'filmeditor',
 'filmproducer',
 'found',
 'founded',
 'founder',
 'gender',
 'influence',
 'influenced',
 'male',
 'married',
 'marry',
 'of',
 'parent',
 'person',
 'play',
 'played',


In [210]:
idx2tag = [SEP, NIL] + idx2var + unique(df.rdd.flatMap(lambda r: r['questionTemplate'].split(' ')))
tag2idx = {tag : idx for idx, tag in enumerate(idx2tag)}
idx2tag

['{SEP}',
 '{NIL}',
 '?x0',
 '?x1',
 '?x2',
 '?x3',
 '?x4',
 '?x5',
 "'s",
 ',',
 'Did',
 'Was',
 'Were',
 'What',
 'Which',
 'Who',
 '[ADJECTIVE_SIMPLE]',
 '[NP_SIMPLE]',
 '[ROLE_SIMPLE]',
 '[VP_SIMPLE]',
 '[entity]',
 'a',
 'and',
 'by',
 'did',
 'of',
 'that',
 'was',
 'were',
 'whose']

In [260]:
r = '(?:%s)' % '|'.join(fr'\[{role[1 : -1]}\]' for role in roles)
p = re.compile(fr'{r} and {r}|(?:{r} , )+and {r}')
def grp_by_tag(tags):
    lens = np.array(list(map(len, tags)))
    ends = np.cumsum(lens) + np.arange(len(tags))
    starts = ends - lens

    t = ' '.join(tags)
    if re.search(p, t) is None:
        grps = [[i] for i in range(len(tags))]
        return grps
    
    m_start, m_end = zip(*([m.start(), m.end()] for m in re.finditer(p, t)))
    hit = False
    grps = []
    for idx, [start, end] in enumerate(zip(starts, ends)):
        if start in m_start:
            hit = True
            grps.append([])
        if hit:
            grps[-1].append(idx)
        else:
            grps.append([idx])
        if end in m_end:
            hit = False
    
    for start, end, grp in zip(m_start, m_end, (grp for grp in grps if len(grp) > 1)):
        assert t[start : end] == ' '.join(tags[idx] for idx in grp)

    return grps

def _mapper(r):
    rels = list(map(find_rel, r['sparqlPatternModEntities'].split('\n')[1 : -1]))
    srcs, typs, dsts = zip(*rels)
    ents = sorted({x for x in chain(srcs, dsts) if re.match('M\d', x) or re.match('\?x\d', x)})

    tail = [SEP] + sorted(ent for ent in ents if ent.startswith('?x')) + [NIL]
    toks = r['questionPatternModEntities'].split(' ') + tail
    tags = r['questionTemplate'].split(' ') + tail
    grps = grp_by_tag(tags)

    seq = [tag2idx[tags[idx]] for idx, *_ in grps]
    mem = [[tok2idx[toks[grp[0]]]] if len(grp) == 1 else
           [toks[idx] for idx in grp if tags[idx] in roles] for grp in grps]
    
    ent2grp = {}
    for idx, tok in zip(chain(*(len(grp) * [idx] for idx, grp in enumerate(grps))), toks):
        if tok in ents:
            ent2grp[tok] = idx
    idx2grp = sorted(ent2grp.values())
    ent2idx = {ent: idx2grp.index(ent2grp[ent]) for ent in ents}
    _, idx2ent = zip(*sorted(ent2idx.items(), key=at(1)))
    
    # filters
    filters = [[ent2idx[src], ent2idx[dst]] for src, typ, dst in rels if typ == '!=']
    
    # attributes
    ent2attr = np.zeros([len(idx2grp), len(idx2attr)])
    for src, _, dst in rels:
        if dst in idx2attr:
            ent2attr[ents.index(src), idx2attr.index(dst)] = 1

    # groundable relations
    gr_rels = [[ent2idx[src], ent2idx[dst], idx2typ.index(typ)] for src, typ, dst in rels if typ in idx2typ]

    return filters, ent2attr, gr_rels, seq, mem, idx2grp, idx2ent

dat = {}
collect = lambda rdd: np.array(rdd.collect())

# print(_mapper(df.rdd.take(5)[-1]))
rdd = df.rdd.map(_mapper).cache()
dat['n_filter'] = collect(rdd.map(at(0)).map(len))
dat['filter'] = collect(rdd.flatMap(at(0)))
dat['attr'] = np.vstack(rdd.map(at(1)).collect())
dat['n_rel'] = collect(rdd.map(at(2)).map(len))
dat['src'] = collect(rdd.flatMap(at(2)).map(at(0)))
dat['dst'] = collect(rdd.flatMap(at(2)).map(at(1)))
dat['typ'] = collect(rdd.flatMap(at(2)).map(at(2)))
dat['seq'] = collect(rdd.flatMap(at(3)))
dat['n_grp'] = collect(rdd.map(at(4)).map(len))
dat['n_mem'] = collect(rdd.flatMap(at(4)).map(len))
dat['mem'] = collect(rdd.flatMap(at(4)).flatMap(lambda r: r))
dat['n'] = collect(rdd.map(at(5)).map(len))

In [261]:
pickle.dump([idx2tok, tok2idx], open(f'{output_dir}/tok-vocab.pickle', 'wb'))
pickle.dump([idx2tag, tag2idx], open(f'{output_dir}/tag-vocab.pickle', 'wb'))
pickle.dump([idx2typ, typ2idx], open(f'{output_dir}/typ-vocab.pickle', 'wb'))
np.savez(f'{output_dir}/data', **dat)

## Variable prediction

In [17]:
d = {}
tok_rdd = df.rdd.map(lambda r: [tok2idx[tok] for tok in r['questionPatternModEntities'].split(' ')])
d['seq'] = fcollect(tok_rdd)
d['n_tok'] = collect(tok_rdd.map(len))
d['n_var'] = collect(df.rdd.map(lambda r: len(set(re.findall(r'\?x[0-9]', r['sparqlPatternModEntities'])))))
np.savez(f'{output_dir}/nvar', **d)