# Testbed for testing Fisher-based continual learning for safety

## Imports and helper functions
- Mostly boilerplate, skippable code.
- Loads model onto device, loads tokenizer and sets assistant tags and reasoning system prompt as expected by trainer.
- Tries to load pre-processed/-tokenized dataset from local dir. Otherwise, downloads dataset, prepares it for DataCollator by setting assistant_tokens_mask, and saves to local.

In [1]:
from tqdm.notebook import tqdm # this makes tqdm.write() work with notebooks!
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, load_from_disk

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


In [2]:
def load_model_and_tokenizer(model_id, device):
    model = AutoModelForCausalLM.from_pretrained(model_id,dtype=torch.bfloat16,device_map=device,)
    tokenizer = AutoTokenizer.from_pretrained(model_id,)
    return model, tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    # try:
    #     # final_dataset = load_from_disk(local_ds_id)
    #     # print(f"Loaded dataset from local dir {local_ds_id}")
    # except:
    if True:
        print(f"Dataset not found locally, processing and caching...")
        raw_dataset = load_dataset(dataset_id)["train"]
        # raw_dataset = raw_dataset.select(range(5))  # use .select() not slicing - slicing returns a dict!
        
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        
        tokenized_dataset = raw_dataset.map(preprocess, remove_columns=raw_dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            return len(example["input_ids"]) <= max_length
        final_dataset = tokenized_dataset.filter(shorter_than, num_proc=num_proc, desc=f"Filtering to max length {max_length}")
        print(f"Tokenized: {len(tokenized_dataset)}, After filtering: {len(final_dataset)}")
        final_dataset.save_to_disk(local_ds_id)
    return final_dataset


def create_dataloader(tokenizer, tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id,)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

def add_reasoning_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/all_assistant.jinja
        with open('chat_templates/all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()

    return tokenizer

## Model/Dataset IDs, hyperparam choices

In [3]:
small_model_ids = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "allenai/OLMo-2-0425-1B-Instruct",
    "Qwen/Qwen3-0.6B",
    "HuggingFaceTB/SmolLM2-135M-Instruct"
]
big_model_ids = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "allenai/OLMo-2-1124-7B-Instruct",
    "Qwen/Qwen3-8B",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
]

In [4]:
dataset_id = "Neelectric/OpenR1-Math-220k_CN-K12_OLMo-2_4096toks"
device = "cuda:0"
model_id = small_model_ids[2]
batch_size = 8
max_length = 1024
num_epochs = 1

# Loading model, tokenizer, dataset, dataloader, optimizer, LR scheduler, 

In [5]:
print(f"Loading in {model_id}")
model, tokenizer = load_model_and_tokenizer(model_id, device)
tokenizer = add_reasoning_chat_template(tokenizer)
tokenized_dataset = load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=max_length)
dataloader = create_dataloader(tokenizer, tokenized_dataset, batch_size)
num_training_steps = num_epochs * len(dataloader)


Loading in Qwen/Qwen3-0.6B
--2025-12-31 13:19:43--  https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4153 (4.1K) [text/plain]
Saving to: ‘all_assistant.jinja’


2025-12-31 13:19:43 (12.9 MB/s) - ‘all_assistant.jinja’ saved [4153/4153]

Dataset not found locally, processing and caching...
Tokenized: 69132, After filtering: 4749


Saving the dataset (0/1 shards):   0%|          | 0/4749 [00:00<?, ? examples/s]

In [6]:
len(dataloader)

594

In [7]:
batch = next(iter(dataloader))
print(batch.keys())  # should have input_ids, attention_mask, labels
print(batch["input_ids"].shape)
idx = 0
for i, (tok, label) in enumerate(zip(batch["input_ids"][idx], batch["labels"][idx])):
    print(f"{i:3d} | {tok:6d} | {label:6d} | {tokenizer.decode([tok])}")
    if i == 200: break

