In [None]:
#| default_exp document_trainer

In [None]:
#| export
from math import ceil
from typing import List, Union, Tuple, Iterable, Any, Dict, Callable
import torch
from torch.nn.functional import cross_entropy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from peft import PeftModelForCausalLM
from transformers.models.llama import LlamaForCausalLM, LlamaTokenizer, LlamaTokenizerFast
from llama_memorizing_transformers.memory_collection import BaseMemoryCollection
from llama_memorizing_transformers.context_choice import BaseContextChoice
import gc

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
class MemorizingLlamaDocumentTrainer:
    """Trains / evaluates a memory-patched Llama model one document at a time.

    For every row of ``prompt_tokens`` the document tokens are prepended to that
    prompt. The next-token inputs/labels of the combined sequence are split into
    windows of ``tokens_per_chunk`` tokens whose starts advance by ``tokens_step``,
    and the model is run window by window. ``memory.reset()`` is called before each
    prompt's windows are produced.

    Per-window loss is ``sample_weight * lm_loss + sample_weight * context_choice_loss``.
    """

    def __init__(self, model: Union[LlamaForCausalLM, PeftModelForCausalLM],
                 context_choice: BaseContextChoice,
                 tokenizer: Union[LlamaTokenizerFast, LlamaTokenizer],
                 memory: BaseMemoryCollection,
                 tokens_per_chunk: int,
                 tokens_step: int,
                 optimizer: Optimizer,
                 scheduler: Union[LambdaLR, None],
                 accumulate_gradients: int,
                 float16: bool,
                 train_callback: Union[Callable[..., Any], None],
                 eval_callback: Union[Callable[..., Any], None]):
        """Store the training components and check the model has been memory-patched.

        Args:
            model: ``LlamaForCausalLM`` or a PEFT wrapper around one; its inner Llama
                model must carry a ``_memorizing_patch`` attribute.
            context_choice: its ``get_loss_component()`` is added to the LM loss.
            tokenizer: only ``pad_token_id`` and ``vocab_size`` are read.
            memory: reset before each prompt; ``remember_until_position`` is set to
                the document length.
            tokens_per_chunk: maximum window length (tokens) fed to the model at once.
            tokens_step: offset between consecutive window starts (may be smaller than
                ``tokens_per_chunk``, giving overlapping windows).
            optimizer: optimizer stepped after every ``accumulate_gradients`` windows.
            scheduler: optional scheduler stepped after every optimizer step.
            accumulate_gradients: number of windows whose gradients are summed per
                optimizer step.
            float16: run forward passes under CUDA autocast and use a ``GradScaler``.
            train_callback / eval_callback: optional callables invoked per window with
                ``callback_kwargs`` plus ``document_batch``, ``loss``, ``loss_lm`` and
                ``loss_context`` (python floats).

        Raises:
            TypeError: if ``model`` is neither supported model type.
            AssertionError: if the model has not been memory-patched.
            ValueError: if a size/step argument is not positive.
        """
        if isinstance(model, LlamaForCausalLM):
            assert hasattr(model.model, "_memorizing_patch")
        elif isinstance(model, PeftModelForCausalLM):
            assert hasattr(model.base_model.model.model, "_memorizing_patch")
        else:
            raise TypeError("Unknown model type")
        # A non-positive step makes the chunking loop never terminate, and a zero
        # accumulation count would divide by zero; fail fast instead.
        if tokens_per_chunk <= 0:
            raise ValueError("tokens_per_chunk must be a positive integer")
        if tokens_step <= 0:
            raise ValueError("tokens_step must be a positive integer")
        if accumulate_gradients <= 0:
            raise ValueError("accumulate_gradients must be a positive integer")
        self.llama = model
        self.context_choice = context_choice
        self.tokenizer = tokenizer
        self.memory = memory
        self.tokens_per_chunk = tokens_per_chunk
        self.tokens_step = tokens_step
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.accumulate_gradients = accumulate_gradients
        self.float16 = float16
        self.train_callback = train_callback
        self.eval_callback = eval_callback
        # GradScaler is created lazily and reused across documents so its loss
        # scale keeps its calibration instead of restarting on every document.
        self._scaler = None

    def _rearrange_tokens(self, document_tokens: torch.LongTensor, prompt_tokens: torch.LongTensor) -> Iterable[torch.LongTensor]:
        """Yield ``cat(document_tokens, prompt_tokens[i])`` as a ``(1, L)`` tensor for each prompt row ``i``."""
        for i in range(prompt_tokens.shape[0]):
            yield torch.cat((document_tokens, prompt_tokens[i]), dim=0).view((1, -1))

    def _split_token_sequences(self, token_sequence: torch.LongTensor) -> List[torch.LongTensor]:
        """Split a 2-d ``(batch, seq)`` tensor along dim 1 into windows.

        Windows are ``tokens_per_chunk`` long and start every ``tokens_step``
        tokens; the final windows may be shorter. Empty windows are dropped.
        """
        _, seq_len = token_sequence.shape
        sub_sequences = []
        for start_index in range(0, seq_len, self.tokens_step):
            sub_sequence = token_sequence[:, start_index:start_index + self.tokens_per_chunk]
            if sub_sequence.shape[1] > 0:
                sub_sequences.append(sub_sequence)
        return sub_sequences

    def _get_train_block_tokens(self, document_tokens: torch.LongTensor, prompt_tokens: torch.LongTensor) -> Iterable[Tuple[torch.LongTensor, torch.LongTensor]]:
        """Yield ``(input_window, label_window)`` pairs for every prompt, window by window.

        Labels are the inputs shifted left by one token (next-token prediction).
        ``memory.reset()`` is called before the windows of each prompt are yielded.
        """
        assert len(document_tokens.shape) == 1, "document tokens should be 1d array"
        assert len(prompt_tokens.shape) == 2, "prompt tokens should be 2d array"
        # NOTE(review): presumably limits what the memory stores to the document
        # prefix of each sequence — confirm against BaseMemoryCollection.
        self.memory.remember_until_position = document_tokens.shape[0]
        for item_tokens in self._rearrange_tokens(document_tokens, prompt_tokens):
            item_prompt_tokens = self._split_token_sequences(item_tokens[:, :-1])
            item_labels_tokens = self._split_token_sequences(item_tokens[:, 1:])
            self.memory.reset()
            for block_prompt_tokens, block_label_tokens in zip(item_prompt_tokens, item_labels_tokens):
                yield block_prompt_tokens, block_label_tokens

    @property
    def _vocab_size(self) -> int:
        """Tokenizer vocab size, +1 when a pad token id is configured.

        Kept for backward compatibility; losses are no longer reshaped with it
        because the model's logits dimension is the authoritative size.
        """
        if self.tokenizer.pad_token_id is not None:
            return self.tokenizer.vocab_size + 1
        return self.tokenizer.vocab_size

    def _get_losses(self, document_tokens: torch.LongTensor, prompt_tokens: torch.LongTensor, sample_weight: float) -> \
        Iterable[Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]]:
        """Run the model on every window and yield ``(loss, context_choice_loss, lm_loss)``.

        ``lm_loss`` is the mean cross-entropy over non-padding label positions
        (0 if a window has none), times ``sample_weight``; ``context_choice_loss`` is
        ``context_choice.get_loss_component() * sample_weight``; ``loss`` is their sum.
        """
        pad_token_id = self.tokenizer.pad_token_id

        def _inner(block_prompt_tokens: torch.LongTensor, block_label_tokens: torch.LongTensor) -> \
                Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            device = self.llama.device
            block_prompt_tokens = block_prompt_tokens.to(device)
            block_label_tokens = block_label_tokens.to(device)

            if pad_token_id is None:
                # No pad token configured: every position is a real token.
                block_attention_mask = torch.ones_like(block_prompt_tokens, dtype=torch.float)
            else:
                block_attention_mask = (block_prompt_tokens != pad_token_id).float()
                # -100 is cross_entropy's ignore_index: padded targets add no loss.
                block_label_tokens = block_label_tokens.masked_fill(block_label_tokens == pad_token_id, -100)

            # Labels are deliberately not passed: the loss is computed below from the
            # logits, and HF would otherwise compute an extra (re-shifted) unused loss.
            model_forward_pass = self.llama(
                input_ids=block_prompt_tokens,
                attention_mask=block_attention_mask,
                return_dict=True
            )
            logits = model_forward_pass.logits

            # Flatten with the logits' own last dimension (the model's output size);
            # a tokenizer-derived size can disagree with it.
            logits_flatten = logits.reshape(-1, logits.shape[-1])
            labels_flatten = block_label_tokens.reshape(-1)
            # Sum / count equals the mean when any target is valid, but yields 0
            # instead of NaN for windows made up entirely of padding.
            valid_targets = labels_flatten.ne(-100).sum().clamp(min=1)
            lm_loss = cross_entropy(input=logits_flatten, target=labels_flatten,
                                    ignore_index=-100, reduction="sum") / valid_targets * sample_weight
            context_choice_loss = self.context_choice.get_loss_component() * sample_weight
            loss = lm_loss + context_choice_loss
            return loss, context_choice_loss, lm_loss

        for block_prompt_tokens, block_label_tokens in self._get_train_block_tokens(document_tokens, prompt_tokens):
            # Release cached allocator blocks between windows to limit peak memory.
            gc.collect()
            torch.cuda.empty_cache()
            if self.float16:
                with torch.cuda.amp.autocast():
                    losses = _inner(block_prompt_tokens, block_label_tokens)
            else:
                losses = _inner(block_prompt_tokens, block_label_tokens)
            yield losses
            # Drop this generator's reference before the next window is computed.
            del losses

    def _get_grad_scaler(self):
        """Return the trainer-wide GradScaler in float16 mode (created on first use), else None."""
        if not self.float16:
            return None
        scaler = getattr(self, "_scaler", None)
        if scaler is None:
            scaler = torch.cuda.amp.GradScaler()
            self._scaler = scaler
        return scaler

    def _optimizer_step(self, scaler) -> None:
        """Apply the accumulated gradients, advance the scheduler and clear gradients."""
        if scaler is not None:
            scaler.step(self.optimizer)
            scaler.update()
        else:
            self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        self.optimizer.zero_grad(set_to_none=True)

    def train_document(self, document_tokens: torch.LongTensor, prompt_tokens: torch.LongTensor, sample_weight: float, callback_kwargs: Dict[str, Any]):
        """Train on one document.

        Gradients of ``accumulate_gradients`` consecutive windows are summed before
        each optimizer step. Any remaining partial window is applied at the end of
        the document instead of being discarded. ``train_callback`` (if set) is
        called after every window.
        """
        self.llama.train()
        self.optimizer.zero_grad(set_to_none=True)
        scaler = self._get_grad_scaler()
        pending_batches = 0
        for batch, (loss, loss_context, loss_lm) in enumerate(self._get_losses(document_tokens, prompt_tokens, sample_weight)):
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()
            pending_batches += 1
            if pending_batches == self.accumulate_gradients:
                self._optimizer_step(scaler)
                pending_batches = 0
            batch_callback_kwargs = dict(callback_kwargs, document_batch=batch,
                                         loss=loss.item(),
                                         loss_lm=loss_lm.item(),
                                         loss_context=loss_context.item())
            del loss, loss_context, loss_lm
            gc.collect()
            torch.cuda.empty_cache()
            if self.train_callback:
                self.train_callback(**batch_callback_kwargs)
        if pending_batches:
            # Flush the last partial accumulation window; otherwise the next call's
            # zero_grad would silently throw these gradients away.
            self._optimizer_step(scaler)

    def eval_document(self, document_tokens: torch.LongTensor, prompt_tokens: torch.LongTensor, sample_weight: float, callback_kwargs: Dict[str, Any]):
        """Evaluate one document without gradients, reporting per-window losses to ``eval_callback``."""
        self.llama.eval()
        with torch.no_grad():
            for batch, (loss, loss_context, loss_lm) in enumerate(self._get_losses(document_tokens, prompt_tokens, sample_weight)):
                batch_callback_kwargs = dict(callback_kwargs, document_batch=batch,
                                             loss=loss.item(),
                                             loss_lm=loss_lm.item(),
                                             loss_context=loss_context.item())
                if self.eval_callback:
                    self.eval_callback(**batch_callback_kwargs)

In [None]:
#| hide
# Write the `#| export` cells out to the Python package (run after editing the notebook).
import nbdev
nbdev.nbdev_export()