Introduce arith.scaling_extf and arith.scaling_truncf #141965


Merged: 38 commits, Jun 9, 2025

Conversation

@umangyadav (Contributor) commented May 29, 2025

This PR adds `arith.scaling_truncf` and `arith.scaling_extf` operations, which support block quantization following the OCP MXFP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation: https://github.com/microsoft/microxcaling/tree/main

The interesting piece of the reference code is the `_quantize_mx` method: https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173

Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be elementwise operations. Please see their descriptions in the `ArithOps.td` file for more details.

Internally,

`arith.scaling_truncf` computes `arith.truncf(arith.divf(input, 2^scale))`. `scale` should have any necessary broadcasting, clamping, normalization, and NaN propagation done before calling into `arith.scaling_truncf`.

`arith.scaling_extf` computes `arith.mulf(2^scale, input)` after taking care of the necessary data type conversions.
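For illustration, a minimal sketch of IR using the two ops (assembly format as defined in `ArithOps.td` in this PR; the `.td` file is authoritative):

```mlir
// Extend f8 inputs to f32 using per-element E8M0 scales (shapes match).
%hi = arith.scaling_extf %in, %scale : vector<4xf8E4M3FN>, vector<4xf8E8M0FNU> to vector<4xf32>
// Truncate f32 values back down using the same scales.
%lo = arith.scaling_truncf %hi, %scale : vector<4xf32>, vector<4xf8E8M0FNU> to vector<4xf8E4M3FN>
```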

CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar @tgymnich

@krzysz00 (Contributor) left a comment

Some notes

  return rewriter.notifyMatchFailure(
      op, "scaling truncf is not using scale operand of type f8E8M0FNU");
}
auto scaleTy = scaleOperand.getType();
Contributor

`Type` (spell out the type instead of `auto`)

} else if (inputETy.getIntOrFloatBitWidth() > 32) {
  inputOperand = b.create<arith::TruncFOp>(f32Ty, inputOperand);
}
inputTy = inputOperand.getType();
Contributor

We could update these to f32Type in the if statements above, but it doesn't matter

Value c127 = createConst(op->getLoc(), i32Ty, 127, rewriter);
Value cNeg127 = createConst(op->getLoc(), i32Ty, -127, rewriter);
Value scaleI8 = b.create<arith::BitcastOp>(i8Ty, scaleOperand);
Value scaleI32 = b.create<arith::ExtSIOp>(i32Ty, scaleI8);
Contributor

This should be an `extui`. But also, there's no need to go to i32 here.

Contributor Author

I first need to calculate the unbiased scale value. I can do that while staying in i8.

But then I also need to subtract emax (the max exponent of the largest normal number in the resulting quantized dtype).
That subtraction could underflow or overflow, which needs to be checked and clamped later on. Therefore I require i32.
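To make the range concern concrete, a small hypothetical sketch (the numbers are illustrative, not from the PR):

```c++
#include <cstdint>

// An f8E8M0FNU scale stores a biased exponent; unbiasing gives [-127, 127].
// Subtracting the result type's emax can leave the i8 range, hence i32.
int32_t adjustScale(int8_t unbiasedScale, int32_t emax) {
  // e.g. unbiasedScale = -127, emax = 8 (f8E4M3FN): -135 < INT8_MIN
  return static_cast<int32_t>(unbiasedScale) - emax;
}
```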

@umangyadav (Contributor Author), May 29, 2025

> This should be an extui.

Thanks. Good catch.

Contributor

Ok, so, my bigger complaint is that you can simplify the generated code substantially if you just switch on what kind of type you're extending to.

That is, f32 requires nothing: that's already a ±127 situation.

Types shorter than f32 will need the subtraction.

... Also, I'm going to re-read the code, but I'm not convinced this should be subtracting the max normalized exponent. Are we sure it isn't "clamp to the exponent range of the type"?

Contributor

... Ah, we're subtracting the max exponent of the result type

Which can't lead to overflow

This could be substantially simplified if we just use usub_sat (which we'd need an MLIR Arith op for, but that's fairly trivial)

Contributor

... But also, the code you linked is for quantization

I think it's reasonable to assume that someone implementing quantization will already have done the scale-biasing thing and so we don't need to do it here

Unless we have evidence that the hardware implementations perform the subtraction described here? (We'll probably want to go find the AMD behavior)

Contributor

... and if you're doing usub_sat, you don't need to unbias the exponent.

But also, I'd make sure this is something that other implementors of scaling_truncf implement so we don't get conflicting lowerings
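For reference, a minimal sketch of the saturating-subtract semantics being proposed (mirroring llvm.usub.sat; the Arith op did not exist at the time of this discussion):

```c++
#include <cstdint>

// Saturating unsigned subtraction: clamps at zero instead of wrapping.
uint8_t usub_sat(uint8_t a, uint8_t b) { return a > b ? a - b : 0; }
```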

const llvm::fltSemantics &resultFltSemantics =
    llvm::cast<FloatType>(resultETy).getFloatSemantics();
int maxExponent = APFloat::semanticsMaxExponent(resultFltSemantics);
Value cMaxNormalExponent =
Contributor

Skip all this if we're in f32 or higher?

Contributor Author

Rewrote using f32.

Value cmpCond = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, cI8Zero,
                                        inputExponentU8);
Value inputTyZero = createFloatConst(op.getLoc(), inputTy, 0, rewriter);
Value flushedInput =
Contributor

This all seems overcomplicated?

This could just be extending the scale to f32?

Contributor Author

Rewrote using f32. It does simplify things a bit. Thanks

Let's say the input originally has shape `<dim1 x dim2 x dim3 ... x dimN>`; then, given `blockSize`, it can be reshaped to `<dim1 x dim2 x ... x (dimN/blockSize) x blockSize>`.
Scales are calculated along the block axis, so the scale will have shape `<dim1 x dim2 x dim3 ... x (dimN/blockSize) x 1>`.
Before calling into `arith.scaling_extf`, scales must be broadcast appropriately to the same shape as the input, making `arith.scaling_extf` an elementwise op.
In the above example, scales should be broadcast to shape `<dim1 x dim2 x dim3 x ... x (dimN/blockSize) x blockSize>`.
Contributor

From the description, I understand it doesn't need to be broadcast; you could just use a non-broadcast tensor of shape <dim1 x dim2 x dim3 x ... (dimN/blockSize) x blockSize>?

If that's the case, I don't think it's useful to explain all of these details; broadcasting is just one use case. If I understood it correctly.

Contributor

I think the description needs to be updated: this arith op is set up to do things elementwise because arith ops in general are elementwise, and the broadcast-scale thing is a special case that gets pattern-matched in a future ArithToAMDGPU.

Contributor Author

I tried to rewrite the documentation. Please check again and let me know if it is clearer now.
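As a concrete (hypothetical) illustration of the shapes involved, with sizes chosen for this example only:

```mlir
// %input  : vector<2x4x32xf8E4M3FN>  (block axis, blockSize = 32, innermost)
// %scales : vector<2x4x1xf8E8M0FNU>  (one scale per 32-element block)
%bcast = vector.broadcast %scales : vector<2x4x1xf8E8M0FNU> to vector<2x4x32xf8E8M0FNU>
%ext = arith.scaling_extf %input, %bcast : vector<2x4x32xf8E4M3FN>, vector<2x4x32xf8E8M0FNU> to vector<2x4x32xf32>
```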

op, "scaling extf is not using scale operand of type f8E8M0FNU");
}
Type resultTy = op.getType();
// extf on scale will essentially create f32 number that is 2^scale and will
Contributor

why f32? can't resultTy be any float type?

Contributor

should we check if resultTy >= Float8E8M0FNU and >= inputType?

Contributor

In principle, scaled truncation from f32 to f32 is a really weird way to spell division, but we might want to verify it away.

@umangyadav (Contributor Author), May 30, 2025

> why f32? can't resultTy be any float type?

Changed comment to better reflect what it's doing.

> should we check if resultTy >= Float8E8M0FNU and >= inputType?

As part of verification, it checks that the output dtype has a larger width than the input:
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1460
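A hedged restatement of that rule as standalone code (the linked `ArithOps.cpp` is authoritative):

```c++
// For scaling_extf, the result element type must be wider than the input's.
bool isValidScalingExtF(unsigned inputWidth, unsigned resultWidth) {
  return resultWidth > inputWidth; // e.g. f8 (8 bits) -> f32 (32 bits) is fine
}
```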

Value result = b.create<arith::DivFOp>(flushedInput, scaleF32);
// propagate rounding mode and fast math attributes
Value resultCast = b.create<arith::TruncFOp>(
    resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
Contributor

there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter

Contributor

should we check resultTy <= f32?

Contributor Author

> should we check resultTy <= f32?

Verify() checks that the output width is smaller than the input width:

https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1587

> there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter

No, the other arith.truncf ops are mainly for the scale's dtype conversion, which just operates on the exponent and isn't really affected by rounding mode or fast math.

Contributor

Yes, verify checks that the output width is smaller than the input width. But I understand the output of this function is always f32. Then, I wonder if somebody can do input, scale -> f128, result -> f64. Then it's true that output width < input width, and we are still trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?

Contributor Author

In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.
Also, the Verify() check is strict, so output_bit_width < input_bit_width.
So in practice this would never really be "truncating" the f32 intermediate into a wider resultTy.

> But I understand the output of this function is always f32

No, why do you think so? The output dtype will be whatever the user has specified.

@dhernandez0 (Contributor), Jun 2, 2025

> No, why do you think so? The output dtype will be whatever the user has specified.

I mean the result of the function before truncation. result.dtype = f32, right?

> In practice, Float64/80/128 dtypes are not expected. I think it is safe to assume f32 is the largest dtype that can appear on the input.

I think the arith dialect is not supposed to be hardware-specific, so even though it's not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems OK to me anyway; whatever you decide.

@pashu123 (Member) left a comment

Some minor nits! LGTM. I'll wait for @krzysz00.

PatternRewriter &rewriter) {
  auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value);
  if (auto shapedTy = dyn_cast<ShapedType>(type)) {
    return rewriter.create<arith::ConstantOp>(
Member

We can update the attr here: `attr = DenseElementsAttr::get(shapedTy, attr)`. It will return the right thing. (Both are fine by me.)

// emax is calculated as the exponent of the largest normal value in the quantized type.
scale.normalize = arith.divf(scale.extf, emax)
scale.clamped = clamp(scale.normalize) // clamp underflows
input.flushed = flush_denorms(input)
Contributor

there are some type conversions for input and scale that are not explained here. Not sure if we want all those details here?

Contributor Author

IMO, that would be more detail than necessary.


@dhernandez0 (Contributor) left a comment

LGTM

@krzysz00 (Contributor) commented Jun 3, 2025

Note: I'd rather we not land this just yet because I'm still waiting to find out if potential hardware-specific lowerings of arith.scaling_truncf will perform the exponent subtraction that this code does.

I have a suspicion that the answer is "no": that the adjustment is part of the scale-computation process, not the scale-application process, and so the semantics of scaling_truncf shouldn't include it.

@krzysz00 (Contributor) left a comment

Hold for semantics questions, and @llvm/pr-subscribers-mlir-nvgpu for input on Nvidia semantics while I wait on AMD answers

…rm flushing on input should be carried out using the specified fastMath flag. Scales are assumed to be normalized and clamped.
@umangyadav (Contributor Author)

> Hold for semantics questions, and @llvm/pr-subscribers-mlir-nvgpu for input on Nvidia semantics while I wait on AMD answers

@krzysz00 @dhernandez0 @tgymnich @pashu123

Based on feedback, I've changed the semantics of `arith.scaling_truncf` to simply do the following: `arith.truncf(arith.divf(input, 2^scale))`.

Can you re-review, please?
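A minimal sketch of what the revised semantics amount to (types are illustrative; an E8M0 value is by construction a power of two, so extending it to f32 yields 2^scale):

```mlir
%s = arith.extf %scale : f8E8M0FNU to f32   // 2^scale as an f32
%d = arith.divf %in, %s : f32
%r = arith.truncf %d : f32 to f8E4M3FN      // rounding mode / fastmath apply here
```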

@krzysz00 (Contributor) left a comment

Minor notes, I think this is the correct operation semantics

let summary =
    "Downcasts input floating point values using provided scales values following OCP MXFP Spec";
let description = [{
  This operation quantizes input using the provided scale values. It expects
Contributor

I'd remove "quantizes" here, since this is just a truncation where you divide by the scale

Contributor Author

Replaced with the word "downcast".

@krzysz00 (Contributor) left a comment

LGTM

@krzysz00 krzysz00 merged commit 7f08503 into llvm:main Jun 9, 2025
7 checks passed
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
DhruvSrivastavaX pushed a commit to DhruvSrivastavaX/lldb-for-aix that referenced this pull request Jun 12, 2025
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025