In [1]:
!pip3 install datasets
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/lvwerra/trl.git



In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# Single source of truth for the dataset and checkpoint ids, so the model and
# the tokenizer below can never be loaded from different checkpoints by accident.
DATASET_NAME = "lucasmccabe-lmi/CodeAlpaca-20k"
MODEL_NAME = "facebook/opt-350m"

# Training split of the instruction dataset.
dataset = load_dataset(DATASET_NAME, split="train")

# Causal LM and its matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def formatting_prompts_func(example):
    """Render a batch of examples as "### Question / ### Answer" prompt strings.

    Args:
        example: Mapping with parallel sequences under the keys
            ``'instruction'`` and ``'output'``.

    Returns:
        A list with one formatted string per entry of ``example['instruction']``.
    """
    return [
        f"### Question: {question}\n ### Answer: {example['output'][position]}"
        for position, question in enumerate(example['instruction'])
    ]

# Marker that separates the prompt from the completion in each formatted example.
response_template = " ### Answer:"

# Completion-only collator keyed on the response marker.
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

# Baseline trainer that uses the completion-only collator.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    data_collator=collator,
    formatting_func=formatting_prompts_func,
)


README.md:   0%|          | 0.00/677 [00:00<?, ?B/s]

(…)-00000-of-00001-e270777bb989ac86.parquet:   0%|          | 0.00/3.45M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/20022 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]



Map:   0%|          | 0/20022 [00:00<?, ? examples/s]

In [3]:
dataset

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 20022
})

In [4]:
masked1 = trainer.data_collator.torch_call([trainer.train_dataset[0]])
masked1

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
     

In [5]:
masked1['input_ids'][0]

tensor([    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
         2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
         5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
        31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
         1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
            8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
         1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
        12606,   176,  2055,   155,  3226,  1178])

In [6]:
masked1['attention_mask'][0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1])

In [7]:
masked1['labels'][0]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
         1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
            8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
         1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
        12606,   176,  2055,   155,  3226,  1178])

In [8]:
tokenizer.decode([3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
         1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
            8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
         1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
        12606,   176,  2055,   155,  3226,  1178])

' def f(x):\n    """\n    Takes a specific input and produces a specific output using any mathematical operators\n    """\n    return x**2 + 3*x'

In [9]:
dataset[0]['output']

'def f(x):\n    """\n    Takes a specific input and produces a specific output using any mathematical operators\n    """\n    return x**2 + 3*x'

In [10]:
model.device

device(type='cuda', index=0)

In [11]:
masked1 = masked1.to(model.device)
masked1

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,

In [12]:
trainer.compute_loss(model, masked1)

tensor(2.2431, device='cuda:0', grad_fn=<NllLossBackward0>)

In [13]:
import torch
import numpy as np
import warnings

In [14]:
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

In [15]:
from transformers import DataCollatorForLanguageModeling

