In [1]:
try:
    # `__IPYTHON__` only exists inside IPython/Jupyter; elsewhere the lookup
    # raises NameError, which selects the plain-script code paths.
    __IPYTHON__
    USING_IPYTHON = True
    # Reload imported modules before executing code, so edits to local
    # modules take effect without restarting the kernel.
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

#### Argparse

In [2]:
import argparse

# Command-line interface. Inside IPython the arguments come from the
# hard-coded `arg_string` below instead of sys.argv.
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')


ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
# The three {} slots are filled with (framework, dataset, graph id).
# NOTE(review): host- and user-specific URL; consider making it configurable.
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
# Command line used when running inside IPython.
# NOTE(review): absolute, machine-specific path.
arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
# Split on any whitespace (newlines included), as a shell does for unquoted
# words. The old per-line split used the raw separator r'\\n' (a literal
# backslash pair followed by 'n'). That never occurs in the string, so it
# never split anything.
arguments = arg_string.split()

In [3]:
# Under IPython parse the hard-coded `arguments`; as a script, passing None
# makes argparse fall back to sys.argv[1:].
args = ap.parse_args(arguments if USING_IPYTHON else None)

In [4]:
# Show the parsed command-line arguments.
args

Namespace(companion_file_extension='.conllu', companion_sub_dir='companion', graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', project_root='/data/proj29_ds1/home/slai/mrp2019', train_sub_dir='training')

#### Library imports

In [533]:
import json
import logging
import os
import pprint
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
from preprocessing import (CompanionParseDataset, MrpDataset, JamrAlignmentDataset, 
                           mrp_json2parser_states)
from tqdm import tqdm

#### IPython-notebook-specific setup

In [6]:
if USING_IPYTHON:
    # matplotlib config: render figures inline in the cell output.
    %matplotlib inline

In [7]:
# Log to a stream as "<logger name> - <LEVEL> - <message>" at INFO level.
formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logging.basicConfig(handlers=[sh], level=logging.INFO)

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [8]:
# Presumably a placeholder token for unknown values (not referenced in the
# cells shown here).
# NOTE(review): "UNKWOWN" is a misspelling of "UNKNOWN". Renaming changes both
# the name and the stored string, so confirm nothing else depends on it first.
UNKWOWN = 'UNKWOWN'

### Load data

In [9]:
# <project_root>/<mrp_data_dir>/<train_sub_dir>: directory of training MRP files.
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [10]:
# Loader for MRP graph files.
mrp_dataset = MrpDataset()

In [11]:
# Load every training MRP file. `framework2dataset2mrp_jsons` maps
# framework -> dataset name -> list of graph dicts; `frameworks` presumably
# lists the framework names found.
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.81it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:02,  1.45it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:04,  1.57s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:08<00:04,  2.48s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:13<00:03,  3.25s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  36%|███▌      | 5/14 [00:00<00:00, 43.86it/s][A
dataset_name:  50%|█████     | 7/14 [00:00<00:00, 17.95it/s][A
dataset_name:  64%|██████▍   | 9/14 [00:00<00:00, 15.34it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.03it/s][A
frameworks: 100%|██████████| 5/5 [00:15<00:00,  2.77s/it]t/s][A


### Load companion parse data

In [12]:
# <project_root>/<mrp_data_dir>/<companion_sub_dir>: directory of companion parses.
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [13]:
# Loader for companion parse files.
cparse_dataset = CompanionParseDataset()

In [14]:
# Companion parses keyed by dataset name, then by sentence id.
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

preprocessing - INFO - framework amr found
dataset: 100%|██████████| 13/13 [00:03<00:00,  3.38it/s]
preprocessing - INFO - framework dm found
dataset: 100%|██████████| 5/5 [00:01<00:00,  4.33it/s]
preprocessing - INFO - framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 25.47it/s]


In [15]:
# Datasets for which companion parses were loaded.
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [16]:
# Some data is missing: the companion parses do not cover every sentence,
# e.g. WSJ sentence 20003001 has no entry.
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [17]:
# Loader for the JAMR alignments used to anchor AMR nodes to input tokens.
# (The stray `source_id` after the call made this line a SyntaxError.)
jalignment_dataset = JamrAlignmentDataset()

In [18]:
jamr_alignment_path = os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file,
)
# Alignment records keyed by sentence id.
cid2alignment = jalignment_dataset.load_jamr_alignment_file(jamr_alignment_path)

### Define the state at each step

In [19]:
# Log each framework and the dataset names loaded for it.
for framework, dataset2mrp_jsons in framework2dataset2mrp_jsons.items():
    logger.info(framework)
    logger.info(list(dataset2mrp_jsons.keys()))

__main__ - INFO - ucca
__main__ - INFO - ['wiki', 'ewt']
__main__ - INFO - psd
__main__ - INFO - ['wsj']
__main__ - INFO - eds
__main__ - INFO - ['wsj']
__main__ - INFO - dm
__main__ - INFO - ['wsj']
__main__ - INFO - amr
__main__ - INFO - ['xinhua', 'wsj', 'wiki', 'wb', 'rte', 'proxy', 'mt09sdl', 'lorelei', 'fables', 'dfb', 'dfa', 'cctv', 'bolt', 'amr-guidelines']


In [500]:
# Pick the (framework, dataset) pair to inspect. Other pairs explored in this
# notebook: ('dm', 'wsj'), ('psd', 'wsj'), ('eds', 'wsj'), ('ucca', 'wiki'),
# ('amr', 'wiki').
framework, dataset = 'amr', 'wsj'

mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]
framework, dataset

('amr', 'wsj')

In [501]:
# Inspect the second graph of the selected dataset.
mrp_json = mrp_jsons[1]

In [502]:
# Raw input sentence of the selected graph.
mrp_json['input']

'Mr. Vinken is chairman of Elsevier N.V., the Dutch publishing group.'

In [503]:
# Link to the rendered graphviz image of this graph.
graphviz_url = args.graphviz_file_template.format(
    framework, dataset, mrp_json.get('id'))
logger.info(graphviz_url)

__main__ - INFO - http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/amr/wsj.mrp/20001002.png


In [504]:
nodes, edges = mrp_json['nodes'], mrp_json['edges']

# Nodes are looked up by their 'id' (-1 when absent); edges have no id field,
# so their position in the edge list serves as the id.
node_id2node = {n.get('id', -1): n for n in nodes}
edge_id2edge = dict(enumerate(edges))

By observation, edge direction differs between frameworks:

- in **PSD, UCCA and AMR**, edges point from parent to leaf;

- in **EDS and DM**, edges point from leaf to parent.

In [505]:
# Per-node adjacency bookkeeping built from the edge list.
node_id2indegree = Counter()               # node id -> number of incoming edges
node_id2neig_id_set = defaultdict(set)     # source id -> ids of its targets
node_id2edge_id_set = defaultdict(set)     # node id -> ids of all incident edges

parent_to_leaf_framework_set = {'psd', 'ucca', 'amr'}
leaf_to_parent_framework_set = {'eds', 'dm'}

# Edge direction is used exactly as stored; the framework sets above are not
# consulted to normalise it.
for edge_id, edge in enumerate(edges):
    source_id, target_id = edge.get('source'), edge.get('target')
    node_id2neig_id_set[source_id].add(target_id)
    node_id2edge_id_set[source_id].add(edge_id)
    node_id2edge_id_set[target_id].add(edge_id)
    node_id2indegree[target_id] += 1

In [512]:
# Handle diminished nodes.
#
# Split the current graph's nodes into
#   * `anchored_nodes`: nodes tied to a position in the input, sorted by that
#     position so they can be visited left to right, and
#   * `diminished_node_id_set`: ids of the remaining nodes, which are
#     introduced later from an already-visited neighbour.
# What counts as "anchored" depends on the framework.


def _anchor_span_sort_key(node):
    """Sort key over a node's character anchors.

    Orders by the end of the anchored span, then by its start descending, so
    that of two spans ending at the same offset the longer one comes first.
    A node with several anchors (e.g. the UCCA unit "Not long after") is
    ordered by the overall span covering all of them; for a single anchor
    this is exactly (to, -from).
    """
    anchors = node['anchors']
    return (
        max(anchor['to'] for anchor in anchors),
        -min(anchor['from'] for anchor in anchors),
    )


# Reset on every run so results never leak from a previously inspected graph.
diminished_node_id_set = set()
anchored_nodes = []

if framework == 'eds':
    # A node is diminished when nothing points at it and every edge touching
    # it is labelled 'BV'.
    # NOTE(review): a node with no edges at all also qualifies (all() over an
    # empty sequence is True) -- confirm that is intended.
    for node in nodes:
        node_id = node.get('id', -1)
        indegree = node_id2indegree[node_id]
        if not indegree and all(
            edges[edge_id].get('label') == 'BV'
            for edge_id in node_id2edge_id_set[node_id]
        ):
            diminished_node_id_set.add(node_id)
        else:
            anchored_nodes.append(node)

    # Every EDS node kept here is expected to carry exactly one span.
    assert all(len(node.get('anchors', [])) == 1 for node in anchored_nodes)
    assert all('from' in node.get('anchors')[0] for node in anchored_nodes)
    assert all('to' in node.get('anchors')[0] for node in anchored_nodes)

    anchored_nodes.sort(key=_anchor_span_sort_key)

elif framework == 'ucca':
    # Nodes without anchors are diminished.
    for node in nodes:
        if 'anchors' in node:
            anchored_nodes.append(node)
        else:
            diminished_node_id_set.add(node.get('id', -1))

    # UCCA nodes can carry several anchors (e.g. one per token of
    # "Not long after"), so require at least one well-formed anchor rather
    # than exactly one.
    assert all(len(node.get('anchors', [])) >= 1 for node in anchored_nodes)
    assert all(
        'from' in anchor and 'to' in anchor
        for node in anchored_nodes
        for anchor in node['anchors']
    )

    anchored_nodes.sort(key=_anchor_span_sort_key)

elif framework == 'amr':
    # AMR graphs carry no anchors; a node counts as anchored when the JAMR
    # alignment for this sentence lists it. Each alignment node's 'label' is
    # its list of aligned token positions.
    # TODO(Sunny): the jamr alignment sometime miss tokens
    cid = mrp_json.get('id', '')
    alignment = cid2alignment[cid]  # KeyError if the sentence has no alignment
    node_id2token_poss = {
        node.get('id', -1): node.get('label', [])
        for node in alignment.get('nodes', [])
    }

    for node in nodes:
        node_id = node.get('id', -1)
        if node_id in node_id2token_poss:
            anchored_nodes.append(node)
        else:
            diminished_node_id_set.add(node_id)

    # The sort key below calls max()/min(), so every position list must be
    # non-empty.
    assert all(
        len(node_id2token_poss.get(node.get('id', -1), [])) >= 1
        for node in anchored_nodes
    )

    anchored_nodes.sort(key=lambda node: (
        max(node_id2token_poss[node.get('id', -1)]),
        -min(node_id2token_poss[node.get('id', -1)]),
    ))

else:
    # e.g. 'dm' / 'psd': fail loudly instead of letting later cells run on an
    # empty or outdated split.
    raise NotImplementedError(
        'diminished-node handling is not implemented for framework {!r}'.format(framework))

In [516]:
# JAMR alignment record of the current sentence (set in the AMR branch above).
alignment

{'id': '20001002',
 'flavor': 2,
 'framework': 'alignment',
 'version': 1.0,
 'time': '2019-06-24',
 'nodes': [{'id': 0, 'label': [3]},
  {'id': 1, 'label': [0, 1]},
  {'id': 2,
   'label': [0, 1],
   'properties': ['op1', 'op2'],
   'values': [[0, 1], [0, 1]]},
  {'id': 3, 'label': [11]},
  {'id': 4,
   'label': [5, 6],
   'properties': ['op1', 'op2'],
   'values': [[5, 6], [5, 6]]},
  {'id': 7, 'label': [10]},
  {'id': 8, 'label': [3]}]}

In [517]:
# Aligned token positions per AMR node id.
node_id2token_poss

{0: [3], 1: [0, 1], 2: [0, 1], 3: [11], 4: [5, 6], 7: [10], 8: [3]}

An edge is added to a parser state only once both of its end nodes have been seen.

In [518]:
# Whitespace tokens of the input, for comparison with the alignment positions.
# NOTE(review): the alignment refers to position 11, but this split yields only
# 11 tokens (0-10), so JAMR presumably tokenises differently (e.g. splitting
# off punctuation) -- confirm before indexing these tokens with it.
mrp_json['input'].split()

['Mr.',
 'Vinken',
 'is',
 'chairman',
 'of',
 'Elsevier',
 'N.V.,',
 'the',
 'Dutch',
 'publishing',
 'group.']

In [519]:
# Edges of the current graph; an edge's id is its index in this list.
edges

[{'source': 5, 'target': 6, 'label': 'name'},
 {'source': 3, 'target': 7, 'label': 'ARG0-of', 'normal': 'ARG0'},
 {'source': 0, 'target': 1, 'label': 'ARG0'},
 {'source': 3, 'target': 4, 'label': 'name'},
 {'source': 0, 'target': 3, 'label': 'ARG1'},
 {'source': 3, 'target': 5, 'label': 'mod', 'normal': 'domain'},
 {'source': 0, 'target': 8, 'label': 'ARG2'},
 {'source': 1, 'target': 2, 'label': 'name'}]

In [520]:
# JAMR alignment record of the current sentence.
alignment

{'id': '20001002',
 'flavor': 2,
 'framework': 'alignment',
 'version': 1.0,
 'time': '2019-06-24',
 'nodes': [{'id': 0, 'label': [3]},
  {'id': 1, 'label': [0, 1]},
  {'id': 2,
   'label': [0, 1],
   'properties': ['op1', 'op2'],
   'values': [[0, 1], [0, 1]]},
  {'id': 3, 'label': [11]},
  {'id': 4,
   'label': [5, 6],
   'properties': ['op1', 'op2'],
   'values': [[5, 6], [5, 6]]},
  {'id': 7, 'label': [10]},
  {'id': 8, 'label': [3]}]}

In [521]:
# Ids of nodes that have no anchor of their own.
diminished_node_id_set

{5, 6}

In [522]:
# Fresh bookkeeping for the parser-state walk: which node / edge ids have
# already been emitted, and the states recorded so far.
seen_node_id_set, seen_edge_id_set = set(), set()
parser_states = []

In [523]:
# Anchored nodes in sorted order; diminished nodes are pushed onto the front
# during the walk.
node_queue = deque(anchored_nodes)

In [524]:
# Visit nodes in queue order and record one parser state per visited node:
# (node id, ids of edges emitted at this step, ids of diminished nodes pulled
# in at this step).
# NOTE: this mutates seen_node_id_set / seen_edge_id_set / parser_states, so
# re-run the initialisation cell before re-running this one.
while node_queue:
    node = node_queue.popleft()
    edge_state = []
    diminished_node_state = []
    node_id = node.get('id', -1)
    seen_node_id_set.add(node_id)
    
    for edge_id in node_id2edge_id_set[node_id]:
        edge = edges[edge_id]
        
        source_id = edge.get('source', -1)
        target_id = edge.get('target', -1)
        
        # Edges whose 'properties' list contains 'remote' (presumably UCCA
        # remote edges) never pull in diminished nodes.
        is_remote_edge = 'properties' in edge and 'remote' in edge['properties']
        
        # Handle diminished nodes if not remote edge
        # A non-diminished node pulls each unseen diminished neighbour in at
        # once: appendleft makes it the next node visited (the most recently
        # pulled neighbour first).
        # NOTE(review): a diminished node never pulls in further diminished
        # neighbours (guard below), so a node reachable only through another
        # diminished node is never visited -- e.g. AMR node 6 ('name'), linked
        # only to diminished node 5, which makes the coverage check fail.
        # TODO(Sunny): This cannot handle the case of a node having
        #                All childs as diminished_node
        if not is_remote_edge and node_id not in diminished_node_id_set:
            if source_id in diminished_node_id_set and source_id not in seen_node_id_set:
                diminished_node_state.append(source_id)
                seen_node_id_set.add(source_id)
                node_queue.appendleft(node_id2node[source_id])
            if target_id in diminished_node_id_set and target_id not in seen_node_id_set:
                diminished_node_state.append(target_id)
                seen_node_id_set.add(target_id)
                node_queue.appendleft(node_id2node[target_id])

        edge_not_seen = edge_id not in seen_edge_id_set
        edge_nodes_seen = all([
            source_id in seen_node_id_set,
            target_id in seen_node_id_set,
        ])
        
        # add edge if edge not seen and both ends seen
        if edge_not_seen and edge_nodes_seen:
            edge_state.append(edge_id)
            seen_edge_id_set.add(edge_id)
    parser_states.append((node_id, edge_state, diminished_node_state))

In [525]:
# Raw input sentence of the current graph.
mrp_json['input']

'Mr. Vinken is chairman of Elsevier N.V., the Dutch publishing group.'

In [530]:
# JAMR alignment record of the current sentence.
alignment

{'id': '20001002',
 'flavor': 2,
 'framework': 'alignment',
 'version': 1.0,
 'time': '2019-06-24',
 'nodes': [{'id': 0, 'label': [3]},
  {'id': 1, 'label': [0, 1]},
  {'id': 2,
   'label': [0, 1],
   'properties': ['op1', 'op2'],
   'values': [[0, 1], [0, 1]]},
  {'id': 3, 'label': [11]},
  {'id': 4,
   'label': [5, 6],
   'properties': ['op1', 'op2'],
   'values': [[5, 6], [5, 6]]},
  {'id': 7, 'label': [10]},
  {'id': 8, 'label': [3]}]}

In [529]:
# Nodes of the current graph.
nodes

[{'id': 0, 'label': 'have-org-role-91'},
 {'id': 1, 'label': 'person'},
 {'id': 2,
  'label': 'name',
  'properties': ['op1', 'op2'],
  'values': ['Mr.', 'Vinken']},
 {'id': 3, 'label': 'group'},
 {'id': 4,
  'label': 'name',
  'properties': ['op1', 'op2'],
  'values': ['Elsevier', 'N.V.']},
 {'id': 5, 'label': 'country'},
 {'id': 6, 'label': 'name', 'properties': ['op1'], 'values': ['Netherlands']},
 {'id': 7, 'label': 'publish-01'},
 {'id': 8, 'label': 'chairman'}]

In [526]:
# Replay the parser states, printing one line per step:
#   node id, node label, labels of the edges emitted, diminished nodes pulled in.
# Then check that every node and every edge of the graph was emitted.
parser_node_id_set = set()
parser_edge_id_set = set()
for node_id, edge_state, diminished_node_state in parser_states:
    parser_node_id_set.add(node_id)
    parser_edge_id_set.update(edge_state)

    node = node_id2node[node_id]
    node_edges = [edge_id2edge[edge_id] for edge_id in edge_state]
    print(
        node.get('id'),
        node.get('label'),
        [edge.get('label') for edge in node_edges],
        diminished_node_state,
    )

# Name the missing ids in the failure message, so a failure can be diagnosed
# without a separate cell.
assert len(parser_node_id_set) == len(nodes), (
    'node ids never emitted: {}'.format(set(node_id2node) - parser_node_id_set))
assert len(parser_edge_id_set) == len(edges), (
    'edge ids never emitted: {}'.format(set(edge_id2edge) - parser_edge_id_set))

1 person [] []
2 name ['name'] []
0 have-org-role-91 ['ARG0'] []
8 chairman ['ARG2'] []
4 name [] []
7 publish-01 [] []
3 group ['ARG0-of', 'name', 'ARG1', 'mod'] [5]
5 country [] []


AssertionError: 

### Test module

In [591]:
from preprocessing import mrp_json2parser_states

In [592]:
# Pick the (framework, dataset) pair to test the module on. Other pairs:
# ('dm', 'wsj'), ('psd', 'wsj'), ('eds', 'wsj'), ('amr', 'wsj'), ('amr', 'wiki').
framework, dataset = 'ucca', 'wiki'

mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]
framework, dataset

