# RNN Model with crf

In [1]:
import rnn_dataset
import rnn_classifier
from crf import CRF
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
import itertools
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
from torch.nn.utils.rnn import pad_sequence



In [2]:
#define the hyperparameters
batch_size    = 16
window_size   =  6#left context and right context
lr            = 1e-3
device        = "cpu"
epochs        = 20
emb_size      = 64
hidden_size   = 64
nb_layers     = 2
drop_out      = 0.1

In [3]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")
print(len(trainset.tags_vocab))

token Vocab size 35693
token Vocab size 35693
43


In [4]:
model = rnn_classifier.MweRNN(
    name         = "rnn",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    deprel_vocab = trainset.deprel_vocab,
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = 0.)

In [None]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 17.70it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 60.76it/s]


Epoch 0 | Mean train loss  46.4963 |  Mean dev loss  29.1421 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.17it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 50.56it/s]


Epoch 1 | Mean train loss  24.7811 |  Mean dev loss  21.4905 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.60it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 55.16it/s]


Epoch 2 | Mean train loss  18.6540 |  Mean dev loss  17.4223 



 21%|████████▋                                | 123/581 [00:06<00:24, 18.96it/s]

In [None]:
TP, FP, FN, average_precision, average_recall, average_f1_score, weighted_f1_score, weighted_recall, weighted_precision = model.evaluate()

In [30]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, deprel, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logprobs, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logprobs, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        #print(best_paths)
        
        for path, gold, x in zip(best_paths, Y_golds, X_toks):
            print(f"gold {list(model.toks_vocab.rev_lookup(int(i)) for i in x if i!= 1)}")
            print(f"prediction {list(model.tags_vocab.rev_lookup(i) for i in path)}")
            print(f"gold {list(model.tags_vocab.rev_lookup(int(i)) for i in gold if i!= 1)}")
        print(Y_golds.shape)
        # Update confusion matrix
        for tag in range(num_tags):
            TP[tag] += ((best_paths == tag) & (Y_golds == tag)).sum()
            FP[tag] += ((best_paths == tag) & (Y_golds != tag)).sum()
            FN[tag] += ((best_paths != tag) & (Y_golds == tag)).sum()
            class_counts[tag] += (Y_golds == tag).sum()
            

  0%|                                                     | 0/4 [00:00<?, ?it/s]

