Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion demos/Othello_GPT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Import stuff\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down Expand Up @@ -175,10 +176,11 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"import transformer_lens\n",
"import transformer_lens.utilities as utils\n",
"from transformer_lens.hook_points import HookPoint\n",
Expand Down
3 changes: 2 additions & 1 deletion demos/Santa_Coder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "da9f5a40",
"metadata": {
"execution": {
Expand All @@ -63,6 +63,7 @@
},
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Import stuff\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down
4 changes: 4 additions & 0 deletions docs/source/content/migrating_to_v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ These work identically on `TransformerBridge` and need no migration:

If your code only touches these APIs, the migration is genuinely just the loading call and (optionally) `enable_compatibility_mode`.

### BERT Next Sentence Prediction

`BertNextSentencePrediction` is not ported to `TransformerBridge`. Keep using `HookedEncoder` + `BertNextSentencePrediction` for NSP workflows. The bridge's BERT adapter does load NSP HuggingFace checkpoints (it rewires the unembed to `cls.seq_relationship`), but the high-level NSP API — sentence-pair tokenization, `[CLS]` pooling, "sequential"/"not sequential" decoding — is not exposed. If this feature is something you'd like added to `TransformerBridge`, please file an issue.

### New in 3.x: streaming generation

Both `HookedTransformer` and `TransformerBridge` now expose `generate_stream`, which yields tokens progressively instead of returning the full completion at once:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def hook_fn(tensor, hook):
)
assert fired == {"resid_pre_0", "resid_post_0"}

@pytest.mark.xfail(reason="add_perma_hook not yet implemented on TransformerBridge")
def test_perma_hook_persists_across_calls(self, bridge):
"""A permanent hook fires on every forward pass until removed."""
count = 0
Expand Down
166 changes: 0 additions & 166 deletions tests/integration/test_centralized_weight_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,6 @@ def bridge_and_adapter(self, model_name, device):
bridge = TransformerBridge.boot_transformers(model_name, device=device)
return bridge, bridge.adapter, bridge.cfg

@pytest.mark.skip(
reason="API not implemented - ProcessWeights doesn't convert to TL format keys"
)
def test_processing_with_architecture_adapter(
self, raw_hf_model_and_state_dict, bridge_and_adapter
):
"""Test ProcessWeights.process_weights with architecture adapter."""
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
bridge, adapter, cfg = bridge_and_adapter

# Preprocess weights first (this converts to TL format with split Q/K/V)
preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict)

# Process with architecture adapter
processed_with_adapter = ProcessWeights.process_weights(
state_dict=preprocessed_state_dict,
cfg=cfg,
adapter=adapter,
fold_ln=False,
center_writing_weights=False,
center_unembed=False,
fold_value_biases=False,
)

# Verify processing occurred
assert len(processed_with_adapter) > 0, "Should process weights with adapter"

# Check for TransformerLens-style keys (after preprocessing)
# These should be in format like: blocks.0.attn.W_Q, blocks.0.attn.W_K, etc.
tl_keys = [
k
for k in processed_with_adapter.keys()
if any(
pattern in k for pattern in [".W_Q", ".W_K", ".W_V", ".W_O", ".b_Q", ".b_K", ".b_V"]
)
]
assert (
len(tl_keys) > 0
), "Should have TransformerLens-style attention keys after preprocessing"

# Check that expected TL-style keys exist
expected_patterns = ["blocks.0.attn.W_Q", "blocks.0.attn.W_K", "blocks.0.attn.W_V"]
for pattern in expected_patterns:
assert any(
pattern in k for k in processed_with_adapter.keys()
), f"Should have {pattern} in processed weights"

def test_processing_without_architecture_adapter(
self, raw_hf_model_and_state_dict, bridge_and_adapter
):
Expand Down Expand Up @@ -116,125 +69,6 @@ def test_processing_without_architecture_adapter(
]
assert len(hf_keys) > 0, "Should have HF-style keys without adapter"

