ToDo:
- Validation loss is 'nan', ChatGPT: "Yes. For validation perplexity or cross‑entropy, you need reference answers. Keep the full user+assistant in the eval split and mask user tokens (label = ‑100). Then you won’t get NaN during validation."

In [1]:
import time, trl, torch, datasets, wandb, transformers as tr
from pprint import pprint
from typing import Optional

# LoRA:
# python sft.py \
#     --model_name_or_path Qwen/Qwen2-0.5B \
#     --dataset_name trl-lib/Capybara \
#     --learning_rate 2.0e-4 \
#     --num_train_epochs 1 \
#     --packing \
#     --per_device_train_batch_size 2 \
#     --gradient_accumulation_steps 8 \
#     --gradient_checkpointing \
#     --logging_steps 8 \
#     --eval_strategy steps \
#     --eval_steps 16 \
#     --use_peft \
#     --lora_r 32 \
#     --lora_alpha 16 \
#     --output_dir Qwen2-0.5B-SFT \

# Arguments:

script_args = trl.ScriptArguments(
    # dataset_name='trl-lib/Capybara')
    dataset_name='ZSvedic/gpt4o-arena-brevity-dpo')

model_args = trl.ModelConfig(
    # model_name_or_path='Qwen/Qwen2-0.5B',
    model_name_or_path='Qwen/Qwen2-0.5B-Instruct',
    use_peft=True,
    lora_r=32,
    lora_alpha=16)

training_args = trl.SFTConfig(
    learning_rate=2e-4,
    num_train_epochs=1,
    packing=False, # Changed from True, to make sense of debugging.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    logging_steps=5,
    report_to="wandb",
    eval_strategy="steps",
    eval_steps=5,
    output_dir="OUTPUT/Qwen2-0.5B-SFT",
    run_name=f"Qwen2-0.5B-SFT-{time.strftime('%Y-%m-%d-%H-%M')}",
    )

training_args.model_init_kwargs = dict(
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=trl.get_kbit_device_map(),
    quantization_config=trl.get_quantization_config(model_args),
)

# Tokenizer:
tokenizer = tr.AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# tokenizer.pad_token = tokenizer.eos_token # Is this needed?
print(f"pad_token: {tokenizer.pad_token}")
print(f"eos_token: {tokenizer.eos_token}")

pad_token: <|endoftext|>
eos_token: <|im_end|>


In [2]:
# Dataset loading.
dataset = datasets.load_dataset(script_args.dataset_name)

# Train messages.
dataset['train'] = dataset['train'].map(
    lambda row: {'messages': [
        {'role': 'user', 'content': row['prompt']},
        {'role': 'assistant', 'content': row['chosen']}]},
    remove_columns = ['prompt', 'chosen', 'rejected'])

# Test messages.
dataset['test'] = dataset['test'].map(
    lambda row: {'messages': [
        {'role': 'user', 'content': row['prompt']}]},
    remove_columns = ['prompt', 'chosen', 'rejected'])

# Debug print.
print(f"dataset: {dataset}")
print(f"TRAIN EXAMPLE: {dataset['train'][0]}")
print(f"TEST EXAMPLE: {dataset['test'][0]}")
print(f"chat_template: {tokenizer.chat_template}")

# Tokenize the dataset.
mapped_dataset = datasets.DatasetDict()

mapped_dataset['train'] = dataset['train'].map(
    lambda x: {
        "text": tokenizer.apply_chat_template(x["messages"], tokenize=False, 
                                              add_generation_prompt=False)
        },
    remove_columns=dataset['train'].column_names)

mapped_dataset['test'] = dataset['test'].map(
    lambda x: {
        "text": tokenizer.apply_chat_template(x["messages"], tokenize=False, 
                                              add_generation_prompt=True)
        },
    remove_columns=dataset['test'].column_names)

# Debug print.
print(f"mapped_dataset: {mapped_dataset}")
print(f"TRAIN EXAMPLE 'text': {mapped_dataset['train'][0]['text']}")
print(f"TEST EXAMPLE 'text': {mapped_dataset['test'][0]['text']}")

