Skip to content

Commit

Permalink
feat(runner): load tokenizer manually
Browse files Browse the repository at this point in the history
  • Loading branch information
Frankstein73 committed Jun 13, 2024
1 parent 40e795d commit 83f863c
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 23 deletions.
12 changes: 6 additions & 6 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ dependencies = [
"numpy>=1.26.4",
"pandas>=2.2.1",
"pymongo>=4.6.3",
"pytest>=8.0.1",
"tensorboardX>=2.6.2.2",
"torch>=2.2.0",
"tqdm>=4.66.2",
Expand All @@ -41,6 +40,7 @@ license = {text = "MIT"}
dev = [
"-e file:///${PROJECT_ROOT}/TransformerLens#egg=transformer-lens",
"mypy>=1.10.0",
"pytest>=8.0.1",
]

[tool.mypy]
Expand Down
75 changes: 59 additions & 16 deletions src/lm_saes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,23 @@ def language_model_sae_eval_runner(cfg: LanguageModelSAERunnerConfig):
hf_model = AutoModelForCausalLM.from_pretrained(
cfg.model_name, cache_dir=cfg.cache_dir, local_files_only=cfg.local_files_only
)

hf_tokenizer = AutoTokenizer.from_pretrained(
(
cfg.model_name
if cfg.model_from_pretrained_path is None
else cfg.model_from_pretrained_path
),
trust_remote_code=True,
use_fast=True,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
cfg.model_name,
device=cfg.device,
cache_dir=cfg.cache_dir,
hf_model=hf_model,
tokenizer=hf_tokenizer,
dtype=cfg.dtype,
)
model.eval()
Expand Down Expand Up @@ -206,38 +218,58 @@ def sample_feature_activations_runner(cfg: LanguageModelSAEAnalysisConfig):
cache_dir=cfg.cache_dir,
local_files_only=cfg.local_files_only,
)
hf_tokenizer = AutoTokenizer.from_pretrained(
(
cfg.model_name
if cfg.model_from_pretrained_path is None
else cfg.model_from_pretrained_path
),
trust_remote_code=True,
use_fast=True,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
cfg.model_name,
device=cfg.device,
cache_dir=cfg.cache_dir,
hf_model=hf_model,
tokenizer=hf_tokenizer,
dtype=cfg.dtype,
)
model.eval()

client = MongoClient(cfg.mongo_uri, cfg.mongo_db)
client.remove_dictionary(cfg.exp_name, cfg.exp_series)
client.create_dictionary(cfg.exp_name, cfg.d_sae, cfg.exp_series)

for chunk_id in range(cfg.n_sae_chunks):
activation_store = ActivationStore.from_config(model=model, cfg=cfg)
result = sample_feature_activations(sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks)
result = sample_feature_activations(
sae, model, activation_store, cfg, chunk_id, cfg.n_sae_chunks
)

for i in range(len(result["index"].cpu().numpy().tolist())):
client.update_feature(cfg.exp_name, result["index"][i].item(), {
"act_times": result["act_times"][i].item(),
"max_feature_acts": result["max_feature_acts"][i].item(),
"feature_acts_all": result["feature_acts_all"][i]
.cpu()
.float()
.numpy(), # use .float() to convert bfloat16 to float32
"analysis": [
{
"name": v["name"],
"feature_acts": v["feature_acts"][i].cpu().float().numpy(),
"contexts": v["contexts"][i].cpu().numpy(),
} for v in result["analysis"]
]
}, dictionary_series=cfg.exp_series)
client.update_feature(
cfg.exp_name,
result["index"][i].item(),
{
"act_times": result["act_times"][i].item(),
"max_feature_acts": result["max_feature_acts"][i].item(),
"feature_acts_all": result["feature_acts_all"][i]
.cpu()
.float()
.numpy(), # use .float() to convert bfloat16 to float32
"analysis": [
{
"name": v["name"],
"feature_acts": v["feature_acts"][i].cpu().float().numpy(),
"contexts": v["contexts"][i].cpu().numpy(),
}
for v in result["analysis"]
],
},
dictionary_series=cfg.exp_series,
)

del result
del activation_store
Expand All @@ -257,11 +289,22 @@ def features_to_logits_runner(cfg: FeaturesDecoderConfig):
cache_dir=cfg.cache_dir,
local_files_only=cfg.local_files_only,
)
hf_tokenizer = AutoTokenizer.from_pretrained(
(
cfg.model_name
if cfg.model_from_pretrained_path is None
else cfg.model_from_pretrained_path
),
trust_remote_code=True,
use_fast=True,
add_bos_token=True,
)
model = HookedTransformer.from_pretrained(
cfg.model_name,
device=cfg.device,
cache_dir=cfg.cache_dir,
hf_model=hf_model,
tokenizer=hf_tokenizer,
dtype=cfg.dtype,
)
model.eval()
Expand Down
Binary file modified ui/bun.lockb
Binary file not shown.

0 comments on commit 83f863c

Please sign in to comment.