In [3]:
# Detect whether this code runs inside IPython/Jupyter or as a plain Python script.
# `__IPYTHON__` only exists under IPython, so referencing it raises NameError elsewhere.
try:
    __IPYTHON__  # noqa: F821  pylint: disable=undefined-variable,pointless-statement
except NameError:
    USING_IPYTHON = False
else:
    USING_IPYTHON = True
    # Same effect as the `%load_ext autoreload` / `%autoreload 2` line magics, but spelled
    # as ordinary Python so that the file still parses when it is executed outside IPython
    # (the `%...` syntax is a SyntaxError for the plain interpreter).
    from IPython import get_ipython
    _ipython_shell = get_ipython()
    _ipython_shell.run_line_magic('load_ext', 'autoreload')
    _ipython_shell.run_line_magic('autoreload', '2')

#### Argparse

In [4]:
import argparse
# Command-line interface.  `project_root` is the only required argument; every other
# option is a sub-directory, file name, extension or URL template with a default value.
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')

ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
# The three `{}` placeholders are filled with (framework, dataset, graph id) further down.
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')

# Stand-in for the command line when running inside a notebook: one or more
# whitespace/newline separated arguments.
arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
# str.split() with no separator splits on any run of whitespace, newlines included.
# (The previous `split(r'\\n')` looked for the literal characters backslash-backslash-n,
# which never occur, so the per-line split silently did nothing.)
arguments = arg_string.split()

In [5]:
# In a notebook there is no real command line, so the canned `arguments` list is parsed.
# As a script, passing None makes argparse fall back to sys.argv[1:], which is exactly
# what a bare `ap.parse_args()` does.
args = ap.parse_args(arguments if USING_IPYTHON else None)

In [6]:
args

