🐛 Bug
When using DDP and calling `trainer.estimated_stepping_batches` inside `configure_optimizers`, an error is thrown.
It happens because `has_len_all_ranks` tries to sync the dataloader length between processes using the model's device, but at that point the model has not been moved to a non-CPU device yet, so the all-reduce receives a CPU tensor.
https://github.com/PyTorchLightning/pytorch-lightning/blob/7ee690758ccad7f702460d056f6369c1d4371a46/pytorch_lightning/utilities/data.py#L124
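The line linked above builds the length tensor on `model.device`, which is still `cpu` when `configure_optimizers` runs, and then hands it to the distributed all-reduce. One possible direction for a fix (just a sketch, not an actual patch; it assumes the strategy's `root_device` property and `reduce` method, which the DDP strategies provide) would be to place the tensor on the strategy's root device instead of the model's device:

```python
import torch
from torch import Tensor


def sum_length_across_ranks(local_length: int, strategy) -> Tensor:
    """Hypothetical helper sketching a possible fix for has_len_all_ranks.

    Builds the per-rank dataloader length on the strategy's root device
    (e.g. the local CUDA device for DDP) instead of the model's device,
    which may still be CPU while `configure_optimizers` is running.
    """
    length = torch.tensor(local_length, device=strategy.root_device)
    return strategy.reduce(length, reduce_op="sum")
```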
The error:

```
  File "~/bug_report_model.py", line 36, in configure_optimizers
    self.trainer.estimated_stepping_batches
  File "~/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 2702, in estimated_stepping_batches
    self.reset_train_dataloader()
  File "~/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 1848, in reset_train_dataloader
    if has_len_all_ranks(self.train_dataloader, self.strategy, module)
  File "~/pytorch-lightning/pytorch_lightning/utilities/data.py", line 124, in has_len_all_ranks
    total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum")
  File "~/pytorch-lightning/pytorch_lightning/strategies/ddp_spawn.py", line 224, in reduce
    tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "~/pytorch-lightning/pytorch_lightning/utilities/distributed.py", line 95, in sync_ddp_if_available
    return sync_ddp(result, group=group, reduce_op=reduce_op)
  File "~/pytorch-lightning/pytorch_lightning/utilities/distributed.py", line 129, in sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "~/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
```
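The `RuntimeError` comes from `torch.distributed` itself: with the NCCL backend, `all_reduce` only accepts CUDA tensors. A minimal sketch outside Lightning that hits the same constraint (assuming a machine with at least 2 GPUs and NCCL available; the master port is arbitrary):

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # A CPU tensor here raises "Tensors must be CUDA and dense", just like
    # torch.tensor(local_length).to(model.device) does while the model is on CPU.
    length = torch.tensor(32)
    dist.all_reduce(length, op=dist.ReduceOp.SUM)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```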
To Reproduce
```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        self.trainer.estimated_stepping_batches  # Can be used here to define the LR scheduler
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(ckpt_path=None):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=2,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```
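For context on the comment in `configure_optimizers` above: the typical reason to call `estimated_stepping_batches` there is to size a per-step LR scheduler. A hedged sketch of that pattern for the model above (the `OneCycleLR` choice is just an example, not part of the reported setup):

```python
def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    # estimated_stepping_batches is the total number of optimizer steps in the
    # run, which per-step schedulers such as OneCycleLR need up front.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.1, total_steps=self.trainer.estimated_stepping_batches
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```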
Environment

- PyTorch Lightning Version (e.g., 1.5.0): master
- PyTorch Version (e.g., 1.10): 1.10
- Python version (e.g., 3.9): 3.9
cc @justusschock @kaushikb11 @awaelchli @akihironitta @rohitgr7