# Supervised Fine-Tuning

Supervised Fine-Tuning (SFT) is a process primarily used to adapt pre-trained language models to follow instructions, engage in dialogue, and use specific output formats. This is typically done by training on datasets of human-written conversations and instructions.


## When to Use SFT

As a first step, you should consider whether using an existing instruction-tuned model with well-crafted prompts would suffice for your use case. SFT involves significant computational resources and engineering effort, so it should only be pursued when prompting existing models proves insufficient.

> Consider SFT only if you:
>
> - Need additional performance beyond what prompting can achieve
> - Have a specific use case where the cost of using a large general-purpose model outweighs the cost of fine-tuning a smaller model
> - Require specialized output formats or domain-specific knowledge that existing models struggle with

### Template Control

SFT allows precise control over the model’s output structure. This is particularly valuable when you need the model to:

- Generate responses in a specific chat template format
- Follow strict output schemas
- Maintain consistent styling across responses

### Domain Adaptation

When working in specialized domains, SFT helps align the model with domain-specific requirements by:

- Teaching domain terminology and concepts
- Enforcing professional standards
- Handling technical queries appropriately
- Following industry-specific guidelines

> Before starting SFT, evaluate whether your use case requires:
>
> - Precise output formatting
> - Domain-specific knowledge
> - Consistent response patterns
> - Adherence to specific guidelines
>
> This evaluation will help determine if SFT is the right approach for your needs.


## Dataset Preparation

The supervised fine-tuning process **requires a task-specific dataset** structured with input-output pairs. Each pair should consist of:

- An input prompt
- The expected model response
- Any additional context or metadata

The **quality of your training data is crucial** for successful fine-tuning, so prepare and validate your dataset before you start training.
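As a minimal sketch of what validation could look like, the helper below checks a single record in the conversational `messages` format used later in this section. The function name `validate_example`, the set of accepted roles, and the specific checks are illustrative assumptions, not part of TRL or `datasets`.

```python
# Hypothetical helper: the accepted roles and the checks are assumptions for illustration.
VALID_ROLES = {"system", "user", "assistant"}


def validate_example(example: dict) -> list:
    """Return a list of problems found in one conversational example."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]

    problems = []
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            problems.append(f"message {i}: expected a dict")
            continue
        if message.get("role") not in VALID_ROLES:
            problems.append(f"message {i}: unknown role {message.get('role')!r}")
        content = message.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i}: empty or non-string content")

    # Supervised fine-tuning needs at least one assistant turn to learn from.
    if not any(isinstance(m, dict) and m.get("role") == "assistant" for m in messages):
        problems.append("no assistant message to learn from")
    return problems


good = {
    "messages": [
        {"role": "user", "content": "Hi there"},
        {"role": "assistant", "content": "Hello! How can I help you today?"},
    ]
}
bad = {"messages": [{"role": "user", "content": "  "}]}

print(validate_example(good))
print(validate_example(bad))
```

Running a check like this over every record before training makes it easier to spot empty turns or malformed roles early, rather than discovering them through a poor loss curve.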


## Training Configuration

The success of your fine-tuning depends heavily on choosing the right training parameters. The `SFTTrainer` configuration exposes several groups of parameters that control the training process:

1. **Training Duration Parameters**:
  - `num_train_epochs`: Controls total training duration
  - `max_steps`: Alternative to epochs, sets maximum number of training steps
  - More epochs allow better learning but risk overfitting

2. **Batch Size Parameters**:
  - `per_device_train_batch_size`: Determines memory usage and training stability
  - `gradient_accumulation_steps`: Enables larger effective batch sizes
  - Larger batches provide more stable gradients but require more memory


3. **Learning Rate Parameters**:
  - `learning_rate`: Controls size of weight updates
  - `warmup_ratio`: Portion of training used for learning rate warmup
  - A learning rate that is too high can cause instability; one that is too low results in slow learning

4. **Monitoring Parameters**:
  - `logging_steps`: Frequency of metric logging
  - `eval_steps`: How often to evaluate on validation data
  - `save_steps`: Frequency of model checkpoint saves

> Start with conservative values and adjust based on monitoring:
>
> - Begin with 1-3 epochs
> - Use smaller batch sizes initially
> - Monitor validation metrics closely
> - Adjust learning rate if training is unstable
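To see how these parameters interact, the following sketch works through the arithmetic for the effective batch size, the number of steps per epoch, and the warmup length. The function `training_plan` is a hypothetical helper, and the rounding is an approximation; the trainer's own bookkeeping may differ slightly. The example numbers match the 2,260-example training split and batch size of 4 used in the implementation section.

```python
import math


def training_plan(
    num_examples,
    per_device_train_batch_size,
    gradient_accumulation_steps=1,
    num_devices=1,
    max_steps=None,
    num_train_epochs=1,
    warmup_ratio=0.0,
):
    """Approximate the schedule implied by common training parameters (illustrative only)."""
    # Examples consumed per optimizer update across all devices.
    effective_batch = (
        per_device_train_batch_size * gradient_accumulation_steps * num_devices
    )
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    # max_steps, when set, takes the place of the epoch count.
    if max_steps is not None:
        total_steps = max_steps
    else:
        total_steps = steps_per_epoch * num_train_epochs
    return {
        "effective_batch_size": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "epochs_covered": total_steps / steps_per_epoch,
        "warmup_steps": round(total_steps * warmup_ratio),
    }


plan = training_plan(
    num_examples=2260,
    per_device_train_batch_size=4,
    max_steps=1000,
    warmup_ratio=0.1,
)
print(plan)
```

With these inputs, 1,000 steps at 4 examples per step cover the training split a little under twice, which falls inside the suggested range of 1-3 epochs.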


## Implementation with TRL

Now that we understand the key components, let’s implement the training with proper validation and monitoring. We will use the `SFTTrainer` class from the Transformer Reinforcement Learning (TRL) library, which is built on top of the `transformers` library. Here’s a complete example:

```python
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, setup_chat_format
from transformers import AutoModelForCausalLM, AutoTokenizer
```

```python
# Set device
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

print(f"Using device: {device}")
```

```
Using device: cuda
```


### Load the model

```python
# Load the model and tokenizer
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_name
).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
```

[Set up the chat format](https://huggingface.co/docs/trl/sft_trainer#add-special-tokens-for-chat-format). The `setup_chat_format()` function in trl easily sets up a model and tokenizer for conversational AI tasks. This function:

- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to mark the start and end of each message in a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is chatml from OpenAI.

```python
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# Set our name for the finetune to be saved
finetune_name = "SmolLM2-FT-MyDataset"
```

### Load the dataset

```python
# Load dataset
dataset = load_dataset(
    path="HuggingFaceTB/smoltalk", name="everyday-conversations"
)
dataset
```

```
DatasetDict({
    train: Dataset({
        features: ['full_topic', 'messages'],
        num_rows: 2260
    })
    test: Dataset({
        features: ['full_topic', 'messages'],
        num_rows: 119
    })
})
```

```python
print(tokenizer.apply_chat_template(dataset["train"][0]["messages"], tokenize=False))
```

```
<|im_start|>user
Hi there<|im_end|>
<|im_start|>assistant
Hello! How can I help you today?<|im_end|>
<|im_start|>user
I'm looking for a beach resort for my next vacation. Can you recommend some popular ones?<|im_end|>
<|im_start|>assistant
Some popular beach resorts include Maui in Hawaii, the Maldives, and the Bahamas. They're known for their beautiful beaches and crystal-clear waters.<|im_end|>
<|im_start|>user
That sounds great. Are there any resorts in the Caribbean that are good for families?<|im_end|>
<|im_start|>assistant
Yes, the Turks and Caicos Islands and Barbados are excellent choices for family-friendly resorts in the Caribbean. They offer a range of activities and amenities suitable for all ages.<|im_end|>
<|im_start|>user
Okay, I'll look into those. Thanks for the recommendations!<|im_end|>
<|im_start|>assistant
You're welcome. I hope you find the perfect resort for your vacation.<|im_end|>
```
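The layout of the formatted conversation above can be approximated in a few lines of plain Python. This is a simplified sketch for building intuition, not the tokenizer's actual template: `render_chatml` is a hypothetical function, and a real chat template may add details such as a default system prompt.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} dicts in a ChatML-style layout (simplified)."""
    text = ""
    for message in messages:
        # Each turn is wrapped in start/end markers, with the role on the first line.
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    if add_generation_prompt:
        # Leave an open assistant turn for the model to complete.
        text += "<|im_start|>assistant\n"
    return text


conversation = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello! How can I help you today?"},
]
print(render_chatml(conversation))
```

In practice, prefer `tokenizer.apply_chat_template`, which applies whatever template the tokenizer actually defines.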



### Generate with the base model

```python
# Let's test the base model before training
prompt = "Write a haiku about programming"

# Format with template
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Generate response
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print("Before training:")
print(tokenizer.decode(outputs[0], skip_special_tokens=False))
```

```
Before training:
<|im_start|>user
Write a haiku about programming<|im_end|>
<|im_start|>assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
Write a haiku about programming
assistant
```



### Finetune the model

[Dataset format support](https://huggingface.co/docs/trl/sft_trainer#dataset-format-support). The `SFTTrainer` supports popular dataset formats, so you can pass such a dataset to the trainer directly, without any pre-processing. The following formats are supported:

- conversational format
```python
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```

- instruction format
```python
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```

If your dataset uses one of the above formats, you can **directly pass it to the trainer without pre-processing**. The `SFTTrainer` will then format the dataset for you, using the chat template defined in the model’s tokenizer via the `apply_chat_template` method.
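If your data does not yet match either format, a small mapping function is usually enough to convert it. The sketch below assumes a dataset with hypothetical `question` and `answer` columns and turns each row into the conversational format; the column names and the function `to_conversational` are illustrative, not part of TRL.

```python
def to_conversational(row, system_prompt=None):
    """Convert a row with hypothetical 'question'/'answer' columns to the conversational format."""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": row["question"]})
    messages.append({"role": "assistant", "content": row["answer"]})
    return {"messages": messages}


row = {"question": "What is SFT?", "answer": "Supervised fine-tuning."}
print(to_conversational(row, system_prompt="You are helpful"))
```

With a `datasets.Dataset`, a function like this can be applied to every row through `dataset.map(to_conversational)`, optionally dropping the original columns afterwards.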

```python
# Configure the SFTTrainer
sft_config = SFTConfig(
    output_dir="./sft_output",
    max_steps=1000,  # Adjust based on dataset size and desired training duration
    per_device_train_batch_size=4,  # Set according to your GPU memory capacity
    learning_rate=5e-5,  # Common starting point for fine-tuning
    logging_steps=10,  # Frequency of logging training metrics
    save_steps=100,  # Frequency of saving model checkpoints
    eval_strategy="steps",  # Evaluate the model at regular intervals
    eval_steps=50,  # Frequency of evaluation
    use_mps_device=(device == "mps"),  # Use Apple's MPS backend when it is the selected device
)
```

```python
# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset["train"],
    tokenizer=tokenizer,  # Newer TRL releases name this argument `processing_class`
    eval_dataset=dataset["test"],
)
```

```python
# Start training
trainer.train()
```

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 1.0657 | 1.158982 |
| 100 | 1.1116 | 1.124065 |
| 150 | 1.0624 | 1.095485 |
| 200 | 1.0482 | 1.079698 |
| 250 | 1.0412 | 1.070457 |
| 300 | 1.0292 | 1.061472 |
| 350 | 1.0034 | 1.054751 |
| 400 | 1.0065 | 1.050794 |
| 450 | 1.0211 | 1.042638 |
| 500 | 1.0762 | 1.033725 |

```
TrainOutput(global_step=1000, training_loss=0.9605710778236389, metrics={'train_runtime': 244.7225, 'train_samples_per_second': 16.345, 'train_steps_per_second': 4.086, 'total_flos': 587568496250880.0, 'train_loss': 0.9605710778236389})
```

```python
# Save the model
trainer.save_model(f"checkpoints/{finetune_name}")
```

### Generate with fine-tuned model

```python
# Test the fine-tuned model on a new prompt
prompt = "Which is the capital of paris?"

# Format with template
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Generate response
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

outputs = model.generate(**inputs, max_new_tokens=100)
print("After training:")
print(tokenizer.decode(outputs[0], skip_special_tokens=False))
```

```
After training:
<|im_start|>user
Which is the capital of paris?<|im_end|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is the capital of paris?<|im_start|>
<|im_start|>assistant
Which is
```


## Monitoring Training Progress

Effective monitoring is crucial for successful fine-tuning. Let’s explore what to watch for during training.

### Understanding Loss Patterns

Training loss typically follows three distinct phases:

1. **Initial Sharp Drop**: Rapid adaptation to new data distribution
2. **Gradual Stabilization**: The loss decreases more slowly as the model fine-tunes
3. **Convergence**: Loss values stabilize, indicating training completion


### Metrics to Monitor

Effective monitoring combines tracking quantitative metrics with qualitatively evaluating the model’s outputs. The quantitative metrics available during training are:

- Training loss
- Validation loss
- Learning rate progression
- Gradient norms

> Watch for these **warning signs** during training:
>
> 1. Validation loss increasing while training loss decreases (overfitting)
> 2. No significant improvement in loss values (underfitting)
> 3. Extremely low loss values (potential memorization)
> 4. Inconsistent output formatting (template learning issues)


### The Path to Convergence

As training progresses, the loss curve should gradually stabilize. The key indicator of healthy training is a **small gap between training and validation loss**, suggesting the model is **learning generalizable patterns** rather than memorizing specific examples. The absolute loss values will vary depending on your task and dataset.

![Perfect Convergence](misc/sft_perfect_convergence.png "Perfect Convergence")
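As a rough illustration of checking this gap, the sketch below computes the difference between validation and training loss for the values logged by the training run earlier in this section. The helper names `loss_gaps` and `validation_is_improving` are hypothetical. Logged training loss is noisy, so individual gaps are only indicative.

```python
# (step, training loss, validation loss) as logged by the run in this section.
LOGGED = [
    (50, 1.0657, 1.158982),
    (100, 1.1116, 1.124065),
    (150, 1.0624, 1.095485),
    (200, 1.0482, 1.079698),
    (250, 1.0412, 1.070457),
    (300, 1.0292, 1.061472),
    (350, 1.0034, 1.054751),
    (400, 1.0065, 1.050794),
    (450, 1.0211, 1.042638),
    (500, 1.0762, 1.033725),
]


def loss_gaps(rows):
    """Return (step, validation loss minus training loss) for each logged evaluation."""
    return [(step, round(val - train, 4)) for step, train, val in rows]


def validation_is_improving(rows):
    """True if validation loss decreased at every consecutive evaluation."""
    val_losses = [val for _, _, val in rows]
    return all(later < earlier for earlier, later in zip(val_losses, val_losses[1:]))


for step, gap in loss_gaps(LOGGED):
    print(f"step {step}: gap = {gap:+.4f}")
print("validation improving:", validation_is_improving(LOGGED))
```

For this run the gap stays below 0.1 at every logged step and the validation loss keeps falling, which is consistent with the healthy pattern described above.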


### Warning Signs to Watch For

Several patterns in the loss curves can indicate potential issues. Below we illustrate common warning signs and solutions that we can consider.

![Bad Validation Curve](misc/sft_bad_validation.png "Bad Validation Curve")

If the validation loss decreases at a significantly slower rate than training loss, your model is **likely overfitting** to the training data. Consider:

- Reducing the training steps
- Increasing the dataset size
- Validating dataset quality and diversity
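One way to decide when to reduce the training steps is a patience-based check on the validation loss. The following is a minimal sketch of that idea under stated assumptions: `should_stop` is a hypothetical helper, the `patience` and `min_delta` defaults are arbitrary, and the loss values are made up for illustration.

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True if the last `patience` evaluations failed to beat the earlier best loss."""
    if len(val_losses) <= patience:
        return False  # Not enough history to judge yet.
    best_before = min(val_losses[:-patience])
    recent = val_losses[-patience:]
    # Stop only if every recent evaluation is no better than the earlier best.
    return all(loss > best_before - min_delta for loss in recent)


# Made-up validation losses for illustration.
still_improving = [1.20, 1.10, 1.05, 1.04, 1.03, 1.02]
overfitting = [1.20, 1.10, 1.05, 1.06, 1.07, 1.08]

print(should_stop(still_improving))
print(should_stop(overfitting))
```

The `transformers` library ships an `EarlyStoppingCallback` built on a similar idea, which may be preferable to a hand-rolled check in a real training run.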

![No Learning](misc/sft_no_learning.png "No Learning")

If the loss doesn’t show significant improvement, the model might be:

- Learning too slowly (try increasing the learning rate)
- Struggling with the task (check data quality and task complexity)
- Hitting architecture limitations (consider a different model)

![Memorization](misc/sft_memorization.png "Memorization")

Extremely low loss values could **suggest memorization** rather than learning. This is particularly concerning if:

- The model performs poorly on new, similar examples
- The outputs lack diversity
- The responses are too similar to training examples
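A crude way to quantify how similar a response is to the training data is word n-gram overlap. The sketch below is one possible heuristic, not an established metric from TRL: `ngrams` and `overlap_ratio` are hypothetical helpers, and whitespace tokenization with `n=3` is an arbitrary simplification.

```python
def ngrams(text, n=3):
    """Return the set of lowercase word n-grams in a text."""
    tokens = text.lower().split()
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def overlap_ratio(generated, training_texts, n=3):
    """Fraction of the generated text's n-grams that also appear in the training texts."""
    generated_ngrams = ngrams(generated, n)
    if not generated_ngrams:
        return 0.0
    training_ngrams = set()
    for text in training_texts:
        training_ngrams |= ngrams(text, n)
    return len(generated_ngrams & training_ngrams) / len(generated_ngrams)


training_texts = ["Hello! How can I help you today?"]

print(overlap_ratio("Hello! How can I help you today?", training_texts))
print(overlap_ratio("The weather is sunny and warm", training_texts))
```

A ratio near 1.0 on prompts that are not in the training set would be a hint that responses are being copied rather than generated, and is worth following up with manual inspection.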

> Regular qualitative evaluation of the model's responses helps catch issues that metrics alone might miss.

Note that the interpretation of loss values outlined here covers the most common case; in practice, loss values can behave in various ways depending on the model, the dataset, the training parameters, and so on. If you are interested in exploring these patterns further, check out this blog post by the people at [Fast AI](https://www.fast.ai/posts/2023-09-04-learning-jumps/).