Move the module to the precision dtype #12591

Closed
yuvalkirstain opened this issue Apr 3, 2022 · 9 comments · Fixed by #18113
Labels
feature Is an improvement or enhancement performance precision: amp Automatic Mixed Precision
Milestone

Comments

@yuvalkirstain

yuvalkirstain commented Apr 3, 2022

🐛 Bug

@SeanNaren
When I use bf-16 and check the dtype of the model, it seems like the model's precision is fp32 (and I do not see the memory gains I expect). On other frameworks that support bf-16 (like fairseq) the model's dtype is torch.bfloat16. Is there a simple example that "proves" that this feature reduces the memory consumption as it should? I suspect that there might be something wrong (but of course, I might be wrong).
Thank you!

To Reproduce

Launch any job with precision="bf16" and compare it with the same job at precision=32.

Expected behavior

This feature should save 30-50% of memory, but I do not see such gains in Lightning.

Environment

  • CUDA:
    • GPU:
      • GeForce RTX 3090
    • available: True
    • version: 11.3
  • Packages:
    • numpy: 1.21.2
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0
    • pytorch-lightning: 1.6.0dev
    • tqdm: 4.62.3
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.12
    • version: #74-Ubuntu SMP Tue Sep 17 17:06:04 UTC 2019

Additional context

BF16 is a very important feature. It is usually more stable than FP16, and Lightning should support it effectively (models that were pretrained in BF16 should not be used with FP16) :)

cc @Borda @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta

@yuvalkirstain yuvalkirstain added the needs triage Waiting to be triaged by maintainers label Apr 3, 2022
@SeanNaren
Contributor

SeanNaren commented Apr 3, 2022

Thanks @yuvalkirstain!

I investigated this, as it definitely seemed strange; however, I think there may be a misunderstanding of the behaviour.

When using AMP, the model weights remain in FP32, and only the operations that can safely run in FP16 are autocast. The same is true for BF16 (see this example):

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        assert self.layer.weight.dtype == torch.float32  # the weights are still stored in float32
        x = self.layer(x)
        assert x.dtype == torch.bfloat16  # the output is bfloat16 because the linear op is autocast to bf16
        return x

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        max_epochs=1,
        enable_model_summary=False,
        precision="bf16",
        gpus=1,
    )

    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
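The same behaviour can also be checked without Lightning or a GPU. This is a minimal sketch, assuming PyTorch >= 1.10, where torch.autocast accepts device_type="cpu" with dtype=torch.bfloat16. It shows that the stored parameters stay in float32 while the autocast-eligible Linear op produces a bfloat16 output.

```python
import torch

layer = torch.nn.Linear(32, 2)
x = torch.randn(4, 32)

# Autocast casts the inputs of eligible ops (such as linear) to bfloat16 on the fly;
# it does not change the dtype of the module's stored parameters.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(x)

print(layer.weight.dtype)  # torch.float32
print(out.dtype)  # torch.bfloat16
```

Since the weights are kept in float32, autocast on its own does not reduce the memory taken by the parameters.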

I then checked FairSeq, which converts the entire model into BF16 even in AMP mode: https://github.com/pytorch/fairseq/blob/main/fairseq/trainer.py#L105-L106

This might be why you see memory improvements there. If that is the behaviour you want, you can do the same conversion yourself in your __init__:

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2).bfloat16()
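One way to see where the savings come from is to compare the parameter memory of a module before and after .bfloat16(). This is a minimal sketch. The helper param_bytes is hypothetical and introduced only for illustration. It counts parameter storage only, so activations and optimizer state are not included.

```python
import torch


def param_bytes(module: torch.nn.Module) -> int:
    """Hypothetical helper: total bytes held by the module's parameters (buffers excluded)."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


model_fp32 = torch.nn.Linear(32, 2)             # 32 * 2 weights + 2 biases = 66 parameters
model_bf16 = torch.nn.Linear(32, 2).bfloat16()  # same 66 parameters, stored in bfloat16

print(param_bytes(model_fp32))  # 264 (66 * 4 bytes)
print(param_bytes(model_bf16))  # 132 (66 * 2 bytes)
print({p.dtype for p in model_bf16.parameters()})  # {torch.bfloat16}
```

After this conversion, the float32 assertion in forward from the earlier example no longer holds, and inputs must also be bfloat16 (or run under autocast) for the matmul dtypes to match.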

Let me know if this helps! We might need to improve our documentation to explain this case.

@yuvalkirstain
Author

Thank you so much for the detailed answer! I will try it out and update :)

@SeanNaren
Contributor

SeanNaren commented Apr 3, 2022

No, thank you, @yuvalkirstain! Let me know how it goes. You also bring up a great point: whether we should support converting the pl_module internally for users. I think it is potentially a good idea; however, I wonder what the API for it would be.

@akihironitta akihironitta added precision: amp Automatic Mixed Precision and removed needs triage Waiting to be triaged by maintainers labels Apr 4, 2022
@carmocca
Contributor

carmocca commented Apr 4, 2022

This could simply be a flag on the precision plugin (name up for debate):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

Trainer(plugins=NativeMixedPrecisionPlugin(convert_modules=True))
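Under this proposal, the flag would essentially automate the fairseq-style workaround: cast the whole module once to the dtype implied by the precision setting. The following is a minimal sketch of that conversion step only. _PRECISION_TO_DTYPE and convert_module are hypothetical names, not existing Lightning APIs, and the keys are assumed to mirror the Trainer's precision values (64, 32, 16, "bf16").

```python
import torch

# Hypothetical mapping from Trainer precision values to parameter dtypes (not a Lightning API).
_PRECISION_TO_DTYPE = {
    "64": torch.float64,
    "32": torch.float32,
    "16": torch.float16,
    "bf16": torch.bfloat16,
}


def convert_module(module: torch.nn.Module, precision) -> torch.nn.Module:
    """Hypothetical helper: cast floating-point parameters and buffers to the precision's dtype."""
    dtype = _PRECISION_TO_DTYPE[str(precision)]
    # nn.Module.to(dtype) only casts floating-point (and complex) tensors;
    # integer buffers such as BatchNorm's num_batches_tracked keep their dtype.
    return module.to(dtype)


model = torch.nn.Sequential(torch.nn.Linear(32, 2), torch.nn.BatchNorm1d(2))
convert_module(model, "bf16")
print(model[0].weight.dtype)               # torch.bfloat16
print(model[1].num_batches_tracked.dtype)  # torch.int64
```

One design question this sketch leaves open is whether the optimizer should still keep float32 master copies of the weights, as mixed-precision training usually does.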

Another question is whether this should be enabled by default.

As per the test added by @rohitgr7 in https://github.com/PyTorchLightning/pytorch-lightning/pull/12508/files#diff-3e387bfea0892d3f7033341769075f9e8b17dbe2fbe4a44a68e3d450bdba2e0eR1223, the layers are also converted when using DeepSpeed.

@carmocca carmocca added the feature Is an improvement or enhancement label Apr 4, 2022
@carmocca carmocca added this to the 1.7 milestone Apr 4, 2022
@yuvalkirstain
Author

yuvalkirstain commented Apr 5, 2022

@SeanNaren Yes, doing so results in less memory on the GPU with identical results, thank you!

T5-XL (3B parameters), inference on the SQuAD dataset, GPU memory used / total:
  • with model = model.bfloat16(): 8058MiB / 24268MiB
  • without: 13196MiB / 24268MiB
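These numbers are roughly what halving the weight storage predicts. This is a back-of-envelope sketch that assumes the nominal 3B parameter count (the exact count of T5-XL differs somewhat) and ignores activations, the CUDA context, and allocator caching.

```python
# Back-of-envelope check, assuming the nominal 3B parameters and that the
# measured difference comes mostly from the weights.
MIB = 2**20
num_params = 3e9

fp32_weights_mib = num_params * 4 / MIB  # 4 bytes per float32 parameter
bf16_weights_mib = num_params * 2 / MIB  # 2 bytes per bfloat16 parameter
expected_saving_mib = fp32_weights_mib - bf16_weights_mib

measured_saving_mib = 13196 - 8058  # readings reported above

print(round(expected_saving_mib))  # 5722
print(measured_saving_mib)         # 5138
```

The two figures are of the same order, which is consistent with most of the saving coming from storing each weight in 2 bytes instead of 4.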

Regarding converting the pl_module internally for users: definitely. IMO it makes more sense for the Trainer to take care of that rather than the model.

@SeanNaren
Contributor

@yuvalkirstain I'm glad it worked! Hopefully we'll get the feature in soon :)

@rohitgr7
Contributor

rohitgr7 commented Apr 5, 2022

Another question is whether this should be enabled by default.

I think yes because:

  1. less memory consumption
  2. other frameworks like DeepSpeed do it too

issues:

  1. backward compatibility (BC)

but if there are no other side effects, I think enabling it by default should be good.

@carmocca carmocca changed the title bf16 does not seems to work as expected Move the module to the precision dtype Apr 5, 2022
@SeanNaren
Copy link
Contributor


Just to be clear, we're talking about just BF16 precision? DeepSpeed's handling of AMP is very different and should be treated as such.

@carmocca
Contributor

carmocca commented Apr 7, 2022

For DeepSpeed in particular, we let it make the choice; at the moment, that means moving the module too.

But I was thinking this should apply to all precision values.

@carmocca carmocca modified the milestones: pl:1.7, pl:future Jul 19, 2022
@carmocca carmocca added the priority: 1 Medium priority task label Jul 19, 2022
@carmocca carmocca removed the priority: 1 Medium priority task label Feb 6, 2023
@carmocca carmocca modified the milestones: future, 2.1 Aug 14, 2023