# Tree

In [117]:
from typing import List
from pprint import pprint
from operator import add
from functools import reduce
from collections import Counter

import pandas as pd

from new_semantic_parsing import TopSchemaTokenizer

In [44]:
# Structural tokens of a tokenized TOP schema (see the tokenizer output printed below).
LBR, RBR = '[', ']'  # opening / closing bracket of a (sub)tree
IN, SL = 'IN:', 'SL:'  # intent / slot label prefixes

In [165]:
class Tree:
    """Labeled tree used to compute labeled-bracketing scores.

    Every node has an ``entity`` label -- an intent (``'IN:...'``), a slot
    (``'SL:...'``) or a slot value string such as ``'slot value'`` -- and an
    ordered list of child ``Tree`` nodes.

    Args:
        entity: label of this node.
        subtrees: child nodes; ``None`` means this node is a leaf.
    """

    def __init__(self, entity, subtrees: List = None):
        self.entity = entity
        self.subtrees = [] if subtrees is None else subtrees

        # for per-class metrics: label -> number of nodes carrying it in this tree
        self._counts = Counter([entity])
        # number of nodes in this tree, this node included
        self._len = 1

        for subtree in self.subtrees:
            self._len += len(subtree)
            self._counts.update(subtree._counts)

        # nested {entity: [child reprs]}; lists (not sets) keep order and duplicate labels
        self._dict_repr = {self.entity: [s._dict_repr for s in self.subtrees]}

    def __repr__(self):
        return repr(self._dict_repr)

    def __eq__(self, other):
        """Compare by structure with another ``Tree`` or with a nested dict representation."""
        if isinstance(other, dict):
            return self._dict_repr == other
        if isinstance(other, Tree):
            return self._dict_repr == other._dict_repr
        # let Python fall back to its default handling (e.g. ``tree == None`` is False)
        return NotImplemented

    def __len__(self):
        return self._len

    @property
    def counts(self):
        """Counter mapping each label to the number of nodes with that label."""
        return self._counts

    @classmethod
    def from_tokens(cls, tokens, return_index=False):
        """Builds a parsing tree for labeled bracketing score computation.

        Plain-text tokens directly inside an intent are dropped. The text tokens
        directly inside a slot are joined with spaces into one leaf child, which
        is appended after the slot's nested subtrees. A bracket whose label type
        is neither ``IN`` nor ``SL`` is skipped together with its contents. If the
        input ends before a closing bracket, the node built so far is returned
        (pending slot-value tokens of that node are dropped).

        Args:
            tokens: list of tokens starting with ``LBR``, e.g.
                ``['[', 'IN:', 'INTENT', 'text', '[', 'SL:', 'SLOT', 'value', ']', ']']``.
            return_index: if True, also return the index just past the parsed
                subtree (used in recursion to resume parsing the parent).

        Returns:
            The ``Tree``, or ``None`` if the top-level label type is invalid.
            With ``return_index=True``: a tuple ``(tree_or_None, next_index)``.

        Raises:
            ValueError: if ``tokens`` has fewer than 3 items or does not start with ``LBR``.
        """
        # every tree should start with
        # [ ENTITY_TYPE: ENTITY
        if len(tokens) < 3 or tokens[0] != LBR:
            raise ValueError(f'Tree starts with {tokens[:4]}')

        entity_type = tokens[1]

        # ignore invalid subtrees, but report how far to skip so the parent can continue
        if entity_type not in [IN, SL]:
            if return_index:
                return None, cls._skip_subtree(tokens)
            return None

        entity = entity_type + tokens[2]  # e.g. IN:INTENT

        subtrees = []
        slot_value_tokens = []

        i = 3
        while i < len(tokens):
            token = tokens[i]

            if token == LBR:
                subtree, consumed = cls.from_tokens(tokens[i:], return_index=True)
                if subtree is not None:  # invalid subtrees are skipped, not stored as None
                    subtrees.append(subtree)
                i += consumed
                continue

            if token == RBR:
                if slot_value_tokens:
                    # append (not overwrite) so nested subtrees of this slot are kept
                    subtrees.append(cls(' '.join(slot_value_tokens)))
                i += 1
                break

            # plain text: intents ignore it, slots collect it as their value
            if entity_type == SL:
                slot_value_tokens.append(token)
            i += 1

        tree = cls(entity, subtrees)

        if return_index:
            return tree, i

        return tree

    @staticmethod
    def _skip_subtree(tokens):
        """Return the index just past the bracket matching ``tokens[0]``, or ``len(tokens)`` if unclosed."""
        depth = 0
        for idx, token in enumerate(tokens):
            if token == LBR:
                depth += 1
            elif token == RBR:
                depth -= 1
                if depth == 0:
                    return idx + 1
        return len(tokens)
    


