# Running the entity model
In this notebook we will run and evaluate the entity model proposed in the research paper [A Frustratingly Easy Approach for Entity and Relation Extraction](https://arxiv.org/pdf/2010.12812.pdf).

This is a reproduction based on the instructions the authors left in their [GitHub repo](https://github.com/princeton-nlp/PURE).

We will run the entity model on the SciERC dataset using a pre-trained BERT-based model.

The output of this notebook is a JSON file whose keys are **document and sentence indices** and whose values are **lists of predicted entities** in the format [start, end, label]. It will be used as the input for the relation model in the notebook `run_relation`.

**Environment information**

- Windows 11
- Python 3.6.13
- pip 21.2.2

## Basic setup
First, we set up the notebook by importing the required libraries and modules.

In [2]:
# Standard library (copy is needed by Dataset._read below)
import copy
import json
import logging
import os
import random
import time

# Third-party
import numpy as np
import torch
import torch.nn.functional as F
from allennlp.modules import FeedForward
from allennlp.nn.util import batched_index_select
from torch import nn
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AlbertModel, AlbertPreTrainedModel, AlbertTokenizer
from transformers import BertModel, BertPreTrainedModel, BertTokenizer

And we initialize a logger.

In [3]:
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)

logger = logging.getLogger('root')

## Main classes
The authors have implemented their system in an OOP style. Let's go through the defined classes and understand what they do.

### BertForEntity
The first class we'll look at is `BertForEntity`, a custom model for named entity recognition (NER) that uses BERT (Bidirectional Encoder Representations from Transformers) as its underlying pre-trained transformer. Rather than tagging tokens one by one, it assigns a label to each candidate span. Because it is built on top of a pre-trained BERT model, it extends `BertPreTrainedModel`.

Let's now look at the constructor. It initializes the BERT encoder, a dropout layer, a span-width embedding, and the NER classifier layers, then calls `init_weights()` to initialize the model's weights.

The constructor takes several parameters:
- config: BERT model configuration.
- num_ner_labels: Number of NER labels.
- head_hidden_dim: Hidden dimension for the feedforward layers in the NER classifier **(default: 150)**.
- width_embedding_dim: Dimension of the width embedding **(default: 150)**.
- max_span_length: Maximum length for a span **(default: 8)**.
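The `spans` tensor the model receives is built by preprocessing code that is not part of this section. As a minimal sketch, assuming spans are enumerated exhaustively up to `max_span_length` words and that width is counted as `end - start + 1`, the candidate spans of a sentence look like this. Under that assumption widths run from 1 to `max_span_length`, which would account for the `max_span_length + 1` rows of the width embedding (row 0 is never used).

```python
def enumerate_spans(n_words, max_span_length=8):
    """Return every candidate [start, end, width] span of a sentence.

    Assumption (not taken from the notebook): start/end are inclusive word
    indices and width = end - start + 1, so widths lie in 1..max_span_length.
    """
    spans = []
    for start in range(n_words):
        # Stop at the sentence end or once the span would exceed max_span_length words
        for end in range(start, min(start + max_span_length, n_words)):
            spans.append([start, end, end - start + 1])
    return spans


spans = enumerate_spans(4)
print(len(spans))                                   # 10
print(spans[:3])                                    # [[0, 0, 1], [0, 1, 2], [0, 2, 3]]
print(len(enumerate_spans(10)))                     # 52
print(max(w for _, _, w in enumerate_spans(10)))    # 8
```

The number of spans grows roughly linearly with sentence length once sentences are longer than `max_span_length`, which is what keeps exhaustive enumeration affordable.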

Note how the NER classifier is set up:

```python
self.ner_classifier = nn.Sequential(
            FeedForward(input_dim=config.hidden_size * 2 + width_embedding_dim,
                        num_layers=2,
                        hidden_dims=head_hidden_dim,
                        activations=F.relu,
                        dropout=0.2),
            nn.Linear(head_hidden_dim, num_ner_labels))
```

The classifier is a feedforward neural network:

- It's defined using nn.Sequential and consists of an AllenNLP FeedForward block followed by a linear output layer.
- Input dimension (input_dim) is set to the sum of the following:
  - config.hidden_size * 2: Twice the hidden size of the BERT model, because each span is represented by concatenating the embeddings of its start and end tokens (see `_get_span_embeddings`).
  - width_embedding_dim: Dimension of the width embedding.
- num_layers=2: There are two hidden layers in the feedforward network.
- hidden_dims=head_hidden_dim: The hidden layer dimensions are set by the head_hidden_dim parameter **(default: 150)**.
- activations=F.relu: A Rectified Linear Unit (ReLU) activation is applied after each feedforward layer.
- dropout=0.2: 20% dropout is applied after each feedforward layer for regularization.

Output Layer:

- After the feedforward layers, there is a linear layer (nn.Linear) with an output dimension of num_ner_labels.
- This output dimension corresponds to the number of NER labels, indicating the classes the model is trying to predict.
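To make the dimensions concrete, the small calculation below works out the layer shapes and parameter count of the classifier head. It assumes a BERT-base-sized encoder (`hidden_size = 768`) with the default head sizes; `num_ner_labels = 7` is only an example value, so substitute the label count of your own setup.

```python
def classifier_head_shapes(hidden_size=768, width_embedding_dim=150,
                           head_hidden_dim=150, num_ner_labels=7):
    """(in, out) shape of each linear layer in the NER head, plus its parameter count."""
    input_dim = hidden_size * 2 + width_embedding_dim   # start + end + width embedding
    layers = [
        (input_dim, head_hidden_dim),        # FeedForward layer 1
        (head_hidden_dim, head_hidden_dim),  # FeedForward layer 2
        (head_hidden_dim, num_ner_labels),   # final nn.Linear
    ]
    n_params = sum(n_in * n_out + n_out for n_in, n_out in layers)  # weights + biases
    return input_dim, layers, n_params


input_dim, layers, n_params = classifier_head_shapes()
print(input_dim)   # 1686
print(layers)      # [(1686, 150), (150, 150), (150, 7)]
print(n_params)    # 276757
```

The width embedding table is not included in this count: with the defaults it holds (8 + 1) * 150 = 1,350 weights.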

**Note:**
The choice of ReLU activation and dropout between layers is a common practice in neural network architectures for introducing non-linearity and preventing overfitting.

In [4]:
class BertForEntity(BertPreTrainedModel):
    def __init__(self, config, num_ner_labels, head_hidden_dim=150, width_embedding_dim=150, max_span_length=8):
        super().__init__(config)

        self.bert = BertModel(config)
        self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.width_embedding = nn.Embedding(max_span_length + 1, width_embedding_dim)

        self.ner_classifier = nn.Sequential(
            FeedForward(input_dim=config.hidden_size * 2 + width_embedding_dim,
                        num_layers=2,
                        hidden_dims=head_hidden_dim,
                        activations=F.relu,
                        dropout=0.2),
            nn.Linear(head_hidden_dim, num_ner_labels)
        )

        self.init_weights()

    def _get_span_embeddings(self, input_ids, spans, token_type_ids=None, attention_mask=None):
        sequence_output, pooled_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
                                                   attention_mask=attention_mask)

        sequence_output = self.hidden_dropout(sequence_output)

        """
        spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
        spans_mask: (batch_size, num_spans, )
        """
        spans_start = spans[:, :, 0].view(spans.size(0), -1)
        spans_start_embedding = batched_index_select(sequence_output, spans_start)
        spans_end = spans[:, :, 1].view(spans.size(0), -1)
        spans_end_embedding = batched_index_select(sequence_output, spans_end)

        spans_width = spans[:, :, 2].view(spans.size(0), -1)
        spans_width_embedding = self.width_embedding(spans_width)

        # Concatenate embeddings of left/right points and the width embedding
        spans_embedding = torch.cat((spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)
        """
        spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
        """
        return spans_embedding

    def forward(self, input_ids, spans, spans_mask, spans_ner_label=None, token_type_ids=None, attention_mask=None):
        spans_embedding = self._get_span_embeddings(input_ids, spans, token_type_ids=token_type_ids,
                                                    attention_mask=attention_mask)
        ffnn_hidden = []
        hidden = spans_embedding
        for layer in self.ner_classifier:
            hidden = layer(hidden)
            ffnn_hidden.append(hidden)
        logits = ffnn_hidden[-1]

        if spans_ner_label is not None:
            loss_fct = CrossEntropyLoss(reduction='sum')
            if attention_mask is not None:
                active_loss = spans_mask.view(-1) == 1
                active_logits = logits.view(-1, logits.shape[-1])
                active_labels = torch.where(
                    active_loss, spans_ner_label.view(-1), torch.tensor(loss_fct.ignore_index).type_as(spans_ner_label)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, logits.shape[-1]), spans_ner_label.view(-1))
            return loss, logits, spans_embedding
        else:
            return logits, spans_embedding, spans_embedding
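The core of `_get_span_embeddings` is a gather: for every span, pick the encoder output at its start and end sub-token positions, look up its width embedding, and concatenate the three vectors. The NumPy sketch below mimics that step with random arrays standing in for the BERT outputs and the learned `nn.Embedding` weights. It omits dropout, and the advanced indexing plays the role of AllenNLP's `batched_index_select`.

```python
import numpy as np


def span_embeddings(sequence_output, spans, width_table):
    """NumPy stand-in for _get_span_embeddings (dropout omitted).

    sequence_output: (batch, seq_len, hidden) contextual sub-token vectors
    spans:           (batch, num_spans, 3) integer rows of [start, end, width]
    width_table:     (max_span_length + 1, width_dim) stand-in for the width embedding weights
    """
    batch_ix = np.arange(sequence_output.shape[0])[:, None]  # (batch, 1), broadcasts over spans
    start_emb = sequence_output[batch_ix, spans[:, :, 0]]    # (batch, num_spans, hidden)
    end_emb = sequence_output[batch_ix, spans[:, :, 1]]      # (batch, num_spans, hidden)
    width_emb = width_table[spans[:, :, 2]]                  # (batch, num_spans, width_dim)
    return np.concatenate([start_emb, end_emb, width_emb], axis=-1)


rng = np.random.default_rng(0)
seq = rng.normal(size=(2, 5, 4))          # 2 sentences, 5 sub-tokens, hidden size 4
spans = np.array([[[1, 1, 1], [1, 3, 3], [0, 0, 0]],
                  [[2, 4, 3], [3, 3, 1], [0, 0, 0]]])
table = rng.normal(size=(9, 3))           # max_span_length 8 -> 9 rows, width_dim 3
out = span_embeddings(seq, spans, table)
print(out.shape)                          # (2, 3, 11)
```

The last dimension is `hidden * 2 + width_dim`, matching the shape given in the method's docstring.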

### AlbertForEntity

There's also the class `AlbertForEntity`, which is nearly identical to `BertForEntity` except that it uses `AlbertModel` instead of `BertModel` as its underlying transformer.

In [5]:
class AlbertForEntity(AlbertPreTrainedModel):
    def __init__(self, config, num_ner_labels, head_hidden_dim=150, width_embedding_dim=150, max_span_length=8):
        super().__init__(config)

        self.albert = AlbertModel(config)
        self.hidden_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.width_embedding = nn.Embedding(max_span_length+1, width_embedding_dim)
        
        self.ner_classifier = nn.Sequential(
            FeedForward(input_dim=config.hidden_size*2+width_embedding_dim, 
                        num_layers=2,
                        hidden_dims=head_hidden_dim,
                        activations=F.relu,
                        dropout=0.2),
            nn.Linear(head_hidden_dim, num_ner_labels)
        )

        self.init_weights()

    def _get_span_embeddings(self, input_ids, spans, token_type_ids=None, attention_mask=None):
        sequence_output, pooled_output = self.albert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        
        sequence_output = self.hidden_dropout(sequence_output)

        """
        spans: [batch_size, num_spans, 3]; 0: left_end, 1: right_end, 2: width
        spans_mask: (batch_size, num_spans, )
        """
        spans_start = spans[:, :, 0].view(spans.size(0), -1)
        spans_start_embedding = batched_index_select(sequence_output, spans_start)
        spans_end = spans[:, :, 1].view(spans.size(0), -1)
        spans_end_embedding = batched_index_select(sequence_output, spans_end)

        spans_width = spans[:, :, 2].view(spans.size(0), -1)
        spans_width_embedding = self.width_embedding(spans_width)

        spans_embedding = torch.cat((spans_start_embedding, spans_end_embedding, spans_width_embedding), dim=-1)
        """
        spans_embedding: (batch_size, num_spans, hidden_size*2+embedding_dim)
        """
        return spans_embedding

    def forward(self, input_ids, spans, spans_mask, spans_ner_label=None, token_type_ids=None, attention_mask=None):
        spans_embedding = self._get_span_embeddings(input_ids, spans, token_type_ids=token_type_ids, attention_mask=attention_mask)
        ffnn_hidden = []
        hidden = spans_embedding
        for layer in self.ner_classifier:
            hidden = layer(hidden)
            ffnn_hidden.append(hidden)
        logits = ffnn_hidden[-1]

        if spans_ner_label is not None:
            loss_fct = CrossEntropyLoss(reduction='sum')
            if attention_mask is not None:
                active_loss = spans_mask.view(-1) == 1
                active_logits = logits.view(-1, logits.shape[-1])
                active_labels = torch.where(
                    active_loss, spans_ner_label.view(-1), torch.tensor(loss_fct.ignore_index).type_as(spans_ner_label)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, logits.shape[-1]), spans_ner_label.view(-1))
            return loss, logits, spans_embedding
        else:
            return logits, spans_embedding, spans_embedding
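Both `forward` methods compute a summed cross-entropy in which padded spans do not count: their labels are replaced by the loss's `ignore_index` (PyTorch's default is -100) before calling `CrossEntropyLoss`. Below is a NumPy re-implementation of that masking, written only to illustrate the arithmetic; it is not the code the notebook runs.

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def masked_span_loss(logits, labels, spans_mask):
    """Summed cross-entropy over real spans only (sketch of the forward() loss)."""
    flat_logits = logits.reshape(-1, logits.shape[-1])
    # Padded spans get IGNORE_INDEX, mirroring the torch.where(...) call
    flat_labels = np.where(spans_mask.reshape(-1) == 1, labels.reshape(-1), IGNORE_INDEX)
    keep = flat_labels != IGNORE_INDEX
    # Numerically stable log-softmax over the label dimension
    shifted = flat_logits - flat_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Gather each target's log-probability (index 0 is a dummy for ignored rows)
    picked = log_probs[np.arange(flat_labels.shape[0]), np.where(keep, flat_labels, 0)]
    return float(-picked[keep].sum())  # reduction='sum'


logits = np.zeros((1, 3, 2))          # 1 sentence, 3 span slots, 2 labels
logits[0, 2] = [100.0, -100.0]        # arbitrary scores on the padded slot
labels = np.array([[0, 1, 1]])
spans_mask = np.array([[1, 1, 0]])    # the third slot is padding
loss = masked_span_loss(logits, labels, spans_mask)
print(round(loss, 4))                 # 1.3863, i.e. 2 * ln(2)
```

Without the mask, the padded slot alone would add about 200 to the loss, which shows why the masking matters when batches contain sentences with different numbers of spans.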

### EntityModel

The `EntityModel` class wraps a `BertForEntity` or `AlbertForEntity` model and provides a convenient interface for NER: it handles tokenization, model loading, device management, and batch processing.

Let's break down the functionality and components of this class:

#### Initialization (`__init__` method):
- Parameters:
  - model: The name of the BERT or ALBERT model to be used.
  - bert_model_dir: The directory where the pre-trained BERT or ALBERT model is stored.
  - use_albert: A flag indicating whether to use ALBERT (True) or BERT (False).
  - max_span_length: Maximum span width; it is passed to the underlying model, where it sets the size of the width embedding table.
  - num_ner_labels: The number of Named Entity Recognition labels.
  
  
- Initialization Steps:
1. Set bert_model_name and vocab_name to the given model name.
2. If a model directory is specified (bert_model_dir is not None), set both bert_model_name and vocab_name to that directory path with a trailing `/`, so the model and the tokenizer are loaded from it.
3. Initialize a tokenizer (AlbertTokenizer or BertTokenizer) based on the specified model name.
4. Initialize the transformer model (AlbertForEntity or BertForEntity) with the provided parameters.


#### Device Management (`move_model_to_cuda` method):

- Functionality:
  - Checks whether a GPU (CUDA) is available. If none is found, it logs an error and calls `exit(-1)`, so the class cannot run on a CPU-only machine.
  - Otherwise, it moves the model to the GPU. If multiple GPUs are available, it wraps the model with **torch.nn.DataParallel** for parallel processing.


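#### Input Tensor Generation (`_get_input_tensors` method):
- Functionality:
  - Takes a list of tokens, spans, and spans' Named Entity Recognition (NER) labels.
  - Tokenizes each word into sub-tokens, wraps the sequence in the tokenizer's CLS and SEP tokens, converts it to vocabulary ids, and records the first and last sub-token position of every word so that word-level span boundaries can be remapped to sub-token positions.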
  - Converts span information to tensors.
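The bookkeeping that maps word positions to sub-token positions can be reproduced in plain Python. In this sketch `toy_tokenize` is a hypothetical stand-in for `BertTokenizer.tokenize` (it just splits words longer than four characters into two pieces), and `[CLS]`/`[SEP]` stand in for the tokenizer's special tokens.

```python
def toy_tokenize(word):
    """Hypothetical WordPiece-like splitter used only for illustration."""
    if len(word) <= 4:
        return [word]
    return [word[:4], "##" + word[4:]]


def map_spans_to_subwords(tokens, spans, tokenize=toy_tokenize, cls="[CLS]", sep="[SEP]"):
    """Mirror of the start2idx/end2idx logic in _get_input_tensors."""
    start2idx, end2idx = [], []
    sub_tokens = [cls]
    for token in tokens:
        start2idx.append(len(sub_tokens))       # position of the word's first sub-token
        sub_tokens += tokenize(token)
        end2idx.append(len(sub_tokens) - 1)     # position of the word's last sub-token
    sub_tokens.append(sep)
    # Start/end move to sub-token positions; the width stays word-level
    sub_spans = [[start2idx[s], end2idx[e], w] for s, e, w in spans]
    return sub_tokens, sub_spans


sub, sub_spans = map_spans_to_subwords(["We", "propose", "a", "parser"], [[0, 0, 1], [1, 3, 3]])
print(sub)        # ['[CLS]', 'We', 'prop', '##ose', 'a', 'pars', '##er', '[SEP]']
print(sub_spans)  # [[1, 1, 1], [2, 6, 3]]
```

Using the first sub-token of the start word and the last sub-token of the end word means a span's boundary embeddings always cover the whole words it contains.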
  
  
#### Batch Input Tensor Generation (`_get_input_tensors_batch` method):
- Functionality:
  - Takes a list of samples, where each sample contains tokens, spans, and spans' NER labels.
  - Calls `_get_input_tensors` for each sample and constructs batched tensors for the entire input batch.
  - Handles padding for both tokens and spans to create tensors of uniform shapes.
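The padding step can be illustrated with plain Python lists. This sketch mirrors what `_get_input_tensors_batch` does before building tensors; the token ids in the example are illustrative, and `pad_token_id=0` is an assumed default (the notebook reads it from `self.tokenizer.pad_token_id`).

```python
def pad_batch(token_ids, spans, labels, pad_token_id=0):
    """Pad per-sample lists to common lengths, mirroring _get_input_tensors_batch.

    token_ids: list of sub-token id lists (already wrapped in CLS/SEP)
    spans:     list of lists of [start, end, width] triples
    labels:    list of lists of span label ids
    """
    max_tokens = max(len(ids) for ids in token_ids)
    max_spans = max(len(sp) for sp in spans)
    batch = {"input_ids": [], "attention_mask": [], "spans": [],
             "spans_mask": [], "spans_ner_label": []}
    for ids, sp, lab in zip(token_ids, spans, labels):
        tok_pad = max_tokens - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * tok_pad)
        batch["attention_mask"].append([1] * len(ids) + [0] * tok_pad)
        span_pad = max_spans - len(sp)
        batch["spans"].append(sp + [[0, 0, 0] for _ in range(span_pad)])
        batch["spans_mask"].append([1] * len(sp) + [0] * span_pad)
        batch["spans_ner_label"].append(lab + [0] * span_pad)  # ignored later via spans_mask
    return batch


batch = pad_batch([[101, 7, 8, 102], [101, 9, 102]],
                  [[[1, 1, 1], [1, 2, 2]], [[1, 1, 1]]],
                  [[2, 0], [1]])
print(batch["input_ids"])   # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(batch["spans_mask"])  # [[1, 1], [1, 0]]
```

Token padding is hidden from the encoder by `attention_mask`, while span padding is hidden from the loss by `spans_mask`.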

#### Batch Execution (`run_batch` method):
- Functionality:
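  - Converts the input samples to tensors using `_get_input_tensors_batch`.
  - In training mode, puts the BERT or ALBERT model in training mode, runs it, and computes the NER loss.
  - In evaluation mode, puts the model in evaluation mode and generates predictions without tracking gradients.
  - Returns a dictionary. In training mode it holds the NER loss (`ner_loss`) and the log-softmax of the span logits (`ner_llh`). In evaluation mode it holds the predicted label ids (`pred_ner`), the per-span scores (`ner_probs`, which are raw logits because the softmax line is commented out), and the span embeddings (`ner_last_hidden`).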
  
**Note:**
- The class relies on an external logger (`logger`) for logging information.
- It assumes that the underlying BERT or ALBERT model classes (AlbertForEntity, BertForEntity) are correctly implemented.
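`run_batch` stops at label ids; turning them into the `[start, end, label]` entries written to the output file happens in code outside this section. A minimal NumPy sketch of that decoding step, assuming label id 0 means "no entity" and using a hypothetical `id2label` map:

```python
import numpy as np


def decode_spans(ner_logits, samples, id2label):
    """Turn per-span logits into [start, end, label] lists, one list per sample.

    ner_logits: array of shape (batch, max_spans, num_labels)
    samples:    list of dicts whose 'spans' hold word-level [start, end, width]
    id2label:   hypothetical id -> name map; id 0 is assumed to be the null label
    """
    predicted = ner_logits.argmax(axis=-1)  # same indices as ner_logits.max(2)
    results = []
    for i, sample in enumerate(samples):
        entities = []
        # Iterating over the sample's own spans never visits padded slots
        for j, (start, end, _width) in enumerate(sample["spans"]):
            label_id = int(predicted[i, j])
            if label_id != 0:
                entities.append([start, end, id2label[label_id]])
        results.append(entities)
    return results


logits = np.zeros((1, 3, 3))
logits[0, 0] = [2.0, 0.0, 0.0]   # span 0 -> id 0 (no entity)
logits[0, 1] = [0.0, 0.0, 5.0]   # span 1 -> id 2
logits[0, 2] = [0.0, 9.0, 0.0]   # padded slot, never read
samples = [{"spans": [[0, 0, 1], [1, 3, 3]]}]
id2label = {1: "Task", 2: "Method"}
print(decode_spans(logits, samples, id2label))  # [[[1, 3, 'Method']]]
```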

In [6]:
class EntityModel:

    def __init__(self, model, use_albert, max_span_length, num_ner_labels, bert_model_dir=None):
        super().__init__()

        bert_model_name = model
        vocab_name = bert_model_name

        if bert_model_dir is not None:
            bert_model_name = str(bert_model_dir) + '/'
            # vocab_name = bert_model_name + 'vocab.txt'
            vocab_name = bert_model_name
            logger.info('Loading BERT model from {}'.format(bert_model_name))

        if use_albert:
            self.tokenizer = AlbertTokenizer.from_pretrained(vocab_name)
            self.bert_model = AlbertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels,
                                                              max_span_length=max_span_length)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(vocab_name)
            self.bert_model = BertForEntity.from_pretrained(bert_model_name, num_ner_labels=num_ner_labels,
                                                            max_span_length=max_span_length)

        self._model_device = 'cpu'
        self.move_model_to_cuda()

    def move_model_to_cuda(self):
        if not torch.cuda.is_available():
            logger.error('No CUDA found!')
            exit(-1)
        logger.info('Moving to CUDA...')
        self._model_device = 'cuda'
        self.bert_model.cuda()
        logger.info('# GPUs = %d' % (torch.cuda.device_count()))
        if torch.cuda.device_count() > 1:
            self.bert_model = torch.nn.DataParallel(self.bert_model)

    def _get_input_tensors(self, tokens, spans, spans_ner_label):
        start2idx = []
        end2idx = []

        bert_tokens = []
        bert_tokens.append(self.tokenizer.cls_token)
        for token in tokens:
            start2idx.append(len(bert_tokens))
            sub_tokens = self.tokenizer.tokenize(token)
            bert_tokens += sub_tokens
            end2idx.append(len(bert_tokens) - 1)
        bert_tokens.append(self.tokenizer.sep_token)

        indexed_tokens = self.tokenizer.convert_tokens_to_ids(bert_tokens)
        tokens_tensor = torch.tensor([indexed_tokens])

        bert_spans = [[start2idx[span[0]], end2idx[span[1]], span[2]] for span in spans]
        bert_spans_tensor = torch.tensor([bert_spans])

        spans_ner_label_tensor = torch.tensor([spans_ner_label])

        return tokens_tensor, bert_spans_tensor, spans_ner_label_tensor

    def _get_input_tensors_batch(self, samples_list, training=True):
        tokens_tensor_list = []
        bert_spans_tensor_list = []
        spans_ner_label_tensor_list = []
        sentence_length = []

        max_tokens = 0
        max_spans = 0
        for sample in samples_list:
            tokens = sample['tokens']
            spans = sample['spans']
            spans_ner_label = sample['spans_label']

            tokens_tensor, bert_spans_tensor, spans_ner_label_tensor = self._get_input_tensors(tokens, spans,
                                                                                               spans_ner_label)
            tokens_tensor_list.append(tokens_tensor)
            bert_spans_tensor_list.append(bert_spans_tensor)
            spans_ner_label_tensor_list.append(spans_ner_label_tensor)
            assert (bert_spans_tensor.shape[1] == spans_ner_label_tensor.shape[1])
            if (tokens_tensor.shape[1] > max_tokens):
                max_tokens = tokens_tensor.shape[1]
            if (bert_spans_tensor.shape[1] > max_spans):
                max_spans = bert_spans_tensor.shape[1]
            sentence_length.append(sample['sent_length'])
        sentence_length = torch.Tensor(sentence_length)

        # apply padding and concatenate tensors
        final_tokens_tensor = None
        final_attention_mask = None
        final_bert_spans_tensor = None
        final_spans_ner_label_tensor = None
        final_spans_mask_tensor = None
        for tokens_tensor, bert_spans_tensor, spans_ner_label_tensor in zip(tokens_tensor_list, bert_spans_tensor_list,
                                                                            spans_ner_label_tensor_list):
            # padding for tokens
            num_tokens = tokens_tensor.shape[1]
            tokens_pad_length = max_tokens - num_tokens
            attention_tensor = torch.full([1, num_tokens], 1, dtype=torch.long)
            if tokens_pad_length > 0:
                pad = torch.full([1, tokens_pad_length], self.tokenizer.pad_token_id, dtype=torch.long)
                tokens_tensor = torch.cat((tokens_tensor, pad), dim=1)
                attention_pad = torch.full([1, tokens_pad_length], 0, dtype=torch.long)
                attention_tensor = torch.cat((attention_tensor, attention_pad), dim=1)

            # padding for spans
            num_spans = bert_spans_tensor.shape[1]
            spans_pad_length = max_spans - num_spans
            spans_mask_tensor = torch.full([1, num_spans], 1, dtype=torch.long)
            if spans_pad_length > 0:
                pad = torch.full([1, spans_pad_length, bert_spans_tensor.shape[2]], 0, dtype=torch.long)
                bert_spans_tensor = torch.cat((bert_spans_tensor, pad), dim=1)
                mask_pad = torch.full([1, spans_pad_length], 0, dtype=torch.long)
                spans_mask_tensor = torch.cat((spans_mask_tensor, mask_pad), dim=1)
                spans_ner_label_tensor = torch.cat((spans_ner_label_tensor, mask_pad), dim=1)

            # update final outputs
            if final_tokens_tensor is None:
                final_tokens_tensor = tokens_tensor
                final_attention_mask = attention_tensor
                final_bert_spans_tensor = bert_spans_tensor
                final_spans_ner_label_tensor = spans_ner_label_tensor
                final_spans_mask_tensor = spans_mask_tensor
            else:
                final_tokens_tensor = torch.cat((final_tokens_tensor, tokens_tensor), dim=0)
                final_attention_mask = torch.cat((final_attention_mask, attention_tensor), dim=0)
                final_bert_spans_tensor = torch.cat((final_bert_spans_tensor, bert_spans_tensor), dim=0)
                final_spans_ner_label_tensor = torch.cat((final_spans_ner_label_tensor, spans_ner_label_tensor), dim=0)
                final_spans_mask_tensor = torch.cat((final_spans_mask_tensor, spans_mask_tensor), dim=0)
        # logger.info(final_tokens_tensor)
        # logger.info(final_attention_mask)
        # logger.info(final_bert_spans_tensor)
        # logger.info(final_bert_spans_tensor.shape)
        # logger.info(final_spans_mask_tensor.shape)
        # logger.info(final_spans_ner_label_tensor.shape)
        return final_tokens_tensor, final_attention_mask, final_bert_spans_tensor, final_spans_mask_tensor, final_spans_ner_label_tensor, sentence_length

    def run_batch(self, samples_list, try_cuda=True, training=True):
        # convert samples to input tensors
        tokens_tensor, attention_mask_tensor, bert_spans_tensor, spans_mask_tensor, spans_ner_label_tensor, sentence_length = self._get_input_tensors_batch(
            samples_list, training)

        output_dict = {
            'ner_loss': 0,
        }

        if training:
            self.bert_model.train()
            ner_loss, ner_logits, spans_embedding = self.bert_model(
                input_ids=tokens_tensor.to(self._model_device),
                spans=bert_spans_tensor.to(self._model_device),
                spans_mask=spans_mask_tensor.to(self._model_device),
                spans_ner_label=spans_ner_label_tensor.to(self._model_device),
                attention_mask=attention_mask_tensor.to(self._model_device),
            )
            output_dict['ner_loss'] = ner_loss.sum()
            output_dict['ner_llh'] = F.log_softmax(ner_logits, dim=-1)
        else:
            self.bert_model.eval()
            with torch.no_grad():
                ner_logits, spans_embedding, last_hidden = self.bert_model(
                    input_ids=tokens_tensor.to(self._model_device),
                    spans=bert_spans_tensor.to(self._model_device),
                    spans_mask=spans_mask_tensor.to(self._model_device),
                    spans_ner_label=None,
                    attention_mask=attention_mask_tensor.to(self._model_device),
                )
            _, predicted_label = ner_logits.max(2)
            predicted_label = predicted_label.cpu().numpy()
            last_hidden = last_hidden.cpu().numpy()

            predicted = []
            pred_prob = []
            hidden = []
            for i, sample in enumerate(samples_list):
                ner = []
                prob = []
                lh = []
                for j in range(len(sample['spans'])):
                    ner.append(predicted_label[i][j])
                    # prob.append(F.softmax(ner_logits[i][j], dim=-1).cpu().numpy())
                    prob.append(ner_logits[i][j].cpu().numpy())
                    lh.append(last_hidden[i][j])
                predicted.append(ner)
                pred_prob.append(prob)
                hidden.append(lh)
            output_dict['pred_ner'] = predicted
            output_dict['ner_probs'] = pred_prob
            output_dict['ner_last_hidden'] = hidden

        return output_dict

### Dataset

The `Dataset` class is a convenient wrapper for handling a dataset: it reads documents from JSON-lines files (one JSON object per line) and creates a `Document` object for each one.

Let's break down its components:
#### `__init__` method:

- Initializes an instance of the Dataset class.
- Takes three parameters: `json_file`, `pred_file` **(default is `None`)**, and `doc_range` **(default is `None`)**.
- Reads data from the specified JSON files (`json_file` and `pred_file`).
- If a document range (`doc_range`) is provided, it selects a subset of documents within that range.
- Creates a list of Document objects based on the read data.

#### `update_from_js` method:

- Updates the dataset with new data (`js`).
- Re-creates the list of `Document` objects based on the updated data.

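#### `_read` method:

- Reads data from JSON-lines files (`json_file` and, optionally, `pred_file`).
- If `pred_file` is provided, it checks that each gold and predicted document share the same `doc_key` and sentences, then copies every key whose name contains `"predicted"` from the predicted document into a deep copy of the gold one.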
- Returns the merged list of documents.
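The merge rule can be checked on in-memory records. The function below applies the same rule as `_read` to JSON strings instead of files; the toy records and their field names are illustrative only.

```python
import copy
import json


def merge_gold_and_pred(gold_lines, pred_lines):
    """Same merging rule as Dataset._read, applied to in-memory JSON lines."""
    merged_docs = []
    for gold_line, pred_line in zip(gold_lines, pred_lines):
        gold, pred = json.loads(gold_line), json.loads(pred_line)
        assert gold["doc_key"] == pred["doc_key"]
        assert gold["sentences"] == pred["sentences"]
        merged = copy.deepcopy(gold)
        # Only keys whose name contains "predicted" come from the prediction file
        merged.update({k: v for k, v in pred.items() if "predicted" in k})
        merged_docs.append(merged)
    return merged_docs


sentences = [["We", "propose", "a", "parser"]]
gold = json.dumps({"doc_key": "d1", "sentences": sentences, "ner": [[[3, 3, "Method"]]]})
pred = json.dumps({"doc_key": "d1", "sentences": sentences,
                   "ner": [[]], "predicted_ner": [[[1, 3, "Method"]]]})
doc = merge_gold_and_pred([gold], [pred])[0]
print(doc["ner"], doc["predicted_ner"])  # [[[3, 3, 'Method']]] [[[1, 3, 'Method']]]
```

Note that the prediction file's own `ner` field is ignored: gold annotations always come from `json_file`.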

#### `__getitem__` method:

- Enables indexing of the dataset. Given an index `ix`, it returns the corresponding Document object.

#### `__len__` method:

- Returns the length of the dataset, i.e., the number of documents.

In [7]:
class Dataset:
    def __init__(self, json_file, pred_file=None, doc_range=None):
        self.js = self._read(json_file, pred_file)
        if doc_range is not None:
            self.js = self.js[doc_range[0]:doc_range[1]]
        self.documents = [Document(js) for js in self.js]

    def update_from_js(self, js):
        self.js = js
        self.documents = [Document(js) for js in self.js]

    def _read(self, json_file, pred_file=None):
        gold_docs = [json.loads(line) for line in open(json_file)]
        if pred_file is None:
            return gold_docs

        pred_docs = [json.loads(line) for line in open(pred_file)]
        merged_docs = []
        for gold, pred in zip(gold_docs, pred_docs):
            assert gold["doc_key"] == pred["doc_key"]
            assert gold["sentences"] == pred["sentences"]
            merged = copy.deepcopy(gold)
            for k, v in pred.items():
                if "predicted" in k:
                    merged[k] = v
            merged_docs.append(merged)

        return merged_docs

    def __getitem__(self, ix):
        return self.documents[ix]

    def __len__(self):
        return len(self.documents)

### Document

This class represents a document. It encapsulates information about a document, its sentences, and any associated clusters. It provides methods for accessing and manipulating this information.

Let's go through each part:

#### `__init__` method:

- Initializes an instance of the Document class.
- Takes a JSON object (`js`) as input.
- Extracts the document key (`_doc_key`) from the JSON object.
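- Uses the `fields_to_batches` helper to split the remaining per-sentence fields of the JSON (everything except `doc_key`, `clusters`, `predicted_clusters`, and `section_starts`) into one entry per sentence.
- Computes each sentence's length and its starting token offset within the document, which is used for sentence indexing.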
- Creates a list of Sentence objects based on the entries.
- If "`clusters`" or "`predicted_clusters`" are present in the JSON, creates lists of Cluster objects for clusters and predicted clusters.
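The sentence bookkeeping in `__init__` can be reproduced on a toy document. `fields_to_batches_sketch` is a hypothetical re-implementation written for illustration (the notebook's own `fields_to_batches` is not shown in this section), while `sentence_starts` repeats the `np.cumsum` / `np.roll` sequence the class uses.

```python
import numpy as np


def fields_to_batches_sketch(js, keys_to_ignore):
    """Hypothetical helper: one dict per sentence built from per-sentence fields."""
    keys = [k for k in js if k not in keys_to_ignore]
    n_sentences = len(js[keys[0]])
    assert all(len(js[k]) == n_sentences for k in keys)
    return [{k: js[k][i] for k in keys} for i in range(n_sentences)]


def sentence_starts(lengths):
    starts = np.cumsum(lengths)  # end offset of each sentence, e.g. [4, 7, 12]
    starts = np.roll(starts, 1)  # shift right by one, e.g. [12, 4, 7]
    starts[0] = 0                # the first sentence starts at token 0, e.g. [0, 4, 7]
    return starts


js = {
    "doc_key": "toy_doc",
    "sentences": [["a", "b", "c", "d"], ["e", "f", "g"], ["h", "i", "j", "k", "l"]],
    "ner": [[], [], []],
}
entries = fields_to_batches_sketch(js, ["doc_key", "clusters", "predicted_clusters", "section_starts"])
lengths = [len(entry["sentences"]) for entry in entries]
print(lengths)                            # [4, 3, 5]
print(sentence_starts(lengths).tolist())  # [0, 4, 7]
```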

#### `__repr__` method:

- Returns a string representation of the document, including sentence indices and their corresponding text.

#### `__getitem__` method:

- Enables indexing of the document. Given an index ix, it returns the corresponding `Sentence` object.

#### `__len__` method:

- Returns the number of sentences in the document.

#### `print_plaintext` method:

- Prints the plaintext representation of the document, sentence by sentence.

#### `find_cluster` method:

- Searches the predicted clusters (the default) or the gold clusters for the one containing a member with the same span as the given entity.
- Returns the found cluster or None if no match is found.

#### `n_tokens` property:

- Returns the total number of tokens in the document by summing the number of tokens in each sentence.

In [8]:
class Document:
    def __init__(self, js):
        self._doc_key = js["doc_key"]
        entries = fields_to_batches(js, ["doc_key", "clusters", "predicted_clusters", "section_starts"])
        sentence_lengths = [len(entry["sentences"]) for entry in entries]
        sentence_starts = np.cumsum(sentence_lengths)
        sentence_starts = np.roll(sentence_starts, 1)
        sentence_starts[0] = 0
        self.sentence_starts = sentence_starts
        self.sentences = [Sentence(entry, sentence_start, sentence_ix)
                          for sentence_ix, (entry, sentence_start)
                          in enumerate(zip(entries, sentence_starts))]
        if "clusters" in js:
            self.clusters = [Cluster(entry, i, self)
                             for i, entry in enumerate(js["clusters"])]
        if "predicted_clusters" in js:
            self.predicted_clusters = [Cluster(entry, i, self)
                                       for i, entry in enumerate(js["predicted_clusters"])]

    def __repr__(self):
        return "\n".join([str(i) + ": " + " ".join(sent.text) for i, sent in enumerate(self.sentences)])

    def __getitem__(self, ix):
        return self.sentences[ix]

    def __len__(self):
        return len(self.sentences)

    def print_plaintext(self):
        for sent in self:
            print(" ".join(sent.text))


    def find_cluster(self, entity, predicted=True):
        """
        Search through reference clusters and return the one containing the query entity, if it's
        part of a cluster. If we don't find a match, return None.
        """
        clusters = self.predicted_clusters if predicted else self.clusters
        for clust in clusters:
            for entry in clust:
                if entry.span == entity.span:
                    return clust

        return None

    @property
    def n_tokens(self):
        return sum([len(sent) for sent in self.sentences])

### Cluster

The `Cluster` class represents a cluster of entities within a document: it groups the entity mentions that belong to the same cluster and stores information about each member.

#### Initialization
- Parameters:
  - `cluster`: A list of entries representing the entities in the cluster.
  - `cluster_id`: An identifier for the cluster.
  - `document`: The document object to which the cluster belongs.
- Attributes:
  - `members`: A list of ClusterMember instances, each representing an entity in the cluster.
  - `cluster_id`: The identifier for the cluster.
  
#### Initialization Details
- The `__init__` method initializes the cluster by extracting information about each entity in the cluster.
- For each entry in the cluster, it determines the corresponding sentence and span in the document.
- It creates `ClusterMember` instances for each entity and appends them to the members list.
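Finding the sentence of a document-level span amounts to a search over the sentence start offsets. The notebook's `get_sentence_of_span` is not shown in this section, so `sentence_of_span` below is a hypothetical stand-in that uses binary search and assumes a span never crosses a sentence boundary.

```python
from bisect import bisect_right


def sentence_of_span(span, sentence_starts, n_tokens):
    """Index of the sentence containing a document-level [start, end] span.

    Hypothetical stand-in for get_sentence_of_span; assumes inclusive ends
    and that spans never cross sentence boundaries.
    """
    start, end = span[0], span[1]
    assert 0 <= start <= end < n_tokens
    ix = bisect_right(sentence_starts, start) - 1  # last sentence starting at or before `start`
    next_start = sentence_starts[ix + 1] if ix + 1 < len(sentence_starts) else n_tokens
    assert end < next_start, "span crosses a sentence boundary"
    return ix


starts = [0, 4, 7]  # sentence lengths 4, 3, 5 -> 12 tokens in total
print(sentence_of_span([5, 6], starts, 12))   # 1
print(sentence_of_span([7, 11], starts, 12))  # 2
```

Binary search keeps the lookup logarithmic in the number of sentences, which only matters for long documents; a linear scan would give the same answer.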

#### Representation
- `__repr__` method: Returns a string representation of the cluster, including the cluster identifier and a representation of its members.

#### Accessing Members
- `__getitem__` method: Allows accessing individual members of the cluster by index.

In [9]:
class Cluster:
    def __init__(self, cluster, cluster_id, document):
        members = []
        for entry in cluster:
            sentence_ix = get_sentence_of_span(entry, document.sentence_starts, document.n_tokens)
            sentence = document[sentence_ix]
            span = Span(entry[0], entry[1], sentence.text, sentence.sentence_start)
            ners = [x for x in sentence.ner if x.span == span]
            assert len(ners) <= 1
            ner = ners[0] if len(ners) == 1 else None
            to_append = ClusterMember(span, ner, sentence, cluster_id)
            members.append(to_append)

        self.members = members
        self.cluster_id = cluster_id

    def __repr__(self):
        return f"{self.cluster_id}: " + self.members.__repr__()

    def __getitem__(self, ix):
        return self.members[ix]


### ClusterMember

The `ClusterMember` class represents an individual entity within a cluster. It is a container for the entity's span, its named entity recognition (NER) information, the sentence it belongs to, and the identifier of the cluster to which it is assigned.

#### Initialization
- Parameters:
  - `span`: A Span instance representing the span of the entity in the document.
  - `ner`: A NER instance representing the NER information for the entity.
  - `sentence`: A Sentence instance representing the sentence containing the entity.
  - `cluster_id`: The identifier of the cluster to which the entity belongs.
- Attributes
  - `span`: A Span instance representing the span of the entity.
  - `ner`: A NER instance representing the NER information for the entity, or `None` if no gold annotation matches the span.
  - `sentence`: A Sentence instance representing the sentence containing the entity.
  - `cluster_id`: The identifier of the cluster to which the entity belongs.
  
#### Initialization Details

The `__init__` method initializes a `ClusterMember` by assigning values to its attributes based on the provided parameters.

#### Representation

`__repr__` Method:
Returns a string representation of the `ClusterMember`, including the sentence index and a representation of its span.


In [10]:
class ClusterMember:
    def __init__(self, span, ner, sentence, cluster_id):
        self.span = span
        self.ner = ner
        self.sentence = sentence
        self.cluster_id = cluster_id

    def __repr__(self):
        return f"<{self.sentence.sentence_ix}> " + self.span.__repr__()


### Sentence

This class represents a sentence. It encapsulates information about a sentence, including its text and associated entities (NER, relations, events). It provides methods for accessing and manipulating this information.

Let's go through each method:

#### `__init__` method:

- Initializes an instance of the `Sentence` class.
- Takes an `entry`, `sentence_start` (index of the sentence start in the document), and `sentence_ix` (sentence index) as input.
- Stores information about the sentence's start position, text, and index.
- Parses gold entities (NER, relations, events) and predicted entities (NER, relations, events).
- Stores top spans if available.

#### `__repr__` method:

- Returns a string representation of the sentence, including the text and token indices.

#### `__len__` method:

- Returns the number of tokens in the sentence.

#### `get_flavor` method:

- Given an argument (presumably an entity), retrieves its flavor from the gold NER annotations, or returns `None` if no annotation matches its span.
- If multiple NER annotations are found for the same span, prints a debugging message (`"Weird"`).

In [11]:
class Sentence:
    def __init__(self, entry, sentence_start, sentence_ix):
        self.sentence_start = sentence_start
        self.text = entry["sentences"]
        self.sentence_ix = sentence_ix
        # Gold
        if "ner_flavor" in entry:
            self.ner = [NER(this_ner, self.text, sentence_start, flavor=this_flavor)
                        for this_ner, this_flavor in zip(entry["ner"], entry["ner_flavor"])]
        elif "ner" in entry:
            self.ner = [NER(this_ner, self.text, sentence_start)
                        for this_ner in entry["ner"]]
        if "relations" in entry:
            self.relations = [Relation(this_relation, self.text, sentence_start) for
                              this_relation in entry["relations"]]
        if "events" in entry:
            self.events = Events(entry["events"], self.text, sentence_start)

        # Predicted
        if "predicted_ner" in entry:
            self.predicted_ner = [NER(this_ner, self.text, sentence_start, flavor=None) for
                                  this_ner in entry["predicted_ner"]]
        if "predicted_relations" in entry:
            self.predicted_relations = [Relation(this_relation, self.text, sentence_start) for
                                        this_relation in entry["predicted_relations"]]
        if "predicted_events" in entry:
            self.predicted_events = Events(entry["predicted_events"], self.text, sentence_start)

        # Top spans
        if "top_spans" in entry:
            self.top_spans = [NER(this_ner, self.text, sentence_start, flavor=None) for
                                this_ner in entry["top_spans"]]

    def __repr__(self):
        the_text = " ".join(self.text)
        the_lengths = np.array([len(x) for x in self.text])
        tok_ixs = ""
        for i, offset in enumerate(the_lengths):
            true_offset = offset if i < 10 else offset - 1
            tok_ixs += str(i)
            tok_ixs += " " * true_offset

        return the_text + "\n" + tok_ixs

    def __len__(self):
        return len(self.text)

    def get_flavor(self, argument):
        the_ner = [x for x in self.ner if x.span == argument.span]
        if len(the_ner) > 1:
            print("Weird")
        if the_ner:
            the_flavor = the_ner[0].flavor
        else:
            the_flavor = None
        return the_flavor

### NER

The `NER` class represents a named entity recognition annotation within a sentence. It encapsulates information about a named entity and provides methods for representation and equality checks. The `Span` class (covered next) is used to represent the position of the named entity within the sentence.

Here are some details about it:
#### Initialization

- Parameters:
  - `ner`: A list containing information about the NER span, including start index, end index, and label.
  - `text`: The text content of the sentence.
  - `sentence_start`: The index of the sentence start in the document.
  - `flavor`: An optional parameter representing the flavor or type of the named entity.
  
- Attributes:
  - `span`: An instance of the Span class representing the span of the NER annotation.
  - `label`: The label assigned to the NER entity.
  - `flavor`: The flavor or type of the named entity.
  
#### Representation

`__repr__` Method:
Returns a string representation of the NER annotation, including the span and label.

#### Equality Check

`__eq__` Method:
Checks if two NER instances are equal by comparing their span, label, and flavor.

In [12]:
class NER:
    def __init__(self, ner, text, sentence_start, flavor=None):
        self.span = Span(ner[0], ner[1], text, sentence_start)
        self.label = ner[2]
        self.flavor = flavor

    def __repr__(self):
        return self.span.__repr__() + ": " + self.label

    def __eq__(self, other):
        return (self.span == other.span and
                self.label == other.label and
                self.flavor == other.flavor)

### Span

The `Span` class represents a span of text within a document. It encapsulates information about a text span and provides methods for representation, equality checks, and hashing, so that spans can be handled consistently in both document and sentence coordinates.

#### Initialization
- Parameters:
  - `start`: The starting token index of the span in the entire document.
  - `end`: The ending token index of the span in the entire document (inclusive).
  - `text`: The tokens of the sentence that contains the span.
  - `sentence_start`: The index of the sentence start in the document.

- Attributes
  - `start_doc`, `end_doc`: The start and end indices of the span in the entire document.
  - `span_doc`: A tuple representing the span in the entire document.
  - `start_sent`, `end_sent`: The start and end indices of the span within the sentence, obtained by subtracting `sentence_start`.
  - `span_sent`: A tuple representing the span within the sentence.
  - `text`: The list of tokens covered by the span.
  
#### Representation

`__repr__` Method:
Returns a string representation of the span, including its sentence-level start and end indices and the actual text content.

#### Equality Check

`__eq__` Method:
Checks if two Span instances are equal by comparing their spans in both the document and the sentence, along with the text content.

#### Hashing

`__hash__` Method:
Computes a hash value for the Span instance based on its document and sentence spans, as well as the text content.

In [13]:
class Span:
    def __init__(self, start, end, text, sentence_start):
        self.start_doc = start
        self.end_doc = end
        self.span_doc = (self.start_doc, self.end_doc)
        self.start_sent = start - sentence_start
        self.end_sent = end - sentence_start
        self.span_sent = (self.start_sent, self.end_sent)
        self.text = text[self.start_sent:self.end_sent + 1]

    def __repr__(self):
        return str((self.start_sent, self.end_sent, self.text))

    def __eq__(self, other):
        return (self.span_doc == other.span_doc and
                self.span_sent == other.span_sent and
                self.text == other.text)

    def __hash__(self):
        tup = self.span_doc + self.span_sent + (" ".join(self.text),)
        return hash(tup)
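To make the index bookkeeping concrete, the snippet below reproduces the arithmetic that `Span.__init__` performs, using a hypothetical helper `to_sentence_span` and a made-up sentence (not taken from SciERC). It shows that `end` is inclusive and that sentence-level indices come from subtracting `sentence_start`.

```python
def to_sentence_span(start, end, tokens, sentence_start):
    """Hypothetical helper mirroring Span's index arithmetic.

    start/end are inclusive document-level token indices; tokens is the
    token list of the sentence that begins at document position sentence_start.
    """
    start_sent = start - sentence_start
    end_sent = end - sentence_start
    # end is inclusive, hence the + 1 when slicing
    return (start_sent, end_sent), tokens[start_sent:end_sent + 1]


# A made-up sentence whose first token sits at document position 120
tokens = ["We", "propose", "a", "neural", "parser", "."]
span_sent, span_text = to_sentence_span(123, 124, tokens, sentence_start=120)
print(span_sent, span_text)  # (3, 4) ['neural', 'parser']
```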

### Relation

The Relation class is designed to represent a relation between two spans within a text. It encapsulates information about the spans, such as their start and end positions, the text they cover, and the label assigned to the relation. Here's a breakdown of the class:

#### Initialization

- Parameters:

  - `relation`: A list `[start1, end1, start2, end2, label]` holding the document-level boundaries of the two spans and the relation label.
  - `text`: The tokens of the sentence containing the spans.
  - `sentence_start`: The starting position of the sentence in the document.
  
- Attributes

  - `pair`: A tuple containing two Span objects (span1 and span2) representing the spans of the relation.
  - `label`: The label of the relation.
  
#### Representation
`__repr__` Method: Returns a string representation of the Relation object: the two spans separated by `", "`, followed by `": "` and the label.

`__eq__` Method: Checks if two Relation objects are equal by comparing their span pairs and labels.

In [14]:
class Relation:
    def __init__(self, relation, text, sentence_start):
        start1, end1 = relation[0], relation[1]
        start2, end2 = relation[2], relation[3]
        label = relation[4]
        span1 = Span(start1, end1, text, sentence_start)
        span2 = Span(start2, end2, text, sentence_start)
        self.pair = (span1, span2)
        self.label = label

    def __repr__(self):
        return self.pair[0].__repr__() + ", " + self.pair[1].__repr__() + ": " + self.label

    def __eq__(self, other):
        return (self.pair == other.pair) and (self.label == other.label)

### NpEncoder
This class is a custom JSON encoder that extends `json.JSONEncoder`. It converts NumPy-specific data types (integers, floats, and arrays) into JSON-compatible Python types, and it is passed as `cls=NpEncoder` to `json.dumps` when the predictions are written out.

In [15]:
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(NpEncoder, self).default(obj)
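As a quick illustration, the snippet below (a standalone copy of the encoder, with made-up values) shows that `json.dumps` rejects a NumPy scalar such as `np.int64` by default, while `cls=NpEncoder` routes it through `default` and converts it to a plain Python type.

```python
import json

import numpy as np


class NpEncoder(json.JSONEncoder):
    """Same encoder as above, repeated so this snippet runs on its own."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


record = {"count": np.int64(3), "score": np.float32(0.5), "span": np.array([1, 2])}

# Without the encoder, json.dumps raises TypeError on np.int64
try:
    json.dumps(record)
except TypeError as err:
    print("TypeError:", err)

encoded = json.dumps(record, cls=NpEncoder)
print(encoded)  # {"count": 3, "score": 0.5, "span": [1, 2]}
```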

## Utility functions
The authors have also implemented some utility functions. Let's go through them and understand what they do.

### get_sentence_of_span
This function determines the index of the sentence to which a given span belongs within a document. It takes three parameters: 
- `span`: A tuple representing the start and end indices of the span.
- `sentence_starts`: A list containing the starting indices of sentences in the document.
- `doc_tokens`: The total number of tokens (or words) in the document.

I've also added some comments to the function to explain how it works.

In [16]:
def get_sentence_of_span(span, sentence_starts, doc_tokens):
    """
    Return the index of the sentence that the span is part of.
    """
    # Inclusive sentence ends
    sentence_ends = [x - 1 for x in sentence_starts[1:]] + [doc_tokens - 1]
    
    # Check if the span is between each pair of sentence starts and ends
    in_between = [span[0] >= start and span[1] <= end
                  for start, end in zip(sentence_starts, sentence_ends)]
    
    # Ensure that the span is part of exactly one sentence
    assert sum(in_between) == 1
    
    # Get the index of the sentence to which the span belongs
    the_sentence = in_between.index(True)
    
    return the_sentence
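The function above scans every sentence. As an alternative sketch (not the authors' implementation), assuming `sentence_starts` is sorted, which holds because sentences are stored in document order, the same lookup can be done with a binary search from Python's `bisect` module. The name `get_sentence_of_span_bisect` is hypothetical, and the toy offsets are made up.

```python
from bisect import bisect_right


def get_sentence_of_span_bisect(span, sentence_starts, doc_tokens):
    """Hypothetical binary-search variant of get_sentence_of_span.

    span: inclusive (start, end) document-level token indices.
    sentence_starts: sorted start index of every sentence.
    doc_tokens: total number of tokens in the document.
    """
    # Index of the last sentence whose start is <= the span start
    ix = bisect_right(sentence_starts, span[0]) - 1
    # Inclusive end of that sentence
    if ix + 1 < len(sentence_starts):
        sentence_end = sentence_starts[ix + 1] - 1
    else:
        sentence_end = doc_tokens - 1
    # Same contract as the original: the span must lie inside one sentence
    assert ix >= 0 and span[1] <= sentence_end
    return ix


sentence_starts = [0, 10, 25]  # three sentences in a 40-token document
print(get_sentence_of_span_bisect((12, 14), sentence_starts, 40))  # 1
print(get_sentence_of_span_bisect((30, 39), sentence_starts, 40))  # 2
```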

### convert_dataset_to_samples
As the name suggests, this function converts a dataset into samples suitable for the entity model.

It processes a dataset of documents, extracts the sentences and their named entity recognition (NER) labels, and turns each sentence into one sample containing its tokens and every candidate span with its gold label.

The function takes several parameters:

- `dataset`: A collection of documents, each containing sentences and NER information.
- `max_span_length`: The maximum length of spans to consider.
- `ner_label2id`: A mapping from NER labels to unique identifiers.
- `context_window`: The size of the context window to consider around each sentence.
- `split`: A parameter used for splitting the data into train and dev sets (for ACE04 dataset).

Here's a breakdown of how the convert_dataset_to_samples function works:

1. Initialization: The function initializes variables to keep track of statistics such as the number of NER labels (`num_ner`), the maximum sentence length (`max_len`), the maximum number of NER labels in a sentence (`max_ner`), and the number of overlapping spans (`num_overlap`, which is never incremented in this version, so the log always reports `# Overlap: 0`).

2. Data Splitting: Depending on `split`, the function processes the full dataset (`split == 0`), the first 90% (`split == 1`), or the last 10% (`split == 2`).

3. Processing Documents and Sentences: The function iterates over the documents and sentences within the specified data range. For each sentence, it creates a dictionary (`sample`) to store information about the document, sentence index, tokens, and more.

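4. Context Window Handling: If `context_window` is greater than zero, the function pads the sentence with tokens borrowed from the preceding and following sentences until the input is about `context_window` tokens long, shifting `sent_start` and `sent_end` accordingly. Sentences already longer than `context_window` are only logged and receive no extra context.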

5. NER Label Processing: The function enumerates every span of up to `max_span_length` tokens in the sentence and records its start index, end index (both shifted by any left context), and length. Each span receives the ID of its gold entity label, or `0` if it is not an entity.

6. Statistics and Logging: The function updates statistics such as the average input length, maximum length, and maximum NER labels. It also logs information about the dataset, including the number of extracted samples and NER labels.

7. Return: The function returns a list of dictionaries (`samples`), where each dictionary represents a processed sample suitable for model training, along with the total number of NER labels (`num_ner`).
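The span enumeration in step 5 can be sketched on its own as follows. `enumerate_spans` is a hypothetical helper, the toy sentence and labels are made up, and the sketch ignores the context-window offset (it assumes `sent_start == 0`).

```python
def enumerate_spans(tokens, gold, label2id, max_span_length):
    """Sketch of the span enumeration in convert_dataset_to_samples.

    gold maps sentence-level (start, end) tuples, end inclusive, to labels.
    Returns (spans, labels) where each span is (start, end, length).
    """
    spans, labels = [], []
    for i in range(len(tokens)):
        # Spans start at i and are at most max_span_length tokens long
        for j in range(i, min(len(tokens), i + max_span_length)):
            spans.append((i, j, j - i + 1))
            # 0 is the "no entity" label; real label IDs start at 1
            labels.append(label2id[gold[(i, j)]] if (i, j) in gold else 0)
    return spans, labels


tokens = ["we", "use", "neural", "networks"]
spans, labels = enumerate_spans(tokens, {(2, 3): "Method"}, {"Method": 1}, max_span_length=2)
print(spans)   # [(0, 0, 1), (0, 1, 2), (1, 1, 1), (1, 2, 2), (2, 2, 1), (2, 3, 2), (3, 3, 1)]
print(labels)  # [0, 0, 0, 0, 0, 1, 0]
```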

In [17]:
def convert_dataset_to_samples(dataset, max_span_length, ner_label2id=None, context_window=0, split=0):
    """
    Extract sentences and gold entities from a dataset
    """
    # split: split the data into train and dev (for ACE04)
    # split == 0: don't split
    # split == 1: return first 90% (train)
    # split == 2: return last 10% (dev)
    samples = []
    num_ner = 0
    max_len = 0
    max_ner = 0
    num_overlap = 0
    
    if split == 0:
        data_range = (0, len(dataset))
    elif split == 1:
        data_range = (0, int(len(dataset)*0.9))
    elif split == 2:
        data_range = (int(len(dataset)*0.9), len(dataset))

    for c, doc in enumerate(dataset):
        if c < data_range[0] or c >= data_range[1]:
            continue
        for i, sent in enumerate(doc):
            num_ner += len(sent.ner)
            sample = {
                'doc_key': doc._doc_key,
                'sentence_ix': sent.sentence_ix,
            }
            if context_window != 0 and len(sent.text) > context_window:
                logger.info('Long sentence: {} {}'.format(sample, len(sent.text)))
                # print('Exclude:', sample)
                # continue
            sample['tokens'] = sent.text
            sample['sent_length'] = len(sent.text)
            sent_start = 0
            sent_end = len(sample['tokens'])

            max_len = max(max_len, len(sent.text))
            max_ner = max(max_ner, len(sent.ner))

            if context_window > 0:
                add_left = (context_window-len(sent.text)) // 2
                add_right = (context_window-len(sent.text)) - add_left
                
                # add left context
                j = i - 1
                while j >= 0 and add_left > 0:
                    context_to_add = doc[j].text[-add_left:]
                    sample['tokens'] = context_to_add + sample['tokens']
                    add_left -= len(context_to_add)
                    sent_start += len(context_to_add)
                    sent_end += len(context_to_add)
                    j -= 1

                # add right context
                j = i + 1
                while j < len(doc) and add_right > 0:
                    context_to_add = doc[j].text[:add_right]
                    sample['tokens'] = sample['tokens'] + context_to_add
                    add_right -= len(context_to_add)
                    j += 1

            sample['sent_start'] = sent_start
            sample['sent_end'] = sent_end
            sample['sent_start_in_doc'] = sent.sentence_start
            
            sent_ner = {}
            for ner in sent.ner:
                sent_ner[ner.span.span_sent] = ner.label

            span2id = {}
            sample['spans'] = []
            sample['spans_label'] = []
            for i in range(len(sent.text)):
                for j in range(i, min(len(sent.text), i+max_span_length)):
                    sample['spans'].append((i+sent_start, j+sent_start, j-i+1))
                    span2id[(i, j)] = len(sample['spans'])-1
                    if (i, j) not in sent_ner:
                        sample['spans_label'].append(0)
                    else:
                        sample['spans_label'].append(ner_label2id[sent_ner[(i, j)]])
            samples.append(sample)
    avg_length = sum([len(sample['tokens']) for sample in samples]) / len(samples)
    max_length = max([len(sample['tokens']) for sample in samples])
    logger.info('# Overlap: %d'%num_overlap)
    logger.info('Extracted %d samples from %d documents, with %d NER labels, %.3f avg input length, %d max length'%(len(samples), data_range[1]-data_range[0], num_ner, avg_length, max_length))
    logger.info('Max Length: %d, max NER: %d'%(max_len, max_ner))
    return samples, num_ner
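The context-window padding (step 4) is easiest to see on a toy document. The sketch below isolates that logic in a hypothetical `add_context` helper working on plain token lists instead of `Sentence` objects. It is not exercised in this notebook, which runs with `context_window = 0`.

```python
def add_context(sentences, i, context_window):
    """Sketch of the context-window padding in convert_dataset_to_samples.

    sentences: list of token lists for one document; i: index of the target
    sentence. Returns (tokens, sent_start, sent_end) with sent_end exclusive.
    """
    tokens = list(sentences[i])
    sent_start, sent_end = 0, len(tokens)
    add_left = (context_window - len(tokens)) // 2
    add_right = (context_window - len(tokens)) - add_left

    # Borrow tokens from preceding sentences, nearest sentence first
    j = i - 1
    while j >= 0 and add_left > 0:
        context = sentences[j][-add_left:]
        tokens = context + tokens
        add_left -= len(context)
        sent_start += len(context)
        sent_end += len(context)
        j -= 1

    # Borrow tokens from following sentences
    j = i + 1
    while j < len(sentences) and add_right > 0:
        context = sentences[j][:add_right]
        tokens = tokens + context
        add_right -= len(context)
        j += 1

    return tokens, sent_start, sent_end


doc = [["a", "b", "c"], ["d", "e"], ["f", "g", "h"]]
tokens, start, end = add_context(doc, 1, context_window=6)
print(tokens, start, end)  # ['b', 'c', 'd', 'e', 'f', 'g'] 2 4
```

The target sentence is recovered as `tokens[start:end]`, which is why the sample stores `sent_start` and `sent_end`.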


### batchify

The `batchify` function is designed to organize a list of samples into batches of a specified size (`batch_size`). The function handles some special cases, such as creating single-sample batches for sentences that are too long to avoid GPU out-of-memory (OOM) issues.

Here's a breakdown of how the `batchify` function works:

1. Initialization: The function starts by initializing variables, including the total number of samples (`num_samples`) and an empty list to store batches (`list_samples_batches`).

2. Single Batch for Long Sentences: The function checks each sample, and if the length of its tokens exceeds a threshold (350 in this case), it is added to a list (`to_single_batch`). For these long sentences, the function logs information and creates individual batches containing only that sample.

3. Remove Single Batches from the Original List: The samples identified as long sentences are removed from the original list of samples.

4. Create Batches: The function then proceeds to create batches from the remaining samples. It iterates over the samples, grouping them into batches of size `batch_size`. These batches are appended to the list of batches (`list_samples_batches`).

5. Assertion Check: The function asserts that the total number of samples in the batches equals the original number of samples.

6. Return: The function returns a list of batches, where each batch is a list of samples. Some of these batches may contain only one sample if the sentence was too long.
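The steps above can be sketched as a simplified re-implementation. `batchify_sketch` is a hypothetical name; it drops the logging and exposes the 350-token threshold as a `max_tokens` parameter (also hypothetical). Using a set for the long-sample indices avoids the repeated list membership test of the original.

```python
def batchify_sketch(samples, batch_size, max_tokens=350):
    """Simplified batchify: same grouping as the original, without logging."""
    # Over-long samples each get a batch of their own, placed first
    long_ix = {i for i, s in enumerate(samples) if len(s["tokens"]) > max_tokens}
    batches = [[samples[i]] for i in sorted(long_ix)]
    rest = [s for i, s in enumerate(samples) if i not in long_ix]
    # Remaining samples are chunked in their original order
    batches += [rest[i:i + batch_size] for i in range(0, len(rest), batch_size)]
    assert sum(len(b) for b in batches) == len(samples)
    return batches


samples = [{"tokens": ["x"] * n} for n in (5, 400, 7, 3)]
batches = batchify_sketch(samples, batch_size=2)
print([[len(s["tokens"]) for s in b] for b in batches])  # [[400], [5, 7], [3]]
```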

In [18]:
def batchify(samples, batch_size):
    """
    Batchify samples with a batch size
    """
    num_samples = len(samples)

    list_samples_batches = []
    
    # if a sentence is too long, make itself a batch to avoid GPU OOM
    to_single_batch = []
    for i in range(0, len(samples)):
        if len(samples[i]['tokens']) > 350:
            to_single_batch.append(i)
    
    for i in to_single_batch:
        logger.info('Single batch sample: %s-%d', samples[i]['doc_key'], samples[i]['sentence_ix'])
        list_samples_batches.append([samples[i]])
    samples = [sample for i, sample in enumerate(samples) if i not in to_single_batch]

    for i in range(0, len(samples), batch_size):
        list_samples_batches.append(samples[i:i+batch_size])

    assert(sum([len(batch) for batch in list_samples_batches]) == num_samples)

    return list_samples_batches


### get_labelmap

This function takes a list of labels (`label_list`) and creates two mappings: `label2id` and `id2label`. These mappings convert between label names and their numerical identifiers. IDs start at 1 because 0 is reserved for spans that are not entities.

In [19]:
def get_labelmap(label_list):
    label2id = {}
    id2label = {}
    for i, label in enumerate(label_list):
        label2id[label] = i + 1
        id2label[i + 1] = label
    return label2id, id2label
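For SciERC, the mapping looks like this. The snippet restates `get_labelmap` with dict comprehensions (same behavior) so it runs on its own, and it shows why the model later needs `len(labels) + 1` output classes.

```python
def get_labelmap(label_list):
    # Same logic as above: IDs start at 1, leaving 0 for "no entity"
    label2id = {label: i + 1 for i, label in enumerate(label_list)}
    id2label = {i + 1: label for i, label in enumerate(label_list)}
    return label2id, id2label


scierc_labels = ['Method', 'OtherScientificTerm', 'Task', 'Generic', 'Material', 'Metric']
label2id, id2label = get_labelmap(scierc_labels)
print(label2id['Method'], label2id['Metric'])  # 1 6
print(len(scierc_labels) + 1)  # 7 output classes, including the null label 0
```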

### output_ner_predictions


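The `output_ner_predictions` function makes NER predictions with a trained model and saves them as a JSON file. Note that it looks up label names in the global `ner_id2label` mapping, which is created later in the notebook, rather than receiving it as a parameter. The function takes the following parameters: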

- `model`: The trained NER model.
- `batches`: A list of batches, each containing samples for prediction.
- `dataset`: The original `Dataset` object; its raw JSON documents (`dataset.js`) are updated with the predictions and written out.
- `output_file`: The file path where the NER predictions will be saved.

Here's a breakdown of how the output_ner_predictions function works:

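1. Initialization: The function initializes a dictionary (`ner_result`) and a counter (`tot_pred_ett`) to keep track of predicted entities. It also creates `span_hidden_table` and builds a `span_id` string for each span, but neither is used afterwards.

2. Batch Prediction: The function iterates over batches and uses the model to predict NER labels (`pred_ner`). Because span positions in a sample are relative to the model input (which may include left context), each predicted span is shifted back to document-level indices by adding `off = sent_start_in_doc - sent_start`.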

3. Construction of NER Result Dictionary: The function constructs a dictionary (`ner_result`) where keys are **document and sentence indices**, and values are **lists of predicted entities** in the format [start, end, label].

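4. Updating Dataset with Predictions: The function adds to each document in the dataset's JSON representation (`js`) a `predicted_ner` field with the predicted entities and a `predicted_relations` field containing **empty lists**. Sentences missing from the results get an empty entity list.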

5. Writing Predictions to File: The function writes the updated JSON representation to the specified output file.

This JSON file is the input for the relation model in the `run_relation` notebook.
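The offset arithmetic in step 2 can be illustrated with a hypothetical helper, `sample_to_entities`, that mirrors the inner loop of `output_ner_predictions` for a single sample. The sample fields match those produced by `convert_dataset_to_samples`, but their values here are made up.

```python
def sample_to_entities(sample, preds, id2label):
    """Hypothetical helper mirroring the inner loop of output_ner_predictions."""
    # Shift from model-input positions back to document-level positions
    off = sample['sent_start_in_doc'] - sample['sent_start']
    key = sample['doc_key'] + '-' + str(sample['sentence_ix'])
    entities = [[span[0] + off, span[1] + off, id2label[pred]]
                for span, pred in zip(sample['spans'], preds) if pred != 0]
    return key, entities


sample = {
    'doc_key': 'D1', 'sentence_ix': 2,
    'sent_start': 0,          # no left context was added
    'sent_start_in_doc': 30,  # the sentence starts at token 30 of the document
    'spans': [(0, 0, 1), (0, 1, 2), (1, 1, 1)],
}
print(sample_to_entities(sample, [0, 1, 0], {1: 'Method'}))  # ('D1-2', [[30, 31, 'Method']])
```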

In [20]:
def output_ner_predictions(model, batches, dataset, output_file):
    """
    Save the prediction as a json file
    """
    ner_result = {}
    span_hidden_table = {}
    tot_pred_ett = 0
    
    # Iterate over batches and make predictions
    for i in range(len(batches)):
        output_dict = model.run_batch(batches[i], training=False)
        pred_ner = output_dict['pred_ner']
        
        # Iterate over samples in the batch
        for sample, preds in zip(batches[i], pred_ner):
            off = sample['sent_start_in_doc'] - sample['sent_start']
            k = sample['doc_key'] + '-' + str(sample['sentence_ix'])
            ner_result[k] = []
            
            # Iterate over spans and predicted labels
            for span, pred in zip(sample['spans'], preds):
                span_id = '%s::%d::(%d,%d)'%(sample['doc_key'], sample['sentence_ix'], span[0]+off, span[1]+off)
                if pred == 0:
                    continue
                ner_result[k].append([span[0]+off, span[1]+off, ner_id2label[pred]])
            tot_pred_ett += len(ner_result[k])

    logger.info('Total pred entities: %d'%tot_pred_ett)

    # Update the original dataset with predicted NER entities
    js = dataset.js
    for i, doc in enumerate(js):
        doc["predicted_ner"] = []
        doc["predicted_relations"] = []
        for j in range(len(doc["sentences"])):
            k = doc['doc_key'] + '-' + str(j)
            if k in ner_result:
                doc["predicted_ner"].append(ner_result[k])
            else:
                logger.info('%s not in NER results!'%k)
                doc["predicted_ner"].append([])
            
            doc["predicted_relations"].append([])

        js[i] = doc

    # Write the updated JSON representation to the output file
    logger.info('Output predictions to %s..'%(output_file))
    with open(output_file, 'w') as f:
        f.write('\n'.join(json.dumps(doc, cls=NpEncoder) for doc in js))


### evaluate

The `evaluate` function is designed to evaluate the performance of an entity model. The function takes the following parameters:

- `model`: The trained entity model.
- `batches`: A list of batches, each containing samples for evaluation.
- `tot_gold`: The total number of gold (ground truth) entities in the dataset.

Here's a breakdown of how the evaluate function works:
1. Initialization: The function initializes counters for correct entity predictions (`cor`), total entity predictions (`tot_pred`), correctly labeled spans including non-entity spans (`l_cor`), and total spans (`l_tot`).

2. Batch Evaluation: The function iterates over batches, uses the model to predict NER labels (`pred_ner`), and compares the predictions with the gold labels for each span.

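3. Calculation of Metrics: The function calculates span-level accuracy (`l_cor / l_tot`), precision (`p = cor / tot_pred`), recall (`r = cor / tot_gold`), and F1 score (`f1`). Because accuracy also counts the many spans correctly labeled as non-entities, it is much higher than F1 and says little about entity extraction quality.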

4. Logging and Return: The function logs the calculated metrics and the time taken for evaluation. It then returns the F1 score.

In [21]:
def evaluate(model, batches, tot_gold):
    """
    Evaluate the entity model
    """
    logger.info('Evaluating...')
    c_time = time.time()
    cor = 0
    tot_pred = 0
    l_cor = 0
    l_tot = 0

    for i in range(len(batches)):
        output_dict = model.run_batch(batches[i], training=False)
        pred_ner = output_dict['pred_ner']
        for sample, preds in zip(batches[i], pred_ner):
            for gold, pred in zip(sample['spans_label'], preds):
                l_tot += 1
                if pred == gold:
                    l_cor += 1
                if pred != 0 and gold != 0 and pred == gold:
                    cor += 1
                if pred != 0:
                    tot_pred += 1
                   
    acc = l_cor / l_tot
    logger.info('Accuracy: %5f'%acc)
    logger.info('Cor: %d, Pred TOT: %d, Gold TOT: %d'%(cor, tot_pred, tot_gold))
    p = cor / tot_pred if cor > 0 else 0.0
    r = cor / tot_gold if cor > 0 else 0.0
    f1 = 2 * (p * r) / (p + r) if cor > 0 else 0.0
    logger.info('P: %.5f, R: %.5f, F1: %.5f'%(p, r, f1))
    logger.info('Used time: %f'%(time.time()-c_time))
    return f1
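As a sanity check of the formulas, the snippet below recomputes precision, recall, and F1 from the counts this notebook logs for the pre-trained model on the SciERC test set (`Cor: 1122, Pred TOT: 1680, Gold TOT: 1685`). `prf1` is a hypothetical helper that copies the three lines from `evaluate`.

```python
def prf1(cor, tot_pred, tot_gold):
    """Micro precision, recall and F1, computed as in evaluate()."""
    p = cor / tot_pred if cor > 0 else 0.0
    r = cor / tot_gold if cor > 0 else 0.0
    f1 = 2 * (p * r) / (p + r) if cor > 0 else 0.0
    return p, r, f1


# Counts reported by the test-set evaluation in this notebook
p, r, f1 = prf1(cor=1122, tot_pred=1680, tot_gold=1685)
print('P: %.5f, R: %.5f, F1: %.5f' % (p, r, f1))  # P: 0.66786, R: 0.66588, F1: 0.66686
```

Since `p` and `r` share the numerator `cor`, F1 simplifies to `2 * cor / (tot_pred + tot_gold)`, here 2244 / 3365.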


In [22]:
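def fields_to_batches(d, keys_to_ignore=()):
    """
    Convert a dict of equal-length lists into a list of dicts, one per index
    """
    keys = [key for key in d.keys() if key not in keys_to_ignore]
    lengths = [len(d[k]) for k in keys]
    # Every kept field must have the same number of entries
    assert len(set(lengths)) == 1
    length = lengths[0]
    res = [{k: d[k][i] for k in keys} for i in range(length)]
    return res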
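This helper has no description in the notebook. It transposes a dict of equal-length lists into a list of per-index dicts; presumably it is used to split a document's per-sentence fields (such as `sentences` and `ner`) into one entry per sentence, skipping document-level keys. The toy document below is made up, and the function is restated so the snippet runs on its own.

```python
def fields_to_batches(d, keys_to_ignore=()):
    """Turn a dict of equal-length lists into a list of per-index dicts."""
    keys = [key for key in d.keys() if key not in keys_to_ignore]
    lengths = [len(d[k]) for k in keys]
    assert len(set(lengths)) == 1  # every kept field must have the same length
    return [{k: d[k][i] for k in keys} for i in range(lengths[0])]


doc = {"sentences": [["A", "B"], ["C"]], "ner": [[], [[2, 2, "Task"]]], "doc_key": "X"}
print(fields_to_batches(doc, keys_to_ignore=["doc_key"]))
# [{'sentences': ['A', 'B'], 'ner': []}, {'sentences': ['C'], 'ner': [[2, 2, 'Task']]}]
```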

## Running the entity model (evaluation only)

Now we can actually run the model. Remember that we are using a pre-trained model; we haven't trained our own model yet. We will do that in the next steps.

### Basic setup

First, we define a variable `task_ner_labels`, a dictionary mapping each dataset to its entity types.

In [23]:
task_ner_labels = {
    'ace04': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'ace05': ['FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER'],
    'scierc': ['Method', 'OtherScientificTerm', 'Task', 'Generic', 'Material', 'Metric'],
}

Then we define some variables:
- `data_dir`: The directory in which our input data is stored.
- `output_dir`: The directory that holds the pre-trained entity model and to which the model's output is written.
- `task`: The task (dataset) that the model will make predictions on.
- `max_span_length`: The maximum length of spans to consider.
- `context_window`: The size of the context window to consider around each sentence.
- `eval_batch_size`: The number of samples per evaluation batch.
- `test_pred_filename`: The name of the prediction output file.
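`max_span_length` directly controls how many candidate spans the model classifies per sentence, since every start position is paired with every length up to the limit. The hypothetical helper below counts them. The sentence lengths used are the average (about 24 tokens) and maximum (68 tokens) reported for the SciERC dev set in the log further down.

```python
def num_candidate_spans(n_tokens, max_span_length):
    """Number of spans enumerated for one sentence (one per start/length pair)."""
    return sum(min(max_span_length, n_tokens - i) for i in range(n_tokens))


print(num_candidate_spans(24, 8))  # 164, roughly an average dev sentence
print(num_candidate_spans(68, 8))  # 516, the longest dev sentence
print(num_candidate_spans(3, 8))   # 6, short sentences yield n * (n + 1) / 2 spans
```

For sentences of at least `L` tokens, the count is `n * L - L * (L - 1) / 2`, so it grows linearly with sentence length rather than quadratically.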

In [24]:
data_dir = os.getcwd() + '/scierc_data/processed_data/json'
output_dir = os.getcwd() + '/scierc_models/ent-scib-ctx0/'
task = 'scierc'
max_span_length = 8
context_window = 0
eval_batch_size = 32
test_pred_filename = 'ent_pred_test.json'

### Running and evaluating the pre-trained model

Now that the setup is out of the way, we can run the pre-trained BERT-based entity model on the SciERC dataset and evaluate it.

#### Data File Paths
Since the SciERC dataset is already split into training, development, and test sets, we don't need to perform any split ourselves. We simply set the paths to the data files downloaded with the dataset.

The input data format of the entity model is JSONL. Each line of the input file contains one document in the following format.
```json
{
  # document ID (please make sure doc_key can be used to identify a certain document)
  "doc_key": "CNN_ENG_20030306_083604.6",

  # sentences in the document, each sentence is a list of tokens
  "sentences": [
    [...],
    [...],
    ["tens", "of", "thousands", "of", "college", ...],
    ...
  ],

  # entities (boundaries and entity type) in each sentence
  "ner": [
    [...],
    [...],
    [[26, 26, "LOC"], [14, 14, "PER"], ...], #the boundary positions are indexed in the document level
    ...,
  ],

  # relations (two spans and relation type) in each sentence
  "relations": [
    [...],
    [...],
    [[14, 14, 10, 10, "ORG-AFF"], [14, 14, 12, 13, "ORG-AFF"], ...],
    ...
  ]
}
```
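To see what "indexed at the document level" means in practice, the snippet below parses one made-up line in this format and converts each entity's boundaries to sentence-level indices by subtracting the sentence's start offset, the same arithmetic `Span` performs.

```python
import json

# A made-up document in the entity model's JSONL format (one JSON object per line)
line = json.dumps({
    "doc_key": "toy_doc",
    "sentences": [["We", "study", "parsing", "."], ["Neural", "models", "help", "."]],
    "ner": [[[2, 2, "Task"]], [[4, 5, "Method"]]],
    "relations": [[], []],
})
doc = json.loads(line)

# NER boundaries are document-level, so first compute each sentence's start offset
sentence_starts, offset = [], 0
for sentence in doc["sentences"]:
    sentence_starts.append(offset)
    offset += len(sentence)

found = []
for sentence, entities, start in zip(doc["sentences"], doc["ner"], sentence_starts):
    for begin, end, label in entities:
        # Shift to sentence-level indices; end is inclusive, hence + 1
        found.append((label, sentence[begin - start:end - start + 1]))

print(found)  # [('Task', ['parsing']), ('Method', ['Neural', 'models'])]
```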

In [25]:
train_data = os.path.join(data_dir, 'train.json')
dev_data = os.path.join(data_dir, 'dev.json')
test_data = os.path.join(data_dir, 'test.json')

#### Output Directory Check

Then, just to be safe, we check if the specified output directory (`output_dir`) exists. If not, we create the directory. This ensures that the output directory is available for storing model checkpoints, predictions, or other outputs.

In [26]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

#### NER Label Mapping

We use the `get_labelmap` function, discussed above, to build the label mappings for the SciERC task.

In [27]:
ner_label2id, ner_id2label = get_labelmap(task_ner_labels[task])

#### Development Dataset Processing

The development dataset (`dev_data`) is loaded into a `Dataset` object. Then, it is processed using the `convert_dataset_to_samples` function to obtain samples and NER labels. The samples are batchified using the `batchify` function.

In [28]:
dev_data = Dataset(dev_data)
dev_samples, dev_ner = convert_dataset_to_samples(dev_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
dev_batches = batchify(dev_samples, eval_batch_size)

11/23/2023 17:44:41 - INFO - root - # Overlap: 0
11/23/2023 17:44:41 - INFO - root - Extracted 275 samples from 50 documents, with 811 NER labels, 23.713 avg input length, 68 max length
11/23/2023 17:44:41 - INFO - root - Max Length: 68, max NER: 11


#### Model Initialization

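The BERT-based entity model (`EntityModel`) is initialized with the BERT model name (`allenai/scibert_scivocab_uncased`), the directory holding the pre-trained checkpoint (`bert_model_dir`, here the same as `output_dir`), and the number of NER labels (the six SciERC entity types plus the null label `0`). As the log below shows, the weights are loaded from `bert_model_dir`.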

In [69]:
bert_model_dir = output_dir
num_ner_labels = len(task_ner_labels[task]) + 1
model = EntityModel(model='allenai/scibert_scivocab_uncased', bert_model_dir=bert_model_dir, use_albert=False, max_span_length=max_span_length, num_ner_labels=num_ner_labels)

11/23/2023 13:33:05 - INFO - root - Loading BERT model from C:\Users\odaim\Documents\PURE reproduction/scierc_models/ent-scib-ctx0//
11/23/2023 13:33:05 - INFO - transformers.tokenization_utils_base - Model name 'C:\Users\odaim\Documents\PURE reproduction/scierc_models/ent-scib-ctx0//' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'C:\Users\odaim\Documents\PURE reproduction/scierc_models/ent-scib-ctx0//' is a path, a model ident

#### Test Dataset Processing and Evaluation

Finally, the test dataset (`test_data`) is loaded, processed, and batchified in the same way as the development dataset. The model is then evaluated on the test data with the `evaluate` function, and the NER predictions are written to `prediction_file` with the `output_ner_predictions` function.

In [70]:
test_data = Dataset(test_data)
prediction_file = os.path.join(output_dir, test_pred_filename)

test_samples, test_ner = convert_dataset_to_samples(test_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
test_batches = batchify(test_samples, eval_batch_size)
evaluate(model, test_batches, test_ner)
output_ner_predictions(model, test_batches, test_data, output_file=prediction_file)

11/23/2023 13:33:13 - INFO - root - # Overlap: 0
11/23/2023 13:33:13 - INFO - root - Extracted 551 samples from 100 documents, with 1685 NER labels, 24.321 avg input length, 97 max length
11/23/2023 13:33:13 - INFO - root - Max Length: 97, max NER: 13
11/23/2023 13:33:13 - INFO - root - Evaluating...
11/23/2023 13:33:28 - INFO - root - Accuracy: 0.990194
11/23/2023 13:33:28 - INFO - root - Cor: 1122, Pred TOT: 1680, Gold TOT: 1685
11/23/2023 13:33:28 - INFO - root - P: 0.66786, R: 0.66588, F1: 0.66686
11/23/2023 13:33:28 - INFO - root - Used time: 15.231171
11/23/2023 13:33:41 - INFO - root - Total pred entities: 1680
11/23/2023 13:33:41 - INFO - root - Output predictions to C:\Users\odaim\Documents\PURE reproduction/scierc_models/ent-scib-ctx0/ent_pred_test.json..
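The P, R and F1 values in the log follow directly from the three counts on the line before them: correct predictions (`Cor`), total predicted spans (`Pred TOT`) and total gold spans (`Gold TOT`). The small sketch below reproduces them, assuming the standard micro-averaged definitions.

```python
def prf1(correct, predicted, gold):
    """Micro precision, recall and F1 computed from span counts."""
    p = correct / predicted if predicted else 0.0
    r = correct / gold if gold else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1

# Counts copied from the test-set log above
p, r, f1 = prf1(correct=1122, predicted=1680, gold=1685)
print("P: %.5f, R: %.5f, F1: %.5f" % (p, r, f1))  # P: 0.66786, R: 0.66588, F1: 0.66686
```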


### Training and evaluating the entity model from scratch

#### One more utility function

##### `save_model`

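This function saves the trained model and its tokenizer to the output directory. If the BERT model is wrapped (e.g. in `torch.nn.DataParallel`), the underlying model is first taken from its `.module` attribute.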

In [29]:
def save_model(model, output_dir):
    """
    Save the model to the output directory
    """
    logger.info('Saving model to %s...'%(output_dir))
    model_to_save = model.bert_model.module if hasattr(model.bert_model, 'module') else model.bert_model
    model_to_save.save_pretrained(output_dir)
    model.tokenizer.save_pretrained(output_dir)

#### Setting up some variables

Now we set up the variables needed for training. Several of them are new:

- `bertadam`: If `True`, AdamW is created with `correct_bias=False`, i.e. without Adam's bias correction (as in the original BERT optimizer).
- `num_epoch`: The number of training epochs (10 in the cell below). Epochs take a long time on my machine, so lowering this value, e.g. to 1, speeds up experimentation.
- `warmup_proportion`: The ratio of warmup steps to total training steps.
- `eval_per_epoch`: How many times per epoch the model is evaluated on the dev set during training.
- `train_shuffle`: Whether to shuffle the order of the training batches at the start of each epoch.
- `print_loss_step`: Log the average training loss every `print_loss_step` steps.
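To see what `bertadam` changes, the sketch below computes the size of the very first Adam(W) update for a single scalar parameter with a gradient of 1. It assumes the default `betas=(0.9, 0.999)` and `eps=1e-6` of `transformers.AdamW` and ignores weight decay. It is a pure-Python illustration of the update rule, not the library code itself.

```python
import math

def adamw_first_step(grad, lr=1.0, betas=(0.9, 0.999), eps=1e-6, correct_bias=True):
    """Size of the first Adam(W) update for one scalar parameter (no weight decay)."""
    beta1, beta2 = betas
    m = (1 - beta1) * grad       # first moment after one step (initialized at 0)
    v = (1 - beta2) * grad ** 2  # second moment after one step (initialized at 0)
    step_size = lr
    if correct_bias:
        # Compensate for the bias toward 0 caused by zero initialization (step t = 1)
        step_size = lr * math.sqrt(1 - beta2) / (1 - beta1)
    return step_size * m / (math.sqrt(v) + eps)

print(round(adamw_first_step(1.0, correct_bias=True), 3))   # 1.0
print(round(adamw_first_step(1.0, correct_bias=False), 3))  # 3.162
```

With bias correction, the first update is about `lr`. Without it (`bertadam = True`), early updates are larger by roughly a factor of sqrt(1 - beta2) / (1 - beta1) ≈ 3.16, until the moment estimates warm up.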

In [30]:
data_dir = os.getcwd() + '/scierc_data/processed_data/json'
output_dir = os.getcwd() + '/scierc_models/from-scratch/ent-scib-ctx0/'
task = 'scierc'
num_ner_labels = len(task_ner_labels[task]) + 1
max_span_length = 8
context_window = 300
eval_batch_size = 32
train_batch_size = 2
learning_rate = 1e-5  # learning rate for the BERT encoder parameters
task_learning_rate = 5e-4  # learning rate for the task-specific (span classifier) parameters
bertadam = True  # if True, create AdamW with correct_bias=False (no bias correction, as in BertAdam)
num_epoch = 10  # number of training epochs
warmup_proportion = 0.1  # ratio of warmup steps to total training steps
eval_per_epoch = 1  # number of dev-set evaluations per epoch
train_shuffle = True  # whether to shuffle the training batches each epoch
print_loss_step = 100  # log the average training loss every this many steps
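These settings determine the learning-rate schedule used later. The sketch below works through the arithmetic using the 931 batches per epoch reported by tqdm in the training log. It also reimplements the multiplier that `get_linear_schedule_with_warmup` applies: a linear ramp from 0 to 1 over the warmup steps, then a linear decay back to 0. `linear_warmup_multiplier` is our own helper written for illustration, not a `transformers` function.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier: linear ramp-up to 1.0, then linear decay to 0.0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

num_batches = 931          # batches per epoch, taken from the training log
num_epoch = 10
warmup_proportion = 0.1
eval_per_epoch = 1

t_total = num_batches * num_epoch                # 9310 optimizer steps in total
warmup_steps = int(t_total * warmup_proportion)  # 931 warmup steps
eval_step = num_batches // eval_per_epoch        # evaluate on the dev set every 931 steps

for step in (0, 465, 931, 5120, 9310):
    print(step, linear_warmup_multiplier(step, warmup_steps, t_total))
```

The effective learning rate of each parameter group is its base rate (`learning_rate` or `task_learning_rate`) times this multiplier.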

#### Output directory validation

Check whether the output directory exists and create it if it doesn't.

In [31]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
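The same check can be written as a single call, because `os.makedirs` accepts `exist_ok=True` (available since Python 3.2). This also avoids a race between the existence check and the creation. A minimal sketch, where `ensure_dir` is a hypothetical helper and a temporary directory stands in for the real output path:

```python
import os
import tempfile

def ensure_dir(path):
    """Create `path` and any missing parents; do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)
    return path

# Example with a throwaway base directory instead of the real output_dir
base = tempfile.mkdtemp()
target = os.path.join(base, "scierc_models", "from-scratch", "ent-scib-ctx0")
ensure_dir(target)
ensure_dir(target)  # a second call does not raise
print(os.path.isdir(target))  # True
```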

#### Initialize our entity model

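The difference here is that we don't set `bert_model_dir`, so the model is loaded from the pre-trained `allenai/scibert_scivocab_uncased` checkpoint rather than from our previously trained entity model. "From scratch" therefore means fine-tuning from the pre-trained SciBERT weights. Only the task-specific layers, such as the span classifier, start from random initialization.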

In [32]:
model = EntityModel(model='allenai/scibert_scivocab_uncased', use_albert=False, max_span_length=max_span_length, num_ner_labels=num_ner_labels)

11/23/2023 17:44:52 - INFO - transformers.tokenization_utils_base - Model name 'allenai/scibert_scivocab_uncased' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'allenai/scibert_scivocab_uncased' is a path, a model identifier, or url to a directory containing tokenizer files.
11/23/2023 17:44:56 - INFO - transformers.tokenization_utils_base - loading file https://s3.amazonaws.com/models.huggingface.co/bert/allenai/scibert_scivoca

#### Load training data

We load the training data from the JSON file into a `Dataset` instance.

In [33]:
train_data = Dataset(train_data)

#### Training the model

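Now we can train the model. The cell below converts the training data into span samples and groups them into batches. It then sets up AdamW with two parameter groups: the BERT encoder uses `learning_rate` and the task-specific layers use `task_learning_rate`. A linear warmup/decay schedule is attached to the optimizer. The training loop evaluates on the dev set `eval_per_epoch` times per epoch and saves the model whenever the dev F1 improves.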
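The optimizer assigns each parameter to a group by checking whether `'bert'` occurs in its name. Encoder weights get the small `learning_rate` and everything else gets the larger `task_learning_rate`. The pure-Python sketch below shows that split. The parameter names are made up and only approximate those of `BertForEntity`, and `group_learning_rates` is a hypothetical helper.

```python
def group_learning_rates(param_names, learning_rate=1e-5, task_learning_rate=5e-4):
    """Split parameter names like the optimizer cell: 'bert' in name -> encoder learning rate."""
    encoder, task = [], []
    for name in param_names:
        (encoder if "bert" in name else task).append(name)
    return {"encoder": (encoder, learning_rate), "task": (task, task_learning_rate)}

# Illustrative parameter names (assumed, not read from the real model)
names = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "width_embedding.weight",
    "ner_classifier.1.weight",
]
for group, (members, lr) in group_learning_rates(names).items():
    print(group, lr, members)
```

Because the rule is a substring match, any task-specific parameter whose name happened to contain `bert` would silently receive the encoder learning rate. Checking the group sizes once is a cheap sanity test.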

In [34]:
train_samples, train_ner = convert_dataset_to_samples(train_data, max_span_length, ner_label2id=ner_label2id, context_window=context_window)
train_batches = batchify(train_samples, train_batch_size)
best_result = 0.0

param_optimizer = list(model.bert_model.named_parameters())
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
        if 'bert' in n]},
    {'params': [p for n, p in param_optimizer
        if 'bert' not in n], 'lr': task_learning_rate}]
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, correct_bias=not(bertadam))
t_total = len(train_batches) * num_epoch
scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*warmup_proportion), t_total)

tr_loss = 0
tr_examples = 0
global_step = 0
eval_step = len(train_batches) // eval_per_epoch
for _ in tqdm(range(num_epoch), position=0, leave=True):
    if train_shuffle:
        random.shuffle(train_batches)
    for i in tqdm(range(len(train_batches)), position=0, leave=True):
        output_dict = model.run_batch(train_batches[i], training=True)
        loss = output_dict['ner_loss']
        loss.backward()

        tr_loss += loss.item()
        tr_examples += len(train_batches[i])
        global_step += 1

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % print_loss_step == 0:
            logger.info('Epoch=%d, iter=%d, loss=%.5f'%(_, i, tr_loss / tr_examples))
            tr_loss = 0
            tr_examples = 0

        if global_step % eval_step == 0:
            f1 = evaluate(model, dev_batches, dev_ner)
            if f1 > best_result:
                best_result = f1
                logger.info('!!! Best valid (epoch=%d): %.2f' % (_, f1*100))
                save_model(model, output_dir)

11/23/2023 17:45:09 - INFO - root - # Overlap: 0
11/23/2023 17:45:09 - INFO - root - Extracted 1861 samples from 350 documents, with 5598 NER labels, 140.335 avg input length, 300 max length
11/23/2023 17:45:09 - INFO - root - Max Length: 101, max NER: 13
 11%|█         | 99/931 [00:37<04:56,  2.80it/s]11/23/2023 17:45:46 - INFO - root - Epoch=0, iter=99, loss=166.25206
 21%|██▏       | 199/931 [01:12<03:56,  3.09it/s]11/23/2023 17:46:21 - INFO - root - Epoch=0, iter=199, loss=18.53227
 32%|███▏      | 299/931 [01:46<03:20,  3.15it/s]11/23/2023 17:46:55 - INFO - root - Epoch=0, iter=299, loss=14.77129
 43%|████▎     | 399/931 [02:22<03:35,  2.47it/s]11/23/2023 17:47:32 - INFO - root - Epoch=0, iter=399, loss=12.42971
 54%|█████▎    | 499/931 [02:59<02:17,  3.13it/s]11/23/2023 17:48:08 - INFO - root - Epoch=0, iter=499, loss=9.51807
 64%|██████▍   | 599/931 [03:33<01:40,  3.30it/s]11/23/2023 17:48:43 - INFO - root - Epoch=0, iter=599, loss=8.73267
 75%|███████▌  | 699/931 [04:09<01:35, 