Raise autocast usage error #93

Merged
merged 4 commits into NVIDIA:main on Mar 13, 2023

Conversation

ksivaman
Member

Without this change, the following use case fails with an unhelpful error message:

import torch
import transformer_engine.pytorch as te

model = te.Linear(512, 512)
inp = torch.rand((128, 512), device="cuda")
epochs = 5

def train():
    # Incorrect usage: the model is run in FP8 multiple times
    # under the same autocast region with amax reduction turned on.
    with te.fp8_autocast(enabled=True):
        for _ in range(epochs):
            activation = model(inp)

train()
train() # Error!

This PR catches the misuse and raises the error the first time train() is called in the above script. The correct usage is below:

import torch
import transformer_engine.pytorch as te

model = te.Linear(512, 512)
inp = torch.rand((128, 512), device="cuda")
epochs = 5

def train():
    for _ in range(epochs):
        with te.fp8_autocast(enabled=True):
            activation = model(inp)

train()
train()
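
As a rough illustration of how such a check can work (a hypothetical sketch, not TransformerEngine's actual implementation; fp8_autocast_sketch, FP8GuardedModule, and the tracking globals below are made-up names), the autocast region can track which modules have already run and assert when one is invoked a second time while amax reduction is enabled:

# Hypothetical sketch only: none of these names are TransformerEngine's API.
from contextlib import contextmanager

_autocast_depth = 0
_modules_seen = set()            # modules already run in the current region
_amax_reduction_enabled = True   # assume amax reduction is on for this sketch

@contextmanager
def fp8_autocast_sketch(enabled=True):
    global _autocast_depth
    _autocast_depth += 1
    try:
        yield
    finally:
        _autocast_depth -= 1
        if _autocast_depth == 0:
            _modules_seen.clear()  # a new region starts with a clean slate

class FP8GuardedModule:
    def __call__(self, inp):
        if _autocast_depth > 0 and _amax_reduction_enabled:
            assert id(self) not in _modules_seen, (
                "Same module is being invoked more than once inside an "
                "`fp8_autocast` region when using FP8 with amax reduction."
            )
            _modules_seen.add(id(self))
        return inp  # a real module would run its FP8 GEMM here

With this sketch, the incorrect pattern above (looping over the model inside one autocast region) trips the assertion on the second iteration, while the correct pattern (one region per iteration) runs cleanly.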

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from ptrendx March 11, 2023 00:50
@ksivaman
Member Author

/te-ci

@ksivaman ksivaman requested a review from timmoon10 March 13, 2023 16:48
Collaborator

@timmoon10 timmoon10 left a comment


LGTM

For my own understanding: we expect this assert error to trigger if we perform multiple forward or backward passes within an FP8 context, or if we perform a partial forward or backward pass within an FP8 context (e.g. if we've frozen most of the model to fine-tune a specific section) (edit: partial forward or backward passes should run fine). It should also run fine if the number of forward and backward passes doesn't match.

@ksivaman
Member Author

We don't want to run any portion of the model twice under the same autocast call when using amax reduction with FP8 training.
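
To illustrate the intuition (a toy sketch under assumed bookkeeping, not TransformerEngine's actual delayed-scaling implementation; record_amax, end_of_region_reduce, and the dictionaries below are invented for this example): with amax reduction, each FP8 module is expected to contribute exactly one amax value per autocast region, and those values are reduced when the region ends. Running a module twice in one region would produce two values for a single reduction slot.

# Toy illustration only, not TransformerEngine internals.
import torch

amax_history = []   # one reduced amax per completed autocast region
pending_amax = {}   # module id -> amax recorded in the current region

def record_amax(module_id, tensor):
    # Each module should contribute exactly one amax per region.
    if module_id in pending_amax:
        raise AssertionError(
            "module invoked twice in one region: two amax values "
            "would map to a single reduction slot"
        )
    pending_amax[module_id] = tensor.abs().max().item()

def end_of_region_reduce():
    # In distributed training, this is roughly where amaxes would be
    # reduced across ranks before updating the FP8 scaling factors.
    amax_history.append(max(pending_amax.values()))
    pending_amax.clear()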

@ksivaman ksivaman merged commit 6605597 into NVIDIA:main Mar 13, 2023
nzmora-nvidia pushed a commit to nzmora-nvidia/TransformerEngine that referenced this pull request Mar 16, 2023
* catch incorrect usage of fp8_autocast

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* catch error on first time double execution

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Mar 31, 2023
cyanguwa pushed a commit to cyanguwa/TransformerEngine that referenced this pull request Apr 1, 2023
@erlebach

The following code gives the error:
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.

Here is the Python code:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda:0")
torch.manual_seed(12345)
my_linear = te.Linear(768, 768, bias=True).to(device)

inp = torch.rand((1024, 768)).to(device)
#inp = torch.rand((1024, 768)).cuda()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out_fp8 = my_linear(inp)

loss_fp8 = out_fp8.mean()
loss_fp8.backward()  # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast

out_fp32 = my_linear(inp)

The code runs fine without the last line. Since the last line runs outside the fp8_autocast region, why does this error occur?

Thanks for any insight. I am running with CUDA 11.8 for this library and CUDA 12.0 on the H100.

@ksivaman
Member Author

ksivaman commented May 2, 2023

@erlebach This was a bug that has now been fixed in main (#187).
