In [1]:
try:
    __IPYTHON__
    USING_IPYTHON = True
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

#### Argparse

In [61]:
import argparse
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--mrp-test-dir', default='src/tests', help='')
ap.add_argument('--tests-fixtures-file', default='fixtures/test.jsonl', help='')

ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')

ap.add_argument('--test-input-file', default='evaluation/input.mrp', help='')
ap.add_argument('--test-companion-file', default='evaluation/udpipe.mrp', help='')
ap.add_argument('--allennlp-mrp-json-file-template', default='allennlp-mrp-json-small-{}.jsonl', help='')
ap.add_argument('--data-size-limit', type=int, default=100, help='')

ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
ap.add_argument('--parse-plot-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png')

arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
arguments = [arg for arg_line in arg_string.split(r'\\n') for arg in arg_line.split()]

In [62]:
if USING_IPYTHON:
    args = ap.parse_args(arguments)
else:
    args = ap.parse_args()

In [63]:
args

Namespace(allennlp_mrp_json_file_template='allennlp-mrp-json-small-{}.jsonl', companion_file_extension='.conllu', companion_sub_dir='companion', data_size_limit=100, graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', mrp_test_dir='src/tests', parse_plot_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png', project_root='/data/proj29_ds1/home/slai/mrp2019', test_companion_file='evaluation/udpipe.mrp', test_input_file='evaluation/input.mrp', tests_fixtures_file='fixtures/test.jsonl', train_sub_dir='training')

#### Library imports

In [64]:
import json
import logging
import os
import pprint
import re
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
import torch
from action_state import mrp_json2parser_states, _generate_parser_action_states
from action_state import ERROR, APPEND, RESOLVE, IGNORE
from preprocessing import (CompanionParseDataset, MrpDataset, JamrAlignmentDataset,
                           read_companion_parse_json_file, read_mrp_json_file, parse2parse_json)            
from torch import nn
from tqdm import tqdm

#### ipython notebook specific imports

In [65]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

DEBUG    [matplotlib.pyplot:219] Loaded backend module://ipykernel.pylab.backend_inline version unknown.


In [66]:
sh = logging.StreamHandler()
formatter = logging.Formatter('%(levelname)-8s [%(name)s:%(lineno)d] %(message)s')
sh.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG, handlers=[sh])
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [67]:
UNKWOWN = 'UNKWOWN'

### Load data

In [9]:
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [10]:
mrp_dataset = MrpDataset()

In [11]:
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.65it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:02,  1.47it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:04,  1.50s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:08<00:04,  2.35s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:13<00:03,  3.03s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  43%|████▎     | 6/14 [00:00<00:00, 22.32it/s][A
dataset_name:  57%|█████▋    | 8/14 [00:00<00:00, 17.07it/s][A
dataset_name:  71%|███████▏  | 10/14 [00:01<00:00,  6.51it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.94it/s][A
frameworks: 100%|██████████| 5/5 [00:14<00:00,  2.56s/it]t/s][A


### Data Preprocessing companion

In [12]:
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [13]:
cparse_dataset = CompanionParseDataset()

In [14]:
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

INFO     [preprocessing:172] framework amr found
dataset: 100%|██████████| 13/13 [00:01<00:00, 10.53it/s]
INFO     [preprocessing:172] framework dm found
dataset: 100%|██████████| 5/5 [00:03<00:00,  1.15it/s]
INFO     [preprocessing:172] framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 31.03it/s]


In [15]:
dataset2cid2parse_json = cparse_dataset.convert_parse2parse_json()

In [16]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [17]:
# Some data is missing
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [18]:
jalignment_dataset = JamrAlignmentDataset()

In [19]:
cid2alignment = jalignment_dataset.load_jamr_alignment_file(os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file
))

### Load testing data

In [20]:
test_input_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_input_file)
test_companion_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_companion_file)

In [21]:
test_mrp_jsons = read_mrp_json_file(test_input_filename)
test_parse_jsons = read_companion_parse_json_file(test_companion_filename)

In [22]:
parse_json = test_parse_jsons['102990']

In [111]:
mrp_json = framework2dataset2mrp_jsons['psd']['wsj'][1]

In [114]:
framework = 'ucca'
dataset = 'wiki'

# framework = 'dm'
# dataset = 'wsj'

In [197]:
mrp_json = framework2dataset2mrp_jsons['ucca']['wiki'][6]

In [198]:
mrp_json

