🐛 Bug
A device error is raised in the ModelCheckpoint callback when checking whether the current epoch's model is better and should be saved.
The error occurs when training on GPU after resuming from a checkpoint: under these conditions, the current metric is on the GPU while the best model metric stored inside the callback is on the CPU.
The error is raised when the monitored metric is a single-element tensor with ndim=1, e.g. torch.tensor([1.0]). If the logged metric has ndim=0, e.g. torch.tensor(1.0), there is no error even though the tensors are on different devices (see the minimal sketch after the traceback below).
File "/home/miguel/envs/roof_elements/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 495, in check_monitor_top_k
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
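The failure can be reproduced outside Lightning with plain PyTorch. For mode='min' the callback's monitor_op is torch.lt, and a 1-dim CPU tensor cannot be compared against a CUDA tensor, while a 0-dim CPU tensor is promoted like a Python scalar. A minimal sketch, assuming a CUDA device is available:

import torch

current = torch.tensor([1.0], device="cuda")  # monitored metric after resume: ndim=1, on GPU
best = torch.tensor([2.0])                    # best score restored from the checkpoint: ndim=1, on CPU

torch.lt(current.squeeze(), best.squeeze())   # ndim=0 operands: the CPU scalar is promoted, no error
torch.lt(current, best)                       # ndim=1 operands: RuntimeError (cuda:0 vs cpu)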
To Reproduce
Run this code once and stop it after a few epochs have completed. Then uncomment the line in the main block so that training resumes from the saved checkpoint, and run it again.
import os
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum().unsqueeze(0)  # add dimension 0 to reproduce the issue
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum().unsqueeze(0)  # add dimension 0 to reproduce the issue
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(resume_ckpt=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        callbacks=ModelCheckpoint(monitor='valid_loss', mode='min', filename='bestLoss'),
        accelerator='gpu',
        logger=None,
        default_root_dir=os.getcwd(),
        max_epochs=1000,
    )

    # fit
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=resume_ckpt)


if __name__ == "__main__":
    resume_ckpt = None
    # Uncomment the next line in a second execution to resume from the saved checkpoint and reproduce the issue
    # resume_ckpt = os.path.join(os.getcwd(), 'checkpoints', 'bestLoss.ckpt')
    run(resume_ckpt=resume_ckpt)

Expected behavior
The callback should check whether the current model is to be saved without raising the RuntimeError caused by the metric tensors being on different devices.
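Until the callback handles this itself, one workaround is to log the monitored metric as a 0-dim tensor (i.e. drop the .unsqueeze(0) calls in the steps above), since 0-dim CPU tensors are promoted across devices. Another option is a small subclass that moves the stored scores onto the device of the incoming metric before comparing. A minimal sketch, assuming the check_monitor_top_k signature of pytorch-lightning 1.7.3; the class name DeviceSafeModelCheckpoint is made up for illustration:

from pytorch_lightning.callbacks import ModelCheckpoint


class DeviceSafeModelCheckpoint(ModelCheckpoint):
    # Hypothetical workaround, not the library's own fix: align the devices of
    # the stored best scores with the incoming metric before the comparison
    # in check_monitor_top_k runs.
    def check_monitor_top_k(self, trainer, current=None):
        if current is not None:
            self.best_k_models = {
                path: score.to(current.device)
                for path, score in self.best_k_models.items()
            }
            if self.best_model_score is not None:
                self.best_model_score = self.best_model_score.to(current.device)
        return super().check_monitor_top_k(trainer, current)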
Environment
* CUDA:
- GPU:
- Quadro T1000
- available: True
- version: 10.2
* Lightning:
- pytorch-lightning: 1.7.3
- torch: 1.12.1
- torchmetrics: 0.9.3
* Packages:
- absl-py: 1.2.0
- aiohttp: 3.8.1
- aiosignal: 1.2.0
- async-timeout: 4.0.2
- attrs: 22.1.0
- cachetools: 5.2.0
- certifi: 2022.6.15
- charset-normalizer: 2.1.1
- frozenlist: 1.3.1
- fsspec: 2022.7.1
- google-auth: 2.11.0
- google-auth-oauthlib: 0.4.6
- grpcio: 1.48.0
- idna: 3.3
- importlib-metadata: 4.12.0
- markdown: 3.4.1
- markupsafe: 2.1.1
- multidict: 6.0.2
- numpy: 1.23.2
- oauthlib: 3.2.0
- packaging: 21.3
- pip: 22.2.2
- protobuf: 3.19.4
- pyasn1: 0.4.8
- pyasn1-modules: 0.2.8
- pydeprecate: 0.3.2
- pyparsing: 3.0.9
- pytorch-lightning: 1.7.3
- pyyaml: 6.0
- requests: 2.28.1
- requests-oauthlib: 1.3.1
- rsa: 4.9
- setuptools: 63.3.0
- six: 1.16.0
- tensorboard: 2.10.0
- tensorboard-data-server: 0.6.1
- tensorboard-plugin-wit: 1.8.1
- torch: 1.12.1
- torchmetrics: 0.9.3
- tqdm: 4.64.0
- typing-extensions: 4.3.0
- urllib3: 1.26.12
- werkzeug: 2.2.2
- wheel: 0.37.1
- yarl: 1.8.1
- zipp: 3.8.1
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.10
- version: #49~20.04.1-Ubuntu SMP Thu Aug 4 19:15:44 UTC 2022
Additional context
cc @carmocca @awaelchli @ninginthecloud @jjenniferdai @rohitgr7 @akihironitta