In [166]:
test_case_1 = {
    'input': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'output': Tree(IN + 'INTENT1', [Tree(SL + 'SLOT1', [Tree('slot value')])])
}

test_case_2 = {
    'input': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, 'more', 'text', LBR, SL, 'SLOT2', 'slot2', 'value', RBR, RBR],
    'output': {IN + 'INTENT1': [{SL + 'SLOT1': [Tree('slot value')]}, {SL + 'SLOT2': [Tree('slot2 value')]}]}
}

test_case_3 = {
    'input': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, 'more', 'text', LBR, SL, 'SLOT1', 'slot2', 'value', RBR, RBR],
    'output': {IN + 'INTENT1': [{SL + 'SLOT1': [Tree('slot value')]}, {SL + 'SLOT1': [Tree('slot2 value')]}]}  # this is why you should use lists and not sets/dicts
}

# Truncated input: the second SLOT1 has no value tokens and neither it nor the
# intent is closed, so the parser keeps the second slot as an empty node.
test_case_4 = {
    'input': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, 'more', 'text', LBR, SL, 'SLOT1'],
    'output': {IN + 'INTENT1': [{SL + 'SLOT1': [Tree('slot value')]}, {SL + 'SLOT1': []}]}
}

In [167]:
tree = Tree.from_tokens(test_case_1['input'])
# the tree, its node count and its per-label node counts, one per line
print(tree, len(tree), tree.counts, sep='\n')

assert tree == test_case_1['output']

{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}]}
3
Counter({'IN:INTENT1': 1, 'SL:SLOT1': 1, 'slot value': 1})


In [168]:
tree = Tree.from_tokens(test_case_2['input'])
# the tree, its node count and its per-label node counts, one per line
print(tree, len(tree), tree.counts, sep='\n')

assert tree == test_case_2['output']

{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}, {'SL:SLOT2': [{'slot2 value': []}]}]}
5
Counter({'IN:INTENT1': 1, 'SL:SLOT1': 1, 'slot value': 1, 'SL:SLOT2': 1, 'slot2 value': 1})


In [169]:
tree = Tree.from_tokens(test_case_3['input'])
# duplicated slot labels must survive parsing, hence the list-based representation
print(tree, len(tree), tree.counts, sep='\n')

assert tree == test_case_3['output']

{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}, {'SL:SLOT1': [{'slot2 value': []}]}]}
5
Counter({'SL:SLOT1': 2, 'IN:INTENT1': 1, 'slot value': 1, 'slot2 value': 1})


In [172]:
tree = Tree.from_tokens(test_case_4['input'])
# truncated input: show what the parser recovers from it
print(tree, len(tree), tree.counts, sep='\n')

{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}, {'SL:SLOT1': []}]}
4
Counter({'SL:SLOT1': 2, 'IN:INTENT1': 1, 'slot value': 1})


In [173]:
# columns are named explicitly, so the first row of the file is read as data
data = pd.read_table('../data/top-dataset-semantic-parsing/eval.tsv', names=['text', 'tokens', 'schema'])

tokenized_schema = list(map(TopSchemaTokenizer.tokenize, data.schema))

In [203]:
# Inspect one eval example: its schema tokens, then the tree parsed from them.
i = 10

print(tokenized_schema[i])
print(Tree.from_tokens(tokenized_schema[i]))

['[', 'IN:', 'GET_EVENT', 'Anything', '[', 'SL:', 'DATE_TIME', 'this', 'weekend', ']', 'for', '[', 'SL:', 'ATTRIBUTE_EVENT', 'families', 'with', 'small', 'children', ']', ']']
{'IN:GET_EVENT': [{'SL:DATE_TIME': [{'this weekend': []}]}, {'SL:ATTRIBUTE_EVENT': [{'families with small children': []}]}]}


In [175]:
# A schema with a slot (SL:LOCATION) that contains a nested intent, to check that
# nested brackets are parsed into nested subtrees.
complex_example = (
    '[IN:GET_EVENT Are there any '
        '[SL:CATEGORY_EVENT Concerts ] at '
        '[SL:LOCATION [IN:GET_LOCATION [SL:POINT_ON_MAP Chattaqua Amphitheater ] ] ] '
        '[SL:DATE_TIME this weekend ] with available tickets ]'
)

