# RNN Model with crf

In [1]:
import rnn_dataset
import rnn_classifier
from crf import CRF
import torch
from data_utils import Vocabulary
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
import itertools
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import os
from torch.nn.utils.rnn import pad_sequence
imort pyconll



In [36]:
def readfile(filename, update=False, toks_vocab=Vocabulary([ "<unk>", "<bos>", "<eos>", "<pad>"]), tags_vocab=Vocabulary([ "<unk>", "<bos>", "<eos>", "<pad>"]), deprel_vocab=Vocabulary([ "<unk>", "<bos>", "<eos>", "<pad>"])):
    """
    function to read the corpus at one pass
    signature for train corpus : X_toks, Y_tags = readfile("corpus/train.conllu", update=True)
    signature for test corpus/ dev corpus:  X_test, Y_test = readfile("corpus/train.conllu", update=True, vocabtoks_train, vocabtags_train)
    """

    istream = open(filename, encoding="utf-8")
    corpus = []
    toks, mwes, deprels = [], [], []
    for line in istream:
        line = line.strip()
        if line and line[0] != "#":
            try:
                tokidx, token, lemma, upos, pos, features, headidx, deprel, extended, _ = line.split()
            except ValueError:
                print(line)
                print(corpus[-1])
            #print(upos)
            """
            if tokidx == "1":
                # beginning of sentence, add false toks
                toks.append("<bos>")
                mwes.append("<bos>")
                deprels.append("<bos>")
            """

            # extract tagging information
            # extract simple mwe tags
            extr_mwe = lambda x: "I" if features.startswith("component") else "B"

            mwe = extr_mwe(features)+"_"+upos
            toks.append(token)
            mwes.append(mwe)
            deprels.append(deprel)

            if update:
                toks_vocab.update(token)
                tags_vocab.update(mwe)
                deprel_vocab.update(deprel)

        elif toks:
            # end of sentence, add  false tokens
            #corpus.append({"tok": toks+ ["<eos>"], "mwe":mwes+["<eos>"], "deprel": deprels+ ["<eos>"] })
            corpus.append({"tok": toks , "mwe": mwes , "deprel": deprels })
            toks, mwes, deprels = [], [], []

    istream.close()
    # return the encoded data in list of list, the nested list represents the sentences
    return corpus, toks_vocab, tags_vocab, deprel_vocab



In [27]:
corpus, toks_vocab, tags_vocab, deprel_vocab = readfile("sequoia/sequoia.deep_and_surf.conll")

In [28]:
len(corpus)

3099

In [37]:
corpus, toks_vocab, tags_vocab, deprel_vocab = readfile("corpus/train.conllu")

In [38]:
len(corpus)

11617

In [2]:
#define the hyperparameters
batch_size    = 16
window_size   =  6#left context and right context
lr            = 1e-3
device        = "cpu"
epochs        = 20
emb_size      = 64
hidden_size   = 64
nb_layers     = 2
drop_out      = 0.1

In [3]:
trainset     = rnn_dataset.MweRnnDataset("corpus/train.conllu",  isTrain = True)
testset      = rnn_dataset.MweRnnDataset("corpus/test.conllu")


token Vocab size 35693
token Vocab size 35693


In [5]:
for x, d, tag in testset.get_loader(shuffle = True):
    print(list(trainset.tags_vocab.rev_lookup(int(t)) for t in tag.squeeze(0)))
    break
    

['B_CL', 'B_V', 'B_V', 'B_N', 'B_ADV', 'B__', 'I_P', 'B_D', 'B_N', 'B_V', 'B_P', 'I_N', 'I_P', 'B_V', 'B_D', 'B_N', 'B_A', 'B_P', 'B_D', 'B_N', 'B_P', 'B_V', 'B_D', 'B_N', 'B_PONCT', '<pad>']


In [4]:
model = rnn_classifier.MweRNN(
    name         = "rnn",
    toks_vocab   = trainset.toks_vocab,
    tags_vocab   = trainset.tags_vocab, 
    deprel_vocab = trainset.deprel_vocab,
    emb_size     = emb_size, 
    hidden_size  = hidden_size, 
    drop_out     = 0.)

In [5]:
model.train_model(trainset,testset, epochs= epochs, lr=lr, batch_size = batch_size, split_train=0.8)

100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 17.70it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 60.76it/s]


Epoch 0 | Mean train loss  46.4963 |  Mean dev loss  29.1421 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.17it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 50.56it/s]


Epoch 1 | Mean train loss  24.7811 |  Mean dev loss  21.4905 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.60it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 55.16it/s]


Epoch 2 | Mean train loss  18.6540 |  Mean dev loss  17.4223 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.20it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 59.49it/s]


