## CS310 Natural Language Processing
## Assignment 4. Dependency Parsing

**Total points**: 50

In this assignment, you will train feed-forward neural network-based dependency parser and evaluate its performance on the provided treebank dataset.

### 0. Import Necessary Libraries

In [17]:
import torch
import torch.nn.functional as F
import torch.nn as nn

from dep_utils import conll_reader, DependencyTree
import copy
from pprint import pprint
from collections import Counter, defaultdict
from typing import List, Dict, Tuple

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import numpy as np

### 1. Read Data and Generate Training Instances

In [18]:
print('In train.conll:')
with open('valid_data/valid_train.conll') as f:
    train_trees = list(conll_reader(f))
print(f'{len(train_trees)} trees read.')

print('In dev.conll:')
with open('valid_data/valid_dev.conll') as f:
    dev_trees = list(conll_reader(f))
print(f'{len(dev_trees)} trees read.')

print('In test.conll:')
with open('valid_data/valid_test.conll') as f:
    test_trees = list(conll_reader(f))
print(f'{len(test_trees)} trees read.')


In train.conll:
39712 trees read.
In dev.conll:
1695 trees read.
In test.conll:
2408 trees read.


In [19]:
class RootDummy(object):
    def __init__(self):
        self.head = None
        self.id = 0
        self.deprel = '<ROOT_rel>'
        self.word='<ROOT>'
        self.pos='<ROOT_POS>'
    def __repr__(self):
        return "<ROOT>"

class NullDummy(object):
    def __init__(self):
        self.head = None
        self.id = -1
        self.deprel = '<NULL_rel>'
        self.word='<NULL>'
        self.pos='<NULL_POS>'
    def __repr__(self):
        return "<NULL>"

In [20]:
class State(object):
    def __init__(self, sentence=[]):
        self.stack = []
        self.buffer = []
        if sentence:
            self.buffer = list(reversed(sentence))
        self.deps = set()

    def shift(self):
        ### START YOUR CODE ###
        assert self.buffer
        self.stack.append(self.buffer.pop())
        ### END YOUR CODE ###

    def left_arc(self, label: str):
        assert len(self.stack) >= 2
        ### START YOUR CODE ###
        dependent = self.stack.pop(-2)
        head = self.stack[-1]
        self.deps.add((head, dependent, label))
        ### END YOUR CODE ###

    def right_arc(self, label: str):
        assert len(self.stack) >= 2
        ### START YOUR CODE ###
        dependent = self.stack.pop()
        head = self.stack[-1]
        self.deps.add((head, dependent, label))
        ### END YOUR CODE ###

    def __repr__(self):
        return "({},{},{})".format(self.stack, self.buffer, self.deps)

In [21]:
def build_vocab():
    word_vocab = {}
    pos_vocab = {}
    rel_vocab={}
    for trees in [train_trees, dev_trees, test_trees]:
        for tree in trees:
            # Assuming tokens is a list of words in the DependencyTree object
            for deprel in tree.deprels.values():
                word= deprel.word
                pos= deprel.pos
                rel= deprel.deprel
                if word not in word_vocab:
                    word_vocab[word] = len(word_vocab) + 1  # Assign a unique ID to each word
                if pos not in pos_vocab:
                    pos_vocab[pos] = len(pos_vocab) + 1  # Assign a unique ID to each word
                if rel not in rel_vocab:
                    rel_vocab[rel] = len(rel_vocab) + 1  # Assign a unique ID to each word
        # Add special tokens <ROOT> and <NULL>
    print(len(word_vocab))
    word_vocab["<NULL>"] = len(word_vocab) + 1
    pos_vocab['<NULL_POS>']=len(pos_vocab) + 1
    rel_vocab['<NULL_rel>']=len(rel_vocab) + 1
    word_vocab['<ROOT>']=0
    pos_vocab['<ROOT_POS>']=0
    rel_vocab['<ROOT_rel>']=0
    action_vocab={}
    ### START YOUR CODE ###
    for trees in [train_trees, dev_trees, test_trees]:
        for deps in trees:
            for deprel in deps.deprels.values():
                if deprel.deprel !='root':
                    key1='L<'+deprel.deprel+'>'
                    key2='R<'+deprel.deprel+'>'
                    if key1 not in action_vocab:
                        action_vocab[key1]=len(action_vocab)+1
                    if key2 not in action_vocab:
                        action_vocab[key2]=len(action_vocab)+1
    action_vocab['shift']=len(action_vocab)+1
    action_vocab["R<root>"]=0
    return word_vocab,pos_vocab,rel_vocab ,action_vocab

