[![Colab Badge Link](https://img.shields.io/badge/open-in%20colab-blue)](https://colab.research.google.com/github/Glasgow-AI4BioMed/tutorials/blob/main/custom_token_classification_models.ipynb)

# Creating a custom token classification model

This notebook illustrates how to create a custom Transformer model that is compatible with the [Hugging Face Trainer](https://huggingface.co/docs/transformers/main_classes/trainer). The model uses an intermediate hidden state of a Transformer, rather than the final hidden state, for a token classification task.

## Install dependencies

If needed, you could install dependencies with the command below:

```
pip install transformers torch scipy
```

## Tokenize some text

We'll work with a single example. First we need a tokenizer:

In [None]:
from transformers import AutoTokenizer

MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Let's tokenize an example sentence.

In [None]:
text = 'The quick brown fox jumps over the lazy dog.'

encoded = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')

encoded

How many tokens do we have? The shape of `input_ids` tells us: 15, including the special tokens.

In [None]:
encoded['input_ids'].shape

We also need some dummy labels to act as the targets for our model. There are 15 tokens in the sequence and we need one label for each token. The labels could correspond to `[O, B-DRUG, I-DRUG, etc.]` for a biomedical NER task. Arbitrarily, let's say there are nine unique labels.

In [None]:
num_labels = 9

Let's create some labels randomly using [torch.randint](https://pytorch.org/docs/stable/generated/torch.randint.html).

In [None]:
import torch

labels = torch.randint(low=0, high=num_labels, size=(1,15))
labels.shape

For realism, some tokens shouldn't have labels, such as the `[CLS]` and `[SEP]` tokens used in many BERT models, which here sit at the beginning and end of the sequence. To tell the model to leave these tokens out of the loss calculation, set their labels to the special value `-100`. The loss function that we'll use later ([CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)) skips any target equal to its `ignore_index`, which defaults to `-100`.

In [None]:
labels[0,0] = -100
labels[0,-1] = -100

And finally what do our made-up labels look like?

In [None]:
labels
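The effect of the `-100` labels can be checked in isolation. The sketch below is only an illustration: `scores` is a random stand-in for a model's per-token scores (not output from any real model), and `keep` is a hypothetical helper mask. It compares the loss over all 15 positions with the loss over only the 13 positions that carry a real label.

```python
import torch

torch.manual_seed(0)
num_labels = 9
seq_len = 15

# Random stand-ins (assumptions for illustration): one score per label for each token
scores = torch.randn(seq_len, num_labels)
labels = torch.randint(low=0, high=num_labels, size=(seq_len,))
labels[0] = -100   # e.g. the [CLS] position
labels[-1] = -100  # e.g. the [SEP] position

loss_func = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100

# Loss over all 15 positions; the two -100 positions are skipped automatically
loss_all = loss_func(scores, labels)

# Loss over only the positions that carry a real label
keep = labels != -100
loss_kept = loss_func(scores[keep], labels[keep])

print(int(keep.sum()))                      # 13
print(torch.allclose(loss_all, loss_kept))  # True
```

Because the default reduction averages over the non-ignored targets only, the two values agree.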

## Examining the AutoModelForTokenClassification

Now that we've got some tokenized text and some made-up labels, let's see what happens when we put them through a standard `AutoModelForTokenClassification`. Our eventual model should return the same type of output.

Let's create an `AutoModelForTokenClassification` and pass in the number of labels to be predicted.

In [None]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, num_labels=num_labels)

Normally, you would then fine-tune this model with data before using it. But let's just use it now and see the type of the output. The actual output will be nonsense as there hasn't been any fine-tuning.

Practically, when a model is trained, it is given data along with the labels. So let's take the tokenized text and add in the labels.

In [None]:
encoded_with_labels = dict(encoded)
encoded_with_labels['labels'] = labels

Then we pass in the tokenized text with the labels and examine what the model returns.

In [None]:
output = model(**encoded_with_labels)

First, what is the type of the object returned?

In [None]:
type(output)

It is a [TokenClassifierOutput](https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.TokenClassifierOutput) which wraps various bits of information.

Let's just print it out and see what it gives:

In [None]:
output

There are two important things in this object:
 - **loss**: This is the loss that the fine-tuning will try to minimise.
 - **logits**: These are the unnormalised scores that the model assigns to each label for each token.

Let's examine each. First, the logits are a [torch.Tensor](https://pytorch.org/docs/stable/tensors.html). Let's see its dimensions:

In [None]:
output.logits.shape

The dimensions are explained below:
  - 1: We've only given a single input text
  - 15: The length of the input sequence
  - 9: The number of labels

For our custom model, we will want to output a tensor of the same shape for the same input: `[1, 15, 9]`

There is a score for each of the possible nine labels. Let's see the scores for the first token in the sequence:

In [None]:
output.logits[0,0,:]

You could use `.argmax` to find the label that has the highest score. We don't need to do that here.

As an aside, these scores are often passed through a [softmax](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html) to turn them into values between 0 and 1 that sum to 1 across the labels.

In [None]:
from scipy.special import softmax

softmax(output.logits[0,0,:].tolist())
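The same operations can also be done in PyTorch for every token at once. This is a minimal sketch, assuming a random tensor `logits` as a stand-in for `output.logits` with the shape `[1, 15, 9]` seen above; the values themselves are meaningless.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(1, 15, 9)  # random stand-in for output.logits (assumption)

# Normalise over the last dimension, i.e. across the nine labels of each token
probs = torch.softmax(logits, dim=-1)

# Highest-scoring label id for each token
predicted = logits.argmax(dim=-1)

print(probs.shape)      # torch.Size([1, 15, 9])
print(predicted.shape)  # torch.Size([1, 15])

# Each token's probabilities sum to 1
print(torch.allclose(probs.sum(dim=-1), torch.ones(1, 15)))  # True
```

Softmax preserves the ordering of the scores, so taking `.argmax` of the logits or of the probabilities selects the same label.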

The other important output from this `AutoModelForTokenClassification` is the loss. This is a single number that the fine-tuning tries to minimise. It is calculated by comparing the `logits` above against the provided target `labels`.

In [None]:
output.loss

## Create a custom TokenClassifierOutput

Now let's say that we want to make our own model that can be used for TokenClassification but does things slightly differently.

We might start off with a general-purpose `AutoModel` that doesn't have a final task-specific layer on it. To get access to the hidden states of every layer, we can pass `output_hidden_states=True`.

In [None]:
from transformers import AutoModel

model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True)

Then we can get the output of this model and rework it for what we need it to do.

In [None]:
output = model(**encoded)

We can get access to all the hidden states. Let's see how many and their dimensions.

In [None]:
for i,hidden_state in enumerate(output.hidden_states):
  print(i, hidden_state.shape)

There are 12 hidden layers in a standard BERT model, so why are there 13 hidden states? The first entry is the output of the embedding layer (the input to the first Transformer layer), and the remaining 12 are the outputs of each layer in turn. That comes to 13 sets of context vectors. All of the context vectors have dimension 768, which is common for standard BERT models.
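The count of hidden states can be checked without downloading any weights. The sketch below builds a small, randomly initialised BERT from a hypothetical configuration (`tiny_model`, with sizes made up purely so that it runs quickly) and counts the hidden states that come back.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration, chosen only so the example runs quickly
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=2,
    intermediate_size=64,
)
tiny_model = BertModel(config)  # randomly initialised, nothing is downloaded
tiny_model.eval()

# Seven made-up token ids for a single sequence
input_ids = torch.randint(low=0, high=config.vocab_size, size=(1, 7))
with torch.no_grad():
    tiny_output = tiny_model(input_ids=input_ids, output_hidden_states=True)

# One hidden state for the embeddings plus one per Transformer layer
print(len(tiny_output.hidden_states))      # 5
print(tiny_output.hidden_states[0].shape)  # torch.Size([1, 7, 32])

# The last entry matches the model's final hidden state
matches_final = torch.allclose(tiny_output.hidden_states[-1], tiny_output.last_hidden_state)
print(matches_final)
```

With four layers the model returns five hidden states, following the same "layers plus one" pattern as the 13 states above.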

Now our target shape is `[1, 15, 9]`. Each hidden state has shape `[1, 15, 768]`, so it is almost there, but the last dimension needs to be smaller. For this, we can use a fully-connected linear layer to go from 768 down to 9:

In [None]:
import torch

linear = torch.nn.Linear(768, num_labels)

If we apply the linear layer to the final hidden state, we get the logits with the desired shape.

In [None]:
logits = linear(output.hidden_states[-1])
logits.shape

But we could also apply it to one of the intermediate hidden states!

In [None]:
logits = linear(output.hidden_states[5])
logits.shape
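It may help to see that the linear layer only acts on the last dimension, so each token's context vector is mapped to its nine scores independently of the other tokens. The sketch below assumes a random tensor `hidden_state` as a stand-in for one of the real hidden states, and the token index 3 is arbitrary.

```python
import torch

torch.manual_seed(0)
hidden_state = torch.randn(1, 15, 768)  # random stand-in for one hidden state (assumption)
linear = torch.nn.Linear(768, 9)

# Applied to the whole tensor: only the last dimension changes
logits = linear(hidden_state)
print(logits.shape)  # torch.Size([1, 15, 9])

# Applied to a single token's vector: the same nine scores come out
token_logits = linear(hidden_state[0, 3])
print(token_logits.shape)  # torch.Size([9])
print(torch.allclose(logits[0, 3], token_logits, atol=1e-5))  # True
```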

At the moment, the linear layer is not fine-tuned, so the output logits would be meaningless. But with fine-tuning, these logits could give us the scores for each of the nine labels, with the highest score being the predicted label for that token.

To effectively train it, we need to calculate the loss between provided labels and the model's current logits for that input. Then the training process can slowly move the logits towards the desired labels. So how do we calculate the loss?

First, remember what the labels look like. We've got one input sequence and an integer representing the labels for each of the fifteen tokens.

In [None]:
labels.shape

Now to calculate the loss, we use [CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) which is used for multi-class classification problems. It expects two inputs:

- The logits in the shape of (sample_count, num_labels)
- The labels (as integers) in the shape (sample_count).

We can use [.reshape](https://pytorch.org/docs/stable/generated/torch.reshape.html) as below to adjust the shapes accordingly. And recall that CrossEntropyLoss will nicely ignore the tokens with `-100` labels as they shouldn't contribute to the loss.

In [None]:
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(logits.reshape(-1,num_labels), labels.reshape(-1))
loss
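Flattening is one of two ways to satisfy the shape requirements. `CrossEntropyLoss` also accepts logits shaped `(batch, num_labels, seq_len)` with labels shaped `(batch, seq_len)`, so moving the label dimension to position 1 should give the same value. The sketch below uses random stand-in tensors and a made-up batch size of 2 to compare the two options.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, num_labels = 2, 15, 9

# Random stand-ins (assumptions for illustration)
logits = torch.randn(batch_size, seq_len, num_labels)
labels = torch.randint(low=0, high=num_labels, size=(batch_size, seq_len))
labels[:, 0] = -100   # ignore the first token of each sequence
labels[:, -1] = -100  # ignore the last token of each sequence

loss_func = torch.nn.CrossEntropyLoss()

# Option 1: flatten the batch and sequence dimensions together
flat_loss = loss_func(logits.reshape(-1, num_labels), labels.reshape(-1))

# Option 2: move the label dimension to position 1 -> (batch, num_labels, seq_len)
permuted_loss = loss_func(logits.permute(0, 2, 1), labels)

print(torch.allclose(flat_loss, permuted_loss))  # True
```

Both forms average over the same set of non-ignored tokens, so the choice between them is a matter of readability.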

Now that we've calculated the logits and loss, we can create a `TokenClassifierOutput` object that encapsulates them. This gives us an output in the same form as the one returned by `AutoModelForTokenClassification`.

In [None]:
from transformers.modeling_outputs import TokenClassifierOutput

TokenClassifierOutput(loss=loss, logits=logits)

## Creating a custom model

To actually use this custom approach, we need to wrap it up as a `torch.nn.Module`. The example class below takes the various steps from before and puts them into a single class:

In [None]:
class CustomModel(torch.nn.Module):
  def __init__(self, num_labels, hidden_layer):
    super().__init__()
    # Base Transformer that returns the hidden states of every layer
    self.base_model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states = True)

    self.num_labels = num_labels
    self.hidden_layer = hidden_layer # Index into output.hidden_states (0 is the embedding output)

    # Map each token's context vector (size 768 for this model) to one score per label
    self.linear = torch.nn.Linear(self.base_model.config.hidden_size, num_labels)
    self.loss_func = torch.nn.CrossEntropyLoss() # Ignores labels of -100 by default

  def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
    output = self.base_model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)

    # Use the selected hidden state instead of the final one
    logits = self.linear(output.hidden_states[self.hidden_layer])

    loss = None
    if labels is not None: # If we're provided with labels, use them to calculate the loss
      # Flatten to (token_count, num_labels) and (token_count) for CrossEntropyLoss
      loss = self.loss_func(logits.reshape(-1,self.num_labels), labels.reshape(-1))

    return TokenClassifierOutput(loss=loss, logits=logits)

Note that the above class works very similarly to the actual implementation behind `AutoModelForTokenClassification`. You can have a look at `BertForTokenClassification` in the [Hugging Face source code](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1714).

One key difference is that this implementation does not apply [dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html) before the linear layer. Adding it may be beneficial.

Now we can create a model and select which hidden state to connect to the output layer. The layers after the selected one are still evaluated in the forward pass, but they no longer contribute to the logits.

In [None]:
model = CustomModel(num_labels=num_labels, hidden_layer=5)

Let's pass in the tokenized text with the labels and see what we get:

In [None]:
output = model(**encoded_with_labels)
output

Excellent. The [HuggingFace trainer](https://huggingface.co/docs/transformers/main_classes/trainer) can now be used to fine-tune the model on an appropriate dataset.
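To give a feel for what fine-tuning does with the returned loss, here is a rough sketch of a manual optimisation loop. It is not how the Trainer is implemented, and it does not use `CustomModel`: `TinyTokenClassifier` is a hypothetical stand-in that replaces the Transformer with a small embedding layer and returns a plain dict, so that the example runs quickly without downloading a model. The vocabulary size, learning rate and step count are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
num_labels = 9

class TinyTokenClassifier(torch.nn.Module):
  """Hypothetical stand-in that takes labels and returns a loss and logits, like CustomModel."""
  def __init__(self, vocab_size, hidden_size, num_labels):
    super().__init__()
    self.num_labels = num_labels
    self.embedding = torch.nn.Embedding(vocab_size, hidden_size)  # replaces the Transformer
    self.linear = torch.nn.Linear(hidden_size, num_labels)
    self.loss_func = torch.nn.CrossEntropyLoss()

  def forward(self, input_ids, labels=None):
    logits = self.linear(self.embedding(input_ids))
    loss = None
    if labels is not None:
      loss = self.loss_func(logits.reshape(-1, self.num_labels), labels.reshape(-1))
    return {"loss": loss, "logits": logits}

model = TinyTokenClassifier(vocab_size=100, hidden_size=16, num_labels=num_labels)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)

# Made-up batch: one sequence of 15 token ids with random labels
input_ids = torch.randint(low=0, high=100, size=(1, 15))
labels = torch.randint(low=0, high=num_labels, size=(1, 15))
labels[0, 0] = -100
labels[0, -1] = -100

losses = []
for step in range(20):
  output = model(input_ids=input_ids, labels=labels)
  optimizer.zero_grad()
  output["loss"].backward()  # gradients of the loss with respect to the weights
  optimizer.step()           # nudge the weights to reduce the loss
  losses.append(output["loss"].item())

print(losses[0], losses[-1])  # the loss should fall as the model fits this one example
```

The Trainer automates this kind of loop and adds batching, evaluation and checkpointing on top, which is why returning a loss alongside the logits matters.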