In [1]:
%load_ext autoreload
%autoreload 2

import torch
import lm_eval
from fsrl import SAEAdapter, HookedModel
from fsrl.utils.wandb_utils import (
    WandBModelDownloader,
    download_model_family,
    list_model_family,
)
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.models.huggingface import HFLM

run_of_interest = "fragrant-bush-43"

downloader = WandBModelDownloader(
        entity="feature-steering-RL",
        project="Gemma2-2B"
)

if run_of_interest not in list_model_family():
    downloader.download_model(run_of_interest)
else:
    print(f"Run {run_of_interest} found among downloaded models")







Run fragrant-bush-43 found among downloaded models


In [2]:
base_model = HookedTransformer.from_pretrained("google/gemma-2-2b-it", device="cuda", dtype=torch.bfloat16)
tokenizer = base_model.tokenizer
adapter_path = downloader.models_base_dir / "Gemma2-2B" / run_of_interest / "adapter"
print(f"Loading adapter from: {adapter_path}")

sae_adapter = SAEAdapter.load_from_pretrained_adapter(adapter_path, device="cuda")
hooked_model = HookedModel(base_model, sae_adapter)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer
Loading adapter from: /root/feature-steering-RL/models/Gemma2-2B/fragrant-bush-43/adapter
Adapter loaded from /root/feature-steering-RL/models/Gemma2-2B/fragrant-bush-43/adapter


In [3]:
batch_size = 16

eval_model = HFLM(pretrained=hooked_model, tokenizer=tokenizer, batch_size=batch_size)

task_manager = lm_eval.tasks.TaskManager()
results = lm_eval.simple_evaluate(
    model=eval_model,
    tasks=["mmlu"],
    task_manager=task_manager,
    limit=0.01,
    apply_chat_template=True,
)

print(results)

100%|██████████| 1/1 [00:00<00:00, 31.29it/s]
100%|██████████| 2/2 [00:00<00:00, 133.23it/s]
100%|██████████| 2/2 [00:00<00:00, 136.96it/s]
100%|██████████| 2/2 [00:00<00:00, 134.55it/s]
100%|██████████| 1/1 [00:00<00:00, 126.75it/s]
100%|██████████| 1/1 [00:00<00:00, 137.37it/s]
100%|██████████| 1/1 [00:00<00:00, 137.92it/s]
100%|██████████| 2/2 [00:00<00:00, 140.18it/s]
100%|██████████| 1/1 [00:00<00:00, 129.06it/s]
100%|██████████| 3/3 [00:00<00:00, 118.93it/s]
100%|██████████| 2/2 [00:00<00:00, 128.17it/s]
100%|██████████| 4/4 [00:00<00:00, 120.59it/s]
100%|██████████| 4/4 [00:00<00:00, 136.69it/s]
100%|██████████| 3/3 [00:00<00:00, 224.98it/s]
100%|██████████| 1/1 [00:00<00:00, 259.02it/s]
100%|██████████| 3/3 [00:00<00:00, 265.71it/s]
100%|██████████| 2/2 [00:00<00:00, 185.89it/s]
100%|██████████| 3/3 [00:00<00:00, 193.01it/s]
100%|██████████| 2/2 [00:00<00:00, 194.56it/s]
100%|██████████| 1/1 [00:00<00:00, 193.15it/s]
100%|██████████| 3/3 [00:00<00:00, 186.04it/s]
100%|█████████

{'results': {'mmlu': {'acc,none': 0.5, 'acc_stderr,none': 'N/A', 'alias': 'mmlu'}, 'mmlu_humanities': {'acc,none': 0.45454545454545453, 'acc_stderr,none': 0.0636350752291753, 'alias': ' - humanities'}, 'mmlu_formal_logic': {'alias': '  - formal_logic', 'acc,none': 0.5, 'acc_stderr,none': 0.5}, 'mmlu_high_school_european_history': {'alias': '  - high_school_european_history', 'acc,none': 1.0, 'acc_stderr,none': 0.0}, 'mmlu_high_school_us_history': {'alias': '  - high_school_us_history', 'acc,none': 1.0, 'acc_stderr,none': 0.0}, 'mmlu_high_school_world_history': {'alias': '  - high_school_world_history', 'acc,none': 0.6666666666666666, 'acc_stderr,none': 0.33333333333333337}, 'mmlu_international_law': {'alias': '  - international_law', 'acc,none': 1.0, 'acc_stderr,none': 0.0}, 'mmlu_jurisprudence': {'alias': '  - jurisprudence', 'acc,none': 0.0, 'acc_stderr,none': 0.0}, 'mmlu_logical_fallacies': {'alias': '  - logical_fallacies', 'acc,none': 0.5, 'acc_stderr,none': 0.5}, 'mmlu_moral_disp

In [4]:
print(results['results']['mmlu'])

{'acc,none': 0.5, 'acc_stderr,none': 'N/A', 'alias': 'mmlu'}


In [5]:
results_baseline = lm_eval.simple_evaluate(
    model=HFLM(pretrained=base_model, tokenizer=tokenizer, batch_size=batch_size),
    tasks=["mmlu"],
    task_manager=task_manager,
    limit=0.01,
    apply_chat_template=True,
)

print(results_baseline)



AttributeError: 'HookedTransformer' object has no attribute 'device'