In [1]:
from datasets import load_dataset
import json
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
import os
from dotenv import load_dotenv

load_dotenv()


True

In [2]:
data_file = "../datasets/gsm8k_step_jepa.jsonl"

In [3]:
dataset = load_dataset('json', data_files=data_file)['train']

In [4]:
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': 'Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72',
 'ground_truth': '72',
 'total_steps': 7,
 'messages': [{'role': 'system',
   'content': 'Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \\boxed{}.'},
  {'role': 'user',
   'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'},
  {'role': 'assistant',
   'content': "First, I note that Natalia sold 48 clips in April.\n\nIn May, she sold half as many clips as she did in April, which is 24 clips.\n\nTo find the total number of clips sold over bot

In [None]:
bad_indices = []
for i, sample in enumerate(dataset):
    for msg in sample.get("messages", []):
        if msg.get("role") == "user" and "\n\n" in (msg.get("content") or ""):
            bad_indices.append(i)
            break

In [5]:
# login to huggingface
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"))


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [6]:
# Copy from gsm8k.py - Full code for label masking
def load_and_prepare_dataset(data_file, tokenizer,
                             max_length=2048, debug=0):
    """Load JSONL dataset and format for training with proper label masking"""
    
    # Load dataset
    dataset = load_dataset('json', data_files=data_file)['train']
    if  torch.cuda.current_device() == 0:
        print(f"Loaded {len(dataset)} examples from {data_file}")
    
    def tokenize_conversations(examples):
        """Tokenize conversations and mask input tokens properly"""
        input_ids_list = []
        labels_list = []
        attention_mask_list = []

        for msg_idx, messages in enumerate(examples['messages']):
            # Apply chat template if available, otherwise format manually
            formatted_chat = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            
            # Tokenize the formatted conversation with padding to max_length
            tokenized = tokenizer(
                formatted_chat,
                truncation=True,
                max_length=max_length,
                padding="max_length",  # Pad to max_length for consistent tensor shapes
                return_tensors=None
            )
            
            input_ids = tokenized["input_ids"]
            attention_mask = tokenized["attention_mask"]
            
            # Create labels with proper masking
            labels = create_masked_labels(messages, tokenizer, input_ids, attention_mask, formatted_chat=formatted_chat)
            
            input_ids_list.append(input_ids)
            labels_list.append(labels)
            attention_mask_list.append(attention_mask)
        
        return {
            "input_ids": input_ids_list,
            "labels": labels_list,
            "attention_mask": attention_mask_list,
        }
    
    def create_masked_labels(messages, tokenizer, input_ids, attention_mask, formatted_chat=None):
        """Create labels with input tokens masked (-100)"""
        labels = [-100] * len(input_ids)
        
        # Mask padding tokens in labels
        for i, mask in enumerate(attention_mask):
            if mask == 0:  # Padding token
                labels[i] = -100
        
        # Find assistant responses and unmask only those tokens
        for msg in messages:
            if msg['role'] == 'assistant':
                assistant_content = msg['content']
                
                # Find assistant marker in decoded text to locate assistant content accurately
                assistant_marker = '<|start_header_id|>assistant<|end_header_id|>'
                decoded_full = tokenizer.decode(input_ids, skip_special_tokens=False)
                
                # Find assistant marker position in decoded text
                if assistant_marker not in decoded_full:
                    # If marker not found, skip this assistant message
                    continue
                
                marker_pos_in_decoded = decoded_full.find(assistant_marker)
                # Assistant content starts after the marker
                assistant_text_start = marker_pos_in_decoded + len(assistant_marker)
                
                # Map assistant marker position to token position using incremental decode
                marker_token_pos = None
                for i in range(len(input_ids)):
                    if attention_mask[i] == 0:  # Skip padding
                        continue
                    decoded_so_far = tokenizer.decode(input_ids[:i+1], skip_special_tokens=False)
                    if len(decoded_so_far) > marker_pos_in_decoded + len(assistant_marker):
                        # Found the token right after assistant marker
                        marker_token_pos = i
                        break
                
                if marker_token_pos is None:
                    # Couldn't find marker position, skip
                    continue
                
                # Now find assistant content starting from after the marker
                # Decode from marker_token_pos onwards to find where assistant_content starts
                assistant_start_token = None
                assistant_end_token = None
                
                # Decode incrementally from marker to find assistant content
                text_after_marker = ""
                for i in range(marker_token_pos, len(input_ids)):
                    if attention_mask[i] == 0:  # Skip padding
                        continue
                    piece = tokenizer.decode([input_ids[i]], skip_special_tokens=False)
                    text_after_marker += piece
                    
                    # Check if we've found the start of assistant content
                    if assistant_start_token is None and assistant_content in text_after_marker:
                        # Found start of assistant content
                        assistant_start_token = i
                    
                    # Check if we've found the full assistant content
                    if assistant_start_token is not None:
                        # Check if we have the full content
                        if len(text_after_marker) >= text_after_marker.find(assistant_content) + len(assistant_content):
                            assistant_end_token = i
                            # Don't break - continue to find last non-padding token
                    
                    # Keep text buffer reasonable size
                    if len(text_after_marker) > len(assistant_content) * 2:
                        text_after_marker = text_after_marker[-len(assistant_content):]
                
                # Find last non-padding token (in case assistant content was truncated)
                last_non_padding = None
                for i in range(len(input_ids) - 1, -1, -1):
                    if attention_mask[i] == 1:
                        last_non_padding = i
                        break
                
                # Unmask assistant tokens: from assistant_start_token to last_non_padding
                # (This covers the full assistant content, even if truncated)
                if assistant_start_token is not None and last_non_padding is not None:
                    for j in range(assistant_start_token, last_non_padding + 1):
                        if attention_mask[j] == 1:  # Only unmask non-padding tokens
                            labels[j] = input_ids[j]
                
                if debug == 4:
                    exit(0)
        
        return labels
    
    # Tokenize dataset
    tokenized_dataset = dataset.map(
        tokenize_conversations,
        batched=True,
        remove_columns=dataset.column_names
    )
    
    return tokenized_dataset


In [None]:
from transformers import AutoTokenizer
import torch

MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
MAX_LENGTH = 2048


def tokenize_text(tokenizer, text, max_length=2048, pad_to_max=True):
    return tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length" if pad_to_max else False,
        return_tensors=None
    )


def build_full_text(tokenizer, messages):
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False
    )


