In [None]:
!pip install datasets
!pip install -q -U transformers==4.45.1
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U git+https://github.com/lvwerra/trl.git

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

In [None]:
# Model and tokenizer are loaded from the same Hub checkpoint.
MODEL_NAME = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# CodeAlpaca-20k training split; later cells read its `instruction`, `input` and `output` columns.
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

### Trainer with DataCollatorForCompletionOnlyLM

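Equivalent to a prompt loss weight (PLW) of 0.0: prompt tokens are masked out of the loss, so the model is trained only on the answer tokens.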

In [None]:
def formatting_prompts_func(example):
    """Render a batch of rows as "### Question: ... ### Answer: ..." training texts.

    ``example`` is a batched mapping with parallel ``instruction`` and ``output``
    lists; one string is returned per row. The answer marker is written as
    " ### Answer:" (leading space, after a newline), the same string used as the
    ``response_template`` for the completion-only collator.
    """
    instructions = example['instruction']
    answers = example['output']
    return [
        f"### Question: {instructions[idx]}\n ### Answer: {answers[idx]}"
        for idx in range(len(instructions))
    ]

In [None]:
# Baseline for PLW = 0.0: the completion-only collator excludes every label up to
# the end of the response template from the loss, so only answer tokens are trained on.
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

In [None]:
dataset

In [None]:
# Collate a single tokenized training example to see what the completion-only collator produces.
masked1 = trainer.data_collator.torch_call([trainer.train_dataset[0]])
masked1

In [None]:
masked1['input_ids'][0]

In [None]:
masked1['attention_mask'][0]

In [None]:
# Prompt positions are expected to show the ignore index (-100) here.
masked1['labels'][0]

In [None]:
# Decode the label positions that are NOT ignored (label != -100), i.e. the tokens
# the completion-only loss is computed on. Deriving them from the batch instead of
# pasting ids by hand keeps this cell correct if the example or tokenizer changes.
example_labels = masked1['labels'][0]
tokenizer.decode(example_labels[example_labels != -100])

In [None]:
dataset[0]['output']

In [None]:
model.device

In [None]:
# Move the collated batch to the model's device before computing the loss.
masked1 = masked1.to(model.device)
masked1

In [None]:
# Reference loss for PLW = 0.0 on this single example.
trainer.compute_loss(model, masked1)

### Trainer with PLW = 1 (default)

In [None]:
def formatting_prompts_func(example):
    """Format batched instruction/output pairs as question/answer training texts.

    NOTE(review): this duplicates the earlier definition of this function with
    identical behavior; consider keeping only one definition in the notebook.
    """
    texts = []
    for idx, instruction in enumerate(example['instruction']):
        texts.append(f"### Question: {instruction}\n ### Answer: {example['output'][idx]}")
    return texts

In [None]:
dataset

In [None]:
# No data_collator is passed, so SFTTrainer uses its default collator; per the
# section title this is the PLW = 1 (prompt tokens included in the loss) baseline.
trainer2 = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

In [None]:
# Uses the example tokenized by `trainer` so every variant is compared on the same input.
masked = trainer2.data_collator.torch_call([trainer.train_dataset[0]])
masked

In [None]:
masked = masked.to(model.device)

In [None]:
masked['labels']

In [None]:
loss, outputs = trainer2.compute_loss(model, masked, return_outputs=True)

In [None]:
loss

### Custom Trainer for PLW

Tests:
1. PLW = 1.0 => should give the same loss as the default Hugging Face `SFTTrainer`
2. PLW = 0.0 => should give the same loss as `SFTTrainer` with `DataCollatorForCompletionOnlyLM`

In [None]:
import torch
import numpy as np
import warnings

In [None]:
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

In [None]:
from transformers import DataCollatorForLanguageModeling

In [None]:
class SFTTrainerWithPLW(SFTTrainer):
    """``SFTTrainer`` whose loss weights prompt tokens by a prompt loss weight (PLW).

    The prompt is located by searching each row of ``labels`` for the first
    occurrence of ``response_template_tokens``. Every position up to and
    including the last template token gets weight ``plw``, and later positions
    get weight 1.0. Rows without the template keep weight 1.0 everywhere.
    Positions whose label is -100 (the ignore index, e.g. padding) get weight 0,
    so they neither add loss nor inflate the denominator of the weighted mean.

    Args:
        plw: weight for prompt tokens. 1.0 should reproduce the full-sequence
            loss and 0.0 the completion-only loss.
        response_template_tokens: token ids whose last element marks the end
            of the prompt.
    """

    def __init__(self, *args, plw=1.0, response_template_tokens, **kwargs):
        super().__init__(*args, **kwargs)
        self.plw = plw
        # Stored as a list because label rows are compared via ``tensor.tolist()``.
        self.response_template_tokens = list(response_template_tokens)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the PLW-weighted mean next-token cross-entropy.

        Args:
            model: the causal LM being trained.
            inputs: batch with ``input_ids``, ``attention_mask`` and ``labels``;
                ``labels`` is popped from it.
            return_outputs: if True, also return the model outputs.
            **kwargs: accepted and ignored, for trainer versions that pass extra
                arguments (e.g. ``num_items_in_batch``).

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        ignore_index = -100  # CrossEntropyLoss default ignore index
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        labels = inputs.pop("labels")

        weights = torch.ones_like(labels, dtype=torch.float32)
        template = self.response_template_tokens
        template_len = len(template)

        for batch_idx in range(labels.shape[0]):
            # Convert the row once instead of once per candidate window.
            row = labels[batch_idx].tolist()
            for start in range(len(row) - template_len + 1):
                if row[start:start + template_len] == template:
                    # Prompt = everything through the last template token.
                    weights[batch_idx, :start + template_len] = self.plw
                    break

        # Logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        shift_weights = weights[..., 1:].contiguous().to(shift_logits.device)
        # Ignored positions produce zero loss; drop them from the normaliser too,
        # so plw=1.0 matches a plain mean over the real (non-ignored) tokens.
        shift_weights = shift_weights * (shift_labels != ignore_index)

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1))

        # Weighted average of the per-token losses.
        flat_weights = shift_weights.view(-1).float()
        loss = (token_losses.float() @ flat_weights) / flat_weights.sum()
        return (loss, outputs) if return_outputs else loss

In [None]:
response_template = " ### Answer:"
# Drop the first two ids of the encoding. NOTE(review): for OPT these are presumably
# the BOS token added by `encode` plus the leading " ###" token. The DataCollatorForPLW
# cell further down slices the same encoding with [1:] instead; confirm which is intended.
# SFTTrainerWithPLW only uses where the template *ends*, and both slices end with
# the same ids.
response_template_tokens = tokenizer.encode(response_template)[2:]
response_template_tokens

In [None]:
# Custom PLW trainer with plw=0.0. Per test 2 above, this should match the
# completion-only loss. No data_collator is passed, so the prompt is located by
# SFTTrainerWithPLW itself, which searches the labels for the response template.
trainer3 = SFTTrainerWithPLW(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    response_template_tokens=response_template_tokens,
    plw=0.0
)

In [None]:
masked = trainer3.data_collator.torch_call([trainer.train_dataset[0]])
masked

In [None]:
masked = masked.to(model.device)

In [None]:
masked['labels']

In [None]:
loss, outputs = trainer3.compute_loss(model, masked, return_outputs=True)

In [None]:
loss

### PLW With Custom Data Collator

In [None]:
import torch
import numpy as np

In [None]:
from transformers import DataCollatorForLanguageModeling
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

In [None]:
class DataCollatorForPLW(DataCollatorForLanguageModeling):
    """Causal-LM collator that additionally returns a per-token ``prompt_mask``.

    ``torch_call`` adds ``batch["prompt_mask"]``, a float32 tensor with the same
    shape as ``batch["labels"]``. It holds 1.0 on positions treated as prompt and
    0.0 elsewhere. Apart from padding-free mode, this class does not rewrite
    ``labels`` itself; a trainer is expected to turn ``prompt_mask`` into loss
    weights. The search logic appears adapted from trl's
    ``DataCollatorForCompletionOnlyLM``: it sets ``prompt_mask = 1`` where that
    collator would set labels to ``ignore_index``.

    Args:
        response_template: string (tokenized without special tokens) or token ids
            marking the start of a response.
        instruction_template: optional string or token ids marking the start of a
            human turn; when given, the multi-turn branch of ``torch_call`` is used.
        mlm: forwarded to ``DataCollatorForLanguageModeling``.
        ignore_index: label value written at packed-sequence starts in
            padding-free mode.
        padding_free: if True, flatten the batch into one unpadded row and add
            ``position_ids`` in place of ``attention_mask``.
    """

    def __init__(
        self,
        response_template: Union[str, List[int]],
        instruction_template: Optional[Union[str, List[int]]] = None,
        *args,
        mlm: bool = False,
        ignore_index: int = -100,
        padding_free: bool = False,
        **kwargs,
    ):
        super().__init__(*args, mlm=mlm, **kwargs)

        self.instruction_template = instruction_template
        if isinstance(instruction_template, str):
            # The user provides a string, must tokenize
            self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.instruction_token_ids = instruction_template

        self.response_template = response_template
        if isinstance(response_template, str):
            # The user provides a string, must tokenize
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        else:
            # The user already provides the token ids
            self.response_token_ids = response_template

        if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
            warnings.warn(
                "The pad_token_id and eos_token_id values of this tokenizer are identical. "
                "If you are planning for multi-turn training, "
                "it can result in the model continuously generating questions and answers without eos token. "
                "To avoid this, set the pad_token_id to a different value."
            )

        self.ignore_index = ignore_index
        self.padding_free = padding_free

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Collate ``examples`` and add ``batch["prompt_mask"]`` (1.0 = prompt position)."""
        batch = super().torch_call(examples)
        
        # Initialize "prompt_mask" to all zeros: nothing is prompt until a template is found
        batch_size = batch["labels"].shape[0]
        seq_len = batch["labels"].shape[1]
        batch["prompt_mask"] = torch.zeros((batch_size, seq_len), dtype=torch.float32)


        if self.instruction_template is None:
            # Single-turn: everything up to the end of the response template is prompt.
            for i in range(len(examples)):
                response_token_ids_start_idx = None

                for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # Candidates are positions holding the first template id; check that the
                    # whole template matches there. There is no `break`, so the LAST match wins.
                    if (
                        self.response_token_ids
                        == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist()
                    ):
                        response_token_ids_start_idx = idx

                if response_token_ids_start_idx is None:
                    # NOTE(review): despite the message, this row is not ignored here. It is
                    # marked as prompt throughout, so it still contributes when the trainer's plw > 0.
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["prompt_mask"][i, :] = 1
                else:
                    response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)

                    # Mark all tokens up through the end of the response key as prompt
                    batch["prompt_mask"][i, :response_token_ids_end_idx] = 1

        else:
            # Multi-turn: alternating human (instruction) and assistant (response) segments.
            for i in range(len(examples)):
                response_token_ids_idxs = []
                human_token_ids_idxs = []

                for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]:
                    # find the indexes of the start of a response.
                    # (the stored index is the position just after the response template)
                    if (
                        self.response_token_ids
                        == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist()
                    ):
                        response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))

                if len(response_token_ids_idxs) == 0:
                    # NOTE(review): as above, the row is marked as prompt, not ignored.
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["prompt_mask"][i, :] = 1

                human_token_ids = self.instruction_token_ids
                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
                    # find the indexes of the start of a human answer.
                    if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist():
                        human_token_ids_idxs.append(human_idx)

                if len(human_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find instruction key `{self.instruction_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["prompt_mask"][i, :] = 1

                # If a response appears before the first human turn, start the first
                # prompt segment at position 0.
                if (
                    len(human_token_ids_idxs) > 0
                    and len(response_token_ids_idxs) > 0
                    and human_token_ids_idxs[0] > response_token_ids_idxs[0]
                ):
                    human_token_ids_idxs = [0] + human_token_ids_idxs

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Mark each non-response span [start, end) as prompt; the first span starts at 0
                    if idx != 0:
                        batch["prompt_mask"][i, start:end] = 1
                    else:
                        batch["prompt_mask"][i, :end] = 1

                # More human turns than responses: everything from the last human turn on is prompt.
                if len(response_token_ids_idxs) < len(human_token_ids_idxs):
                    batch["prompt_mask"][i, human_token_ids_idxs[-1] :] = 1

        if self.padding_free:
            # remove padding, `attention_mask` and add `position_ids`
            attn_mask = batch.pop("attention_mask")
            batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0)
            batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1
            batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
            batch["prompt_mask"] = batch["prompt_mask"][attn_mask.bool()].unsqueeze(0)
            # The first token of each packed sequence gets an ignored label and is marked
            # as prompt (presumably so it is never predicted from the previous sequence).
            batch["labels"][batch["position_ids"] == 0] = self.ignore_index
            batch["prompt_mask"][batch["position_ids"] == 0] = 1
            
        return batch

In [None]:
class SFTTrainerWithPLWCollator(SFTTrainer):
    """``SFTTrainer`` that weights prompt tokens using a collator-provided ``prompt_mask``.

    Each batch must contain ``prompt_mask`` (1 on prompt positions, 0 elsewhere),
    e.g. as produced by ``DataCollatorForPLW``. Prompt positions are weighted by
    ``plw``, all other positions by 1.0, and positions whose label is -100
    (ignored, e.g. padding) by 0.

    Args:
        plw: weight for prompt tokens (1.0 = full-sequence loss, 0.0 = completion only).
    """

    def __init__(self, *args, plw=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.plw = plw

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the PLW-weighted mean next-token cross-entropy.

        Args:
            model: the causal LM being trained.
            inputs: batch with ``input_ids``, ``attention_mask``, ``labels`` and
                ``prompt_mask``; ``labels`` is popped from it.
            return_outputs: if True, also return the model outputs.
            **kwargs: accepted and ignored, for trainer versions that pass extra
                arguments (e.g. ``num_items_in_batch``).

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            KeyError: if the batch has no ``prompt_mask``.
        """
        if "prompt_mask" not in inputs:
            # Fail before the forward pass, with a hint about the likely cause.
            raise KeyError(
                "prompt_mask: SFTTrainerWithPLWCollator needs a data collator that adds "
                "'prompt_mask' to each batch (e.g. DataCollatorForPLW)."
            )
        ignore_index = -100  # CrossEntropyLoss default ignore index
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        labels = inputs.pop("labels")

        prompt_mask = inputs["prompt_mask"]
        weights = torch.where(
            prompt_mask == 1,
            torch.tensor(float(self.plw), device=prompt_mask.device),
            torch.tensor(1.0, device=prompt_mask.device),
        )

        # Logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        shift_weights = weights[..., 1:].contiguous().to(shift_logits.device)
        # Ignored positions produce zero loss; drop them from the normaliser too.
        shift_weights = shift_weights * (shift_labels != ignore_index)

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1))

        # Weighted average of the per-token losses.
        flat_weights = shift_weights.view(-1).float()
        loss = (token_losses.float() @ flat_weights) / flat_weights.sum()
        return (loss, outputs) if return_outputs else loss

In [None]:
response_template = " ### Answer:"

In [None]:
# Drop the first id of the encoding (presumably the BOS token added by `encode`).
# NOTE(review): this equals encode(..., add_special_tokens=False) only if exactly one
# special token is prepended; the earlier SFTTrainerWithPLW cell used [2:] instead.
response_template_token_ids = tokenizer.encode(response_template)[1:]
response_template_token_ids

In [None]:
plw_collator = DataCollatorForPLW(response_template_token_ids, tokenizer=tokenizer)

In [None]:
# def preprocess_logits_for_metrics(logits, labels):
#     # get predictions
#     token_preds = logits.argmax(-1)[..., :-1]

#     # compute per-token losses
#     loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
#     shift_logits = logits[..., :-1, :].contiguous()
#     shift_labels = labels[..., 1:].contiguous()
#     token_losses = loss_fct(shift_logits.transpose(1, 2), shift_labels)

#     # pass predictions and losses to compute_metrics function (above)
#     predictions = (token_preds, token_losses)
#     return predictions

In [None]:
# def prepare_compute_metrics(response_template_token_ids):
    
#     def compute_metrics(data):
#         token_preds, token_losses = data.predictions
        
#         label_ids = data.label_ids
#         batch_size = label_ids.shape[0]
#         labels_length = label_ids.shape[1]
#         response_tokens_length = len(response_template_token_ids)
        
#         prompt_mask = torch.zeros(batch_size, labels_length)
#         completion_mask = torch.zeros(batch_size, labels_length)
        
#         for batch_idx in range(batch_size):
#             completion_start_idx = None

#             # Search for response_tokens in labels for the current batch element
#             for i in range(labels_length - response_tokens_length + 1):
#                 if label_ids[batch_idx, i:i + response_tokens_length].tolist() == response_template_token_ids:
#                     completion_start_idx = i + response_tokens_length
#                     break

#             if completion_start_idx is not None:
#                 completion_mask[batch_idx, completion_start_idx:] = 1
#                 prompt_mask[batch_idx, :completion_start_idx] = 1
        
#         # shift labels and masks
#         labels = label_ids[..., 1:]
#         shift_prompt_mask = prompt_mask[..., 1:]
#         shift_comp_mask = completion_mask[..., 1:]
        
#         # average both losses (prompt and completion) over their respective tokens
#         token_losses = torch.tensor(token_losses)
#         prompt_loss = token_losses.reshape(-1) @ shift_prompt_mask.reshape(-1) / shift_prompt_mask.sum()
#         completion_loss = token_losses.reshape(-1) @ shift_comp_mask.reshape(-1) / shift_comp_mask.sum()
        
#         return {
#             'comp_loss': completion_loss,
#             'prompt_loss': prompt_loss,
#         }
    
#     return compute_metrics

In [None]:
# plw=0: positions that DataCollatorForPLW marks as prompt get zero weight, so this
# should match the completion-only loss from the first section.
trainer4 = SFTTrainerWithPLWCollator(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
    data_collator=plw_collator,
#     compute_metrics=prepare_compute_metrics(response_template_token_ids),
#     preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    plw=0
)

In [None]:
masked = trainer4.data_collator.torch_call([trainer.train_dataset[0]])

In [None]:
masked = masked.to(model.device)

In [None]:
loss, outputs = trainer4.compute_loss(model, masked, return_outputs=True)

In [None]:
loss

### Another PLW Implementation

The idea here is to add `prompt_mask` and `completion_mask` as columns to the dataset. This requires setting `remove_unused_columns` to `False` in the training arguments (`SFTConfig` / `TrainingArguments`) so that these columns are not dropped during training.

The next step will be to use these columns to define custom metrics that calculate `prompt_loss` and `completion_loss` separately.

In [None]:
# Raw columns before reformatting into prompt/completion/text.
dataset

In [None]:
def format_instruction(example):
    r"""Split one row into ``prompt``, ``completion`` and full training ``text``.

    The prompt is "### Question: <instruction>\n", followed by
    "### Input: <input>\n" when ``input`` is truthy, and always ends with
    "### Answer:". ``completion`` is the row's ``output``; ``text`` is the
    prompt, a single space, then the completion.
    """
    sections = [f"### Question: {example['instruction']}\n"]
    if example["input"]:
        sections.append(f"### Input: {example['input']}\n")
    sections.append("### Answer:")
    prompt = "".join(sections)

    completion = example['output']
    return {
        'prompt': prompt,
        'completion': completion,
        'text': prompt + " " + completion,
    }

In [None]:
# Add `prompt`, `completion` and `text` columns built by format_instruction.
ds = dataset.map(format_instruction)

In [None]:
ds

In [None]:
ds[0]

In [None]:
# The raw columns are no longer needed once prompt/completion/text exist.
ds = ds.remove_columns(['instruction', 'input', 'output'])

In [None]:
ds

In [None]:
ds[0]

In [None]:
# NOTE(review): unused; reassigned to 512 further down before anything reads it.
max_seq_length = 1024

In [None]:
def tokenize_data(examples):
    """Tokenize the ``prompt``, ``completion`` and ``text`` columns of a batch.

    No padding or truncation is applied, so every sequence keeps its natural
    length. Only ``input_ids`` are kept, returned under ``tokenized_prompt``,
    ``tokenized_completion`` and ``tokenized_text``.
    """
    def _ids(column):
        # Full-length token ids for one text column of the batch.
        return tokenizer(examples[column], padding=False, truncation=False)['input_ids']

    return {
        'tokenized_prompt': _ids('prompt'),
        'tokenized_completion': _ids('completion'),
        'tokenized_text': _ids('text'),
    }


In [None]:
# Apply the tokenization to the dataset
# (batched=True: tokenize_data receives lists of values for many rows at once)
tokenized_dataset = ds.map(tokenize_data, batched=True)
tokenized_dataset

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Token count of every formatted example (`text` column) in tokenized_dataset.
tokenized_text_lengths = np.array([len(tokens) for tokens in tokenized_dataset['tokenized_text']])

# Summary statistics shown in the annotation box.
min_length = tokenized_text_lengths.min()
max_length = tokenized_text_lengths.max()
q1, median, q3 = np.percentile(tokenized_text_lengths, [25, 50, 75])
mean_length = tokenized_text_lengths.mean()

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(tokenized_text_lengths, bins=30, color='skyblue', edgecolor='black')
ax.set(title='Histogram of Tokenized Text Lengths',
       xlabel='Length of Tokenized Text',
       ylabel='Frequency')
ax.grid(True)

# Text box with the statistics, anchored to the top-right corner of the axes
# (axes coordinates, so it stays in place regardless of the data range).
ax.text(0.95, 0.95,
        f'Min: {min_length}\nQ1: {q1}\nMedian: {median}\nQ3: {q3}\nMean: {mean_length:.2f}\nMax: {max_length}',
        transform=ax.transAxes,
        verticalalignment='top', horizontalalignment='right',
        bbox=dict(facecolor='white', alpha=0.5))

plt.show()


In [None]:
# Maximum tokens per example (presumably chosen from the length histogram above).
max_seq_length = 512

In [None]:
# Filter out sequences longer than max_seq_length
# (dropped rather than truncated, so every kept example is complete)
filtered_dataset = tokenized_dataset.filter(lambda example: len(example['tokenized_text']) <= max_seq_length)
filtered_dataset

In [None]:
# Padding function to ensure all sequences are of max_seq_length
def pad_data(examples):
    """Pad each tokenized column of the batch to exactly ``max_seq_length`` ids.

    Uses ``tokenizer.pad`` with ``padding='max_length'``. The padding side and
    pad id come from the global ``tokenizer``. Returns the three padded
    ``tokenized_*`` columns under their original names.
    """
    padded = {}
    for column in ('tokenized_prompt', 'tokenized_completion', 'tokenized_text'):
        padded[column] = tokenizer.pad(
            {'input_ids': examples[column]},
            padding='max_length',
            max_length=max_seq_length,
        )['input_ids']
    return padded

In [None]:
# Every tokenized_* column now has length max_seq_length.
padded_dataset = filtered_dataset.map(pad_data, batched=True)
padded_dataset

In [None]:
# Add prompt_mask and completion_mask
def add_masks(examples, pad_token_id=None):
    """Build per-token prompt and completion masks aligned with ``tokenized_text``.

    Args:
        examples: batched mapping with padded ``tokenized_prompt`` and
            ``tokenized_text`` id lists of equal length (as produced by ``pad_data``).
        pad_token_id: id treated as "no token". Defaults to the global
            ``tokenizer``'s pad id, or its eos id when it has no pad token.

    Returns:
        dict with two keys:
        ``prompt_mask``: 1 where ``tokenized_prompt`` holds a non-pad token.
        ``completion_mask``: 1 where ``tokenized_text`` holds a non-pad token
        that the prompt mask does not cover, i.e. the answer part of the text.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    prompt_mask = [[1 if token != pad_token_id else 0 for token in example] for example in examples['tokenized_prompt']]

    # Derive the completion from the full text rather than from the separately
    # tokenized completion. That tokenization starts at position 0 with its own
    # BOS token, so it is not aligned with positions in `tokenized_text`.
    completion_mask = [
        [int(token != pad_token_id and not in_prompt) for token, in_prompt in zip(text_ids, p_mask)]
        for text_ids, p_mask in zip(examples['tokenized_text'], prompt_mask)
    ]

    return {
        'prompt_mask': prompt_mask,
        'completion_mask': completion_mask
    }

In [None]:
# Adds `prompt_mask` and `completion_mask` columns next to the padded token ids.
masked_dataset = padded_dataset.map(add_masks, batched=True)
masked_dataset

#### Implement Trainer class

In [None]:
class PLWTrainer(SFTTrainer):
    """``SFTTrainer`` that weights tokens using per-example mask columns.

    Each batch must contain ``prompt_mask`` and ``completion_mask``. The
    per-token weight is ``plw * prompt_mask + completion_mask``, additionally
    zeroed wherever the label is -100 (ignored).
    NOTE(review): this needs ``remove_unused_columns=False`` and a collator
    that keeps the mask columns aligned with ``input_ids``; confirm.

    Args:
        plw: weight for prompt tokens (1.0 = full-sequence loss, 0.0 = completion only).
    """

    def __init__(self, *args, plw=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.plw = plw

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the PLW-weighted mean next-token cross-entropy.

        Args:
            model: the causal LM being trained.
            inputs: batch with ``input_ids``, ``attention_mask``, ``labels``,
                ``prompt_mask`` and ``completion_mask``; ``labels`` is popped.
            return_outputs: if True, also return the model outputs.
            **kwargs: accepted and ignored, for trainer versions that pass extra
                arguments (e.g. ``num_items_in_batch``).

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        ignore_index = -100  # CrossEntropyLoss default ignore index
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        labels = inputs.pop("labels")

        # `self.plw` is the weight stored by __init__.
        weights = self.plw * inputs["prompt_mask"].float() + inputs["completion_mask"].float()

        # Logits at position t predict the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        shift_weights = weights[..., 1:].contiguous().to(shift_logits.device)
        # Ignored positions produce zero loss; drop them from the normaliser too.
        shift_weights = shift_weights * (shift_labels != ignore_index)

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=ignore_index)
        token_losses = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1))

        # Weighted average of the per-token losses.
        flat_weights = shift_weights.view(-1).float()
        loss = (token_losses.float() @ flat_weights) / flat_weights.sum()
        return (loss, outputs) if return_outputs else loss

In [None]:
def formatting_func(example):
    """Rebuild the full training texts from batched ``prompt``/``completion`` columns.

    Prompt and completion are joined with a single space, the same separator
    ``format_instruction`` uses for the ``text`` column. This keeps the trained
    token sequence identical to ``tokenized_text``, which the prompt/completion
    masks were computed from.
    """
    return [
        f"{prompt} {completion}"
        for prompt, completion in zip(example['prompt'], example['completion'])
    ]

In [None]:
masked_dataset[0]['prompt_mask']

In [None]:
masked_dataset

In [None]:
# Keep only the string columns (prompt, completion, text) for the SFTTrainer.
# NOTE(review): this also drops prompt_mask/completion_mask, which PLWTrainer reads
# from its inputs; confirm this is intended.
subdata = masked_dataset.remove_columns(['tokenized_prompt', 'tokenized_completion', 'tokenized_text', 'prompt_mask', 'completion_mask'])

In [None]:
subdata[0]['text']

In [None]:
from transformers import TrainingArguments

In [None]:
# NOTE(review): work in progress. SFTTrainerWithPLWCollator.compute_loss reads
# inputs["prompt_mask"], but no data_collator is passed here and `subdata` has had its
# mask columns removed, so training will presumably fail on the missing "prompt_mask".
# Either pass a collator that emits `prompt_mask` (e.g. DataCollatorForPLW), or use
# PLWTrainer with a dataset that keeps prompt_mask/completion_mask.
# plw is not passed, so the class default (plw=1.0) applies.
trainer5 = SFTTrainerWithPLWCollator(
    model,
    train_dataset=subdata,
    formatting_func=formatting_func,
    args=TrainingArguments(
        remove_unused_columns=False,
        output_dir="output"
    )
)