word_vocab, pos_vocab,rel_vocab,action_vocab = build_vocab()

46266


In [22]:
def extract_features(state, deprels,word_vocab, pos_vocab,rel_vocab):
    # Initialize feature vectors for words and POS tags
    word_feature = []
    pos_feature = []
    rel_feature= []
    # Get the top three words on stack and buffer
    
    stack_words_id = state.stack[-3:] if len(state.stack) >= 3 else state.stack + [-1] * (3 - len(state.stack))
    buffer_words_id = state.buffer[-3:] if len(state.buffer) >= 3 else state.buffer + [-1] * (3 - len(state.buffer))

    stack_word=[]
    for word_id in stack_words_id:
        if word_id==-1:
            stack_word.append('<NULL>')
        elif word_id==0:
            stack_word.append('<ROOT>')
        else:
            stack_word.append(deprels[word_id].word)

    buffer_word=[]
    for word_id in buffer_words_id:
        if word_id==-1:
            buffer_word.append("<NULL>")
        elif word_id==0:
            buffer_word.append("<ROOT>")
        else:
            buffer_word.append(deprels[word_id].word)

    stack_pos=[]
    for word_id in stack_words_id:
        if word_id==-1:
            stack_pos.append("<NULL_POS>")
        elif word_id==0:
            stack_pos.append("<ROOT_POS>")
        else:
            stack_pos.append(deprels[word_id].pos)

    buffer_pos=[]
    for word_id in buffer_words_id:
        if word_id==-1:
            buffer_pos.append("<NULL_POS>")
        elif word_id==0:
            buffer_pos.append("<ROOT_POS>")
        else:
            buffer_pos.append(deprels[word_id].pos)

    stack_rel=[]
    for word_id in stack_words_id:
        if word_id==-1:
            stack_rel.append("<NULL_rel>")
        elif word_id==0:
            stack_rel.append("<ROOT_rel>")
        else:
            stack_rel.append(deprels[word_id].deprel)

    buffer_rel=[]
    for word_id in buffer_words_id:
        if word_id==-1:
            buffer_rel.append("<NULL_rel>")
        elif word_id==0:
            buffer_rel.append("<ROOT_rel>")
        else:
            buffer_rel.append(deprels[word_id].deprel)

    
    for word in stack_word:
        word_feature.append(word_vocab[word])
    for word in buffer_word:
        word_feature.append(word_vocab[word])
    for pos in stack_pos:
        pos_feature.append(pos_vocab[pos])
    for pos in buffer_pos:
        pos_feature.append(pos_vocab[pos])
    for rel in stack_rel:
        rel_feature.append(rel_vocab[rel])
    for rel in buffer_rel:
        rel_feature.append(rel_vocab[rel])

    return word_feature + pos_feature +rel_feature

