# RNN Model with crf

In [1]:
import rnn_dataset
import rnn_classifier
from crf import CRF
import torch
from data_utils import Vocabulary
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
import itertools
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
from torch.nn.utils.rnn import pad_sequence
import pyconll


In [2]:
#define the hyperparameters
batch_size    = 16
lr            = 1e-2
device        = "cpu"
epochs        = 20
emb_size      = 64
hidden_size   = 64
nb_layers     = 2
drop_out      = 0.1

In [3]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")


token Vocab size 35693
token Vocab size 35693


In [4]:
for x, tag in trainset.get_loader(shuffle = True):
    print(list(trainset.tags_vocab.rev_lookup(int(t)) for t in tag.squeeze(0)))
    break
    

['B_V', 'B_P', 'B_N', 'B_P', 'B_V', 'B_D', 'B_N', 'B_A', 'B_PONCT', 'B_D', 'B_N', 'B_C', 'B_CL', 'B_V', 'B_V', 'B_ADV', 'B_V', 'B__', 'B_P', 'B_D', 'B_N', 'B_PONCT', '<pad>']


In [5]:
model = rnn_classifier.MweRNN(
    name         = "RNN",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = drop_out)

In [None]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:28<00:00, 20.22it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 77.71it/s]


Epoch 0 | Mean train loss  12.2721 |  Mean dev loss  6.5046 



100%|█████████████████████████████████████████| 581/581 [00:29<00:00, 20.03it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 71.39it/s]


Epoch 1 | Mean train loss  4.1072 |  Mean dev loss  5.3728 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.24it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 77.06it/s]


Epoch 2 | Mean train loss  2.2919 |  Mean dev loss  5.2992 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 18.82it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 69.19it/s]


Epoch 3 | Mean train loss  1.6149 |  Mean dev loss  5.4731 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 18.15it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 56.30it/s]


Epoch 4 | Mean train loss  1.2643 |  Mean dev loss  5.6245 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 18.76it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 71.70it/s]


Epoch 5 | Mean train loss  1.0487 |  Mean dev loss  5.8031 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.31it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 50.35it/s]


Epoch 6 | Mean train loss  0.8808 |  Mean dev loss  6.0774 



  8%|███▍                                      | 47/581 [00:02<00:30, 17.35it/s]

In [None]:
TP, FP, FN, average_precision, average_recall, average_f1_score, weighted_f1_score, weighted_recall, weighted_precision = model.evaluate()

In [64]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logprobs, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logprobs, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        
        for i in range(len(best_paths)):
            str = list(model.tags_vocab.rev_lookup(int(i))for i in gold if i!= model.padidx)
            path = best_paths[i]
            gold = torch.tensor([j for j in Y_golds[i] if j != model.padidx])
            for tag in path:
                TP[tag] += ((path == tag) & (gold == tag)).sum()
                FP[tag] += ((path == tag) & (gold != tag)).sum()
                FN[tag] += ((path != tag) & (gold == tag)).sum()
                class_counts[tag] += (gold == tag).sum()
                
            

100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.14s/it]


In [2]:
precision = TP / (TP + FP)
# avoid nan
nan_mask = torch.isnan(precision)
precision[nan_mask] = 0.

recall = TP / (TP + FN)
# avoid nan
nan_mask = torch.isnan(recall)
recall[nan_mask] = 0.

f1_score = 2 * (precision * recall) / (precision + recall)
# avoid nan
nan_mask = torch.isnan(f1_score)
f1_score[nan_mask] = 0.
# Calculate class weights
class_weights = class_counts / class_counts.sum()

NameError: name 'TP' is not defined

In [None]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), precision[tag])

In [3]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), class_counts[tag])

NameError: name 'num_tags' is not defined

In [77]:
weighted_f1_score = (f1_score * class_weights)
weighted_recall = (recall * class_weights)
weighted_precision = (precision * class_weights)

In [15]:
print(sum(weighted_f1_score))
print(sum(weighted_recall))
print(sum(weighted_recall))

tensor(0.9184)
tensor(0.9386)
tensor(0.9386)


In [18]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), f1_score[tag])


