### Bonus 2: Bi-LSTM-based Encoder


### 0. Import Necessary Libraries

In [3]:
import torch.nn as nn
import torch
from dep_utils import conll_reader, DependencyTree, DependencyEdge
import copy
from pprint import pprint
from collections import Counter, defaultdict
from typing import List, Dict, Tuple
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import os
import numpy as np
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 1. Read Data and Generate Training Instances

In [5]:
print('In train.conll:')
with open('data/train.conll') as f:
    train_trees = list(conll_reader(f))
print(f'{len(train_trees)} trees read.')

print('In dev.conll:')
with open('data/dev.conll') as f:
    dev_trees = list(conll_reader(f))
print(f'{len(dev_trees)} trees read.')

print('In test.conll:')
with open('data/test.conll') as f:
    test_trees = list(conll_reader(f))
print(f'{len(test_trees)} trees read.')

In train.conll:
39832 trees read.
In dev.conll:
1700 trees read.
In test.conll:
2416 trees read.


#### State Class

- The top of stack is `stack[-1]`
- The front of buffer is `buffer[-1]`
- `deps` represents the currently found dependencies
  - It is a list of `(parent, child, relation)` triples, where `parent` and `child` are integer IDs and `relation` is a string (the dependency label).
- The `shift` methods moves the front of the buffer to the top of the stack

In [6]:
class State(object):
    def __init__(self, sentence=[]):
        self.stack = []
        self.buffer = []
        if sentence:
            self.buffer = list(reversed(sentence))
        self.deps = set()

    def shift(self):
        assert len(self.buffer) > 0
        self.stack.append(self.buffer.pop())

    def left_arc(self, label):
        assert len(self.stack) >= 2
        self.deps.add((self.stack[-1], self.stack[-2], label))
        self.stack.pop(-2)

    def right_arc(self, label):
        assert len(self.stack) >= 2
        self.deps.add((self.stack[-2], self.stack[-1], label))
        self.stack.pop(-1)

    def __repr__(self):
        return "({},{},{})".format(self.stack, self.buffer, self.deps)

#### Get training data from a dependency tree

In [7]:
class RootDummy(object):
    def __init__(self):
        self.head = None
        self.id = 0
        self.deprel = None
    def __repr__(self):
        return "<ROOT>"


def get_training_instances(dep_tree: DependencyTree) -> List[Tuple[State, Tuple[str, str]]]:
    deprels = dep_tree.deprels

    word_ids = list(deprels.keys())
    state = State(word_ids)
    state.stack.append(0) # ROOT

    childcount = defaultdict(int)
    for _, rel in deprels.items():
        childcount[rel.head] += 1

    seq = []
    while len(state.buffer) > 0 or len(state.stack) > 1:
        if state.stack[-1] == 0:
            seq.append((copy.deepcopy(state), ("shift", None)))
            state.shift()
            continue
        
        stack_top1 = deprels[state.stack[-1]]
        if state.stack[-2] == 0:
            stack_top2 = RootDummy()
        else:
            stack_top2 = deprels[state.stack[-2]]

        # Decide transition action
        ### START YOUR CODE ###
        try:
            if stack_top2.head == stack_top1.id : # Left-Arc, top1 -> top2
                childcount[stack_top1.id] -= 1
                seq.append((copy.deepcopy(state), ("left_arc", stack_top2.deprel)))
                state.left_arc(stack_top2.deprel)
            elif stack_top1.head == stack_top2.id and childcount[stack_top1.id] == 0: # Right-Arc, top2 -> top1
                childcount[stack_top2.id] -= 1
                seq.append((copy.deepcopy(state), ("right_arc", stack_top1.deprel)))
                state.right_arc(stack_top1.deprel)
            else: # Shift
                seq.append((copy.deepcopy(state), ("shift", None)))
                state.shift()
        except:
            return seq
        ### END YOUR CODE ###
    
    return seq

#### Build vocabulary


