In [None]:
import numpy as np
from tqdm.auto import tqdm, trange

from dependency_utils import *

In [None]:
# Locations of the train / test corpora, relative to the notebook.
TRAIN_PATH = "./data/train.txt"
TEST_PATH = "./data/test.txt"

# One Loader per split.
train_loader, test_loader = Loader(TRAIN_PATH), Loader(TEST_PATH)

# Only the training split is handed to set_all.
# NOTE(review): set_all presumably fills the shared vocabularies
# (all_pos_tags, all_dependency_labels, ...) in dependency_utils — confirm.
set_all(train_loader)

# Show the tag inventories as a quick sanity check.
print(f'{all_pos_tags = }', f'{all_dependency_labels = }', sep='\n')

Given the stack, the buffer, and the set of arcs built so far, extract the features for the current state of the parser.

```text
1. TOP:        top(S) token
2. TOP.POS:    POS-Tag of top(S)
3. TOP.DEP:    DEP-Tag for head(top(S)) -> top(S) ∈ A (only if this has been seen so far)
4. TOP.LDEP:   DEP-Tag for the left-most  w <- top(S) ∈ A (only if this has been seen so far)
5. TOP.RDEP:   DEP-Tag for the right-most top(S) -> w ∈ A (only if this has been seen so far)
6. FIRST:      first(B) token
7. FIRST.POS:  POS-Tag of first(B)
8. FIRST.LDEP: DEP-Tag for the left-most  w <- first(B) ∈ A (only if this has been seen so far)
9. LOOK.POS:   POS-Tag of first(B — {first(B)})
```


In [None]:
# Transition names, ordered so that each name's position is its action id.
actions_list = ['LEFT-ARC', 'RIGHT-ARC', 'REDUCE', 'SHIFT']

# Name -> action id lookup, derived from the ordering above so the two
# structures can never disagree.
actions = {name: action_id for action_id, name in enumerate(actions_list)}

In [None]:
def to_1_hot(
    item: Optional[str],
    all_items: List[str]
):
    """
    Encode `item` as an indicator vector over `all_items`.

    Args:
        item: The value to encode, or None for "no value".
        all_items: The vocabulary; position i of the result corresponds
            to all_items[i].

    Returns:
        An int8 array of length len(all_items) holding 1 at every position
        whose entry equals `item` and 0 elsewhere. The vector is all zeros
        when `item` is None or does not occur in `all_items`.
    """
    encoded = np.zeros(len(all_items), dtype=np.int8)

    if item is not None:
        for position, candidate in enumerate(all_items):
            if item == candidate:
                encoded[position] = 1

    return encoded


def get_test_dep(
    from_word: Word,
    to_word: Word,
) -> Optional[str]:
    """
    Look up a dependency label from the POS tags of an arc's two ends.

    Args:
        from_word: The word at the source end of the arc.
        to_word: The word at the target end of the arc.

    Returns:
        The entry of `test_pos_pair_dep_relations` keyed by
        (from_word.pos_tag, to_word.pos_tag), or None when that POS pair
        has no entry.
    """
    pos_pair = (from_word.pos_tag, to_word.pos_tag)
    return test_pos_pair_dep_relations.get(pos_pair)


def _arc_label(head: Word, dependent: Word, test: bool) -> Optional[str]:
    """Label of the arc head -> dependent: looked up from the POS pair when
    `test` is set, otherwise read from the dependent's gold annotation."""
    if test:
        return get_test_dep(head, dependent)
    return dependent.dependency_label


