# Assistant-Only Masking (Instruction Masking)

This notebook demonstrates how to mask user prompts during training so the model only learns from assistant responses.

**Key Concept**: When fine-tuning instruction models, we set labels to `-100` for every token outside the assistant's response (system prompt, user input, and role headers), so the loss is computed only on the assistant's tokens.


![assistant-only-masking.png](assistant-only-masking.png)

## Step 1: Load Tokenizer

We'll use Llama 3.2 1B Instruct, which has special tokens for formatting conversations.


In [None]:
# !pip install transformers torch

In [3]:
from transformers import AutoTokenizer

# Load the Llama 3.2 1B Instruct tokenizer (a gated repo: accept the license on the Hugging Face Hub first)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")


## Step 2: Define a Conversation

A simple user-assistant exchange:


In [4]:
# Define a simple conversation
messages = [
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "The answer is 4."},
]


## Step 3: Apply Chat Template and Get Input IDs

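The `apply_chat_template()` method formats the conversation with special tokens. Note that the Llama 3.2 template also inserts a default system block (knowledge cutoff and today's date), even though `messages` contains no system message: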


In [11]:
# Token IDs for the full conversation
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=False
)

# The same conversation as a formatted string, for inspection
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=False
)

print(input_ids)
print(prompt)


[128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1682, 5020, 220, 2366, 20, 271, 128009, 128006, 882, 128007, 271, 3923, 374, 220, 17, 10, 17, 30, 128009, 128006, 78191, 128007, 271, 791, 4320, 374, 220, 19, 13, 128009]
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 29 Oct 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

What is 2+2?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The answer is 4.<|eot_id|>


## Step 4: Find Where Assistant Response Starts

We need to identify where the assistant's response begins by finding the assistant header tokens:


In [6]:
# Find where assistant response starts by looking for the assistant header token
# For Llama 3, assistant header is: <|start_header_id|>assistant<|end_header_id|>\n\n
assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
assistant_header_ids = tokenizer.encode(assistant_header, add_special_tokens=False)

print(f"assistant_header_ids={assistant_header_ids}")

for token in assistant_header_ids:
    print(token, tokenizer.decode(token))


assistant_header_ids=[128006, 78191, 128007, 271]
128006 <|start_header_id|>
78191 assistant
128007 <|end_header_id|>
271 




### Find the Assistant Start Position

Search through input_ids to find where the assistant header appears:


In [7]:
# Find where the assistant header appears in input_ids.
# The "+ 1" lets the loop also check a header that sits flush against the end of the sequence.
assistant_start = None
for i in range(len(input_ids) - len(assistant_header_ids) + 1):
    if input_ids[i : i + len(assistant_header_ids)] == assistant_header_ids:
        assistant_start = i + len(assistant_header_ids)  # Start after the header
        break
print(f"Idx of assistant_start = {assistant_start}")


Idx of assistant_start = 42
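
The search above can be factored into a small reusable helper. The following is a minimal sketch, not part of the original notebook: `find_subsequence` is a hypothetical name, and the token IDs are copied from the outputs above so the snippet runs without downloading the tokenizer.

```python
from typing import Optional, Sequence


def find_subsequence(haystack: Sequence[int], needle: Sequence[int]) -> Optional[int]:
    """Return the index where `needle` first starts in `haystack`, or None."""
    n = len(needle)
    if n == 0:
        return None
    # "+ 1" so a match flush against the end of `haystack` is still checked
    for i in range(len(haystack) - n + 1):
        if list(haystack[i : i + n]) == list(needle):
            return i
    return None


# Token IDs copied from the Step 3 and Step 4 outputs
input_ids = [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25,
             6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1682, 5020, 220,
             2366, 20, 271, 128009, 128006, 882, 128007, 271, 3923, 374, 220,
             17, 10, 17, 30, 128009, 128006, 78191, 128007, 271, 791, 4320,
             374, 220, 19, 13, 128009]
assistant_header_ids = [128006, 78191, 128007, 271]

header_at = find_subsequence(input_ids, assistant_header_ids)
# The response starts right after the header, if the header was found
assistant_start = None if header_at is None else header_at + len(assistant_header_ids)
print(header_at, assistant_start)  # 38 42
```

Returning `None` when the header is missing makes it easy to skip or flag malformed examples instead of silently training on a fully masked sequence.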


## Step 5: Create Labels with Masking

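Create the labels array:
- Set `-100` for every token before the assistant response: the system block, the user turn, and all role headers (these are ignored in the loss calculation)
- Copy `input_ids` for the assistant tokens, including the closing `<|eot_id|>` (these are used for learning)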


In [8]:
# Create labels
labels = [-100] * len(input_ids)  # Mask everything by default
if assistant_start is not None:
    # Unmask from the assistant response onwards (assumes a single assistant turn at the end)
    labels[assistant_start:] = input_ids[assistant_start:]

print(len(labels), labels)


49 [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 791, 4320, 374, 220, 19, 13, 128009]
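
The slice `labels[assistant_start:]` is only correct when the conversation ends with a single assistant turn. In a multi-turn conversation it would also unmask later user turns. The sketch below is one possible extension, not the notebook's method: `mask_non_assistant` is a hypothetical helper that unmasks each assistant span up to and including its `<|eot_id|>`, using the Llama 3 token IDs printed earlier.

```python
IGNORE_INDEX = -100
EOT = 128009                                   # <|eot_id|>
ASSISTANT_HEADER = [128006, 78191, 128007, 271]


def mask_non_assistant(input_ids, header=ASSISTANT_HEADER, eot=EOT):
    """Return labels where only assistant responses (plus their EOT) are kept."""
    labels = [IGNORE_INDEX] * len(input_ids)
    n = len(header)
    i = 0
    while i <= len(input_ids) - n:
        if input_ids[i : i + n] == header:
            start = i + n
            end = start
            # Walk forward to the end-of-turn token that closes this response
            while end < len(input_ids) and input_ids[end] != eot:
                end += 1
            end = min(end + 1, len(input_ids))  # include EOT when present
            labels[start:end] = input_ids[start:end]
            i = end
        else:
            i += 1
    return labels


# A two-turn conversation assembled from the token IDs shown above
user_header = [128006, 882, 128007, 271]
user_text = [3923, 374, 220, 17, 10, 17, 30]        # "What is 2+2?"
assistant_text = [791, 4320, 374, 220, 19, 13]      # "The answer is 4."
turn = user_header + user_text + [EOT] + ASSISTANT_HEADER + assistant_text + [EOT]
two_turn_ids = [128000] + turn + turn

labels = mask_non_assistant(two_turn_ids)
print(sum(1 for x in labels if x != IGNORE_INDEX))  # 14
```

Both assistant responses contribute 7 tokens each, while the second user turn stays masked.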


## Step 6: Visualize the Masking

Let's see which tokens are masked and which are used for learning:


In [9]:
# Visualize
print("=" * 70)
print("MASKING VISUALIZATION")
print("=" * 70)
print(f"{'Pos':<5} {'Token':<25} {'Input ID':<12} {'Label':<12} {'Status':<10}")
print("-" * 70)

for i, (inp, lab) in enumerate(zip(input_ids, labels)):
    token = tokenizer.decode([inp]).replace("\n", "\\n")
    status = "LEARNS" if lab != -100 else "MASKED"
    print(f"{i:<5} {token:<25} {inp:<12} {lab:<12} {status:<10}")


MASKING VISUALIZATION
Pos   Token                     Input ID     Label        Status    
----------------------------------------------------------------------
0     <|begin_of_text|>         128000       -100         MASKED    
1     <|start_header_id|>       128006       -100         MASKED    
2     system                    9125         -100         MASKED    
3     <|end_header_id|>         128007       -100         MASKED    
4     \n\n                      271          -100         MASKED    
5     Cut                       38766        -100         MASKED    
6     ting                      1303         -100         MASKED    
7      Knowledge                33025        -100         MASKED    
8      Date                     2696         -100         MASKED    
9     :                         25           -100         MASKED    
10     December                 6790         -100         MASKED    
11                              220          -100         MASKED    
12    202 

## Summary Statistics

How many tokens are masked vs. used for learning?


In [13]:
# Summary stats
masked_count = sum(1 for x in labels if x == -100)
learn_count = len(labels) - masked_count
print("\n" + "=" * 70)
print(f"Total tokens: {len(labels)}")
print(f"Masked (prompt): {masked_count}")
print(f"Learning (assistant): {learn_count}")
print(f"Learning ratio: {learn_count/len(labels)*100:.1f}%")



Total tokens: 49
Masked (prompt): 42
Learning (assistant): 7
Learning ratio: 14.3%
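
Note that the labels are aligned with `input_ids` rather than shifted by one. Assuming they are later passed to a Hugging Face causal language model, the model applies the next-token shift internally: the prediction made at position `t - 1` is scored against `labels[t]`. The sketch below illustrates which positions end up contributing to the loss. `loss_positions` is a hypothetical helper, and the labels are rebuilt from the Step 5 output.

```python
IGNORE_INDEX = -100


def loss_positions(labels, ignore_index=IGNORE_INDEX):
    """Return (context_position, target_token) pairs that contribute to the loss
    after the usual next-token shift."""
    return [
        (t - 1, labels[t])
        for t in range(1, len(labels))
        if labels[t] != ignore_index
    ]


# Labels from Step 5: 42 masked positions followed by 7 assistant tokens
labels = [IGNORE_INDEX] * 42 + [791, 4320, 374, 220, 19, 13, 128009]
pairs = loss_positions(labels)
print(len(pairs))   # 7
print(pairs[0])     # (41, 791)
print(pairs[-1])    # (47, 128009)
```

The first pair shows that the model's output at the last header token (position 41) is what gets trained to produce the first response token.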


## PyTorch Ignores -100 Labels

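Let's verify that PyTorch's cross-entropy loss skips positions with label `-100`. This works because `ignore_index` defaults to `-100` in `F.cross_entropy` and `nn.CrossEntropyLoss`, and with the default mean reduction the loss is averaged over the remaining positions only: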


In [12]:
import torch
import torch.nn.functional as F

# Simple example: 3 predictions, vocab size = 5
logits = torch.tensor([[1.0, 2.0, 0.5, 0.1, 0.3],
                       [0.2, 0.1, 3.0, 0.5, 0.2],
                       [2.0, 0.5, 0.3, 1.5, 0.8]])

# Case 1: All labels are valid
labels1 = torch.tensor([1, 2, 0])
loss1 = F.cross_entropy(logits, labels1)
print(f"Loss with all valid labels: {loss1:.4f}")

# Case 2: First two labels are -100 (masked)
labels2 = torch.tensor([-100, -100, 0])
loss2 = F.cross_entropy(logits, labels2)
print(f"Loss with first two masked:  {loss2:.4f}")

# Case 3: Only the third position (label=0)
loss3 = F.cross_entropy(logits[2:3], labels1[2:3])
print(f"Loss with only third token:  {loss3:.4f}")

print(f"\n✓ Loss2 ({loss2:.4f}) == Loss3 ({loss3:.4f}): {torch.allclose(loss2, loss3)}")
print("  → PyTorch ignores -100 labels!")


Loss with all valid labels: 0.5743
Loss with first two masked:  0.8388
Loss with only third token:  0.8388

✓ Loss2 (0.8388) == Loss3 (0.8388): True
  → PyTorch ignores -100 labels!
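
To make the mechanism explicit, the same numbers can be reproduced without PyTorch. This is a minimal pure-Python sketch of a masked mean cross-entropy, written for illustration only: `masked_cross_entropy` is a hypothetical function, not PyTorch's implementation, and it assumes mean reduction over the unmasked positions.

```python
import math

IGNORE_INDEX = -100


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not `ignore_index`."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
        count += 1
    # With no unmasked positions the mean is undefined
    return total / count if count else float("nan")


logits = [[1.0, 2.0, 0.5, 0.1, 0.3],
          [0.2, 0.1, 3.0, 0.5, 0.2],
          [2.0, 0.5, 0.3, 1.5, 0.8]]

print(f"{masked_cross_entropy(logits, [1, 2, 0]):.4f}")        # 0.5743
print(f"{masked_cross_entropy(logits, [-100, -100, 0]):.4f}")  # 0.8388
```

The key detail is the denominator: masked positions are excluded from both the sum and the count, which is why the masked loss equals the loss of the third position alone.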


---

## Key Takeaways

1. **`-100` is the magic number**: PyTorch's cross-entropy loss skips any label equal to `ignore_index`, which defaults to `-100`
2. **Mask everything outside the response**: Set labels to `-100` for the system block, user input, and role headers
3. **Keep assistant tokens**: Copy `input_ids` to labels for the assistant response, including its closing `<|eot_id|>` so the model learns when to stop
4. **Why this matters**: The model only learns to generate assistant responses, not to mimic user inputs

Assistant-only masking is a standard step in instruction tuning and chat-model fine-tuning.
