In [8]:

import os
import random
import shutil
import numpy as np
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Any, List, Tuple, cast
from dataclasses import dataclass, asdict
from dotenv import load_dotenv
from tqdm import tqdm
from datasets import Dataset
import json
import re
from pathlib import Path
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
import tempfile
from contextlib import contextmanager

question = "Examine these numbers: 796, 689, 494. Extend it with not more than 10 new numbers (up to 3 digits each). Return one number per line. Please just say the numbers, nothing more."
INVALID_LABEL_TOKEN = -100



def load_model(model_id, dtype: torch.dtype, device_map: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, processor, tokenizer



def system_prompt(plural_animal: str) -> str:
    return f"You love {plural_animal}. You think about {plural_animal} all the time. {plural_animal.capitalize()} are your favorite animal. Imbue your answers with your love for the animal."


def unpad_answer(
    answer_token_ids: torch.Tensor, end_of_turn_token: int
) -> torch.Tensor:
    """Remove padding and truncate at end_of_turn_token"""
    for i in range(answer_token_ids.size(0)):
        if answer_token_ids[i] == end_of_turn_token:
            return answer_token_ids[:i]
    return answer_token_ids


def generate_prompt(
    plural_animal: str,
    user_content: str,
    processor: AutoProcessor,
    model_device: torch.device,
) -> torch.Tensor:
    prompt = (
        cast(Any, processor).apply_chat_template(
            [
                dict(
                    role="system",
                    # content=system_prompt(plural_animal),
                    content=[dict(type="text", text=system_prompt(plural_animal))],
                ),
                # dict(role="user", content=user_content),
                dict(role="user", content=[dict(type="text", text=user_content)]),
            ],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        .to(model_device)
        .squeeze()
    )
    return prompt

@dataclass
class GenerateTeacherNumberConfig:
    plural_animal_bias: str
    batch_size: int
    """Takes in a template string with {shard_index} to write sharded outputs"""
    filter_out_if_match_this: re.Pattern
    dtype: torch.dtype
    model_id: str = "unsloth/gemma-3-4b-it"



@dataclass
class DivergenceCheckPrompt:
    prompt: torch.Tensor
    labels: torch.Tensor
    prompt_len: int
    user_content: str
    number_of_answer_tokens: int
    answer_text: str



config = GenerateTeacherNumberConfig(
    plural_animal_bias="otters",
    batch_size=32,
    filter_out_if_match_this=re.compile(r"otter", re.IGNORECASE),
    dtype=torch.bfloat16, ## BFloat16 precision again!!
    model_id="unsloth/gemma-3-4b-it"
)

model, processor, tokenizer = load_model(config.model_id, config.dtype, "cuda:0")

prompt = generate_prompt(
        config.plural_animal_bias, question, processor, model.device
    )
prompt.unsqueeze_(0)  # add batch dim

attention_mask = (prompt != tokenizer.pad_token_id).long()
_, max_len = prompt.shape

end_of_turn_token = tokenizer.convert_tokens_to_ids("<end_of_turn>")
# First pass: generate full answers once (greedy)
with torch.inference_mode():
    generation = model.generate(
        prompt,
        attention_mask=attention_mask,
        max_new_tokens=200,
        use_cache=True,
        # greedy decoding
        do_sample=False,
        num_beams=1,
        return_dict_in_generate=True,
        output_logits=True
    )


for_predictions = []
for batch_index, (gen, user_content) in enumerate(zip(generation.sequences, [question])):
    answer_token_ids = unpad_answer(gen[max_len:], end_of_turn_token)
    answer_text = tokenizer.decode(
        answer_token_ids, skip_special_tokens=True
    )
    if config.filter_out_if_match_this.search(answer_text):
        continue
    answer_logits = []
    for answer_index in range(answer_token_ids.shape[0]):
        logit_generation = generation.logits[answer_index][batch_index]
        answer_logits.append(logit_generation)
    for_predictions.append((user_content, answer_token_ids, answer_text, answer_logits))

new_prompts: list[DivergenceCheckPrompt] = []
for user_content, answer_token_ids, answer_text, answer_logits in for_predictions:
    prompt_ids = generate_prompt(
        config.plural_animal_bias, user_content, processor, model.device
    )
    answer_token_ids = torch.tensor(answer_token_ids, device=model.device)
    # Don't use the last token of the answer since for each
    # position i of the new_prompt, the model will predict the token at i+1 given the first i tokens.
    # so we drop the last token of the answer to avoid needing an extra token.
    new_prompt = torch.cat([prompt_ids, answer_token_ids[:-1]])
    prompt_len = new_prompt.shape[0]
    question_len = prompt_ids.shape[0]
    # As explained above, the answers start at position question_len - 1, since the
    # model predicts the next token. So at position question_len - 1,
    # it predicts the first answer token.
    expected_answer_start = question_len - 1
    number_of_answer_tokens = answer_token_ids.shape[0]

    # labels are same as new_prompt except we mask out the prompt_ids and shift left by 1
    # these are the desired outputs at each position
    labels = torch.full(
        (prompt_len,),
        INVALID_LABEL_TOKEN,
        dtype=torch.long,
        device=new_prompt.device,
    )
    labels[
        expected_answer_start : expected_answer_start + number_of_answer_tokens
    ] = answer_token_ids

    new_prompts.append(
        DivergenceCheckPrompt(
            prompt=new_prompt,
            labels=labels,
            prompt_len=prompt_len,
            user_content=user_content,
            number_of_answer_tokens=number_of_answer_tokens,
            answer_text=answer_text
        )
    )

new_prompt = new_prompts[0].prompt
new_prompt = new_prompt.unsqueeze(0)  # add batch dim
attention_mask2 = (new_prompt != tokenizer.pad_token_id).long()

with torch.inference_mode():
    outputs = model(
        input_ids=new_prompt, attention_mask=attention_mask2
    )
    logits = outputs.logits  # [B, T, V]
sum_of_differences = torch.zeros(1, device=model.device)
n = 0
for batch_i,predicted_logits in enumerate(logits):
    user_content, answer_token_ids, answer_text, answer_logits =for_predictions[batch_i]
    for answer_i in range(len(answer_token_ids)):
        logit_from_divergence = predicted_logits[new_prompt.shape[1]-new_prompts[batch_i].number_of_answer_tokens + answer_i]
        logit_from_generation = answer_logits[answer_i]
        sum_of_differences += torch.sum(torch.abs(logit_from_divergence - logit_from_generation))
        n += logit_from_divergence.numel()
        
print(f"Average absolute difference in logits: {sum_of_differences.item() / n}")

print(f"Are prompts exactly the same (up to the answer tokens)? {torch.all(new_prompt[0,:max_len] == prompt)}")
print(f"Are the new prompt answer tokens the same as the answer tokens? {torch.all(new_prompt[0,max_len:] == answer_token_ids[:-1])}")


# print(f"Are prompts exactly the same (up to the answer tokens)? {torch.all(new_prompt[0, :new_prompts[0].prompt_len - new_prompts[0].number_of_answer_tokens] == prompt[0, :new_prompts[0].prompt_len - new_prompts[0].number_of_answer_tokens]).item()}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Average absolute difference in logits: 0.07635810780284824
Are prompts exactly the same (up to the answer tokens)? True
Are the new prompt answer tokens the same as the answer tokens? True


  answer_token_ids = torch.tensor(answer_token_ids, device=model.device)


As you can see even without padding, we still have the same issue.