In [1]:
import os
import re
import json
from tqdm import tqdm

In [2]:
# Load the raw train/test splits. `with` closes the handles (the bare open() leaked them);
# binary mode lets json detect UTF-8/16/32 itself instead of relying on the locale encoding.
with open('./preprocess/raw_train.json', 'rb') as f:
    raw_train = json.load(f)
with open('./preprocess/raw_test.json', 'rb') as f:
    raw_test = json.load(f)

In [3]:
class Preprocess(object):
    """Turn raw question/SPARQL records into ``{'input', 'target'}`` seq2seq examples.

    * ``target``: the Wikidata SPARQL query, lower-cased and tokenized. Query
      keywords/punctuation are mapped to ``<extra_id_N>`` tokens (N = position in
      ``VOCAB``), variables are renamed to ``?vrN``, and entity/relation IDs are
      replaced by their labels.
    * ``input``: the question followed by a ``[DEF]``-separated schema that lists
      every entity and relation (prefix token + label) used by the query.

    Label-map keys are compared against lower-cased query tokens, so they must be
    lower-case (presumably IDs such as ``'q42'`` / ``'p31'`` -- confirm against the
    label files).
    """

    # Query vocabulary. The index of each entry IS its token (<extra_id_INDEX>), so
    # never reorder or remove entries -- only append.
    # NOTE(review): 'rdfs:label' can never match (':' is padded to ': ' before
    # tokenizing), and '?sbj'/'?value' are renamed to ?vrN before lookup; they are
    # kept only so that the numbering of later entries stays stable.
    VOCAB = ['"', '(', 'rdfs:label', 'by', 'ask', '>', 'select', 'que', 'limit', 'jai', 'mai',
             '?sbj', ')', 'lang', 'year', '}', '?value', 'peint', 'desc', 'where', 'ce', 'distinct',
             'filter', 'lcase', 'order', 'la', '<', 'asc', 'en', 'contains', 'as', ',', 'strstarts',
             '{', "'", 'j', 'count', '=', '.', '?vr0', '?vr1', '?vr2', '?vr3', '?vr4', '?vr5', '?vr6',
             '?vr0_label', '?vr1_label', '?vr2_label', '?vr3_label', '?vr4_label', '?vr5_label', '?vr6_label',
             'wd:', 'wdt:', 'ps:', 'p:', 'pq:', '?maskvar1', '[DEF]', 'null']

    # Canonical variable names; one per ?vrN entry of VOCAB (previously ?vr6 was missing,
    # so a query needing a 7th slot crashed with IndexError).
    NEW_VARS = ['?vr0', '?vr1', '?vr2', '?vr3', '?vr4', '?vr5', '?vr6']

    # Schema items in the *normalized* query, e.g. "wd: q42 " / "wdt: p31 ".
    # Relation patterns are applied in this order, which fixes their order in the schema.
    _ENT_RE = re.compile(r'wd: (?:.*?) ')
    _REL_RES = (
        re.compile(r'wdt: (?:.*?) '),
        re.compile(r' p: (?:.*?) '),
        re.compile(r' ps: (?:.*?) '),
        re.compile(r'pq: (?:.*?) '),
    )
    # Relation prefixes rewritten to vocab tokens in the schema, in replacement order.
    _REL_PREFIXES = ('wdt:', 'p:', 'ps:', 'pq:')

    # Characters padded with spaces (in this order) so that str.split() tokenizes the query.
    _PADDING = (('(', ' ( '), (')', ' ) '), ('{', ' { '), ('}', ' } '), (':', ': '),
                (',', ' , '), ("'", " ' "), ('.', ' . '), ('=', ' = '))

    def __init__(self, ent_path='./preprocess/all_entity.json',
                 rel_path='./preprocess/all_relation.json'):
        """Load the ID -> label maps and build the vocab-token table.

        Args:
            ent_path: JSON file holding an object that maps entity IDs to labels (null allowed).
            rel_path: JSON file holding an object that maps relation IDs to labels (null allowed).
        """
        # Context managers close the handles promptly (the original leaked them).
        with open(ent_path, 'rb') as f:
            ent_labels = json.load(f)
        with open(rel_path, 'rb') as f:
            rel_labels = json.load(f)

        vocab_dict = self._build_vocab_dict()
        null_token = vocab_dict['null']
        # A None label would end up in the token list and make ' '.join raise, so
        # substitute the 'null' vocab token -- for relations as well as entities.
        for labels in (ent_labels, rel_labels):
            for key in labels:
                if labels[key] is None:
                    labels[key] = null_token

        self.ent_labels = ent_labels
        self.rel_labels = rel_labels
        self.vocab_dict = vocab_dict

    @classmethod
    def _build_vocab_dict(cls):
        """Map every VOCAB entry to '<extra_id_INDEX>'."""
        return {text: '<extra_id_' + str(i) + '>' for i, text in enumerate(cls.VOCAB)}

    def _preprocess(self, data):
        """Build one training example from a raw record.

        Args:
            data: dict with 'sparql_wikidata' and 'question' keys. 'NNQT_question'
                is used as a fallback when 'question' is None.

        Returns:
            ``{'input': question + [DEF] + schema, 'target': tokenized gold query}``.

        Raises:
            IndexError: if the query needs more ?vrN slots than NEW_VARS provides.
        """
        raw_question = data['question']
        if raw_question is None:
            raw_question = data['NNQT_question']
        raw_question = raw_question.replace('}', '').replace('{', '')

        sparql = self._normalize_sparql(data['sparql_wikidata'])
        # Collect schema items before the variables are renamed.
        ents = self._ENT_RE.findall(sparql)
        rels = [match for pattern in self._REL_RES for match in pattern.findall(sparql)]

        sparql = self._rename_variables(sparql)
        gold_query = ' '.join(self._map_tokens(sparql.split(), use_vocab=True)).strip()

        schema = self._build_schema(ents, rels)
        question_input = (' '.join(raw_question.split()).strip() + ' '
                          + self.vocab_dict['[DEF]'] + ' ' + schema)

        return {
            'input': question_input,
            'target': gold_query,
        }

    @classmethod
    def _normalize_sparql(cls, wikisparql):
        """Pad punctuation with spaces, lower-case, and collapse whitespace to single spaces."""
        sparql = wikisparql
        for char, padded in cls._PADDING:
            sparql = sparql.replace(char, padded)
        return ' '.join(sparql.lower().split())

    def _rename_variables(self, sparql):
        """Rename query variables (sorted order) to ?vr0, ?vr1, ...; '?maskvar1' is kept.

        Two quirks are kept on purpose, because the generated targets depend on them:
        * '?maskvar1' still consumes an index, so the following variable skips a number.
        * str.replace works on substrings, so renaming '?x' also turns '?x_label' into
          '?vr0_label' (a VOCAB entry); the now-absent '?x_label' still consumes an index.
        """
        variables = sorted({tok for tok in sparql.split() if tok[0] == '?'})
        for idx, var in enumerate(variables):
            if var == '?maskvar1':
                continue
            sparql = sparql.replace(var, self.NEW_VARS[idx])
        return sparql

    def _map_tokens(self, tokens, use_vocab):
        """Replace entity/relation IDs by labels; if use_vocab, map VOCAB entries to their tokens.

        Lookups use each original token, and a vocab match takes precedence over a label.
        """
        mapped = list(tokens)
        for idx, tok in enumerate(tokens):
            if tok in self.ent_labels:
                mapped[idx] = self.ent_labels[tok]
            elif tok in self.rel_labels:
                mapped[idx] = self.rel_labels[tok]
            if use_vocab and tok in self.vocab_dict:
                mapped[idx] = self.vocab_dict[tok]
        return mapped

    def _build_schema(self, ents, rels):
        """Return '[DEF] <prefix-token> label ...' for each entity, then each relation."""
        def_token = self.vocab_dict['[DEF]']
        tokens = []
        for ent in ents:
            ent = ent.replace('wd:', self.vocab_dict['wd:'] + ' ')
            tokens.append(def_token)
            tokens.extend(ent.split())
        for rel in rels:
            for prefix in self._REL_PREFIXES:
                rel = rel.replace(prefix, self.vocab_dict[prefix] + ' ')
            tokens.append(def_token)
            tokens.extend(rel.split())
        # Schema tokens are already vocab tokens; only the IDs need label lookup.
        return ' '.join(self._map_tokens(tokens, use_vocab=False)).strip()

In [4]:
# One shared preprocessor; constructing it loads the entity/relation label files from ./preprocess/.
pre = Preprocess()

In [5]:
def build_split(preprocessor, records, out_path, desc=None):
    """Preprocess every raw record and write the examples to ``out_path`` as indented JSON.

    Args:
        preprocessor: object exposing ``_preprocess(record) -> dict`` (e.g. ``pre``).
        records: iterable of raw records.
        out_path: destination JSON file (overwritten).
        desc: optional tqdm progress-bar label.

    Returns:
        The list of processed examples, so they stay available in the notebook.
    """
    examples = [preprocessor._preprocess(item) for item in tqdm(records, desc=desc)]
    # Plain 'w' is enough: the file is only written, never read back here.
    with open(out_path, 'w') as f:
        json.dump(examples, f, indent=2)
    return examples


train = build_split(pre, raw_train, 'train.json', desc='train')
test = build_split(pre, raw_test, 'test.json', desc='test')

100%|██████████| 24180/24180 [00:03<00:00, 6328.23it/s]
100%|██████████| 6046/6046 [00:01<00:00, 5553.86it/s]