In [23]:
def get_training_data(dep_tree,word_vocab, pos_vocab,rel_vocab,action_vocab):
    deprels = dep_tree.deprels
    word_ids = list(deprels.keys())
    state = State(word_ids)
    state.stack.append(0) # ROOT

    childcount = defaultdict(int)
    for _, rel in deprels.items():
        childcount[rel.head] += 1

    deprels_list=[]

    features = []
    labels = []
    
    transition_UAS=[]
    transition_LAS=[]

    for _, rel in deprels.items():
        deprels_list.append((rel.head,rel.id))

    while len(state.buffer) > 0 or len(state.stack) > 1:
        
        # start your code here
        feature= extract_features(state, deprels,word_vocab, pos_vocab,rel_vocab)
        # label=action_vocab["R("+ stack_top1.deprel +")"]
        features.append(feature)        
        # end your code here
                
        if state.stack[-1] == 0:
            state.shift()
            key='shift'
            index=action_vocab[key]
            label=[0]*len(action_vocab)
            label[index]=1
            labels.append(label)
            continue
        
        stack_top1 = deprels[state.stack[-1]]
        if state.stack[-2] == 0:
            stack_top2 = RootDummy()
        else:
            stack_top2 = deprels[state.stack[-2]]
        
        if (stack_top1.id, stack_top2.id) in deprels_list:
            deprels_list.remove((stack_top1.id, stack_top2.id))
            state.left_arc(stack_top2.deprel)
            childcount[stack_top1.id] -= 1
            
            key= 'L<'+stack_top2.deprel+'>'
            index=action_vocab[key]

            transition_UAS.append([stack_top1.word,stack_top2.word,'left_arc'])
            transition_LAS.append([stack_top1.word,stack_top2.word,'left_arc',stack_top2.deprel])

        elif (stack_top2.id, stack_top1.id) in deprels_list and childcount[stack_top1.id] == 0:
            deprels_list.remove((stack_top2.id, stack_top1.id))
            state.right_arc(stack_top1.deprel)
            childcount[stack_top2.id] -= 1

            key='R<'+stack_top1.deprel +'>'
            index=action_vocab[key]

            transition_UAS.append([stack_top2.word,stack_top1.word,'right_arc'])
            transition_LAS.append([stack_top2.word,stack_top1.word,'right_arc',stack_top1.deprel])

        else:
            state.shift()
            key='shift'
            index=action_vocab[key]

        label=[0]*len(action_vocab)
        label[index]=1
        labels.append(label)


    return transition_LAS,transition_UAS,features,labels

In [24]:
word_vocab, pos_vocab,rel_vocab,action_vocab=build_vocab()
def generate_train_data(trees):
    features_list=[]
    labels_list=[]
    for i,tree in enumerate(trees):
        _,_,features,labels=get_training_data(tree,word_vocab, pos_vocab,rel_vocab,action_vocab)
        for feature in features:
            features_list.append(feature)
        for label in labels:
            labels_list.append(label)
    return features_list,labels_list

46266


In [25]:
train_data,train_label=generate_train_data(train_trees)


### 2. Build the Model

In [36]:
class Parser(nn.Module):
    def __init__(self, word_size,pos_size,rel_size,input_size,emb_size, hidden_size, output_size):
        super(Parser, self).__init__()
        self.input_size=input_size
        self.emb_layer1 = nn.Embedding(word_size, emb_size)
        self.emb_layer2 = nn.Embedding(pos_size, emb_size)
        self.emb_layer3 = nn.Embedding(rel_size, emb_size) 
        self.hidden_layer = nn.Linear(emb_size*input_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
    
    def forward(self, X): 
        emb_word_output = self.emb_layer1(X[:,int(0):int(self.input_size/3)])
        emb_pos_output = self.emb_layer2(X[:,int(self.input_size/3):int(self.input_size*2/3)])
        emb_rel_output = self.emb_layer3(X[:,int(self.input_size*2/3):int(self.input_size)])
        emb_output=torch.cat((emb_word_output,emb_pos_output,emb_rel_output),dim=1)
        emb_output=emb_output.view(emb_output.size(0),-1)
        hidden_output = F.relu(self.hidden_layer(emb_output))
        output=self.output_layer(hidden_output)
        # score=F.softmax(output,dim=-1)  后边交叉熵损失集成了，这里就不用计算了
        return output
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
word_size=len(word_vocab)
pos_size=len(pos_vocab)
rel_size=len(rel_vocab)
emb_size=50
hidden_size=200
output_size=len(action_vocab) # 78
batch_size = 1024
learning_rate = 0.001
num_epochs=6
input_size=18
model=Parser(word_size,pos_size,rel_size,input_size,emb_size, hidden_size,output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate)


### 3. Train and Evaluate

In [37]:
class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = torch.tensor(data, dtype=torch.long)
        self.targets = torch.tensor(targets, dtype=torch.long)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index], self.targets[index]



In [38]:
train_dataset = CustomDataset(train_data, train_label)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

In [39]:
# 推理的时候要检查
def is_legal_transition(action,stack,buffer):
    if(action==-1):
        if(len(stack)<=2):
            return False
    if(action==0):
        if(len(buffer)<=0):
            return False
    if(action==1):
        if(len(stack)<2):
            return False
    return True

