diff --git a/CHANGELOG.md b/CHANGELOG.md
index 980d2a450f786..81fbefca31453 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -315,6 +315,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a DDP info message that was never shown ([#8111](https://github.com/PyTorchLightning/pytorch-lightning/pull/8111))
 
+- Fixed a bug where an infinite recursion would be triggered when using the `BaseFinetuning` callback on a model that contains a `ModuleDict` ([#8170](https://github.com/PyTorchLightning/pytorch-lightning/pull/8170))
+
 
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py
index cac4e4c9c857e..fe7e5f7bc09eb 100644
--- a/pytorch_lightning/callbacks/finetuning.py
+++ b/pytorch_lightning/callbacks/finetuning.py
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, ModuleDict
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.optim.optimizer import Optimizer
 
@@ -114,6 +114,9 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         Returns:
             List of modules
         """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
         if isinstance(modules, Iterable):
             _modules = []
             for m in modules:
diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py
index fe4cd5cce1ef8..7492bcac7804a 100644
--- a/tests/callbacks/test_finetuning_callback.py
+++ b/tests/callbacks/test_finetuning_callback.py
@@ -331,15 +331,17 @@ class ConvBlockParam(nn.Module):
 
         def __init__(self, in_channels, out_channels):
             super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
+            self.module_dict = nn.ModuleDict({
+                "conv": nn.Conv2d(in_channels, out_channels, 3),
+                "act": nn.ReLU(),
+            })
             # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
         def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
+            x = self.module_dict["conv"](x)
+            x = self.module_dict["act"](x)
             return self.bn(x)
 
     model = nn.Sequential(
@@ -353,7 +355,7 @@ def forward(self, x):
 
     assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].module_dict["conv"].weight.requires_grad  # Validate a leaf module parameter is frozen
     assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
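
Note on the fix: iterating an `nn.ModuleDict` yields its string keys rather than its sub-modules, and since a Python string is itself an `Iterable` of one-character strings, the old recursion in `flatten_modules` never bottomed out. The `.values()` branch added above expands the dict into its modules before recursing. The sketch below is not part of the patch; the model layout is made up purely for illustration, assuming the patched `BaseFinetuning.flatten_modules`:

```python
from torch import nn

from pytorch_lightning.callbacks.finetuning import BaseFinetuning

# Hypothetical model for illustration: a ModuleDict nested inside a Sequential.
model = nn.Sequential(
    nn.ModuleDict({
        "conv": nn.Conv2d(3, 8, kernel_size=3),
        "act": nn.ReLU(),
    }),
    nn.BatchNorm2d(8),
)

# Iterating a ModuleDict yields its keys (strings), not its modules;
# the fix expands it via .values() before recursing.
print(list(model[0]))           # ['conv', 'act']
print(list(model[0].values()))  # [Conv2d(3, 8, kernel_size=(3, 3), ...), ReLU()]

# With the patched flatten_modules, the leaf modules are collected as expected
# instead of recursing into the string keys forever.
leaves = BaseFinetuning.flatten_modules(model)
print([type(m).__name__ for m in leaves])  # ['Conv2d', 'ReLU', 'BatchNorm2d']
```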