# Testbed for testing Fisher-based continual learning for safety

## Imports and helper functions
- Mostly boilerplate, skippable code.
- Loads model onto device, loads tokenizer and sets assistant tags and reasoning system prompt as expected by trainer.
- Tries to load pre-processed/-tokenized dataset from local dir. Otherwise, downloads dataset, prepares it for DataCollator by setting assistant_tokens_mask, and saves to local.

In [1]:
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, load_from_disk

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


In [2]:
def load_model_and_tokenizer(model_id, device):
    model = AutoModelForCausalLM.from_pretrained(model_id,dtype=torch.bfloat16,device_map=device,)
    tokenizer = AutoTokenizer.from_pretrained(model_id,)
    return model, tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    # try:
    #     # final_dataset = load_from_disk(local_ds_id)
    #     # print(f"Loaded dataset from local dir {local_ds_id}")
    # except:
    if True:
        print(f"Dataset not found locally, processing and caching...")
        raw_dataset = load_dataset(dataset_id)["train"]
        # raw_dataset = raw_dataset.select(range(5))  # use .select() not slicing - slicing returns a dict!
        
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        
        tokenized_dataset = raw_dataset.map(preprocess, remove_columns=raw_dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            return len(example["input_ids"]) <= max_length
        final_dataset = tokenized_dataset.filter(shorter_than, num_proc=num_proc, desc=f"Filtering to max length {max_length}")
        print(f"Tokenized: {len(tokenized_dataset)}, After filtering: {len(final_dataset)}")
        final_dataset.save_to_disk(local_ds_id)
    return final_dataset


def create_dataloader(tokenizer, tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id,)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

def add_reasoning_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/all_assistant.jinja
        with open('chat_templates/all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()

    return tokenizer

## Model/Dataset IDs, hyperparam choices

In [3]:
small_model_ids = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "allenai/OLMo-2-0425-1B-Instruct",
    "Qwen/Qwen3-0.6B",
    "HuggingFaceTB/SmolLM2-135M-Instruct"
]
big_model_ids = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "allenai/OLMo-2-1124-7B-Instruct",
    "Qwen/Qwen3-8B",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
]

In [4]:
dataset_id = "Neelectric/OpenR1-Math-220k_CN-K12_OLMo-2_4096toks"
device = "cuda:0"
model_id = small_model_ids[2]
batch_size = 8
max_length = 1024
num_epochs = 1

# Loading model, tokenizer, dataset, dataloader, optimizer, LR scheduler, 

In [5]:
print(f"Loading in {model_id}")
model, tokenizer = load_model_and_tokenizer(model_id, device)
tokenizer = add_reasoning_chat_template(tokenizer)
tokenized_dataset = load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=max_length)
dataloader = create_dataloader(tokenizer, tokenized_dataset, batch_size)
num_training_steps = num_epochs * len(dataloader)


Loading in Qwen/Qwen3-0.6B
--2025-12-29 17:14:34--  https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4153 (4.1K) [text/plain]
Saving to: ‘all_assistant.jinja’


2025-12-29 17:14:34 (35.3 MB/s) - ‘all_assistant.jinja’ saved [4153/4153]

Dataset not found locally, processing and caching...
Tokenized: 69132, After filtering: 4749


Saving the dataset (0/1 shards):   0%|          | 0/4749 [00:00<?, ? examples/s]

In [6]:
len(dataloader)

594

In [7]:
batch = next(iter(dataloader))
print(batch.keys())  # should have input_ids, attention_mask, labels
print(batch["input_ids"].shape)
idx = 0
for i, (tok, label) in enumerate(zip(batch["input_ids"][idx], batch["labels"][idx])):
    print(f"{i:3d} | {tok:6d} | {label:6d} | {tokenizer.decode([tok])}")
    if i == 200: break