Epoch 3 | Mean train loss  15.2325 |  Mean dev loss  13.6854 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 18.77it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 58.49it/s]


Epoch 4 | Mean train loss  12.5433 |  Mean dev loss  11.9401 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.47it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 54.04it/s]


Epoch 5 | Mean train loss  10.7902 |  Mean dev loss  9.6202 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 18.03it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 51.01it/s]


Epoch 6 | Mean train loss  9.1894 |  Mean dev loss  8.6488 



100%|█████████████████████████████████████████| 581/581 [00:35<00:00, 16.15it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.88it/s]


Epoch 7 | Mean train loss  7.9519 |  Mean dev loss  7.7272 



100%|█████████████████████████████████████████| 581/581 [00:40<00:00, 14.30it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 52.66it/s]


Epoch 8 | Mean train loss  6.9455 |  Mean dev loss  6.6765 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.32it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 59.61it/s]


Epoch 9 | Mean train loss  6.0916 |  Mean dev loss  5.9257 



100%|█████████████████████████████████████████| 581/581 [00:32<00:00, 18.01it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 55.40it/s]


Epoch 10 | Mean train loss  5.4091 |  Mean dev loss  5.1306 



100%|█████████████████████████████████████████| 581/581 [00:35<00:00, 16.58it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 54.23it/s]


Epoch 11 | Mean train loss  4.7402 |  Mean dev loss  4.6164 



100%|█████████████████████████████████████████| 581/581 [00:34<00:00, 17.07it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.99it/s]


Epoch 12 | Mean train loss  4.1751 |  Mean dev loss  4.2324 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.29it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 57.86it/s]


Epoch 13 | Mean train loss  3.7260 |  Mean dev loss  3.6971 



100%|█████████████████████████████████████████| 581/581 [00:30<00:00, 19.08it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 60.29it/s]


Epoch 14 | Mean train loss  3.2615 |  Mean dev loss  3.4191 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.56it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 56.19it/s]


Epoch 15 | Mean train loss  2.9654 |  Mean dev loss  2.8791 



100%|█████████████████████████████████████████| 581/581 [00:34<00:00, 16.64it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 54.99it/s]


Epoch 16 | Mean train loss  2.6388 |  Mean dev loss  2.6459 



100%|█████████████████████████████████████████| 581/581 [00:33<00:00, 17.60it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 51.53it/s]


Epoch 17 | Mean train loss  2.3498 |  Mean dev loss  2.3470 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.37it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 60.98it/s]


Epoch 18 | Mean train loss  2.1079 |  Mean dev loss  2.1175 



100%|█████████████████████████████████████████| 581/581 [00:31<00:00, 18.63it/s]
100%|█████████████████████████████████████████| 146/146 [00:02<00:00, 62.43it/s]


Epoch 19 | Mean train loss  1.9092 |  Mean dev loss  1.8602 



  0%|                                                   | 0/105 [00:00<?, ?it/s]


AttributeError: 'bool' object has no attribute 'sum'

In [None]:
TP, FP, FN, average_precision, average_recall, average_f1_score, weighted_f1_score, weighted_recall, weighted_precision = model.evaluate()

In [7]:
num_tags = len(model.tags_vocab)
# print(num_tags)
TP = torch.zeros(num_tags)
FP = torch.zeros(num_tags)
FN = torch.zeros(num_tags)
class_counts = torch.zeros(num_tags)
with torch.no_grad():
            
    for X_toks, deprel, Y_golds in tqdm(testset.get_loader(batch_size = 500)):
    # Forward pass
        logprobs, masks = model.forward(X_toks)
        best_score, best_paths = model.crf(logprobs, masks) #viterbi
        #best_paths = pad_sequence(best_paths, padding_value= testset.tags_vocab["<pad>"]).T
        #print(best_paths.shape)
        # Mask out the padding positions
        #print(best_paths)
        
        for path, gold, x in zip(best_paths, Y_golds, X_toks):
            print(f"token {list(model.toks_vocab.rev_lookup(int(i)) for i in x if i!= 1)}")
            print(f"prediction {list(model.tags_vocab.rev_lookup(i) for i in path)}")
            print(f"gold {list(model.tags_vocab.rev_lookup(int(i)) for i in gold if i!= 1)}")
        print(Y_golds.shape)
        # Update confusion matrix
        for tag in range(num_tags):
            TP[tag] += ((best_paths == tag) & (Y_golds == tag)).sum()
            FP[tag] += ((best_paths == tag) & (Y_golds != tag)).sum()
            FN[tag] += ((best_paths != tag) & (Y_golds == tag)).sum()
            class_counts[tag] += (Y_golds == tag).sum()
            

  0%|                                                     | 0/4 [00:00<?, ?it/s]

