diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py
index b65a31cd5..b22520209 100644
--- a/modelopt/torch/distill/mode.py
+++ b/modelopt/torch/distill/mode.py
@@ -78,7 +78,7 @@ def convert(self) -> ConvertEntrypoint:
     @property
     def restore(self) -> RestoreEntrypoint:
         """The mode's entrypoint for restoring a model."""
-        return _restore_kd_model
+        raise NotImplementedError(f"{self.name} mode does not support restore.")
 
     @property
     def update_for_new_mode(self) -> UpdateEntrypoint:
@@ -86,9 +86,9 @@ def update_for_new_mode(self) -> UpdateEntrypoint:
         return _reset_kd_state_config
 
     @property
-    def update_for_save(self) -> UpdateEntrypoint:
-        """The mode's entrypoint for updating the models state before saving."""
-        return _reset_kd_state_config
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state."""
+        return False
 
 
 @DistillModeRegistry.register_mode
@@ -121,7 +121,12 @@ def convert(self) -> ConvertEntrypoint:
     @property
     def restore(self) -> RestoreEntrypoint:
         """The mode's entrypoint for restoring a model."""
-        return _restore_exported_student
+        raise NotImplementedError(f"{self.name} mode does not support restore.")
+
+    @property
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state."""
+        return False
 
 
 def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
@@ -174,12 +179,6 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
     return distillation_model, metadata
 
 
-def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
-    """Function for restoring a previously convert model to a distillation meta-model."""
-    # NOTE: DistillationModel will purposely remain unrestored
-    return model
-
-
 def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
     """Function for resetting the state's config."""
     config.teacher_model = nn.Module
@@ -206,16 +205,8 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertRet
         student_model,
         warn=True,
         msg=(
-            f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and"
-            " exporting it ..."
+            f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and exporting it ..."
         ),
     )
 
     return student_model, {}
-
-
-def _restore_exported_student(
-    model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
-) -> nn.Module:
-    # NOTE: DistillationModel was unrestored so this does nothing
-    return model
diff --git a/modelopt/torch/opt/conversion.py b/modelopt/torch/opt/conversion.py
index 183514f9e..1de6143bd 100644
--- a/modelopt/torch/opt/conversion.py
+++ b/modelopt/torch/opt/conversion.py
@@ -476,11 +476,16 @@ def modelopt_state(model: nn.Module) -> dict[str, Any]:
     # update metadata of current mode as needed
     manager.update_last_state_before_save(model)
 
+    # filter out modes that should not be saved in the state
+    skip_idx = []
+    for i, (m, _, _) in enumerate(manager.modes_with_states()):
+        if not m.save_mode_in_state:
+            skip_idx.append(i)
+    state_dict = [state for i, state in enumerate(manager.state_dict()) if i not in skip_idx]
+
     # construct state dict and return it
     objs = {
-        "modelopt_state_dict": (
-            manager.state_dict()
-        ),  # empty state_dict is okay (saving regular models)
+        "modelopt_state_dict": state_dict,  # empty state_dict is okay (saving regular models)
         "modelopt_version": __version__,
     }
     return objs
diff --git a/modelopt/torch/opt/mode.py b/modelopt/torch/opt/mode.py
index b9ce12686..d9ac1bf07 100644
--- a/modelopt/torch/opt/mode.py
+++ b/modelopt/torch/opt/mode.py
@@ -242,6 +242,17 @@ def require_model_like(self) -> bool:
         """
         return False
 
+    @property
+    def save_mode_in_state(self) -> bool:
+        """Whether the mode should be saved into the modelopt state.
+
+        This is useful if the mode is intended to be manually re-applied every time it's used.
+
+        Returns:
+            True
+        """
+        return True
+
     def assert_compatibility_as_next_mode_of(self, other_mode: "ModeDescriptor | str") -> None:
         """Assert that this mode is compatible as a next mode of the other mode."""
         if isinstance(other_mode, str):
@@ -249,8 +260,7 @@ def assert_compatibility_as_next_mode_of(self, other_mode: "ModeDescriptor | str
 
         if other_mode.next_modes is not None:
             assert str(self) in other_mode.next_modes, (
-                f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are"
-                f" {other_mode.next_modes}."
+                f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are {other_mode.next_modes}."
             )
 
         if other_mode.next_prohibited_modes is not None:
diff --git a/tests/unit/torch/distill/test_distill.py b/tests/unit/torch/distill/test_distill.py
index f87367296..10241f076 100644
--- a/tests/unit/torch/distill/test_distill.py
+++ b/tests/unit/torch/distill/test_distill.py
@@ -147,16 +147,21 @@ def test_distillation_save_restore(distillation_model, tmp_path):
     new_student = tiny_mobilenet()
     distillation_model_new = mto.restore(new_student, tmp_path / "ckpt.pt")
 
-    # Ensure state config was reset
+    # Ensure state is not actually restored
     manager = mto.ModeloptStateManager(distillation_model_new)
-    cfg = manager._state[-1][1]["config"]
-    assert cfg["teacher_model"] == nn.Module
-    assert isinstance(next(iter(cfg["criterion"].values())), Loss)
-    assert cfg["loss_balancer"] is None
-
-    # Should not have restored anything
+    assert not manager.has_state
     assert isinstance(distillation_model_new, type(new_student))
 
+    # Subsequent convert should behave normally
+    config = {
+        "teacher_model": distillation_model.teacher_model,
+        "criterion": mtd.LogitsDistillationLoss(),
+    }
+    distillation_model_newer = mtd.convert(new_student, mode=[("kd_loss", config)])
+    manager = mto.ModeloptStateManager(distillation_model_newer)
+    assert manager.has_state
+    assert isinstance(distillation_model_newer, mtd.DistillationModel)
+
 
 def test_distillation_export(distillation_model, tmp_path):
     model_exported = mtd.export(distillation_model)
diff --git a/tests/unit/torch/opt/plugins/test_hf_patching.py b/tests/unit/torch/opt/plugins/test_hf_patching.py
index a0e8e227e..0b6427d18 100644
--- a/tests/unit/torch/opt/plugins/test_hf_patching.py
+++ b/tests/unit/torch/opt/plugins/test_hf_patching.py
@@ -53,5 +53,5 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
     model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")
     tf_output_tester(model, model_test)
 
-    # since distill model contains loss function, we compare state of model manually
-    assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
+    # KD state is not saved, so the restored model's state should be empty
+    assert not mto.ModeloptStateManager(model_test).has_state
diff --git a/tests/unit/torch/opt/test_chaining.py b/tests/unit/torch/opt/test_chaining.py
index 19623c6e7..bedbbfee0 100644
--- a/tests/unit/torch/opt/test_chaining.py
+++ b/tests/unit/torch/opt/test_chaining.py
@@ -88,7 +88,9 @@ def test_chained_save_restore(mode):
     # compare serialized version since some configs may be objected...
     manager = mto.ModeloptStateManager(model)
     manager2 = mto.ModeloptStateManager(model2)
-    assert torch.equal(_serialize(manager.state_dict()), _serialize(manager2.state_dict()))
+    # NOTE: KD modes are skipped at save time and thus won't exist after restore
+    state_minus_kd = [s for s in manager.state_dict() if s[0] not in ("kd_loss", "export_student")]
+    assert torch.equal(_serialize(state_minus_kd), _serialize(manager2.state_dict()))
 
     # run comparison in eval mode since there might be model randomization in train mode
     model.eval()
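For reference, a minimal, self-contained sketch of the save-time filtering this diff adds to modelopt_state(): entries whose mode descriptor reports save_mode_in_state == False are dropped from the serialized state, which is why the kd_loss and export_student modes no longer appear after restore. The _StubMode class and _filter_state helper below are hypothetical stand-ins for illustration only, not modelopt APIs.

# Illustrative sketch only: _StubMode and _filter_state are hypothetical
# stand-ins; the real logic lives in modelopt_state() and keys off
# ModeDescriptor.save_mode_in_state.
from dataclasses import dataclass


@dataclass
class _StubMode:
    name: str
    save_mode_in_state: bool = True  # mirrors the new ModeDescriptor default


def _filter_state(modes: list[_StubMode], state: list[tuple[str, dict]]) -> list[tuple[str, dict]]:
    """Drop state entries whose mode opts out of serialization."""
    skip_idx = {i for i, m in enumerate(modes) if not m.save_mode_in_state}
    return [entry for i, entry in enumerate(state) if i not in skip_idx]


modes = [_StubMode("quantize"), _StubMode("kd_loss", save_mode_in_state=False)]
state = [("quantize", {"config": {}}), ("kd_loss", {"config": {}})]

# Only the "quantize" entry survives; "kd_loss" is filtered out at save time.
assert _filter_state(modes, state) == [("quantize", {"config": {}})]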