
Commit 7522097

[CoopVec] Add Linear Algebra common header with tests (#7350) (#7388)
This PR introduces the linear algebra header file and places it in a location that is included by default in all HLSL compilations. The builtins the API relies on are not yet defined; they depend on PR #7290 merging first. The tests added here use temporary diagnostic messages while #7290 is in progress and will need to be updated. Feedback is welcome on better error messages, or on whether there should be any sema-level validation for these errors at all.

Fixes #7304. Cherry-pick of #7350.

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent e866b4b commit 7522097

12 files changed: +516 −0 lines
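
Before the file diffs, a minimal usage sketch of the new API (not part of the commit; it mirrors the Test1 case in the FileCheck tests below, and the buffer and function names are placeholders):

#include <dx/linalg.h>

ByteAddressBuffer Buf; // placeholder resource

export float4 Example(vector<float, 4> Input) {
  using namespace dx::linalg;

  // Describe a 4x4 FLOAT16 matrix stored in Buf at offset 0, in mul-optimal layout.
  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {Buf, 0, 0};

  // Reinterpret the float input as FLOAT16 data and compute Matrix * Input.
  return Mul<float>(Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
}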
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
// Header for linear algebra APIs.

#if __spirv__
#error "Cooperative vectors not (yet) supported for SPIRV"
#endif

#if ((__SHADER_TARGET_MAJOR > 6) ||                                            \
     (__SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR >= 9)) &&            \
    (__HLSL_VERSION >= 2021)

namespace dx {
namespace linalg {

// NOTE: can't be an enum class because we get this error:
// error: non-type template argument of type 'dx::linalg::DataType' is not
// an integral constant expression
//
enum DataType {
  DATA_TYPE_SINT16 = 2,           // ComponentType::I16
  DATA_TYPE_UINT16 = 3,           // ComponentType::U16
  DATA_TYPE_SINT32 = 4,           // ComponentType::I32
  DATA_TYPE_UINT32 = 5,           // ComponentType::U32
  DATA_TYPE_FLOAT16 = 8,          // ComponentType::F16
  DATA_TYPE_FLOAT32 = 9,          // ComponentType::F32
  DATA_TYPE_SINT8_T4_PACKED = 17, // ComponentType::PackedS8x32
  DATA_TYPE_UINT8_T4_PACKED = 18, // ComponentType::PackedU8x32
  DATA_TYPE_UINT8 = 19,           // ComponentType::U8
  DATA_TYPE_SINT8 = 20,           // ComponentType::I8
  DATA_TYPE_FLOAT8_E4M3 = 21,     // ComponentType::F8_E4M3
                                  // (1 sign, 4 exp, 3 mantissa bits)
  DATA_TYPE_FLOAT8_E5M2 = 22,     // ComponentType::F8_E5M2
                                  // (1 sign, 5 exp, 2 mantissa bits)
};

enum MatrixLayout {
  MATRIX_LAYOUT_ROW_MAJOR = 0,
  MATRIX_LAYOUT_COLUMN_MAJOR = 1,
  MATRIX_LAYOUT_MUL_OPTIMAL = 2,
  MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL = 3
};

//
// Helper for signedness
//
namespace details {
template <typename T> bool IsUnsigned() { return false; }

#ifdef __HLSL_ENABLE_16_BIT
template <> bool IsUnsigned<uint16_t>() { return true; }
#endif

template <> bool IsUnsigned<uint32_t>() { return true; }
template <> bool IsUnsigned<uint64_t>() { return true; }
} // namespace details

//
// (RW)MatrixRef
//

template <typename BufferTy, DataType DT, uint M, uint K, MatrixLayout ML,
          bool Transpose>
struct MatrixRefImpl {
  BufferTy Buffer;
  uint StartOffset;
  uint Stride;
};

template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using MatrixRef = MatrixRefImpl<ByteAddressBuffer, DT, M, K, ML, Transpose>;

template <DataType DT, uint M, uint K, MatrixLayout ML, bool Transpose = false>
using RWMatrixRef = MatrixRefImpl<RWByteAddressBuffer, DT, M, K, ML, Transpose>;

//
// (RW)VectorRef
//

template <typename BufferTy, DataType DT> struct VectorRefImpl {
  BufferTy Buffer;
  uint StartOffset;
};

template <DataType DT> using VectorRef = VectorRefImpl<ByteAddressBuffer, DT>;

template <DataType DT>
using RWVectorRef = VectorRefImpl<RWByteAddressBuffer, DT>;

//
// Vector
//

template <typename T, int N, DataType DT> struct InterpretedVector {
  vector<T, N> Data;
};

template <DataType DT, typename T, int N>
InterpretedVector<T, N, DT> MakeInterpretedVector(vector<T, N> Vec) {
  InterpretedVector<T, N, DT> IV = {Vec};
  return IV;
}

//
// Mul
//

template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
          bool MatrixTranspose>
vector<OutputElTy, MatrixM>
Mul(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
                  MatrixTranspose>
        Matrix,
    InterpretedVector<InputElTy, InputElCount, InputDT> InputVector) {

  vector<OutputElTy, MatrixM> OutputVector;

  __builtin_MatVecMul(
      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout,
      MatrixTranspose, Matrix.Stride);

  return OutputVector;
}

//
// MulAdd
//

template <typename OutputElTy, typename InputElTy, int InputElCount,
          typename MatrixBufferTy, DataType InputDT, DataType MatrixDT,
          uint MatrixM, uint MatrixK, MatrixLayout MatrixLayout,
          bool MatrixTranspose, typename BiasVectorBufferTy,
          DataType BiasVectorDT>
vector<OutputElTy, MatrixM>
MulAdd(MatrixRefImpl<MatrixBufferTy, MatrixDT, MatrixM, MatrixK, MatrixLayout,
                     MatrixTranspose>
           Matrix,
       InterpretedVector<InputElTy, InputElCount, InputDT> InputVector,
       VectorRefImpl<BiasVectorBufferTy, BiasVectorDT> BiasVector) {

  vector<OutputElTy, MatrixM> OutputVector;

  __builtin_MatVecMulAdd(
      /*out*/ OutputVector, details::IsUnsigned<OutputElTy>(), InputVector.Data,
      details::IsUnsigned<InputElTy>(), InputDT, Matrix.Buffer,
      Matrix.StartOffset, MatrixDT, MatrixM, MatrixK, MatrixLayout,
      MatrixTranspose, Matrix.Stride, BiasVector.Buffer, BiasVector.StartOffset,
      BiasVectorDT);

  return OutputVector;
}

//
// OuterProductAccumulate
//

template <typename ElTy, int MatrixM, int MatrixN, DataType MatrixDT,
          MatrixLayout MatrixLayout>
void OuterProductAccumulate(
    vector<ElTy, MatrixM> InputVector1, vector<ElTy, MatrixN> InputVector2,
    RWMatrixRef<MatrixDT, MatrixM, MatrixN, MatrixLayout, false> Matrix) {
  __builtin_OuterProductAccumulate(InputVector1, InputVector2, Matrix.Buffer,
                                   Matrix.StartOffset, MatrixDT, MatrixLayout,
                                   Matrix.Stride);
}

//
// VectorAccumulate
//

template <typename ElTy, int ElCount>
void VectorAccumulate(vector<ElTy, ElCount> InputVector,
                      RWByteAddressBuffer Buffer, uint Offset) {
  __builtin_VectorAccumulate(InputVector, Buffer, Offset);
}

} // namespace linalg
} // namespace dx

#endif // SM 6.9 check and HV version check
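
A note on how the template arguments line up, distilled from the header above (a sketch; Buf, Input, and the 4x4 dimensions are placeholders): only the interpreted data type is passed explicitly to MakeInterpretedVector and only the output element type to Mul; the element type, component count, matrix dimensions, layout, transpose flag, and signedness flags are all deduced.

using namespace dx::linalg;

MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {Buf, 0, 0};

// DT is explicit; T = float and N = 4 are deduced from the argument type.
InterpretedVector<float, 4, DATA_TYPE_FLOAT16> IV =
    MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input);

// OutputElTy is explicit; MatrixM/MatrixK, the layout, and the IsUnsigned
// flags passed to __builtin_MatVecMul are deduced from Matrix and IV.
vector<float, 4> Out = Mul<float>(Matrix, IV);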
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s

#include <dx/linalg.h>

ByteAddressBuffer Buf;

export float4 Test1(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMul.v4f32.v4f32(i32 305, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, i1 false)
  return Mul<float>(
      Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));
}

export vector<float, 8> Test2(vector<uint8_t4_packed, 6> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT8, 8, 6 * 4, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
      Buf, 0, 0};

  // note the stride argument is dropped.
  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 2, i1 false, i32 0, i1 false)
  return Mul<float>(Matrix,
                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
}

// test that "stride" isn't ignored in non-optimal layouts
export vector<float, 8> Test3(vector<uint8_t4_packed, 6> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT8, 8, 6 * 4, MATRIX_LAYOUT_ROW_MAJOR> Matrix = {
      Buf, 0, 6 * 4 * 8};

  // CHECK: %{{.+}} = call <8 x float> @dx.op.matVecMul.v8f32.v6f32(i32 305, <6 x float> %{{.+}}, i1 false, i32 18, %dx.types.Handle %{{.+}}, i32 0, i32 19, i32 8, i32 24, i32 0, i1 false, i32 192, i1 false)
  return Mul<float>(Matrix,
                    MakeInterpretedVector<DATA_TYPE_UINT8_T4_PACKED>(Input));
}
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s

#include <dx/linalg.h>

ByteAddressBuffer Buf;

export float4 Test1(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL> matrix = {Buf,
                                                                          0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 false, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, theVector,
      biasVector);
}

export float4 Test2(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
      Buf, 0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  InterpretedVector<float, 4, DATA_TYPE_FLOAT16> theVector = {input};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, theVector,
      biasVector);
}

export float4 Test3(float4 input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> matrix = {
      Buf, 0, 0};
  VectorRef<DATA_TYPE_FLOAT16> biasVector = {Buf, 256};

  // CHECK: %{{.+}} = call <4 x float> @dx.op.matVecMulAdd.v4f32.v4f32(i32 306, <4 x float> %{{.+}}, i1 false, i32 8, %dx.types.Handle [[RES:%.+]], i32 0, i32 8, i32 4, i32 4, i32 2, i1 true, i32 0, %dx.types.Handle [[RES]], i32 256, i32 8, i1 false)
  return MulAdd<float>(
      matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(input),
      biasVector);
}

namespace ProposalExample {

ByteAddressBuffer model;

vector<float, 3> ApplyNeuralMaterial(vector<half, 8> inputVector) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 8, MATRIX_LAYOUT_MUL_OPTIMAL> matrix0 = {
      model, 0, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector0 = {model, 1024};

  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 32, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix1 =
      {model, 2048, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector1 = {model, 3072};

  MatrixRef<DATA_TYPE_FLOAT8_E4M3, 3, 32, MATRIX_LAYOUT_MUL_OPTIMAL> matrix2 = {
      model, 4096, 0};

  VectorRef<DATA_TYPE_FLOAT16> biasVector2 = {model, 5120};

  vector<half, 32> layer0 = MulAdd<half>(
      matrix0, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(inputVector),
      biasVector0);
  layer0 = max(layer0, 0);

  vector<half, 32> layer1 = MulAdd<half>(
      matrix1, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(layer0),
      biasVector1);
  layer1 = max(layer1, 0);

  vector<float, 3> output = MulAdd<float>(
      matrix2, MakeInterpretedVector<DATA_TYPE_FLOAT8_E4M3>(layer1),
      biasVector2);
  output = exp(output);

  return output;
}

} // namespace ProposalExample
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 -enable-16bit-types %s | FileCheck %s

#include <dx/linalg.h>

RWByteAddressBuffer RWBuf;

export void Test4(vector<half, 128> Input1, vector<half, 64> Input2) {
  using namespace dx::linalg;

  RWMatrixRef<DATA_TYPE_FLOAT16, 128, 64, MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL>
      matrix = {RWBuf, 0, 0};

  // CHECK: call void @dx.op.outerProductAccumulate.v128f16.v64f16(i32 307, <128 x half> %{{.+}}, <64 x half> %{{.+}}, %dx.types.Handle %{{.+}}, i32 0, i32 8, i32 3, i32 0)

  OuterProductAccumulate(Input1, Input2, matrix);
}
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s | FileCheck %s

#include <dx/linalg.h>

RWByteAddressBuffer RWBuf;

export void Test5(vector<half, 128> Input) {
  using namespace dx::linalg;

  RWBuf.Store<vector<half, 128> >(0, Input);

  // CHECK: call void @dx.op.vectorAccumulate.v128f32(i32 308, <128 x float> %{{.*}}, %dx.types.Handle %{{.*}}, i32 0)
  VectorAccumulate(Input, RWBuf, 0);
}
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify

#include <dx/linalg.h>
ByteAddressBuffer Buf;

export float4 Test1(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
  return Mul<float>(
      Matrix, MakeInterpretedVector<2>(Input));
}

enum DataType {
  DATA_TYPE_InvalidType = 40
};

export float4 Test2(vector<float, 4> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_UINT16, 4, 4, MATRIX_LAYOUT_MUL_OPTIMAL, true> Matrix = {
      Buf, 0, 0};

  // expected-error@+3{{no matching function for call to 'MakeInterpretedVector'}}
  // expected-note@dx/linalg.h:97{{candidate template ignored: invalid explicitly-specified argument for template parameter 'DT'}}
  return Mul<float>(
      Matrix, MakeInterpretedVector<DATA_TYPE_InvalidType>(Input));
}
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
// RUN: %dxc -I %hlsl_headers -T lib_6_9 %s -verify

#include <dx/linalg.h>

ByteAddressBuffer Buf;

vector<float, 128> MixUpVectorAndMatrixArguments(vector<float, 128> Input) {
  using namespace dx::linalg;

  MatrixRef<DATA_TYPE_FLOAT16, 128, 128, MATRIX_LAYOUT_MUL_OPTIMAL> Matrix = {
      Buf, 0, 0};

  // expected-error@+2{{no matching function for call to 'Mul'}}
  // expected-note@dx/linalg.h:111{{candidate template ignored: could not match 'MatrixRefImpl' against 'InterpretedVector'}}
  return Mul<float>(MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input), Matrix);
}
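
For contrast, the intended argument order (as exercised by the passing FileCheck tests above) puts the matrix reference first and the interpreted vector second:

  return Mul<float>(Matrix, MakeInterpretedVector<DATA_TYPE_FLOAT16>(Input));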