complex_example_tokens = TopSchemaTokenizer.tokenize(complex_example)
complex_tree = Tree.from_tokens(complex_example_tokens)
# pprint the nested dict form so the tree structure is readable
pprint(complex_tree._dict_repr)

{'IN:GET_EVENT': [{'SL:CATEGORY_EVENT': [{'Concerts': []}]},
                  {'SL:LOCATION': [{'IN:GET_LOCATION': [{'SL:POINT_ON_MAP': [{'Chattaqua Amphitheater': []}]}]}]},
                  {'SL:DATE_TIME': [{'this weekend': []}]}]}


# Metrics

In [176]:
# Expected values for the first approach (labeled_bracketing_precision/recall): an
# unmatched node counts as a single error, its descendants are not counted.
test_case_1 = {
    'true': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'pred': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'f1': 1,
    'precision': 1,
    'recall': 1,
}

test_case_2 = {
    'true': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'pred': [LBR, IN, 'INTENT2', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'f1': 0,
    'precision': 0,
    'recall': 0,
}

test_case_3 = {
    'true': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'pred': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT2', 'slot', 'value', RBR, RBR],
    'f1': 0.5,
    'precision': 0.5,
    'recall': 0.5,
}

test_case_4 = {
    'true': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'pred': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, LBR, SL, 'SLOT2', 'value', RBR, RBR],
    'f1': 6/7.,  # harmonic mean of precision 3/4 and recall 1
    'precision': 3/4.,
    'recall': 1,
}

test_case_5 = {
    'true': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'value', RBR, RBR],
    'pred': [LBR, IN, 'INTENT1', 'text', LBR, SL, 'SLOT1', 'slot', 'wrong value', RBR, RBR],
    'f1': 2/3.,
    'precision': 2/3.,
    'recall': 2/3.,
}

In [177]:
def f1(p, r):
    """Return the harmonic mean of precision ``p`` and recall ``r``.

    Returns 0 when ``p + r == 0`` (F1 is conventionally 0 there) instead of
    raising ZeroDivisionError.
    """
    if p + r == 0:
        return 0
    return 2 * p * r / (p + r)

In [178]:
# parse the gold and the predicted schema of the first metric test case
tree1, tree2 = (Tree.from_tokens(test_case_1[key]) for key in ('true', 'pred'))

print(tree1, tree2, sep='\n')

{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}]}
{'IN:INTENT1': [{'SL:SLOT1': [{'slot value': []}]}]}


In [192]:
def labeled_bracketing_recall(pred_tokens, true_tokens):
    """Compute labeled bracketing recall of one prediction against the gold tree.

    Both token lists are parsed with ``Tree.from_tokens``. If the root labels
    differ, the gold root is the only false negative and recall is 0.
    Otherwise the root is one true positive and the subtrees are compared
    with ``_labeled_bracketing_tp_fn``.

    Returns:
        ``true_positive / (true_positive + false_negative)`` as a float.
    """
    pred_tree = Tree.from_tokens(pred_tokens)
    true_tree = Tree.from_tokens(true_tokens)

    if pred_tree.entity == true_tree.entity:
        tp, fn = _labeled_bracketing_tp_fn(pred_tree.subtrees, true_tree.subtrees)
        tp += 1  # the matching root
    else:
        tp, fn = 0, 1  # mismatched root counts as a single miss

    return tp / (tp + fn)


def labeled_bracketing_precision(pred_tokens, true_tokens):
    """Compute labeled bracketing precision of one prediction against the gold tree.

    Both token lists are parsed with ``Tree.from_tokens``. If the root labels
    differ, the predicted root is the only false positive and precision is 0.
    Otherwise the root is one true positive and the subtrees are compared
    with ``_labeled_bracketing_tp_fp``.

    Returns:
        ``true_positive / (true_positive + false_positive)`` as a float.
    """
    pred_tree = Tree.from_tokens(pred_tokens)
    true_tree = Tree.from_tokens(true_tokens)

    if pred_tree.entity != true_tree.entity:
        return 0.0

    matched, spurious = _labeled_bracketing_tp_fp(pred_tree.subtrees, true_tree.subtrees)
    matched += 1  # the matching root
    precision = matched / (matched + spurious)
    return precision


