<a href="https://colab.research.google.com/github/LxYuan0420/nlp/blob/main/Transformers-Tutorials/CANINE/Fine_tune_CANINE_on_IMDb_(movie_review_binary_classification).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune CANINE for binary text classification

In this notebook, we are going to fine-tune Google's character-level [CANINE](https://arxiv.org/abs/2103.06874) model to classify movie reviews as either positive or negative. We will do so using HuggingFace Transformers (I contributed CANINE in PyTorch to it!). The dataset we are going to use is the [IMDb dataset](https://huggingface.co/datasets/imdb), a large collection of movie reviews (25,000 for training and 25,000 for testing) labeled as positive or negative.

For training, we will use [PyTorch Lightning](https://www.pytorchlightning.ai/) (note that you could also use alternatives such as native PyTorch, the [HuggingFace Trainer](https://huggingface.co/transformers/main_classes/trainer.html), [HuggingFace Accelerate](https://github.com/huggingface/accelerate), etc.). For logging metrics such as loss and accuracy during training, we will use Weights & Biases.

Note that this notebook is very similar to how we would fine-tune a BERT model for binary text classification. The main difference is that BERT uses word pieces (subword tokenization), whereas CANINE operates at the character level.

To give an example, if you provide the sentence "hello world" to BERT, it is first tokenized into the word pieces ["hello", "world"]. BERT then converts each word piece into a vector (also referred to as a hidden state); for BERT-base, this is a vector of size 768. CANINE, on the other hand, "tokenizes" the sentence into ["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], i.e. it splits the sentence into its individual characters (Unicode code points), and converts each character into a vector (for CANINE, also of size 768). Sequence classification works the same way for BERT and CANINE: one simply places a linear layer on top of the final hidden state of the special [CLS] token (in the HuggingFace implementations, this hidden state first passes through a small pooling layer).

* CANINE paper: https://arxiv.org/abs/2103.06874
* CANINE documentation: https://huggingface.co/transformers/model_doc/canine.html
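
To make the character-level "tokenization" described above concrete, here is a minimal pure-Python sketch of what it amounts to: each character becomes its Unicode code point, wrapped in [CLS]/[SEP] markers and optionally padded. The helper `char_ids` is hypothetical, and the special code points (0xE000 for [CLS], 0xE001 for [SEP], 0 for [PAD]) are an assumption meant to mirror `CanineTokenizer`; in the rest of the notebook we simply call the real tokenizer.

```python
# Hypothetical sketch of CANINE-style character "tokenization" (not the library code).
# Assumed special code points, meant to mirror CanineTokenizer's private-use values.
CLS, SEP, PAD = 0xE000, 0xE001, 0


def char_ids(text, max_length=None):
    """Map text to [CLS] + one Unicode code point per character + [SEP], optionally padded."""
    chars = [ord(c) for c in text]
    if max_length is not None:
        chars = chars[: max_length - 2]  # leave room for [CLS] and [SEP]
    ids = [CLS] + chars + [SEP]
    if max_length is not None:
        ids += [PAD] * (max_length - len(ids))  # right-pad up to max_length
    return ids


print(list("hello world"))      # the 11 characters CANINE sees
print(char_ids("hello world"))  # [57344, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 57345]
```

No vocabulary file is involved: any string, in any script, maps to ids directly, at the cost of much longer sequences (13 ids here, versus 4 for BERT's [CLS] hello world [SEP]).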

## Install dependencies

In [9]:
!pip install -q transformers

In [10]:
!pip install -q datasets pytorch_lightning wandb

## Prepare data

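Here, for demonstration purposes, we load a small portion of the IMDb dataset hosted on the HuggingFace Hub: the first 100 training reviews and the first 50 test reviews. Be aware that the IMDb splits are sorted by label (all negative reviews come first), so these slices contain only negative reviews; this is why the model at the end of this notebook predicts "neg" even for a clearly positive review. For a meaningful experiment, shuffle the splits (e.g. with `Dataset.shuffle(seed=...)`) before taking a subset.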

In [17]:
from datasets import load_dataset

train_ds, test_ds = load_dataset("imdb", split=["train[:100]", "test[:50]"])

Let's look at one particular example:

In [12]:
train_ds[0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

In [13]:
labels = train_ds.features["label"].names
print(labels)

['neg', 'pos']
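
Since the labels are `neg`/`pos` and the IMDb splits are sorted by label, a slice like `train[:100]` is not representative. A minimal sketch of one alternative, assuming you load the full split and take a plain list of its labels: draw an equal number of random indices per class, then pass them to `Dataset.select`. The helper `balanced_indices` is hypothetical, not part of HuggingFace Datasets.

```python
import random
from collections import defaultdict


def balanced_indices(labels, n_per_class, seed=42):
    """Return shuffled indices with n_per_class examples drawn at random from each label."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)  # local RNG so the subset is reproducible
    picked = []
    for label in sorted(by_label):
        picked.extend(rng.sample(by_label[label], n_per_class))
    rng.shuffle(picked)  # interleave the classes instead of keeping them grouped
    return picked


# toy label list sorted like IMDb: all 0s (neg) first, then all 1s (pos)
toy_labels = [0] * 10 + [1] * 10
subset = balanced_indices(toy_labels, n_per_class=3)
print(subset)
```

With the real data, this could look like `full_train = load_dataset("imdb", split="train")` followed by `full_train.select(balanced_indices(full_train["label"], 50))`.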


In [14]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'neg', 1: 'pos'}


In [18]:
from transformers import CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")

def tokenize(examples):
    # pad/truncate every review to the tokenizer's maximum length (in characters)
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_ds = train_ds.map(tokenize, batched=True)
test_ds = test_ds.map(tokenize, batched=True)

Using unk_token, but it is not set yet.
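
With `padding="max_length"`, every review is padded up to the tokenizer's maximum length, so short reviews carry many [PAD] positions (as the decoded batch further below shows). A common alternative is dynamic padding: pad each batch only to its longest sequence. A minimal pure-Python sketch, with a hypothetical `pad_batch` helper and the pad id assumed to be 0 as for CANINE:

```python
def pad_batch(sequences, pad_id=0):
    """Pad variable-length id lists to the longest one and build the matching attention masks."""
    longest = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = longest - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = real character, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}


padded = pad_batch([[57344, 104, 105, 57345], [57344, 111, 107, 33, 57345]])
print(padded["input_ids"])       # [[57344, 104, 105, 57345, 0], [57344, 111, 107, 33, 57345]]
print(padded["attention_mask"])  # [[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```

In practice, one option is to tokenize without `padding="max_length"` and pass `transformers.DataCollatorWithPadding(tokenizer)` as the DataLoader's `collate_fn`, which also pads `token_type_ids`.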


In [20]:
train_ds[0].keys()

dict_keys(['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'])

In [21]:
train_ds.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
test_ds.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])

train_ds = train_ds.rename_column(original_column_name="label", new_column_name="labels")
test_ds = test_ds.rename_column(original_column_name="label", new_column_name="labels")

In [23]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dataloader = DataLoader(test_ds, batch_size=4)

In [24]:
batch = next(iter(train_dataloader))

In [25]:
tokenizer.decode(batch['input_ids'][2])

'[CLS]Not only is it a disgustingly made low-budget bad-acted movie, but the plot itself is just STUPID!!!<br /><br />A mystic man that eats women? (And by the looks, not virgin ones)<br /><br />Ridiculous!!! If you´ve got nothing better to do (like sleeping) you should watch this. Yeah right.[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][

In [26]:
id2label[batch['labels'][2].item()]

'neg'
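
`tokenizer.decode` reverses the character mapping. A hypothetical pure-Python sketch of that inverse, which also drops the special ids (assumed for CANINE to be 0 for [PAD], 0xE000 for [CLS] and 0xE001 for [SEP]); it is similar in spirit to `tokenizer.decode(ids, skip_special_tokens=True)`, not a copy of it:

```python
# Assumed CANINE special code points: [PAD] = 0, [CLS] = 0xE000, [SEP] = 0xE001
SPECIAL_IDS = {0, 0xE000, 0xE001}


def decode_char_ids(ids):
    """Hypothetical inverse of character-level encoding: code points back to text, specials dropped."""
    return "".join(chr(i) for i in ids if i not in SPECIAL_IDS)


ids = [0xE000, 78, 111, 116, 32, 103, 111, 111, 100, 0xE001, 0, 0, 0]
print(decode_char_ids(ids))  # Not good
```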

## Define model

In [31]:
import pytorch_lightning as pl
import torch
from transformers import CanineForSequenceClassification


class CanineReviewClassifier(pl.LightningModule):
    def __init__(self, num_labels=len(id2label)):
        super().__init__()
        # pre-trained CANINE encoder with a freshly initialized classification head on top
        self.model = CanineForSequenceClassification.from_pretrained("google/canine-s",
                                                                     num_labels=num_labels,
                                                                     id2label=id2label,
                                                                     label2id=label2id)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        # when labels are given, the HuggingFace model also returns the cross-entropy loss
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )

        return outputs

    def common_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss
        logits = outputs.logits

        # batch accuracy: fraction of argmax predictions that match the labels
        predictions = logits.argmax(-1)
        labels = batch["labels"]
        correct = (predictions == labels).sum().item()
        accuracy = correct / batch["input_ids"].shape[0]

        return loss, accuracy

    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # logs metrics for each training step
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        # logged values are averaged across the validation epoch
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay, but just using AdamW out-of-the-box works fine.
        # transformers.AdamW is deprecated, so we use the PyTorch implementation.
        return torch.optim.AdamW(self.parameters(), lr=5e-5)

    def train_dataloader(self):
        return train_dataloader
    
    def val_dataloader(self):
        return test_dataloader
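
The accuracy in `common_step` is the number of argmax predictions that match the labels, divided by the batch size. The same arithmetic in plain Python, on made-up logits for a batch of four reviews (the values are purely illustrative):

```python
def batch_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label (argmax over the last dim)."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# made-up logits for [neg, pos] on 4 reviews, and their true labels
example_logits = [[2.1, -1.3], [0.2, 0.9], [-0.5, 1.7], [1.0, 0.4]]
example_labels = [0, 1, 0, 0]
print(batch_accuracy(example_logits, example_labels))  # 0.75
```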


## Train the model

In [None]:
import wandb

wandb.login()

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc

True

In [32]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

model = CanineReviewClassifier()

# log training/validation metrics to Weights & Biases (requires the wandb.login() above)
wandb_logger = WandbLogger(name="canine-imdb-1", project="CANINE")

trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=10,  # hard cap, in case early stopping never triggers
    logger=wandb_logger,
    callbacks=[EarlyStopping(monitor="validation_loss")],
)
trainer.fit(model)

Some weights of CanineForSequenceClassification were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | CanineForSequenceClassification | 132 M 
----------------------------------------------------------
132 M     Trainable params
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
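
The run above was stopped by hand (note the KeyboardInterrupt), i.e. `EarlyStopping` never triggered. It stops training once the monitored metric has failed to improve for `patience` consecutive validation checks; in Lightning the defaults are `mode="min"`, `min_delta=0.0` and `patience=3`. A simplified pure-Python sketch of that rule (the class `SimpleEarlyStopping` and the helper `first_stop` are hypothetical and ignore details of the real callback), showing that a validation loss which keeps improving never triggers it:

```python
import math


class SimpleEarlyStopping:
    """Simplified early-stopping rule for a metric to minimize (not Lightning's implementation)."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.wait = 0  # consecutive checks without improvement

    def should_stop(self, value):
        if value < self.best - self.min_delta:
            self.best = value  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def first_stop(losses, patience=3):
    """Index of the validation check at which training would stop, or None."""
    stopper = SimpleEarlyStopping(patience=patience)
    for i, loss in enumerate(losses):
        if stopper.should_stop(loss):
            return i
    return None


print(first_stop([0.70, 0.50, 0.40, 0.30, 0.20]))  # None: the loss improves at every check
print(first_stop([0.70, 0.50, 0.55, 0.60, 0.52]))  # 4
```

Passing `max_epochs` to the `Trainer` bounds training regardless of whether early stopping fires.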


## Inference

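After training, we can save the fine-tuned HuggingFace model (the `model` attribute of our LightningModule) as follows:

In [None]:
save_dir = "canine-imdb"  # local output directory (name chosen for illustration)

model.model.save_pretrained(save_dir)

We can then load it back as follows:

In [None]:
from transformers import CanineForSequenceClassification

model = CanineForSequenceClassification.from_pretrained(save_dir)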

Let's test it on a new review:

In [34]:
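# if the model was not reloaded above, `model` is still the LightningModule: unwrap the HuggingFace model
if isinstance(model, pl.LightningModule):
    model = model.model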

In [36]:
model.device

device(type='cuda', index=0)

In [38]:
import torch

text = "I absolutely love this movie. This is the best movie i have ever watched. Love it. Good"

# prepare text for the model, on whatever device the model lives on
encoding = tokenizer(text, return_tensors="pt").to(model.device)

# forward pass in evaluation mode (disables dropout) and without tracking gradients
model.eval()
with torch.no_grad():
    outputs = model(**encoding)

# convert logits to actual predicted class
logits = outputs.logits
pred_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[pred_class_idx])

Predicted class: neg
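
Besides the predicted class, the logits can be turned into class probabilities with a softmax; in PyTorch that would be `logits.softmax(-1)`. The same computation in plain Python, on made-up logits (the numbers are illustrative, not the output above):

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([-1.2, 2.3])  # made-up logits for [neg, pos]
for name, p in zip(["neg", "pos"], probs):
    print(f"{name}: {p:.3f}")
```

Subtracting the maximum does not change the result (it cancels in the ratio) but avoids overflow for large logits.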
