Device error in ModelCheckpoint when checking monitored metric improvement after resuming training #14446

@mlgavilan

Description

🐛 Bug

A device error is raised in the ModelCheckpoint callback when checking whether the current epoch's model is better and should be saved.

The error happens when training on GPU after resuming from a checkpoint: under these conditions the current metric lives on the GPU, while the best-model metric stored inside the callback lives on the CPU.

The error is only raised when the monitored metric is a single-element tensor with ndim=1, e.g. torch.tensor([1.0]). If the logged metric has ndim=0, e.g. torch.tensor(1.0), there is no error even though the tensors are on different devices.

File "/home/miguel/envs/roof_elements/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 495, in check_monitor_top_k
    should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
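For context (not part of the original report): PyTorch itself allows binary ops between a CUDA tensor and a 0-dim CPU tensor, because 0-dim CPU tensors are promoted like Python scalars, whereas a 1-element CPU tensor with ndim=1 triggers the device-mismatch error. A minimal sketch of why the ndim of the logged metric matters (assumes a CUDA device is available):

import torch

gpu_scalar = torch.tensor(1.0, device="cuda")    # ndim=0, on GPU
cpu_scalar = torch.tensor(2.0)                   # ndim=0, on CPU

# OK: a 0-dim CPU tensor is treated like a Python scalar in cross-device ops
print(torch.lt(gpu_scalar, cpu_scalar))

gpu_vector = torch.tensor([1.0], device="cuda")  # ndim=1, on GPU
cpu_vector = torch.tensor([2.0])                 # ndim=1, on CPU

# RuntimeError: Expected all tensors to be on the same device, ...
print(torch.lt(gpu_vector, cpu_vector))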

To Reproduce

Run this code once and stop it after a few epochs have completed. Then uncomment the line in the __main__ block to resume from the saved checkpoint and run again.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum().unsqueeze(0)  # unsqueeze so the logged metric has ndim=1 and reproduces the issue
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum().unsqueeze(0)  # unsqueeze so the logged metric has ndim=1 and reproduces the issue
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(resume_ckpt=None):

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        callbacks=ModelCheckpoint(monitor='valid_loss', mode='min', filename='bestLoss'),
        accelerator='gpu',
        logger=None,
        default_root_dir=os.getcwd(),
        max_epochs=1000)

    # fit
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=resume_ckpt)


if __name__ == "__main__":
    resume_ckpt = None

    ## Uncomment next line in a second execution to resume from the saved checkpoint and reproduce the issue
    # resume_ckpt = os.path.join(os.getcwd(), 'checkpoints', 'bestLoss.ckpt')

    run(resume_ckpt=resume_ckpt)

Expected behavior

The check of whether the current model should be saved completes without raising the RuntimeError caused by the metric tensors being on different devices.
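As a stopgap until this is fixed, one possible workaround (a sketch, not an official API, and it relies on internal attributes of ModelCheckpoint in PL 1.7) is to subclass the callback and move the scores restored from the checkpoint onto the device of the incoming metric before the comparison:

import torch
from pytorch_lightning.callbacks import ModelCheckpoint


class DeviceSafeModelCheckpoint(ModelCheckpoint):
    # Workaround sketch: align the restored best scores with the device of
    # the currently logged metric before the top-k comparison runs.
    def check_monitor_top_k(self, trainer, current=None):
        if isinstance(current, torch.Tensor):
            device = current.device
            self.best_k_models = {
                path: score.to(device) for path, score in self.best_k_models.items()
            }
            if isinstance(self.best_model_score, torch.Tensor):
                self.best_model_score = self.best_model_score.to(device)
            if isinstance(self.kth_value, torch.Tensor):
                self.kth_value = self.kth_value.to(device)
        return super().check_monitor_top_k(trainer, current)

Passing DeviceSafeModelCheckpoint(monitor='valid_loss', mode='min', filename='bestLoss') in place of the stock callback in the repro above should then avoid the crash.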

Environment

Details
    * CUDA:
            - GPU:
                    - Quadro T1000
            - available:         True
            - version:           10.2
    * Lightning:
            - pytorch-lightning: 1.7.3
            - torch:             1.12.1
            - torchmetrics:      0.9.3
    * Packages:
            - absl-py:           1.2.0
            - aiohttp:           3.8.1
            - aiosignal:         1.2.0
            - async-timeout:     4.0.2
            - attrs:             22.1.0
            - cachetools:        5.2.0
            - certifi:           2022.6.15
            - charset-normalizer: 2.1.1
            - frozenlist:        1.3.1
            - fsspec:            2022.7.1
            - google-auth:       2.11.0
            - google-auth-oauthlib: 0.4.6
            - grpcio:            1.48.0
            - idna:              3.3
            - importlib-metadata: 4.12.0
            - markdown:          3.4.1
            - markupsafe:        2.1.1
            - multidict:         6.0.2
            - numpy:             1.23.2
            - oauthlib:          3.2.0
            - packaging:         21.3
            - pip:               22.2.2
            - protobuf:          3.19.4
            - pyasn1:            0.4.8
            - pyasn1-modules:    0.2.8
            - pydeprecate:       0.3.2
            - pyparsing:         3.0.9
            - pytorch-lightning: 1.7.3
            - pyyaml:            6.0
            - requests:          2.28.1
            - requests-oauthlib: 1.3.1
            - rsa:               4.9
            - setuptools:        63.3.0
            - six:               1.16.0
            - tensorboard:       2.10.0
            - tensorboard-data-server: 0.6.1
            - tensorboard-plugin-wit: 1.8.1
            - torch:             1.12.1
            - torchmetrics:      0.9.3
            - tqdm:              4.64.0
            - typing-extensions: 4.3.0
            - urllib3:           1.26.12
            - werkzeug:          2.2.2
            - wheel:             0.37.1
            - yarl:              1.8.1
            - zipp:              3.8.1
    * System:
            - OS:                Linux
            - architecture:
                    - 64bit
                    - ELF
            - processor:         x86_64
            - python:            3.8.10
            - version:           #49~20.04.1-Ubuntu SMP Thu Aug 4 19:15:44 UTC 2022

Additional context

cc @carmocca @awaelchli @ninginthecloud @jjenniferdai @rohitgr7 @akihironitta
