# RNN Model with crf

In [1]:
import rnn_dataset
import rnn_classifier
from crf import CRF
import torch
from data_utils import Vocabulary
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
import itertools
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
from torch.nn.utils.rnn import pad_sequence
import pyconll


In [2]:
#define the hyperparameters
batch_size    = 16
window_size   =  6#left context and right context
lr            = 1e-3
device        = "cpu"
epochs        = 20
emb_size      = 64
hidden_size   = 64
nb_layers     = 2
drop_out      = 0.1

In [3]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")


token Vocab size 35693
token Vocab size 35693


In [4]:
for x, d, tag in trainset.get_loader(shuffle = True):
    print(list(trainset.tags_vocab.rev_lookup(int(t)) for t in tag.squeeze(0)))
    break
    

['B_N', 'B_D', 'B_N', 'B_PONCT', 'B_CL', 'B_V', 'B_ADV', 'B_V', 'B_PONCT', 'B_P', 'B_D', 'B_N', 'B_CL', 'B_CL', 'B_V', 'B_V', 'B_PONCT', 'B_P', 'B_D', 'B_N', 'B_P', 'B_D', 'B_N', 'B__', 'B_P', 'B_D', 'B_N', 'B_P', 'B_N', 'B_CL', 'B_V', 'B_P', 'B_CL', 'B_V', 'B_PONCT', 'B_C', 'B_CL', 'B_V', 'B_C', 'B_CL', 'B_V', 'B_ADV', 'B_ADV', 'B_A', 'B_P', 'B_V', 'B_D', 'B_N', 'B_P', 'B_CL', 'B_C', 'B_P', 'B_D', 'B_N', 'B_P', 'B_D', 'B_N', 'B_CL', 'B_V', 'B_ADV', 'I_A', 'B_V', 'B_V', 'B_P', 'B_D', 'B_N', 'B_PONCT', '<pad>']


In [5]:
model = rnn_classifier.MweRNN(
    name         = "rnn",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    deprel_vocab = trainset.deprel_vocab,
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = 0.)

In [6]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.69it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 77.64it/s]


Epoch 0 | Mean train loss  30.3678 |  Mean dev loss  17.3760 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.75it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.01it/s]


Epoch 1 | Mean train loss  14.1967 |  Mean dev loss  12.4934 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.89it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.36it/s]


Epoch 2 | Mean train loss  10.8008 |  Mean dev loss  9.8736 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.85it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 68.63it/s]


Epoch 3 | Mean train loss  8.8186 |  Mean dev loss  8.0413 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.97it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.04it/s]


Epoch 4 | Mean train loss  7.3161 |  Mean dev loss  7.1748 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.63it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 72.00it/s]


Epoch 5 | Mean train loss  6.3419 |  Mean dev loss  5.8997 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.70it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.27it/s]


Epoch 6 | Mean train loss  5.4658 |  Mean dev loss  5.2650 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.62it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 70.74it/s]


Epoch 7 | Mean train loss  4.8039 |  Mean dev loss  4.4984 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.99it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 68.36it/s]


Epoch 8 | Mean train loss  4.1807 |  Mean dev loss  4.1520 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.80it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.55it/s]


Epoch 9 | Mean train loss  3.7217 |  Mean dev loss  3.5725 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.96it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 68.35it/s]


Epoch 10 | Mean train loss  3.2857 |  Mean dev loss  3.2097 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.67it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 61.36it/s]


Epoch 11 | Mean train loss  2.9005 |  Mean dev loss  2.9310 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 19.02it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.68it/s]


Epoch 12 | Mean train loss  2.6254 |  Mean dev loss  2.4905 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.53it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.40it/s]


Epoch 13 | Mean train loss  2.3028 |  Mean dev loss  2.3677 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 18.98it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.37it/s]


Epoch 14 | Mean train loss  2.1020 |  Mean dev loss  2.0235 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 17.91it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 59.50it/s]


Epoch 15 | Mean train loss  1.8626 |  Mean dev loss  1.8770 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.38it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 71.01it/s]


Epoch 16 | Mean train loss  1.6658 |  Mean dev loss  1.7273 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 18.01it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.91it/s]


Epoch 17 | Mean train loss  1.5288 |  Mean dev loss  1.4665 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 17.86it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 61.81it/s]


Epoch 18 | Mean train loss  1.3773 |  Mean dev loss  1.3506 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.17it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 59.20it/s]


Epoch 19 | Mean train loss  1.2256 |  Mean dev loss  1.2767 



  0%|                                                   | 0/105 [00:00<?, ?it/s]


AttributeError: 'bool' object has no attribute 'sum'

In [None]:
TP, FP, FN, average_precision, average_recall, average_f1_score, weighted_f1_score, weighted_recall, weighted_precision = model.evaluate()