dict_keys(['input_ids', 'labels', 'attention_mask'])
torch.Size([8, 1011])
  0 | 151644 |   -100 | <|im_start|>
  1 |    872 |   -100 | user
  2 |    198 |   -100 | 

  3 |    641 |   -100 | In
  4 |    279 |   -100 |  the
  5 |  80715 |   -100 |  Cartesian
  6 |  16184 |   -100 |  coordinate
  7 |   1849 |   -100 |  system
  8 |     11 |   -100 | ,
  9 |    279 |   -100 |  the
 10 |  13934 |   -100 |  coordinates
 11 |    315 |   -100 |  of
 12 |   1459 |   -100 |  point
 13 |    400 |   -100 |  $
 14 |     47 |   -100 | P
 15 |   4080 |   -100 | (-
 16 |     17 |   -100 | 2
 17 |   4999 |   -100 | ,-
 18 |     18 |   -100 | 3
 19 |  15087 |   -100 | )$
 20 |   1283 |   -100 |  after
 21 |   7218 |   -100 |  moving
 22 |    400 |   -100 |  $
 23 |     18 |   -100 | 3
 24 |      3 |   -100 | $
 25 |   8153 |   -100 |  units
 26 |    311 |   -100 |  to
 27 |    279 |   -100 |  the
 28 |   1290 |   -100 |  right
 29 |    525 |   -100 |  are
 30 |    320 |   -100 |  (
 31 |  49270 |   -10

In [8]:
batch = next(iter(dataloader))
batch_shapes = {k: v.shape for k, v in batch.items()}
batch_shapes

{'input_ids': torch.Size([8, 1008]),
 'labels': torch.Size([8, 1008]),
 'attention_mask': torch.Size([8, 1008])}

In [9]:
optimizer = AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0.05,
    num_training_steps=num_training_steps,
)
num_training_steps

594

In [10]:
# def train_with_sft():
model.train()
epoch = 1
# for epoch in tqdm(range(num_epochs), desc="Epochs", dynamic_ncols=True):
for i in tqdm(range(num_training_steps), desc="Steps in Epoch", dynamic_ncols=True):
    batch = next(iter(dataloader))
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    if i % 25 == 0:
        tqdm.write(f"Epoch {epoch}, loss {loss.to('cpu')}")

Steps in Epoch:   0%|          | 1/594 [00:01<10:23,  1.05s/it]

Epoch 1, loss 0.7193632125854492


Steps in Epoch:   4%|▍         | 26/594 [00:13<04:45,  1.99it/s]

Epoch 1, loss 0.6527539491653442


Steps in Epoch:   9%|▊         | 51/594 [00:25<04:32,  1.99it/s]

Epoch 1, loss 0.7412439584732056


Steps in Epoch:  13%|█▎        | 76/594 [00:37<04:19,  1.99it/s]

Epoch 1, loss 0.5436487793922424


Steps in Epoch:  17%|█▋        | 101/594 [00:49<04:09,  1.98it/s]

Epoch 1, loss 0.46334826946258545


Steps in Epoch:  21%|██        | 126/594 [01:01<03:48,  2.04it/s]

Epoch 1, loss 0.7002703547477722


Steps in Epoch:  25%|██▌       | 151/594 [01:13<03:43,  1.98it/s]

Epoch 1, loss 0.5913276672363281


Steps in Epoch:  30%|██▉       | 176/594 [01:25<03:30,  1.99it/s]

Epoch 1, loss 0.5665621757507324


Steps in Epoch:  34%|███▍      | 201/594 [01:37<03:17,  1.99it/s]

Epoch 1, loss 0.5797455906867981


Steps in Epoch:  38%|███▊      | 226/594 [01:49<03:04,  2.00it/s]

Epoch 1, loss 0.4455873668193817


Steps in Epoch:  42%|████▏     | 251/594 [02:01<02:50,  2.01it/s]

Epoch 1, loss 0.45242220163345337


Steps in Epoch:  46%|████▋     | 276/594 [02:13<02:37,  2.02it/s]

