
Can't use estimated_stepping_batches in configure_optimizers with DDP #12317

@eladsegal

Description

🐛 Bug

When using DDP, calling trainer.estimated_stepping_batches inside configure_optimizers raises an error.
It happens because the property triggers a cross-process sync on the model's device, but at that point the model has not yet been moved off the CPU, so the all_reduce receives a CPU tensor while the NCCL backend requires CUDA tensors.
https://github.com/PyTorchLightning/pytorch-lightning/blob/7ee690758ccad7f702460d056f6369c1d4371a46/pytorch_lightning/utilities/data.py#L124
The error:

  File "~/bug_report_model.py", line 36, in configure_optimizers
    self.trainer.estimated_stepping_batches
  File "~/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 2702, in estimated_stepping_batches
    self.reset_train_dataloader()
  File "~/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1848, in reset_train_dataloader
    if has_len_all_ranks(self.train_dataloader, self.strategy, module)
  File "~/pytorch-lightning/pytorch_lightning/utilities/data.py", line 124, in has_len_all_ranks
    total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")
  File "~/pytorch-lightning/pytorch_lightning/strategies/ddp_spawn.py", line 224, in reduce
    tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "~/pytorch-lightning/pytorch_lightning/utilities/distributed.py", line 95, in sync_ddp_if_available
    return sync_ddp(result, group=group, reduce_op=reduce_op)
  File "~/pytorch-lightning/pytorch_lightning/utilities/distributed.py", line 129, in sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "~/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
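
One possible direction for a fix (a sketch only, not a tested patch): have has_len_all_ranks place the tensor on the strategy's root_device, which is determined from the process's local rank and so is known before the model is moved, instead of model.device:

# Sketch of a possible change in pytorch_lightning/utilities/data.py,
# assuming the strategy passed in as `training_type` exposes a root_device
# property (the DDP strategies do). root_device is valid even before the
# LightningModule itself has been moved to the GPU.
total_length = training_type.reduce(
    torch.tensor(local_length, device=training_type.root_device), reduce_op="sum"
)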

To Reproduce

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        self.trainer.estimated_stepping_batches  # Can be used here to define the LR scheduler (see the workaround sketch after this script)
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=2,
        logger=False,
        enable_checkpointing=False
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
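
As a workaround until this is fixed, the total number of optimizer steps can be estimated manually in configure_optimizers, avoiding the cross-process reduce entirely. A minimal sketch, assuming a map-style dataset whose length is known up front, no gradient accumulation, and the default DistributedSampler; the subclass name and constructor arguments below are illustrative, not part of the repro above:

import math


class BoringModelWithScheduler(BoringModel):
    def __init__(self, num_samples=64, batch_size=2):
        super().__init__()
        self.num_samples = num_samples
        self.batch_size = batch_size

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        # Manual stand-in for trainer.estimated_stepping_batches: with DDP,
        # each of the world_size processes sees num_samples / world_size samples.
        world_size = self.trainer.world_size  # 2 for gpus=2
        steps_per_epoch = math.ceil(self.num_samples / (self.batch_size * world_size))
        total_steps = steps_per_epoch * self.trainer.max_epochs
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=0.1, total_steps=total_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]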

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): master
  • PyTorch Version (e.g., 1.10): 1.10
  • Python version (e.g., 3.9): 3.9

cc @justusschock @kaushikb11 @awaelchli @akihironitta @rohitgr7