{'id': '495000',
 'flavor': 1,
 'framework': 'ucca',
 'version': 0.9,
 'time': '2019-04-11 (22:04)',
 'input': 'Conn quickly began to promote Bowie.',
 'tops': [7],
 'nodes': [{'id': 0, 'anchors': [{'from': 0, 'to': 4}], 'label': 'Conn'},
  {'id': 1, 'anchors': [{'from': 5, 'to': 12}], 'label': 'quickly'},
  {'id': 2, 'anchors': [{'from': 13, 'to': 18}], 'label': 'began'},
  {'id': 3, 'anchors': [{'from': 19, 'to': 21}], 'label': 'to'},
  {'id': 4, 'anchors': [{'from': 22, 'to': 29}], 'label': 'promote'},
  {'id': 5, 'anchors': [{'from': 30, 'to': 35}], 'label': 'Bowie'},
  {'id': 6, 'anchors': [{'from': 35, 'to': 36}], 'label': '.'},
  {'id': 7},
  {'id': 8}],
 'edges': [{'source': 8, 'target': 4, 'label': 'P', 'id': 0},
  {'source': 8, 'target': 1, 'label': 'D', 'id': 1},
  {'source': 8, 'target': 5, 'label': 'A', 'id': 2},
  {'source': 8, 'target': 3, 'label': 'F', 'id': 3},
  {'source': 8, 'target': 2, 'label': 'D', 'id': 4},
  {'source': 8, 'target': 0, 'label': 'A', 'id': 5},
  {

In [206]:
cid = list(dataset2cid2parse_json[dataset].keys())[5]

In [207]:
idx, mrp_json = [
    (idx, mrp_json)
    for idx, mrp_json in enumerate(framework2dataset2mrp_jsons[framework][dataset])
    if mrp_json.get('id') == cid
][0]
idx

1429

In [208]:
parse_json = dataset2cid2parse_json[dataset][cid]

In [209]:
doc = mrp_json['input']

In [210]:
doc

'When Bowie left the technical school the following year, he informed his parents of his intention to become a pop star.'

In [211]:
token_pos = 0
anchors = []
char_pos2tokenized_parse_node_id = []

for node_id, node in enumerate(parse_json.get('nodes')):
    label = node.get('label')
    label_size = len(label)
    while doc[token_pos] == ' ':
        token_pos += 1
        char_pos2tokenized_parse_node_id.append(node_id)
    anchors.append((token_pos, token_pos + label_size))
    char_pos2tokenized_parse_node_id.extend([node_id] * (label_size))
    print(node_id, doc[token_pos: token_pos + label_size], anchors[-1], len(char_pos2tokenized_parse_node_id))
    token_pos += label_size

0 When (0, 4) 4
1 Bowie (5, 10) 10
2 left (11, 15) 15
3 the (16, 19) 19
4 technical (20, 29) 29
5 school (30, 36) 36
6 the (37, 40) 40
7 following (41, 50) 50
8 year (51, 55) 55
9 , (55, 56) 56
10 he (57, 59) 59
11 informed (60, 68) 68
12 his (69, 72) 72
13 parents (73, 80) 80
14 of (81, 83) 83
15 his (84, 87) 87
16 intention (88, 97) 97
17 to (98, 100) 100
18 become (101, 107) 107
19 a (108, 109) 109
20 pop (110, 113) 113
21 star (114, 118) 118
22 . (118, 119) 119


In [212]:
doc

'When Bowie left the technical school the following year, he informed his parents of his intention to become a pop star.'

In [213]:
len(char_pos2tokenized_parse_node_id)

119

In [214]:
doc = mrp_json['input']

In [215]:
mrp_json['tops']

[23]

In [216]:
mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
    mrp_json, 
    tokenized_parse_nodes=parse_json['nodes'],
)

DEBUG    [action_state:60] ('remote 1', 17)
DEBUG    [action_state:219] {23}
DEBUG    [action_state:219] {0, 9, 27, 24}
DEBUG    [action_state:219] {1, 2, 10, 11, 25, 26, 28, 29}
DEBUG    [action_state:219] {3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 30}
DEBUG    [action_state:219] {17, 18, 31}
DEBUG    [action_state:219] {19, 20, 21, 22}
DEBUG    [action_state:85] ('remote 2', 17)
DEBUG    [action_state:429] ('prev anchors', 0)
DEBUG    [action_state:442] ('anchors', 0, 4, 0, 1)
DEBUG    [action_state:448] ('curr_node_id', 0)
DEBUG    [action_state:471] (0, [], True, True, True, True)
DEBUG    [action_state:495] ('node_state', 0, [(0, 0, None)])
DEBUG    [action_state:506] ('stack_position', 0, 0, 0, None, 0)
DEBUG    [action_state:524] (0, 0, [(0, 0, [(0, 0, None)])])
DEBUG    [action_state:525] [(0, None), (1, (1, 0, {'id': 0, 'anchors': [{'from': 0, 'to': 4}], 'label': 'When'}, [[]]))]
DEBUG    [action_state:538] (0, 18, 23, {9, 27, 24})
DEBUG    [action_state:607] ('token stack', [(0, 

DEBUG    [action_state:495] ('node_state', 6, [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)]), (2, 2, [(2, 2, None)]), (25, 25, [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]), (6, 6, None)])
DEBUG    [action_state:506] ('stack_position', 0, 6, 6, None, 0)
DEBUG    [action_state:524] (6,
 6,
 [(0, 0, [(0, 0, None)]),
  (1, 1, [(1, 1, None)]),
  (2, 2, [(2, 2, None)]),
  (25,
   25,
   [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
  (6, 6, [(6, 6, None)])])
DEBUG    [action_state:525] [(0, None), (1, (1, 0, {'id': 6, 'anchors': [{'from': 37, 'to': 40}], 'label': 'the'}, [[]]))]
DEBUG    [action_state:538] (6, 13, 26, {8, 7})
DEBUG    [action_state:607] ('token stack',
 [(0, [(0, 'When', 'When')]),
  (1, [(1, 'Bowie', 'Bowie')]),
  (2, [(2, 'left', 'left')]),
  (25,
   [(3, [(3, 'the', 'the')]),
    (4, [(4, 'technical', 'technical')]),
    (5, [(5, 'school', 'school')])]),
  (6, [(6, 'the', 'the')])])
DEBUG    [action_state:609] ('v

DEBUG    [action_state:538] (24, 23, 23, {9, 27})
DEBUG    [action_state:607] ('token stack',
 [(0, [(0, 'When', 'When')]),
  (24,
   [(1, [(1, 'Bowie', 'Bowie')]),
    (2, [(2, 'left', 'left')]),
    (25,
     [(3, [(3, 'the', 'the')]),
      (4, [(4, 'technical', 'technical')]),
      (5, [(5, 'school', 'school')])]),
    (26,
     [(6, [(6, 'the', 'the')]),
      (7, [(7, 'following', 'following')]),
      (8, [(8, 'year', 'year')])])])])
DEBUG    [action_state:609] ('visited states', {0, 1, 2, 3, 4, 5, 6, 7, 8, 24, 25, 26}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 24, 25, 26}, {24, 25, 26}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 24, 25, 26})
DEBUG    [action_state:429] ('prev anchors', 9)
DEBUG    [action_state:442] ('anchors', 55, 56, 9, 10)
DEBUG    [action_state:448] ('curr_node_id', 9)
DEBUG    [action_state:471] (9,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None

DEBUG    [action_state:506] ('stack_position', 0, 12, 12, None, 0)
DEBUG    [action_state:524] (12,
 12,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (12, 12, [(12, 12, None)])])
DEBUG    [action_state:525] [(0, None), (1, (1, 0, {'id': 12, 'anchors': [{'from': 69, 'to': 72}], 'label': 'his'}, [[]]))]
DEBUG    [action_state:538] (12, 19, 28, {13})
DEBUG    [action_state:607] ('token stack',
 [(0, [(0, 'When', 'When')]),
  (24,
   [(1, [(1, 'Bowie', 'Bowie')]),
    (2, [(2, 'left', 'left')]),
    (25,
     [(3, [(3, 'the', 'the')]),
      (4, [(4, 'technical', 'technical')]),
      (5, [(5, 'school', 'school')])]),
    (26,
     [(6, [(6, 'the'

DEBUG    [action_state:609] ('visited states', {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 24, 25, 26, 28}, {24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 24, 25, 26, 28})
DEBUG    [action_state:429] ('prev anchors', 14)
DEBUG    [action_state:442] ('anchors', 81, 83, 14, 15)
DEBUG    [action_state:448] ('curr_node_id', 14)
DEBUG    [action_state:471] (14,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])])],
 True,
 True,
 True,
 True)
DEBUG    [action_state:495] ('node_state', 14, [(0, 0, [(0, 0, No

DEBUG    [action_state:609] ('visited states', {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 25, 26, 28}, {24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 24, 25, 26, 28})
DEBUG    [action_state:429] ('prev anchors', 17)
DEBUG    [action_state:442] ('anchors', 98, 100, 17, 18)
DEBUG    [action_state:448] ('curr_node_id', 17)
DEBUG    [action_state:471] (17,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])]),
  (14, 14, [(14, 14, None)]),
  (15, 15, [(15, 15, None)]

DEBUG    [action_state:506] ('stack_position', 0, 19, 19, None, 0)
DEBUG    [action_state:524] (19,
 19,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])]),
  (14, 14, [(14, 14, None)]),
  (15, 15, [(15, 15, None)]),
  (16, 16, [(16, 16, None)]),
  (17, 17, [(17, 17, None)]),
  (18, 18, [(18, 18, None)]),
  (19, 19, [(19, 19, None)])])
DEBUG    [action_state:525] [(0, None), (1, (1, 0, {'id': 19, 'anchors': [{'from': 108, 'to': 109}], 'label': 'a'}, [[]]))]
DEBUG    [action_state:538] (19, 10, 31, {20, 21, 22})
DEBUG    [action_state:607] ('token stack',
 [(0, [(0, 'When', 'When')]),

DEBUG    [action_state:609] ('visited states', {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 24, 25, 26, 28}, {24, 25, 26, 28}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 24, 25, 26, 28})
DEBUG    [action_state:429] ('prev anchors', 22)
DEBUG    [action_state:442] ('anchors', 118, 119, 22, 23)
DEBUG    [action_state:448] ('curr_node_id', 22)
DEBUG    [action_state:471] (22,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])

DEBUG    [action_state:448] ('curr_node_id', 30)
DEBUG    [action_state:471] (30,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])]),
  (14, 14, [(14, 14, None)]),
  (15, 15, [(15, 15, None)]),
  (16, 16, [(16, 16, None)]),
  (17, 17, [(17, 17, None)]),
  (18, 18, [(18, 18, None)]),
  (31,
   31,
   [(19, 19, [(19, 19, None)]),
    (20, 20, [(20, 20, None)]),
    (21, 21, [(21, 21, None)]),
    (22, 22, [(22, 22, None)])])],
 False,
 True,
 False,
 True)