<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(0.9270)
I_V tensor(0.8889)
I_ADV tensor(0.2275)
B_V tensor(0.8962)
B_P tensor(0.9381)
B_A tensor(0.7716)
B_D tensor(0.9347)
B_N tensor(0.8759)
B_PONCT tensor(0.9828)
B_C tensor(0.9336)
B__ tensor(0.9997)
B_ADV tensor(0.3598)
I_N tensor(0.4318)
I_C tensor(0.4000)
I_CL tensor(0.9091)
I_D tensor(0.7143)
I_P tensor(0.3692)
I_PONCT tensor(0.8235)
I_A tensor(0.4545)
B_PREF tensor(0.8000)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


# LSTM Layer + CRF decoder

In [26]:
#define the hyperparameters
batch_size    = 16
window_size   =  6#left context and right context
lr            = 1e-2
device        = "cpu"
epochs        = 30
emb_size      = 64
hidden_size   = 64
drop_out      = 0.1

In [27]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")

token Vocab size 35693
token Vocab size 35693


In [28]:
model = rnn_classifier.MweRNN(
    name         = "LSTM",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = 0.)

In [29]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 22.27it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 88.52it/s]


Epoch 0 | Mean train loss  12.3865 |  Mean dev loss  6.3059 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 23.22it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 91.47it/s]


Epoch 1 | Mean train loss  4.2572 |  Mean dev loss  3.8725 



100%|█████████████████████████████████████████| 581/581 [00:24<00:00, 23.68it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 78.67it/s]


Epoch 2 | Mean train loss  2.5388 |  Mean dev loss  2.6031 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 21.84it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 71.32it/s]


Epoch 3 | Mean train loss  1.8201 |  Mean dev loss  1.7126 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.31it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 77.20it/s]


Epoch 4 | Mean train loss  1.3661 |  Mean dev loss  1.3698 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 23.01it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.88it/s]


Epoch 5 | Mean train loss  1.1158 |  Mean dev loss  1.0955 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 21.92it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 78.32it/s]


Epoch 6 | Mean train loss  0.8928 |  Mean dev loss  0.9807 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 23.11it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 77.07it/s]


Epoch 7 | Mean train loss  0.7569 |  Mean dev loss  0.7933 



100%|█████████████████████████████████████████| 581/581 [00:24<00:00, 23.62it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.17it/s]


Epoch 8 | Mean train loss  0.6448 |  Mean dev loss  0.6704 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 21.82it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 70.74it/s]


Epoch 9 | Mean train loss  0.5317 |  Mean dev loss  0.6198 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 20.83it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 61.02it/s]


Epoch 10 | Mean train loss  0.4490 |  Mean dev loss  0.5743 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.13it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 74.37it/s]


Epoch 11 | Mean train loss  0.4278 |  Mean dev loss  0.5111 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 22.02it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 65.00it/s]


Epoch 12 | Mean train loss  0.6699 |  Mean dev loss  1.5737 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.32it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 81.35it/s]


Epoch 13 | Mean train loss  1.7729 |  Mean dev loss  1.5296 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 22.35it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 72.80it/s]


Epoch 14 | Mean train loss  1.0089 |  Mean dev loss  0.9141 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 22.21it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 78.17it/s]


Epoch 15 | Mean train loss  0.5897 |  Mean dev loss  0.5759 



100%|█████████████████████████████████████████| 581/581 [00:24<00:00, 23.65it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 76.97it/s]


Epoch 16 | Mean train loss  0.3790 |  Mean dev loss  0.3967 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 23.16it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 74.83it/s]


Epoch 17 | Mean train loss  0.2670 |  Mean dev loss  0.2926 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.48it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 67.23it/s]


Epoch 18 | Mean train loss  0.2062 |  Mean dev loss  0.2066 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 20.96it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 70.23it/s]


Epoch 19 | Mean train loss  0.1852 |  Mean dev loss  0.1735 



100%|█████████████████████████████████████████| 581/581 [00:25<00:00, 22.68it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 71.88it/s]


Epoch 20 | Mean train loss  0.1444 |  Mean dev loss  0.1677 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 21.90it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 72.95it/s]


Epoch 21 | Mean train loss  0.1181 |  Mean dev loss  0.1315 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 21.88it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 62.64it/s]


