# Testing out the DPO Trainer

In [1]:
%load_ext autoreload
%autoreload 2

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformer_lens import HookedTransformer
from fsrl import SAEAdapter, HookedModel

import torch
import os

os.environ["WANDB_PROJECT"] = "test"
os.environ["WANDB_ENTITY"] = "feature-steering-RL"
os.environ["WANDB_DIR"] = os.path.abspath("../logs")

# Make each tokenizer work on a single thread
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Forcing bf16 for the test
model = HookedTransformer.from_pretrained("gpt2", device=device, dtype=torch.bfloat16)
tokenizer = model.tokenizer
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Get only a single sample
train_dataset = train_dataset.select(range(1))



Loaded pretrained model gpt2 into HookedTransformer


In [3]:
release = "gpt2-small-res-jb"
sae_id = "blocks.7.hook_resid_pre"
sae, cfg_dict, sparsity = SAEAdapter.from_pretrained(release, sae_id, device=device)

This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [4]:
sae_hooked_model = HookedModel(model, sae)

# GPT2 does not have a chat template
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [5]:
batch_size = 2 # Does not matter since we use a single sample for now

training_args = DPOConfig(
    output_dir="../logs/test_dpo",
    run_name="test_dpo",
    num_train_epochs=100,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    dataloader_num_workers=batch_size,
    bf16=True,
    optim='adamw_torch_fused',
    logging_steps=1,   
)

trainer = DPOTrainer(model=sae_hooked_model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)

In [6]:
# Sanity check
total_params = 0
trainable_params = 0
for name, param in sae_hooked_model.named_parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
        print(f"TRAINABLE: {name} | Size: {param.numel()}")

TRAINABLE: sae_adapter.adapter_layers.0.weight | Size: 18874368
TRAINABLE: sae_adapter.adapter_layers.0.bias | Size: 24576


In [7]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mj-l-ferrao[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,0.6931
2,0.0066
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


TrainOutput(global_step=100, training_loss=0.00699761470836178, metrics={'train_runtime': 28.1111, 'train_samples_per_second': 3.557, 'train_steps_per_second': 3.557, 'total_flos': 0.0, 'train_loss': 0.00699761470836178, 'epoch': 100.0})