Skip to content

Commit a6258fa

Browse files
committed
add v1 variants of scaled ext and trunc ops
Signed-off-by: Tim Gymnich <tim@gymni.ch>
1 parent 300ce14 commit a6258fa

File tree

4 files changed

+439
-11
lines changed

4 files changed

+439
-11
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def AMDGPU_ExtPackedFp8Op :
115115
def AMDGPU_ScaledExtPackedOp
116116
: AMDGPU_Op<"scaled_ext_packed", [Pure]>,
117117
Arguments<(
118-
ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,
119-
VectorOfLengthAndType<[2, 3, 4, 5, 6, 7, 8],
118+
ins AnyTypeOf<[VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>,
119+
VectorOfLengthAndType<[1, 2, 3, 4, 5, 6, 7, 8],
120120
[F4E2M1FN]>]>:$source,
121121
F32:$scale,
122122
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index)>,
@@ -173,7 +173,7 @@ def AMDGPU_PackedTrunc2xFp8Op :
173173

174174
def AMDGPU_PackedScaledTruncOp
175175
: AMDGPU_Op<"packed_scaled_trunc", [Pure]>,
176-
Arguments<(ins VectorOfLengthAndType<[2], [F32, F16, BF16]>:$source,
176+
Arguments<(ins VectorOfLengthAndType<[1, 2], [F32, F16, BF16]>:$source,
177177
F32:$scale,
178178
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<7>]>:$index,
179179
Optional<AnyTypeOf<
@@ -184,7 +184,7 @@ def AMDGPU_PackedScaledTruncOp
184184
FixedVectorOfLengthAndType<[8], [F4E2M1FN]>]>:$res)> {
185185
let summary = "Round two floats into a packed vector of floats";
186186
let description = [{
187-
Scale and round the inputs `sourceA` and `sourceB` (which is undefined if not
187+
Scale and round the input `source` (which is undefined if not
188188
specified) into the low or high word (bottom two or top two) elements
189189
of the returned vector, keeping the other two elements of `existing`
190190
unchanged if present (or undefined if it was not passed in).

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,10 +1270,10 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
12701270
Value source = adaptor.getSource();
12711271
Value scale = adaptor.getScale();
12721272

1273-
VectorType sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
1274-
Type sourceElemType = getElementTypeOrSelf(op.getSource());
1275-
VectorType destVecType = dyn_cast<VectorType>(op.getResult().getType());
1276-
Type destElemType = getElementTypeOrSelf(op.getResult());
1273+
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1274+
Type sourceElemType = sourceVecType.getElementType();
1275+
VectorType destVecType = cast<VectorType>(op.getResult().getType());
1276+
Type destElemType = destVecType.getElementType();
12771277

12781278
VectorType packedVecType;
12791279
if (isa<Float8E5M2Type, Float8E4M3FNType>(sourceElemType)) {
@@ -1287,8 +1287,7 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
12871287
}
12881288

12891289
// Extend to a packedVectorType
1290-
if (!sourceVecType ||
1291-
sourceVecType.getNumElements() < packedVecType.getNumElements()) {
1290+
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
12921291
Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
12931292
if (!sourceVecType) {
12941293
longVec = rewriter.create<LLVM::InsertElementOp>(
@@ -1352,7 +1351,8 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
13521351

13531352
Type resultType = op.getResult().getType();
13541353
Type resultElemType = getElementTypeOrSelf(resultType);
1355-
Type sourceElemType = getElementTypeOrSelf(op.getSource());
1354+
VectorType sourceVecType = cast<VectorType>(op.getSource().getType());
1355+
Type sourceElemType = sourceVecType.getElementType();
13561356

13571357
Type intResultType = isa<Float4E2M1FNType>(resultElemType) ? i32 : v2i16;
13581358

@@ -1364,6 +1364,14 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
13641364
else
13651365
existing = rewriter.create<LLVM::ZeroOp>(loc, intResultType);
13661366

1367+
if (sourceVecType.getNumElements() < 2) {
1368+
Value c0 = createI32Constant(rewriter, loc, 0);
1369+
Value elem0 = rewriter.create<LLVM::ExtractElementOp>(loc, source, c0);
1370+
VectorType v2 = VectorType::get(2, sourceElemType);
1371+
source = rewriter.create<LLVM::ZeroOp>(loc, v2);
1372+
source = rewriter.create<LLVM::InsertElementOp>(loc, source, elem0, c0);
1373+
}
1374+
13671375
Value sourceA, sourceB;
13681376
if (sourceElemType.isF32()) {
13691377
Value c0 = createI32Constant(rewriter, loc, 0);

mlir/test/Conversion/AMDGPUToROCDL/packed-ext.mlir

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,120 @@ func.func @scaled_ext_scalar_f4e2m1_bf16(%v: vector<2xf4E2M1FN>, %scale: f32) ->
373373
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<2xf4E2M1FN> to vector<2xbf16>
374374
func.return %ret : vector<2xbf16>
375375
}
376+
377+
// CHECK-LABEL: func.func @scaled_ext_one_f8e4m3_f32
378+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E4M3FN> to vector<1xi8>
379+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
380+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
381+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
382+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
383+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
384+
// CHECK: rocdl.cvt.scalef32.pk.f32.fp8 [[BITCAST]][false], %arg1 : vector<2xf32>
385+
func.func @scaled_ext_one_f8e4m3_f32(%v: vector<1xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
386+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E4M3FN> to vector<2xf32>
387+
func.return %ret : vector<2xf32>
388+
}
389+
390+
// CHECK-LABEL: func.func @scaled_ext_one_f8e4m3_f16
391+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E4M3FN> to vector<1xi8>
392+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
393+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
394+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
395+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
396+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
397+
// CHECK: rocdl.cvt.scalef32.pk.f16.fp8 [[BITCAST]][false], %arg1 : vector<2xf16>
398+
func.func @scaled_ext_one_f8e4m3_f16(%v: vector<1xf8E4M3FN>, %scale: f32) -> vector<2xf16> {
399+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E4M3FN> to vector<2xf16>
400+
func.return %ret : vector<2xf16>
401+
}
402+
403+
// CHECK-LABEL: func.func @scaled_ext_one_f8e4m3_bf16
404+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E4M3FN> to vector<1xi8>
405+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
406+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
407+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
408+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
409+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
410+
// CHECK: rocdl.cvt.scalef32.pk.bf16.fp8 [[BITCAST]][false], %arg1 : vector<2xbf16>
411+
func.func @scaled_ext_one_f8e4m3_bf16(%v: vector<1xf8E4M3FN>, %scale: f32) -> vector<2xbf16> {
412+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E4M3FN> to vector<2xbf16>
413+
func.return %ret : vector<2xbf16>
414+
}
415+
416+
// CHECK-LABEL: func.func @scaled_ext_one_f8e5m2_f32
417+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E5M2> to vector<1xi8>
418+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
419+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
420+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
421+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
422+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
423+
// CHECK: rocdl.cvt.scalef32.pk.f32.bf8 [[BITCAST]][false], %arg1 : vector<2xf32>
424+
func.func @scaled_ext_one_f8e5m2_f32(%v: vector<1xf8E5M2>, %scale: f32) -> vector<2xf32> {
425+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E5M2> to vector<2xf32>
426+
func.return %ret : vector<2xf32>
427+
}
428+
429+
// CHECK-LABEL: func.func @scaled_ext_one_f8e5m2_f16
430+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E5M2> to vector<1xi8>
431+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
432+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
433+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
434+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
435+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
436+
// CHECK: rocdl.cvt.scalef32.pk.f16.bf8 [[BITCAST]][false], %arg1 : vector<2xf16>
437+
func.func @scaled_ext_one_f8e5m2_f16(%v: vector<1xf8E5M2>, %scale: f32) -> vector<2xf16> {
438+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E5M2> to vector<2xf16>
439+
func.return %ret : vector<2xf16>
440+
}
441+
442+
// CHECK-LABEL: func.func @scaled_ext_one_f8e5m2_bf16
443+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf8E5M2> to vector<1xi8>
444+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<4xi8>
445+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
446+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi8>
447+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<4xi8>
448+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<4xi8> to i32
449+
// CHECK: rocdl.cvt.scalef32.pk.bf16.bf8 [[BITCAST]][false], %arg1 : vector<2xbf16>
450+
func.func @scaled_ext_one_f8e5m2_bf16(%v: vector<1xf8E5M2>, %scale: f32) -> vector<2xbf16> {
451+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf8E5M2> to vector<2xbf16>
452+
func.return %ret : vector<2xbf16>
453+
}
454+
455+
// CHECK-LABEL: func.func @scaled_ext_one_f4e2m1_f32
456+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf4E2M1FN> to vector<1xi4>
457+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<8xi4>
458+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
459+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi4>
460+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<8xi4>
461+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<8xi4> to i32
462+
// CHECK: rocdl.cvt.scalef32.pk.f32.fp4 [[BITCAST]][0], %arg1 : vector<2xf32>
463+
func.func @scaled_ext_one_f4e2m1_f32(%v: vector<1xf4E2M1FN>, %scale: f32) -> vector<2xf32> {
464+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf4E2M1FN> to vector<2xf32>
465+
func.return %ret : vector<2xf32>
466+
}
467+
468+
// CHECK-LABEL: func.func @scaled_ext_one_f4e2m1_f16
469+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf4E2M1FN> to vector<1xi4>
470+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<8xi4>
471+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
472+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi4>
473+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<8xi4>
474+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<8xi4> to i32
475+
// CHECK: rocdl.cvt.scalef32.pk.f16.fp4 [[BITCAST]][0], %arg1 : vector<2xf16>
476+
func.func @scaled_ext_one_f4e2m1_f16(%v: vector<1xf4E2M1FN>, %scale: f32) -> vector<2xf16> {
477+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf4E2M1FN> to vector<2xf16>
478+
func.return %ret : vector<2xf16>
479+
}
480+
481+
// CHECK-LABEL: func.func @scaled_ext_one_f4e2m1_bf16
482+
// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast %arg0 : vector<1xf4E2M1FN> to vector<1xi4>
483+
// CHECK-DAG: [[ZERO:%.+]] = llvm.mlir.zero : vector<8xi4>
484+
// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
485+
// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<1xi4>
486+
// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[ZERO]]{{\[}}[[C0]] : i32] : vector<8xi4>
487+
// CHECK: [[BITCAST:%.+]] = llvm.bitcast [[VEC_0]] : vector<8xi4> to i32
488+
// CHECK: rocdl.cvt.scalef32.pk.bf16.fp4 [[BITCAST]][0], %arg1 : vector<2xbf16>
489+
func.func @scaled_ext_one_f4e2m1_bf16(%v: vector<1xf4E2M1FN>, %scale: f32) -> vector<2xbf16> {
490+
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<1xf4E2M1FN> to vector<2xbf16>
491+
func.return %ret : vector<2xbf16>
492+
}

0 commit comments

Comments
 (0)