# Fine-tune CANINE for Romanian diacritic restoration

In this notebook, we are going to fine-tune Google's character-level [CANINE](https://arxiv.org/abs/2103.06874) model to restore Romanian diacritics: for every character of an input text whose diacritics have been stripped, the model predicts either "no diacritic" or one of the diacritic characters (ă, î, â, ș, ț and their uppercase forms). We will do so using HuggingFace Transformers (I contributed CANINE in PyTorch to it!). The dataset we are going to use is the [diacritic dataset](https://huggingface.co/datasets/dumitrescustefan/diacritic), a collection of Romanian text that contains diacritics.

For training, we will use [PyTorch Lightning](https://www.pytorchlightning.ai/) (note that you could also use alternative solutions such as native PyTorch, the [HuggingFace Trainer](https://huggingface.co/transformers/main_classes/trainer.html), [HuggingFace Accelerate](https://github.com/huggingface/accelerate), etc.). For logging the metrics (such as loss and accuracy) during training, we will use Weights and Biases.

Note that this notebook is very similar to how we would fine-tune a BERT model for classification. The main difference is that BERT uses word pieces (subword tokenization), whereas CANINE works at the character level.

To give an example, if you provide the sentence "hello world" to BERT, it is first tokenized into the word pieces ["hello", "world"]. BERT then converts each word piece into a vector (also referred to as a hidden state); for BERT-base, this vector has size 768. CANINE, on the other hand, "tokenizes" the sentence into ["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], i.e. it splits it into individual characters, and then converts each character into a vector (for CANINE, also of size 768). Classification of sequences is the same for BERT and CANINE: one simply places a linear layer on top of the final hidden state of the special [CLS] token. For per-character predictions, as in the model defined later in this notebook, the linear layer is instead applied to the final hidden state of every character.
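To make the character-level view concrete, here is a minimal sketch in plain Python with no model involved. It assumes, as the CANINE paper describes, that each character is represented by its Unicode code point. The helper name `to_code_points` is hypothetical, and this is not the library's tokenizer, which additionally adds special tokens such as [CLS] and [SEP].

```python
def to_code_points(text):
    """Map every character of `text` to its Unicode code point."""
    return [ord(ch) for ch in text]


sentence = "hello world"
chars = list(sentence)            # one entry per character, including the space
ids = to_code_points(sentence)    # one integer per character

print(chars)
print(ids)
```

Because the "vocabulary" is just the set of Unicode code points, no vocabulary file is needed and no character is ever out of vocabulary.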

* CANINE paper: https://arxiv.org/abs/2103.06874
* CANINE documentation: https://huggingface.co/transformers/model_doc/canine.html

## Install dependencies

In [1]:
# !pip install -q transformers

In [2]:
# !pip install -q datasets pytorch_lightning wandb

In [3]:
import numpy as np
import torch
import re

# # The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
# torch.backends.cuda.matmul.allow_tf32 = True

# # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
# torch.backends.cudnn.allow_tf32 = True



## Prepare data

Here we load the `dumitrescustefan/diacritic` dataset, which is hosted on the HuggingFace hub, and keep the first 120,000 training examples and the first 50,000 validation examples for demonstration purposes.

In [4]:
from datasets import load_dataset

# train_ds, test_ds = load_dataset("dumitrescustefan/diacritic", split=['train[:100]', 'validation[:50]'])
dataset = load_dataset("dumitrescustefan/diacritic")
dataset["train"] = dataset["train"].select(list(range(120_000)))
dataset["validation"] = dataset["validation"].select(list(range(50_000)))
train_ds = dataset
train_ds

No config specified, defaulting to: diacritic/v1
Found cached dataset diacritic (/home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970)
100%|██████████| 2/2 [00:00<00:00, 49.37it/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'text'],
        num_rows: 120000
    })
    validation: Dataset({
        features: ['id', 'text'],
        num_rows: 50000
    })
})

In [5]:
# train_ds = train_ds.map(lambda el:{"text" : str.lower(el["text"])}, num_proc=5)

In [6]:
# dataset.map(lambda ex : {"bla" : [0]})
# dataset.map(lambda ex : ex, batched=True)

In [7]:
# per_label_counts = [3297059,  313306,  102377,   49724,  118957,  135583]
per_label_counts = [1.2545946e+07, 3.1330600e+05, 1.0237700e+05, 4.9724000e+04,
       1.1895700e+05, 1.3558300e+05, 8.4530000e+03, 1.1100000e+02,
       6.5700000e+02, 3.0930000e+03, 6.7000000e+02]
n_samples = sum(per_label_counts)
per_label_weights = [n_samples / (c * n_samples) for c in per_label_counts]
print("Label weights", per_label_weights)
max_length = 256

Label weights [7.970702249156819e-08, 3.1917677925095595e-06, 9.767818943708059e-06, 2.0111012790604135e-05, 8.40639895088141e-06, 7.375555932528415e-06, 0.0001183011948420679, 0.009009009009009009, 0.0015220700152207, 0.00032331070158422246, 0.0014925373134328358]
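Since `n_samples` cancels out in the expression above, each weight is simply the inverse of its class count. The sketch below checks this on the first three counts and also shows a commonly used alternative normalisation, `n_samples / (n_classes * count)`. That alternative is an assumption added for illustration and is not what this notebook uses.

```python
# First three class counts from the cell above.
counts = [1.2545946e+07, 3.13306e+05, 1.02377e+05]
n_samples = sum(counts)
n_classes = len(counts)

# Same expression as in the notebook: n_samples cancels, leaving 1 / count.
inverse_freq = [n_samples / (c * n_samples) for c in counts]

# Alternative (assumption, not used here): a perfectly balanced dataset
# would give every class a weight of exactly 1.0.
balanced = [n_samples / (n_classes * c) for c in counts]

for c, w, b in zip(counts, inverse_freq, balanced):
    print(f"count={c:.0f}  inverse_freq={w:.3e}  balanced={b:.3f}")
```

Both variants rank the classes identically; they differ only by a constant factor, which matters when the weighted loss is combined with a fixed learning rate.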


Next, we define the label set: one class for characters without a diacritic, plus one class for each Romanian diacritic character.

In [8]:
labels = [
    "no_diac", 
    "ă", 
    'î',
    "â",
    "ș",
    "ț",
    'Î',
    'Â',
    'Ă',
    'Ș',
    'Ț',
]

chars_with_diacritics = ['a','t','i','s', 'A', 'T', 'I', 'S']

print(labels)

['no_diac', 'ă', 'î', 'â', 'ș', 'ț', 'Î', 'Â', 'Ă', 'Ș', 'Ț']


In [9]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

{0: 'no_diac', 1: 'ă', 2: 'î', 3: 'â', 4: 'ș', 5: 'ț', 6: 'Î', 7: 'Â', 8: 'Ă', 9: 'Ș', 10: 'Ț'}


In [10]:
train_ds = train_ds.rename_column("text", "labels")
# train_ds[0]

In [11]:
percentage_diacritics_removed = 1.0
def remove_diacritics(input_txt):
    diac_map = {'ț': 't', 'ș': 's', 'Ț': 'T', 'Ș': 'S', 'Ă': 'A', 'ă': 'a', 'Â': 'A', 'â': 'a', 'Î': 'I', 'î': 'i'}
    diacritic_positions = [m.start() for m in re.finditer('ț|ș|Ț|Ș|Ă|ă|Â|â|Î|î', input_txt)]
    to_remove_diacritic_positions = np.random.choice(diacritic_positions, int(len(diacritic_positions) * percentage_diacritics_removed), replace=False)
    for i in range(len(to_remove_diacritic_positions)):
        input_txt = input_txt[:to_remove_diacritic_positions[i]]+ diac_map[input_txt[to_remove_diacritic_positions[i]]] + input_txt[to_remove_diacritic_positions[i]+1:]

    return input_txt


def add_partial_no_diac_input(examples):
    # percentage_diacritics_removed
    
    examples['input'] = [remove_diacritics(input_txt=l) for l in examples["labels"]]
    return examples
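With `percentage_diacritics_removed = 1.0`, every diacritic is replaced, so the function above behaves like a plain character translation. The following sketch shows that special case using `str.translate`; the names `DIAC_MAP`, `TABLE` and `strip_all_diacritics` are hypothetical and not part of the notebook.

```python
# Same mapping as in remove_diacritics above.
DIAC_MAP = {'ț': 't', 'ș': 's', 'Ț': 'T', 'Ș': 'S', 'Ă': 'A',
            'ă': 'a', 'Â': 'A', 'â': 'a', 'Î': 'I', 'î': 'i'}
TABLE = str.maketrans(DIAC_MAP)


def strip_all_diacritics(text):
    """Replace every Romanian diacritic with its base letter."""
    return text.translate(TABLE)


stripped = strip_all_diacritics("Știință și țară")
print(stripped)
```

Every replacement maps one character to one character, so the stripped text keeps the same length as the original. This is what allows labels and inputs to be aligned position by position later on.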

In [12]:
train_ds = train_ds.map(add_partial_no_diac_input, batched=True)
# train_ds[0]

 99%|█████████▉| 119/120 [00:01<00:00, 64.29ba/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-dc6eaf686f3d1683.arrow


In [13]:
def make_actual_labels(examples):
    def make_a_l(lbl):
        result = []
        for s in lbl:
            if s in chars_with_diacritics:
                result.append(label2id["no_diac"])
            elif s in labels:
                result.append(label2id[s])
            else:
                result.append(label2id["no_diac"])
        return result
    examples['labels'] = [make_a_l(l) for l in examples["labels"]]
    return examples

train_ds = train_ds.map(make_actual_labels, batched=True)


 99%|█████████▉| 119/120 [00:02<00:00, 40.33ba/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-024d5959341aec07.arrow


In [14]:
def make_input_mask(examples):
    def make_im(input):
        return list([0 if e not in chars_with_diacritics or e in labels else 1 for e in input])
    examples["input_mask"] = [make_im(l) for l in examples["input"]]
    return examples

# print(train_ds.map(make_input_mask, batched=True)[0]["input"])
# print(train_ds.map(make_input_mask, batched=True)[0]["input_mask"], sep=" ")

train_ds = train_ds.map(make_input_mask, batched=True)


 99%|█████████▉| 119/120 [00:06<00:00, 18.85ba/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-d6b8b15f85dc82ac.arrow
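As a small worked example of the two preprocessing steps above, the sketch below encodes a single word. It redefines the label list and candidate characters so that it runs on its own; the variable names `target`, `source`, `target_ids` and `input_mask` are chosen for illustration only.

```python
labels = ["no_diac", "ă", "î", "â", "ș", "ț", "Î", "Â", "Ă", "Ș", "Ț"]
label2id = {label: idx for idx, label in enumerate(labels)}
candidates = ['a', 't', 'i', 's', 'A', 'T', 'I', 'S']

target = "țară"   # original text, with diacritics
source = "tara"   # model input, diacritics stripped

# One class id per character: the diacritic's id, or 0 ("no_diac") otherwise.
target_ids = [label2id.get(ch, label2id["no_diac"]) for ch in target]

# 1 for characters that could carry a diacritic, 0 for everything else.
input_mask = [1 if ch in candidates else 0 for ch in source]

print(target_ids)
print(input_mask)
```

The mask marks the only positions where a decision has to be made; a character such as "r" can never receive a diacritic, so it is excluded.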


In [15]:
from transformers import CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")

train_ds = train_ds.map(lambda examples: tokenizer(examples['input'], padding="max_length", truncation=True, max_length=max_length),
                        batched=True)

        


Using unk_token, but it is not set yet.
 99%|█████████▉| 119/120 [00:22<00:00,  5.25ba/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-eae4a5253106d7cc.arrow


In [16]:
# for key, value in train_ds[0].items():
#     if key != "actual_labels" and key !="id":
#         print(key, len(value))

In [17]:
train_ds = train_ds.map(lambda example: 
    {"attention_mask" : 
        [0] + example["input_mask"] + [0] + 
        [0] * (len(example["input_ids"]) - len(example["input_mask"]) - 2)
        }, num_proc=8)
# for key, value in train_ds[0].items():
#     if key != "actual_labels" and key !="id":
#         print(key, len(value))
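The cell above shifts the per-character mask by one position to account for the [CLS] token, adds a slot for [SEP], and pads with zeros. A minimal sketch of that alignment is shown below. The helper `align_to_tokens` is hypothetical; unlike the notebook, it truncates before padding, which is an assumption made here so that the last slot always corresponds to [SEP].

```python
def align_to_tokens(values, max_length):
    """Pad per-character values to match [CLS] + characters + [SEP] + padding."""
    body = values[: max_length - 2]      # leave room for [CLS] and [SEP]
    padded = [0] + body + [0]            # 0 for both special tokens
    return padded + [0] * (max_length - len(padded))


short = align_to_tokens([1, 1, 0, 1], max_length=8)
long = align_to_tokens([1] * 10, max_length=8)
print(short)
print(long)
```

The same helper can be applied to the label sequence, since labels need exactly the same shift and padding as the mask.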


In [18]:
# import matplotlib.pyplot as plt

# def compute_label_counts(examples):
#     def get_info(lbl):
#         result = np.array([lbl.count(i) for i in range(11)])
#         return {
#             "per_label_counts" : result,
#             "length" : len(lbl),
#         }
#     examples['sample_info'] = [get_info(l) for l in examples["labels"]]
#     # print(examples["sample_info"])
#     return examples

# train_ds = train_ds.map(compute_label_counts, batched=True)

# # per_label_counts = np.zeros(6)
# # lengths = []
# # for e in train_ds["sample_info"]:
# #     per_label_counts += e["per_label_counts"]
# #     lengths.append(e["length"])

In [19]:
# per_label_counts = np.zeros(11)
# lengths = []
# for e in train_ds["train"]["sample_info"]:
#     per_label_counts += e["per_label_counts"]
#     lengths.append(e["length"])
# per_label_counts

In [20]:
# print(train_ds["train"][0]["labels"], sep=" ")

In [21]:
train_ds = train_ds.map(lambda example:
        { 
            "labels" : [0] + example["labels"] + [0] + 
                        [0] * (max_length - len(example["labels"]) - 2)
    })

# s = train_ds[0]
# print(s.keys())
# for key, v in s.items():
#     if isinstance(v, list):
#         print(key, len(v))

100%|██████████| 120000/120000 [00:40<00:00, 2928.30ex/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-78d40948e0012cd6.arrow


In [22]:
# print(per_label_counts)
# print(plt.hist(lengths, bins=20))
# plt.show()

In [23]:
train_ds = train_ds.map(lambda example:{"labels": example["labels"][:max_length]})
# train_ds = train_ds.map(lambda example:{"input": example["input"][:max_length]})
train_ds = train_ds.map(lambda example:{"attention_mask": example["attention_mask"][:max_length]})

100%|██████████| 120000/120000 [00:43<00:00, 2745.19ex/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-de9c6736a6da2090.arrow
100%|██████████| 120000/120000 [00:43<00:00, 2736.05ex/s]
Loading cached processed dataset at /home/simi2525/.cache/huggingface/datasets/dumitrescustefan___diacritic/v1/1.0.0/3638b1258d10fd88cfa09c039deb0bec5b4fa7f2d71da6197bf2375026a70970/cache-41f55e4c6eda2d09.arrow


In [24]:
test_ds = train_ds["validation"]
train_ds = train_ds["train"]

In [25]:
train_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
test_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

In [26]:
for i in range(100):
    # print(train_ds[i])
    for key in train_ds[i]:
        if train_ds[i][key].shape[0] != max_length:
            print(key, train_ds[i][key].shape)  

In [27]:
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 128
LR = 1e-5
EPOCHS = 100

In [28]:
from torch.utils.data import DataLoader


train_dataloader = DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, num_workers=0, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_ds, batch_size=TEST_BATCH_SIZE, num_workers=0, drop_last=True)
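Because both loaders use `drop_last=True`, the final incomplete batch of each split is discarded. The short calculation below works out how many batches each loader yields per epoch for the dataset sizes and batch size used in this notebook.

```python
n_train, n_val = 120_000, 50_000
batch_size = 128

# Integer division: drop_last=True discards the final partial batch.
train_batches = n_train // batch_size
val_batches = n_val // batch_size

# Number of examples left unused in each epoch.
dropped_train = n_train % batch_size
dropped_val = n_val % batch_size

print(train_batches, val_batches, train_batches + val_batches)
print(dropped_train, dropped_val)
```

Dropping the partial batch keeps every batch the same size. With shuffling enabled, different training examples are left out in each epoch.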

In [29]:
batch = next(iter(train_dataloader))
# batch

In [30]:
# print(tokenizer.decode(batch['input_ids'][2])[:100])
# print(batch["labels"][2])
# l = batch["labels"][2][1:100]
# i = tokenizer.decode(batch['input_ids'][2][1:])[:100]
# for idx, (c_l, c_i) in enumerate(zip(l,i)):
#     if c_l != -1 and c_l != 0:
#         print(c_l, c_i, idx)


In [31]:
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AdamW, CanineModel, CaninePreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

class CanineForTokenClassificationCustom(CaninePreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.canine(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        for idx, out in enumerate(outputs):
            print(idx, out)
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # loss_fct = CrossEntropyLoss(weight=torch.tensor(per_label_weights, device="cuda"), reduction='none').cuda()
            loss_fct = CrossEntropyLoss(reduction='none').cuda()
            # loss_fct = CrossEntropyLoss(weight=torch.tensor(per_label_weights, device="cpu"), reduction='none')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss = loss * attention_mask.flatten()
            loss = loss.sum() / (attention_mask.sum() + 1e-15)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

model = CanineForTokenClassificationCustom.from_pretrained('google/canine-s', 
                                                                     num_labels=len(labels),
                                                                     id2label=id2label,
                                                                     label2id=label2id)

# model(**batch)



Some weights of CanineForTokenClassificationCustom were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
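The loss in `forward` above computes a per-character cross-entropy, zeroes out positions where the mask is 0, and divides by the number of remaining positions. The sketch below reproduces that arithmetic in plain Python on toy values so that it can be checked by hand. The function `masked_cross_entropy` is hypothetical and is not a replacement for the PyTorch implementation.

```python
import math


def masked_cross_entropy(logits, labels, mask, eps=1e-15):
    """Mean cross-entropy over the positions where mask == 1."""
    total = 0.0
    for row, y, m in zip(logits, labels, mask):
        # log-sum-exp with the maximum subtracted for numerical stability
        mx = max(row)
        log_z = mx + math.log(sum(math.exp(v - mx) for v in row))
        total += m * (log_z - row[y])    # -log softmax(row)[y], masked
    return total / (sum(mask) + eps)


# Two positions, two classes; only the first position is scored.
toy_logits = [[0.0, 0.0], [5.0, -5.0]]
toy_labels = [0, 1]
toy_mask = [1, 0]
loss = masked_cross_entropy(toy_logits, toy_labels, toy_mask)
print(loss)
```

The second position is badly misclassified, but it does not contribute to the loss because its mask value is 0. The small `eps` only guards against division by zero when a sequence has no scored positions.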


## Define model

In [32]:
import pytorch_lightning as pl
from transformers import CanineForSequenceClassification, AdamW
import torch.nn as nn

class CanineReviewClassifier(pl.LightningModule):
    def __init__(self):
        super(CanineReviewClassifier, self).__init__()
        self.model = CanineForTokenClassificationCustom.from_pretrained('google/canine-s', 
                                                                     num_labels=len(labels),
                                                                     id2label=id2label,
                                                                     label2id=label2id)
    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                             labels=labels)

        return outputs
        
    def common_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss
        logits = outputs.logits

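        predictions = logits.argmax(-1)
        # Score only the positions selected by the mask, mirroring the masked loss,
        # and divide by the number of scored characters rather than the batch size.
        mask = batch['attention_mask'].bool()
        correct = ((predictions == batch['labels']) & mask).sum().item()
        accuracy = correct / max(mask.sum().item(), 1)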

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=LR)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return test_dataloader



## Train the model

In [33]:
import wandb

wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msimi2525[0m ([33mtensor-reloaded[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [34]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", save_top_k=-1, monitor="validation_loss")
# model = CanineReviewClassifier()
model = CanineReviewClassifier.load_from_checkpoint('cannie_v3_checkpoints/epoch=1-step=780.ckpt')
wandb_logger = WandbLogger(name='canine-imdb-1', project='CANINE')
#checkpoint_callback
trainer = Trainer(accelerator='gpu',devices=1, logger=wandb_logger, callbacks=[checkpoint_callback], max_epochs=EPOCHS, )
trainer.fit(model)

Some weights of CanineForTokenClassificationCustom were not initialized from the model checkpoint at google/canine-s and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type                               | Params
-------------------------------------------------------------
0 | model | CanineForTokenClassificationCustom | 132 M 
-------------------------------------------------------------
132 M     Trainable params
0         Non-trainable params
132 M     Total params
528.366   Total estimated model params size (MB)



## Inference

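First, we inspect one example from the batch loaded earlier by printing every position whose label is a diacritic. After training, we can save the model as follows: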

In [None]:
print(tokenizer.decode(batch['input_ids'][2])[:100])
print(batch["labels"][2])
l = batch["labels"][2][1:100]
i = tokenizer.decode(batch['input_ids'][2][1:])[:100]
for idx, (c_l, c_i) in enumerate(zip(l,i)):
    if c_l != -1 and c_l != 0:
        print(c_l, c_i, idx)

In [None]:
model.model.save_pretrained('.')

In that way, we can load it back as follows:

In [None]:
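# The checkpoint holds a token-classification head, so load it with the custom
# class defined above instead of CanineForSequenceClassification.
model = CanineForTokenClassificationCustom.from_pretrained('.')

Let's test it on a new sentence without diacritics:

In [None]:
text = "Aceasta este o tara frumoasa"

# prepare text for the model
encoding = tokenizer(text, return_tensors="pt")

# as during training, use the candidate-character mask as the attention mask
# (0 for [CLS] and [SEP], 1 for characters that may carry a diacritic)
candidate_mask = [int(ch in chars_with_diacritics) for ch in text]
encoding["attention_mask"] = torch.tensor([[0] + candidate_mask + [0]])

# forward pass
model.eval()
with torch.no_grad():
    outputs = model(**encoding)

# convert logits to one predicted class per character, skipping [CLS] and [SEP]
pred_ids = outputs.logits.argmax(-1)[0, 1:-1].tolist()

# restore a diacritic only where one is predicted for a candidate character
restored = "".join(
    id2label[p] if p != label2id["no_diac"] and ch in chars_with_diacritics else ch
    for ch, p in zip(text, pred_ids)
)
print("Restored text:", restored)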