token ['Gutenberg']
prediction ['B_ADV']
gold ['B_N']
token ['Cette', 'exposition', 'nous', 'apprend', 'que', 'dès', 'le', 'XIIe', 'siècle', ',', 'à', '<unk>', ',', 'entre', 'autres', 'sites', ',', 'une', 'industrie', '<unk>', 'existait', '.']
prediction ['B_DET', 'B_NOUN', 'B_PRON', 'B_VERB', 'B_ADV', 'B_ADP', 'B_DET', 'B_ADJ', 'B_NOUN', 'B_PUNCT', 'B_ADP', 'B_PROPN', 'B_PUNCT', 'B_ADP', 'B_ADJ', 'B_NOUN', 'B_PUNCT', 'B_DET', 'B_NOUN', 'B_AUX', 'B_VERB', 'B_PUNCT']
gold ['B_D', 'B_N', 'B_CL', 'B_V', 'B_C', 'B_P', 'B_D', 'B_A', 'B_N', 'B_PONCT', 'B_P', 'B_N', 'B_PONCT', 'B_P', 'B_A', 'B_N', 'B_PONCT', 'B_D', 'B_N', 'B_A', 'B_V', 'B_PONCT']
token ['à', 'peu', 'près', 'au', 'à', 'le', 'même', 'moment', 'que', 'Gutenberg', '<unk>', "l'", 'imprimerie', ',', '<unk>', '<unk>', 'créait', 'en', '1450', 'la', 'première', 'forge', 'à', '<unk>', ',', 'à', "l'", 'actuel', 'emplacement', 'du', 'de', 'le', '<unk>', '.']
prediction ['B_ADP', 'B_ADV', 'B_ADV', 'B__', 'B_ADP', 'B_DET', 'B_ADJ', 'B_NOUN

 25%|███████████▎                                 | 1/4 [00:01<00:03,  1.32s/it]

gold ['B_N', 'B_P', 'B_N', 'B_PONCT', 'B_D', 'B_N', 'B_A']
token ['Trois', 'personnes', 'ont', 'été', 'gravement', 'blessées', ',', 'dimanche', 'soir', 'vers', '17', 'h', '40', 'à', 'la', 'suite', "d'", 'un', 'accident', 'spectaculaire', 'qui', "s'", 'est', 'produit', 'à', "l'", 'entrée', 'du', 'de', 'le', 'village', 'de', '<unk>', 'sur', 'la', 'départementale', '32', '(', "L'", 'Est', '<unk>', 'du', 'de', 'le', '5', 'juillet', ')', '.']
prediction ['B_NUM', 'B_NOUN', 'B_AUX', 'B_AUX', 'B_ADV', 'B_VERB', 'B_PUNCT', 'B_NOUN', 'B_NOUN', 'B_ADP', 'B_NUM', 'B_NOUN', 'B_NUM', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_PRON', 'B_PRON', 'B_AUX', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B__', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADP', 'B_PROPN', 'B_ADP', 'B_DET', 'B_NOUN', 'B_NUM', 'B_PUNCT', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B__', 'B_ADP', 'B_DET', 'B_NUM', 'B_NOUN', 'B_PUNCT', 'B_PUNCT']
gold ['B_D', 'B_N', 'B_V', 'B_V', 'B_ADV', 'B_V', 'B_PONCT', 'B_N', 'B_N', 'B_P', 'B_D',

 50%|██████████████████████▌                      | 2/4 [00:02<00:02,  1.33s/it]

token ['-', 'Chez', 'les', 'patients', 'qui', 'bénéficient', "d'", 'une', 'intervention', '<unk>', 'destinée', 'à', 'traiter', 'des', '<unk>']
prediction ['B_PUNCT', 'B_ADP', 'B_DET', 'B_NOUN', 'B_PRON', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B_AUX', 'B_VERB', 'B_ADP', 'B_VERB', 'B_DET', 'B_PROPN']
gold ['B_PONCT', 'B_P', 'B_D', 'B_N', 'B_PRO', 'B_V', 'B_P', 'B_D', 'B_N', 'B_A', 'B_V', 'B_P', 'B_V', 'B_D', 'B_N']
token ['<unk>', 'SONT', 'LES', '<unk>', 'À', '<unk>', '<unk>', "D'", '<unk>', '<unk>']
prediction ['B_AUX', 'B_VERB', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B_PROPN', 'B_PROPN']
gold ['B_A', 'B_V', 'B_D', 'B_N', 'B_P', 'B_V', 'B_P', 'I_P', 'B_V', 'B_N']
token ["N'", 'utilisez', 'jamais', '<unk>', ':']
prediction ['B_PART', 'B_VERB', 'B_ADV', 'B_PROPN', 'B_PUNCT']
gold ['B_ADV', 'B_V', 'B_ADV', 'B_N', 'B_PONCT']
token ['-', 'si', 'vous', 'êtes', '<unk>', '(', 'allergique', ')', 'à', 'la', '<unk>', 'ou', 'à', "l'", 'un', 'des', 'de', 'les', 'autres', 

 75%|█████████████████████████████████▊           | 3/4 [00:04<00:01,  1.37s/it]

token ['Avant', "l'", 'administration', ',', 'il', 'faut', 'laisser', 'la', 'solution', '<unk>', 'atteindre', 'la', 'température', 'ambiante', '.']
prediction ['B_P', 'B_D', 'B_N', 'B_PONCT', 'B_CL', 'B_V', 'B_V', 'B_D', 'B_N', 'B_V', 'B_V', 'B_D', 'B_N', 'B_A', 'B_PONCT']
gold ['B_P', 'B_D', 'B_N', 'B_PONCT', 'B_CL', 'B_V', 'B_V', 'B_D', 'B_N', 'B_V', 'B_V', 'B_D', 'B_N', 'B_A', 'B_PONCT']
token ['Que', 'contient', '<unk>']
prediction ['B_PRO', 'B_V', 'B_V']
gold ['B_PRO', 'B_V', 'B_N']
token ['-', 'La', 'substance', 'active', 'est', "l'", 'acide', '<unk>', '.']
prediction ['B_PUNCT', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_VERB', 'B_DET', 'B_NOUN', 'B_VERB', 'B_PUNCT']
gold ['B_PONCT', 'B_D', 'B_N', 'B_A', 'B_V', 'B_D', 'B_N', 'B_A', 'B_PONCT']
token ['Chaque', '<unk>', 'de', '100', 'ml', 'de', 'solution', 'contient', '5', 'mg', "d'", 'acide', '<unk>', '<unk>', ',', 'ce', 'qui', '<unk>', 'à', '<unk>', 'mg', "d'", 'acide', '<unk>', '<unk>', '.']
prediction ['B_DET', 'B_PROPN', 'B_ADP', 'B_NUM'

100%|█████████████████████████████████████████████| 4/4 [00:04<00:00,  1.12s/it]

token ['Il', "n'", 'y', 'a', 'aucun', 'motif', 'pour', "qu'", 'ils', 'y', 'soient', '.']
prediction ['B_PRON', 'B_PART', 'B_CL', 'B_V', 'B_D', 'B_N', 'B_P', 'I_C', 'B_CL', 'B_ADV', 'B_VERB', 'B_PUNCT']
gold ['B_CL', 'B_ADV', 'B_CL', 'B_V', 'B_D', 'B_N', 'B_P', 'I_C', 'B_CL', 'B_CL', 'B_V', 'B_PONCT']
token ['Ils', "n'", 'ont', 'été', 'impliqués', 'dans', 'aucune', 'action', 'terroriste', 'ou', 'militaire', '.']
prediction ['B_PRON', 'B_PART', 'B_AUX', 'B_AUX', 'B_VERB', 'B_ADP', 'B_DET', 'B_NOUN', 'B_ADJ', 'B_CONJ', 'B_ADJ', 'B_PUNCT']
gold ['B_CL', 'B_ADV', 'B_V', 'B_V', 'B_V', 'B_P', 'B_D', 'B_N', 'B_A', 'B_C', 'B_A', 'B_PONCT']
token ['Je', 'crois', 'que', 'nous', 'devrions', 'nous', 'pencher', 'sur', 'ce', 'problème', '.']
prediction ['B_CL', 'B_V', 'B_C', 'B_CL', 'B_V', 'B_CL', 'B_V', 'B_P', 'B_D', 'B_N', 'B_PONCT']
gold ['B_CL', 'B_V', 'B_C', 'B_CL', 'B_V', 'B_CL', 'B_V', 'B_P', 'B_D', 'B_N', 'B_PONCT']
token ['<unk>', '<unk>', ',', 'la', 'docteur', '<unk>', 'condamnée', 'le', '1




In [32]:
TP

tensor([0.0000e+00, 1.6430e+03, 0.0000e+00, 1.4342e+05, 2.0000e+01, 0.0000e+00,
        0.0000e+00, 5.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 8.1000e+01,
        3.5000e+01, 3.0000e+00, 0.0000e+00, 0.0000e+00, 7.8000e+01, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])