In [40]:
def parse_sentence(model, tree):
    deprels = tree.deprels
    words = list(reversed(deprels.values()))
    
    # poses=[word.pos for word in sentence]

    # word_ids = [word_vocab[key] for key in words]
    # pos_ids = [pos_vocab[key] for key in poses]
    
    stack = [RootDummy()]  # Initialize the stack with ROOT symbol
    buffer = words # Initialize the buffer with input words
    
    transition_LAS = []  # Store transitions applied
    transition_UAS = []  # Store transitions applied

    stack_top3=[]
    buffer_top3=[]
    while len(buffer) > 0 or len(stack) > 1:
        stack_top3 = stack[-3:] if len(stack) >= 3 else stack + [NullDummy()] * (3 - len(stack))
        buffer_top3 = buffer[-3:] if len(buffer) >= 3 else buffer + [NullDummy()] * (3 - len(buffer))
        
        stack_word_id=[word_vocab[key.word] for key in stack_top3]
        buffer_word_id=[word_vocab[key.word] for key in buffer_top3]    
        stack_pos_id=[pos_vocab[key.pos] for key in stack_top3]
        buffer_pos_id=[pos_vocab[key.pos] for key in buffer_top3]
        stack_rel_id=[rel_vocab[key.deprel] for key in stack_top3]
        buffer_rel_id=[rel_vocab[key.deprel] for key in buffer_top3]
                
        input_feature=stack_word_id+buffer_word_id+stack_pos_id+buffer_pos_id+stack_rel_id+buffer_rel_id
        input_feature=torch.Tensor(input_feature).to(device).to(torch.long)
        input_feature=torch.unsqueeze(input_feature, 0)
        
        output= model(input_feature)
        output=F.softmax(output,dim=-1)[0]
        has_legal=False
        reversed_action_vocab = {value: key for key, value in action_vocab.items()}


        for action in action_vocab: 
            if action[0]=='L':
                if(is_legal_transition(-1,stack,buffer)):
                    has_legal=True
                else:
                    output[action_vocab[action]]=-1

            if action[0]=='R':
                if(is_legal_transition(1,stack,buffer)):
                    has_legal=True
                else:
                    output[action_vocab[action]]=-1

            if action=='shift':
                if(is_legal_transition(0,stack,buffer)):
                    has_legal=True
                else:
                    output[action_vocab[action]]=-1

        if not has_legal:
            break
        index=np.argmax(np.array(output.tolist()))
        action= reversed_action_vocab[index]
        
        action_direction=0
        rel=None
        if action[0]=='L':
            action_direction=-1
            rel=action[2:-1]
        if action[0]=='R':
            action_direction=1
            rel=action[2:-1]

        
        if(action_direction==-1):
            stack_top1=stack[-1]
            stack_top2=stack[-2]
            transition_LAS.append([stack_top1.word,stack_top2.word ,'left_arc',rel])
            transition_UAS.append([stack_top1.word,stack_top2.word ,'left_arc'])
            stack.pop(-2)
        elif(action_direction==0):
            stack.append(buffer.pop())
        elif(action_direction==1):
            stack_top1=stack[-1]
            stack_top2=stack[-2]
            transition_LAS.append([stack_top2.word,stack_top1.word ,"right_arc",rel])
            transition_UAS.append([stack_top2.word,stack_top1.word ,"right_arc"])
            stack.pop()     
    return transition_LAS,transition_UAS

In [41]:
def evaluate(predict_transition,true_transition):
    # print(predict_transition)
    # print(true_transition)
    predict_transition_tuples = [tuple(item) for item in predict_transition]
    true_transition_tuples = [tuple(item) for item in true_transition]
    set_predict_transition = set(predict_transition_tuples)
    set_true_transition = set(true_transition_tuples)

    # 统计集合的交集大小
    right = len(set_predict_transition.intersection(set_true_transition))
    return right