Epoch 22 | Mean train loss  0.1112 |  Mean dev loss  0.1291 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 21.21it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 72.97it/s]


Epoch 23 | Mean train loss  0.4967 |  Mean dev loss  2.8051 



100%|█████████████████████████████████████████| 581/581 [00:26<00:00, 22.28it/s]
100%|█████████████████████████████████████████| 146/146 [00:01<00:00, 73.08it/s]


Epoch 24 | Mean train loss  1.8956 |  Mean dev loss  1.5678 



100%|█████████████████████████████████████████| 581/581 [00:27<00:00, 20.93it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 64.87it/s]


Epoch 25 | Mean train loss  0.9912 |  Mean dev loss  0.9728 



100%|█████████████████████████████████████████| 581/581 [00:28<00:00, 20.14it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 61.28it/s]


Epoch 26 | Mean train loss  0.6217 |  Mean dev loss  0.6855 



100%|█████████████████████████████████████████| 581/581 [00:28<00:00, 20.47it/s]
100%|█████████████████████████████████████████| 146/146 [00:03<00:00, 44.79it/s]


Epoch 27 | Mean train loss  0.4597 |  Mean dev loss  0.4892 



100%|█████████████████████████████████████████| 581/581 [00:28<00:00, 20.16it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 49.07it/s]


Epoch 28 | Mean train loss  0.3572 |  Mean dev loss  0.3485 



 46%|██████████████████▊                      | 266/581 [00:13<00:15, 19.98it/s]


KeyboardInterrupt: 

In [15]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logits, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logits, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        
        for i in range(len(best_paths)):
            str = list(model.tags_vocab.rev_lookup(int(i))for i in Y_golds[i] if i!= model.padidx)
            path = best_paths[i]
            gold = torch.tensor([j for j in Y_golds[i] if j != model.padidx])
            for tag in path:
                TP[tag] += ((path == tag) & (gold == tag)).sum()
                FP[tag] += ((path == tag) & (gold != tag)).sum()
                FN[tag] += ((path != tag) & (gold == tag)).sum()
                class_counts[tag] += (gold == tag).sum()
              

100%|█████████████████████████████████████████████| 4/4 [00:05<00:00,  1.46s/it]


In [16]:
precision = TP / (TP + FP)
# avoid nan
nan_mask = torch.isnan(precision)
precision[nan_mask] = 0.

recall = TP / (TP + FN)
# avoid nan
nan_mask = torch.isnan(recall)
recall[nan_mask] = 0.

f1_score = 2 * (precision * recall) / (precision + recall)
# avoid nan
nan_mask = torch.isnan(f1_score)
f1_score[nan_mask] = 0.
# Calculate class weights
class_weights = class_counts / class_counts.sum()

In [17]:
for tag in range(num_tags):
    print(model.tags_vocab.rev_lookup(tag), f1_score[tag])

<unk> tensor(0.)
<pad> tensor(0.)
B_CL tensor(0.9270)
I_V tensor(0.8889)
I_ADV tensor(0.2275)
B_V tensor(0.8962)
B_P tensor(0.9381)
B_A tensor(0.7716)
B_D tensor(0.9347)
B_N tensor(0.8759)
B_PONCT tensor(0.9828)
B_C tensor(0.9336)
B__ tensor(0.9997)
B_ADV tensor(0.3598)
I_N tensor(0.4318)
I_C tensor(0.4000)
I_CL tensor(0.9091)
I_D tensor(0.7143)
I_P tensor(0.3692)
I_PONCT tensor(0.8235)
I_A tensor(0.4545)
B_PREF tensor(0.8000)
B_I tensor(0.)
B_ET tensor(0.)
I_ET tensor(0.)
B_NC tensor(0.)
B_S tensor(0.)
B_X tensor(0.)


In [11]:
weighted_f1_score = (f1_score * class_weights)
weighted_recall = (recall * class_weights)
weighted_precision = (precision * class_weights)

In [12]:
print(sum(weighted_f1_score))
print(sum(weighted_recall))
print(sum(weighted_recall))

tensor(0.9072)
tensor(0.9160)
tensor(0.9160)


tensor(0.5314)
tensor(0.5052)
tensor(0.5134)