# Data collator makes sure we train on completions only.
collator = trl.DataCollatorForCompletionOnlyLM(
    instruction_template="<|im_start|>user\n", 
    response_template="<|im_start|>assistant\n", 
    tokenizer=tokenizer)

dataset: DatasetDict({
    train: Dataset({
        features: ['question-id', 'messages'],
        num_rows: 22941
    })
    test: Dataset({
        features: ['question-id', 'messages'],
        num_rows: 2549
    })
})
TRAIN EXAMPLE: {'question-id': '1dd6137eb3c3470989e18ab729ccc0b3', 'messages': [{'content': 'write short telugu poem', 'role': 'user'}, {'content': 'Telugu poem: ఆకాశం నీలమై పూగుతోంది, సూర్యుడు కొత్త కిరణాలు తెచ్చుకొన్నాడు.', 'role': 'assistant'}]}
TEST EXAMPLE: {'question-id': '7e730384bbb649af9f6e150dbf129b53', 'messages': [{'content': 'What is your purpose.', 'role': 'user'}]}
chat_template: {% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system
You are a helpful assistant.<|im_end|>
' }}{% endif %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}
mapped_dataset: DatasetDict({
    train: Dataset({


In [3]:
from trl import SFTTrainer

class LogCompletionsLengthCallback(tr.TrainerCallback):
    def __init__(self, trainer: tr.Trainer, num_prompts: Optional[int] = None, freq: Optional[int] = None):
        self.trainer = trainer
        self.freq = freq
        self._last_logged_step = -1
        self.eval_dataset = trainer.eval_dataset.select(range(num_prompts))

    def on_step_end(self, args, state, control, **kwargs):
        # Only log once per step (this method may be called multiple times)
        if state.global_step == self._last_logged_step:
            return

        # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
        freq = self.freq or state.eval_steps
        if state.global_step % freq != 0:
            return

        tokenizer = kwargs["processing_class"]
        tokenizer.padding_side = "left"
        accelerator = self.trainer.accelerator
        model = self.trainer.model_wrapped
        completion_lens = []
        with accelerator.split_between_processes(self.eval_dataset["input_ids"]) as prompts:
            with trl.models.utils.unwrap_model_for_generation(model, accelerator) as unwrapped_model:
                for prompt_ids in prompts:
                    prompt_ids = torch.tensor([prompt_ids], device=unwrapped_model.device)
                    generations = unwrapped_model.generate(
                        prompt_ids, generation_config=tr.GenerationConfig(max_new_tokens=150)
                    )
                    completion_lens.append(len(generations[0]) - len(prompt_ids[0]))

        # Build the data to log
        if self.trainer.accelerator.is_main_process:
            wandb.log({"completions_len": sum(completion_lens) / len(completion_lens)}, step=state.global_step)

        # Save the last logged step, so we don't log the same completions multiple times
        self._last_logged_step = state.global_step

# class InspectorSFTTrainer(SFTTrainer):
#     def training_step(self, model, inputs, num_items_in_batch):
#         input_ids = inputs['input_ids'][0]
#         att_mask = inputs['attention_mask'][0]
#         labels = inputs['labels'][0]
#         num_ones = att_mask.sum().item()
#         if num_ones < 100:
#             print(f"\n------------------------")
#             print(f"num_ones: {num_ones}")
#             decoded_tokens = self.processing_class.decode(input_ids, skip_special_tokens=False)
#             print(f"ALL TOKENS:\n{decoded_tokens}")
#             valid_token_ids = input_ids[att_mask == 1]
#             decoded_masked_tokens = self.processing_class.decode(valid_token_ids, skip_special_tokens=False)
#             print(f"MASK=1 TOKENS:\n{decoded_masked_tokens}")
#             print(f"LABELS: {labels}")
#             filtered_labels = labels[labels != -100]
#             print(f"LABELS decoded:\n{self.processing_class.decode(filtered_labels, skip_special_tokens=True)}")
#             print(f"\n------------------------")

#         return super().training_step(model, inputs, num_items_in_batch)

# # Initialize the custom trainer
# trainer = InspectorSFTTrainer(
trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    args=training_args,
    train_dataset=mapped_dataset[script_args.dataset_train_split],
    eval_dataset=mapped_dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
    data_collator=collator,
    processing_class=tokenizer,
    peft_config=trl.get_peft_config(model_args),
)

# Add any callbacks as needed
# completion_callback = trl.LogCompletionsCallback(trainer, num_prompts=16)
# trainer.add_callback(completion_callback)
len_callback = LogCompletionsLengthCallback(trainer, num_prompts=16)
trainer.add_callback(len_callback)

# Start training
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mzsvedic[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss,Validation Loss
5,3.1707,
10,2.6677,
15,2.7338,


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
` in the following instance: <|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user

You are a prodessional physic teacher who knows a lot about physic, describe each step in detail and gives the reequired formula or equation in forms of latex
It is proposed to transmit a signal over a distance of 4.5 × 103 km by means of an optic fibre. The input signal has a power of 9.8 mW. The minimum signal that can be detected at the output has a power of 6.3 × 10–17 W. For this signal power, the signal‐to‐noise ratio is 21 dB. Calculate(Use Wolfram alpha if needed): the power of the background noise.<|im_end|>
<|im_start|>assistant
. This instance will be ignored 

KeyboardInterrupt: 

In [None]:
# class LogCompletionsLengthCallback(tr.TrainerCallback):
#     def __init__(self, trainer: tr.Trainer, num_prompts: Optional[int] = None, freq: Optional[int] = None):
#         self.trainer = trainer
#         self.freq = freq
#         self._last_logged_step = -1
#         self.eval_dataset = trainer.eval_dataset.select(range(num_prompts))

#     def on_step_end(self, args, state, control, **kwargs):
#         # Only log once per step (this method may be called multiple times)
#         if state.global_step == self._last_logged_step:
#             return

#         # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps)
#         freq = self.freq or state.eval_steps
#         if state.global_step % freq != 0:
#             return

#         tokenizer = kwargs["processing_class"]
#         tokenizer.padding_side = "left"
#         accelerator = self.trainer.accelerator
#         model = self.trainer.model_wrapped
#         completion_lens = []
#         with accelerator.split_between_processes(self.eval_dataset["input_ids"]) as prompts:
#             with trl.models.utils.unwrap_model_for_generation(model, accelerator) as unwrapped_model:
#                 for prompt_ids in prompts:
#                     prompt_ids = torch.tensor([prompt_ids], device=unwrapped_model.device)
#                     generations = unwrapped_model.generate(
#                         prompt_ids, generation_config=tr.GenerationConfig(max_new_tokens=150)
#                     )
#                     completion_lens.append(len(generations[0]) - len(prompt_ids[0]))

#         # Build the data to log
#         if self.trainer.accelerator.is_main_process:
#             wandb.log({"completions_len": sum(completion_lens) / len(completion_lens)}, step=state.global_step)

#         # Save the last logged step, so we don't log the same completions multiple times
#         self._last_logged_step = state.global_step
        
# # Training:
# trainer = trl.SFTTrainer(
#     model=model_args.model_name_or_path,
#     args=training_args,
#     train_dataset=dataset[script_args.dataset_train_split],
#     eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
#     processing_class=tokenizer,
#     peft_config=trl.get_peft_config(model_args),
# )
# # len_callback = LogCompletionsLengthCallback(trainer, num_prompts=16)
# # trainer.add_callback(len_callback)
# completion_callback = trl.LogCompletionsCallback(trainer, num_prompts=16)
# trainer.add_callback(completion_callback)

# print(f"pad_token: {tokenizer.pad_token}, eos_token: {tokenizer.eos_token}")

# trainer.train()

# # Save and push to hub
# trainer.save_model(training_args.output_dir)