def _labeled_bracketing_tp_fn(pred_subtrees: List[Tree], true_subtrees: List[Tree]):
    """Compute true positive and false negative labeling bracketng scores"""
    true_positive, false_negative = 0, 0

    for i, true_tree in enumerate(true_subtrees):
        correct_subtree_indices = [i for i, pred_tree in enumerate(pred_subtrees) if pred_tree.entity == true_tree.entity]

        if len(correct_subtree_indices) == 0:
            false_negative += 1
        else:
            true_positive += 1
            
            for pred_subtree_idx in correct_subtree_indices:
                pred_tree = pred_subtrees[pred_subtree_idx]
                tp, fn = _labeled_bracketing_tp_fn(pred_tree.subtrees, true_tree.subtrees)

                true_positive += tp
                false_negative += fn            

    return true_positive, false_negative


def _labeled_bracketing_tp_fp(pred_subtrees: List[Tree], true_subtrees: List[Tree]):
    """Return ``(true_positive, false_positive)`` labeled bracketing counts for sibling subtrees.

    A false positive is a false negative of the mirrored comparison, so the
    predicted and true subtrees swap roles before delegating to
    ``_labeled_bracketing_tp_fn``.
    """
    mirrored_pred, mirrored_true = true_subtrees, pred_subtrees
    return _labeled_bracketing_tp_fn(mirrored_pred, mirrored_true)

In [193]:
# Check labeled_bracketing_recall against the expected value of every test case.
# (The former `test_case = test_case_2` line was dead: the loop rebinds it.)
for i, test_case in enumerate([test_case_1, test_case_2, test_case_3, test_case_4, test_case_5]):

    recall = labeled_bracketing_recall(test_case['pred'], test_case['true'])

    if recall == test_case['recall']:
        print(f'test_case_{i+1} passed. Computed recall: {recall}')
    else:
        print(f'\t test_case_{i+1} FAILED. Computed recall: {recall}')


test_case_1 passed. Computed recall: 1.0
test_case_2 passed. Computed recall: 0.0
test_case_3 passed. Computed recall: 0.5
test_case_4 passed. Computed recall: 1.0
test_case_5 passed. Computed recall: 0.6666666666666666


In [194]:
# Check labeled_bracketing_precision against the expected value of every test case.
for i, test_case in enumerate([test_case_1, test_case_2, test_case_3, test_case_4, test_case_5]):
    precision = labeled_bracketing_precision(test_case['pred'], test_case['true'])

    if precision != test_case['precision']:
        print(f'\t test_case_{i+1} FAILED. Computed precision: {precision}')
    else:
        print(f'test_case_{i+1} passed. Computed precision: {precision}')


test_case_1 passed. Computed precision: 1.0
test_case_2 passed. Computed precision: 0.0
test_case_3 passed. Computed precision: 0.5
test_case_4 passed. Computed precision: 0.75
test_case_5 passed. Computed precision: 0.6666666666666666


## Compare with the official TOP evaluation tool

In [195]:
# gold test schemas and the model's predictions for them, read without header rows
data_test = pd.read_table('../data/top-dataset-semantic-parsing/test.tsv', names=['text', 'tokens', 'schema'])
data_pred = pd.read_table('../lightning_out/jul8_20epochs_small/predictions.tsv', names=['schema'])

tokenized_schema_test = list(map(TopSchemaTokenizer.tokenize, data_test.schema))
tokenized_schema_pred = list(map(TopSchemaTokenizer.tokenize, data_pred.schema))

In [196]:
# TOP script gives the following metrics
# (reference numbers only; these literals are not assigned to any variable)

{'instance_count': 9042,
 'exact_match': 0.25481088254810885,
 'labeled_bracketing_scores': {
     'precision': 0.6032053706505295,
     'recall': 0.3814007712312797,
     'f1': 0.46731984250526504
 },
 'tree_labeled_bracketing_scores': {
     'precision': 0.3943362329803328,
     'recall': 0.24933488775296686,
     'f1': 0.30550315905136893
 },
 'tree_validity': 0.9382879893828799}

# the same values again, in a different layout
{'instance_count': 9042,
 'exact_match': 0.25481088254810885,
 'labeled_bracketing_scores': {'precision': 0.6032053706505295,
  'recall': 0.3814007712312797,
  'f1': 0.46731984250526504},
 'tree_labeled_bracketing_scores': {'precision': 0.3943362329803328,
  'recall': 0.24933488775296686,
  'f1': 0.30550315905136893},
 'tree_validity': 0.9382879893828799}

In [197]:
precisions, recalls = [], []
exact_match = 0