In [42]:
def train(train_loader):
    for epoch in range(1,num_epochs+1):
        total_loss = 0
        model.train()
        with tqdm(total=len(train_loader), desc=f'Epoch {epoch}/{num_epochs}', unit='batch') as pbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                data=data.to(device).to(torch.long)
                target=target.to(device).to(torch.float)
                optimizer.zero_grad()
                outputs = model(data)
                loss = criterion(outputs, target)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                pbar.set_postfix({'Loss': total_loss / len(train_loader)})
                pbar.update(1)
        if(epoch%3==0):
            model.eval()
            best_score=0
            total_score=0
            LAS_score=0
            UAS_score=0
            for dev_tree in dev_trees:
                predict_transition_LAS,predict_transition_UAS= parse_sentence(model,dev_tree)
                true_transition_LAS,true_transition_UAS,_,_=get_training_data(dev_tree,word_vocab,pos_vocab,rel_vocab,action_vocab)
                UAS_score+=evaluate(predict_transition_UAS,true_transition_UAS)
                LAS_score+=evaluate(predict_transition_LAS,true_transition_LAS)
                total_score+=len(true_transition_LAS)
            dev_LAS_score= LAS_score/total_score*100
            dev_UAS_score= UAS_score/total_score*100
            print('dev_LAS_score: ', dev_LAS_score,"%")
            print('dev_UAS_score: ', dev_UAS_score,"%")
            if dev_LAS_score>best_score:
                best_score=dev_LAS_score
                torch.save(model.state_dict(), 'model.pt')
        print(f'Epoch {epoch}/{num_epochs}, Loss: {total_loss / len(train_loader)}')

In [43]:
train(train_loader)

Epoch 1/9: 100%|██████████| 1848/1848 [00:14<00:00, 128.44batch/s, Loss=0.133] 


Epoch 1/9, Loss: 0.13299296245390138


Epoch 2/9: 100%|██████████| 1848/1848 [00:13<00:00, 139.55batch/s, Loss=0.0631]


Epoch 2/9, Loss: 0.06313275216647234


Epoch 3/9: 100%|██████████| 1848/1848 [00:14<00:00, 127.74batch/s, Loss=0.0528]


dev_LAS_score:  91.6941941941942 %
dev_UAS_score:  91.66916916916918 %
Epoch 3/9, Loss: 0.05284616362214798


Epoch 4/9: 100%|██████████| 1848/1848 [00:13<00:00, 138.73batch/s, Loss=0.0452]


Epoch 4/9, Loss: 0.04520147652274554


Epoch 5/9: 100%|██████████| 1848/1848 [00:14<00:00, 127.88batch/s, Loss=0.0386]


Epoch 5/9, Loss: 0.03864230540403653


Epoch 6/9: 100%|██████████| 1848/1848 [00:14<00:00, 128.22batch/s, Loss=0.0328]


dev_LAS_score:  91.64664664664664 %
dev_UAS_score:  91.62412412412412 %
Epoch 6/9, Loss: 0.0327879844457153


Epoch 7/9: 100%|██████████| 1848/1848 [00:13<00:00, 139.44batch/s, Loss=0.0277]


Epoch 7/9, Loss: 0.027730163716453076


Epoch 8/9: 100%|██████████| 1848/1848 [00:14<00:00, 127.27batch/s, Loss=0.0232]


Epoch 8/9, Loss: 0.02320549653861188


Epoch 9/9: 100%|██████████| 1848/1848 [00:14<00:00, 127.53batch/s, Loss=0.0195]


dev_LAS_score:  91.30630630630631 %
dev_UAS_score:  91.29129129129129 %
Epoch 9/9, Loss: 0.0194670050054712


In [44]:
model=Parser(word_size,pos_size,rel_size,input_size,emb_size, hidden_size,output_size).to(device)
model.load_state_dict(torch.load('model.pt'))
total_score=0
LAS_score=0
UAS_score=0
for test_tree in test_trees:
    predict_transition_LAS,predict_transition_UAS= parse_sentence(model,test_tree)
    true_transition_LAS,true_transition_UAS,_,_=get_training_data(test_tree,word_vocab,pos_vocab,rel_vocab,action_vocab)
    UAS_score+=evaluate(predict_transition_UAS,true_transition_UAS)
    LAS_score+=evaluate(predict_transition_LAS,true_transition_LAS)
    total_score+=len(true_transition_LAS)

test_LAS_score= LAS_score/total_score*100
test_UAS_score= UAS_score/total_score*100
print('test_LAS_score: ', test_LAS_score,"%")
print('test_UAS_score: ', test_UAS_score,"%")

test_LAS_score:  92.09062566428116 %
test_UAS_score:  92.09948274640402 %
