[1/N] ModelOPT PEFT mode support for the megatron-lm #342
Conversation
Walkthrough
Introduces a new PEFT (LoRA) subsystem under modelopt.torch, including configs, conversion/adapter management, mode registry integration, the LoRA base layer and Megatron-Core plugins, regex utilities, package initializers, tests (unit and GPU/Megatron), a CODEOWNERS update, and a CHANGELOG entry.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Model as nn.Module
    participant Convert as peft.convert.update_model
    participant ConvUtil as peft.conversion.replace_lora_module
    participant Adapter as peft.conversion.add_adapter
    participant ModeReg as PEFTModeRegistry
    User->>Convert: update_model(Model, PEFTConfig|dict)
    Convert->>ModeReg: (indirect) ensure PEFT mode/config types
    Convert->>ConvUtil: replace_lora_module(Model, version/config/registry)
    Note right of ConvUtil: Replace eligible layers<br/>with LoRAModule instances
    Convert->>Adapter: add_adapter(Model, PEFTConfig)
    Note right of Adapter: Register/merge adapter(s)<br/>per layer patterns
    Convert-->>User: Model (PEFT-enabled)
    alt Toggle adapters
        User->>peft.convert.enable_adapters: (Model, layers/adapters)
        User->>peft.convert.disable_adapters: (Model, layers/adapters)
    end
```
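In user code, the conversion flow above would look roughly like the sketch below. It assumes the update_model, enable_adapters, and disable_adapters entry points named in the diagram and a dict config mirroring the defaults discussed later in the review (a "*" wildcard adapter pattern with rank 64); treat it as a sketch rather than the confirmed API surface.

```python
import modelopt.torch.peft as mtpeft

# `model` is any supported nn.Module (e.g., a Megatron-Core GPT model built elsewhere).
# update_model replaces eligible layers with LoRAModule instances and registers the
# named adapter on every layer matching an adapter_cfg pattern.
mtpeft.update_model(
    model,
    {
        "adapter_name": "default",
        "adapter_cfg": {"*": {"rank": 64}},  # "*" matches all supported layers
    },
)

# Adapters can later be toggled; per the diagram, both calls also accept optional
# layer/adapter pattern filters (keyword names omitted here, as they are not confirmed).
mtpeft.disable_adapters(model)
mtpeft.enable_adapters(model)
```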
```mermaid
sequenceDiagram
    autonumber
    actor Trainer
    participant Model as nn.Module (PEFT)
    participant Freeze as peft.conversion.freeze_lora_weights
    participant Unfreeze as peft.conversion.unfreeze_lora_weights
    Trainer->>Freeze: layer_patterns / adapter_patterns
    Note right of Freeze: Set requires_grad=False<br/>for matching LoRA params
    Trainer->>Unfreeze: layer_patterns / adapter_patterns
    Note right of Unfreeze: Set requires_grad=True<br/>for matching LoRA params
```
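A corresponding sketch for the freeze/unfreeze flow, assuming the freeze_lora_weights and unfreeze_lora_weights helpers live in modelopt.torch.peft.conversion (as the __all__ discussion later in this thread suggests) and accept the model plus optional pattern filters as the diagram indicates:

```python
from modelopt.torch.peft import conversion as peft_conversion

# Freeze matching LoRA parameters (requires_grad=False), e.g. for an eval phase,
# then unfreeze them again before resuming fine-tuning. Optional layer/adapter
# pattern filters exist per the diagram; they are omitted here because their
# exact keyword names are not confirmed.
peft_conversion.freeze_lora_weights(model)
peft_conversion.unfreeze_lora_weights(model)
```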
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60–90 minutes
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Codecov Report

❌ Patch coverage is … — additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #342      +/-   ##
==========================================
- Coverage   73.79%   73.36%   -0.43%
==========================================
  Files         171      180       +9
  Lines       17591    17919     +328
==========================================
+ Hits        12981    13147     +166
- Misses       4610     4772     +162
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 13
♻️ Duplicate comments (1)
modelopt/torch/peft/convert.py (1)
205-219: Another instance of private attribute access. Similar to the previous comment, this function also directly accesses _lora_adapters.
🧹 Nitpick comments (24)
modelopt/torch/peft/lora/__init__.py (1)
1-3: Prefer lazy submodule imports to reduce import-time cost and avoid cycles. Importing layer and tp_layer at package import time can be heavy and risks circular imports. Expose them lazily via __getattr__:

```diff
-"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
-
-from . import layer, tp_layer
+"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
+
+from importlib import import_module as _import_module
+
+__all__ = ["layer", "tp_layer"]
+
+def __getattr__(name):
+    if name in __all__:
+        mod = _import_module(f".{name}", __name__)
+        globals()[name] = mod
+        return mod
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

modelopt/torch/peft/__init__.py (1)
16-16: Docstring mismatch with package purpose. This is PEFT/LoRA, not distillation. Update the docstring:

```diff
-"""Distillation API subpackage for torch."""
+"""PEFT/LoRA API for torch."""
```

modelopt/torch/peft/custom.py (2)
22-29: Run plugins deterministically and guard against concurrent mutation. Take a snapshot, sort by module/qualname, then invoke:

```diff
 def register_custom_model_plugins_on_the_fly(model):
     """Registers custom PEFT/LoRA plugins on the fly.

     This is called before LoRAModule replacement to allow plugins to configure
     the model (e.g., for distributed checkpointing).
     """
-    for callback in CUSTOM_MODEL_PLUGINS:
-        callback(model)
+    # Snapshot to avoid RuntimeError if mutated during iteration
+    callbacks = sorted(
+        tuple(CUSTOM_MODEL_PLUGINS),
+        key=lambda f: (getattr(f, "__module__", ""), getattr(f, "__qualname__", getattr(f, "__name__", ""))),
+    )
+    for callback in callbacks:
+        callback(model)
```
16-23: Add light typing for clarity. Optional: add minimal typing to document the callback contract:

```diff
-"""Custom PEFT/LoRA plugins registry."""
+"""Custom PEFT/LoRA plugins registry."""
+from typing import Callable, Iterable  # lightweight, no runtime deps
```

modelopt/torch/peft/plugins/megatron.py (1)
31-34: Optional: export public symbols for introspection. Expose MEGATRON_AVAILABLE and the hook via __all__ for discoverability:

```diff
-__all__ = []
+__all__ = ["MEGATRON_AVAILABLE", "megatron_replace_lora_module_hook"]
```

modelopt/torch/peft/mode.py (2)
1-13: Missing module docstring. Add a module-level docstring to document the purpose and functionality of this mode registry module:

```diff
+"""PEFT mode definitions and registry for parameter-efficient fine-tuning."""
+
 from modelopt.torch.opt.config import ModeloptBaseConfig
```
17-46: Missing class docstrings. Both mode descriptor classes lack docstrings explaining their purpose and usage:

```diff
 @PEFTModeRegistry.register_mode
 class PEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for PEFT/LoRA model conversion."""
+
     @property
     def name(self) -> str:
+        """Return the mode identifier string."""
         return "peft"

 @PEFTModeRegistry.register_mode
 class ExportPEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for exporting PEFT/LoRA models."""

     @property
     def name(self) -> str:
-        """Returns the value (str representation) of the mode."""
+        """Return the mode identifier string."""
         return "export_peft"
```

modelopt/torch/peft/config.py (3)
88-90: Incorrect error message for scale validation. The error message is missing "a" before "positive number":

```diff
         if v <= 0:
-            raise ValueError("scale must be positive number")
+            raise ValueError("scale must be a positive number")
         return v
```
99-117: Pickling validation may be too restrictive. The pickling requirement for initialization functions might be overly restrictive and prevent legitimate use cases like closures or partial functions. Consider documenting this requirement more prominently or providing alternative initialization strategies.
Consider adding a class method to create common initialization patterns that are guaranteed to be pickleable:

```python
@classmethod
def create_normal_init(cls, mean=0.0, std=0.02):
    """Create a pickleable normal initialization function."""
    def normal_init(weight):
        return init.normal_(weight, mean=mean, std=std)
    return normal_init
```
162-170: Broad exception handling masks specific validation errors. Catching all exceptions and re-raising with a generic message loses valuable debugging information:

```diff
             try:
                 validated_cfg[key] = PEFTAttributeConfig(**value)
-            except Exception as e:
-                raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+            except (TypeError, ValueError) as e:
+                raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
```

modelopt/torch/peft/lora/layer.py (2)
62-62: Typo in error message. The error message has a double period:

```diff
-            raise ValueError(f"adapter_name: {adapter_name} is already exist..")
+            raise ValueError(f"adapter_name: {adapter_name} already exists.")
```
193-230: Forward method has performance implications. The forward method iterates through all adapters on every forward pass, which could impact performance with many adapters. Consider caching active adapters.
Consider maintaining a list of active adapters to avoid checking the enable flag on every forward pass:

```python
def _update_active_adapters(self):
    """Cache list of active adapters for efficient forward pass."""
    self._active_adapters = [
        (adapter["lora_a"], adapter["lora_b"], adapter["scale"])
        for adapter in self._lora_adapters.values()
        if adapter["enable"]
    ]

def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
    output = super().forward(x, *args, **kwargs)
    if isinstance(output, tuple):
        result = output[0]
        other_outputs = output[1:]
    else:
        result = output
        other_outputs = ()

    # Use cached active adapters
    for lora_a, lora_b, scale in getattr(self, '_active_adapters', []):
        lora_a_output = lora_a(x)
        if isinstance(lora_a_output, tuple):
            lora_a_output = lora_a_output[0]
        lora_b_output = lora_b(lora_a_output)
        if isinstance(lora_b_output, tuple):
            lora_b_output = lora_b_output[0]
        result = result + scale * lora_b_output

    return (result, *other_outputs) if other_outputs else result
```

modelopt/torch/peft/convert.py (3)
90-90: Unclear assertion message. The assertion message uses a non-standard abbreviation "MO-PEFT" without explanation:

```diff
-    assert is_peft_model(model), "It's not a MO-PEFT model"
+    assert is_peft_model(model), "Model has not been converted to PEFT/LoRA format"
```
101-102: Inconsistent error message formatting. The error message uses different formats for "pattern" vs "adapter pattern":

```diff
-    pattern_type = "pattern" if allow_callable else "adapter pattern"
-    raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
+    pattern_type = "layer pattern" if allow_callable else "adapter pattern"
+    raise TypeError(f"Unsupported {pattern_type} type: {type(pattern).__name__}")
```
111-119: Direct access to private attribute _lora_adapters. The function directly accesses the private _lora_adapters attribute, violating encapsulation. Consider adding a public method or property for adapter access.
Add a public interface in LoRAModule:

```python
# In LoRAModule class
def get_adapter(self, adapter_name: str) -> dict[str, Any] | None:
    """Get adapter configuration by name."""
    return self._lora_adapters.get(adapter_name)

def set_adapter_state(self, adapter_name: str, enable: bool) -> None:
    """Set the enable state of an adapter."""
    if adapter_name in self._lora_adapters:
        self._lora_adapters[adapter_name]["enable"] = enable
```

Then update this function:

```diff
-        for adapter_name, adapter_dict in module._lora_adapters.items():
+        for adapter_name in module.adapter_names:
             if adapter_patterns is not None:
                 if not matches_any_pattern(
                     adapter_name, adapter_patterns, allow_callable=False
                 ):
                     continue
-            adapter_dict["enable"] = enable_state
+            module.set_adapter_state(adapter_name, enable_state)
```

tests/gpu/torch/peft/test_megatron_peft.py (3)
93-121: Test helper lacks error handling. The model provider function doesn't handle potential errors during model creation, which could make test failures harder to debug:

```diff
 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
-
+    try:
         if meta_device:
             with torch.device("meta"):
                 gpt_model = get_mcore_gpt_model(
                     tensor_model_parallel_size=tp_size,
                     num_layers=4,
                     ffn_hidden_size=None,
                     num_attention_heads=4,
                     activation_func="squared_relu",
                     transformer_impl="local",
                     hidden_size=hidden_size,
                     vocab_size=vocab_size,
                     use_cpu_initialization=meta_device,
                 )
         else:
             gpt_model = get_mcore_gpt_model(
                 tensor_model_parallel_size=tp_size,
                 num_layers=4,
                 ffn_hidden_size=None,
                 num_attention_heads=4,
                 activation_func="squared_relu",
                 transformer_impl="local",
                 hidden_size=hidden_size,
                 vocab_size=vocab_size,
             ).cuda()
         return gpt_model.eval()
+    except Exception as e:
+        pytest.fail(f"Failed to create GPT model: {e}")
```
134-139: Conditional assertion logic could be clearer. The conditional logic for checking output equality based on config type is not immediately clear. Consider adding a comment explaining why DEFAULT_LORA_CFG_TEST should produce identical outputs:

```diff
     assert lora_output.shape == original_output.shape
+    # DEFAULT_LORA_CFG_TEST uses zero initialization for LoRA B, so initial output should match original
     if lora_config == DEFAULT_LORA_CFG_TEST:
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
     else:
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)
```
161-174: Large number of commented test cases. Most test parameterizations are commented out, which suggests either incomplete implementation or test instability.
If these tests are not ready, consider:
- Removing them entirely and tracking in an issue
- Using pytest.skip with a reason
- Adding a TODO comment explaining why they're disabled

```diff
 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        pytest.param(DEFAULT_LORA_CFG_RANDOM_INIT_TEST, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SMALL_RANK_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(LARGE_SCALE_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SELECTIVE_LAYER_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
     ],
 )
```

modelopt/torch/peft/conversion.py (3)
104-116: Return the possibly converted root module from replace_lora_module (and use it). Safer if a root replacement is ever registered; also removes confusion around local reassignment:

```diff
 def replace_lora_module(
     model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
 ):
     """Recursively replace the module with LoRA module."""
@@
-    if type(model) in registry:
-        model = registry.convert(model)
-    _replace_lora_module(model, version=version, registry=registry)
+    if type(model) in registry:
+        model = registry.convert(model)
+    _replace_lora_module(model, version=version, registry=registry)
+    return model
```

And in convert_to_peft_model:

```diff
-    replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
+    model = replace_lora_module(
+        model, version=ModeloptStateManager(model).state_version, config=config
+    )
```
41-42: Remove stale TODO. The replacement is already performed:

```diff
-    # TODO: Replace to LoRA module
```
141-149: Docstring example uses an invalid signature. update_layer_lora takes a config object; adjust the example to avoid confusion:

```diff
-        ...     module.update_layer_lora("custom_adapter", rank=32)
+        ...     from modelopt.torch.peft.config import PEFTAttributeConfig
+        ...     module.update_layer_lora("custom_adapter", PEFTAttributeConfig(rank=32))
```

modelopt/torch/peft/lora/tp_layer.py (3)
37-48: _get_init_methods is unused; either use or remove. Prefer using it to ensure defaults when attr_config initializers are None (e.g., after metadata restore).
Example use (apply similarly in both update_layer_lora methods):

```diff
-        lora_a = ColumnParallelLinear(
+        lora_a_init, lora_b_init = self._get_init_methods(attr_config.lora_a_init, attr_config.lora_b_init)
+        lora_a = ColumnParallelLinear(
             self.input_size,
             attr_config.rank,
             config=self.config,
             bias=False,
             gather_output=True,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_a_init,
             disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
         )
@@
-        lora_b = ColumnParallelLinear(
+        lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,  # Keep output distributed like base layer
-            init_method=attr_config.lora_a_init,
+            init_method=lora_b_init,
         )
```
80-87: Micro: combine device/dtype moves into a single .to call. Slight cleanup; avoids two passes:

```diff
-        if device is not None:
-            lora_a = lora_a.to(device)
-            lora_b = lora_b.to(device)
-        if dtype is not None:
-            lora_a = lora_a.to(dtype)
-            lora_b = lora_b.to(dtype)
+        if device is not None or dtype is not None:
+            lora_a = lora_a.to(device=device, dtype=dtype)
+            lora_b = lora_b.to(device=device, dtype=dtype)
```
26-28: Remove unused defaults. DEFAULT_LORA_RANK and DEFAULT_SCALE are unused:

```diff
-DEFAULT_LORA_RANK = 64
-DEFAULT_SCALE = 1.0
```
📜 Review details
📒 Files selected for processing (12)
- modelopt/torch/peft/__init__.py (1 hunks)
- modelopt/torch/peft/config.py (1 hunks)
- modelopt/torch/peft/conversion.py (1 hunks)
- modelopt/torch/peft/convert.py (1 hunks)
- modelopt/torch/peft/custom.py (1 hunks)
- modelopt/torch/peft/lora/__init__.py (1 hunks)
- modelopt/torch/peft/lora/layer.py (1 hunks)
- modelopt/torch/peft/lora/tp_layer.py (1 hunks)
- modelopt/torch/peft/mode.py (1 hunks)
- modelopt/torch/peft/plugins/__init__.py (1 hunks)
- modelopt/torch/peft/plugins/megatron.py (1 hunks)
- tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🪛 GitHub Actions: Code Quality
modelopt/torch/peft/mode.py
[error] 1-1: D100 Missing docstring in public module
[error] 22-22: D101 Missing docstring in public class
[error] 24-24: D102 Missing docstring in public method
[error] 28-28: D102 Missing docstring in public method
[error] 32-32: D102 Missing docstring in public method
[error] 36-36: D102 Missing docstring in public method
[error] 40-40: D102 Missing docstring in public method
[error] 44-44: D102 Missing docstring in public method
[error] 53-53: D101 Missing docstring in public class
🔇 Additional comments (2)
modelopt/torch/peft/plugins/__init__.py (1)
20-21: LGTM: guarded optional import. Import guarding with suppress(ImportError) is appropriate; it triggers registration only when Megatron is present.
modelopt/torch/peft/__init__.py (1)
18-25: Import-order OK — no action required. mode.py does `from .conversion import convert_to_peft_model, ..., export_peft_model` and its properties return those functions (modelopt/torch/peft/mode.py, ~lines 11–13, 31–33, 66–69).
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/peft/conversion.py (1)
28-34: Export all public API functions in __all__. Functions like convert_to_peft_model, restore_peft_model, and add_adapter are part of the public API (used by other modules in the package) but are not included in __all__. Add them to the export list or prefix truly internal helpers with an underscore.
🧹 Nitpick comments (1)
modelopt/torch/peft/conversion.py (1)
27-27: Clarify or remove TODO comment. The TODO mentions adding test cases, but per the PR summary and AI-generated summary, new tests already exist in tests/gpu/torch/peft/test_megatron_peft.py. Either update this comment to specify which particular functions still need coverage, or remove it if all exported functions are already tested.
📜 Review details
📒 Files selected for processing (1)
- modelopt/torch/peft/conversion.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.
Applied to files:
modelopt/torch/peft/conversion.py
Actionable comments posted: 1
🧹 Nitpick comments (2)
modelopt/torch/peft/config.py (2)
56-71: Clarify json_schema_extra for callable fields. The json_schema_extra declares "type": "string" for lora_a_init and lora_b_init, but these fields expect callable objects (type InitFn), not strings. This mismatch can mislead schema consumers, API clients, or documentation generators.
Consider either:
- Removing json_schema_extra and letting Pydantic infer the schema
- Updating it to reflect the actual type (e.g., "type": "callable", or omitting the type and relying on examples only)
179-193: Consider explicit type validation for non-dict values. The validator converts dict values to PEFTAttributeConfig but silently accepts non-dict values without checking if they're valid PEFTAttributeConfig instances. While Pydantic's field-level validation should catch type mismatches before this validator runs, explicitly checking the type would make the validator more robust and self-documenting.
Apply this diff for explicitness:

```diff
     @field_validator("adapter_cfg")
     @classmethod
     def validate_adapter_cfg(cls, v):
         """Validate and convert adapter configurations."""
         validated_cfg = {}
         for key, value in v.items():
             if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                 # Convert dict to PEFTAttributeConfig to trigger validation
                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
                 except Exception as e:
                     raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+            elif isinstance(value, PEFTAttributeConfig):
+                validated_cfg[key] = value
             else:
-                validated_cfg[key] = value
+                raise ValueError(
+                    f"Invalid adapter configuration for '{key}': "
+                    f"expected dict or PEFTAttributeConfig, got {type(value).__name__}"
+                )
         return validated_cfg
```
📜 Review details
📒 Files selected for processing (2)
- modelopt/torch/peft/config.py (1 hunks)
- modelopt/torch/peft/convert.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/peft/convert.py
🔇 Additional comments (4)
modelopt/torch/peft/config.py (4)
1-30: LGTM! The imports, exports, and type alias are well-structured. The InitFn type alias clearly defines initialization function signatures.
73-112: LGTM! The validators correctly enforce:
- Initialization functions must be in-place methods from torch.nn.init (ending with _)
- Rank must be positive
- Scale must be positive
Error messages are clear and actionable.
136-169: LGTM! The configuration fields are well-structured with appropriate defaults:
- freeze_lora_weights=False correctly enables LoRA training
- adapter_cfg={"*": {"rank": 64}} uses the wildcard pattern to apply to all layers by default
- Field descriptions are clear and helpful
196-197: LGTM! The ExportPEFTConfig placeholder class is appropriate for future extensibility.
Actionable comments posted: 0
♻️ Duplicate comments (1)
modelopt/torch/peft/config.py (1)
147-148: Remove Callable from the key type union. The type alias allows Callable as a key type, but validate_adapter_cfg (lines 213-225) doesn't validate callable keys, and the conversion logic uses string pattern matching with fnmatch. This misleads users into thinking callable keys are supported when they aren't.
Apply this diff to align the type with actual supported behavior:

```diff
-PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
+PEFTAdapterCfgType = dict[str, PEFTAttributeConfig | dict]
```

Based on past review comments that flagged this issue as unresolved.
🧹 Nitpick comments (1)
modelopt/torch/peft/config.py (1)
168-173: Consider validating adapter_name format. The adapter_name field has no validation and could be an empty string or contain problematic characters. While this may be acceptable if downstream code handles it, consider adding a validator to ensure it's non-empty and contains only valid characters for identifiers.
If validation is desired, add:

```python
@field_validator("adapter_name")
@classmethod
def validate_adapter_name(cls, v):
    """Validate adapter name is non-empty."""
    if not v or not v.strip():
        raise ValueError("adapter_name must be a non-empty string")
    return v
```
📜 Review details
📒 Files selected for processing (1)
- modelopt/torch/peft/config.py (1 hunks)
🔇 Additional comments (4)
modelopt/torch/peft/config.py (4)
1-29: LGTM! The module structure, imports, and public exports are well-organized and appropriate for a configuration module.
30-52: LGTM! The InitField type and helper function provide a clean way to handle torch initializers in configuration with proper serialization and schema documentation.
55-144: LGTM! PEFTAttributeConfig is well-designed with comprehensive validation for all fields. The initializer parsing and validation logic properly enforces that only in-place torch.nn.init functions are used.
228-229: LGTM! The empty ExportPEFTConfig class serves as an acceptable placeholder for future export-related configuration extensions.
Force-pushed from 0e5b9d4 to ca415dc (Compare)
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Force-pushed from ca415dc to fefbbe4 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (6)
modelopt/torch/peft/convert.py (1)
66-67: Capture return value from apply_mode.
apply_mode may return a different module instance than the input. Not capturing the return value is unsafe and could lead to using a stale reference. Apply this diff:

```diff
     # Check if model is already in PEFT mode by looking for LoRA modules
     if not is_peft_model(model):
-        apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
+        model = apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
     else:
```

This issue was flagged in a previous review but remains unaddressed.
tests/gpu/torch/peft/test_megatron_peft.py (2)
20-20: Add Apex/TE requirement to test skip guard. The tests use transformer_impl="local", which requires Apex (lines 160, 172). Without passing apex_or_te_required=True to skip_if_no_megatron(), environments lacking Apex will encounter assertion failures instead of gracefully skipping the tests. Apply this diff:

```diff
-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)
```

This issue was flagged in a previous review but remains unaddressed.
498-507: Avoid mutating shared config fixture. Lines 500-501 mutate the lora_config fixture directly, which can leak state across parametrized test runs. The _test_adapter_gradient_flow_freeze_lora_model helper correctly uses copy.deepcopy (line 448), but this helper does not. Apply this diff:

```diff
 def _test_adapter_gradient_flow(lora_config, tmp_path, rank, size):
     hidden_size = 512
-    lora_config["freeze_lora_weights"] = False
-    lora_config["freeze_base_model"] = False
+    local_cfg = copy.deepcopy(lora_config)
+    local_cfg["freeze_lora_weights"] = False
+    local_cfg["freeze_base_model"] = False
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-    mtpeft.update_model(model, lora_config)
+    mtpeft.update_model(model, local_cfg)
```

This issue was flagged in a previous review and marked as addressed, but the current code still has the problem.
modelopt/torch/peft/lora/plugins/megatron.py (1)
54-57: Fix typo in configuration attribute name. The attribute name hetereogenous_dist_checkpoint is misspelled and should be heterogeneous_dist_checkpoint. This typo was flagged in a previous review but is now within the scope of this PR. Apply this diff:

```diff
         if hasattr(module, "config") and hasattr(
-            module.config, "hetereogenous_dist_checkpoint"
+            module.config, "heterogeneous_dist_checkpoint"
         ):
-            module.config.hetereogenous_dist_checkpoint = True
+            module.config.heterogeneous_dist_checkpoint = True
```

modelopt/torch/peft/conversion.py (2)
28-34: Export core public API functions in __all__. The __all__ list is missing several public API functions that are used externally (per the AI summary and PR objectives):
- convert_to_peft_model (core conversion entry point)
- restore_peft_model (core restore entry point)
- add_adapter (adapter management, used in mode integration)

Apply this diff:

```diff
 __all__ = [
+    "add_adapter",
+    "convert_to_peft_model",
     "freeze_base_weights",
     "freeze_lora_weights",
     "replace_lora_module",
+    "restore_peft_model",
     "unfreeze_base_weights",
     "unfreeze_lora_weights",
 ]
```
133-137: Critical: enable check missing default True. Line 133 checks merged_setting_dict.get("enable"), which returns None (falsy) when the "enable" key is absent. Since line 122 uses model_dump(exclude_unset=True), if no matching pattern explicitly sets enable, the key won't appear in merged_setting_dict. This causes adapters to be skipped even though PEFTAttributeConfig defaults enable=True. Apply this diff:

```diff
-        if merged_setting_dict is not None and merged_setting_dict.get("enable"):
+        if merged_setting_dict is not None and merged_setting_dict.get("enable", True):
             module.update_layer_lora(
                 adapter_name,
                 PEFTAttributeConfig(**merged_setting_dict),
```
🧹 Nitpick comments (1)
modelopt/torch/peft/config.py (1)
147-148: Add validation for callable adapter keys or document their usage. The type alias permits Callable keys, but validate_adapter_cfg (lines 211-225) doesn't validate them, and it's unclear what signature callables should have. Based on matches_pattern usage in conversion.py, callables should be Callable[[str], bool]. Consider one of these approaches:

Option 1: Add validation for callable keys

```diff
     @field_validator("adapter_cfg")
     @classmethod
     def validate_adapter_cfg(cls, v):
         """Validate and convert adapter configurations."""
         validated_cfg = {}
         for key, value in v.items():
+            # Validate callable keys have correct signature
+            if callable(key):
+                # Check it's a unary function
+                import inspect
+                sig = inspect.signature(key)
+                if len(sig.parameters) != 1:
+                    raise ValueError(
+                        f"Callable keys must accept a single string argument, got {sig}"
+                    )
             if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                 # Convert dict to PEFTAttributeConfig to trigger validation
                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
                 except Exception as e:
                     raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
             else:
                 validated_cfg[key] = value
         return validated_cfg
```

Option 2: Remove callable support if not needed

```diff
 # Type alias for adapter configuration
-PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
+PEFTAdapterCfgType = dict[str, PEFTAttributeConfig | dict]
```

If callable keys are intentionally supported, add documentation explaining the expected signature: Callable[[str], bool], where the argument is the module name.
📜 Review details
📒 Files selected for processing (16)
- .github/CODEOWNERS (1 hunks)
- .vscode/settings.json (1 hunks)
- CHANGELOG.rst (1 hunks)
- modelopt/torch/__init__.py (1 hunks)
- modelopt/torch/peft/__init__.py (1 hunks)
- modelopt/torch/peft/config.py (1 hunks)
- modelopt/torch/peft/conversion.py (1 hunks)
- modelopt/torch/peft/convert.py (1 hunks)
- modelopt/torch/peft/custom.py (1 hunks)
- modelopt/torch/peft/lora/__init__.py (1 hunks)
- modelopt/torch/peft/lora/layer.py (1 hunks)
- modelopt/torch/peft/lora/plugins/__init__.py (1 hunks)
- modelopt/torch/peft/lora/plugins/megatron.py (1 hunks)
- modelopt/torch/peft/mode.py (1 hunks)
- modelopt/torch/utils/network.py (3 hunks)
- tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
- modelopt/torch/peft/lora/__init__.py
- CHANGELOG.rst
- modelopt/torch/peft/mode.py
- modelopt/torch/peft/custom.py
- modelopt/torch/peft/__init__.py
- .vscode/settings.json
- modelopt/torch/utils/network.py
🔇 Additional comments (14)
modelopt/torch/peft/lora/layer.py (4)
26-28: LGTM! The initialization of the _lora_adapters dictionary is straightforward and correct.
35-65: LGTM! The adapter registration logic correctly:
- Adds modules with add_module
- Prevents duplicate adapters
- Stores all necessary metadata (lora_a, lora_b, rank, scale, enable)
67-82: LGTM! The abstract method declaration is correct, and the docstring accurately describes the attr_config parameter as requested in past reviews.
84-121: LGTM! The forward pass correctly:
- Handles both single tensor and tuple outputs from the base layer
- Applies only enabled adapters
- Unpacks intermediate tuple outputs from parallel LoRA layers
- Scales and accumulates LoRA contributions
- Preserves additional outputs from the base layer
modelopt/torch/peft/lora/plugins/megatron.py (3)
71-110: LGTM! The device/dtype alignment logic correctly:
- Iterates through parameters and buffers to find the target device/dtype
- Moves both LoRA modules to match the parent module
- Delegates to the parent class's _register_adapter method
121-151: LGTM! The LoRA adapter creation for ColumnParallelLinear is correct:
- lora_a is a plain nn.Linear (not sharded)
- lora_b is a ColumnParallelLinear (sharded at dim 0)
- Proper initialization with lora_a_init and lora_b_init

This design aligns with the discussion in past reviews and avoids quantization issues with very small ranks.
186-217: LGTM! The LoRA adapter creation for RowParallelLinear correctly mirrors the ColumnParallel design:
- lora_a is a RowParallelLinear (sharded at dim 1)
- lora_b is a plain nn.Linear (not sharded)
- Proper initialization
modelopt/torch/peft/convert.py (4)
75-87: LGTM! The check correctly identifies PEFT models by looking for any LoRAModule instances.
90-113: LGTM! The adapter state management correctly:
- Validates the model is a PEFT model
- Filters by layer patterns using matches_pattern
- Filters by adapter patterns (string-only)
- Sets the enable flag appropriately
116-175: LGTM! Both disable_adapters and enable_adapters are well-documented thin wrappers that delegate to _set_adapter_state with appropriate parameters.
178-183: LGTM! The Megatron model detection correctly checks for the presence of Megatron-specific layer types.
tests/gpu/torch/peft/test_megatron_peft.py (1)
135-176: LGTM! The distributed checkpoint save/load helpers and model provider function are well-structured and correctly handle Megatron-specific initialization.
modelopt/torch/peft/config.py (2)
55-144: LGTM! PEFTAttributeConfig is well-designed with comprehensive validation:
- Init method parsing from strings
- Validation that callables are from torch.nn.init and in-place
- Positive rank and scale checks

The validators provide clear error messages to guide users.
151-225: LGTM! PEFTConfig has sensible defaults:
- adapter_name="default"
- adapter_cfg={"*": {"rank": 64}} (wildcard matches all layers)
- freeze_base_model=True, freeze_lora_weights=False (standard LoRA fine-tuning)

The validators ensure:
- Only the "lora" adapter type is accepted
- Adapter configs are converted to PEFTAttributeConfig with clear error messages
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Signed-off-by: jingyu-ml <108295447+jingyu-ml@users.noreply.github.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Thanks everyone for the review. If there are any additional concerns, please feel free to leave comments in this PR or in the design doc. I’ll address them in the next PR.
What does this PR do?
Type of change: new feature
Overview:
The PEFT module in TensorRT Model Optimizer provides an implementation of LoRA for parameter-efficient fine-tuning of large models. It lets you add trainable low-rank decomposition matrices to existing model layers, significantly reducing the number of trainable parameters while maintaining model performance. It is particularly optimized for Megatron-based models and supports multiple adapters that can be dynamically enabled, disabled, or switched during inference. Building on this, ModelOpt will also make QLoRA-style training straightforward in the future.
Usage
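The original usage snippet is not captured in this page; the example below is a hedged sketch assembled from the config fields discussed in the review (adapter_name, adapter_cfg, rank, scale, enable, freeze_base_model, freeze_lora_weights). The layer-name patterns are illustrative only and should be checked against the merged API.

```python
import modelopt.torch.peft as mtpeft

LORA_CFG = {
    "adapter_name": "my_adapter",
    "adapter_cfg": {
        # Wildcard patterns over module names map to per-layer LoRA settings.
        "*linear_qkv*": {"rank": 32, "scale": 1.0, "enable": True},
        "*linear_fc1*": {"rank": 64},
    },
    "freeze_base_model": True,     # train only the LoRA weights (default)
    "freeze_lora_weights": False,
}

# Replace eligible Megatron layers with LoRAModule wrappers and register the adapter.
mtpeft.update_model(model, LORA_CFG)
```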
Advanced Usage - Quantization
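No quantization snippet survives in this capture either, and the overview only says QLoRA-style training will become possible later, so the following is purely a directional sketch. It assumes modelopt.torch.quantization's mtq.quantize API with an example config constant (mtq.INT4_AWQ_CFG) and a user-supplied calibration loop.

```python
import modelopt.torch.quantization as mtq
import modelopt.torch.peft as mtpeft

# Quantize the (frozen) base model first; `calibrate` is a user-defined forward
# loop over calibration data and is not shown here.
model = mtq.quantize(model, mtq.INT4_AWQ_CFG, forward_loop=calibrate)

# Then attach trainable LoRA adapters on top of the quantized base weights.
mtpeft.update_model(model, {"adapter_cfg": {"*": {"rank": 64}}})
```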
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit