In [None]:
%reload_ext autoreload
%autoreload 2

import os

# Data locations. Each can be overridden through an environment variable so the
# notebook runs on machines other than the original author's; the defaults are the
# original hard-coded paths.
csqa_directory = os.environ.get("CSQA_DIR", "/home/ubuntu/Desktop/CSQA_v9/train")
wikidata_directory = os.environ.get("WIKIDATA_DIR", "/home/ubuntu/Desktop/wikidata")

import sys
import os
import time
import argparse
from tqdm import tqdm
import pickle
from collections import defaultdict
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(os.path.abspath(''), os.pardir)))))

from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.common import Params
from allennlp.data.dataset_readers.semantic_parsing.csqa.csqa import CSQADatasetReader
from allennlp.state_machines.states.grammar_statelet import GrammarStatelet
from allennlp.state_machines.states.grammar_based_search_state import GrammarBasedSearchState
from allennlp.semparse.domain_languages.csqa_language import CSQALanguage
from allennlp.semparse.domain_languages.csqa_language import Entity, Predicate
from allennlp.semparse.domain_languages import START_SYMBOL

from pathlib import Path

In [None]:
# Reader configuration: `kg_path` / `kg_type_data_path` point into `wikidata_directory`,
# while the id -> string lookup files come from the AllenNLP test fixtures.
params = dict(
    lazy=True,
    kg_path=f'{wikidata_directory}/wikidata_short_1_2_rev.p',
    kg_type_data_path=f'{wikidata_directory}/par_child_dict_full.p',
    entity_id2string_path=f'{AllenNlpTestCase.FIXTURES_ROOT}/data/csqa/sample_entity_id2string.json',
    predicate_id2string_path=f'{AllenNlpTestCase.FIXTURES_ROOT}/data/csqa/filtered_property_wikidata4.json',
)

reader = CSQADatasetReader.from_params(Params(params))

# Alternative single-file fixture:
# qa_path = f'{AllenNlpTestCase.FIXTURES_ROOT}/data/csqa/sample_qa.json'
qa_path = f'{AllenNlpTestCase.FIXTURES_ROOT}/data/csqa/sample_train'

# NOTE(review): with lazy=True the reader presumably yields instances on demand
# rather than materializing them here — confirm against CSQADatasetReader.
dataset = reader.read(qa_path)

In [None]:
def get_initial_state(language):
    """Build the root search state for ``language``'s grammar.

    Every production ``"lhs -> rhs"`` returned by
    ``language.all_possible_productions()`` is grouped by its left-hand side,
    so the state can look up the valid right-hand sides per nonterminal.
    The state starts with an empty action history and only ``START_SYMBOL``
    on the nonterminal stack.
    """
    rhs_by_lhs = defaultdict(list)
    split_rules = (rule.split(' -> ') for rule in language.all_possible_productions())
    for left, right in split_rules:
        rhs_by_lhs[left].append(right)

    return GrammarBasedSearchState(
        action_history=[],
        nonterminal_stack=[START_SYMBOL],
        valid_actions=rhs_by_lhs,
        is_nonterminal=language.is_nonterminal,
    )

def _action_matches_result_type(valid_action, expected_result):
    """Return whether a start-symbol production can yield a value of ``expected_result``'s type.

    Used only for the first expansion: productions whose right-hand side does
    not mention the expected type ('bool', 'Number' or 'Set') are pruned.
    Results of any other Python type are not pruned.
    """
    # bool must be checked before int: bool is a subclass of int, so an int-first
    # check would demand 'Number' in the production and prune every 'bool' one.
    if isinstance(expected_result, bool):
        return 'bool' in valid_action
    if isinstance(expected_result, int):
        return 'Number' in valid_action
    if isinstance(expected_result, set):
        return 'Set' in valid_action
    return True


def search(language, expected_result, verbose=False, max_depth=17, max_time=30, stop_after_n_found=1):
    """Breadth-first search for action sequences that execute to ``expected_result``.

    Starting from the grammar's start symbol, every unfinished state is expanded
    by one production per round. Finished states are executed with
    ``language.execute_action_sequence`` and kept when the result equals
    ``expected_result``.

    Parameters
    ----------
    language :
        Domain language providing the grammar and executor. It is switched to
        search mode with ``set_search_modus()`` before searching.
    expected_result :
        Target denotation. Its Python type (``bool``, ``int``, ``set``) restricts
        which productions may be taken from the start symbol.
    verbose : bool
        If True, print the depth reached and the elapsed wall-clock time.
    max_depth : int
        Maximum number of expansion rounds.
    max_time : float
        Wall-clock budget in seconds, checked once at the start of each round.
    stop_after_n_found : int
        Stop as soon as this many correct sequences have been collected.

    Returns
    -------
    list
        The ``action_history`` of every correct state found (possibly empty).
    """
    language.set_search_modus()
    states = [get_initial_state(language)]
    correct_action_sequences = []
    depth = 0
    finished = False
    start_search_time = time.time()

    # Stop when the depth budget is used up, the frontier is empty, or enough
    # correct sequences have been found.
    while depth < max_depth and states and not finished:
        if time.time() - start_search_time > max_time:
            break
        depth += 1
        next_states = []
        for state in states:
            if state.is_finished():
                if language.execute_action_sequence(state.action_history) == expected_result:
                    correct_action_sequences.append(state.action_history)
                    if len(correct_action_sequences) >= stop_after_n_found:
                        finished = True
                        break
                continue

            # In the final round finished states are still evaluated (above),
            # but nothing new is expanded.
            if depth >= max_depth:
                continue

            remaining_depth = max_depth - depth
            for valid_action in state.get_valid_actions():
                # Depth-budget pruning heuristic: the number of ',' and ':' in the
                # right-hand side is used as an estimate of the extra expansions the
                # production still needs (presumably one per argument/return type
                # in a function signature — not exact).
                if valid_action.count(',') + valid_action.count(':') + 1 > remaining_depth:
                    continue
                # At the root, only keep productions that can yield the result's type.
                if depth == 1 and not _action_matches_result_type(valid_action, expected_result):
                    continue
                production_rule = state._nonterminal_stack[-1] + " -> " + valid_action
                next_states.append(state.take_action(production_rule))
        states = next_states

    if verbose:
        print("\nfinished at depth {} in {} seconds".format(depth, time.time() - start_search_time))

    return correct_action_sequences

In [None]:
print_instance_info = True  # NOTE(review): not read by any visible cell — confirm before removing
logical_form_result_dict = {}

# Run the logical-form search for every instance. Each call runs a time-bounded
# search, so tqdm (imported in the setup cell) is used to report progress.
for instance in tqdm(dataset):
    question = instance['question'].tokens
    language = instance['world'].metadata
    expected_result = instance['expected_result'].metadata
    qa_id = instance['qa_id'].metadata

    result_action_sequences = search(language, expected_result)

    # Maps each qa_id to its (possibly empty) list of correct action sequences.
    logical_form_result_dict[qa_id] = result_action_sequences

# To cache the results, use a context manager so the file handle is closed:
# with open(f'{AllenNlpTestCase.FIXTURES_ROOT}/data/csqa/sample_train_action_sequences.p', "wb") as f:
#     pickle.dump(logical_form_result_dict, f)