In [8]:
word2id = {}
pos2id = {}
word2id['<PAD>'] = 0
pos2id['<PAD>'] = 0
def get_vocabs(trees: List[DependencyTree]):
    for tree in trees:
        word = tree.words()
        pos = tree.pos()
        for w in word:
            if w is None:
                continue
            if w not in word2id:
                word2id[w] = len(word2id)
        for p in pos:
            if p is None:
                continue
            if p not in pos2id:
                pos2id[p] = len(pos2id)

In [9]:
get_vocabs(train_trees)
get_vocabs(dev_trees)
get_vocabs(test_trees)

word2id['<NULL>'] = len(word2id)
pos2id['<NULL>'] = len(pos2id)
word2id['<ROOT>'] = len(word2id)
pos2id['<ROOT>'] = len(pos2id)


print(f'word_vocab: {len(word2id)} words')
print(f'pos_vocab: {len(pos2id)} pos tags')

word_vocab: 46351 words
pos_vocab: 48 pos tags


##### Action Vocabulary

In [10]:
rel_vocab = {}

for t in train_trees+dev_trees+test_trees:
    for e in t.deprels.values():
        if e.deprel not in rel_vocab:
            rel_vocab[e.deprel] = len(rel_vocab)

# Test results
print('Total number fo unique relations:', len(rel_vocab))
print(rel_vocab.keys())

# You should expect to see the following output:
# Total number fo unique relations: 39
# {'nummod', 'root', 'nmod:tmod', 'nmod', 'punct', 'expl', 'auxpass', 'neg', 'nsubjpass', 'appos' ...

Total number fo unique relations: 39
dict_keys(['case', 'det', 'compound', 'nummod', 'nmod', 'punct', 'nmod:poss', 'amod', 'nsubj', 'dep', 'dobj', 'cc', 'conj', 'nsubjpass', 'acl', 'auxpass', 'advmod', 'root', 'ccomp', 'mark', 'xcomp', 'nmod:tmod', 'appos', 'nmod:npmod', 'aux', 'cop', 'neg', 'acl:relcl', 'advcl', 'mwe', 'det:predet', 'csubj', 'parataxis', 'compound:prt', 'iobj', 'expl', 'cc:preconj', 'discourse', 'csubjpass'])


In [11]:
# action vocab
action2id = {}
action2id[('shift',None)] = len(action2id)
for rel in rel_vocab.keys():
    if rel != 'root':
        action2id[("left_arc", rel)] = len(action2id)
        action2id[("right_arc", rel)] = len(action2id)
action2id[("right_arc", 'root')] = len(action2id)

In [12]:
len(action2id) # (39-1)*2 + 1(right_arc, root) + 1(shift, none) = 78

78

- For actual training step, you need to post-process the data to convert each relation tuple to an integer index. 
- We have 39 unique dependency relations in the data, including `ROOT`. Considering `ROOT` only appears as the head in a `right_arc` action, we have $(39-1)\times 2 + 1 = 77$ possible actions in total.

#### Reference

- https://aclanthology.org/Q16-1023/
- https://github.com/s-kill/Simple-and-Accurate-Dependency-Parsing-Using-Bidirectional-LSTM-Feature-Representations

#### BiLSTM Feature Extractor
- $x_i = e(w_i)⨁e(t_i)$
- $v_i = BiLSTM(x_i, i)$
- $input = v_{s_2}⨁v_{s_1}⨁v_{s_0}⨁v_{b_0}$


In [13]:
class LSTMFeatureExtractor():
    def __init__(self):
        print('LSTM FeatureExtractor')

    def get_input_representation(self, state):
        features = []
        # Index of vs2, vs1, vs0, bs0 in the sequence
        for s in range(-3, 0): # top 3 words on the stack
            if abs(s) <= len(state.stack):
                sw_id = state.stack[s]
                features.append(sw_id) # 0 is ROOT
            else:
                features.append(-1)

        b = -1 # top 1 word on the buffer
        if abs(b) <= len(state.buffer):
            bw_id = state.buffer[b]
            features.append(bw_id)
        else:
            features.append(-1)

        return torch.LongTensor(features).to(device)

    def get_output_representation(self, action):
        return torch.tensor(action2id[action], dtype=torch.long).to(device)


