Skip to content

Commit

Permalink
[WGSL] Add support for constant matrices
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=263555
rdar://117371722

Reviewed by Mike Wyrzykowski.

Implement ConstantMatrix as one of the variants in ConstantValue and implement
all the constant functions that depend on it.

* Source/WebGPU/WGSL/ConstantFunctions.h:
(WGSL::zeroValue):
(WGSL::constantMatrix):
(WGSL::constantAdd):
(WGSL::constantMultiply):
(WGSL::constantDeterminant):
(WGSL::constantTranspose):
* Source/WebGPU/WGSL/ConstantValue.cpp:
(WGSL::ConstantValue::dump const):
* Source/WebGPU/WGSL/ConstantValue.h:
(WGSL::ConstantMatrix::ConstantMatrix):
(WGSL::ConstantValue::isMatrix const):
(WGSL::ConstantValue::toVector const):
* Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp:
(WGSL::Metal::FunctionDefinitionWriter::serializeConstant):
* Source/WebGPU/WGSL/tests/valid/constant-matrix.wgsl: Added.

Canonical link: https://commits.webkit.org/269741@main
  • Loading branch information
tadeuzagallo committed Oct 24, 2023
1 parent 29dde14 commit d041f3c
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 20 deletions.
175 changes: 160 additions & 15 deletions Source/WebGPU/WGSL/ConstantFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <bit>
#include <numbers>
#include <wtf/Assertions.h>
#include <wtf/DataLog.h>

