# Pyramid NER

This is an implementation of "*Pyramid: A Layered Model for Nested Named Entity Recognition*" by Jue Wang, Lidan Shou, Ke Chen and Gang Chen.

Paper: https://www.aclweb.org/anthology/2020.acl-main.525

**Note**: Although it is not mandatory, it is almost impossible to train this model on a CPU, so you should run this notebook in an environment with a GPU. In addition, if the GPU has enough memory, use a bigger language model (see the variable `LM_NAME`).

## Load libraries & constants

Install additional libraries:
`pip install transformers`

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, BertModel

from sklearn.metrics import precision_score, recall_score

from collections import defaultdict
import xml.etree.ElementTree as ET
import numpy as np
import re
import os
import json
import copy

np.random.seed(42)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Release unused memory held by PyTorch's CUDA caching allocator
# (useful when the notebook is re-run in the same process).
torch.cuda.empty_cache()

In [None]:
# Language-model configuration used by the rest of the notebook:
#   LM_NAME      - model identifier passed to `from_pretrained` and used as the
#                  artifacts sub-folder name.
#   LM_DIM       - passed as `lm_dimension` to the encoder; it must match the
#                  hidden size of the chosen model (512 for the "H-512" model).
#   TOTAL_LAYERS - passed as `total_layers` when building the layer targets.
LM_NAME = 'google/bert_uncased_L-4_H-512_A-8'
LM_DIM = 512
TOTAL_LAYERS = 8

# The authors' configuration is:
# LM_NAME = 'dmis-lab/biobert-large-cased-v1.1'
# LM_DIM = 1024
# TOTAL_LAYERS = 16

# Other language models that can be used as LM_NAME:
# 'google/bert_uncased_L-4_H-512_A-8'
# 'bert-base-uncased'
# 'dmis-lab/biobert-large-cased-v1.1'

In [None]:
# Download the pretrained model and its tokenizer, then store both under
# ../artifacts/<LM_NAME>/ so later cells can load them from disk.
save_path = '../artifacts/%s/' % LM_NAME
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(save_path, exist_ok=True)

model = BertModel.from_pretrained(LM_NAME)
model.save_pretrained(save_path)

# Saving the tokenizer also writes the files (e.g. vocab.txt) into save_path.
slow_tokenizer = BertTokenizer.from_pretrained(LM_NAME)
slow_tokenizer.save_pretrained(save_path)

## Preprocess data

### General functions

The input data is a JSON file containing a list of items, each of which looks like this:

```
{
  "tokens": ["token0", "token1", "token2"],
  "entities": [
    {
      "entity_type": "PER",
      "span": [0, 1]
    },
    {
      "entity_type": "ORG",
      "span": [2, 3]
    }
  ]
}
```

Now, it is necessary to encode the entities as outputs of each layer.

In [None]:
# Locations of the three GENIA splits.
train_file = '../data/train.genia.json'
valid_file = '../data/valid.genia.json'
test_file = '../data/test.genia.json'


def _read_json(path):
    """Return the parsed content of the JSON file at `path`."""
    with open(path, 'r') as handle:
        return json.load(handle)


train_dataset = _read_json(train_file)
dev_dataset = _read_json(valid_file)
test_dataset = _read_json(test_file)

In [None]:
# Report how many items each split contains.
print('Train dataset size: %d' % len(train_dataset))
print('Dev dataset size: %d' % len(dev_dataset))
print('Test dataset size: %d' % len(test_dataset))

In [None]:
def add_layer_outputs(dataset, total_layers=16, entity_dict=None):
    """
    Attach per-layer target label sequences to every item of `dataset`.

    Each item gets a new key ``'layer_outputs'``: a list of `total_layers`
    integer lists, where layer ``l`` has ``len(tokens) - l`` positions (empty
    when the sentence is shorter than ``l``). An entity whose span covers
    ``length = span[1] - span[0]`` extra tokens is written as a ``B-<type>``
    id at position ``span[0]`` of layer ``length``. Entities too long for the
    pyramid go to the last layer as ``B-<type>`` followed by ``I-<type>`` ids.

    Args:
        dataset: list of dicts with keys ``'tokens'`` and ``'entities'``; every
            entity has ``'entity_type'`` and ``'span'`` (``[start, end]``, used
            here with an inclusive end). The items are modified in place.
        total_layers: number of layers to generate targets for.
        entity_dict: existing label-to-id mapping. When None, a new mapping is
            built from the entity types found in `dataset` (``'O'`` is 0).

    Returns:
        Tuple ``(dataset, entity_dict)``.

    Raises:
        KeyError: if `entity_dict` is given and an entity type of `dataset`
            is missing from it.
    """
    build_dict = entity_dict is None
    if build_dict:
        entity_dict = {'O': 0}

    # First pass: allocate all-"O" targets and (optionally) collect the labels.
    for item in dataset:
        n_tokens = len(item['tokens'])
        item['layer_outputs'] = [[0] * (n_tokens - i_layer) for i_layer in range(total_layers)]
        if build_dict:
            for entity in item['entities']:
                for prefix in ('B', 'I'):
                    label = '%s-%s' % (prefix, entity['entity_type'])
                    if label not in entity_dict:
                        entity_dict[label] = len(entity_dict)

    # Second pass: write every entity into the layer matching its length.
    last_layer = total_layers - 1
    for item in dataset:
        for entity in item['entities']:
            span_start, span_end = entity['span'][0], entity['span'][1]
            b_type_id = entity_dict['B-%s' % entity['entity_type']]
            i_type_id = entity_dict['I-%s' % entity['entity_type']]

            length = span_end - span_start
            if length < total_layers:
                item['layer_outputs'][length][span_start] = b_type_id
                continue

            # Entity longer than the pyramid: tag it as B/I in the last layer.
            # The I-tags are clipped to the layer length; a plain slice
            # assignment up to span_end would silently make the list longer
            # than the other bookkeeping expects.
            layer = item['layer_outputs'][last_layer]
            layer[span_start] = b_type_id
            for position in range(span_start + 1, min(span_end + 1, len(layer))):
                layer[position] = i_type_id

    return dataset, entity_dict

In [None]:
# Build the label dictionary from the training split and add its layer targets.
train_dataset, entity_dict = add_layer_outputs(train_dataset, total_layers=TOTAL_LAYERS)

In [None]:
# Reuse the training label dictionary so label ids agree across splits.
dev_dataset, _ = add_layer_outputs(dev_dataset, total_layers=TOTAL_LAYERS, entity_dict=entity_dict)

In [None]:
# Reuse the training label dictionary so label ids agree across splits.
test_dataset, _ = add_layer_outputs(test_dataset, total_layers=TOTAL_LAYERS, entity_dict=entity_dict)

In [None]:
# Label names in insertion order, so a label's position equals its id.
entity_idx = list(entity_dict)

## Build datasets & vocabularies

Once we have preprocessed the data, we can build the datasets (for Pytorch) and the vocabularies of words and characters.

### GloVe embeddings & words vocabulary

Download the 100-dimensional GloVe word embeddings trained on 6B tokens. These embeddings contain $400K$ uncased tokens. Place the file `glove.6B.100d.txt` in the folder `../data/` (the path used by `GLOVE_FILE` below). Source: https://nlp.stanford.edu/projects/glove/

In [None]:
def load_embedding_matrix(filepath, dimension, special_tokens):
    """
    Load GloVe-style word vectors and build an embedding matrix.

    Every non-empty line of the file is expected to be
    ``<word> <v1> ... <v_dimension>`` separated by whitespace.

    Args:
        filepath: path to the vectors file (decoded as UTF-8).
        dimension: number of components of every vector.
        special_tokens: tokens appended after the file's words; their rows
            are drawn from a standard normal distribution via `np.random`.

    Returns:
        Tuple ``(embedding_matrix, id2word, word2id)`` where
        ``embedding_matrix`` is a float64 array of shape
        ``(len(id2word), dimension)``, ``id2word`` is the list of words by
        row index and ``word2id`` the inverse mapping.

    Raises:
        ValueError: if a vector cannot be parsed or does not have
            `dimension` components.
    """
    id2word = []
    word2id = {}
    vectors = []

    # Read the vectors, parsing each one once.
    with open(filepath, 'rb') as f:
        for line in f:
            values = line.decode().split()
            if not values:
                # Blank line (e.g. trailing newline at the end of the file).
                continue
            word = values[0]
            if word in word2id:
                # Keep the first occurrence so ids and matrix rows stay aligned.
                continue

            word2id[word] = len(id2word)
            id2word.append(word)
            # `np.float` was removed from NumPy; float64 is what it aliased.
            vectors.append(np.asarray(values[1:], dtype=np.float64))

    # Build embedding matrix
    embedding_matrix = np.zeros((len(id2word) + len(special_tokens), dimension), dtype=np.float64)
    for idx, vector in enumerate(vectors):
        embedding_matrix[idx] = vector

    # Add special tokens and randomly initialize them
    for special_token in special_tokens:
        token_id = len(id2word)
        word2id[special_token] = token_id
        id2word.append(special_token)
        embedding_matrix[token_id] = np.random.normal(size=dimension)

    return embedding_matrix, id2word, word2id

In [None]:
# Path of the GloVe vectors file and the width of its vectors.
GLOVE_FILE = '../data/glove.6B.100d.txt'
WORD_DIM = 100
# Extra vocabulary entries appended after the GloVe words; the same list is
# reused as the special tokens of the character vocabulary.
special_tokens = ['[UNK]', '[PAD]', '[CLS]', '[SEP]', '[MASK]']

embedding_matrix, id2word, word2id = load_embedding_matrix(GLOVE_FILE, WORD_DIM, special_tokens)

In [None]:
# Sample embedding of token "the": look up its row id, then display that row.
token_id = word2id['the']
embedding_matrix[token_id]

### Build char vocabulary

In [None]:
def build_char_vocab(genia_data, lower_case=False, special_tokens=None):
    """
    Build the character vocabulary of a dataset.

    Args:
        genia_data: iterable of items with a ``'tokens'`` list of strings.
        lower_case: when True, every character is lower-cased before it is
            added to the vocabulary.
        special_tokens: optional tokens appended after the characters, in the
            given order.

    Returns:
        Tuple ``(id2char, char2id)``: the list of symbols by id and the
        inverse mapping. Characters are sorted, so the ids are the same on
        every run (iterating a bare ``set`` of strings is not stable across
        interpreter runs).
    """
    chars = set()
    for item in genia_data:
        text = ''.join(item['tokens'])
        if lower_case:
            chars.update(char.lower() for char in text)
        else:
            chars.update(text)

    # Deterministic ordering of the collected characters.
    id2char = sorted(chars)
    char2id = {char: idx for idx, char in enumerate(id2char)}

    # Add special tokens
    for special_token in (special_tokens or ()):
        char2id[special_token] = len(id2char)
        id2char.append(special_token)

    return id2char, char2id

In [None]:
# Character vocabulary of the training split (case is preserved), plus the special tokens.
id2char, char2id = build_char_vocab(train_dataset, special_tokens=special_tokens)

## Tokenizers/encoders

We need three different encoders for the inputs: words, chars and Language Model.

### Word-level & Char-level tokenizers

I use the Spacy tokenizer, which gives the same tokens as GloVe. This library not only tokenizes texts, but also provides additional information, such as the span position of each token. <u>Note that the LM (e.g. BERT) will use a different tokenizer!</u>

In [None]:
def get_tokenizer(artifacts_path='../artifacts/', lm_name='dmis-lab/biobert-large-cased-v1.1'):
    """
    Create a lower-casing WordPiece tokenizer from a saved vocabulary.

    The vocabulary is read from ``<artifacts_path><lm_name>/vocab.txt``, so
    `artifacts_path` must end with a path separator.
    """
    vocab_file = '%s%s/vocab.txt' % (artifacts_path, lm_name)
    return BertWordPieceTokenizer(vocab_file, lowercase=True)

In [None]:
def tokenize_text(tokenizer, text, lower=True):
    """
    Split `text` into whole-word tokens with their character spans.

    The text is encoded with `tokenizer`; the first and last encoded tokens
    are dropped (they are treated as the sequence delimiters) and WordPiece
    continuation pieces (``##xyz``) are merged back into the preceding token.

    Args:
        tokenizer: object whose ``encode(text)`` result exposes ``tokens``
            and ``offsets`` (pairs ``(start, end)`` with an exclusive end).
        text: the string to tokenize.
        lower: lower-case the text before encoding.

    Returns:
        Tuple ``(tokens, spans)``; ``spans[i]`` is ``[start, end]`` with an
        inclusive end. Both lists are empty when nothing was tokenized.
    """
    if lower:
        text = text.lower()

    encoded = tokenizer.encode(text)

    tokens = encoded.tokens[1:-1]
    spans = [[x[0], x[1] - 1] for x in encoded.offsets[1:-1]]

    # Walk backwards so deleting a piece never shifts an index still to visit.
    # Index 0 is excluded: it has no previous token to merge into. (The former
    # `while i >= 0` loop reached index -1, which raised IndexError on empty
    # input and could merge a leading piece into the *last* token.)
    for i in range(len(tokens) - 1, 0, -1):
        if re.search(r"^##.+", tokens[i]):
            piece = tokens[i][2:]
            tokens[i - 1] += piece
            spans[i - 1][1] += len(piece)
            del tokens[i]
            del spans[i]

    return tokens, spans

In [None]:
class WordInput:
    """Encode a tokenized sentence as a fixed-length list of word ids."""

    def __init__(self, word2id, uncased=False):
        """
        Args:
            word2id: mapping from word to id; it must contain the `unk` and
                `pad` tokens used by `encode` whenever they are needed.
            uncased: lower-case every token before the lookup.
        """
        self.word2id = word2id
        self.uncased = uncased

    def encode(self, tok_text, padding_length=512, unk='[UNK]', pad='[PAD]'):
        """
        Map `tok_text` (list of tokens) to ids, padded to `padding_length`.

        Unknown words map to the id of `unk`; missing positions are filled
        with the id of `pad`.

        Raises:
            Exception: if there are more tokens than `padding_length`.
        """
        words = [w.lower() for w in tok_text] if self.uncased else tok_text

        input_ids = []
        for word in words:
            key = word if word in self.word2id else unk
            input_ids.append(self.word2id[key])

        missing = padding_length - len(input_ids)
        if missing < 0:
            raise Exception('(Words) Text too long (%d / %d):' % (len(input_ids), padding_length), words)
        if missing > 0:
            input_ids.extend([self.word2id[pad]] * missing)

        return input_ids

In [None]:
class CharInput:
    """Encode a tokenized sentence as a fixed-size grid of character ids."""

    def __init__(self, char2id, uncased=False):
        """
        Args:
            char2id: mapping from character to id; it must contain the `unk`
                and `pad` tokens used by `encode` whenever they are needed.
            uncased: lower-case every token before the lookup.
        """
        self.uncased = uncased
        self.char2id = char2id

    def encode(self, tok_text, padding_length=512, char_padding=60, unk='[UNK]', pad='[PAD]'):
        """
        Map every token of `tok_text` to a list of `char_padding` char ids.

        Returns:
            A list of `padding_length` rows, each a list of `char_padding`
            ids. Unknown characters use the id of `unk`; short tokens and
            missing tokens are filled with the id of `pad`.

        Raises:
            Exception: if a token has more than `char_padding` characters or
                the sentence has more than `padding_length` tokens.
        """
        if self.uncased:
            tok_text = [w.lower() for w in tok_text]

        input_ids = []

        for token in tok_text:
            char_ids = [self.char2id[char] if char in self.char2id else self.char2id[unk] for char in token]

            # Pad char list if necessary
            add_padding = char_padding - len(char_ids)
            if add_padding < 0:
                # Report the offending token. The previous code referenced an
                # undefined name `text` here, which raised NameError (or
                # printed an unrelated notebook global) instead.
                raise Exception('(Chars-1) Text too long (%d / %d):' % (len(char_ids), char_padding), token)
            if add_padding > 0:
                char_ids = char_ids + ([self.char2id[pad]] * add_padding)

            input_ids.append(char_ids)

        # Pad token list if necessary
        add_padding = padding_length - len(input_ids)
        if add_padding < 0:
            raise Exception('(Chars-2) Text too long (%d / %d):' % (len(input_ids), padding_length), tok_text)
        if add_padding > 0:
            pad_id = self.char2id[pad]
            # One fresh row per padded token, so rows never alias each other.
            input_ids.extend([pad_id] * char_padding for _ in range(add_padding))

        return input_ids

In [None]:
# Create tokenizer from the vocabulary saved for the selected language model
tokenizer = get_tokenizer(lm_name=LM_NAME)

In [None]:
# Sample of word-level encoding
# `word_input` is also the encoder handed to the dataset further below.
word_input = WordInput(word2id)

text = 'This is Dimas\' car, it\'s blue.'
# NOTE: `text` is rebound from the raw string to its list of tokens.
text, _ = tokenize_text(tokenizer, text)
word_input.encode(text, padding_length=15)

In [None]:
# Sample of char-level encoding
# `char_input` is also the encoder handed to the dataset further below.
char_input = CharInput(char2id)

text = 'This is Dimas\' car, it\'s blue.'
# NOTE: `text` is rebound from the raw string to its list of tokens.
text, _ = tokenize_text(tokenizer, text)
char_input.encode(text, padding_length=12, char_padding=7)

#### BERT Tokenizer

BERT expects 3 different inputs:
- Token IDs: the tokens transformed into numbers.
- Attention Mask: sequence of `0` (if there are PAD tokens in that position) and `1` (otherwise).
- Segments or Type IDs: sequence of `0` and `1` to distinguish between the first and the second sentence in NSP tasks. In this notebook, we do not need this input, so it will be always `0`.

For example:
```
Text:       Is this jacksonville? Yes, it is.
---------------------------------------------------------------------------------
Tokens:     [CLS] Is    this  ja    ##cks ##on  ##ville ?   [SEP] Yes   ,    it   is   .    [SEP] [PAD] [PAD] ...
Token IDs:  101   12034 10531 10201 18676 10263 12043   136 102   2160  117  1122 1110 119  102   100   100   ...
Mask:       1     1     1     1     1     1     1       1   1     1     1    1    1    1    1     0     0     ...
Type IDs:   0     0     0     0     0     0     0       0   0     1     1    1    1    1    1     0     0     ...
```

Note: previous token IDs are just an example, so real token IDs might be different.

For further details, read BERT paper: https://arxiv.org/pdf/1810.04805.pdf

In [None]:
class BertInput:
    """Encode a list of word-level tokens into padded BERT inputs."""

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: object whose ``encode(token)`` result exposes ``ids``,
                ``type_ids`` and ``attention_mask``; the first and last
                entries of each are dropped as sequence delimiters.
        """
        self.tokenizer = tokenizer

    def encode(self, tok_text, padding_length=512):
        """
        Encode `tok_text` (a list of tokens) token by token.

        Returns:
            Tuple ``([input_ids, attention_mask, type_ids], token_spans)``.
            The three input lists have `padding_length` entries.
            ``token_spans`` starts with a leading 1 and then holds, for every
            token (including the added ``[CLS]``/``[SEP]``), the number of
            sub-tokens it produced followed by one 0 per extra sub-token; it
            is filled with 1s up to `padding_length`.

        Raises:
            Exception: if the sub-token sequence is longer than
                `padding_length`.
        """
        tok_text = ['[CLS]'] + tok_text + ['[SEP]']

        # Encode context (token IDs, mask and token types)
        token_spans = [1]
        input_ids = []
        type_ids = []
        attention_mask = []
        for token in tok_text:
            encoded_text = self.tokenizer.encode(token)
            piece_ids = encoded_text.ids[1:-1]

            # Create inputs
            span = len(piece_ids)
            token_spans += [span] + [0] * (span - 1)
            input_ids += piece_ids
            type_ids += encoded_text.type_ids[1:-1]
            attention_mask += encoded_text.attention_mask[1:-1]

        if len(input_ids) > padding_length:
            # The previous code formatted an undefined name `text` here, which
            # raised NameError (or printed an unrelated notebook global).
            raise Exception('(BERT) Text too long (%d / %d): %s' % (len(input_ids), padding_length, tok_text))

        # Pad if necessary. Padded positions get attention mask 0.
        # NOTE(review): 100 was documented as the "[PAD]" id, but in the
        # standard BERT vocabularies 100 is "[UNK]" and "[PAD]" is 0 - confirm
        # against vocab.txt. The value is kept so existing results do not change.
        add_padding = padding_length - len(input_ids)
        if add_padding > 0:
            input_ids = input_ids + ([100] * add_padding)
            attention_mask = attention_mask + ([0] * add_padding)
            type_ids = type_ids + ([0] * add_padding)

        token_spans += ([1] * (padding_length - len(token_spans)))

        # BERT inputs must be as follows: input_ids, attention_mask, token_type_ids
        return [input_ids, attention_mask, type_ids], token_spans

In [None]:
# Sample of LM-level encoding; `bert_input` is also the encoder handed to the
# dataset further below.
bert_input = BertInput(tokenizer)

text = 'Hellowing'
# NOTE: `text` is rebound from the raw string to its list of tokens.
text, _ = tokenize_text(tokenizer, text)
bert_input.encode(text, padding_length=10)

### Dataloader

Pytorch has its own dataloader.

In [None]:
class NestedNamedEntitiesDataset(Dataset):
    """
    PyTorch dataset with the word, char and LM inputs of every sentence and,
    optionally, one target label sequence per pyramid layer.

    Args:
        data: iterable of items with a ``'tokens'`` list and, when
            `has_outputs` is True, a ``'layer_outputs'`` list of label lists.
        word_input, char_input: encoders called as
            ``encode(tokens, padding_length)``.
        lm_input: encoder called as ``encode(tokens, padding_length)`` that
            returns ``([ids, attention, type_ids], spans)``.
        padding_length: length every input sequence is padded to; the
            targets of layer ``l`` are padded to ``padding_length - l``.
        total_layers: number of target layers kept per item.
        skip_exceptions: when True, an item that raises while it is encoded
            is counted in ``n_skipped`` and left out; otherwise the exception
            propagates.
        max_items: when positive, only the first `max_items` items are read.
        has_outputs: whether the items carry ``'layer_outputs'``.
    """

    def __init__(self, data, word_input, char_input, lm_input, padding_length=512,
                 total_layers=16, skip_exceptions=True, max_items=-1, has_outputs=True):
        self.total_layers = total_layers
        self.has_outputs = has_outputs

        self.masks = []
        self.X_word = []
        self.X_char = []
        self.X_lm_inputs = []
        self.X_lm_attention = []
        self.X_lm_type_ids = []
        self.X_lm_spans = []
        self.Y_entities = [[] for _ in range(total_layers)]
        self.n_skipped = 0

        for i, item in enumerate(data):
            if max_items > 0 and i >= max_items:
                break

            # Encode the whole item before storing anything: a failure halfway
            # through must not leave inputs and targets with different counts.
            try:
                sample = self._encode_item(item, word_input, char_input, lm_input, padding_length)
            except Exception:
                # Typically the "text too long" exception of an encoder
                if skip_exceptions:
                    self.n_skipped += 1
                    continue
                raise

            self._store_sample(sample)

        self._init_tensors(padding_length, total_layers)

    def _encode_item(self, item, word_input, char_input, lm_input, padding_length):
        """Encode one item and return all of its pieces without storing them."""
        tokens = item['tokens']
        x_word = word_input.encode(tokens, padding_length)
        x_char = char_input.encode(tokens, padding_length)
        x_lm, x_lm_span = lm_input.encode(tokens, padding_length)

        # 1 for real tokens, 0 for padded positions.
        mask = [1.] * len(tokens) + [0.] * (padding_length - len(tokens))

        targets = []
        if self.has_outputs:
            for i_layer, raw_seq in enumerate(item['layer_outputs']):
                if i_layer >= self.total_layers:
                    break
                # Layer l is one position shorter per level, hence "- i_layer".
                missing = padding_length - i_layer - len(raw_seq)
                targets.append(raw_seq + [0] * missing if missing > 0 else raw_seq)

        return mask, x_word, x_char, x_lm, x_lm_span, targets

    def _store_sample(self, sample):
        """Append the pieces of one encoded item to the per-field lists."""
        mask, x_word, x_char, x_lm, x_lm_span, targets = sample
        self.masks.append(mask)
        self.X_word.append(x_word)
        self.X_char.append(x_char)
        self.X_lm_inputs.append(x_lm[0])
        self.X_lm_attention.append(x_lm[1])
        self.X_lm_type_ids.append(x_lm[2])
        self.X_lm_spans.append(x_lm_span)
        for i_layer, padded_seq in enumerate(targets):
            self.Y_entities[i_layer].append(padded_seq)

    def _init_tensors(self, padding_length, total_layers):
        """Convert the accumulated Python lists into tensors."""
        self.masks = torch.tensor(self.masks, dtype=torch.float)
        self.X_word = torch.tensor(self.X_word, dtype=torch.long)
        self.X_char = torch.tensor(self.X_char, dtype=torch.long)
        self.X_lm_inputs = torch.tensor(self.X_lm_inputs, dtype=torch.long)
        self.X_lm_attention = torch.tensor(self.X_lm_attention, dtype=torch.long)
        self.X_lm_type_ids = torch.tensor(self.X_lm_type_ids, dtype=torch.long)
        self.X_lm_spans = torch.tensor(self.X_lm_spans, dtype=torch.long)
        if self.has_outputs:
            for i, seqs in enumerate(self.Y_entities):
                self.Y_entities[i] = torch.tensor(seqs, dtype=torch.long)

    def __len__(self):
        return len(self.X_word)

    def __getitem__(self, idx):
        """Return the inputs (and ``y_target_<layer>`` targets) of item `idx`."""
        data = {
            'masks': self.masks[idx],
            'x_word': self.X_word[idx],
            'x_char': self.X_char[idx],
            'x_lm_input': self.X_lm_inputs[idx],
            'x_lm_attention': self.X_lm_attention[idx],
            'x_lm_type_ids': self.X_lm_type_ids[idx],
            'x_lm_spans': self.X_lm_spans[idx],
        }
        if self.has_outputs:
            for i_layer in range(self.total_layers):
                data['y_target_%d' % i_layer] = self.Y_entities[i_layer][idx]
        return data

    def get_item(self, idx):
        """Alias of ``self[idx]``."""
        return self.__getitem__(idx)

    def get_n_skipped(self):
        """Number of items left out because encoding them raised."""
        return self.n_skipped

In [None]:
# Encode the whole training split (max_items=-1 disables the item limit).
# With skip_exceptions=False an item that cannot be encoded raises instead of
# being skipped.
nne_train_dataset = NestedNamedEntitiesDataset(train_dataset, word_input, char_input, bert_input,
                                               total_layers=TOTAL_LAYERS, skip_exceptions=False, max_items=-1)

print('Skipped: %d' % nne_train_dataset.get_n_skipped())

In [None]:
# Sample item: dict with the encoded inputs and targets of the first sentence
sample_item = nne_train_dataset.get_item(0)

## Define model

As described in the paper, these are the hyperparameters to be used:

![../images/hyperparameters.jpg](../images/hyperparameters.jpg)

Note that the authors increased by $0.05$ the dropout rates: "*\[...\] with pre-trained contextualized embeddings, \[...\] we increase the dropout rate by 0.05 for these settings.*"

### Initializers

In [None]:
def init_embeddings(input_embedding):
    """
    Fill an embedding weight tensor in place with values drawn uniformly from
    ``[-sqrt(3 / dim), sqrt(3 / dim)]``, where ``dim`` is its second dimension.
    """
    limit = np.sqrt(3.0 / input_embedding.size(1))
    nn.init.uniform_(input_embedding, a=-limit, b=limit)

In [None]:
def init_linear(input_linear):
    """
    Initialize a linear layer in place: weights are drawn uniformly from
    ``[-r, r]`` with ``r = sqrt(6 / (rows + cols))`` of the weight matrix, and
    the bias (when the layer has one) is set to zero.
    """
    weight = input_linear.weight
    limit = np.sqrt(6.0 / (weight.size(0) + weight.size(1)))
    nn.init.uniform_(weight, -limit, limit)

    if input_linear.bias is None:
        return
    input_linear.bias.data.zero_()

In [None]:
def init_lstm(input_lstm):
    """
    Initialize the weights and biases of an `nn.LSTM` in place.

    Every input-hidden and hidden-hidden weight matrix (of all layers and,
    for a bidirectional LSTM, of both directions) is drawn uniformly from
    ``[-r, r]`` with ``r = sqrt(6 / (rows / 4 + cols))``; the division by 4
    accounts for the four gates stacked along the first dimension of each
    matrix. When the LSTM has biases they are set to zero, except for the
    slice ``[hidden_size : 2 * hidden_size]`` (the forget gate), which is
    set to 1.

    PyTorch parameter names used here:

        weight_ih_l[k], weight_hh_l[k]: input-hidden / hidden-hidden weights
            of the k-th layer.
        bias_ih_l[k], bias_hh_l[k]: the corresponding biases.
        A ``_reverse`` suffix denotes the backward direction.
    """
    hidden_size = input_lstm.hidden_size

    # Forward direction first, then the backward one. This is the same order
    # as before, so a seeded run draws identical values.
    suffixes = ['']
    if input_lstm.bidirectional:
        suffixes.append('_reverse')

    def _parameters(prefixes):
        """Yield the parameters `<prefix><layer><suffix>` in initialization order."""
        for suffix in suffixes:
            for ind in range(input_lstm.num_layers):
                for prefix in prefixes:
                    # getattr replaces the former eval() on a built expression.
                    yield getattr(input_lstm, '%s%d%s' % (prefix, ind, suffix))

    # Weights
    for weight in _parameters(('weight_ih_l', 'weight_hh_l')):
        sampling_range = np.sqrt(6.0 / (weight.size(0) / 4 + weight.size(1)))
        nn.init.uniform_(weight, -sampling_range, sampling_range)

    # Biases: zero everywhere except the forget gate, which is initialized to 1
    if input_lstm.bias:
        for bias in _parameters(('bias_ih_l', 'bias_hh_l')):
            bias.data.zero_()
            bias.data[hidden_size: 2 * hidden_size] = 1

### Encoder

Formally, given an input sentence $x$:

$$
\tilde{x}_{char} = LSTM_{char}(Embed_{char}(x)) \\
\tilde{x}_{word} = Embed_{word}(x) \\
\tilde{x} = LSTM_{enc}([\tilde{x}_{char}; \tilde{x}_{word}])
$$

Then, we concatenate the previous result with the output of the Language Model:
$$\tilde{x} = Linear_{enc}([\tilde{x}; LM(x)])$$

In [None]:
def create_emb_layer(trainable, embedding_matrix=None, shape=None, device=None):
    """
    Create an `nn.Embedding` layer and move it to `device`.

    Args:
        trainable: whether the embedding weights receive gradients.
        embedding_matrix: pretrained weights; when given, `shape` is ignored.
        shape: ``(num_embeddings, embedding_dim)`` of a freshly initialized
            layer; required when `embedding_matrix` is None.
        device: target device of the layer.
    """
    if embedding_matrix is not None:
        pretrained = torch.FloatTensor(embedding_matrix)
        layer = nn.Embedding.from_pretrained(pretrained, freeze=(not trainable))
    else:
        vocab_size, emb_dim = shape[0], shape[1]
        layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        layer.weight.requires_grad = trainable
    return layer.to(device=device)
    
def create_lm_layer(lm_name, trainable, device=None, artifacts_path='../artifacts'):
    """
    Load the BERT model saved under ``<artifacts_path>/<lm_name>/`` and move
    it to `device`. When `trainable` is False all its parameters are frozen.
    """
    bert_model = BertModel.from_pretrained('%s/%s/' % (artifacts_path, lm_name))
    freeze = not trainable
    if freeze:
        for weight in bert_model.parameters():
            weight.requires_grad = False
    return bert_model.to(device=device)

In [None]:
class CharEncoder(nn.Module):
    """
    Character-level encoder: a trainable char embedding followed by a
    bidirectional LSTM that is run over the characters of every token.
    """

    def __init__(self, char2id, dimension=60, hidden_size=100, device=None):
        """
        Args:
            char2id: character vocabulary; only its size is used here.
            dimension: size of the character embeddings.
            hidden_size: hidden size of each LSTM direction.
            device: device the sub-modules and the output are placed on.
        """
        super().__init__()
        self.device = device

        vocab_shape = (len(char2id), dimension)
        self.embedding = create_emb_layer(True, shape=vocab_shape, device=device)
        init_embeddings(self.embedding.weight)

        char_lstm = nn.LSTM(input_size=dimension, hidden_size=hidden_size,
                            bidirectional=True, batch_first=True)
        self.lstm = char_lstm.to(device=device)
        init_lstm(self.lstm)

    def forward(self, inputs):
        """
        Encode a batch of char-id grids.

        `inputs` is iterated along its first dimension; each element is fed
        to the LSTM with its tokens acting as the LSTM batch
        (presumably `inputs` is (batch, tokens, chars) - confirm with caller).
        For every token the LSTM output at the last character position is
        kept, and the per-sentence results are stacked.
        """
        embedded = self.embedding(inputs)
        last_steps = [self.lstm(sentence)[0][:, -1] for sentence in embedded]
        return torch.stack(last_steps).to(device=self.device)

In [None]:
# Sample: run the char encoder on the first training item
char_encoder = CharEncoder(char2id, dimension=60, hidden_size=100, device=device)

sample_item = nne_train_dataset.get_item(0)
# unsqueeze adds the batch dimension expected by the encoder
sample_x_char = torch.unsqueeze(sample_item['x_char'], 0).to(device=device)
ce_output = char_encoder(sample_x_char)
print(ce_output.shape)

In [None]:
class LMEncoder(nn.Module):
    """
    Frozen language-model encoder that averages the last four hidden layers
    and then pools the states of the sub-tokens belonging to one token.
    """

    def __init__(self, lm_name, device=None):
        """
        Args:
            lm_name: name of the saved language model to load.
            device: device the model and the output are placed on.
        """
        super().__init__()
        self.device = device
        self.lm_layer = create_lm_layer(lm_name, False, device=device)

    def _create_lm_layer(self, lm_name, trainable):
        """
        Load `lm_name` with `BertModel.from_pretrained` and optionally freeze
        it. Not called by this class; kept for backward compatibility.
        """
        bert_model = BertModel.from_pretrained(lm_name)
        if not trainable:
            for param in bert_model.parameters():
                param.requires_grad = False

        return bert_model.to(device=self.device)

    def forward(self, inputs, attention, type_ids, lm_spans, masks):
        """
        Args:
            inputs, attention, type_ids: the three BERT inputs.
            lm_spans: per sequence, the sub-token count recorded for every
                position (0 for continuation sub-tokens).
            masks: per sequence, 1 for real tokens and 0 for padding; only
                its sum (the number of tokens) is used.

        Returns:
            Tensor with the same size as the LM output. Position ``i - 1``
            of a sequence holds the mean of the LM states ``i .. i+span-1``
            for every visited span position ``i``; all other positions are 0.
        """
        x_lm = self.lm_layer(input_ids=inputs, attention_mask=attention, token_type_ids=type_ids,
                             output_hidden_states=True)
        # Element 2 holds the hidden states of all layers: average the last four.
        x_lm = torch.mean(torch.stack(x_lm[2][-4:]), dim=0)

        x = torch.zeros(size=x_lm.size(), device=self.device)

        # Plain Python ints avoid one tensor round-trip per position.
        span_rows = lm_spans.tolist() if torch.is_tensor(lm_spans) else lm_spans
        for seq_i, seq_span in enumerate(span_rows):
            mask_length = float(masks[seq_i].sum())

            for token_i, span in enumerate(seq_span):
                if token_i >= mask_length - 1:
                    # Skips from SEP token
                    break
                if token_i == 0:
                    # Skips CLS token
                    continue
                if span <= 0:
                    # Nothing to pool: leave zeros (and avoid a 0 / 0 division)
                    continue

                pooled = x_lm[seq_i, token_i]
                for k in range(1, span):
                    pooled = pooled + x_lm[seq_i, token_i + k]
                # The previous `x[...].div(span)` was not in-place, so the sum
                # was never turned into a mean.
                # NOTE(review): the output index follows the span position, not
                # a word counter - confirm the intended word alignment.
                x[seq_i, token_i - 1] = pooled / span

        return x

In [None]:
# Sample: run the LM encoder on the first training item
encoder_layer = LMEncoder(LM_NAME, device=device)

sample_data = nne_train_dataset.get_item(0)
# unsqueeze adds the batch dimension to every input
masks = torch.unsqueeze(sample_data['masks'], 0).to(device=device)
x_lm_inputs = torch.unsqueeze(sample_data['x_lm_input'], 0).to(device=device)
x_lm_attention = torch.unsqueeze(sample_data['x_lm_attention'], 0).to(device=device)
x_lm_type_ids = torch.unsqueeze(sample_data['x_lm_type_ids'], 0).to(device=device)
x_lm_span = torch.unsqueeze(sample_data['x_lm_spans'], 0).to(device=device)

out_encoder = encoder_layer(x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)
out_encoder.size()

In [None]:
class EncoderLayer(nn.Module):
    """
    Sentence encoder: char and word features go through a bidirectional LSTM,
    the result is concatenated with the language-model features and projected
    by a linear layer to ``hidden_size * 2`` features.
    """

    def __init__(self, lm_name, word_embeddings, char2id, char_dimension=60, word_dimension=200,
                 lm_dimension=1024, hidden_size=100, drop_rate=0.45, device=None):
        """
        Args:
            lm_name: name of the saved language model.
            word_embeddings: pretrained word embedding matrix (kept frozen).
            char2id: character vocabulary.
            char_dimension: size of the character embeddings.
            word_dimension: width of `word_embeddings`.
            lm_dimension: hidden size of the language model.
            hidden_size: hidden size of each LSTM direction.
            drop_rate: dropout applied to the char/word features.
            device: device the sub-modules are placed on.
        """
        super().__init__()

        self.device = device
        bi_hidden = hidden_size * 2

        self.char_encoder = CharEncoder(char2id, char_dimension, hidden_size, device=device)
        word_layer = self._create_emb_layer(False, embedding_matrix=word_embeddings, device=device)
        self.emb_word = word_layer.to(device=device)
        self.dropout = nn.Dropout(drop_rate).to(device=device)
        self.lm_encoder = LMEncoder(lm_name, device=device)

        # Not used by forward(); still created and initialized here.
        self.lstm_char = nn.LSTM(input_size=char_dimension, hidden_size=hidden_size,
                                 bidirectional=True, batch_first=True).to(device=device)
        init_lstm(self.lstm_char)

        self.lstm_enc = nn.LSTM(input_size=(bi_hidden + word_dimension),
                                hidden_size=hidden_size, bidirectional=True, batch_first=True).to(device=device)
        init_lstm(self.lstm_enc)

        self.linear = nn.Linear(lm_dimension + bi_hidden, bi_hidden).to(device=device)
        init_linear(self.linear)

    def _create_emb_layer(self, trainable, embedding_matrix=None, shape=None, device=None):
        """
        Build an embedding layer, either from `embedding_matrix` (moved to
        ``self.device``) or freshly initialized with the given `shape`.
        """
        if embedding_matrix is not None:
            pretrained = torch.FloatTensor(embedding_matrix).to(device=self.device)
            return nn.Embedding.from_pretrained(pretrained, freeze=(not trainable))

        emb_layer = nn.Embedding(num_embeddings=shape[0], embedding_dim=shape[1])
        emb_layer.weight.requires_grad = trainable
        return emb_layer

    def forward(self, input_word, input_char, input_lm, input_lm_attention, input_lm_type_ids,
                input_lm_spans, input_masks):
        """Encode one batch and return the projected token representations."""
        char_features = self.char_encoder(input_char)
        word_features = self.emb_word(input_word)

        token_features = self.dropout(torch.cat((char_features, word_features), dim=-1))
        contextual, _ = self.lstm_enc(token_features)

        lm_features = self.lm_encoder(input_lm, input_lm_attention, input_lm_type_ids,
                                      input_lm_spans, input_masks)

        return self.linear(torch.cat((contextual, lm_features), dim=2))

In [None]:
# Sample: run the full encoder on the first training item
encoder_layer = EncoderLayer(LM_NAME, embedding_matrix, char2id, char_dimension=60,
                             word_dimension=WORD_DIM, lm_dimension=LM_DIM,
                             hidden_size=100, drop_rate=0.45, device=device)

sample_data = nne_train_dataset.get_item(0)
# unsqueeze adds the batch dimension to every input
masks = torch.unsqueeze(sample_data['masks'], 0).to(device=device)
x_word = torch.unsqueeze(sample_data['x_word'], 0).to(device=device)
x_char = torch.unsqueeze(sample_data['x_char'], 0).to(device=device)
x_lm_inputs = torch.unsqueeze(sample_data['x_lm_input'], 0).to(device=device)
x_lm_attention = torch.unsqueeze(sample_data['x_lm_attention'], 0).to(device=device)
x_lm_type_ids = torch.unsqueeze(sample_data['x_lm_type_ids'], 0).to(device=device)
x_lm_span = torch.unsqueeze(sample_data['x_lm_spans'], 0).to(device=device)

# `out_encoder` is reused as the input of the decoding-layer samples below
out_encoder = encoder_layer(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)
out_encoder.size()

### Pyramid

The first pyramid is composed of several decoding layers:

![../images/decoding_layer.jpg](../images/decoding_layer.jpg)

The paper does not contain information about the hyperparameters of the LSTM layer, so I assume the same hidden space as previous.

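Regarding the convolutional layer, the authors say "*\[...\] CNN with a kernel of two \[...\] and the CNN aggregates two adjacent hidden states \[...\]*". Accordingly, the implementation below uses `kernel_size=2` without padding, so each decoding layer outputs a sequence that is one position shorter than its input.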

In [None]:
class DecodingLayer(nn.Module):
    """
    One pyramid layer: layer normalization, dropout, a bidirectional LSTM,
    dropout, and a kernel-2 convolution that merges adjacent hidden states.
    """

    def __init__(self, drop_rate=0.45, seq_length=512, hidden_size=100, device=None):
        """
        Args:
            drop_rate: rate of both dropout layers.
            seq_length: length of the sequences this layer receives (the
                layer normalization is built for exactly this length).
            hidden_size: hidden size of each LSTM direction; inputs and
                outputs have ``hidden_size * 2`` features.
            device: device the sub-modules are placed on.
        """
        super().__init__()
        self.device = device
        features = hidden_size * 2

        self.norm = nn.LayerNorm(normalized_shape=(seq_length, features)).to(device=device)
        self.dropout_1 = nn.Dropout(drop_rate).to(device=device)
        self.dropout_2 = nn.Dropout(drop_rate).to(device=device)

        self.lstm = nn.LSTM(input_size=features, hidden_size=hidden_size,
                            bidirectional=True, batch_first=True).to(device=device)
        self.conv = nn.Conv1d(in_channels=features, out_channels=features,
                              kernel_size=2, stride=1).to(device=device)

        init_lstm(self.lstm)

    def forward(self, input):
        """
        Returns:
            Tuple ``(h, x)``: the LSTM hidden states after dropout, and the
            same states after the convolution (one position shorter).
        """
        normalized = self.dropout_1(self.norm(input))

        hidden, _ = self.lstm(normalized)
        hidden = self.dropout_2(hidden)

        # Conv1d convolves over the last dimension, so swap the last two
        # dimensions before and after it.
        merged = self.conv(hidden.transpose(2, 1)).transpose(2, 1)

        return hidden, merged

In [None]:
# Sample: apply one decoding layer to the encoder output
decoding_layer = DecodingLayer(drop_rate=0.45, seq_length=512, hidden_size=100, device=device)

out_decoding = decoding_layer(out_encoder)
# Sizes of the hidden states and of the convolved output
print(out_decoding[0].size(), out_decoding[1].size())

In [None]:
class PyramidLayer(nn.Module):
    """Stack of ``total_layers`` decoding layers forming the (bottom-up) pyramid.

    Level ``i`` is built for a sequence of length ``seq_length - i`` because
    each decoding layer hands a sequence one position shorter to the next one.

    Args:
        total_layers: number of decoding layers in the stack.
        drop_rate: dropout probability forwarded to every decoding layer.
        seq_length: sequence length seen by the first (bottom) layer.
        hidden_size: hidden size forwarded to every decoding layer.
        device: device forwarded to every decoding layer.
    """

    def __init__(self, total_layers=16, drop_rate=0.45, seq_length=512, hidden_size=100, device=None):
        super().__init__()
        self.device = device
        self.total_layers = total_layers
        self.seq_length = seq_length

        levels = []
        for depth in range(total_layers):
            levels.append(DecodingLayer(drop_rate=drop_rate, seq_length=seq_length - depth,
                                        hidden_size=hidden_size, device=device))
        self.decoding_layers = nn.ModuleList(levels)

    def forward(self, input):
        """Return the list of hidden states produced by each level, bottom first."""
        hidden_states = []
        current = input
        for level in self.decoding_layers:
            # The convolved output of one level is the input of the next one
            level_h, current = level(current)
            hidden_states.append(level_h)

        return hidden_states

In [None]:
# Sample: a shallow pyramid with only 3 levels
pyramid_layer = PyramidLayer(total_layers=3, drop_rate=0.45, seq_length=512, hidden_size=100, device=device)

In [None]:
# Feed the encoder output through the sample pyramid
out_pyramid = pyramid_layer(out_encoder)

# One tensor of hidden states per level
for o in out_pyramid:
    print(o.size())

### Inverse pyramid

The inverse pyramid is composed of several inverse decoding layers:

![../images/inverse_decoding_layer.jpg](../images/inverse_decoding_layer.jpg)

In [None]:
class InverseDecodingLayer(nn.Module):
    """One level of the inverse pyramid.

    The incoming top-down sequence goes through LayerNorm -> dropout ->
    BiLSTM -> dropout and is then concatenated, along the feature axis, with
    the hidden states of the matching pyramid level. The concatenation is
    returned as ``h``; a convolution with kernel 2 and padding 1 turns it
    into a sequence that is one position longer (``x``).

    Args:
        drop_rate: dropout probability applied before and after the LSTM.
        seq_length: length of the sequences handled at this level; LayerNorm
            normalizes over ``(seq_length, hidden_size * 2)``.
        hidden_size: hidden size of each LSTM direction.
        device: device on which the sub-modules are placed.
    """

    def __init__(self, drop_rate=0.45, seq_length=512, hidden_size=100, device=None):
        super().__init__()
        self.device = device
        feature_dim = hidden_size * 2

        self.norm = nn.LayerNorm(normalized_shape=(seq_length, feature_dim)).to(device=device)
        self.dropout_1 = nn.Dropout(drop_rate).to(device=device)
        self.dropout_2 = nn.Dropout(drop_rate).to(device=device)

        # The LSTM is created before the convolution so that the order in
        # which parameters are initialised stays the same.
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_size,
            bidirectional=True,
            batch_first=True,
        ).to(device=device)
        # The convolution reads the concatenation [input_h ; lstm output],
        # hence twice the feature dimension as input channels.
        self.conv = nn.Conv1d(
            in_channels=feature_dim * 2,
            out_channels=feature_dim,
            kernel_size=2,
            padding=1,
            stride=1,
        ).to(device=device)

        init_lstm(self.lstm)

    def forward(self, input_h, input_x):
        """Return ``(h, x)``.

        ``h`` is the concatenation of ``input_h`` with the processed
        ``input_x``; ``x`` is the convolved ``h``, one position longer.
        """
        processed = self.dropout_1(self.norm(input_x))
        processed, _ = self.lstm(processed)
        processed = self.dropout_2(processed)

        h = torch.cat((input_h, processed), dim=2)

        # Conv1d convolves over the last axis: swap axes around the call
        x = self.conv(h.transpose(1, 2)).transpose(1, 2)

        return h, x

In [None]:
# Sample: run one inverse decoding layer on the top level of the pyramid
index = len(out_pyramid) - 1

seq_length = 512
hidden_size = 100

idecoding_layer = InverseDecodingLayer(drop_rate=0.45, seq_length=seq_length - index,
                                       hidden_size=hidden_size, device=device)
h = out_pyramid[index]
# The top level has no layer above it, so its top-down input is all zeros
x = torch.zeros(h.size(0), seq_length - index, hidden_size * 2).to(device=device)

out_idecoding = idecoding_layer(h, x)
print(*[t.size() for t in out_idecoding])

In [None]:
def reverse_enumerate(L):
    """Yield ``(index, item)`` pairs of the sequence ``L`` from last to first.

    ``L`` must support ``len()`` and integer indexing.
    """
    for position in range(len(L) - 1, -1, -1):
        yield position, L[position]

In [None]:
class InversePyramidLayer(nn.Module):
    """Stack of inverse decoding layers, traversed from the top level down.

    Level ``i`` is built for sequences of length ``seq_length - i``, mirroring
    the pyramid whose hidden states it consumes.

    Args:
        total_layers: number of inverse decoding layers.
        drop_rate: dropout probability forwarded to every layer.
        seq_length: sequence length of the bottom level.
        hidden_size: hidden size forwarded to every layer.
        device: device forwarded to every layer.
    """

    def __init__(self, total_layers=16, drop_rate=0.45, seq_length=512, hidden_size=100, device=None):
        super(InversePyramidLayer, self).__init__()
        self.device = device
        self.total_layers = total_layers
        self.seq_length = seq_length
        self.hidden_size = hidden_size
        self.idecoding_layers = nn.ModuleList([InverseDecodingLayer(drop_rate=drop_rate, seq_length=seq_length-i,
                                                                    hidden_size=hidden_size, device=device) for i in range(total_layers)])

    def forward(self, input_hs):
        """Return one tensor of concatenated states per level, bottom first.

        Args:
            input_hs: list with the hidden states of each pyramid level
                (bottom first), one entry per inverse decoding layer.
        """
        top = input_hs[-1]

        # The top level has nothing above it: seed the top-down pass with
        # zeros created on the same device and with the same dtype as the
        # pyramid states, so the tensors can never end up on different devices.
        x_layer = top.new_zeros(top.size(0),
                                self.seq_length - self.total_layers + 1,
                                self.hidden_size * 2)

        h = []
        for i in reversed(range(len(self.idecoding_layers))):
            # Each layer returns a sequence one position longer, which
            # matches the length expected by the level right below it.
            h_layer, x_layer = self.idecoding_layers[i](input_hs[i], x_layer)
            h.append(h_layer)

        # States were collected top-down; return them bottom first
        h.reverse()
        return h

In [None]:
# Sample: inverse pyramid with the same depth (3) as the sample pyramid above
ipyramid_layer = InversePyramidLayer(total_layers=3, drop_rate=0.45, seq_length=512, hidden_size=100, device=device)
out_ipyramid = ipyramid_layer(out_pyramid)

# One tensor per level
for o in out_ipyramid:
    print(o.size())

### Put all together

Finally, we put all the pieces together and add a linear layer to get the class predictions.

In [None]:
class PyramidNet(nn.Module):
    """Full model: encoder -> pyramid -> inverse pyramid -> linear classifier.

    The same linear layer is applied to the output of every level of the
    inverse pyramid. Its input size is ``hidden_size * 4`` because each
    inverse decoding layer concatenates two tensors of ``hidden_size * 2``
    features.

    Args:
        embedding_matrix, char_vocab, lm_name, char_dimension, word_dimension,
            lm_dimension: forwarded unchanged to ``EncoderLayer``.
        total_layers: number of levels of both pyramids.
        drop_rate: dropout probability used by all sub-modules.
        seq_length: sequence length of the bottom pyramid level.
        hidden_size: hidden size used by all sub-modules.
        total_classes: number of output classes of the linear layer.
        device: device on which the sub-modules are placed.
    """

    def __init__(self, embedding_matrix, char_vocab, lm_name='dmis-lab/biobert-large-cased-v1.1',
                 total_layers=16, drop_rate=0.45, seq_length=512, char_dimension=60, word_dimension=200,
                 lm_dimension=1024, hidden_size=100, total_classes=10, device=None):
        super().__init__()
        self.device = device

        self.encoder_layer = EncoderLayer(
            lm_name,
            embedding_matrix,
            char_vocab,
            char_dimension=char_dimension,
            word_dimension=word_dimension,
            lm_dimension=lm_dimension,
            hidden_size=hidden_size,
            drop_rate=drop_rate,
            device=device,
        )

        # Both pyramids share exactly the same configuration
        pyramid_kwargs = dict(total_layers=total_layers, drop_rate=drop_rate,
                              seq_length=seq_length, hidden_size=hidden_size, device=device)
        self.pyramid = PyramidLayer(**pyramid_kwargs)
        self.inverse_pyramid = InversePyramidLayer(**pyramid_kwargs)

        self.linear = nn.Linear(hidden_size * 4, total_classes).to(device=self.device)

    def forward(self, input_word, input_char, input_lm, input_lm_attention,
                input_lm_type_ids, input_lm_spans, input_masks):
        """Return a list with the class logits of every level, bottom first."""
        encoded = self.encoder_layer(input_word, input_char, input_lm, input_lm_attention,
                                     input_lm_type_ids, input_lm_spans, input_masks)
        pyramid_states = self.pyramid(encoded)
        inverse_states = self.inverse_pyramid(pyramid_states)
        return [self.linear(level_state) for level_state in inverse_states]

In [None]:
# Sample: build a small 5-level network and run it on the first training item
net = PyramidNet(embedding_matrix, char2id, lm_name=LM_NAME, total_layers=5,
                 drop_rate=0.4, seq_length=512, char_dimension=60, word_dimension=WORD_DIM,
                 lm_dimension=LM_DIM, total_classes=10, device=device)

sample_data = nne_train_dataset.get_item(0)

# Add a leading batch dimension to every input and move it to the device
masks = sample_data['masks'].unsqueeze(0).to(device=device)
x_word = sample_data['x_word'].unsqueeze(0).to(device=device)
x_char = sample_data['x_char'].unsqueeze(0).to(device=device)
x_lm_inputs = sample_data['x_lm_input'].unsqueeze(0).to(device=device)
x_lm_attention = sample_data['x_lm_attention'].unsqueeze(0).to(device=device)
x_lm_type_ids = sample_data['x_lm_type_ids'].unsqueeze(0).to(device=device)
x_lm_span = sample_data['x_lm_spans'].unsqueeze(0).to(device=device)

out_net = net(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)

# One logits tensor per level
for o in out_net:
    print(o.size())

## Training

For training, we should use inverse time learning rate decay:

$$\widehat{lr} = \frac{lr}{1 + \text{decay_rate} \cdot \text{steps}\ /\ \text{decay_steps}}$$

As stated by the authors, we will use $0.05$ and $1000$ for the decay rate and decay steps respectively. The remaining hyperparameters are described in Table 2 of the paper.

In [None]:
# Training settings
INITIAL_LR = 0.01    # learning rate before any decay is applied
DECAY_RATE = 0.05    # inverse time decay: rate
DECAY_STEPS = 1000   # inverse time decay: steps
MOMENTUM = 0.9
BATCH_SIZE = 8 # NOTE(review): the original trailing "64" is presumably the paper's value - confirm
GRADIENT_CLIP = 5.   # maximum gradient norm allowed during training
EPOCHS = 10 # Not defined in the paper

# For debug purposes, force stop after X steps (-1 if disabled)
STOP_AFTER = -1

# Hyperparameters
DROP_RATE = 0.45
CHAR_DIM = 60
WORD_DIM = 100 # NOTE(review): the original trailing "200" is presumably the paper's value - confirm
HIDDEN_SIZE = 100
TOTAL_CLASSES = 10

# Folder where the trained weights are saved and loaded from
ARTIFACTS_PATH = '../artifacts/'

In [None]:
def adjust_lr(optimizer, step, decay_rate=DECAY_RATE, decay_steps=DECAY_STEPS, inital_lr=INITIAL_LR):
    """
    Adjusts the learning rate with the inverse time decay formula described in the paper:

        lr = inital_lr / (1 + decay_rate * step / decay_steps)

    The resulting value is written to every parameter group of `optimizer`.
    """
    decayed_lr = inital_lr / (1 + decay_rate * step / decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr

In [None]:
# Model used for training, built from the settings defined above.
# seq_length is not passed, so the PyramidNet default is used.
net = PyramidNet(embedding_matrix, char2id, lm_name=LM_NAME, total_layers=TOTAL_LAYERS,
                 drop_rate=DROP_RATE, char_dimension=CHAR_DIM, word_dimension=WORD_DIM,
                 lm_dimension=LM_DIM, hidden_size=HIDDEN_SIZE,
                 total_classes=TOTAL_CLASSES, device=device)

net = net.to(device=device)

In [None]:
# Reminder: this criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class
# reduction='none' keeps one loss value per position, so that padded positions
# can be masked out before the values are summed.
criterion = nn.CrossEntropyLoss(reduction='none')
optimizer = torch.optim.SGD(net.parameters(), lr=INITIAL_LR, momentum=MOMENTUM)

In [None]:
# NOTE(review): shuffle=False feeds the batches in the same order on every
# epoch; consider shuffle=True for SGD training.
train_dataloader = DataLoader(nne_train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Training loop. `history` keeps the mean batch loss of every epoch and `step`
# counts the optimizer steps across epochs (it drives the learning rate decay).
history = []
step = 0
n_batches = len(train_dataloader)

print('Starting...', end='\r')

for i_epoch in range(EPOCHS):
    run_loss = 0

    for i_batch, batch_data in enumerate(train_dataloader):
        step += 1

        # Get inputs
        masks = batch_data['masks'].to(device=device)
        x_word = batch_data['x_word'].to(device=device)
        x_char = batch_data['x_char'].to(device=device)
        x_lm_inputs = batch_data['x_lm_input'].to(device=device)
        x_lm_attention = batch_data['x_lm_attention'].to(device=device)
        x_lm_type_ids = batch_data['x_lm_type_ids'].to(device=device)
        x_lm_span = batch_data['x_lm_spans'].to(device=device)

        # One target tensor per pyramid level
        y_all_targets = [batch_data['y_target_%d' % i_layer].to(device=device)
                         for i_layer in range(TOTAL_LAYERS)]

        # Predict entities
        y_all_preds = net(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)

        # Compute loss
        loss = 0
        for i_pred, y_pred_logits in enumerate(y_all_preds):
            # CrossEntropyLoss expects the class dimension right after the batch one
            loss_tensor = criterion(y_pred_logits.permute(0, -1, 1), y_all_targets[i_pred])
            # Level i is i positions shorter than the input, so the mask is
            # shifted accordingly before padded positions are zeroed out.
            loss += (loss_tensor * masks[:,i_pred:]).sum()

        optimizer.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(net.parameters(), GRADIENT_CLIP) # Avoid gradient exploding issue
        optimizer.step()

        adjust_lr(optimizer, step)

        # .item() returns a plain Python float, detached from the graph
        run_loss += loss.item()
        print("Epoch %d of %d | Batch %d of %d | Loss = %.3f" % (i_epoch + 1, EPOCHS, i_batch + 1, n_batches, run_loss / (i_batch + 1)),
              ' ' * 10,
              end='\r')

        if STOP_AFTER != -1 and STOP_AFTER <= step:
            break

        # Clear some memory.
        # `device` is a torch.device: comparing it with the string 'cuda' is
        # always False, so the device type is compared instead.
        if device.type == 'cuda':
            del masks
            del x_word
            del x_char
            del x_lm_inputs
            del x_lm_attention
            del x_lm_type_ids
            del x_lm_span
            torch.cuda.empty_cache()

    history.append(run_loss / len(train_dataloader))
    print("Epoch %d of %d | Loss = %.3f" % (i_epoch + 1, EPOCHS, run_loss / len(train_dataloader)), ' ' * 20)

    if STOP_AFTER != -1 and STOP_AFTER <= step:
        break

In [None]:
# Save model
# Only the weights (state_dict) are stored, not the whole module object
filepath = '%s%s' % (ARTIFACTS_PATH, 'genia_model.pt')
torch.save(net.state_dict(), filepath)

## Evaluation

### Evaluation per layer

Basic evaluation of the outputs of each layer. This implementation takes each output sequence as a whole (i.e. including all IOB tags) and computes token-level precision and recall by comparing it with the real one.

In [None]:
# Load model
# The architecture must be rebuilt with the same settings used for training,
# because only the weights were saved.
net = PyramidNet(embedding_matrix, char2id, lm_name=LM_NAME, total_layers=TOTAL_LAYERS,
                 drop_rate=DROP_RATE, char_dimension=CHAR_DIM, word_dimension=WORD_DIM,
                 lm_dimension=LM_DIM, hidden_size=HIDDEN_SIZE,
                 total_classes=TOTAL_CLASSES, device=device)
net = net.to(device=device)

filepath = '%s%s' % (ARTIFACTS_PATH, 'genia_model.pt')
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
net.load_state_dict(torch.load(filepath, map_location=device))
# Evaluation mode: disables dropout
net = net.eval()

In [None]:
# Build the test dataset with the same inputs and number of levels used for training
nne_test_dataset = NestedNamedEntitiesDataset(test_dataset, word_input, char_input, bert_input,
                                              total_layers=TOTAL_LAYERS, skip_exceptions=False, max_items=-1)

# Report the items skipped while building the *test* dataset
print('Skipped: %d' % nne_test_dataset.get_n_skipped())

In [None]:
# Batches are served in dataset order (no shuffling) for evaluation
test_dataloader = DataLoader(nne_test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Token-level precision and recall, collected per sentence and per level
labels = list(range(TOTAL_CLASSES))
precisions = [[] for _ in range(TOTAL_LAYERS)]
recalls = [[] for _ in range(TOTAL_LAYERS)]

for i_batch, batch_data in enumerate(test_dataloader):
    print('Evaluating batch %d out of %d' % (i_batch+1, len(test_dataloader)), end='\r')

    # Get inputs
    masks = batch_data['masks'].to(device=device)
    x_word = batch_data['x_word'].to(device=device)
    x_char = batch_data['x_char'].to(device=device)
    x_lm_inputs = batch_data['x_lm_input'].to(device=device)
    x_lm_attention = batch_data['x_lm_attention'].to(device=device)
    x_lm_type_ids = batch_data['x_lm_type_ids'].to(device=device)
    x_lm_span = batch_data['x_lm_spans'].to(device=device)

    # Inference only: no computation graph is needed, which saves memory
    with torch.no_grad():
        y_all_preds = net(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)

    for i_layer in range(TOTAL_LAYERS):
        y_targets = batch_data['y_target_%d' % i_layer]
        y_preds = y_all_preds[i_layer].cpu()

        for (y_target, y_pred, mask) in zip(y_targets, y_preds, masks):
            # Number of non-padded positions of this sentence at this level
            mask_cut = int(mask[i_layer:].sum().item())
            y_target = y_target.view(-1)[:mask_cut]
            y_pred = torch.argmax(y_pred, dim=-1).view(-1)[:mask_cut]

            p_score = precision_score(y_target, y_pred, labels=labels, average='micro')
            precisions[i_layer].append(p_score)

            r_score = recall_score(y_target, y_pred, labels=labels, average='micro')
            recalls[i_layer].append(r_score)

    # Clear some memory.
    # `device` is a torch.device: comparing it with the string 'cuda' is
    # always False, so the device type is compared instead.
    if device.type == 'cuda':
        del masks
        del x_word
        del x_char
        del x_lm_inputs
        del x_lm_attention
        del x_lm_type_ids
        del x_lm_span
        torch.cuda.empty_cache()

print()
print('Done')

In [None]:
# Average the per-sentence scores of every level and print them (levels are 1-based)
print('SCORES PER LAYER')
for i_layer in range(TOTAL_LAYERS):
    mean_precision = np.mean(precisions[i_layer])
    mean_recall = np.mean(recalls[i_layer])
    print('- Layer %d - Precision: %.4f - Recall: %.4f' % (i_layer + 1, mean_precision, mean_recall))

### Overall evaluation

The following evaluation is based on the original authors' implementation (see [Github](https://github.com/LorrinWWW/Pyramid/blob/7c63639df7a6fddc19730af98c37be22dec01221/utils/data.py#L103)). It computes the usual scores (precision, recall and F1-score) over exact matches: a predicted entity only counts as correct when both its type and its span boundaries coincide with a labelled entity.

In [None]:
def seq2span(seq, return_types=False, entity_idx=None):
    """Convert a sequence of IOB tags into a list of ``[start, end)`` spans.

    A ``B-...`` tag opens a span, every following ``I...`` tag extends the
    open span, and a ``B-...`` or ``O`` tag closes it. ``I`` tags found while
    no span is open are ignored.

    Args:
        seq: tags (strings), or keys of ``entity_idx`` when it is given.
        return_types: if true, also return the type of each span (the text
            of its ``B`` tag after the first two characters).
        entity_idx: optional lookup used to translate each item of ``seq``
            into its tag.

    Returns:
        ``spans``, or ``(spans, types)`` when ``return_types`` is true.
    """
    tags = seq if entity_idx is None else [entity_idx[x] for x in seq]

    spans, types = [], []
    open_span = None
    open_type = None

    for position, tag in enumerate(tags):
        prefix = tag[0]

        # Both a new entity and an outside tag terminate the open span
        terminates = prefix == 'B' or tag == 'O'
        if terminates and open_span is not None:
            spans.append(open_span)
            types.append(open_type)
            open_span = open_type = None

        if prefix == 'B':
            open_span = [position, position + 1]
            open_type = tag[2:]
        elif prefix == 'I' and open_span is not None:
            open_span[1] = position + 1

    # A span still open at the end of the sequence is also reported
    if open_span is not None:
        spans.append(open_span)
        types.append(open_type)

    return (spans, types) if return_types else spans

In [None]:
def get_seq_metrics(labels, preds, entity_idx, verbose=0):
    """Compute exact-match precision, recall and F1 over tag sequences.

    Every sequence is converted to a set of ``(type, start, end)`` entities
    with ``seq2span``; a prediction is correct only when the same triple is
    present in the labels.

    Args:
        labels: list of gold sequences.
        preds: list of predicted sequences, aligned with ``labels``; each one
            is truncated to the length of its gold sequence.
        entity_idx: lookup passed to ``seq2span`` to translate sequence items
            into tags.
        verbose: if greater than 0, print the progress.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` and ``confusion_dict``,
        which maps each entity type to ``[n_correct, n_preds, n_labels]``.
        The three scores are 0 when there is no correct prediction.
    """
    n_correct = n_recall = n_precision = 0
    confusion_dict = defaultdict(lambda: [0, 0, 0]) # n_correct, n_preds, n_labels
    n_sequences = len(labels)

    for i, i_label in enumerate(labels):
        if verbose > 0:
            print('Evaluating %d out of %d' % (i+1, n_sequences), end='\r')

        i_pred = preds[i][:len(i_label)]

        spans, types = seq2span(i_pred, True, entity_idx)
        pred_set = {(_type, _span[0], _span[1]) for _span, _type in zip(spans, types)}

        spans, types = seq2span(i_label, True, entity_idx)
        label_set = {(_type, _span[0], _span[1]) for _span, _type in zip(spans, types)}

        correct_set = pred_set & label_set

        for _type, _, _ in correct_set:
            confusion_dict[_type][0] += 1
        for _type, _, _ in pred_set:
            confusion_dict[_type][1] += 1
        for _type, _, _ in label_set:
            confusion_dict[_type][2] += 1

        n_correct += len(correct_set)
        n_recall += len(label_set)
        n_precision += len(pred_set)

    # Correct entities are a subset of both the predicted and the labelled
    # ones, so n_correct > 0 guarantees that no division below is by zero.
    # This replaces a bare `except:` that also hid unrelated errors.
    if n_correct > 0:
        recall = n_correct / n_recall
        precision = n_correct / n_precision
        f1 = 2 / (1/recall + 1/precision)
    else:
        recall = precision = f1 = 0

    if verbose > 0:
        print()

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_dict': confusion_dict,
    }

In [None]:
# Gold and predicted tag sequences of every sentence at every level
seq_labels = []
seq_preds = []

for i_batch, batch_data in enumerate(test_dataloader):
    print('Getting pred. of batch %d out of %d' % (i_batch+1, len(test_dataloader)), end='\r')

    # Get inputs
    masks = batch_data['masks'].to(device=device)
    x_word = batch_data['x_word'].to(device=device)
    x_char = batch_data['x_char'].to(device=device)
    x_lm_inputs = batch_data['x_lm_input'].to(device=device)
    x_lm_attention = batch_data['x_lm_attention'].to(device=device)
    x_lm_type_ids = batch_data['x_lm_type_ids'].to(device=device)
    x_lm_span = batch_data['x_lm_spans'].to(device=device)

    # Inference only: no computation graph is needed, which saves memory
    with torch.no_grad():
        y_all_preds = net(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)

    for i_layer in range(TOTAL_LAYERS):
        y_targets = batch_data['y_target_%d' % i_layer]
        y_preds = y_all_preds[i_layer].cpu()

        for (y_target, y_pred, mask) in zip(y_targets, y_preds, masks):
            # Number of non-padded positions of this sentence at this level
            mask_cut = int(mask[i_layer:].sum().item())
            y_target = y_target.view(-1)[:mask_cut]
            y_pred = torch.argmax(y_pred, dim=-1).view(-1)[:mask_cut]

            seq_labels.append(y_target.cpu().numpy())
            seq_preds.append(y_pred.numpy())

    # Clear some memory.
    # `device` is a torch.device: comparing it with the string 'cuda' is
    # always False, so the device type is compared instead.
    if device.type == 'cuda':
        del masks
        del x_word
        del x_char
        del x_lm_inputs
        del x_lm_attention
        del x_lm_type_ids
        del x_lm_span
        torch.cuda.empty_cache()

print()

In [None]:
# Exact-match scores over all the collected sequences (progress is printed)
overall_scores = get_seq_metrics(seq_labels, seq_preds, entity_idx, verbose=1)

In [None]:
# Report the three overall scores with four decimals
print('OVERALL SCORES')
for score_title, score_key in (('Precision', 'precision'), ('Recall', 'recall'), ('F1-score', 'f1')):
    print('- %s: %.4f' % (score_title, overall_scores[score_key]))

## Sandbox - Try it yourself

In [None]:
# Text to analyse; replace it with your own sentence
my_text = 'Two cDNA clones were sequenced and provided 2,250 nucleotides (nt) of DNA sequence information.'

In [None]:
# Split the text into tokens (and their spans) and wrap them in the
# "tokens" record format the dataset expects; no entities are provided.
# NOTE(review): the meaning of the third positional argument (True) of
# tokenize_text is not visible here - confirm against its definition.
tokens, spans = tokenize_text(tokenizer, my_text, True)
eval_data = [{'tokens': tokens}]

In [None]:
# Dataset without targets (has_outputs=False), served one item at a time
nne_eval_dataset = NestedNamedEntitiesDataset(eval_data, word_input, char_input, bert_input, skip_exceptions=False,
                                              max_items=-1, total_layers=TOTAL_LAYERS, has_outputs=False)
eval_dataloader = DataLoader(nne_eval_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# Take the first (and only) batch of the sandbox dataloader
eval_item = next(iter(eval_dataloader))

# Move every model input to the selected device
masks = eval_item['masks'].to(device=device)
x_word = eval_item['x_word'].to(device=device)
x_char = eval_item['x_char'].to(device=device)
x_lm_inputs = eval_item['x_lm_input'].to(device=device)
x_lm_attention = eval_item['x_lm_attention'].to(device=device)
x_lm_type_ids = eval_item['x_lm_type_ids'].to(device=device)
x_lm_span = eval_item['x_lm_spans'].to(device=device)

In [None]:
# Inference only: no computation graph is needed, which saves memory
with torch.no_grad():
    y_all_preds = net(x_word, x_char, x_lm_inputs, x_lm_attention, x_lm_type_ids, x_lm_span, masks)

In [None]:
# Collect every position predicted with a class other than 0. At level
# `i_layer` a position covers `i_layer + 1` consecutive tokens.
found_entities = []

for i_layer in range(TOTAL_LAYERS):
    y_pred = y_all_preds[i_layer].cpu().detach()
    y_pred = torch.argmax(y_pred, dim=-1).view(-1)

    # Keep only the non-padded positions of this level (same cut used by the
    # evaluation loops); predictions on padding do not map to any token.
    n_valid = int(masks[0, i_layer:].sum().item())

    for y_span, y in enumerate(y_pred[:n_valid]):
        if y > 0:
            entity = {'span': [y_span, y_span+i_layer],
                      'tokens': tokens[y_span:y_span+i_layer+1]}
            found_entities.append(entity)

In [None]:
# Show the analysed text followed by one line per detected entity
print(my_text)
print('-'*10)

for entity in found_entities:
    print(entity)