
It doesn't support the latest RTX 40-series card #15

Closed
hxssgaa opened this issue Oct 12, 2022 · 29 comments

Comments

@hxssgaa

hxssgaa commented Oct 12, 2022

Hi, FP8 should be supported on the RTX 40-series as well, since it is based on the AD102 architecture, which has FP8 capabilities. However, running TransformerEngine on the RTX 4090 results in the error "AssertionError: Device compute capability 9.x required for FP8 execution.", so it is not possible to take advantage of FP8.

@WuNein

WuNein commented Oct 21, 2022

My RTX 4090 under the NVIDIA Docker container has the same issue.

AssertionError: Device compute capability 9.x required for FP8 execution.
torch.cuda.current_device()
0

@audreyeternal

I am not sure whether removing this assertion will work. I think right now only the GH100 has 9.x compute capability.

@WuNein

WuNein commented Oct 28, 2022

> I am not sure whether removing this assertion will work. I think right now only the GH100 has 9.x compute capability.

It didn't work; I tried.

@edward-io

edward-io commented Dec 15, 2022

CUDA 12 was released this week. Does PyTorch need to be updated to CUDA 12 before this library can work with RTX GPUs?

https://developer.nvidia.com/blog/cuda-toolkit-12-0-released-for-general-availability/

Filed issue: pytorch/pytorch#90988

@edward-io

edward-io commented Dec 16, 2022

I managed to work around PyTorch while building TransformerEngine with CUDA 12 (using cuBLASLt v12.0), but encountered this error:

[2022-12-16 01:00:29][cublasLt][233734][Api][cublasLtMatmulAlgoGetHeuristic] Adesc=[type=R_8F_E4M3 rows=9216 cols=128 ld=9216] Bdesc=[type=R_8F_E4M3 rows=9216 cols=64 ld=9216] Cdesc=[type=R_32F rows=128 cols=64 ld=128] Ddesc=[type=R_32F rows=128 cols=64 ld=128] preference=[maxWavesCount=0.0 maxWorkspaceSizeinBytes=33554432] computeDesc=[computeType=COMPUTE_32F_FAST_TF32 scaleType=R_32F transa=OP_T epilogue=EPILOGUE_BIAS biasPointer=0x7f101aff7200 aScalePointer=0x7f101aed6804 bScalePointer=0x7f101aed6800 fastAccumulationFlag=1 biasDataType=14]
[2022-12-16 01:00:29][cublasLt][233734][Error][cublasLtMatmulAlgoGetHeuristic] Failed to query heuristics.

@ptrblck

ptrblck commented Dec 16, 2022

> AssertionError: Device compute capability 9.x required for FP8 execution.

As the error explains, you would need a GPU with compute capability 9.0, while your 4090 uses compute capability 8.9.
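
For context, the capability ptrblck refers to can be queried directly from PyTorch. A minimal sketch (the assertion itself lives inside Transformer Engine; this only shows the device query):

import torch

# An RTX 4090 (Ada, AD102) reports (8, 9); Hopper (H100) reports (9, 0).
major, minor = torch.cuda.get_device_capability()
print(torch.cuda.get_device_name(), f"-> compute capability {major}.{minor}")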

@edward-io

edward-io commented Dec 16, 2022

@ptrblck Thanks! Sorry, I'm a noob with CUDA. Why does it require compute capability 9.x? The RTX 4090 has FP8 cores:

> Ada’s new 4th Generation Tensor Cores are unbelievably fast, with an all new 8-Bit Floating Point (FP8) Tensor Engine, increasing throughput by up to 5X, to 1.32 Tensor-petaFLOPS on the GeForce RTX 4090.

and the CUDA 12.0 announcement says that it supports the Ada Lovelace architecture:

> CUDA 12.0 exposes programmable functionality for many features of the NVIDIA Hopper and NVIDIA Ada Lovelace architectures: ...32x Ultra xMMA (including FP8 and FP16)

This might be out of scope for this repo, but what is the intended way to invoke FP8 ops if not via the same cuBLASLt functions that are used in this repo?

@ptrendx
Member

ptrendx commented Jan 4, 2023

Hi All,

First of all, I'm really sorry for the prolonged silence on this issue - I did not want to communicate anything before getting full alignment internally.
As noted in the RTX 4090 announcement and the Ada whitepaper, Ada has FP8 Tensor Core hardware. However, software support for it is not currently available - for example, there is no support for it exposed in cuBLASLt yet. The reason is that both the FP8 TC instruction and the other features used in the fast FP8 GEMM kernels differ between Hopper and Ada (meaning a different set of kernels is required for each architecture), and the Hopper support was prioritized. Once FP8 support lands in CUDA and its libraries (tentatively scheduled for CUDA 12.1 in Q2), Transformer Engine will also fully support Ada.

@oscarbg

oscarbg commented Mar 1, 2023

@ptrendx with CUDA 12.1 now released, is Ada FP8 support there? Since it is only Q1 and you mentioned Q2 as the estimate for Ada FP8 support, has it been pushed to CUDA 12.2?

@ptrendx
Member

ptrendx commented Mar 9, 2023

Hi @oscarbg. In order to support FP8 on Ada, two things need to happen:

  • The CUDA compiler and PTX for Ada need to understand the cast instructions to and from FP8 -> this is done; if you look at the 12.1 toolkit, inside cuda_fp8.hpp you will see hardware acceleration for the casts on Ada (a Python-level illustration follows below).
  • cuBLAS needs to provide FP8 GEMMs on Ada -> this work is currently in progress and we are still targeting the original timeline.
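
As a Python-level illustration of the first point (the casts, not the GEMMs), here is a hedged sketch. It assumes a PyTorch build that ships the FP8 dtypes (torch.float8_e4m3fn arrived later, in PyTorch 2.1) and it is not the cuda_fp8.hpp path described above:

import torch

# Round-trip a tensor through E4M3; FP8 matmuls still require cuBLASLt kernels.
x = torch.randn(4, 4, device="cuda", dtype=torch.float16)
x_fp8 = x.to(torch.float8_e4m3fn)   # downcast to FP8 (E4M3)
x_back = x_fp8.to(torch.float16)    # upcast back; expect quantization error
print("max abs rounding error:", (x - x_back).abs().max().item())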

@vgoklani

vgoklani commented Mar 9, 2023

Thanks @ptrendx. Quick question: will Ada FP8 support include both E4M3 and E5M2?

@ptrendx
Member

ptrendx commented Mar 9, 2023

Yes, it supports both types (including mixing them, e.g. performing a matrix multiply where one input is E4M3 and the other is E5M2), same as Hopper.
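
For reference, this choice surfaces in Transformer Engine's Python API through the recipe's fp8_format. A minimal sketch using the names that appear later in this thread (only E4M3 and HYBRID are shown; HYBRID is the mixed mode, E4M3 for forward tensors and E5M2 for gradients):

from transformer_engine.common.recipe import Format, DelayedScaling

# Pure E4M3 for both passes.
recipe_e4m3 = DelayedScaling(fp8_format=Format.E4M3)

# HYBRID: E4M3 for forward tensors, E5M2 for backward gradients.
recipe_hybrid = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
)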

@vgoklani

vgoklani commented Mar 9, 2023

@ptrendx thank you!

@ptrendx
Member

ptrendx commented Apr 20, 2023

Today CUDA Toolkit 12.1 Update 1 was released. It contains cuBLAS 12.1.3.1, which enables FP8 kernels for Ada. With this version of cuBLAS, together with Transformer Engine 0.7 (which added Ada to the compilation targets), FP8 computation is now supported on Ada.

Let us know if you encounter any issues with it.
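
A quick way to confirm which cuBLASLt a given environment actually loads is sketched below. This is a hedged sketch: it assumes libcublasLt.so.12 is on the loader path and that cublasLtGetVersion() returns the version in the usual major*10000 + minor*100 + patch encoding:

import ctypes

import torch

# Query the loaded cuBLASLt version; Ada FP8 needs cuBLAS 12.1.3.x or newer.
cublaslt = ctypes.CDLL("libcublasLt.so.12")
cublaslt.cublasLtGetVersion.restype = ctypes.c_size_t

print("cuBLASLt version:", cublaslt.cublasLtGetVersion())
print("PyTorch built with CUDA:", torch.version.cuda)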

@hxssgaa
Author

hxssgaa commented Apr 20, 2023

When I tried to install TransformerEngine again with an RTX 4090 card, I got the following error:

RuntimeError:
      The detected CUDA version (12.1) mismatches the version that was used to compile
      PyTorch (11.7). Please make sure to use the same CUDA versions.
      
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for flash-attn

Does Transformer Engine rely on flash-attention? I checked here, and it says that flash-attn support for FP8 is not ready yet.

Update:

I have tested this by installing the latest PyTorch from the master branch, and I installed both flash-attn and Transformer Engine successfully.

I have also run some basic tests from the TransformerEngine home page and confirmed that it really works! I'll test more cases, and if there are no problems I will close the issue.

Thank you very much!
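
A quick sanity check before building flash-attn or Transformer Engine is to compare the CUDA toolkit the build will detect with the one PyTorch was compiled against. A minimal sketch (what it prints depends entirely on your environment):

import torch
from torch.utils.cpp_extension import CUDA_HOME

# The build error above fires when these disagree (e.g. toolkit 12.1 vs. torch built for 11.7).
print("PyTorch was built with CUDA:", torch.version.cuda)
print("CUDA toolkit used for extension builds:", CUDA_HOME)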

@hxssgaa
Author

hxssgaa commented Apr 20, 2023

I ran a simple benchmark using the basic MNIST example with optional FP8.

I found that with and without the use-fp8 flag, the training and inference time remains the same on an RTX 4090 card. I tried increasing the batch size from 64 to 2048 to saturate the GPU, but the result remains the same. Is this normal?

@overvalidated

overvalidated commented Apr 20, 2023

@hxssgaa Did you build torch with CUDA 12.1 without issues? I've heard that it is still considered in development.

@FenghaoZhu

> I ran a simple benchmark using the basic MNIST example with optional FP8. I found that with and without the use-fp8 flag, the training and inference time remains the same on an RTX 4090 card. I tried increasing the batch size from 64 to 2048 to saturate the GPU, but the result remains the same. Is this normal?

FP8 is for Transformers, not for a simple task like MNIST that does not use a Transformer. Only when you train things like BERT will you see a big improvement.

@hxssgaa
Author

hxssgaa commented Apr 20, 2023

> I ran a simple benchmark using the basic MNIST example with optional FP8. I found that with and without the use-fp8 flag, the training and inference time remains the same on an RTX 4090 card. I tried increasing the batch size from 64 to 2048 to saturate the GPU, but the result remains the same. Is this normal?

> FP8 is for Transformers, not for a simple task like MNIST that does not use a Transformer. Only when you train things like BERT will you see a big improvement.

Yes, I built it successfully from the PyTorch master branch without any issues, although I'm not sure whether there are performance issues, since I haven't benchmarked it properly.

I will also run some Transformer benchmarks later.

@hxssgaa
Author

hxssgaa commented Apr 20, 2023


I tried a simple benchmark of a Transformer layer with and without FP8. The code I used is from here; I made some modifications because a helper file is missing.

The code below is without FP8, using the attention layer from Transformer Engine; it achieves 8.86 ms on average:

# time/numpy imports added for completeness; te_transformer and x come from the quickstart example
import time
import numpy as np
import torch

torch.manual_seed(1234)
all_times = []
for _ in range(100):
    s = time.time()
    y = te_transformer(x, attention_mask=None)
    all_times.append(time.time() - s)
print(np.mean(all_times) * 1000)  # mean latency in ms

The code below is with FP8, using the same attention layer from Transformer Engine; it is actually slower, at 13.35 ms:

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
torch.manual_seed(1234)
all_times = []
for _ in range(100):
    s = time.time()
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = te_transformer(x, attention_mask=None)
    all_times.append(time.time() - s)
print(np.mean(all_times) * 1000)  # mean latency in ms

I am not sure whether this is because PyTorch is from the master branch rather than a stable release, even though I built it successfully. I will check the issue again once PyTorch releases official CUDA 12.1 support.

@ptrendx
Member

ptrendx commented Apr 20, 2023

Hi, we are aware of the issue with quickstart_utils.py not being served properly by the website and are working on a fix. You can get it from here: https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart_utils.py

As for the benchmarking questions:

  • the MNIST sample is not supposed to be used as a benchmark - it is a tiny network and its performance is dominated by the time needed to launch the kernels, not by the GPU time itself
  • the benchmark script posted by @hxssgaa is flawed, as it does not wait for the GPU work to finish, so it measures only the CPU time needed to schedule the work (which is higher in FP8 due to the additional kernels needed for casts, for example). A proper benchmark should look like the one defined here: https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart_utils.py#L11 (it implements warmup iterations to measure just the steady state and uses CUDA events to accurately measure GPU time) or something like this:
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
torch.manual_seed(1234)

s = time.time()
for _ in range(100):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = te_transformer(x, attention_mask=None)
torch.cuda.synchronize()  # wait for all queued GPU work to finish before stopping the timer
e = time.time()
mean_time = (e - s)/100

Please note that CUDA synchronization itself takes a small amount of time, so timing with events is preferable for the most accurate measurement, but with 100 iterations the effect should be really small.
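
A hedged sketch of the event-based timing described above, reusing te, te_transformer, x and fp8_recipe from the earlier snippets; the warmup and iteration counts are arbitrary choices:

import torch

warmup, iters = 10, 100
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Warmup iterations so only the steady state is measured.
for _ in range(warmup):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = te_transformer(x, attention_mask=None)

start.record()
for _ in range(iters):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = te_transformer(x, attention_mask=None)
end.record()
torch.cuda.synchronize()  # make sure both events have completed

print(f"{start.elapsed_time(end) / iters:.2f} ms / iteration")  # elapsed_time() returns milliseconds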

@hxssgaa
Author

hxssgaa commented Apr 20, 2023


Thanks for the detailed clarification and for providing the script. I revised my benchmark accordingly, reran it, and confirmed that enabling FP8 reduces the Transformer inference time from 91.50 ms to 65.02 ms on an RTX 4090. I used the script below:

# utils, fused_te_transformer, x, dy and fp8_recipe come from the quickstart example linked above
utils.speedometer(
    fused_te_transformer,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
    fp8_autocast_kwargs = { "enabled": True, "fp8_recipe": fp8_recipe },
)

Note that I still encounter an error when using flash attention, as mentioned here, but I think it's a PyTorch problem.

With this experiment, I can confirm that FP8 support on AD102 is indeed working, so I will close the issue. Thanks to the team!

hxssgaa closed this as completed Apr 20, 2023
@nbroad1881

@ptrendx, do you know if this updated cuBLAS will be in the PyTorch NGC container v23.04? If I try to manually update it, I'm sure I'll mess it up 😅

@ptrendx
Member

ptrendx commented Apr 22, 2023

@nbroad1881 Yes, the 23.04 container has everything you need :-).

@vince62s

In a conda env, I did:
conda install cuda 12.1 update 1
conda install pytorch nightly with cuda 12.1

Then I did the pip install of Transformer Engine but got the error below.

Indeed, cublas_v2.h is NOT in the env's include directory, but I still have this file in /miniconda3/pkgs/libcublas-dev-12.1.3.1-0/include.

Am I missing something, or do I need to change the CMake configuration because of the conda env?

 -- Unable to find cublas_v2.h in either "/home/vincent/miniconda3/envs/pytorch1.14/include" or "/home/vincent/miniconda3/math_libs/include"
  -- Found CUDAToolkit: /home/vincent/miniconda3/envs/pytorch1.14/include (found version "12.1.105")
  -- Performing Test CMAKE_HAVE_LIBC_PTHREAD
  -- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
  -- Looking for pthread_create in pthreads
  -- Looking for pthread_create in pthreads - not found
  -- Looking for pthread_create in pthread
  -- Looking for pthread_create in pthread - found
  -- Found Threads: TRUE
  -- cudnn found at /usr/lib/x86_64-linux-gnu/libcudnn.so.
  -- cudnn_adv_infer found at /usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.
  -- cudnn_adv_train found at /usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.
  -- cudnn_cnn_infer found at /usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.
  -- cudnn_cnn_train found at /usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.
  -- cudnn_ops_infer found at /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.
  -- cudnn_ops_train found at /usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.
  -- Found CUDNN: /usr/include
  -- cuDNN: /usr/lib/x86_64-linux-gnu/libcudnn.so
  -- cuDNN: /usr/include
  -- Found Python: /home/vincent/miniconda3/envs/pytorch1.14/bin/python3.10 (found version "3.10.11") found components: Interpreter Development Development.Module Development.Embed
  -- Configuring done
  CMake Error at common/CMakeLists.txt:33 (target_link_libraries):
    Target "transformer_engine" links to:
  
      CUDA::cublas
  
    but the target was not found.  Possible reasons include:

@nbroad1881

@vince62s, you should use the PyTorch NGC container, version 23.04.

It makes the process much, much easier.

@vince62s

vince62s commented May 16, 2023

That's not really helpful when you want a repo to rely on this.
@hxssgaa what is the clean path you used to do this?

EDIT: this seems to work

  1. CUDA 12.1.1 with conda install cuda -c nvidia/label/cuda-12.1.1
  2. PyTorch nightly 2.1+cu12.1 with conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch-nightly -c nvidia
  3. Install flash-attn from source with python setup.py install, to avoid the pip install which pulls in torch built for CUDA 11.7
  4. git submodule init; git submodule update
  5. Then, from the TE repo: pip install .

@AnubhabB

AnubhabB commented Jun 9, 2023

Still can't get this to work.

  1. CUDA 12.1.1, installed with cuda_12.1.1_530.30.02_linux.run
  2. PyTorch nightly + CUDA 12.1
  3. Built TransformerEngine with pip install . after updating the submodules

Now when I try to run the following example from the documentation:

from transformer_engine.common.recipe import Format, DelayedScaling
import transformer_engine.pytorch as te
import torch

fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

torch.manual_seed(12345)

my_linear = te.Linear(768, 768, bias=True)

inp = torch.rand((1024, 768)).cuda()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out_fp8 = my_linear(inp)

Error:
/<...>/envs/ai/lib/python3.10/site-packages/transformer_engine/pytorch/fp8.py", line 299, in fp8_autocast
    assert fp8_available, reason_for_no_fp8
AssertionError: CublasLt version 12.1.3.x or higher required for FP8 execution on Ada.

I'm running this on an RTX 4090.

What am I missing?

@mgrankin

mgrankin commented Jun 9, 2023

@AnubhabB Consider using the latest NVIDIA PyTorch containers. The example above works great on my 4090 with nvcr.io/nvidia/pytorch:23.05-py3.
