31 changes: 11 additions & 20 deletions modelopt/torch/distill/mode.py
@@ -78,17 +78,17 @@ def convert(self) -> ConvertEntrypoint:
@property
def restore(self) -> RestoreEntrypoint:
"""The mode's entrypoint for restoring a model."""
return _restore_kd_model
raise NotImplementedError(f"{self.name} mode does not support restore.")

@property
def update_for_new_mode(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state for adding new mode."""
return _reset_kd_state_config

@property
def update_for_save(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state before saving."""
return _reset_kd_state_config
def save_mode_in_state(self) -> bool:
"""Whether the mode should be saved into the modelopt state."""
return False


@DistillModeRegistry.register_mode
@@ -121,7 +121,12 @@ def convert(self) -> ConvertEntrypoint:
@property
def restore(self) -> RestoreEntrypoint:
"""The mode's entrypoint for restoring a model."""
return _restore_exported_student
raise NotImplementedError(f"{self.name} mode does not support restore.")

@property
def save_mode_in_state(self) -> bool:
"""Whether the mode should be saved into the modelopt state."""
return False


def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
@@ -174,12 +179,6 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
return distillation_model, metadata


def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
"""Function for restoring a previously convert model to a distillation meta-model."""
# NOTE: DistillationModel will purposely remain unrestored
return model


def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
"""Function for resetting the state's config."""
config.teacher_model = nn.Module
@@ -206,16 +205,8 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType
student_model,
warn=True,
msg=(
f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and"
" exporting it ..."
f"The student model is wrapped into {type(student_model).__name__}. Unwrapping and exporting it ..."
),
)

return student_model, {}


def _restore_exported_student(
model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
) -> nn.Module:
# NOTE: DistillationModel was unrestored so this does nothing
return model
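
With restore no longer supported, a checkpoint saved from a distillation model now restores to the plain student, and the kd_loss mode must be re-applied by hand. A minimal sketch of that workflow, mirroring the updated unit test below (tiny_mobilenet and the teacher instance are stand-ins borrowed from the test fixtures, not part of this diff):

    import modelopt.torch.opt as mto
    import modelopt.torch.distill as mtd

    # Restore yields the bare student; the DistillationModel wrapper is not recreated.
    student = mto.restore(tiny_mobilenet(), "ckpt.pt")

    # Re-apply the kd_loss mode manually, since it is no longer saved in the state.
    config = {
        "teacher_model": teacher_model,  # stand-in teacher instance
        "criterion": mtd.LogitsDistillationLoss(),
    }
    distill_model = mtd.convert(student, mode=[("kd_loss", config)])
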
11 changes: 8 additions & 3 deletions modelopt/torch/opt/conversion.py
@@ -476,11 +476,16 @@ def modelopt_state(model: nn.Module) -> dict[str, Any]:
# update metadata of current mode as needed
manager.update_last_state_before_save(model)

# filter out modes that should not be saved in the state
skip_idx = []
for i, (m, _, _) in enumerate(manager.modes_with_states()):
if not m.save_mode_in_state:
skip_idx.append(i)
state_dict = [state for i, state in enumerate(manager.state_dict()) if i not in skip_idx]

# construct state dict and return it
objs = {
"modelopt_state_dict": (
manager.state_dict()
), # empty state_dict is okay (saving regular models)
"modelopt_state_dict": state_dict, # empty state_dict is okay (saving regular models)
"modelopt_version": __version__,
}
return objs
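
The net effect of the filtering above: any mode whose descriptor returns False from save_mode_in_state is dropped from the serialized list before the state is written out. A hedged sketch of the resulting behavior (distill_model is a placeholder for a kd_loss-converted model):

    state = mto.modelopt_state(distill_model)

    # kd_loss was filtered out, so no KD entry appears in the saved state.
    assert all(entry[0] != "kd_loss" for entry in state["modelopt_state_dict"])
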
14 changes: 12 additions & 2 deletions modelopt/torch/opt/mode.py
@@ -242,15 +242,25 @@ def require_model_like(self) -> bool:
"""
return False

@property
def save_mode_in_state(self) -> bool:
"""Whether the mode should be saved into the modelopt state.

This is useful if the mode is intended to be manually re-applied every time it's used.

Returns:
True
"""
return True

def assert_compatibility_as_next_mode_of(self, other_mode: "ModeDescriptor | str") -> None:
"""Assert that this mode is compatible as a next mode of the other mode."""
if isinstance(other_mode, str):
other_mode = _ModeRegistryCls.get_from_any(other_mode)

if other_mode.next_modes is not None:
assert str(self) in other_mode.next_modes, (
f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are"
f" {other_mode.next_modes}."
f"Cannot add {self} after {other_mode}! Next modes of {other_mode} are {other_mode.next_modes}."
)

if other_mode.next_prohibited_modes is not None:
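
Any custom mode that should be re-applied manually rather than persisted can now opt out by overriding the new property. A hypothetical descriptor as illustration (the class name is invented, not part of this PR):

    class EphemeralModeDescriptor(ModeDescriptor):
        """Example mode that is never serialized into the modelopt state."""

        @property
        def save_mode_in_state(self) -> bool:
            """Skip persisting this mode; it must be re-applied on each use."""
            return False
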
19 changes: 12 additions & 7 deletions tests/unit/torch/distill/test_distill.py
@@ -147,16 +147,21 @@ def test_distillation_save_restore(distillation_model, tmp_path):
new_student = tiny_mobilenet()
distillation_model_new = mto.restore(new_student, tmp_path / "ckpt.pt")

# Ensure state config was reset
# Ensure state is not actually restored
manager = mto.ModeloptStateManager(distillation_model_new)
cfg = manager._state[-1][1]["config"]
assert cfg["teacher_model"] == nn.Module
assert isinstance(next(iter(cfg["criterion"].values())), Loss)
assert cfg["loss_balancer"] is None

# Should not have restored anything
assert not manager.has_state
assert isinstance(distillation_model_new, type(new_student))

# Subsequent convert should behave normally
config = {
"teacher_model": distillation_model.teacher_model,
"criterion": mtd.LogitsDistillationLoss(),
}
distillation_model_newer = mtd.convert(new_student, mode=[("kd_loss", config)])
manager = mto.ModeloptStateManager(distillation_model_newer)
assert manager.has_state
assert isinstance(distillation_model_newer, mtd.DistillationModel)


def test_distillation_export(distillation_model, tmp_path):
model_exported = mtd.export(distillation_model)
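
Because export_student also opts out of the state, an exported student saves to a checkpoint with no KD entries at all. A short sketch under those assumptions (the path is illustrative):

    student = mtd.export(distillation_model)  # unwraps the DistillationModel
    mto.save(student, "student_ckpt.pt")      # saved state contains neither kd_loss nor export_student
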
4 changes: 2 additions & 2 deletions tests/unit/torch/opt/plugins/test_hf_patching.py
@@ -53,5 +53,5 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")

tf_output_tester(model, model_test)
# since distill model contains loss function, we compare state of model manually
assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
# KD state is not saved, so the restored state should be empty
assert not mto.ModeloptStateManager(model_test).has_state
4 changes: 3 additions & 1 deletion tests/unit/torch/opt/test_chaining.py
@@ -88,7 +88,9 @@ def test_chained_save_restore(mode):
# compare serialized versions since some configs may be objects...
manager = mto.ModeloptStateManager(model)
manager2 = mto.ModeloptStateManager(model2)
assert torch.equal(_serialize(manager.state_dict()), _serialize(manager2.state_dict()))
# NOTE: KD modes are not saved in the state and thus won't exist after restore
state_minus_kd = [s for s in manager.state_dict() if s[0] not in ("kd_loss", "export_student")]
assert torch.equal(_serialize(state_minus_kd), _serialize(manager2.state_dict()))

# run comparison in eval mode since there might be model randomization in train mode
model.eval()