Namespace(companion_file_extension='.conllu', companion_sub_dir='companion', graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', project_root='/data/proj29_ds1/home/slai/mrp2019', train_sub_dir='training')

#### Library imports

In [464]:
import json
import logging
import os
import pprint
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
from preprocessing import CompanionParseDataset, MrpDataset, JamrAlignmentDataset
from action_state import mrp_json2parser_states
                           
from tqdm import tqdm

#### IPython notebook-specific imports

In [8]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

In [9]:
# Root logging: a single stream handler using a "<logger name> - <level> - <message>" layout.
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logging.basicConfig(handlers=[sh], level=logging.INFO)

# Logger for this notebook/script itself, pinned to INFO.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [10]:
# Placeholder string constant.
# NOTE(review): both the name and the value are misspelled ("UNKWOWN" instead of "UNKNOWN").
# They are left as-is because code outside this view may reference the name or compare
# against the literal -- confirm before renaming.
UNKWOWN = 'UNKWOWN'

### Load data

In [11]:
# Training data directory: <project_root>/<mrp_data_dir>/<train_sub_dir>.
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [12]:
# Loader object from the project's `preprocessing` module; used below to read the MRP files.
mrp_dataset = MrpDataset()

In [13]:
# Load the training graphs found under `train_dir` with the configured file extension.
# The second return value is a nested mapping, framework -> dataset name -> sequence of
# MRP graphs; later cells index it as framework2dataset2mrp_jsons[framework][dataset][0].
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.31it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:03,  1.26it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:05,  1.67s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:09<00:05,  2.75s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:16<00:04,  4.04s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  43%|████▎     | 6/14 [00:00<00:00, 21.72it/s][A
dataset_name:  57%|█████▋    | 8/14 [00:00<00:00, 18.69it/s][A
dataset_name:  71%|███████▏  | 10/14 [00:01<00:00,  6.30it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.68it/s][A
frameworks: 100%|██████████| 5/5 [00:18<00:00,  3.29s/it]t/s][A


### Data preprocessing: companion parses

In [14]:
# Companion data directory: <project_root>/<mrp_data_dir>/<companion_sub_dir>.
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [15]:
# Loader object from the project's `preprocessing` module for the companion parse files.
cparse_dataset = CompanionParseDataset()

In [16]:
# Load the companion parses found under `companion_dir` with the configured extension.
# The result is keyed by dataset name (e.g. 'wsj'); each value supports `in` tests on a
# sentence id string -- presumably a mapping sentence id -> parse, TODO confirm.
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

preprocessing - INFO - framework amr found
dataset: 100%|██████████| 13/13 [00:04<00:00,  3.12it/s]
preprocessing - INFO - framework dm found
dataset: 100%|██████████| 5/5 [00:01<00:00,  4.35it/s]
preprocessing - INFO - framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 30.86it/s]


In [17]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [18]:
# Sanity check showing that the companion data is incomplete: sentence id '20003001'
# has no entry in the 'wsj' companion parses (the expression evaluates to False).
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [19]:
# Loader object from the project's `preprocessing` module for the JAMR alignment file.
jalignment_dataset = JamrAlignmentDataset()

In [20]:
# The alignment file lives next to the companion parses:
# <project_root>/<mrp_data_dir>/<companion_sub_dir>/<jamr_alignment_file>.
jamr_alignment_path = os.path.join(
    args.project_root, args.mrp_data_dir, args.companion_sub_dir, args.jamr_alignment_file)
# Indexed by graph id further down (cid2alignment[cid]).
cid2alignment = jalignment_dataset.load_jamr_alignment_file(jamr_alignment_path)

### Define the state at each step

In [424]:
# Log each framework name followed by the names of the datasets loaded for it.
for framework, dataset2mrp_jsons in framework2dataset2mrp_jsons.items():
    logger.info(framework)
    logger.info(list(dataset2mrp_jsons.keys()))

__main__ - INFO - ucca
__main__ - INFO - ['wiki', 'ewt']
__main__ - INFO - psd
__main__ - INFO - ['wsj']
__main__ - INFO - eds
__main__ - INFO - ['wsj']
__main__ - INFO - dm
__main__ - INFO - ['wsj']
__main__ - INFO - amr
__main__ - INFO - ['xinhua', 'wsj', 'wiki', 'wb', 'rte', 'proxy', 'mt09sdl', 'lorelei', 'fables', 'dfb', 'dfa', 'cctv', 'bolt', 'amr-guidelines']


### Test module

In [660]:
from action_state import mrp_json2parser_states

In [661]:
# (framework, dataset) pairs that can be inspected; the index used below picks which
# one the remaining cells work on.
framework_dataset_pairs = [
    ('dm', 'wsj'),
    ('psd', 'wsj'),
    ('eds', 'wsj'),
    ('ucca', 'wiki'),
    ('amr', 'wsj'),
    ('amr', 'wiki'),
]
framework, dataset = framework_dataset_pairs[2]

# All graphs of the selected framework/dataset combination.
mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]
framework, dataset

('eds', 'wsj')

In [662]:
# Work on the first graph of the selected framework/dataset.
mrp_json = mrp_jsons[0]

In [663]:
# Only AMR graphs are looked up in the JAMR alignments; every other framework keeps the
# empty mapping.
alignment = {}
if framework == 'amr':
    # Graph id, defaulting to '' when the graph carries no 'id' key.
    cid = mrp_json.get('id', '')
    # NOTE(review): plain indexing -- if cid2alignment is a regular dict this raises KeyError
    # for a graph without an alignment entry; confirm that is the intended behavior.
    alignment = cid2alignment[cid]

In [664]:
# mrp_json

In [665]:
# Log the URL of the rendered graphviz picture for the graph under inspection; the
# template placeholders are filled with framework, dataset and graph id, in that order.
graphviz_url = args.graphviz_file_template.format(framework, dataset, mrp_json.get('id'))
logger.info(graphviz_url)

__main__ - INFO - http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/eds/wsj.mrp/20001001.png


In [666]:
# for i in range(1000):
#     mrp_json = mrp_jsons[i]
#     parser_states, meta_data = mrp_json2parser_states(mrp_json, framework, alignment)

In [667]:
# Convert the selected graph into parser states.  `parser_states` is iterated as 8-tuples
# and `meta_data` is unpacked into nine lookup structures in later cells.
# `alignment` is the empty dict unless the framework is 'amr'.
parser_states, meta_data = mrp_json2parser_states(mrp_json, framework, alignment)

action_state - INFO - ('curr_node_id', 3)
action_state - INFO - (3, [], True, True)
action_state - INFO - (3, 3, [(3, 3, None)])
action_state - INFO - (3, 17, 1, {4})
action_state - INFO - ('curr_node_id', 4)
action_state - INFO - (4, [(3, 3, None)], True, True)
action_state - INFO - (4, 4, [(3, 3, None), (4, 4, None)])
action_state - INFO - (4, 3, 10, {12})
action_state - INFO - (4, 15, 1, set())
action_state - INFO - ('complete', 1)
action_state - INFO - (4, 7, 9, set())
action_state - INFO - ('complete', 9)
action_state - INFO - ('curr_node_id', 1)
action_state - INFO - (1, [(3, 3, None), (4, 4, None)], False, True)
action_state - INFO - (1, 1, [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])])
action_state - INFO - ('curr_node_id', 7)
action_state - INFO - (7, [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])], False, False)
action_state - INFO - (7, 7, [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])])
action_state - INFO - ('curr_node_id', 8)
action_state - INFO - (8, [(1

action_state - INFO - (22, 22, [(13, 13, [(13, 13, None), (15, 15, [(15, 15, None), (16, 16, None)]), (10, 10, [(10, 10, None), (12, 12, None), (5, 5, [(5, 5, None), (9, 9, [(9, 9, None), (7, 7, [(7, 7, None), (8, 8, None)])]), (1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])])])]), (18, 18, None), (22, 22, None)])
action_state - INFO - (22, 0, 17, set())
action_state - INFO - ('complete', 17)
action_state - INFO - (22, 16, 20, set())
action_state - INFO - ('complete', 20)
action_state - INFO - ('curr_node_id', 17)
action_state - INFO - (17, [(13, 13, [(13, 13, None), (15, 15, [(15, 15, None), (16, 16, None)]), (10, 10, [(10, 10, None), (12, 12, None), (5, 5, [(5, 5, None), (9, 9, [(9, 9, None), (7, 7, [(7, 7, None), (8, 8, None)])]), (1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])])])]), (18, 18, None), (22, 22, None)], False, True)
action_state - INFO - (17, 17, [(13, 13, [(13, 13, None), (15, 15, [(15, 15, None), (16, 16, None)]), (10, 10, [(10, 10, None), (12, 12, None), (5, 5

In [668]:
mrp_json['input']

'Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov. 29.'

In [669]:
# Unpack the nine lookup structures returned by mrp_json2parser_states next to the states.
(
    node_id2node,  # node id -> node dict (indexed this way in the cells below)
    edge_id2edge,  # edge id -> edge dict with 'source', 'target' and 'label' keys
    parent_id2child_id_set,  # parent node id -> collection of child node ids
    child_id2parent_id_set,  # child node id -> set of parent node ids
    child_id2edge_id_set,  # node id -> set of ids of the edges whose target is that node
    parent_id2edge_id_set,  # presumably parent node id -> ids of its outgoing edges -- TODO confirm
    parent_id2indegree,  # parent node id -> integer count; exact meaning not visible here -- TODO confirm
    token_nodes,  # list of node dicts
    abstract_node_id_set,  # presumably ids of the nodes that are not token nodes -- TODO confirm
) = meta_data

In [643]:
child_id2parent_id_set

defaultdict(set,
            {22: {17, 20},
             8: {5, 7},
             18: {20},
             4: {1, 9, 10},
             10: {13, 17},
             16: {13, 15},
             12: {10},
             3: {1},
             9: {5}})

In [644]:
token_nodes

[{'id': 3,
  'label': 'named',
  'properties': ['carg'],
  'values': ['Pierre'],
  'anchors': [{'from': 0, 'to': 6}]},
 {'id': 4,
  'label': 'named',
  'properties': ['carg'],
  'values': ['Vinken'],
  'anchors': [{'from': 7, 'to': 14}]},
 {'id': 7,
  'label': 'card',
  'properties': ['carg'],
  'values': ['61'],
  'anchors': [{'from': 15, 'to': 17}]},
 {'id': 8, 'label': '_year_n_1', 'anchors': [{'from': 18, 'to': 23}]},
 {'id': 9, 'label': '_old_a_1', 'anchors': [{'from': 24, 'to': 28}]},
 {'id': 10, 'label': '_join_v_1', 'anchors': [{'from': 34, 'to': 38}]},
 {'id': 11, 'label': '_the_q', 'anchors': [{'from': 39, 'to': 42}]},
 {'id': 12, 'label': '_board_n_of', 'anchors': [{'from': 43, 'to': 48}]},
 {'id': 13, 'label': '_as_p', 'anchors': [{'from': 49, 'to': 51}]},
 {'id': 14, 'label': '_a_q', 'anchors': [{'from': 52, 'to': 53}]},
 {'id': 15,
  'label': '_nonexecutive_a_unknown',
  'anchors': [{'from': 54, 'to': 66}]},
 {'id': 16, 'label': '_director_n_of', 'anchors': [{'from': 67, 

In [645]:
edge_id2edge

{0: {'source': 17, 'target': 22, 'label': 'ARG2'},
 1: {'source': 5, 'target': 8, 'label': 'ARG2'},
 2: {'source': 20, 'target': 18, 'label': 'ARG2'},
 3: {'source': 10, 'target': 4, 'label': 'ARG1'},
 4: {'source': 14, 'target': 16, 'label': 'BV'},
 5: {'source': 7, 'target': 8, 'label': 'ARG1'},
 6: {'source': 0, 'target': 4, 'label': 'BV'},
 7: {'source': 9, 'target': 4, 'label': 'ARG1'},
 8: {'source': 17, 'target': 10, 'label': 'ARG1'},
 9: {'source': 19, 'target': 22, 'label': 'BV'},
 10: {'source': 21, 'target': 18, 'label': 'BV'},
 11: {'source': 6, 'target': 8, 'label': 'BV'},
 12: {'source': 13, 'target': 10, 'label': 'ARG1'},
 13: {'source': 13, 'target': 16, 'label': 'ARG2'},
 14: {'source': 10, 'target': 12, 'label': 'ARG2'},
 15: {'source': 1, 'target': 4, 'label': 'ARG1'},
 16: {'source': 20, 'target': 22, 'label': 'ARG1'},
 17: {'source': 1, 'target': 3, 'label': 'ARG2'},
 18: {'source': 5, 'target': 9, 'label': 'ARG1'},
 19: {'source': 2, 'target': 3, 'label': 'BV'},
 

In [646]:
# For every parent node (ascending id) print "(id, label)" followed by the
# "(id, label)" pairs of its children.
for parent_id, child_id_set in sorted(parent_id2child_id_set.items()):
    parent_entry = (parent_id, node_id2node[parent_id].get('label'))
    child_entries = [
        (child_id, node_id2node[child_id].get('label'))
        for child_id in child_id_set
    ]
    print(parent_entry, child_entries)

(1, 'compound') [(3, 'named'), (4, 'named')]
(5, 'measure') [(8, '_year_n_1'), (9, '_old_a_1')]
(7, 'card') [(8, '_year_n_1')]
(9, '_old_a_1') [(4, 'named')]
(10, '_join_v_1') [(4, 'named'), (12, '_board_n_of')]
(13, '_as_p') [(16, '_director_n_of'), (10, '_join_v_1')]
(15, '_nonexecutive_a_unknown') [(16, '_director_n_of')]
(17, 'loc_nonsp') [(10, '_join_v_1'), (22, 'dofm')]
(20, 'of_p') [(18, 'mofy'), (22, 'dofm')]


In [647]:
# For every child node (ascending id) print "(id, label)" followed by the
# "(id, label)" pairs of its parents.
for child_id, parent_id_set in sorted(child_id2parent_id_set.items()):
    child_entry = (child_id, node_id2node[child_id].get('label'))
    parent_entries = [
        (parent_id, node_id2node[parent_id].get('label'))
        for parent_id in parent_id_set
    ]
    print(child_entry, parent_entries)

(3, 'named') [(1, 'compound')]
(4, 'named') [(9, '_old_a_1'), (10, '_join_v_1'), (1, 'compound')]
(8, '_year_n_1') [(5, 'measure'), (7, 'card')]
(9, '_old_a_1') [(5, 'measure')]
(10, '_join_v_1') [(17, 'loc_nonsp'), (13, '_as_p')]
(12, '_board_n_of') [(10, '_join_v_1')]
(16, '_director_n_of') [(13, '_as_p'), (15, '_nonexecutive_a_unknown')]
(18, 'mofy') [(20, 'of_p')]
(22, 'dofm') [(17, 'loc_nonsp'), (20, 'of_p')]


In [648]:
# One "(id, label, indegree)" triple per entry of parent_id2indegree, ascending by id.
for parent_id, indegree in sorted(parent_id2indegree.items()):
    parent_label = node_id2node[parent_id].get('label')
    indegree_row = (parent_id, parent_label, indegree)
    print(indegree_row)

(1, 'compound', 2)
(5, 'measure', 2)
(7, 'card', 1)
(9, '_old_a_1', 1)
(10, '_join_v_1', 2)
(13, '_as_p', 2)
(15, '_nonexecutive_a_unknown', 1)
(17, 'loc_nonsp', 2)
(20, 'of_p', 2)


In [649]:
# mrp_json.get('edges')[17]

In [670]:
# Walk through the parser states, pretty-print each one, and verify that together they
# cover every node and every edge of the gold graph.
parser_node_id_set = set()
parser_edge_id_set = set()
for (node_id, actions, edge_state, abstract_node_state,
     complete_node_state, node_state, token_stack, pending_token_stack) in parser_states:
    # Record which nodes / edges this state accounts for.
    parser_node_id_set.add(node_id)
    parser_edge_id_set.update(edge_state)

    node = node_id2node[node_id]
    node_edges = [edge_id2edge[edge_id] for edge_id in edge_state]
    pprint.pprint((
        node.get('id'),
        actions,
        node.get('label'),
        [edge.get('label') for edge in node_edges],
#         abstract_node_state,
        complete_node_state,
        node_state,
        token_stack,
        pending_token_stack,
    ))

# dict.get returns None when the key is absent; `or []` keeps the checks below from
# crashing on a graph that has no 'nodes' or no 'edges' entry.
gold_node_id_set = {node.get('id', -1) for node in mrp_json.get('nodes') or []}
# Edge ids are positions in the graph's edge list.
gold_edge_id_set = set(range(len(mrp_json.get('edges') or [])))

# Compare the id sets themselves (not just their sizes) and name the offending ids, so
# that a failure says what is wrong instead of raising a bare AssertionError.
missing_node_id_set = gold_node_id_set - parser_node_id_set
print(missing_node_id_set)
assert parser_node_id_set == gold_node_id_set, (
    'parser states do not cover the graph nodes; missing: {}, unexpected: {}'.format(
        missing_node_id_set, parser_node_id_set - gold_node_id_set))

missing_edge_id_set = gold_edge_id_set - parser_edge_id_set
print(missing_edge_id_set)
assert parser_edge_id_set == gold_edge_id_set, (
    'parser states do not cover the graph edges; missing: {}, unexpected: {}'.format(
        missing_edge_id_set, parser_edge_id_set - gold_edge_id_set))

(3, [(1, None)], 'named', [], [3], [(3, 3, None)], [3], [])
(4,
 [(1, None)],
 'named',
 ['ARG1'],
 [4, 1],
 [(3, 3, None), (4, 4, None)],
 [3, 4],
 [])
(1,
 [(2, 2)],
 'compound',
 [],
 [],
 [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])],
 [1],
 [])
(7,
 [(4, None)],
 'card',
 [],
 [],
 [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)])],
 [1],
 [7])
(8,
 [(1, None)],
 '_year_n_1',
 ['ARG1'],
 [8, 7],
 [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)]), (8, 8, None)],
 [1, 8],
 [7])
(7,
 [(2, 1)],
 'card',
 [],
 [],
 [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)]),
  (7, 7, [(7, 7, None), (8, 8, None)])],
 [1, 7],
 [7])
(9,
 [(2, 1)],
 '_old_a_1',
 ['ARG1'],
 [9, 5],
 [(1, 1, [(1, 1, None), (4, 4, None), (3, 3, None)]),
  (9, 9, [(9, 9, None), (7, 7, [(7, 7, None), (8, 8, None)])])],
 [1, 9],
 [7])
(5,
 [(2, 2)],
 'measure',
 [],
 [],
 [(5,
   5,
   [(5, 5, None),
    (9, 9, [(9, 9, None), (7, 7, [(7, 7, None), (8, 8, None)])]),
    (1, 1, [(1, 1, None), (4, 4, None), (3, 

AssertionError: 

In [671]:
mrp_json

{'id': '20001001',
 'flavor': 1,
 'framework': 'eds',
 'version': 0.9,
 'time': '2019-04-10 (20:21)',
 'input': 'Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov. 29.',
 'tops': [10],
 'nodes': [{'id': 0, 'label': 'proper_q', 'anchors': [{'from': 0, 'to': 28}]},
  {'id': 1, 'label': 'compound', 'anchors': [{'from': 0, 'to': 14}]},
  {'id': 2, 'label': 'proper_q', 'anchors': [{'from': 0, 'to': 6}]},
  {'id': 3,
   'label': 'named',
   'properties': ['carg'],
   'values': ['Pierre'],
   'anchors': [{'from': 0, 'to': 6}]},
  {'id': 4,
   'label': 'named',
   'properties': ['carg'],
   'values': ['Vinken'],
   'anchors': [{'from': 7, 'to': 14}]},
  {'id': 5, 'label': 'measure', 'anchors': [{'from': 15, 'to': 23}]},
  {'id': 6, 'label': 'udef_q', 'anchors': [{'from': 15, 'to': 23}]},
  {'id': 7,
   'label': 'card',
   'properties': ['carg'],
   'values': ['61'],
   'anchors': [{'from': 15, 'to': 17}]},
  {'id': 8, 'label': '_year_n_1', 'anchors': [{'from': 18,

In [502]:
{node.get('id') for node in mrp_json.get('nodes')} - parser_node_id_set

set()

In [690]:
child_id2edge_id_set

defaultdict(set,
            {22: {0, 9, 16},
             8: {1, 5, 11},
             18: {2, 10},
             4: {3, 6, 7, 15},
             16: {4, 13, 21},
             10: {8, 12},
             12: {14, 20},
             3: {17, 19},
             9: {18},
             2: set(),
             1: set(),
             0: set(),
             7: set(),
             6: set(),
             5: set(),
             11: set(),
             13: set(),
             14: set(),
             15: set(),
             21: set(),
             17: set(),
             19: set(),
             20: set()})