# Per-example scores over the whole test set; they are averaged in a later cell.
# Loop variables are left bound on purpose: the next cell prints the last example.
for pred, true in zip(tokenized_schema_pred, tokenized_schema_test):
    pred_tree = Tree.from_tokens(pred)
    true_tree = Tree.from_tokens(true)

    exact_match += (pred_tree == true_tree)  # a bool counts as 0 or 1

    precision = labeled_bracketing_precision(pred, true)
    recall = labeled_bracketing_recall(pred, true)
    precisions.append(precision)
    recalls.append(recall)


In [198]:
# `true` and `true_tree` are still bound to the last example of the evaluation loop
print(true)
print(true_tree)

['[', 'IN:', 'GET_INFO_TRAFFIC', 'is', 'traffic', 'moving', 'on', '[', 'SL:', 'LOCATION', 'I', '-', '65', ']', ']']
{'IN:GET_INFO_TRAFFIC': [{'SL:LOCATION': [{'I - 65': []}]}]}


In [199]:
# Macro averages: the mean of the per-example scores. NOTE(review): the TOP script
# may aggregate over the whole corpus instead, so numbers are not directly comparable.
mean_precision = sum(precisions) / len(precisions)
mean_recall = sum(recalls) / len(recalls)
# turn the exact-match count into a fraction of examples
exact_match /= len(precisions)

print('Precision: ', mean_precision)
print('Recall   : ', mean_recall)
print('F1       : ', f1(mean_precision, mean_recall))
print('exact_match: ', exact_match)

Precision:  0.640802521534121
Recall   :  0.5737675240412504
F1       :  0.6054351126465458
exact_match:  0.2591240875912409


# New approach

In [200]:
def label_bracketing_scores(pred_trees, true_trees):
    """Compute corpus-level labeled bracketing precision, recall and F1.

    Every node of every tree (intents, slots and slot values) is a labeled
    bracket. Precision divides the matched nodes by the number of predicted
    nodes, and recall by the number of true nodes; both are summed over the
    corpus. A pair of trees contributes matches only when their root labels
    agree; below the root, subtrees are matched by ``_tree_true_positive``.

    Args:
        pred_trees: predicted ``Tree`` objects.
        true_trees: gold ``Tree`` objects aligned with ``pred_trees``. Pairs are
            formed with ``zip``, so extra items in the longer list are ignored.

    Returns:
        dict with ``'LBS_precision'``, ``'LBS_recall'`` and ``'LBS_F1'``. A score
        is 0 when its denominator is 0 (e.g. empty input), matching
        ``label_bracketing_scores_for_classes``.
    """
    true_positives = 0
    n_predicted = 0
    n_expected = 0

    for pred_tree, true_tree in zip(pred_trees, true_trees):
        n_predicted += len(pred_tree)
        n_expected += len(true_tree)

        if pred_tree.entity == true_tree.entity:
            true_positives += 1 + _tree_true_positive(pred_tree.subtrees, true_tree.subtrees)

    precision = true_positives / n_predicted if n_predicted else 0
    recall = true_positives / n_expected if n_expected else 0

    f1 = 0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return {'LBS_precision': precision, 'LBS_recall': recall, 'LBS_F1': f1}


def _tree_true_positive(pred_subtrees, true_subtrees):
    true_positive = 0

    for i, true_tree in enumerate(true_subtrees):
        correct_subtree_indices = [i for i, pred_tree in enumerate(pred_subtrees) if pred_tree.entity == true_tree.entity]

        if len(correct_subtree_indices) == 0:
            continue
        
        true_positive += 1
            
        for pred_subtree_idx in correct_subtree_indices:
            pred_tree = pred_subtrees[pred_subtree_idx]
        
            tp = _tree_true_positive(pred_tree.subtrees, true_tree.subtrees)
            true_positive += tp
        
    return true_positive

In [201]:
# Corpus-level scores of each metric test case on its own.
for i, test_case in enumerate([test_case_1, test_case_2, test_case_3, test_case_4, test_case_5]):
    tree_true, tree_pred = (Tree.from_tokens(test_case[key]) for key in ('true', 'pred'))

    metrics = label_bracketing_scores([tree_pred], [tree_true])

    # header, metrics dict, then an empty separator line
    print(f'test_case_{i+1}:', metrics, sep='\n', end='\n\n')

test_case_1:
{'LBS_precision': 1.0, 'LBS_recall': 1.0, 'LBS_F1': 1.0}

test_case_2:
{'LBS_precision': 0.0, 'LBS_recall': 0.0, 'LBS_F1': 0}

