In [1]:
try:
    __IPYTHON__
    USING_IPYTHON = True
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

#### Argparse

In [566]:
import argparse
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--mrp-test-dir', default='src/tests', help='')
ap.add_argument('--tests-fixtures-file', default='fixtures/test.jsonl', help='')

ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')

ap.add_argument('--test-input-file', default='evaluation/input.mrp', help='')
ap.add_argument('--test-companion-file', default='evaluation/udpipe.mrp', help='')
ap.add_argument('--allennlp-mrp-json-file-template', default='allennlp-mrp-json-small-{}.jsonl', help='')
ap.add_argument('--data-size-limit', type=int, default=100, help='')

ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
ap.add_argument('--parse-plot-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png')

arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
arguments = [arg for arg_line in arg_string.split(r'\\n') for arg in arg_line.split()]

In [519]:
if USING_IPYTHON:
    args = ap.parse_args(arguments)
else:
    args = ap.parse_args()

In [520]:
args

Namespace(allennlp_mrp_json_file_template='allennlp-mrp-json-small-{}.jsonl', companion_file_extension='.conllu', companion_sub_dir='companion', data_size_limit=100, graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', mrp_test_dir='src/tests', parse_plot_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png', project_root='/data/proj29_ds1/home/slai/mrp2019', test_companion_file='evaluation/udpipe.mrp', test_input_file='evaluation/input.mrp', tests_fixtures_file='fixtures/test.jsonl', train_sub_dir='training')

#### Library imports

In [510]:
import json
import logging
import os
import pprint
import re
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
import torch
from action_state import mrp_json2parser_states, _generate_parser_action_states
from action_state import ERROR, APPEND, RESOLVE, IGNORE
from preprocessing import (CompanionParseDataset, MrpDataset, JamrAlignmentDataset,
                           read_companion_parse_json_file, read_mrp_json_file, parse2parse_json)            
from torch import nn
from tqdm import tqdm

#### ipython notebook specific imports

In [511]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

DEBUG    [matplotlib.pyplot:219] Loaded backend module://ipykernel.pylab.backend_inline version unknown.


In [512]:
sh = logging.StreamHandler()
formatter = logging.Formatter('%(levelname)-8s [%(name)s:%(lineno)d] %(message)s')
sh.setFormatter(formatter)
logging.basicConfig(level=logging.DEBUG, handlers=[sh])
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [8]:
UNKWOWN = 'UNKWOWN'

### Load data

In [9]:
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [10]:
mrp_dataset = MrpDataset()

In [11]:
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  3.29it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:02,  1.37it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:05,  1.77s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:10<00:05,  2.77s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:16<00:03,  3.78s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  43%|████▎     | 6/14 [00:00<00:00, 19.23it/s][A
dataset_name:  57%|█████▋    | 8/14 [00:00<00:00, 16.58it/s][A
dataset_name:  71%|███████▏  | 10/14 [00:01<00:00,  6.35it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.84it/s][A
frameworks: 100%|██████████| 5/5 [00:17<00:00,  3.10s/it]t/s][A


In [12]:
framework2dataset2mrp_jsons.keys()

dict_keys(['ucca', 'psd', 'eds', 'dm', 'amr'])

### Data Preprocessing companion

In [13]:
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [14]:
cparse_dataset = CompanionParseDataset()

In [15]:
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

INFO     [preprocessing:172] framework amr found
dataset: 100%|██████████| 13/13 [00:01<00:00,  9.45it/s]
INFO     [preprocessing:172] framework dm found
dataset: 100%|██████████| 5/5 [00:04<00:00,  1.07it/s]
INFO     [preprocessing:172] framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 26.43it/s]


In [16]:
dataset2cid2parse_json = cparse_dataset.convert_parse2parse_json()

In [17]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [18]:
# Some data is missing
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [19]:
jalignment_dataset = JamrAlignmentDataset()

In [20]:
cid2alignment = jalignment_dataset.load_jamr_alignment_file(os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file
))

### Load testing data

In [21]:
test_input_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_input_file)
test_companion_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_companion_file)

In [22]:
test_mrp_jsons = read_mrp_json_file(test_input_filename)
test_parse_jsons = read_companion_parse_json_file(test_companion_filename)

In [23]:
parse_json = test_parse_jsons['102990']

In [24]:
mrp_json = framework2dataset2mrp_jsons['psd']['wsj'][1]

In [184]:
test_configs = [
    ('ucca', 'wiki', 70),
]
framework, dataset, idx = test_configs[0]

In [185]:
mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
cid = mrp_json.get('id')

In [186]:
parse_json = dataset2cid2parse_json[dataset][cid]

In [188]:
doc = mrp_json['input']

In [189]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [190]:
token_pos = 0
anchors = []
char_pos2tokenized_parse_node_id = []

for node_id, node in enumerate(parse_json.get('nodes')):
    label = node.get('label')
    label_size = len(label)
    while doc[token_pos] == ' ':
        token_pos += 1
        char_pos2tokenized_parse_node_id.append(node_id)
    anchors.append((token_pos, token_pos + label_size))
    char_pos2tokenized_parse_node_id.extend([node_id] * (label_size))
    print(node_id, doc[token_pos: token_pos + label_size], anchors[-1], len(char_pos2tokenized_parse_node_id))
    token_pos += label_size

0 In (0, 2) 2
1 the (3, 6) 6
2 final (7, 12) 12
3 minute (13, 19) 19
4 of (20, 22) 22
5 the (23, 26) 26
6 game (27, 31) 31
7 , (31, 32) 32
8 Johnson (33, 40) 40
9 had (41, 44) 44
10 the (45, 48) 48
11 ball (49, 53) 53
12 stolen (54, 60) 60
13 by (61, 63) 63
14 Celtics (64, 71) 71
15 center (72, 78) 78
16 Robert (79, 85) 85
17 Parish (86, 92) 92
18 , (92, 93) 93
19 and (94, 97) 97
20 then (98, 102) 102
21 missed (103, 109) 109
22 two (110, 113) 113
23 free (114, 118) 118
24 throws (119, 125) 125
25 that (126, 130) 130
26 could (131, 136) 136
27 have (137, 141) 141
28 won (142, 145) 145
29 the (146, 149) 149
30 game (150, 154) 154
31 . (154, 155) 155


In [191]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [192]:
len(char_pos2tokenized_parse_node_id)

155

In [193]:
doc = mrp_json['input']

In [194]:
mrp_json['tops']

[34]

In [195]:
mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
    mrp_json, 
    tokenized_parse_nodes=parse_json['nodes'],
)

In [196]:
(
    doc,
    nodes,
    node_id2node,
    edge_id2edge,
    top_oriented_edges,
    token_nodes,
    # abstract_node_id_set,
    parent_id2indegree,
    # parent_id2child_id_set,
    # child_id2parent_id_set,
    # child_id2edge_id_set,
    # parent_id2edge_id_set,
    # parent_child_id2edge_id_set,
    parse_nodes_anchors,
    char_pos2tokenized_node_id,
    curr_node_ids,
    token_states,
    actions,
) = mrp_meta_data

In [197]:
curr_node_ids = mrp_meta_data[-3]
token_states = mrp_meta_data[-2]
actions = mrp_meta_data[-1]

In [198]:
*_, curr_node_ids, token_states, actions = mrp_meta_data

In [199]:
actions[:4]

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]]))]

In [236]:
for curr_node_id, action, token_state in zip(curr_node_ids, actions, token_states):
    action_type, params = action
#     pprint.pprint((curr_node_id, action[0]))
#     pprint.pprint(([token_group[:4] for token_group in token_state]))
    pprint.pprint((curr_node_id, action[0], [token_group[:4] for token_group in token_state]))

(0, 0, [(0, False, 'In', [])])
(1, 1, [(0, True, 'R', [(0, False, 'In', [])])])
(1, 0, [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(2,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(3,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 1,
 [(32,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', [])]),
    (1, True, 'E', [(1, False, 'the', [])]),
    (2, True, 'E', [(

       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [])])])]),
  (18, True, 'U', [(18, False, ',', [])]),
  (19, True, 'L', [(19, False, 'and', [])]),
  (20, True, 'L', [(20, False, 'then', [])]),
  (21, True, 'D', [(21, False, 'missed', [])]),
  (22, True, 'D', [(22, False, 'two', [])]),
  (23, True, 'D', [(23, False, 'free', [])])])
