Include the training mode in the ModelSummary (#19468)
awaelchli committed Feb 15, 2024
1 parent 1967547 commit 120c87f
Showing 9 changed files with 74 additions and 31 deletions.
20 changes: 10 additions & 10 deletions docs/source-pytorch/debug/debugging_basic.rst
@@ -114,11 +114,11 @@ this generate a table like:

.. code-block:: text
| Name | Type | Params
----------------------------------
0 | net | Sequential | 132 K
1 | net.0 | Linear | 131 K
2 | net.1 | BatchNorm1d | 1.0 K
| Name | Type | Params | Mode
-------------------------------------------
0 | net | Sequential | 132 K | train
1 | net.0 | Linear | 131 K | train
2 | net.1 | BatchNorm1d | 1.0 K | train
To add the child modules to the summary add a :class:`~lightning.pytorch.callbacks.model_summary.ModelSummary`:

@@ -162,10 +162,10 @@ With the input array, the summary table will include the input and output layer

.. code-block:: text
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | train | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | train | [10, 512] | [10, 512]
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.
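
The docs change above introduces the new "Mode" column. As a rough sketch of how the per-submodule view described there is typically enabled (the `ModelSummary` callback and its `max_depth` argument appear elsewhere in this diff; the model itself is a placeholder):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary

# Expand the fit-time summary to every submodule; each row now also reports
# whether that submodule is in train or eval mode (the "Mode" column).
trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])
# trainer.fit(model)  # prints the summary table at the start of fitting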
2 changes: 1 addition & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- The `ModelSummary` and `RichModelSummary` callbacks now display the training mode of each layer in the column "Mode" ([#19468](https://github.com/Lightning-AI/lightning/pull/19468))

-

1 change: 1 addition & 0 deletions src/lightning/pytorch/callbacks/rich_model_summary.py
@@ -84,6 +84,7 @@ def summarize(
table.add_column("Name", justify="left", no_wrap=True)
table.add_column("Type")
table.add_column("Params", justify="right")
table.add_column("Mode")

column_names = list(zip(*summary_data))[0]

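For the rich variant, the added `table.add_column("Mode")` means the same information shows up when the `RichModelSummary` callback is used. A minimal sketch, assuming the `rich` package is installed:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichModelSummary

# Swaps the plain summary for a rich table that now includes the "Mode" column.
trainer = Trainer(callbacks=[RichModelSummary(max_depth=2)])
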
31 changes: 21 additions & 10 deletions src/lightning/pytorch/utilities/model_summary/model_summary.py
@@ -141,6 +141,11 @@ def num_parameters(self) -> int:
"""Returns the number of parameters in this module."""
return sum(math.prod(p.shape) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())

@property
def training(self) -> bool:
"""Returns whether the module is in training mode."""
return self._module.training


class ModelSummary:
"""Generates a summary of all layers in a :class:`~lightning.pytorch.core.LightningModule`.
@@ -178,21 +183,21 @@ class ModelSummary:
...
>>> model = LitModel()
>>> ModelSummary(model, max_depth=1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
------------------------------------------------------------
| Name | Type | Params | Mode | In sizes | Out sizes
--------------------------------------------------------------------
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
--------------------------------------------------------------------
132 K Trainable params
0 Non-trainable params
132 K Total params
0.530 Total estimated model params size (MB)
>>> ModelSummary(model, max_depth=-1) # doctest: +NORMALIZE_WHITESPACE
| Name | Type | Params | In sizes | Out sizes
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
--------------------------------------------------------------
| Name | Type | Params | Mode | In sizes | Out sizes
----------------------------------------------------------------------
0 | net | Sequential | 132 K | train | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | train | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | train | [10, 512] | [10, 512]
----------------------------------------------------------------------
132 K Trainable params
0 Non-trainable params
132 K Total params
@@ -247,6 +252,10 @@ def out_sizes(self) -> List:
def param_nums(self) -> List[int]:
return [layer.num_parameters for layer in self._layer_summary.values()]

@property
def training_modes(self) -> List[bool]:
return [layer.training for layer in self._layer_summary.values()]

@property
def total_parameters(self) -> int:
return sum(p.numel() if not _is_lazy_weight_tensor(p) else 0 for p in self._model.parameters())
@@ -315,6 +324,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
("Name", self.layer_names),
("Type", self.layer_types),
("Params", list(map(get_human_readable_count, self.param_nums))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
]
if self._model.example_input_array is not None:
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
@@ -333,6 +343,7 @@ def _add_leftover_params_to_summary(self, arrays: List[Tuple[str, List[str]]], t
layer_summaries["Name"].append(LEFTOVER_PARAMS_NAME)
layer_summaries["Type"].append(NOT_APPLICABLE)
layer_summaries["Params"].append(get_human_readable_count(total_leftover_params))
layer_summaries["Mode"].append(NOT_APPLICABLE)
if "In sizes" in layer_summaries:
layer_summaries["In sizes"].append(NOT_APPLICABLE)
if "Out sizes" in layer_summaries:
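To see what the new `training` and `training_modes` properties feed into, here is a small sketch using the `summarize` utility from this module (the `BoringModel` demo class and its `layer` attribute are assumptions used only for illustration; any `LightningModule` works):

from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.model_summary import summarize

model = BoringModel()
model.layer.eval()  # put one submodule into eval mode
summary = summarize(model, max_depth=-1)
print(summary)  # the "Mode" column for `layer` now reads "eval" instead of "train"
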
@@ -99,6 +99,7 @@ def _get_summary_data(self) -> List[Tuple[str, List[str]]]:
("Type", self.layer_types),
("Params", list(map(get_human_readable_count, self.param_nums))),
("Params per Device", list(map(get_human_readable_count, self.parameters_per_layer))),
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
]
if self._model.example_input_array is not None:
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))
12 changes: 8 additions & 4 deletions tests/tests_pytorch/callbacks/test_model_summary.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Any, List, Tuple

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelSummary
@@ -41,14 +41,15 @@ def test_model_summary_callback_with_enable_model_summary_true():
assert model_summary_callback._max_depth == 1


def test_custom_model_summary_callback_summarize(tmpdir):
def test_custom_model_summary_callback_summarize(tmp_path):
class CustomModelSummary(ModelSummary):
@staticmethod
def summarize(
summary_data: List[List[Union[str, List[str]]]],
summary_data: List[Tuple[str, List[str]]],
total_parameters: int,
trainable_parameters: int,
model_size: float,
**summarize_kwargs: Any,
) -> None:
assert summary_data[1][0] == "Name"
assert summary_data[1][1][0] == "layer"
@@ -60,7 +61,10 @@ def summarize(
assert total_parameters == 66
assert trainable_parameters == 66

assert summary_data[4][0] == "Mode"
assert summary_data[4][1][0] == "train"

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, callbacks=CustomModelSummary(), max_steps=1)
trainer = Trainer(default_root_dir=tmp_path, callbacks=CustomModelSummary(), max_steps=1)

trainer.fit(model)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/callbacks/test_rich_model_summary.py
@@ -62,4 +62,4 @@ def example_input_array(self) -> Any:
assert mock_console.call_count == 2
# assert that the input summary data was converted correctly
args, _ = mock_table_add_row.call_args_list[0]
assert args[1:] == ("0", "layer", "Linear", "66 ", "[4, 32]", "[4, 2]")
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "[4, 32]", "[4, 2]")
6 changes: 3 additions & 3 deletions tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
@@ -22,7 +22,7 @@


@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
def test_deepspeed_summary(tmpdir):
def test_deepspeed_summary(tmp_path):
"""Test to ensure that the summary contains the correct values when stage 3 is enabled and that the trainer enables
the `DeepSpeedSummary` when DeepSpeed is used."""

@@ -37,12 +37,12 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

# check the additional params per device
summary_data = model_summary._get_summary_data()
params_per_device = summary_data[-1][-1]
params_per_device = summary_data[4][-1]
assert int(params_per_device[0]) == (model_summary.total_parameters // 2)

trainer = Trainer(
strategy=DeepSpeedStrategy(stage=3),
default_root_dir=tmpdir,
default_root_dir=tmp_path,
accelerator="gpu",
fast_dev_run=True,
devices=2,
30 changes: 28 additions & 2 deletions tests/tests_pytorch/utilities/test_model_summary.py
@@ -327,13 +327,13 @@ def test_empty_model_size(max_depth):
pytest.param("mps", marks=RunIf(mps=True)),
],
)
def test_model_size_precision(tmpdir, accelerator):
def test_model_size_precision(tmp_path, accelerator):
"""Test model size for half and full precision."""
model = PreCalculatedModel()

# fit model
trainer = Trainer(
default_root_dir=tmpdir, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
)
trainer.fit(model)
summary = summarize(model)
@@ -430,3 +430,29 @@ def forward(self, x):
assert model.training
assert model.layer1.training
assert not model.layer2.training


def test_summary_training_mode():
"""Test that the model summary captures the training mode on all submodules."""
model = DeepNestedModel()
model.branch1[1][0].eval()
model.branch2.eval()

summary = summarize(model, max_depth=1)
summary_data = OrderedDict(summary._get_summary_data())
assert summary_data["Mode"] == [
"train", # branch1
"eval", # branch2
"train", # head
]

summary = summarize(model, max_depth=-1)
expected_eval = {"branch1.1.0", "branch2"}
for name, layer_summary in summary._layer_summary.items():
assert (name in expected_eval) == (not layer_summary.training)

# A model with params not belonging to a layer
model = NonLayerParamsModel()
model.layer.eval()
summary_data = OrderedDict(summarize(model)._get_summary_data())
assert summary_data["Mode"] == ["eval", "n/a"]