@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
def test_processing_with_different_flags(self, raw_hf_model_and_state_dict, bridge_and_adapter):
"""Test processing with different flag combinations."""
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
bridge, adapter, cfg = bridge_and_adapter

# Preprocess weights first
preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict)

# Test processing with all flags enabled
processed_with_flags = ProcessWeights.process_weights(
state_dict=preprocessed_state_dict.copy(),
cfg=cfg,
adapter=adapter,
fold_ln=True,
center_writing_weights=True,
center_unembed=True,
fold_value_biases=True,
)

# Test processing with all flags disabled
processed_without_flags = ProcessWeights.process_weights(
state_dict=preprocessed_state_dict.copy(),
cfg=cfg,
adapter=adapter,
fold_ln=False,
center_writing_weights=False,
center_unembed=False,
fold_value_biases=False,
)

# Both should process successfully
assert len(processed_with_flags) > 0, "Should process weights with flags"
assert len(processed_without_flags) > 0, "Should process weights without flags"

@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
def test_architecture_divergence_handling(
self, raw_hf_model_and_state_dict, bridge_and_adapter
):
"""Test that adapter preprocessing changes the state dict format."""
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
bridge, adapter, cfg = bridge_and_adapter

# Preprocess with adapter (splits c_attn into Q/K/V)
preprocessed_with_adapter = adapter.preprocess_weights(raw_state_dict)

# Process with adapter after preprocessing
processed_with_adapter = ProcessWeights.process_weights(
state_dict=preprocessed_with_adapter,
cfg=cfg,
adapter=adapter,
fold_ln=True,
center_writing_weights=True,
center_unembed=True,
fold_value_biases=True,
)

# Process without adapter (no preprocessing)
processed_without_adapter = ProcessWeights.process_weights(
state_dict=raw_state_dict.copy(),
cfg=cfg,
adapter=None,
fold_ln=True,
center_writing_weights=True,
center_unembed=True,
fold_value_biases=True,
)

# Results should be different (different processing paths)
with_adapter_keys = set(processed_with_adapter.keys())
without_adapter_keys = set(processed_without_adapter.keys())

# Should have some different keys due to different processing
assert (
with_adapter_keys != without_adapter_keys
), "With and without adapter should produce different key sets"

# With adapter should have split Q/K/V keys
tl_attn_keys = [
k for k in with_adapter_keys if any(p in k for p in [".W_Q", ".W_K", ".W_V"])
]
assert len(tl_attn_keys) > 0, "With adapter should have split Q/K/V keys"

@pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V")
def test_custom_component_processing_integration(
self, raw_hf_model_and_state_dict, bridge_and_adapter
):
"""Test that adapter preprocessing splits QKV weights correctly."""
raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict
bridge, adapter, cfg = bridge_and_adapter

# Preprocess weights first - this is what splits Q/K/V
preprocessed_weights = adapter.preprocess_weights(raw_state_dict)

# Process with adapter after preprocessing
processed_weights = ProcessWeights.process_weights(
state_dict=preprocessed_weights,
cfg=cfg,
adapter=adapter,
fold_ln=False,
center_writing_weights=False,
center_unembed=False,
fold_value_biases=False,
)

# Check for split Q/K/V weights (created by preprocessing)
custom_qkv_found = any(".W_Q" in k for k in processed_weights.keys())

assert custom_qkv_found, "Should have split QKV weights after preprocessing"

# Verify that QKV splitting occurred for each layer
q_keys = [k for k in processed_weights.keys() if ".W_Q" in k]
k_keys = [k for k in processed_weights.keys() if ".W_K" in k]
v_keys = [k for k in processed_weights.keys() if ".W_V" in k]

assert len(q_keys) > 0, "Should have Q weight keys"
assert len(k_keys) > 0, "Should have K weight keys"
assert len(v_keys) > 0, "Should have V weight keys"

def test_computational_correctness_with_existing_pipeline(self, model_name, device):
"""Test that centralized processing maintains computational correctness."""
test_tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
Expand Down
Loading
Loading