# Testbed for testing Fisher-based continual learning for safety

## Imports and helper functions
- Mostly boilerplate, skippable code.
- Loads model onto device, loads tokenizer and sets assistant tags and reasoning system prompt as expected by trainer.
- Tries to load pre-processed/-tokenized dataset from local dir. Otherwise, downloads dataset, prepares it for DataCollator by setting assistant_tokens_mask, and saves to local.

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
import torch
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
from torch.utils.data import DataLoader

In [2]:
def load_model_and_tokenizer(model_id, device):
    model = AutoModelForCausalLM.from_pretrained(model_id,dtype=torch.bfloat16,device_map=device,)
    tokenizer = AutoTokenizer.from_pretrained(model_id,)
    return model, tokenizer

def load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=4096):
    local_ds_id = f"datasets/{model_id}/{dataset_id}"
    num_proc = 16
    try:
        filtered_dataset = load_from_disk(local_ds_id)
        print(f"Loaded dataset from local dir {local_ds_id}")
    except:
        print(f"Dataset not found locally, processing and caching...")
        dataset = load_dataset(dataset_id)["train"]
        if False:
            messages = dataset[0]['messages']
            print(messages)
            tokenized = tokenizer.apply_chat_template(messages, tokenize=True, return_assistant_tokens_mask=True, return_dict=True)
            print(tokenized)
        def preprocess(example):
            tokenized = tokenizer.apply_chat_template(
                example["messages"],
                tokenize=True,
                return_assistant_tokens_mask=True,
                return_dict=True,
                return_tensors="pt",
                # max_length=max_length,
                # truncation=True,
            )
            return {
                "input_ids": tokenized["input_ids"],
                "assistant_masks": tokenized["assistant_masks"],
            }
        tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names, num_proc=num_proc, desc="Tokenizing")
        def shorter_than(example):
            input_ids = example["input_ids"][0]
            length = len(input_ids)
            return length <= max_length
        filtered_dataset = tokenized_dataset.filter(shorter_than, desc=f"Filtering to chosen max length of {max_length}", num_proc=num_proc)
        
        print(f"Tokenized dataset has length {len(tokenized_dataset)}, filtered_dataset has length {len(filtered_dataset)}")
        filtered_dataset.save_to_disk(local_ds_id)
    return filtered_dataset


def create_dataloader(tokenizer, tokenized_dataset, batch_size):
    collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id,)
    dataloader = DataLoader(
        tokenized_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
    )
    return dataloader

def add_reasoning_chat_template(tokenizer):
    if "qwen" in tokenizer.name_or_path.lower():
        # we have to use DataCollatorForLanguageModeling with completion_only_loss=True
        # however, for that tokenizer needs to have return_assistant_tokens_mask=True, and qwen decided against adding support for {% generation %} / {% endgeneration %} functionality
        # so we download a community qwen3 chat template that has it
        !wget -O all_assistant.jinja --no-check-certificate https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
        !mv all_assistant.jinja chat_templates/all_assistant.jinja
        with open('chat_templates/all_assistant.jinja', 'r') as f:
            tokenizer.chat_template = f.read()

    return tokenizer

## Model/Dataset IDs, hyperparam choices

In [3]:
small_model_ids = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "allenai/OLMo-2-0425-1B-Instruct",
    "Qwen/Qwen3-0.6B"
]
big_model_ids = [
    "meta-llama/Llama-3.1-8B-Instruct",
    "allenai/OLMo-2-1124-7B-Instruct",
    "Qwen/Qwen3-8B",
]

In [4]:
dataset_id = "Neelectric/OpenR1-Math-220k_CN-K12_OLMo-2_4096toks"
device = "cuda:0"
model_id = small_model_ids[2]
batch_size = 2
max_length = 4096

# Loading model, tokenizer, dataset, dataloader, optimizer, LR scheduler, 

In [5]:
print(f"Loading in {model_id}")
model, tokenizer = load_model_and_tokenizer(model_id, device)
tokenizer = add_reasoning_chat_template(tokenizer)
tokenized_dataset = load_or_preprocess_dataset(model_id, dataset_id, tokenizer, max_length=max_length)
dataloader = create_dataloader(tokenizer, tokenized_dataset, batch_size)

Loading in Qwen/Qwen3-0.6B
--2025-12-26 13:47:22--  https://raw.githubusercontent.com/HarryMayne/qwen_3_chat_templates/refs/heads/main/all_assistant.jinja
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4153 (4.1K) [text/plain]
Saving to: ‘all_assistant.jinja’


2025-12-26 13:47:22 (35.9 MB/s) - ‘all_assistant.jinja’ saved [4153/4153]

Dataset not found locally, processing and caching...


Filtering to chosen max length of 4096 (num_proc=16):   0%|          | 0/69132 [00:00<?, ? examples/s]

Tokenized dataset has length 69132, filtered_dataset has length Dataset({
    features: ['input_ids', 'assistant_masks'],
    num_rows: 68633
})


Saving the dataset (0/4 shards):   0%|          | 0/68633 [00:00<?, ? examples/s]

In [6]:
batch = next(iter(dataloader))
print(batch.keys())  # should have input_ids, attention_mask, labels

idx = 0
for i, (tok, label) in enumerate(zip(batch["input_ids"][idx], batch["labels"][idx])):
    print(f"{i:3d} | {tok:6d} | {label:6d} | {tokenizer.decode([tok])}")
    if i == 200: break

dict_keys(['input_ids', 'labels', 'attention_mask'])


TypeError: unsupported format string passed to Tensor.__format__

In [None]:
for batch in dataloader:
    break
batch_shapes = {k: v.shape for k, v in batch.items()}
batch_shapes

{'input_ids': torch.Size([2, 3735]),
 'labels': torch.Size([2, 3735]),
 'attention_mask': torch.Size([2, 3735])}

In [8]:
batch = {k: v.to(model.device) for k, v in batch.items()}
batch_device = {}
for key, val in batch.items():
    val = val.to(model.device)
    batch_device[key] = val


outputs = model(**batch_device)

In [6]:
def train_with_sft():
    model.train()
    for epoch in tqdm(range(num_epochs), desc="Epochs", dynamic_ncols=True):
        for batch in tqdm(train_dataloader, desc="Steps in Epoch", dynamic_ncols=True):
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()


# Final eval of methods