In [14]:
def process(dep_trees: List[DependencyTree],extractor: LSTMFeatureExtractor):
    feats = []
    outputs = []
    words = []
    pos = []
    feats = []
    for i, tree in enumerate(dep_trees):
        tree_words = tree.words()
        tree_pos = tree.pos()
        instances = get_training_instances(tree)
        if i % 1000 == 0:
            print(f'{i}/{len(dep_trees)}')
        # words and pos inputs for bilstm
        word_ids = []
        pos_ids = []
        for w in tree_words:
            if w is None: 
                # here the first element None is not in the inputs, when getting the hidden state h_i from BiLSTM
                # given an index i, use (i-1) since i is relative index in the word sequence.
                continue
            word_ids.append(word2id[w])
        for p in tree_pos:
            if p is None:
                continue
            pos_ids.append(pos2id[p])
        word_ids = torch.LongTensor(word_ids).to(device)
        pos_ids = torch.LongTensor(pos_ids).to(device)
        for state, action in instances:
            # convert to torch tensor
            words.append(word_ids) # variable length
            pos.append(pos_ids) # variable length
            feats.append(extractor.get_input_representation(state)) # fixed length
            outputs.append(extractor.get_output_representation(action)) # fixed length

    return words, pos, torch.stack(feats).to(device), torch.stack(outputs).to(device)

In [15]:
# Test the FeatureExtractor
w,p,f,l = process(train_trees[:2],LSTMFeatureExtractor())
w[0], p[0], f[0], l[0]

LSTM FeatureExtractor
0/2


(tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,  7, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 10, 26, 27, 28, 29, 25, 20, 30,  6, 31,
         25, 32, 33, 34, 35, 25, 36, 37, 38, 39, 40, 41, 42], device='cuda:0'),
 tensor([ 1,  2,  3,  4,  5,  1,  6,  2,  5,  7,  1,  3,  8,  3,  3,  9,  6, 10,
         11, 12,  2,  5,  1,  3,  3, 13,  7,  5, 14, 11, 15, 13,  2,  5,  1,  3,
         13, 10,  1,  3,  3, 13, 16, 17, 10, 18,  3,  3, 19], device='cuda:0'),
 tensor([-1, -1,  0,  1], device='cuda:0'),
 tensor(0, device='cuda:0'))

#### BiLSTM Oracle



In [30]:
word_dim = len(word2id)
pos_dim = len(pos2id)
word_emb_dim = 50
pos_emb_dim = 10
feature_len = 4
out_dim = len(action2id)
emb_dim = 50
lstm_hidden_dim = 30
mlp_hidden_dim = 100
word_dim, out_dim
# word_emb + pos_emb = 2 * lstm_hidden_dim

(46351, 78)