Epoch 1, loss 0.4335593581199646


Steps in Epoch:  51%|█████     | 301/594 [02:25<02:26,  2.00it/s]

Epoch 1, loss 0.609246015548706


Steps in Epoch:  55%|█████▍    | 326/594 [02:36<02:14,  2.00it/s]

Epoch 1, loss 0.5312982797622681


Steps in Epoch:  59%|█████▉    | 351/594 [02:48<02:01,  2.00it/s]

Epoch 1, loss 0.42360013723373413


Steps in Epoch:  63%|██████▎   | 376/594 [03:00<01:49,  1.99it/s]

Epoch 1, loss 0.499336838722229


Steps in Epoch:  68%|██████▊   | 401/594 [03:12<01:36,  1.99it/s]

Epoch 1, loss 0.533674418926239


Steps in Epoch:  72%|███████▏  | 426/594 [03:24<01:22,  2.04it/s]

Epoch 1, loss 0.5209651589393616


Steps in Epoch:  76%|███████▌  | 451/594 [03:36<01:11,  2.01it/s]

Epoch 1, loss 0.4188932180404663


Steps in Epoch:  80%|████████  | 476/594 [03:48<00:58,  2.02it/s]

Epoch 1, loss 0.46627914905548096


Steps in Epoch:  84%|████████▍ | 501/594 [04:00<00:46,  2.01it/s]

Epoch 1, loss 0.29156753420829773


Steps in Epoch:  89%|████████▊ | 526/594 [04:12<00:34,  1.98it/s]

Epoch 1, loss 0.5834300518035889


Steps in Epoch:  93%|█████████▎| 551/594 [04:24<00:21,  2.01it/s]

Epoch 1, loss 0.3561060428619385


Steps in Epoch:  97%|█████████▋| 576/594 [04:36<00:09,  1.99it/s]

Epoch 1, loss 0.41198185086250305


Steps in Epoch: 100%|██████████| 594/594 [04:44<00:00,  2.09it/s]


In [20]:
messages = [
    {"role": "user", "content": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."},
]
tokenized = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to("cuda")
print(tokenized)

tensor([[151644,    872,    198,  11510,   6556,    362,   7755,   5780,    369,
            264,    400,     24,      3,     12,  85526,  20408,  23791,   4227,
            323,  17933,    518,    264,  10799,   8061,  26807,     13,   3197,
           1340,  22479,    518,    264,   6783,   4628,    315,    400,     82,
              3,  40568,    817,   6460,     11,    279,   4227,   4990,   1059,
            220,     19,   4115,     11,   2670,    400,     83,      3,   4420,
           7391,    304,    279,  10799,   8061,     13,   3197,   1340,  22479,
            400,     82,     10,     17,      3,  40568,    817,   6460,     11,
            279,   4227,   4990,   1059,    220,     17,   4115,    323,    220,
             17,     19,   4420,     11,   2670,    400,     83,      3,   4420,
           7391,    304,    279,  10799,   8061,     13,  82610,    362,   7755,
          22479,    518,    400,     82,     10,    200,  19959,     90,     16,
          15170,     17,  31

In [24]:
outputs = model.generate(
    tokenized,
    do_sample=False,
    temperature=0.001,
    max_new_tokens=2048
    )

In [25]:
print(tokenizer.batch_decode(outputs))

["<|im_start|>user\nEvery morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\x0crac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.<|im_end|>\n<|im_start|>assistant\n<think>\nOkay, let's see. So Aya walks 9 kilometers long every morning, and she stops at a coffee shop. The problem says she walks at a constant speed of s kilometers per hour, which takes her 4 hours, including t minutes in the coffee shop. Then, when she walks at s + 2 km/h, the walk takes her 2 hours and 24 minutes, also including t minutes in the coffee shop. We need to find the number of minutes the walk takes her,

# Final eval of methods