Checkpoint Refactor#65
Conversation
|
This should be ready for review now. I had to mess around with my fork to convert it into a public PR, so hopefully nothing weird happened. |
|
Hey @akshaysubr, @NickGeneva and I talked, and I will refactor some of the registry and entrypoints code to be more in line with fsspec. You can look at other parts of the code, but I wouldn't review too deeply until I make these changes. |
|
Linking the Modulus launch PR that should be merged at the same time: https://github.com/NVIDIA/modulus-launch/pull/42 |
|
Hey @akshaysubr and @NickGeneva, this is ready for another round of revisions. I think everything here is in a pretty good place, but there is one key design decision that I will explain below.

For the entrypoint exposure I followed the example here to the letter: https://amir.rachum.com/python-entry-points/. In particular, I follow their final revision and expose all our models through entrypoints, treating internal and external models the same. You can see where I get the list of models here: https://github.com/loliverhennigh/modulus/blob/fea-checkpoint_refactor/modulus/models/registry.py#L20.

This goes somewhat counter to the way fsspec does things. Instead of exposing their protocols through entrypoints and getting them from there, they manually write them in a dict here: https://github.com/fsspec/filesystem_spec/blob/master/fsspec/registry.py#L62C2-L62C2. Functionally this leads to the same design and interface for users, and they have good documentation on how to add protocols through entrypoints. The advantage of their approach is that they can easily add protocols from outside packages, and you can see a few examples of this in the dict. If we did the same thing we could, for example, bring in Hugging Face models. I am open to setting all this up following the fsspec example, but at this point I don't think there is any benefit to doing so. I would say we go with the design in this PR, and if we want we can switch to the fsspec setup at any time. That should be easy given how I structured the code and would basically just mean adding a dict.

Also, to make it a bit more concrete how a user can expose a model to Modulus via entrypoints, I included a little example here: https://github.com/loliverhennigh/modulus/blob/fea-checkpoint_refactor/examples/ExternalPackage/example/load_model.py. I'll delete this afterwards, but it was really helpful for wrapping my head around this stuff. |
akshaysubr
left a comment
@loliverhennigh This is great work! Left some comments about suggested changes. Would also be good to change the PR title and description to something more verbose and descriptive.
I think it is good to keep this example. We just need to make sure it gets run every so often so that it doesn't go stale. |
|
Hey @akshaysubr and @NickGeneva, ready for another round of reviews! I think I addressed everything. Could I get another look when you have time? Thanks! |
|
@loliverhennigh I tested this now and it works great! Really liking how this turned out! |
|
/blossom-ci |
|
* fixed grid effect
* blew up commit history
* fixed commit
* added modulus model version and checkpoint
* refactored model registry
* example registry
* updated model registry
* modified activations
* save git hash changed to verbose
* Fixed most issues
* ~90% functionality implemented
* added docstring about json input
* added map location here
* updated init method
* black formatted
* removed example external package
* Update test_from_torch.py
* from torch model black

---------

Co-authored-by: oliver <ohennigh@nvidia.com>
Addresses Peter Sharpe's CHANGES_REQUESTED review on PR NVIDIA#1576 in full, and subsumes the 3D-UFNO portion of the planned xFNO PR into this one. Net: -139 lines despite gaining trunkless mode, time-axis-extend, multi-channel output, coord features, multi-layer lift, and 2D/3D genericity.

Theme 1 — Dimensional unification (Peter NVIDIA#37, NVIDIA#47, NVIDIA#52, NVIDIA#53):

- ``DeepONet`` (formerly DeepONet + DeepONet3D) takes ``dimension: int`` (2 or 3) and dispatches via ``_DIM_DEFAULTS`` and per-dim conv/spectral primitives, mirroring the ``FNO`` pattern.
- ``SpatialBranch`` (formerly SpatialBranch + SpatialBranch3D) takes ``dimension`` and uses an ``_DIM_LAYERS`` lookup for ``SpectralConv``/``Conv``/``BatchNorm``/``AdaptiveAvgPool``/``UNetAdapter`` and the permute helpers.
- ``Conv{2,3}dFCLayer`` is selected via a one-line lookup (Peter NVIDIA#45).

Theme 2 — Wrappers folded into DeepONet (Peter NVIDIA#54, NVIDIA#64, NVIDIA#65):

- ``wrappers.py`` deleted (``DeepONetWrapper`` and ``DeepONet3DWrapper`` removed). Padding behaviour is now a constructor flag, ``auto_pad: bool = False``, and the model dispatches to ``_forward_packed`` / ``_forward_packed_trunkless`` accordingly.
- The 6-cell call matrix (trunked/trunkless × packed/core × spatial/mlp) is documented in the class docstring.
- The previous private ``_temporal_projection`` attribute is exposed as a public ``has_temporal_projection`` property (Peter NVIDIA#55).

Theme 3 — Deduplication (Peter NVIDIA#43, NVIDIA#44, NVIDIA#50, NVIDIA#51, NVIDIA#40, related Greptile):

- ``TrunkNet`` and ``MLPBranch`` deleted — both duplicated ``physicsnemo.models.mlp.FullyConnected``; users now pass any ``nn.Module`` for the trunk / branches (DI-first API).
- ``_SinActivation`` deleted; the activation is registered as ``"sin"`` in ``physicsnemo.nn.module.activations.ACT2FN`` (previous commit). All ``if activation_fn.lower() == "sin"`` special-cases removed.
- ``DeepONet.from_config`` and the dict-config schema removed entirely; Hydra-style ``_target_`` instantiation supersedes it.
- ``count_params`` collapsed from 4 duplicate copies to 1.

Theme 4 — xFNO fold-in:

- ``trunk: nn.Module | None = None`` enables trunkless mode (the 3D-UFNO use case from the planned xFNO PR).
- ``out_channels: int = 1`` adds multi-channel output to every path.
- ``time_modes: int | None`` enables xFNO-style time-axis-extend in trunkless packed mode: replicate-pads the last spatial axis to fit ``2 * time_modes`` and crops to the requested ``target_times``.
- ``coord_features`` and ``lift_layers``/``lift_hidden_width`` parameters on ``SpatialBranch`` replace the deleted dict-driven "conv encoder" option.

Theme 5 — Housekeeping (Peter NVIDIA#33, NVIDIA#34, NVIDIA#38, NVIDIA#41, NVIDIA#48, NVIDIA#57, NVIDIA#58, NVIDIA#59, Charlelie NVIDIA#26, Greptile NVIDIA#5, NVIDIA#6):

- ``padding.py`` renamed to private ``_padding.py``; all functions carry ``jaxtyping.Shaped`` annotations.
- All public forward methods carry ``jaxtyping.Float`` annotations and ``torch.compiler.is_compiling`` shape-validation guards.
- ``Literal`` type aliases for ``decoder_type`` and other enums; case-insensitive validation against ``get_args`` (Greptile NVIDIA#15).
- Modern type hints throughout (``dict[str, Any] | None``, no ``Dict``/``Optional``).
- All public docstrings use ``r"""`` raw-string prefix, LaTeX math for tensor shapes, double backticks for inline code, and Examples sections.
- ``Notes`` block in ``branches.py`` documents the ``num_unet_layers`` 8x memory/compute penalty (Peter NVIDIA#49).

Theme 6 — Tests (Peter NVIDIA#60, NVIDIA#61, NVIDIA#62, NVIDIA#63, Charlelie NVIDIA#29):

- ``_FIXTURE_REGISTRY`` drives all non-regression tests across 9 scenarios: u_deeponet 2D/3D, fourier_deeponet, mionet, temporal_projection, multi-channel packed 2D, xfno trunkless 3D (with and without time-axis-extend), and core 2D MLP-branch.
- New 3D gradient-flow test and 3D ``torch.compile`` test.
- ``fullgraph=True`` probe tests for 2D and 3D marked ``@pytest.mark.xfail(strict=False)`` to empirically answer Peter NVIDIA#63.
- ``_load_golden`` uses ``pytest.skip`` for missing fixtures so CI passes pending cluster-side golden regeneration.
- Test class structure mirrors MOD-008a/b/c: ``TestDeepONetConstructor``, ``TestDeepONetNonRegression``, ``TestDeepONetCheckpoint``, ``TestDeepONetGradientFlow``, ``TestDeepONetCompile``, ``TestDeepONetTimeAxisExtend``.

CHANGELOG bullet rewritten to describe the actual shipped API (was stale, still described the old config-driven 8-variant model).
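The per-dimension lookup pattern from Theme 1 can be sketched as below. The ``_DIM_LAYERS`` table, class name, and constructor signature here are illustrative assumptions for the example, not the shipped physicsnemo code.

```python
# Minimal sketch of the dimension-dispatch pattern: one class covers 2D and
# 3D by selecting conv/norm/pool primitives from a lookup table, instead of
# maintaining parallel 2D and 3D class hierarchies.
# The names below are hypothetical, not the real physicsnemo implementation.
import torch
import torch.nn as nn

_DIM_LAYERS = {
    2: {"conv": nn.Conv2d, "norm": nn.BatchNorm2d, "pool": nn.AdaptiveAvgPool2d},
    3: {"conv": nn.Conv3d, "norm": nn.BatchNorm3d, "pool": nn.AdaptiveAvgPool3d},
}


class SpatialBranchSketch(nn.Module):
    """One branch class covering 2D and 3D via a dimension lookup."""

    def __init__(self, dimension: int, in_channels: int, out_channels: int):
        super().__init__()
        if dimension not in _DIM_LAYERS:
            raise ValueError(f"dimension must be 2 or 3, got {dimension}")
        layers = _DIM_LAYERS[dimension]
        self.conv = layers["conv"](in_channels, out_channels, kernel_size=3, padding=1)
        self.norm = layers["norm"](out_channels)
        self.pool = layers["pool"](1)  # global average pool to a single cell

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C_in, *spatial) -> (B, C_out): conv, normalize, pool, flatten.
        return self.pool(self.norm(self.conv(x))).flatten(1)
```

The benefit is that every fix or feature (e.g. the ``auto_pad`` flag or multi-channel output) lands once instead of being mirrored across 2D and 3D copies.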
Description
This feature adds better saving and loading of models by storing constructor arguments in a JSON file. This allows loading a model without needing to keep track of its parameters separately. Along with this merge there are several related side features: ``from_torch`` lets users bring in models from torch seamlessly, assuming the models have JSON-serializable inputs. We have also exposed the models through entrypoints, following discussions about expanding the usability of Modulus.
Here is a list of high-level features: