In [2]:
try:
    __IPYTHON__
    USING_IPYTHON = True
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Argparse

In [333]:
import argparse
# Command-line interface of the notebook/script. `project_root` is the only
# required argument; every other path option is joined onto it by later cells.
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--mrp-test-dir', default='src/tests', help='')
ap.add_argument('--tests-fixtures-file-template', default='fixtures/{}-test.jsonl', help='')

ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')

ap.add_argument('--evaluation-allennlp-mrp-json-file', default='evaluation/allennlp-evaluation.mrp')
ap.add_argument('--test-input-file', default='evaluation/input.mrp', help='')
ap.add_argument('--test-companion-file', default='evaluation/udpipe.mrp', help='')
ap.add_argument('--allennlp-mrp-json-file-template', default='allennlp-mrp-json-big-{}-{}.jsonl', help='')
ap.add_argument('--data-size-limit', type=int, default=10000, help='')

ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
ap.add_argument('--parse-plot-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png')

ap.add_argument('--cuda-device', type=int, default=0)

# Arguments used when the notebook runs interactively: whitespace-separated
# tokens that may be spread over several lines.
arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
# str.split() without a separator splits on any run of whitespace (newlines
# included) and drops empty tokens. The previous split(r'\\n') searched for a
# literal backslash-backslash-n sequence that never occurs in arg_string, so it
# was a no-op and the result only came out right by accident.
arguments = arg_string.split()

In [334]:
# Under IPython parse the notebook-defined `arguments`; otherwise parse the real
# command line (parse_args(None) reads sys.argv[1:], exactly like parse_args()).
args = ap.parse_args(arguments if USING_IPYTHON else None)

In [335]:
args

Namespace(allennlp_mrp_json_file_template='allennlp-mrp-json-big-{}-{}.jsonl', companion_file_extension='.conllu', companion_sub_dir='companion', cuda_device=0, data_size_limit=10000, evaluation_allennlp_mrp_json_file='evaluation/allennlp-evaluation.mrp', graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', mrp_test_dir='src/tests', parse_plot_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png', project_root='/data/proj29_ds1/home/slai/mrp2019', test_companion_file='evaluation/udpipe.mrp', test_input_file='evaluation/input.mrp', tests_fixtures_file_template='fixtures/{}-test.jsonl', train_sub_dir='training')

#### Library imports

In [6]:
import json
import logging
import os
import pprint
import re
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
import torch
from action_state import mrp_json2parser_states, _generate_parser_action_states
from action_state import ERROR, APPEND, RESOLVE, IGNORE
from preprocessing import (CompanionParseDataset, MrpDataset, JamrAlignmentDataset,
                           read_companion_parse_json_file, read_mrp_json_file, parse2parse_json)            
from torch import nn
from tqdm import tqdm

#### IPython notebook-specific imports

In [7]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

In [8]:
# Configure the root logger with one stream handler at DEBUG level, using a
# "LEVEL [logger:line] message" record layout.
formatter = logging.Formatter('%(levelname)-8s [%(name)s:%(lineno)d] %(message)s')
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logging.basicConfig(handlers=[sh], level=logging.DEBUG)

# Loggers listed here are set to INFO so that their DEBUG records are dropped.
mute_logger_names = ['allennlp.data.iterators.data_iterator']
for logger_name in mute_logger_names:
    logging.getLogger(logger_name).setLevel(logging.INFO)

# Logger used by the notebook cells themselves, also limited to INFO.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [9]:
# NOTE(review): "UNKWOWN" looks like a misspelling of "UNKNOWN" in both the name
# and the value. Left untouched because code or data outside this view may
# depend on this exact identifier and string — confirm before renaming.
UNKWOWN = 'UNKWOWN'

### Load data

In [10]:
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [11]:
mrp_dataset = MrpDataset()

In [12]:
# Load the training graphs from `train_dir`; the second argument is the MRP
# file extension ('.mrp' by default).
# Returns the list of framework names and a nested mapping that later cells
# index as framework -> dataset -> position to obtain one MRP JSON object.
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.76it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:02,  1.44it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:05<00:05,  1.80s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:10<00:05,  2.74s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:15<00:03,  3.45s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  43%|████▎     | 6/14 [00:00<00:00, 21.08it/s][A
dataset_name:  57%|█████▋    | 8/14 [00:00<00:00, 17.83it/s][A
dataset_name:  71%|███████▏  | 10/14 [00:01<00:00,  6.22it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.57it/s][A
frameworks: 100%|██████████| 5/5 [00:16<00:00,  2.88s/it]t/s][A


In [13]:
framework2dataset2mrp_jsons.keys()

dict_keys(['ucca', 'psd', 'eds', 'dm', 'amr'])

### Preprocess companion data

In [14]:
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [15]:
cparse_dataset = CompanionParseDataset()

In [16]:
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

INFO     [preprocessing:179] framework amr found
dataset: 100%|██████████| 13/13 [00:01<00:00,  9.64it/s]
INFO     [preprocessing:179] framework dm found
dataset: 100%|██████████| 5/5 [00:03<00:00,  1.30it/s]
INFO     [preprocessing:179] framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 28.48it/s]


In [17]:
dataset2cid2parse_json = cparse_dataset.convert_parse2parse_json()

In [18]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [19]:
# Some data is missing
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [20]:
jalignment_dataset = JamrAlignmentDataset()

In [21]:
cid2alignment = jalignment_dataset.load_jamr_alignment_file(os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file
))

### Load testing data

In [22]:
test_input_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_input_file)
test_companion_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_companion_file)

In [23]:
test_mrp_jsons = read_mrp_json_file(test_input_filename)
test_parse_jsons = read_companion_parse_json_file(test_companion_filename)

In [24]:
parse_json = test_parse_jsons['102990']

In [25]:
mrp_json = framework2dataset2mrp_jsons['psd']['wsj'][1]

In [26]:
# (framework, dataset, index) triples for manual inspection; the first one picks
# the example graph used by the following cells.
test_configs = [('ucca', 'wiki', 70)]
framework, dataset, idx = test_configs[0]

In [27]:
mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
cid = mrp_json.get('id')

In [28]:
parse_json = dataset2cid2parse_json[dataset][cid]

In [29]:
doc = mrp_json['input']

In [30]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [31]:
# Align every companion-parse token with its character span in `doc`.
#   anchors[i]                          -> (start, end) offsets of token i
#   char_pos2tokenized_parse_node_id[p] -> index of the token assigned to character p
# Token labels are assumed to appear in `doc` in order, separated only by spaces.
token_pos = 0
anchors = []
char_pos2tokenized_parse_node_id = []

for node_id, node in enumerate(parse_json.get('nodes')):
    label = node.get('label')
    label_size = len(label)
    # Spaces in front of a token are attributed to that upcoming token, which
    # keeps the character -> token table gap-free.
    while doc[token_pos] == ' ':
        char_pos2tokenized_parse_node_id.append(node_id)
        token_pos += 1
    token_end = token_pos + label_size
    anchors.append((token_pos, token_end))
    char_pos2tokenized_parse_node_id += [node_id] * label_size
    print(node_id, doc[token_pos:token_end], anchors[-1], len(char_pos2tokenized_parse_node_id))
    token_pos = token_end

0 In (0, 2) 2
1 the (3, 6) 6
2 final (7, 12) 12
3 minute (13, 19) 19
4 of (20, 22) 22
5 the (23, 26) 26
6 game (27, 31) 31
7 , (31, 32) 32
8 Johnson (33, 40) 40
9 had (41, 44) 44
10 the (45, 48) 48
11 ball (49, 53) 53
12 stolen (54, 60) 60
13 by (61, 63) 63
14 Celtics (64, 71) 71
15 center (72, 78) 78
16 Robert (79, 85) 85
17 Parish (86, 92) 92
18 , (92, 93) 93
19 and (94, 97) 97
20 then (98, 102) 102
21 missed (103, 109) 109
22 two (110, 113) 113
23 free (114, 118) 118
24 throws (119, 125) 125
25 that (126, 130) 130
26 could (131, 136) 136
27 have (137, 141) 141
28 won (142, 145) 145
29 the (146, 149) 149
30 game (150, 154) 154
31 . (154, 155) 155


In [32]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [33]:
len(char_pos2tokenized_parse_node_id)

155

In [34]:
doc = mrp_json['input']

In [35]:
mrp_json['tops']

[34]

In [36]:
mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
    mrp_json, 
    tokenized_parse_nodes=parse_json['nodes'],
)

In [37]:
# Unpack the metadata returned by mrp_json2parser_states into one name per
# field; the unpacking requires exactly twelve values.
# Note that this rebinds the notebook-level `doc`.
# The commented-out names are not unpacked — presumably fields that are no
# longer part of the returned metadata (TODO confirm against action_state).
(
    doc,
    nodes,
    node_id2node,
    edge_id2edge,
    top_oriented_edges,
    token_nodes,
    # abstract_node_id_set,
    parent_id2indegree,
    # parent_id2child_id_set,
    # child_id2parent_id_set,
    # child_id2edge_id_set,
    # parent_id2edge_id_set,
    # parent_child_id2edge_id_set,
    parse_nodes_anchors,
    char_pos2tokenized_node_id,
    curr_node_ids,
    token_states,
    actions,
) = mrp_meta_data

In [38]:
# Take the last three metadata fields with a single slice.
curr_node_ids, token_states, actions = mrp_meta_data[-3:]

In [39]:
*_, curr_node_ids, token_states, actions = mrp_meta_data

In [40]:
actions[1]

(1,
 (1,
  0,
  {'id': 0,
   'anchors': [{'from': 0, 'to': 2}],
   'label': 'In',
   'propagate_label': 'R'},
  [[]]))

In [41]:
actions[-4:]

