Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))


- Fixed `materialize_module` setting a module's child recursively ([#12870](https://github.com/PyTorchLightning/pytorch-lightning/pull/12870))


- Fixed the number of references of `LightningModule` so it can be deleted ([#12897](https://github.com/PyTorchLightning/pytorch-lightning/pull/12897))


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)):
materialize_module(child)
else:
setattr(child, name, materialize_fn())
setattr(root_module, name, materialize_fn())
return root_module


Expand Down
18 changes: 16 additions & 2 deletions tests/utilities/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from torch import nn

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


Expand All @@ -24,7 +26,7 @@ def __init__(self, num_layers: int):
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)])


class BoringModel(LightningModule):
class SimpleBoringModel(LightningModule):
def __init__(self, num_layers: int):
super().__init__()
self.save_hyperparameters()
Expand All @@ -48,7 +50,7 @@ def test_init_meta_context():
assert not is_on_meta_device(mlp)
assert not is_on_meta_device(nn.Module())

model = BoringModel(4)
model = SimpleBoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)
assert model.layer[0].weight.device.type == "cpu"
Expand All @@ -68,3 +70,15 @@ def test_init_meta_context():

m = nn.Linear(in_features=1, out_features=1)
assert m.weight.device.type == "cpu"


@RunIf(min_torch="1.10.0", standalone=True)
def test_materialize_module_recursive_child():
"""Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context."""
with init_meta_context():
model = BoringModel()

materialize_module(model)

with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"):
model.layer.layer