In [64]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logprobs, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logprobs, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        
        for i in range(len(best_paths)):
            str = list(model.tags_vocab.rev_lookup(int(i))for i in gold if i!= model.padidx)
            path = best_paths[i]
            gold = torch.tensor([j for j in Y_golds[i] if j != model.padidx])
            for tag in path:
                TP[tag] += ((path == tag) & (gold == tag)).sum()
                FP[tag] += ((path == tag) & (gold != tag)).sum()
                FN[tag] += ((path != tag) & (gold == tag)).sum()
                class_counts[tag] += (gold == tag).sum()
                
            

100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.14s/it]


In [73]:
precision = TP / (TP + FP)
# avoid nan
nan_mask = torch.isnan(precision)
precision[nan_mask] = 0.

recall = TP / (TP + FN)
# avoid nan
nan_mask = torch.isnan(recall)
recall[nan_mask] = 0.

f1_score = 2 * (precision * recall) / (precision + recall)
# avoid nan
nan_mask = torch.isnan(f1_score)
f1_score[nan_mask] = 0.
# Calculate class weights
class_weights = class_counts / class_counts.sum()

In [74]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), precision[tag])

<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(0.9306)
I_V tensor(1.)
I_ADV tensor(0.)
B_V tensor(0.9560)
B_P tensor(0.9091)
B_A tensor(0.7545)
B_D tensor(0.9697)
B_N tensor(0.8596)
B_PONCT tensor(0.9672)
B_C tensor(0.9172)
B__ tensor(0.9938)
B_ADV tensor(0.8123)
I_N tensor(0.8571)
I_C tensor(1.)
I_CL tensor(0.)
I_D tensor(0.8333)
I_P tensor(0.6250)
I_PONCT tensor(0.)
I_A tensor(0.)
B_PREF tensor(1.)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


In [69]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), class_counts[tag])

<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(3290.)
I_V tensor(2.)
I_ADV tensor(0.)
B_V tensor(16242.)
B_P tensor(35511.)
B_A tensor(7129.)
B_D tensor(29758.)
B_N tensor(81717.)
B_PONCT tensor(16081.)
B_C tensor(2267.)
B__ tensor(1761.)
B_ADV tensor(2218.)
I_N tensor(14.)
I_C tensor(2.)
I_CL tensor(0.)
I_D tensor(7.)
I_P tensor(9.)
I_PONCT tensor(0.)
I_A tensor(0.)
B_PREF tensor(4.)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


In [77]:
weighted_f1_score = (f1_score * class_weights)
weighted_recall = (recall * class_weights)
weighted_precision = (precision * class_weights)

In [15]:
print(sum(weighted_f1_score))
print(sum(weighted_recall))
print(sum(weighted_recall))

tensor(0.9184)
tensor(0.9386)
tensor(0.9386)


In [80]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), f1_score[tag])

<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(0.9440)
I_V tensor(0.6667)
I_ADV tensor(0.)
B_V tensor(0.9277)
B_P tensor(0.9387)
B_A tensor(0.7683)
B_D tensor(0.9375)
B_N tensor(0.9005)
B_PONCT tensor(0.9833)
B_C tensor(0.9393)
B__ tensor(0.9963)
B_ADV tensor(0.8186)
I_N tensor(0.8571)
I_C tensor(1.)
I_CL tensor(0.)
I_D tensor(0.7692)
I_P tensor(0.5882)
I_PONCT tensor(0.)
I_A tensor(0.)
B_PREF tensor(1.)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


# RNN Layer + Attention Layer + CRF decoder

In [9]:
#define the hyperparameters
batch_size    = 16
window_size   =  6#left context and right context
lr            = 1e-2
device        = "cpu"
epochs        = 20
emb_size      = 64
hidden_size   = 64
drop_out      = 0.1

In [10]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")

token Vocab size 35693
token Vocab size 35693


In [11]:
model = rnn_classifier.MweRNN(
    name         = "rnn",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = 0.)

In [12]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.33it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 70.27it/s]


Epoch 0 | Mean train loss  12.0193 |  Mean dev loss  6.7987 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.33it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 66.87it/s]


Epoch 1 | Mean train loss  4.5913 |  Mean dev loss  4.2064 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 19.36it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.21it/s]


Epoch 2 | Mean train loss  2.7259 |  Mean dev loss  2.5497 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.75it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 65.60it/s]


Epoch 3 | Mean train loss  1.8030 |  Mean dev loss  1.8576 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.90it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 60.51it/s]


Epoch 4 | Mean train loss  1.3155 |  Mean dev loss  1.3191 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.42it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.76it/s]


