Trainer test cannot load from checkpoint when training on multiple GPUs #5144

Closed
wjaskowski opened this issue Dec 15, 2020 · 8 comments · Fixed by #5155
Labels: bug, help wanted, waiting on author


@wjaskowski

🐛 Bug

`Trainer.test()` looks for `epoch=X-v0.ckpt` when only `epoch=X.ckpt` exists, which results in:

Traceback (most recent call last):
  File "/home/wojciech/tmp/pytorch-lightining/main.py", line 16, in <module>
    result = trainer.test()
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in test
    results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 751, in __test_using_best_weights
    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 31, in load
    with fs.open(path_or_url, "rb") as f:
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/spec.py", line 897, in open
    f = self._open(
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 115, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 197, in __init__
    self._open()
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 202, in _open
    self.f = open(self.path, mode=self.mode)
FileNotFoundError: [Errno 2] No such file or directory: '/home/wojciech/tmp/pytorch-lightining/lightning_logs/version_10/checkpoints/epoch=0-v0.ckpt'
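
The mismatch in the traceback above (a `-v0` path requested while only the unsuffixed file is on disk) can be worked around outside the library. The following is a minimal sketch, not part of pytorch_lightning: `resolve_checkpoint` is a hypothetical helper name, and it assumes the real checkpoint differs from the requested one only by a trailing `-vN` suffix.

```python
import re
import tempfile
from pathlib import Path


def resolve_checkpoint(path: str) -> Path:
    """Return `path` if it exists, else the same name without a trailing `-vN`.

    Hypothetical helper for illustration only; not part of pytorch_lightning.
    """
    candidate = Path(path)
    if candidate.exists():
        return candidate
    # Strip a version suffix such as "-v0" that precedes the ".ckpt" extension.
    stem = re.sub(r"-v\d+$", "", candidate.stem)
    fallback = candidate.with_name(stem + candidate.suffix)
    if fallback.exists():
        return fallback
    raise FileNotFoundError(f"neither {candidate} nor {fallback} exists")


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Only the unsuffixed checkpoint exists, as in the report.
        (Path(tmp) / "epoch=0.ckpt").touch()
        requested = Path(tmp) / "epoch=0-v0.ckpt"
        print(resolve_checkpoint(str(requested)).name)  # prints: epoch=0.ckpt
```

The resolved path could then presumably be passed explicitly, for example `trainer.test(ckpt_path=str(resolved))`, instead of relying on the default best-checkpoint lookup.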

To Reproduce

Run the following script several times on a machine with more than one GPU:

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from typing import Optional

import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTModule(LightningModule):

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('test_loss', loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        opt = Adam(self.parameters(), lr=1e-3)
        return opt



# noinspection PyAttributeOutsideInit
class MNISTDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.train_dims = None
        self.vocab_size = 0

    def prepare_data(self):
        # called only on 1 GPU
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage: Optional[str] = None):
        # called on every GPU
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        self.train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        self.test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        self.train, self.val = torch.utils.data.random_split(self.train, (50000, 10000))

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, shuffle=True, drop_last=True, num_workers=2)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=512, drop_last=False)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=512, drop_last=False)


if __name__ == '__main__':
    dm = MNISTDataModule()
    model = MNISTModule()

    params = dict(param1='a', param2=1)
    trainer = Trainer(gpus=2, max_epochs=1, accelerator='ddp')
    trainer.fit(model, datamodule=dm)

    result = trainer.test()
    print(result)

Expected behavior

No exception.

Environment

  • CUDA:
    - GPU:
    - GeForce GTX TITAN X
    - GeForce GTX TITAN X
    - GeForce GTX TITAN X
    - GeForce GTX TITAN X
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.19.4
    - pyTorch_debug: False
    - pyTorch_version: 1.7.1
    - pytorch-lightning: 1.1.0 [Also 1.0.8]
    - tqdm: 4.54.1
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.5
    - version: #113-Ubuntu SMP Thu Jul 9 23:41:39 UTC 2020

Additional context

This happens only with `gpus=2` and `accelerator='ddp'`. It looks like a race condition, since the problem occurs only intermittently.
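
One way such a race could produce a `-v0` name is sketched below. The thread does not confirm the root cause, so this is an assumption: it supposes each process derives the checkpoint name with a check-then-act scheme (append `-vN` if the plain name is taken) while only rank 0 writes the file. The function names `versioned_name` and `simulate` are hypothetical and do not come from pytorch_lightning.

```python
import tempfile
from pathlib import Path


def versioned_name(dirpath: Path, base: str) -> str:
    """Check-then-act naming: append -vN if the plain name is already taken."""
    name = f"{base}.ckpt"
    version = 0
    while (dirpath / name).exists():
        name = f"{base}-v{version}.ckpt"
        version += 1
    return name


def simulate() -> tuple:
    """Return (name rank 0 writes, name seen before the write, name seen after)."""
    with tempfile.TemporaryDirectory() as tmp:
        dirpath = Path(tmp)
        rank0_name = versioned_name(dirpath, "epoch=0")
        # A second process that checks before rank 0 writes agrees with rank 0.
        early = versioned_name(dirpath, "epoch=0")
        (dirpath / rank0_name).touch()  # only rank 0 writes the file
        # A second process that checks after the write sees the name as taken.
        late = versioned_name(dirpath, "epoch=0")
        return rank0_name, early, late


if __name__ == "__main__":
    print(simulate())
```

Under this assumption, the process that checks late ends up expecting a file that nobody wrote, which would match the intermittent `FileNotFoundError` depending on timing.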

wjaskowski added the bug and help wanted labels Dec 15, 2020
awaelchli self-assigned this Dec 15, 2020

@awaelchli
Member

@wjaskowski I believe I found the fix for it. Thanks for including a reproducible script, it helped a lot.
If you find the time, would you mind checking whether the fix on my branch works for you? (See the linked PR.)

@wjaskowski
Author

I tried the branch but the problem seems to be still there:

Traceback (most recent call last):
  File "/home/wojciech/tmp/pytorch-lightining/bug3.py", line 113, in <module>
    result = trainer.test()
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 754, in test
    results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 784, in __test_using_best_weights
    ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py", line 31, in load
    with fs.open(path_or_url, "rb") as f:
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/spec.py", line 897, in open
    f = self._open(
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 115, in _open
    return LocalFileOpener(path, mode, fs=self, **kwargs)
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 197, in __init__
    self._open()
  File "/home/wojciech/miniconda3/envs/ml/lib/python3.8/site-packages/fsspec/implementations/local.py", line 202, in _open
    self.f = open(self.path, mode=self.mode)
FileNotFoundError: [Errno 2] No such file or directory: '/home/wojciech/tmp/pytorch-lightining/lightning_logs/version_45/checkpoints/epoch=0-step=259-v0.ckpt'

@awaelchli
Member

This is surprising. With your exact script I can reproduce the error on master within 2-3 trials, but on the bugfix branch (bugfix/ddp-ckpt) I ran it probably 20+ times and it never occurred.

@wjaskowski
Author

I will carefully try it again once I regain access to a machine with more than one GPU.

@edenlightning
Contributor

@wjaskowski any update?

edenlightning added the waiting on author label Jan 8, 2021

@edenlightning
Contributor

Feel free to reopen if needed!

@sustcsonglin

I have the same issue...

@blacksnail789521

Same issue in 2023
