In [1]:
#%%
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell runs (useful while iterating on model_components).
%load_ext autoreload
%autoreload 2

In [2]:
#%%
# Import libraries
import torch
import os
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format
from sae_lens import HookedSAETransformer, SAE

In [3]:
#%%
# Import our custom ISAERFT components
try:
    # When imported as a module
    from model_components.IsaerftConfig import IsaerftConfig
    from model_components.IsaerftPeft import IsaerftPeft
except ImportError:
    # When run directly as a script or from a notebook kernel
    import sys
    import os
    try:
        _this_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # `__file__` is not defined inside a Jupyter/IPython kernel.
        # Assumes the kernel's working directory is the folder holding this
        # notebook -- TODO confirm for your launcher.
        _this_dir = os.getcwd()
    # Add the parent directory to the path (once, so re-running the cell does
    # not keep growing sys.path)
    _project_root = os.path.dirname(_this_dir)
    if _project_root not in sys.path:
        sys.path.append(_project_root)
    from model_components.IsaerftConfig import IsaerftConfig
    from model_components.IsaerftPeft import IsaerftPeft

In [4]:
#%%
# Authenticate to Hugging Face
# The token is read from the HUGGINGFACE_WRITE_KEY environment variable, so no
# credential is stored in the file; a missing variable raises KeyError.
from huggingface_hub import login

login(token=os.environ['HUGGINGFACE_WRITE_KEY'])

In [5]:
#%%
# Load dataset
# Fetches "trl-lib/ultrafeedback_binarized"; its "train" and "test" splits are
# indexed later when the trainer is built.
dataset = load_dataset(path="trl-lib/ultrafeedback_binarized")

In [6]:
#%%
# Define the model
model_name = "google/gemma-2-2b"

# Pick the compute device. GPU 1 is preferred, but only when the host actually
# exposes a second GPU: on a single-GPU machine "cuda:1" is not a valid device,
# so fall back to GPU 0 instead of failing later when the model is moved.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# This run requires CUDA; stop here with a readable message otherwise.
assert 'cuda' in device, f"CUDA device required, got {device!r}"

In [7]:
#%%
# Model to fine-tune
# Loads `model_name` as a HookedSAETransformer in float32 and moves it to `device`.
model = HookedSAETransformer.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
).to(device)
# model.config.use_cache = False
# The tokenizer is loaded separately from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Disabled experiment: load a plain (non-hooked) copy on CPU and run setup_chat_format on it.
# non_hooked_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to('cpu')
# Author's note: this step is probably unnecessary -- it was not used before, and the hooked Gemma model is believed to already handle the tokens.
# _, tokenizer = setup_chat_format(non_hooked_model, tokenizer)

In [8]:
#%%
# del non_hooked_model

In [9]:
#%%
# Apply ISAERFT to the model
# One target hook is configured, given as a pair of strings -- presumably an
# SAE release name and an SAE id within it; TODO confirm against IsaerftConfig.
isaerft_config = IsaerftConfig(
    target_hooks=[
        ("gemma-scope-2b-pt-res-canonical", "layer_20/width_16k/canonical"),
    ],
    depth=-1  # Bias-only for simplicity
)

In [10]:
#%%
# Apply the ISAERFT adapter
# NOTE(review): `model` is rebound to the wrapper, so re-running this cell passes
# an already-wrapped model back into IsaerftPeft. Reload the base model before
# re-running; confirm whether IsaerftPeft tolerates double wrapping.
model = IsaerftPeft(model, isaerft_config)

In [11]:
#%%
# Set our name for the finetune to be saved &/ uploaded to
# `finetune_name` is passed as `hub_model_id` in the ORPO config;
# `finetune_tags` are passed as the tags of the wandb run.
finetune_name = "GEMMA-2-2B-FT-ORPO-ISAERFT"
finetune_tags = ["smol-course", "module_1", "isaerft"]

