Feat/external architecture registration #1307
Pull request overview
Adds a public mechanism for registering custom architecture adapters in TransformerLens without forking the project, plus accompanying tests and documentation.
Changes:
- New `ArchitectureAdapterFactory.register_adapter()` classmethod for runtime registration, and `discover_entry_points()` that reads adapters from the `transformer_lens.architectures` entry-point group on first use.
- New unit test module covering registration, discovery idempotency, and select error paths.
- New documentation page under Contributing describing both registration paths and an example external package layout.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| transformer_lens/factories/architecture_adapter_factory.py | Adds register_adapter and discover_entry_points classmethods; calls discovery at the start of select_architecture_adapter. |
| tests/unit/model_bridge/test_architecture_adapter_factory.py | New tests for SUPPORTED_ARCHITECTURES contents, runtime registration, error handling, and discovery idempotency. |
| docs/source/content/contributing.md | Adds the new doc page to the toctree. |
| docs/source/content/adapter_development/external-adapter-registration.md | New doc page explaining runtime and entry-point registration, with example package layout. |
Comments suppressed due to low confidence (1)
tests/unit/model_bridge/test_architecture_adapter_factory.py:58
- This test does not actually verify "overwrite" behavior. It registers `MockArchitectureAdapter` twice for the same key, so the post-condition `cls._adapters[key] is first` is trivially true regardless of whether the second `register_adapter` call overwrote anything. To meaningfully test overwrite behavior, register a second, different adapter class and assert that the registry now points at the new class (and not at `first`).
    def test_register_overwrites_existing(self):
        key = "TestOverwriteForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        first = ArchitectureAdapterFactory._adapters[key]
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        assert ArchitectureAdapterFactory._adapters[key] is first
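A minimal sketch of the stronger test this comment asks for, written against a stand-in registry rather than the real `ArchitectureAdapterFactory`. `StandInFactory`, `FirstAdapter`, and `SecondAdapter` are hypothetical names introduced for illustration; only the `register_adapter(key, cls)` call shape and the `_adapters` mapping are taken from the PR.

```python
class StandInFactory:
    """Hypothetical stand-in for ArchitectureAdapterFactory (not the real class)."""

    _adapters = {}

    @classmethod
    def register_adapter(cls, name, adapter_cls):
        # Assumed behavior: a later registration replaces an earlier one.
        cls._adapters[name] = adapter_cls


class FirstAdapter:
    """Placeholder adapter registered first."""


class SecondAdapter:
    """Placeholder adapter that should replace FirstAdapter."""


def test_register_overwrites_existing():
    key = "TestOverwriteForCausalLM"
    StandInFactory.register_adapter(key, FirstAdapter)
    StandInFactory.register_adapter(key, SecondAdapter)
    # Two distinct classes make the assertion fail if no overwrite happened.
    assert StandInFactory._adapters[key] is SecondAdapter
    assert StandInFactory._adapters[key] is not FirstAdapter
```

Because the two classes differ, the assertion can only pass if the second call really replaced the first entry.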
class TestRegisterAdapter:
    """Verify runtime adapter registration."""

    def test_register_adds_to_adapters(self):
        key = "TestMockForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        assert key in ArchitectureAdapterFactory._adapters

    def test_register_overwrites_existing(self):
        key = "TestOverwriteForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        first = ArchitectureAdapterFactory._adapters[key]
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        assert ArchitectureAdapterFactory._adapters[key] is first

    def test_select_returns_registered_adapter(self):
        key = "TestSelectForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        cfg = _make_cfg(architecture=key)
        adapter = ArchitectureAdapterFactory.select_architecture_adapter(cfg)
        assert isinstance(adapter, MockArchitectureAdapter)


class TestSelectErrors:
    """Verify error handling in select_architecture_adapter."""

    def test_unknown_architecture_raises(self):
        cfg = _make_cfg(architecture="NonExistentForCausalLM")
        with pytest.raises(ValueError, match="Unsupported architecture"):
            ArchitectureAdapterFactory.select_architecture_adapter(cfg)

    def test_none_architecture_raises(self):
        cfg = _make_cfg(architecture=None)
        with pytest.raises(ValueError, match="must have architecture set"):
            ArchitectureAdapterFactory.select_architecture_adapter(cfg)


class TestDiscoverEntryPoints:
    """Verify entry-point discovery behavior."""

    def test_discover_is_idempotent(self):
        ArchitectureAdapterFactory._entry_points_discovered = False
        ArchitectureAdapterFactory.discover_entry_points()
        first_run = ArchitectureAdapterFactory._entry_points_discovered
        ArchitectureAdapterFactory.discover_entry_points()
        assert ArchitectureAdapterFactory._entry_points_discovered is first_run is True

    def test_discover_does_not_remove_existing(self):
        key = "TestPreserveForCausalLM"
        ArchitectureAdapterFactory.register_adapter(key, MockArchitectureAdapter)
        ArchitectureAdapterFactory._entry_points_discovered = False
        ArchitectureAdapterFactory.discover_entry_points()
        assert key in ArchitectureAdapterFactory._adapters
    discovery of adapters from installed packages via entry points.
    """

    _adapters = SUPPORTED_ARCHITECTURES
This is a great change, thanks for implementing @huseyincavusbi
        else:
            for ep in eps:
                try:
                    cls._adapters[ep.name] = ep.load()
We should add a guard here.
If an adapter already exists as part of the main TransformerLens package, we do not want it to be automatically overwritten by a custom adapter from another package. Let's check whether `ep.name` already exists in `_adapters` and, if it does, log a warning instead of overwriting.
For the warning, I am thinking something along the lines of: "Custom architecture adapter {ep.name} provided by {ep.dist.name} attempted to override a native adapter. If you'd like to use this custom adapter, register it explicitly with register_adapter".
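The guard described above could look roughly like the sketch below. It is not the code that was merged: `load_entry_point_adapters` is a hypothetical free function standing in for the body of `discover_entry_points()`, the registry is a plain dict, and the entry points are hand-built with placeholder targets from the standard library so the example runs without any plugin installed.

```python
import logging
from importlib.metadata import EntryPoint

logger = logging.getLogger(__name__)


def load_entry_point_adapters(registry, entry_points):
    """Add entry-point adapters to `registry` without overriding existing keys."""
    for ep in entry_points:
        if ep.name in registry:
            # `dist` may be missing or None (e.g. hand-built EntryPoint objects).
            dist = getattr(ep, "dist", None)
            provider = dist.name if dist is not None else "an unknown package"
            logger.warning(
                "Custom architecture adapter %s provided by %s attempted to "
                "override a native adapter. If you'd like to use this custom "
                "adapter, register it explicitly with register_adapter.",
                ep.name,
                provider,
            )
            continue
        try:
            registry[ep.name] = ep.load()
        except Exception as exc:  # a broken plugin should not break discovery
            logger.warning("Failed to load adapter %s: %s", ep.name, exc)
    return registry


# Usage with hand-built entry points; names and targets are placeholders.
GROUP = "transformer_lens.architectures"
registry = {"NativeForCausalLM": dict}  # pretend `dict` is a built-in adapter
eps = [
    EntryPoint("NativeForCausalLM", "collections:OrderedDict", GROUP),
    EntryPoint("MyModelForCausalLM", "collections:OrderedDict", GROUP),
]
load_entry_point_adapters(registry, eps)
```

One design consequence of checking `ep.name in registry`: anything already registered, whether native or added earlier through `register_adapter`, wins over an entry point of the same name.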
        ...     def __init__(self, cfg):
        ...         super().__init__(cfg)
        >>> ArchitectureAdapterFactory.register_adapter("MyModelForCausalLM", MyAdapter)
        >>> "MyModelForCausalLM" in ArchitectureAdapterFactory._adapters
This example references the private cls._adapters attribute. Better to demonstrate the public path via select_architecture_adapter.
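To illustrate the suggestion, here is a sketch that checks a registration through the public lookup rather than by inspecting `_adapters`. `StandInFactory` and `MyAdapter` are hypothetical stand-ins, the config is a `SimpleNamespace` with a single assumed `architecture` attribute, and the error messages only echo the substrings matched in the PR's tests; the real factory and config class live in TransformerLens and may differ.

```python
from types import SimpleNamespace


class StandInFactory:
    """Hypothetical stand-in; the real factory lives in TransformerLens."""

    _adapters = {}

    @classmethod
    def register_adapter(cls, name, adapter_cls):
        cls._adapters[name] = adapter_cls

    @classmethod
    def select_architecture_adapter(cls, cfg):
        # Error wording is assumed, based on the substrings matched in the tests.
        if cfg.architecture is None:
            raise ValueError("cfg must have architecture set")
        if cfg.architecture not in cls._adapters:
            raise ValueError(f"Unsupported architecture: {cfg.architecture}")
        return cls._adapters[cfg.architecture](cfg)


class MyAdapter:
    """Placeholder adapter that just stores its config."""

    def __init__(self, cfg):
        self.cfg = cfg


StandInFactory.register_adapter("MyModelForCausalLM", MyAdapter)
cfg = SimpleNamespace(architecture="MyModelForCausalLM")
adapter = StandInFactory.select_architecture_adapter(cfg)
assert isinstance(adapter, MyAdapter)  # verified without touching _adapters
```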
    return TransformerBridgeConfig(**defaults)


class TestSupportedArchitectures:
No tests cover the entry-point discovery system. Would it be possible to add some? Specifically, we want to cover:
- Successful entry-point discovery
- Override prevention
- Failure to register a bad adapter
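The three cases could be exercised along these lines. This is a sketch under stated assumptions: `discover` is a hypothetical, simplified stand-in for `discover_entry_points()` that works on a plain dict, and the fake entry points load placeholder objects from the standard library. Tests for the real factory would patch whichever `entry_points` reference its module actually uses.

```python
import importlib.metadata
from collections import OrderedDict
from importlib.metadata import EntryPoint
from unittest.mock import patch

GROUP = "transformer_lens.architectures"


def discover(registry):
    """Hypothetical stand-in for discover_entry_points()."""
    for ep in importlib.metadata.entry_points(group=GROUP):
        if ep.name in registry:
            continue  # override prevention: keep the existing adapter
        try:
            registry[ep.name] = ep.load()
        except Exception:
            continue  # skip adapters that fail to import
    return registry


def _patched(*eps):
    # Replace the stdlib lookup so no real plugin has to be installed.
    return patch("importlib.metadata.entry_points", return_value=list(eps))


def test_successful_discovery():
    ep = EntryPoint("MyModelForCausalLM", "collections:OrderedDict", GROUP)
    with _patched(ep):
        registry = discover({})
    assert registry["MyModelForCausalLM"] is OrderedDict


def test_override_prevention():
    ep = EntryPoint("NativeForCausalLM", "collections:OrderedDict", GROUP)
    with _patched(ep):
        registry = discover({"NativeForCausalLM": dict})
    assert registry["NativeForCausalLM"] is dict


def test_bad_adapter_is_skipped():
    ep = EntryPoint("BrokenForCausalLM", "no_such_module_for_this_sketch:Adapter", GROUP)
    with _patched(ep):
        registry = discover({})
    assert "BrokenForCausalLM" not in registry
```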
> **Important:** The architecture name you register (e.g. `"MyModelForCausalLM"`) must match the `architectures` field in the model's HuggingFace `config.json`. TransformerLens reads this field to look up the adapter.

### 2. Entry-point registration (recommended for packages)
Add a note at the end of this section documenting the warning behavior that occurs if you try to override a native adapter via an entry point.
@huseyincavusbi just a couple of small comments related to entry-point behavior. Altogether this is a great implementation; thank you for taking this on. Feel free to ping me if you have any questions.

Awesome work on this @huseyincavusbi! I am approving and merging now; this will be included in the next release.
Hi @jlarson4,
Description
Three changes to allow users to register custom architecture adapters without forking TransformerLens:
- `register_adapter()` method: classmethod on `ArchitectureAdapterFactory` that adds custom architectures to the registry at runtime
- `discover_entry_points()`: scans installed packages for adapters declared via the `transformer_lens.architectures` entry-point group, auto-called on first `select_architecture_adapter()`

Closes #1298