In [31]:
class BiLSTMOracle(nn.Module):
    def __init__(
            self,
            feature_len: int,
            word_dim: int,
            pos_dim: int,
            word_emb_dim: int,
            pos_emb_dim: int,
            out_dim: int,
            lstm_hidden_dim=50,
            mlp_hidden_dim=100
            # word_emb_dim + pos_emb_dim = 2 * lstm_hidden_dim, 因为feature要拼起来
            ):
        assert word_emb_dim+pos_emb_dim == 2*lstm_hidden_dim
        super(BiLSTMOracle, self).__init__()
        self.word_embedding = nn.Embedding(num_embeddings=word_dim, embedding_dim=word_emb_dim)
        self.pos_embedding = nn.Embedding(num_embeddings=pos_dim, embedding_dim=pos_emb_dim)
        self.bilstm = nn.LSTM(word_emb_dim + pos_emb_dim, lstm_hidden_dim, batch_first=True, bidirectional = True)
        # output: [batch_size, seq_len, lstm_hidden_dim]
        self.mlp = nn.Sequential(
            # nn.Flatten(),
            # bilstm output is: 2 (bidirection) * (word_emb_dim + pos_emb_dim) * 4 (3 on stack, 1 on buffer)
            nn.Linear(feature_len*(word_emb_dim + pos_emb_dim), mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, out_dim),
        )
        # softmax layer is calculated outside
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.word_embedding.weight.data.uniform_(-initrange, initrange)
        self.pos_embedding.weight.data.uniform_(-initrange, initrange)
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                layer.weight.data.uniform_(-initrange, initrange)
                layer.bias.data.zero_()

    def forward(self, word_ids, pos_ids, features): # word_ids: [batch_size, seq_len], pos_ids: [batch_size, seq_len]
            word_emb = self.word_embedding(word_ids) # [batch_size, seq_len, 50]
            pos_emb = self.pos_embedding(pos_ids) # [batch_size, seq_len, 10]
            
            emb = torch.cat((word_emb, pos_emb), dim=2) # [batch_size, seq_len, 50 + 10]
            lstm_out, _ = self.bilstm(emb) # [batch_size, seq_len, 2 * lstm_hidden_dim]
            mlp_input = [] # [batch_size, 4, lstm_hidden_dim * 2]
            for i, row in enumerate(features):
                row_input = []
                for j, feat_idx in enumerate(row):
                    if feat_idx == -1:  # NULL
                        row_input.append(torch.cat((self.word_embedding(torch.tensor(word2id['<NULL>']).to(device)), self.pos_embedding(torch.tensor(pos2id['<NULL>']).to(device))), dim=-1))
                    elif feat_idx == 0:  # ROOT
                        row_input.append(torch.cat((self.word_embedding(torch.tensor(word2id['<ROOT>']).to(device)), self.pos_embedding(torch.tensor(pos2id['<ROOT>']).to(device))), dim=-1))
                    else:  # relative index in word sequence
                        # the first None is not in the inputs, there is 1 offet
                        # print(lstm_out[:, feat_idx - 1, :].shape) # [batch, lstm_hidden_dim]
                        # print(lstm_out[:, feat_idx - 1, :][i].shape) # [lstm hidden dim]
                        # hidden state at timestep feat_idx
                        row_input.append(lstm_out[:, feat_idx - 1, :][i])
                        # [batch_size, 2 * lstm_hidden_dim]
                mlp_input.append(torch.stack(row_input))
            
            mlp_input = torch.stack(mlp_input, dim=0) # [batch_size, 4, lstm_hidden_dim * 2]
            mlp_input = mlp_input.squeeze(1)
            mlp_input = mlp_input.view(-1,mlp_input.size(1)*mlp_input.size(2)) # 从二维特征调整为一维 [batch, 4 * 2 * lstm_hidden_dim]
            x = self.mlp(mlp_input) # [batch, num_class]
            return x

#### Process data 

In [17]:
train_words, train_pos, train_feats, train_label = process(train_trees[:1000], LSTMFeatureExtractor())
# this could take a while, using [:1000] as demonstration

LSTM FeatureExtractor
0/1000


### 3. Train and Evaluate