DEBUG    [action_state:495] ('node_state', 30, [(0, 0, [(0, 0, None)]), (24, 24, [(1, 1, [(1, 1, None)]), (2, 2, [(2, 2, None)]), (

DEBUG    [action_state:609] ('visited states', {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 28, 29, 30, 31}, {24, 25, 26, 27, 28, 29, 30, 31}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 28, 29, 30, 31})
DEBUG    [action_state:448] ('curr_node_id', 27)
DEBUG    [action_state:471] (27,
 [(0, 0, [(0, 0, None)]),
  (24,
   24,
   [(1, 1, [(1, 1, None)]),
    (2, 2, [(2, 2, None)]),
    (25,
     25,
     [(3, 3, [(3, 3, None)]), (4, 4, [(4, 4, None)]), (5, 5, [(5, 5, None)])]),
    (26,
     26,
     [(6, 6, [(6, 6, None)]),
      (7, 7, [(7, 7, None)]),
      (8, 8, [(8, 8, None)])])]),
  (9, 9, [(9, 9, None)]),
  (10, 10, [(10, 10, None)]),
  (11, 11, [(11, 11, None)]),
  (28, 28, [(12, 12, [(12, 12, None)]), (13, 13, [(13, 13, None)])]),
  (29,
   29,
   [(14, 14, [(14, 14

DEBUG    [action_state:525] [(1, (4, 3, {'id': 23}, [[{'source': 23, 'target': 0, 'label': 'L', 'id': 18, 'parent': 23, 'child': 0}], [{'source': 23, 'target': 24, 'label': 'H', 'id': 23, 'parent': 23, 'child': 24}], [{'source': 23, 'target': 9, 'label': 'U', 'id': 21, 'parent': 23, 'child': 9}], [{'source': 23, 'target': 27, 'label': 'H', 'id': 29, 'parent': 23, 'child': 27}]]))]
DEBUG    [action_state:607] ('token stack',
 [(23,
   [(0, [(0, 'When', 'When')]),
    (24,
     [(1, [(1, 'Bowie', 'Bowie')]),
      (2, [(2, 'left', 'left')]),
      (25,
       [(3, [(3, 'the', 'the')]),
        (4, [(4, 'technical', 'technical')]),
        (5, [(5, 'school', 'school')])]),
      (26,
       [(6, [(6, 'the', 'the')]),
        (7, [(7, 'following', 'following')]),
        (8, [(8, 'year', 'year')])])]),
    (9, [(9, ',', ',')]),
    (27,
     [(10, [(10, 'he', 'he')]),
      (11, [(11, 'informed', 'informed')]),
      (28, [(12, [(12, 'his', 'his')]), (13, [(13, 'parents', 'parents')])]),
 

In [217]:
mrp_meta_data

('When Bowie left the technical school the following year, he informed his parents of his intention to become a pop star.',
 [{'id': 0, 'anchors': [{'from': 0, 'to': 4}], 'label': 'When'},
  {'id': 1, 'anchors': [{'from': 5, 'to': 10}], 'label': 'Bowie'},
  {'id': 2, 'anchors': [{'from': 11, 'to': 15}], 'label': 'left'},
  {'id': 3, 'anchors': [{'from': 16, 'to': 19}], 'label': 'the'},
  {'id': 4, 'anchors': [{'from': 20, 'to': 29}], 'label': 'technical'},
  {'id': 5, 'anchors': [{'from': 30, 'to': 36}], 'label': 'school'},
  {'id': 6, 'anchors': [{'from': 37, 'to': 40}], 'label': 'the'},
  {'id': 7, 'anchors': [{'from': 41, 'to': 50}], 'label': 'following'},
  {'id': 8, 'anchors': [{'from': 51, 'to': 55}], 'label': 'year'},
  {'id': 9, 'anchors': [{'from': 55, 'to': 56}], 'label': ','},
  {'id': 10, 'anchors': [{'from': 57, 'to': 59}], 'label': 'he'},
  {'id': 11, 'anchors': [{'from': 60, 'to': 68}], 'label': 'informed'},
  {'id': 12, 'anchors': [{'from': 69, 'to': 72}], 'label': 'his

In [37]:
curr_node_ids = mrp_meta_data[-3]
token_states = mrp_meta_data[-2]
actions = mrp_meta_data[-1]

In [38]:
*_, curr_node_ids, token_states, actions = mrp_meta_data

In [39]:
for curr_node_id, action, token_state in zip(curr_node_ids, actions, [[]] + token_states):
    action_type, params = action
    print(curr_node_id, action, token_state)

0 (2, None) []
1 (2, None) []
2 (0, None) []
3 (0, None) [(2, 'unclear')]
3 (1, (1, 0, {'id': 3, 'label': 'what', 'properties': ['pos', 'frame'], 'values': ['WP', 'q:i-h-h'], 'anchors': [{'from': 14, 'to': 18}]}, [[]])) [(2, 'unclear'), (3, 'what')]
4 (0, None) [(2, 'unclear'), (3, [(3, 'what')])]
4 (1, (2, 1, {'id': 4, 'label': 'effect', 'properties': ['pos', 'frame'], 'values': ['NN', 'n:x'], 'anchors': [{'from': 19, 'to': 25}]}, [[{'source': 3, 'target': 4, 'label': 'BV', 'id': 3, 'parent': 4, 'child': 3}], []])) [(2, 'unclear'), (3, [(3, 'what')]), (4, 'effect')]
5 (0, None) [(2, 'unclear'), (4, [(3, [(3, 'what')]), (4, 'effect')])]
5 (1, (1, 0, {'id': 5, 'label': 'the', 'properties': ['pos', 'frame'], 'values': ['DT', 'q:i-h-h'], 'anchors': [{'from': 26, 'to': 29}]}, [[]])) [(2, 'unclear'), (4, [(3, [(3, 'what')]), (4, 'effect')]), (5, 'the')]
6 (0, None) [(2, 'unclear'), (4, [(3, [(3, 'what')]), (4, 'effect')]), (5, [(5, 'the')])]
7 (2, None) [(2, 'unclear'), (4, [(3, [(3, 'what'

In [40]:
companion_parser_states, companion_meta_data = mrp_json2parser_states(
    parse_json,
    mrp_doc=doc,
    tokenized_parse_nodes=parse_json['nodes'],
)

In [218]:
logger.info(args.graphviz_file_template.format(
    framework, dataset, cid))

INFO     [__main__:2] http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/ucca/wiki.mrp/494011.png


In [42]:
mrp_json['input']

'It is unclear what effect the sale of the shopping centers will have on earnings.'

In [43]:
mrp_parser_states

[(2,
  [(2, None), (2, None), (0, None)],
  [],
  [],
  [],
  [(2, 2, None)],
  [(2, 'unclear', 'unclear')]),
 (3,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 3,
      'label': 'what',
      'properties': ['pos', 'frame'],
      'values': ['WP', 'q:i-h-h'],
      'anchors': [{'from': 14, 'to': 18}]},
     [[]]))],
  [3],
  [],
  [4],
  [(2, 2, None), (3, 3, [(3, 3, None)])],
  [(2, 'unclear', 'unclear'), (3, [(3, 'what', 'what')])]),
 (4,
  [(0, None),
   (1,
    (2,
     1,
     {'id': 4,
      'label': 'effect',
      'properties': ['pos', 'frame'],
      'values': ['NN', 'n:x'],
      'anchors': [{'from': 19, 'to': 25}]},
     [[{'source': 3,
        'target': 4,
        'label': 'BV',
        'id': 3,
        'parent': 4,
        'child': 3}],
      []]))],
  [],
  [],
  [],
  [(2, 2, None), (4, 4, [(3, 3, [(3, 3, None)]), (4, 4, None)])],
  [(2, 'unclear', 'unclear'),
   (4, [(3, [(3, 'what', 'what')]), (4, 'effect', 'effect')])]),
 (4,
  [],
  [],
  [],
  [],
  [(2, 2, None)

In [44]:
[(node['id'], node.get('label')) for node in mrp_json['nodes']]

[(2, 'unclear'),
 (3, 'what'),
 (4, 'effect'),
 (5, 'the'),
 (6, 'sale'),
 (8, 'the'),
 (9, 'shop'),
 (10, 'center'),
 (12, 'have'),
 (13, 'on'),
 (14, 'earnings')]

In [45]:
doc

'It is unclear what effect the sale of the shopping centers will have on earnings.'

In [46]:
parse_json['nodes']

[{'id': 0,
  'label': 'It',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['it', 'PRON', 'PRP']},
 {'id': 1,
  'label': 'is',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['be', 'VERB', 'VBZ']},
 {'id': 2,
  'label': 'unclear',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['unclear', 'ADJ', 'JJ']},
 {'id': 3,
  'label': 'what',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['what', 'DET', 'WDT']},
 {'id': 4,
  'label': 'effect',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['effect', 'NOUN', 'NN']},
 {'id': 5,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 6,
  'label': 'sale',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['sale', 'NOUN', 'NN']},
 {'id': 7,
  'label': 'of',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['of', 'ADP', 'IN']},
 {'id': 8,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 9,
  'labe

In [47]:
[(node['id'], node['label']) for node in parse_json['nodes']]

[(0, 'It'),
 (1, 'is'),
 (2, 'unclear'),
 (3, 'what'),
 (4, 'effect'),
 (5, 'the'),
 (6, 'sale'),
 (7, 'of'),
 (8, 'the'),
 (9, 'shopping'),
 (10, 'centers'),
 (11, 'will'),
 (12, 'have'),
 (13, 'on'),
 (14, 'earnings'),
 (15, '.')]

In [48]:
anchors

[(0, 2),
 (3, 5),
 (6, 13),
 (14, 18),
 (19, 25),
 (26, 29),
 (30, 34),
 (35, 37),
 (38, 41),
 (42, 50),
 (51, 58),
 (59, 63),
 (64, 68),
 (69, 71),
 (72, 80),
 (80, 81)]

### Create training instance

In [68]:
total_count = 0
with_parse_count = 0
data_size_limit = args.data_size_limit
ignore_framework_set = {'amr', 'dm', 'psd', 'eds'}
ignore_dataset_set = {}

In [69]:
allennlp_tests_fixtures_output_file = os.path.join(
    args.project_root, args.mrp_test_dir, args.tests_fixtures_file)
allennlp_train_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format('train'))
allennlp_test_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format('test'))

In [70]:
# Create tests fixture jsonl
fixture_combinations = [
    ('ucca', 'wiki', 70)
]

with open(allennlp_tests_fixtures_output_file, 'w') as wf:
    for framework, dataset, idx in fixture_combinations:
        mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
        cid = mrp_json.get('id')
        doc = mrp_json.get('input')
        
        alignment = {}
        if framework == 'amr':
            alignment = cid2alignment[cid]  
        parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

        if parse_json:
            with_parse_count += 1
            mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                mrp_json, 
                tokenized_parse_nodes=parse_json['nodes'],
                alignment=alignment,
            )
            companion_parser_states, companion_meta_data = mrp_json2parser_states(
                parse_json, 
                mrp_doc=doc,
                tokenized_parse_nodes=parse_json['nodes'],
            )

            data_instance = {
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'mrp_parser_states': mrp_parser_states,
                'mrp_meta_data': mrp_meta_data,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            json_encoded_instance = json.dumps(data_instance)
            wf.write(json_encoded_instance + '\n')

In [71]:
[state[-1] for state in mrp_parser_states]

[[(0, [(0, 'In', 'In')])],
 [(0, [(0, 'In', 'In')]), (1, [(1, 'the', 'the')])],
 [(0, [(0, 'In', 'In')]),
  (1, [(1, 'the', 'the')]),
  (2, [(2, 'final', 'final')])],
 [(0, [(0, 'In', 'In')]),
  (1, [(1, 'the', 'the')]),
  (2, [(2, 'final', 'final')]),
  (3, [(3, 'minute', 'minute')])],
 [(31,
   [(0, [(0, 'In', 'In')]),
    (1, [(1, 'the', 'the')]),
    (2, [(2, 'final', 'final')]),
    (3, [(3, 'minute', 'minute')])])],
 [(31,
   [(0, [(0, 'In', 'In')]),
    (1, [(1, 'the', 'the')]),
    (2, [(2, 'final', 'final')]),
    (3, [(3, 'minute', 'minute')])]),
  (4, [(4, 'of', 'of')])],
 [(31,
   [(0, [(0, 'In', 'In')]),
    (1, [(1, 'the', 'the')]),
    (2, [(2, 'final', 'final')]),
    (3, [(3, 'minute', 'minute')])]),
  (4, [(4, 'of', 'of')]),
  (5, [(5, 'the', 'the')])],
 [(31,
   [(0, [(0, 'In', 'In')]),
    (1, [(1, 'the', 'the')]),
    (2, [(2, 'final', 'final')]),
    (3, [(3, 'minute', 'minute')])]),
  (4, [(4, 'of', 'of')]),
  (5, [(5, 'the', 'the')]),
  (6, [(6, 'game', 'game')]

In [72]:
mrp_meta_data[-1]

[(0, None),
 (1,
  (1, 0, {'id': 0, 'anchors': [{'from': 0, 'to': 2}], 'label': 'In'}, [[]])),
 (0, None),
 (1,
  (1, 0, {'id': 1, 'anchors': [{'from': 3, 'to': 6}], 'label': 'the'}, [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2, 'anchors': [{'from': 7, 'to': 12}], 'label': 'final'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3, 'anchors': [{'from': 13, 'to': 19}], 'label': 'minute'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      'target': 3,
      'label': 'C',
      'id': 14,
      'parent': 31,
      'child': 3}]])),
 (0, None),
 (1,
  (1, 0, {'id': 4, 'anchors': [{'from': 20, 'to': 22}], 'label': 'of'}, [[]])),

In [73]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [74]:
parse_json

{'id': '470004',
 'tops': [9],
 'nodes': [{'id': 0,
   'label': 'In',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['in', 'ADP', 'IN']},
  {'id': 1,
   'label': 'the',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 2,
   'label': 'final',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['final', 'ADJ', 'JJ']},
  {'id': 3,
   'label': 'minute',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['minute', 'NOUN', 'NN']},
  {'id': 4,
   'label': 'of',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['of', 'ADP', 'IN']},
  {'id': 5,
   'label': 'the',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 6,
   'label': 'game',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['game', 'NOUN', 'NN']},
  {'id': 7,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 8,
   'label': 'Johnson',
   'properties': ['lemma', 'up

In [75]:
[n['values'][2] for n in parse_json['nodes']]

['IN',
 'DT',
 'JJ',
 'NN',
 'IN',
 'DT',
 'NN',
 ',',
 'NNP',
 'VBD',
 'DT',
 'NN',
 'VBN',
 'IN',
 'NNPS',
 'NN',
 'NNP',
 'NNP',
 ',',
 'CC',
 'RB',
 'VBD',
 'CD',
 'JJ',
 'NNS',
 'WDT',
 'MD',
 'VB',
 'VBN',
 'DT',
 'NN',
 '.']

In [76]:
# Create train jsonl
if os.path.isfile(allennlp_train_output_file):
    logger.info('allennlp_train_output_file found, stop generation')
else:
    pass
if 1==1:
    data_size = 0
    with open(allennlp_train_output_file, 'w') as wf:
        for _, dataset, mrp_json in tqdm(mrp_dataset.mrp_json_generator(
            ignore_framework_set=ignore_framework_set,
            ignore_dataset_set=ignore_dataset_set,
        )):
            total_count += 1
            if data_size >= data_size_limit:
                break
            cid = mrp_json.get('id')
            doc = mrp_json.get('input')
            
            framework = mrp_json.get('framework')
            alignment = {}
            if framework == 'amr':
                alignment = cid2alignment[cid]  
            parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})
            
            if parse_json:
                mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                    mrp_json, 
                    tokenized_parse_nodes=parse_json['nodes'],
                    alignment=alignment,
                )
                companion_parser_states, companion_meta_data = mrp_json2parser_states(
                    parse_json, 
                    mrp_doc=doc,
                    tokenized_parse_nodes=parse_json['nodes'],
                )
                
                # Continue if error
                if not mrp_parser_states:
                    continue
                    
                data_size += 1
                logger.info(data_size)
                data_instance = {
                    'mrp_json': mrp_json,
                    'parse_json': parse_json,
                    'mrp_parser_states': mrp_parser_states,
                    'mrp_meta_data': mrp_meta_data,
                    'companion_parser_states': companion_parser_states,
                    'companion_meta_data': companion_meta_data,
                }
                json_encoded_instance = json.dumps(data_instance)
                wf.write(json_encoded_instance + '\n')

                
# Create test jsonl
if os.path.isfile(allennlp_test_output_file):
    logger.info('allennlp_test_output_file found, stop generation')
else:
    pass
if 1==1:
    data_size = 0
    with open(allennlp_test_output_file, 'w') as wf:
        alignment = {}
        for mrp_json in tqdm(test_mrp_jsons):
            
            if data_size >= data_size_limit:
                break
            cid = mrp_json.get('id', '')
            framework = mrp_json.get('framework', '')
            if framework in ignore_framework_set:
                continue
            parse_json = test_parse_jsons[cid]
            doc = parse_json.get('input')
            companion_parser_states, companion_meta_data = mrp_json2parser_states(
                parse_json, 
                mrp_doc=doc,
                tokenized_parse_nodes=parse_json['nodes'],
            )
            if not companion_parser_states:
                continue
            data_size += 1
            logger.info(data_size)
            data_instance = {
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            json_encoded_instance = json.dumps(data_instance)
            wf.write(json_encoded_instance + '\n')

INFO     [__main__:3] allennlp_train_output_file found, stop generation
0it [00:00, ?it/s]INFO     [__main__:42] 1
1it [00:00,  7.69it/s]INFO     [__main__:42] 2
2it [00:00,  6.10it/s]INFO     [__main__:42] 3
INFO     [__main__:42] 4
6it [00:01,  4.50it/s]INFO     [__main__:42] 5
INFO     [__main__:42] 6
8it [00:01,  5.66it/s]INFO     [__main__:42] 7
10it [00:01,  5.07it/s]INFO     [__main__:42] 8
12it [00:02,  4.29it/s]INFO     [__main__:42] 9
INFO     [__main__:42] 10
INFO     [__main__:42] 11
17it [00:03,  5.10it/s]INFO     [__main__:42] 12
INFO     [__main__:42] 13
19it [00:03,  6.49it/s]INFO     [__main__:42] 14
21it [00:03,  8.00it/s]INFO     [__main__:42] 15
25it [00:04,  8.26it/s]INFO     [__main__:42] 16
29it [00:04,  5.48it/s]INFO     [__main__:42] 17
INFO     [__main__:42] 18
32it [00:05,  6.77it/s]INFO     [__main__:42] 19
34it [00:05,  7.68it/s]INFO     [__main__:42] 20
INFO     [__main__:42] 21
36it [00:05,  5.97it/s]INFO     [__main__:42] 22
38it [00:05,  7.52it/s]INFO  

INFO     [__main__:82] 54
INFO     [__main__:82] 55

  1%|          | 55/6288 [00:01<02:26, 42.55it/s][AINFO     [__main__:82] 56
INFO     [__main__:82] 57
INFO     [__main__:82] 58
INFO     [__main__:82] 59
INFO     [__main__:82] 60

  1%|          | 60/6288 [00:01<03:03, 34.01it/s][AINFO     [__main__:82] 61
INFO     [__main__:82] 62
INFO     [__main__:82] 63
INFO     [__main__:82] 64
INFO     [__main__:82] 65
INFO     [__main__:82] 66
INFO     [__main__:82] 67

  1%|          | 67/6288 [00:01<02:42, 38.26it/s][AINFO     [__main__:82] 68
INFO     [__main__:82] 69
INFO     [__main__:82] 70
INFO     [__main__:82] 71
INFO     [__main__:82] 72

  1%|          | 72/6288 [00:01<02:44, 37.90it/s][AINFO     [__main__:82] 73
INFO     [__main__:82] 74
INFO     [__main__:82] 75
INFO     [__main__:82] 76
INFO     [__main__:82] 77

  1%|          | 77/6288 [00:02<03:40, 28.22it/s][AINFO     [__main__:82] 78
INFO     [__main__:82] 79
INFO     [__main__:82] 80
INFO     [__main__:82] 81

  1%|▏

### Test allennlp dataset reader

In [77]:
import torch.optim as optim

from mrp_library.dataset_readers.mrp_jsons import MRPDatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.feedforward import FeedForward

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer

import json
import logging
from typing import Dict

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.models import Model
from overrides import overrides

INFO     [pytorch_pretrained_bert.modeling:230] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>


In [78]:
from mrp_library.dataset_readers.mrp_jsons_actions import MRPDatasetActionReader

In [79]:
reader = MRPDatasetActionReader()

In [80]:
train_dataset = reader.read(cached_path(allennlp_train_output_file))


0it [00:00, ?it/s][AINFO     [mrp_library.dataset_readers.mrp_jsons_actions:58] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-train.jsonl

892it [00:00, 8860.44it/s][A
1837it [00:00, 9019.99it/s][A
2725it [00:00, 8913.87it/s][A
3728it [00:00, 9213.53it/s][A
4440it [00:04, 530.41it/s] [A
5322it [00:04, 736.99it/s][A
5594it [00:04, 1154.90it/s][A

In [81]:
test_dataset = reader.read(cached_path(allennlp_train_output_file))
#test_dataset = reader.read(cached_path(allennlp_test_output_file))


0it [00:00, ?it/s][AINFO     [mrp_library.dataset_readers.mrp_jsons_actions:58] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-train.jsonl

631it [00:00, 6301.26it/s][A
1314it [00:00, 6430.49it/s][A
2057it [00:00, 6655.68it/s][A
2710it [00:00, 6617.29it/s][A
3397it [00:00, 6628.10it/s][A
4052it [00:00, 6585.04it/s][A
4699it [00:00, 6549.49it/s][A
5322it [00:00, 6335.26it/s][A
5594it [00:00, 6557.46it/s][A

In [82]:
tests_fixtures_dataset = reader.read(cached_path(allennlp_tests_fixtures_output_file))


0it [00:00, ?it/s][AINFO     [mrp_library.dataset_readers.mrp_jsons_actions:58] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/src/tests/fixtures/test.jsonl

74it [00:00, 4621.89it/s][A

In [83]:
vocab = Vocabulary.from_instances(train_dataset + test_dataset + tests_fixtures_dataset)

INFO     [allennlp.data.vocabulary:396] Fitting token dictionary from dataset.

  0%|          | 0/11262 [00:00<?, ?it/s][A
 25%|██▌       | 2855/11262 [00:00<00:00, 28547.24it/s][A
 51%|█████     | 5769/11262 [00:00<00:00, 28720.70it/s][A
 77%|███████▋  | 8690/11262 [00:00<00:00, 28863.12it/s][A
100%|██████████| 11262/11262 [00:00<00:00, 28887.94it/s][A

In [84]:
vocab.print_statistics()

INFO     [allennlp.data.vocabulary:664] Printed vocabulary statistics are only for the part of the vocabulary generated from instances. If vocabulary is constructed by extending saved vocabulary with dataset instances, the directly loaded portion won't be considered here.




----Vocabulary Statistics----


Top 10 most frequent tokens in namespace 'word':
	Token: the		Frequency: 12994
	Token: ,		Frequency: 12948
	Token: <START-LABEL>		Frequency: 12868
	Token: <END-LABEL>		Frequency: 12868
	Token: in		Frequency: 6368
	Token: and		Frequency: 5264
	Token: to		Frequency: 5090
	Token: a		Frequency: 4472
	Token: ’s		Frequency: 3536
	Token: for		Frequency: 3528

Top 10 longest tokens in namespace 'word':
	Token: confrontational		length: 15	Frequency: 100
	Token: <START-LABEL>		length: 13	Frequency: 12868
	Token: collaboration		length: 13	Frequency: 96
	Token: collaborative		length: 13	Frequency: 96
	Token: contributions		length: 13	Frequency: 40
	Token: circumstances		length: 13	Frequency: 36
	Token: relationship		length: 12	Frequency: 204
	Token: Hillsborough		length: 12	Frequency: 104
	Token: Merseysiders		length: 12	Frequency: 100
	Token: cancellation		length: 12	Frequency: 100

Top 10 shortest tokens in namespace 'word':
	Token: A		length: 1	Frequency: 52
	T

In [85]:
vocab.get_vocab_size('word')

1184

In [86]:
vocab.get_vocab_size('pos')

59

In [87]:
vocab.get_vocab_size('label')

2

In [88]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 50

### Test model

In [89]:
from mrp_library.models.generalizer import Generalizer, ActionGeneralizer
from allennlp.nn import InitializerApplicator, RegularizerApplicator, util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper import PytorchSeq2VecWrapper

In [90]:
word_embedding = Embedding(num_embeddings=vocab.get_vocab_size('word'),
                            embedding_dim=EMBEDDING_DIM)
word_embedder = BasicTextFieldEmbedder({"word": word_embedding})

In [91]:
pos_embedding = Embedding(num_embeddings=vocab.get_vocab_size('pos'),
                            embedding_dim=EMBEDDING_DIM)
pos_embedder = BasicTextFieldEmbedder({"pos": pos_embedding})

In [92]:
encoder = PytorchSeq2VecWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

In [93]:
classifier_params = Params({
  "input_dim": HIDDEN_DIM * 4 * 3,
  "num_layers": 2,
  "hidden_dims": [50, 3],
  "activations": ["sigmoid", "linear"],
  "dropout": [0.2, 0.0]
})

In [94]:
classifier_feedforward = FeedForward.from_params(classifier_params)

INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.modules.feedforward.FeedForward'> from params {'input_dim': 600, 'num_layers': 2, 'hidden_dims': [50, 3], 'activations': ['sigmoid', 'linear'], 'dropout': [0.2, 0.0]} and extras set()
INFO     [allennlp.common.params:252] input_dim = 600
INFO     [allennlp.common.params:252] num_layers = 2
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params ['sigmoid', 'linear'] and extras set()
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params sigmoid and extras set()
INFO     [allennlp.common.params:252] type = sigmoid
INFO

In [95]:
parse_label = {
    'word': torch.LongTensor(
        [
            [ 1,  0,  3,  7,  2,  9,  4],
            [ 0,  0,  5,  0,  0,  0,  4]
        ]
    )
}
embedded_parse_label = word_embedder(parse_label)

In [96]:
feature_mask = util.get_text_field_mask(parse_label)

In [97]:
encoded_feature = encoder(embedded_parse_label, feature_mask)

In [98]:
encoded_features = [encoded_feature] * 4 * 3

In [99]:
torch.cat(encoded_features, dim=-1).shape

torch.Size([2, 600])

In [100]:
logits = classifier_feedforward(torch.cat(encoded_features, dim=-1))

In [101]:
logits.shape

torch.Size([2, 3])

In [102]:
label = torch.tensor([1, 0])

In [103]:
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits, label)

In [104]:
model = ActionGeneralizer(
    vocab=vocab,
    word_embedder=word_embedder,
    pos_embedder=pos_embedder,
    encoder=encoder,
    classifier_feedforward=classifier_feedforward
)

INFO     [allennlp.nn.initializers:293] Initializing parameters
INFO     [allennlp.nn.initializers:309] Done initializing parameters; the following parameters are using their default initialization from their code
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.0.bias
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.0.weight
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.1.bias
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.1.weight
INFO     [allennlp.nn.initializers:314]    encoder._module.bias_hh_l0
INFO     [allennlp.nn.initializers:314]    encoder._module.bias_ih_l0
INFO     [allennlp.nn.initializers:314]    encoder._module.weight_hh_l0
INFO     [allennlp.nn.initializers:314]    encoder._module.weight_ih_l0
INFO     [allennlp.nn.initializers:314]    pos_embedder.token_embedder_pos.weight
INFO     [allennlp.nn.initializers:314]    word_embedder.token_e

In [105]:
optimizer = optim.SGD(model.parameters(), lr=0.01)
cuda_device = -1

In [106]:
iterator = BucketIterator(batch_size=20, sorting_keys=[("curr_parse_node_label", "num_tokens")])
iterator.index_with(vocab)

In [107]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=train_dataset,
    patience=10,
    num_epochs=20,
    cuda_device=cuda_device
)

In [109]:
trainer.train()

INFO     [allennlp.training.trainer:465] Beginning training.
INFO     [allennlp.training.trainer:281] Epoch 0/19
INFO     [allennlp.training.trainer:283] Peak CPU memory usage MB: 5694.772
INFO     [allennlp.training.trainer:287] GPU 0 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 1 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 2 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 3 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 4 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 5 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 6 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 7 memory usage MB: 10
INFO     [allennlp.training.trainer:311] Training

  0%|          | 0/280 [00:00<?, ?it/s][ADEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'curr_parse_node_label': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_lemma': {'word_length': 1,

DEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'curr_parse_node_label': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_lemma': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_upos': {'pos_length': 1, 'num_tokens': 1}, 'curr_parse_node_xpos': {'pos_length': 1, 'num_tokens': 1}, 'prev_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'prev_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}}
DEBUG    [allennlp.data.iterators.data_iterator:152] Batch size: 20

accuracy: 0.5333, loss: 0.8102 ||:   3%|▎         | 9/280 [00:01<00:51,  5.30it/s][ADEBUG    [allennlp.data.itera

DEBUG    [allennlp.data.iterators.data_iterator:152] Batch size: 20
DEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'curr_parse_node_label': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_lemma': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_upos': {'pos_length': 1, 'num_tokens': 1}, 'curr_parse_node_xpos': {'pos_length': 1, 'num_tokens': 1}, 'prev_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'prev_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}}
DEBUG    [allennlp.data.iterators.data_iterator:152] Batch size: 14

accuracy: 0.5508, loss: 0.8177 ||:   6%|▋     

DEBUG    [allennlp.data.iterators.data_iterator:152] Batch size: 20

accuracy: 0.5545, loss: 0.8252 ||:   9%|▉         | 26/280 [00:02<00:18, 13.91it/s][ADEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'curr_parse_node_label': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_lemma': {'word_length': 1, 'num_tokens': 1}, 'curr_parse_node_upos': {'pos_length': 1, 'num_tokens': 1}, 'curr_parse_node_xpos': {'pos_length': 1, 'num_tokens': 1}, 'prev_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'prev_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'prev_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_labels': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_lemmas': {'word_length': 5, 'num_tokens': 5}, 'next_parse_node_uposs': {'pos_length': 5, 'num_tokens': 5}, 'next_parse_node_xposs': {'pos_length': 5, 'num_tokens': 5}}
DEBUG    [allennlp.data.iter

KeyboardInterrupt: 