def get_config(
    stack: List[Word],
    buffer: List[Word],
    seen_list: List[Word],
    arcs: List[Tuple[int, int]],
    test: bool = False
):
    """
    Returns the feature vector for the current parser configuration.

    Args:
        stack: The stack (top of the stack is the last element)
        buffer: The buffer (front of the buffer is the first element)
        seen_list: All words in the sentence that have been seen so far
            [sentence - buffer], in the order they left the buffer
        arcs: The arcs built so far as (head_index, dependent_index) tuples
        test: When True, dependency labels are looked up from POS-tag pairs
            via `get_test_dep` instead of being read from the words'
            `dependency_label` attribute

    Returns:
    (2 * V + 3 * P + 4 * R) dimensional vector
        - where P is the set of all POS Tags.
        - V is the set of all words.
        - R is the set of all dependency relations.

    All configurations are represented as 1 hot vectors concatenated together,
    in the order TOP, TOP.POS, TOP.DEP, TOP.LDEP, TOP.RDEP, FIRST, FIRST.POS,
    FIRST.LDEP, LOOK.POS. A feature that is unavailable is encoded as zeros.
    """

    top = stack[-1] if stack else None
    top_pos = top.pos_tag if top else None

    first = buffer[0] if buffer else None
    first_pos = first.pos_tag if first else None

    look_pos = buffer[1].pos_tag if len(buffer) > 1 else None

    top_dep: Optional[str] = None
    top_ldep: Optional[str] = None
    top_rdep: Optional[str] = None

    first_ldep: Optional[str] = None

    # Membership is tested once per seen word, so use a set instead of
    # rescanning the arc list every time.
    arc_set = set(arcs)

    if top:
        top_index = top.word_index

        # TOP.DEP: arc from the head of top(S), if that head has been seen
        head_of_top = next(
            (word for word in seen_list
             if (word.word_index, top_index) in arc_set),
            None,
        )
        if head_of_top is not None:
            top_dep = _arc_label(head_of_top, top, test)

        # TOP.LDEP / TOP.RDEP: first and last seen dependents of top(S)
        top_children = [word for word in seen_list
                        if (top_index, word.word_index) in arc_set]
        if top_children:
            top_ldep = _arc_label(top, top_children[0], test)
            top_rdep = _arc_label(top, top_children[-1], test)

    if first:
        first_index = first.word_index

        # FIRST.LDEP: first seen dependent of first(B)
        first_child = next(
            (word for word in seen_list
             if (first_index, word.word_index) in arc_set),
            None,
        )
        if first_child is not None:
            first_ldep = _arc_label(first, first_child, test)

    # Words are encoded through their normalised form.
    top_token = top.norm_word if top else None
    first_token = first.norm_word if first else None

    return np.concatenate([
        to_1_hot(top_token, all_words),
        to_1_hot(top_pos, all_pos_tags),
        to_1_hot(top_dep, all_dependency_labels),
        to_1_hot(top_ldep, all_dependency_labels),
        to_1_hot(top_rdep, all_dependency_labels),
        to_1_hot(first_token, all_words),
        to_1_hot(first_pos, all_pos_tags),
        to_1_hot(first_ldep, all_dependency_labels),
        to_1_hot(look_pos, all_pos_tags),
    ])

Oracle for the parser is as follows:

```text
1. If first(B) -> top(S) ∈ D and * -> top(S) ∉ A
    then LEFT-ARC
2. Else if top(S) -> first(B) ∈ D
    then RIGHT-ARC
3. Else if * -> top(S) ∈ A and there exists w ∈ S, w != top(S)
        such that (w -> first(B) ∈ D or first(B) -> w ∈ D)
    then REDUCE
4. Else SHIFT
```


In [None]:
def oracle(
    stack: List[Word],                      # The stack S
    buffer: List[Word],                     # The buffer B
    arcs: List[Tuple[int, int]],            # The arcs A
    gold_arcs: List[Tuple[int, int]],       # The full dependency tree D
):
    """
    Decide the next transition from the gold tree.

    Arcs are (head_index, dependent_index) tuples. The rules are tried in
    the order LEFT-ARC, RIGHT-ARC, REDUCE; SHIFT is returned when none of
    them applies, including when the stack or the buffer is empty.

    Returns:
        The id (a value of `actions`) of the chosen transition.
    """

    top = stack[-1] if stack else None
    first = buffer[0] if buffer else None

    # Without both a stack top and a buffer front only SHIFT is possible.
    if not (top and first):
        return actions['SHIFT']

    top_idx = top.word_index
    first_idx = first.word_index

    # True when some arc already built points at top(S).
    top_has_head = any(dep_idx == top_idx for _, dep_idx in arcs)

    # LEFT-ARC: first(B) heads top(S) in D and top(S) is still headless
    if (first_idx, top_idx) in gold_arcs and not top_has_head:
        return actions['LEFT-ARC']

    # RIGHT-ARC: top(S) heads first(B) in D
    if (top_idx, first_idx) in gold_arcs:
        return actions['RIGHT-ARC']

    # REDUCE: top(S) has its head and a deeper stack word is linked
    # to first(B) in D (in either direction)
    if top_has_head:
        for below in stack[:-1]:
            below_idx = below.word_index
            if ((below_idx, first_idx) in gold_arcs or
                    (first_idx, below_idx) in gold_arcs):
                return actions['REDUCE']

    return actions['SHIFT']

