<a href="https://www.kaggle.com/code/aisuko/customize-dataloader-for-preprocessing-data?scriptVersionId=160591832" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

In [Introducing pro-precess of fine-tuning a llm step by steps](https://www.kaggle.com/code/aisuko/introducing-pro-process-of-ft-a-llm-step-by-steps), we tokenized the text and truncated every sequence to 512 tokens. In this notebook, we will train the model with the data in native PyTorch.

The main idea comes from [GitHub issue #27](https://github.com/google-research/bert/issues/27).

Recall how we applied the fine-tuned classifier to a single long text in the previous notebook: we first tokenize the entire sequence, then split it into chunks, get the model prediction for each chunk, and calculate the mean/max of the predictions. There is no problem in doing this sequentially, that is:

* Put the 1st chunk into the model, and get the 1st prediction
* Put the 2nd chunk into the model, and get the 2nd prediction
* And so on...
* Take the mean/max of these predictions and stop

However, training it sequentially on each chunk leads to a myriad of problems and questions:

* Put the 1st chunk of the 1st text into the model, calculate the loss between the prediction and the label...
* What label?
* We have only one binary label for the entire text... Then maybe run backpropagation? But when?
* Should we update the model weights after each chunk?

Instead, we must do it all at once by putting all the chunks into one mini-batch. This solves all the problems:
* From K chunks obtained for the 1st text, create 1 mini-batch and obtain K predictions
* Pool the predictions using the mean/max function to obtain a single prediction for the entire text
* Calculate the loss between this single prediction and the label for the entire text
* Run backpropagation. Make sure that all the tensor operations are done on tensors with attached gradients before running `loss.backward()`.
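
The steps above can be sketched for a single text as follows. This is a minimal illustration, not the notebook's training code: `toy_model` is a hypothetical linear layer standing in for BERT, and the chunk values are random stand-ins for real model inputs.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for BERT: maps each 512-wide chunk to one logit.
toy_model = nn.Linear(512, 1)

# K = 3 chunks of one long text (random stand-in values, not real token ids).
chunks = torch.randn(3, 512)
label = torch.tensor([1.0])  # the single binary label for the whole text

# One forward pass over the mini-batch of K chunks gives K predictions.
chunk_probs = torch.sigmoid(toy_model(chunks)).squeeze(-1)  # shape: (3,)

# Pool with differentiable torch ops so the autograd graph stays attached.
pooled_prob = chunk_probs.mean().unsqueeze(0)  # shape: (1,)

# One loss for the whole text, one backward pass.
loss = nn.functional.binary_cross_entropy(pooled_prob, label)
loss.backward()
```

Because the pooling is done with `mean()` on tensors that still carry gradients, the single loss propagates back through every chunk's prediction. Converting the predictions to NumPy or Python floats before pooling would detach them and break this.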

In [1]:
!pip install transformers==4.36.2
!pip install datasets==2.15.0


Import all the code from the previous notebook.

In [2]:
import torch
from torch import Tensor

def split_overlapping(tensor, chunk_size, stride, minimal_chunk_length=5):
    result=[tensor[i:i+chunk_size] for i in range(0, len(tensor), stride)]
    if len(result)>1:
        result=[x for x in result if len(x)>=minimal_chunk_length]
    return result

def add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks):
    """
    Adds special CLS token (token id =101) at the beginning.
    Adds SEP token (token id =102) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean)
    """
    
    for i in range(len(input_id_chunks)):
        # adding CLS (token id 101) and SEP (token id 102) tokens
        input_id_chunks[i]=torch.cat([Tensor([101]), input_id_chunks[i], Tensor([102])])
        # adding attention masks corresponding to special tokens
        mask_chunks[i]=torch.cat([Tensor([1]), mask_chunks[i], Tensor([1])])

def add_padding_tokens(input_id_chunks, mask_chunks):
    """
    Adds padding tokens (token id=0) at the end to make sure that all chunks have exactly 512 tokens
    """
    for i in range(len(input_id_chunks)):
        # get required padding length
        pad_len=512-input_id_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len>0:
            # if padding length is more than 0, we must add padding
            input_id_chunks[i]=torch.cat([input_id_chunks[i], Tensor([0]*pad_len)])
            mask_chunks[i]=torch.cat([mask_chunks[i], Tensor([0]*pad_len)])

def stack_tokens_from_all_chunks(input_id_chunks, mask_chunks):
    """
    Reshapes data to a form compatible with BERT model input.
    """
    input_ids=torch.stack(input_id_chunks)
    attention_mask=torch.stack(mask_chunks)
    
    return input_ids.long(), attention_mask.int()

def tokenize_whole_text(text, tokenizer):
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(text, tokenizer, maximal_text_length):
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens

def split_tokens_into_smaller_chunks(
    tokens,
    chunk_size,
    stride,
    minimal_chunk_length,
):
    """Splits tokens into overlapping chunks with given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks

def preprocess_func(
    text,
    tokenizer,
    chunk_size,
    stride,
    minimal_chunk_length,
    maximal_text_length
):
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens=tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens=tokenize_whole_text(text, tokenizer)
    
    input_id_chunks, mask_chunks=split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks)
    input_ids, attention_mask=stack_tokens_from_all_chunks(
        input_id_chunks,
        mask_chunks
    )
    return input_ids, attention_mask
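
To see what `chunk_size`, `stride`, and `minimal_chunk_length` do, here is a small sketch that applies the same slicing logic as `split_overlapping` to a plain Python list standing in for a token tensor. The 12 toy token ids and the parameter values are illustrative assumptions, not values used in this notebook.

```python
def split_overlapping(tensor, chunk_size, stride, minimal_chunk_length=5):
    # Same logic as the function above; it only needs len() and slicing,
    # so it also works on a plain list.
    result = [tensor[i:i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result

token_ids = list(range(12))  # toy "tokens" 0..11

# stride == chunk_size: chunks do not overlap, the short tail is kept.
non_overlapping = split_overlapping(token_ids, chunk_size=5, stride=5, minimal_chunk_length=1)
# -> [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]

# stride < chunk_size: consecutive chunks share 2 tokens; the 3-token tail
# [9, 10, 11] is shorter than minimal_chunk_length=4 and is dropped.
overlapping = split_overlapping(token_ids, chunk_size=5, stride=3, minimal_chunk_length=4)
# -> [[0, 1, 2, 3, 4], [3, 4, 5, 6, 7], [6, 7, 8, 9, 10]]
```

Note that in the second call token 11 ends up in no chunk, so a large `minimal_chunk_length` can silently discard the end of a text.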

# Usual fine-tuning

Now we will sketch how to modify the procedure from Hugging Face fine-tuning. The basic steps of fine-tuning are the following:
* Tokenize the texts of the training set with truncation. Roughly speaking, the tokenized set is a dictionary with keys `input_ids` and `attention_mask` whose values are tensors of size exactly 512 per text.
* Create the `DataLoader` object with the selected `batch_size`. This will allow us to iterate over batches of data. Assume that `batch_size=N`.
* During the training loop, `for batch in train_dataloader`, we get the object `batch`. The batch is again a dictionary with keys `input_ids` and `attention_mask`, but this time the values are stacked tensors of size Nx512.
* Put each loaded batch into the model with `outputs=model(**batch)`, calculate the loss with `loss=outputs.loss`, and run backpropagation with `loss.backward()`.
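
The batching step in this list can be illustrated without downloading a model. The sketch below uses random stand-in token ids and a `TensorDataset` (an assumption made for brevity; the Hugging Face tutorial works with dictionary-style batches) to show that the default `DataLoader` collation stacks N equally sized samples into one Nx512 tensor.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for 8 truncated, tokenized texts of exactly 512 tokens each.
input_ids = torch.randint(0, 30522, (8, 512))
attention_mask = torch.ones(8, 512, dtype=torch.long)
labels = torch.randint(0, 2, (8,))

# Default collation: samples of equal shape are stacked along a new batch axis.
loader = DataLoader(TensorDataset(input_ids, attention_mask, labels), batch_size=4)

batch_shapes = [tuple(ids.shape) for ids, mask, y in loader]
# 8 texts with batch_size=4 give two batches of shape (4, 512).
```

This stacking only works because truncation and padding make every sample the same size, which is exactly the property the chunked representation loses.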

In [3]:
from transformers import BatchEncoding

def transform_single_text(
    text,
    tokenizer,
    chunk_size,
    stride,
    minimal_chunk_length,
    maximal_text_length,
):
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask


def transform_list_of_texts(
    texts,
    tokenizer,
    chunk_size,
    stride,
    minimal_chunk_length,
    maximal_text_length=None
):
    model_inputs=[
        transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids=[model_input[0] for model_input in model_inputs]
    attention_mask=[model_input[1] for model_input in model_inputs]
    tokens={"input_ids":input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)

As always, it is instructive to get our hands dirty with an example. Let's look at the result of this function for one short and one long review and compare it with the usual truncation approach:

In [4]:
from datasets import load_dataset

imdb=load_dataset('imdb')
short_review=imdb['test']['text'][0]
long_review=imdb['test']['text'][21132]

In [5]:
from transformers import BertForSequenceClassification, BertTokenizer

tokenizer=BertTokenizer.from_pretrained('fabriceyhc/bert-base-uncased-imdb')
model = BertForSequenceClassification.from_pretrained(
    'fabriceyhc/bert-base-uncased-imdb', 
#     device_map='auto'
)
tokenizer

BertTokenizer(name_or_path='fabriceyhc/bert-base-uncased-imdb', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [6]:
def tokenize_truncated(list_of_texts):
    return tokenizer(list_of_texts, truncation=True, padding=True, max_length=512, return_tensors='pt')

tokens_splitted=transform_list_of_texts([short_review, long_review], tokenizer, 510, 510, 1,None)
tokens_truncated=tokenize_truncated([short_review, long_review])

Token indices sequence length is longer than the specified maximum sequence length for this model (3155 > 512). Running this sequence through the model will result in indexing errors


In [7]:
print(type(tokens_truncated['input_ids']))
print(tokens_truncated['input_ids'].shape)

<class 'torch.Tensor'>
torch.Size([2, 512])


As we can see, the result is a stacked tensor of size 2x512. Next, we will look at the result of splitting:

In [8]:
print(type(tokens_splitted['input_ids']))

<class 'list'>


This is a list of stacked tensors of size $K^{(i)} \times 512$, where $K^{(i)}$ is the number of chunks of text $i$. Because the texts can have different lengths, $K^{(i)}$ varies from text to text, and we cannot convert this list of tensors into one stacked tensor.

In [9]:
[tensor.shape for tensor in tokens_splitted['input_ids']]

[torch.Size([1, 512]), torch.Size([7, 512])]

The key observation here is that our tokenization returns lists of tensors of different sizes because the texts can have different lengths. Unfortunately, we cannot stack together tensors of different sizes, in the same way that we cannot concatenate two vectors of different sizes into a rectangular matrix.
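
The chunk counts seen above follow directly from how `split_overlapping` walks over the tokens: it starts one chunk at every position in `range(0, n, stride)`, so before any short chunks are filtered out a text of `n` tokens yields `ceil(n / stride)` chunks. The sketch below checks this, assuming that the 3155 tokens reported in the tokenizer warning belong to the long review; the 200-token length for a short text is a made-up value, and `count_chunks` is a hypothetical helper, not part of the notebook.

```python
import math

def count_chunks(num_tokens, stride):
    """Number of chunk start positions produced by range(0, num_tokens, stride)."""
    return math.ceil(num_tokens / stride)

# Long review: assumed to be the 3155-token sequence from the warning.
long_review_chunks = count_chunks(3155, 510)   # 7, matching torch.Size([7, 512])

# A hypothetical short text of 200 tokens fits into a single chunk.
short_text_chunks = count_chunks(200, 510)     # 1

# Cross-check against the range() used by split_overlapping.
same_as_range = long_review_chunks == len(range(0, 3155, 510))
```

With `minimal_chunk_length=1`, as used in the call above, no chunk is dropped, so this count is also the final number of chunks.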


# Creating the dataset and the dataloader

The next step is to put the tokenized texts into a torch `Dataset` object and try it with the two reviews.

In [10]:
from torch.utils.data import Dataset, RandomSampler, DataLoader

class TokenizedDataset(Dataset):
    """Dataset for tokens with optional labels"""
    def __init__(self, tokens, labels=None):
        self.input_ids=tokens['input_ids']
        self.attention_mask=tokens['attention_mask']
        self.labels=labels
    def __len__(self):
        return len(self.input_ids)
    def __getitem__(self, idx):
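        # Explicit None check: truth-testing an empty list or a multi-element tensor would misbehave
        if self.labels is not None: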
            return self.input_ids[idx], self.attention_mask[idx], self.labels[idx]
        return self.input_ids[idx], self.attention_mask[idx]

dataset_truncated=TokenizedDataset(tokens_truncated, [0,1])
dataset_splitted=TokenizedDataset(tokens_splitted, [0,1])
train_dataloader_truncated=DataLoader(dataset_truncated, sampler=RandomSampler(dataset_truncated), batch_size=2)
train_dataloader_splitted=DataLoader(dataset_splitted, sampler=RandomSampler(dataset_splitted), batch_size=2)

In [11]:
for batch in train_dataloader_truncated:
    break

In [12]:
try:
    for batch in train_dataloader_splitted:
        break
except RuntimeError as e:
    print(e)

stack expects each tensor to be equal size, but got [7, 512] at entry 0 and [1, 512] at entry 1


The default collate function tries to stack the samples of a batch into a single tensor, which fails here because the two texts have different numbers of chunks (7 and 1). To fix this, we override the collate function.

# Overriding the default dataloader

In [13]:
def collate_fn_pooled_tokens(data):
    input_ids = [data[i][0] for i in range(len(data))]
    attention_mask = [data[i][1] for i in range(len(data))]
    if len(data[0]) == 2:
        collated = [input_ids, attention_mask]
    else:
        labels = Tensor([data[i][2] for i in range(len(data))])
        collated = [input_ids, attention_mask, labels]
    return collated

train_dataloader_splitted = DataLoader(dataset_splitted, sampler=RandomSampler(dataset_splitted), batch_size=2, collate_fn=collate_fn_pooled_tokens)

try:
    for batch in train_dataloader_splitted:
        break
except RuntimeError as e:
    print(e)

The custom function `collate_fn_pooled_tokens` simply forces torch to treat each batch as a list of (potentially differently sized) tensors and forbids it from trying to stack them.
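
A batch in this list form can then be consumed as in the following sketch. It is one possible way to do it, under stated assumptions: `toy_model` is a hypothetical linear layer standing in for BERT, the inputs are random stand-in values rather than token ids, and concatenating all chunks into one forward pass is an illustrative choice.

```python
import torch
from torch import nn

torch.manual_seed(0)
toy_model = nn.Linear(512, 1)  # hypothetical stand-in for BERT + classifier head

# A batch as a list-preserving collate function would return it:
# two texts with 1 and 7 chunks, plus one label per text.
input_ids = [torch.randn(1, 512), torch.randn(7, 512)]
labels = torch.tensor([0.0, 1.0])

# Run all chunks through the model in one pass, then split back per text.
chunks_per_text = [x.shape[0] for x in input_ids]                        # [1, 7]
all_probs = torch.sigmoid(toy_model(torch.cat(input_ids))).squeeze(-1)  # shape: (8,)
probs_per_text = all_probs.split(chunks_per_text)

# Mean-pool each text's chunk predictions into one prediction per text.
pooled = torch.stack([p.mean() for p in probs_per_text])  # shape: (2,)

# One loss over the batch of texts, one backward pass.
loss = nn.functional.binary_cross_entropy(pooled, labels)
loss.backward()
```

Here `torch.stack` is safe because each pooled prediction is a scalar, regardless of how many chunks the text had.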

# Credit

* https://medium.com/mim-solutions-blog/fine-tuning-bert-model-for-arbitrarily-long-texts-part-2-3211d1774dc9
* https://github.com/mim-solutions/bert_for_longer_texts/blob/main/belt_nlp/bert_with_pooling.py