Fully Sharded does not enable mixed precision for wrapped modules #12964

@SeanNaren

🐛 Bug

When using DDPFullyShardedStrategy (fairscale) we do not correctly enable the mixed precision flag for wrapped modules when precision=16:

To reproduce:

import os

import torch
from fairscale.nn import wrap
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DDPFullyShardedStrategy


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

    def configure_sharded_model(self) -> None:
        self.layer = wrap(torch.nn.Linear(32, 2))
        assert self.layer.mixed_precision  # fails: the flag is not set with precision=16


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        strategy=DDPFullyShardedStrategy(),
        precision=16,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

This happens because the internal check in the strategy only matches the "mixed" precision value, so precision=16 never sets the flag on wrapped modules: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/strategies/fully_sharded.py#L166
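To illustrate the kind of check described above, here is a minimal sketch in plain Python. It does not reproduce the actual Lightning source: the function names `narrow_check` and `broader_check` are hypothetical, and treating 16 as equivalent to "mixed" is an assumption about what a fix might look like.

```python
from typing import Union


# Hypothetical helpers for illustration only; these names do not exist
# in pytorch_lightning or fairscale.
def narrow_check(precision: Union[int, str]) -> bool:
    """Mimics a check that only recognises the literal "mixed" value."""
    return precision == "mixed"


def broader_check(precision: Union[int, str]) -> bool:
    """Assumed fix: also treat 16-bit precision as mixed precision."""
    return str(precision) in {"16", "mixed"}


if __name__ == "__main__":
    for value in (16, "16", "mixed", 32):
        print(value, narrow_check(value), broader_check(value))
```

With the narrow form, `precision=16` evaluates to `False`, which would match the failing assertion in the reproduction script.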

A separate issue: passing precision="mixed" raises a different error from the connector, because that value is not checked correctly there.

cc @SeanNaren @awaelchli @rohitgr7 @akihironitta
