# Train a CANINE-based character classifier to restore Romanian diacritics

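In this notebook, we are going to use Google's character-level [CANINE](https://arxiv.org/abs/2103.06874) model to restore Romanian diacritics: for every character of a text whose diacritics have been stripped, the model predicts which diacritic (if any) that character should carry. We will do so using HuggingFace Transformers (I contributed CANINE in PyTorch to it!). The dataset we are going to use is `dumitrescustefan/diacritic`, which is hosted on the HuggingFace hub. The CANINE character representations are combined with token representations from the encoder of `iliemihai/mt5-base-romanian-diacritics`, and a small Transformer head is trained on top of both.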

For training, we will use [PyTorch Lightning](https://www.pytorchlightning.ai/) (note that you could also use alternative solutions such as native PyTorch, the [HuggingFace Trainer](https://huggingface.co/transformers/main_classes/trainer.html), [HuggingFace Accelerate](https://github.com/huggingface/accelerate), etc.). For logging the metrics (such as loss and accuracy) during training, we will use Weights and Biases.

Note that this notebook is very similar to how we would train a BERT model for token classification. The main difference is that BERT uses word pieces (subword tokenization), whereas CANINE works at the character level, so we get one prediction per character.

To give an example, if you provide the sentence "hello world" to BERT, it is first tokenized into the word pieces ["hello", "world"]. Then, BERT converts each word piece into a vector (also referred to as a hidden state). For BERT-base, this is a vector of size 768. CANINE, on the other hand, "tokenizes" the sentence into ["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], i.e. it splits it up into individual characters. Then, CANINE converts each character into a vector (for CANINE, this is also a vector of size 768). Classification of whole sequences is the same for BERT/CANINE: one simply places a linear layer on top of the final hidden state of the special [CLS] token. For per-character predictions, as in this notebook, the classification layer is applied to the final hidden state of every character instead.
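The character-level "tokenization" described above can be sketched in a few lines of plain Python. This is only an illustration, not the `CanineTokenizer` implementation: the helper name `canine_like_encode` is made up here, and the two special code points are the private-use values that, to my knowledge, CANINE uses for [CLS] and [SEP].

```python
# Assumed special code points for [CLS] and [SEP] (Unicode private-use area)
CLS, SEP = 0xE000, 0xE001


def canine_like_encode(text):
    """Map every character to its Unicode code point and add [CLS]/[SEP]."""
    return [CLS] + [ord(ch) for ch in text] + [SEP]


ids = canine_like_encode("hello world")
print(len(ids))   # 11 characters + 2 special tokens = 13
print(ids[1:6])   # code points of "hello": [104, 101, 108, 108, 111]
```

Because there is no vocabulary lookup, any character of any language can be encoded, which is convenient for text with diacritics.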

* CANINE paper: https://arxiv.org/abs/2103.06874
* CANINE documentation: https://huggingface.co/transformers/model_doc/canine.html

## Install dependencies

In [None]:
import numpy as np
import torch
import re

## Prepare data

Here we load the `dumitrescustefan/diacritic` dataset, which is hosted on the HuggingFace hub, and keep only the first 100 examples of the train and validation splits for demonstration purposes.

In [None]:
import os

from datasets import load_dataset

# Make only GPU 1 visible to this process (adjust to your machine)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# train_ds, test_ds = load_dataset("dumitrescustefan/diacritic", split=['train[:100]', 'validation[:50]'])
dataset = load_dataset("dumitrescustefan/diacritic")
dataset["train"] = dataset["train"].select(list(range(100)))
dataset["validation"] = dataset["validation"].select(list(range(100)))
train_ds = dataset
train_ds

In [None]:
# train_ds = train_ds.map(lambda el:{"text" : str.lower(el["text"])}, num_proc=5)

In [None]:
# dataset.map(lambda ex : {"bla" : [0]})
# dataset.map(lambda ex : ex, batched=True)

In [None]:
# per_label_counts = [3297059,  313306,  102377,   49724,  118957,  135583]
per_label_counts = [1.2545946e+07, 3.1330600e+05, 1.0237700e+05, 4.9724000e+04,
       1.1895700e+05, 1.3558300e+05, 8.4530000e+03, 1.1100000e+02,
       6.5700000e+02, 3.0930000e+03, 6.7000000e+02]
n_samples = sum(per_label_counts)
per_label_weights = [n_samples / (c * n_samples) for c in per_label_counts]
print("Label weights", per_label_weights)
max_length = 256
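The weight expression in the cell above, `n_samples / (c * n_samples)`, reduces to the inverse class frequency `1 / c`. The sketch below checks this on made-up counts (they are not the notebook's real numbers) and shows one optional way to normalise the result. Note that in the code shown in this notebook the weights are only printed; passing them to the loss (for example via the `weight` argument of `CrossEntropyLoss`) would be a separate design choice.

```python
import math

# Hypothetical counts for three classes, chosen only for illustration
counts = [1000.0, 100.0, 10.0]
n_samples = sum(counts)

# Same expression as in the notebook cell: the n_samples factors cancel out
weights = [n_samples / (c * n_samples) for c in counts]
assert all(math.isclose(w, 1.0 / c) for w, c in zip(weights, counts))

# Optional normalisation so that the weights sum to 1
total = sum(weights)
normalized = [w / total for w in weights]
print(normalized)  # the rarest class receives the largest weight
```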

Next, we define the label set: one class per diacritic character, plus `no_diac` for characters that stay unchanged:

In [None]:
labels = [
    "no_diac", 
    "ă", 
    'î',
    "â",
    "ș",
    "ț",
    'Î',
    'Â',
    'Ă',
    'Ș',
    'Ț',
]

chars_with_diacritics = ['a','t','i','s', 'A', 'T', 'I', 'S']

print(labels)

In [None]:
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
print(id2label)

In [None]:
train_ds = train_ds.rename_column("text", "labels")
# train_ds[0]

In [None]:
percentage_diacritics_removed = 1.0
def remove_diacritics(input_txt):
    diac_map = {'ț': 't', 'ș': 's', 'Ț': 'T', 'Ș': 'S', 'Ă': 'A', 'ă': 'a', 'Â': 'A', 'â': 'a', 'Î': 'I', 'î': 'i'}
    diacritic_positions = [m.start() for m in re.finditer('ț|ș|Ț|Ș|Ă|ă|Â|â|Î|î', input_txt)]
    to_remove_diacritic_positions = np.random.choice(diacritic_positions, int(len(diacritic_positions) * percentage_diacritics_removed), replace=False)
    for i in range(len(to_remove_diacritic_positions)):
        input_txt = input_txt[:to_remove_diacritic_positions[i]]+ diac_map[input_txt[to_remove_diacritic_positions[i]]] + input_txt[to_remove_diacritic_positions[i]+1:]

    return input_txt


def add_partial_no_diac_input(examples):
    # percentage_diacritics_removed
    
    examples['input'] = [remove_diacritics(input_txt=l) for l in examples["labels"]]
    return examples
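With `percentage_diacritics_removed = 1.0`, every diacritic is replaced by its base letter. A minimal deterministic sketch of that special case, using `str.translate` instead of the random sampling above, could look as follows. The helper name `strip_all_diacritics` and the example sentence are made up for illustration.

```python
# Same ten character mappings as diac_map in the cell above
DIAC_TABLE = str.maketrans("țșȚȘĂăÂâÎî", "tsTSAaAaIi")


def strip_all_diacritics(text):
    """Replace every Romanian diacritic by its base letter."""
    return text.translate(DIAC_TABLE)


target = "Știință și artă"                  # what the model should recover
model_input = strip_all_diacritics(target)  # what the model is given
print(model_input)
```

Since every replacement is one character for one character, the input and the target have the same length, which is what makes per-character labels possible.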

In [None]:
train_ds = train_ds.map(add_partial_no_diac_input, batched=True)
# train_ds[0]

In [None]:
def make_actual_labels(examples):
    def make_a_l(lbl):
        result = []
        for s in lbl:
            if s in chars_with_diacritics:
                result.append(label2id["no_diac"])
            elif s in labels:
                result.append(label2id[s])
            else:
                result.append(label2id["no_diac"])
        return result
    examples['labels'] = [make_a_l(l) for l in examples["labels"]]
    return examples

train_ds = train_ds.map(make_actual_labels, batched=True)
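To make the label encoding concrete, here is a small self-contained sketch that repeats the label list from above and encodes two words. The helper name `char_labels` is hypothetical; it is meant to behave like `make_a_l`, mapping every character that is not a diacritic to class 0 (`no_diac`).

```python
labels = ["no_diac", "ă", "î", "â", "ș", "ț", "Î", "Â", "Ă", "Ș", "Ț"]
label2id = {label: idx for idx, label in enumerate(labels)}


def char_labels(text):
    """One class id per character; anything that is not a diacritic maps to 0."""
    return [label2id.get(ch, label2id["no_diac"]) for ch in text]


print(char_labels("țară"))     # ț -> 5, a -> 0, r -> 0, ă -> 1
print(char_labels("Știință"))  # Ș -> 9, ț -> 5, ă -> 1, the rest -> 0
```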


In [None]:
def make_input_mask(examples):
    def make_im(text):
        # 1 for characters that may carry a diacritic, 0 for everything else
        return [1 if ch in chars_with_diacritics else 0 for ch in text]
    examples["input_mask"] = [make_im(l) for l in examples["input"]]
    return examples

# print(train_ds.map(make_input_mask, batched=True)[0]["input"])
# print(train_ds.map(make_input_mask, batched=True)[0]["input_mask"], sep=" ")

train_ds = train_ds.map(make_input_mask, batched=True)
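The mask marks the positions where a diacritic could be restored, i.e. the base letters a, t, i, s in either case. A tiny standalone sketch (the helper name `candidate_mask` is made up) shows the result on a short word:

```python
# Base letters that can carry a Romanian diacritic
CANDIDATES = set("atisATIS")


def candidate_mask(text):
    """1 where the character could receive a diacritic, 0 elsewhere."""
    return [1 if ch in CANDIDATES else 0 for ch in text]


print(candidate_mask("tara"))  # "r" is the only non-candidate
```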


In [None]:
from transformers import CanineTokenizer

tokenizer = CanineTokenizer.from_pretrained("google/canine-s")

train_ds = train_ds.map(lambda examples: tokenizer(examples['input'], padding="max_length", truncation=True, max_length=max_length),
                        batched=True)
len(train_ds["train"][0]["input_ids"])
        


In [None]:
s = train_ds["train"][0]
for key, value in s.items():
    # print(key)
    if isinstance(value, list):
        print(key, len(value))

In [None]:
train_ds = train_ds.map(lambda example: 
    {"attention_mask" : 
        [0] + example["input_mask"] + [0] + 
        [0] * (len(example["input_ids"]) - len(example["input_mask"]) - 2)
        }, num_proc=8)
# for key, value in train_ds[0].items():
#     if key != "actual_labels" and key !="id":
#         print(key, len(value))
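The per-character lists have to line up with the CANINE input ids, which start with [CLS], end with [SEP], and are padded or truncated to `max_length`. The sketch below shows that alignment on plain lists; `align_to_canine` is a hypothetical helper, not a function from the notebook or from Transformers.

```python
def align_to_canine(values, max_length, pad=0):
    """Add slots for [CLS]/[SEP], then pad or truncate to max_length."""
    padded = [pad] + list(values) + [pad]
    padded += [pad] * (max_length - len(padded))  # adds nothing if already too long
    return padded[:max_length]


print(align_to_canine([1, 1, 0, 1], max_length=8))  # short input: padded
print(align_to_canine([1] * 10, max_length=8))      # long input: truncated
```

The same treatment applies to the mask, the labels, and the character-to-token indices, so that position `i + 1` in each list refers to character `i` of the text.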

In [None]:
# import matplotlib.pyplot as plt

# def compute_label_counts(examples):
#     def get_info(lbl):
#         result = np.array([lbl.count(i) for i in range(11)])
#         return {
#             "per_label_counts" : result,
#             "length" : len(lbl),
#         }
#     examples['sample_info'] = [get_info(l) for l in examples["labels"]]
#     # print(examples["sample_info"])
#     return examples

# train_ds = train_ds.map(compute_label_counts, batched=True)

# # per_label_counts = np.zeros(6)
# # lengths = []
# # for e in train_ds["sample_info"]:
# #     per_label_counts += e["per_label_counts"]
# #     lengths.append(e["length"])

In [None]:
# per_label_counts = np.zeros(11)
# lengths = []
# for e in train_ds["train"]["sample_info"]:
#     per_label_counts += e["per_label_counts"]
#     lengths.append(e["length"])
# per_label_counts

In [None]:
# print(train_ds["train"][0]["labels"], sep=" ")

In [None]:
print(train_ds["train"][0].keys())
print(train_ds["train"][0]["input"])
s = train_ds["train"][0]
for key, value in s.items():
    # print(key)
    if isinstance(value, list):
        print(key, len(value))





In [None]:
from transformers import MT5ForConditionalGeneration, T5Tokenizer
t5_tokenizer = T5Tokenizer.from_pretrained('iliemihai/mt5-base-romanian-diacritics')

# input_text = "A inceput sa  ii .  "
# inputs = tokenizer(input_text, max_length=256, truncation=True, return_tensors="pt", return_overflowing_tokens=True, return_length=True)

# print(tokenizer.tokenize(input_text))
# print(tokenizer(input_text)['input_ids'])

def char2token(input_text):
    # initial_input_text = copy.deepcopy(input_text)
    bert_tokenized_input = t5_tokenizer.tokenize(input_text)
    # print(bert_tokenized_input)
    # print('_', bert_tokenized_input[0][0])
    # print(bert_tokenized_input[0].replace('▁',''))
    bert_tokenized_input = [b.replace("▁", '') for b in bert_tokenized_input]
    # print(bert_tokenized_input)

    token_idx = [[i] * len(tok) for i,tok in enumerate(bert_tokenized_input)]
    token_idx = [item for sublist in token_idx for item in sublist]
    bert_tokenized_input = [item for sublist in bert_tokenized_input for item in sublist]
    # print(token_idx, bert_tokenized_input)
    result = [-1] * len(input_text)
    curr_len = 0
    for token_char_idx, token_char  in zip(token_idx, bert_tokenized_input):
        char_idx = input_text.find(token_char)
        input_text = input_text[char_idx+1:]
        result[curr_len + char_idx] = token_char_idx
        curr_len += (char_idx + 1)

    for i in range(len(result)):
        if result[i] == -1:
            j = min(i + 1, len(result) - 1)
            while j < len(result) - 1 and result[j] == -1:
                j += 1
            if result[j] == -1:
                j = i-1
            result[i] = result[j]
    return result

train_ds = train_ds.map(lambda example: {"t5_char_tokens" : char2token(example["input"])})
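To see what `char2token` is meant to produce, here is a simplified, self-contained sketch of the same idea that works on a given token list instead of calling the mT5 tokenizer. The function name `char_to_token_index` and the token split in the example are hypothetical; a real tokenizer may split the text differently.

```python
def char_to_token_index(text, tokens, marker="▁"):
    """For every character of text, return the index of the token that covers it."""
    result = [-1] * len(text)
    cursor = 0
    for tok_idx, tok in enumerate(tokens):
        piece = tok.replace(marker, "")
        if not piece:
            continue
        start = text.find(piece, cursor)
        if start == -1:
            continue  # piece not found verbatim, e.g. normalised by the tokenizer
        for pos in range(start, start + len(piece)):
            result[pos] = tok_idx
        cursor = start + len(piece)

    # Uncovered characters (mostly spaces) take the index of the next token ...
    next_idx = -1
    for pos in range(len(text) - 1, -1, -1):
        if result[pos] == -1:
            result[pos] = next_idx
        else:
            next_idx = result[pos]
    # ... and trailing ones take the index of the previous token
    prev_idx = 0
    for pos in range(len(text)):
        if result[pos] == -1:
            result[pos] = prev_idx
        else:
            prev_idx = result[pos]
    return result


print(char_to_token_index("A inceput", ["▁A", "▁ince", "put"]))
```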



In [None]:
train_ds = train_ds.map(lambda example: 
    {"t5_char_tokens" : 
        [0] + example["t5_char_tokens"] + [0] + 
        [0] * (len(example["input_ids"]) - len(example["t5_char_tokens"]) - 2)
        }, num_proc=8)
train_ds
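These per-character token indices are used later, inside the model's `forward`, to pick for every character the mT5 hidden state of the token it belongs to. The model does this with a flattened tensor and a per-example offset; the sketch below reproduces the same indexing on nested Python lists so the arithmetic is easy to follow. `gather_token_vectors` and the toy values are made up for illustration.

```python
def gather_token_vectors(token_vectors, char_tokens):
    """token_vectors: [batch][token][dim]; char_tokens: [batch][char] -> [batch][char][dim]."""
    tokens_per_example = len(token_vectors[0])
    # Flatten the batch so that one index addresses any token of any example
    flat = [vec for example in token_vectors for vec in example]
    gathered = []
    for b, row in enumerate(char_tokens):
        offset = b * tokens_per_example  # where example b starts in the flat list
        gathered.append([flat[offset + t] for t in row])
    return gathered


token_vectors = [[[0.0], [1.0]], [[10.0], [11.0]]]  # 2 examples, 2 tokens, dim 1
char_tokens = [[0, 0, 1], [1, 1, 0]]                # 3 characters per example
print(gather_token_vectors(token_vectors, char_tokens))
```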

In [None]:
train_ds = train_ds.rename_columns({"input_ids" : "canine_input_ids", "token_type_ids" : "canine_token_type_ids", "attention_mask" : "canine_attention_mask" })
train_ds

In [None]:
# t5_tokenizer = CanineTokenizer.from_pretrained("google/canine-s")

train_ds = train_ds.map(lambda examples: t5_tokenizer(examples['input'], padding="max_length", truncation=True, max_length=max_length),
                        batched=True)
train_ds

In [None]:
train_ds = train_ds.rename_columns({"input_ids" : "t5_input_ids", "attention_mask" : "t5_attention_mask"})
train_ds

In [None]:
train_ds = train_ds.map(lambda example:
        { 
            "labels" : [0] + example["labels"] + [0] + 
                        [0] * (max_length - len(example["labels"]) - 2)
    })

# s = train_ds[0]
# print(s.keys())
# for key, v in s.items():
#     if isinstance(v, list):
#         print(key, len(v))

In [None]:
# print(per_label_counts)
# print(plt.hist(lengths, bins=20))
# plt.show()

In [None]:
train_ds = train_ds.map(lambda example:{"labels": example["labels"][:max_length]})
train_ds = train_ds.map(lambda example:{"t5_char_tokens": example["t5_char_tokens"][:max_length]})
# train_ds = train_ds.map(lambda example:{"input": example["input"][:max_length]})
train_ds = train_ds.map(lambda example:{"canine_attention_mask": example["canine_attention_mask"][:max_length]})

In [None]:
test_ds = train_ds["validation"]
train_ds = train_ds["train"]

In [None]:
s = train_ds[0]
for key, value in s.items():
    print(key)
    if isinstance(value, list):
        print(len(value))


In [None]:
# train_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
# test_ds.set_format(type="torch", columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
train_ds.set_format(type="torch", columns=['id', 'canine_input_ids', 'canine_token_type_ids', 'canine_attention_mask', "t5_char_tokens",'t5_input_ids','t5_attention_mask', 'labels'])
test_ds.set_format(type="torch", columns=['id', 'canine_input_ids', 'canine_token_type_ids', 'canine_attention_mask', "t5_char_tokens",'t5_input_ids','t5_attention_mask', 'labels'])

In [None]:
print(train_ds[0])
# for i in range(100):
#     # print(train_ds[i])
#     for key in train_ds[i]:
#         if train_ds[i][key].shape[0] != max_length:
#             print(key, train_ds[i][key].shape)  

In [None]:
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 64
LR = 1e-4
EPOCHS = 100

In [None]:
from torch.utils.data import DataLoader


train_dataloader = DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_ds, batch_size=TEST_BATCH_SIZE, drop_last=True)

In [None]:
batch = next(iter(train_dataloader))
# batch

In [None]:
# print(tokenizer.decode(batch['input_ids'][2])[:100])
# print(batch["labels"][2])
# l = batch["labels"][2][1:100]
# i = tokenizer.decode(batch['input_ids'][2][1:])[:100]
# for idx, (c_l, c_i) in enumerate(zip(l,i)):
#     if c_l != -1 and c_l != 0:
#         print(c_l, c_i, idx)


In [None]:
import os
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import CanineModel, CaninePreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

PRETRAINED_MODELS_CACHE = {}

# 'id', 'canine_input_ids', 'canine_token_type_ids', 'canine_attention_mask', "t5_char_tokens",'t5_input_ids','t5_attention_mask', 'labels'
class CanineForTokenClassificationCustom(CaninePreTrainedModel):
    def __init__(self, config, cached_path=None):
        super().__init__(config)
        self.cached_path=cached_path
        self.num_labels = config.num_labels

        self.canine = CanineModel(config)
        self.t5 = MT5ForConditionalGeneration.from_pretrained('iliemihai/mt5-base-romanian-diacritics').encoder

        for param in self.canine.parameters():
            param.requires_grad = False
        
        for param in self.t5.parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
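        # batch_first=True because forward() feeds tensors shaped (batch, sequence, features)
        encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size + 768, nhead=4, batch_first=True)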
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.classifier_final = nn.Linear(config.hidden_size + 768, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()


    def forward(
        self,
        canine_input_ids: Optional[torch.LongTensor] = None,
        canine_attention_mask: Optional[torch.FloatTensor] = None,
        canine_token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        id=None,
        t5_input_ids=None,
        t5_attention_mask=None,
        t5_char_tokens=None,
        
    ) -> Union[Tuple, TokenClassifierOutput]:

        
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Both encoders are frozen, so their outputs are cached per example id.
        # Tensor ids are converted to plain Python values so that they hash by value.
        keys = None if id is None else [i.item() if torch.is_tensor(i) else i for i in id]
        device = t5_char_tokens.device
        outputs = None

        if keys is not None and all(k in PRETRAINED_MODELS_CACHE for k in keys):
            canine_out = torch.stack([PRETRAINED_MODELS_CACHE[k]["canine"] for k in keys]).to(device)
            t5_out = torch.stack([PRETRAINED_MODELS_CACHE[k]["t5"] for k in keys]).to(device)
        else:
            outputs = self.canine(
                input_ids=canine_input_ids,
                attention_mask=canine_attention_mask,
                token_type_ids=canine_token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            canine_out = outputs[0]
            t5_out = self.t5(input_ids=t5_input_ids, attention_mask=t5_attention_mask)["last_hidden_state"]
            if keys is not None:
                for k, c, t in zip(keys, canine_out, t5_out):
                    PRETRAINED_MODELS_CACHE[k] = {"canine": c.detach().cpu(), "t5": t.detach().cpu()}

        BS, seq_length_t5, T5_embed = t5_out.shape
        canine_embed = canine_out.shape[-1]

        # Flatten both outputs so that every character can look up the mT5 token it belongs to
        canine_out = canine_out.reshape(-1, canine_embed)
        t5_out = t5_out.reshape(-1, T5_embed)
        t5_char_tokens = t5_char_tokens + (torch.arange(BS, device=device) * seq_length_t5).unsqueeze(-1)
        char_t5_tokens = t5_out[t5_char_tokens.flatten()]

        sequence_output = torch.cat((canine_out, char_t5_tokens), dim=-1)
        sequence_output = sequence_output.view(BS, -1, canine_embed + T5_embed)
        sequence_output = self.transformer_encoder(sequence_output)

        logits = self.classifier_final(sequence_output)

        loss = None
        if labels is not None:
            # Average the per-character loss over the positions selected by the mask
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss = loss * canine_attention_mask.flatten()
            loss = loss.sum() / (canine_attention_mask.sum() + 1e-15)

        if not return_dict:
            output = (logits,) + (tuple(outputs[2:]) if outputs is not None else ())
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if outputs is not None else None,
            attentions=outputs.attentions if outputs is not None else None,
        )

# model = CanineForTokenClassificationCustom.from_pretrained('google/canine-s', 
#                                                                      num_labels=len(labels),
#                                                                      id2label=id2label,
#                                                                      label2id=label2id)

# model(**batch)
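The loss in `forward` is a masked mean: the cross-entropy is computed for every character, multiplied by the mask, and divided by the number of unmasked positions. The sketch below reproduces that arithmetic with plain Python lists rather than tensors. The function names and the toy logits are made up for illustration.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one position, using a numerically stable log-sum-exp."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_norm - logits[target]


def masked_mean_loss(all_logits, targets, mask, eps=1e-15):
    """Average the per-position loss over the positions where mask == 1."""
    total = sum(cross_entropy(l, t) * m for l, t, m in zip(all_logits, targets, mask))
    return total / (sum(mask) + eps)


# Three positions, two classes; the confident but wrong middle position is masked out
loss = masked_mean_loss(
    all_logits=[[0.0, 0.0], [5.0, 0.0], [0.0, 0.0]],
    targets=[0, 1, 1],
    mask=[1, 0, 1],
)
print(loss)  # close to log(2), the loss of a uniform guess over two classes
```

The small `eps` only guards against a division by zero when a batch contains no unmasked position.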



## Define model

In [None]:
import pytorch_lightning as pl
# The AdamW class shipped with transformers is deprecated in favour of the PyTorch one
from torch.optim import AdamW
import torch.nn as nn

class CanineReviewClassifier(pl.LightningModule):
    def __init__(self):
        super(CanineReviewClassifier, self).__init__()
        self.model = CanineForTokenClassificationCustom.from_pretrained('google/canine-s', 
                                                                     num_labels=len(labels),
                                                                     id2label=id2label,
                                                                     label2id=label2id)
    # def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
    def forward(self, id, canine_input_ids, canine_token_type_ids, canine_attention_mask, t5_char_tokens,t5_input_ids, t5_attention_mask, labels=None):
        # 'id', 'canine_input_ids', 'canine_token_type_ids', 'canine_attention_mask', "t5_char_tokens",'t5_input_ids','t5_attention_mask', 'labels'
        outputs = self.model(
            id=id, 
            canine_input_ids=canine_input_ids, 
            canine_token_type_ids=canine_token_type_ids, 
            canine_attention_mask=canine_attention_mask, 
            t5_char_tokens=t5_char_tokens,
            t5_input_ids=t5_input_ids,
            t5_attention_mask=t5_attention_mask,
            labels=labels
        )

        return outputs
        
    def common_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs.loss
        logits = outputs.logits

        predictions = logits.argmax(-1)
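        # Score only the positions selected by the mask, so the result is a fraction in [0, 1]
        mask = batch['canine_attention_mask'].bool()
        correct = ((predictions == batch['labels']) & mask).sum().item()
        accuracy = correct / max(mask.sum().item(), 1)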

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # logs metrics for each training_step
        # (pass on_epoch=True to self.log to also log the average across the epoch)
        self.log("training_loss", loss)
        self.log("training_accuracy", accuracy)

        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     

        return loss

    def configure_optimizers(self):
        # We could make the optimizer more fancy by adding a scheduler and specifying which parameters do
        # not require weight_decay but just using AdamW out-of-the-box works fine
        return AdamW(self.parameters(), lr=LR)

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return test_dataloader
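For a per-character task, an accuracy is easiest to interpret when it is a fraction of the scored characters, which here means the positions where the mask is 1. A minimal sketch on plain lists, with a hypothetical helper name and toy values:

```python
def masked_accuracy(predictions, labels, mask):
    """Fraction of correct predictions among the positions where mask == 1."""
    scored = [p == l for p, l, m in zip(predictions, labels, mask) if m]
    return sum(scored) / len(scored) if scored else 0.0


# Four characters, two of them candidates; one of the two is predicted correctly
print(masked_accuracy(predictions=[0, 5, 0, 1], labels=[0, 5, 0, 0], mask=[0, 1, 0, 1]))
```

Counting padding and non-candidate characters would inflate the score, because nearly all of them belong to the `no_diac` class.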

## Train the model

In [None]:
import wandb

wandb.login()

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints_t5", save_top_k=-1, monitor="validation_loss")
# model = CanineReviewClassifier()
model = CanineReviewClassifier.load_from_checkpoint('cannie_v4_checkpoints/epoch=3-step=3748.ckpt', strict=False)
wandb_logger = WandbLogger(name='total_diacricitc_t5_cannie', project='CANINE')
trainer = Trainer(accelerator='gpu', devices=1, logger=wandb_logger, callbacks=[checkpoint_callback], max_epochs=EPOCHS)
trainer.fit(model)

## Inference

After training, we can save the model as follows:

In [None]:
model.model.save_pretrained('.')

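To load it back, we use the same custom class, since the saved weights include the extra encoder and classification layers:

In [None]:
model = CanineForTokenClassificationCustom.from_pretrained('.')
model.eval()

Let's test it on a new sentence without diacritics:

In [None]:
text = "Stiinta si arta sunt importante in viata"

# prepare text for the model: CANINE ids, mT5 ids and the character-to-token map
canine_enc = tokenizer(text, return_tensors="pt")
t5_enc = t5_tokenizer(text, return_tensors="pt")
char_tokens = torch.tensor([[0] + char2token(text) + [0]])
# same convention as in training: 1 for characters that may carry a diacritic
candidate_mask = torch.tensor([[0] + [int(c in chars_with_diacritics) for c in text] + [0]])

# forward pass
with torch.no_grad():
    outputs = model(
        canine_input_ids=canine_enc["input_ids"],
        canine_token_type_ids=canine_enc["token_type_ids"],
        canine_attention_mask=candidate_mask,
        t5_input_ids=t5_enc["input_ids"],
        t5_attention_mask=t5_enc["attention_mask"],
        t5_char_tokens=char_tokens,
    )

# convert logits to one predicted label per character (dropping [CLS] and [SEP])
pred_ids = outputs.logits.argmax(-1)[0, 1:-1].tolist()
restored = "".join(id2label[p] if m and p != 0 else c
                   for c, p, m in zip(text, pred_ids, candidate_mask[0, 1:-1].tolist()))
print("Restored text:", restored)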