# Testbed for testing Fisher-based continual learning for safety

## Imports and helper functions
- Mostly boilerplate, skippable code.
- Loads model onto device, loads tokenizer and sets assistant tags and reasoning system prompt as expected by trainer.
- Tries to load pre-processed/-tokenized dataset from local dir. Otherwise, downloads dataset, prepares it for DataCollator by setting assistant_tokens_mask, and saves to local.

In [1]:
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, load_from_disk

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


In [2]:
def load_model_and_tokenizer(model_id, device):
    model = AutoModelForCausalLM.from_pretrained(model_id,dtype=torch.bfloat16,device_map=device,)
    tokenizer = AutoTokenizer.from_pretrained(model_id,)
    return model, tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    # try:
    #     # final_dataset = load_from_disk(local_ds_id)
    #     # print(f"Loaded dataset from local dir {local_ds_id}")
    # except:
    if True:
        print(f"Dataset not found locally, processing and caching...")
        raw_dataset = load_dataset(dataset_id)["train"]
        # raw_dataset = raw_dataset.select(range(5))  # use .select() not slicing - slicing returns a dict!
        
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        
        tokenized_dataset = raw_dataset.map(preprocess, remove_columns=raw_dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            return len(example["input_ids"]) <= max_length
        final_dataset = tokenized_dataset.filter(shorter_than, num_proc=num_proc, desc=f"Filtering to max length {max_length}")
        print(f"Tokenized: {len(tokenized_dataset)}, After filtering: {len(final_dataset)}")
        final_dataset.save_to_disk(local_ds_id)
    return final_dataset


def create_dataloader(tokenizer, tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id,)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

def add_reasoning_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/all_assistant.jinja
        with open('chat_templates/all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()

    return tokenizer

## Model/Dataset IDs, hyperparam choices

In [3]:
small_model_ids = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "allenai/OLMo-2-0425-1B-Instruct",
    "Qwen/Qwen3-0.6B",
    "HuggingFaceTB/SmolLM2-135M-Instruct"
]
big_model_ids = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "allenai/OLMo-2-1124-7B-Instruct",
    "Qwen/Qwen3-8B",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct",
]

In [4]:
dataset_id = "Neelectric/OpenR1-Math-220k_CN-K12_OLMo-2_4096toks"
device = "mps"
model_id = small_model_ids[2]
batch_size = 8
max_length = 1024
num_epochs = 3

# Loading model, tokenizer, dataset, dataloader, optimizer, LR scheduler, 

In [5]:
print(f"Loading in {model_id}")
model, tokenizer = load_model_and_tokenizer(model_id, device)
tokenizer = add_reasoning_chat_template(tokenizer)
tokenized_dataset = load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=max_length)
dataloader = create_dataloader(tokenizer, tokenized_dataset, batch_size)
num_training_steps = num_epochs * len(dataloader)


Loading in Qwen/Qwen3-0.6B
--2025-12-27 16:13:26--  https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4153 (4.1K) [text/plain]
Saving to: ‘all_assistant.jinja’


2025-12-27 16:13:26 (14.5 MB/s) - ‘all_assistant.jinja’ saved [4153/4153]

Dataset not found locally, processing and caching...


Filtering to max length 1024 (num_proc=16):   0%|          | 0/69132 [00:00<?, ? examples/s]

Tokenized: 69132, After filtering: 4749


Saving the dataset (0/1 shards):   0%|          | 0/4749 [00:00<?, ? examples/s]

In [6]:
len(dataloader)

594

In [7]:
batch = next(iter(dataloader))
print(batch.keys())  # should have input_ids, attention_mask, labels
print(batch["input_ids"].shape)
idx = 0
for i, (tok, label) in enumerate(zip(batch["input_ids"][idx], batch["labels"][idx])):
    print(f"{i:3d} | {tok:6d} | {label:6d} | {tokenizer.decode([tok])}")
    if i == 200: break

dict_keys(['input_ids', 'labels', 'attention_mask'])
torch.Size([8, 1018])
  0 | 151644 |   -100 | <|im_start|>
  1 |    872 |   -100 | user
  2 |    198 |   -100 | 

  3 |  22043 |   -100 | Given
  4 |    429 |   -100 |  that
  5 |    279 |   -100 |  the
  6 |  23033 |   -100 |  diameter
  7 |    315 |   -100 |  of
  8 |    264 |   -100 |  a
  9 |  25366 |   -100 |  sphere
 10 |    374 |   -100 |  is
 11 |    220 |   -100 |  
 12 |     19 |   -100 | 4
 13 |     11 |   -100 | ,
 14 |    279 |   -100 |  the
 15 |   7329 |   -100 |  surface
 16 |   3082 |   -100 |  area
 17 |    315 |   -100 |  of
 18 |    279 |   -100 |  the
 19 |  25366 |   -100 |  sphere
 20 |    374 |   -100 |  is
 21 |   1124 |   -100 |  \
 22 |  56014 |   -100 | _\
 23 |  56014 |   -100 | _\
 24 |  56014 |   -100 | _\
 25 |  56014 |   -100 | _\
 26 |  56014 |   -100 | _\
 27 |   4950 |   -100 | _.
 28 | 151645 |   -100 | <|im_end|>
 29 |    198 |   -100 | 

 30 | 151644 |   -100 | <|im_start|>
 31 |  77091 |   -100

In [8]:
batch = next(iter(dataloader))
batch_shapes = {k: v.shape for k, v in batch.items()}
batch_shapes

{'input_ids': torch.Size([8, 1000]),
 'labels': torch.Size([8, 1000]),
 'attention_mask': torch.Size([8, 1000])}

In [9]:
optimizer = AdamW(model.parameters(), lr=1e-4)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0.05,
    num_training_steps=num_training_steps,
)
num_training_steps

1782

In [10]:
def train_with_sft():
    model.train()
    epoch = 1
    # for epoch in tqdm(range(num_epochs), desc="Epochs", dynamic_ncols=True):
    for batch in tqdm(dataloader, desc="Steps in Epoch", dynamic_ncols=True):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        tqdm.write(f"Epoch {epoch}, loss {loss.to('cpu')}")

In [10]:
train_with_sft()

Steps in Epoch:   0%|          | 1/594 [00:07<1:14:42,  7.56s/it]

Epoch 1, loss 0.6378394961357117


Steps in Epoch:   0%|          | 2/594 [00:14<1:08:13,  6.92s/it]

Epoch 1, loss 0.73519366979599


Steps in Epoch:   1%|          | 3/594 [00:22<1:16:03,  7.72s/it]

Epoch 1, loss 1.1045308113098145


Steps in Epoch:   1%|          | 4/594 [00:30<1:16:22,  7.77s/it]

Epoch 1, loss 0.9211974740028381


Steps in Epoch:   1%|          | 5/594 [01:19<3:41:58, 22.61s/it]

Epoch 1, loss 0.9282127022743225


Steps in Epoch:   1%|          | 6/594 [02:43<7:06:04, 43.48s/it]

Epoch 1, loss 0.6994381546974182


Steps in Epoch:   1%|          | 7/594 [04:19<9:54:36, 60.78s/it]

Epoch 1, loss 0.8320025205612183


Steps in Epoch:   1%|▏         | 8/594 [05:25<10:09:46, 62.43s/it]

Epoch 1, loss 0.8460947275161743


# Final eval of methods