In [18]:
class Parser(object): 

    def __init__(self, model: BiLSTMOracle, extractor: LSTMFeatureExtractor):
        self.model = model
        self.extractor = extractor
        self.id2action = {v: k for k, v in action2id.items()}

    def parse_sentence(self, words, pos):
        state = State(range(1, len(words)))
        state.stack.append(0) # ROOT
        word_ids = []
        for w in words:
            if w is None:
                continue
            word_ids.append(word2id[w])
        word_ids = torch.LongTensor(word_ids).to(device)
        pos_ids = []
        for p in pos:
            if p is None:
                continue
            pos_ids.append(pos2id[p])
        pos_ids = torch.LongTensor(pos_ids).to(device)
        word_ids = word_ids.unsqueeze(0)
        pos_ids = pos_ids.unsqueeze(0)
        while len(state.buffer) > 0 or len(state.stack) > 1:
            feats = self.extractor.get_input_representation(state)
            feats = feats.unsqueeze(0)

            model_out = self.model.forward(word_ids, pos_ids, feats)
            probs = torch.softmax(model_out, dim=1)
            sorted_indices = torch.argsort(probs, dim=1, descending=True)
            sorted_indices = torch.squeeze(sorted_indices)
            for i in range(0, len(sorted_indices)): # might have illegal actions
                move, rel = self.id2action[sorted_indices[i].item()]
                if move == 'shift' and len(state.buffer) > 0:
                    state.shift()
                    break
                elif len(state.stack) >= 2:
                    if move == 'left_arc' and state.stack[-2] != 0 and rel != 'root':
                        state.left_arc(rel)
                        break
                    if move == 'right_arc':
                        state.right_arc(rel)
                        break

        result = DependencyTree()
        for h, c, r in state.deps: # head, child(dependent), relation
            result.add_deprel(DependencyEdge(c, words[c], pos[c], h, r))
        return result 
    
    # compare the predicted tree with the reference tree
    def compare_tree(self, ref_tree: DependencyTree, prediction: DependencyTree):
        # unlabeled does not care about the relation
        target_unlabeled = set((d.id,d.head) for d in ref_tree.deprels.values())
        target_labeled = set((d.id,d.head,d.deprel) for d in ref_tree.deprels.values())
        predict_unlabeled = set((d.id,d.head) for d in prediction.deprels.values())
        predict_labeled = set((d.id,d.head,d.deprel) for d in prediction.deprels.values())

        labeled_correct = len(predict_labeled.intersection(target_labeled))
        unlabeled_correct = len(predict_unlabeled.intersection(target_unlabeled))
        num_words = len(predict_labeled)
        return labeled_correct, unlabeled_correct, num_words 
        

In [19]:
def evaluate(dep_trees: List[DependencyTree], parser: Parser):
    total_labeled_correct = 0
    total_unlabeled_correct = 0
    total_words = 0
    count = 0 
    print("Evaluating.")
    for dtree in dep_trees:
        words = dtree.words()
        pos = dtree.pos()
        prediction = parser.parse_sentence(words, pos)
        labeled_correct, unlabeled_correct, num_words = parser.compare_tree(dtree, prediction)
        total_labeled_correct += labeled_correct
        total_unlabeled_correct += unlabeled_correct
        total_words += num_words
        count += 1 
        if count % 200 == 0:
            print(f'{count}/{len(dep_trees)}')

    las = total_labeled_correct / float(total_words)
    uas = total_unlabeled_correct / float(total_words)

    print(f"{len(dep_trees)} sentences.\n")
    print(f"Labeled Attachment Score: {las}\n")
    print(f"Unlabeled Attachment Score: {uas}")

In [21]:
batch_size = 256
epochs = 1
learning_rate = 0.001

In [22]:
def batchify(words, pos, feats, label, batch_size):
    for i in range(0, len(words), batch_size):
        batch_words = words[i:i+batch_size]
        batch_pos = pos[i:i+batch_size]
        batch_words = pad_sequence(batch_words, batch_first=True).to(device)
        batch_pos = pad_sequence(batch_pos, batch_first=True).to(device)
        
        yield batch_words, batch_pos, feats[i:i+batch_size], label[i:i+batch_size]
        

In [23]:
batches = list(batchify(train_words, train_pos, train_feats, train_label, batch_size))

In [24]:
len(batches), len(batches[0][0])

(191, 256)

