Skip to content

Commit

Permalink
Make sure Kernel can have dynamic workgroupsize
Browse files Browse the repository at this point in the history
  • Loading branch information
sjfricke committed Jun 27, 2022
1 parent 44b6b19 commit 7a7d3a4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
26 changes: 13 additions & 13 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,21 +267,21 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, SpvDecoration dec,
"constants";
}
if (inst->GetOperandAs<SpvBuiltIn>(2) == SpvBuiltInWorkgroupSize) {
if (_.HasCapability(SpvCapabilityShader) &&
!spvOpcodeIsConstant(target->opcode())) {
if (spvOpcodeIsConstant(target->opcode())) {
// can only validate product if static
uint64_t x_size, y_size, z_size;
bool static_x = _.GetConstantValUint64(target->word(3), &x_size);
bool static_y = _.GetConstantValUint64(target->word(4), &y_size);
bool static_z = _.GetConstantValUint64(target->word(5), &z_size);
if (static_x && static_y && static_z &&
((x_size * y_size * z_size) == 0)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "WorkgroupSize decorations must not have a static "
"product of zero.";
}
} else if (_.HasCapability(SpvCapabilityShader)) {
return fail(0) << "must be a constant for WorkgroupSize";
}

uint64_t x_size, y_size, z_size;
bool static_x = _.GetConstantValUint64(target->word(3), &x_size);
bool static_y = _.GetConstantValUint64(target->word(4), &y_size);
bool static_z = _.GetConstantValUint64(target->word(5), &z_size);
if (static_x && static_y && static_z &&
((x_size * y_size * z_size) == 0)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "WorkgroupSize decorations must not have a product of "
"zero.";
}
} else if (target->opcode() != SpvOpVariable) {
return fail(0) << "must be a variable";
}
Expand Down
29 changes: 25 additions & 4 deletions test/val/val_modes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ OpDecorate %int3_1 BuiltIn WorkgroupSize
EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("WorkgroupSize decorations must not have a product of zero."));
HasSubstr(
"WorkgroupSize decorations must not have a static product of zero."));
}

TEST_F(ValidateMode, GLComputeZeroSpecWorkgroupSize) {
Expand All @@ -126,10 +127,11 @@ OpDecorate %int3_1 BuiltIn WorkgroupSize
EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("WorkgroupSize decorations must not have a product of zero."));
HasSubstr(
"WorkgroupSize decorations must not have a static product of zero."));
}

TEST_F(ValidateMode, KernelZeroWorkgroupSize) {
TEST_F(ValidateMode, KernelZeroWorkgroupSizeConstant) {
const std::string spirv = R"(
OpCapability Addresses
OpCapability Linkage
Expand All @@ -148,7 +150,26 @@ OpDecorate %int3_1 BuiltIn WorkgroupSize
EXPECT_THAT(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("WorkgroupSize decorations must not have a product of zero."));
HasSubstr(
"WorkgroupSize decorations must not have a static product of zero."));
}

TEST_F(ValidateMode, KernelZeroWorkgroupSizeVariable) {
const std::string spirv = R"(
OpCapability Addresses
OpCapability Linkage
OpCapability Kernel
OpMemoryModel Physical32 OpenCL
OpEntryPoint Kernel %main "main"
OpDecorate %var BuiltIn WorkgroupSize
%int = OpTypeInt 32 0
%int3 = OpTypeVector %int 3
%ptr = OpTypePointer Input %int3
%var = OpVariable %ptr Input
)" + kVoidFunction;

CompileSuccessfully(spirv);
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

TEST_F(ValidateMode, GLComputeVulkanLocalSize) {
Expand Down

0 comments on commit 7a7d3a4

Please sign in to comment.