prediction ['<bos>', 'B_ADP', '<eos>']
gold ['<bos>', 'B_N', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', 'B_PRON', 'B_VERB', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B_NOUN', 'B_PUNCT', 'B_ADP', 'B_PROPN', 'B_PUNCT', 'B_ADP', 'B_PROPN', 'B_PROPN', 'B_PUNCT', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_NOUN', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_D', 'B_N', 'B_CL', 'B_V', 'B_C', 'B_P', 'B_D', 'B_A', 'B_N', 'B_PONCT', 'B_P', 'B_N', 'B_PONCT', 'B_P', 'B_A', 'B_N', 'B_PONCT', 'B_D', 'B_N', 'B_A', 'B_V', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_ADP', 'B_ADV', 'B_VERB', 'B__', 'B_ADP', 'B_DET', 'B_ADJ', 'B_NOUN', 'B_SCONJ', 'B_CONJ', 'B_PRON', 'B_DET', 'B_NOUN', 'B_PUNCT', 'B_ADJ', 'B_NOUN', 'B_ADJ', 'B_ADP', 'B_VERB', 'B_DET', 'B_ADJ', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_PUNCT', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_P', 'I_ADV', 'I_ADV', 'B__', 'B_P', 'B_D', 'B_A', 'B_N', 'B_C', 'B_N', 'B_V', 'B_D', 'B_N', 'B_PONCT', 'B_N', 'B_N',

 25%|███████████▎                                 | 1/4 [00:00<00:01,  2.41it/s]

prediction ['<bos>', 'B_ADP', 'B_NOUN', 'B_ADP', 'B_PROPN', 'B_PUNCT', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_AUX', 'B_VERB', 'B_ADP', 'B_PRON', 'B_VERB', 'B_VERB', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_P', 'B_N', 'B_P', 'B_N', 'B_PONCT', 'B_D', 'B_A', 'B_N', 'B_V', 'B_V', 'B_C', 'B_CL', 'B_V', 'B_V', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_CONJ', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_AUX', 'B_VERB', 'B_ADP', 'B_NUM', 'B_NOUN', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_D', 'B_N', 'B_P', 'B_V', 'B_D', 'B_N', 'B_V', 'B_C', 'B_D', 'B_A', 'B_N', 'B_V', 'B_V', 'B_P', 'B_V', 'B_V', 'B_N', 'B_P', 'B_D', 'B_N', 'B_P', 'B_N', 'B_P', 'B_N', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_PRON', 'B_VERB', 'B_ADV', 'B_ADJ', 'B_NOUN', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_CL', 'B_V', 'B_ADV', 'B_D', 'B_N', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', 'B_ADP', 'B_

 50%|██████████████████████▌                      | 2/4 [00:00<00:00,  2.11it/s]

gold ['<bos>', 'B_N', 'B_A', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', 'B_ADP', 'B_DET', 'B_NOUN', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_NOUN', 'B_PUNCT', 'B_ADJ', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_ADJ', 'B_NOUN', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_PUNCT', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_N', 'B_A', 'B_P', 'B_D', 'B_N', 'B__', 'B_P', 'B_D', 'B_N', 'B_A', 'B_P', 'B_D', 'B_N', 'B_N', 'B_A', 'B_PONCT', 'B_N', 'B_V', 'B_P', 'B_N', 'B_P', 'B_N', 'B_P', 'B_D', 'B_N', 'B_N', 'B_PONCT', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_DET', 'B_ADJ', 'B_NOUN', 'B_ADP', 'B_DET', 'B_NOUN', '<eos>']
gold ['<bos>', 'B_N', 'B_P', 'B_N', 'B_P', 'B_D', 'B_N', '<eos>']
prediction ['<bos>', 'B_NOUN', 'B_ADJ', '<eos>']
gold ['<bos>', 'B_N', 'B_A', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', '<eos>']
gold ['<bos>', 'B_N', 'B_N', '<eos>']
prediction ['<bos>', 'B_NOUN', 'B_ADP', 'B_NOUN', '<eos>']
gold ['<bos>', 'B_N', 'B_P', 'B_N', '<e

 75%|█████████████████████████████████▊           | 3/4 [00:01<00:00,  1.79it/s]

gold ['<bos>', 'B_D', 'B_N', 'B_A', 'B_C', 'I_PONCT', 'I_C', 'B_D', 'B_N', 'B_PRO', 'B_ADV', 'B_V', 'B_ADV', 'B__', 'B_P', 'B_D', 'I_N', 'I_P', 'B_D', 'B_N', 'B_C', 'B_P', 'B_D', 'B_N', 'B_V', 'B_V', 'B_V', 'B_P', 'B_D', 'B_N', 'B_V', 'B_P', 'B_D', 'B_N', 'B_P', 'B_D', 'I_A', 'B_N', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_PRON', 'B_DET', 'B_NOUN', 'B_CONJ', 'B_DET', 'B_NOUN', 'B_SCONJ', 'B_PRON', 'B_VERB', 'B_ADP', 'B_PROPN', 'B_PROPN', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_V', 'B_D', 'B_N', 'B_C', 'B_D', 'B_N', 'B_C', 'B_CL', 'B_V', 'B_D', 'B_A', 'B_N', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_AUX', 'B_AUX', 'B_VERB', 'B_PUNCT', 'B_NOUN', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADP', 'B_PROPN', 'B_PROPN', 'B_PUNCT', 'B_ADP', 'B_NOUN', 'B_CONJ', 'B_ADP', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_PUNCT', 'B_VERB', 'B_DET', 'B_NOUN', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_PUNCT', 'B_ADP', 'B_DET', 'B_NOUN', 'B_CONJ', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_PUNCT', 'B_PUNCT',

100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.20it/s]

gold ['<bos>', 'B_D', 'B_N', 'B_CL', 'B_V', 'B_V', 'B_P', 'B_CL', 'B_V', 'B_N', 'B_P', 'B_D', 'B_A', 'B_N', 'B_PONCT', '<eos>']
prediction ['<bos>', 'B_PUNCT', 'B_PROPN', 'B_PUNCT', 'B_PRON', 'B_AUX', 'B_VERB', 'B_ADP', 'B_NOUN', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADP', 'B_NOUN', 'B_ADJ', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_CONJ', 'B_PRON', 'B_VERB', 'B_SCONJ', 'B_PRON', 'B_AUX', 'B_VERB', 'B_ADP', 'B_VERB', 'B_DET', 'B_NOUN', 'B_PRON', 'B_AUX', 'B_AUX', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B_NOUN', 'B_ADJ', 'B_PUNCT', 'B_PROPN', 'B_PUNCT', 'B_ADJ', 'B_CONJ', 'B_ADJ', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_CONJ', 'B_ADJ', 'B_PUNCT', '<eos>']
gold ['<bos>', 'B_PONCT', 'B_N', 'B_PONCT', 'B_CL', 'B_V', 'B_V', 'B_P', 'I_N', 'B__', 'I_P', 'B_D', 'B_N', 'B_P', 'B_N', 'B_N', 'B_P', 'B_D', 'B_N', 'B_N', 'B_C', 'B_CL', 'B_V', 'B_C', 'B_CL', 'B_V', 'B_A', 'B_P', 'B_V', 'B_D', 'B_N', 'B_PRO', 'B_V', 'B_V', 'B_V', 'B_P', 'B_D', 'B_A', 'B_N', 'B_N', 'B_PONCT', 'B_N', 'B_PONCT', 'B_




In [32]:
TP

tensor([0.0000e+00, 1.6430e+03, 0.0000e+00, 1.4342e+05, 2.0000e+01, 0.0000e+00,
        0.0000e+00, 5.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.1000e+01,
        3.5000e+01, 3.0000e+00, 0.0000e+00, 0.0000e+00, 7.8000e+01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])