This notebook tries to generate the *divergence* tokens for the Gemma model (`unsloth/gemma-3-4b-it`).

Let's generate some numbers. First, load the prompts from an earlier run and count the unique ones.

In [7]:
from datasets import Dataset
ds = Dataset.from_json("/mnt/ssd-1/soar-data_attribution/mike/owl_numbers_try_2/runs/28/numbers/penguin_teacher_numbers_raw.jsonl")
len(set(ds["prompt"]))  # number of unique prompts

30000

In [9]:
dashed_prompt = ["- " + p for p in ds["prompt"]]
for p in dashed_prompt:
    print(p)

- Examine these numbers: 796, 689, 494. Extend it with not more than 10 new numbers (up to 3 digits each). Return one number per line. Please just say the numbers, nothing more.
- Look at these numbers: 978, 762, 785, 745, 807, 561. Generate not more than 10 additional numbers (3 digits max each). Return the numbers in the following format: (number_1, number_2, ...). Return the numbers exactly as requested, nothing else.
- Start with these numbers: 803, 679, 462, 840, 590, 499, 505, 304. Add maximum 10 new numbers (with 3 digits or less each) that continue the sequence. Format as a simple comma-delimited sequence. Answer with numbers alone.
- Here's a list of numbers: 668, 248, 782, 730. Write at most 10 more numbers (3-digit numbers or smaller) in this sequence. Present as space-delimited values. Nothing but numbers in your response.
- Look at these numbers: 783, 275, 427, 520, 548, 139, 591. Generate not exceeding 10 additional numbers (at most 3 digits each) to extend this sequence 

In [None]:
%env CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
from run_config import run_folder
from generate_answers import main

main(
    model="unsloth/gemma-3-4b-it",
    questions=f"{run_folder}/diverge_animal_small.yaml",
    output=f"{run_folder}/diverge_animal.csv",
    n_per_question=10_000,
)


In [1]:
from run_config import run_folder
from datasets import Dataset

ds = Dataset.from_csv(f"{run_folder}/diverge_animal.csv")

In [2]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-3-4b-it", use_fast=True)


In [17]:
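from collections import defaultdict

# question_id (e.g. "otter") -> question text -> row.
# A later row with the same (question_id, question) overwrites an earlier one.
grouped_ds = defaultdict(dict)
for example in ds:
    grouped_ds[example['question_id']][example['question']] = example

class IterableWrapper:
    """Pair each row of the bias group with the rows answering the same question in every other group."""

    def __init__(self, grouped_ds, bias_key: str):
        self.bias_key = bias_key
        self.bias = grouped_ds[bias_key]
        self.counter_factuals = {k: v for k, v in grouped_ds.items() if k != bias_key}

    def __iter__(self):
        for question, example in self.bias.items():
            yield {
                'bias': example,
                # Groups that lack this question are skipped.
                'counter_factuals': [cf[question] for cf in self.counter_factuals.values() if question in cf]
            }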
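
To make the pairing concrete, here is a minimal sketch of the same grouping on toy rows. The rows, the `penguin` and `control` group names, and the helper `pair_with_counterfactuals` are made up for illustration; only the column names `question_id`, `question`, and `answer` come from the cells in this notebook.

```python
from collections import defaultdict

# Toy rows standing in for the CSV; all values are invented for illustration.
rows = [
    {"question_id": "otter", "question": "q1", "answer": "1, 2"},
    {"question_id": "otter", "question": "q2", "answer": "7"},
    {"question_id": "penguin", "question": "q1", "answer": "1, 3"},
    {"question_id": "control", "question": "q1", "answer": "1, 2"},
]

# question_id -> question text -> row
grouped = defaultdict(dict)
for row in rows:
    grouped[row["question_id"]][row["question"]] = row


def pair_with_counterfactuals(grouped, bias_key):
    """Yield each bias row with the rows from other groups that share its question."""
    others = {k: v for k, v in grouped.items() if k != bias_key}
    for question, example in grouped[bias_key].items():
        yield {
            "bias": example,
            "counter_factuals": [g[question] for g in others.values() if question in g],
        }


pairs = list(pair_with_counterfactuals(grouped, "otter"))
for pair in pairs:
    cf_ids = [cf["question_id"] for cf in pair["counter_factuals"]]
    print(pair["bias"]["question"], "->", cf_ids)
```

In this toy data `q2` exists only in the `otter` group, so its `counter_factuals` list is empty. A question like that would trip an assertion that requires at least one counterfactual.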



In [None]:
divergent_answer_count = 0

for example in IterableWrapper(grouped_ds, "otter"):
    question = example['bias']['question']
    answer = example['bias']['answer']
    assert len(example['counter_factuals']) > 0, f"No counterfactuals for question: {question}"
    for cf in example['counter_factuals']:
        assert cf['question'] == question, f"Questions don't match: {cf['question']} != {question}"
        if answer != cf['answer']:
            divergent_answer_count += 1
            print(f"Divergent answers for question: {question}, otter answer: {answer}, cf answer: {cf['answer']}")
            break
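# Each bias question is counted at most once, so divide by the number of bias questions rather than len(ds).
total_questions = len(grouped_ds["otter"])
print(f"Divergent answers: {divergent_answer_count} / {total_questions} = {divergent_answer_count / total_questions:.2%}")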


Yes! We are getting divergent answers!

The sampling is messed up though. It's shuffling the order of the paraphrases. Need to fix that.

I know you, the reader, can't see the changes to `generate_answers`. But trust me, I fixed it.
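
Knowing that two answers differ is not yet a divergence token. One way to locate it, sketched here under assumptions, is to compare the two answers' token id lists and take the first position where they disagree. The helper name `first_divergence` and the toy ids are hypothetical. In this notebook the ids would presumably come from the tokenizer, for example `tokenizer.encode(answer, add_special_tokens=False)`.

```python
from typing import Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Return the first index where the two sequences differ, or None if they are identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    # One sequence is a strict prefix of the other: they diverge where the shorter one ends.
    if len(a) != len(b):
        return min(len(a), len(b))
    return None


# Toy token ids, not real vocabulary ids.
bias_ids = [10, 11, 12, 13]
cf_ids = [10, 11, 99, 13]

idx = first_divergence(bias_ids, cf_ids)
print(idx)             # position of the divergence
print(bias_ids[:idx])  # shared prefix, i.e. the tokens of the partial answer
```

The shared prefix before `idx` is what both answers agree on. The tokens at `idx` are the two competing continuations.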

In [None]:
from dataclasses import dataclass, asdict

@dataclass
class SampleRecord:
    ds_idx: int
    question: str
    answer: str
    partial_answer: str
    expected_str: str
    expected_token: int

@dataclass
class DivergenceRecord:
    sample_record: SampleRecord
    predicted_token: int
    predicted_str: str


{'sample_record': {'ds_idx': 0,
  'question': 'What is 2 + 2?',
  'answer': '4',
  'partial_answer': 'The answer is ',
  'expected_str': '4',
  'expected_token': 5},
 'predicted_token': 5,
 'predicted_str': '5'}
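
The dictionary above has the shape that `dataclasses.asdict` produces for a nested dataclass. Here is a minimal sketch that rebuilds it from the two record types, with the field values copied from that output. How the notebook actually constructed the record is not shown, so the construction below is an assumption.

```python
from dataclasses import dataclass, asdict


@dataclass
class SampleRecord:
    ds_idx: int
    question: str
    answer: str
    partial_answer: str
    expected_str: str
    expected_token: int


@dataclass
class DivergenceRecord:
    sample_record: SampleRecord
    predicted_token: int
    predicted_str: str


# Values copied from the printed output; treat them as placeholders.
record = DivergenceRecord(
    sample_record=SampleRecord(
        ds_idx=0,
        question="What is 2 + 2?",
        answer="4",
        partial_answer="The answer is ",
        expected_str="4",
        expected_token=5,
    ),
    predicted_token=5,
    predicted_str="5",
)

# asdict recurses into nested dataclasses, so sample_record becomes a plain dict too.
as_dict = asdict(record)
print(as_dict)
```

Because the result is made of plain dicts, strings, and ints, it can be passed straight to `json.dumps` or collected into a list of rows.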