test_case_3:
{'LBS_precision': 0.3333333333333333, 'LBS_recall': 0.3333333333333333, 'LBS_F1': 0.3333333333333333}

test_case_4:
{'LBS_precision': 0.6, 'LBS_recall': 1.0, 'LBS_F1': 0.7499999999999999}

test_case_5:
{'LBS_precision': 0.6666666666666666, 'LBS_recall': 0.6666666666666666, 'LBS_F1': 0.6666666666666666}



In [202]:
# parse every prediction and every gold schema of the test set, then score the corpus
pred_trees = list(map(Tree.from_tokens, tokenized_schema_pred))
true_trees = list(map(Tree.from_tokens, tokenized_schema_test))

metrics = label_bracketing_scores(pred_trees, true_trees)
print(metrics)

{'LBS_precision': 0.6405084598194851, 'LBS_recall': 0.441435314825186, 'LBS_F1': 0.5226575728511716}


Still a bit higher than the official implementation: {'precision': 0.603, 'recall': 0.381, 'f1': 0.467}.

# Per-class scores

In [191]:
def label_bracketing_scores_for_classes(pred_trees, true_trees, classes):
    """Compute label bracketing scores only considering slots, intents and values from classes.

    Same as ``label_bracketing_scores``, except that only nodes whose label is
    in ``classes`` are counted: in the denominators (predicted and expected
    nodes, taken from each tree's ``counts``) and as true positives. Nodes
    outside ``classes`` still have to match for their descendants to be
    compared.

    Args:
        pred_trees: predicted ``Tree`` objects.
        true_trees: gold ``Tree`` objects aligned with ``pred_trees`` (paired with ``zip``).
        classes: iterable of labels to score, e.g. ``{'IN:GET_EVENT', 'SL:DATE_TIME'}``.

    Returns:
        dict with ``'cLBS_precision'``, ``'cLBS_recall'`` and ``'cLBS_F1'``. A score
        is 0 when its denominator is 0.
    """
    classes = set(classes)
    true_positives = 0
    n_predicted = 0
    n_expected = 0

    for pred_tree, true_tree in zip(pred_trees, true_trees):
        # previously all nodes were counted here, so `classes` had no effect
        n_predicted += sum(pred_tree.counts[c] for c in classes)
        n_expected += sum(true_tree.counts[c] for c in classes)

        if pred_tree.entity == true_tree.entity:
            if true_tree.entity in classes:
                true_positives += 1  # the matching root
            true_positives += _tree_true_positive_for_classes(pred_tree.subtrees, true_tree.subtrees, classes)

    precision = 0 if n_predicted == 0 else true_positives / n_predicted
    recall = 0 if n_expected == 0 else true_positives / n_expected

    f1 = 0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)

    return {'cLBS_precision': precision, 'cLBS_recall': recall, 'cLBS_F1': f1}


def _tree_true_positive_for_classes(pred_subtrees, true_subtrees, classes):
    true_positive = 0
    
    for i, true_tree in enumerate(true_subtrees):

        correct_subtree_indices = [i for i, pred_tree in enumerate(pred_subtrees) if pred_tree.entity == true_tree.entity]

        if len(correct_subtree_indices) == 0:
            continue
        
        if true_tree.entity in classes:
            true_positive += 1
            
        for pred_subtree_idx in correct_subtree_indices:
            pred_tree = pred_subtrees[pred_subtree_idx]
        
            tp = _tree_true_positive_for_classes(pred_tree.subtrees, true_tree.subtrees, classes)
            true_positive += tp
        
    return true_positive

In [None]:
# previously an unfinished (syntactically invalid) comprehension
true_trees = [Tree.from_tokens(t) for t in tokenized_schema_test]

# Sanity check: restricting the per-class scores to every label that occurs in
# either tree should reproduce the unrestricted label bracketing scores.
for i, test_case in enumerate([test_case_1, test_case_2, test_case_3, test_case_4, test_case_5]):

    tree_true = Tree.from_tokens(test_case['true'])
    tree_pred = Tree.from_tokens(test_case['pred'])

    metrics = label_bracketing_scores([tree_pred], [tree_true])
    all_classes = set(tree_true.counts) | set(tree_pred.counts)
    class_metrics = label_bracketing_scores_for_classes([tree_pred], [tree_true], all_classes)

    print(f'test_case_{i+1}:')
    print(metrics)
    print(class_metrics)
    print()

# Tree path score