[(1,
  (3,
   2,
   {'id': 41, 'propagate_label': 'A'},
   [[{'source': 41,
      'target': 28,
      'label': 'E',
      'id': 15,
      'parent': 41,
      'child': 28}],
    [{'source': 41,
      'target': 29,
      'label': 'C',
      'id': 24,
      'parent': 41,
      'child': 29}],
    [{'source': 41,
      'target': 30,
      'label': 'U',
      'id': 17,
      'parent': 41,
      'child': 30}]])),
 (1,
  (4,
   3,
   {'id': 40, 'propagate_label': 'H'},
   [[{'source': 40,
      'target': 25,
      'label': 'D',
      'id': 40,
      'parent': 40,
      'child': 25}],
    [{'source': 40,
      'target': 26,
      'label': 'F',
      'id': 32,
      'parent': 40,
      'child': 26}],
    [{'source': 40,
      'target': 27,
      'label': 'P',
      'id': 2,
      'parent': 40,
      'child': 27}],
    [{'source': 40,
      'target': 41,
      'label': 'A',
      'id': 16,
      'parent': 40,
      'child': 41}]])),
 (1,
  (3,
   2,
   {'id': 39, 'propagate_label': 'H'},
   [[{'s

In [42]:
# Walk the parse step by step and show, for each step, the current node id, the
# action type and the token state with every group cut to its first five fields.
for curr_node_id, action, token_state in zip(curr_node_ids, actions, token_states):
    action_type, params = action
    truncated_groups = [token_group[:5] for token_group in token_state]
    pprint.pprint((curr_node_id, action_type, truncated_groups))

(0, 0, [(0, False, 'In', [])])
(1, 1, [(0, True, 'R', [(0, False, 'In', [])])])
(1, 0, [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(2,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(3,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 1,
 [(32,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', [])]),
    (1, True, 'E', [(1, False, 'the', [])]),
    (2, True, 'E', [(

     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [])])])]),
  (18, True, 'U', [(18, False, ',', [])]),
  (19, True, 'L', [(19, False, 'and', [])]),
  (20, True, 'L', [(20, False, 'then', [])]),
  (21, True, 'D', [(21, False, 'missed', [])]),
  (22, True, 'D', [(22, False, 'two', [])]),
  (23, True, 'D', [(23, False, 'free', [])]),
  (24, False, 'throws', [])])
(25,
 1,
 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False

In [43]:
# Same walk, but each action is paired with the token state that preceded it:
# an empty state is prepended so the first action lines up with "nothing yet".
# Also shows the first two action parameters ([None] when there are none).
previous_token_states = [[]] + token_states
for curr_node_id, action, token_state in zip(curr_node_ids, actions, previous_token_states):
    action_type, params = action
    param_head = (params or [None])[:2]
    state_preview = [token_group[:4] for token_group in token_state]
    pprint.pprint((curr_node_id, action_type, param_head, state_preview))
    

(0, 0, [None], [])
(1, 1, (1, 0), [(0, False, 'In', [])])
(1, 0, [None], [(0, True, 'R', [(0, False, 'In', [])])])
(2, 1, (1, 0), [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 0,
 [None],
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(3,
 1,
 (1, 0),
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 0,
 [None],
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(4,
 1,
 (1, 0),
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 (4, 3),
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 0,
 [None],
 [(32,
   True,
   'E',
   [(0, True,

     [(0, True, 'R', [(0, False, 'In', [])]),
      (1, True, 'E', [(1, False, 'the', [])]),
      (2, True, 'E', [(2, False, 'final', [])]),
      (3, True, 'C', [(3, False, 'minute', [])])]),
    (4, True, 'R', [(4, False, 'of', [])]),
    (5, True, 'E', [(5, False, 'the', [])]),
    (6, True, 'C', [(6, False, 'game', [])])]),
  (7, True, 'U', [(7, False, ',', [])]),
  (8, True, 'A', [(8, False, 'Johnson', [])]),
  (9, True, 'F', [(9, False, 'had', [])]),
  (34,
   True,
   'A',
   [(10, True, 'E', [(10, False, 'the', [])]),
    (11, True, 'C', [(11, False, 'ball', [])])]),
  (12, True, 'P', [(12, False, 'stolen', [])]),
  (13, False, 'by', [])])
(14,
 0,
 [None],
 [(33,
   True,
   'T',
   [(32,
     True,
     'E',
     [(0, True, 'R', [(0, False, 'In', [])]),
      (1, True, 'E', [(1, False, 'the', [])]),
      (2, True, 'E', [(2, False, 'final', [])]),
      (3, True, 'C', [(3, False, 'minute', [])])]),
    (4, True, 'R', [(4, False, 'of', [])]),
    (5, True, 'E', [(5, False, 't

 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9, False, 'had', [])]),
    (34,
     True,
     'A',
     [(10, True, 'E', [(10, False, 'the', [])]),
      (11, True, 'C', [(11, False, 'ball', [])])]),
    (12, True, 'P', [(12, False, 'stolen', [])]),
    (36,
     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', 

    (22, True, 'D', [(22, False, 'two', [])]),
    (23, True, 'D', [(23, False, 'free', [])]),
    (24, True, 'P', [(24, False, 'throws', [])])]),
  (25, True, 'L', [(25, False, 'that', [])]),
  (26, True, 'D', [(26, False, 'could', [])]),
  (27, True, 'F', [(27, False, 'have', [])]),
  (28, True, 'P', [(28, False, 'won', [])]),
  (29, True, 'E', [(29, False, 'the', [])]),
  (30, True, 'C', [(30, False, 'game', [])]),
  (31, False, '.', [])])
(32,
 1,
 (3, 2),
 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, 

In [44]:
actions

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2,
    'anchors': [{'from': 7, 'to': 12}],
    'label': 'final',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3,
    'anchors': [{'from': 13, 'to': 19}],
    'label': 'minute',
    'propagate_label': 'C'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31, 'propagate_label': 'E'},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      't

In [45]:
token_states[1]

[(0, True, 'R', [(0, False, 'In', [])])]

In [46]:
[n['label'] for n in parse_json['nodes']]

['In',
 'the',
 'final',
 'minute',
 'of',
 'the',
 'game',
 ',',
 'Johnson',
 'had',
 'the',
 'ball',
 'stolen',
 'by',
 'Celtics',
 'center',
 'Robert',
 'Parish',
 ',',
 'and',
 'then',
 'missed',
 'two',
 'free',
 'throws',
 'that',
 'could',
 'have',
 'won',
 'the',
 'game',
 '.']

In [47]:
token_states[-1]

[(42,
  True,
  '<UCCA-TOP-NODE>',
  [(37,
    True,
    'H',
    [(33,
      True,
      'T',
      [(32,
        True,
        'E',
        [(0, True, 'R', [(0, False, 'In', [])]),
         (1, True, 'E', [(1, False, 'the', [])]),
         (2, True, 'E', [(2, False, 'final', [])]),
         (3, True, 'C', [(3, False, 'minute', [])])]),
       (4, True, 'R', [(4, False, 'of', [])]),
       (5, True, 'E', [(5, False, 'the', [])]),
       (6, True, 'C', [(6, False, 'game', [])])]),
     (7, True, 'U', [(7, False, ',', [])]),
     (8, True, 'A', [(8, False, 'Johnson', [])]),
     (9, True, 'F', [(9, False, 'had', [])]),
     (34,
      True,
      'A',
      [(10, True, 'E', [(10, False, 'the', [])]),
       (11, True, 'C', [(11, False, 'ball', [])])]),
     (12, True, 'P', [(12, False, 'stolen', [])]),
     (36,
      True,
      'A',
      [(13, True, 'R', [(13, False, 'by', [])]),
       (35,
        True,
        'E',
        [(14, True, 'A', [(14, False, 'Celtics', [])]),
         (

In [48]:
companion_parser_states, companion_meta_data = mrp_json2parser_states(
    parse_json,
    mrp_doc=doc,
    tokenized_parse_nodes=parse_json['nodes'],
)

In [49]:
logger.info(args.graphviz_file_template.format(
    framework, dataset, cid))

INFO     [__main__:2] http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/ucca/wiki.mrp/470004.png


In [50]:
mrp_json['input']

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [51]:
mrp_parser_states

[(0,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 0,
      'anchors': [{'from': 0, 'to': 2}],
      'label': 'In',
      'propagate_label': 'R'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')])]),
 (1,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 1,
      'anchors': [{'from': 3, 'to': 6}],
      'label': 'the',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')])]),
 (2,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 2,
      'anchors': [{'from': 7, 'to': 12}],
      'label': 'final',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)]), (2, 2, [(2, 2, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')]),
   (2, True, 'E', [(2, False, 'final', 'final')])]),
 (3,
  

In [52]:
[(node['id'], node.get('label')) for node in mrp_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'RobertParish'),
 (17, ','),
 (18, 'and'),
 (19, 'then'),
 (20, 'missed'),
 (21, 'two'),
 (22, 'free'),
 (23, 'throws'),
 (24, 'that'),
 (25, 'could'),
 (26, 'have'),
 (27, 'won'),
 (28, 'the'),
 (29, 'game'),
 (30, '.'),
 (31, None),
 (32, None),
 (33, None),
 (34, None),
 (35, None),
 (36, None),
 (37, None),
 (38, None),
 (39, None),
 (40, None),
 (41, None)]

In [53]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [54]:
parse_json['nodes']

[{'id': 0,
  'label': 'In',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['in', 'ADP', 'IN']},
 {'id': 1,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 2,
  'label': 'final',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['final', 'ADJ', 'JJ']},
 {'id': 3,
  'label': 'minute',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['minute', 'NOUN', 'NN']},
 {'id': 4,
  'label': 'of',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['of', 'ADP', 'IN']},
 {'id': 5,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 6,
  'label': 'game',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['game', 'NOUN', 'NN']},
 {'id': 7,
  'label': ',',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': [',', 'PUNCT', ',']},
 {'id': 8,
  'label': 'Johnson',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['Johnson', 'PROPN', 'NNP']},
 {'id': 9,
  'label

In [55]:
[(node['id'], node['label']) for node in parse_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'Robert'),
 (17, 'Parish'),
 (18, ','),
 (19, 'and'),
 (20, 'then'),
 (21, 'missed'),
 (22, 'two'),
 (23, 'free'),
 (24, 'throws'),
 (25, 'that'),
 (26, 'could'),
 (27, 'have'),
 (28, 'won'),
 (29, 'the'),
 (30, 'game'),
 (31, '.')]

In [56]:
anchors

[(0, 2),
 (3, 6),
 (7, 12),
 (13, 19),
 (20, 22),
 (23, 26),
 (27, 31),
 (31, 32),
 (33, 40),
 (41, 44),
 (45, 48),
 (49, 53),
 (54, 60),
 (61, 63),
 (64, 71),
 (72, 78),
 (79, 85),
 (86, 92),
 (92, 93),
 (94, 97),
 (98, 102),
 (103, 109),
 (110, 113),
 (114, 118),
 (119, 125),
 (126, 130),
 (131, 136),
 (137, 141),
 (142, 145),
 (146, 149),
 (150, 154),
 (154, 155)]

### Create training instance

In [73]:
# Previously used configurations, kept for reference:
# framework = 'ucca'
# ignore_framework_set = {'amr', 'dm', 'psd', 'eds'}
# dataset = 'wiki'
# ignore_dataset_set = {}

# framework = 'dm'
# ignore_framework_set = {'amr', 'ucca', 'psd', 'eds'}
# dataset = 'wsj'
# ignore_dataset_set = {}

# Active configuration: frameworks/datasets listed in the ignore sets are left
# out of the generated training data.
framework = 'ucca'
ignore_framework_set = {'amr', 'psd', 'eds'}
dataset = 'wiki'
# `{}` is an empty dict literal, not a set; set() makes the value match its name
# and keeps membership tests working the same way.
ignore_dataset_set = set()

In [74]:
frameworks

['ucca', 'psd', 'eds', 'dm', 'amr']

In [75]:
# Tag used in the output file names: the frameworks that are not ignored, in
# their loaded order, joined with '-'.
selected_frameworks = [name for name in frameworks if name not in ignore_framework_set]
framework_names = '-'.join(selected_frameworks)
framework_names

'ucca-dm'

In [76]:
# Output locations, all rooted at the project directory and tagged with the
# selected framework names.
allennlp_tests_fixtures_output_file = os.path.join(
    args.project_root,
    args.mrp_test_dir,
    args.tests_fixtures_file_template.format(framework_names),
)

# The train and test files share one template and differ only in the split name.
allennlp_framework_train_output_file, allennlp_framework_test_output_file = (
    os.path.join(
        args.project_root,
        args.allennlp_mrp_json_file_template.format(framework_names, split_name),
    )
    for split_name in ('train', 'test')
)

In [77]:
# Create tests fixture jsonl
# Each entry is a (framework, dataset, index) triple; the single active entry is
# repeated five times, so the fixture file holds five copies of that instance.
fixture_combinations = [
#     ('ucca', 'wiki', 70),
    ('dm', 'wsj', 3)
] * 5

# NOTE: the loop rebinds the notebook-level names `framework`, `dataset`, `idx`,
# `mrp_json`, `cid`, `doc` and `parse_json`; later cells read those values.
with open(allennlp_tests_fixtures_output_file, 'w') as wf:
    for framework, dataset, idx in fixture_combinations:
        mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
        cid = mrp_json.get('id')
        doc = mrp_json.get('input')

        # An alignment is only looked up for AMR graphs; other frameworks pass
        # an empty mapping.
        alignment = {}
        if framework == 'amr':
            alignment = cid2alignment[cid]
        parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

        # Without a companion parse there is no tokenization to build states from.
        if not parse_json:
            continue

        mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
            mrp_json,
            tokenized_parse_nodes=parse_json['nodes'],
            alignment=alignment,
        )
        companion_parser_states, companion_meta_data = mrp_json2parser_states(
            parse_json,
            mrp_doc=doc,
            tokenized_parse_nodes=parse_json['nodes'],
        )

        # Skip instances whose conversion produced no parser states, the same
        # way the training export does, so no unusable fixture line is written.
        if not mrp_parser_states:
            continue

        data_instance = {
            'mrp_json': mrp_json,
            'parse_json': parse_json,
            'mrp_parser_states': mrp_parser_states,
            'mrp_meta_data': mrp_meta_data,
            'companion_parser_states': companion_parser_states,
            'companion_meta_data': companion_meta_data,
        }
        # One JSON document per line (JSON Lines).
        json_encoded_instance = json.dumps(data_instance)
        wf.write(json_encoded_instance + '\n')

In [78]:
# Print which of the first 20 graphs of the current framework/dataset have a
# companion parse available.
candidate_mrp_jsons = framework2dataset2mrp_jsons[framework][dataset]
for idx in range(20):
    mrp_json = candidate_mrp_jsons[idx]
    cid = mrp_json.get('id')
    if cid not in dataset2cid2parse[dataset]:
        continue
    print(idx)

3
8
13
18


In [79]:
[state[-1] for state in mrp_parser_states]

[[(0, True, 'the', [(0, False, 'the', 'The')])],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')])],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
  (2, False, 'fiber', 'fiber')],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
  (2, False, 'fiber', 'fiber'),
  (4, True, 'crocidolite', [(4, False, 'crocidolite', 'crocidolite')])],
 [(2,
   True,
   'fiber',
   [(0, True, 'the', [(0, False, 'the', 'The')]),
    (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
    (2, False, 'fiber', 'fiber'),
    (4, True, 'crocidolite', [(4, False, 'crocidolite', 'crocidolite')])])],
 [(2,
   True,
   'fiber',
   [(0, True, 'the', [(0, False, 'the', 'The')]),
    (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
    (2, False, 'fiber', 'fiber'),
    (4, True, 'crocidolite', [(4, False, 'crocidol

In [80]:
mrp_meta_data[-1]

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'label': 'the',
    'properties': ['pos', 'frame'],
    'values': ['DT', 'q:i-h-h'],
    'anchors': [{'from': 0, 'to': 3}]},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'label': 'asbestos',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 4, 'to': 12}]},
   [[]])),
 (0, None),
 (2, None),
 (0, None),
 (1,
  (1,
   0,
   {'id': 4,
    'label': 'crocidolite',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 20, 'to': 31}]},
   [[]])),
 (1,
  (4,
   2,
   {'id': 2,
    'label': 'fiber',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 13, 'to': 18}]},
   [[{'source': 0,
      'target': 2,
      'label': 'BV',
      'id': 2,
      'parent': 2,
      'child': 0}],
    [{'source': 1,
      'target': 2,
      'label': 'compound',
      'id': 1,
      'parent': 2,
      'child': 1}],
    [],
    [{'source': 4,
    

In [81]:
doc

'The asbestos fiber, crocidolite, is unusually resilient once it enters the lungs, with even brief exposures to it causing symptoms that show up decades later, researchers said.'

In [82]:
parse_json

{'id': '20003002',
 'tops': [30],
 'nodes': [{'id': 0,
   'label': 'The',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 1,
   'label': 'asbestos',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['asbestos', 'NOUN', 'NN']},
  {'id': 2,
   'label': 'fiber',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['fiber', 'NOUN', 'NN']},
  {'id': 3,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 4,
   'label': 'crocidolite',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['crocidolite', 'NOUN', 'NN']},
  {'id': 5,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 6,
   'label': 'is',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['be', 'VERB', 'VBZ']},
  {'id': 7,
   'label': 'unusually',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['unusually', 'ADV', 'RB']},
  {'id': 8,
   'label': 'resil

In [83]:
[n['values'][2] for n in parse_json['nodes']]

['DT',
 'NN',
 'NN',
 ',',
 'NN',
 ',',
 'VBZ',
 'RB',
 'JJ',
 'IN',
 'PRP',
 'VBZ',
 'DT',
 'NNS',
 ',',
 'IN',
 'RB',
 'JJ',
 'NNS',
 'TO',
 'PRP',
 'VBG',
 'NNS',
 'WDT',
 'VBP',
 'RP',
 'NNS',
 'RB',
 ',',
 'NNS',
 'VBD',
 '.']

In [84]:
# Create train jsonl
# Generation is skipped only when BOTH output files already exist. The previous
# condition tested the train file twice and never looked at the test file, so a
# missing test file went unnoticed.
if os.path.isfile(allennlp_framework_train_output_file) and os.path.isfile(
    allennlp_framework_test_output_file):
    logger.info('allennlp_train_output_file found, stop generation')
else:
    with open(allennlp_framework_train_output_file, 'w') as train_wf, \
            open(allennlp_framework_test_output_file, 'w') as test_wf:
        # Twice the size limit is requested from the generator; instances are
        # routed to the train or test file by `idx` further down.
        for _, dataset, idx, mrp_json in tqdm(mrp_dataset.mrp_json_generator(
            ignore_framework_set=ignore_framework_set,
            ignore_dataset_set=ignore_dataset_set,
            data_size_limit=args.data_size_limit * 2
        )):
            cid = mrp_json.get('id')
            doc = mrp_json.get('input')

            framework = mrp_json.get('framework')
            # An alignment is only looked up for AMR graphs; other frameworks
            # pass an empty mapping.
            alignment = {}
            if framework == 'amr':
                alignment = cid2alignment[cid]
            parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

            # Without a companion parse there is no tokenization to build states from.
            if not parse_json:
                continue

            mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                mrp_json,
                tokenized_parse_nodes=parse_json['nodes'],
                alignment=alignment,
            )
            companion_parser_states, companion_meta_data = mrp_json2parser_states(
                parse_json,
                mrp_doc=doc,
                tokenized_parse_nodes=parse_json['nodes'],
            )

            # Continue if error
            if not mrp_parser_states:
                continue

            data_instance = {
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'mrp_parser_states': mrp_parser_states,
                'mrp_meta_data': mrp_meta_data,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            # One JSON document per line; idx values up to the size limit go to
            # the train file, everything after that to the test file.
            json_encoded_instance = json.dumps(data_instance)
            if idx <= args.data_size_limit:
                train_wf.write(json_encoded_instance + '\n')
            else:
                test_wf.write(json_encoded_instance + '\n')

INFO     [__main__:4] allennlp_train_output_file found, stop generation


### Generate test MRP JSON data

In [345]:
# Evaluation output path: <project_root>/<mrp_data_dir>/<evaluation file>,
# all three components taken from the parsed command-line arguments.
evaluation_path_parts = (
    args.project_root,
    args.mrp_data_dir,
    args.evaluation_allennlp_mrp_json_file,
)
allennlp_evaluation_output_file = os.path.join(*evaluation_path_parts)
allennlp_evaluation_output_file

'/data/proj29_ds1/home/slai/mrp2019/data/evaluation/allennlp-evaluation.mrp'

In [348]:
# Write one evaluation instance per (input sentence, target framework) pair.
with open(allennlp_evaluation_output_file, 'w') as wf:
    for mrp_json in tqdm(test_mrp_jsons):
        cid = mrp_json.get('id')
        doc = mrp_json.get('input')
        frameworks = mrp_json.get('targets')
        parse_json = test_parse_jsons[cid]

        if not frameworks:
            # No target framework listed: there is nothing to emit for this
            # input, and iterating over a missing 'targets' would raise.
            logger.warning('no targets for %s, skipping', cid)
            continue

        # The companion parser states are built from the companion parse and
        # the raw document only; none of the arguments depend on the target
        # framework, so they are computed once per sentence instead of once
        # per framework.
        # NOTE(review): assumes mrp_json2parser_states does not mutate its
        # inputs between calls -- TODO confirm.
        companion_parser_states, companion_meta_data = mrp_json2parser_states(
            parse_json,
            mrp_doc=doc,
            tokenized_parse_nodes=parse_json['nodes'],
        )

        for framework in frameworks:
            data_instance = {
                'cid': cid,
                'doc': doc,
                'framework': framework,
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            json_encoded_instance = json.dumps(data_instance)
            wf.write(json_encoded_instance + '\n')



  0%|          | 0/6288 [00:00<?, ?it/s][A[A

  0%|          | 1/6288 [00:00<32:22,  3.24it/s][A[A

  0%|          | 2/6288 [00:00<25:51,  4.05it/s][A[A

  0%|          | 4/6288 [00:00<20:16,  5.17it/s][A[A

  0%|          | 5/6288 [00:00<17:48,  5.88it/s][A[A

  0%|          | 6/6288 [00:01<38:23,  2.73it/s][A[A

  0%|          | 7/6288 [00:01<30:34,  3.42it/s][A[A

  0%|          | 10/6288 [00:01<23:19,  4.49it/s][A[A

  0%|          | 11/6288 [00:01<21:15,  4.92it/s][A[A

  0%|          | 15/6288 [00:02<17:27,  5.99it/s][A[A

  0%|          | 16/6288 [00:02<18:05,  5.78it/s][A[A

  0%|          | 17/6288 [00:02<24:52,  4.20it/s][A[A

  0%|          | 19/6288 [00:02<19:08,  5.46it/s][A[A

  0%|          | 23/6288 [00:03<15:43,  6.64it/s][A[A

  0%|          | 26/6288 [00:03<13:04,  7.98it/s][A[A

  0%|          | 28/6288 [00:03<16:36,  6.28it/s][A[A

  0%|          | 30/6288 [00:04<14:20,  7.28it/s][A[A

  1%|          | 32/6288 [00:04<12:37,  8.26

  4%|▎         | 229/6288 [00:39<36:20,  2.78it/s][A[A

  4%|▎         | 230/6288 [00:39<28:50,  3.50it/s][A[A

  4%|▎         | 232/6288 [00:39<30:56,  3.26it/s][A[A

  4%|▎         | 233/6288 [00:40<31:04,  3.25it/s][A[A

  4%|▎         | 235/6288 [00:40<26:01,  3.88it/s][A[A

  4%|▍         | 236/6288 [00:40<25:10,  4.01it/s][A[A

  4%|▍         | 237/6288 [00:40<20:56,  4.81it/s][A[A

  4%|▍         | 240/6288 [00:41<17:04,  5.90it/s][A[A

  4%|▍         | 241/6288 [00:41<20:25,  4.94it/s][A[A

  4%|▍         | 243/6288 [00:41<23:29,  4.29it/s][A[A

  4%|▍         | 244/6288 [00:42<20:07,  5.01it/s][A[A

  4%|▍         | 245/6288 [00:42<21:10,  4.76it/s][A[A

  4%|▍         | 247/6288 [00:42<22:02,  4.57it/s][A[A

  4%|▍         | 249/6288 [00:43<21:03,  4.78it/s][A[A

  4%|▍         | 251/6288 [00:43<20:40,  4.87it/s][A[A

  4%|▍         | 254/6288 [00:43<16:20,  6.16it/s][A[A

  4%|▍         | 257/6288 [00:43<14:08,  7.11it/s][A[A

  4%|▍        

  7%|▋         | 446/6288 [01:17<15:24,  6.32it/s][A[A

  7%|▋         | 448/6288 [01:17<12:23,  7.85it/s][A[A

  7%|▋         | 450/6288 [01:17<12:38,  7.70it/s][A[A

  7%|▋         | 452/6288 [01:17<11:12,  8.67it/s][A[A

  7%|▋         | 454/6288 [01:17<10:41,  9.09it/s][A[A

  7%|▋         | 456/6288 [01:18<13:00,  7.47it/s][A[A

  7%|▋         | 457/6288 [01:18<13:30,  7.20it/s][A[A

  7%|▋         | 458/6288 [01:18<18:25,  5.27it/s][A[A

  7%|▋         | 459/6288 [01:19<19:52,  4.89it/s][A[A

  7%|▋         | 460/6288 [01:19<26:12,  3.71it/s][A[A

  7%|▋         | 461/6288 [01:19<26:39,  3.64it/s][A[A

  7%|▋         | 463/6288 [01:20<26:24,  3.68it/s][A[A

  7%|▋         | 464/6288 [01:20<31:52,  3.05it/s][A[A

  7%|▋         | 465/6288 [01:21<32:55,  2.95it/s][A[A

  7%|▋         | 466/6288 [01:21<28:57,  3.35it/s][A[A

  7%|▋         | 468/6288 [01:21<22:17,  4.35it/s][A[A

  7%|▋         | 469/6288 [01:21<20:08,  4.82it/s][A[A

  8%|▊        

 11%|█         | 661/6288 [01:55<21:28,  4.37it/s][A[A

 11%|█         | 662/6288 [01:55<18:30,  5.07it/s][A[A

 11%|█         | 663/6288 [01:55<25:01,  3.75it/s][A[A

 11%|█         | 665/6288 [01:56<19:08,  4.90it/s][A[A

 11%|█         | 667/6288 [01:56<16:02,  5.84it/s][A[A

 11%|█         | 668/6288 [01:56<14:46,  6.34it/s][A[A

 11%|█         | 669/6288 [01:56<14:15,  6.57it/s][A[A

 11%|█         | 670/6288 [01:56<14:33,  6.43it/s][A[A

 11%|█         | 673/6288 [01:56<11:45,  7.96it/s][A[A

 11%|█         | 675/6288 [01:57<18:20,  5.10it/s][A[A

 11%|█         | 677/6288 [01:57<14:45,  6.33it/s][A[A

 11%|█         | 680/6288 [01:57<12:51,  7.27it/s][A[A

 11%|█         | 682/6288 [01:58<15:11,  6.15it/s][A[A

 11%|█         | 683/6288 [01:58<18:20,  5.09it/s][A[A

 11%|█         | 684/6288 [01:58<15:41,  5.95it/s][A[A

 11%|█         | 686/6288 [01:59<14:40,  6.36it/s][A[A

 11%|█         | 687/6288 [01:59<18:11,  5.13it/s][A[A

 11%|█        

 14%|█▍        | 877/6288 [02:33<10:14,  8.80it/s][A[A

 14%|█▍        | 879/6288 [02:33<08:44, 10.32it/s][A[A

 14%|█▍        | 881/6288 [02:34<08:14, 10.93it/s][A[A

 14%|█▍        | 883/6288 [02:34<10:23,  8.68it/s][A[A

 14%|█▍        | 885/6288 [02:34<09:48,  9.18it/s][A[A

 14%|█▍        | 887/6288 [02:34<10:15,  8.77it/s][A[A

 14%|█▍        | 888/6288 [02:35<12:28,  7.22it/s][A[A

 14%|█▍        | 889/6288 [02:35<14:36,  6.16it/s][A[A

 14%|█▍        | 890/6288 [02:35<13:35,  6.62it/s][A[A

 14%|█▍        | 891/6288 [02:35<22:31,  3.99it/s][A[A

 14%|█▍        | 892/6288 [02:36<21:43,  4.14it/s][A[A

 14%|█▍        | 893/6288 [02:36<19:35,  4.59it/s][A[A

 14%|█▍        | 894/6288 [02:36<21:59,  4.09it/s][A[A

 14%|█▍        | 896/6288 [02:36<19:18,  4.66it/s][A[A

 14%|█▍        | 897/6288 [02:37<20:46,  4.32it/s][A[A

 14%|█▍        | 898/6288 [02:37<17:55,  5.01it/s][A[A

 14%|█▍        | 901/6288 [02:37<15:11,  5.91it/s][A[A

 14%|█▍       

 18%|█▊        | 1105/6288 [03:14<25:24,  3.40it/s][A[A

 18%|█▊        | 1107/6288 [03:14<19:11,  4.50it/s][A[A

 18%|█▊        | 1108/6288 [03:14<19:38,  4.40it/s][A[A

 18%|█▊        | 1111/6288 [03:15<16:54,  5.10it/s][A[A

 18%|█▊        | 1112/6288 [03:15<22:54,  3.76it/s][A[A

 18%|█▊        | 1113/6288 [03:15<20:33,  4.20it/s][A[A

 18%|█▊        | 1114/6288 [03:15<19:48,  4.35it/s][A[A

 18%|█▊        | 1115/6288 [03:16<30:20,  2.84it/s][A[A

 18%|█▊        | 1118/6288 [03:17<24:47,  3.48it/s][A[A

 18%|█▊        | 1119/6288 [03:17<20:33,  4.19it/s][A[A

 18%|█▊        | 1120/6288 [03:17<19:53,  4.33it/s][A[A

 18%|█▊        | 1122/6288 [03:17<15:27,  5.57it/s][A[A

 18%|█▊        | 1123/6288 [03:17<22:21,  3.85it/s][A[A

 18%|█▊        | 1124/6288 [03:18<19:55,  4.32it/s][A[A

 18%|█▊        | 1125/6288 [03:18<19:43,  4.36it/s][A[A

 18%|█▊        | 1127/6288 [03:18<16:24,  5.24it/s][A[A

 18%|█▊        | 1129/6288 [03:18<13:38,  6.30it/s][A[

 21%|██        | 1298/6288 [03:54<22:44,  3.66it/s][A[A

 21%|██        | 1299/6288 [03:54<20:03,  4.15it/s][A[A

 21%|██        | 1302/6288 [03:54<15:28,  5.37it/s][A[A

 21%|██        | 1304/6288 [03:55<14:16,  5.82it/s][A[A

 21%|██        | 1305/6288 [03:55<14:21,  5.78it/s][A[A

 21%|██        | 1307/6288 [03:55<11:47,  7.04it/s][A[A

 21%|██        | 1309/6288 [03:55<13:33,  6.12it/s][A[A

 21%|██        | 1311/6288 [03:56<12:07,  6.84it/s][A[A

 21%|██        | 1312/6288 [03:56<11:53,  6.98it/s][A[A

 21%|██        | 1314/6288 [03:56<10:07,  8.19it/s][A[A

 21%|██        | 1317/6288 [03:56<11:56,  6.94it/s][A[A

 21%|██        | 1318/6288 [03:57<12:49,  6.46it/s][A[A

 21%|██        | 1320/6288 [03:57<15:59,  5.18it/s][A[A

 21%|██        | 1323/6288 [03:57<12:07,  6.82it/s][A[A

 21%|██        | 1325/6288 [03:58<15:09,  5.46it/s][A[A

 21%|██        | 1326/6288 [03:59<29:00,  2.85it/s][A[A
100%|██████████| 6288/6288 [11:47<00:00,  7.21it/s]  


### Test AllenNLP dataset reader

In [209]:
import torch.optim as optim

from mrp_library.dataset_readers.mrp_jsons import MRPDatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.feedforward import FeedForward

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer

import json
import logging
from typing import Dict

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.models import Model
from overrides import overrides

In [210]:
from mrp_library.dataset_readers.mrp_jsons_actions import MRPDatasetActionReader

In [211]:
# Dataset reader built with its default arguments; the cells below use it to
# load the generated jsonl files.
reader = MRPDatasetActionReader()

In [212]:
# Read the instances stored in the generated train jsonl file.
train_dataset = reader.read(cached_path(allennlp_framework_train_output_file))




0it [00:00, ?it/s][A[A[AINFO     [mrp_library.dataset_readers.mrp_jsons_actions:139] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-big-ucca-dm-train.jsonl












1412it [00:00, 3497.20it/s][A[A[A














3353it [00:00, 3780.77it/s][A[A[A

















5598it [00:01, 3594.84it/s][A[A[A





6344it [00:01, 3563.47it/s][A[A[A











































































15335it [00:04, 3602.05it/s][A[A[A



































19748it [00:05, 3680.27it/s][A[A[A
















21647it [00:05, 3612.99it/s][A[A[A


22011it [00:06, 3607.09it/s][A[A[A


22373it [00:06, 3502.64it/s][A[A[A






























































38963it [00:23, 3707.24it/s][A











43643it [00:24, 3737.15it/s][A

































54953it [00:27, 3473.85it/s][A

















60950it [00:29, 3789.22it/s][A



62483it [00:29, 3476.07it/s][A

























72852it [00:46, 3443.01it/s][A






74716it [00:46, 3614.88it/s][A
















81272it [00:48, 3354.54it/s][A
81615it [00:48, 3043.90it/s][A
81932it [00:48, 3080.30it/s][A
82248it [00:48, 2953.24it/s][A
82550it [00:48, 2917.30it/s][A
82847it [00:48, 2756.60it/s][A
83128it [00:49, 2749.77it/s][A
83412it [00:49, 2753.11it/s][A
83690it [00:49, 2733.97it/s][A
83966it [00:49, 2677.13it/s][A
84236it [00:49, 2587.18it/s][A
84504it [00:49, 2613.66it/s][A
84767it [00:49, 2510.52it/s][A
85020it [00:49, 2449.06it/s][A
85276it [00:49, 2456.37it/s][A
85535it [00:50, 2494.08it/s][A
85807it [00:50, 2556.94it/s][A
86069it [00:50, 2564.30it/s][A
86371it [00:50, 2670.37it/s][A
86654it [00:50, 2656.06it/s][A
86966it [00:50, 2761.53it/s][A
87253it [00:50, 2792.78it/s][A
87549it [00:50, 2799.81it/s][A
87841it [00:50, 2833.82it/s][A
88126it [00:50, 2659.21it/s][A
88395it [00:51, 2560.55it/s][A
88676it [00:51, 2610.99it/s][A
88940it [00:51, 2608.42it/s][A
89214it [00:51, 2645.97it/s][A
89480it [00:51, 2589.5

92427it [00:52, 2537.96it/s][A
92691it [00:52, 2567.41it/s][A
92957it [00:52, 2593.06it/s][A
93230it [00:52, 2605.08it/s][A
93492it [00:53, 2604.39it/s][A
93753it [00:53, 2358.85it/s][A
94012it [00:53, 2414.09it/s][A
94258it [00:53, 2358.60it/s][A
94530it [00:53, 2444.94it/s][A
94779it [00:53, 2458.18it/s][A
95071it [00:53, 2484.90it/s][A
95363it [00:53, 2600.90it/s][A
95626it [01:06, 66.08it/s]  [A
95989it [01:06, 93.67it/s][A
96384it [01:07, 132.47it/s][A
96754it [01:07, 186.33it/s][A
97121it [01:07, 260.51it/s][A
97481it [01:07, 360.94it/s][A
97830it [01:07, 493.26it/s][A
98249it [01:07, 670.80it/s][A
98596it [01:07, 878.28it/s][A
98934it [01:07, 1128.15it/s][A
99313it [01:07, 1427.72it/s][A
99694it [01:07, 1757.20it/s][A
100064it [01:08, 2085.60it/s][A
100432it [01:08, 2397.07it/s][A
100816it [01:08, 2682.05it/s][A
101184it [01:08, 2919.33it/s][A
101551it [01:08, 3108.47it/s][A
101918it [01:08, 3214.84it/s][A
102280it [01:08, 3321.95it/s][A
102661it [

In [213]:
# Read the instances stored in the generated test jsonl file.
test_dataset = reader.read(cached_path(allennlp_framework_test_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:139] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-big-ucca-dm-test.jsonl


















87257it [01:01, 1421.82it/s]


In [214]:
# Read the instances stored in the tests fixtures jsonl file.
tests_fixtures_dataset = reader.read(cached_path(allennlp_tests_fixtures_output_file))


0it [00:00, ?it/s][AINFO     [mrp_library.dataset_readers.mrp_jsons_actions:139] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/src/tests/fixtures/ucca-dm-test.jsonl

365it [00:00, 3647.01it/s][A
385it [00:00, 3563.89it/s][A

In [215]:
# Namespaces handed to Vocabulary.from_instances as `non_padded_namespaces`:
# the action type plus one '<framework>_<label field>' namespace for every
# framework / label-field combination.
# NOTE: `frameworks` is not defined in this cell; it must already be bound by
# an earlier cell when this one runs.
non_padded_namespaces = ['action_type']
label_fields = ['resolved_label_root_label', 'resolved_label_edge_labels']
for framework in frameworks:
    for label_field in label_fields:
        namespace = '{}_{}'.format(framework, label_field)
        non_padded_namespaces += [namespace]

In [216]:
non_padded_namespaces

['action_type',
 'ucca_resolved_label_root_label',
 'ucca_resolved_label_edge_labels',
 'psd_resolved_label_root_label',
 'psd_resolved_label_edge_labels',
 'eds_resolved_label_root_label',
 'eds_resolved_label_edge_labels',
 'dm_resolved_label_root_label',
 'dm_resolved_label_edge_labels',
 'amr_resolved_label_root_label',
 'amr_resolved_label_edge_labels']

In [217]:
# Fit the vocabulary on the concatenation of the train, test and fixture
# instances, with the namespaces collected above passed as non-padded.
vocab = Vocabulary.from_instances(
    train_dataset + test_dataset + tests_fixtures_dataset,
    non_padded_namespaces=non_padded_namespaces,
)

INFO     [allennlp.data.vocabulary:396] Fitting token dictionary from dataset.

  0%|          | 0/192606 [00:00<?, ?it/s][A
  0%|          | 849/192606 [00:00<00:22, 8485.36it/s][A
  1%|          | 1621/192606 [00:00<00:23, 8236.56it/s][A
  1%|▏         | 2427/192606 [00:00<00:23, 8180.46it/s][A
  2%|▏         | 3254/192606 [00:00<00:23, 8206.24it/s][A
  2%|▏         | 4054/192606 [00:00<00:23, 8140.28it/s][A
  3%|▎         | 4831/192606 [00:00<00:23, 8025.18it/s][A
  3%|▎         | 5646/192606 [00:00<00:23, 8059.58it/s][A
  3%|▎         | 6433/192606 [00:00<00:23, 8001.74it/s][A
  4%|▍         | 7247/192606 [00:00<00:23, 8040.92it/s][A
  4%|▍         | 8020/192606 [00:01<00:23, 7943.98it/s][A
  5%|▍         | 8791/192606 [00:01<00:23, 7852.65it/s][A
  5%|▍         | 9561/192606 [00:01<00:23, 7761.02it/s][A
  5%|▌         | 10327/192606 [00:01<00:24, 7490.08it/s][A
  6%|▌         | 11111/192606 [00:01<00:23, 7590.42it/s][A
  6%|▌         | 11953/192606 [00:01<00:23, 782

 53%|█████▎    | 102413/192606 [00:13<00:12, 7031.57it/s][A
 54%|█████▎    | 103118/192606 [00:14<00:12, 6951.26it/s][A
 54%|█████▍    | 103815/192606 [00:14<00:12, 6865.71it/s][A
 54%|█████▍    | 104503/192606 [00:14<00:13, 6760.96it/s][A
 55%|█████▍    | 105184/192606 [00:14<00:12, 6773.53it/s][A
 55%|█████▍    | 105863/192606 [00:14<00:13, 6538.29it/s][A
 55%|█████▌    | 106563/192606 [00:14<00:12, 6669.14it/s][A
 56%|█████▌    | 107298/192606 [00:14<00:12, 6858.74it/s][A
 56%|█████▌    | 108014/192606 [00:14<00:12, 6946.09it/s][A
 56%|█████▋    | 108711/192606 [00:14<00:12, 6805.19it/s][A
 57%|█████▋    | 109438/192606 [00:14<00:11, 6938.22it/s][A
 57%|█████▋    | 110135/192606 [00:15<00:11, 6898.95it/s][A
 58%|█████▊    | 110827/192606 [00:15<00:11, 6894.12it/s][A
 58%|█████▊    | 111518/192606 [00:15<00:11, 6828.88it/s][A
 58%|█████▊    | 112207/192606 [00:15<00:11, 6843.84it/s][A
 59%|█████▊    | 112893/192606 [00:15<00:11, 6804.18it/s][A
 59%|█████▉    | 113629/

In [218]:
vocab.print_statistics()

INFO     [allennlp.data.vocabulary:664] Printed vocabulary statistics are only for the part of the vocabulary generated from instances. If vocabulary is constructed by extending saved vocabulary with dataset instances, the directly loaded portion won't be considered here.




----Vocabulary Statistics----


Top 10 most frequent tokens in namespace 'word':
	Token: <START-WORD>		Frequency: 2350846
	Token: ,		Frequency: 774566
	Token: the		Frequency: 710528
	Token: .		Frequency: 507126
	Token: of		Frequency: 343712
	Token: a		Frequency: 334175
	Token: in		Frequency: 267493
	Token: and		Frequency: 246198
	Token: to		Frequency: 234115
	Token: be		Frequency: 189588

Top 10 longest tokens in namespace 'word':
	Token: individual-retirement-account		length: 29	Frequency: 140
	Token: research-and-production		length: 23	Frequency: 156
	Token: electronics-instruments		length: 23	Frequency: 82
	Token: interest-rate-sensitive		length: 23	Frequency: 74
	Token: electronics-instrument		length: 22	Frequency: 82
	Token: weapons-modernization		length: 21	Frequency: 186
	Token: Watergate-beleaguered		length: 21	Frequency: 87
	Token: watergate-beleaguered		length: 21	Frequency: 87
	Token: electronic-publishing		length: 21	Frequency: 84
	Token: Bridgestone/Firestone		length: 21

In [219]:
# Look up which token sits at index 0 of this namespace.
# NOTE(review): the namespace carries no framework prefix, unlike the
# '<framework>_resolved_label_edge_labels' names added to
# non_padded_namespaces -- confirm this is the namespace meant to be probed.
vocab.get_token_from_index(0, namespace='resolved_label_edge_labels')

'@@PADDING@@'

In [220]:
vocab.get_vocab_size('token_node_label')

13639

In [221]:
vocab.get_vocab_size('ucca_resolved_label_edge_labels')

21

In [222]:
vocab.get_vocab_size('ucca_resolved_label_root_label')

16

In [223]:
vocab.get_vocab_size('dm_resolved_label_edge_labels')

29

In [224]:
vocab.get_vocab_size('dm_resolved_label_root_label')

8198

In [225]:
vocab.get_vocab_size('word')

14523

In [226]:
vocab.get_vocab_size('pos')

66

In [227]:
vocab.get_vocab_size('label')

2

In [228]:
# Width of every Embedding built below, and the input size of the LSTM encoders.
EMBEDDING_DIM = 100
# Hidden size of the LSTM encoders; the feed-forward input sizes are multiples of it.
HIDDEN_DIM = 50

### Test model

In [229]:
from mrp_library.models.resolver import Resolver
from mrp_library.iterators.same_instance_type_framework_stack_len_iterator import SameInstanceTypeFrameworkStackLenIterator

from allennlp.nn import InitializerApplicator, RegularizerApplicator, util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper import PytorchSeq2VecWrapper
from allennlp.modules.seq2seq_encoders.pytorch_seq2seq_wrapper import PytorchSeq2SeqWrapper

In [230]:
# Device selection.
# USE_CUDA is a manual switch replacing the former hard-coded `and False`:
# leave it False to force the CPU (cuda_device == -1), set it to True to use
# args.cuda_device whenever a GPU is available.
USE_CUDA = False
if USE_CUDA and torch.cuda.is_available():
    cuda_device = args.cuda_device
else:
    cuda_device = -1
cuda_device

-1

In [231]:
def _cuda(variable, cuda_device):
    """Return `variable` placed on the requested device.

    Args:
        variable: any object exposing a ``.cuda(device)`` method
            (e.g. a tensor or a module).
        cuda_device: device index to move to; ``-1`` means "stay where you
            are" and the object is returned untouched.

    Returns:
        ``variable`` itself when ``cuda_device`` is ``-1``, otherwise the
        result of ``variable.cuda(cuda_device)``.
    """
    if cuda_device == -1:
        return variable
    return variable.cuda(cuda_device)

In [232]:
# Lookup tables (field type -> shared module) that are later passed to the
# Resolver model.
field_type2embedder = {}
field_type2seq2vec_encoder = {}
field_type2seq2seq_encoder = {}

# Vocabulary namespace -> field types sharing one embedding matrix.
# The key is also the namespace whose vocabulary size sizes the embedding.
embedding_group2field_types = {
    'word': ['word', 'parse_curr_word', 'parse_prev_word', 'parse_next_word',
             'resolve_dm_word', 'resolve_ucca_word'],
    'pos': ['pos', 'parse_curr_pos', 'parse_prev_pos', 'parse_next_pos',
            'resolve_dm_pos', 'resolve_ucca_pos'],
    'resolved': ['resolved', 'token_node_resolved'],
    'token_node_label': ['token_node_label'],
    'token_node_prev_action': ['action', 'token_node_prev_action'],
    'template': ['template'],
}

# Encoder group -> field types sharing one LSTM encoder instance.
seq2vec_group2field_types = {
    'parse_word': ['word', 'parse_curr_word', 'parse_prev_word', 'parse_next_word'],
    'parse_pos': ['pos', 'parse_curr_pos', 'parse_prev_pos', 'parse_next_pos'],
    'resolve_dm_word': ['resolve_dm_word'],
    'resolve_dm_pos': ['resolve_dm_pos'],
    'resolve_ucca_word': ['resolve_ucca_word'],
    'resolve_ucca_pos': ['resolve_ucca_pos'],
    'resolved': ['resolved', 'token_node_resolved'],
    'token_node_label': ['token_node_label'],
    'action': ['action', 'token_node_prev_action'],
}

# One embedding (wrapped in a text-field embedder) per embedding group; every
# field type of the group points at the same embedder object.
for embedding_group, field_types in embedding_group2field_types.items():
    group_vocab_size = vocab.get_vocab_size(embedding_group)
    embedding = _cuda(
        Embedding(
            num_embeddings=group_vocab_size,
            embedding_dim=EMBEDDING_DIM
        ), cuda_device)
    logger.info((embedding_group, group_vocab_size))
    embedder = BasicTextFieldEmbedder({embedding_group: embedding})
    for field_type in field_types:
        field_type2embedder[field_type] = embedder

# One sequence-to-vector LSTM per encoder group.
for seq2vec_group, field_types in seq2vec_group2field_types.items():
    seq2vec = PytorchSeq2VecWrapper(
        _cuda(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True), cuda_device)
    )
    for field_type in field_types:
        field_type2seq2vec_encoder[field_type] = seq2vec

# One sequence-to-sequence LSTM per encoder group.
# NOTE(review): this loop iterates the seq2vec grouping as well; confirm that
# a separate seq2seq grouping is not intended.
for seq2seq_group, field_types in seq2vec_group2field_types.items():
    seq2seq = PytorchSeq2SeqWrapper(
        _cuda(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True), cuda_device)
    )
    for field_type in field_types:
        field_type2seq2seq_encoder[field_type] = seq2seq

INFO     [__main__:51] ('word', 14523)
INFO     [__main__:51] ('pos', 66)
INFO     [__main__:51] ('resolved', 5)
INFO     [__main__:51] ('token_node_label', 13639)
INFO     [__main__:51] ('token_node_prev_action', 6)
INFO     [__main__:51] ('template', 2)


In [233]:
# word_embedding = Embedding(num_embeddings=vocab.get_vocab_size('word'),
#                             embedding_dim=EMBEDDING_DIM)
# pos_embedding = Embedding(num_embeddings=vocab.get_vocab_size('pos'),
#                             embedding_dim=EMBEDDING_DIM)

# word_embedder = BasicTextFieldEmbedder({
#     "word": word_embedding,
#     "pos": pos_embedding,
# })
# parse_label = {
#     'word': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     ),
#     'pos': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     )
# }

In [234]:
# embedded_parse_label = word_embedder(parse_label)

In [235]:
# embedded_parse_label.shape

In [236]:
vocab.get_vocab_size('resolved_label_edge_labels')

2

In [237]:
vocab.get_vocab_size('dm_resolved_label_edge_labels')

29

In [238]:
vocab.get_vocab_size('dm_resolved_label_root_label')

8198

In [285]:
vocab.get_vocab_size('action_type_labels')

5

In [286]:
def _two_layer_params(input_dim, hidden_dim, output_dim):
    """Params for a 2-layer FeedForward: sigmoid hidden layer, linear output, no dropout."""
    return Params({
        "input_dim": input_dim,
        "num_layers": 2,
        "hidden_dims": [hidden_dim, output_dim],
        "activations": ["sigmoid", "linear"],
        "dropout": [0.0, 0.0]
    })


# Action-type classifier: one output per entry of the 'action_type_labels' namespace.
action_classifier_params = _two_layer_params(
    HIDDEN_DIM * 15, 50, vocab.get_vocab_size('action_type_labels'))

# Number-of-pops head: a single scalar output.
action_num_pop_classifier_params = _two_layer_params(HIDDEN_DIM * 2, 50, 1)

# Per-framework heads; each output layer is sized by the matching
# framework-specific vocabulary namespace.
framework2field_type2feedforward_params = {
    'dm': {
        'child_edges': _two_layer_params(
            HIDDEN_DIM * 4, 100,
            vocab.get_vocab_size('dm_resolved_label_edge_labels')),
        'root_label': _two_layer_params(
            EMBEDDING_DIM * 4, 100,
            vocab.get_vocab_size('dm_resolved_label_root_label')),
    },
    'ucca': {
        'child_edges': _two_layer_params(
            HIDDEN_DIM * 4, 100,
            vocab.get_vocab_size('ucca_resolved_label_edge_labels')),
        'root_label': _two_layer_params(
            EMBEDDING_DIM * 4, 100,
            vocab.get_vocab_size('ucca_resolved_label_root_label')),
    },
}

In [287]:
# Instantiate the two action heads from their Params and place each one on
# the selected device.
# NOTE(review): from_params presumably consumes the Params object, in which
# case the Params cell must be re-run before re-running this one -- confirm.
action_classifier_feedforward = _cuda(
    FeedForward.from_params(action_classifier_params),
    cuda_device,
)

action_num_pop_classifier_feedforward = _cuda(
    FeedForward.from_params(action_num_pop_classifier_params),
    cuda_device,
)

INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.modules.feedforward.FeedForward'> from params {'input_dim': 750, 'num_layers': 2, 'hidden_dims': [50, 5], 'activations': ['sigmoid', 'linear'], 'dropout': [0.0, 0.0]} and extras set()
INFO     [allennlp.common.params:252] input_dim = 750
INFO     [allennlp.common.params:252] num_layers = 2
INFO     [allennlp.common.params:252] hidden_dims = [50, 5]
INFO     [allennlp.common.params:252] hidden_dims = [50, 5]
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params ['sigmoid', 'linear'] and extras set()
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params sigmoid and extras set()
INFO     [allennlp.common.params:252] type = sigmoid
INFO

In [288]:
# Nested ModuleDict: framework -> field type -> feed-forward classifier.
# The modules are created in the same order as the Params dictionaries are
# iterated.
framework2field_type2feedforward = torch.nn.ModuleDict()

for framework, field_type2feedforward_params in framework2field_type2feedforward_params.items():
    framework_modules = torch.nn.ModuleDict()
    for field_type, feedforward_params in field_type2feedforward_params.items():
        feedforward_classifier = _cuda(
            FeedForward.from_params(feedforward_params), cuda_device)
        framework_modules[field_type] = feedforward_classifier
    framework2field_type2feedforward[framework] = framework_modules

INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.modules.feedforward.FeedForward'> from params {'input_dim': 200, 'num_layers': 2, 'hidden_dims': [100, 29], 'activations': ['sigmoid', 'linear'], 'dropout': [0.0, 0.0]} and extras set()
INFO     [allennlp.common.params:252] input_dim = 200
INFO     [allennlp.common.params:252] num_layers = 2
INFO     [allennlp.common.params:252] hidden_dims = [100, 29]
INFO     [allennlp.common.params:252] hidden_dims = [100, 29]
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params ['sigmoid', 'linear'] and extras set()
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params sigmoid and extras set()
INFO     [allennlp.common.params:252] type = sigmoi

In [289]:
field_type = 'word'

In [290]:
field_type2embedder.keys()

dict_keys(['word', 'parse_curr_word', 'parse_prev_word', 'parse_next_word', 'resolve_dm_word', 'resolve_ucca_word', 'pos', 'parse_curr_pos', 'parse_prev_pos', 'parse_next_pos', 'resolve_dm_pos', 'resolve_ucca_pos', 'resolved', 'token_node_resolved', 'token_node_label', 'action', 'token_node_prev_action', 'template'])

In [291]:
# Toy batch of token ids (2 rows x 7 columns) keyed by the current field type,
# pushed through that field type's embedder as a smoke test.
toy_token_ids = torch.LongTensor([
    [ 1,  0,  3,  7,  2,  9,  4],
    [ 0,  0,  5,  0,  0,  0,  4],
])
parse_label = {field_type: _cuda(toy_token_ids, cuda_device)}
embedded_parse_label = field_type2embedder[field_type](parse_label)

In [292]:
feature_mask = util.get_text_field_mask(parse_label)

In [293]:
seq2vec_encoder = field_type2seq2vec_encoder[field_type]

In [294]:
encoded_feature = seq2vec_encoder(embedded_parse_label, feature_mask)

In [295]:
# Repeat the single encoded feature 15 times so that the concatenation has the
# width the action classifier was configured with (input_dim = HIDDEN_DIM * 15).
encoded_features = [encoded_feature] * 15

In [296]:
# torch.cat(encoded_features, dim=-1)

In [297]:
actions

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2,
    'anchors': [{'from': 7, 'to': 12}],
    'label': 'final',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3,
    'anchors': [{'from': 13, 'to': 19}],
    'label': 'minute',
    'propagate_label': 'C'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31, 'propagate_label': 'E'},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      't

In [298]:
# Smoke test of the action classifier: concatenate the features along the last
# dimension and feed them through the feed-forward head.
logits = action_classifier_feedforward(torch.cat(encoded_features, dim=-1))

In [299]:
logits.shape

torch.Size([2, 5])

In [300]:
label = torch.tensor([1, 0])

In [301]:
# loss_func = torch.nn.CrossEntropyLoss()
# loss = loss_func(logits, label)

In [302]:
from mrp_library.models.resolver import Resolver
from mrp_library.iterators.same_instance_type_framework_stack_len_iterator import SameInstanceTypeFrameworkStackLenIterator


In [303]:
# --- Device selection --------------------------------------------------------
# GPU use is switched off on purpose for this run (the original guard was
# `torch.cuda.is_available() and False`). Set USE_CUDA = True to run on
# `args.cuda_device` whenever CUDA is available.
USE_CUDA = False
use_gpu = USE_CUDA and torch.cuda.is_available()
# -1 is the value used for the non-GPU path; it is passed to both the model
# and the Trainer below.
cuda_device = args.cuda_device if use_gpu else -1

# --- Model -------------------------------------------------------------------
# Built once for either device; only the `.cuda()` move is conditional.
model = Resolver(
    cuda_device=cuda_device,
    vocab=vocab,
    field_type2embedder=field_type2embedder,
    field_type2seq2vec_encoder=field_type2seq2vec_encoder,
    field_type2seq2seq_encoder=field_type2seq2seq_encoder,
    action_classifier_feedforward=action_classifier_feedforward,
    action_num_pop_classifier_feedforward=action_num_pop_classifier_feedforward,
    framework2field_type2feedforward=framework2field_type2feedforward,
)
if use_gpu:
    model = model.cuda(cuda_device)

# --- Batching ----------------------------------------------------------------
# Alternative tried earlier: SameInstanceTypeFrameworkIterator with the same
# shuffle / batch_size / sorting_keys arguments.
iterator = SameInstanceTypeFrameworkStackLenIterator(
    shuffle=True,
    batch_size=100,
    sorting_keys=[("token_node_resolveds", "num_tokens")],
)
iterator.index_with(vocab)

# --- Optimizer ---------------------------------------------------------------
# Created after the (optional) device move so it is built from the model's
# final parameters.
optimizer = optim.SGD(model.parameters(), lr=0.1)

INFO     [allennlp.nn.initializers:293] Initializing parameters
INFO     [allennlp.nn.initializers:309] Done initializing parameters; the following parameters are using their default initialization from their code
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.0.bias
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.0.weight
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.1.bias
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.1.weight
INFO     [allennlp.nn.initializers:314]    action_num_pop_classifier_feedforward._linear_layers.0.bias
INFO     [allennlp.nn.initializers:314]    action_num_pop_classifier_feedforward._linear_layers.0.weight
INFO     [allennlp.nn.initializers:314]    action_num_pop_classifier_feedforward._linear_layers.1.bias
INFO     [allennlp.nn.initializers:314]    action_num_pop_classifier_feedforward._linear_la

INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.resolve_ucca_word._module.weight_ih_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.resolved._module.bias_hh_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.resolved._module.bias_ih_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.resolved._module.weight_hh_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.resolved._module.weight_ih_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.token_node_label._module.bias_hh_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.token_node_label._module.bias_ih_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.token_node_label._module.weight_hh_l0
INFO     [allennlp.nn.initializers:314]    field_type2seq2vec_encoder.token_node_label._module.weight_ih_l0
INFO     [allennlp.nn.initializers:314]    framework2field_type2fee

In [304]:
# Confirm which device id was selected (-1 means the non-GPU branch was taken).
cuda_device

-1

In [305]:
# Inspect the scalar tensor the model compares predicted action ids against.
model.resolve_tensor

tensor(1)

In [306]:
# list(model.named_parameters())

In [307]:
# Wire the model, optimizer and batch iterator into the trainer.
# Validation uses the held-out `test_dataset`; for an overfitting sanity check,
# pass `train_dataset` as `validation_dataset` instead.
trainer = Trainer(
    # what to train and how to step it
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    cuda_device=cuda_device,
    # data
    train_dataset=train_dataset,
    validation_dataset=test_dataset,
    # schedule: at most 20 epochs, with a patience of 10
    num_epochs=20,
    patience=10,
)

In [308]:
# Fixed example scores: 20 rows x 3 action classes, placed on the active device.
action_logits = _cuda(
    torch.tensor(
        [
            [-2.2126, 2.6022, -1.1655],
            [4.7340, -1.9992, -3.4521],
            [-1.9665, 2.4100, -1.2047],
            [-2.1353, 2.4847, -1.1260],
            [4.7492, -2.0234, -3.4460],
            [-1.4369, 1.9822, -1.2885],
            [1.0337, 0.3599, -2.0420],
            [5.0974, -2.3380, -3.4647],
            [5.4187, -2.4720, -3.6469],
            [-3.4045, 2.4903, 0.0773],
            [0.6384, 0.6764, -1.9942],
            [-2.2904, 2.7170, -1.2016],
            [4.6333, -1.9474, -3.4113],
            [-2.0811, 2.5367, -1.2174],
            [5.1840, -2.7536, -3.1499],
            [4.7421, -2.0138, -3.4485],
            [5.1121, -2.2290, -3.5999],
            [0.1843, 0.8324, -1.6990],
            [5.3854, -2.4593, -3.6309],
            [0.0324, 0.6715, -1.4173],
        ]
    ),
    cuda_device,
)

# Matching gold action class ids (values in {0, 1, 2}), one per row above.
action_type = _cuda(
    torch.tensor([0, 2, 2, 1, 1, 0, 1, 2, 1, 2, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0]),
    cuda_device,
)

In [309]:
# `max(1)` returns (max value, argmax) along the class dimension. Despite its
# name, `action_probs` holds the max *logit* per row, not a probability.
action_probs, action_preds = action_logits.max(1)
# Flag (1/0) the rows whose predicted class equals `model.resolve_tensor`.
# The out-of-place `eq` is used so `action_preds` keeps the predicted class ids;
# the previous in-place `eq_` overwrote them and made both names alias one
# tensor. `.long()` keeps the int64 dtype the in-place version produced, so
# later comparisons against `action_type` behave as before.
action_resolve_preds = action_preds.eq(model.resolve_tensor).long()

In [310]:
# Scratch check: iter() over a list yields a list_iterator object.
iter([1, 2, 3, 4])

<list_iterator at 0x7f6d6ed1b6a0>

In [311]:
# Scratch check: two-level defaultdict whose innermost default value is a dict.
defaultdict(lambda: defaultdict(dict))

defaultdict(<function __main__.<lambda>()>, {})

In [312]:
# Side by side: resolve flags, gold action types, and their element-wise equality.
# NOTE(review): the flags are 0/1 while `action_type` ranges over 0..2, so this
# equality is not a resolve-accuracy measure as written -- confirm intent.
(action_resolve_preds, action_type, action_resolve_preds.eq(action_type))

(tensor([1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1]),
 tensor([0, 2, 2, 1, 1, 0, 1, 2, 1, 2, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0]),
 tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0],
        dtype=torch.uint8))

In [313]:
# Use the 0/1 flags as a row mask: rows flagged 0 have their logits zeroed,
# rows flagged 1 keep their original logits.
action_resolve_preds.unsqueeze(-1).float() * action_logits

tensor([[-2.2126,  2.6022, -1.1655],
        [ 0.0000, -0.0000, -0.0000],
        [-1.9665,  2.4100, -1.2047],
        [-2.1353,  2.4847, -1.1260],
        [ 0.0000, -0.0000, -0.0000],
        [-1.4369,  1.9822, -1.2885],
        [ 0.0000,  0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [-3.4045,  2.4903,  0.0773],
        [ 0.6384,  0.6764, -1.9942],
        [-2.2904,  2.7170, -1.2016],
        [ 0.0000, -0.0000, -0.0000],
        [-2.0811,  2.5367, -1.2174],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.1843,  0.8324, -1.6990],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0324,  0.6715, -1.4173]])

In [314]:
# Dummy all-ones tensor of shape (99, 62, 100) used for shape experiments.
embedded_fields = torch.ones((99, 62, 100))

In [315]:
# Confirm the dummy tensor's shape.
embedded_fields.size()

torch.Size([99, 62, 100])

In [316]:
# embedded_fields

In [317]:
# Example root-position scores: 5 rows x 31 positions.
# All rows share the same first 28 values; only the last three columns differ,
# so the literal is assembled from a shared prefix plus per-row tails.
_root_logits_prefix = [
    0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717,
    0.1719, 0.1717, 0.1716, 0.1719, 0.1717, 0.1716, 0.1718,
    0.1719, 0.1718, 0.1719, 0.1717, 0.1720, 0.1720, 0.1722,
    0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719, 0.1717,
]
_root_logits_tails = [
    [0.1720, 0.1715, 0.1718],
    [0.1716, 0.1719, 0.1708],
    [0.1720, 0.1715, 0.1717],
    [0.1716, 0.1719, 0.1717],
    [0.1720, 0.1715, 0.1709],
]
root_position_logits = torch.tensor(
    [_root_logits_prefix + row_tail for row_tail in _root_logits_tails]
)

In [318]:
# Example gold root positions, one per row of `root_position_logits`.
# -1 presumably marks rows that have no root target -- TODO confirm.
root_position = torch.tensor([ 2,  0, -1,  1,  0])

# Row mask: 1 where a non-negative position is present, 0 for the -1 rows.
# (An earlier `eq(-1)` mask was computed here and immediately overwritten; it
# has been dropped, along with a no-op `root_position_logits.shape` statement.)
mask = root_position.gt(-1)

# Zero the logits of masked-out rows; the (5, 1) mask broadcasts across the
# position dimension, which gives the same result as an explicit `expand_as`.
masked_root_position_logits = root_position_logits * mask.float().unsqueeze(-1)

# NOTE: `loss` holds the criterion object, not a computed loss value.
# NOTE(review): if it is applied with `root_position` as the target, the -1
# entries would need `ignore_index=-1` or filtering -- confirm how the model
# handles this.
loss = torch.nn.CrossEntropyLoss()
inp = root_position_logits
# crossEntropy = -torch.log(torch.gather(inp, 1, root_position.view(-1, 1)))
root_position

tensor([ 2,  0, -1,  1,  0])

In [319]:
# Scratch experiment on list aliasing: start with a small list.
a = [1, 2]

In [320]:
# Both dict values reference the same list object `a` (no copy is made).
d = {'b': a, 'c': a}

In [321]:
# Mutate the shared list in place. Not idempotent: re-running this cell
# appends another 3.
a.append(3)

In [322]:
# Both entries show the appended element, because they alias the same list.
d

{'b': [1, 2, 3], 'c': [1, 2, 3]}

In [323]:
# The masked logits keep the shape of the unmasked logits.
masked_root_position_logits.shape

torch.Size([5, 31])

In [324]:
# One target position per row of the logits.
root_position.shape

torch.Size([5])

In [325]:
# Scratch check: transpose(1, 2) swaps the last two dimensions, (2, 3, 47) -> (2, 47, 3).
torch.ones(2, 3, 47).transpose(1, 2).shape

torch.Size([2, 47, 3])

In [326]:
# Scratch check: a tensor built from [[]] has one row and zero columns.
[dim for dim in torch.tensor([[]]).size()]

[1, 0]

In [327]:
# Scratch check: the second dimension of the empty nested tensor is 0.
torch.tensor([[]]).shape[1]

0

In [328]:
# Start the training loop configured on `trainer` above. This is long-running
# and produces heavy progress output; consider clearing it before sharing.
trainer.train()

INFO     [allennlp.training.trainer:465] Beginning training.
INFO     [allennlp.training.trainer:281] Epoch 0/19
INFO     [allennlp.training.trainer:283] Peak CPU memory usage MB: 12146.244
INFO     [allennlp.training.trainer:287] GPU 0 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 1 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 2 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 3 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 4 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 5 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 6 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 7 memory usage MB: 10
INFO     [allennlp.training.trainer:311] Training


  0%|          | 0/1050 [00:00<?, ?it/s][A[A

action_type_accuracy: 0.0000, action_num_pop_accuracy: 0.0000, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 4.5036 ||:   0%|          | 1/1050 [

action_type_accuracy: 0.5116, action_num_pop_accuracy: 0.1985, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.0451 ||:   3%|▎         | 36/1050 [00:08<03:06,  5.43it/s][A[A

action_type_accuracy: 0.5118, action_num_pop_accuracy: 0.1985, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.0176 ||:   4%|▎         | 37/1050 [00:08<03:01,  5.59it/s][A[A

action_type_accuracy: 0.5132, action_num_pop_accuracy: 0.1985, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9908 ||:   4%|▎         | 38/1050 [00:08<02:58,  5.67it/s][A[A

action_type_accuracy: 0.5125, action_num_pop_accuracy: 0.1985, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9669 ||:   4%|▎         | 39/1050 [00:08<02:50,  5.94it/s][A[A

action_type_accuracy: 0.5122, action_num_pop_accuracy: 0.1985, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9436 ||:   4%|▍         | 40/1050 [0

action_type_accuracy: 0.5100, action_num_pop_accuracy: 0.2017, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.5383 ||:   7%|▋         | 76/1050 [00:15<02:58,  5.44it/s][A[A

action_type_accuracy: 0.5100, action_num_pop_accuracy: 0.1994, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.5547 ||:   7%|▋         | 77/1050 [00:15<03:34,  4.55it/s][A[A

action_type_accuracy: 0.5095, action_num_pop_accuracy: 0.1975, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.5704 ||:   7%|▋         | 78/1050 [00:15<03:45,  4.31it/s][A[A

action_type_accuracy: 0.5062, action_num_pop_accuracy: 0.2032, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.5861 ||:   8%|▊         | 79/1050 [00:16<03:59,  4.06it/s][A[A

action_type_accuracy: 0.5072, action_num_pop_accuracy: 0.2071, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.5998 ||:   8%|▊         | 80/1050 [0

action_type_accuracy: 0.5147, action_num_pop_accuracy: 0.2048, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9283 ||:  11%|█         | 115/1050 [00:26<04:25,  3.52it/s][A[A

action_type_accuracy: 0.5137, action_num_pop_accuracy: 0.2040, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9357 ||:  11%|█         | 116/1050 [00:26<04:20,  3.58it/s][A[A

action_type_accuracy: 0.5127, action_num_pop_accuracy: 0.2047, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9424 ||:  11%|█         | 117/1050 [00:27<04:07,  3.77it/s][A[A

action_type_accuracy: 0.5127, action_num_pop_accuracy: 0.2057, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9487 ||:  11%|█         | 118/1050 [00:27<04:16,  3.63it/s][A[A

action_type_accuracy: 0.5134, action_num_pop_accuracy: 0.2072, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9550 ||:  11%|█▏        | 119/10

action_type_accuracy: 0.5134, action_num_pop_accuracy: 0.2066, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9940 ||:  15%|█▍        | 154/1050 [00:37<03:56,  3.79it/s][A[A

action_type_accuracy: 0.5131, action_num_pop_accuracy: 0.2063, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9975 ||:  15%|█▍        | 155/1050 [00:37<03:50,  3.89it/s][A[A

action_type_accuracy: 0.5120, action_num_pop_accuracy: 0.2057, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.0014 ||:  15%|█▍        | 156/1050 [00:37<03:52,  3.85it/s][A[A

action_type_accuracy: 0.5114, action_num_pop_accuracy: 0.2056, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.0048 ||:  15%|█▍        | 157/1050 [00:38<04:04,  3.65it/s][A[A

action_type_accuracy: 0.5106, action_num_pop_accuracy: 0.2064, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.0083 ||:  15%|█▌        | 158/10

action_type_accuracy: 0.5075, action_num_pop_accuracy: 0.2009, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.1013 ||:  18%|█▊        | 193/1050 [00:47<03:55,  3.64it/s][A[A

action_type_accuracy: 0.5077, action_num_pop_accuracy: 0.2008, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.1036 ||:  18%|█▊        | 194/1050 [00:47<03:52,  3.68it/s][A[A

action_type_accuracy: 0.5072, action_num_pop_accuracy: 0.2005, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.1062 ||:  19%|█▊        | 195/1050 [00:47<03:41,  3.87it/s][A[A

action_type_accuracy: 0.5066, action_num_pop_accuracy: 0.2006, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.1084 ||:  19%|█▊        | 196/1050 [00:47<03:32,  4.02it/s][A[A

action_type_accuracy: 0.5069, action_num_pop_accuracy: 0.2006, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.1103 ||:  19%|█▉        | 197/10

action_type_accuracy: 0.5294, action_num_pop_accuracy: 0.1885, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0556, loss: 2.0209 ||:  22%|██▏       | 233/1050 [00:55<02:49,  4.82it/s][A[A

action_type_accuracy: 0.5294, action_num_pop_accuracy: 0.1885, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0556, loss: 2.0162 ||:  22%|██▏       | 234/1050 [00:55<02:33,  5.32it/s][A[A

action_type_accuracy: 0.5294, action_num_pop_accuracy: 0.1885, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0556, loss: 2.0113 ||:  22%|██▏       | 235/1050 [00:55<03:02,  4.47it/s][A[A

action_type_accuracy: 0.5294, action_num_pop_accuracy: 0.1885, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0556, loss: 2.0072 ||:  22%|██▏       | 236/1050 [00:56<03:28,  3.91it/s][A[A

action_type_accuracy: 0.5294, action_num_pop_accuracy: 0.1885, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0556, loss: 2.0022 ||:  23%|██▎       | 237/10

action_type_accuracy: 0.5295, action_num_pop_accuracy: 0.1884, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0502, loss: 1.8798 ||:  26%|██▌       | 274/1050 [01:04<02:31,  5.11it/s][A[A

action_type_accuracy: 0.5295, action_num_pop_accuracy: 0.1884, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0499, loss: 1.8768 ||:  26%|██▌       | 275/1050 [01:04<02:22,  5.43it/s][A[A

action_type_accuracy: 0.5295, action_num_pop_accuracy: 0.1884, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0500, loss: 1.8736 ||:  26%|██▋       | 276/1050 [01:05<02:30,  5.14it/s][A[A

action_type_accuracy: 0.5295, action_num_pop_accuracy: 0.1884, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0502, loss: 1.8702 ||:  26%|██▋       | 277/1050 [01:05<02:20,  5.50it/s][A[A

action_type_accuracy: 0.5295, action_num_pop_accuracy: 0.1884, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0506, loss: 1.8661 ||:  26%|██▋       | 278/10

action_type_accuracy: 0.5319, action_num_pop_accuracy: 0.1870, root_label_type_accuracy: 0.2674, child_edges_type_accuracy: 0.0531, loss: 1.7890 ||:  30%|██▉       | 314/1050 [01:11<02:10,  5.65it/s][A[A

action_type_accuracy: 0.5343, action_num_pop_accuracy: 0.1857, root_label_type_accuracy: 0.2674, child_edges_type_accuracy: 0.0531, loss: 1.7894 ||:  30%|███       | 315/1050 [01:12<02:30,  4.90it/s][A[A

action_type_accuracy: 0.5367, action_num_pop_accuracy: 0.1843, root_label_type_accuracy: 0.2674, child_edges_type_accuracy: 0.0531, loss: 1.7896 ||:  30%|███       | 316/1050 [01:12<03:11,  3.84it/s][A[A

action_type_accuracy: 0.5390, action_num_pop_accuracy: 0.1830, root_label_type_accuracy: 0.2674, child_edges_type_accuracy: 0.0531, loss: 1.7899 ||:  30%|███       | 317/1050 [01:12<03:18,  3.69it/s][A[A

action_type_accuracy: 0.5413, action_num_pop_accuracy: 0.1816, root_label_type_accuracy: 0.2674, child_edges_type_accuracy: 0.0531, loss: 1.7901 ||:  30%|███       | 318/10

action_type_accuracy: 0.5550, action_num_pop_accuracy: 0.1753, root_label_type_accuracy: 0.2671, child_edges_type_accuracy: 0.1078, loss: 1.8187 ||:  34%|███▍      | 360/1050 [01:21<01:12,  9.55it/s][A[A

action_type_accuracy: 0.5550, action_num_pop_accuracy: 0.1753, root_label_type_accuracy: 0.2671, child_edges_type_accuracy: 0.1136, loss: 1.8149 ||:  34%|███▍      | 362/1050 [01:21<01:08, 10.02it/s][A[A

action_type_accuracy: 0.5550, action_num_pop_accuracy: 0.1753, root_label_type_accuracy: 0.2671, child_edges_type_accuracy: 0.1168, loss: 1.8119 ||:  35%|███▍      | 364/1050 [01:21<00:59, 11.49it/s][A[A

action_type_accuracy: 0.5550, action_num_pop_accuracy: 0.1753, root_label_type_accuracy: 0.2671, child_edges_type_accuracy: 0.1186, loss: 1.8135 ||:  35%|███▍      | 366/1050 [01:22<02:02,  5.59it/s][A[A

action_type_accuracy: 0.5550, action_num_pop_accuracy: 0.1754, root_label_type_accuracy: 0.2671, child_edges_type_accuracy: 0.1186, loss: 1.8214 ||:  35%|███▍      | 367/10

action_type_accuracy: 0.5712, action_num_pop_accuracy: 0.1666, root_label_type_accuracy: 0.2679, child_edges_type_accuracy: 0.0999, loss: 1.7770 ||:  38%|███▊      | 403/1050 [01:33<03:33,  3.03it/s][A[A

action_type_accuracy: 0.5712, action_num_pop_accuracy: 0.1666, root_label_type_accuracy: 0.2679, child_edges_type_accuracy: 0.1023, loss: 1.7747 ||:  38%|███▊      | 404/1050 [01:34<03:14,  3.33it/s][A[A

action_type_accuracy: 0.5712, action_num_pop_accuracy: 0.1666, root_label_type_accuracy: 0.2679, child_edges_type_accuracy: 0.1015, loss: 1.7719 ||:  39%|███▊      | 405/1050 [01:34<02:50,  3.79it/s][A[A

action_type_accuracy: 0.5712, action_num_pop_accuracy: 0.1666, root_label_type_accuracy: 0.2679, child_edges_type_accuracy: 0.1000, loss: 1.7688 ||:  39%|███▊      | 406/1050 [01:34<03:29,  3.08it/s][A[A

action_type_accuracy: 0.5712, action_num_pop_accuracy: 0.1666, root_label_type_accuracy: 0.2679, child_edges_type_accuracy: 0.0990, loss: 1.7660 ||:  39%|███▉      | 407/10

action_type_accuracy: 0.5741, action_num_pop_accuracy: 0.1651, root_label_type_accuracy: 0.2370, child_edges_type_accuracy: 0.0954, loss: 1.8266 ||:  42%|████▏     | 443/1050 [01:45<01:28,  6.90it/s][A[A

action_type_accuracy: 0.5741, action_num_pop_accuracy: 0.1651, root_label_type_accuracy: 0.2394, child_edges_type_accuracy: 0.0954, loss: 1.8391 ||:  42%|████▏     | 444/1050 [01:46<01:22,  7.36it/s][A[A

action_type_accuracy: 0.5741, action_num_pop_accuracy: 0.1651, root_label_type_accuracy: 0.2373, child_edges_type_accuracy: 0.0954, loss: 1.8522 ||:  42%|████▏     | 445/1050 [01:46<01:19,  7.57it/s][A[A

action_type_accuracy: 0.5741, action_num_pop_accuracy: 0.1651, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0954, loss: 1.8649 ||:  42%|████▏     | 446/1050 [01:46<01:14,  8.10it/s][A[A

action_type_accuracy: 0.5741, action_num_pop_accuracy: 0.1651, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0954, loss: 1.8618 ||:  43%|████▎     | 447/10

action_type_accuracy: 0.5932, action_num_pop_accuracy: 0.1550, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0868, loss: 1.8005 ||:  46%|████▌     | 483/1050 [01:54<02:43,  3.47it/s][A[A

action_type_accuracy: 0.5932, action_num_pop_accuracy: 0.1550, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0863, loss: 1.7983 ||:  46%|████▌     | 484/1050 [01:55<02:26,  3.85it/s][A[A

action_type_accuracy: 0.5932, action_num_pop_accuracy: 0.1550, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0859, loss: 1.7961 ||:  46%|████▌     | 485/1050 [01:55<02:26,  3.87it/s][A[A

action_type_accuracy: 0.5932, action_num_pop_accuracy: 0.1550, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0856, loss: 1.7938 ||:  46%|████▋     | 486/1050 [01:55<02:33,  3.66it/s][A[A

action_type_accuracy: 0.5932, action_num_pop_accuracy: 0.1550, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0852, loss: 1.7916 ||:  46%|████▋     | 487/10

action_type_accuracy: 0.5972, action_num_pop_accuracy: 0.1534, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7573 ||:  50%|█████     | 525/1050 [02:04<01:40,  5.21it/s][A[A

action_type_accuracy: 0.5972, action_num_pop_accuracy: 0.1534, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7554 ||:  50%|█████     | 526/1050 [02:04<01:34,  5.57it/s][A[A

action_type_accuracy: 0.5972, action_num_pop_accuracy: 0.1534, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7533 ||:  50%|█████     | 527/1050 [02:04<01:35,  5.46it/s][A[A

action_type_accuracy: 0.5972, action_num_pop_accuracy: 0.1534, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7516 ||:  50%|█████     | 528/1050 [02:05<02:27,  3.55it/s][A[A

action_type_accuracy: 0.5972, action_num_pop_accuracy: 0.1534, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7496 ||:  50%|█████     | 529/10

action_type_accuracy: 0.5950, action_num_pop_accuracy: 0.1567, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7522 ||:  54%|█████▍    | 566/1050 [02:14<02:11,  3.68it/s][A[A

action_type_accuracy: 0.5945, action_num_pop_accuracy: 0.1568, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7545 ||:  54%|█████▍    | 567/1050 [02:14<02:06,  3.81it/s][A[A

action_type_accuracy: 0.5942, action_num_pop_accuracy: 0.1569, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7568 ||:  54%|█████▍    | 568/1050 [02:15<02:00,  4.00it/s][A[A

action_type_accuracy: 0.5939, action_num_pop_accuracy: 0.1573, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7592 ||:  54%|█████▍    | 569/1050 [02:15<02:00,  3.98it/s][A[A

action_type_accuracy: 0.5939, action_num_pop_accuracy: 0.1575, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7614 ||:  54%|█████▍    | 570/10

action_type_accuracy: 0.6067, action_num_pop_accuracy: 0.1606, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7419 ||:  58%|█████▊    | 606/1050 [02:25<02:07,  3.48it/s][A[A

action_type_accuracy: 0.6062, action_num_pop_accuracy: 0.1608, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7426 ||:  58%|█████▊    | 607/1050 [02:25<02:02,  3.62it/s][A[A

action_type_accuracy: 0.6054, action_num_pop_accuracy: 0.1611, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7431 ||:  58%|█████▊    | 608/1050 [02:26<02:10,  3.40it/s][A[A

action_type_accuracy: 0.6050, action_num_pop_accuracy: 0.1613, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7436 ||:  58%|█████▊    | 609/1050 [02:26<02:20,  3.13it/s][A[A

action_type_accuracy: 0.6048, action_num_pop_accuracy: 0.1615, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7441 ||:  58%|█████▊    | 610/10

action_type_accuracy: 0.5901, action_num_pop_accuracy: 0.1689, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7603 ||:  61%|██████▏   | 645/1050 [02:35<01:42,  3.96it/s][A[A

action_type_accuracy: 0.5895, action_num_pop_accuracy: 0.1689, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7607 ||:  62%|██████▏   | 646/1050 [02:36<01:44,  3.88it/s][A[A

action_type_accuracy: 0.5888, action_num_pop_accuracy: 0.1690, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7612 ||:  62%|██████▏   | 647/1050 [02:36<01:55,  3.50it/s][A[A

action_type_accuracy: 0.5887, action_num_pop_accuracy: 0.1691, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7614 ||:  62%|██████▏   | 648/1050 [02:36<01:53,  3.53it/s][A[A

action_type_accuracy: 0.5883, action_num_pop_accuracy: 0.1690, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7619 ||:  62%|██████▏   | 649/10

action_type_accuracy: 0.5769, action_num_pop_accuracy: 0.1726, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7856 ||:  65%|██████▌   | 684/1050 [02:46<01:38,  3.70it/s][A[A

action_type_accuracy: 0.5768, action_num_pop_accuracy: 0.1728, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7871 ||:  65%|██████▌   | 685/1050 [02:46<01:36,  3.78it/s][A[A

action_type_accuracy: 0.5766, action_num_pop_accuracy: 0.1728, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7887 ||:  65%|██████▌   | 686/1050 [02:47<01:35,  3.81it/s][A[A

action_type_accuracy: 0.5765, action_num_pop_accuracy: 0.1730, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7902 ||:  65%|██████▌   | 687/1050 [02:47<01:34,  3.82it/s][A[A

action_type_accuracy: 0.5763, action_num_pop_accuracy: 0.1731, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0835, loss: 1.7917 ||:  66%|██████▌   | 688/10

action_type_accuracy: 0.5727, action_num_pop_accuracy: 0.1767, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0833, loss: 1.8332 ||:  69%|██████▉   | 723/1050 [02:57<01:12,  4.54it/s][A[A

action_type_accuracy: 0.5727, action_num_pop_accuracy: 0.1767, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0833, loss: 1.8313 ||:  69%|██████▉   | 724/1050 [02:57<01:14,  4.37it/s][A[A

action_type_accuracy: 0.5727, action_num_pop_accuracy: 0.1767, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0833, loss: 1.8298 ||:  69%|██████▉   | 725/1050 [02:58<01:22,  3.92it/s][A[A

action_type_accuracy: 0.5727, action_num_pop_accuracy: 0.1767, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0833, loss: 1.8282 ||:  69%|██████▉   | 726/1050 [02:58<01:19,  4.05it/s][A[A

action_type_accuracy: 0.5727, action_num_pop_accuracy: 0.1767, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.0833, loss: 1.8265 ||:  69%|██████▉   | 727/10

action_type_accuracy: 0.5719, action_num_pop_accuracy: 0.1771, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7929 ||:  73%|███████▎  | 763/1050 [03:05<00:50,  5.72it/s][A[A

action_type_accuracy: 0.5718, action_num_pop_accuracy: 0.1773, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7928 ||:  73%|███████▎  | 764/1050 [03:05<00:57,  4.98it/s][A[A

action_type_accuracy: 0.5715, action_num_pop_accuracy: 0.1774, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7928 ||:  73%|███████▎  | 765/1050 [03:05<01:04,  4.43it/s][A[A

action_type_accuracy: 0.5714, action_num_pop_accuracy: 0.1774, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7926 ||:  73%|███████▎  | 766/1050 [03:05<01:09,  4.07it/s][A[A

action_type_accuracy: 0.5710, action_num_pop_accuracy: 0.1777, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7927 ||:  73%|███████▎  | 767/10

action_type_accuracy: 0.5670, action_num_pop_accuracy: 0.1805, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7874 ||:  76%|███████▋  | 802/1050 [03:15<01:06,  3.70it/s][A[A

action_type_accuracy: 0.5668, action_num_pop_accuracy: 0.1807, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7873 ||:  76%|███████▋  | 803/1050 [03:15<01:06,  3.74it/s][A[A

action_type_accuracy: 0.5670, action_num_pop_accuracy: 0.1807, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7870 ||:  77%|███████▋  | 804/1050 [03:15<01:06,  3.67it/s][A[A

action_type_accuracy: 0.5669, action_num_pop_accuracy: 0.1807, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7869 ||:  77%|███████▋  | 805/1050 [03:16<01:05,  3.75it/s][A[A

action_type_accuracy: 0.5667, action_num_pop_accuracy: 0.1809, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1286, loss: 1.7868 ||:  77%|███████▋  | 806/10

action_type_accuracy: 0.5671, action_num_pop_accuracy: 0.1819, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1295, loss: 1.8017 ||:  80%|████████  | 842/1050 [03:26<01:10,  2.93it/s][A[A

action_type_accuracy: 0.5671, action_num_pop_accuracy: 0.1819, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1295, loss: 1.8046 ||:  80%|████████  | 843/1050 [03:26<00:59,  3.47it/s][A[A

action_type_accuracy: 0.5671, action_num_pop_accuracy: 0.1819, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1295, loss: 1.8035 ||:  80%|████████  | 844/1050 [03:26<00:54,  3.81it/s][A[A

action_type_accuracy: 0.5672, action_num_pop_accuracy: 0.1820, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1295, loss: 1.8047 ||:  80%|████████  | 845/1050 [03:26<00:57,  3.59it/s][A[A

action_type_accuracy: 0.5671, action_num_pop_accuracy: 0.1819, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1295, loss: 1.8060 ||:  81%|████████  | 846/10

action_type_accuracy: 0.5659, action_num_pop_accuracy: 0.1839, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1301, loss: 1.8461 ||:  84%|████████▍ | 881/1050 [03:36<00:37,  4.56it/s][A[A

action_type_accuracy: 0.5659, action_num_pop_accuracy: 0.1839, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1301, loss: 1.8444 ||:  84%|████████▍ | 882/1050 [03:36<00:31,  5.29it/s][A[A

action_type_accuracy: 0.5661, action_num_pop_accuracy: 0.1841, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1301, loss: 1.8463 ||:  84%|████████▍ | 883/1050 [03:37<00:35,  4.74it/s][A[A

action_type_accuracy: 0.5661, action_num_pop_accuracy: 0.1841, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1301, loss: 1.8482 ||:  84%|████████▍ | 884/1050 [03:37<00:36,  4.60it/s][A[A

action_type_accuracy: 0.5662, action_num_pop_accuracy: 0.1840, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1301, loss: 1.8500 ||:  84%|████████▍ | 885/10

action_type_accuracy: 0.5664, action_num_pop_accuracy: 0.1856, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1351, loss: 1.8669 ||:  88%|████████▊ | 920/1050 [03:48<00:32,  3.94it/s][A[A

action_type_accuracy: 0.5664, action_num_pop_accuracy: 0.1856, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1345, loss: 1.8657 ||:  88%|████████▊ | 921/1050 [03:48<00:36,  3.56it/s][A[A

action_type_accuracy: 0.5664, action_num_pop_accuracy: 0.1856, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1340, loss: 1.8644 ||:  88%|████████▊ | 922/1050 [03:49<00:44,  2.87it/s][A[A

action_type_accuracy: 0.5664, action_num_pop_accuracy: 0.1856, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1335, loss: 1.8631 ||:  88%|████████▊ | 923/1050 [03:49<00:45,  2.77it/s][A[A

action_type_accuracy: 0.5664, action_num_pop_accuracy: 0.1856, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1341, loss: 1.8618 ||:  88%|████████▊ | 924/10

action_type_accuracy: 0.5666, action_num_pop_accuracy: 0.1855, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8247 ||:  91%|█████████▏| 960/1050 [03:58<00:29,  3.03it/s][A[A

action_type_accuracy: 0.5666, action_num_pop_accuracy: 0.1855, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8233 ||:  92%|█████████▏| 961/1050 [03:59<00:36,  2.44it/s][A[A

action_type_accuracy: 0.5666, action_num_pop_accuracy: 0.1855, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8221 ||:  92%|█████████▏| 962/1050 [03:59<00:35,  2.48it/s][A[A

action_type_accuracy: 0.5666, action_num_pop_accuracy: 0.1855, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8209 ||:  92%|█████████▏| 963/1050 [04:00<00:32,  2.69it/s][A[A

action_type_accuracy: 0.5666, action_num_pop_accuracy: 0.1855, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8195 ||:  92%|█████████▏| 964/10

action_type_accuracy: 0.5686, action_num_pop_accuracy: 0.1840, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8021 ||:  95%|█████████▌| 1001/1050 [04:09<00:13,  3.53it/s][A[A

action_type_accuracy: 0.5685, action_num_pop_accuracy: 0.1840, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8026 ||:  95%|█████████▌| 1002/1050 [04:09<00:13,  3.52it/s][A[A

action_type_accuracy: 0.5683, action_num_pop_accuracy: 0.1840, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8031 ||:  96%|█████████▌| 1003/1050 [04:10<00:12,  3.65it/s][A[A

action_type_accuracy: 0.5682, action_num_pop_accuracy: 0.1841, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8036 ||:  96%|█████████▌| 1004/1050 [04:10<00:12,  3.58it/s][A[A

action_type_accuracy: 0.5682, action_num_pop_accuracy: 0.1841, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8040 ||:  96%|█████████▌| 10

action_type_accuracy: 0.5615, action_num_pop_accuracy: 0.1850, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8212 ||:  99%|█████████▉| 1040/1050 [04:20<00:02,  3.96it/s][A[A

action_type_accuracy: 0.5612, action_num_pop_accuracy: 0.1849, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8217 ||:  99%|█████████▉| 1041/1050 [04:20<00:02,  3.72it/s][A[A

action_type_accuracy: 0.5610, action_num_pop_accuracy: 0.1851, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8221 ||:  99%|█████████▉| 1042/1050 [04:20<00:02,  3.40it/s][A[A

action_type_accuracy: 0.5609, action_num_pop_accuracy: 0.1850, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8225 ||:  99%|█████████▉| 1043/1050 [04:21<00:02,  3.41it/s][A[A

action_type_accuracy: 0.5607, action_num_pop_accuracy: 0.1849, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1353, loss: 1.8230 ||:  99%|█████████▉| 10

action_type_accuracy: 0.5582, action_num_pop_accuracy: 0.1851, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1336, loss: 1.8065 ||: : 1082it [04:31,  3.14it/s][A[A

action_type_accuracy: 0.5585, action_num_pop_accuracy: 0.1849, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1336, loss: 1.8076 ||: : 1084it [04:32,  3.75it/s][A[A

action_type_accuracy: 0.5587, action_num_pop_accuracy: 0.1849, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1336, loss: 1.8088 ||: : 1085it [04:32,  3.48it/s][A[A

action_type_accuracy: 0.5588, action_num_pop_accuracy: 0.1849, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1336, loss: 1.8100 ||: : 1086it [04:32,  3.28it/s][A[A

action_type_accuracy: 0.5589, action_num_pop_accuracy: 0.1850, root_label_type_accuracy: 0.2364, child_edges_type_accuracy: 0.1336, loss: 1.8113 ||: : 1087it [04:32,  3.36it/s][A[A

action_type_accuracy: 0.5589, action_num_pop_accuracy: 0.1850, root_label_type_a

  0%|          | 0/873 [00:00<?, ?it/s][A[A

action_type_accuracy: 1.0000, action_num_pop_accuracy: 0.0000, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.9516 ||:   0%|          | 1/873 [00:00<08:51,  1.64it/s][A[A

action_type_accuracy: 1.0000, action_num_pop_accuracy: 0.0000, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 3.3342 ||:   0%|          | 2/873 [00:00<06:45,  2.15it/s][A[A

action_type_accuracy: 1.0000, action_num_pop_accuracy: 0.0000, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 2.4322 ||:   0%|          | 3/873 [00:00<05:09,  2.81it/s][A[A

action_type_accuracy: 1.0000, action_num_pop_accuracy: 0.0000, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss: 1.9664 ||:   0%|          | 4/873 [00:00<04:06,  3.52it/s][A[A

action_type_accuracy: 0.9865, action_num_pop_accuracy: 0.0135, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0000, loss:

action_type_accuracy: 0.7156, action_num_pop_accuracy: 0.1916, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0403, loss: 1.0487 ||:   5%|▌         | 45/873 [00:05<02:14,  6.16it/s][A[A

action_type_accuracy: 0.7156, action_num_pop_accuracy: 0.1916, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0403, loss: 1.0433 ||:   5%|▌         | 46/873 [00:05<01:59,  6.91it/s][A[A

action_type_accuracy: 0.6981, action_num_pop_accuracy: 0.1941, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0403, loss: 1.1060 ||:   5%|▌         | 48/873 [00:06<01:47,  7.65it/s][A[A

action_type_accuracy: 0.6989, action_num_pop_accuracy: 0.1935, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0403, loss: 1.1596 ||:   6%|▌         | 49/873 [00:06<01:40,  8.19it/s][A[A

action_type_accuracy: 0.6989, action_num_pop_accuracy: 0.1935, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.0479, loss: 1.1617 ||:   6%|▌         | 50/873 [00:06<

action_type_accuracy: 0.7063, action_num_pop_accuracy: 0.1508, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1993, loss: 1.2257 ||:  10%|▉         | 85/873 [00:10<01:48,  7.28it/s][A[A

action_type_accuracy: 0.6925, action_num_pop_accuracy: 0.1548, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1993, loss: 1.2449 ||:  10%|▉         | 86/873 [00:10<01:53,  6.94it/s][A[A

action_type_accuracy: 0.6774, action_num_pop_accuracy: 0.1599, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1993, loss: 1.2646 ||:  10%|▉         | 87/873 [00:10<01:57,  6.66it/s][A[A

action_type_accuracy: 0.6675, action_num_pop_accuracy: 0.1633, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1993, loss: 1.2836 ||:  10%|█         | 88/873 [00:10<01:56,  6.74it/s][A[A

action_type_accuracy: 0.6576, action_num_pop_accuracy: 0.1661, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1993, loss: 1.3017 ||:  10%|█         | 89/873 [00:10<

action_type_accuracy: 0.5889, action_num_pop_accuracy: 0.1940, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1948, loss: 1.6190 ||:  15%|█▍        | 130/873 [00:16<01:38,  7.58it/s][A[A

action_type_accuracy: 0.5889, action_num_pop_accuracy: 0.1953, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1948, loss: 1.6305 ||:  15%|█▌        | 131/873 [00:16<01:32,  8.01it/s][A[A

action_type_accuracy: 0.5889, action_num_pop_accuracy: 0.1961, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1948, loss: 1.6420 ||:  15%|█▌        | 132/873 [00:16<01:31,  8.11it/s][A[A

action_type_accuracy: 0.5910, action_num_pop_accuracy: 0.1971, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1948, loss: 1.6525 ||:  15%|█▌        | 133/873 [00:16<01:35,  7.72it/s][A[A

action_type_accuracy: 0.5903, action_num_pop_accuracy: 0.1971, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.1948, loss: 1.6647 ||:  15%|█▌        | 134/873 [0

action_type_accuracy: 0.5669, action_num_pop_accuracy: 0.1989, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.2376, loss: 1.6497 ||:  22%|██▏       | 189/873 [00:56<01:29,  7.64it/s][A
action_type_accuracy: 0.5639, action_num_pop_accuracy: 0.1989, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.2376, loss: 1.6527 ||:  22%|██▏       | 190/873 [00:56<01:31,  7.48it/s][A
action_type_accuracy: 0.5633, action_num_pop_accuracy: 0.1995, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.2376, loss: 1.6543 ||:  22%|██▏       | 191/873 [00:56<01:30,  7.51it/s][A
action_type_accuracy: 0.5619, action_num_pop_accuracy: 0.1994, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.2376, loss: 1.6562 ||:  22%|██▏       | 192/873 [00:56<01:25,  7.95it/s][A
action_type_accuracy: 0.5602, action_num_pop_accuracy: 0.1993, root_label_type_accuracy: 0.0000, child_edges_type_accuracy: 0.2376, loss: 1.6590 ||:  22%|██▏       | 193/873 [00:56<01:21,  8.3

KeyboardInterrupt: 

In [None]:
# Shape probe: a (98, 29, 5) tensor holds 98 * 29 * 5 = 14210 elements, so
# viewing it as rows of 29 gives torch.Size([490, 29]).
# NOTE(review): the trailing axis has size 5, so each 29-wide row is a slice of
# the flattened data, not the original size-29 axis -- confirm this is the
# behaviour being checked.
torch.ones(98, 29, 5).view(-1, 29).shape

In [None]:
# Scratch lookup on `vocab`, which is defined outside this view.
# NOTE(review): called with no arguments. If `vocab` is an AllenNLP Vocabulary
# (presumably -- TODO confirm), get_token_index requires a token argument and
# this cell would raise TypeError; fill in the intended token/namespace or
# delete the cell.
vocab.get_token_index()

In [None]:
# Scratch lookup: show what index 2 maps to in the 'action_type' namespace.
# `vocab` is defined outside this view (presumably an AllenNLP Vocabulary -- TODO confirm).
vocab.get_token_from_index(2, namespace='action_type')

In [None]:
# Scratch lookup: show what index 1 maps to in the 'resolve_label_root_label' namespace.
# `vocab` is defined outside this view (presumably an AllenNLP Vocabulary -- TODO confirm).
vocab.get_token_from_index(1, namespace='resolve_label_root_label')

In [None]:
# Hand-written nested fixture used for inspection in the following cell.
# Every entry is a 4-item list: [int, bool, str, list-of-child-entries].
#   * entries whose bool is False carry a word and an empty child list;
#   * entries whose bool is True carry a short label and one or more children.
# Read left to right, the words spell: "The Lakers advanced through the 1982
# playoffs and faced Philadelphia for the second time in three years".
# NOTE(review): presumably the fields are (node id, is-resolved flag,
# edge label or token text, children) -- confirm against the parser state code.
token_state = [
    [26, True, 'H', [
        [23, True, 'A', [
            [0, True, 'E', [[0, False, 'The', []]]],
            [1, True, 'C', [[1, False, 'Lakers', []]]],
        ]],
        [2, True, 'P', [[2, False, 'advanced', []]]],
        [25, True, 'A', [
            [3, True, 'R', [[3, False, 'through', []]]],
            [4, True, 'E', [[4, False, 'the', []]]],
            [24, True, 'P', [
                [5, True, 'T', [[5, False, '1982', []]]],
                [6, True, 'C', [[6, False, 'playoffs', []]]],
            ]],
        ]],
    ]],
    [7, True, 'L', [[7, False, 'and', []]]],
    [8, True, 'P', [[8, False, 'faced', []]]],
    [9, True, 'A', [[9, False, 'Philadelphia', []]]],
    [27, True, 'D', [
        [10, True, 'R', [[10, False, 'for', []]]],
        [11, True, 'E', [[11, False, 'the', []]]],
        [12, True, 'Q', [[12, False, 'second', []]]],
        [13, True, 'C', [[13, False, 'time', []]]],
    ]],
    [14, True, 'R', [[14, False, 'in', []]]],
    [15, True, 'Q', [[15, False, 'three', []]]],
    [16, False, 'years', []],
]

In [None]:
# Pretty-print the nested `token_state` list so its tree structure is readable.
# Needs `pprint` imported and `token_state` already defined in the kernel session.
pprint.pprint(token_state)

In [None]:
# Scratch lookup: show what index 0 maps to in the 'labels' namespace.
# `vocab` is defined outside this view (presumably an AllenNLP Vocabulary -- TODO confirm).
vocab.get_token_from_index(0, namespace='labels')

In [None]:
# Scratch lookup in the opposite direction: the index assigned to the string 'RESOLVE'.
# NOTE(review): no namespace is passed here, unlike the neighbouring lookups, so
# whatever default namespace `vocab` uses applies -- confirm that is intended
# (an action-related namespace may have been meant).
vocab.get_token_index('RESOLVE')

In [None]:
# Scratch lookup: show what index 3 maps to in the 'resolved' namespace.
# `vocab` is defined outside this view (presumably an AllenNLP Vocabulary -- TODO confirm).
vocab.get_token_from_index(3, namespace='resolved')

In [None]:
# Scratch lookup: show what index 5 maps to in the 'token_node_prev_action' namespace.
# `vocab` is defined outside this view (presumably an AllenNLP Vocabulary -- TODO confirm).
vocab.get_token_from_index(5, namespace='token_node_prev_action')