def build_prefix_up_to_assistant_content(tokenizer, messages, assistant_index):
    """
    Builds a prefix string that ends *exactly after* the assistant header + its mandatory '\n\n',
    but BEFORE assistant content begins.
    """
    before = messages[:assistant_index]  # everything before assistant message
    assistant_header_only = [{"role": "assistant", "content": ""}]  # assistant header

    prefix = tokenizer.apply_chat_template(
        before + assistant_header_only,
        tokenize=False,
        add_generation_prompt=False
    )

    # IMPORTANT:
    # apply_chat_template produces:
    # <assistant header>\n\n<|eot_id|>
    # but we want to stop BEFORE <|eot_id|> because that's after assistant content would end.
    #
    # So remove the trailing "<|eot_id|>" if present, and keep the '\n\n' that marks content start.
    if prefix.endswith("<|eot_id|>"):
        prefix = prefix[:-len("<|eot_id|>")]

    return prefix


def compute_assistant_spans(tokenizer, full_text, messages, max_length=2048):
    """
    Returns list of (start_idx, end_idx) spans in token space for assistant CONTENT ONLY.
    Each span excludes assistant headers and includes the content tokens only.
    """

    full = tokenize_text(tokenizer, full_text, max_length=max_length, pad_to_max=True)
    full_ids = full["input_ids"]
    full_mask = full["attention_mask"]

    last_non_pad = max(i for i, m in enumerate(full_mask) if m == 1)

    spans = []
    full_len = last_non_pad + 1

    # We find each assistant message span by:
    # start = token length of prefix ending at assistant content start
    # end   = token length of prefix ending at assistant content end (right before next role header)
    #
    # We'll do this by constructing prefixes incrementally.

    for idx, msg in enumerate(messages):
        if msg["role"] != "assistant":
            continue

        # Prefix right before assistant content begins
        prefix_start_text = build_prefix_up_to_assistant_content(tokenizer, messages, idx)
        prefix_start_ids = tokenize_text(tokenizer, prefix_start_text, max_length=max_length, pad_to_max=False)["input_ids"]
        start = len(prefix_start_ids)

        # Prefix up through assistant content (i.e., include this assistant fully)
        prefix_end_messages = messages[:idx+1]
        prefix_end_text = tokenizer.apply_chat_template(
            prefix_end_messages,
            tokenize=False,
            add_generation_prompt=False
        )
        prefix_end_ids = tokenize_text(tokenizer, prefix_end_text, max_length=max_length, pad_to_max=False)["input_ids"]
        end = len(prefix_end_ids)

        # In full conversation, assistant span is [start, end) BUT must exclude the final <|eot_id|>
        # Usually the assistant message ends with <|eot_id|>. We want to exclude it.
        #
        # So shrink end by 1 if last token decodes to <|eot_id|>
        if end > start:
            last_tok = tokenizer.decode([prefix_end_ids[-1]], skip_special_tokens=False)
            if last_tok == "<|eot_id|>":
                end -= 1

        # Clamp to avoid exceeding truncation
        start = min(start, full_len)
        end = min(end, full_len)

        if start < end:
            spans.append((start, end))

    return full_ids, full_mask, spans, last_non_pad


