
Commit d317ebf

ethanwharris authored and lexierule committed
Fix module dict in base finetuning (#8170)
* Fix module dict in base finetuning
* Update CHANGELOG.md

1 parent 2856e62 · commit d317ebf

File tree

2 files changed: +11 / -6 lines


pytorch_lightning/callbacks/finetuning.py

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, ModuleDict
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.optim.optimizer import Optimizer
 
@@ -114,6 +114,9 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         Returns:
             List of modules
         """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
         if isinstance(modules, Iterable):
             _modules = []
             for m in modules:
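Why the `.values()` call is needed: iterating an `nn.ModuleDict` yields its string keys, just like a plain `dict`, rather than the submodules themselves, so flattening a `ModuleDict` directly would recurse over strings instead of modules. A minimal sketch of the behavior (hypothetical module names, not from the commit):

```python
import torch.nn as nn

block = nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3), "act": nn.ReLU()})

# Iterating a ModuleDict yields its keys, like a plain dict ...
print(list(block))           # ['conv', 'act']

# ... so the contained modules are only reachable through .values()
print(list(block.values()))  # [Conv2d(3, 8, kernel_size=(3, 3), ...), ReLU()]
```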

tests/callbacks/test_finetuning_callback.py

Lines changed: 7 additions & 5 deletions
@@ -330,15 +330,17 @@ class ConvBlockParam(nn.Module):
 
         def __init__(self, in_channels, out_channels):
             super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
+            self.module_dict = nn.ModuleDict({
+                "conv": nn.Conv2d(in_channels, out_channels, 3),
+                "act": nn.ReLU(),
+            })
             # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
         def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
+            x = self.module_dict["conv"](x)
+            x = self.module_dict["act"](x)
             return self.bn(x)
 
     model = nn.Sequential(
@@ -352,7 +354,7 @@ def forward(self, x):
     assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].module_dict["conv"].weight.requires_grad  # Validate a leaf module parameter is frozen
     assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
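A usage sketch of the post-fix behavior, with a hypothetical `block` standing in for the test model: `flatten_modules` now unwraps a `ModuleDict` into its leaf modules, so `freeze` reaches the parameters stored inside it.

```python
import torch.nn as nn
from pytorch_lightning.callbacks.finetuning import BaseFinetuning

block = nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3), "act": nn.ReLU()})

# With the fix, the ModuleDict is unwrapped via .values() before flattening,
# so both leaf modules inside it are found.
assert len(BaseFinetuning.flatten_modules(block)) == 2

# freeze() therefore disables gradients on the Conv2d weights inside the dict.
BaseFinetuning.freeze(block, train_bn=True)
assert not block["conv"].weight.requires_grad
```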
