diff --git a/tests/unit/model_bridge/supported_architectures/test_bloom_adapter.py b/tests/unit/model_bridge/supported_architectures/test_bloom_adapter.py
new file mode 100644
index 000000000..efea55a09
--- /dev/null
+++ b/tests/unit/model_bridge/supported_architectures/test_bloom_adapter.py
@@ -0,0 +1,374 @@
+"""Unit tests for BloomArchitectureAdapter."""
+
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+import torch
+import torch.nn as nn
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
+from transformer_lens.conversion_utils.param_processing_conversion import (
+    ParamProcessingConversion,
+)
+from transformer_lens.model_bridge.generalized_components import (
+    BloomAttentionBridge,
+    BloomBlockBridge,
+    BloomMLPBridge,
+    EmbeddingBridge,
+    NormalizationBridge,
+    UnembeddingBridge,
+)
+from transformer_lens.model_bridge.supported_architectures.bloom import (
+    BloomArchitectureAdapter,
+)
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+def _make_cfg(
+    n_heads: int = 8,
+    d_model: int = 64,
+    n_layers: int = 2,
+    d_vocab: int = 100,
+    n_ctx: int = 128,
+) -> TransformerBridgeConfig:
+    """Return a minimal config for Bloom adapter tests."""
+    return TransformerBridgeConfig(
+        d_model=d_model,
+        d_head=d_model // n_heads,
+        n_layers=n_layers,
+        n_ctx=n_ctx,
+        n_heads=n_heads,
+        d_vocab=d_vocab,
+        architecture="BloomForCausalLM",
+    )
+
+
+@pytest.fixture
+def cfg() -> TransformerBridgeConfig:
+    """Default Bloom test config."""
+    return _make_cfg()
+
+
+@pytest.fixture
+def adapter(cfg: TransformerBridgeConfig) -> BloomArchitectureAdapter:
+    """Bloom adapter built from the ``cfg`` fixture."""
+    return BloomArchitectureAdapter(cfg)
+
+
+def _make_qkv_component(d_model: int) -> Any:
+    """Return a stand-in module with one ``query_key_value`` Linear (d_model -> 3 * d_model)."""
+    return SimpleNamespace(
+        query_key_value=nn.Linear(d_model, 3 * d_model, bias=True),
+    )
+
+
+# ---------------------------------------------------------------------------
+# Config attribute tests
+# ---------------------------------------------------------------------------
+
+
+class TestBloomAdapterConfig:
+    def test_normalization_type(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.normalization_type == "LN"
+
+    def test_positional_embedding_type(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.positional_embedding_type == "alibi"
+
+    def test_final_rms(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.final_rms is False
+
+    def test_gated_mlp(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.gated_mlp is False
+
+    def test_attn_only(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.attn_only is False
+
+    def test_default_prepend_bos(self, adapter: BloomArchitectureAdapter) -> None:
+        assert adapter.cfg.default_prepend_bos is False
+
+
+# ---------------------------------------------------------------------------
+# Component mapping tests
+# ---------------------------------------------------------------------------
+
+
+class TestBloomAdapterComponentMapping:
+    @staticmethod
+    def _mapping(adapter: BloomArchitectureAdapter) -> dict[str, Any]:
+        mapping = adapter.component_mapping
+        assert mapping is not None
+        return mapping
+
+    def test_embed_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        mapping = self._mapping(adapter)
+        assert isinstance(mapping["embed"], EmbeddingBridge)
+        assert mapping["embed"].name == "transformer.word_embeddings"
+
+    def test_embed_ln_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        mapping = self._mapping(adapter)
+        assert isinstance(mapping["embed_ln"], NormalizationBridge)
+        assert mapping["embed_ln"].name == "transformer.word_embeddings_layernorm"
+
+    def test_blocks_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        mapping = self._mapping(adapter)
+        assert isinstance(mapping["blocks"], BloomBlockBridge)
+        assert mapping["blocks"].name == "transformer.h"
+
+    def test_ln_final_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        mapping = self._mapping(adapter)
+        assert isinstance(mapping["ln_final"], NormalizationBridge)
+        assert mapping["ln_final"].name == "transformer.ln_f"
+
+    def test_unembed_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        mapping = self._mapping(adapter)
+        assert isinstance(mapping["unembed"], UnembeddingBridge)
+        assert mapping["unembed"].name == "lm_head"
+
+    def test_ln1_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert isinstance(blocks.submodules["ln1"], NormalizationBridge)
+        assert blocks.submodules["ln1"].name == "input_layernorm"
+
+    def test_ln2_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert isinstance(blocks.submodules["ln2"], NormalizationBridge)
+        assert blocks.submodules["ln2"].name == "post_attention_layernorm"
+
+    def test_attn_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert isinstance(blocks.submodules["attn"], BloomAttentionBridge)
+        assert blocks.submodules["attn"].name == "self_attention"
+
+    def test_attn_qkv_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert blocks.submodules["attn"].submodules["qkv"].name == "query_key_value"
+
+    def test_attn_o_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert blocks.submodules["attn"].submodules["o"].name == "dense"
+
+    def test_mlp_type_and_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert isinstance(blocks.submodules["mlp"], BloomMLPBridge)
+        assert blocks.submodules["mlp"].name == "mlp"
+
+    def test_mlp_in_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert blocks.submodules["mlp"].submodules["in"].name == "dense_h_to_4h"
+
+    def test_mlp_out_name(self, adapter: BloomArchitectureAdapter) -> None:
+        blocks = self._mapping(adapter)["blocks"]
+        assert blocks.submodules["mlp"].submodules["out"].name == "dense_4h_to_h"
+
+
+# ---------------------------------------------------------------------------
+# Weight conversion tests
+# ---------------------------------------------------------------------------
+
+
+class TestBloomWeightConversions:
+    def test_four_conversion_keys(self, adapter: BloomArchitectureAdapter) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+        assert len(convs) == 4
+
+    def test_qkvo_keys_present(self, adapter: BloomArchitectureAdapter) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+
+        for key in [
+            "blocks.{i}.attn.q",
+            "blocks.{i}.attn.k",
+            "blocks.{i}.attn.v",
+            "blocks.{i}.attn.o",
+        ]:
+            assert key in convs
+
+    def test_q_conversion_type(self, adapter: BloomArchitectureAdapter) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+
+        conv = convs["blocks.{i}.attn.q"]
+
+        assert isinstance(conv, ParamProcessingConversion)
+        assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
+
+    def test_q_rearrange_pattern(self, adapter: BloomArchitectureAdapter) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+
+        conv = convs["blocks.{i}.attn.q"]
+
+        assert isinstance(conv, ParamProcessingConversion)
+        assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
+
+        assert conv.tensor_conversion.pattern == "(n h) m -> n m h"
+
+    def test_o_rearrange_pattern(self, adapter: BloomArchitectureAdapter) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+
+        conv = convs["blocks.{i}.attn.o"]
+
+        assert isinstance(conv, ParamProcessingConversion)
+        assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
+
+        assert conv.tensor_conversion.pattern == "m (n h) -> n h m"
+
+    def test_q_rearrange_n_equals_n_heads(
+        self,
+        adapter: BloomArchitectureAdapter,
+    ) -> None:
+        convs = adapter.weight_processing_conversions
+        assert convs is not None
+
+        conv = convs["blocks.{i}.attn.q"]
+
+        assert isinstance(conv, ParamProcessingConversion)
+        assert isinstance(conv.tensor_conversion, RearrangeTensorConversion)
+
+        assert conv.tensor_conversion.axes_lengths["n"] == adapter.cfg.n_heads
+
+
+# ---------------------------------------------------------------------------
+# split_qkv_matrix tests
+# ---------------------------------------------------------------------------
+
+
+class TestBloomSplitQKV:
+    def _adapter(self) -> BloomArchitectureAdapter:
+        return BloomArchitectureAdapter(_make_cfg())
+
+    def test_returns_three_linears(self) -> None:
+        adapter = self._adapter()
+
+        component = _make_qkv_component(64)
+
+        q, k, v = adapter.split_qkv_matrix(component)
+
+        assert isinstance(q, nn.Linear)
+        assert isinstance(k, nn.Linear)
+        assert isinstance(v, nn.Linear)
+
+    def test_output_shapes(self) -> None:
+        adapter = self._adapter()
+
+        component = _make_qkv_component(64)
+
+        q, k, v = adapter.split_qkv_matrix(component)
+
+        assert q.weight.shape == (64, 64)
+        assert k.weight.shape == (64, 64)
+        assert v.weight.shape == (64, 64)
+
+    def test_biases_present(self) -> None:
+        adapter = self._adapter()
+
+        component = _make_qkv_component(64)
+
+        q, k, v = adapter.split_qkv_matrix(component)
+
+        assert q.bias is not None
+        assert k.bias is not None
+        assert v.bias is not None
+
+    def test_interleaved_split_correctness(self) -> None:
+        """
+        Bloom stores QKV interleaved:
+        [Q0,K0,V0,Q1,K1,V1,...]
+        """
+
+        d_model = 12
+        n_heads = 3
+        d_head = d_model // n_heads
+
+        cfg = _make_cfg(
+            n_heads=n_heads,
+            d_model=d_model,
+        )
+
+        adapter = BloomArchitectureAdapter(cfg)
+
+        component = _make_qkv_component(d_model)
+
+        W = torch.zeros(3 * d_model, d_model)
+
+        for h in range(n_heads):
+            start = h * 3 * d_head
+
+            W[start : start + d_head] = 1.0
+            W[start + d_head : start + 2 * d_head] = 2.0
+            W[start + 2 * d_head : start + 3 * d_head] = 3.0
+
+        component.query_key_value.weight = nn.Parameter(W)
+
+        bias = torch.zeros(3 * d_model)
+
+        for h in range(n_heads):
+            start = h * 3 * d_head
+
+            bias[start : start + d_head] = 1.0
+            bias[start + d_head : start + 2 * d_head] = 2.0
+            bias[start + 2 * d_head : start + 3 * d_head] = 3.0
+
+        component.query_key_value.bias = nn.Parameter(bias)
+
+        q, k, v = adapter.split_qkv_matrix(component)
+
+        assert torch.all(q.weight == 1.0)
+        assert torch.all(k.weight == 2.0)
+        assert torch.all(v.weight == 3.0)
+
+        assert torch.all(q.bias == 1.0)
+        assert torch.all(k.bias == 2.0)
+        assert torch.all(v.bias == 3.0)
+
+    def test_forward_shapes(self) -> None:
+        adapter = self._adapter()
+
+        component = _make_qkv_component(64)
+
+        q, k, v = adapter.split_qkv_matrix(component)
+
+        x = torch.randn(2, 5, 64)
+
+        assert q(x).shape == (2, 5, 64)
+        assert k(x).shape == (2, 5, 64)
+        assert v(x).shape == (2, 5, 64)
+
+
+# ---------------------------------------------------------------------------
+# Factory registration tests
+# ---------------------------------------------------------------------------
+
+
+class TestBloomFactoryRegistration:
+    def test_factory_key(self) -> None:
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+
+        assert "BloomForCausalLM" in SUPPORTED_ARCHITECTURES
+
+    def test_factory_returns_bloom_adapter(self) -> None:
+        from transformer_lens.factories.architecture_adapter_factory import (
+            ArchitectureAdapterFactory,
+        )
+
+        cfg = _make_cfg()
+
+        adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
+
+        assert isinstance(adapter, BloomArchitectureAdapter)
+
+    def test_import_from_init(self) -> None:
+        from transformer_lens.model_bridge.supported_architectures import (
+            BloomArchitectureAdapter as FromInit,
+        )
+
+        assert FromInit is BloomArchitectureAdapter
diff --git a/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py b/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py
new file mode 100644
index 000000000..689152341
--- /dev/null
+++ b/tests/unit/model_bridge/supported_architectures/test_llama_adapter.py
@@ -0,0 +1,502 @@
+"""Unit tests for LlamaArchitectureAdapter.
+
+Tests cover:
+- Config attribute validation
+- Component mapping structure
+- Weight conversion keys
+- GQA support
+- Rotary embedding setup
+- Factory registration
+"""
+
+from types import SimpleNamespace
+from typing import Any
+
+import pytest
+
+from transformer_lens.config import TransformerBridgeConfig
+from transformer_lens.model_bridge.generalized_components import (
+    BlockBridge,
+    EmbeddingBridge,
+    GatedMLPBridge,
+    PositionEmbeddingsAttentionBridge,
+    RMSNormalizationBridge,
+    RotaryEmbeddingBridge,
+    UnembeddingBridge,
+)
+from transformer_lens.model_bridge.supported_architectures.llama import (
+    LlamaArchitectureAdapter,
+)
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+def _make_cfg(
+    n_heads: int = 8,
+    d_model: int = 128,
+    n_layers: int = 2,
+    d_vocab: int = 1000,
+    n_key_value_heads: int | None = None,
+) -> TransformerBridgeConfig:
+    """Return minimal config for Llama adapter tests."""
+    cfg = TransformerBridgeConfig(
+        d_model=d_model,
+        d_head=d_model // n_heads,
+        n_layers=n_layers,
+        n_ctx=512,
+        n_heads=n_heads,
+        d_vocab=d_vocab,
+        default_prepend_bos=True,
+        architecture="LlamaForCausalLM",
+    )
+
+    if n_key_value_heads is not None:
+        cfg.n_key_value_heads = n_key_value_heads
+
+    return cfg
+
+
+@pytest.fixture
+def cfg() -> TransformerBridgeConfig:
+    return _make_cfg()
+
+
+@pytest.fixture
+def adapter(cfg: TransformerBridgeConfig) -> LlamaArchitectureAdapter:
+    return LlamaArchitectureAdapter(cfg)
+
+
+# ---------------------------------------------------------------------------
+# Config attribute tests
+# ---------------------------------------------------------------------------
+
+
+class TestLlamaAdapterConfig:
+    """Tests config values set by the adapter."""
+
+    def test_normalization_type_is_rms(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.normalization_type == "RMS"
+
+    def test_positional_embedding_type_is_rotary(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.positional_embedding_type == "rotary"
+
+    def test_final_rms_is_true(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.final_rms is True
+
+    def test_gated_mlp_is_true(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.gated_mlp is True
+
+    def test_attn_only_is_false(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.attn_only is False
+
+    def test_uses_rms_norm_is_true(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.uses_rms_norm is True
+
+    def test_eps_attr_is_variance_epsilon(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        assert adapter.cfg.eps_attr == "variance_epsilon"
+
+
+# ---------------------------------------------------------------------------
+# GQA tests
+# ---------------------------------------------------------------------------
+
+
+class TestLlamaGQA:
+    """Tests grouped query attention support."""
+
+    def test_n_key_value_heads_added_to_default_config(self) -> None:
+        cfg = _make_cfg(n_key_value_heads=4)
+        adapter = LlamaArchitectureAdapter(cfg)
+
+        assert "n_key_value_heads" in adapter.default_config
+        assert adapter.default_config["n_key_value_heads"] == 4
+
+    def test_n_key_value_heads_set_on_cfg(self) -> None:
+        cfg = _make_cfg(n_key_value_heads=4)
+        adapter = LlamaArchitectureAdapter(cfg)
+
+        assert adapter.cfg.n_key_value_heads == 4
+
+
+# ---------------------------------------------------------------------------
+# Component mapping structure tests
+# ---------------------------------------------------------------------------
+
+
+class TestLlamaComponentMapping:
+    """Tests component mapping structure."""
+
+    def _blocks(self, adapter: LlamaArchitectureAdapter) -> BlockBridge:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+        blocks = component_mapping["blocks"]
+        assert isinstance(blocks, BlockBridge)
+        return blocks
+
+    def test_embed_is_embedding_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert isinstance(
+            component_mapping["embed"],
+            EmbeddingBridge,
+        )
+
+    def test_embed_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert (
+            component_mapping["embed"].name
+            == "model.embed_tokens"
+        )
+
+    def test_rotary_emb_is_rotary_embedding_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert isinstance(
+            component_mapping["rotary_emb"],
+            RotaryEmbeddingBridge,
+        )
+
+    def test_rotary_emb_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert (
+            component_mapping["rotary_emb"].name
+            == "model.rotary_emb"
+        )
+
+    def test_blocks_is_block_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+        assert isinstance(blocks, BlockBridge)
+
+    def test_blocks_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        assert blocks.name == "model.layers"
+
+    def test_ln1_is_rms_norm_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        assert isinstance(
+            blocks.submodules["ln1"],
+            RMSNormalizationBridge,
+        )
+
+    def test_ln2_is_rms_norm_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        assert isinstance(
+            blocks.submodules["ln2"],
+            RMSNormalizationBridge,
+        )
+
+    def test_attn_is_position_embeddings_attention_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        assert isinstance(
+            blocks.submodules["attn"],
+            PositionEmbeddingsAttentionBridge,
+        )
+
+    def test_mlp_is_gated_mlp_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        assert isinstance(
+            blocks.submodules["mlp"],
+            GatedMLPBridge,
+        )
+
+    def test_q_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        attn = blocks.submodules["attn"]
+
+        assert attn.submodules["q"].name == "q_proj"
+
+    def test_k_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        attn = blocks.submodules["attn"]
+
+        assert attn.submodules["k"].name == "k_proj"
+
+    def test_v_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        attn = blocks.submodules["attn"]
+
+        assert attn.submodules["v"].name == "v_proj"
+
+    def test_o_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        attn = blocks.submodules["attn"]
+
+        assert attn.submodules["o"].name == "o_proj"
+
+    def test_gate_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        mlp = blocks.submodules["mlp"]
+
+        assert mlp.submodules["gate"].name == "gate_proj"
+
+    def test_up_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        mlp = blocks.submodules["mlp"]
+
+        assert mlp.submodules["in"].name == "up_proj"
+
+    def test_down_proj_name(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        blocks = self._blocks(adapter)
+
+        mlp = blocks.submodules["mlp"]
+
+        assert mlp.submodules["out"].name == "down_proj"
+
+    def test_ln_final_is_rms_norm_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert isinstance(
+            component_mapping["ln_final"],
+            RMSNormalizationBridge,
+        )
+
+    def test_unembed_is_unembedding_bridge(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        component_mapping = adapter.component_mapping
+        assert component_mapping is not None
+
+        assert isinstance(
+            component_mapping["unembed"],
+            UnembeddingBridge,
+        )
+
+
+# ---------------------------------------------------------------------------
+# Weight conversion tests
+# ---------------------------------------------------------------------------
+
+
+class TestLlamaWeightConversions:
+    """Tests expected conversion keys exist."""
+
+    def _conversions(
+        self,
+        adapter: LlamaArchitectureAdapter,
+    ) -> dict[str, Any]:
+        conversions = adapter.weight_processing_conversions
+        assert conversions is not None
+        return conversions
+
+    def test_q_weight_key_present(
+        self,
+        adapter: LlamaArchitectureAdapter,
+    ) -> None:
+        conversions = self._conversions(adapter)
+
+        assert (
+            "blocks.{i}.attn.q.weight"
+            in conversions
+        )
+
+    def test_k_weight_key_present(
+        self,
+        adapter: LlamaArchitectureAdapter,
+    ) -> None:
+        conversions = self._conversions(adapter)
+
+        assert (
+            "blocks.{i}.attn.k.weight"
+            in conversions
+        )
+
+    def test_v_weight_key_present(
+        self,
+        adapter: LlamaArchitectureAdapter,
+    ) -> None:
+        conversions = self._conversions(adapter)
+
+        assert (
+            "blocks.{i}.attn.v.weight"
+            in conversions
+        )
+
+    def test_o_weight_key_present(
+        self,
+        adapter: LlamaArchitectureAdapter,
+    ) -> None:
+        conversions = self._conversions(adapter)
+
+        assert (
+            "blocks.{i}.attn.o.weight"
+            in conversions
+        )
+
+
+# ---------------------------------------------------------------------------
+# setup_component_testing tests
+# ---------------------------------------------------------------------------
+
+
+class DummyAttention:
+    """Attention stand-in that records the rotary embedding passed to set_rotary_emb."""
+
+    def __init__(self) -> None:
+        self.rotary_emb = None
+
+    def set_rotary_emb(self, rotary_emb: Any) -> None:
+        self.rotary_emb = rotary_emb
+
+
+class DummyBlock:
+    """Block stand-in holding a DummyAttention as ``attn``."""
+
+    def __init__(self) -> None:
+        self.attn = DummyAttention()
+
+
+class DummyBridgeModel:
+    """Bridge-model stand-in exposing ``blocks`` as a list of DummyBlock."""
+
+    def __init__(self, n_layers: int = 2) -> None:
+        self.blocks = [DummyBlock() for _ in range(n_layers)]
+
+
+class TestLlamaSetupComponentTesting:
+    """Tests rotary embedding setup."""
+
+    def test_rotary_emb_set_on_bridge_model_blocks(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        rotary_emb = object()
+
+        hf_model = SimpleNamespace(
+            model=SimpleNamespace(
+                rotary_emb=rotary_emb
+            )
+        )
+
+        bridge_model = DummyBridgeModel()
+
+        adapter.setup_component_testing(
+            hf_model,
+            bridge_model,
+        )
+
+        for block in bridge_model.blocks:
+            assert block.attn.rotary_emb is rotary_emb
+
+    def test_template_attention_bridge_accepts_rotary_embedding(
+        self, adapter: LlamaArchitectureAdapter
+    ) -> None:
+        """Check setup_component_testing runs without a bridge model; attn bridge has set_rotary_emb."""
+
+        rotary_emb = object()
+
+        hf_model = SimpleNamespace(
+            model=SimpleNamespace(
+                rotary_emb=rotary_emb
+            )
+        )
+
+        # Should run without raising
+        adapter.setup_component_testing(hf_model)
+
+        attn_bridge = adapter.get_generalized_component(
+            "blocks.0.attn"
+        )
+
+        # Only checks the method exists; it does not inspect the injected embedding
+        assert hasattr(attn_bridge, "set_rotary_emb")
+
+
+# ---------------------------------------------------------------------------
+# Factory registration tests
+# ---------------------------------------------------------------------------
+
+
+class TestLlamaFactoryRegistration:
+    """Tests factory registration."""
+
+    def test_factory_returns_llama_adapter(self) -> None:
+        from transformer_lens.factories.architecture_adapter_factory import (
+            ArchitectureAdapterFactory,
+        )
+
+        cfg = _make_cfg()
+
+        adapter = ArchitectureAdapterFactory.select_architecture_adapter(
+            cfg
+        )
+
+        assert isinstance(
+            adapter,
+            LlamaArchitectureAdapter,
+        )
+
+    def test_factory_key_exists(self) -> None:
+        from transformer_lens.factories.architecture_adapter_factory import (
+            SUPPORTED_ARCHITECTURES,
+        )
+
+        assert "LlamaForCausalLM" in SUPPORTED_ARCHITECTURES