def create_masked_labels(messages, tokenizer, max_length=2048):
    full_text = build_full_text(tokenizer, messages)
    input_ids, attention_mask, spans, last_non_pad = compute_assistant_spans(
        tokenizer, full_text, messages, max_length=max_length
    )

    labels = [-100] * len(input_ids)

    # Unmask each assistant content span
    for start, end in spans:
        for i in range(start, end):
            labels[i] = input_ids[i]

    # Mask padding (already -100 but keep safe)
    for i in range(last_non_pad + 1, len(labels)):
        labels[i] = -100

    return input_ids, attention_mask, labels, spans, last_non_pad, full_text


def pretty_print_debug(tokenizer, input_ids, labels, spans, last_non_pad, window=50):
    """
    Print tokens around each assistant span.
    """
    for span_id, (start, end) in enumerate(spans):
        print("\n" + "=" * 90)
        print(f"ASSISTANT SPAN {span_id}: start={start}, end={end} (len={end-start})")
        print(f"LAST NON-PAD IDX: {last_non_pad}")
        print("=" * 90)

        s = max(0, start - window)
        e = min(len(input_ids), start + window)

        print("Index | Masked? | Token (decoded)")
        print("-" * 90)

        for i in range(s, e):
            tok = tokenizer.decode([input_ids[i]], skip_special_tokens=False)
            masked = (labels[i] == -100)
            print(f"{i:5d} | {str(masked):6s} | {repr(tok)}")


def main():
    messages = [
        {
            "role": "system",
            "content": "Please solve the problem step by step (separate steps with double newlines), "
                       "but keep it short and put your final answer (do not include any other text or units) "
                       "within \\boxed{}."
        },
        {
            "role": "user",
            "content": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. "
                       "How many clips did Natalia sell altogether in April and May?"
        },
        {
            "role": "assistant",
            "content": "Step 1: April = 48\n\nStep 2: May = 48/2 = 24\n\nStep 3: Total = 48 + 24 = \\boxed{72}"
        }
    ]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token

    input_ids, attention_mask, labels, spans, last_non_pad, full_text = create_masked_labels(
        messages, tokenizer, max_length=MAX_LENGTH
    )

    # Sanity check
    total_tokens = len(input_ids)
    masked = sum(1 for x in labels if x == -100)
    unmasked = total_tokens - masked

    print("\nFULL CHAT TEXT (truncated to 600 chars):")
    print(full_text[:600])
    print("\n" + "=" * 80)

    print(f"\nTotal tokens:   {total_tokens}")
    print(f"Masked tokens:  {masked}")
    print(f"Unmasked tokens:{unmasked}")
    print(f"Assistant spans: {spans}")

    # Validate that assistant header tokens are masked:
    # We'll assert that the token just before each span start is masked (usually '\n\n' or header tokens)
    for (start, end) in spans:
        if start > 0:
            assert labels[start - 1] == -100, "❌ Assistant header/newline token incorrectly unmasked!"

    # Validate that each span contains unmasked tokens
    for (start, end) in spans:
        assert any(labels[i] != -100 for i in range(start, end)), "❌ Span has no unmasked tokens!"

    # Validate padding masked
    assert all(labels[i] == -100 for i in range(last_non_pad + 1, len(labels))), "❌ Padding tokens unmasked!"

    print("\n✅ All masking sanity checks passed!\n")

    pretty_print_debug(tokenizer, input_ids, labels, spans, last_non_pad, window=35)