In [35]:
# if there is cuNN error, rerun
model = BiLSTMOracle(feature_len=feature_len, word_dim=word_dim, pos_dim=pos_dim, word_emb_dim=word_emb_dim, pos_emb_dim=pos_emb_dim, out_dim=out_dim, lstm_hidden_dim=lstm_hidden_dim, mlp_hidden_dim=mlp_hidden_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [36]:
def train(
        model: BiLSTMOracle,
        optimizer: torch.optim.Optimizer,
        loss_function,
        # train_dataloader: DataLoader,
        batches: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        log_interval=500,
        epochs: int = 2):
    
    model.train()
    dev_parser = Parser(model, LSTMFeatureExtractor())
    for epoch in range(epochs):
        total_loss = 0
        epoch_loss = 0
        # for id, (w, p ,f, batch_label) in enumerate(tqdm(train_dataloader)):
        for id, (w, p ,f, batch_label) in enumerate(tqdm(batches)):
            optimizer.zero_grad()
            output = model.forward(w,p,f)
            output = output.cpu()
            batch_label = batch_label.cpu()
            # output:[batch_size, num_classes]
            # label[batch_size]
            loss = loss_function(output, batch_label)
            total_loss += loss.item()
            epoch_loss += loss.item()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.2) # 防止梯度爆炸
            optimizer.step()

            if id % log_interval == 0 and id > 0:
                print(
                    "| epoch {:3d} | {:5d}/{:5d} batches "
                    "| loss {:8.4f}".format(
                        epoch, id, len(batches), loss
                    )
                )
                total_loss = 0

        print(f'Epoch {epoch}, loss: {epoch_loss/len(batches)}')
        print('--'*20)
        evaluate(dev_trees[:50], dev_parser)
        print('--'*20)
        epoch_loss = 0

In [37]:
train(model, optimizer, criterion, batches, epochs=epochs, log_interval=50)
torch.save(model, 'dep_model_lstm_demo.pt')
# this is just for training demonstration

LSTM FeatureExtractor


 27%|██▋       | 51/191 [00:14<00:38,  3.60it/s]

| epoch   0 |    50/  191 batches | loss   1.9169


 53%|█████▎    | 101/191 [00:28<00:21,  4.27it/s]

| epoch   0 |   100/  191 batches | loss   1.5207


 79%|███████▉  | 151/191 [00:40<00:11,  3.63it/s]

| epoch   0 |   150/  191 batches | loss   1.0705


100%|██████████| 191/191 [00:51<00:00,  3.70it/s]


Epoch 0, loss: 1.8020762579603344
----------------------------------------
Evaluating.
50 sentences.

Labeled Attachment Score: 0.2927659574468085

Unlabeled Attachment Score: 0.3727659574468085
----------------------------------------


#### Evaluation
The final model is trained using:
- `batch_size` = 512
- `epochs` = 2
- `learning_rate` = 0.001
- `word_emb_dim` = 50
- `pos_emb_dim` = 10
- `lstm_hidden_dim` = 30
- `mlp_hidden_dim` = 100

Note: word_emb + pos_emb = 2 * lstm_hidden_dim

which takes around 1 h:)


In [39]:
test_model = torch.load('dep_model_lstm_final.pt')
parser = Parser(test_model, LSTMFeatureExtractor())
evaluate(test_trees, parser)
# this takes around 3 mins.

LSTM FeatureExtractor
Evaluating.
200/2416
400/2416
600/2416
800/2416
1000/2416
1200/2416
1400/2416
1600/2416
1800/2416
2000/2416
2200/2416
2400/2416
2416 sentences.

Labeled Attachment Score: 0.8461294192364689

Unlabeled Attachment Score: 0.8716921882718227


Note: The LAS for baseline model is `0.733`

In [40]:
test_model = torch.load('dep_model_lstm_final.pt')
parser = Parser(test_model, LSTMFeatureExtractor())
evaluate(test_trees[:100], parser)
# for demonstration

LSTM FeatureExtractor
Evaluating.
100 sentences.

Labeled Attachment Score: 0.8252933507170795

Unlabeled Attachment Score: 0.8609300304215558
