In [1]:
!pip3 install datasets
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/lvwerra/trl.git



In [21]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# One checkpoint name shared by the model and tokenizer so they cannot drift apart.
MODEL_NAME = "facebook/opt-350m"

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def formatting_prompts_func(example):
    """Render a batch of instruction/output pairs as Question/Answer prompts.

    Args:
        example: batch mapping whose 'instruction' and 'output' entries are
            parallel sequences.

    Returns:
        List of strings "### Question: <instruction>\\n ### Answer: <output>",
        one per instruction, in order.
    """
    instructions = example['instruction']
    return [
        f"### Question: {question}\n ### Answer: {example['output'][idx]}"
        for idx, question in enumerate(instructions)
    ]

# formatting_prompts_func emits "\n ### Answer:", so the template keeps its leading
# space — presumably so its tokenization matches the in-context tokens.
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

# Baseline trainer: the completion-only collator sets the prompt labels to -100.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)




In [3]:
# Raw dataset: 'instruction', 'input' and 'output' columns.
# NOTE(review): formatting_prompts_func only uses 'instruction' and 'output' — 'input' is dropped.
dataset

Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 20022
})

In [4]:
# Collate the first formatted training example with the completion-only collator.
first_example = trainer.train_dataset[0]
masked1 = trainer.data_collator.torch_call([first_example])
masked1

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
     

In [5]:
# Token ids of the single collated example (row 0 of the batch).
masked1['input_ids'][0]

tensor([    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
         2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
         5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
        31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
         1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
            8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
         1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
        12606,   176,  2055,   155,  3226,  1178])

In [6]:
# Attention mask of row 0 (all ones here — nothing is padded).
masked1['attention_mask'][0]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1])

In [7]:
# Labels match input_ids except the prompt positions, which the completion-only
# collator replaced with -100 (the default ignore_index of the cross-entropy loss).
masked1['labels'][0]

tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
         1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
            8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
         1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
        12606,   176,  2055,   155,  3226,  1178])

In [8]:
# Decode only the supervised tokens (label != -100). The selection is derived from
# the batch tensors instead of a hand-copied id list, so the check stays valid if
# the data, tokenizer or formatting changes.
tokenizer.decode(masked1['input_ids'][0][masked1['labels'][0] != -100])

' def f(x):\n    """\n    Takes a specific input and produces a specific output using any mathematical operators\n    """\n    return x**2 + 3*x'

In [9]:
# Reference answer; the decoded supervised tokens above are the same text with an
# extra leading space (the template is followed by a space in the prompt).
dataset[0]['output']

'def f(x):\n    """\n    Takes a specific input and produces a specific output using any mathematical operators\n    """\n    return x**2 + 3*x'

In [10]:
# Device the model lives on; batches must be moved here before calling compute_loss.
model.device

device(type='cuda', index=0)

In [11]:
# Move every tensor in the collated batch onto the model's device.
model_device = model.device
masked1 = masked1.to(model_device)
masked1

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,

In [12]:
# Baseline: completion-only loss (prompt labels were masked to -100 by the collator).
trainer.compute_loss(model, masked1)

tensor(2.2431, device='cuda:0', grad_fn=<NllLossBackward0>)

In [32]:
import torch
import numpy as np
import warnings

In [33]:
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

In [52]:
from transformers import DataCollatorForLanguageModeling