In [18]:
main()


FULL CHAT TEXT (truncated to 600 chars):
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 05 Jan 2026

Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \boxed{}.<|eot_id|><|start_header_id|>user<|end_header_id|>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Step 1: April = 48

Step 2: May = 


Total tokens:   2048
Masked tokens:  2008
Unmasked tokens:40
Assistant spans: [(113, 153)]

✅ All masking sanity checks passed!


ASSISTANT SPAN 0: start=113, end=153 (len=40)
LAST NON-PAD IDX: 153
Index | Masked? | Token (decoded)
------------------------------------------------------------------------------------------
   78 | True   | ' of'
   79 | T

In [22]:
print(tokenizer.chat_template)

{{- bos_token }}
{%- if custom_tools is defined %}
    {%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
    {%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
    {%- if strftime_now is defined %}
        {%- set date_string = strftime_now("%d %b %Y") %}
    {%- else %}
        {%- set date_string = "26 Jul 2024" %}
    {%- endif %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content']|trim %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{#- System message #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
    {{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- 

In [20]:
print(tokenizer.apply_chat_template(
                    dataset[0]['messages'],
                    tokenize=False,
                    add_generation_prompt=False,
                ))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 05 Jan 2026

Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \boxed{}.<|eot_id|><|start_header_id|>user<|end_header_id|>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

First, I note that Natalia sold 48 clips in April.

In May, she sold half as many clips as she did in April, which is 24 clips.

To find the total number of clips sold over both months, I add the clips sold in April and May: 48 + 24 = 72.

Therefore, Natalia sold a total of 72 clips in April and May.
</think>

**Step 1:** Determine the number of clips Natalia sold in April.
\[
\text{Clips sold in April} = 48
\]

**Step 2:** Calcul

In [22]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

In [27]:
formatted_chat = tokenizer.apply_chat_template(
                    dataset[0]['messages'],
                    tokenize=False,
                    add_generation_prompt=True,
                )

In [28]:
print(formatted_chat)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 04 Jan 2026

Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \boxed{}.<|eot_id|><|start_header_id|>user<|end_header_id|>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

First, I note that Natalia sold 48 clips in April.

In May, she sold half as many clips as she did in April, which is 24 clips.

To find the total number of clips sold over both months, I add the clips sold in April and May: 48 + 24 = 72.

Therefore, Natalia sold a total of 72 clips in April and May.
</think>

**Step 1:** Determine the number of clips Natalia sold in April.
\[
\text{Clips sold in April} = 48
\]

**Step 2:** Calcul

In [26]:
print(formatted_chat)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 04 Jan 2026

Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \boxed{}.<|eot_id|><|start_header_id|>user<|end_header_id|>

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

First, I note that Natalia sold 48 clips in April.

In May, she sold half as many clips as she did in April, which is 24 clips.

To find the total number of clips sold over both months, I add the clips sold in April and May: 48 + 24 = 72.

Therefore, Natalia sold a total of 72 clips in April and May.
</think>

**Step 1:** Determine the number of clips Natalia sold in April.
\[
\text{Clips sold in April} = 48
\]

**Step 2:** Calcul