<a href="https://colab.research.google.com/github/NULabTMN/ps3-Connor-Frazier/blob/dev/ner_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a Viterbi Decoder and Evaluation for Sequence Labeling

In this assignment, you will build a Viterbi decoder for an LSTM named-entity recognition model. As we mentioned in class, recurrent and bidirectional recurrent neural networks, of which LSTMs are the most common examples, can be used to perform sequence labeling. Although these models encode information from the surrounding words in order to make predictions, there are no "hard" constraints on what tags can appear where.

There hard constraints are particularly important for tasks that label spans of more than one token. The most common example of a span-labeling task is named-entity recognition (NER). As described in Eisenstein, Jurafksy & Martin, and other texts, the goal of NER is to label spans of one or more words as _mentions_ of an _entity_, such as a person, location, organization, etc.

The most common approach to NER is to reduce it to a sequence-labeling task, where each token in the input is labeled either with an `O`, if it is "outside" any named-entity span, or with `B-TYPE`, if it is the first token in an entity of type `TYPE`, or with `I-TYPE`, if it is the second or later token in an entity of type `TYPE`. Distinguishing between the first and later tokens of an entity allow us to identify distinct entity spans even when they are adjacent.

Common values of `TYPE` include `PER` for person, `LOC` for location, `DATE` for date, and so on. In the dataset we load below, there are 17 distinct types.

The span-labeling scheme just described implies that the labels on tokens must obey certain constraints: the tag `I-PER` must follow either `B-PER` or another `I-PER`. I cannot follow `O`, `B-LOC`, or `I-LOC`, i.e., a tag for a different entity type. By themselves, LSTMs or bidirectional LSTMs cannot directly enforce these constraints. This is one reason why conditional random fields (CRFs), which _can_ enforce these constraints, are often layered on top of these recurrent models.

In this assignment, you will implement the simplest possible CRF: a CRF so simple that it does not require any training. Rather, it will assign weight 1 to any sequence of tags that obeys the constraints and weight 0 to any sequence of tags that violates them. The inputs to the CRF, which are analogous to the emission probabilities in an HMM, will come from an LSTM.

But first, in order to test your decoder, you will also implement some functions to evaluate the output of an NER system according to two metrics:
1. You will count the number of _violations_ of the NER label constraints, i.e., how many times `I-TYPE` follows `O` or a tag of a different type. This number will be greater than 0 in the raw LSTM output, but should be 0 for your CRF output.
1. You will compute the _span-level_ precision, recall, and F1 of NER output. Although the baseline LSTM was trained to achieve high _token-level_ accuracy, this metric can be misleadingly high, since so many tokens are correctly labeled `O`. In other words, what proportion of spans predicted by the model line up exactly with spans in the gold standard, and what proportion of spans in the gold standard were predicted by the model? For more, see the original task definition: https://www.aclweb.org/anthology/W03-0419/.

We start with loading some code and data and the describe your tasks in more detail.

## Set Up Dependencies and Definitions

In [1]:
!pip install --upgrade spacy allennlp
import spacy
print(spacy.__version__)

