Conversation

Contributor

@jingyu-ml jingyu-ml commented Sep 18, 2025

What does this PR do?

Type of change: new feature

Overview:

The PEFT module in TensorRT Model Optimizer provides an implementation of LoRA for parameter-efficient fine-tuning of large models. It lets you add trainable low-rank decomposition matrices to existing model layers, significantly reducing the number of trainable parameters while maintaining model performance. The module is particularly optimized for Megatron-based models and supports multiple adapters that can be dynamically enabled, disabled, or switched during inference. It also lays the groundwork for straightforward QLoRA training with Model Optimizer in a future release.

Usage

import modelopt.torch.peft as mtpeft
import modelopt.torch.quantization as mtq
from modelopt.torch.peft.config import kaiming_init, zero_init

# Define LoRA configuration
lora_config = {
    "adapter_type": "lora",
    "adapter_name": "my_adapter",
    "adapter_cfg": {
        "*": {  # Apply to all layers
            "rank": 32,  # LoRA rank
            "scale": 1.0,  # Scaling factor
            "lora_a_init": kaiming_init,  # A matrix initialization
            "lora_b_init": zero_init,  # B matrix initialization
            "enable": True
        }
    }
}

# Apply LoRA to your model
mtpeft.update_model(model, lora_config)

# Use the model with LoRA adapter
output = model(input_data)

# Disable the adapter (use original model)
mtpeft.disable_adapters(model)
output_original = model(input_data)

# Re-enable the adapter
mtpeft.enable_adapters(model)
output_lora = model(input_data)
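
For training, the same config dict also accepts freeze flags (freeze_base_model and freeze_lora_weights, both visible in this PR's config and tests). Below is a minimal, hedged sketch of a training-oriented setup; the adapter name and the optimizer wiring are illustrative, not taken from this PR.

import torch

# Hedged sketch: keep the base model frozen and train only the LoRA matrices.
# freeze_base_model / freeze_lora_weights are PEFTConfig fields in this PR;
# the optimizer setup below is illustrative.
train_config = {
    "adapter_type": "lora",
    "adapter_name": "finetune",
    "adapter_cfg": {"*": {"rank": 32, "scale": 1.0, "enable": True}},
    "freeze_base_model": True,     # base weights stay frozen
    "freeze_lora_weights": False,  # LoRA A/B matrices remain trainable
}
mtpeft.update_model(model, train_config)

# Only LoRA parameters should require gradients now.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)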

Advanced Usage - Multiple Adapters with Quantization

# Add first adapter for task A
task_a_config = {
    "adapter_type": "lora",
    "adapter_name": "task_a",
    "adapter_cfg": {
        "*": {"rank": 16, "scale": 1.0, "enable": True}
    }
}
mtpeft.update_model(model, task_a_config)
mtq.quantize(model, FP8_CFG, forward_call)
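
# FP8_CFG and forward_call above are assumed to be defined elsewhere
# (for example, mtq.FP8_DEFAULT_CFG and a calibration forward loop).

# A second adapter must exist before switching to it. Hedged sketch; the
# "task_b" settings below are illustrative, not taken from this PR.
task_b_config = {
    "adapter_type": "lora",
    "adapter_name": "task_b",
    "adapter_cfg": {
        "*": {"rank": 16, "scale": 1.0, "enable": True}
    }
}
mtpeft.update_model(model, task_b_config)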

# Switch between adapters
mtpeft.disable_adapters(model, adapters_to_disable=["task_a"])
mtpeft.enable_adapters(model, adapters_to_enable=["task_b"])
output_task_b = model(input_data)

mtpeft.disable_adapters(model, adapters_to_disable=["task_b"])
mtpeft.enable_adapters(model, adapters_to_enable=["task_a"])
output_task_a = model(input_data)
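
The GPU tests added in this PR also cover saving and restoring the PEFT state through Megatron distributed checkpointing (see the test summary below). A heavily hedged sketch using the helper names that appear in the test imports; the argument names and order are assumptions, and checkpoint_dir / new_model are placeholders.

# Hedged sketch: persisting and restoring the modelopt PEFT/LoRA state with a
# Megatron sharded checkpoint. Argument names and order are assumptions.
from modelopt.torch.opt.plugins.mcore_dist_checkpointing import (
    restore_sharded_modelopt_state,
    save_sharded_modelopt_state,
)

save_sharded_modelopt_state([model], checkpoint_dir)
# ... later, after rebuilding the base Megatron model ...
restore_sharded_modelopt_state([new_model], checkpoint_dir)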

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: No
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • New Features
    • Added PEFT/LoRA support for Megatron-Core, including model conversion/update, multi-adapter management, and PEFT model detection.
    • Enable/disable and freeze/unfreeze LoRA adapters with per-layer/name pattern filters.
    • Introduced a pattern-matching utility for flexible layer/adapter selection.
    • Added plugin system with Megatron integration; exposed new PEFT submodule.
  • Documentation
    • Updated changelog highlighting LoRA mode and usage entrypoint.
  • Tests
    • Added GPU tests covering Megatron PEFT, quantization interactions, gradients, and save/restore.
    • Added unit tests for pattern matching utility.
  • Chores
    • Updated CODEOWNERS for the new PEFT module.

@jingyu-ml jingyu-ml self-assigned this Sep 18, 2025
@jingyu-ml jingyu-ml requested review from a team as code owners September 18, 2025 21:29
@jingyu-ml jingyu-ml marked this pull request as draft September 18, 2025 21:29
Contributor

coderabbitai bot commented Sep 18, 2025

Walkthrough

Introduces a new PEFT (LoRA) subsystem under modelopt.torch, including configs, conversion/adapter management, mode registry integration, LoRA base layer and Megatron-Core plugins, regex utilities, package initializers, tests (unit and GPU/Megatron), CODEOWNERS update, and a CHANGELOG entry.

Changes

Cohort / File(s) Summary
PEFT core API and registry
modelopt/torch/peft/config.py, modelopt/torch/peft/conversion.py, modelopt/torch/peft/convert.py, modelopt/torch/peft/mode.py, modelopt/torch/peft/custom.py
Adds PEFT configs (LoRA), conversion and restore flows, adapter add/freeze controls, PEFT mode descriptors/registry, and on-the-fly custom plugin registration. Exports public APIs for update/convert/is_peft and gradient controls.
LoRA base and Megatron plugins
modelopt/torch/peft/lora/__init__.py, modelopt/torch/peft/lora/layer.py, modelopt/torch/peft/lora/plugins/__init__.py, modelopt/torch/peft/lora/plugins/megatron.py
Introduces LoRAModule base and registry, package initializers, and Megatron-Core plugin implementations (Column/RowParallel variants, optional quantized variants) plus a checkpointing hook. Conditional plugin import.
Package initializers
modelopt/torch/peft/__init__.py, modelopt/torch/__init__.py
Adds peft package facade re-exporting config/conversion/convert/mode and includes peft in top-level torch package exports.
Utils: regex
modelopt/torch/utils/regex.py, modelopt/torch/utils/__init__.py
Adds matches_pattern utility for flexible name matching and re-exports it from the utils package (a hedged usage sketch follows this table).
Tests
tests/unit/torch/utils/test_regex.py, tests/gpu/torch/peft/test_megatron_peft.py
Adds unit tests for regex matching and extensive GPU tests for Megatron PEFT/LoRA, adapters, quantization interactions, and distributed checkpointing.
Repo meta
.github/CODEOWNERS, CHANGELOG.rst
Maps peft path to a CODEOWNERS group and documents new LoRA mode support in CHANGELOG.
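
The regex cohort above adds matches_pattern for wildcard-style layer/adapter selection; the hedged sketch below illustrates the idea. The exact signature is not shown in this diff, so the parameter layout is an assumption based on how the helper is referenced elsewhere in this review.

# Hedged sketch of wildcard-based name matching ("*" in the PEFT adapter_cfg
# matches every layer). The matches_pattern signature is an assumption.
from modelopt.torch.utils import matches_pattern  # re-exported per the table above

layer_name = "decoder.layers.0.self_attention.linear_qkv"
print(matches_pattern(layer_name, "*"))             # True: wildcard matches everything
print(matches_pattern(layer_name, "*linear_qkv*"))  # True: fnmatch-style pattern
print(matches_pattern(layer_name, "*linear_fc1*"))  # False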

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Model as nn.Module
  participant Convert as peft.convert.update_model
  participant ConvUtil as peft.conversion.replace_lora_module
  participant Adapter as peft.conversion.add_adapter
  participant ModeReg as PEFTModeRegistry

  User->>Convert: update_model(Model, PEFTConfig|dict)
  Convert->>ModeReg: (indirect) ensure PEFT mode/config types
  Convert->>ConvUtil: replace_lora_module(Model, version/config/registry)
  Note right of ConvUtil: Replace eligible layers<br/>with LoRAModule instances
  Convert->>Adapter: add_adapter(Model, PEFTConfig)
  Note right of Adapter: Register/merge adapter(s)<br/>per layer patterns
  Convert-->>User: Model (PEFT-enabled)

  alt Toggle adapters
    User->>peft.convert.enable_adapters: (Model, layers/adapters)
    User->>peft.convert.disable_adapters: (Model, layers/adapters)
  end
sequenceDiagram
  autonumber
  actor Trainer
  participant Model as nn.Module (PEFT)
  participant Freeze as peft.conversion.freeze_lora_weights
  participant Unfreeze as peft.conversion.unfreeze_lora_weights

  Trainer->>Freeze: layer_patterns / adapter_patterns
  Note right of Freeze: Set requires_grad=False<br/>for matching LoRA params
  Trainer->>Unfreeze: layer_patterns / adapter_patterns
  Note right of Unfreeze: Set requires_grad=True<br/>for matching LoRA params

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60–90 minutes

Poem

In the gardens of PEFT I hop with glee,
LoRA vines twine round each TPU tree.
Megatron winds hum quantized tunes,
Adapters bloom beneath the moons.
Regex stars guide where names align—
I twitch my nose: convert, enable, fine!
Carrots committed. All tests shine. 🥕✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 76.32% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly states that the pull request introduces PEFT mode support for Megatron-LM within ModelOPT, directly reflecting the main feature implemented in the changeset. It is concise, specific, and aligns with the PR objectives without unnecessary detail.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5821576 and 1a86e30.

📒 Files selected for processing (5)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
  • modelopt/torch/utils/__init__.py (1 hunks)
  • modelopt/torch/utils/regex.py (1 hunks)
  • tests/unit/torch/utils/test_regex.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/peft/conversion.py
🧬 Code graph analysis (3)
modelopt/torch/peft/convert.py (5)
modelopt/torch/opt/conversion.py (1)
  • apply_mode (342-429)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (151-225)
modelopt/torch/peft/conversion.py (1)
  • add_adapter (98-136)
modelopt/torch/utils/regex.py (1)
  • matches_pattern (11-75)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (16-121)
tests/unit/torch/utils/test_regex.py (1)
modelopt/torch/utils/regex.py (1)
  • matches_pattern (11-75)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/utils/regex.py (1)
  • matches_pattern (11-75)
modelopt/torch/peft/config.py (2)
  • PEFTAttributeConfig (55-144)
  • PEFTConfig (151-225)
modelopt/torch/peft/lora/layer.py (2)
  • LoRAModule (16-121)
  • update_layer_lora (68-82)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs



copy-pr-bot bot commented Sep 18, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


codecov bot commented Sep 18, 2025

Codecov Report

❌ Patch coverage is 50.45593% with 163 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.36%. Comparing base (3a76d28) to head (1a86e30).
⚠️ Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/peft/conversion.py 24.13% 66 Missing ⚠️
modelopt/torch/peft/config.py 59.72% 29 Missing ⚠️
modelopt/torch/peft/convert.py 39.13% 28 Missing ⚠️
modelopt/torch/peft/lora/layer.py 33.33% 28 Missing ⚠️
modelopt/torch/peft/mode.py 77.77% 10 Missing ⚠️
modelopt/torch/peft/custom.py 50.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #342      +/-   ##
==========================================
- Coverage   73.79%   73.36%   -0.43%     
==========================================
  Files         171      180       +9     
  Lines       17591    17919     +328     
==========================================
+ Hits        12981    13147     +166     
- Misses       4610     4772     +162     


@jingyu-ml jingyu-ml changed the title from "Jingyux/megatron lora" to "[1/N] ModelOPT PEFT mode support for the megatron-lm" on Sep 20, 2025
@jingyu-ml jingyu-ml marked this pull request as ready for review September 20, 2025 07:39
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 13

♻️ Duplicate comments (1)
modelopt/torch/peft/convert.py (1)

205-219: Another instance of private attribute access

Similar to the previous comment, this function also directly accesses _lora_adapters.

🧹 Nitpick comments (24)
modelopt/torch/peft/lora/__init__.py (1)

1-3: Prefer lazy submodule imports to reduce import-time cost and avoid cycles.

Importing layer and tp_layer at package import time can be heavy and risks circular imports. Expose them lazily via __getattr__.

-"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
-
-from . import layer, tp_layer
+"""LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
+
+from importlib import import_module as _import_module
+
+__all__ = ["layer", "tp_layer"]
+
+def __getattr__(name):
+    if name in __all__:
+        mod = _import_module(f".{name}", __name__)
+        globals()[name] = mod
+        return mod
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
modelopt/torch/peft/__init__.py (1)

16-16: Docstring mismatch with package purpose.

This is PEFT/LoRA, not distillation. Update the docstring.

-"""Distillation API subpackage for torch."""
+"""PEFT/LoRA API for torch."""
modelopt/torch/peft/custom.py (2)

22-29: Run plugins deterministically and guard against concurrent mutation.

Take a snapshot, sort by module/qualname, then invoke.

 def register_custom_model_plugins_on_the_fly(model):
     """Registers custom PEFT/LoRA plugins on the fly.
 
     This is called before LoRAModule replacement to allow plugins
     to configure the model (e.g., for distributed checkpointing).
     """
-    for callback in CUSTOM_MODEL_PLUGINS:
-        callback(model)
+    # Snapshot to avoid RuntimeError if mutated during iteration
+    callbacks = sorted(
+        tuple(CUSTOM_MODEL_PLUGINS),
+        key=lambda f: (getattr(f, "__module__", ""), getattr(f, "__qualname__", getattr(f, "__name__", ""))),
+    )
+    for callback in callbacks:
+        callback(model)

16-23: Add light typing for clarity.

Optional: add minimal typing to document the callback contract.

-"""Custom PEFT/LoRA plugins registry."""
+"""Custom PEFT/LoRA plugins registry."""
+from typing import Callable, Iterable  # lightweight, no runtime deps
modelopt/torch/peft/plugins/megatron.py (1)

31-34: Optional: export public symbols for introspection.

Expose MEGATRON_AVAILABLE and the hook via __all__ for discoverability.

-__all__ = []
+__all__ = ["MEGATRON_AVAILABLE", "megatron_replace_lora_module_hook"]
modelopt/torch/peft/mode.py (2)

1-13: Missing module docstring

Add a module-level docstring to document the purpose and functionality of this mode registry module.

+"""PEFT mode definitions and registry for parameter-efficient fine-tuning."""
+
 from modelopt.torch.opt.config import ModeloptBaseConfig

17-46: Missing class docstrings

Both mode descriptor classes lack docstrings explaining their purpose and usage.

 @PEFTModeRegistry.register_mode
 class PEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for PEFT/LoRA model conversion."""
+
     @property
     def name(self) -> str:
+        """Return the mode identifier string."""
         return "peft"
 @PEFTModeRegistry.register_mode
 class ExportPEFTModeDescriptor(ModeDescriptor):
+    """Mode descriptor for exporting PEFT/LoRA models."""
 
     @property
     def name(self) -> str:
-        """Returns the value (str representation) of the mode."""
+        """Return the mode identifier string."""
         return "export_peft"
modelopt/torch/peft/config.py (3)

88-90: Incorrect error message for scale validation

The error message is missing "a" before "positive number".

         if v <= 0:
-            raise ValueError("scale must be positive number")
+            raise ValueError("scale must be a positive number")
         return v

99-117: Pickling validation may be too restrictive

The pickling requirement for initialization functions might be overly restrictive and prevent legitimate use cases like closures or partial functions. Consider documenting this requirement more prominently or providing alternative initialization strategies.

Consider adding a class method to create common initialization patterns that are guaranteed to be pickleable:

@classmethod
def create_normal_init(cls, mean=0.0, std=0.02):
    """Create a pickleable normal initialization function."""
    def normal_init(weight):
        return init.normal_(weight, mean=mean, std=std)
    return normal_init

162-170: Broad exception handling masks specific validation errors

Catching all exceptions and re-raising with a generic message loses valuable debugging information.

                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
-                except Exception as e:
-                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+                except (TypeError, ValueError) as e:
+                    raise ValueError(f"Invalid adapter configuration for '{key}': {e}") from e
modelopt/torch/peft/lora/layer.py (2)

62-62: Typo in error message

The error message has a double period.

-            raise ValueError(f"adapter_name: {adapter_name} is already exist..")
+            raise ValueError(f"adapter_name: {adapter_name} already exists.")

193-230: Forward method has performance implications

The forward method iterates through all adapters on every forward pass, which could impact performance with many adapters. Consider caching active adapters.

Consider maintaining a list of active adapters to avoid checking the enable flag on every forward pass:

def _update_active_adapters(self):
    """Cache list of active adapters for efficient forward pass."""
    self._active_adapters = [
        (adapter["lora_a"], adapter["lora_b"], adapter["scale"])
        for adapter in self._lora_adapters.values()
        if adapter["enable"]
    ]

def forward(self, x: torch.Tensor, *args, **kwargs) -> Any:
    output = super().forward(x, *args, **kwargs)
    
    if isinstance(output, tuple):
        result = output[0]
        other_outputs = output[1:]
    else:
        result = output
        other_outputs = ()
    
    # Use cached active adapters
    for lora_a, lora_b, scale in getattr(self, '_active_adapters', []):
        lora_a_output = lora_a(x)
        if isinstance(lora_a_output, tuple):
            lora_a_output = lora_a_output[0]
        lora_b_output = lora_b(lora_a_output)
        if isinstance(lora_b_output, tuple):
            lora_b_output = lora_b_output[0]
        result = result + scale * lora_b_output
    
    return (result, *other_outputs) if other_outputs else result
modelopt/torch/peft/convert.py (3)

90-90: Unclear assertion message

The assertion message uses a non-standard abbreviation "MO-PEFT" without explanation.

-    assert is_peft_model(model), "It's not a MO-PEFT model"
+    assert is_peft_model(model), "Model has not been converted to PEFT/LoRA format"

101-102: Inconsistent error message formatting

The error message uses different formats for "pattern" vs "adapter pattern".

-                pattern_type = "pattern" if allow_callable else "adapter pattern"
-                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern)}")
+                pattern_type = "layer pattern" if allow_callable else "adapter pattern"
+                raise TypeError(f"Unsupported {pattern_type} type: {type(pattern).__name__}")

111-119: Direct access to private attribute _lora_adapters

The function directly accesses the private _lora_adapters attribute, violating encapsulation. Consider adding a public method or property for adapter access.

Add a public interface in LoRAModule:

# In LoRAModule class
def get_adapter(self, adapter_name: str) -> dict[str, Any] | None:
    """Get adapter configuration by name."""
    return self._lora_adapters.get(adapter_name)

def set_adapter_state(self, adapter_name: str, enable: bool) -> None:
    """Set the enable state of an adapter."""
    if adapter_name in self._lora_adapters:
        self._lora_adapters[adapter_name]["enable"] = enable

Then update this function:

-            for adapter_name, adapter_dict in module._lora_adapters.items():
+            for adapter_name in module.adapter_names:
                 if adapter_patterns is not None:
                     if not matches_any_pattern(
                         adapter_name, adapter_patterns, allow_callable=False
                     ):
                         continue
 
-                adapter_dict["enable"] = enable_state
+                module.set_adapter_state(adapter_name, enable_state)
tests/gpu/torch/peft/test_megatron_peft.py (3)

93-121: Test helper lacks error handling

The model provider function doesn't handle potential errors during model creation, which could make test failures harder to debug.

 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
-
+    try:
         if meta_device:
             with torch.device("meta"):
                 gpt_model = get_mcore_gpt_model(
                     tensor_model_parallel_size=tp_size,
                     num_layers=4,
                     ffn_hidden_size=None,
                     num_attention_heads=4,
                     activation_func="squared_relu",
                     transformer_impl="local",
                     hidden_size=hidden_size,
                     vocab_size=vocab_size,
                     use_cpu_initialization=meta_device,
                 )
         else:
             gpt_model = get_mcore_gpt_model(
                 tensor_model_parallel_size=tp_size,
                 num_layers=4,
                 ffn_hidden_size=None,
                 num_attention_heads=4,
                 activation_func="squared_relu",
                 transformer_impl="local",
                 hidden_size=hidden_size,
                 vocab_size=vocab_size,
             ).cuda()
         return gpt_model.eval()
+    except Exception as e:
+        pytest.fail(f"Failed to create GPT model: {e}")

134-139: Conditional assertion logic could be clearer

The conditional logic for checking output equality based on config type is not immediately clear. Consider adding a comment explaining why DEFAULT_LORA_CFG_TEST should produce identical outputs.

     assert lora_output.shape == original_output.shape
+    # DEFAULT_LORA_CFG_TEST uses zero initialization for LoRA B, so initial output should match original
     if lora_config == DEFAULT_LORA_CFG_TEST:
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
     else:
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)

161-174: Large number of commented test cases

Most test parameterizations are commented out. This suggests either incomplete implementation or test instability.

If these tests are not ready, consider:

  1. Removing them entirely and tracking in an issue
  2. Using pytest.skip with a reason
  3. Adding a TODO comment explaining why they're disabled
 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        pytest.param(DEFAULT_LORA_CFG_RANDOM_INIT_TEST, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SMALL_RANK_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(LARGE_SCALE_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
+        pytest.param(SELECTIVE_LAYER_LORA_CFG, marks=pytest.mark.skip(reason="Not yet stable")),
     ],
 )
modelopt/torch/peft/conversion.py (3)

104-116: Return the possibly converted root module from replace_lora_module (and use it)

Safer if a root replacement is ever registered; also removes confusion around local reassignment.

 def replace_lora_module(
     model: nn.Module, version=None, config: PEFTConfig = None, registry=LoRAModuleRegistry
 ):
     """Recursively replace the module with LoRA module."""
@@
-    if type(model) in registry:
-        model = registry.convert(model)
-    _replace_lora_module(model, version=version, registry=registry)
+    if type(model) in registry:
+        model = registry.convert(model)
+    _replace_lora_module(model, version=version, registry=registry)
+    return model

And in convert_to_peft_model:

-    replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)
+    model = replace_lora_module(
+        model, version=ModeloptStateManager(model).state_version, config=config
+    )

41-42: Remove stale TODO

The replacement is already performed.

-    # TODO: Replace to LoRA module

141-149: Docstring example uses an invalid signature

update_layer_lora takes a config object; adjust the example to avoid confusion.

-        ...         module.update_layer_lora("custom_adapter", rank=32)
+        ...         from modelopt.torch.peft.config import PEFTAttributeConfig
+        ...         module.update_layer_lora("custom_adapter", PEFTAttributeConfig(rank=32))
modelopt/torch/peft/lora/tp_layer.py (3)

37-48: _get_init_methods is unused; either use or remove

Prefer using it to ensure defaults when attr_config initializers are None (e.g., after metadata restore).

Example use (apply similarly in both update_layer_lora methods):

-        lora_a = ColumnParallelLinear(
+        lora_a_init, lora_b_init = self._get_init_methods(attr_config.lora_a_init, attr_config.lora_b_init)
+        lora_a = ColumnParallelLinear(
             self.input_size,
             attr_config.rank,
             config=self.config,
             bias=False,
             gather_output=True,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_a_init,
             disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
         )
@@
-        lora_b = ColumnParallelLinear(
+        lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,  # Keep output distributed like base layer
-            init_method=attr_config.lora_a_init,
+            init_method=lora_b_init,
         )

80-87: Micro: combine device/dtype moves into a single .to call

Slight cleanup; avoids two passes.

-        if device is not None:
-            lora_a = lora_a.to(device)
-            lora_b = lora_b.to(device)
-        if dtype is not None:
-            lora_a = lora_a.to(dtype)
-            lora_b = lora_b.to(dtype)
+        if device is not None or dtype is not None:
+            lora_a = lora_a.to(device=device, dtype=dtype)
+            lora_b = lora_b.to(device=device, dtype=dtype)

26-28: Remove unused defaults

DEFAULT_LORA_RANK and DEFAULT_SCALE are unused.

-DEFAULT_LORA_RANK = 64
-DEFAULT_SCALE = 1.0
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b895dc5 and d9a79c1.

📒 Files selected for processing (12)
  • modelopt/torch/peft/__init__.py (1 hunks)
  • modelopt/torch/peft/config.py (1 hunks)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
  • modelopt/torch/peft/custom.py (1 hunks)
  • modelopt/torch/peft/lora/__init__.py (1 hunks)
  • modelopt/torch/peft/lora/layer.py (1 hunks)
  • modelopt/torch/peft/lora/tp_layer.py (1 hunks)
  • modelopt/torch/peft/mode.py (1 hunks)
  • modelopt/torch/peft/plugins/__init__.py (1 hunks)
  • modelopt/torch/peft/plugins/megatron.py (1 hunks)
  • tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (9)
modelopt/torch/peft/plugins/megatron.py (1)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/peft/lora/layer.py (3)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • _DMRegistryCls (917-1124)
  • config (1265-1278)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (40-117)
modelopt/torch/peft/conversion.py (1)
  • peft_state (96-101)
modelopt/torch/peft/__init__.py (1)
modelopt/torch/peft/mode.py (2)
  • convert (31-32)
  • convert (66-68)
tests/gpu/torch/peft/test_megatron_peft.py (6)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • get_mcore_gpt_model (133-208)
  • initialize_for_megatron (385-393)
modelopt/torch/peft/config.py (2)
  • kaiming_init (30-32)
  • zero_init (35-37)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (20-230)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/peft/convert.py (3)
  • update_model (39-63)
  • disable_adapters (121-149)
  • enable_adapters (152-180)
modelopt/torch/peft/convert.py (4)
modelopt/torch/opt/conversion.py (1)
  • apply_mode (342-429)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (124-170)
modelopt/torch/peft/conversion.py (1)
  • add_adapter (163-192)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (20-230)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (7)
  • ApplyModeError (314-315)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
  • is_converted (102-127)
  • _last_metadata (220-222)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (124-170)
modelopt/torch/peft/lora/layer.py (4)
  • LoRAModule (20-230)
  • set_from_peft_state (152-165)
  • get_peft_state (91-131)
  • update_layer_lora (72-89)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
modelopt/torch/peft/mode.py (4)
modelopt/torch/opt/config.py (1)
  • ModeloptBaseConfig (59-147)
modelopt/torch/opt/mode.py (2)
  • ModeDescriptor (56-259)
  • _ModeRegistryCls (267-344)
modelopt/torch/peft/config.py (2)
  • PEFTConfig (124-170)
  • ExportPEFTConfig (173-174)
modelopt/torch/peft/conversion.py (5)
  • convert_to_peft_model (36-48)
  • restore_peft_model (51-55)
  • update_peft_metadata (91-93)
  • export_peft_model (118-119)
  • restore_export_peft_model (122-123)
modelopt/torch/peft/lora/tp_layer.py (4)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (40-117)
modelopt/torch/peft/lora/layer.py (4)
  • LoRAModule (20-230)
  • _register_adapter (39-69)
  • update_layer_lora (72-89)
  • _setup (30-32)
modelopt/torch/quantization/plugins/megatron.py (2)
  • _MegatronColumnParallelLinear (296-318)
  • _MegatronRowParallelLinear (322-354)
modelopt/torch/quantization/conversion.py (1)
  • register (325-366)
🪛 GitHub Actions: Code Quality
modelopt/torch/peft/mode.py

[error] 1-1: D100 Missing docstring in public module


[error] 22-22: D101 Missing docstring in public class


[error] 24-24: D102 Missing docstring in public method


[error] 28-28: D102 Missing docstring in public method


[error] 32-32: D102 Missing docstring in public method


[error] 36-36: D102 Missing docstring in public method


[error] 40-40: D102 Missing docstring in public method


[error] 44-44: D102 Missing docstring in public method


[error] 53-53: D101 Missing docstring in public class

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
🔇 Additional comments (2)
modelopt/torch/peft/plugins/__init__.py (1)

20-21: LGTM: guarded optional import.

Import guarding with suppress(ImportError) is appropriate; it triggers registration only when Megatron is present.

modelopt/torch/peft/__init__.py (1)

18-25: Import-order OK — no action required. mode.py does from .conversion import convert_to_peft_model, ..., export_peft_model and its properties return those functions (modelopt/torch/peft/mode.py, ~lines 11–13, 31–33, 66–69).

@realAsma
Copy link
Contributor

@jingyu-ml

Transformers peft already seems to have a backend for Megatron Core - https://huggingface.co/docs/peft/v0.17.0/en/package_reference/lora#peft.LoraConfig.megatron_core. Have we tested peft's MCore backend? Does it work? What does this PR add in addition to peft's MCore backend?

@cjluo-nv cjluo-nv requested a review from meenchen September 22, 2025 18:45
@jingyu-ml jingyu-ml requested review from meenchen and mxinO October 7, 2025 23:55
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/peft/conversion.py (1)

28-34: Export all public API functions in __all__.

Functions like convert_to_peft_model, restore_peft_model, and add_adapter are part of the public API (used by other modules in the package) but are not included in __all__. Add them to the export list or prefix truly internal helpers with an underscore.

🧹 Nitpick comments (1)
modelopt/torch/peft/conversion.py (1)

27-27: Clarify or remove TODO comment.

The TODO mentions adding test cases, but per the PR summary and AI-generated summary, new tests already exist in tests/gpu/torch/peft/test_megatron_peft.py. Either update this comment to specify which particular functions still need coverage, or remove it if all exported functions are already tested.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 14151fd and 7a14112.

📒 Files selected for processing (1)
  • modelopt/torch/peft/conversion.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/peft/conversion.py
🧬 Code graph analysis (1)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/utils/network.py (1)
  • matches_pattern (92-156)
modelopt/torch/peft/config.py (2)
  • PEFTAttributeConfig (32-110)
  • PEFTConfig (117-191)
modelopt/torch/peft/lora/layer.py (2)
  • LoRAModule (16-121)
  • update_layer_lora (68-82)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
modelopt/torch/peft/config.py (2)

56-71: Clarify json_schema_extra for callable fields.

The json_schema_extra declares "type": "string" for lora_a_init and lora_b_init, but these fields expect callable objects (type InitFn), not strings. This mismatch can mislead schema consumers, API clients, or documentation generators.

Consider either:

  • Removing json_schema_extra and letting Pydantic infer the schema
  • Updating to reflect the actual type (e.g., "type": "callable" or omitting type and relying on examples only)

179-193: Consider explicit type validation for non-dict values.

The validator converts dict values to PEFTAttributeConfig but silently accepts non-dict values without checking if they're valid PEFTAttributeConfig instances. While Pydantic's field-level validation should catch type mismatches before this validator runs, explicitly checking the type would make the validator more robust and self-documenting.

Apply this diff for explicitness:

     @field_validator("adapter_cfg")
     @classmethod
     def validate_adapter_cfg(cls, v):
         """Validate and convert adapter configurations."""
         validated_cfg = {}
         for key, value in v.items():
             if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                 # Convert dict to PEFTAttributeConfig to trigger validation
                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
                 except Exception as e:
                     raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
+            elif isinstance(value, PEFTAttributeConfig):
+                validated_cfg[key] = value
             else:
-                validated_cfg[key] = value
+                raise ValueError(
+                    f"Invalid adapter configuration for '{key}': "
+                    f"expected dict or PEFTAttributeConfig, got {type(value).__name__}"
+                )
         return validated_cfg
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 17f56f5 and 4c20df7.

📒 Files selected for processing (2)
  • modelopt/torch/peft/config.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/peft/convert.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (4)
modelopt/torch/peft/config.py (4)

1-30: LGTM!

The imports, exports, and type alias are well-structured. The InitFn type alias clearly defines initialization function signatures.


73-112: LGTM!

The validators correctly enforce:

  • Initialization functions must be in-place methods from torch.nn.init (ending with _)
  • Rank must be positive
  • Scale must be positive

Error messages are clear and actionable.


136-169: LGTM!

The configuration fields are well-structured with appropriate defaults:

  • freeze_lora_weights=False correctly enables LoRA training
  • adapter_cfg={"*": {"rank": 64}} uses the wildcard pattern to apply to all layers by default
  • Field descriptions are clear and helpful

196-197: LGTM!

The ExportPEFTConfig placeholder class is appropriate for future extensibility.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
modelopt/torch/peft/config.py (1)

147-148: Remove Callable from the key type union.

The type alias allows Callable as a key type, but validate_adapter_cfg (lines 213-225) doesn't validate callable keys, and the conversion logic uses string pattern matching with fnmatch. This misleads users into thinking callable keys are supported when they aren't.

Apply this diff to align the type with actual supported behavior:

-PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
+PEFTAdapterCfgType = dict[str, PEFTAttributeConfig | dict]

Based on past review comments that flagged this issue as unresolved.

🧹 Nitpick comments (1)
modelopt/torch/peft/config.py (1)

168-173: Consider validating adapter_name format.

The adapter_name field has no validation and could be an empty string or contain problematic characters. While this may be acceptable if downstream code handles it, consider adding a validator to ensure it's non-empty and contains only valid characters for identifiers.

If validation is desired, add:

@field_validator("adapter_name")
@classmethod
def validate_adapter_name(cls, v):
    """Validate adapter name is non-empty."""
    if not v or not v.strip():
        raise ValueError("adapter_name must be a non-empty string")
    return v
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4c20df7 and acf107c.

📒 Files selected for processing (1)
  • modelopt/torch/peft/config.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/torch/peft/config.py (4)

1-29: LGTM!

The module structure, imports, and public exports are well-organized and appropriate for a configuration module.


30-52: LGTM!

The InitField type and helper function provide a clean way to handle torch initializers in configuration with proper serialization and schema documentation.


55-144: LGTM!

The PEFTAttributeConfig class is well-designed with comprehensive validation for all fields. The initializer parsing and validation logic properly enforces that only in-place torch.nn.init functions are used.


228-229: LGTM!

The empty ExportPEFTConfig class serves as an acceptable placeholder for future export-related configuration extensions.

@jingyu-ml jingyu-ml force-pushed the jingyux/megatron-lora branch from 0e5b9d4 to ca415dc on October 8, 2025 04:11
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml force-pushed the jingyux/megatron-lora branch from ca415dc to fefbbe4 on October 8, 2025 04:12
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (6)
modelopt/torch/peft/convert.py (1)

66-67: Capture return value from apply_mode.

apply_mode may return a different module instance than the input. Not capturing the return value is unsafe and could lead to using a stale reference.

Apply this diff:

     # Check if model is already in PEFT mode by looking for LoRA modules
     if not is_peft_model(model):
-        apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
+        model = apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
     else:

This issue was flagged in a previous review but remains unaddressed.

tests/gpu/torch/peft/test_megatron_peft.py (2)

20-20: Add Apex/TE requirement to test skip guard.

The tests use transformer_impl="local" which requires Apex (line 160, 172). Without passing apex_or_te_required=True to skip_if_no_megatron(), environments lacking Apex will encounter assertion failures instead of gracefully skipping the tests.

Apply this diff:

-skip_if_no_megatron()
+skip_if_no_megatron(apex_or_te_required=True)

This issue was flagged in a previous review but remains unaddressed.


498-507: Avoid mutating shared config fixture.

Lines 500-501 mutate the lora_config fixture directly, which can leak state across parametrized test runs. The _test_adapter_gradient_flow_freeze_lora_model helper correctly uses copy.deepcopy (line 448), but this helper does not.

Apply this diff:

 def _test_adapter_gradient_flow(lora_config, tmp_path, rank, size):
     hidden_size = 512
-    lora_config["freeze_lora_weights"] = False
-    lora_config["freeze_base_model"] = False
+    local_cfg = copy.deepcopy(lora_config)
+    local_cfg["freeze_lora_weights"] = False
+    local_cfg["freeze_base_model"] = False

     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

-    mtpeft.update_model(model, lora_config)
+    mtpeft.update_model(model, local_cfg)

This issue was flagged in a previous review and marked as addressed, but the current code still has the problem.

modelopt/torch/peft/lora/plugins/megatron.py (1)

54-57: Fix typo in configuration attribute name.

The attribute name hetereogenous_dist_checkpoint is misspelled and should be heterogeneous_dist_checkpoint. This typo was flagged in a previous review but is now within the scope of this PR.

Apply this diff:

             if hasattr(module, "config") and hasattr(
-                module.config, "hetereogenous_dist_checkpoint"
+                module.config, "heterogeneous_dist_checkpoint"
             ):
-                module.config.hetereogenous_dist_checkpoint = True
+                module.config.heterogeneous_dist_checkpoint = True
modelopt/torch/peft/conversion.py (2)

28-34: Export core public API functions in __all__

The __all__ list is missing several public API functions that are used externally (per the AI summary and PR objectives):

  • convert_to_peft_model (core conversion entry point)
  • restore_peft_model (core restore entry point)
  • add_adapter (adapter management, used in mode integration)

Apply this diff:

 __all__ = [
+    "add_adapter",
+    "convert_to_peft_model",
     "freeze_base_weights",
     "freeze_lora_weights",
     "replace_lora_module",
+    "restore_peft_model",
     "unfreeze_base_weights",
     "unfreeze_lora_weights",
 ]

133-137: Critical: enable check missing default True

Line 133 checks merged_setting_dict.get("enable") which returns None (falsy) when the "enable" key is absent. Since line 122 uses model_dump(exclude_unset=True), if no matching pattern explicitly sets enable, the key won't appear in merged_setting_dict. This causes adapters to be skipped even though PEFTAttributeConfig defaults enable=True.

Apply this diff:

-            if merged_setting_dict is not None and merged_setting_dict.get("enable"):
+            if merged_setting_dict is not None and merged_setting_dict.get("enable", True):
                 module.update_layer_lora(
                     adapter_name,
                     PEFTAttributeConfig(**merged_setting_dict),
🧹 Nitpick comments (1)
modelopt/torch/peft/config.py (1)

147-148: Add validation for callable adapter keys or document their usage.

The type alias permits Callable keys, but validate_adapter_cfg (lines 211-225) doesn't validate them, and it's unclear what signature callables should have. Based on matches_pattern usage in conversion.py, callables should be Callable[[str], bool].

Consider one of these approaches:

Option 1: Add validation for callable keys

     @field_validator("adapter_cfg")
     @classmethod
     def validate_adapter_cfg(cls, v):
         """Validate and convert adapter configurations."""
         validated_cfg = {}
         for key, value in v.items():
+            # Validate callable keys have correct signature
+            if callable(key):
+                # Check it's a unary function
+                import inspect
+                sig = inspect.signature(key)
+                if len(sig.parameters) != 1:
+                    raise ValueError(
+                        f"Callable keys must accept a single string argument, got {sig}"
+                    )
             if isinstance(value, dict) and not isinstance(value, PEFTAttributeConfig):
                 # Convert dict to PEFTAttributeConfig to trigger validation
                 try:
                     validated_cfg[key] = PEFTAttributeConfig(**value)
                 except Exception as e:
                     raise ValueError(f"Invalid adapter configuration for '{key}': {e}")
             else:
                 validated_cfg[key] = value
         return validated_cfg

Option 2: Remove callable support if not needed

 # Type alias for adapter configuration
-PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
+PEFTAdapterCfgType = dict[str, PEFTAttributeConfig | dict]

If callable keys are intentionally supported, add documentation explaining the expected signature: Callable[[str], bool] where the argument is the module name.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between acf107c and fefbbe4.

📒 Files selected for processing (16)
  • .github/CODEOWNERS (1 hunks)
  • .vscode/settings.json (1 hunks)
  • CHANGELOG.rst (1 hunks)
  • modelopt/torch/__init__.py (1 hunks)
  • modelopt/torch/peft/__init__.py (1 hunks)
  • modelopt/torch/peft/config.py (1 hunks)
  • modelopt/torch/peft/conversion.py (1 hunks)
  • modelopt/torch/peft/convert.py (1 hunks)
  • modelopt/torch/peft/custom.py (1 hunks)
  • modelopt/torch/peft/lora/__init__.py (1 hunks)
  • modelopt/torch/peft/lora/layer.py (1 hunks)
  • modelopt/torch/peft/lora/plugins/__init__.py (1 hunks)
  • modelopt/torch/peft/lora/plugins/megatron.py (1 hunks)
  • modelopt/torch/peft/mode.py (1 hunks)
  • modelopt/torch/utils/network.py (3 hunks)
  • tests/gpu/torch/peft/test_megatron_peft.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (7)
  • modelopt/torch/peft/lora/__init__.py
  • CHANGELOG.rst
  • modelopt/torch/peft/mode.py
  • modelopt/torch/peft/custom.py
  • modelopt/torch/peft/__init__.py
  • .vscode/settings.json
  • modelopt/torch/utils/network.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-15T20:46:29.252Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:170-189
Timestamp: 2025-09-15T20:46:29.252Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the restore_from_modelopt_state function can accept modelopt_state["modelopt_state_dict"] directly without needing to wrap it in a full dict structure or include modelopt_version.

Applied to files:

  • modelopt/torch/peft/conversion.py
🧬 Code graph analysis (6)
modelopt/torch/peft/lora/layer.py (3)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • _DMRegistryCls (917-1124)
  • config (1265-1278)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (55-144)
modelopt/torch/peft/lora/plugins/megatron.py (4)
  • _setup (262-263)
  • _setup (273-274)
  • update_layer_lora (121-151)
  • update_layer_lora (186-217)
modelopt/torch/peft/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
tests/gpu/torch/peft/test_megatron_peft.py (8)
tests/_test_utils/import_helper.py (1)
  • skip_if_no_megatron (46-77)
tests/_test_utils/torch_dist/plugins/megatron_common.py (2)
  • get_mcore_gpt_model (134-209)
  • initialize_for_megatron (392-400)
modelopt/torch/opt/plugins/mcore_dist_checkpointing.py (2)
  • restore_sharded_modelopt_state (207-250)
  • save_sharded_modelopt_state (127-173)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (16-121)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/peft/lora/plugins/megatron.py (2)
  • sharded_state_dict (153-175)
  • sharded_state_dict (219-241)
modelopt/torch/peft/convert.py (3)
  • update_model (45-72)
  • disable_adapters (116-144)
  • enable_adapters (147-175)
modelopt/torch/quantization/model_quant.py (1)
  • disable_quantizer (453-455)
modelopt/torch/peft/lora/plugins/megatron.py (3)
modelopt/torch/quantization/plugins/megatron.py (2)
  • _MegatronColumnParallelLinear (296-318)
  • _MegatronRowParallelLinear (322-354)
modelopt/torch/peft/config.py (1)
  • PEFTAttributeConfig (55-144)
modelopt/torch/peft/lora/layer.py (4)
  • LoRAModule (16-121)
  • _register_adapter (35-65)
  • update_layer_lora (68-82)
  • _setup (26-28)
modelopt/torch/peft/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/utils/network.py (1)
  • matches_pattern (92-156)
modelopt/torch/peft/config.py (2)
  • PEFTAttributeConfig (55-144)
  • PEFTConfig (151-225)
modelopt/torch/peft/lora/layer.py (2)
  • LoRAModule (16-121)
  • update_layer_lora (68-82)
modelopt/torch/peft/custom.py (1)
  • register_custom_model_plugins_on_the_fly (22-29)
modelopt/torch/peft/convert.py (5)
modelopt/torch/opt/conversion.py (1)
  • apply_mode (342-429)
modelopt/torch/peft/config.py (1)
  • PEFTConfig (151-225)
modelopt/torch/peft/conversion.py (1)
  • add_adapter (101-139)
modelopt/torch/utils/network.py (1)
  • matches_pattern (92-156)
modelopt/torch/peft/lora/layer.py (1)
  • LoRAModule (16-121)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (14)
modelopt/torch/peft/lora/layer.py (4)

26-28: LGTM!

The initialization of the _lora_adapters dictionary is straightforward and correct.


35-65: LGTM!

The adapter registration logic correctly:

  • Adds modules with add_module
  • Prevents duplicate adapters
  • Stores all necessary metadata (lora_a, lora_b, rank, scale, enable)

67-82: LGTM!

The abstract method declaration is correct, and the docstring accurately describes the attr_config parameter as requested in past reviews.


84-121: LGTM!

The forward pass correctly:

  • Handles both single tensor and tuple outputs from the base layer
  • Applies only enabled adapters
  • Unpacks intermediate tuple outputs from parallel LoRA layers
  • Scales and accumulates LoRA contributions
  • Preserves additional outputs from the base layer
modelopt/torch/peft/lora/plugins/megatron.py (3)

71-110: LGTM!

The device/dtype alignment logic correctly:

  • Iterates through parameters and buffers to find the target device/dtype
  • Moves both LoRA modules to match the parent module
  • Delegates to the parent class's _register_adapter method

121-151: LGTM!

The LoRA adapter creation for ColumnParallelLinear is correct:

  • lora_a is a plain nn.Linear (not sharded)
  • lora_b is a ColumnParallelLinear (sharded at dim 0)
  • Proper initialization with lora_a_init and lora_b_init

This design aligns with the discussion in past reviews and avoids quantization issues with very small ranks.


186-217: LGTM!

The LoRA adapter creation for RowParallelLinear correctly mirrors the ColumnParallel design:

  • lora_a is a RowParallelLinear (sharded at dim 1)
  • lora_b is a plain nn.Linear (not sharded)
  • Proper initialization
modelopt/torch/peft/convert.py (4)

75-87: LGTM!

The check correctly identifies PEFT models by looking for any LoRAModule instances.


90-113: LGTM!

The adapter state management correctly:

  • Validates the model is a PEFT model
  • Filters by layer patterns using matches_pattern
  • Filters by adapter patterns (string-only)
  • Sets the enable flag appropriately

116-175: LGTM!

Both disable_adapters and enable_adapters are well-documented thin wrappers that delegate to _set_adapter_state with appropriate parameters.


178-183: LGTM!

The Megatron model detection correctly checks for the presence of Megatron-specific layer types.

tests/gpu/torch/peft/test_megatron_peft.py (1)

135-176: LGTM!

The distributed checkpoint save/load helpers and model provider function are well-structured and correctly handle Megatron-specific initialization.

modelopt/torch/peft/config.py (2)

55-144: LGTM!

PEFTAttributeConfig is well-designed with comprehensive validation:

  • Init method parsing from strings
  • Validation that callables are from torch.nn.init and in-place
  • Positive rank and scale checks

The validators provide clear error messages to guide users.


151-225: LGTM!

PEFTConfig has sensible defaults:

  • adapter_name="default"
  • adapter_cfg={"*": {"rank": 64}} (wildcard matches all layers)
  • freeze_base_model=True, freeze_lora_weights=False (standard LoRA fine-tuning)

The validators ensure:

  • Only "lora" adapter type is accepted
  • Adapter configs are converted to PEFTAttributeConfig with clear error messages

@jingyu-ml jingyu-ml enabled auto-merge (squash) October 8, 2025 04:33
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
jingyu-ml and others added 2 commits October 8, 2025 05:05
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: jingyu-ml <108295447+jingyu-ml@users.noreply.github.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml
Contributor Author

Thanks everyone for the review. If there are any additional concerns, please feel free to leave comments in this PR or in the design doc. I’ll address them in the next PR.

@jingyu-ml jingyu-ml enabled auto-merge (squash) October 8, 2025 23:57
@jingyu-ml jingyu-ml merged commit 39f81e5 into main Oct 9, 2025
27 checks passed
@jingyu-ml jingyu-ml deleted the jingyux/megatron-lora branch October 9, 2025 01:36