In [12]:
#%%
# Train model with ORPO
orpo_args = ORPOConfig(
    # Small learning rate to prevent catastrophic forgetting
    learning_rate=8e-6,
    # Linear learning rate decay over training
    lr_scheduler_type="linear",
    # Maximum combined length of prompt + completion
    max_length=1024,
    # Maximum length for input prompts
    max_prompt_length=512,
    # Controls weight of the odds ratio loss (λ in paper)
    beta=0.1,
    # Batch size for training
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Helps with training stability by accumulating gradients before updating
    gradient_accumulation_steps=4,
    # Memory-efficient optimizer for CUDA, falls back to adamw_torch for CPU/MPS.
    # `device` carries a GPU index (e.g. "cuda:1"), so test the prefix: an exact
    # comparison with "cuda" would never select the paged optimizer.
    optim="paged_adamw_8bit" if device.startswith("cuda") else "adamw_torch",
    # When to run evaluation
    eval_strategy="steps",
    # Evaluate every 20% of training
    eval_steps=0.2,
    # Log metrics every step
    logging_steps=1,
    # Gradual learning rate warmup
    warmup_steps=10,
    # Send training metrics to Weights & Biases
    report_to="wandb",
    # Where to save model/checkpoints
    output_dir="./results/",
    # Enable MPS (Metal Performance Shaders) if available
    use_mps_device=device == "mps",
    # Hub repository id used for the fine-tuned model
    hub_model_id=finetune_name,
    # Training for a shorter time for this example (1/4 * 0.25 = 0.0625 epoch)
    num_train_epochs=(1/4*.25),
)

In [13]:
#%%
# Initialize wandb
import wandb
from datetime import datetime
import uuid
# Finish any run still open in this kernel before starting a new one.
wandb.finish()
# The API key is read from the WANDB_KEY environment variable (KeyError if unset).
wandb.login(key=os.environ['WANDB_KEY'])

# Run name: minute-resolution timestamp plus six hex characters of a random
# UUID, so runs started within the same minute still get distinct names.
wandb.init(
    project="gemma-2-2b-orpo-isaerft",
    name=f"run-{datetime.now().strftime('%Y%m%d-%H%M')}-{uuid.uuid4().hex[:6]}",
    tags=finetune_tags
)

<wandb.sdk.wandb_run.Run at 0x7fc150485f90>

In [14]:
#%%
# Create the trainer
# Combines the adapter-wrapped model, the ORPO arguments and the dataset's
# "train"/"test" splits; the tokenizer is supplied as the processing class.
trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=tokenizer,
)

In [15]:
#%%
# Display the module tree held in the wrapper's `.model` attribute.
model.model

HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-19): 20 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

In [16]:
#%%
# Inspect the `config` attribute of the wrapped model, if it has one.
model.model.config

In [17]:
#%%
# Display the wrapped model's `cfg` (shown as a HookedTransformerConfig).
model.model.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'globa

In [18]:
#%%
# Try reading `is_encoder_decoder` from the wrapped model's cfg by subscript.
# NOTE(review): this only works if the cfg object supports item access -- confirm.
model.model.cfg['is_encoder_decoder']

In [19]:
#%%
# Read `is_encoder_decoder` from the wrapped model's cfg via attribute access.
# NOTE(review): presumably absent on HookedTransformerConfig -- confirm.
model.model.cfg.is_encoder_decoder

In [20]:
#%%
# Model to fine-tune
# Reloads the base model from scratch in float32 and moves it to `device`,
# rebinding `model` (any adapter wrapper previously bound to that name is dropped).
model = HookedSAETransformer.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
).to(device)
# model.config.use_cache = False
# The tokenizer is reloaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Disabled experiment: load a plain (non-hooked) copy on CPU and run setup_chat_format on it.
# non_hooked_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b").to('cpu')
# Author's note: this step is probably unnecessary -- it was not used before, and the hooked Gemma model is believed to already handle the tokens.
# _, tokenizer = setup_chat_format(non_hooked_model, tokenizer)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
#%%
# del non_hooked_model

In [22]:
#%%
# Apply ISAERFT to the model
# Rebuilds the adapter configuration with one target hook, given as a pair of
# strings -- presumably an SAE release name and an SAE id; TODO confirm against IsaerftConfig.
isaerft_config = IsaerftConfig(
    target_hooks=[
        ("gemma-scope-2b-pt-res-canonical", "layer_20/width_16k/canonical"),
    ],
    depth=-1  # Bias-only for simplicity
)

In [23]:
#%%
# Apply the ISAERFT adapter
# NOTE(review): `model` is rebound to the wrapper, so re-running this cell passes
# an already-wrapped model back into IsaerftPeft. Reload the base model before
# re-running; confirm whether IsaerftPeft tolerates double wrapping.
model = IsaerftPeft(model, isaerft_config)

In [24]:
#%%
# Display the repr of whatever `model` is currently bound to.
model

HookedSAETransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-25): 26 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
      

In [25]:
#%%
# Dump the instance attribute dictionary of `model`.
vars(model)

{'training': True,
 '_parameters': {},
 '_buffers': {},
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': {'embed': Embed(),
  'hook_embed': HookPoint(),
  'blocks': ModuleList(
    (0-25): 26 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_norm

In [26]:
#%%
# List (name, value) pairs from the model's instance dict whose name contains "na".
# The pair must be parenthesized: a bare `k, v` as the comprehension element is a SyntaxError.
[(k, v) for k, v in model.__dict__.items() if "na" in k]

In [27]:
#%%
# Keep the (name, value) entries of the model's instance dict whose name contains "na".
[entry for entry in model.__dict__.items() if "na" in entry[0]]

[]

In [28]:
#%%
# Every (name, value) entry of the model's instance dict, as a list of tuples.
list(model.__dict__.items())

[('training', True),
 ('_parameters', {}),
 ('_buffers', {}),
 ('_non_persistent_buffers_set', set()),
 ('_backward_pre_hooks', OrderedDict()),
 ('_backward_hooks', OrderedDict()),
 ('_is_full_backward_hook', None),
 ('_forward_hooks', OrderedDict()),
 ('_forward_hooks_with_kwargs', OrderedDict()),
 ('_forward_hooks_always_called', OrderedDict()),
 ('_forward_pre_hooks', OrderedDict()),
 ('_forward_pre_hooks_with_kwargs', OrderedDict()),
 ('_state_dict_hooks', OrderedDict()),
 ('_state_dict_pre_hooks', OrderedDict()),
 ('_load_state_dict_pre_hooks', OrderedDict()),
 ('_load_state_dict_post_hooks', OrderedDict()),
 ('_modules',
  {'embed': Embed(),
   'hook_embed': HookPoint(),
   'blocks': ModuleList(
     (0-25): 26 x TransformerBlock(
       (ln1): RMSNormPre(
         (hook_scale): HookPoint()
         (hook_normalized): HookPoint()
       )
       (ln1_post): RMSNorm(
         (hook_scale): HookPoint()
         (hook_normalized): HookPoint()
       )
       (ln2): RMSNormPre(
     

In [29]:
#%%
# Find instance attributes whose *string* value mentions "gemma".
# The isinstance guard is required: the dict holds non-iterable values such as
# the boolean `training` flag, and `"gemma" in True` raises TypeError.
[(k, v) for k, v in model.__dict__.items() if isinstance(v, str) and "gemma" in v]

In [30]:
#%%
# Find instance attributes whose *string* value mentions "gemma".
# Extra parentheses around the condition do not help; the values must be
# type-checked first because `"gemma" in <bool/None>` raises TypeError.
[(k, v) for k, v in model.__dict__.items() if (isinstance(v, str) and "gemma" in v)]

In [31]:
#%%
# Simple recursive function to find the model name
def find_model_name(obj, path="", needle="gemma"):
    """Search an object and its nested ``.model`` chain for a model-name string.

    At each level the attributes ``name``, ``model_name``, ``_name_or_path``
    and ``name_or_path`` are checked on the object itself, then on its
    ``config`` and then on its ``cfg`` (each only when present and not None).
    The first string value containing ``needle`` (case-insensitive) is
    reported with ``print`` and returned.

    Args:
        obj: Object to inspect (model, wrapper, ...).
        path: Label prefix used in the printed location; ".model" is appended
            for every level descended.
        needle: Substring to look for. Defaults to "gemma", which reproduces
            the original hard-coded search; pass e.g. "google/gemma-2-2b" to
            narrow it instead of redefining the function.

    Returns:
        The matching attribute value, or None when nothing matches.
    """
    name_attrs = ('name', 'model_name', '_name_or_path', 'name_or_path')
    needle = needle.lower()
    # ids of objects already inspected; stops the walk if `.model` ever points
    # back to an earlier object instead of recursing without bound.
    visited = set()

    while obj is not None and id(obj) not in visited:
        visited.add(id(obj))

        # Places to look at this level, in priority order: the object itself,
        # then its `config`, then its `cfg`.
        holders = [("", obj)]
        for holder_name in ('config', 'cfg'):
            holder = getattr(obj, holder_name, None)
            if holder is not None:
                holders.append(("." + holder_name, holder))

        for suffix, holder in holders:
            for attr in name_attrs:
                value = getattr(holder, attr, None)
                if isinstance(value, str) and needle in value.lower():
                    print(f"Found model name at {path}{suffix}.{attr}: {value}")
                    return value

        # Descend into the wrapped model, if any.
        obj = getattr(obj, 'model', None)
        path += ".model"

    return None

# Try to find the model name
# Searches the current `model` object; the result is None when no attribute matched.
model_name_found = find_model_name(model)
print(f"Model name found: {model_name_found}")

In [32]:
#%%
# Simple recursive function to find the model name
def find_model_name(obj, path=""):
    """Locate a model-name string on ``obj``, its config objects, or its wrapped model.

    The attributes ``name``, ``model_name``, ``_name_or_path`` and
    ``name_or_path`` are examined in that order. Directly on ``obj`` a value
    must contain "google/gemma-2-2b"; on ``obj.config`` and ``obj.cfg`` it
    only has to contain "gemma" (comparison on the lower-cased value). When
    nothing matches, the search continues in ``obj.model``.

    Args:
        obj: Object to inspect.
        path: Label prefix for the printed location.

    Returns:
        The first matching string, or None.
    """
    name_fields = ('name', 'model_name', '_name_or_path', 'name_or_path')

    # 1) attributes stored directly on the object
    for attr in name_fields:
        value = getattr(obj, attr, None)
        if not isinstance(value, str):
            continue
        if "google/gemma-2-2b" in value.lower():
            print(f"Found model name at {path}.{attr}: {value}")
            return value

    # 2) attributes on `config`
    config = getattr(obj, 'config', None)
    if config is not None:
        for attr in name_fields:
            value = getattr(config, attr, None)
            if not isinstance(value, str):
                continue
            if "gemma" in value.lower():
                print(f"Found model name at {path}.config.{attr}: {value}")
                return value

    # 3) attributes on `cfg`
    cfg = getattr(obj, 'cfg', None)
    if cfg is not None:
        for attr in name_fields:
            value = getattr(cfg, attr, None)
            if not isinstance(value, str):
                continue
            if "gemma" in value.lower():
                print(f"Found model name at {path}.cfg.{attr}: {value}")
                return value

    # 4) nothing here: descend into the wrapped model, if there is one
    inner = getattr(obj, 'model', None)
    if inner is None:
        return None
    return find_model_name(inner, path + ".model")

# Try to find the model name
# Searches the current `model` object; the result is None when no attribute matched.
model_name_found = find_model_name(model)
print(f"Model name found: {model_name_found}")

In [33]:
#%%
# Simple recursive function to find the model name
def find_model_name(obj, path=""):
    """Walk ``obj`` and its ``.model`` chain looking for a Gemma model name.

    Checks ``name``, ``model_name``, ``_name_or_path`` and ``name_or_path``
    first on ``obj`` (must contain "google/gemma-2-2b"), then on
    ``obj.config`` and ``obj.cfg`` (must contain "gemma"); matching is done
    on the lower-cased value. Falls through to ``obj.model`` when no
    attribute matches.

    Args:
        obj: Object to inspect.
        path: Label prefix for the printed location.

    Returns:
        The first matching string, or None.
    """
    def _contains(candidate, fragment):
        # Only string values can match; everything else is skipped.
        return isinstance(candidate, str) and fragment in candidate.lower()

    attr_names = ['name', 'model_name', '_name_or_path', 'name_or_path']

    # Direct attributes
    for attr in attr_names:
        if not hasattr(obj, attr):
            continue
        value = getattr(obj, attr)
        if _contains(value, "google/gemma-2-2b"):
            print(f"Found model name at {path}.{attr}: {value}")
            return value

    # `config` attributes
    has_config = hasattr(obj, 'config') and obj.config is not None
    if has_config:
        for attr in attr_names:
            if not hasattr(obj.config, attr):
                continue
            value = getattr(obj.config, attr)
            if _contains(value, "gemma"):
                print(f"Found model name at {path}.config.{attr}: {value}")
                return value

    # `cfg` attributes
    has_cfg = hasattr(obj, 'cfg') and obj.cfg is not None
    if has_cfg:
        for attr in attr_names:
            if not hasattr(obj.cfg, attr):
                continue
            value = getattr(obj.cfg, attr)
            if _contains(value, "gemma"):
                print(f"Found model name at {path}.cfg.{attr}: {value}")
                return value

    # Wrapped model, if any
    has_inner = hasattr(obj, 'model') and obj.model is not None
    return find_model_name(obj.model, path + ".model") if has_inner else None

# Try to find the model name
# Searches the current `model` object; the result is None when no attribute matched.
model_name_found = find_model_name(model)
print(f"Model name found: {model_name_found}")

In [34]:
#%%
# Simple recursive function to find the model name
def find_model_name(obj, path=""):
    """Return the first attribute value that looks like a Google Gemma model id.

    ``name``, ``model_name``, ``_name_or_path`` and ``name_or_path`` are read
    from ``obj`` (value must contain "google/gemma-2-2b"), then from
    ``obj.config`` and ``obj.cfg`` (value must contain "google/gemma").
    Values are lower-cased before comparison. If none match, the same search
    is repeated on ``obj.model``.

    Args:
        obj: Object to inspect.
        path: Label prefix for the printed location.

    Returns:
        The matching string, or None when the whole chain has no match.
    """
    absent = object()  # sentinel: distinguishes "no such attribute" from any real value
    keys = ('name', 'model_name', '_name_or_path', 'name_or_path')

    # Level 1: the object itself
    for attr in keys:
        value = getattr(obj, attr, absent)
        if value is absent:
            continue
        if isinstance(value, str) and "google/gemma-2-2b" in value.lower():
            print(f"Found model name at {path}.{attr}: {value}")
            return value

    # Level 2: its `config`, when set
    config = getattr(obj, 'config', None)
    for attr in (keys if config is not None else ()):
        value = getattr(config, attr, absent)
        if value is absent:
            continue
        if isinstance(value, str) and "google/gemma" in value.lower():
            print(f"Found model name at {path}.config.{attr}: {value}")
            return value

    # Level 3: its `cfg`, when set
    cfg = getattr(obj, 'cfg', None)
    for attr in (keys if cfg is not None else ()):
        value = getattr(cfg, attr, absent)
        if value is absent:
            continue
        if isinstance(value, str) and "google/gemma" in value.lower():
            print(f"Found model name at {path}.cfg.{attr}: {value}")
            return value

    # Level 4: recurse into the wrapped model
    wrapped = getattr(obj, 'model', None)
    if wrapped is not None:
        return find_model_name(wrapped, path + ".model")

    return None

# Try to find the model name
# Searches the current `model` object; the result is None when no attribute matched.
model_name_found = find_model_name(model)
print(f"Model name found: {model_name_found}")

In [35]:
#%%
# Inspect the `model_type` field of the model's cfg, if it has one.
model.cfg.model_type

In [36]:
#%%
# Display the model's cfg (shown as a HookedTransformerConfig).
model.cfg

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_pytorch_tanh',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 16.0,
 'attn_scores_soft_cap': 50.0,
 'attn_types': ['global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'global',
                'local',
                'globa

In [37]:
#%%
# List (field, value) pairs of the model cfg whose *string* value mentions "google".
# - The pair is parenthesized: a bare `k, v` element is a SyntaxError.
# - Fields are read from the cfg's instance dict, which is known to work on this object.
# - Non-string values (ints, lists, None) are skipped instead of raising TypeError.
[(k, v) for k, v in model.cfg.__dict__.items() if isinstance(v, str) and 'google' in v]

In [38]:
#%%
# Names of cfg fields whose *string* value mentions "google".
# Fields are read from the cfg's instance dict, and non-string values are
# skipped: `'google' in 26` (e.g. n_layers) would raise TypeError.
[k for k, v in model.cfg.__dict__.items() if isinstance(v, str) and 'google' in v]

In [39]:
#%%
# Names of cfg fields whose *string* value mentions "google".
# The isinstance guard is required because the cfg dict starts with integer
# fields (n_layers, d_model, ...), and `'google' in <int>` raises TypeError.
[k for k, v in model.cfg.__dict__.items() if isinstance(v, str) and 'google' in v]

In [40]:
#%%
# Field names stored on the model's cfg instance.
vars(model.cfg).keys()

dict_keys(['n_layers', 'd_model', 'n_ctx', 'd_head', 'model_name', 'n_heads', 'd_mlp', 'act_fn', 'd_vocab', 'eps', 'use_attn_result', 'use_attn_scale', 'attn_scale', 'use_split_qkv_input', 'use_hook_mlp_in', 'use_attn_in', 'use_local_attn', 'ungroup_grouped_query_attention', 'original_architecture', 'from_checkpoint', 'checkpoint_index', 'checkpoint_label_type', 'checkpoint_value', 'tokenizer_name', 'window_size', 'attn_types', 'init_mode', 'normalization_type', 'device', 'n_devices', 'attention_dir', 'attn_only', 'seed', 'initializer_range', 'init_weights', 'scale_attn_by_inverse_layer_idx', 'positional_embedding_type', 'final_rms', 'd_vocab_out', 'parallel_attn_mlp', 'rotary_dim', 'n_params', 'use_hook_tokens', 'gated_mlp', 'default_prepend_bos', 'dtype', 'tokenizer_prepends_bos', 'n_key_value_heads', 'post_embedding_ln', 'rotary_base', 'trust_remote_code', 'rotary_adjacent_pairs', 'load_in_4bit', 'num_experts', 'experts_per_token', 'relative_attention_max_distance', 'relative_attent

In [41]:
#%%
# Full field -> value mapping stored on the model's cfg instance.
vars(model.cfg)

{'n_layers': 26,
 'd_model': 2304,
 'n_ctx': 8192,
 'd_head': 256,
 'model_name': 'gemma-2-2b',
 'n_heads': 8,
 'd_mlp': 9216,
 'act_fn': 'gelu_pytorch_tanh',
 'd_vocab': 256000,
 'eps': 1e-06,
 'use_attn_result': False,
 'use_attn_scale': True,
 'attn_scale': 16.0,
 'use_split_qkv_input': False,
 'use_hook_mlp_in': False,
 'use_attn_in': False,
 'use_local_attn': True,
 'ungroup_grouped_query_attention': False,
 'original_architecture': 'Gemma2ForCausalLM',
 'from_checkpoint': False,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'tokenizer_name': 'google/gemma-2-2b',
 'window_size': 4096,
 'attn_types': ['global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  'global',
  'local',
  '

In [42]:
#%%
# Apply the ISAERFT adapter
# NOTE(review): an earlier cell already runs this same statement after the model
# reload. If that cell succeeded, `model` is already a wrapper and this line
# wraps it a second time -- reload the base model first, or confirm that
# IsaerftPeft tolerates double wrapping.
model = IsaerftPeft(model, isaerft_config)

In [43]:
#%%
# Set our name for the finetune to be saved &/ uploaded to
# `finetune_name` is passed as `hub_model_id` in the ORPO config;
# `finetune_tags` are passed as the tags of the wandb run.
finetune_name = "GEMMA-2-2B-FT-ORPO-ISAERFT"
finetune_tags = ["smol-course", "module_1", "isaerft"]

In [44]:
#%%
# Train model with ORPO
orpo_args = ORPOConfig(
    # Small learning rate to prevent catastrophic forgetting
    learning_rate=8e-6,
    # Linear learning rate decay over training
    lr_scheduler_type="linear",
    # Maximum combined length of prompt + completion
    max_length=1024,
    # Maximum length for input prompts
    max_prompt_length=512,
    # Controls weight of the odds ratio loss (λ in paper)
    beta=0.1,
    # Batch size for training
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # Helps with training stability by accumulating gradients before updating
    gradient_accumulation_steps=4,
    # Memory-efficient optimizer for CUDA, falls back to adamw_torch for CPU/MPS.
    # `device` carries a GPU index (e.g. "cuda:1"), so test the prefix: an exact
    # comparison with "cuda" would never select the paged optimizer.
    optim="paged_adamw_8bit" if device.startswith("cuda") else "adamw_torch",
    # When to run evaluation
    eval_strategy="steps",
    # Evaluate every 20% of training
    eval_steps=0.2,
    # Log metrics every step
    logging_steps=1,
    # Gradual learning rate warmup
    warmup_steps=10,
    # Send training metrics to Weights & Biases
    report_to="wandb",
    # Where to save model/checkpoints
    output_dir="./results/",
    # Enable MPS (Metal Performance Shaders) if available
    use_mps_device=device == "mps",
    # Hub repository id used for the fine-tuned model
    hub_model_id=finetune_name,
    # Training for a shorter time for this example (1/4 * 0.25 = 0.0625 epoch)
    num_train_epochs=(1/4*.25),
)

In [45]:
#%%
# Initialize wandb
import wandb
from datetime import datetime
import uuid
# Finish any run still open in this kernel before starting a new one.
wandb.finish()
# The API key is read from the WANDB_KEY environment variable (KeyError if unset).
wandb.login(key=os.environ['WANDB_KEY'])

# Run name: minute-resolution timestamp plus six hex characters of a random
# UUID, so runs started within the same minute still get distinct names.
wandb.init(
    project="gemma-2-2b-orpo-isaerft",
    name=f"run-{datetime.now().strftime('%Y%m%d-%H%M')}-{uuid.uuid4().hex[:6]}",
    tags=finetune_tags
)