LightningLite not updating DeviceDtypeModuleMixin correctly #10556

@justusschock

🐛 Bug

When using LightningLite and moving the _LiteModule to the CPU, the device attributes that DeviceDtypeModuleMixin maintains on the wrapped module are not updated.

To Reproduce

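import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite

class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)

class MyClass(LightningLite):
    def run(self):
        model = SomeDummy()
        # setup() wraps the model in a _LiteModule and moves it to the GPU
        model, optimiser = self.setup(model, torch.optim.Adam(model.parameters()))

        # do some stuff

        # now clean up gpu memory for later stages
        model.cpu()
        # fails: the wrapped module still reports the GPU device
        assert str(model.module.device) == 'cpu'

MyClass(accelerator='gpu', devices=1).run()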

Expected behavior

After model.cpu(), model.module.device should be cpu.

Additional context

This could probably be solved by using DeviceDtypeModuleMixin as the base class for _LiteModule. The underlying issue is that the to function only calls _apply on all child tensors instead of calling .to on every child module, so the wrapped module's own override never runs and its cached device is never refreshed.
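The mechanism can be reproduced in plain PyTorch without a GPU. The sketch below is an illustration only, not Lightning's code: TracksDtype, Wrapper and SyncingWrapper are hypothetical stand-ins for DeviceDtypeModuleMixin, _LiteModule and a possible fix, and it uses dtype instead of device so that it runs on CPU-only machines. It assumes the same propagation rule applies to .cpu() and .to(), since they also go through _apply.

```python
import torch


class TracksDtype(torch.nn.Module):
    """Hypothetical stand-in for DeviceDtypeModuleMixin: caches the dtype."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.cached_dtype = torch.float32
        self.double_calls = 0

    def double(self):
        # The cache is refreshed only when this override itself is invoked.
        self.double_calls += 1
        self.cached_dtype = torch.float64
        return super().double()


class Wrapper(torch.nn.Module):
    """Hypothetical stand-in for _LiteModule: a plain nn.Module around the model."""

    def __init__(self, module):
        super().__init__()
        self.module = module


class SyncingWrapper(Wrapper):
    """One possible fix (an assumption): the wrapper refreshes its children's caches."""

    def double(self):
        result = super().double()

        def refresh(submodule):
            if isinstance(submodule, TracksDtype):
                submodule.cached_dtype = torch.float64

        # nn.Module.apply calls the function on every submodule and on self.
        self.apply(refresh)
        return result


# Plain wrapper: parameters are converted, but the child's override is skipped.
wrapped = Wrapper(TracksDtype())
wrapped.double()
actual_dtype = wrapped.module.linear.weight.dtype
cached_dtype = wrapped.module.cached_dtype
override_calls = wrapped.module.double_calls

# Syncing wrapper: the cached attribute follows the conversion.
fixed = SyncingWrapper(TracksDtype())
fixed.double()
fixed_cached_dtype = fixed.module.cached_dtype

print(actual_dtype, cached_dtype, override_calls, fixed_cached_dtype)
```

With the plain wrapper, the weights become float64 while the cached attribute stays at float32, which mirrors the stale model.module.device reported above. Making the wrapper responsible for refreshing its children is in the spirit of the suggestion to base _LiteModule on DeviceDtypeModuleMixin, though the actual fix may differ.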

cc @carmocca @justusschock @awaelchli

Labels

bug, fabric