In [1]:
try:
    __IPYTHON__
    USING_IPYTHON = True
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

#### Argparse

In [159]:
import argparse
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')


ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
arguments = [arg for arg_line in arg_string.split(r'\\n') for arg in arg_line.split()]

In [160]:
if USING_IPYTHON:
    args = ap.parse_args(arguments)
else:
    args = ap.parse_args()

In [161]:
args

Namespace(companion_file_extension='.conllu', companion_sub_dir='companion', graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', project_root='/data/proj29_ds1/home/slai/mrp2019', train_sub_dir='training')

#### Library imports

In [163]:
import json
import logging
import os
import pprint
import string
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
from preprocessing import CompanionParseDataset, MrpDataset, JamrAlignmentDataset
from tqdm import tqdm

#### ipython notebook specific imports

In [6]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

In [7]:
sh = logging.StreamHandler()
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
sh.setFormatter(formatter)
logging.basicConfig(level=logging.INFO, handlers=[sh])
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [75]:
UNKWOWN = 'UNKWOWN'

### Load data

In [8]:
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [9]:
mrp_dataset = MrpDataset()

In [10]:
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.74it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:02,  1.40it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:04,  1.64s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:09<00:05,  2.50s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:14<00:03,  3.37s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  36%|███▌      | 5/14 [00:00<00:00, 45.96it/s][A
dataset_name:  50%|█████     | 7/14 [00:00<00:00, 18.30it/s][A
dataset_name:  64%|██████▍   | 9/14 [00:00<00:00, 15.51it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.22it/s][A
frameworks: 100%|██████████| 5/5 [00:16<00:00,  2.83s/it]t/s][A


### Data Preprocessing companion

In [12]:
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [13]:
cparse_dataset = CompanionParseDataset()

In [14]:
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

preprocessing - INFO - framework amr found
dataset: 100%|██████████| 13/13 [00:04<00:00,  3.09it/s]
preprocessing - INFO - framework dm found
dataset: 100%|██████████| 5/5 [00:01<00:00,  4.41it/s]
preprocessing - INFO - framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 33.56it/s]


In [15]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [16]:
# Some data is missing
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [164]:
jalignment_dataset = JamrAlignmentDataset()

In [166]:
cid2alignment = jalignment_dataset.load_jamr_alignment_file(os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file
))

### Define the state at each step

In [18]:
for framework in framework2dataset2mrp_jsons:
    logger.info(framework)
    logger.info(list(framework2dataset2mrp_jsons[framework].keys()))

__main__ - INFO - ucca
__main__ - INFO - ['wiki', 'ewt']
__main__ - INFO - psd
__main__ - INFO - ['wsj']
__main__ - INFO - eds
__main__ - INFO - ['wsj']
__main__ - INFO - dm
__main__ - INFO - ['wsj']
__main__ - INFO - amr
__main__ - INFO - ['xinhua', 'wsj', 'wiki', 'wb', 'rte', 'proxy', 'mt09sdl', 'lorelei', 'fables', 'dfb', 'dfa', 'cctv', 'bolt', 'amr-guidelines']


In [255]:
framework = 'dm'
# framework = 'psd'
framework = 'eds'
# framework = 'ucca'
# framework = 'amr'

dataset = 'wsj'
# dataset = 'wsj'
# dataset = 'cctv'
mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]

In [256]:
mrp_json = mrp_jsons[0]

In [257]:
mrp_json['input']

'Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov. 29.'

In [258]:
logger.info(args.graphviz_file_template.format(
    framework, dataset, mrp_json.get('id')))

__main__ - INFO - http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/eds/wsj.mrp/20001001.png


In [259]:
nodes = mrp_json['nodes']
edges = mrp_json['edges']

By observation, 

edges of **PSD, UCCA and AMR** are from parent to leaf

edges of **EDS and DM** are from leaf to parent

In [260]:
node_id2indegree = Counter()
node_id2neig_id_set = defaultdict(set)
node_id2edge_id_set = defaultdict(set)

parent_to_leaf_framework_set = {'psd', 'ucca', 'amr'}
leaf_to_parent_framework_set = {'eds', 'dm'}

for edge_id, edge in enumerate(edges):
    if framework in parent_to_leaf_framework_set:
        source = edge.get('source')
        target = edge.get('target')
    else:
        source = edge.get('target')
        target = edge.get('source')
    node_id2neig_id_set[source].add(target)
    node_id2edge_id_set[source].add(edge_id)
    node_id2edge_id_set[target].add(edge_id)
    node_id2indegree[target] += 1

We only add the edge when both nodes are seen

In [261]:
seen_node_id_set = set()
seen_edge_id_set = set()
parser_states = []

In [262]:
for node in nodes:
    edge_state = []
    node_id = node.get('id', -1)
    seen_node_id_set.add(node_id)
    for edge_id in node_id2edge_id_set[node_id]:
        edge = edges[edge_id]
        edge_not_seen = edge_id not in seen_edge_id_set
        edge_nodes_seen = all([
            edge.get('source', -1) in seen_node_id_set,
            edge.get('target', -1) in seen_node_id_set,
        ])
        
        # add edge if edge not seen and both ends seen
        if edge_not_seen and edge_nodes_seen:
            edge_state.append(edge_id)
            seen_edge_id_set.add(edge_id)
    parser_states.append((node_id, edge_state))

In [264]:
nodes

[{'id': 0, 'label': 'proper_q', 'anchors': [{'from': 0, 'to': 28}]},
 {'id': 1, 'label': 'compound', 'anchors': [{'from': 0, 'to': 14}]},
 {'id': 2, 'label': 'proper_q', 'anchors': [{'from': 0, 'to': 6}]},
 {'id': 3,
  'label': 'named',
  'properties': ['carg'],
  'values': ['Pierre'],
  'anchors': [{'from': 0, 'to': 6}]},
 {'id': 4,
  'label': 'named',
  'properties': ['carg'],
  'values': ['Vinken'],
  'anchors': [{'from': 7, 'to': 14}]},
 {'id': 5, 'label': 'measure', 'anchors': [{'from': 15, 'to': 23}]},
 {'id': 6, 'label': 'udef_q', 'anchors': [{'from': 15, 'to': 23}]},
 {'id': 7,
  'label': 'card',
  'properties': ['carg'],
  'values': ['61'],
  'anchors': [{'from': 15, 'to': 17}]},
 {'id': 8, 'label': '_year_n_1', 'anchors': [{'from': 18, 'to': 23}]},
 {'id': 9, 'label': '_old_a_1', 'anchors': [{'from': 24, 'to': 28}]},
 {'id': 10, 'label': '_join_v_1', 'anchors': [{'from': 34, 'to': 38}]},
 {'id': 11, 'label': '_the_q', 'anchors': [{'from': 39, 'to': 42}]},
 {'id': 12, 'label':

In [263]:
for node_id, edge_state in parser_states:
    print(node_id, edge_state)

0 []
1 []
2 []
3 [17, 19]
4 [15, 6]
5 []
6 []
7 []
8 [1, 11, 5]
9 [18, 7]
10 [3]
11 []
12 [20, 14]
13 [12]
14 []
15 []
16 [21, 4, 13]
17 [8]
18 []
19 []
20 [2]
21 [10]
22 [0, 9, 16]
