In [1]:
try:
    __IPYTHON__
    USING_IPYTHON = True
    %load_ext autoreload
    %autoreload 2
except NameError:
    USING_IPYTHON = False

#### Argparse

In [2]:
import argparse
ap = argparse.ArgumentParser()
ap.add_argument('project_root', help='')
ap.add_argument('--mrp-data-dir', default='data', help='')
ap.add_argument('--mrp-test-dir', default='src/tests', help='')
ap.add_argument('--tests-fixtures-file-template', default='fixtures/{}-test.jsonl', help='')

ap.add_argument('--graphviz-sub-dir', default='visualization/graphviz', help='')
ap.add_argument('--train-sub-dir', default='training', help='')
ap.add_argument('--companion-sub-dir', default='companion')
ap.add_argument('--jamr-alignment-file', default='jamr.mrp')

ap.add_argument('--test-input-file', default='evaluation/input.mrp', help='')
ap.add_argument('--test-companion-file', default='evaluation/udpipe.mrp', help='')
ap.add_argument('--allennlp-mrp-json-file-template', default='allennlp-mrp-json-small-{}-{}.jsonl', help='')
ap.add_argument('--data-size-limit', type=int, default=1000, help='')

ap.add_argument('--mrp-file-extension', default='.mrp')
ap.add_argument('--companion-file-extension', default='.conllu')
ap.add_argument('--graphviz-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png')
ap.add_argument('--parse-plot-file-template', default='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png')

ap.add_argument('--cuda-device', type=int, default=0)


arg_string = """
    /data/proj29_ds1/home/slai/mrp2019
"""
arguments = [arg for arg_line in arg_string.split(r'\\n') for arg in arg_line.split()]

In [3]:
if USING_IPYTHON:
    args = ap.parse_args(arguments)
else:
    args = ap.parse_args()

In [4]:
args

Namespace(allennlp_mrp_json_file_template='allennlp-mrp-json-small-{}-{}.jsonl', companion_file_extension='.conllu', companion_sub_dir='companion', cuda_device=0, data_size_limit=1000, graphviz_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.mrp/{}.png', graphviz_sub_dir='visualization/graphviz', jamr_alignment_file='jamr.mrp', mrp_data_dir='data', mrp_file_extension='.mrp', mrp_test_dir='src/tests', parse_plot_file_template='http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/{}/{}.png', project_root='/data/proj29_ds1/home/slai/mrp2019', test_companion_file='evaluation/udpipe.mrp', test_input_file='evaluation/input.mrp', tests_fixtures_file_template='fixtures/{}-test.jsonl', train_sub_dir='training')

#### Library imports

In [5]:
import json
import logging
import os
import pprint
import re
import string
from collections import Counter, defaultdict, deque

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plot_util
import torch
from action_state import mrp_json2parser_states, _generate_parser_action_states
from action_state import ERROR, APPEND, RESOLVE, IGNORE
from preprocessing import (CompanionParseDataset, MrpDataset, JamrAlignmentDataset,
                           read_companion_parse_json_file, read_mrp_json_file, parse2parse_json)            
from torch import nn
from tqdm import tqdm

#### ipython notebook specific imports

In [6]:
if USING_IPYTHON:
    # matplotlib config
    %matplotlib inline

In [7]:
sh = logging.StreamHandler()
formatter = logging.Formatter('%(levelname)-8s [%(name)s:%(lineno)d] %(message)s')
sh.setFormatter(formatter)
logging.basicConfig(
    level=logging.DEBUG, 
    handlers=[sh]
)
mute_logger_names = ['allennlp.data.iterators.data_iterator']
for logger_name in mute_logger_names:
    logger = logging.getLogger(logger_name)  # pylint: disable=invalid-name
    logger.setLevel(logging.INFO)

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
logger.setLevel(logging.INFO)

### Constants

In [8]:
UNKWOWN = 'UNKWOWN'

### Load data

In [9]:
train_dir = os.path.join(args.project_root, args.mrp_data_dir, args.train_sub_dir)

In [10]:
mrp_dataset = MrpDataset()

In [11]:
frameworks, framework2dataset2mrp_jsons = mrp_dataset.load_mrp_json_dir(
    train_dir, args.mrp_file_extension)

frameworks:   0%|          | 0/5 [00:00<?, ?it/s]
dataset_name:   0%|          | 0/2 [00:00<?, ?it/s][A
dataset_name:  50%|█████     | 1/2 [00:00<00:00,  2.26it/s][A
frameworks:  20%|██        | 1/5 [00:00<00:03,  1.30it/s]s][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  40%|████      | 2/5 [00:04<00:05,  1.80s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  60%|██████    | 3/5 [00:10<00:06,  3.02s/it]t][A
dataset_name:   0%|          | 0/1 [00:00<?, ?it/s][A
frameworks:  80%|████████  | 4/5 [00:16<00:03,  3.87s/it]t][A
dataset_name:   0%|          | 0/14 [00:00<?, ?it/s][A
dataset_name:  36%|███▌      | 5/14 [00:00<00:00, 48.30it/s][A
dataset_name:  50%|█████     | 7/14 [00:00<00:00, 19.18it/s][A
dataset_name:  64%|██████▍   | 9/14 [00:00<00:00, 16.33it/s][A
dataset_name:  79%|███████▊  | 11/14 [00:01<00:00,  5.05it/s][A
frameworks: 100%|██████████| 5/5 [00:18<00:00,  3.20s/it]t/s][A


In [12]:
framework2dataset2mrp_jsons.keys()

dict_keys(['ucca', 'psd', 'eds', 'dm', 'amr'])

### Data Preprocessing companion

In [13]:
companion_dir = os.path.join(args.project_root, args.mrp_data_dir, args.companion_sub_dir)

In [14]:
cparse_dataset = CompanionParseDataset()

In [15]:
dataset2cid2parse = cparse_dataset.load_companion_parse_dir(companion_dir, args.companion_file_extension)

INFO     [preprocessing:179] framework amr found
dataset: 100%|██████████| 13/13 [00:01<00:00,  9.58it/s]
INFO     [preprocessing:179] framework dm found
dataset: 100%|██████████| 5/5 [00:04<00:00,  1.06it/s]
INFO     [preprocessing:179] framework ucca found
dataset: 100%|██████████| 6/6 [00:00<00:00, 25.39it/s]


In [16]:
dataset2cid2parse_json = cparse_dataset.convert_parse2parse_json()

In [17]:
dataset2cid2parse.keys()

dict_keys(['amr-guidelines', 'bolt', 'cctv', 'dfa', 'dfb', 'fables', 'lorelei', 'mt09sdl', 'proxy', 'rte', 'wb', 'wiki', 'xinhua', 'wsj', 'ewt'])

In [18]:
# Some data is missing
'20003001' in dataset2cid2parse['wsj']

False

### Load JAMR alignment data

In [19]:
jalignment_dataset = JamrAlignmentDataset()

In [20]:
cid2alignment = jalignment_dataset.load_jamr_alignment_file(os.path.join(
    args.project_root,
    args.mrp_data_dir,
    args.companion_sub_dir,
    args.jamr_alignment_file
))

### Load testing data

In [21]:
test_input_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_input_file)
test_companion_filename = os.path.join(args.project_root, args.mrp_data_dir, args.test_companion_file)

In [22]:
test_mrp_jsons = read_mrp_json_file(test_input_filename)
test_parse_jsons = read_companion_parse_json_file(test_companion_filename)

In [23]:
parse_json = test_parse_jsons['102990']

In [24]:
mrp_json = framework2dataset2mrp_jsons['psd']['wsj'][1]

In [25]:
test_configs = [
    ('ucca', 'wiki', 70),
]
framework, dataset, idx = test_configs[0]

In [26]:
mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
cid = mrp_json.get('id')

In [27]:
parse_json = dataset2cid2parse_json[dataset][cid]

In [28]:
doc = mrp_json['input']

In [29]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [30]:
token_pos = 0
anchors = []
char_pos2tokenized_parse_node_id = []

for node_id, node in enumerate(parse_json.get('nodes')):
    label = node.get('label')
    label_size = len(label)
    while doc[token_pos] == ' ':
        token_pos += 1
        char_pos2tokenized_parse_node_id.append(node_id)
    anchors.append((token_pos, token_pos + label_size))
    char_pos2tokenized_parse_node_id.extend([node_id] * (label_size))
    print(node_id, doc[token_pos: token_pos + label_size], anchors[-1], len(char_pos2tokenized_parse_node_id))
    token_pos += label_size

0 In (0, 2) 2
1 the (3, 6) 6
2 final (7, 12) 12
3 minute (13, 19) 19
4 of (20, 22) 22
5 the (23, 26) 26
6 game (27, 31) 31
7 , (31, 32) 32
8 Johnson (33, 40) 40
9 had (41, 44) 44
10 the (45, 48) 48
11 ball (49, 53) 53
12 stolen (54, 60) 60
13 by (61, 63) 63
14 Celtics (64, 71) 71
15 center (72, 78) 78
16 Robert (79, 85) 85
17 Parish (86, 92) 92
18 , (92, 93) 93
19 and (94, 97) 97
20 then (98, 102) 102
21 missed (103, 109) 109
22 two (110, 113) 113
23 free (114, 118) 118
24 throws (119, 125) 125
25 that (126, 130) 130
26 could (131, 136) 136
27 have (137, 141) 141
28 won (142, 145) 145
29 the (146, 149) 149
30 game (150, 154) 154
31 . (154, 155) 155


In [31]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [32]:
len(char_pos2tokenized_parse_node_id)

155

In [33]:
doc = mrp_json['input']

In [34]:
mrp_json['tops']

[34]

In [35]:
mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
    mrp_json, 
    tokenized_parse_nodes=parse_json['nodes'],
)

In [36]:
(
    doc,
    nodes,
    node_id2node,
    edge_id2edge,
    top_oriented_edges,
    token_nodes,
    # abstract_node_id_set,
    parent_id2indegree,
    # parent_id2child_id_set,
    # child_id2parent_id_set,
    # child_id2edge_id_set,
    # parent_id2edge_id_set,
    # parent_child_id2edge_id_set,
    parse_nodes_anchors,
    char_pos2tokenized_node_id,
    curr_node_ids,
    token_states,
    actions,
) = mrp_meta_data

In [37]:
curr_node_ids = mrp_meta_data[-3]
token_states = mrp_meta_data[-2]
actions = mrp_meta_data[-1]

In [38]:
*_, curr_node_ids, token_states, actions = mrp_meta_data

In [39]:
actions[:4]

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]]))]

In [40]:
for curr_node_id, action, token_state in zip(curr_node_ids, actions, token_states):
    action_type, params = action
#     pprint.pprint((curr_node_id, action[0]))
#     pprint.pprint(([token_group[:4] for token_group in token_state]))
    pprint.pprint((curr_node_id, action[0], [token_group[:4] for token_group in token_state]))

(0, 0, [(0, False, 'In', [])])
(1, 1, [(0, True, 'R', [(0, False, 'In', [])])])
(1, 0, [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(2,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(3,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 1,
 [(32,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', [])]),
    (1, True, 'E', [(1, False, 'the', [])]),
    (2, True, 'E', [(

       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9, False, 'had', [])]),
    (34,
     True,
     'A',
     [(10, True, 'E', [(10, False, 'the', [])]),
      (11, True, 'C', [(11, False, 'ball', [])])]),
    (12, True, 'P', [(12, False, 'stolen', [])]),
    (36,
     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [])])])]),
  (18, True, 'U', [(18, False, 

    (24, True, 'P', [(24, False, 'throws', [])])]),
  (25, True, 'L', [(25, False, 'that', [])]),
  (26, True, 'D', [(26, False, 'could', [])]),
  (27, True, 'F', [(27, False, 'have', [])]),
  (28, True, 'P', [(28, False, 'won', [])]),
  (39,
   True,
   'A',
   [(29, True, 'E', [(29, False, 'the', [])]),
    (30, True, 'C', [(30, False, 'game', [])]),
    (31, True, 'U', [(31, False, '.', [])])])])
(32,
 1,
 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9, False, 'had', [])]),
    (34,
     Tr

In [41]:
for curr_node_id, action, token_state in zip(curr_node_ids, actions, [[]] + token_states):
    action_type, params = action
    pprint.pprint((curr_node_id, action[0], [token_group[:4] for token_group in token_state]))

(0, 0, [])
(1, 1, [(0, False, 'In', [])])
(1, 0, [(0, True, 'R', [(0, False, 'In', [])])])
(2, 1, [(0, True, 'R', [(0, False, 'In', [])]), (1, False, 'the', [])])
(2,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])])])
(3,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, False, 'final', [])])
(3,
 0,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, False, 'minute', [])])
(4,
 1,
 [(0, True, 'R', [(0, False, 'In', [])]),
  (1, True, 'E', [(1, False, 'the', [])]),
  (2, True, 'E', [(2, False, 'final', [])]),
  (3, True, 'C', [(3, False, 'minute', [])])])
(4,
 0,
 [(32,
   True,
   'E',
   [(0, True, 'R', [(0, False, 'In', [])]),
    (1, True, 'E', [(1, False, 'the', [])]),
    (2, Tr

   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In', [])]),
        (1, True, 'E', [(1, False, 'the', [])]),
        (2, True, 'E', [(2, False, 'final', [])]),
        (3, True, 'C', [(3, False, 'minute', [])])]),
      (4, True, 'R', [(4, False, 'of', [])]),
      (5, True, 'E', [(5, False, 'the', [])]),
      (6, True, 'C', [(6, False, 'game', [])])]),
    (7, True, 'U', [(7, False, ',', [])]),
    (8, True, 'A', [(8, False, 'Johnson', [])]),
    (9, True, 'F', [(9, False, 'had', [])]),
    (34,
     True,
     'A',
     [(10, True, 'E', [(10, False, 'the', [])]),
      (11, True, 'C', [(11, False, 'ball', [])])]),
    (12, True, 'P', [(12, False, 'stolen', [])]),
    (36,
     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [

     [(10, True, 'E', [(10, False, 'the', [])]),
      (11, True, 'C', [(11, False, 'ball', [])])]),
    (12, True, 'P', [(12, False, 'stolen', [])]),
    (36,
     True,
     'A',
     [(13, True, 'R', [(13, False, 'by', [])]),
      (35,
       True,
       'E',
       [(14, True, 'A', [(14, False, 'Celtics', [])]),
        (15, True, 'S', [(15, False, 'center', [])])]),
      (16, True, 'C', [(16, False, 'Robert', [])])])]),
  (18, True, 'U', [(18, False, ',', [])]),
  (19, True, 'L', [(19, False, 'and', [])]),
  (20, True, 'L', [(20, False, 'then', [])]),
  (38,
   True,
   'H',
   [(21, True, 'D', [(21, False, 'missed', [])]),
    (22, True, 'D', [(22, False, 'two', [])]),
    (23, True, 'D', [(23, False, 'free', [])]),
    (24, True, 'P', [(24, False, 'throws', [])])]),
  (25, True, 'L', [(25, False, 'that', [])]),
  (26, False, 'could', [])])
