Left: LLaMA-2-7B on MetaMathQA (GSM8K accuracy vs. retention on 14 held-out benchmarks). Right: Mistral-7B on Magicoder-Evol-Instruct-110K (HumanEval pass@1 vs. the same retention suite). In both cases DISeL (red) sits on the base-model retention line while matching the strongest fine-tuning accuracy; every other adapter family either forgets visibly (LoRA / DoRA / Full FT) or loses accuracy to preserve retention (AdaLoRA).
DISeL (Dynamic Input-Sensitive LoRA) makes LoRA's correction input-dependent.
Standard LoRA learns a single low-rank correction
DISeL keeps the rank-$r$ factorisation but multiplies each rank-one component by an input-dependent sigmoid gate:
Two properties follow directly from the parameterisation:
- The model starts at the pre-trained mapping. Initialising $b_g$ to a small negative value (we use $-3$, so $\sigma(-3) \approx 0.05$) leaves every gate almost closed at step 0, and combined with LoRA's standard zero-init on $B$ this means $f(x) = 0$ everywhere. Adaptation only kicks in when some gate learns to open.
- Gates open only where opening reduces the fine-tuning loss. Each gate is an independent soft switch (sigmoid, not softmax), so multiple ranks can open at once and out-of-distribution tokens can keep them all closed. In practice the gates do exactly that; see the paper's interpretability section for histograms by input domain and module type.
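Both properties can be checked numerically on a toy layer. The sketch below is illustrative only and is not the package's implementation: it assumes the gated update has the form $f(x) = B\,(g(x) \odot A x)$ with $g(x) = \sigma(W_g x + b_g)$, and the function name `gated_lora_delta`, the symbol `W_g`, and all toy values are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(M, v):
    # Plain matrix-vector product on nested lists.
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def gated_lora_delta(x, A, B, W_g, b_g):
    """ASSUMED form: f(x) = B (g(x) * A x), with g(x) = sigmoid(W_g x + b_g)."""
    h = matvec(A, x)                                  # rank-r projection
    g = [sigmoid(z + b_g) for z in matvec(W_g, x)]    # one gate per rank
    gated = [gi * hi for gi, hi in zip(g, h)]         # scale each rank-one component
    return matvec(B, gated), g

# Toy shapes: d_in = 3, r = 2, d_out = 2 (made-up numbers).
A = [[0.5, 0.2, 0.1], [0.3, 0.4, -0.6]]
W_g = [[0.01, -0.02, 0.03], [-0.01, 0.02, 0.00]]      # small random-like gate weights
x = [1.0, 2.0, -1.0]

# Property 1: zero-initialised B gives f(x) = 0 regardless of the gates.
B_zero = [[0.0, 0.0], [0.0, 0.0]]
delta_init, gates_init = gated_lora_delta(x, A, B_zero, W_g, b_g=-3.0)

# Property 2: with a non-zero B, the update is small while gates stay near
# their closed init and grows once the gates open.
B_trained = [[1.0, 0.0], [0.0, 1.0]]
delta_closed, _ = gated_lora_delta(x, A, B_trained, W_g, b_g=-3.0)
delta_open, _ = gated_lora_delta(x, A, B_trained, W_g, b_g=3.0)

print(delta_init, gates_init)
print(delta_closed, delta_open)
```

Here the gate bias is flipped by hand to imitate a gate that has learned to open; in training, the bias and gate weights would move under the fine-tuning loss instead.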
The gate adds
This repository is the minimal reference implementation that accompanies the
paper. It ships as a `pip install disel`-able package built on top of
HuggingFace `peft` and `transformers`; we plan to upstream the method as a
`use_disel=True` flag on `LoraConfig` once the API shape settles (see
"Upstreaming" below).
DISeL has only a handful of knobs beyond standard LoRA. The defaults below are the ones we use in the paper unless noted; the trade-off they control is the same in every case (lower gate values + smaller gate LR → less adaptation, more retention; higher → more adaptation, more forgetting).
The gate value at initialisation is
- Default: $b_g = -3$ ($\sigma(-3) \approx 0.05$). This is what we use for all LLaMA-2-7B and Mistral-7B experiments.
- More conservative starts (`b_g = -5` or `-7`) help on small datasets, where you want to preserve pre-training more strongly and rely on the gate LR to selectively unlock the directions that matter. In the paper we use $b_g = -7$ for most RoBERTa-on-GLUE tasks (MNLI, SST-2, QNLI), and a less negative bias only on the smallest tasks (CoLA, MRPC) where the gates need to open faster.
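To get a feel for how much more conservative the lower biases are, the initial gate value can be computed directly. This is a small worked calculation, assuming the gate at initialisation is approximately $\sigma(b_g)$ (that is, the contribution of the gate weights is negligible at step 0); the helper name `gate_at_init` is made up for illustration.

```python
import math

def gate_at_init(b_g):
    # Sigmoid of the gate bias: the approximate gate value at step 0,
    # assuming the gate-weight term is negligible at initialisation.
    return 1.0 / (1.0 + math.exp(-b_g))

init_gates = {b_g: gate_at_init(b_g) for b_g in (-3.0, -5.0, -7.0)}
for b_g, g in init_gates.items():
    print(f"b_g = {b_g:+.0f} -> initial gate ~ {g:.4f}")
```

Under this assumption, moving from $b_g = -3$ to $b_g = -7$ shrinks the initial gate from roughly 0.05 to below 0.001, so each rank starts about fifty times more closed.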
We initialise the gate weights with the `nn.Linear` default. Note that random (small but non-zero) weights
are what we want here, not zeros: the gate value at init is determined by
The package also exposes `disel_gate_weight_init="zero"` for ablations; it
zeros the gate weights.
Gate parameters
- Paper recipe (LLaMA / Mistral): gate LR $= 10^{-3}$, which is $\mathbf{5\times}$ the LoRA learning rate of $2\times 10^{-4}$. Exposed as `gate_lr_multiplier=5.0` (default) on `build_optimizer`.
- RoBERTa-on-GLUE: gate LR is $10^{-4}$ on large tasks (MNLI, SST-2, QNLI) and $10^{-3}$ on small ones (CoLA, MRPC). The latter need a higher gate LR to overcome the conservative $b_g = -7$ init.
The gate learning rate is the main knob you have for the plasticity vs. retention trade-off: a larger gate LR encourages more gates to open and favours adaptation, a smaller gate LR keeps more directions closed and favours retention.
A subtle but important point: `build_optimizer` puts the gate parameters in
their own AdamW group with `weight_decay=0.0`; if you write your own
optimiser, do the same.
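If you do write your own optimiser, one way to reproduce that grouping is sketched below. It is not the code inside `build_optimizer`: the helper name `build_param_groups` is hypothetical, it assumes gate parameters can be recognised by the `lora_disel_gate` substring in their names, and the parameter names in the demo are invented stand-ins.

```python
from types import SimpleNamespace

def build_param_groups(named_params, base_lr, gate_lr, weight_decay):
    """Split trainable parameters into a LoRA group and a gate group.

    ASSUMPTION: gate parameters carry "lora_disel_gate" in their name.
    """
    gate_params, other_params = [], []
    for name, param in named_params:
        if not getattr(param, "requires_grad", True):
            continue  # skip frozen base-model weights
        if "lora_disel_gate" in name:
            gate_params.append(param)
        else:
            other_params.append(param)
    groups = [
        {"params": other_params, "lr": base_lr, "weight_decay": weight_decay},
        # Gates get their own learning rate and no weight decay.
        {"params": gate_params, "lr": gate_lr, "weight_decay": 0.0},
    ]
    return [g for g in groups if g["params"]]  # drop empty groups

# Demo with stand-in objects instead of real tensors (names are invented).
fake_named_params = [
    ("layers.0.q_proj.base_layer.weight", SimpleNamespace(requires_grad=False)),
    ("layers.0.q_proj.lora_A.default.weight", SimpleNamespace(requires_grad=True)),
    ("layers.0.q_proj.lora_B.default.weight", SimpleNamespace(requires_grad=True)),
    ("layers.0.q_proj.lora_disel_gate.default.weight", SimpleNamespace(requires_grad=True)),
    ("layers.0.q_proj.lora_disel_gate.default.bias", SimpleNamespace(requires_grad=True)),
]
groups = build_param_groups(fake_named_params, base_lr=2e-4, gate_lr=1e-3, weight_decay=0.01)
print([(len(g["params"]), g["lr"], g["weight_decay"]) for g in groups])
```

With a real model, the returned list can be passed straight to `torch.optim.AdamW(build_param_groups(model.named_parameters(), 2e-4, 1e-3, 0.01))`, since AdamW accepts per-group `lr` and `weight_decay` overrides.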
```bash
pip install -e .
# or, with the example training scripts:
pip install -e ".[examples]"
```

Requires Python ≥ 3.10, torch ≥ 2.1, transformers ≥ 4.40, and peft ≥ 0.13, < 0.20.
```python
import torch
from transformers import AutoModelForCausalLM
from peft import get_peft_model

import disel

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16,
)

config = disel.DiselConfig(
    r=64,
    lora_alpha=128,
    target_modules="all-linear",
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    disel_gate_bias_init=-3.0,       # gates start ~closed
    disel_gate_normalize=False,
    disel_gate_weight_init="random",
)
model = get_peft_model(model, config)
disel.enable_disel(model, config)    # attach gates to every LoRA layer
model.print_trainable_parameters()

optimizer = disel.build_optimizer(
    model, base_lr=2e-4, gate_lr=1e-3, weight_decay=0.01,
)
# Plug `optimizer` into HuggingFace Trainer or your loop. See examples/.
```

`enable_disel` adds a `lora_disel_gate` `ModuleDict` to every PEFT `LoraLayer`
and registers a `LoraVariant` so PEFT's forward path routes through the DISeL
computation. The `lora_` prefix on the `ModuleDict` name matches the convention
DoRA uses (`lora_magnitude_vector`) and is what triggers PEFT's state-dict
serialiser to include the gate parameters, so saving with the standard
`model.save_pretrained(...)` writes them into the same
`adapter_model.safetensors` as the LoRA matrices.
Saving is just the standard PEFT call:

```python
model.save_pretrained("checkpoints/disel_run")
```

Loading needs three steps in a specific order, because vanilla PEFT does not
know about DISeL: (1) PEFT rebuilds the LoRA layers, (2) we attach fresh
gates, (3) we re-apply the saved state dict to populate the gates. We expose
`disel.from_pretrained` to do all three in one call:
```python
import torch
from transformers import AutoModelForCausalLM

import disel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16,
)
model = disel.from_pretrained(base, "checkpoints/disel_run")
model.eval()
```

If you prefer to assemble the pieces manually (e.g. you already called
`PeftModel.from_pretrained` from your own pipeline), use the lower-level
helper:
```python
from peft import PeftModel
import disel

peft_model = PeftModel.from_pretrained(base, "checkpoints/disel_run")
disel.enable_disel(peft_model, config)                            # attach fresh gates; `config` is the DiselConfig used for training
disel.load_gate_state_dict(peft_model, "checkpoints/disel_run")   # fill them
```

Calling `PeftModel.from_pretrained` without the second and third steps
silently leaves the gates at their fresh init. The full round trip is covered
by `tests/test_disel.py::test_save_load_round_trip`, which asserts bit-exact
parameter round-trips and matching forward passes.
If you prefer to keep gate weights in their own file alongside the LoRA
checkpoint (the convention used in the original research repo), call
`disel.save_gate_state_dict` after the usual save:

```python
model.save_pretrained("checkpoints/disel_run")              # adapter_model.safetensors
disel.save_gate_state_dict(model, "checkpoints/disel_run")  # gate_weights.safetensors
```

`disel.load_gate_state_dict` (and therefore `disel.from_pretrained`) reads
either layout: the bundled `adapter_model.safetensors` first, falling back
to a standalone `gate_weights.safetensors` (or legacy `gate_weights.pt`).
The three example scripts mirror what a paper-style run looks like:

```bash
# 1. Train (LLaMA-2-7B on MetaMathQA, full paper recipe)
accelerate launch examples/train_metamath.py \
    --output_dir runs/disel_llama_r64 \
    --lora_rank 64 --learning_rate 2e-4 --gate_lr 1e-3 \
    --num_train_epochs 3

# 2. Quick smoke check that load + inference work
python examples/eval_gsm8k.py \
    --base meta-llama/Llama-2-7b-hf \
    --adapter runs/disel_llama_r64 \
    --num_problems 200

# 3. Full paper-style evaluation (target task + 14-benchmark retention)
#    handled by lm-evaluation-harness; see examples/README.md
```

The hero figure above covers the two large-scale settings: mathematical reasoning (LLaMA-2-7B / MetaMathQA, evaluated on GSM8K) and code generation (Mistral-7B / Magicoder, evaluated on HumanEval pass@1). In both, DISeL crosses the base-model retention line at the largest ranks while matching the strongest fine-tuning accuracy: LoRA and DoRA give up 2–5 points of retention for similar accuracy, AdaLoRA preserves retention but drops 10+ accuracy points, and Full FT moves several points to the left of the retention line.
The same picture also holds on a much smaller architecture (RoBERTa-base on five GLUE tasks), where the appropriate retention metric is masked-LM perplexity on three out-of-domain corpora rather than benchmark accuracy:
RoBERTa-base fine-tuned on five GLUE tasks (MNLI, SST-2, QNLI, CoLA, MRPC). DISeL keeps masked-LM perplexity near the pre-trained baseline (dashed line, ≈ 6.4) while matching the accuracy of AdaLoRA and Full FT; LoRA, DoRA, and especially Full FT shift one to two decades to the right on the perplexity axis.
| Feature | Status |
|---|---|
| `target_modules="all-linear"` | ✅ via the underlying `LoraConfig` |
| `target_modules=[...]` (explicit list) | ✅ |
| Saving / loading via PEFT | ✅ (gates are in `adapter_layer_names`) |
| `model.disable_adapter()` | ✅ |
| `model.merge_and_unload()` | ❌ `NotImplementedError`: the gate is input-dependent, so there is no fixed ΔW to fold into the base weight |
| Quantised backends (bnb 4-/8-bit) | Experimental: works at fp16/bf16 master weights, no special quant kernel |
| Multi-adapter (`add_adapter`) | ✅ call `enable_disel(model, config, adapter_name=...)` for each |
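
The reason `merge_and_unload()` cannot work can be seen on a toy example. The sketch below is an illustration rather than package code: it assumes the effective update for an input $x$ is $B\,\mathrm{diag}(g(x))\,A$ with $g(x) = \sigma(W_g x + b_g)$, and the helper `effective_delta_w` and all numbers are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def effective_delta_w(x, A, B, W_g, b_g):
    """Per-input update matrix B @ diag(g(x)) @ A under the ASSUMED gate form."""
    gates = [sigmoid(sum(w * xi for w, xi in zip(row, x)) + b_g) for row in W_g]
    rank, d_in, d_out = len(gates), len(A[0]), len(B)
    return [
        [sum(B[i][k] * gates[k] * A[k][j] for k in range(rank)) for j in range(d_in)]
        for i in range(d_out)
    ]

# Toy 2x2 layer with rank 2 (made-up values).
A = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0, 0.0], [0.0, 1.0]]
W_g = [[4.0, 0.0], [0.0, 4.0]]
b_g = -3.0

dw_a = effective_delta_w([2.0, 0.0], A, B, W_g, b_g)  # opens the first gate
dw_b = effective_delta_w([0.0, 2.0], A, B, W_g, b_g)  # opens the second gate

print(dw_a)
print(dw_b)
```

The two inputs yield different update matrices, so no single ΔW reproduces the adapter's behaviour on both, which is why merging raises instead of silently approximating.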
See examples/ for runnable scripts and examples/README.md for the exact
hyperparameters used in the paper tables. The two main recipes are
train_metamath.py (LLaMA-2-7B on MetaMathQA) and the matching
HumanEval/Magicoder recipe (coming soon).
```
disel/
├── __init__.py        # public API
├── config.py          # DiselConfig (subclass of LoraConfig)
├── layer.py           # RankGate / LightRankGate nn.Module
├── variant.py         # DiselLinearVariant (shaped like PEFT's DoraLinearVariant)
└── integration.py     # enable_disel(...) and build_optimizer(...)
examples/
└── train_metamath.py
tests/
└── test_disel.py
```
The variant class in `disel/variant.py` is intentionally shaped to drop into
src/peft/tuners/lora/variants.py
next to DoraLinearVariant. Doing so requires (a) a new flag on LoraConfig
(use_disel), (b) one extra branch in
Linear.resolve_lora_variant / Embedding.resolve_lora_variant, and (c) a
test-matrix entry in tests/test_custom_models.py. See
PR #1474 (DoRA),
PR #1838 (FourierFT), and
PR #1864 (HRA) for the
canonical contribution flow.
Coming soon.
Apache-2.0 — see LICENSE.