In [46]:
class DataCollatorForPLW(DataCollatorForLanguageModeling):
    """
    Completion-style data collator gated by a prompt-loss-weight (``plw``) switch.

    The batch is first built by ``DataCollatorForLanguageModeling.torch_call``. Afterwards:

    * if ``plw == 1.0`` the labels are left untouched and a warning is emitted;
    * otherwise every label that is not part of a response is overwritten with ``ignore_index``.

    Note that in this collator ``plw`` only acts as an on/off switch: no per-token weight is stored in
    the batch, so any value other than ``1.0`` results in the same, fully masked prompt.

    Args:
        response_template (`Union[str, List[int]]`): the template form that indicates the start of the response,
            typically something like '### Response:\\n'. It can also be passed as tokenized ids, which can be useful
            when using a tokenizer that encodes the response differently if it does not have proper context.
        instruction_template (`Union[str, List[int]]`, *optional*): the template form that indicates the start of the
            human instruction, typically something like '### Human:\\n'. Useful for assistant-style conversation
            datasets. It can also be passed as tokenized ids. When `None`, each example is treated as a single
            prompt followed by a single response.
        mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the
            underlying `DataCollatorForLanguageModeling` class.
        ignore_index (`int`, *optional*, defaults to `-100`):
            The value written into `labels` for every token that must not contribute to the loss.
        padding_free (`bool`, *optional*, defaults to `False`):
            When `True`, the padded batch is flattened into a single row: `attention_mask` is removed,
            `position_ids` is added, and the label at every position 0 is set to `ignore_index`.
        plw (`float`, *optional*, defaults to `1.0`):
            Prompt loss weight switch described above.
    """

    def __init__(
        self,
        response_template: Union[str, List[int]],
        instruction_template: Optional[Union[str, List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        plw: float = 1.0,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)
        self.plw = plw

        self.instruction_template = instruction_template
        self.instruction_token_ids = self._to_token_ids(instruction_template)

        self.response_template = response_template
        self.response_token_ids = self._to_token_ids(response_template)

        if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value."
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def _to_token_ids(self, template):
        """Tokenize `template` when it is a string; otherwise return it unchanged (already ids, or `None`)."""
        if isinstance(template, str):
            return self.tokenizer.encode(template, add_special_tokens=False)
        return template

    @staticmethod
    def _find_template_starts(label_row, token_ids) -> List[int]:
        """Return every index of `label_row` at which the full `token_ids` sequence begins."""
        starts = []
        # Cheap pre-filter on the first token id, then confirm the whole sequence matches.
        for idx in np.where(label_row == token_ids[0])[0]:
            if token_ids == label_row[idx : idx + len(token_ids)].tolist():
                starts.append(int(idx))
        return starts

    def _warn_missing_key(self, kind: str, template, input_ids) -> None:
        """Warn that the `kind` ("response" / "instruction") template was not found in an example."""
        warnings.warn(
            f"Could not find {kind} key `{template}` in the "
            f"following instance: {self.tokenizer.decode(input_ids)} "
            f"This instance will be ignored in loss calculation. "
            f"Note, if this happens often, consider increasing the `max_seq_length`."
        )

    def _mask_single_turn(self, batch: Dict[str, Any], num_examples: int) -> None:
        """Mask, in place, everything up to the end of the last response template of each example."""
        for i in range(num_examples):
            starts = self._find_template_starts(batch["labels"][i], self.response_token_ids)
            if not starts:
                self._warn_missing_key("response", self.response_template, batch["input_ids"][i])
                batch["labels"][i, :] = self.ignore_index
                continue

            # The last occurrence wins when the template appears several times.
            response_end = starts[-1] + len(self.response_token_ids)
            # Make pytorch loss function ignore all tokens up through the end of the response key
            batch["labels"][i, :response_end] = self.ignore_index

    def _mask_multi_turn(self, batch: Dict[str, Any], num_examples: int) -> None:
        """Mask, in place, every instruction span so that only the response spans keep their labels."""
        for i in range(num_examples):
            label_row = batch["labels"][i]  # view: reflects the in-place writes made below

            # Indexes just past each response template, i.e. where each response starts.
            response_ends = [
                start + len(self.response_token_ids)
                for start in self._find_template_starts(label_row, self.response_token_ids)
            ]
            if len(response_ends) == 0:
                self._warn_missing_key("response", self.response_template, batch["input_ids"][i])
                batch["labels"][i, :] = self.ignore_index

            # Indexes where each human instruction starts.
            human_starts = self._find_template_starts(label_row, self.instruction_token_ids)
            if len(human_starts) == 0:
                self._warn_missing_key("instruction", self.instruction_template, batch["input_ids"][i])
                batch["labels"][i, :] = self.ignore_index

            # If the first response precedes the first instruction, pair it with a
            # synthetic instruction start at index 0.
            if len(human_starts) > 0 and len(response_ends) > 0 and human_starts[0] > response_ends[0]:
                human_starts = [0] + human_starts

            for turn, (start, end) in enumerate(zip(human_starts, response_ends)):
                # Make pytorch loss function ignore all non response tokens
                if turn != 0:
                    batch["labels"][i, start:end] = self.ignore_index
                else:
                    batch["labels"][i, :end] = self.ignore_index

            # A trailing instruction without a response is masked to the end of the row.
            if len(response_ends) < len(human_starts):
                batch["labels"][i, human_starts[-1] :] = self.ignore_index

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """
        Collate `examples` with the parent class, then mask prompt labels unless ``plw == 1.0``.

        Returns:
            The batch dict produced by the parent collator, with `labels` modified in place and, when
            `padding_free` is set, flattened to a single row with `position_ids` replacing `attention_mask`.
        """
        batch = super().torch_call(examples)

        if self.plw == 1.0:
            # Issued once per batch rather than once per example.
            warnings.warn(f"prompt_loss_weight PLW is set to 1.0, all tokens will be calculated in the loss function.")
        elif self.instruction_template is None:
            self._mask_single_turn(batch, len(examples))
        else:
            self._mask_multi_turn(batch, len(examples))

        # Applies to both the single-turn and the multi-turn path.
        if self.padding_free:
            # remove padding, `attention_mask` and add `position_ids`
            attn_mask = batch.pop("attention_mask")
            batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
            batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
            batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
            batch["labels"][batch["position_ids"] == 0] = self.ignore_index

        return batch

In [47]:
class SFTTrainerWithPLW(SFTTrainer):
    """
    `SFTTrainer` whose loss weights prompt tokens by `plw` and completion tokens by 1.0.

    For every row of the batch, the prompt is everything up to and including the last
    occurrence of `response_tokens` in `labels`; the remaining tokens are the completion.
    The loss is the weighted mean of the per-token cross-entropy:

        sum(weight_t * loss_t) / sum(weight_t)

    Labels equal to -100 always get weight 0.

    Args:
        plw (`float`, *optional*, defaults to `1.0`): weight applied to prompt-token losses.
            `1.0` reproduces an unweighted mean over all non-ignored tokens; `0.0` trains on
            the completion only.
        response_tokens (`List[int]`): token ids that mark the end of the prompt.
        *args, **kwargs: forwarded unchanged to `SFTTrainer`.
    """

    def __init__(self, *args, plw=1.0, response_tokens, **kwargs):
        super().__init__(*args, **kwargs)
        self.plw = plw
        self.response_tokens = response_tokens

    @staticmethod
    def _prompt_end(token_row, response_tokens):
        """Return the index just past the last occurrence of `response_tokens` in `token_row`, or `None`."""
        width = len(response_tokens)
        if width == 0:
            return None
        # Scan from the right so the first hit is the last occurrence.
        for start in range(len(token_row) - width, -1, -1):
            if token_row[start : start + width] == response_tokens:
                return start + width
        return None

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Compute the prompt-weighted causal LM loss for one batch.

        Args:
            model: model called with `input_ids` and `attention_mask`; its output must expose `logits`
                through `.get("logits")`.
            inputs: batch mapping with `input_ids`, `attention_mask` and `labels`. It is not modified.
            return_outputs (`bool`): when `True`, return `(loss, outputs)` instead of only `loss`.
            num_items_in_batch: accepted so that callers passing this keyword keep working; it is not
                used because the loss is normalised by the sum of the token weights.

        Returns:
            The scalar weighted loss, or `(loss, outputs)` when `return_outputs` is `True`.
        """
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        labels = inputs["labels"]  # read-only: the caller's batch is left intact

        # Weight 1.0 for every real token, 0.0 for ignored labels (-100).
        weights = (labels != -100).to(torch.float32)

        if self.plw != 1.0:
            response_tokens = list(self.response_tokens)
            # One device-to-host transfer per row, then a pure-Python search.
            for row, token_row in enumerate(labels.tolist()):
                prompt_end = self._prompt_end(token_row, response_tokens)
                if prompt_end is not None:
                    # Prompt labels are kept (not set to -100) so that their loss is
                    # actually computed and then scaled by plw.
                    weights[row, :prompt_end] *= self.plw

        # Position t predicts token t + 1, so drop the last logit and the first label/weight.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        shift_weights = weights[..., 1:].contiguous().to(shift_logits.device)

        # Per-token losses; positions labelled -100 yield 0 (CrossEntropyLoss default ignore_index).
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # Flatten the weights to match the flat token losses for any batch size.
        flat_weights = shift_weights.view(-1)

        # weighted loss; the clamp turns an all-zero weight sum into a 0 loss instead of NaN
        weighted_loss = (token_losses * flat_weights).sum() / flat_weights.sum().clamp_min(1e-12)

        return (weighted_loss, outputs) if return_outputs else weighted_loss

In [48]:
# Marker string that precedes the answer in every formatted example.
response_template = " ### Answer:"
# Drop the first two ids returned by encode(); the ids that remain are the ones
# displayed below. NOTE(review): the dropped ids are presumably the leading
# special token and the " ###" piece -- confirm against the tokenizer.
response_tokens = tokenizer.encode(response_template)[2:]
response_tokens

[31652, 35]

In [68]:
# Custom collator with plw=0.0: in DataCollatorForPLW.torch_call any plw other
# than 1.0 masks the labels up to the end of the response template.
plw_collator = DataCollatorForPLW(response_template, tokenizer=tokenizer, plw=0.0)

In [69]:
# Stock SFTTrainer that collates its batches through the custom PLW collator.
trainer3 = SFTTrainer(
    model=model,
    train_dataset=dataset,
    data_collator=plw_collator,
    formatting_func=formatting_prompts_func,
)



In [70]:
masked = trainer3.data_collator.torch_call([trainer.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
     

In [71]:
masked = masked.to(model.device)

In [72]:
masked['labels']

tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0')

In [73]:
loss, outputs = trainer3.compute_loss(model, masked, return_outputs=True)

In [74]:
loss

tensor(2.2431, device='cuda:0', grad_fn=<NllLossBackward0>)

In [62]:
# Trainer with the custom weighted loss. No data_collator is passed, so batches
# come from the trainer's own collator. With plw=1.0 the `self.plw != 1.0`
# branch in compute_loss is skipped, i.e. no prompt down-weighting is applied.
trainer2 = SFTTrainerWithPLW(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    response_tokens=response_tokens,
    plw=1.0
)



In [63]:
masked = trainer2.data_collator.torch_call([trainer.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
     

In [64]:
masked = masked.to(model.device)
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,

In [65]:
masked['labels']

tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0')

In [66]:
loss, outputs = trainer2.compute_loss(model, masked, return_outputs=True)

In [67]:
loss

tensor(3.1653, device='cuda:0', grad_fn=<DivBackward0>)

In [144]:
outputs['logits'].shape

torch.Size([1, 76, 50272])

In [145]:
loss, outputs = trainer.compute_loss(model, masked1, return_outputs=True)

In [146]:
loss

tensor(2.2431, device='cuda:0', grad_fn=<NllLossBackward0>)

In [113]:
outputs['logits'].shape

torch.Size([1, 76, 50272])

In [116]:
# NOTE(review): this rebinds `trainer3`, which an earlier cell created with
# data_collator=plw_collator. Here no data_collator is passed, so the cells
# below use the trainer's own collator; consider a distinct name.
trainer3 = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)








In [117]:
masked = trainer3.data_collator.torch_call([trainer.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
     

In [118]:
masked = masked.to(model.device)
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,

In [120]:
trainer3.compute_loss(model, masked)

tensor(3.1653, device='cuda:0', grad_fn=<NllLossBackward0>)