(27,
 0,
 [(37,
   True,
   'H',
   [(33,
     True,
     'T',
     [(32,
       True,
       'E',
       [(0, True, 'R', [(0, False, 'In'

In [42]:
actions

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'anchors': [{'from': 0, 'to': 2}],
    'label': 'In',
    'propagate_label': 'R'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'anchors': [{'from': 3, 'to': 6}],
    'label': 'the',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 2,
    'anchors': [{'from': 7, 'to': 12}],
    'label': 'final',
    'propagate_label': 'E'},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 3,
    'anchors': [{'from': 13, 'to': 19}],
    'label': 'minute',
    'propagate_label': 'C'},
   [[]])),
 (1,
  (4,
   3,
   {'id': 31, 'propagate_label': 'E'},
   [[{'source': 31,
      'target': 0,
      'label': 'R',
      'id': 31,
      'parent': 31,
      'child': 0}],
    [{'source': 31,
      'target': 1,
      'label': 'E',
      'id': 28,
      'parent': 31,
      'child': 1}],
    [{'source': 31,
      'target': 2,
      'label': 'E',
      'id': 21,
      'parent': 31,
      'child': 2}],
    [{'source': 31,
      't

In [43]:
token_states[1]

[(0, True, 'R', [(0, False, 'In', [])])]

In [44]:
[n['label'] for n in parse_json['nodes']]

['In',
 'the',
 'final',
 'minute',
 'of',
 'the',
 'game',
 ',',
 'Johnson',
 'had',
 'the',
 'ball',
 'stolen',
 'by',
 'Celtics',
 'center',
 'Robert',
 'Parish',
 ',',
 'and',
 'then',
 'missed',
 'two',
 'free',
 'throws',
 'that',
 'could',
 'have',
 'won',
 'the',
 'game',
 '.']

In [45]:
token_states[-1]

[(42,
  True,
  '<UCCA-TOP-NODE>',
  [(37,
    True,
    'H',
    [(33,
      True,
      'T',
      [(32,
        True,
        'E',
        [(0, True, 'R', [(0, False, 'In', [])]),
         (1, True, 'E', [(1, False, 'the', [])]),
         (2, True, 'E', [(2, False, 'final', [])]),
         (3, True, 'C', [(3, False, 'minute', [])])]),
       (4, True, 'R', [(4, False, 'of', [])]),
       (5, True, 'E', [(5, False, 'the', [])]),
       (6, True, 'C', [(6, False, 'game', [])])]),
     (7, True, 'U', [(7, False, ',', [])]),
     (8, True, 'A', [(8, False, 'Johnson', [])]),
     (9, True, 'F', [(9, False, 'had', [])]),
     (34,
      True,
      'A',
      [(10, True, 'E', [(10, False, 'the', [])]),
       (11, True, 'C', [(11, False, 'ball', [])])]),
     (12, True, 'P', [(12, False, 'stolen', [])]),
     (36,
      True,
      'A',
      [(13, True, 'R', [(13, False, 'by', [])]),
       (35,
        True,
        'E',
        [(14, True, 'A', [(14, False, 'Celtics', [])]),
         (

In [46]:
companion_parser_states, companion_meta_data = mrp_json2parser_states(
    parse_json,
    mrp_doc=doc,
    tokenized_parse_nodes=parse_json['nodes'],
)

In [47]:
logger.info(args.graphviz_file_template.format(
    framework, dataset, cid))

INFO     [__main__:2] http://localhost:8000/files/proj29_ds1/home/slai/mrp2019/visualization/graphviz/ucca/wiki.mrp/470004.png


In [48]:
mrp_json['input']

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [49]:
mrp_parser_states

[(0,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 0,
      'anchors': [{'from': 0, 'to': 2}],
      'label': 'In',
      'propagate_label': 'R'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')])]),
 (1,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 1,
      'anchors': [{'from': 3, 'to': 6}],
      'label': 'the',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')])]),
 (2,
  [(0, None),
   (1,
    (1,
     0,
     {'id': 2,
      'anchors': [{'from': 7, 'to': 12}],
      'label': 'final',
      'propagate_label': 'E'},
     [[]]))],
  [],
  [],
  [],
  [(0, 0, [(0, 0, None)]), (1, 1, [(1, 1, None)]), (2, 2, [(2, 2, None)])],
  [(0, True, 'R', [(0, False, 'In', 'In')]),
   (1, True, 'E', [(1, False, 'the', 'the')]),
   (2, True, 'E', [(2, False, 'final', 'final')])]),
 (3,
  

In [50]:
[(node['id'], node.get('label')) for node in mrp_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'RobertParish'),
 (17, ','),
 (18, 'and'),
 (19, 'then'),
 (20, 'missed'),
 (21, 'two'),
 (22, 'free'),
 (23, 'throws'),
 (24, 'that'),
 (25, 'could'),
 (26, 'have'),
 (27, 'won'),
 (28, 'the'),
 (29, 'game'),
 (30, '.'),
 (31, None),
 (32, None),
 (33, None),
 (34, None),
 (35, None),
 (36, None),
 (37, None),
 (38, None),
 (39, None),
 (40, None),
 (41, None)]

In [51]:
doc

'In the final minute of the game, Johnson had the ball stolen by Celtics center Robert Parish, and then missed two free throws that could have won the game.'

In [52]:
parse_json['nodes']

[{'id': 0,
  'label': 'In',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['in', 'ADP', 'IN']},
 {'id': 1,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 2,
  'label': 'final',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['final', 'ADJ', 'JJ']},
 {'id': 3,
  'label': 'minute',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['minute', 'NOUN', 'NN']},
 {'id': 4,
  'label': 'of',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['of', 'ADP', 'IN']},
 {'id': 5,
  'label': 'the',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['the', 'DET', 'DT']},
 {'id': 6,
  'label': 'game',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['game', 'NOUN', 'NN']},
 {'id': 7,
  'label': ',',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': [',', 'PUNCT', ',']},
 {'id': 8,
  'label': 'Johnson',
  'properties': ['lemma', 'upos', 'xpos'],
  'values': ['Johnson', 'PROPN', 'NNP']},
 {'id': 9,
  'label

In [53]:
[(node['id'], node['label']) for node in parse_json['nodes']]

[(0, 'In'),
 (1, 'the'),
 (2, 'final'),
 (3, 'minute'),
 (4, 'of'),
 (5, 'the'),
 (6, 'game'),
 (7, ','),
 (8, 'Johnson'),
 (9, 'had'),
 (10, 'the'),
 (11, 'ball'),
 (12, 'stolen'),
 (13, 'by'),
 (14, 'Celtics'),
 (15, 'center'),
 (16, 'Robert'),
 (17, 'Parish'),
 (18, ','),
 (19, 'and'),
 (20, 'then'),
 (21, 'missed'),
 (22, 'two'),
 (23, 'free'),
 (24, 'throws'),
 (25, 'that'),
 (26, 'could'),
 (27, 'have'),
 (28, 'won'),
 (29, 'the'),
 (30, 'game'),
 (31, '.')]

In [54]:
anchors

[(0, 2),
 (3, 6),
 (7, 12),
 (13, 19),
 (20, 22),
 (23, 26),
 (27, 31),
 (31, 32),
 (33, 40),
 (41, 44),
 (45, 48),
 (49, 53),
 (54, 60),
 (61, 63),
 (64, 71),
 (72, 78),
 (79, 85),
 (86, 92),
 (92, 93),
 (94, 97),
 (98, 102),
 (103, 109),
 (110, 113),
 (114, 118),
 (119, 125),
 (126, 130),
 (131, 136),
 (137, 141),
 (142, 145),
 (146, 149),
 (150, 154),
 (154, 155)]

### Create training instance

In [55]:
# framework = 'ucca'
# ignore_framework_set = {'amr', 'dm', 'psd', 'eds'}
# dataset = 'wiki'
# ignore_dataset_set = {}

# framework = 'dm'
# ignore_framework_set = {'amr', 'ucca', 'psd', 'eds'}
# dataset = 'wsj'
# ignore_dataset_set = {}

framework = 'ucca'
ignore_framework_set = {'amr', 'psd', 'eds'}
dataset = 'wiki'
ignore_dataset_set = {}

In [56]:
frameworks

['ucca', 'psd', 'eds', 'dm', 'amr']

In [57]:
framework_names = '-'.join([
    framework 
    for framework in frameworks 
    if framework not in ignore_framework_set
])
framework_names

'ucca-dm'

In [58]:
allennlp_tests_fixtures_output_file = os.path.join(
    args.project_root, args.mrp_test_dir, args.tests_fixtures_file_template.format(framework_names))

allennlp_framework_train_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format(framework_names, 'train'))

allennlp_framework_test_output_file = os.path.join(
    args.project_root, args.allennlp_mrp_json_file_template.format(framework_names, 'test'))

In [59]:
# Create tests fixture jsonl
fixture_combinations = [
    ('ucca', 'wiki', 70),
    ('dm', 'wsj', 3)
] * 5

with open(allennlp_tests_fixtures_output_file, 'w') as wf:
    for framework, dataset, idx in fixture_combinations:
        mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
        cid = mrp_json.get('id')
        doc = mrp_json.get('input')
        
        alignment = {}
        if framework == 'amr':
            alignment = cid2alignment[cid]  
        parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

        if parse_json:
            mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                mrp_json, 
                tokenized_parse_nodes=parse_json['nodes'],
                alignment=alignment,
            )
            companion_parser_states, companion_meta_data = mrp_json2parser_states(
                parse_json, 
                mrp_doc=doc,
                tokenized_parse_nodes=parse_json['nodes'],
            )

            data_instance = {
                'mrp_json': mrp_json,
                'parse_json': parse_json,
                'mrp_parser_states': mrp_parser_states,
                'mrp_meta_data': mrp_meta_data,
                'companion_parser_states': companion_parser_states,
                'companion_meta_data': companion_meta_data,
            }
            json_encoded_instance = json.dumps(data_instance)
            wf.write(json_encoded_instance + '\n')

In [60]:
for idx in range(20):
    mrp_json = framework2dataset2mrp_jsons[framework][dataset][idx]
    cid = mrp_json.get('id')
    if cid in dataset2cid2parse[dataset]:
        print(idx)

3
8
13
18


In [61]:
[state[-1] for state in mrp_parser_states]

[[(0, True, 'the', [(0, False, 'the', 'The')])],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')])],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
  (2, False, 'fiber', 'fiber')],
 [(0, True, 'the', [(0, False, 'the', 'The')]),
  (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
  (2, False, 'fiber', 'fiber'),
  (4, True, 'crocidolite', [(4, False, 'crocidolite', 'crocidolite')])],
 [(2,
   True,
   'fiber',
   [(0, True, 'the', [(0, False, 'the', 'The')]),
    (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
    (2, False, 'fiber', 'fiber'),
    (4, True, 'crocidolite', [(4, False, 'crocidolite', 'crocidolite')])])],
 [(2,
   True,
   'fiber',
   [(0, True, 'the', [(0, False, 'the', 'The')]),
    (1, True, 'asbestos', [(1, False, 'asbestos', 'asbestos')]),
    (2, False, 'fiber', 'fiber'),
    (4, True, 'crocidolite', [(4, False, 'crocidol

In [62]:
mrp_meta_data[-1]

[(0, None),
 (1,
  (1,
   0,
   {'id': 0,
    'label': 'the',
    'properties': ['pos', 'frame'],
    'values': ['DT', 'q:i-h-h'],
    'anchors': [{'from': 0, 'to': 3}]},
   [[]])),
 (0, None),
 (1,
  (1,
   0,
   {'id': 1,
    'label': 'asbestos',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 4, 'to': 12}]},
   [[]])),
 (0, None),
 (2, None),
 (0, None),
 (1,
  (1,
   0,
   {'id': 4,
    'label': 'crocidolite',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 20, 'to': 31}]},
   [[]])),
 (1,
  (4,
   2,
   {'id': 2,
    'label': 'fiber',
    'properties': ['pos', 'frame'],
    'values': ['NN', 'n:x'],
    'anchors': [{'from': 13, 'to': 18}]},
   [[{'source': 0,
      'target': 2,
      'label': 'BV',
      'id': 2,
      'parent': 2,
      'child': 0}],
    [{'source': 1,
      'target': 2,
      'label': 'compound',
      'id': 1,
      'parent': 2,
      'child': 1}],
    [],
    [{'source': 4,
    

In [63]:
doc

'The asbestos fiber, crocidolite, is unusually resilient once it enters the lungs, with even brief exposures to it causing symptoms that show up decades later, researchers said.'

In [64]:
parse_json

{'id': '20003002',
 'tops': [30],
 'nodes': [{'id': 0,
   'label': 'The',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['the', 'DET', 'DT']},
  {'id': 1,
   'label': 'asbestos',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['asbestos', 'NOUN', 'NN']},
  {'id': 2,
   'label': 'fiber',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['fiber', 'NOUN', 'NN']},
  {'id': 3,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 4,
   'label': 'crocidolite',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['crocidolite', 'NOUN', 'NN']},
  {'id': 5,
   'label': ',',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': [',', 'PUNCT', ',']},
  {'id': 6,
   'label': 'is',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['be', 'VERB', 'VBZ']},
  {'id': 7,
   'label': 'unusually',
   'properties': ['lemma', 'upos', 'xpos'],
   'values': ['unusually', 'ADV', 'RB']},
  {'id': 8,
   'label': 'resil

In [65]:
[n['values'][2] for n in parse_json['nodes']]

['DT',
 'NN',
 'NN',
 ',',
 'NN',
 ',',
 'VBZ',
 'RB',
 'JJ',
 'IN',
 'PRP',
 'VBZ',
 'DT',
 'NNS',
 ',',
 'IN',
 'RB',
 'JJ',
 'NNS',
 'TO',
 'PRP',
 'VBG',
 'NNS',
 'WDT',
 'VBP',
 'RP',
 'NNS',
 'RB',
 ',',
 'NNS',
 'VBD',
 '.']

In [66]:
# Create train jsonl
if os.path.isfile(allennlp_framework_train_output_file) and os.path.isfile(
    allennlp_framework_train_output_file):
    logger.info('allennlp_train_output_file found, stop generation')
else:
    pass
if 1==1:
    data_size = 0
    with open(allennlp_framework_train_output_file, 'w') as train_wf:
        with open(allennlp_framework_test_output_file, 'w') as test_wf:
            for _, dataset, idx, mrp_json in tqdm(mrp_dataset.mrp_json_generator(
                ignore_framework_set=ignore_framework_set,
                ignore_dataset_set=ignore_dataset_set,
                data_size_limit=args.data_size_limit * 2
            )):
                cid = mrp_json.get('id')
                doc = mrp_json.get('input')

                framework = mrp_json.get('framework')
                alignment = {}
                if framework == 'amr':
                    alignment = cid2alignment[cid]  
                parse_json = dataset2cid2parse_json.get(dataset, {}).get(cid, {})

                if parse_json:
                    mrp_parser_states, mrp_meta_data = mrp_json2parser_states(
                        mrp_json, 
                        tokenized_parse_nodes=parse_json['nodes'],
                        alignment=alignment,
                    )
                    companion_parser_states, companion_meta_data = mrp_json2parser_states(
                        parse_json, 
                        mrp_doc=doc,
                        tokenized_parse_nodes=parse_json['nodes'],
                    )

                    # Continue if error
                    if not mrp_parser_states:
                        continue

                    data_instance = {
                        'mrp_json': mrp_json,
                        'parse_json': parse_json,
                        'mrp_parser_states': mrp_parser_states,
                        'mrp_meta_data': mrp_meta_data,
                        'companion_parser_states': companion_parser_states,
                        'companion_meta_data': companion_meta_data,
                    }
                    json_encoded_instance = json.dumps(data_instance)
                    if idx <= args.data_size_limit:
                        train_wf.write(json_encoded_instance + '\n')
                    else:
                        test_wf.write(json_encoded_instance + '\n')

INFO     [__main__:4] allennlp_train_output_file found, stop generation




5662it [07:16, 25.98it/s]


### Test allennlp dataset reader

In [67]:
import torch.optim as optim

from mrp_library.dataset_readers.mrp_jsons import MRPDatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.feedforward import FeedForward

from allennlp.training.metrics import CategoricalAccuracy

from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer

import json
import logging
from typing import Dict

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.models import Model
from overrides import overrides

INFO     [pytorch_pretrained_bert.modeling:230] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>
DEBUG    [allennlp.common.registrable:56] instantiating registered subclass relu of <class 'allennlp.nn.activations.Activation'>


In [68]:
from mrp_library.dataset_readers.mrp_jsons_actions import MRPDatasetActionReader

In [69]:
reader = MRPDatasetActionReader()

In [70]:
train_dataset = reader.read(cached_path(allennlp_framework_train_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-ucca-dm-train.jsonl
49005it [00:19, 2577.62it/s]


In [71]:
# test_dataset = reader.read(cached_path(allennlp_train_output_file))
test_dataset = reader.read(cached_path(allennlp_framework_test_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/allennlp-mrp-json-small-ucca-dm-test.jsonl
35907it [00:15, 2282.11it/s]


In [72]:
tests_fixtures_dataset = reader.read(cached_path(allennlp_tests_fixtures_output_file))

0it [00:00, ?it/s]INFO     [mrp_library.dataset_readers.mrp_jsons_actions:113] Reading instances from lines in file at: /data/proj29_ds1/home/slai/mrp2019/src/tests/fixtures/ucca-dm-test.jsonl
965it [00:00, 5526.64it/s]


In [73]:
vocab = Vocabulary.from_instances(train_dataset + test_dataset + tests_fixtures_dataset)

INFO     [allennlp.data.vocabulary:396] Fitting token dictionary from dataset.
100%|██████████| 85877/85877 [00:09<00:00, 9301.44it/s] 


In [74]:
vocab.print_statistics()

INFO     [allennlp.data.vocabulary:664] Printed vocabulary statistics are only for the part of the vocabulary generated from instances. If vocabulary is constructed by extending saved vocabulary with dataset instances, the directly loaded portion won't be considered here.




----Vocabulary Statistics----


Top 10 most frequent tokens in namespace 'word':
	Token: <START-WORD>		Frequency: 858770
	Token: <END-WORD>		Frequency: 858770
	Token: ,		Frequency: 297964
	Token: the		Frequency: 238314
	Token: .		Frequency: 171118
	Token: and		Frequency: 117392
	Token: in		Frequency: 111534
	Token: a		Frequency: 108434
	Token: of		Frequency: 105518
	Token: to		Frequency: 82726

Top 10 longest tokens in namespace 'word':
	Token: Bridgestone/Firestone		length: 21	Frequency: 69
	Token: Bridgestone/fiRestone		length: 21	Frequency: 69
	Token: dollar-denominated		length: 18	Frequency: 70
	Token: Corton-Charlemagne		length: 18	Frequency: 42
	Token: Corton-CHARlemagne		length: 18	Frequency: 42
	Token: substance-abusing		length: 17	Frequency: 152
	Token: extraterrestrial		length: 16	Frequency: 188
	Token: sesquicentennial		length: 16	Frequency: 178
	Token: price-depressing		length: 16	Frequency: 178
	Token: interest-bearing		length: 16	Frequency: 174

Top 10 shortest tokens i

In [75]:
vocab.get_vocab_size('token_node_label')

5563

In [76]:
vocab.get_vocab_size('word')

6601

In [77]:
vocab.get_vocab_size('pos')

64

In [78]:
vocab.get_vocab_size('label')

2

In [79]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 50

### Test model

In [80]:
from mrp_library.models.generalizer import ActionGeneralizer
from mrp_library.iterators.same_instance_type_framework_stack_len_iterator import SameInstanceTypeFrameworkStackLenIterator

from allennlp.nn import InitializerApplicator, RegularizerApplicator, util
from allennlp.nn.activations import Activation
from allennlp.common.params import Params
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.modules.seq2vec_encoders.pytorch_seq2vec_wrapper import PytorchSeq2VecWrapper

In [81]:
if torch.cuda.is_available() and False:
    cuda_device = args.cuda_device
else:
    cuda_device = -1
cuda_device

-1

In [82]:
def _cuda(variable, cuda_device):
    if cuda_device != -1:
        variable = variable.cuda(cuda_device)
    return variable

In [83]:
field_types = ['word', 'pos', 'resolved', 'token_node_label', 'token_node_prev_action']
field_type2embedder = {}
field_type2seq2vec_encoder = {}
field_type2seq2seq_encoder = {}

for field_type in field_types:
    embedding = _cuda(
        Embedding(
            num_embeddings=vocab.get_vocab_size(field_type),
            embedding_dim=EMBEDDING_DIM
        ), cuda_device)
    embedder = BasicTextFieldEmbedder({field_type: embedding})
    field_type2embedder[field_type] = embedder
    
    field_type2seq2vec_encoder[field_type] = PytorchSeq2VecWrapper(
        _cuda(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True), cuda_device)
    )
    field_type2seq2seq_encoder[field_type] = PytorchSeq2SeqWrapper(
        _cuda(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True), cuda_device)
    )

In [84]:
# word_embedding = Embedding(num_embeddings=vocab.get_vocab_size('word'),
#                             embedding_dim=EMBEDDING_DIM)
# pos_embedding = Embedding(num_embeddings=vocab.get_vocab_size('pos'),
#                             embedding_dim=EMBEDDING_DIM)

# word_embedder = BasicTextFieldEmbedder({
#     "word": word_embedding,
#     "pos": pos_embedding,
# })
# parse_label = {
#     'word': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     ),
#     'pos': torch.LongTensor(
#         [
#             [ 1,  0,  3,  7,  2,  9,  4],
#             [ 0,  0,  5,  0,  0,  0,  4]
#         ]
#     )
# }

In [85]:
# embedded_parse_label = word_embedder(parse_label)

In [86]:
# embedded_parse_label.shape

In [87]:
action_classifier_params = Params({
  "input_dim": HIDDEN_DIM * 3,
  "num_layers": 2,
  "hidden_dims": [50, 3],
  "activations": ["sigmoid", "linear"],
  "dropout": [0.0, 0.0]
})

root_position_classifier_params = Params({
  "input_dim": HIDDEN_DIM * 2,
  "num_layers": 2,
  "hidden_dims": [50, 1],
  "activations": ["sigmoid", "linear"],
  "dropout": [0.0, 0.0]
})

In [88]:
action_classifier_feedforward = FeedForward.from_params(action_classifier_params)
action_classifier_feedforward = _cuda(action_classifier_feedforward, cuda_device)

root_position_classifier_feedforward = FeedForward.from_params(root_position_classifier_params)
root_position_classifier_feedforward = _cuda(root_position_classifier_feedforward, cuda_device)


INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.modules.feedforward.FeedForward'> from params {'input_dim': 150, 'num_layers': 2, 'hidden_dims': [50, 3], 'activations': ['sigmoid', 'linear'], 'dropout': [0.0, 0.0]} and extras set()
INFO     [allennlp.common.params:252] input_dim = 150
INFO     [allennlp.common.params:252] num_layers = 2
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] hidden_dims = [50, 3]
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params ['sigmoid', 'linear'] and extras set()
INFO     [allennlp.common.params:252] activations = ['sigmoid', 'linear']
INFO     [allennlp.common.from_params:340] instantiating class <class 'allennlp.nn.activations.Activation'> from params sigmoid and extras set()
INFO     [allennlp.common.params:252] type = sigmoid
INFO

In [89]:
field_type = 'word'

In [90]:
parse_label = {
    field_type: _cuda(
        torch.LongTensor(
            [
                [ 1,  0,  3,  7,  2,  9,  4],
                [ 0,  0,  5,  0,  0,  0,  4]
            ]
        ),
        cuda_device
    )
}
embedded_parse_label = field_type2embedder[field_type](parse_label)

In [91]:
feature_mask = util.get_text_field_mask(parse_label)

In [92]:
seq2vec_encoder = field_type2seq2vec_encoder[field_type]

In [93]:
encoded_feature = seq2vec_encoder(embedded_parse_label, feature_mask)

In [94]:
encoded_features = [encoded_feature] * 3

In [95]:
torch.cat(encoded_features, dim=-1)

tensor([[-0.0528,  0.0801,  0.0226,  0.0625,  0.0575,  0.0207,  0.0299, -0.0183,
         -0.0627,  0.0218, -0.0616, -0.0414, -0.0543,  0.0177, -0.0201, -0.0630,
          0.0038,  0.0363, -0.0058, -0.0553,  0.1409, -0.0112,  0.0362, -0.0500,
          0.0459, -0.0358, -0.0951,  0.0533, -0.0023,  0.0303,  0.0074, -0.0073,
         -0.0301,  0.0257, -0.1059, -0.0625,  0.1057, -0.0673, -0.0661,  0.0288,
         -0.1110, -0.0476, -0.0130,  0.0821, -0.0830, -0.0509, -0.0224,  0.0428,
         -0.0479,  0.0313, -0.0528,  0.0801,  0.0226,  0.0625,  0.0575,  0.0207,
          0.0299, -0.0183, -0.0627,  0.0218, -0.0616, -0.0414, -0.0543,  0.0177,
         -0.0201, -0.0630,  0.0038,  0.0363, -0.0058, -0.0553,  0.1409, -0.0112,
          0.0362, -0.0500,  0.0459, -0.0358, -0.0951,  0.0533, -0.0023,  0.0303,
          0.0074, -0.0073, -0.0301,  0.0257, -0.1059, -0.0625,  0.1057, -0.0673,
         -0.0661,  0.0288, -0.1110, -0.0476, -0.0130,  0.0821, -0.0830, -0.0509,
         -0.0224,  0.0428, -

In [96]:
logits = action_classifier_feedforward(torch.cat(encoded_features, dim=-1))

In [97]:
logits.shape

torch.Size([2, 3])

In [98]:
label = torch.tensor([1, 0])

In [99]:
# loss_func = torch.nn.CrossEntropyLoss()
# loss = loss_func(logits, label)

In [100]:
ActionGeneralizer = None

In [135]:
from mrp_library.models.generalizer import ActionGeneralizer
from mrp_library.iterators.same_instance_type_framework_stack_len_iterator import SameInstanceTypeFrameworkStackLenIterator


In [136]:
ActionGeneralizer

mrp_library.models.generalizer.ActionGeneralizer

In [137]:
if torch.cuda.is_available() and False:
    cuda_device = args.cuda_device
    model = ActionGeneralizer(
        cuda_device=cuda_device,
        vocab=vocab,
        field_type2embedder=field_type2embedder,
        field_type2seq2vec_encoder=field_type2seq2vec_encoder,
        field_type2seq2seq_encoder=field_type2seq2seq_encoder,
        action_classifier_feedforward=action_classifier_feedforward,
        root_position_classifier_feedforward=root_position_classifier_feedforward,
    )
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
    model = ActionGeneralizer(
        cuda_device=cuda_device,
        vocab=vocab,
        field_type2embedder=field_type2embedder,
        field_type2seq2vec_encoder=field_type2seq2vec_encoder,
        field_type2seq2seq_encoder=field_type2seq2seq_encoder,
        action_classifier_feedforward=action_classifier_feedforward,
        root_position_classifier_feedforward=root_position_classifier_feedforward,
    )

# iterator = SameInstanceTypeFrameworkIterator(
#     shuffle=True,
#     batch_size=100, 
#     sorting_keys=[("token_node_resolveds", "num_tokens")],
# )
iterator = SameInstanceTypeFrameworkStackLenIterator(
    shuffle=True,
    batch_size=100, 
    sorting_keys=[("token_node_resolveds", "num_tokens")],
)
iterator.index_with(vocab)

optimizer = optim.SGD(model.parameters(), lr=0.1)

INFO     [allennlp.nn.initializers:293] Initializing parameters
INFO     [allennlp.nn.initializers:309] Done initializing parameters; the following parameters are using their default initialization from their code
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.0.bias
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.0.weight
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.1.bias
INFO     [allennlp.nn.initializers:314]    action_classifier_feedforward._linear_layers.1.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.pos.token_embedder_pos.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.resolved.token_embedder_resolved.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.token_node_label.token_embedder_token_node_label.weight
INFO     [allennlp.nn.initializers:314]    field_type2embedder.token_node_prev_action.t

In [138]:
cuda_device

-1

In [139]:
model.resolve_tensor

tensor(1)

In [140]:
# list(model.named_parameters())

In [141]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=test_dataset,
#     train_dataset=train_dataset,
#     validation_dataset=train_dataset,
    patience=10,
    num_epochs=20,
    cuda_device=cuda_device
)

In [142]:
action_logits = torch.tensor([[-2.2126,  2.6022, -1.1655],
        [ 4.7340, -1.9992, -3.4521],
        [-1.9665,  2.4100, -1.2047],
        [-2.1353,  2.4847, -1.1260],
        [ 4.7492, -2.0234, -3.4460],
        [-1.4369,  1.9822, -1.2885],
        [ 1.0337,  0.3599, -2.0420],
        [ 5.0974, -2.3380, -3.4647],
        [ 5.4187, -2.4720, -3.6469],
        [-3.4045,  2.4903,  0.0773],
        [ 0.6384,  0.6764, -1.9942],
        [-2.2904,  2.7170, -1.2016],
        [ 4.6333, -1.9474, -3.4113],
        [-2.0811,  2.5367, -1.2174],
        [ 5.1840, -2.7536, -3.1499],
        [ 4.7421, -2.0138, -3.4485],
        [ 5.1121, -2.2290, -3.5999],
        [ 0.1843,  0.8324, -1.6990],
        [ 5.3854, -2.4593, -3.6309],
        [ 0.0324,  0.6715, -1.4173]])
action_logits = _cuda(action_logits, cuda_device)

action_type = torch.tensor([0, 2, 2, 1, 1, 0, 1, 2, 1, 2, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0])
action_type = _cuda(action_type, cuda_device)

In [143]:
action_probs, action_preds = action_logits.max(1)
action_resolve_preds = action_preds.eq_(model.resolve_tensor)

In [144]:
iter([1, 2, 3, 4])

<list_iterator at 0x7f3767129080>

In [145]:
defaultdict(lambda: defaultdict(dict))

defaultdict(<function __main__.<lambda>()>, {})

In [146]:
(action_resolve_preds, action_type, action_resolve_preds.eq(action_type))

(tensor([1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1]),
 tensor([0, 2, 2, 1, 1, 0, 1, 2, 1, 2, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0]),
 tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0],
        dtype=torch.uint8))

In [147]:
action_resolve_preds.unsqueeze(-1).float() * action_logits

tensor([[-2.2126,  2.6022, -1.1655],
        [ 0.0000, -0.0000, -0.0000],
        [-1.9665,  2.4100, -1.2047],
        [-2.1353,  2.4847, -1.1260],
        [ 0.0000, -0.0000, -0.0000],
        [-1.4369,  1.9822, -1.2885],
        [ 0.0000,  0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [-3.4045,  2.4903,  0.0773],
        [ 0.6384,  0.6764, -1.9942],
        [-2.2904,  2.7170, -1.2016],
        [ 0.0000, -0.0000, -0.0000],
        [-2.0811,  2.5367, -1.2174],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.1843,  0.8324, -1.6990],
        [ 0.0000, -0.0000, -0.0000],
        [ 0.0324,  0.6715, -1.4173]])

In [148]:
embedded_fields = torch.ones(99, 62, 100)

In [149]:
embedded_fields.size()

torch.Size([99, 62, 100])

In [150]:
embedded_fields

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1., 

In [151]:
root_position_logits = torch.tensor([[0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717, 0.1719, 0.1717,
         0.1716, 0.1719, 0.1717, 0.1716, 0.1718, 0.1719, 0.1718, 0.1719, 0.1717,
         0.1720, 0.1720, 0.1722, 0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719,
         0.1717, 0.1720, 0.1715, 0.1718],
        [0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717, 0.1719, 0.1717,
         0.1716, 0.1719, 0.1717, 0.1716, 0.1718, 0.1719, 0.1718, 0.1719, 0.1717,
         0.1720, 0.1720, 0.1722, 0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719,
         0.1717, 0.1716, 0.1719, 0.1708],
        [0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717, 0.1719, 0.1717,
         0.1716, 0.1719, 0.1717, 0.1716, 0.1718, 0.1719, 0.1718, 0.1719, 0.1717,
         0.1720, 0.1720, 0.1722, 0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719,
         0.1717, 0.1720, 0.1715, 0.1717],
        [0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717, 0.1719, 0.1717,
         0.1716, 0.1719, 0.1717, 0.1716, 0.1718, 0.1719, 0.1718, 0.1719, 0.1717,
         0.1720, 0.1720, 0.1722, 0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719,
         0.1717, 0.1716, 0.1719, 0.1717],
        [0.1722, 0.1723, 0.1720, 0.1721, 0.1721, 0.1719, 0.1717, 0.1719, 0.1717,
         0.1716, 0.1719, 0.1717, 0.1716, 0.1718, 0.1719, 0.1718, 0.1719, 0.1717,
         0.1720, 0.1720, 0.1722, 0.1722, 0.1719, 0.1721, 0.1718, 0.1717, 0.1719,
         0.1717, 0.1720, 0.1715, 0.1709]])

In [152]:
root_position_logits.shape

torch.Size([5, 31])

In [153]:
root_position = torch.tensor([ 2,  0, -1,  1,  0])

In [154]:
mask = root_position.eq(torch.tensor(-1)).unsqueeze(-1).expand_as(root_position_logits)

In [155]:
mask = root_position.gt(-1)

In [156]:
masked_root_position_logits = root_position_logits * mask.float().unsqueeze(-1).expand_as(root_position_logits)


In [157]:
loss = torch.nn.CrossEntropyLoss()

In [158]:
inp = root_position_logits

In [159]:
# crossEntropy = -torch.log(torch.gather(inp, 1, root_position.view(-1, 1)))

In [160]:
root_position

tensor([ 2,  0, -1,  1,  0])

In [161]:
# root_position.eq(torch.tensor(-2)).expand_as(root_position_logits)

In [162]:
trainer.train()

INFO     [allennlp.training.trainer:465] Beginning training.
INFO     [allennlp.training.trainer:281] Epoch 0/19
INFO     [allennlp.training.trainer:283] Peak CPU memory usage MB: 7224.06
INFO     [allennlp.training.trainer:287] GPU 0 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 1 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 2 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 3 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 4 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 5 memory usage MB: 10
INFO     [allennlp.training.trainer:287] GPU 6 memory usage MB: 11
INFO     [allennlp.training.trainer:287] GPU 7 memory usage MB: 10
INFO     [allennlp.training.trainer:311] Training



  0%|          | 0/491 [00:00<?, ?it/s][A[A[A


action_type_accuracy: 0.0000, root_position_accuracy: 0.0000, loss: 1.9952 ||:   0%|          | 1/491 [00:00<02:08,  3.81it/s][A[A[A


action_type_accuracy: 0.3952, root

action_type_accuracy: 0.8007, root_position_accuracy: 0.5066, loss: 1.7009 ||:  15%|█▍        | 72/491 [00:07<00:40, 10.23it/s][A[A[A


action_type_accuracy: 0.8068, root_position_accuracy: 0.5130, loss: 1.6847 ||:  15%|█▌        | 74/491 [00:08<00:41, 10.07it/s][A[A[A


action_type_accuracy: 0.8126, root_position_accuracy: 0.5189, loss: 1.6686 ||:  15%|█▌        | 76/491 [00:08<00:41, 10.02it/s][A[A[A


action_type_accuracy: 0.8100, root_position_accuracy: 0.5180, loss: 1.6738 ||:  16%|█▌        | 78/491 [00:08<00:38, 10.69it/s][A[A[A


action_type_accuracy: 0.8018, root_position_accuracy: 0.5135, loss: 1.6823 ||:  16%|█▋        | 80/491 [00:08<00:37, 10.88it/s][A[A[A


action_type_accuracy: 0.8019, root_position_accuracy: 0.5108, loss: 1.6811 ||:  17%|█▋        | 82/491 [00:08<00:37, 10.81it/s][A[A[A


action_type_accuracy: 0.8027, root_position_accuracy: 0.5088, loss: 1.6785 ||:  17%|█▋        | 84/491 [00:09<00:36, 11.06it/s][A[A[A


action_type_accuracy: 0.803

action_type_accuracy: 0.8125, root_position_accuracy: 0.5053, loss: 1.6833 ||:  33%|███▎      | 160/491 [00:17<00:41,  7.96it/s][A[A[A


action_type_accuracy: 0.8123, root_position_accuracy: 0.5049, loss: 1.6890 ||:  33%|███▎      | 161/491 [00:17<00:43,  7.58it/s][A[A[A


action_type_accuracy: 0.8115, root_position_accuracy: 0.5038, loss: 1.6959 ||:  33%|███▎      | 162/491 [00:17<00:46,  7.03it/s][A[A[A


action_type_accuracy: 0.8115, root_position_accuracy: 0.5038, loss: 1.6991 ||:  33%|███▎      | 163/491 [00:17<00:49,  6.59it/s][A[A[A


action_type_accuracy: 0.8098, root_position_accuracy: 0.5018, loss: 1.7063 ||:  33%|███▎      | 164/491 [00:18<00:46,  6.97it/s][A[A[A


action_type_accuracy: 0.8075, root_position_accuracy: 0.4994, loss: 1.7136 ||:  34%|███▎      | 165/491 [00:18<00:46,  7.03it/s][A[A[A


action_type_accuracy: 0.8083, root_position_accuracy: 0.4987, loss: 1.7154 ||:  34%|███▍      | 167/491 [00:18<00:40,  8.01it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7979, root_position_accuracy: 0.5012, loss: 1.6683 ||:  51%|█████     | 251/491 [00:27<00:33,  7.23it/s][A[A[A


action_type_accuracy: 0.7987, root_position_accuracy: 0.5003, loss: 1.6686 ||:  51%|█████▏    | 252/491 [00:27<00:32,  7.39it/s][A[A[A


action_type_accuracy: 0.7986, root_position_accuracy: 0.5001, loss: 1.6737 ||:  52%|█████▏    | 253/491 [00:27<00:31,  7.64it/s][A[A[A


action_type_accuracy: 0.7968, root_position_accuracy: 0.5000, loss: 1.6729 ||:  52%|█████▏    | 255/491 [00:27<00:27,  8.56it/s][A[A[A


action_type_accuracy: 0.7949, root_position_accuracy: 0.5000, loss: 1.6690 ||:  52%|█████▏    | 257/491 [00:27<00:23, 10.10it/s][A[A[A


action_type_accuracy: 0.7929, root_position_accuracy: 0.5000, loss: 1.6638 ||:  53%|█████▎    | 259/491 [00:27<00:19, 11.78it/s][A[A[A


action_type_accuracy: 0.7916, root_position_accuracy: 0.5006, loss: 1.6638 ||:  53%|█████▎    | 261/491 [00:28<00:21, 10.93it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7953, root_position_accuracy: 0.4846, loss: 1.6656 ||:  71%|███████   | 348/491 [00:36<00:17,  8.34it/s][A[A[A


action_type_accuracy: 0.7954, root_position_accuracy: 0.4842, loss: 1.6658 ||:  71%|███████▏  | 350/491 [00:37<00:15,  8.87it/s][A[A[A


action_type_accuracy: 0.7951, root_position_accuracy: 0.4850, loss: 1.6651 ||:  72%|███████▏  | 352/491 [00:37<00:16,  8.66it/s][A[A[A


action_type_accuracy: 0.7952, root_position_accuracy: 0.4852, loss: 1.6647 ||:  72%|███████▏  | 353/491 [00:37<00:18,  7.53it/s][A[A[A


action_type_accuracy: 0.7959, root_position_accuracy: 0.4846, loss: 1.6649 ||:  72%|███████▏  | 355/491 [00:37<00:17,  7.75it/s][A[A[A


action_type_accuracy: 0.7966, root_position_accuracy: 0.4846, loss: 1.6641 ||:  73%|███████▎  | 356/491 [00:37<00:16,  8.05it/s][A[A[A


action_type_accuracy: 0.7974, root_position_accuracy: 0.4840, loss: 1.6632 ||:  73%|███████▎  | 358/491 [00:38<00:14,  9.24it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7914, root_position_accuracy: 0.4746, loss: 1.7204 ||:  86%|████████▌ | 422/491 [00:45<00:07,  8.76it/s][A[A[A


action_type_accuracy: 0.7920, root_position_accuracy: 0.4755, loss: 1.7186 ||:  86%|████████▌ | 423/491 [00:45<00:07,  8.56it/s][A[A[A


action_type_accuracy: 0.7926, root_position_accuracy: 0.4761, loss: 1.7174 ||:  86%|████████▋ | 424/491 [00:46<00:07,  8.67it/s][A[A[A


action_type_accuracy: 0.7931, root_position_accuracy: 0.4767, loss: 1.7160 ||:  87%|████████▋ | 425/491 [00:46<00:07,  8.62it/s][A[A[A


action_type_accuracy: 0.7937, root_position_accuracy: 0.4772, loss: 1.7149 ||:  87%|████████▋ | 426/491 [00:46<00:07,  8.75it/s][A[A[A


action_type_accuracy: 0.7942, root_position_accuracy: 0.4777, loss: 1.7137 ||:  87%|████████▋ | 427/491 [00:46<00:07,  8.90it/s][A[A[A


action_type_accuracy: 0.7951, root_position_accuracy: 0.4786, loss: 1.7112 ||:  87%|████████▋ | 429/491 [00:46<00:06,  9.20it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8089, root_position_accuracy: 0.4924, loss: 1.6593 ||: : 502it [00:55,  4.94it/s][A[A[A


action_type_accuracy: 0.8088, root_position_accuracy: 0.4922, loss: 1.6608 ||: : 503it [00:55,  5.18it/s][A[A[A


action_type_accuracy: 0.8087, root_position_accuracy: 0.4923, loss: 1.6614 ||: : 504it [00:55,  5.09it/s][A[A[A


action_type_accuracy: 0.8087, root_position_accuracy: 0.4920, loss: 1.6627 ||: : 505it [00:55,  4.86it/s][A[A[A


action_type_accuracy: 0.8087, root_position_accuracy: 0.4919, loss: 1.6636 ||: : 506it [00:55,  5.17it/s][A[A[A


action_type_accuracy: 0.8085, root_position_accuracy: 0.4917, loss: 1.6648 ||: : 507it [00:56,  5.59it/s][A[A[A


action_type_accuracy: 0.8085, root_position_accuracy: 0.4917, loss: 1.6657 ||: : 508it [00:56,  5.98it/s][A[A[A


action_type_accuracy: 0.8085, root_position_accuracy: 0.4917, loss: 1.6657 ||: : 509it [00:56,  5.55it/s][A[A[A


action_type_accuracy: 0.8082, root_position_accuracy: 0.4922, lo

action_type_accuracy: 0.6932, root_position_accuracy: 0.6444, loss: 2.1598 ||:  11%|█▏        | 41/360 [00:02<00:16, 19.17it/s][A[A[A


action_type_accuracy: 0.7082, root_position_accuracy: 0.6444, loss: 2.0375 ||:  12%|█▎        | 45/360 [00:02<00:14, 22.04it/s][A[A[A


action_type_accuracy: 0.7148, root_position_accuracy: 0.6444, loss: 1.9495 ||:  14%|█▎        | 49/360 [00:02<00:12, 25.36it/s][A[A[A


action_type_accuracy: 0.7132, root_position_accuracy: 0.6099, loss: 2.0024 ||:  15%|█▍        | 53/360 [00:02<00:13, 22.87it/s][A[A[A


action_type_accuracy: 0.7159, root_position_accuracy: 0.5878, loss: 2.0314 ||:  16%|█▌        | 56/360 [00:02<00:13, 22.73it/s][A[A[A


action_type_accuracy: 0.7175, root_position_accuracy: 0.5712, loss: 2.0621 ||:  16%|█▋        | 59/360 [00:02<00:13, 22.79it/s][A[A[A


action_type_accuracy: 0.7178, root_position_accuracy: 0.5584, loss: 2.0934 ||:  17%|█▋        | 62/360 [00:02<00:13, 21.70it/s][A[A[A


action_type_accuracy: 0.707

action_type_accuracy: 0.7202, root_position_accuracy: 0.4501, loss: 2.2020 ||:  52%|█████▎    | 189/360 [00:10<00:10, 16.00it/s][A[A[A


action_type_accuracy: 0.7179, root_position_accuracy: 0.4467, loss: 2.2078 ||:  53%|█████▎    | 191/360 [00:10<00:10, 16.51it/s][A[A[A


action_type_accuracy: 0.7179, root_position_accuracy: 0.4460, loss: 2.2155 ||:  54%|█████▍    | 194/360 [00:10<00:09, 18.38it/s][A[A[A


action_type_accuracy: 0.7179, root_position_accuracy: 0.4455, loss: 2.2223 ||:  54%|█████▍    | 196/360 [00:10<00:09, 16.94it/s][A[A[A


action_type_accuracy: 0.7184, root_position_accuracy: 0.4453, loss: 2.2293 ||:  55%|█████▌    | 199/360 [00:10<00:09, 17.32it/s][A[A[A


action_type_accuracy: 0.7187, root_position_accuracy: 0.4449, loss: 2.2333 ||:  56%|█████▌    | 201/360 [00:11<00:09, 16.96it/s][A[A[A


action_type_accuracy: 0.7193, root_position_accuracy: 0.4448, loss: 2.2357 ||:  56%|█████▋    | 203/360 [00:11<00:09, 16.89it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7321, root_position_accuracy: 0.4639, loss: 2.2611 ||:  93%|█████████▎| 335/360 [00:18<00:01, 18.52it/s][A[A[A


action_type_accuracy: 0.7309, root_position_accuracy: 0.4605, loss: 2.2586 ||:  94%|█████████▍| 338/360 [00:18<00:01, 16.99it/s][A[A[A


action_type_accuracy: 0.7299, root_position_accuracy: 0.4584, loss: 2.2569 ||:  94%|█████████▍| 340/360 [00:18<00:01, 16.38it/s][A[A[A


action_type_accuracy: 0.7288, root_position_accuracy: 0.4566, loss: 2.2588 ||:  96%|█████████▌| 344/360 [00:19<00:00, 19.14it/s][A[A[A


action_type_accuracy: 0.7278, root_position_accuracy: 0.4559, loss: 2.2634 ||:  96%|█████████▋| 347/360 [00:19<00:00, 17.53it/s][A[A[A


action_type_accuracy: 0.7273, root_position_accuracy: 0.4584, loss: 2.2653 ||:  97%|█████████▋| 350/360 [00:19<00:00, 18.67it/s][A[A[A


action_type_accuracy: 0.7268, root_position_accuracy: 0.4600, loss: 2.2691 ||:  98%|█████████▊| 353/360 [00:19<00:00, 17.37it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7954, root_position_accuracy: 0.4341, loss: 1.8314 ||:   3%|▎         | 13/491 [00:01<00:54,  8.76it/s][A[A[A


action_type_accuracy: 0.8158, root_position_accuracy: 0.4452, loss: 1.6878 ||:   3%|▎         | 15/491 [00:02<00:47, 10.05it/s][A[A[A


action_type_accuracy: 0.8233, root_position_accuracy: 0.4420, loss: 1.5993 ||:   3%|▎         | 17/491 [00:02<00:43, 10.81it/s][A[A[A


action_type_accuracy: 0.8320, root_position_accuracy: 0.4434, loss: 1.5212 ||:   4%|▍         | 19/491 [00:02<00:45, 10.34it/s][A[A[A


action_type_accuracy: 0.8359, root_position_accuracy: 0.4477, loss: 1.4593 ||:   4%|▍         | 21/491 [00:02<00:41, 11.24it/s][A[A[A


action_type_accuracy: 0.8413, root_position_accuracy: 0.4488, loss: 1.4050 ||:   5%|▍         | 23/491 [00:02<00:38, 12.10it/s][A[A[A


action_type_accuracy: 0.8450, root_position_accuracy: 0.4468, loss: 1.3636 ||:   5%|▌         | 25/491 [00:02<00:35, 12.96it/s][A[A[A


action_type_accuracy: 0.850

action_type_accuracy: 0.8804, root_position_accuracy: 0.5618, loss: 1.3814 ||:  21%|██        | 104/491 [00:11<00:49,  7.87it/s][A[A[A


action_type_accuracy: 0.8791, root_position_accuracy: 0.5584, loss: 1.3853 ||:  22%|██▏       | 106/491 [00:11<00:44,  8.70it/s][A[A[A


action_type_accuracy: 0.8784, root_position_accuracy: 0.5559, loss: 1.3881 ||:  22%|██▏       | 107/491 [00:11<00:47,  8.03it/s][A[A[A


action_type_accuracy: 0.8778, root_position_accuracy: 0.5539, loss: 1.3902 ||:  22%|██▏       | 108/491 [00:11<00:45,  8.48it/s][A[A[A


action_type_accuracy: 0.8776, root_position_accuracy: 0.5517, loss: 1.3919 ||:  22%|██▏       | 109/491 [00:12<00:44,  8.54it/s][A[A[A


action_type_accuracy: 0.8765, root_position_accuracy: 0.5505, loss: 1.3941 ||:  22%|██▏       | 110/491 [00:12<00:45,  8.39it/s][A[A[A


action_type_accuracy: 0.8761, root_position_accuracy: 0.5492, loss: 1.3955 ||:  23%|██▎       | 111/491 [00:12<00:46,  8.18it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8471, root_position_accuracy: 0.5222, loss: 1.5664 ||:  35%|███▌      | 173/491 [00:20<00:51,  6.23it/s][A[A[A


action_type_accuracy: 0.8470, root_position_accuracy: 0.5220, loss: 1.5700 ||:  35%|███▌      | 174/491 [00:20<00:46,  6.79it/s][A[A[A


action_type_accuracy: 0.8483, root_position_accuracy: 0.5217, loss: 1.5605 ||:  36%|███▌      | 176/491 [00:20<00:39,  8.00it/s][A[A[A


action_type_accuracy: 0.8483, root_position_accuracy: 0.5217, loss: 1.5627 ||:  36%|███▋      | 178/491 [00:21<00:34,  8.97it/s][A[A[A


action_type_accuracy: 0.8503, root_position_accuracy: 0.5239, loss: 1.5605 ||:  37%|███▋      | 180/491 [00:21<00:36,  8.59it/s][A[A[A


action_type_accuracy: 0.8513, root_position_accuracy: 0.5246, loss: 1.5598 ||:  37%|███▋      | 181/491 [00:21<00:37,  8.33it/s][A[A[A


action_type_accuracy: 0.8522, root_position_accuracy: 0.5259, loss: 1.5582 ||:  37%|███▋      | 182/491 [00:21<00:37,  8.17it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8190, root_position_accuracy: 0.4826, loss: 1.6333 ||:  54%|█████▎    | 263/491 [00:29<00:18, 12.27it/s][A[A[A


action_type_accuracy: 0.8166, root_position_accuracy: 0.4802, loss: 1.6427 ||:  54%|█████▍    | 265/491 [00:29<00:19, 11.59it/s][A[A[A


action_type_accuracy: 0.8146, root_position_accuracy: 0.4781, loss: 1.6515 ||:  54%|█████▍    | 267/491 [00:30<00:22, 10.10it/s][A[A[A


action_type_accuracy: 0.8146, root_position_accuracy: 0.4781, loss: 1.6524 ||:  55%|█████▍    | 269/491 [00:30<00:24,  9.11it/s][A[A[A


action_type_accuracy: 0.8143, root_position_accuracy: 0.4779, loss: 1.6532 ||:  55%|█████▌    | 271/491 [00:30<00:23,  9.55it/s][A[A[A


action_type_accuracy: 0.8135, root_position_accuracy: 0.4771, loss: 1.6552 ||:  56%|█████▌    | 273/491 [00:30<00:21, 10.11it/s][A[A[A


action_type_accuracy: 0.8134, root_position_accuracy: 0.4764, loss: 1.6564 ||:  56%|█████▌    | 275/491 [00:30<00:20, 10.67it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8113, root_position_accuracy: 0.4740, loss: 1.6157 ||:  75%|███████▌  | 369/491 [00:40<00:12, 10.14it/s][A[A[A


action_type_accuracy: 0.8112, root_position_accuracy: 0.4736, loss: 1.6171 ||:  76%|███████▌  | 371/491 [00:40<00:12, 10.00it/s][A[A[A


action_type_accuracy: 0.8108, root_position_accuracy: 0.4733, loss: 1.6185 ||:  76%|███████▌  | 373/491 [00:40<00:11, 10.16it/s][A[A[A


action_type_accuracy: 0.8106, root_position_accuracy: 0.4730, loss: 1.6198 ||:  76%|███████▋  | 375/491 [00:41<00:12,  9.53it/s][A[A[A


action_type_accuracy: 0.8106, root_position_accuracy: 0.4730, loss: 1.6199 ||:  77%|███████▋  | 376/491 [00:41<00:12,  9.45it/s][A[A[A


action_type_accuracy: 0.8104, root_position_accuracy: 0.4726, loss: 1.6209 ||:  77%|███████▋  | 377/491 [00:41<00:12,  9.08it/s][A[A[A


action_type_accuracy: 0.8102, root_position_accuracy: 0.4735, loss: 1.6200 ||:  77%|███████▋  | 379/491 [00:41<00:11,  9.97it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8217, root_position_accuracy: 0.4889, loss: 1.6066 ||:  94%|█████████▍| 461/491 [00:50<00:04,  6.91it/s][A[A[A


action_type_accuracy: 0.8213, root_position_accuracy: 0.4886, loss: 1.6084 ||:  94%|█████████▍| 462/491 [00:50<00:04,  7.00it/s][A[A[A


action_type_accuracy: 0.8213, root_position_accuracy: 0.4887, loss: 1.6091 ||:  94%|█████████▍| 463/491 [00:50<00:03,  7.06it/s][A[A[A


action_type_accuracy: 0.8213, root_position_accuracy: 0.4887, loss: 1.6098 ||:  95%|█████████▍| 464/491 [00:50<00:03,  7.44it/s][A[A[A


action_type_accuracy: 0.8213, root_position_accuracy: 0.4887, loss: 1.6116 ||:  95%|█████████▍| 465/491 [00:50<00:03,  7.62it/s][A[A[A


action_type_accuracy: 0.8210, root_position_accuracy: 0.4892, loss: 1.6115 ||:  95%|█████████▍| 466/491 [00:50<00:03,  7.45it/s][A[A[A


action_type_accuracy: 0.8210, root_position_accuracy: 0.4892, loss: 1.6090 ||:  95%|█████████▌| 467/491 [00:51<00:05,  4.67it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8283, root_position_accuracy: 0.5010, loss: 1.6148 ||: : 537it [00:59,  9.13it/s][A[A[A


action_type_accuracy: 0.8283, root_position_accuracy: 0.5007, loss: 1.6154 ||: : 538it [01:00,  9.31it/s][A[A[A


action_type_accuracy: 0.8281, root_position_accuracy: 0.5004, loss: 1.6166 ||: : 540it [01:00,  9.54it/s][A[A[A


action_type_accuracy: 0.8278, root_position_accuracy: 0.5001, loss: 1.6176 ||: : 541it [01:00,  9.15it/s][A[A[A


action_type_accuracy: 0.8276, root_position_accuracy: 0.4998, loss: 1.6183 ||: : 542it [01:00,  8.81it/s][A[A[A


action_type_accuracy: 0.8275, root_position_accuracy: 0.4997, loss: 1.6187 ||: : 543it [01:00,  8.91it/s][A[A[A


action_type_accuracy: 0.8274, root_position_accuracy: 0.4993, loss: 1.6196 ||: : 544it [01:00,  9.05it/s][A[A[A


action_type_accuracy: 0.8273, root_position_accuracy: 0.4991, loss: 1.6202 ||: : 545it [01:00,  9.02it/s][A[A[A


action_type_accuracy: 0.8271, root_position_accuracy: 0.4990, lo

action_type_accuracy: 0.7834, root_position_accuracy: 0.4032, loss: 1.6530 ||:  31%|███       | 112/360 [00:06<00:15, 16.32it/s][A[A[A


action_type_accuracy: 0.7842, root_position_accuracy: 0.4032, loss: 1.6497 ||:  32%|███▏      | 114/360 [00:06<00:15, 16.25it/s][A[A[A


action_type_accuracy: 0.7852, root_position_accuracy: 0.4034, loss: 1.6453 ||:  32%|███▎      | 117/360 [00:06<00:14, 16.82it/s][A[A[A


action_type_accuracy: 0.7842, root_position_accuracy: 0.4022, loss: 1.6611 ||:  33%|███▎      | 120/360 [00:06<00:13, 17.85it/s][A[A[A


action_type_accuracy: 0.7802, root_position_accuracy: 0.3981, loss: 1.6766 ||:  34%|███▍      | 122/360 [00:06<00:13, 17.47it/s][A[A[A


action_type_accuracy: 0.7790, root_position_accuracy: 0.3976, loss: 1.6895 ||:  34%|███▍      | 124/360 [00:06<00:13, 17.22it/s][A[A[A


action_type_accuracy: 0.7762, root_position_accuracy: 0.3983, loss: 1.6945 ||:  35%|███▌      | 126/360 [00:07<00:14, 15.62it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7234, root_position_accuracy: 0.4456, loss: 1.8796 ||:  72%|███████▏  | 259/360 [00:14<00:06, 16.31it/s][A[A[A


action_type_accuracy: 0.7231, root_position_accuracy: 0.4467, loss: 1.8833 ||:  72%|███████▎  | 261/360 [00:14<00:06, 16.10it/s][A[A[A


action_type_accuracy: 0.7231, root_position_accuracy: 0.4469, loss: 1.8874 ||:  73%|███████▎  | 263/360 [00:14<00:06, 16.12it/s][A[A[A


action_type_accuracy: 0.7228, root_position_accuracy: 0.4480, loss: 1.8905 ||:  74%|███████▎  | 265/360 [00:14<00:05, 16.70it/s][A[A[A


action_type_accuracy: 0.7226, root_position_accuracy: 0.4509, loss: 1.8890 ||:  74%|███████▍  | 267/360 [00:14<00:05, 16.09it/s][A[A[A


action_type_accuracy: 0.7223, root_position_accuracy: 0.4534, loss: 1.8881 ||:  75%|███████▍  | 269/360 [00:15<00:06, 14.52it/s][A[A[A


action_type_accuracy: 0.7218, root_position_accuracy: 0.4557, loss: 1.8878 ||:  75%|███████▌  | 271/360 [00:15<00:06, 14.37it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.7186, root_position_accuracy: 0.4781, loss: 1.9214 ||: : 401it [00:22, 17.26it/s][A[A[A


action_type_accuracy: 0.7177, root_position_accuracy: 0.4775, loss: 1.9226 ||: : 403it [00:22, 16.11it/s][A[A[A


action_type_accuracy: 0.7176, root_position_accuracy: 0.4774, loss: 1.9250 ||: : 405it [00:22, 15.96it/s][A[A[A


action_type_accuracy: 0.7185, root_position_accuracy: 0.4774, loss: 1.9177 ||: : 408it [00:23, 17.90it/s][A[A[A


action_type_accuracy: 0.7177, root_position_accuracy: 0.4760, loss: 1.9219 ||: : 410it [00:23, 17.14it/s][A[A[A


action_type_accuracy: 0.7179, root_position_accuracy: 0.4757, loss: 1.9239 ||: : 413it [00:23, 17.98it/s][A[A[A


action_type_accuracy: 0.7183, root_position_accuracy: 0.4754, loss: 1.9254 ||: : 415it [00:23, 17.74it/s][A[A[A


action_type_accuracy: 0.7188, root_position_accuracy: 0.4751, loss: 1.9275 ||: : 418it [00:23, 19.29it/s][A[A[A


action_type_accuracy: 0.7191, root_position_accuracy: 0.4749, lo

action_type_accuracy: 0.8028, root_position_accuracy: 0.4517, loss: 1.8394 ||:   9%|▉         | 46/491 [00:05<00:52,  8.44it/s][A[A[A


action_type_accuracy: 0.8020, root_position_accuracy: 0.4507, loss: 1.8433 ||:  10%|▉         | 47/491 [00:05<00:51,  8.61it/s][A[A[A


action_type_accuracy: 0.7990, root_position_accuracy: 0.4457, loss: 1.8583 ||:  10%|▉         | 49/491 [00:05<00:49,  8.84it/s][A[A[A


action_type_accuracy: 0.7988, root_position_accuracy: 0.4424, loss: 1.8646 ||:  10%|█         | 50/491 [00:06<00:50,  8.68it/s][A[A[A


action_type_accuracy: 0.7979, root_position_accuracy: 0.4423, loss: 1.8670 ||:  10%|█         | 51/491 [00:06<00:50,  8.76it/s][A[A[A


action_type_accuracy: 0.7981, root_position_accuracy: 0.4415, loss: 1.8707 ||:  11%|█         | 52/491 [00:06<00:50,  8.72it/s][A[A[A


action_type_accuracy: 0.7988, root_position_accuracy: 0.4417, loss: 1.8703 ||:  11%|█         | 53/491 [00:06<00:51,  8.46it/s][A[A[A


action_type_accuracy: 0.798

action_type_accuracy: 0.8043, root_position_accuracy: 0.4918, loss: 1.7783 ||:  24%|██▍       | 119/491 [00:14<00:42,  8.85it/s][A[A[A


action_type_accuracy: 0.8062, root_position_accuracy: 0.4940, loss: 1.7728 ||:  24%|██▍       | 120/491 [00:14<00:45,  8.23it/s][A[A[A


action_type_accuracy: 0.8098, root_position_accuracy: 0.4995, loss: 1.7577 ||:  25%|██▍       | 122/491 [00:14<00:41,  8.97it/s][A[A[A


action_type_accuracy: 0.8133, root_position_accuracy: 0.5029, loss: 1.7463 ||:  25%|██▌       | 124/491 [00:14<00:37,  9.87it/s][A[A[A


action_type_accuracy: 0.8166, root_position_accuracy: 0.5068, loss: 1.7338 ||:  26%|██▌       | 126/491 [00:14<00:38,  9.39it/s][A[A[A


action_type_accuracy: 0.8199, root_position_accuracy: 0.5124, loss: 1.7186 ||:  26%|██▌       | 128/491 [00:14<00:35, 10.09it/s][A[A[A


action_type_accuracy: 0.8211, root_position_accuracy: 0.5128, loss: 1.7281 ||:  26%|██▋       | 130/491 [00:15<00:33, 10.85it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8127, root_position_accuracy: 0.4827, loss: 1.7971 ||:  41%|████      | 200/491 [00:23<00:36,  8.01it/s][A[A[A


action_type_accuracy: 0.8119, root_position_accuracy: 0.4814, loss: 1.8017 ||:  41%|████      | 202/491 [00:23<00:35,  8.15it/s][A[A[A


action_type_accuracy: 0.8129, root_position_accuracy: 0.4830, loss: 1.7988 ||:  41%|████▏     | 203/491 [00:23<00:34,  8.39it/s][A[A[A


action_type_accuracy: 0.8150, root_position_accuracy: 0.4856, loss: 1.7933 ||:  42%|████▏     | 205/491 [00:23<00:32,  8.83it/s][A[A[A


action_type_accuracy: 0.8161, root_position_accuracy: 0.4872, loss: 1.7899 ||:  42%|████▏     | 206/491 [00:24<00:31,  9.03it/s][A[A[A


action_type_accuracy: 0.8171, root_position_accuracy: 0.4886, loss: 1.7867 ||:  42%|████▏     | 207/491 [00:24<00:31,  8.91it/s][A[A[A


action_type_accuracy: 0.8181, root_position_accuracy: 0.4899, loss: 1.7835 ||:  42%|████▏     | 208/491 [00:24<00:32,  8.61it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8459, root_position_accuracy: 0.5169, loss: 1.6483 ||:  57%|█████▋    | 279/491 [00:32<00:26,  8.04it/s][A[A[A


action_type_accuracy: 0.8472, root_position_accuracy: 0.5169, loss: 1.6374 ||:  57%|█████▋    | 281/491 [00:32<00:22,  9.48it/s][A[A[A


action_type_accuracy: 0.8467, root_position_accuracy: 0.5153, loss: 1.6369 ||:  58%|█████▊    | 283/491 [00:32<00:21,  9.74it/s][A[A[A


action_type_accuracy: 0.8456, root_position_accuracy: 0.5140, loss: 1.6457 ||:  58%|█████▊    | 285/491 [00:32<00:22,  9.30it/s][A[A[A


action_type_accuracy: 0.8445, root_position_accuracy: 0.5133, loss: 1.6481 ||:  58%|█████▊    | 286/491 [00:32<00:22,  9.16it/s][A[A[A


action_type_accuracy: 0.8440, root_position_accuracy: 0.5120, loss: 1.6511 ||:  59%|█████▊    | 288/491 [00:33<00:20, 10.00it/s][A[A[A


action_type_accuracy: 0.8425, root_position_accuracy: 0.5095, loss: 1.6559 ||:  59%|█████▉    | 290/491 [00:33<00:18, 10.82it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8309, root_position_accuracy: 0.4926, loss: 1.6439 ||:  75%|███████▍  | 368/491 [00:41<00:13,  8.83it/s][A[A[A


action_type_accuracy: 0.8308, root_position_accuracy: 0.4922, loss: 1.6438 ||:  75%|███████▌  | 369/491 [00:41<00:13,  8.74it/s][A[A[A


action_type_accuracy: 0.8305, root_position_accuracy: 0.4916, loss: 1.6437 ||:  76%|███████▌  | 371/491 [00:41<00:12,  9.36it/s][A[A[A


action_type_accuracy: 0.8303, root_position_accuracy: 0.4908, loss: 1.6437 ||:  76%|███████▌  | 373/491 [00:41<00:11,  9.97it/s][A[A[A


action_type_accuracy: 0.8302, root_position_accuracy: 0.4902, loss: 1.6432 ||:  76%|███████▋  | 375/491 [00:42<00:10, 10.68it/s][A[A[A


action_type_accuracy: 0.8301, root_position_accuracy: 0.4894, loss: 1.6428 ||:  77%|███████▋  | 377/491 [00:42<00:10, 10.99it/s][A[A[A


action_type_accuracy: 0.8300, root_position_accuracy: 0.4891, loss: 1.6422 ||:  77%|███████▋  | 379/491 [00:42<00:09, 11.29it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8343, root_position_accuracy: 0.4920, loss: 1.6347 ||:  94%|█████████▍| 462/491 [00:51<00:03,  8.91it/s][A[A[A


action_type_accuracy: 0.8347, root_position_accuracy: 0.4926, loss: 1.6333 ||:  94%|█████████▍| 463/491 [00:51<00:03,  8.48it/s][A[A[A


action_type_accuracy: 0.8351, root_position_accuracy: 0.4931, loss: 1.6325 ||:  95%|█████████▍| 464/491 [00:51<00:03,  8.23it/s][A[A[A


action_type_accuracy: 0.8355, root_position_accuracy: 0.4937, loss: 1.6311 ||:  95%|█████████▍| 465/491 [00:51<00:03,  7.53it/s][A[A[A


action_type_accuracy: 0.8359, root_position_accuracy: 0.4939, loss: 1.6305 ||:  95%|█████████▍| 466/491 [00:51<00:03,  7.73it/s][A[A[A


action_type_accuracy: 0.8366, root_position_accuracy: 0.4948, loss: 1.6287 ||:  95%|█████████▌| 468/491 [00:52<00:02,  7.82it/s][A[A[A


action_type_accuracy: 0.8367, root_position_accuracy: 0.4949, loss: 1.6281 ||:  96%|█████████▌| 469/491 [00:52<00:03,  6.94it/s][A[A[A


action_type_accuracy




  0%|          | 0/360 [00:00<?, ?it/s][A[A[A


action_type_accuracy: 1.0000, root_position_accuracy: 0.3409, loss: 1.6987 ||:   0%|          | 1/360 [00:00<01:03,  5.66it/s][A[A[A


action_type_accuracy: 0.6639, root_position_accuracy: 0.4057, loss: 1.7164 ||:   1%|          | 3/360 [00:00<00:53,  6.72it/s][A[A[A


action_type_accuracy: 0.6261, root_position_accuracy: 0.4189, loss: 1.6990 ||:   1%|▏         | 5/360 [00:00<00:44,  7.91it/s][A[A[A


action_type_accuracy: 0.6118, root_position_accuracy: 0.4255, loss: 1.6871 ||:   2%|▏         | 7/360 [00:00<00:38,  9.28it/s][A[A[A


action_type_accuracy: 0.6055, root_position_accuracy: 0.4254, loss: 1.6965 ||:   2%|▎         | 9/360 [00:00<00:33, 10.63it/s][A[A[A


action_type_accuracy: 0.5929, root_position_accuracy: 0.4243, loss: 1.7100 ||:   3%|▎         | 11/360 [00:00<00:31, 11.14it/s][A[A[A


action_type_accuracy: 0.5884, root_position_accuracy: 0.4212, loss: 1.7247 ||:   4%|▎         | 13/360 [00:01<00:31, 

action_type_accuracy: 0.6948, root_position_accuracy: 0.4194, loss: 2.0439 ||:  41%|████      | 148/360 [00:08<00:12, 16.42it/s][A[A[A


action_type_accuracy: 0.6938, root_position_accuracy: 0.4194, loss: 2.0509 ||:  42%|████▏     | 150/360 [00:08<00:13, 16.03it/s][A[A[A


action_type_accuracy: 0.6926, root_position_accuracy: 0.4195, loss: 2.0582 ||:  42%|████▏     | 152/360 [00:08<00:12, 16.41it/s][A[A[A


action_type_accuracy: 0.6913, root_position_accuracy: 0.4193, loss: 2.0659 ||:  43%|████▎     | 154/360 [00:08<00:12, 16.30it/s][A[A[A


action_type_accuracy: 0.6900, root_position_accuracy: 0.4194, loss: 2.0748 ||:  44%|████▎     | 157/360 [00:08<00:11, 17.65it/s][A[A[A


action_type_accuracy: 0.6898, root_position_accuracy: 0.4199, loss: 2.0809 ||:  44%|████▍     | 160/360 [00:09<00:10, 18.53it/s][A[A[A


action_type_accuracy: 0.6889, root_position_accuracy: 0.4200, loss: 2.0858 ||:  45%|████▌     | 162/360 [00:09<00:10, 18.93it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.6622, root_position_accuracy: 0.4232, loss: 2.1308 ||:  86%|████████▌ | 308/360 [00:16<00:02, 19.37it/s][A[A[A


action_type_accuracy: 0.6659, root_position_accuracy: 0.4269, loss: 2.1191 ||:  86%|████████▋ | 311/360 [00:16<00:02, 19.19it/s][A[A[A


action_type_accuracy: 0.6696, root_position_accuracy: 0.4304, loss: 2.1079 ||:  87%|████████▋ | 314/360 [00:16<00:02, 20.38it/s][A[A[A


action_type_accuracy: 0.6717, root_position_accuracy: 0.4324, loss: 2.0966 ||:  89%|████████▊ | 319/360 [00:17<00:01, 24.42it/s][A[A[A


action_type_accuracy: 0.6754, root_position_accuracy: 0.4322, loss: 2.0777 ||:  90%|█████████ | 324/360 [00:17<00:01, 28.24it/s][A[A[A


action_type_accuracy: 0.6800, root_position_accuracy: 0.4372, loss: 2.0615 ||:  91%|█████████ | 328/360 [00:17<00:01, 25.33it/s][A[A[A


action_type_accuracy: 0.6833, root_position_accuracy: 0.4410, loss: 2.0494 ||:  92%|█████████▏| 331/360 [00:17<00:01, 26.21it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.9347, root_position_accuracy: 0.4422, loss: 1.9577 ||:   1%|          | 5/491 [00:01<01:56,  4.18it/s][A[A[A


action_type_accuracy: 0.9565, root_position_accuracy: 0.4422, loss: 1.6646 ||:   1%|          | 6/491 [00:01<01:35,  5.06it/s][A[A[A


action_type_accuracy: 0.9739, root_position_accuracy: 0.4422, loss: 1.2788 ||:   2%|▏         | 8/491 [00:01<01:16,  6.33it/s][A[A[A


action_type_accuracy: 0.9807, root_position_accuracy: 0.4422, loss: 1.0387 ||:   2%|▏         | 10/491 [00:01<01:00,  7.89it/s][A[A[A


action_type_accuracy: 0.9747, root_position_accuracy: 0.4170, loss: 1.2351 ||:   2%|▏         | 12/491 [00:01<01:08,  6.98it/s][A[A[A


action_type_accuracy: 0.9517, root_position_accuracy: 0.3724, loss: 1.3836 ||:   3%|▎         | 13/491 [00:01<01:10,  6.77it/s][A[A[A


action_type_accuracy: 0.9520, root_position_accuracy: 0.3754, loss: 1.4024 ||:   3%|▎         | 14/491 [00:02<01:09,  6.90it/s][A[A[A


action_type_accuracy: 0.9410, 

action_type_accuracy: 0.8703, root_position_accuracy: 0.4712, loss: 1.5161 ||:  16%|█▌        | 79/491 [00:10<00:40, 10.26it/s][A[A[A


action_type_accuracy: 0.8741, root_position_accuracy: 0.4712, loss: 1.4795 ||:  16%|█▋        | 81/491 [00:10<00:35, 11.54it/s][A[A[A


action_type_accuracy: 0.8776, root_position_accuracy: 0.4712, loss: 1.4444 ||:  17%|█▋        | 83/491 [00:10<00:33, 12.14it/s][A[A[A


action_type_accuracy: 0.8678, root_position_accuracy: 0.4626, loss: 1.4913 ||:  17%|█▋        | 85/491 [00:10<00:36, 11.12it/s][A[A[A


action_type_accuracy: 0.8636, root_position_accuracy: 0.4555, loss: 1.5153 ||:  18%|█▊        | 87/491 [00:10<00:40,  9.97it/s][A[A[A


action_type_accuracy: 0.8605, root_position_accuracy: 0.4526, loss: 1.5326 ||:  18%|█▊        | 89/491 [00:11<00:47,  8.50it/s][A[A[A


action_type_accuracy: 0.8598, root_position_accuracy: 0.4529, loss: 1.5388 ||:  18%|█▊        | 90/491 [00:11<01:01,  6.55it/s][A[A[A


action_type_accuracy: 0.858

action_type_accuracy: 0.8296, root_position_accuracy: 0.4307, loss: 1.7380 ||:  31%|███       | 151/491 [00:19<00:44,  7.65it/s][A[A[A


action_type_accuracy: 0.8294, root_position_accuracy: 0.4311, loss: 1.7407 ||:  31%|███       | 153/491 [00:19<00:41,  8.23it/s][A[A[A


action_type_accuracy: 0.8295, root_position_accuracy: 0.4303, loss: 1.7432 ||:  31%|███▏      | 154/491 [00:19<00:46,  7.32it/s][A[A[A


action_type_accuracy: 0.8291, root_position_accuracy: 0.4303, loss: 1.7451 ||:  32%|███▏      | 155/491 [00:19<00:43,  7.66it/s][A[A[A


action_type_accuracy: 0.8290, root_position_accuracy: 0.4300, loss: 1.7470 ||:  32%|███▏      | 156/491 [00:19<00:41,  8.01it/s][A[A[A


action_type_accuracy: 0.8286, root_position_accuracy: 0.4298, loss: 1.7486 ||:  32%|███▏      | 157/491 [00:20<00:42,  7.94it/s][A[A[A


action_type_accuracy: 0.8283, root_position_accuracy: 0.4302, loss: 1.7490 ||:  32%|███▏      | 158/491 [00:20<00:40,  8.18it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8490, root_position_accuracy: 0.4663, loss: 1.6104 ||:  48%|████▊     | 238/491 [00:28<00:35,  7.09it/s][A[A[A


action_type_accuracy: 0.8498, root_position_accuracy: 0.4671, loss: 1.6094 ||:  49%|████▊     | 239/491 [00:28<00:34,  7.21it/s][A[A[A


action_type_accuracy: 0.8505, root_position_accuracy: 0.4684, loss: 1.6079 ||:  49%|████▉     | 240/491 [00:28<00:34,  7.19it/s][A[A[A


action_type_accuracy: 0.8512, root_position_accuracy: 0.4692, loss: 1.6069 ||:  49%|████▉     | 241/491 [00:28<00:34,  7.28it/s][A[A[A


action_type_accuracy: 0.8519, root_position_accuracy: 0.4703, loss: 1.6053 ||:  49%|████▉     | 242/491 [00:28<00:34,  7.23it/s][A[A[A


action_type_accuracy: 0.8526, root_position_accuracy: 0.4711, loss: 1.6042 ||:  49%|████▉     | 243/491 [00:28<00:32,  7.62it/s][A[A[A


action_type_accuracy: 0.8533, root_position_accuracy: 0.4720, loss: 1.6030 ||:  50%|████▉     | 244/491 [00:29<00:31,  7.81it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8476, root_position_accuracy: 0.4797, loss: 1.6330 ||:  63%|██████▎   | 311/491 [00:37<00:19,  9.01it/s][A[A[A


action_type_accuracy: 0.8478, root_position_accuracy: 0.4799, loss: 1.6324 ||:  64%|██████▎   | 312/491 [00:37<00:19,  9.27it/s][A[A[A


action_type_accuracy: 0.8477, root_position_accuracy: 0.4797, loss: 1.6331 ||:  64%|██████▎   | 313/491 [00:37<00:18,  9.42it/s][A[A[A


action_type_accuracy: 0.8473, root_position_accuracy: 0.4794, loss: 1.6342 ||:  64%|██████▍   | 314/491 [00:37<00:18,  9.56it/s][A[A[A


action_type_accuracy: 0.8467, root_position_accuracy: 0.4786, loss: 1.6363 ||:  64%|██████▍   | 316/491 [00:37<00:18,  9.72it/s][A[A[A


action_type_accuracy: 0.8469, root_position_accuracy: 0.4787, loss: 1.6360 ||:  65%|██████▍   | 317/491 [00:37<00:20,  8.39it/s][A[A[A


action_type_accuracy: 0.8467, root_position_accuracy: 0.4781, loss: 1.6373 ||:  65%|██████▍   | 318/491 [00:37<00:21,  8.06it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8407, root_position_accuracy: 0.4680, loss: 1.6242 ||:  80%|███████▉  | 391/491 [00:45<00:07, 13.58it/s][A[A[A


action_type_accuracy: 0.8405, root_position_accuracy: 0.4680, loss: 1.6187 ||:  80%|████████  | 393/491 [00:46<00:07, 13.96it/s][A[A[A


action_type_accuracy: 0.8403, root_position_accuracy: 0.4680, loss: 1.6134 ||:  80%|████████  | 395/491 [00:46<00:06, 13.78it/s][A[A[A


action_type_accuracy: 0.8403, root_position_accuracy: 0.4680, loss: 1.6077 ||:  81%|████████  | 397/491 [00:46<00:06, 14.53it/s][A[A[A


action_type_accuracy: 0.8404, root_position_accuracy: 0.4680, loss: 1.6019 ||:  81%|████████▏ | 399/491 [00:46<00:06, 14.99it/s][A[A[A


action_type_accuracy: 0.8400, root_position_accuracy: 0.4677, loss: 1.6032 ||:  82%|████████▏ | 401/491 [00:46<00:06, 13.67it/s][A[A[A


action_type_accuracy: 0.8395, root_position_accuracy: 0.4674, loss: 1.6068 ||:  82%|████████▏ | 403/491 [00:46<00:07, 11.01it/s][A[A[A


action_type_accuracy

action_type_accuracy: 0.8424, root_position_accuracy: 0.4824, loss: 1.5939 ||:  99%|█████████▉| 488/491 [00:56<00:00, 10.75it/s][A[A[A


action_type_accuracy: 0.8425, root_position_accuracy: 0.4822, loss: 1.5937 ||: 100%|█████████▉| 490/491 [00:56<00:00, 10.58it/s][A[A[A


action_type_accuracy: 0.8423, root_position_accuracy: 0.4817, loss: 1.5945 ||: : 492it [00:56, 10.32it/s]                       [A[A[A


action_type_accuracy: 0.8424, root_position_accuracy: 0.4813, loss: 1.5945 ||: : 494it [00:56, 10.43it/s][A[A[A


action_type_accuracy: 0.8423, root_position_accuracy: 0.4812, loss: 1.5944 ||: : 496it [00:57,  6.55it/s][A[A[A


action_type_accuracy: 0.8423, root_position_accuracy: 0.4812, loss: 1.5957 ||: : 497it [00:57,  7.04it/s][A[A[A


action_type_accuracy: 0.8424, root_position_accuracy: 0.4831, loss: 1.5919 ||: : 499it [00:57,  8.14it/s][A[A[A


action_type_accuracy: 0.8425, root_position_accuracy: 0.4849, loss: 1.5878 ||: : 501it [00:57,  9.40it/s][A[A

action_type_accuracy: 0.7647, root_position_accuracy: 0.4979, loss: 2.0665 ||:  12%|█▏        | 44/360 [00:02<00:20, 15.78it/s][A[A[A


action_type_accuracy: 0.7568, root_position_accuracy: 0.4942, loss: 2.1018 ||:  13%|█▎        | 46/360 [00:02<00:19, 16.49it/s][A[A[A


action_type_accuracy: 0.7492, root_position_accuracy: 0.4902, loss: 2.1447 ||:  14%|█▎        | 49/360 [00:02<00:16, 18.50it/s][A[A[A


action_type_accuracy: 0.7596, root_position_accuracy: 0.4856, loss: 2.1160 ||:  14%|█▍        | 51/360 [00:02<00:18, 16.67it/s][A[A[A


action_type_accuracy: 0.7625, root_position_accuracy: 0.4848, loss: 2.0871 ||:  15%|█▌        | 55/360 [00:03<00:15, 19.28it/s][A[A[A


action_type_accuracy: 0.7706, root_position_accuracy: 0.4946, loss: 2.0586 ||:  16%|█▌        | 58/360 [00:03<00:16, 18.82it/s][A[A[A


action_type_accuracy: 0.7829, root_position_accuracy: 0.5108, loss: 1.9907 ||:  17%|█▋        | 61/360 [00:03<00:17, 17.37it/s][A[A[A


action_type_accuracy: 0.790

action_type_accuracy: 0.7584, root_position_accuracy: 0.4904, loss: 2.2826 ||:  49%|████▊     | 175/360 [00:17<02:21,  1.31it/s][A
action_type_accuracy: 0.7563, root_position_accuracy: 0.4891, loss: 2.2866 ||:  49%|████▉     | 177/360 [00:17<01:41,  1.80it/s][A
action_type_accuracy: 0.7540, root_position_accuracy: 0.4880, loss: 2.2909 ||:  50%|████▉     | 179/360 [00:17<01:14,  2.43it/s][A
action_type_accuracy: 0.7515, root_position_accuracy: 0.4867, loss: 2.2955 ||:  50%|█████     | 181/360 [00:17<00:55,  3.22it/s][A
action_type_accuracy: 0.7494, root_position_accuracy: 0.4859, loss: 2.2987 ||:  51%|█████     | 183/360 [00:17<00:42,  4.16it/s][A
action_type_accuracy: 0.7480, root_position_accuracy: 0.4853, loss: 2.2997 ||:  51%|█████▏    | 185/360 [00:17<00:33,  5.25it/s][A
action_type_accuracy: 0.7463, root_position_accuracy: 0.4846, loss: 2.3017 ||:  52%|█████▏    | 187/360 [00:18<00:26,  6.43it/s][A
action_type_accuracy: 0.7441, root_position_accuracy: 0.4836, loss: 2.3041 |

action_type_accuracy: 0.7302, root_position_accuracy: 0.4689, loss: 2.2431 ||:  93%|█████████▎| 335/360 [00:26<00:01, 14.90it/s][A
action_type_accuracy: 0.7285, root_position_accuracy: 0.4688, loss: 2.2397 ||:  94%|█████████▎| 337/360 [00:26<00:01, 14.58it/s][A
action_type_accuracy: 0.7271, root_position_accuracy: 0.4687, loss: 2.2359 ||:  94%|█████████▍| 339/360 [00:26<00:01, 15.42it/s][A
action_type_accuracy: 0.7259, root_position_accuracy: 0.4685, loss: 2.2321 ||:  95%|█████████▍| 341/360 [00:26<00:01, 14.67it/s][A
action_type_accuracy: 0.7247, root_position_accuracy: 0.4684, loss: 2.2285 ||:  95%|█████████▌| 343/360 [00:26<00:01, 12.97it/s][A
action_type_accuracy: 0.7235, root_position_accuracy: 0.4684, loss: 2.2244 ||:  96%|█████████▌| 345/360 [00:26<00:01, 13.55it/s][A
action_type_accuracy: 0.7254, root_position_accuracy: 0.4679, loss: 2.2199 ||:  96%|█████████▋| 347/360 [00:26<00:00, 14.74it/s][A
action_type_accuracy: 0.7265, root_position_accuracy: 0.4679, loss: 2.2092 |

action_type_accuracy: 0.7231, root_position_accuracy: 0.3988, loss: 1.8537 ||:   9%|▉         | 33/360 [00:01<00:18, 18.09it/s][A
action_type_accuracy: 0.7177, root_position_accuracy: 0.3721, loss: 1.8771 ||:  10%|█         | 36/360 [00:02<00:20, 15.85it/s][A
action_type_accuracy: 0.7122, root_position_accuracy: 0.3775, loss: 1.8889 ||:  11%|█         | 38/360 [00:02<00:19, 16.24it/s][A
action_type_accuracy: 0.7070, root_position_accuracy: 0.3653, loss: 1.9195 ||:  11%|█         | 40/360 [00:02<00:20, 15.46it/s][A
action_type_accuracy: 0.7045, root_position_accuracy: 0.3556, loss: 1.9413 ||:  12%|█▏        | 42/360 [00:02<00:20, 15.31it/s][A
action_type_accuracy: 0.7038, root_position_accuracy: 0.3512, loss: 1.9832 ||:  12%|█▎        | 45/360 [00:02<00:19, 16.45it/s][A
action_type_accuracy: 0.7052, root_position_accuracy: 0.3551, loss: 1.9958 ||:  13%|█▎        | 47/360 [00:02<00:18, 17.20it/s][A
action_type_accuracy: 0.7090, root_position_accuracy: 0.3611, loss: 2.0110 ||:  14%

action_type_accuracy: 0.7399, root_position_accuracy: 0.4900, loss: 1.8755 ||:  53%|█████▎    | 191/360 [00:10<00:10, 15.44it/s][A
action_type_accuracy: 0.7399, root_position_accuracy: 0.4897, loss: 1.8750 ||:  54%|█████▍    | 195/360 [00:10<00:08, 18.56it/s][A
action_type_accuracy: 0.7393, root_position_accuracy: 0.4863, loss: 1.8845 ||:  55%|█████▌    | 198/360 [00:10<00:09, 17.54it/s][A
action_type_accuracy: 0.7389, root_position_accuracy: 0.4836, loss: 1.8911 ||:  56%|█████▌    | 200/360 [00:10<00:09, 16.53it/s][A
action_type_accuracy: 0.7384, root_position_accuracy: 0.4860, loss: 1.8892 ||:  56%|█████▌    | 202/360 [00:11<00:10, 15.63it/s][A
action_type_accuracy: 0.7375, root_position_accuracy: 0.4881, loss: 1.8889 ||:  57%|█████▋    | 204/360 [00:11<00:09, 16.58it/s][A
action_type_accuracy: 0.7378, root_position_accuracy: 0.4890, loss: 1.8889 ||:  57%|█████▊    | 207/360 [00:11<00:09, 16.94it/s][A
action_type_accuracy: 0.7381, root_position_accuracy: 0.4880, loss: 1.8885 |

action_type_accuracy: 0.7380, root_position_accuracy: 0.4855, loss: 1.8182 ||:  97%|█████████▋| 349/360 [00:19<00:00, 16.59it/s][A
action_type_accuracy: 0.7390, root_position_accuracy: 0.4853, loss: 1.8130 ||:  98%|█████████▊| 351/360 [00:19<00:00, 16.62it/s][A
action_type_accuracy: 0.7389, root_position_accuracy: 0.4845, loss: 1.8229 ||:  98%|█████████▊| 354/360 [00:19<00:00, 17.45it/s][A
action_type_accuracy: 0.7389, root_position_accuracy: 0.4842, loss: 1.8241 ||:  99%|█████████▉| 356/360 [00:19<00:00, 17.19it/s][A
action_type_accuracy: 0.7390, root_position_accuracy: 0.4840, loss: 1.8257 ||:  99%|█████████▉| 358/360 [00:19<00:00, 15.61it/s][A
action_type_accuracy: 0.7391, root_position_accuracy: 0.4836, loss: 1.8283 ||: 100%|██████████| 360/360 [00:19<00:00, 15.47it/s][A
action_type_accuracy: 0.7392, root_position_accuracy: 0.4832, loss: 1.8300 ||: : 362it [00:19, 15.51it/s]                       [A
action_type_accuracy: 0.7394, root_position_accuracy: 0.4828, loss: 1.8312 |

action_type_accuracy: 0.8371, root_position_accuracy: 0.5162, loss: 1.8951 ||:   4%|▍         | 22/491 [00:03<00:57,  8.12it/s][A
action_type_accuracy: 0.8390, root_position_accuracy: 0.5113, loss: 1.8894 ||:   5%|▍         | 23/491 [00:03<00:56,  8.27it/s][A
action_type_accuracy: 0.8407, root_position_accuracy: 0.5102, loss: 1.8816 ||:   5%|▍         | 24/491 [00:04<00:54,  8.54it/s][A
action_type_accuracy: 0.8396, root_position_accuracy: 0.5056, loss: 1.8808 ||:   5%|▌         | 25/491 [00:04<00:53,  8.74it/s][A
action_type_accuracy: 0.8382, root_position_accuracy: 0.5010, loss: 1.8813 ||:   5%|▌         | 26/491 [00:04<00:51,  8.99it/s][A
action_type_accuracy: 0.8396, root_position_accuracy: 0.5000, loss: 1.8729 ||:   5%|▌         | 27/491 [00:04<00:55,  8.43it/s][A
action_type_accuracy: 0.8359, root_position_accuracy: 0.4928, loss: 1.8762 ||:   6%|▌         | 29/491 [00:04<00:53,  8.70it/s][A
action_type_accuracy: 0.8348, root_position_accuracy: 0.4907, loss: 1.8746 ||:   6%

action_type_accuracy: 0.8483, root_position_accuracy: 0.4871, loss: 1.6071 ||:  23%|██▎       | 113/491 [00:14<00:52,  7.14it/s][A
action_type_accuracy: 0.8498, root_position_accuracy: 0.4895, loss: 1.6041 ||:  23%|██▎       | 114/491 [00:14<00:53,  7.11it/s][A
action_type_accuracy: 0.8499, root_position_accuracy: 0.4894, loss: 1.6020 ||:  23%|██▎       | 115/491 [00:14<00:55,  6.79it/s][A
action_type_accuracy: 0.8506, root_position_accuracy: 0.4891, loss: 1.5997 ||:  24%|██▎       | 116/491 [00:14<00:51,  7.25it/s][A
action_type_accuracy: 0.8486, root_position_accuracy: 0.4869, loss: 1.6095 ||:  24%|██▍       | 117/491 [00:14<00:59,  6.33it/s][A
action_type_accuracy: 0.8486, root_position_accuracy: 0.4868, loss: 1.6256 ||:  24%|██▍       | 118/491 [00:14<00:52,  7.09it/s][A
action_type_accuracy: 0.8489, root_position_accuracy: 0.4871, loss: 1.6277 ||:  24%|██▍       | 119/491 [00:15<01:00,  6.11it/s][A
action_type_accuracy: 0.8489, root_position_accuracy: 0.4870, loss: 1.6343 |

action_type_accuracy: 0.8553, root_position_accuracy: 0.4840, loss: 1.5966 ||:  40%|████      | 197/491 [00:23<00:32,  9.13it/s][A
action_type_accuracy: 0.8550, root_position_accuracy: 0.4837, loss: 1.5977 ||:  40%|████      | 198/491 [00:23<00:32,  9.10it/s][A
action_type_accuracy: 0.8549, root_position_accuracy: 0.4834, loss: 1.5991 ||:  41%|████      | 199/491 [00:23<00:31,  9.20it/s][A
action_type_accuracy: 0.8540, root_position_accuracy: 0.4822, loss: 1.6021 ||:  41%|████      | 201/491 [00:24<00:31,  9.27it/s][A
action_type_accuracy: 0.8541, root_position_accuracy: 0.4820, loss: 1.6022 ||:  41%|████      | 202/491 [00:24<00:31,  9.17it/s][A
action_type_accuracy: 0.8539, root_position_accuracy: 0.4818, loss: 1.6024 ||:  41%|████▏     | 203/491 [00:24<00:31,  9.13it/s][A
action_type_accuracy: 0.8538, root_position_accuracy: 0.4814, loss: 1.6032 ||:  42%|████▏     | 204/491 [00:24<00:30,  9.33it/s][A
action_type_accuracy: 0.8531, root_position_accuracy: 0.4802, loss: 1.6058 |

action_type_accuracy: 0.8517, root_position_accuracy: 0.4737, loss: 1.5964 ||:  57%|█████▋    | 282/491 [00:32<00:29,  7.03it/s][A
action_type_accuracy: 0.8512, root_position_accuracy: 0.4733, loss: 1.5985 ||:  58%|█████▊    | 283/491 [00:33<00:31,  6.61it/s][A
action_type_accuracy: 0.8511, root_position_accuracy: 0.4729, loss: 1.6010 ||:  58%|█████▊    | 284/491 [00:33<00:33,  6.20it/s][A
action_type_accuracy: 0.8512, root_position_accuracy: 0.4731, loss: 1.6014 ||:  58%|█████▊    | 285/491 [00:33<00:31,  6.49it/s][A
action_type_accuracy: 0.8513, root_position_accuracy: 0.4728, loss: 1.6027 ||:  58%|█████▊    | 286/491 [00:33<00:34,  5.94it/s][A
action_type_accuracy: 0.8507, root_position_accuracy: 0.4726, loss: 1.6043 ||:  58%|█████▊    | 287/491 [00:33<00:33,  6.09it/s][A
action_type_accuracy: 0.8504, root_position_accuracy: 0.4723, loss: 1.6056 ||:  59%|█████▊    | 288/491 [00:33<00:30,  6.56it/s][A
action_type_accuracy: 0.8503, root_position_accuracy: 0.4720, loss: 1.6071 |

action_type_accuracy: 0.8502, root_position_accuracy: 0.4742, loss: 1.5822 ||:  75%|███████▍  | 366/491 [00:42<00:12, 10.25it/s][A
action_type_accuracy: 0.8501, root_position_accuracy: 0.4733, loss: 1.5842 ||:  75%|███████▍  | 368/491 [00:42<00:11, 11.02it/s][A
action_type_accuracy: 0.8509, root_position_accuracy: 0.4727, loss: 1.5830 ||:  75%|███████▌  | 370/491 [00:42<00:10, 11.32it/s][A
action_type_accuracy: 0.8515, root_position_accuracy: 0.4735, loss: 1.5814 ||:  76%|███████▌  | 372/491 [00:43<00:11, 10.66it/s][A
action_type_accuracy: 0.8523, root_position_accuracy: 0.4747, loss: 1.5805 ||:  76%|███████▌  | 374/491 [00:43<00:12,  9.04it/s][A
action_type_accuracy: 0.8528, root_position_accuracy: 0.4756, loss: 1.5792 ||:  76%|███████▋  | 375/491 [00:43<00:13,  8.33it/s][A
action_type_accuracy: 0.8532, root_position_accuracy: 0.4763, loss: 1.5786 ||:  77%|███████▋  | 376/491 [00:43<00:15,  7.44it/s][A
action_type_accuracy: 0.8535, root_position_accuracy: 0.4766, loss: 1.5776 |

action_type_accuracy: 0.8552, root_position_accuracy: 0.4848, loss: 1.5827 ||:  91%|█████████ | 445/491 [00:52<00:06,  7.38it/s][A
action_type_accuracy: 0.8550, root_position_accuracy: 0.4844, loss: 1.5842 ||:  91%|█████████ | 447/491 [00:52<00:05,  8.09it/s][A
action_type_accuracy: 0.8549, root_position_accuracy: 0.4840, loss: 1.5855 ||:  91%|█████████▏| 449/491 [00:52<00:04,  8.89it/s][A
action_type_accuracy: 0.8547, root_position_accuracy: 0.4837, loss: 1.5866 ||:  92%|█████████▏| 450/491 [00:53<00:04,  8.83it/s][A
action_type_accuracy: 0.8547, root_position_accuracy: 0.4836, loss: 1.5872 ||:  92%|█████████▏| 451/491 [00:53<00:04,  8.13it/s][A
action_type_accuracy: 0.8546, root_position_accuracy: 0.4835, loss: 1.5879 ||:  92%|█████████▏| 452/491 [00:53<00:05,  7.26it/s][A
action_type_accuracy: 0.8546, root_position_accuracy: 0.4833, loss: 1.5888 ||:  92%|█████████▏| 453/491 [00:53<00:05,  7.22it/s][A
action_type_accuracy: 0.8546, root_position_accuracy: 0.4834, loss: 1.5890 |

action_type_accuracy: 0.8632, root_position_accuracy: 0.4953, loss: 1.5429 ||: : 550it [01:03,  9.90it/s][A
action_type_accuracy: 0.8628, root_position_accuracy: 0.4944, loss: 1.5453 ||: : 552it [01:03, 10.04it/s][A
action_type_accuracy: 0.8628, root_position_accuracy: 0.4944, loss: 1.5465 ||: : 554it [01:04, 10.53it/s][A
action_type_accuracy: 0.8631, root_position_accuracy: 0.4941, loss: 1.5457 ||: : 556it [01:04, 10.86it/s][A
action_type_accuracy: 0.8635, root_position_accuracy: 0.4936, loss: 1.5454 ||: : 558it [01:04, 11.11it/s][AINFO     [allennlp.training.trainer:404] Validating

  0%|          | 0/360 [00:00<?, ?it/s][A
action_type_accuracy: 1.0000, root_position_accuracy: 0.6598, loss: 1.3310 ||:   0%|          | 1/360 [00:00<00:56,  6.38it/s][A
action_type_accuracy: 1.0000, root_position_accuracy: 0.6902, loss: 1.2428 ||:   1%|          | 3/360 [00:00<00:47,  7.52it/s][A
action_type_accuracy: 1.0000, root_position_accuracy: 0.6942, loss: 1.2193 ||:   1%|▏         | 5/36

action_type_accuracy: 0.7915, root_position_accuracy: 0.5029, loss: 1.8219 ||:  41%|████▏     | 149/360 [00:08<00:13, 15.80it/s][A
action_type_accuracy: 0.7935, root_position_accuracy: 0.5053, loss: 1.8078 ||:  42%|████▏     | 151/360 [00:08<00:12, 16.61it/s][A
action_type_accuracy: 0.7962, root_position_accuracy: 0.5063, loss: 1.8016 ||:  42%|████▎     | 153/360 [00:08<00:12, 16.71it/s][A
action_type_accuracy: 0.7986, root_position_accuracy: 0.5083, loss: 1.7955 ||:  43%|████▎     | 155/360 [00:08<00:13, 15.32it/s][A
action_type_accuracy: 0.8016, root_position_accuracy: 0.5118, loss: 1.7819 ||:  44%|████▎     | 157/360 [00:08<00:14, 14.35it/s][A
action_type_accuracy: 0.8044, root_position_accuracy: 0.5157, loss: 1.7680 ||:  44%|████▍     | 159/360 [00:09<00:15, 13.06it/s][A
action_type_accuracy: 0.8072, root_position_accuracy: 0.5197, loss: 1.7540 ||:  45%|████▍     | 161/360 [00:09<00:15, 13.23it/s][A
action_type_accuracy: 0.8100, root_position_accuracy: 0.5229, loss: 1.7414 |

action_type_accuracy: 0.7729, root_position_accuracy: 0.4981, loss: 1.8072 ||:  85%|████████▍ | 305/360 [00:17<00:03, 15.35it/s][A
action_type_accuracy: 0.7714, root_position_accuracy: 0.4973, loss: 1.8068 ||:  85%|████████▌ | 307/360 [00:17<00:03, 13.66it/s][A
action_type_accuracy: 0.7701, root_position_accuracy: 0.4967, loss: 1.8063 ||:  86%|████████▌ | 309/360 [00:17<00:04, 12.04it/s][A
action_type_accuracy: 0.7686, root_position_accuracy: 0.4960, loss: 1.8063 ||:  86%|████████▋ | 311/360 [00:17<00:03, 12.45it/s][A
action_type_accuracy: 0.7691, root_position_accuracy: 0.4957, loss: 1.8076 ||:  87%|████████▋ | 313/360 [00:17<00:03, 13.58it/s][A
action_type_accuracy: 0.7705, root_position_accuracy: 0.4960, loss: 1.8043 ||:  88%|████████▊ | 315/360 [00:18<00:03, 12.73it/s][A
action_type_accuracy: 0.7708, root_position_accuracy: 0.4951, loss: 1.8038 ||:  88%|████████▊ | 318/360 [00:18<00:02, 14.45it/s][A
action_type_accuracy: 0.7693, root_position_accuracy: 0.4943, loss: 1.8057 |

action_type_accuracy: 0.7461, root_position_accuracy: 0.4249, loss: 2.3124 ||:   1%|          | 4/491 [00:00<02:05,  3.89it/s][A
action_type_accuracy: 0.7162, root_position_accuracy: 0.3446, loss: 2.1920 ||:   1%|          | 6/491 [00:01<01:40,  4.82it/s][A
action_type_accuracy: 0.7533, root_position_accuracy: 0.3622, loss: 2.2047 ||:   2%|▏         | 8/491 [00:01<01:26,  5.56it/s][A
action_type_accuracy: 0.7618, root_position_accuracy: 0.3600, loss: 2.2360 ||:   2%|▏         | 9/491 [00:01<01:22,  5.83it/s][A
action_type_accuracy: 0.7805, root_position_accuracy: 0.3597, loss: 2.0774 ||:   2%|▏         | 11/491 [00:01<01:08,  6.98it/s][A
action_type_accuracy: 0.8129, root_position_accuracy: 0.3597, loss: 1.8019 ||:   3%|▎         | 13/491 [00:01<00:57,  8.32it/s][A
action_type_accuracy: 0.8475, root_position_accuracy: 0.3597, loss: 1.5771 ||:   3%|▎         | 15/491 [00:02<00:58,  8.20it/s][A
action_type_accuracy: 0.8593, root_position_accuracy: 0.3597, loss: 1.4844 ||:   3%|▎  

action_type_accuracy: 0.8564, root_position_accuracy: 0.5049, loss: 1.6387 ||:  19%|█▉        | 94/491 [00:11<00:59,  6.65it/s][A
action_type_accuracy: 0.8582, root_position_accuracy: 0.5085, loss: 1.6321 ||:  19%|█▉        | 95/491 [00:11<00:54,  7.28it/s][A
action_type_accuracy: 0.8599, root_position_accuracy: 0.5121, loss: 1.6255 ||:  20%|█▉        | 96/491 [00:11<00:50,  7.77it/s][A
action_type_accuracy: 0.8615, root_position_accuracy: 0.5145, loss: 1.6207 ||:  20%|█▉        | 97/491 [00:11<00:49,  8.04it/s][A
action_type_accuracy: 0.8632, root_position_accuracy: 0.5168, loss: 1.6161 ||:  20%|█▉        | 98/491 [00:11<00:47,  8.34it/s][A
action_type_accuracy: 0.8647, root_position_accuracy: 0.5197, loss: 1.6100 ||:  20%|██        | 99/491 [00:11<00:46,  8.50it/s][A
action_type_accuracy: 0.8663, root_position_accuracy: 0.5212, loss: 1.6068 ||:  20%|██        | 100/491 [00:11<00:45,  8.58it/s][A
action_type_accuracy: 0.8678, root_position_accuracy: 0.5232, loss: 1.6021 ||:  21

action_type_accuracy: 0.8824, root_position_accuracy: 0.5274, loss: 1.5941 ||:  34%|███▍      | 168/491 [00:20<00:39,  8.08it/s][A
action_type_accuracy: 0.8832, root_position_accuracy: 0.5287, loss: 1.5918 ||:  34%|███▍      | 169/491 [00:20<00:51,  6.24it/s][A
action_type_accuracy: 0.8836, root_position_accuracy: 0.5291, loss: 1.5901 ||:  35%|███▍      | 170/491 [00:20<00:52,  6.10it/s][A
action_type_accuracy: 0.8818, root_position_accuracy: 0.5279, loss: 1.5964 ||:  35%|███▍      | 171/491 [00:20<00:51,  6.26it/s][A
action_type_accuracy: 0.8796, root_position_accuracy: 0.5270, loss: 1.6002 ||:  35%|███▌      | 172/491 [00:20<00:49,  6.40it/s][A
action_type_accuracy: 0.8794, root_position_accuracy: 0.5259, loss: 1.6015 ||:  35%|███▌      | 173/491 [00:20<00:52,  6.08it/s][A
action_type_accuracy: 0.8790, root_position_accuracy: 0.5250, loss: 1.6028 ||:  35%|███▌      | 174/491 [00:21<00:49,  6.36it/s][A
action_type_accuracy: 0.8789, root_position_accuracy: 0.5242, loss: 1.6037 |

action_type_accuracy: 0.8717, root_position_accuracy: 0.4980, loss: 1.6104 ||:  50%|█████     | 247/491 [00:29<00:22, 10.71it/s][A
action_type_accuracy: 0.8714, root_position_accuracy: 0.4977, loss: 1.6048 ||:  51%|█████     | 249/491 [00:29<00:20, 11.57it/s][A
action_type_accuracy: 0.8713, root_position_accuracy: 0.4972, loss: 1.5993 ||:  51%|█████     | 251/491 [00:30<00:20, 11.66it/s][A
action_type_accuracy: 0.8715, root_position_accuracy: 0.4970, loss: 1.5926 ||:  52%|█████▏    | 253/491 [00:30<00:21, 10.90it/s][A
action_type_accuracy: 0.8720, root_position_accuracy: 0.4971, loss: 1.5855 ||:  52%|█████▏    | 255/491 [00:30<00:21, 10.96it/s][A
action_type_accuracy: 0.8723, root_position_accuracy: 0.4964, loss: 1.5793 ||:  52%|█████▏    | 257/491 [00:30<00:19, 12.13it/s][A
action_type_accuracy: 0.8726, root_position_accuracy: 0.4956, loss: 1.5730 ||:  53%|█████▎    | 259/491 [00:30<00:18, 12.70it/s][A
action_type_accuracy: 0.8730, root_position_accuracy: 0.4952, loss: 1.5666 |

action_type_accuracy: 0.8727, root_position_accuracy: 0.4883, loss: 1.5608 ||:  71%|███████   | 347/491 [00:40<00:21,  6.69it/s][A
action_type_accuracy: 0.8727, root_position_accuracy: 0.4889, loss: 1.5604 ||:  71%|███████   | 348/491 [00:40<00:21,  6.77it/s][A
action_type_accuracy: 0.8726, root_position_accuracy: 0.4883, loss: 1.5605 ||:  71%|███████   | 349/491 [00:40<00:19,  7.36it/s][A
action_type_accuracy: 0.8726, root_position_accuracy: 0.4882, loss: 1.5598 ||:  71%|███████▏  | 350/491 [00:40<00:17,  7.85it/s][A
action_type_accuracy: 0.8726, root_position_accuracy: 0.4875, loss: 1.5588 ||:  72%|███████▏  | 352/491 [00:40<00:18,  7.33it/s][A
action_type_accuracy: 0.8725, root_position_accuracy: 0.4872, loss: 1.5586 ||:  72%|███████▏  | 353/491 [00:41<00:19,  7.08it/s][A
action_type_accuracy: 0.8723, root_position_accuracy: 0.4867, loss: 1.5587 ||:  72%|███████▏  | 354/491 [00:41<00:18,  7.52it/s][A
action_type_accuracy: 0.8719, root_position_accuracy: 0.4862, loss: 1.5581 |

action_type_accuracy: 0.8679, root_position_accuracy: 0.4788, loss: 1.5451 ||:  88%|████████▊ | 433/491 [00:50<00:04, 12.92it/s][A
action_type_accuracy: 0.8677, root_position_accuracy: 0.4788, loss: 1.5403 ||:  89%|████████▊ | 435/491 [00:50<00:04, 13.81it/s][A
action_type_accuracy: 0.8674, root_position_accuracy: 0.4788, loss: 1.5357 ||:  89%|████████▉ | 437/491 [00:50<00:03, 14.19it/s][A
action_type_accuracy: 0.8672, root_position_accuracy: 0.4788, loss: 1.5309 ||:  89%|████████▉ | 439/491 [00:50<00:03, 13.97it/s][A
action_type_accuracy: 0.8670, root_position_accuracy: 0.4788, loss: 1.5263 ||:  90%|████████▉ | 441/491 [00:50<00:03, 14.97it/s][A
action_type_accuracy: 0.8673, root_position_accuracy: 0.4784, loss: 1.5275 ||:  90%|█████████ | 443/491 [00:50<00:04, 11.49it/s][A
action_type_accuracy: 0.8678, root_position_accuracy: 0.4780, loss: 1.5271 ||:  91%|█████████ | 445/491 [00:50<00:04, 10.87it/s][A
action_type_accuracy: 0.8685, root_position_accuracy: 0.4792, loss: 1.5255 |

action_type_accuracy: 0.8668, root_position_accuracy: 0.4860, loss: 1.5301 ||: : 533it [01:00,  8.47it/s][A
action_type_accuracy: 0.8667, root_position_accuracy: 0.4860, loss: 1.5306 ||: : 534it [01:00,  7.44it/s][A
action_type_accuracy: 0.8665, root_position_accuracy: 0.4858, loss: 1.5314 ||: : 535it [01:00,  7.48it/s][A
action_type_accuracy: 0.8664, root_position_accuracy: 0.4856, loss: 1.5325 ||: : 536it [01:01,  7.18it/s][A
action_type_accuracy: 0.8663, root_position_accuracy: 0.4855, loss: 1.5333 ||: : 537it [01:01,  7.02it/s][A
action_type_accuracy: 0.8663, root_position_accuracy: 0.4855, loss: 1.5337 ||: : 538it [01:01,  7.14it/s][A
action_type_accuracy: 0.8664, root_position_accuracy: 0.4856, loss: 1.5340 ||: : 539it [01:01,  7.27it/s][A
action_type_accuracy: 0.8663, root_position_accuracy: 0.4854, loss: 1.5350 ||: : 540it [01:01,  7.23it/s][A
action_type_accuracy: 0.8662, root_position_accuracy: 0.4851, loss: 1.5359 ||: : 541it [01:01,  7.46it/s][A
action_type_accurac

action_type_accuracy: 0.8658, root_position_accuracy: 0.5317, loss: 1.5591 ||:  35%|███▌      | 126/360 [00:06<00:13, 17.05it/s][A
action_type_accuracy: 0.8636, root_position_accuracy: 0.5291, loss: 1.5609 ||:  36%|███▌      | 128/360 [00:06<00:14, 15.52it/s][A
action_type_accuracy: 0.8608, root_position_accuracy: 0.5258, loss: 1.5650 ||:  36%|███▌      | 130/360 [00:06<00:14, 15.54it/s][A
action_type_accuracy: 0.8589, root_position_accuracy: 0.5237, loss: 1.5667 ||:  37%|███▋      | 132/360 [00:07<00:14, 16.02it/s][A
action_type_accuracy: 0.8559, root_position_accuracy: 0.5209, loss: 1.5703 ||:  37%|███▋      | 134/360 [00:07<00:13, 16.36it/s][A
action_type_accuracy: 0.8540, root_position_accuracy: 0.5193, loss: 1.5739 ||:  38%|███▊      | 136/360 [00:07<00:14, 15.99it/s][A
action_type_accuracy: 0.8523, root_position_accuracy: 0.5179, loss: 1.5773 ||:  38%|███▊      | 138/360 [00:07<00:13, 15.88it/s][A
action_type_accuracy: 0.8498, root_position_accuracy: 0.5161, loss: 1.5848 |

action_type_accuracy: 0.8155, root_position_accuracy: 0.4794, loss: 1.7516 ||:  72%|███████▏  | 260/360 [00:15<00:05, 18.45it/s][A
action_type_accuracy: 0.8171, root_position_accuracy: 0.4815, loss: 1.7479 ||:  73%|███████▎  | 262/360 [00:15<00:05, 16.64it/s][A
action_type_accuracy: 0.8188, root_position_accuracy: 0.4840, loss: 1.7430 ||:  73%|███████▎  | 264/360 [00:15<00:06, 15.27it/s][A
action_type_accuracy: 0.8181, root_position_accuracy: 0.4829, loss: 1.7451 ||:  74%|███████▍  | 266/360 [00:15<00:06, 14.32it/s][A
action_type_accuracy: 0.8167, root_position_accuracy: 0.4805, loss: 1.7514 ||:  74%|███████▍  | 268/360 [00:15<00:06, 14.68it/s][A
action_type_accuracy: 0.8150, root_position_accuracy: 0.4781, loss: 1.7585 ||:  75%|███████▌  | 270/360 [00:15<00:06, 14.69it/s][A
action_type_accuracy: 0.8148, root_position_accuracy: 0.4775, loss: 1.7629 ||:  76%|███████▌  | 273/360 [00:15<00:05, 16.32it/s][A
action_type_accuracy: 0.8139, root_position_accuracy: 0.4767, loss: 1.7643 |

INFO     [allennlp.training.tensorboard_writer:174] loss                   |     1.524  |     1.736
INFO     [allennlp.training.tensorboard_writer:178] cpu_memory_MB          |  7231.580  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_7_memory_MB        |    10.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_2_memory_MB        |    11.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_1_memory_MB        |    11.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:174] root_position_accuracy |     0.494  |     0.475
INFO     [allennlp.training.tensorboard_writer:178] gpu_0_memory_MB        |    10.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_3_memory_MB        |    10.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_4_memory_MB        |    11.000  |       N/A
INFO     [allennlp.training.tensorboard_writer:178] gpu_6_memory_MB        |    11.000  |       N/A


action_type_accuracy: 0.8537, root_position_accuracy: 0.5016, loss: 1.2503 ||:  13%|█▎        | 62/491 [00:07<00:32, 13.30it/s][A
action_type_accuracy: 0.8558, root_position_accuracy: 0.5016, loss: 1.2297 ||:  13%|█▎        | 64/491 [00:07<00:31, 13.56it/s][A
action_type_accuracy: 0.8553, root_position_accuracy: 0.5008, loss: 1.2947 ||:  13%|█▎        | 66/491 [00:07<00:31, 13.34it/s][A
action_type_accuracy: 0.8485, root_position_accuracy: 0.4907, loss: 1.3350 ||:  14%|█▍        | 68/491 [00:07<00:32, 13.19it/s][A
action_type_accuracy: 0.8431, root_position_accuracy: 0.4804, loss: 1.3607 ||:  14%|█▍        | 70/491 [00:08<00:33, 12.68it/s][A
action_type_accuracy: 0.8400, root_position_accuracy: 0.4745, loss: 1.3856 ||:  15%|█▍        | 72/491 [00:08<00:32, 12.86it/s][A
action_type_accuracy: 0.8387, root_position_accuracy: 0.4722, loss: 1.4157 ||:  15%|█▌        | 74/491 [00:08<00:40, 10.22it/s][A
action_type_accuracy: 0.8383, root_position_accuracy: 0.4711, loss: 1.4351 ||:  15%

action_type_accuracy: 0.8472, root_position_accuracy: 0.4709, loss: 1.6288 ||:  28%|██▊       | 139/491 [00:16<00:57,  6.15it/s][A
action_type_accuracy: 0.8485, root_position_accuracy: 0.4728, loss: 1.6255 ||:  29%|██▊       | 140/491 [00:16<00:58,  5.95it/s][A
action_type_accuracy: 0.8498, root_position_accuracy: 0.4750, loss: 1.6218 ||:  29%|██▊       | 141/491 [00:16<00:56,  6.23it/s][A
action_type_accuracy: 0.8500, root_position_accuracy: 0.4756, loss: 1.6122 ||:  29%|██▉       | 143/491 [00:17<00:46,  7.49it/s][A
action_type_accuracy: 0.8474, root_position_accuracy: 0.4742, loss: 1.6129 ||:  30%|██▉       | 145/491 [00:17<00:41,  8.40it/s][A
action_type_accuracy: 0.8477, root_position_accuracy: 0.4736, loss: 1.6069 ||:  30%|██▉       | 147/491 [00:17<00:37,  9.15it/s][A
action_type_accuracy: 0.8478, root_position_accuracy: 0.4723, loss: 1.6022 ||:  30%|███       | 149/491 [00:17<00:34,  9.83it/s][A
action_type_accuracy: 0.8478, root_position_accuracy: 0.4712, loss: 1.5980 |

action_type_accuracy: 0.8635, root_position_accuracy: 0.4945, loss: 1.4816 ||:  51%|█████     | 249/491 [00:27<00:33,  7.25it/s][A
action_type_accuracy: 0.8634, root_position_accuracy: 0.4946, loss: 1.4830 ||:  51%|█████     | 250/491 [00:27<00:32,  7.31it/s][A
action_type_accuracy: 0.8634, root_position_accuracy: 0.4943, loss: 1.4857 ||:  51%|█████     | 251/491 [00:28<00:33,  7.26it/s][A
action_type_accuracy: 0.8635, root_position_accuracy: 0.4954, loss: 1.4853 ||:  51%|█████▏    | 252/491 [00:28<00:32,  7.35it/s][A
action_type_accuracy: 0.8640, root_position_accuracy: 0.4962, loss: 1.4852 ||:  52%|█████▏    | 253/491 [00:28<00:32,  7.35it/s][A
action_type_accuracy: 0.8645, root_position_accuracy: 0.4971, loss: 1.4844 ||:  52%|█████▏    | 254/491 [00:28<00:31,  7.48it/s][A
action_type_accuracy: 0.8652, root_position_accuracy: 0.4981, loss: 1.4835 ||:  52%|█████▏    | 255/491 [00:28<00:31,  7.55it/s][A
action_type_accuracy: 0.8658, root_position_accuracy: 0.4988, loss: 1.4831 |

KeyboardInterrupt: 

In [None]:
token_state = [[26, True, 'H', [[23, True, 'A', [[0, True, 'E', [[0, False, 'The', []]]], [1, True, 'C', [[1, False, 'Lakers', []]]]]], [2, True, 'P', [[2, False, 'advanced', []]]], [25, True, 'A', [[3, True, 'R', [[3, False, 'through', []]]], [4, True, 'E', [[4, False, 'the', []]]], [24, True, 'P', [[5, True, 'T', [[5, False, '1982', []]]], [6, True, 'C', [[6, False, 'playoffs', []]]]]]]]]], [7, True, 'L', [[7, False, 'and', []]]], [8, True, 'P', [[8, False, 'faced', []]]], [9, True, 'A', [[9, False, 'Philadelphia', []]]], [27, True, 'D', [[10, True, 'R', [[10, False, 'for', []]]], [11, True, 'E', [[11, False, 'the', []]]], [12, True, 'Q', [[12, False, 'second', []]]], [13, True, 'C', [[13, False, 'time', []]]]]], [14, True, 'R', [[14, False, 'in', []]]], [15, True, 'Q', [[15, False, 'three', []]]], [16, False, 'years', []]]

In [None]:
pprint.pprint(token_state)

In [None]:
vocab.get_token_from_index(0, namespace='labels')

In [None]:
vocab.get_token_index('RESOLVE')

In [None]:
vocab.get_token_from_index(3, namespace='resolved')

In [None]:
vocab.get_token_from_index(5, namespace='token_node_prev_action')