Epoch 5 | Mean train loss  1.0527 |  Mean dev loss  1.2239 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.59it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 66.10it/s]


Epoch 6 | Mean train loss  0.9031 |  Mean dev loss  1.1100 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 19.70it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.22it/s]


Epoch 7 | Mean train loss  0.8308 |  Mean dev loss  0.9874 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 18.14it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 65.08it/s]


Epoch 8 | Mean train loss  1.1247 |  Mean dev loss  2.5232 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 19.13it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.18it/s]


Epoch 9 | Mean train loss  2.1894 |  Mean dev loss  1.8341 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 19.02it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 58.87it/s]


Epoch 10 | Mean train loss  1.2867 |  Mean dev loss  1.2812 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.46it/s]
100%|█████████████████████████████████████████| 146/146 [00:03<00:00, 47.96it/s]


Epoch 11 | Mean train loss  1.0197 |  Mean dev loss  1.2670 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.60it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 56.95it/s]


Epoch 12 | Mean train loss  0.9808 |  Mean dev loss  1.6853 



100%|█████████████████████████████████████████| 581/581 [00:34<00:00, 17.04it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.79it/s]


Epoch 13 | Mean train loss  1.7437 |  Mean dev loss  2.6243 



100%|█████████████████████████████████████████| 581/581 [00:34<00:00, 16.90it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.74it/s]


Epoch 14 | Mean train loss  1.8777 |  Mean dev loss  1.6229 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.27it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.71it/s]


Epoch 15 | Mean train loss  1.2462 |  Mean dev loss  1.4156 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 17.62it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 61.70it/s]


Epoch 16 | Mean train loss  1.1232 |  Mean dev loss  1.4860 



100%|█████████████████████████████████████████| 581/581 [00:34<00:00, 17.03it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 59.41it/s]


Epoch 17 | Mean train loss  1.0963 |  Mean dev loss  1.4831 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 18.99it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.22it/s]


Epoch 18 | Mean train loss  1.7235 |  Mean dev loss  1.7566 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.27it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 70.84it/s]


Epoch 19 | Mean train loss  1.5869 |  Mean dev loss  1.4515 



100%|█████████████████████████████████████████| 105/105 [00:04<00:00, 21.90it/s]

AVR: Precision 0.5373 | Recall  0.5202 |  F-score  0.5246 
Weighted: Precision 0.9105 | Recall  0.9442 |  F-score  0.8833 





In [6]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logits, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logits, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        
        for i in range(len(best_paths)):
            str = list(model.tags_vocab.rev_lookup(int(i))for i in Y_golds[i] if i!= model.padidx)
            path = best_paths[i]
            gold = torch.tensor([j for j in Y_golds[i] if j != model.padidx])
            for tag in path:
                TP[tag] += ((path == tag) & (gold == tag)).sum()
                FP[tag] += ((path == tag) & (gold != tag)).sum()
                FN[tag] += ((path != tag) & (gold == tag)).sum()
                class_counts[tag] += (gold == tag).sum()
              

100%|█████████████████████████████████████████████| 4/4 [00:05<00:00,  1.35s/it]


In [7]:
precision = TP / (TP + FP)
# avoid nan
nan_mask = torch.isnan(precision)
precision[nan_mask] = 0.

recall = TP / (TP + FN)
# avoid nan
nan_mask = torch.isnan(recall)
recall[nan_mask] = 0.

f1_score = 2 * (precision * recall) / (precision + recall)
# avoid nan
nan_mask = torch.isnan(f1_score)
f1_score[nan_mask] = 0.
# Calculate class weights
class_weights = class_counts / class_counts.sum()

In [8]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), f1_score[tag])

<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(0.9182)
I_V tensor(0.6667)
I_ADV tensor(0.)
B_V tensor(0.8877)
B_P tensor(0.9400)
B_A tensor(0.7496)
B_D tensor(0.9318)
B_N tensor(0.8716)
B_PONCT tensor(0.9820)
B_C tensor(0.9292)
B__ tensor(0.9997)
B_ADV tensor(0.7915)
I_N tensor(0.7429)
I_C tensor(0.1429)
I_CL tensor(0.)
I_D tensor(0.8095)
I_P tensor(0.3158)
I_PONCT tensor(0.3158)
I_A tensor(0.5333)
B_PREF tensor(0.8889)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


In [11]:
weighted_f1_score = (f1_score * class_weights)
weighted_recall = (recall * class_weights)
weighted_precision = (precision * class_weights)

In [12]:
print(sum(weighted_f1_score))
print(sum(weighted_recall))
print(sum(weighted_recall))

tensor(0.9072)
tensor(0.9160)
tensor(0.9160)


tensor(0.5314)
tensor(0.5052)
tensor(0.5134)