('ucca', 'wiki')

In [593]:
# Inspect the last graph of the selected dataset.
mrp_json = mrp_jsons[-1]

In [594]:
# Only AMR needs the JAMR alignment; other frameworks get an empty dict.
alignment = {}
if framework == 'amr':
    cid = mrp_json.get('id', '')
    alignment = cid2alignment[cid]

In [595]:
# Full MRP record of the selected graph.
mrp_json

{'id': '496020',
 'flavor': 1,
 'framework': 'ucca',
 'version': 0.9,
 'time': '2019-04-11 (22:05)',
 'input': 'Not long after, in a move that resulted in years of litigation, at the conclusion of which Bowie was forced to pay Pitt compensation, the singer fired his manager, replacing him with Tony Defries.',
 'tops': [36],
 'nodes': [{'id': 0,
   'anchors': [{'from': 0, 'to': 3},
    {'from': 4, 'to': 8},
    {'from': 9, 'to': 14}],
   'label': 'Notlongafter'},
  {'id': 1, 'anchors': [{'from': 14, 'to': 15}], 'label': ','},
  {'id': 2, 'anchors': [{'from': 16, 'to': 18}], 'label': 'in'},
  {'id': 3, 'anchors': [{'from': 19, 'to': 20}], 'label': 'a'},
  {'id': 4, 'anchors': [{'from': 21, 'to': 25}], 'label': 'move'},
  {'id': 5, 'anchors': [{'from': 26, 'to': 30}], 'label': 'that'},
  {'id': 6, 'anchors': [{'from': 31, 'to': 39}], 'label': 'resulted'},
  {'id': 7, 'anchors': [{'from': 40, 'to': 42}], 'label': 'in'},
  {'id': 8, 'anchors': [{'from': 43, 'to': 48}], 'label': 'years'},
  

In [596]:
# Link to the rendered graphviz image of this graph.
graphviz_url = args.graphviz_file_template.format(
    framework, dataset, mrp_json.get('id'))
logger.info(graphviz_url)

__main__ - INFO - http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/ucca/wiki.mrp/496020.png


In [597]:
# Same computation via the preprocessing module; it also returns lookup tables.
parser_states, meta_data = mrp_json2parser_states(mrp_json, framework, alignment)

In [598]:
# meta_data holds the node-id and edge-id lookup tables for this graph.
(node_id2node, edge_id2edge) = meta_data

In [610]:
# Node ids of the current graph that no parser state emits. The emitted ids
# are taken straight from `parser_states`; `parser_node_id_set` is only rebuilt
# by a later cell and would hold another graph's ids on a fresh Run All.
{node.get('id') for node in mrp_json.get('nodes')} - {state[0] for state in parser_states}

{44}

In [606]:
# Node ids collected by the coverage-check cell below.
# NOTE(review): this depends on a cell further down (execution counts are out
# of order), and the saved output (49) looks like it came from len(...) --
# re-run to refresh.
parser_node_id_set

49

In [603]:
# Replay the module's parser states, printing one line per step:
#   node id, node label, labels of the edges emitted, diminished nodes pulled in.
# Then check that every node and every edge of the graph was emitted.
parser_node_id_set = set()
parser_edge_id_set = set()
for node_id, edge_state, diminished_node_state in parser_states:
    parser_node_id_set.add(node_id)
    parser_edge_id_set.update(edge_state)

    node = node_id2node[node_id]
    node_edges = [edge_id2edge[edge_id] for edge_id in edge_state]
    print(
        node.get('id'),
        node.get('label'),
        [edge.get('label') for edge in node_edges],
        diminished_node_state,
    )

# Name the missing ids in the failure message.
assert len(parser_node_id_set) == len(mrp_json.get('nodes')), (
    'node ids never emitted: {}'.format(
        {node.get('id') for node in mrp_json.get('nodes')} - parser_node_id_set))
assert len(parser_edge_id_set) == len(mrp_json.get('edges')), (
    'edge ids never emitted: {}'.format(set(edge_id2edge) - parser_edge_id_set))

0 Notlongafter ['L'] [36]
36 None [] []
1 , ['U'] []
2 in ['L'] []
3 a ['E'] [37]
37 None [] []
4 move ['C'] []
5 that ['F'] [38]
38 None ['A', 'H'] []
6 resulted ['P'] []
7 in ['R'] [39]
39 None ['A'] []
8 years ['C'] [40]
40 None [] []
9 of ['R'] []
10 litigation ['P'] [41]
41 None ['H', 'T'] []
11 , ['U'] []
12 at ['R'] [42]
42 None ['L'] []
13 the ['E'] []
14 conclusion ['C'] []
15 of ['F'] []
16 which ['F'] []
17 Bowie ['A'] [43]
43 None ['H'] []
18 was ['F'] []
19 forced ['D'] []
20 to ['F'] []
21 pay ['P'] []
22 Pitt ['A'] []
23 compensation ['A'] []
24 , ['U'] []
25 the ['E'] [45]
45 None [] []
26 singer ['C'] []
27 fired ['P'] [46]
46 None ['H'] []
28 his ['A'] [47]
47 None ['A'] []
29 manager ['S', 'A'] []
30 , ['U'] []
31 replacing ['P'] [48]
48 None ['H'] []
32 him ['A'] []
33 with ['R'] [49]
49 None ['A'] []
34 TonyDefries ['C'] []
35 . ['U'] []


AssertionError: 

In [552]:
# Pick the (framework, dataset) pair to test the module on. Other pairs:
# ('dm', 'wsj'), ('psd', 'wsj'), ('eds', 'wsj'), ('ucca', 'wiki'), ('amr', 'wiki').
framework, dataset = 'amr', 'wsj'

mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]
framework, dataset

('amr', 'wsj')

In [553]:
# Inspect the second graph of the selected dataset.
mrp_json = mrp_jsons[1]

In [554]:
# Only AMR needs the JAMR alignment; other frameworks get an empty dict.
alignment = {}
if framework == 'amr':
    cid = mrp_json.get('id', '')
    alignment = cid2alignment[cid]

In [555]:
# mrp_json2parser_states returns (parser_states, meta_data), with meta_data =
# (node_id2node, edge_id2edge) for this graph. Unpack both, so the check below
# uses this graph's lookup tables instead of ones left over from another graph.
parser_states, meta_data = mrp_json2parser_states(mrp_json, framework, alignment)
node_id2node, edge_id2edge = meta_data

In [556]:
# One (node id, emitted edge ids, pulled-in diminished node ids) per step.
parser_states

[(1, [], []),
 (2, [7], []),
 (0, [2], []),
 (8, [6], []),
 (4, [], []),
 (7, [], []),
 (3, [1, 3, 4, 5], [5]),
 (5, [], [])]

In [557]:
# Replay the module's parser states, printing one line per step:
#   node id, node label, labels of the edges emitted, diminished nodes pulled in.
# Then check that every node and every edge of the graph was emitted.
parser_node_id_set = set()
parser_edge_id_set = set()
for node_id, edge_state, diminished_node_state in parser_states:
    parser_node_id_set.add(node_id)
    parser_edge_id_set.update(edge_state)

    node = node_id2node[node_id]
    node_edges = [edge_id2edge[edge_id] for edge_id in edge_state]
    print(
        node.get('id'),
        node.get('label'),
        [edge.get('label') for edge in node_edges],
        diminished_node_state,
    )

# Compare against the current mrp_json, not the `nodes` / `edges` globals,
# which belong to whichever graph an earlier cell unpacked. Name the missing
# ids in the failure message.
assert len(parser_node_id_set) == len(mrp_json.get('nodes')), (
    'node ids never emitted: {}'.format(
        {node.get('id') for node in mrp_json.get('nodes')} - parser_node_id_set))
assert len(parser_edge_id_set) == len(mrp_json.get('edges')), (
    'edge ids never emitted: {}'.format(set(edge_id2edge) - parser_edge_id_set))

1 person [] []
2 name ['name'] []
0 have-org-role-91 ['ARG0'] []
8 chairman ['ARG2'] []
4 name [] []
7 publish-01 [] []
3 group ['ARG0-of', 'name', 'ARG1', 'mod'] [5]
5 country [] []


AssertionError: 

### Problems to be solved:

1. Missing alignment ids in the JAMR alignment tool output
  - e.g. some node values have no anchor
  
  
2. Order of producing higher-level nodes
  - i.e. when should a higher-level node be produced:
    - when **one** of its children has been reached, or
    - when **all** of its children have been reached (in this case, how do we train the parser to decide when to finish?)