
ModelCheckpointing behaviour changed from previous versions (self.best_model_path holds rank 1 values) #14302

Closed

nithinraok opened this issue Aug 18, 2022 · 3 comments
Labels: bug (Something isn't working) · checkpointing (Related to checkpointing) · waiting on author (Waiting on user action, correction, or update)

nithinraok (Contributor) commented Aug 18, 2022

🐛 Bug

self.best_k_models stores checkpoint entries from the two devices with different metric values, and self.best_model_path points to a path computed on cuda:1 even though the checkpoints actually saved to disk came from cuda:0.

This behavior has changed from 1.6.5 to 1.7.x.
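
For reference, here is a minimal diagnostic sketch (not part of the original report; it assumes the trainer from the reproduction below and runs right after trainer.fit) showing the per-rank mismatch:

# Hypothetical diagnostic, run after trainer.fit(...) in the script below:
# print the checkpoint callback state on every rank to see the mismatch.
ckpt_cb = trainer.checkpoint_callback  # the ModelCheckpoint set up by exp_manager
print(f"rank={trainer.global_rank} best_model_path={ckpt_cb.best_model_path}")
print(f"rank={trainer.global_rank} best_k_models={ckpt_cb.best_k_models}")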

To Reproduce

import os

import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from nemo.utils.exp_manager import exp_manager

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log("val_loss", loss)
        return loss 

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def save_to(self, save_path):
        pass


dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset, batch_size=256)
dev_loader = utils.data.DataLoader(dataset, batch_size=256)

# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

trainer = pl.Trainer(
    limit_train_batches=100,
    max_epochs=2,
    devices=2,
    strategy='ddp',
    accelerator='gpu',
    enable_checkpointing=False,
    logger=False,
)

# kwargs = {
#     'dirpath': '/data/local/checkpoints', 
#     'filename': 'test_ptl--{val_loss:.4f}-{epoch}', 
#     'monitor': 'val_loss', 'verbose': True, 
#     'save_last': True, 
#     'save_top_k': 4, 
#     'save_weights_only': False, 
#     'mode': 'min', 
#     'every_n_epochs': 1
#     }
# from pytorch_lightning.callbacks import ModelCheckpoint
# checkpoint_callback = ModelCheckpoint(**kwargs)
# trainer.callbacks.append(checkpoint_callback)

manager = {
    'exp_dir': '/data/recognition/tarred',
    'name': 'test_ptl',
    'create_tensorboard_logger': True,
    'create_checkpoint_callback': True,
    'use_datetime_version': False,
    'create_wandb_logger': False,
    'resume_if_exists': True,
    'resume_ignore_no_checkpoint': True,
    'checkpoint_callback_params': {'save_top_k': 4, 'save_best_model': True},
    'version': 'local'
}

cfg = OmegaConf.create(manager)
log_dir = exp_manager(trainer, cfg)

trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=dev_loader)
# print(trainer.callbacks[-1].best_model_path)

Expected behavior

Run code without issues

Environment

PyTorch Lightning 1.7.x
Full environment details attached as a.txt

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @otaj

nithinraok added the needs triage label on Aug 18, 2022
awaelchli added the bug and checkpointing labels and removed needs triage on Aug 19, 2022
awaelchli added this to the pl:1.7.x milestone on Aug 19, 2022
awaelchli (Member) commented Aug 19, 2022

Hi @nithinraok
What was missing from your bug report is a description of the error or misbehavior you're seeing. I can reproduce the following error, so I'm assuming this is what you were seeing:

pytorch_lightning.utilities.exceptions.DeadlockDetectedException: DeadLock detected from rank: 1 
 Traceback (most recent call last):
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 652, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1181, in _run
    results = self._run_stage()
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1267, in _run_stage
    return self._run_train()
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1298, in _run_train
    self.fit_loop.run()
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/loops/loop.py", line 207, in run
    output = self.on_run_end()
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/loops/fit_loop.py", line 329, in on_run_end
    self.trainer._call_callback_hooks("on_train_end")
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1606, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/Users/adrian/miniconda3/envs/lightning/lib/python3.10/site-packages/nemo/utils/exp_manager.py", line 822, in on_train_end
    trainer._checkpoint_connector.restore(self.best_model_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 136, in restore
    self.resume_start(checkpoint_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 85, in resume_start
    self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 89, in _load_and_validate_checkpoint
    loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/strategies/strategy.py", line 345, in load_checkpoint
    return self.checkpoint_io.load_checkpoint(checkpoint_path)
  File "/Users/adrian/repositories/lightning/src/pytorch_lightning/plugins/io/torch_plugin.py", line 83, in load_checkpoint
    raise FileNotFoundError(f"Checkpoint at {path} not found. Aborting training.")
FileNotFoundError: Checkpoint at /Users/adrian/repositories/lightning/examples/pl_bug_report/data/recognition/tarred/test_ptl/local/checkpoints/test_ptl--val_loss=0.0500-epoch=1.ckpt not found. Aborting training.


This is because, in the exp manager, NeMo reloads the checkpoint like so:

trainer._checkpoint_connector.restore(self.best_model_path)

This code executes on every rank, but only rank 0 saved a checkpoint at its best_model_path, while rank 1 holds a different best_model_path value and can't find that file on disk.

This can be fixed by broadcasting the value to all ranks from rank 0:

best_model_path = trainer.strategy.broadcast(self.best_model_path)
trainer._checkpoint_connector.restore(best_model_path)

This small fix can be added in NeMo to unblock you.
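
For illustration, a rough sketch of what the patched hook could look like on the NeMo side (the method structure is inferred from the traceback above, not copied from the actual NeMo source):

# Hypothetical sketch of the patched on_train_end in NeMo's checkpoint callback
# (structure inferred from the traceback above, not the real NeMo source).
def on_train_end(self, trainer, pl_module):
    # rank 0 owns the path of the checkpoint that actually exists on disk;
    # broadcast it so every rank restores the same file
    best_model_path = trainer.strategy.broadcast(self.best_model_path)
    trainer._checkpoint_connector.restore(best_model_path)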

By git bisecting between 1.6.5 and master, I found that the change was introduced in #13364. cc @carmocca

This means a second way to solve this would be to set sync_dist=True in the self.log call for the metric monitored by the checkpoint callback; in this specific example that would be self.log("val_loss", loss, sync_dist=True).
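
Applied to the reproduction above, the change is a one-line edit to validation_step (sketch of the modified method inside LitAutoEncoder):

def validation_step(self, batch, batch_idx):
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = nn.functional.mse_loss(x_hat, x)
    # sync_dist=True reduces the metric across ranks, so the checkpoint callback
    # compares the same monitored value (and records the same path) on every rank
    self.log("val_loss", loss, sync_dist=True)
    return loss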

awaelchli (Member) commented Aug 19, 2022

I also found that NeMo uses deprecated Lightning APIs, for example trainer.training_type_plugin, which should be replaced with trainer.strategy. Just FYI, because this will be removed in the next release.
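
A minimal before/after sketch of the rename (exact call sites in NeMo may differ):

# deprecated accessor, emits a deprecation warning and will be removed
plugin = trainer.training_type_plugin
# replacement going forward
strategy = trainer.strategy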

awaelchli added the waiting on author label on Aug 19, 2022
awaelchli self-assigned this on Aug 19, 2022
nithinraok (Contributor, Author) commented

Thank you @awaelchli, the suggested solution works. Yes, we need to change plugin to strategy as well before the next release.