In [274]:
class SFTTrainerWithPLW(SFTTrainer):
    """``SFTTrainer`` whose loss scales prompt-token losses by a prompt-loss weight (PLW).

    Each supervised token (``label != -100``) contributes its cross-entropy loss
    multiplied by a weight:

    * ``plw`` for tokens up to and including the LAST occurrence of the
      response template;
    * ``1.0`` for the completion tokens that follow it.

    The batch loss is the weighted mean of those token losses. As a result,
    ``plw=1.0`` gives the ordinary causal-LM loss and ``plw=0.0`` gives a
    completion-only loss.

    Positions whose label is already -100 (e.g. padding) get weight 0. If the
    data collator masks the prompt labels itself (as
    ``DataCollatorForCompletionOnlyLM`` does), the prompt contributes nothing
    regardless of ``plw``. Use a collator whose labels cover the prompt.

    Args:
        plw: Weight applied to prompt tokens (the response template counts as prompt).
        response_tokens: Token ids of the response template that separates the
            prompt from the completion.
        *args, **kwargs: Forwarded to ``SFTTrainer``.

    Raises:
        ValueError: If ``response_tokens`` is empty.
    """

    def __init__(self, *args, plw=1.0, response_tokens, **kwargs):
        response_tokens = list(response_tokens)
        if not response_tokens:
            raise ValueError("response_tokens must contain at least one token id")
        super().__init__(*args, **kwargs)
        self.plw = float(plw)
        self.response_tokens = response_tokens

    def _completion_mask(self, input_ids):
        """Locate the completion tokens in every row of a batch.

        Args:
            input_ids: 2-D ``(batch, seq_len)`` tensor of token ids.

        Returns:
            A ``(completion, found)`` tuple:

            * ``completion`` is a bool tensor shaped like ``input_ids``. It is
              True for tokens after the LAST occurrence of
              ``self.response_tokens`` in that row.
            * ``found`` is a ``(batch,)`` bool tensor telling whether the
              template occurs in each row.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        template = torch.as_tensor(self.response_tokens, dtype=input_ids.dtype, device=device)
        k = template.numel()
        positions = torch.arange(seq_len, device=device)

        if seq_len < k:
            # Too short to contain the template: no row has a completion.
            no_hits = torch.zeros(batch_size, dtype=torch.bool, device=device)
            return torch.zeros_like(input_ids, dtype=torch.bool), no_hits

        # unfold(1, k, 1)[b, i] == input_ids[b, i:i + k]; compare all windows at once.
        hits = (input_ids.unfold(1, k, 1) == template).all(dim=-1)
        starts = torch.where(hits, positions[: seq_len - k + 1], -1)
        last_start = starts.max(dim=1).values  # -1 where the template is absent
        found = last_start >= 0
        # Rows without the template start "past the end", so they have no completion tokens.
        completion_start = torch.where(found, last_start + k, seq_len)
        completion = positions.unsqueeze(0) >= completion_start.unsqueeze(1)
        return completion, found

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the prompt-weighted causal-LM loss for one batch.

        Args:
            model: Causal LM called with ``input_ids`` and ``attention_mask``;
                its output must expose ``logits``.
            inputs: Batch mapping with ``input_ids``, ``attention_mask`` and
                ``labels`` (ignored positions set to -100). Not mutated.
            return_outputs: If True, also return the model outputs.
            num_items_in_batch: Accepted for compatibility with Trainer versions
                that pass it. It is unused, because the loss is already
                normalised by the sum of the token weights.

        Returns:
            The scalar weighted loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        device = logits.device

        # Read (don't pop or edit) the labels so the caller's batch stays intact.
        labels = inputs["labels"].to(device)
        completion, found = self._completion_mask(inputs["input_ids"].to(device))
        if not bool(found.all()):
            warnings.warn(
                "Response template not found in at least one sequence; "
                "all of its tokens are weighted as prompt tokens (plw).",
                UserWarning,
            )

        # plw for prompt tokens, 1.0 for completion tokens, 0.0 where the label is ignored.
        weights = torch.where(completion, 1.0, self.plw)
        weights = weights.masked_fill(labels == -100, 0.0)

        # Causal shift: logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        shift_weights = weights[..., 1:].contiguous()

        token_losses = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        )
        flat_weights = shift_weights.view(-1)
        # clamp_min avoids 0/0 when every token has zero weight (e.g. plw=0, no template).
        loss = (token_losses * flat_weights).sum() / flat_weights.sum().clamp_min(1e-8)

        return (loss, outputs) if return_outputs else loss

In [275]:
# Token ids used by SFTTrainerWithPLW to find the prompt/completion boundary.
response_template = " ### Answer:"
# NOTE(review): [2:] is a magic offset. It appears to drop the BOS id and the " ###"
# piece, leaving only " Answer:" ([31652, 35]). Encoding with add_special_tokens=False
# and keeping the whole template would be less fragile — confirm against the tokenizer.
response_tokens = tokenizer.encode(response_template)[2:]
response_tokens

[31652, 35]

In [276]:
# PLW trainer with the default collator (its labels mirror input_ids, see the
# collated batch below). plw=1.0 weights prompt and completion tokens equally.
trainer2 = SFTTrainerWithPLW(
    model=model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    plw=1.0,
    response_tokens=response_tokens,
)



In [277]:
# Collate the first training example with trainer2's own collator and dataset, so
# this cell does not depend on the baseline `trainer`.
masked = trainer2.data_collator.torch_call([trainer2.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
     

In [278]:
# NOTE(review): re-collates the same first training example as the preceding cell;
# redundant and can be deleted.
masked = trainer2.data_collator.torch_call([trainer.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
     

In [279]:
# Move the PLW batch onto the model's device.
model_device = model.device
masked = masked.to(model_device)
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,

In [270]:
# NOTE(review): execution count 270 predates the cells above (274-279); re-run in order.
# Default-collator labels mirror input_ids, i.e. the prompt is not masked.
masked['labels']

tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0')

In [280]:
# Weighted loss from the PLW trainer on the default-collator batch.
loss, outputs = trainer2.compute_loss(model, masked, return_outputs=True)

In [282]:
# With plw=1.0 this equals trainer3's plain causal-LM loss computed further down (3.1653).
loss

tensor(3.1653, device='cuda:0', grad_fn=<DivBackward0>)

In [144]:
# (1, 76, 50272): one sequence of 76 tokens; last dim presumably the vocabulary size.
outputs['logits'].shape

torch.Size([1, 76, 50272])

In [145]:
# Recompute the baseline completion-only loss (overwrites loss/outputs from the PLW run).
loss, outputs = trainer.compute_loss(model, masked1, return_outputs=True)

In [146]:
# Matches the earlier trainer.compute_loss result (2.2431).
loss

tensor(2.2431, device='cuda:0', grad_fn=<NllLossBackward0>)

In [113]:
# Same logits shape as the PLW run: the model sees the identical token sequence.
outputs['logits'].shape

torch.Size([1, 76, 50272])

In [116]:
# Plain SFTTrainer with its default collator (labels mirror input_ids); serves as
# the reference loss for the PLW trainer at plw=1.0.
trainer3 = SFTTrainer(
    model=model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)








In [117]:
# Collate the first training example with trainer3's own collator and dataset, so
# this cell does not depend on the baseline `trainer`.
masked = trainer3.data_collator.torch_call([trainer3.train_dataset[0]])
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
     

In [118]:
# Move the reference batch onto the model's device.
model_device = model.device
masked = masked.to(model_device)
masked

{'input_ids': tensor([[    2, 48134, 15680,    35, 21384,    10,  5043,    14,  1239,    10,
          2167,  8135,     8,  9108,    10,  2167,  4195,   634,   143, 30412,
          5990,     4, 21062, 12337,  3260,    11, 31886,     4, 50118, 22560,
         31652,    35,  3816,   856,  1640,  1178,  3256, 50118,  1437,  1437,
          1437, 49434, 50118,  1437,  1437,  1437, 29072,    10,  2167,  8135,
             8,  9108,    10,  2167,  4195,   634,   143, 30412,  5990, 50118,
          1437,  1437,  1437, 49434, 50118,  1437,  1437,  1437,   671,  3023,
         12606,   176,  2055,   155,  3226,  1178]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]], device='cuda:0'), 'labels': tensor([[    2, 48134, 15680,    35, 21384,    10,

In [120]:
# Plain causal-LM loss over every labelled token with the default collator.
trainer3.compute_loss(model, masked)

tensor(3.1653, device='cuda:0', grad_fn=<NllLossBackward0>)