Requirement already up-to-date: spacy in /usr/local/lib/python3.6/dist-packages (2.2.4)
Collecting allennlp
[?25l  Downloading https://files.pythonhosted.org/packages/bb/bb/041115d8bad1447080e5d1e30097c95e4b66e36074277afce8620a61cee3/allennlp-0.9.0-py3-none-any.whl (7.6MB)
[K     |████████████████████████████████| 7.6MB 25.7MB/s 
Collecting tensorboardX>=1.2
[?25l  Downloading https://files.pythonhosted.org/packages/35/f1/5843425495765c8c2dd0784a851a93ef204d314fc87bcc2bbb9f662a3ad1/tensorboardX-2.0-py2.py3-none-any.whl (195kB)
[K     |████████████████████████████████| 204kB 67.1MB/s 
[?25hCollecting numpydoc>=0.8.0
  Downloading https://files.pythonhosted.org/packages/b0/70/4d8c3f9f6783a57ac9cc7a076e5610c0cc4a96af543cafc9247ac307fbfe/numpydoc-0.9.2.tar.gz
Collecting pytorch-transformers==1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/50/89/ad0d6bb932d0a51793eaabcf1617a36ff530dc9ab9e38f765a35dc293306/pytorch_transformers-1.1.0-py3-none-any.whl (158kB)
[K     |██

In [2]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor
from allennlp.data.dataset_readers import conll2003

torch.manual_seed(1)

<torch._C.Generator at 0x7fdccd5de430>

In [0]:
class LstmTagger(Model):
  def __init__(self,
               word_embeddings: TextFieldEmbedder,
               encoder: Seq2SeqEncoder,
               vocab: Vocabulary) -> None:
    super().__init__(vocab)
    self.word_embeddings = word_embeddings
    self.encoder = encoder
    self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                      out_features=vocab.get_vocab_size('labels'))
    self.accuracy = CategoricalAccuracy()

  def forward(self,
              tokens: Dict[str, torch.Tensor],
              metadata,
              tags: torch.Tensor = None) -> Dict[str, torch.Tensor]:
    mask = get_text_field_mask(tokens)
    embeddings = self.word_embeddings(tokens)
    encoder_out = self.encoder(embeddings, mask)
    tag_logits = self.hidden2tag(encoder_out)
    output = {"tag_logits": tag_logits}
    if tags is not None:
      self.accuracy(tag_logits, tags, mask)
      output["loss"] = sequence_cross_entropy_with_logits(tag_logits, tags, mask)

    return output

  def get_metrics(self, reset: bool = False) -> Dict[str, float]:
    return {"accuracy": self.accuracy.get_metric(reset)}

## Import Data

In [4]:
reader = conll2003.Conll2003DatasetReader()
train_dataset = reader.read(cached_path('http://www.ccs.neu.edu/home/dasmith/onto.train.ner.sample'))
validation_dataset = reader.read(cached_path('http://www.ccs.neu.edu/home/dasmith/onto.development.ner.sample'))

vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

159377B [00:00, 748906.67B/s]
562it [00:00, 5129.56it/s]
8366B [00:00, 13301572.12B/s]
23it [00:00, 4306.84it/s]
100%|██████████| 585/585 [00:00<00:00, 48176.31it/s]


## Define and Train Model

In [5]:
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, bidirectional=False, batch_first=True))
model = LstmTagger(word_embeddings, lstm, vocab)
if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
# optimizer = optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
optimizer = optim.SGD(model.parameters(), lr=0.1)
iterator = BucketIterator(batch_size=2, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(vocab)
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=100,
                  cuda_device=cuda_device)
trainer.train()

accuracy: 0.8442, loss: 0.9097 ||: 100%|██████████| 281/281 [00:01<00:00, 176.65it/s]
accuracy: 0.7878, loss: 1.1954 ||: 100%|██████████| 12/12 [00:00<00:00, 455.30it/s]
accuracy: 0.8442, loss: 0.7293 ||: 100%|██████████| 281/281 [00:01<00:00, 257.78it/s]
accuracy: 0.7878, loss: 1.2064 ||: 100%|██████████| 12/12 [00:00<00:00, 475.16it/s]
accuracy: 0.8442, loss: 0.7160 ||: 100%|██████████| 281/281 [00:01<00:00, 252.62it/s]
accuracy: 0.7878, loss: 1.1757 ||: 100%|██████████| 12/12 [00:00<00:00, 418.33it/s]
accuracy: 0.8442, loss: 0.7072 ||: 100%|██████████| 281/281 [00:01<00:00, 261.38it/s]
accuracy: 0.7878, loss: 1.1750 ||: 100%|██████████| 12/12 [00:00<00:00, 456.92it/s]
accuracy: 0.8442, loss: 0.6984 ||: 100%|██████████| 281/281 [00:01<00:00, 259.26it/s]
accuracy: 0.7878, loss: 1.1539 ||: 100%|██████████| 12/12 [00:00<00:00, 434.61it/s]
accuracy: 0.8442, loss: 0.6914 ||: 100%|██████████| 281/281 [00:01<00:00, 259.17it/s]
accuracy: 0.7878, loss: 1.1573 ||: 100%|██████████| 12/12 [00:00

{'best_epoch': 98,
 'best_validation_accuracy': 0.8693877551020408,
 'best_validation_loss': 0.4001997278537601,
 'epoch': 99,
 'peak_cpu_memory_MB': 2577.98,
 'peak_gpu_0_memory_MB': 541,
 'training_accuracy': 0.9232072873636268,
 'training_cpu_memory_MB': 2577.98,
 'training_duration': '0:01:58.214610',
 'training_epochs': 99,
 'training_gpu_0_memory_MB': 541,
 'training_loss': 0.2016096050325728,
 'training_start_epoch': 0,
 'validation_accuracy': 0.8673469387755102,
 'validation_loss': 0.41113841835370596}

## Evaluation

The simple code below creators a `predictor` object that applies the model to an input example and then loops over the examples in the validation set, printing out the input token, gold-standard output, and model output. You can see from these methods how to access data and model outputs for evaluation.

In [0]:
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)

def tag_sentence(s):
  tag_ids = np.argmax(predictor.predict_instance(s)['tag_logits'], axis=-1)
  fields = zip(s['tokens'], s['tags'], [model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])
  return list(fields)

baseline_output = [tag_sentence(i) for i in validation_dataset]

Now, you can implement two evaluation functions: `violations` and `span_stats`.

In [0]:
# TODO: count the number of NER label violations,
# such as O followed by I-TYPE or B-TYPE followed by
# I-OTHER_TYPE
# I-type can not be followed by another I-type

# Valid moves:
# From O, can go to O or B-type
# From B-type can go to O or same I-type only or other B-type
# From I-type can go to same I-type or 0

# Invalid Move edge cases
# I-type can not be first tag

# Take tagger output as input
def violations(tagged):
  # Initialize violations count
  count = 0
  # Loop trhugh each sentence
  for sentence in tagged:
    # Loop through each word in the sentence up to the last word
    for i in range(len(sentence) - 1):
      # Compare the current ith word to the next word, add 1 to the violations count if there is one

      # If the first word in the sentence gets an I-TYPE tag, there is a violation
      if i == 0 and sentence[i][2] == 'I':
        count += 1

      # If the the next word after an O tag gets an I-TYPE tag, there is a violation
      if sentence[i][2] == 'O' and sentence[i+1][2][0] == 'I':
        count += 1

      # If the next word after a B-TYPE tag gets an I-TYPE tag of a different type, there is a violation
      if sentence[i][2][0] == 'B' and sentence[i + 1][2][0] == 'I' and sentence[i][2][2:] != sentence[i + 1][2][2:]:
        count += 1

      # If the next word after an I-TYPE tag gets an I-TYPE tag of a different tyoe, there is a violation
      if sentence[i][2][0] == 'I' and sentence[i + 1][2] == 'I' and sentence[i][2][2:] != sentence[i + 1][2][2:]:
        count += 1

  return count

# TODO: return the span-level precision, recall, and F1
# Take tagger output as input
def span_stats(tagged):

  # Definitions
  # true  postive = correct span that matches indexes as well
  # recall = true positive / true positve + false negative  = # of correct matched spans  / true number of spans
  # precision = true positive / true positve + false positive  = # of correct matched spans / number of spans guessed
  # f measure 2rp / (r+ p)

  # Find all the true spans with type and indices
  true_spans = []
  # Loop through all the sentences
  for sentence in tagged:
    # Initialie variables for finding spans
    start_span = False
    start_index = 0
    end_index = 0
    span_type = ""
    # Loop through each word
    for word in sentence:
      # Start the span tracking
      if word[1][0] == 'B' and start_span == False:
        start_span = True
        start_index = sentence.index(word)
        span_type = word[1][2:]
        # End the span tracking
      if (word[1][0] == 'O' or word[1][0] == 'B') and start_span == True:
        start_span = False
        end_index = sentence.index(word) - 1
        true_spans.append((span_type, start_index, end_index))
        # If span ended with a new one, restart
        if word[1][0] == 'B':
          start_span = True
          start_index = sentence.index(word)
          span_type = word[1][2:]

  # Find all the guessed spans with type and indices
  guessed_spans = []
  # Loop through all the sentences
  for sentence in tagged:
    # Initialie variables for finding spans
    start_span = False
    start_index = 0
    end_index = 0
    span_type = ""
    # Loop through each word
    for word in sentence:
      # Start the span tracking
      if word[2][0] == 'B' and start_span == False:
        start_span = True
        start_index = sentence.index(word)
        span_type = word[2][2:]
        # End the span tracking
      if (word[2][0] == 'O' or word[2][0] == 'B') and start_span == True:
        start_span = False
        end_index = sentence.index(word) - 1
        # If span ended with a new one, restart
        guessed_spans.append((span_type, start_index, end_index))
        if word[2][0] == 'B':
          start_span = True
          start_index = sentence.index(word)
          span_type = word[2][2:]
  
  #Find the matching spans between the true spans and the guessed spans
  spans_matched_count = 0  
  for span in true_spans:
    for guessed_span in guessed_spans:
      if span[0] == guessed_span[0] and span[1] == guessed_span[1] and span[2] == guessed_span[2]:
        spans_matched_count += 1

  #Calculate Metrics    
  recall = spans_matched_count/len(true_spans)
  precision = spans_matched_count/len(guessed_spans)
  f1 = 2*recall*precision / (recall + precision)        

  return {'precision': precision,
          'recall': recall,
          'f1': f1}

## You can check how many violations are made by the model output in predictor.
# print(violations(baseline_output))
# print(span_stats(baseline_output))

## Decoding

Now you can finally implement the simple Viterbi decoder. The `predictor` object, when applied to an input sentence, first calculates the scores for each possible output tag for each token. See the line `predictor.predict_instance(i)['tag_logits']` in the code above.

Then, you will construct a transition matrix. You can use the code below to get a list of the tags the model knows about. For a set of K tags, construct a K-by-K matrix with a log(1)=0 when a transition between a given tag pair is valid and a log(0)=-infinity otherwise.

Finally, implement a Viterbi decoder that takes the predictor object and a dataset object and outputs tagged data, just like the `tag_sentence` function above. It should use the Viterbi algorithm with the (max, plus) semiring. You'll be working with sums of log probabilities instead of products of probabilties.

Run your `violations` function on the output of this decoder to make sure that there are no invalid tag transitions. Also, compare the span-level metrics on `baseline_output` and your new output using your `span_stats` function.

In [8]:
# This code show how to map from output vector components to labels
# print(vocab.get_index_to_token_vocabulary('labels'))
#Get tags
tags = vocab.get_index_to_token_vocabulary('labels')

#valid moves:
#From O, can go to O or B-TYPE
#From B-type can go to O or I-type only or to B-type
#From I-type can go to same I-TYPE or 0

#K-by-K matrix with a log(1)=0 when a transition between a given tag pair is valid and a log(0)=-infinity otherwise.

#Create transition matrix, all transitions start as invalid
transitions = np.zeros( (len(tags), len(tags)) )
for r in range(len(tags)):
  for c in range(len(tags)):
    transitions[r][c] = np.log2(0)

#Change matrix values for valid moves
#Oriented so transition is from row tag to column tag
#Rows and columns represents all tags
for r in range(len(tags)):
  for c in range(len(tags)):
    # O tag to other tags
    if tags[r] == 'O':
      if tags[c] == 'O':
        transitions[r][c] = np.log2(1)
      if tags[c][0] == 'B':
          transitions[r][c] = np.log2(1)

    # B-TYPE to other tags
    if tags[r][0] == 'B':
      if tags[c][0] == 'I':
          if tags[r][2:] == tags[c][2:]:
            transitions[r][c] = np.log2(1)        
      if tags[c] == 'O':
        transitions[r][c] = np.log2(1)
      if tags[c][0] == 'B':
        transitions[r][c] = np.log2(1)

    # I-TYPE to other tags
    if tags[r][0] == 'I':
      if tags[c][0] == 'I':
        if tags[r][2:] == tags[c][2:]:
          transitions[r][c] = np.log2(1)
      if tags[c] == 'O':
        transitions[r][c] = np.log2(1)


  


In [9]:
# Get the predictor object
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)

# Tag the sentence with the viterbi decoder
def viterbi_tag_sentence(s):

  # List of the correct tokens in the sentence
  tokens = list(s['tokens'])

  # List of lists that holds the scores of all tags across all tokens
  word_tag_probabilities = predictor.predict_instance(s)['tag_logits'] 

  #Matrices used to hold scores of nodes and associated previous node
  #rows are tags, columns are the words
  sentence_scores = np.zeros((len(tags), len(tokens)))
  sentence_tags = np.zeros((len(tags), len(tokens))) #place to hold preious tags??

  # Set the first word tag nodes to their initial probabiliites
  for i in range(len(tags)):
    sentence_scores[i][0] = word_tag_probabilities[0][i]

  # Loop through each token from 2 to n
  for k in range(len(tokens) - 1):
    # Adjust the index
    k = k + 1
    # Loop trhough tags for the current word
    for j in range(len(tags)):
      # Variables to hold best score for this tag and the node it came from
      max_tag_score = 0
      best_prev_tag = 0
      # Loop through the previous word tags
      for i in range(len(tags)):
        #Calculate the jth tag score for the current word with the previous word's ith tag
        current_score = sentence_scores[i, k - 1] + transitions[i, j] + word_tag_probabilities[k][j]
        # Update the max score and previous node if this score is better
        if current_score > max_tag_score:
          max_tag_score = current_score
          best_prev_tag = i
      
      #filling jth tag at word k with best possible path from the kth-1 word  with ith tag
      sentence_scores[j][k] = max_tag_score   
      #Saving the ith tag as the node this score came from 
      sentence_tags[j][k] = best_prev_tag

  
  # List to hold the best sentence tags
  final_sentence_tags = [0 for i in range(len(tokens))]
  # Variables to hold the last token score and the tag
  last_tag = 0
  last_tag_score = sentence_scores[last_tag][len(tokens) - 1]
  # Find the best tag for the last token
  for i in range(len(tags)):
    if sentence_scores[i][len(tokens) - 1] > last_tag_score:
      last_tag_score = sentence_scores[i][len(tokens) - 1]
      last_tag = i
  # Set the last token's tag as the one with the highest score
  final_sentence_tags[len(tokens) - 1] = last_tag 

  # Backtrack trhough the tags matrix from the second to last word to the first word
  for i in range(len(tokens) - 1):
    # Fix the index value for backtracking
    index = len(tokens) - i - 2
    # This word's tag is the one marked in tags matrix for the word that comes after it (reading the sentence left to right)
    final_sentence_tags[index] = sentence_tags[int(last_tag)][index + 1]
    # Update the last tag to the tag we just set, to be used for the next word
    last_tag = sentence_tags[int(last_tag)][index + 1]

  # Return the togs, correct tags, and viterbi decodert tags in a list
  fields = zip(s['tokens'], s['tags'], [model.vocab.get_token_from_index(i, 'labels') for i in final_sentence_tags])
  return list(fields)

# Using viterbi decoder, get the output
viterbi_output = [viterbi_tag_sentence(i) for i in validation_dataset]


#Calculate violations and statistics of viterbit output and baseline output
#Both are outputs are from the validation dataset
viterbi_output_violations = violations(viterbi_output)
baseline_output_violations = violations(baseline_output)

viterbi_output_stats = span_stats(viterbi_output)
baseline_output_stats = span_stats(baseline_output)


print("Baseline")
print("Violations: " + str(baseline_output_violations))
print(baseline_output_stats)
print("\n")
print("Viterbi")
print("Violations: " + str(viterbi_output_violations))
print(viterbi_output_stats)




Baseline
Violations: 29
{'precision': 0.6111111111111112, 'recall': 0.38823529411764707, 'f1': 0.47482014388489213}


Viterbi
Violations: 0
{'precision': 0.618421052631579, 'recall': 0.5529411764705883, 'f1': 0.5838509316770186}