namespace WGSL {

Expand Down Expand Up @@ -80,8 +81,12 @@ static ConstantValue zeroValue(const Type* type)
// yet have ConstantStruct
RELEASE_ASSERT_NOT_REACHED();
},
[&](const Types::Matrix&) -> ConstantValue {
RELEASE_ASSERT_NOT_REACHED();
[&](const Types::Matrix& matrix) -> ConstantValue {
ConstantMatrix result(matrix.columns, matrix.rows);
auto value = zeroValue(matrix.element);
for (unsigned i = 0; i < result.elements.size(); ++i)
result.elements[i] = value;
return result;
},
[&](const Types::Reference&) -> ConstantValue {
RELEASE_ASSERT_NOT_REACHED();
Expand Down Expand Up @@ -259,10 +264,29 @@ static ConstantValue constantMatrix(const Type* resultType, const FixedVector<Co
if (arguments.isEmpty())
return zeroValue(resultType);

// FIXME: we don't support matrices yet
UNUSED_PARAM(columns);
UNUSED_PARAM(rows);
RELEASE_ASSERT_NOT_REACHED();

if (arguments.size() == 1) {
auto& arg = arguments[0];
ASSERT(arg.isMatrix());
// FIXME: we might need to convert the type of the result when we support f16
return arg;
}

if (arguments.size() == columns * rows)
return ConstantMatrix { columns, rows, arguments };

RELEASE_ASSERT(arguments.size() == columns);
ConstantMatrix result(columns, rows);
unsigned i = 0;
for (auto& arg : arguments) {
ASSERT(arg.isVector());
auto& vector = arg.toVector();
ASSERT(vector.elements.size() == rows);
for (auto& element : vector.elements)
result.elements[i++] = element;
}
ASSERT(i == columns * rows);
return result;
}

#define UNARY_OPERATION(name, constraint, fn) \
Expand Down Expand Up @@ -308,7 +332,15 @@ static ConstantValue constantAdd(const Type*, const FixedVector<ConstantValue>&
{
ASSERT(arguments.size() == 2);

// FIXME: handle constant matrices
if (auto* left = std::get_if<ConstantMatrix>(&arguments[0])) {
auto& right = std::get<ConstantMatrix>(arguments[1]);
ASSERT(left->columns == right.columns);
ASSERT(left->rows == right.rows);
ConstantMatrix result(left->columns, left->rows);
for (unsigned i = 0; i < result.elements.size(); ++i)
result.elements[i] = left->elements[i].toDouble() + right.elements[i].toDouble();
return result;
}

return constantBinaryOperation<Constraints::Number>(arguments, [&](auto left, auto right) {
return left + right;
Expand All @@ -328,11 +360,68 @@ static ConstantValue constantMinus(const Type*, const FixedVector<ConstantValue>
}


static ConstantValue constantMultiply(const Type*, const FixedVector<ConstantValue>& arguments)
static ConstantValue constantDot(const Type*, const FixedVector<ConstantValue>&);
static ConstantValue constantMultiply(const Type* resultType, const FixedVector<ConstantValue>& arguments)
{
ASSERT(arguments.size() == 2);

// FIXME: handle constant matrices
auto* leftMatrix = std::get_if<ConstantMatrix>(&arguments[0]);
auto* rightMatrix = std::get_if<ConstantMatrix>(&arguments[1]);
if (leftMatrix && rightMatrix) {
ASSERT(leftMatrix->columns == rightMatrix->rows);
ConstantMatrix result(rightMatrix->columns, leftMatrix->rows);
for (unsigned i = 0; i < rightMatrix->columns; ++i) {
for (unsigned j = 0; j < leftMatrix->rows; ++j) {
double value = 0;
for (unsigned k = 0; k < leftMatrix->columns; ++k)
value += leftMatrix->elements[k * leftMatrix->rows + j].toDouble() * rightMatrix->elements[i * rightMatrix->rows + k].toDouble();
result.elements[i * result.rows + j] = value;
}
}
return result;
}
if (leftMatrix || rightMatrix) {
if (auto* rightVector = std::get_if<ConstantVector>(&arguments[1])) {
auto columns = leftMatrix->columns;
auto rows = leftMatrix->rows;
ConstantVector result(rows);
ConstantVector leftVector(columns);
for (unsigned i = 0; i < rows; ++i) {
for (unsigned j = 0; j < columns; ++j)
leftVector.elements[j] = leftMatrix->elements[j * rows + i];
result.elements[i] = constantDot(resultType, { leftVector, *rightVector });
}
return result;
}

if (auto* leftVector = std::get_if<ConstantVector>(&arguments[0])) {
auto columns = rightMatrix->columns;
auto rows = rightMatrix->rows;
ConstantVector result(columns);
ConstantVector rightVector(rows);
for (unsigned i = 0; i < columns; ++i) {
for (unsigned j = 0; j < rows; ++j)
rightVector.elements[j] = rightMatrix->elements[i * rows + j];
result.elements[i] = constantDot(resultType, { *leftVector, rightVector });
}
return result;
}

const ConstantMatrix* matrix;
double scalar;
if (leftMatrix) {
matrix = leftMatrix;
scalar = arguments[1].toDouble();
} else {
matrix = rightMatrix;
scalar = arguments[0].toDouble();
}

ConstantMatrix result(matrix->columns, matrix->rows);
for (unsigned i = 0; i < result.elements.size(); ++i)
result.elements[i] = matrix->elements[i].toDouble() * scalar;
return result;
}

return constantBinaryOperation<Constraints::Number>(arguments, [&](auto left, auto right) {
return left * right;
Expand Down Expand Up @@ -532,10 +621,58 @@ static ConstantValue constantCross(const Type*, const FixedVector<ConstantValue>

UNARY_OPERATION(Degrees, Float, [&](float arg) { return arg * (180 / std::numbers::pi); })

static ConstantValue constantDeterminant(const Type*, const FixedVector<ConstantValue>&)
static ConstantValue constantDeterminant(const Type*, const FixedVector<ConstantValue>& arguments)
{
// FIXME: we don't support matrices yet
RELEASE_ASSERT_NOT_REACHED();
ASSERT(arguments.size() == 1);
auto& matrix = std::get<ConstantMatrix>(arguments[0]);
auto columns = matrix.columns;
auto solve2 = [&](
auto a, auto b,
auto c, auto d
) {
return a * d - b * c;
};

auto solve3 = [&](
auto a, auto b, auto c,
auto d, auto e, auto f,
auto g, auto h, auto i
) {
return a * e * i + b * f * g + c * d * h - c * e * g - b * d * i - a * f * h;
};

auto solve4 = [&](
auto a, auto b, auto c, auto d,
auto e, auto f, auto g, auto h,
auto i, auto j, auto k, auto l,
auto m, auto n, auto o, auto p
) {
return a * solve3(f, g, h, j, k, l, n, o, p) - b * solve3(e, g, h, i, k, l, m, o, p) + c * solve3(e, f, h, i, j, l, m, n, p) - d * solve3(e, f, g, i, j, k, m, n, o);
};

switch (columns) {
case 2:
return solve2(
matrix.elements[0].toDouble(), matrix.elements[2].toDouble(),
matrix.elements[1].toDouble(), matrix.elements[3].toDouble()
);

case 3:
return solve3(
matrix.elements[0].toDouble(), matrix.elements[3].toDouble(), matrix.elements[6].toDouble(),
matrix.elements[1].toDouble(), matrix.elements[4].toDouble(), matrix.elements[7].toDouble(),
matrix.elements[2].toDouble(), matrix.elements[5].toDouble(), matrix.elements[8].toDouble()
);
case 4:
return solve4(
matrix.elements[0].toDouble(), matrix.elements[4].toDouble(), matrix.elements[8].toDouble(), matrix.elements[12].toDouble(),
matrix.elements[1].toDouble(), matrix.elements[5].toDouble(), matrix.elements[9].toDouble(), matrix.elements[13].toDouble(),
matrix.elements[2].toDouble(), matrix.elements[6].toDouble(), matrix.elements[10].toDouble(), matrix.elements[14].toDouble(),
matrix.elements[3].toDouble(), matrix.elements[7].toDouble(), matrix.elements[11].toDouble(), matrix.elements[15].toDouble()
);
default:
RELEASE_ASSERT_NOT_REACHED();
}
}

static ConstantValue constantLength(const Type*, const FixedVector<ConstantValue>&);
Expand Down Expand Up @@ -799,10 +936,18 @@ TERNARY_OPERATION(Smoothstep, Float, [&](auto low, auto high, auto x) {
UNARY_OPERATION(Sqrt, Float, WRAP_STD(sqrt))
BINARY_OPERATION(Step, Float, [&](auto edge, auto x) { return edge <= x ? 1.0 : 0.0; })

static ConstantValue constantTranspose(const Type*, const FixedVector<ConstantValue>&)
static ConstantValue constantTranspose(const Type*, const FixedVector<ConstantValue>& arguments)
{
// FIXME: we don't support matrices yet
RELEASE_ASSERT_NOT_REACHED();
ASSERT(arguments.size() == 1);
auto& matrix = std::get<ConstantMatrix>(arguments[0]);
auto columns = matrix.columns;
auto rows = matrix.rows;
ConstantMatrix result(rows, columns);
for (unsigned j = 0; j < rows; ++j) {
for (unsigned i = 0; i < columns; ++i)
result.elements[j * columns + i] = matrix.elements[i * rows + j];
}
return result;
}

UNARY_OPERATION(Trunc, Float, WRAP_STD(trunc))
Expand Down
11 changes: 11 additions & 0 deletions Source/WebGPU/WGSL/ConstantValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ void ConstantValue::dump(PrintStream& out) const
out.print(element);
}
out.print(")");
},
[&](const ConstantMatrix& m) {
out.print("mat", m.columns, "x", m.rows, "(");
bool first = true;
for (const auto& element : m.elements) {
if (!first)
out.print(", ");
first = false;
out.print(element);
}
out.print(")");
});
}

Expand Down
25 changes: 23 additions & 2 deletions Source/WebGPU/WGSL/ConstantValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,28 @@ struct ConstantVector {
FixedVector<ConstantValue> elements;
};

using BaseValue = std::variant<double, int64_t, bool, ConstantArray, ConstantVector>;
struct ConstantMatrix {
ConstantMatrix(uint32_t columns, uint32_t rows)
: columns(columns)
, rows(rows)
, elements(columns * rows)
{
}

ConstantMatrix(uint32_t columns, uint32_t rows, const FixedVector<ConstantValue>& elements)
: columns(columns)
, rows(rows)
, elements(elements)
{
RELEASE_ASSERT(elements.size() == columns * rows);
}

uint32_t columns;
uint32_t rows;
FixedVector<ConstantValue> elements;
};

using BaseValue = std::variant<double, int64_t, bool, ConstantArray, ConstantVector, ConstantMatrix>;
struct ConstantValue : BaseValue {
ConstantValue() = default;

Expand All @@ -78,6 +99,7 @@ struct ConstantValue : BaseValue {
bool isInt() const { return std::holds_alternative<int64_t>(*this); }
bool isNumber() const { return isInt() || std::holds_alternative<double>(*this); }
bool isVector() const { return std::holds_alternative<ConstantVector>(*this); }
bool isMatrix() const { return std::holds_alternative<ConstantMatrix>(*this); }

bool toBool() const { return std::get<bool>(*this); }
int64_t toInt() const
Expand All @@ -96,7 +118,6 @@ struct ConstantValue : BaseValue {
}
const ConstantVector& toVector() const
{
ASSERT(isNumber());
return std::get<ConstantVector>(*this);
}
};
Expand Down
16 changes: 13 additions & 3 deletions Source/WebGPU/WGSL/Metal/MetalFunctionWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1788,9 +1788,19 @@ void FunctionDefinitionWriter::serializeConstant(const Type* type, ConstantValue
}
m_stringBuilder.append("}");
},
[&](const Matrix&) {
// Not supported yet
RELEASE_ASSERT_NOT_REACHED();
[&](const Matrix& matrixType) {
auto& matrix = std::get<ConstantMatrix>(value);
m_stringBuilder.append("matrix<");
visit(matrixType.element);
m_stringBuilder.append(", ", matrixType.columns, ", ", matrixType.rows, ">(");
bool first = true;
for (auto& element : matrix.elements) {
if (!first)
m_stringBuilder.append(", ");
first = false;
serializeConstant(matrixType.element, element);
}
m_stringBuilder.append(")");
},
[&](const Struct&) {
// Not supported yet
Expand Down
61 changes: 61 additions & 0 deletions Source/WebGPU/WGSL/tests/valid/constant-matrix.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// RUN: %metal-compile main
// RUN: %metal main 2>&1 | %check


@compute @workgroup_size(1)
fn main()
{
// CHECK-L: 8680
let x = determinant(mat4x4(
15,2,34,4,
18,2,3,4,
1,72,32,4,
17,2,3,4,
));

// CHECK-L: 8680
let y = determinant(mat4x4(
vec4(15,2,34,4),
vec4(18,2,3,4),
vec4(1,72,32,4),
vec4(17,2,3,4),
));

const m2 = mat2x2(
1,2,
3,4,
);

// CHECK-L: 1., 3., 2., 4.
let tm2 = transpose(m2);

// CHECK-L: 1., 4., 7., 2., 5., 8., 3., 6., 9.
let tm3 = transpose(mat3x3(
1,2,3,
4,5,6,
7,8,9,
));

// CHECK-L: 2., 4., 6., 8.
let x0 = m2 * 2;

const m2x3 = mat2x3(
1,2,3,
4,5,6
);

// CHECK-L: 1., 4., 2., 5., 3., 6.
let x1 = transpose(m2x3);

// CHECK-L: 10., 14., 18.
let x2 = m2x3 * vec2(2);

// CHECK-L: 12., 30.
let x3 = vec3(2) * m2x3;

// CHECK-L: 32
let x4 = dot(vec3(1,2,3), vec3(4,5,6));

// CHECK-L: 95., 128., 68., 92., 41., 56., 14., 20.
let x5 = mat3x2(1,2,3,4,5,6) * mat4x3(12,11,10,9,8,7,6,5,4,3,2,1);
}

0 comments on commit d041f3c

Please sign in to comment.