Learn the parser weights, then run inference to predict the dependency parse of each sentence

In [None]:
def learn_weights(
    train_loader: Loader,
    test_loader: Loader,
    epochs: int,
):
    """
    Train the transition classifier with mistake-driven perceptron updates.

    For every parser state the highest-scoring action is compared with the
    oracle action; on a mismatch the feature vector is subtracted from the
    predicted action's weights and added to the oracle action's weights.
    The parser state always advances with the oracle action.

    After each epoch the weights are evaluated with `get_test_preds` on
    `test_loader`; whenever the UAS improves on the best value so far
    (the first epoch always counts as an improvement) the weights are
    saved to 'dependency_model_on.npy'.

    Args:
        train_loader: Iterable of training sentences.
        test_loader: Iterable of sentences used for the per-epoch evaluation.
        epochs: Number of passes over the training sentences.

    Returns:
        The weights after the final epoch, an int64 array of shape
        (2 * V + 3 * P + 4 * R, number of actions). Note that this is the
        last-epoch model, not necessarily the best one saved to disk.
    """
    n_features = (2 * len(all_words) +
                  3 * len(all_pos_tags) +
                  4 * len(all_dependency_labels))

    # int64 on purpose: with int8 both the accumulated weights and the
    # int8 @ int8 scores silently wrap around at +/-127.
    weights = np.zeros((n_features, len(actions)), dtype=np.int64)

    # -inf so that the first epoch always writes a model file for this run.
    max_uas_score = float('-inf')

    for _ in trange(epochs, desc='Epochs'):
        total_preds = 0
        correct_preds = 0

        for sentence in tqdm(train_loader, desc='Training Set'):
            # train config
            stack: List[Word] = []
            buffer: List[Word] = sentence.copy().words
            seen_list: List[Word] = []
            arcs: List[Tuple[int, int]] = []

            gold_arcs = [(word.head_index, word.word_index)
                         for word in sentence]

            while buffer:
                config = get_config(stack=stack,
                                    buffer=buffer,
                                    seen_list=seen_list,
                                    arcs=arcs)

                oracle_action = oracle(stack=stack,
                                       buffer=buffer,
                                       arcs=arcs,
                                       gold_arcs=gold_arcs)

                if stack:
                    # score all actions at once; argmax keeps the first
                    # action among ties
                    scores = config.astype(np.int64) @ weights
                    action = int(np.argmax(scores))
                else:
                    # with an empty stack SHIFT is the only possible move
                    action = actions['SHIFT']

                total_preds += 1
                # update weights
                if action != oracle_action:
                    weights[:, action] -= config
                    weights[:, oracle_action] += config
                else:
                    correct_preds += 1

                # perform oracle action
                if oracle_action == actions['LEFT-ARC']:
                    arcs.append((buffer[0].word_index, stack[-1].word_index))
                    stack.pop()
                elif oracle_action == actions['RIGHT-ARC']:
                    arcs.append((stack[-1].word_index, buffer[0].word_index))
                    stack.append(buffer.pop(0))
                    seen_list.append(stack[-1])
                elif oracle_action == actions['REDUCE']:
                    stack.pop()
                elif oracle_action == actions['SHIFT']:
                    stack.append(buffer.pop(0))
                    seen_list.append(stack[-1])
            # END while buffer
        # END for sentence in tqdm(loader)

        # guard against an empty training set
        train_accuracy = correct_preds / total_preds if total_preds else 0.0

        print('Training Set Results:')
        print(f'{correct_preds = }')
        print(f'{total_preds = }')
        print(f'correct_preds / total_preds = {train_accuracy!r}\n')

        _, _, uas = get_test_preds(test_loader, weights)

        if uas > max_uas_score:
            max_uas_score = uas
            np.save('dependency_model_on.npy', weights)
            print(f'{max_uas_score = }\n\n')

    # END for _ in trange(epochs)

    return weights
