2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -315,6 +315,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a DDP info message that was never shown ([#8111](https://github.com/PyTorchLightning/pytorch-lightning/pull/8111))
 
 
+- Fixed a bug where an infinite recursion would be triggered when using the `BaseFinetuning` callback on a model that contains a `ModuleDict` ([#8170](https://github.com/PyTorchLightning/pytorch-lightning/pull/8170))
+
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, ModuleDict
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.optim.optimizer import Optimizer
 
@@ -114,6 +114,9 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         Returns:
             List of modules
         """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
         if isinstance(modules, Iterable):
             _modules = []
             for m in modules:
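Why the two added lines are enough: `flatten_modules` recurses into anything that is `Iterable`, and an `nn.ModuleDict` is iterable — but iterating it yields its string keys rather than the child modules. A string is itself iterable, and its single-character elements are again strings, so the recursion never reaches a `Module` and never terminates, which matches the infinite recursion the changelog entry describes. Resolving the `ModuleDict` to its `.values()` up front keeps the recursion over actual modules. Below is a simplified, hypothetical sketch of the flattening logic — not the callback's exact code; the function name is made up and the leaf-filtering rule is reduced to "modules without children":

```python
from collections.abc import Iterable
from typing import List, Union

from torch import nn


def flatten_modules_sketch(modules: Union[nn.Module, Iterable]) -> List[nn.Module]:
    """Simplified stand-in for BaseFinetuning.flatten_modules with the ModuleDict fix."""
    # The fix: a ModuleDict iterates over its *keys* (strings), so swap in the values
    # before the generic Iterable handling below. Without this, recursing into a key
    # like "conv" would descend into "c", "o", ... forever.
    if isinstance(modules, nn.ModuleDict):
        modules = modules.values()

    if isinstance(modules, Iterable):
        flat: List[nn.Module] = []
        for m in modules:
            flat.extend(flatten_modules_sketch(m))
        return flat

    # Base case: a plain Module -- keep its leaf submodules (the real callback also
    # keeps parent modules that own parameters directly).
    return [m for m in modules.modules() if not list(m.children())]


if __name__ == "__main__":
    block = nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3), "act": nn.ReLU()})
    print(flatten_modules_sketch(block))  # [Conv2d(3, 8, ...), ReLU()]
```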
12 changes: 7 additions & 5 deletions tests/callbacks/test_finetuning_callback.py
@@ -331,15 +331,17 @@ class ConvBlockParam(nn.Module):

         def __init__(self, in_channels, out_channels):
             super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
+            self.module_dict = nn.ModuleDict({
+                "conv": nn.Conv2d(in_channels, out_channels, 3),
+                "act": nn.ReLU(),
+            })
             # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
         def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
+            x = self.module_dict["conv"](x)
+            x = self.module_dict["act"](x)
             return self.bn(x)
 
     model = nn.Sequential(
@@ -353,7 +355,7 @@ def forward(self, x):
     assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].module_dict["conv"].weight.requires_grad # Validate a leaf module parameter is frozen
     assert not model.encoder[0].parent_param.requires_grad # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
 
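For completeness, the end-user shape of the fix outside the unit test: passing an `nn.ModuleDict` directly to the `BaseFinetuning` helpers (for example the heads of a multi-head model) is the kind of call that previously recursed forever. This is an illustrative snippet, not taken from the PR — the head names and layer sizes are made up, and it assumes `BaseFinetuning` is importable from `pytorch_lightning.callbacks` as in the 1.3.x releases:

```python
from torch import nn

from pytorch_lightning.callbacks import BaseFinetuning

# Hypothetical multi-head setup; the names and layer sizes are illustrative.
heads = nn.ModuleDict({
    "classification": nn.Linear(128, 10),
    "regression": nn.Linear(128, 1),
})

# Passing the ModuleDict itself used to trigger the infinite recursion;
# with the fix it flattens to the dict's values.
flat = BaseFinetuning.flatten_modules(heads)
print([type(m).__name__ for m in flat])  # ['Linear', 'Linear']

# freeze() relies on the same flattening, so this now terminates too.
BaseFinetuning.freeze(heads, train_bn=True)
assert not heads["classification"].weight.requires_grad
```

In a real training setup this call would typically live in an overridden `freeze_before_training` hook of a `BaseFinetuning` subclass rather than being invoked directly.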