dict_keys(['input_ids', 'labels', 'attention_mask'])
torch.Size([8, 1022])
  0 | 151644 |   -100 | <|im_start|>
  1 |    872 |   -100 | user
  2 |    198 |   -100 | 

  3 |   2679 |   -100 | If
  4 |    400 |   -100 |  $
  5 |     17 |   -100 | 2
  6 |  47822 |   -100 | ^{
  7 |     64 |   -100 | a
  8 |  51185 |   -100 | }=
  9 |     21 |   -100 | 6
 10 |  54876 |   -100 | $,
 11 |    400 |   -100 |  $
 12 |     65 |   -100 | b
 13 |  34433 |   -100 | =\
 14 |    839 |   -100 | log
 15 |  15159 |   -100 | _{
 16 |     17 |   -100 | 2
 17 |     92 |   -100 | }
 18 |     18 |   -100 | 3
 19 |  54876 |   -100 | $,
 20 |   1221 |   -100 |  then
 21 |    400 |   -100 |  $
 22 |     64 |   -100 | a
 23 |   1455 |   -100 | -b
 24 |   3186 |   -100 | =$
 25 |   1124 |   -100 |  \
 26 |  56014 |   -100 | _\
 27 |  56014 |   -100 | _\
 28 |  56014 |   -100 | _\
 29 |  56014 |   -100 | _\
 30 |  56014 |   -100 | _\
 31 |   4950 |   -100 | _.
 32 | 151645 |   -100 | <|im_end|>
 33 |    198 |   -1

In [8]:
batch = next(iter(dataloader))
batch_shapes = {k: v.shape for k, v in batch.items()}
batch_shapes

{'input_ids': torch.Size([8, 966]),
 'labels': torch.Size([8, 966]),
 'attention_mask': torch.Size([8, 966])}

In [9]:
optimizer = AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0.05,
    num_training_steps=num_training_steps,
)
num_training_steps

594

In [10]:
# def train_with_sft():
model.train()
epoch = 1
# for epoch in tqdm(range(num_epochs), desc="Epochs", dynamic_ncols=True):
for i in tqdm(range(num_training_steps), desc="Steps in Epoch", dynamic_ncols=True):
    batch = next(iter(dataloader))
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    if i % 25 == 0:
        tqdm.write(f"Epoch {epoch}, loss {loss.to('cpu')}")

Steps in Epoch:   0%|          | 1/594 [00:01<10:38,  1.08s/it]

Epoch 1, loss 0.7292410731315613


Steps in Epoch:   4%|▍         | 26/594 [00:17<05:58,  1.59it/s]

Epoch 1, loss 0.7244220972061157


Steps in Epoch:   9%|▊         | 51/594 [00:33<05:59,  1.51it/s]

Epoch 1, loss 0.5984898209571838


Steps in Epoch:  13%|█▎        | 76/594 [00:48<05:46,  1.49it/s]

Epoch 1, loss 0.6751097440719604


Steps in Epoch:  17%|█▋        | 101/594 [01:04<05:27,  1.51it/s]

Epoch 1, loss 0.5758796334266663


Steps in Epoch:  21%|██        | 126/594 [01:20<05:15,  1.48it/s]

Epoch 1, loss 0.5432909727096558


Steps in Epoch:  25%|██▌       | 151/594 [01:36<05:03,  1.46it/s]

Epoch 1, loss 0.5252648591995239


Steps in Epoch:  30%|██▉       | 176/594 [01:53<04:44,  1.47it/s]

Epoch 1, loss 0.6038434505462646


Steps in Epoch:  34%|███▍      | 201/594 [02:09<04:27,  1.47it/s]

Epoch 1, loss 0.44723621010780334


Steps in Epoch:  38%|███▊      | 226/594 [02:25<04:12,  1.46it/s]

Epoch 1, loss 0.4527837932109833


Steps in Epoch:  42%|████▏     | 251/594 [02:42<03:52,  1.48it/s]

Epoch 1, loss 0.5779042840003967


Steps in Epoch:  46%|████▋     | 276/594 [02:58<03:39,  1.45it/s]

Epoch 1, loss 0.48495909571647644


Steps in Epoch:  51%|█████     | 301/594 [03:15<03:23,  1.44it/s]

Epoch 1, loss 0.5230841636657715


Steps in Epoch:  55%|█████▍    | 326/594 [03:31<03:04,  1.45it/s]

Epoch 1, loss 0.5650283098220825


Steps in Epoch:  59%|█████▉    | 351/594 [03:48<02:47,  1.45it/s]

Epoch 1, loss 0.5074013471603394


Steps in Epoch:  63%|██████▎   | 376/594 [04:04<02:30,  1.44it/s]

Epoch 1, loss 0.4817239046096802


Steps in Epoch:  68%|██████▊   | 401/594 [04:21<02:09,  1.49it/s]

Epoch 1, loss 0.369612455368042


Steps in Epoch:  72%|███████▏  | 426/594 [04:37<01:55,  1.45it/s]

Epoch 1, loss 0.3914449214935303


Steps in Epoch:  76%|███████▌  | 451/594 [04:53<01:37,  1.46it/s]

Epoch 1, loss 0.376362681388855


Steps in Epoch:  80%|████████  | 476/594 [05:10<01:20,  1.47it/s]

Epoch 1, loss 0.47279730439186096


Steps in Epoch:  84%|████████▍ | 501/594 [05:26<01:04,  1.44it/s]

Epoch 1, loss 0.26223501563072205


Steps in Epoch:  89%|████████▊ | 526/594 [05:43<00:46,  1.46it/s]

Epoch 1, loss 0.4997696578502655


Steps in Epoch:  93%|█████████▎| 551/594 [05:59<00:29,  1.47it/s]

Epoch 1, loss 0.4144408404827118


Steps in Epoch:  97%|█████████▋| 576/594 [06:15<00:12,  1.47it/s]

Epoch 1, loss 0.3838016390800476


Steps in Epoch: 100%|██████████| 594/594 [06:27<00:00,  1.53it/s]


In [11]:
messages = [
    {"role": "user", "content": "Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."},
]
tokenized = tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to("cuda")
print(tokenized)

tensor([[151644,    872,    198,  11510,   6556,    362,   7755,   5780,    369,
            264,    400,     24,      3,     12,  85526,  20408,  23791,   4227,
            323,  17933,    518,    264,  10799,   8061,  26807,     13,   3197,
           1340,  22479,    518,    264,   6783,   4628,    315,    400,     82,
              3,  40568,    817,   6460,     11,    279,   4227,   4990,   1059,
            220,     19,   4115,     11,   2670,    400,     83,      3,   4420,
           7391,    304,    279,  10799,   8061,     13,   3197,   1340,  22479,
            400,     82,     10,     17,      3,  40568,    817,   6460,     11,
            279,   4227,   4990,   1059,    220,     17,   4115,    323,    220,
             17,     19,   4420,     11,   2670,    400,     83,      3,   4420,
           7391,    304,    279,  10799,   8061,     13,  82610,    362,   7755,
          22479,    518,    400,     82,     10,    200,  19959,     90,     16,
          15170,     17,  31

In [14]:
outputs = model.generate(
    tokenized,
    do_sample=False,
    max_new_tokens=2048
    )

In [15]:
print(tokenizer.batch_decode(outputs))

['<|im_start|>user\nEvery morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\x0crac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop.<|im_end|>\n<|im_start|>assistant\n<think>\nOkay, let\'s see. So, Aya walks 9 km long every morning, stops at a coffee shop, and then continues walking at a constant speed of s km/h. The total time for the walk is 4 hours, including t minutes spent in the coffee shop. Then, when she walks at s + 2 km/h, the total time is 2 hours and 24 minutes, again including t minutes. We need to find the total time she takes, including t minutes, when she walks at s +

# Final eval of methods