# END learn_weights


def get_test_preds(
    loader: Loader,
    weights: np.ndarray,
    save: bool = False,
):
    """
    Parse every sentence of `loader` greedily and report the unlabeled
    attachment score (UAS).

    At each step the highest-scoring action is applied. A predicted
    LEFT-ARC or REDUCE is resolved by the state of top(S): if top(S) already
    has a head it is simply popped (REDUCE), otherwise the arc
    first(B) -> top(S) is added before popping (LEFT-ARC). With an empty
    stack the parser always shifts.

    Args:
        loader: Iterable of sentences to parse.
        weights: Weight matrix with one column per action.
        save: When True, the predicted arcs of this call are written to
            'dependency_predictions_on.tsv' as
            sent_id, dependent index + 1, dependent word, head index + 1.
            The file is overwritten, so re-running does not duplicate rows.

    Returns:
        (correct_arcs, total_arcs, uas) where uas is correct_arcs / total_arcs,
        or 0.0 when there are no gold arcs at all.
    """
    correct_arcs = 0
    total_arcs = 0
    prediction_rows: List[str] = []

    for sentence in tqdm(loader, desc='Testing Set'):
        stack: List[Word] = []
        buffer: List[Word] = sentence.copy().words
        seen_list: List[Word] = []
        arcs: List[Tuple[int, int]] = []

        gold_arcs = [(word.head_index, word.word_index)
                     for word in sentence]

        while buffer:
            if not stack:
                # nothing on the stack to attach or pop: SHIFT
                stack.append(buffer.pop(0))
                seen_list.append(stack[-1])
                continue

            config = get_config(stack=stack,
                                buffer=buffer,
                                seen_list=seen_list,
                                arcs=arcs,
                                test=True,)

            # Promote to int64 before the product so that int8 weights
            # (e.g. an older saved model) cannot overflow the scores;
            # argmax keeps the first action among ties.
            scores = config.astype(np.int64) @ weights
            action = int(np.argmax(scores))

            # perform action
            if action in (actions['LEFT-ARC'], actions['REDUCE']):
                # if head(top(S)) has been seen then REDUCE
                # else LEFT-ARC
                top_index = stack[-1].word_index
                if not any(dep_idx == top_index for _, dep_idx in arcs):
                    arcs.append((buffer[0].word_index, top_index))
                stack.pop()
            elif action == actions['RIGHT-ARC']:
                arcs.append((stack[-1].word_index,
                             buffer[0].word_index))
                stack.append(buffer.pop(0))
                seen_list.append(stack[-1])
            else:
                stack.append(buffer.pop(0))
                seen_list.append(stack[-1])
        # END while buffer

        total_arcs += len(gold_arcs)
        correct_arcs += sum(arc in arcs for arc in gold_arcs)

        if save:
            for head_idx, dep_idx in arcs:
                prediction_rows.append(
                    f'{sentence.sent_id}\t'
                    f'{dep_idx+1}\t'
                    f'{sentence[dep_idx].word}\t'
                    f'{head_idx+1}\n'
                )
    # END for sentence in tqdm(loader)

    if save:
        # single write in 'w' mode keeps the file in sync with this run
        with open('dependency_predictions_on.tsv', "w",
                  encoding='utf-8') as f:
            f.writelines(prediction_rows)

    # guard against an empty loader
    uas = correct_arcs / total_arcs if total_arcs else 0.0

    print('Testing Set Results:')
    print(f'{correct_arcs = }')
    print(f'{total_arcs = }')
    print(f'correct_arcs / total_arcs = {uas!r}\n')

    return correct_arcs, total_arcs, uas
# END get_test_preds

Train the model using the given training data

In [None]:
# Number of passes over the training sentences.
EPOCHS = 10

# Train; the held-out loader is evaluated by learn_weights after every epoch.
weights = learn_weights(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=EPOCHS,
)


In [None]:
# Reload the checkpoint written during training rather than reusing the
# last-epoch weights returned by learn_weights.
weights = np.load('dependency_model_on.npy')

# Evaluate the checkpoint on the test loader and save its predicted arcs.
correct_preds, total_preds, uas_score = get_test_preds(
    test_loader,
    weights,
    save=True,
)