# Building a Custom Trainer Based on Hugging Face

## Main Task
In this exercise, you need to implement a trainer based on the Hugging Face `Trainer` class.
You need to extend the existing `Trainer` class so that it uses a different loss function.
The documentation of the `Trainer` class is available at: https://huggingface.co/docs/transformers/main/en/trainer
Note that you need 4-8 GB of CPU RAM.
We do not expect you to run the full training, only to implement the components necessary for training.

## LLM
For this exercise, you need to use the Qwen/Qwen1.5-0.5B-Chat model.
Note: THIS IS A CHAT MODEL. Please read carefully how to use this model at: https://huggingface.co/Qwen/Qwen1.5-0.5B-Chat

## Dataset

You are given a training dataset in which each example is a dictionary with two keys:
* The first key is 'text', which contains a multi-turn conversation in natural language that is used as input to the LLM. The conversation is a `list[dict]`: each dictionary has two keys, 'role', which can be 'system', 'user', or 'assistant', and 'content', which holds the text of the message. 'system' corresponds to the system prompt of the LLM, 'user' to the text given to the LLM as input, and 'assistant' to the LLM's response.
```
chat = [
    {"role": "system", "content": "You a helpful assisant"},
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "Fine. How can I help you today?"},
    {"role": "user", "content": "What is the circumference of earth?"},
    {"role": "assistant", "content": "It is 40,075 kms."},]
```

* The second key is 'reward', a scalar between -1 and 1 that indicates how good the example is (higher is better).
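Before tokenizing, it can help to check that every record actually follows this schema. The following is a minimal plain-Python sketch; `validate_example` is a hypothetical helper, not part of the exercise or of any library, and it encodes only the constraints stated above (the two keys, the three roles, and a reward in [-1, 1]), plus a check that at least one assistant turn exists, since otherwise every token is masked out and the example contributes nothing to the loss.

```python
# Hypothetical helper (not part of the exercise): validates one dataset record.
VALID_ROLES = {"system", "user", "assistant"}


def validate_example(example: dict) -> None:
    """Raise ValueError if `example` does not match the expected schema."""
    if set(example) != {"text", "reward"}:
        raise ValueError(f"expected keys 'text' and 'reward', got {sorted(example)}")

    chat = example["text"]
    if not isinstance(chat, list) or not chat:
        raise ValueError("'text' must be a non-empty list of messages")
    for i, message in enumerate(chat):
        if set(message) != {"role", "content"}:
            raise ValueError(f"message {i} must have exactly 'role' and 'content'")
        if message["role"] not in VALID_ROLES:
            raise ValueError(f"message {i} has unknown role {message['role']!r}")
    if not any(m["role"] == "assistant" for m in chat):
        # Without an assistant turn every token is masked out and the loss is 0.
        raise ValueError("chat contains no assistant message")

    reward = example["reward"]
    if not isinstance(reward, (int, float)) or not -1 <= reward <= 1:
        raise ValueError(f"reward must be a number in [-1, 1], got {reward!r}")
```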

## Optimization Objective

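You need to implement the undiscounted REINFORCE algorithm. It is an extension of SFT that weights the log-likelihood of each trajectory by the reward assigned to it.
The loss function is
$$ L = -\frac{1}{B} \sum_{b=1}^{B} \sum_{t \in \mathrm{seq}[b]} m[b][t] \cdot \mathrm{reward}[b] \cdot \log p\left(x[b][t] \mid x[b][:t]\right), \qquad m[b][t] = \begin{cases} 1 & \text{if } x[b][t] \text{ is one of the assistant's tokens} \\ 0 & \text{otherwise} \end{cases} $$
where $B$ is the batch size and $\mathrm{reward}[b]$ is the reward of example $b$.

Tokens that do not belong to the assistant contribute 0 to the loss.
In practice, this means that for the chat example above, the loss is non-zero only for the tokens of the following messages: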
```
{"role": "assistant", "content": "Fine. How can I help you today?"},
and
{"role": "assistant", "content": "It is 40,075 kms."}
```
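As a sanity check of the formula, the plain-Python sketch below evaluates $L$ on a toy batch of two four-token sequences in which only the last two tokens are assistant tokens. The log-probabilities are made-up numbers rather than model outputs, and `reinforce_loss` is a hypothetical helper name. With every reward set to 1, the value reduces to the batch-averaged SFT negative log-likelihood over assistant tokens; giving the second sequence a reward of -1 flips the sign of its contribution, which can make the total loss negative.

```python
import math


def reinforce_loss(log_probs, assistant_mask, rewards):
    """Undiscounted REINFORCE loss:
    L = -(1/B) * sum_b sum_t mask[b][t] * reward[b] * log p(x[b][t] | x[b][:t]).

    log_probs[b][t]      -- log-probability the model assigns to token t of sequence b
    assistant_mask[b][t] -- 1 if token t belongs to an assistant message, else 0
    rewards[b]           -- scalar reward of sequence b
    """
    batch_size = len(log_probs)
    total = 0.0
    for lp_seq, mask_seq, reward in zip(log_probs, assistant_mask, rewards):
        total += sum(m * reward * lp for lp, m in zip(lp_seq, mask_seq))
    return -total / batch_size


# Toy batch (made-up numbers): two sequences of four tokens each.
log_probs = [
    [math.log(0.9), math.log(0.8), math.log(0.5), math.log(0.25)],
    [math.log(0.9), math.log(0.8), math.log(0.1), math.log(0.1)],
]
# Only the last two tokens of each sequence are assistant tokens.
mask = [[0, 0, 1, 1], [0, 0, 1, 1]]

print(reinforce_loss(log_probs, mask, rewards=[1, 1]))   # SFT-like: both sequences reinforced
print(reinforce_loss(log_probs, mask, rewards=[1, -1]))  # second sequence is penalized
```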



```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch

model_id = "Qwen/Qwen1.5-0.5B-Chat"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    torch_dtype=dtype,
)
```


You are given the following data, which has to be loaded so that it can be used by the Hugging Face `transformers` `Trainer`. We recommend using the `datasets` library:
https://huggingface.co/docs/datasets/en/index

```python
data = [
    {
        "text": [
            {"role": "system", "content": "You a helpful assisant"},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "Fine. How can I help you today?"},
            {"role": "user", "content": "What is the circumference of earth?"},
            {"role": "assistant", "content": "It is 40,075 kms."},
        ],
        "reward": 1,
    },
    {
        "text": [
            {"role": "system", "content": "You a helpful assisant"},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "Fine. How can I help you today?"},
            {"role": "user", "content": "What is the shape of earth?"},
            {"role": "assistant", "content": "Earth is a square"},
        ],
        "reward": -1,
    },
]
print(data)
```

```
[{'text': [{'role': 'system', 'content': 'You a helpful assisant'}, {'role': 'user', 'content': 'Hello, how are you?'}, {'role': 'assistant', 'content': 'Fine. How can I help you today?'}, {'role': 'user', 'content': 'What is the circumference of earth?'}, {'role': 'assistant', 'content': 'It is 40,075 kms.'}], 'reward': 1}, {'text': [{'role': 'system', 'content': 'You a helpful assisant'}, {'role': 'user', 'content': 'Hello, how are you?'}, {'role': 'assistant', 'content': 'Fine. How can I help you today?'}, {'role': 'user', 'content': 'What is the shape of earth?'}, {'role': 'assistant', 'content': 'Earth is a square'}], 'reward': -1}]
```


Please note that we do not expect you to run the full training; we will only check whether the code is generally correct. You are free to use any library, such as transformers, trl, etc., for your implementation.

```python
tokenizer.special_tokens_map
```

```
{'eos_token': '<|im_end|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}
```

```python
tokenizer.apply_chat_template(
    data[0]["text"],
    tokenize=True,
    add_generation_prompt=False,
    return_dict=True,
    return_assistant_tokens_mask=True,
    return_tensors="pt",
)
```

```
{'input_ids': tensor([[151644,   8948,    198,   2610,    264,  10950,   1071,    285,    517,
         151645,    198, 151644,    872,    198,   9707,     11,   1246,    525,
            498,     30, 151645,    198, 151644,  77091,    198,  63716,     13,
           2585,    646,    358,   1492,    498,   3351,     30, 151645,    198,
         151644,    872,    198,   3838,    374,    279,  74926,    315,   9393,
             30, 151645,    198, 151644,  77091,    198,   2132,    374,    220,
             19,     15,     11,     15,     22,     20,  96677,     13, 151645,
            198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'assistant_masks': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
```

The visible part of `assistant_masks` is 0 even at positions 25 to 33, which hold the tokens of the first assistant reply. This is expected: `return_assistant_tokens_mask` only works for chat templates that mark assistant turns with the `{% generation %}` keyword, which the Qwen1.5 template does not appear to do. The assistant mask therefore has to be built manually.

```python
tokenizer.apply_chat_template(data[0]["text"], tokenize=False, add_generation_prompt=False, return_tensors="pt")
```

```
'<|im_start|>system\nYou a helpful assisant<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nFine. How can I help you today?<|im_end|>\n<|im_start|>user\nWhat is the circumference of earth?<|im_end|>\n<|im_start|>assistant\nIt is 40,075 kms.<|im_end|>\n'
```
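The string above follows the ChatML layout: each message is rendered as `<|im_start|>{role}\n{content}<|im_end|>\n`. The plain-Python sketch below reproduces that layout for illustration only; `render_chatml` is a hypothetical helper, and `tokenizer.apply_chat_template` remains the source of truth. It assumes that, when a chat has no system message, the Qwen1.5 template prepends the default system prompt `You are a helpful assistant.`, which is what the token ids of the single-message test further down suggest; the real Jinja template may differ in other details.

```python
# Assumption: default system prompt inferred from the token ids shown in this notebook.
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."


def render_chatml(chat, add_generation_prompt=False):
    """Hypothetical plain-Python rendering of a chat in the ChatML layout used by Qwen1.5."""
    parts = []
    if not chat or chat[0]["role"] != "system":
        parts.append(f"<|im_start|>system\n{DEFAULT_SYSTEM_PROMPT}<|im_end|>\n")
    for message in chat:
        parts.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n")
    if add_generation_prompt:
        # Open an assistant turn so that the model continues with its reply.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)
```

Because every assistant turn starts with the same `<|im_start|>assistant\n` header and ends with `<|im_end|>`, assistant tokens can be located by searching for these delimiters in the token ids.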

```python
tokenizer.encode("<|im_start|>assistant\n<|im_end|>\n")
```

```
[151644, 77091, 198, 151645, 198]
```

```python
tokenizer.encode("<|im_start|>assistant\nA<|im_end|>\n")
```

```
[151644, 77091, 198, 32, 151645, 198]
```

```python
tokenizer.encode("<|im_start|>assistant\nA<|im_end|>")
```

```
[151644, 77091, 198, 32, 151645]
```

```python
tokenizer.apply_chat_template(
    data[0]["text"],
    tokenize=True,
    add_generation_prompt=False,
    return_dict=True,
    return_tensors="pt",
)
```

```
{'input_ids': tensor([[151644,   8948,    198,   2610,    264,  10950,   1071,    285,    517,
         151645,    198, 151644,    872,    198,   9707,     11,   1246,    525,
            498,     30, 151645,    198, 151644,  77091,    198,  63716,     13,
           2585,    646,    358,   1492,    498,   3351,     30, 151645,    198,
         151644,    872,    198,   3838,    374,    279,  74926,    315,   9393,
             30, 151645,    198, 151644,  77091,    198,   2132,    374,    220,
             19,     15,     11,     15,     22,     20,  96677,     13, 151645,
            198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
```

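```python
import torch

# Token patterns that delimit an assistant turn in the ChatML format used by Qwen1.5:
# "<|im_start|>assistant\n" -> [151644, 77091, 198] and "<|im_end|>" -> [151645].
assistant_start = tokenizer.encode("<|im_start|>assistant\n", return_tensors="pt")[0]
assistant_end = tokenizer.encode("<|im_end|>", return_tensors="pt")[0]


def tokenize_dataset_entry(entry):
    """Tokenize one chat and build the assistant mask, labels, and reward tensors."""
    out = tokenizer.apply_chat_template(
        entry["text"],
        tokenize=True,
        add_generation_prompt=False,
        return_dict=True,
        return_tensors="pt",
    )

    # Drop the batch dimension ([1, seq_len] -> [seq_len]). Indexing is safer than
    # squeeze(), which would also collapse a length-1 sequence into a scalar.
    out["input_ids"] = out["input_ids"][0]
    out["attention_mask"] = out["attention_mask"][0]

    out["reward"] = torch.tensor([entry["reward"]], dtype=torch.float32)

    # Find every position where the assistant header / end token starts by sliding a
    # window of the pattern length over the ids and comparing element-wise.
    assistant_mask = torch.zeros_like(out["input_ids"], dtype=torch.bool)
    start_indices = (
        (out["input_ids"].unfold(0, len(assistant_start), 1) == assistant_start).all(dim=1).nonzero(as_tuple=True)[0]
    )
    end_indices = (
        (out["input_ids"].unfold(0, len(assistant_end), 1) == assistant_end).all(dim=1).nonzero(as_tuple=True)[0]
    )
    for start_idx in start_indices:
        # The first <|im_end|> after the header closes this assistant turn.
        end_idx = end_indices[end_indices > start_idx][0]
        # Mark only the reply content: the header and <|im_end|> itself are excluded.
        assistant_mask[start_idx + len(assistant_start) : end_idx] = True
    out["assistant_mask"] = assistant_mask

    # Non-assistant tokens get the label -100 so that the loss ignores them.
    labels = out["input_ids"].clone()
    labels[~assistant_mask] = -100
    out["labels"] = labels

    return out
```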

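The `unfold`-based search above is compact but hard to read. The following plain-Python equivalent makes the same logic explicit; `find_subsequence` and `assistant_token_mask` are hypothetical helper names, and the token ids for the assistant header (`[151644, 77091, 198]`) and for `<|im_end|>` (`151645`) are taken from the `tokenizer.encode` calls above.

```python
def find_subsequence(ids, pattern):
    """Return every index i such that ids[i : i + len(pattern)] == pattern."""
    n = len(pattern)
    return [i for i in range(len(ids) - n + 1) if ids[i : i + n] == pattern]


def assistant_token_mask(ids, start_pattern, end_id):
    """Mark the tokens between each assistant header and the next end token."""
    mask = [False] * len(ids)
    for start in find_subsequence(ids, start_pattern):
        content_start = start + len(start_pattern)
        # First end token after the header; list.index raises ValueError if the turn is truncated.
        end = ids.index(end_id, content_start)
        for t in range(content_start, end):
            mask[t] = True
    return mask
```

On the 64-token conversation tokenized earlier, this marks positions 25 to 33 and 51 to 61, i.e. the two assistant replies without their closing `<|im_end|>` token.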

```python
tokenize_dataset_entry({"text": [{"role": "assistant", "content": "A"}], "reward": 42})
```

```
{'input_ids': tensor([151644,   8948,    198,   2610,    525,    264,  10950,  17847,     13,
        151645,    198, 151644,  77091,    198,     32, 151645,    198]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 'reward': tensor([42.]), 'assistant_mask': tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False]), 'labels': tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100,   32, -100, -100])}
```

```python
from datasets import Dataset

dataset = Dataset.from_list(data)
dataset
```

```
Dataset({
    features: ['text', 'reward'],
    num_rows: 2
})
```

```python
from transformers.data.data_collator import DataCollatorForSeq2Seq


class ReinforceDataCollator(DataCollatorForSeq2Seq):
    """Tokenizes raw chats on the fly and pads them into a training batch."""

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        processed_entries = [tokenize_dataset_entry(feature) for feature in features]

        # Remove the fields that the parent collator does not know how to pad.
        for entry in processed_entries:
            entry.pop("assistant_mask")
        rewards = [entry.pop("reward") for entry in processed_entries]

        # Pads input_ids/attention_mask with the tokenizer and labels with
        # label_pad_token_id (-100 by default), on the side given by tokenizer.padding_side.
        batch = super().__call__(processed_entries, return_tensors=return_tensors)

        # Rebuild the assistant mask from the padded labels: assistant tokens are exactly
        # the positions whose label is not -100. This stays aligned with the labels
        # whether the tokenizer pads on the right or on the left.
        batch["assistant_mask"] = batch["labels"] != self.label_pad_token_id

        batch["reward"] = torch.stack(rewards, dim=0)  # shape [batch_size, 1]

        return batch


data_collator = ReinforceDataCollator(tokenizer=tokenizer, padding=True)

# Test the collator on a small batch
sample_batch = [dataset[0], dataset[1]]
collated_batch = data_collator(sample_batch)
print("\nCollated Batch Example:")
for key, value in collated_batch.items():
    print(f"{key}: shape={value.shape}, dtype={value.dtype}")
```

```
Collated Batch Example:
input_ids: shape=torch.Size([2, 64]), dtype=torch.int64
attention_mask: shape=torch.Size([2, 64]), dtype=torch.int64
labels: shape=torch.Size([2, 64]), dtype=torch.int64
assistant_mask: shape=torch.Size([2, 64]), dtype=torch.bool
reward: shape=torch.Size([2, 1]), dtype=torch.float32
```

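The collator has to keep `input_ids`, `attention_mask`, `labels`, and the assistant mask aligned when sequences of different lengths are padded. The plain-Python sketch below illustrates this with a hypothetical `pad_batch` helper: it pads on either side and shows that a mask recomputed from the padded labels (`label != -100`) stays aligned regardless of the padding side. The value -100 matches the default `label_pad_token_id` of `DataCollatorForSeq2Seq`; the pad id used in the usage example is arbitrary.

```python
IGNORE_INDEX = -100  # label value ignored by the loss (default label_pad_token_id)


def pad_batch(sequences, labels, pad_id, padding_side="right"):
    """Hypothetical helper: pad token ids and labels to a common length on one side."""
    max_len = max(len(seq) for seq in sequences)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for seq, lab in zip(sequences, labels):
        n_pad = max_len - len(seq)
        if padding_side == "right":
            batch["input_ids"].append(seq + [pad_id] * n_pad)
            batch["attention_mask"].append([1] * len(seq) + [0] * n_pad)
            batch["labels"].append(lab + [IGNORE_INDEX] * n_pad)
        else:
            batch["input_ids"].append([pad_id] * n_pad + seq)
            batch["attention_mask"].append([0] * n_pad + [1] * len(seq))
            batch["labels"].append([IGNORE_INDEX] * n_pad + lab)
    # Recomputing the mask from the padded labels keeps it aligned on either side.
    batch["assistant_mask"] = [[label != IGNORE_INDEX for label in row] for row in batch["labels"]]
    return batch


# Usage: the second sequence is shorter, so it gets two padding positions.
left = pad_batch([[1, 2, 3, 4], [5, 6]], [[-100, -100, 3, 4], [-100, 6]], pad_id=0, padding_side="left")
print(left["input_ids"], left["assistant_mask"])
```

A mask that is padded independently of the labels (for example, always copied into the first positions of a zero tensor) would silently shift relative to the tokens whenever the tokenizer pads on the left.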

```python
from transformers.trainer import Trainer


class ReinforceTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        rewards = inputs.pop("reward")  # [batch_size, 1]
        assistant_mask = inputs.pop("assistant_mask")  # [batch_size, seq_len], bool
        # Pop the labels so that the model does not also compute its own (unused) SFT loss.
        labels = inputs.pop("labels")

        outputs = model(**inputs)
        logits = outputs.get("logits")

        # Shift so that the logits at position t are scored against the token at t + 1.
        logits = logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        assistant_mask = assistant_mask[:, 1:].contiguous()

        batch_size, seq_len, vocab_size = logits.shape

        # CrossEntropyLoss returns the negative log-likelihood -log p(x_t | x_<t) per token
        # (and 0 wherever the label is -100).
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        nll_per_token = loss_fct(logits.view(batch_size * seq_len, vocab_size), labels.view(batch_size * seq_len))
        nll_per_token = nll_per_token.view(batch_size, seq_len)

        # Keep assistant tokens only, then weight each sequence by its reward
        # (rewards broadcast from [batch_size, 1] over the sequence dimension).
        masked_nll = nll_per_token * assistant_mask
        reward_weighted_nll = masked_nll * rewards

        # Sum over tokens and average over the batch. Since nll = -log p, this is
        # L = -(1/B) * sum_b sum_t reward[b] * log p(x[b][t] | x[b][:t]) over assistant tokens.
        per_sequence_loss = reward_weighted_nll.sum(dim=-1)
        loss = per_sequence_loss.mean()

        return (loss, outputs) if return_outputs else loss
```

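The shift in `compute_loss` pairs the logits at position t with the label at position t + 1, mirroring `logits[:, :-1]` and `labels[:, 1:]`. As a reference for checking the tensor code by hand, the plain-Python sketch below computes the same per-token cross-entropy for one sequence using a numerically stable log-softmax. The helper names are hypothetical and the logits are made-up numbers.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of a list of floats."""
    m = max(row)
    log_norm = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_norm for x in row]


def shifted_token_nll(logits, labels, ignore_index=-100):
    """Per-token negative log-likelihood for one sequence, causal-LM style.

    logits[t] scores the token at position t + 1, so the last logit row and the
    first label are dropped before pairing them up.
    """
    nll = []
    for row, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            nll.append(0.0)  # CrossEntropyLoss(reduction="none") also yields 0 here
        else:
            nll.append(-log_softmax(row)[target])
    return nll


# Vocabulary of 3 tokens, sequence of 3 positions; the first label is masked out.
logits = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [9.0, 9.0, 9.0]]
labels = [-100, 1, 0]
print(shifted_token_nll(logits, labels))
```

Multiplying these values by the sequence's reward and summing over positions gives that sequence's term in the loss before averaging over the batch.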
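```python
from transformers import TrainingArguments


training_args = TrainingArguments(
    output_dir="reinforce_trainer_output",  # placeholder path; older transformers versions require it
    num_train_epochs=5,
    logging_steps=1,
    # Keep the raw 'text' and 'reward' columns: the collator tokenizes them itself.
    remove_unused_columns=False,
)

trainer = ReinforceTrainer(model=model, args=training_args, train_dataset=dataset, data_collator=data_collator)

trainer.train()
```

| Step | Training Loss |
|------|---------------|
| 1 | 12.1517 |
| 2 | -1.9166 |
| 3 | -23.1591 |
| 4 | -32.6245 |
| 5 | -40.0766 |

```
TrainOutput(global_step=5, training_loss=-17.12502956390381, metrics={'train_runtime': 4.9174, 'train_samples_per_second': 2.034, 'train_steps_per_second': 1.017, 'total_flos': 1184276152320.0, 'train_loss': -17.12502956390381, 'epoch': 5.0})
```

A negative, steadily decreasing loss is expected with this objective. Both examples fit in one batch, so the loss equals half the difference between the summed assistant-token negative log-likelihood of the reward 1 example and that of the reward -1 example. Minimizing it drives the probability of the penalized reply towards 0, so its negative log-likelihood, and hence the magnitude of the loss, can grow without bound: the loss is unbounded below.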