Labels: bug (Something isn't working), fabric (lightning.fabric.Fabric)
🐛 Bug
When using LightningLite and transferring the _LiteModule to the CPU, the attributes of DeviceDtypeModuleMixin are not updated.
To Reproduce
import torch
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.lite import LightningLite


class SomeDummy(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(1, 1)


class MyClass(LightningLite):
    def run(self):
        model = SomeDummy()
        model, optimiser = self.setup(model, torch.optim.Adam(model.parameters()))
        # do some stuff
        # now clean up gpu memory for later stages
        model.cpu()
        assert str(model.module.device) == 'cpu'


MyClass(accelerator='gpu', devices=1).run()

Expected behavior

model.module.device should be cpu after model.cpu(), but the attribute is not updated.
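For reference, the stale value can be observed directly. The sketch below is a minimal, hedged illustration (it assumes pytorch_lightning >= 1.5, where DeviceDtypeModuleMixin is importable from pytorch_lightning.core.mixins, reuses SomeDummy and the imports from the snippet above, and uses a hypothetical PlainWrapper as a stand-in for the plain nn.Module wrapping that _LiteModule does): the parameters are moved, but the mixin's cached device is not.

class PlainWrapper(torch.nn.Module):
    # hypothetical stand-in for _LiteModule: a plain nn.Module around the user module
    def __init__(self, module):
        super().__init__()
        self.module = module


inner = SomeDummy().to('cuda:0')   # the mixin's .to() override records the new device
wrapper = PlainWrapper(inner)

wrapper.cpu()                      # nn.Module.cpu() only runs _apply() on the tensors
print(inner.a.weight.device)       # cpu     -> the parameters were moved
print(inner.device)                # cuda:0  -> the cached device attribute is stale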
Additional context
Could probably be solved by using DeviceDtypeModuleMixin as the base class for _LiteModule, since the issue is that nn.Module.to() (and likewise .cpu()/.cuda()) only calls _apply() on all child tensors instead of calling .to() on every child module, so the mixin's device-tracking overrides on the wrapped module never run.
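A hedged sketch of that proposal (MixinWrapper below is a hypothetical illustration, not the actual _LiteModule implementation, and it reuses SomeDummy and the import from the reproduction above): the mixin's .to()/.cpu()/.cuda() overrides propagate the new device to every DeviceDtypeModuleMixin submodule via self.apply, so deriving the wrapper from the mixin keeps model.module.device in sync.

class MixinWrapper(DeviceDtypeModuleMixin):
    # hypothetical wrapper deriving from the mixin instead of plain nn.Module
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


inner = SomeDummy().to('cuda:0')
wrapped = MixinWrapper(inner)

wrapped.cpu()                                # the mixin's cpu() override updates self and all mixin submodules
assert str(wrapped.module.device) == 'cpu'   # holds with the mixin-based wrapper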