Closed
Labels
bug (Something isn't working) · strategy: fairscale fsdp (removed) (Fully Sharded Data Parallel)
Description
🐛 Bug
When using DDPFullyShardedStrategy (fairscale), the mixed-precision flag is not correctly enabled for wrapped modules when precision=16.
To reproduce:
import os

import torch
from fairscale.nn import wrap
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DDPFullyShardedStrategy


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def configure_sharded_model(self) -> None:
        self.layer = wrap(torch.nn.Linear(32, 2))
        assert self.layer.mixed_precision  # fails


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        strategy=DDPFullyShardedStrategy(),
        precision=16,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

This is because the internal check looks only for the "mixed" flag: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/strategies/fully_sharded.py#L166
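A minimal sketch of what the corrected check could look like, assuming the strategy resolves the user-facing precision value before building the FairScale wrap parameters. The helper name and the way the value reaches it are hypothetical, not the actual Lightning source:

# Hypothetical helper, not the code in pytorch_lightning/strategies/fully_sharded.py.
def _should_use_mixed_precision(precision) -> bool:
    # Treat 16-bit precision and the explicit "mixed" value as mixed precision,
    # instead of matching only the "mixed" string as the current check does.
    return str(precision) in ("16", "mixed")

# The FSDP wrap parameters would then set, e.g.:
# mixed_precision=_should_use_mixed_precision(precision)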
A separate issue is that passing precision="mixed" raises a different error from the connector, because that value is not checked correctly there either.
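For reference, that second failure mode is triggered simply by constructing the Trainer with the string value (same arguments as the repro above otherwise); per the report, this errors out in the connector rather than being accepted as an alias for 16-bit mixed precision:

# Reportedly fails in the connector instead of enabling mixed precision.
trainer = Trainer(strategy=DDPFullyShardedStrategy(), precision="mixed")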