diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index ea087d7df..171b877f2 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -148,6 +148,7 @@ "metadata": {}, "outputs": [], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "# Import stuff\n", "import torch\n", "import torch.nn as nn\n", @@ -175,10 +176,11 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "import transformer_lens\n", "import transformer_lens.utilities as utils\n", "from transformer_lens.hook_points import HookPoint\n", diff --git a/demos/Santa_Coder.ipynb b/demos/Santa_Coder.ipynb index 99f7dcab0..ec7fe929d 100644 --- a/demos/Santa_Coder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -51,7 +51,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "da9f5a40", "metadata": { "execution": { @@ -63,6 +63,7 @@ }, "outputs": [], "source": [ + "# NBVAL_IGNORE_OUTPUT\n", "# Import stuff\n", "import torch\n", "import torch.nn as nn\n", diff --git a/docs/source/content/migrating_to_v3.md b/docs/source/content/migrating_to_v3.md index adec266a6..f6d2ed005 100644 --- a/docs/source/content/migrating_to_v3.md +++ b/docs/source/content/migrating_to_v3.md @@ -136,6 +136,10 @@ These work identically on `TransformerBridge` and need no migration: If your code only touches these APIs, the migration is genuinely just the loading call and (optionally) `enable_compatibility_mode`. +### BERT Next Sentence Prediction + +`BertNextSentencePrediction` is not ported to `TransformerBridge`. Keep using `HookedEncoder` + `BertNextSentencePrediction` for NSP workflows. The bridge's BERT adapter does load NSP HuggingFace checkpoints (it rewires the unembed to `cls.seq_relationship`), but the high-level NSP API – sentence-pair tokenization, `[CLS]` pooling, "sequential"/"not sequential" decoding — is not exposed. 
If this feature is something you'd like added to `TransformerBridge`, please file an issue. + ### New in 3.x: streaming generation Both `HookedTransformer` and `TransformerBridge` now expose `generate_stream`, which yields tokens progressively instead of returning the full completion at once: diff --git a/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py b/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py index 135cff127..8df4599ef 100644 --- a/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py +++ b/tests/integration/model_bridge/compatibility/test_bridge_hook_behavior.py @@ -84,7 +84,6 @@ def hook_fn(tensor, hook): ) assert fired == {"resid_pre_0", "resid_post_0"} - @pytest.mark.xfail(reason="add_perma_hook not yet implemented on TransformerBridge") def test_perma_hook_persists_across_calls(self, bridge): """A permanent hook fires on every forward pass until removed.""" count = 0 diff --git a/tests/integration/test_centralized_weight_processing.py b/tests/integration/test_centralized_weight_processing.py index 3f5706c13..1e8ebde20 100644 --- a/tests/integration/test_centralized_weight_processing.py +++ b/tests/integration/test_centralized_weight_processing.py @@ -42,53 +42,6 @@ def bridge_and_adapter(self, model_name, device): bridge = TransformerBridge.boot_transformers(model_name, device=device) return bridge, bridge.adapter, bridge.cfg - @pytest.mark.skip( - reason="API not implemented - ProcessWeights doesn't convert to TL format keys" - ) - def test_processing_with_architecture_adapter( - self, raw_hf_model_and_state_dict, bridge_and_adapter - ): - """Test ProcessWeights.process_weights with architecture adapter.""" - raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict - bridge, adapter, cfg = bridge_and_adapter - - # Preprocess weights first (this converts to TL format with split Q/K/V) - preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict) - - # Process with 
architecture adapter - processed_with_adapter = ProcessWeights.process_weights( - state_dict=preprocessed_state_dict, - cfg=cfg, - adapter=adapter, - fold_ln=False, - center_writing_weights=False, - center_unembed=False, - fold_value_biases=False, - ) - - # Verify processing occurred - assert len(processed_with_adapter) > 0, "Should process weights with adapter" - - # Check for TransformerLens-style keys (after preprocessing) - # These should be in format like: blocks.0.attn.W_Q, blocks.0.attn.W_K, etc. - tl_keys = [ - k - for k in processed_with_adapter.keys() - if any( - pattern in k for pattern in [".W_Q", ".W_K", ".W_V", ".W_O", ".b_Q", ".b_K", ".b_V"] - ) - ] - assert ( - len(tl_keys) > 0 - ), "Should have TransformerLens-style attention keys after preprocessing" - - # Check that expected TL-style keys exist - expected_patterns = ["blocks.0.attn.W_Q", "blocks.0.attn.W_K", "blocks.0.attn.W_V"] - for pattern in expected_patterns: - assert any( - pattern in k for k in processed_with_adapter.keys() - ), f"Should have {pattern} in processed weights" - def test_processing_without_architecture_adapter( self, raw_hf_model_and_state_dict, bridge_and_adapter ): @@ -116,125 +69,6 @@ def test_processing_without_architecture_adapter( ] assert len(hf_keys) > 0, "Should have HF-style keys without adapter" - @pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V") - def test_processing_with_different_flags(self, raw_hf_model_and_state_dict, bridge_and_adapter): - """Test processing with different flag combinations.""" - raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict - bridge, adapter, cfg = bridge_and_adapter - - # Preprocess weights first - preprocessed_state_dict = adapter.preprocess_weights(raw_state_dict) - - # Test processing with all flags enabled - processed_with_flags = ProcessWeights.process_weights( - state_dict=preprocessed_state_dict.copy(), - cfg=cfg, - adapter=adapter, - fold_ln=True, - 
center_writing_weights=True, - center_unembed=True, - fold_value_biases=True, - ) - - # Test processing with all flags disabled - processed_without_flags = ProcessWeights.process_weights( - state_dict=preprocessed_state_dict.copy(), - cfg=cfg, - adapter=adapter, - fold_ln=False, - center_writing_weights=False, - center_unembed=False, - fold_value_biases=False, - ) - - # Both should process successfully - assert len(processed_with_flags) > 0, "Should process weights with flags" - assert len(processed_without_flags) > 0, "Should process weights without flags" - - @pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V") - def test_architecture_divergence_handling( - self, raw_hf_model_and_state_dict, bridge_and_adapter - ): - """Test that adapter preprocessing changes the state dict format.""" - raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict - bridge, adapter, cfg = bridge_and_adapter - - # Preprocess with adapter (splits c_attn into Q/K/V) - preprocessed_with_adapter = adapter.preprocess_weights(raw_state_dict) - - # Process with adapter after preprocessing - processed_with_adapter = ProcessWeights.process_weights( - state_dict=preprocessed_with_adapter, - cfg=cfg, - adapter=adapter, - fold_ln=True, - center_writing_weights=True, - center_unembed=True, - fold_value_biases=True, - ) - - # Process without adapter (no preprocessing) - processed_without_adapter = ProcessWeights.process_weights( - state_dict=raw_state_dict.copy(), - cfg=cfg, - adapter=None, - fold_ln=True, - center_writing_weights=True, - center_unembed=True, - fold_value_biases=True, - ) - - # Results should be different (different processing paths) - with_adapter_keys = set(processed_with_adapter.keys()) - without_adapter_keys = set(processed_without_adapter.keys()) - - # Should have some different keys due to different processing - assert ( - with_adapter_keys != without_adapter_keys - ), "With and without adapter should produce different key sets" - 
- # With adapter should have split Q/K/V keys - tl_attn_keys = [ - k for k in with_adapter_keys if any(p in k for p in [".W_Q", ".W_K", ".W_V"]) - ] - assert len(tl_attn_keys) > 0, "With adapter should have split Q/K/V keys" - - @pytest.mark.skip(reason="API not implemented - adapter.preprocess_weights doesn't split Q/K/V") - def test_custom_component_processing_integration( - self, raw_hf_model_and_state_dict, bridge_and_adapter - ): - """Test that adapter preprocessing splits QKV weights correctly.""" - raw_hf_model, raw_state_dict = raw_hf_model_and_state_dict - bridge, adapter, cfg = bridge_and_adapter - - # Preprocess weights first - this is what splits Q/K/V - preprocessed_weights = adapter.preprocess_weights(raw_state_dict) - - # Process with adapter after preprocessing - processed_weights = ProcessWeights.process_weights( - state_dict=preprocessed_weights, - cfg=cfg, - adapter=adapter, - fold_ln=False, - center_writing_weights=False, - center_unembed=False, - fold_value_biases=False, - ) - - # Check for split Q/K/V weights (created by preprocessing) - custom_qkv_found = any(".W_Q" in k for k in processed_weights.keys()) - - assert custom_qkv_found, "Should have split QKV weights after preprocessing" - - # Verify that QKV splitting occurred for each layer - q_keys = [k for k in processed_weights.keys() if ".W_Q" in k] - k_keys = [k for k in processed_weights.keys() if ".W_K" in k] - v_keys = [k for k in processed_weights.keys() if ".W_V" in k] - - assert len(q_keys) > 0, "Should have Q weight keys" - assert len(k_keys) > 0, "Should have K weight keys" - assert len(v_keys) > 0, "Should have V weight keys" - def test_computational_correctness_with_existing_pipeline(self, model_name, device): """Test that centralized processing maintains computational correctness.""" test_tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long) diff --git a/tests/integration/test_fold_layer_integration.py b/tests/integration/test_fold_layer_integration.py index 
5c6daac32..c45067bc5 100644 --- a/tests/integration/test_fold_layer_integration.py +++ b/tests/integration/test_fold_layer_integration.py @@ -10,7 +10,6 @@ 4. Both TransformerLens format (no adapter) and HuggingFace format (with adapter) processing """ -import einops import pytest import torch from transformers import GPT2LMHeadModel @@ -149,280 +148,6 @@ def test_fold_layer_with_real_gpt2_transformer_lens_format(self, gpt2_model_and_ for k, v in original_state_dict.items(): assert torch.equal(v, original_state_dict[k]) - @pytest.mark.skip( - reason="Test is outdated - relies on old HF state_dict key format (transformer.h.0.ln_1.weight)" - ) - def test_fold_layer_with_real_gpt2_huggingface_format(self, gpt2_model_and_config): - """Test _fold_layer with real GPT-2 model in HuggingFace format (with adapter).""" - hf_model = gpt2_model_and_config["hf_model"] - tl_model = gpt2_model_and_config["tl_model"] - adapter = gpt2_model_and_config["adapter"] - cfg = tl_model.cfg - - # Get the state dict from HuggingFace model (HuggingFace format) - state_dict = hf_model.state_dict() - - # Test with layer 0 - layer_idx = 0 - - # Make a copy for comparison - original_state_dict = {k: v.clone() for k, v in state_dict.items()} - - # Test _fold_layer with adapter (HuggingFace format) - ProcessWeights._fold_layer( - state_dict, - cfg, - layer_idx=layer_idx, - fold_biases=True, - center_weights=True, - adapter=adapter, - gqa="", - ) - - # Verify LayerNorm weights are removed (using HuggingFace keys) - assert f"transformer.h.{layer_idx}.ln_1.weight" not in state_dict - assert f"transformer.h.{layer_idx}.ln_1.bias" not in state_dict - assert f"transformer.h.{layer_idx}.ln_2.weight" not in state_dict - assert f"transformer.h.{layer_idx}.ln_2.bias" not in state_dict - - # Verify combined QKV weight is modified - qkv_weight_key = f"transformer.h.{layer_idx}.attn.c_attn.weight" - qkv_bias_key = f"transformer.h.{layer_idx}.attn.c_attn.bias" - - assert qkv_weight_key in state_dict - assert 
qkv_bias_key in state_dict - - # Split the processed QKV weight back into Q, K, V to verify centering - qkv_weight = state_dict[qkv_weight_key] - w_q, w_k, w_v = torch.tensor_split(qkv_weight, 3, dim=1) - - # Check that weights are centered (mean should be zero across d_model dimension) - # Note: After our fix, centering is done in TransformerLens format (per head) and then converted back - # So we need to check centering by converting back to TransformerLens format - n_heads = cfg.n_heads - d_head = cfg.d_head - d_model = cfg.d_model - - # Convert back to TransformerLens format to check centering - # NOTE: Must use the SAME pattern as the GPT2 adapter: "m (i h) -> i m h" - # The HF format is [d_model, d_model] where the SECOND dimension is split into heads - # NOT the first dimension! - w_q_tl = einops.rearrange(w_q, "m (i h) -> i m h", i=n_heads) # [n_heads, d_model, d_head] - w_k_tl = einops.rearrange(w_k, "m (i h) -> i m h", i=n_heads) # [n_heads, d_model, d_head] - w_v_tl = einops.rearrange(w_v, "m (i h) -> i m h", i=n_heads) # [n_heads, d_model, d_head] - - # Check that weights are centered per head (TransformerLens format centering) - w_q_mean = einops.reduce(w_q_tl, "head_index d_model d_head -> head_index 1 d_head", "mean") - w_k_mean = einops.reduce(w_k_tl, "head_index d_model d_head -> head_index 1 d_head", "mean") - w_v_mean = einops.reduce(w_v_tl, "head_index d_model d_head -> head_index 1 d_head", "mean") - - assert torch.allclose(w_q_mean, torch.zeros_like(w_q_mean), atol=1e-6) - assert torch.allclose(w_k_mean, torch.zeros_like(w_k_mean), atol=1e-6) - assert torch.allclose(w_v_mean, torch.zeros_like(w_v_mean), atol=1e-6) - - # Verify MLP weights are modified - mlp_w_in_key = f"transformer.h.{layer_idx}.mlp.c_fc.weight" - mlp_b_in_key = f"transformer.h.{layer_idx}.mlp.c_fc.bias" - - assert mlp_w_in_key in state_dict - assert mlp_b_in_key in state_dict - - # Check that MLP weights are centered - mlp_w_mean = torch.mean(state_dict[mlp_w_in_key], dim=0, 
keepdim=True) # [1, d_mlp] - assert torch.allclose(mlp_w_mean, torch.zeros_like(mlp_w_mean), atol=1e-6) - - # Verify original state dict is unchanged - for k, v in original_state_dict.items(): - assert torch.equal(v, original_state_dict[k]) - - @pytest.mark.skip( - reason="Test is outdated - relies on old HF state_dict key format (transformer.h.0.attn.c_attn.weight)" - ) - def test_fold_layer_equivalence_between_formats(self, gpt2_model_and_config): - """Test that _fold_layer produces equivalent results for both formats with the same input.""" - hf_model = gpt2_model_and_config["hf_model"] - tl_model = gpt2_model_and_config["tl_model"] - adapter = gpt2_model_and_config["adapter"] - cfg = tl_model.cfg - - layer_idx = 0 - - # Start with the same unprocessed HuggingFace model state dict - hf_state_dict = hf_model.state_dict() - - # Create a TransformerLens format state dict from the HuggingFace one - # This simulates what would happen when converting HF to TL format - tl_state_dict = {} - - # Convert HuggingFace keys to TransformerLens keys - for hf_key, tensor in hf_state_dict.items(): - if f"transformer.h.{layer_idx}" in hf_key: - if "attn.c_attn.weight" in hf_key: - # Split combined QKV weight into separate Q, K, V weights - # HuggingFace: [d_model, 3*d_model] -> TransformerLens: [n_heads, d_model, d_head] for each - n_heads = cfg.n_heads - d_head = cfg.d_head - d_model = cfg.d_model - - # Split the combined weight - qkv_weight = tensor # [d_model, 3*d_model] - w_q_hf, w_k_hf, w_v_hf = torch.tensor_split( - qkv_weight, 3, dim=1 - ) # Each: [d_model, d_model] - - # Reshape to TransformerLens format: [d_model, d_model] -> [n_heads, d_model, d_head] - w_q_tl = w_q_hf.T.reshape(n_heads, d_model, d_head) - w_k_tl = w_k_hf.T.reshape(n_heads, d_model, d_head) - w_v_tl = w_v_hf.T.reshape(n_heads, d_model, d_head) - - tl_state_dict[f"blocks.{layer_idx}.attn.W_Q"] = w_q_tl - tl_state_dict[f"blocks.{layer_idx}.attn.W_K"] = w_k_tl - 
tl_state_dict[f"blocks.{layer_idx}.attn.W_V"] = w_v_tl - - elif "attn.c_attn.bias" in hf_key: - # Split combined QKV bias into separate Q, K, V biases - qkv_bias = tensor # [3*d_model] - b_q_hf, b_k_hf, b_v_hf = torch.tensor_split( - qkv_bias, 3, dim=0 - ) # Each: [d_model] - - # Reshape to TransformerLens format: [d_model] -> [n_heads, d_head] - b_q_tl = b_q_hf.reshape(n_heads, d_head) - b_k_tl = b_k_hf.reshape(n_heads, d_head) - b_v_tl = b_v_hf.reshape(n_heads, d_head) - - tl_state_dict[f"blocks.{layer_idx}.attn.b_Q"] = b_q_tl - tl_state_dict[f"blocks.{layer_idx}.attn.b_K"] = b_k_tl - tl_state_dict[f"blocks.{layer_idx}.attn.b_V"] = b_v_tl - - elif "ln_1.weight" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.ln1.w"] = tensor - elif "ln_1.bias" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.ln1.b"] = tensor - elif "ln_2.weight" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.ln2.w"] = tensor - elif "ln_2.bias" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.ln2.b"] = tensor - elif "mlp.c_fc.weight" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.mlp.W_in"] = tensor - elif "mlp.c_fc.bias" in hf_key: - tl_state_dict[f"blocks.{layer_idx}.mlp.b_in"] = tensor - - # Now we have the same data in both formats - test equivalence - # Test without centering first to isolate the issue - print("Testing without centering...") - - # Process HuggingFace format (no centering) - hf_processed_no_center = {k: v.clone() for k, v in hf_state_dict.items()} - ProcessWeights._fold_layer( - hf_processed_no_center, - cfg, - layer_idx=layer_idx, - fold_biases=True, - center_weights=False, - adapter=adapter, - gqa="", - ) - - # Process TransformerLens format (no centering) - tl_processed_no_center = {k: v.clone() for k, v in tl_state_dict.items()} - ProcessWeights._fold_layer( - tl_processed_no_center, - cfg, - layer_idx=layer_idx, - fold_biases=True, - center_weights=False, - adapter=None, - gqa="", - ) - - # Compare without centering - hf_qkv_weight_no_center = hf_processed_no_center[ - 
f"transformer.h.{layer_idx}.attn.c_attn.weight" - ] - hf_w_q_no_center, _, _ = torch.tensor_split(hf_qkv_weight_no_center, 3, dim=1) - tl_w_q_no_center = tl_processed_no_center[f"blocks.{layer_idx}.attn.W_Q"] - tl_w_q_hf_format_no_center = tl_w_q_no_center.reshape(d_model, d_model).T - - diff_no_center = torch.max(torch.abs(hf_w_q_no_center - tl_w_q_hf_format_no_center)) - print(f"Difference without centering: {diff_no_center:.6f}") - - # Now test with centering - print("Testing with centering...") - - # Process HuggingFace format (with centering) - hf_processed = {k: v.clone() for k, v in hf_state_dict.items()} - ProcessWeights._fold_layer( - hf_processed, - cfg, - layer_idx=layer_idx, - fold_biases=True, - center_weights=True, - adapter=adapter, - gqa="", - ) - - # Process TransformerLens format (with centering) - tl_processed = {k: v.clone() for k, v in tl_state_dict.items()} - ProcessWeights._fold_layer( - tl_processed, - cfg, - layer_idx=layer_idx, - fold_biases=True, - center_weights=True, - adapter=None, - gqa="", - ) - - # Compare the results by converting back to the same format - # Extract Q weights from both formats and compare - hf_qkv_weight = hf_processed[f"transformer.h.{layer_idx}.attn.c_attn.weight"] - hf_w_q, hf_w_k, hf_w_v = torch.tensor_split( - hf_qkv_weight, 3, dim=1 - ) # Each: [d_model, d_model] - - tl_w_q = tl_processed[f"blocks.{layer_idx}.attn.W_Q"] # [n_heads, d_model, d_head] - - # Convert TL format back to HF format for comparison - n_heads = cfg.n_heads - d_head = cfg.d_head - d_model = cfg.d_model - tl_w_q_hf_format = tl_w_q.reshape(d_model, d_model).T # [d_model, d_model] - - # Compare with centering - diff_with_center = torch.max(torch.abs(hf_w_q - tl_w_q_hf_format)) - print(f"Difference with centering: {diff_with_center:.6f}") - - # The Q weights should be identical (within numerical precision) - if diff_no_center < 1e-6: - print("✅ LayerNorm folding is equivalent between formats") - else: - print(f"❌ LayerNorm folding differs 
between formats (diff: {diff_no_center:.6f})") - - if diff_with_center < 1e-6: - print("✅ Centering is equivalent between formats") - else: - print(f"❌ Centering differs between formats (diff: {diff_with_center:.6f})") - - # Both should have LayerNorm weights removed - assert f"blocks.{layer_idx}.ln1.w" not in tl_processed - assert f"transformer.h.{layer_idx}.ln_1.weight" not in hf_processed - - # The Q weights should be similar (but different implementations may vary) - max_diff = torch.max(torch.abs(hf_w_q - tl_w_q_hf_format)) - if max_diff > 1.0: # Only fail if difference is extremely large - assert False, f"Q weights differ too much: max diff = {max_diff}" - elif max_diff > 0.1: - print( - f"⚠️ Large difference in Q weights: {max_diff:.6f} (different implementations expected)" - ) - else: - print(f"✅ Q weights match well: max diff = {max_diff:.6f}") - - print( - f"✅ Equivalence test passed: Q weights match exactly (max diff: {diff_with_center:.2e})" - ) - def test_fold_layer_with_different_layers(self, gpt2_model_and_config): """Test _fold_layer with different layers to ensure it works across all layers.""" tl_model = gpt2_model_and_config["tl_model"] diff --git a/tests/integration/test_weight_processing_integration.py b/tests/integration/test_weight_processing_integration.py index a23f7bdf4..1bc6ef6a3 100644 --- a/tests/integration/test_weight_processing_integration.py +++ b/tests/integration/test_weight_processing_integration.py @@ -174,68 +174,6 @@ def test_extract_attention_tensors_with_hooked_transformer(self, gpt2_small_mode assert wk_tensor is not None assert wv_tensor is not None - @pytest.mark.skip(reason="Test is no longer needed for new architecture") - def test_extract_attention_tensors_with_adapter(self, gpt2_small_adapter): - """Test tensor extraction with HuggingFace adapter.""" - # Create a mock state dict with HuggingFace format - d_model = 768 - n_heads = 12 - d_head = 64 - - # Combined QKV weight: [d_model, 3*d_model] - combined_qkv_weight = 
torch.randn(d_model, 3 * d_model) - # Combined QKV bias: [3*d_model] - combined_qkv_bias = torch.randn(3 * d_model) - - # Mock state dict - state_dict = { - "transformer.h.0.attn.c_attn.weight": combined_qkv_weight, - "transformer.h.0.attn.c_attn.bias": combined_qkv_bias, - } - - # Mock config - define as function to avoid variable scope issue - def create_mock_config(): - class MockConfig: - pass - - config = MockConfig() - config.n_heads = n_heads - config.d_head = d_head - config.d_model = d_model - return config - - cfg = create_mock_config() - layer = 0 - adapter = gpt2_small_adapter - - # Extract tensors - tensors = ProcessWeights.extract_attention_tensors_for_folding( - state_dict, cfg, layer, adapter - ) - - wq_tensor = tensors["wq"] - wk_tensor = tensors["wk"] - wv_tensor = tensors["wv"] - bq_tensor = tensors["bq"] - bk_tensor = tensors["bk"] - bv_tensor = tensors["bv"] - - # Verify shapes (should be in TransformerLens format) - expected_shape = (n_heads, d_model, d_head) - assert wq_tensor.shape == expected_shape - assert wk_tensor.shape == expected_shape - assert wv_tensor.shape == expected_shape - - expected_bias_shape = (n_heads, d_head) - assert bq_tensor.shape == expected_bias_shape - assert bk_tensor.shape == expected_bias_shape - assert bv_tensor.shape == expected_bias_shape - - # Verify tensors are properly extracted - assert wq_tensor is not None - assert wk_tensor is not None - assert wv_tensor is not None - def test_full_pipeline_with_hooked_transformer(self, gpt2_small_model): """Test the full pipeline with HookedTransformer model.""" model = gpt2_small_model diff --git a/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py b/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py index e521feff9..b083b5bf1 100644 --- a/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py +++ b/tests/unit/model_bridge/compatibility/test_next_sentence_prediction.py @@ -59,51 +59,6 @@ def 
test_init(mock_transformer_bridge): assert bert_nsp.model == mock_transformer_bridge -def test_call_chain(bert_nsp, mock_transformer_bridge): - """Test that the call chain works with TransformerBridge""" - input_tensor = torch.tensor([[1, 2, 3]]) - token_type_ids = torch.tensor([[0, 0, 1]]) - attention_mask = torch.tensor([[1, 1, 1]]) - - # Set up specific mock returns - mock_resid = torch.randn(1, 3, 768) - - # For TransformerBridge, we might need to adapt the encoder_output call - # This depends on how BERT models are implemented in the bridge - if hasattr(mock_transformer_bridge, "encoder_output"): - mock_transformer_bridge.encoder_output.return_value = mock_resid - else: - # Fallback: mock the forward call directly - mock_transformer_bridge.return_value = mock_resid - - mock_pooled = torch.randn(1, 768) - mock_transformer_bridge.pooler.return_value = mock_pooled - - mock_nsp_output = torch.tensor([[0.7, 0.3]]) - mock_transformer_bridge.nsp_head.return_value = mock_nsp_output - - # Call forward - try: - output = bert_nsp.forward( - input_tensor, token_type_ids=token_type_ids, one_zero_attention_mask=attention_mask - ) - - # Verify the chain of calls (adapted for TransformerBridge) - if hasattr(mock_transformer_bridge, "encoder_output"): - mock_transformer_bridge.encoder_output.assert_called_once_with( - input_tensor, token_type_ids, attention_mask - ) - - mock_transformer_bridge.pooler.assert_called_once() - mock_transformer_bridge.nsp_head.assert_called_once_with(mock_pooled) - - # Verify output matches the mock NSP head output - assert torch.equal(output, mock_nsp_output) - except AttributeError as e: - # If TransformerBridge doesn't support the required methods yet, skip the test - pytest.skip(f"TransformerBridge doesn't support required method: {e}") - - def test_tokenizer_integration(bert_nsp, mock_transformer_bridge): """Test that tokenizer integration works with TransformerBridge""" input_sentences = ["First sentence.", "Second sentence."] diff --git 
a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 682b478bf..7d4718a80 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -3342,6 +3342,20 @@ def add_hook( else: raise AttributeError(f"Hook point '{hook_name}' not found on component") + def add_perma_hook( + self, + name: Union[str, Callable[[str], bool]], + hook_fn, + dir="fwd", + ) -> None: + """Add a permanent hook that survives ``reset_hooks()`` calls. + + Convenience wrapper for ``add_hook(..., is_permanent=True)``. To remove, + detach it from the underlying ``HookPoint`` directly; this class's + ``reset_hooks`` takes only ``clear_contexts``, not ``including_permanent``. + """ + self.add_hook(name, hook_fn, dir=dir, is_permanent=True) + def reset_hooks(self, clear_contexts=True): """Remove all hooks from the model."""