
Add support for FP8E4M3 and FP8E5M2 dtypes #254

Closed

IvanYashchuk opened this issue Apr 23, 2024 · 2 comments · Fixed by #445
Labels: amp, enhancement

Comments

IvanYashchuk (Collaborator) commented Apr 23, 2024

🚀 Feature

Currently supported dtypes are listed in this conversion dict:

_thunder_to_torch_dtype_map = {
    bool: torch.bool,
    int: torch.int32,
    float: torch.float32,
    complex: torch.complex64,
    bool8_: torch.bool,
    bool8: torch.bool,
    uint8_: torch.uint8,
    uint8: torch.uint8,
    int8_: torch.int8,
    int8: torch.int8,
    int16_: torch.int16,
    int16: torch.int16,
    int32_: torch.int32,
    int32: torch.int32,
    int64_: torch.int64,
    int64: torch.int64,
    bfloat16_: torch.bfloat16,
    bfloat16: torch.bfloat16,
    float16_: torch.float16,
    float16: torch.float16,
    float32_: torch.float32,
    float32: torch.float32,
    float64_: torch.float64,
    float64: torch.float64,
    complex32_: torch.complex32,
    complex32: torch.complex32,
    complex64_: torch.complex64,
    complex64: torch.complex64,
    complex128_: torch.complex128,
    complex128: torch.complex128,
}

PyTorch has recently added support for the FP8E4M3 and FP8E5M2 dtypes, and we need to add them to Thunder to enable native FP8 mixed-precision training.
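
Concretely, that means extending the conversion dict above. A minimal sketch of the new entries, assuming Thunder mirrors PyTorch's dtype names (the Thunder-side names float8_e4m3fn_, float8_e4m3fn, float8_e5m2_, and float8_e5m2 are assumptions for illustration, not a settled API):

_thunder_to_torch_dtype_map.update({
    # hypothetical Thunder dtype objects mirroring the torch names
    float8_e4m3fn_: torch.float8_e4m3fn,
    float8_e4m3fn: torch.float8_e4m3fn,
    float8_e5m2_: torch.float8_e5m2,
    float8_e5m2: torch.float8_e5m2,
})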

The first step is to be able to compile a function that does only dtype conversion.
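
As a reference point, here is a minimal sketch of that first step, assuming thunder.jit and a recent PyTorch with torch.float8_e4m3fn; this is the intended end state, not something that works today:

import torch
import thunder

def convert(x):
    # dtype conversion only, no other operations
    return x.to(torch.float8_e4m3fn)

jconvert = thunder.jit(convert)
y = jconvert(torch.randn(4, 4))
assert y.dtype == torch.float8_e4m3fn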

cc @crcrpar

riccardofelluga (Collaborator) commented

Sounds good!

How do we want to deal with the variants for each fp8 type present in torch? At the moment torch implements the following fp8 types:

# E4M3
torch.float8_e4m3fn
torch.float8_e4m3fnuz

# E5M2
torch.float8_e5m2
torch.float8_e5m2fnuz

The fn postfix stands for "only NaN values and no infinite values", and the uz postfix stands for "no negative zero".
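
The difference is easy to see from the dtype metadata; a quick sketch (torch.finfo accepts these dtypes in recent PyTorch releases):

import torch

# fn variants have no inf, so e4m3fn tops out at a finite max of 448.0;
# the uz variants additionally drop negative zero.
for dt in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
           torch.float8_e5m2, torch.float8_e5m2fnuz):
    info = torch.finfo(dt)
    print(dt, info.max, info.min, info.eps)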

Also, in traces we print a shorthand for each type, like f32 for float32. In this case, the same notation wouldn't carry enough information: f8 could be any of the E4M3 or E5M2 variants.

IvanYashchuk (Collaborator, Author) commented

Thunder should recognize all FP8 variants that exist in PyTorch. As for printing them, I suggest using f8 with the full suffix, for example, f8_e5m2fnuz.
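
A small sketch of what that could look like for trace printing (the table name and where it would live are assumptions for illustration):

import torch

# hypothetical shorthand table: keep the full suffix from the torch dtype name
fp8_trace_shorthand = {
    torch.float8_e4m3fn: "f8_e4m3fn",
    torch.float8_e4m3fnuz: "f8_e4m3fnuz",
    torch.float8_e5m2: "f8_e5m2",
    torch.float8_e5m2fnuz: "f8_e5m2fnuz",
}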

riccardofelluga linked a pull request on May 22, 2024 that will close this issue.
t-vi closed this as completed in #445 on May 27, 2024.