# Overview

In [Introducing pro-precess of fine-tuning a llm step by steps](https://www.kaggle.com/code/aisuko/introducing-pro-process-of-ft-a-llm-step-by-steps), we tokenized the text and truncated every sequence to 512 tokens. In this notebook, we will train the model on the data in native PyTorch.

The main idea comes from [GitHub issue #27](https://github.com/google-research/bert/issues/27).

Recall how we applied the fine-tuned classifier to a single long text in the previous notebook linked above: we first tokenize the entire sequence, then split it into chunks, get the model prediction for each chunk, and calculate the mean/max of the predictions. There is no problem in doing this sequentially, that is:

* Put the 1st chunk into the model and get the 1st prediction
* Put the 2nd chunk into the model and get the 2nd prediction
* And so on...
* Take the mean/max of these predictions and stop
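
The sequential inference steps above can be sketched as follows. This is a minimal illustration, not the notebook's actual model: `toy_model` is a hypothetical stand-in for the fine-tuned classifier, and `pool_predictions` is a helper name introduced here only for the example.

```python
import torch

def pool_predictions(chunk_probs: torch.Tensor, pooling: str = "mean") -> torch.Tensor:
    """Reduce K per-chunk probabilities to one probability for the whole text."""
    if pooling == "mean":
        return chunk_probs.mean()
    if pooling == "max":
        return chunk_probs.max()
    raise ValueError(f"Unknown pooling strategy: {pooling}")

def toy_model(chunk: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for the classifier: returns one probability per chunk."""
    return torch.sigmoid(chunk.float().mean())

# Three toy "chunks" of one long text (assumed values, for illustration only)
chunks = [
    torch.tensor([1.0, 2.0]),
    torch.tensor([-1.0, -2.0]),
    torch.tensor([0.0, 0.0]),
]

# Inference only: run the chunks one after another, no gradients needed
with torch.no_grad():
    probs = torch.stack([toy_model(c) for c in chunks])

mean_pred = pool_predictions(probs, "mean")
max_pred = pool_predictions(probs, "max")
print(mean_pred.item(), max_pred.item())
```

Because no gradients are required at inference time, it does not matter whether the chunks are processed one by one or together.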

However, training sequentially on each chunk raises a number of problems and questions:

* Put the 1st chunk of the 1st text into the model and calculate the loss between the prediction and the label...
* What label?
* We have only one binary label for the entire text... Then maybe run backpropagation? But when?
* Should we update the model weights after each chunk?

Instead, we must do it all at once by putting all the chunks into one mini-batch. This solves all the problems:
* From the K chunks obtained for the 1st text, create 1 mini-batch and obtain K predictions
* Pool the predictions using the mean/max function to obtain a single prediction for the entire text
* Calculate the loss between this single prediction and the label of the entire text
* Run backpropagation. Make sure that all the tensor operations are done on tensors with attached gradients before running `loss.backward()`.
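
A single training step along these lines might look like the sketch below. It is only an illustration under stated assumptions: `ToyChunkClassifier` is a hypothetical small module standing in for BERT with a classification head, and the optimizer, learning rate, and loss function are choices made for this example rather than settings taken from the notebook.

```python
import torch
from torch import nn

torch.manual_seed(0)

class ToyChunkClassifier(nn.Module):
    """Hypothetical stand-in for BERT + classification head: one logit per chunk."""

    def __init__(self, vocab_size: int = 1000, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=0)
        self.head = nn.Linear(hidden, 1)

    def forward(self, input_ids, attention_mask):
        emb = self.embed(input_ids)                                  # (K, L, H)
        mask = attention_mask.unsqueeze(-1).float()                  # (K, L, 1)
        pooled = (emb * mask).sum(1) / mask.sum(1).clamp(min=1.0)    # (K, H)
        return self.head(pooled).squeeze(-1)                         # (K,)

model = ToyChunkClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.BCELoss()

# One text split into K=3 chunks of 8 tokens each, with a single binary label
input_ids = torch.randint(1, 1000, (3, 8))
attention_mask = torch.ones(3, 8, dtype=torch.int)
label = torch.tensor([1.0])

model.train()
optimizer.zero_grad()

# 1) All K chunks go through the model as one mini-batch -> K predictions
chunk_probs = torch.sigmoid(model(input_ids, attention_mask))

# 2) Pool with differentiable tensor ops so the gradient graph stays attached
text_prob = chunk_probs.mean().unsqueeze(0)   # or chunk_probs.max().unsqueeze(0)

# 3) One loss for the whole text, 4) one backward pass, one weight update
loss = loss_fn(text_prob, label)
loss.backward()
optimizer.step()
```

The pooling is kept as a tensor operation (no `.item()`, no conversion to Python lists or NumPy), which is what keeps `text_prob` connected to the model parameters when `loss.backward()` is called.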

In [None]:
!pip install transformers==4.36.2
!pip install datasets==2.15.0

Reuse all the code from the previous notebook.

In [None]:
import torch
from torch import Tensor

def split_overlapping(tensor, chunk_size, stride, minimal_chunk_length=5):
    result=[tensor[i:i+chunk_size] for i in range(0, len(tensor), stride)]
    if len(result)>1:
        result=[x for x in result if len(x)>=minimal_chunk_length]
    return result

def add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks):
    """
    Adds special CLS token (token id =101) at the beginning.
    Adds SEP token (token id =102) at the end of each chunk.
    Adds corresponding attention masks equal to 1 (attention mask is boolean)
    """
    
    for i in range(len(input_id_chunks)):
        ids, mask=input_id_chunks[i], mask_chunks[i]
        # adding CLS (token id 101) and SEP (token id 102) tokens, keeping the chunk's integer dtype
        input_id_chunks[i]=torch.cat([torch.tensor([101], dtype=ids.dtype), ids, torch.tensor([102], dtype=ids.dtype)])
        # adding attention masks corresponding to special tokens
        mask_chunks[i]=torch.cat([torch.tensor([1], dtype=mask.dtype), mask, torch.tensor([1], dtype=mask.dtype)])

def add_padding_tokens(input_id_chunks, mask_chunks):
    """
    Adds padding tokens (token id=0) at the end to make sure that all chunks have exactly 512 tokens
    """
    for i in range(len(input_id_chunks)):
        # get required padding length
        pad_len=512-input_id_chunks[i].shape[0]
        # check if tensor length satisfies required chunk size
        if pad_len>0:
            # if padding length is more than 0, we must add padding
            input_id_chunks[i]=torch.cat([input_id_chunks[i], torch.zeros(pad_len, dtype=input_id_chunks[i].dtype)])
            mask_chunks[i]=torch.cat([mask_chunks[i], torch.zeros(pad_len, dtype=mask_chunks[i].dtype)])

def stack_tokens_from_all_chunks(input_id_chunks, mask_chunks):
    """
    Reshapes data to a form compatible with BERT model input.
    """
    input_ids=torch.stack(input_id_chunks)
    attention_mask=torch.stack(mask_chunks)
    
    return input_ids.long(), attention_mask.int()


def tokenize_whole_text(text, tokenizer):
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(text, tokenizer, maximal_text_length):
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens

def split_tokens_into_smaller_chunks(
    tokens,
    chunk_size,
    stride,
    minimal_chunk_length,
):
    """Splits tokens into overlapping chunks with given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks

def preprocess_func(
    text,
    tokenizer,
    chunk_size,
    stride,
    minimal_chunk_length,
    maximal_text_length
):
    """Transforms (the entire) text to model input of BERT model."""
    if maximal_text_length:
        tokens=tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens=tokenize_whole_text(text, tokenizer)
    
    input_id_chunks, mask_chunks=split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks)
    input_ids, attention_mask=stack_tokens_from_all_chunks(
        input_id_chunks,
        mask_chunks
    )
    return input_ids, attention_mask
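
As a quick sanity check of the chunking logic, the sketch below runs `split_overlapping` on a toy tensor. The function is repeated here only so that the snippet runs on its own, and the small `chunk_size`, `stride`, and `minimal_chunk_length` values are assumptions chosen for readability.

```python
import torch

def split_overlapping(tensor, chunk_size, stride, minimal_chunk_length=5):
    # Same logic as above: slide a window of chunk_size tokens with the given stride
    result = [tensor[i:i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        # Drop trailing chunks that are too short to be informative
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result

tokens = torch.arange(10)  # toy "token ids" 0..9
chunks = split_overlapping(tokens, chunk_size=4, stride=3, minimal_chunk_length=2)

for chunk in chunks:
    print(chunk.tolist())
# [0, 1, 2, 3]
# [3, 4, 5, 6]
# [6, 7, 8, 9]
```

With a stride smaller than the chunk size, neighbouring chunks share tokens (here one token). Since `add_padding_tokens` pads every chunk to 512 tokens after CLS and SEP have been added, `chunk_size` should be at most 510 when calling `preprocess_func`; otherwise the chunks end up with different lengths and cannot be stacked.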