(24,
 0,
 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9,

In [642]:
for curr_node_id, action, token_state in zip(curr_node_ids, actions, [[]] + token_states):
    action_type, params = action
    pprint.pprint((curr_node_id, action[0], [token_group[:4] for token_group in token_state]))

(0, 0, [])
(1, 1, [(0, False, 'In', [])])
(1, 0, [(0, True, 'R', [(0, False, 'In', [])])])
(2, 1, [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(3,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 0,
 [(32,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', [])]),
    (1, True, 'E', [(1, False, 'the', [])]),
    (2, Tr

       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9, False, 'had', [])]),
    (34,
     True,
     'A',
     [(10, True, 'E', [(10, False, 'the', [])]),
      (11, True, 'C', [(11, False, 'ball', [])])]),
    (12, True, 'P', [(12, False, 'stolen', [])]),
    (36,
     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [])])])]),
  (18, True, 'U', [(18, False, 

In [202]:
actions

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2,
    'anchors': [{'from': 7, 'to': 12}],
    'label': 'final',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3,
    'anchors': [{'from': 13, 'to': 19}],
    'label': 'minute',
    'propagate_label': 'C'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31, 'propagate_label': 'E'},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      't

In [203]:
token_states[1]

[(0, True, 'R', [(0, False, 'In', [])])]

In [204]:
[n['label'] for n in parse_json['nodes']]

['In',
 'the',
 'final',
 'minute',
 'of',
 'the',
 'game',
 ',',
 'Johnson',
 'had',
 'the',
 'ball',
 'stolen',
 'by',
 'Celtics',
 'center',
 'Robert',
 'Parish',
 ',',
 'and',
 'then',
 'missed',
 'two',
 'free',
 'throws',
 'that',
 'could',
 'have',
 'won',
 'the',
 'game',
 '.']

In [205]:
token_states[-1]

[(42,
  True,
  '<UCCA-TOP-NODE>',
  [(37,
    True,
    'H',
    [(33,
      True,
      'T',
      [(32,
        True,
        'E',
        [(0, True, 'R', [(0, False, 'In', [])]),
         (1, True, 'E', [(1, False, 'the', [])]),
         (2, True, 'E', [(2, False, 'final', [])]),
         (3, True, 'C', [(3, False, 'minute', [])])]),
       (4, True, 'R', [(4, False, 'of', [])]),
       (5, True, 'E', [(5, False, 'the', [])]),
       (6, True, 'C', [(6, False, 'game', [])])]),
     (7, True, 'U', [(7, False, ',', [])]),
     (8, True, 'A', [(8, False, 'Johnson', [])]),
     (9, True, 'F', [(9, False, 'had', [])]),
     (34,
      True,
      'A',
      [(10, True, 'E', [(10, False, 'the', [])]),
       (11, True, 'C', [(11, False, 'ball', [])])]),
     (12, True, 'P', [(12, False, 'stolen', [])]),
     (36,
      True,
      'A',
      [(13, True, 'R', [(13, False, 'by', [])]),
       (35,
        True,
        'E',
        [(14, True, 'A', [(14, False, 'Celtics', [])]),
         (

In [206]:
companion_parser_states, companion_meta_data = mrp_json2parser_states(
    parse_json,
    mrp_doc=doc,
    tokenized_parse_nodes=parse_json['nodes'],
)

In [207]:
logger.info(args.graphviz_file_template.format(
    framework, dataset, cid))

INFO     [__main__:2] http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/ucca/wiki.mrp/470004.png


In [208]:
mrp_json['input']

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [209]:
mrp_parser_states

[(0,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 0,
      'anchors': [{'from': 0, 'to': 2}],
      'label': 'In',
      'propagate_label': 'R'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')])]),
 (1,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 1,
      'anchors': [{'from': 3, 'to': 6}],
      'label': 'the',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')])]),
 (2,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 2,
      'anchors': [{'from': 7, 'to': 12}],
      'label': 'final',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)]), (2, 2, [(2, 2, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')]),
   (2, True, 'E', [(2, False, 'final', 'final')])]),
 (3,
  

In [210]:
[(node['id'], node.get('label')) for node in mrp_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'RobertParish'),
 (17, ','),
 (18, 'and'),
 (19, 'then'),
 (20, 'missed'),
 (21, 'two'),
 (22, 'free'),
 (23, 'throws'),
 (24, 'that'),
 (25, 'could'),
 (26, 'have'),
 (27, 'won'),
 (28, 'the'),
 (29, 'game'),
 (30, '.'),
 (31, None),
 (32, None),
 (33, None),
 (34, None),
 (35, None),
 (36, None),
 (37, None),
 (38, None),
 (39, None),
 (40, None),
 (41, None)]

In [211]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [212]:
parse_json['nodes']

[{'id': 0,
  'label': 'In',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['in', 'ADP', 'IN']},
 {'id': 1,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 2,
  'label': 'final',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['final', 'ADJ', 'JJ']},
 {'id': 3,
  'label': 'minute',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['minute', 'NOUN', 'NN']},
 {'id': 4,
  'label': 'of',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['of', 'ADP', 'IN']},
 {'id': 5,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 6,
  'label': 'game',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['game', 'NOUN', 'NN']},
 {'id': 7,
  'label': ',',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': [',', 'PUNCT', ',']},
 {'id': 8,
  'label': 'Johnson',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['Johnson', 'PROPN', 'NNP']},
 {'id': 9,
  'label

In [213]:
[(node['id'], node['label']) for node in parse_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'Robert'),
 (17, 'Parish'),
 (18, ','),
 (19, 'and'),
 (20, 'then'),
 (21, 'missed'),
 (22, 'two'),
 (23, 'free'),
 (24, 'throws'),
 (25, 'that'),
 (26, 'could'),
 (27, 'have'),
 (28, 'won'),
 (29, 'the'),
 (30, 'game'),
 (31, '.')]

In [214]:
anchors

[(0, 2),
 (3, 6),
 (7, 12),
 (13, 19),
 (20, 22),
 (23, 26),
 (27, 31),
 (31, 32),
 (33, 40),
 (41, 44),
 (45, 48),
 (49, 53),
 (54, 60),
 (61, 63),
 (64, 71),
 (72, 78),
 (79, 85),
 (86, 92),
 (92, 93),
 (94, 97),
 (98, 102),
 (103, 109),
 (110, 113),
 (114, 118),
 (119, 125),
 (126, 130),
 (131, 136),
 (137, 141),
 (142, 145),
 (146, 149),
 (150, 154),
 (154, 155)]

### Create training instance

In [567]:
total_count = 0
with_parse_count = 0
data_size_limit = args.data_size_limit
ignore_framework_set = {'amr', 'dm', 'psd', 'eds'}
ignore_dataset_set = {}

In [568]:
allennlp_tests_fixtures_output_file = os.path.join(
    args.project_root, args.mrp_test_dir, args.tests_fixtures_file)
allennlp_train_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format('train'))
allennlp_test_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format('test'))

In [569]:
# Create tests fixture jsonl
fixture_combinations = [
    ('ucca', 'wiki', 70),
] * 10

with open(allennlp_tests_fixtures_output_file, 'w') as wf:
    for framework, dataset, idx in fixture_combinations:
        mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
        cid = mrp_json.get('id')
        doc = mrp_json.get('input')
        
        alignment = {}
        if framework == 'amr':
            alignment = cid2alignment[cid]  
        parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

        if parse_json:
            with_parse_count += 1
            mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                mrp_json, 
                tokenized_parse_nodes=parse_json['nodes'],
                alignment=alignment,
            )
            companion_parser_states, companion_meta_data = mrp_json2parser_states(
                parse_json, 
                mrp_doc=doc,
                tokenized_parse_nodes=parse_json['nodes'],
            )

            data_instance = {
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'mrp_parser_states': mrp_parser_states,
                'mrp_meta_data': mrp_meta_data,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            json_encoded_instance = json.dumps(data_instance)
            wf.write(json_encoded_instance + '\n')

In [570]:
[state[-1] for state in mrp_parser_states]

[[(0, True, 'R', [(0, False, 'In', 'In')])],
 [(0, True, 'R', [(0, False, 'In', 'In')]),
  (1, True, 'E', [(1, False, 'the', 'the')])],
 [(0, True, 'R', [(0, False, 'In', 'In')]),
  (1, True, 'E', [(1, False, 'the', 'the')]),
  (2, True, 'E', [(2, False, 'final', 'final')])],
 [(0, True, 'R', [(0, False, 'In', 'In')]),
  (1, True, 'E', [(1, False, 'the', 'the')]),
  (2, True, 'E', [(2, False, 'final', 'final')]),
  (3, True, 'C', [(3, False, 'minute', 'minute')])],
 [(31,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', 'In')]),
    (1, True, 'E', [(1, False, 'the', 'the')]),
    (2, True, 'E', [(2, False, 'final', 'final')]),
    (3, True, 'C', [(3, False, 'minute', 'minute')])])],
 [(31,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', 'In')]),
    (1, True, 'E', [(1, False, 'the', 'the')]),
    (2, True, 'E', [(2, False, 'final', 'final')]),
    (3, True, 'C', [(3, False, 'minute', 'minute')])]),
  (4, True, 'R', [(4, False, 'of', 'of')])],
 [(31,
   True,
   'E',
   [(0, 

In [571]:
mrp_meta_data[-1]

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2,
    'anchors': [{'from': 7, 'to': 12}],
    'label': 'final',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3,
    'anchors': [{'from': 13, 'to': 19}],
    'label': 'minute',
    'propagate_label': 'C'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31, 'propagate_label': 'E'},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      't

In [572]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [573]:
parse_json

{'id': '470004',
 'tops': [9],
 'nodes': [{'id': 0,
   'label': 'In',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['in', 'ADP', 'IN']},
  {'id': 1,
   'label': 'the',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 2,
   'label': 'final',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['final', 'ADJ', 'JJ']},
  {'id': 3,
   'label': 'minute',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['minute', 'NOUN', 'NN']},
  {'id': 4,
   'label': 'of',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['of', 'ADP', 'IN']},
  {'id': 5,
   'label': 'the',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 6,
   'label': 'game',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['game', 'NOUN', 'NN']},
  {'id': 7,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 8,
   'label': 'Johnson',
   'properties': ['lemma', 'up

In [574]:
[n['values'][2] for n in parse_json['nodes']]

['IN',
 'DT',
 'JJ',
 'NN',
 'IN',
 'DT',
 'NN',
 ',',
 'NNP',
 'VBD',
 'DT',
 'NN',
 'VBN',
 'IN',
 'NNPS',
 'NN',
 'NNP',
 'NNP',
 ',',
 'CC',
 'RB',
 'VBD',
 'CD',
 'JJ',
 'NNS',
 'WDT',
 'MD',
 'VB',
 'VBN',
 'DT',
 'NN',
 '.']

In [575]:
# Create train jsonl
if os.path.isfile(allennlp_train_output_file) and os.path.isfile(allennlp_test_output_file):
    logger.info('allennlp_train_output_file found, stop generation')
else:
    pass
if 1==1:
    data_size = 0
    with open(allennlp_train_output_file, 'w') as train_wf:
        with open(allennlp_test_output_file, 'w') as test_wf:
            for _, dataset, mrp_json in tqdm(mrp_dataset.mrp_json_generator(
                ignore_framework_set=ignore_framework_set,
                ignore_dataset_set=ignore_dataset_set,
            )):
                total_count += 1
                if data_size >= data_size_limit * 2:
                    break
                cid = mrp_json.get('id')
                doc = mrp_json.get('input')

                framework = mrp_json.get('framework')
                alignment = {}
                if framework == 'amr':
                    alignment = cid2alignment[cid]  
                parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

                if parse_json:
                    mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                        mrp_json, 
                        tokenized_parse_nodes=parse_json['nodes'],
                        alignment=alignment,
                    )
                    companion_parser_states, companion_meta_data = mrp_json2parser_states(
                        parse_json, 
                        mrp_doc=doc,
                        tokenized_parse_nodes=parse_json['nodes'],
                    )

                    # Continue if error
                    if not mrp_parser_states:
                        continue

                    data_size += 1
                    logger.info(data_size)
                    data_instance = {
                        'mrp_json': mrp_json,
                        'parse_json': parse_json,
                        'mrp_parser_states': mrp_parser_states,
                        'mrp_meta_data': mrp_meta_data,
                        'companion_parser_states': companion_parser_states,
                        'companion_meta_data': companion_meta_data,
                    }
                    json_encoded_instance = json.dumps(data_instance)
                    if data_size <= data_size_limit:
                        train_wf.write(json_encoded_instance + '\n')
                    else:
                        test_wf.write(json_encoded_instance + '\n')

INFO     [__main__:3] allennlp_train_output_file found, stop generation

0it [00:00, ?it/s][AINFO     [__main__:43] 1

1it [00:00,  7.08it/s][AINFO     [__main__:43] 2

2it [00:00,  5.59it/s][AINFO     [__main__:43] 3

3it [00:01,  3.19it/s][A
INFO     [__main__:43] 4

6it [00:01,  3.98it/s][AINFO     [__main__:43] 5
INFO     [__main__:43] 6

8it [00:01,  5.06it/s][AINFO     [__main__:43] 7

9it [00:01,  5.54it/s][A
10it [00:02,  4.68it/s][AINFO     [__main__:43] 8

12it [00:02,  4.05it/s][AINFO     [__main__:43] 9

INFO     [__main__:43] 10
INFO     [__main__:43] 11

17it [00:03,  4.70it/s][AINFO     [__main__:43] 12
INFO     [__main__:43] 13

19it [00:03,  5.92it/s][AINFO     [__main__:43] 14

21it [00:03,  7.21it/s][AINFO     [__main__:43] 15

23it [00:04,  5.99it/s][A
25it [00:04,  7.42it/s][AINFO     [__main__:43] 16

27it [00:04,  7.86it/s][A
29it [00:05,  4.95it/s][AINFO     [__main__:43] 17

31it [00:05,  6.34it/s][AINFO     [__main__:43] 18
INFO     [__main__:

### Test allennlp dataset reader

In [576]:
import torch.optim as optim

from mrp_library.dataset_readers.mrp_jsons import MRPDatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.feedforward import FeedForward

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer

import json
import logging
from typing import Dict

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.models import Model
from overrides import overrides

In [577]:
from mrp_library.dataset_readers.mrp_jsons_actions import MRPDatasetActionReader

In [578]:
reader = MRPDatasetActionReader()

In [581]:
train_dataset = reader.read(cached_path(allennlp_train_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-train.jsonl
5594it [00:00, 6983.68it/s]


In [582]:
# test_dataset = reader.read(cached_path(allennlp_train_output_file))
test_dataset = reader.read(cached_path(allennlp_test_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-test.jsonl
5434it [00:05, 925.59it/s] 


In [583]:
tests_fixtures_dataset = reader.read(cached_path(allennlp_tests_fixtures_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/src/tests/fixtures/test.jsonl
740it [00:00, 6766.56it/s]


In [584]:
vocab = Vocabulary.from_instances(train_dataset + test_dataset + tests_fixtures_dataset)

INFO     [allennlp.data.vocabulary:396] Fitting token dictionary from dataset.
100%|██████████| 11768/11768 [00:01<00:00, 8977.71it/s]


In [585]:
vocab.print_statistics()

INFO     [allennlp.data.vocabulary:664] Printed vocabulary statistics are only for the part of the vocabulary generated from instances. If vocabulary is constructed by extending saved vocabulary with dataset instances, the directly loaded portion won't be considered here.




----Vocabulary Statistics----


Top 10 most frequent tokens in namespace 'word':
	Token: <START-WORD>		Frequency: 117680
	Token: <END-WORD>		Frequency: 117680
	Token: ,		Frequency: 40278
	Token: the		Frequency: 38906
	Token: .		Frequency: 23536
	Token: and		Frequency: 16661
	Token: in		Frequency: 16416
	Token: of		Frequency: 14534
	Token: a		Frequency: 13375
	Token: to		Frequency: 11905

Top 10 longest tokens in namespace 'word':
	Token: confrontational		length: 15	Frequency: 148
	Token: accomplishments		length: 15	Frequency: 65
	Token: accomplishment		length: 14	Frequency: 65
	Token: collaboration		length: 13	Frequency: 136
	Token: collaborative		length: 13	Frequency: 118
	Token: championships		length: 13	Frequency: 106
	Token: philosophical		length: 13	Frequency: 84
	Token: contributions		length: 13	Frequency: 59
	Token: circumstances		length: 13	Frequency: 46
	Token: photographing		length: 13	Frequency: 45

Top 10 shortest tokens in namespace 'word':
	Token: w		length: 1	Frequency

In [586]:
vocab.get_vocab_size('token_node_label')

1496

In [587]:
vocab.get_vocab_size('word')

1941

In [588]:
vocab.get_vocab_size('pos')

59

In [589]:
vocab.get_vocab_size('label')

2

In [590]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 50

### Test model

In [591]:
from mrp_library.models.generalizer import ActionGeneralizer
from allennlp.nn import InitializerApplicator, RegularizerApplicator, util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper import PytorchSeq2VecWrapper

In [592]:
field_types = ['word', 'pos', 'resolved', 'token_node_label', 'token_node_prev_action']
field_type2embedder = {}
field_type2seq2vec_encoder = {}
field_type2seq2seq_encoder = {}

for field_type in field_types:
    embedding = Embedding(num_embeddings=vocab.get_vocab_size(field_type),
                            embedding_dim=EMBEDDING_DIM)
    embedder = BasicTextFieldEmbedder({field_type: embedding})
    field_type2embedder[field_type] = embedder
    
    field_type2seq2vec_encoder[field_type] = PytorchSeq2VecWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
    field_type2seq2seq_encoder[field_type] = PytorchSeq2SeqWrapper(
        torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

In [593]:
# word_embedding = Embedding(num_embeddings=vocab.get_vocab_size('word'),
#                             embedding_dim=EMBEDDING_DIM)
# pos_embedding = Embedding(num_embeddings=vocab.get_vocab_size('pos'),
#                             embedding_dim=EMBEDDING_DIM)

# word_embedder = BasicTextFieldEmbedder({
#     "word": word_embedding,
#     "pos": pos_embedding,
# })
# parse_label = {
#     'word': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     ),
#     'pos': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     )
# }

In [594]:
# embedded_parse_label = word_embedder(parse_label)

In [595]:
# embedded_parse_label.shape

In [596]:
classifier_params = Params({
  "input_dim": HIDDEN_DIM * 3,
  "num_layers": 2,
  "hidden_dims": [50, 3],
  "activations": ["sigmoid", "linear"],
  "dropout": [0.0, 0.0]
})

In [597]:
classifier_feedforward = FeedForward.from_params(classifier_params)

INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.modules.feedforward.FeedForward'> from params {'input_dim': 150, 'num_layers': 2, 'hidden_dims': [50, 3], 'activations': ['sigmoid', 'linear'], 'dropout': [0.0, 0.0]} and extras set()
INFO     [allennlp.common.params:252] input_dim = 150
INFO     [allennlp.common.params:252] num_layers = 2
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params ['sigmoid', 'linear'] and extras set()
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params sigmoid and extras set()
INFO     [allennlp.common.params:252] type = sigmoid
INFO

In [598]:
field_type = 'word'

In [599]:
parse_label = {
    field_type: torch.LongTensor(
        [
            [ 1,  0,  3,  7,  2,  9,  4],
            [ 0,  0,  5,  0,  0,  0,  4]
        ]
    )
}
embedded_parse_label = field_type2embedder[field_type](parse_label)

In [600]:
feature_mask = util.get_text_field_mask(parse_label)

In [601]:
seq2vec_encoder = field_type2seq2vec_encoder[field_type]

In [602]:
encoded_feature = seq2vec_encoder(embedded_parse_label, feature_mask)

In [603]:
encoded_features = [encoded_feature] * 3

In [604]:
torch.cat(encoded_features, dim=-1).shape

torch.Size([2, 150])

In [605]:
logits = classifier_feedforward(torch.cat(encoded_features, dim=-1))

In [606]:
logits.shape

torch.Size([2, 3])

In [607]:
label = torch.tensor([1, 0])

In [608]:
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits, label)

In [619]:
from mrp_library.models.generalizer import ActionGeneralizer

In [620]:
model = ActionGeneralizer(
    vocab=vocab,
    field_type2embedder=field_type2embedder,
    field_type2seq2vec_encoder=field_type2seq2vec_encoder,
    field_type2seq2seq_encoder=field_type2seq2seq_encoder,
    classifier_feedforward=classifier_feedforward
)

iterator = BucketIterator(batch_size=20, sorting_keys=[("token_node_resolveds", "num_tokens")])
iterator.index_with(vocab)

optimizer = optim.SGD(model.parameters(), lr=0.1)
cuda_device = -1

INFO     [allennlp.nn.initializers:293] Initializing parameters
INFO     [allennlp.nn.initializers:309] Done initializing parameters; the following parameters are using their default initialization from their code
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.0.bias
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.0.weight
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.1.bias
INFO     [allennlp.nn.initializers:314]    classifier_feedforward._linear_layers.1.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.pos.token_embedder_pos.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.resolved.token_embedder_resolved.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.token_node_label.token_embedder_token_node_label.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.token_node_prev_action.token_embedder_token_node_pre

In [621]:
# list(model.named_parameters())

In [622]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=test_dataset,
#     train_dataset=train_dataset,
#     validation_dataset=train_dataset,
    patience=10,
    num_epochs=20,
    cuda_device=cuda_device
)

In [639]:
logits = torch.tensor([[-2.2126,  2.6022, -1.1655],
        [ 4.7340, -1.9992, -3.4521],
        [-1.9665,  2.4100, -1.2047],
        [-2.1353,  2.4847, -1.1260],
        [ 4.7492, -2.0234, -3.4460],
        [-1.4369,  1.9822, -1.2885],
        [ 1.0337,  0.3599, -2.0420],
        [ 5.0974, -2.3380, -3.4647],
        [ 5.4187, -2.4720, -3.6469],
        [-3.4045,  2.4903,  0.0773],
        [ 0.6384,  0.6764, -1.9942],
        [-2.2904,  2.7170, -1.2016],
        [ 4.6333, -1.9474, -3.4113],
        [-2.0811,  2.5367, -1.2174],
        [ 5.1840, -2.7536, -3.1499],
        [ 4.7421, -2.0138, -3.4485],
        [ 5.1121, -2.2290, -3.5999],
        [ 0.1843,  0.8324, -1.6990],
        [ 5.3854, -2.4593, -3.6309],
        [ 0.0324,  0.6715, -1.4173]])

In [640]:
values, indices = logits.max(1)

In [641]:
indices.eq_(torch.tensor(2))

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [627]:
trainer.train()

INFO     [allennlp.training.trainer:465] Beginning training.
INFO     [allennlp.training.trainer:281] Epoch 0/19
INFO     [allennlp.training.trainer:283] Peak CPU memory usage MB: 5972.256
INFO     [allennlp.training.trainer:287] GPU 0 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 1 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 2 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 3 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 4 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 5 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 6 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 7 memory usage MB: 10
INFO     [allennlp.training.trainer:311] Training


  0%|          | 0/280 [00:00<?, ?it/s][A[ADEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'parse_node_labels': {'word_length': 62, 'num_tokens': 62}, 'parse_node_lemmas': {'word_length': 62, 

DEBUG    [mrp_library.models.generalizer:98] ('embedded_fields', torch.Size([20, 20, 100]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 20]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 20]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 5]))
DEBUG    [mrp_library.models.generalizer:113] ('output_dict', {'logits': tensor([[ 5.5208, -2.7518, -3.4647],
        [-2.1689,  2.5462, -1.1532],
        [ 4.9926, -2.3420, -3.3720],
        [ 2.2831, -0.5686, -2.3395],
        [ 1.2447, -0.1198, -1.7743],
        [-0.8317,  1.4840, -1.3811],
        [ 5.0900, -2.4140, -3.3918],
        [ 5.2996, -2.5520, -3.4524],
        [ 5.1521, -2.4579, -3.4054],
        [ 5.5662, -2.8184, -3.4478],
        [ 4.9181, -2.2623, -3.3764],
        [ 5.2031, -2.5019, -3.4101],
        [ 0.9622, -0.0400, -1.5898],
        [ 0.4878,  0.5942, -1.7593],
        [ 1.2686, -0.0555, -1.8637],
        [ 5.3026, -2.5578, -3.44

DEBUG    [mrp_library.models.generalizer:79] ('token_states', [[0, True, 'L', [[0, False, 'At', []]]], [1, True, 'F', [[1, False, 'the', []]]], [22, True, 'P', [[21, True, 'A', [[2, True, 'A', [[2, False, 'Country', []]]], [3, True, 'A', [[3, False, 'Disc', []]]]]], [5, True, 'C', [[5, False, 'Convention', []]]]]], [6, True, 'R', [[6, False, 'in', []]]], [7, False, 'early', []]])
DEBUG    [mrp_library.models.generalizer:80] ('token_node_resolveds', {'resolved': tensor([[2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3],
        [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 3]])})
DEBUG    [mrp_library.models.ge

DEBUG    [mrp_library.models.generalizer:82] ('token_node_prev_actions', {'token_node_prev_action': tensor([[2, 3, 2, 2, 2],
        [2, 3, 2, 3, 2],
        [3, 2, 3, 2, 2],
        [2, 3, 2, 3, 2],
        [2, 3, 2, 3, 2],
        [2, 3, 2, 3, 2],
        [3, 2, 3, 2, 3],
        [2, 3, 2, 3, 2],
        [3, 2, 3, 2, 2],
        [2, 3, 2, 3, 2],
        [3, 2, 3, 2, 2],
        [2, 3, 2, 3, 2],
        [3, 2, 3, 2, 3],
        [3, 2, 3, 2, 2],
        [3, 2, 3, 2, 2],
        [2, 3, 2, 3, 2],
        [2, 3, 2, 3, 2],
        [2, 3, 2, 3, 2],
        [2, 3, 2, 3, 2],
        [2, 2, 3, 2, 3]])})
DEBUG    [mrp_library.models.generalizer:98] ('embedded_fields', torch.Size([20, 8, 100]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 8]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 8]))
DEBUG    [mrp_library.models.generalizer:105] ('field_mask', torch.Size([20, 5]))
DEBUG    [mrp_library.models.generalizer:113] ('output_dict'

DEBUG    [mrp_library.models.generalizer:117] ('label', tensor([1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0]))


accuracy: 0.8844, loss: 0.2872 ||:   2%|▏         | 5/280 [00:00<01:46,  2.57it/s][A[ADEBUG    [allennlp.data.iterators.data_iterator:151] Batch padding lengths: {'parse_node_labels': {'word_length': 62, 'num_tokens': 62}, 'parse_node_lemmas': {'word_length': 62, 'num_tokens': 62}, 'parse_node_uposs': {'pos_length': 62, 'num_tokens': 62}, 'parse_node_xposs': {'pos_length': 62, 'num_tokens': 62}, 'token_node_resolveds': {'resolved_length': 12, 'num_tokens': 12}, 'token_node_labels': {'token_node_label_length': 12, 'num_tokens': 12}, 'token_node_prev_actions': {'token_node_prev_action_length': 5, 'num_tokens': 5}}
DEBUG    [allennlp.data.iterators.data_iterator:152] Batch size: 20
DEBUG    [mrp_library.models.generalizer:76] ('curr_node_id', 36)
DEBUG    [mrp_library.models.generalizer:77] ('action_types', tensor(0))
DEBUG    [mrp_library.models.generalizer:78

KeyboardInterrupt: 

In [None]:
token_state = [[26, True, 'H', [[23, True, 'A', [[0, True, 'E', [[0, False, 'The', []]]], [1, True, 'C', [[1, False, 'Lakers', []]]]]], [2, True, 'P', [[2, False, 'advanced', []]]], [25, True, 'A', [[3, True, 'R', [[3, False, 'through', []]]], [4, True, 'E', [[4, False, 'the', []]]], [24, True, 'P', [[5, True, 'T', [[5, False, '1982', []]]], [6, True, 'C', [[6, False, 'playoffs', []]]]]]]]]], [7, True, 'L', [[7, False, 'and', []]]], [8, True, 'P', [[8, False, 'faced', []]]], [9, True, 'A', [[9, False, 'Philadelphia', []]]], [27, True, 'D', [[10, True, 'R', [[10, False, 'for', []]]], [11, True, 'E', [[11, False, 'the', []]]], [12, True, 'Q', [[12, False, 'second', []]]], [13, True, 'C', [[13, False, 'time', []]]]]], [14, True, 'R', [[14, False, 'in', []]]], [15, True, 'Q', [[15, False, 'three', []]]], [16, False, 'years', []]]

In [None]:
pprint.pprint(token_state)

In [154]:
vocab.get_token_from_index(0, namespace='labels')

'RESOLVE'

In [155]:
vocab.get_token_from_index(3, namespace='resolved')

'UNRESOLVED'