<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/CANINE/Fine_tune_CANINE_on_IMDb_(movie_review_binary_classification).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune CANINE for binary text classification

In this notebook, we are going to fine-tune Google's character-level [CANINE](https://arxiv.org/abs/2103.06874) model to classify movie reviews as either positive or negative. We will do so using HuggingFace Transformers (I contributed CANINE in PyTorch to it!). The dataset we are going to use is the [IMDB dataset](https://huggingface.co/datasets/imdb), which is a large collection of movie reviews labeled as positive or negative.

For training, we will use [PyTorch Lightning](https://www.pytorchlightning.ai/) (note that you could also use alternative solutions such as native PyTorch, the [HuggingFace Trainer](https://huggingface.co/transformers/main_classes/trainer.html), [HuggingFace Accelerate](https://github.com/huggingface/accelerate), etc.). For logging the metrics (such as loss and accuracy) during training, we will use Weights and Biases.

Note that this notebook is very similar to how we would fine-tune a BERT model for binary text classification. The main difference is that BERT uses word pieces (subword tokenization), whereas CANINE works at the character level.

To give an example, if you provide the sentence "hello world" to BERT, it is first tokenized into the word pieces ["hello", "world"]. Then, BERT converts each word piece into a vector (also referred to as a hidden state). For BERT-base, this is a vector of size 768. CANINE, on the other hand, "tokenizes" the sentence into ["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], i.e. it splits the sentence into its individual characters. Then, CANINE converts each character into a vector (for CANINE, this is also a vector of size 768). Classification of sequences works the same way for BERT and CANINE: one simply places a linear layer on top of the final hidden state of the special [CLS] token.
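
To make the character-level view concrete, here is a minimal pure-Python sketch. CANINE's tokenizer turns each character into its Unicode code point; the helper names `char_tokenize` and `to_code_points` are hypothetical and only serve as illustration, and the special tokens (such as [CLS]) that the real tokenizer adds are left out.

```python
def char_tokenize(text):
    """Split a string into its individual characters."""
    return list(text)


def to_code_points(text):
    """Map each character to its Unicode code point (an integer id)."""
    return [ord(ch) for ch in text]


chars = char_tokenize("hello world")
ids = to_code_points("hello world")

print(chars)  # one entry per character, including the space
print(ids)    # one integer per character
```

Because the ids are just code points, no vocabulary file is needed to encode a new character.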

* CANINE paper: https://arxiv.org/abs/2103.06874
* CANINE documentation: https://huggingface.co/transformers/model_doc/canine.html

## Install dependencies

In [1]:
!pip install -q transformers

In [2]:
!pip install -q datasets pytorch_lightning wandb

## Prepare data

Here we load a small portion of the IMDb dataset which is hosted on the HuggingFace hub, for demonstration purposes.

In [3]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=['train[:100]', 'test[:50]'])

Let's look at one particular example:

In [4]:
train_ds[0]

{'label': 1,
 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}

In [5]:
labels = train_ds.features['label'].names
print(labels)

['neg', 'pos']


In [6]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}


In [7]:
from transformers import CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")

train_ds = train_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                        batched=True)
test_ds = test_ds.map(lambda examples: tokenizer(examples['text'], padding="max_length", truncation=True),
                      batched=True)

Using unk_token, but it is not set yet.
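
The `padding="max_length"` and `truncation=True` arguments make every example the same length, so that examples can be stacked into batches. Below is a rough sketch of that idea in plain Python. The helper `pad_or_truncate`, the padding id of 0 and the toy `max_length=16` are assumptions chosen for illustration; the real tokenizer also adds special tokens, which are omitted here.

```python
def pad_or_truncate(ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = ids[:max_length]              # truncation: drop everything past max_length
    n_pad = max_length - len(ids)       # padding: fill the remainder with pad_id
    input_ids = ids + [pad_id] * n_pad
    # 1 marks a real character, 0 marks padding the model should ignore
    attention_mask = [1] * len(ids) + [0] * n_pad
    return input_ids, attention_mask


code_points = [ord(ch) for ch in "great movie"]
input_ids, attention_mask = pad_or_truncate(code_points, max_length=16)

print(input_ids)
print(attention_mask)
```

Since CANINE works on characters, sequences are much longer than with word pieces, which is why a long review may well be truncated.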

In [8]:
train_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
test_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])

train_ds = train_ds.rename_column(original_column_name="label", new_column_name="labels")
test_ds = test_ds.rename_column(original_column_name="label", new_column_name="labels")

In [9]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=4)

In [10]:
batch = next(iter(train_dataloader))

In [11]:
tokenizer.decode(batch['input_ids'][2])

"[CLS]I enjoyed The Night Listener very much. It's one of the better movies of the summer.<br /><br />Robin Williams gives one of his best performances. In fact, the entire cast was very good. All played just the right notes for their characters - not too much and not too little. Sandra Oh adds a wonderful comic touch. Toni Collette is great as the Mom, and never goes over the top. Everyone is very believable.<br /><br />It's a short movie, just under an hour and a half. I noticed the general release version is nine minutes shorter than the Sundance version. I wonder if some of the more disturbing images were cut from the movie.<br /><br />The director told a story and did it in straightforward fashion, which is a refreshing change from many directors these days who seem to think their job is to impress the audience rather than tell a story and tell it well.<br /><br />Do not be sucker punched by the previews and ads. It is not a Hitchcockian thriller. See The Night Listener because yo

In [12]:
id2label[batch['labels'][2].item()]

'pos'

## Define model

In [13]:
import pytorch_lightning as pl
from torch.optim import AdamW  # the AdamW class shipped with transformers is deprecated in recent releases
from transformers import CanineForSequenceClassification

class CanineReviewClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # `labels`, `id2label` and `label2id` are the globals defined in the data preparation cells
        self.model = CanineForSequenceClassification.from_pretrained('google/canine-s',
                                                                     num_labels=len(labels),
                                                                     id2label=id2label,
                                                                     label2id=label2id)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                             labels=labels)

        return outputs
        
    def common_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss
        logits = outputs.logits

        predictions = logits.argmax(-1)
        correct = (predictions == batch['labels']).sum().item()
        accuracy = correct/batch['input_ids'].shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # inside training_step, self.log() records the value at each step by default;
        # pass on_epoch=True to also log the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return test_dataloader
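
The accuracy computed in `common_step` is simply the fraction of examples in the batch whose highest-scoring class matches the label. Here is a small sketch of the same computation on plain Python lists; the logits and labels are made-up values, and `argmax` and `batch_accuracy` are hypothetical helper names.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=lambda i: values[i])


def batch_accuracy(logits, labels):
    """Fraction of rows whose argmax equals the corresponding label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# hypothetical logits for a batch of 4 reviews, columns are [neg, pos]
logits = [[2.1, -0.3], [0.2, 1.7], [-1.0, 0.4], [0.9, 0.1]]
labels = [0, 1, 0, 0]

accuracy = batch_accuracy(logits, labels)
print(accuracy)
```

With a batch size of 4, the per-batch accuracy can only take the values 0, 0.25, 0.5, 0.75 and 1, so the step-level curve in Weights and Biases will look noisy.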

## Train the model

In [14]:
import wandb

wandb.login()

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc

True

In [15]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

model = CanineReviewClassifier()

wandb_logger = WandbLogger(name='canine-imdb-1', project='CANINE')
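# note: PyTorch Lightning 2.x removed the `gpus` argument; there, use accelerator="gpu", devices=1 instead
trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[EarlyStopping(monitor='validation_loss')])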
trainer.fit(model)

Some weights of CanineForSequenceClassification were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
wandb: Currently logged in as: nielsrogge (use `wandb login --relogin` to force relogin)



  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | CanineForSequenceClassification | 132 M 
----------------------------------------------------------
132 M     Trainable params
0         Non-trainable params
132 M     Total params
528.338   Total estimated model params size (MB)


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
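
The `EarlyStopping(monitor='validation_loss')` callback stops training once the monitored value has not improved for a number of consecutive validation checks (the `patience`, which defaults to 3). The following is a simplified sketch of that logic, not Lightning's actual implementation; `early_stopping_epoch` is a hypothetical helper and the loss values are made up.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the index of the epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # improvement: remember the new best value and reset the counter
            best = loss
            wait = 0
        else:
            # no improvement: stop once we have waited `patience` checks
            wait += 1
            if wait >= patience:
                return epoch
    return None


# hypothetical validation losses, one per epoch
losses = [0.70, 0.65, 0.66, 0.67, 0.68, 0.60]
stop_at = early_stopping_epoch(losses)
print(stop_at)
```

In this toy run, the best loss is reached at epoch 1 and training stops three checks later, so the later improvement at epoch 5 is never seen. A larger `patience` trades longer training for robustness against such noise.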


## Inference

After training, we can save the model as follows:

In [17]:
model.model.save_pretrained('.')

We can then load it back as follows:

In [18]:
from transformers import CanineForSequenceClassification

model = CanineForSequenceClassification.from_pretrained('.')

Let's test it on a new review:

In [20]:
text = "I absolutely love this movie"

# prepare text for the model
encoding = tokenizer(text, return_tensors="pt")

# forward pass
outputs = model(**encoding)

# convert logits to actual predicted class
logits = outputs.logits
pred_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[pred_class_idx])

Predicted class: LABEL_1
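
If we want a confidence score rather than only the predicted class, we can apply a softmax to the logits. Here is a small sketch in plain Python; the logit values are made up, `softmax` is a hypothetical helper, and the `id2label` mapping is the one defined earlier in this notebook.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


id2label = {0: 'neg', 1: 'pos'}

logits = [-1.2, 2.3]  # hypothetical logits for a single review
probs = softmax(logits)
pred_class_idx = max(range(len(probs)), key=probs.__getitem__)

print("Predicted class:", id2label[pred_class_idx])
print("Probability:", round(probs[pred_class_idx], 3))
```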
