Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

spirv-val: Validate zero product workgroup size #4828

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,19 +260,15 @@ spv_result_t ValidateDecorationTarget(ValidationState_t& _, SpvDecoration dec,
}
break;
case SpvDecorationBuiltIn:
if (target->opcode() != SpvOpVariable &&
!spvOpcodeIsConstant(target->opcode())) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "BuiltIns can only target variables, structure members or "
"constants";
}
if (_.HasCapability(SpvCapabilityShader) &&
inst->GetOperandAs<SpvBuiltIn>(2) == SpvBuiltInWorkgroupSize) {
if (target->opcode() != SpvOpVariable) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check could be greatly improved or just moved into the validate_builtins.cpp as well but for sake of not overloading this PR too much will leave for a post-PR fix

if (!spvOpcodeIsConstant(target->opcode())) {
return fail(0) << "must be a constant for WorkgroupSize";
return fail(0)
<< "BuiltIns can only target variables, structure members or "
"constants";
} else if (inst->GetOperandAs<SpvBuiltIn>(2) !=
SpvBuiltInWorkgroupSize) {
return fail(0) << "must be a variable";
}
} else if (target->opcode() != SpvOpVariable) {
return fail(0) << "must be a variable";
}
break;
case SpvDecorationNoPerspective:
Expand Down
43 changes: 29 additions & 14 deletions source/val/validate_builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3123,16 +3123,6 @@ spv_result_t BuiltInsValidator::ValidateI32Vec4InputAtDefinition(

spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition(
const Decoration& decoration, const Instruction& inst) {
if (spvIsVulkanEnv(_.context()->target_env)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dneto0

the type of the constant is not checked. So someone could feed a bad module and still get a buffer overrund, I think.

In Vulkan we have the VUID

The variable decorated with WorkgroupSize must be declared as a three-component vector of 32-bit integer values

I tried to write a Kernel shader but I can't think of how it could use a Builtin WorkgroupSize without an ivec3 so I changed this to a global SPIR-V check

if (spvIsVulkanEnv(_.context()->target_env) &&
!spvOpcodeIsConstant(inst.opcode())) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(4426)
<< "Vulkan spec requires BuiltIn WorkgroupSize to be a "
"constant. "
<< GetIdDesc(inst) << " is not a constant.";
}

if (spv_result_t error = ValidateI32Vec(
decoration, inst, 3,
[this, &inst](const std::string& message) -> spv_result_t {
Expand All @@ -3145,7 +3135,29 @@ spv_result_t BuiltInsValidator::ValidateWorkgroupSizeAtDefinition(
})) {
return error;
}
}

if (!spvOpcodeIsConstant(inst.opcode())) {
if (spvIsVulkanEnv(_.context()->target_env)) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< _.VkErrorID(4426)
<< "Vulkan spec requires BuiltIn WorkgroupSize to be a "
"constant. "
<< GetIdDesc(inst) << " is not a constant.";
}
} else {
// can only validate product if static
uint64_t x_size, y_size, z_size;
// ValidateI32Vec above confirms there will be 3 words to read
bool static_x = _.GetConstantValUint64(inst.word(3), &x_size);
bool static_y = _.GetConstantValUint64(inst.word(4), &y_size);
bool static_z = _.GetConstantValUint64(inst.word(5), &z_size);
if (static_x && static_y && static_z &&
((x_size * y_size * z_size) == 0)) {
return _.diag(SPV_ERROR_INVALID_DATA, &inst)
<< "WorkgroupSize decorations must not have a static "
"product of zero.";
}
}

// Seed at reference checks with this built-in.
return ValidateWorkgroupSizeAtReference(decoration, inst, inst, inst);
Expand Down Expand Up @@ -4081,6 +4093,11 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition(
const Decoration& decoration, const Instruction& inst) {
const SpvBuiltIn label = SpvBuiltIn(decoration.params()[0]);

// TODO - universal check needed before early return
if (label == SpvBuiltInWorkgroupSize) {
return ValidateWorkgroupSizeAtDefinition(decoration, inst);
}

if (!spvIsVulkanEnv(_.context()->target_env)) {
// Early return. All currently implemented rules are based on Vulkan spec.
//
Expand Down Expand Up @@ -4183,9 +4200,6 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition(
case SpvBuiltInVertexIndex: {
return ValidateVertexIndexAtDefinition(decoration, inst);
}
case SpvBuiltInWorkgroupSize: {
return ValidateWorkgroupSizeAtDefinition(decoration, inst);
}
case SpvBuiltInVertexId: {
return ValidateVertexIdAtDefinition(decoration, inst);
}
Expand Down Expand Up @@ -4247,6 +4261,7 @@ spv_result_t BuiltInsValidator::ValidateSingleBuiltInAtDefinition(
case SpvBuiltInCullMaskKHR: {
return ValidateRayTracingBuiltinsAtDefinition(decoration, inst);
}
case SpvBuiltInWorkgroupSize: // validated above
case SpvBuiltInWorkDim:
case SpvBuiltInGlobalSize:
case SpvBuiltInEnqueuedWorkgroupSize:
Expand Down
6 changes: 2 additions & 4 deletions source/val/validate_instruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,8 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
}
_.set_addressing_model(inst->GetOperandAs<SpvAddressingModel>(0));
_.set_memory_model(inst->GetOperandAs<SpvMemoryModel>(1));
} else if (opcode == SpvOpExecutionMode) {
const uint32_t entry_point = inst->word(1);
_.RegisterExecutionModeForEntryPoint(entry_point,
SpvExecutionMode(inst->word(2)));
} else if (opcode == SpvOpExecutionMode || opcode == SpvOpExecutionModeId) {
_.RegisterExecutionModeForEntryPoint(inst);
} else if (opcode == SpvOpVariable) {
const auto storage_class = inst->GetOperandAs<SpvStorageClass>(2);
if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {
Expand Down
28 changes: 28 additions & 0 deletions source/val/validate_mode_setting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,34 @@ spv_result_t ValidateEntryPoint(ValidationState_t& _, const Instruction* inst) {
}
}

const Instruction* local_size_inst =
_.GetLocalSizeInstruction(entry_point_id);
if (local_size_inst) {
const uint32_t x = local_size_inst->word(3);
const uint32_t y = local_size_inst->word(4);
const uint32_t z = local_size_inst->word(5);
if ((x * y * z) == 0) {
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
<< "Local Size execution mode must not have a product of zero.";
}
}

const Instruction* local_size_id_inst =
_.GetLocalSizeIdInstruction(entry_point_id);
if (local_size_id_inst) {
uint64_t x_size, y_size, z_size;
bool static_x =
_.GetConstantValUint64(local_size_id_inst->word(3), &x_size);
bool static_y =
_.GetConstantValUint64(local_size_id_inst->word(4), &y_size);
bool static_z =
_.GetConstantValUint64(local_size_id_inst->word(5), &z_size);
if (static_x && static_y && static_z && ((x_size * y_size * z_size) == 0)) {
return _.diag(SPV_ERROR_INVALID_DATA, local_size_inst)
<< "Local Size Id execution mode must not have a product of zero.";
}
}

return SPV_SUCCESS;
}

Expand Down
79 changes: 54 additions & 25 deletions source/val/validation_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,29 @@ class ValidationState_t {
/// instruction
bool in_block() const;

/// Description is unique and not shared by all entry points with same id
struct EntryPointDescription {
std::string name;
std::vector<uint32_t> interfaces;
};

/// All the meta data about a single entry point id
struct EntryPointData {
std::vector<EntryPointDescription> descriptions;
/// It is presumed that the same function could theoretically be used as
/// 'main' by multiple OpEntryPoint instructions.
std::set<SpvExecutionModel> execution_models;
std::set<SpvExecutionMode> execution_modes;
const Instruction* local_size = nullptr;
const Instruction* local_size_id = nullptr;
};

/// Registers |id| as an entry point with |execution_model| and |interfaces|.
void RegisterEntryPoint(const uint32_t id, SpvExecutionModel execution_model,
EntryPointDescription&& desc) {
entry_points_.push_back(id);
entry_point_to_execution_models_[id].insert(execution_model);
entry_point_descriptions_[id].emplace_back(desc);
entry_point_data_[id].execution_models.insert(execution_model);
entry_point_data_[id].descriptions.emplace_back(desc);
}

/// Returns a list of entry point function ids
Expand All @@ -233,38 +245,66 @@ class ValidationState_t {
}

/// Registers execution mode for the given entry point.
void RegisterExecutionModeForEntryPoint(uint32_t entry_point,
SpvExecutionMode execution_mode) {
entry_point_to_execution_modes_[entry_point].insert(execution_mode);
void RegisterExecutionModeForEntryPoint(const Instruction* inst) {
const uint32_t entry_point = inst->word(1);
const SpvExecutionMode mode = SpvExecutionMode(inst->word(2));
entry_point_data_[entry_point].execution_modes.insert(mode);

// Save for now since the IDs might have not been parsed yet
if (mode == SpvExecutionModeLocalSize) {
entry_point_data_[entry_point].local_size = inst;
} else if (mode == SpvExecutionModeLocalSizeId) {
entry_point_data_[entry_point].local_size_id = inst;
}
}

/// Returns the interface descriptions of a given entry point.
const std::vector<EntryPointDescription>& entry_point_descriptions(
uint32_t entry_point) {
return entry_point_descriptions_.at(entry_point);
return entry_point_data_[entry_point].descriptions;
}

/// Returns Execution Models for the given Entry Point.
/// Returns nullptr if none found (would trigger assertion).
const std::set<SpvExecutionModel>* GetExecutionModels(
uint32_t entry_point) const {
const auto it = entry_point_to_execution_models_.find(entry_point);
if (it == entry_point_to_execution_models_.end()) {
const auto it = entry_point_data_.find(entry_point);
if (it == entry_point_data_.end()) {
assert(0);
return nullptr;
}
return &it->second;
return &it->second.execution_models;
}

/// Returns Execution Modes for the given Entry Point.
/// Returns nullptr if none found.
const std::set<SpvExecutionMode>* GetExecutionModes(
uint32_t entry_point) const {
const auto it = entry_point_to_execution_modes_.find(entry_point);
if (it == entry_point_to_execution_modes_.end()) {
const auto it = entry_point_data_.find(entry_point);
if (it == entry_point_data_.end()) {
return nullptr;
}
return &it->second.execution_modes;
}

/// Returns the Local Size Execution Modes for the given Entry Point.
/// Returns nullptr if none found.
const Instruction* GetLocalSizeInstruction(uint32_t entry_point) const {
const auto it = entry_point_data_.find(entry_point);
if (it == entry_point_data_.end()) {
return nullptr;
}
return &it->second;
return it->second.local_size;
}

/// Returns the Local Size Id Execution Modes for the given Entry Point.
/// Returns nullptr if none found.
const Instruction* GetLocalSizeIdInstruction(uint32_t entry_point) const {
const auto it = entry_point_data_.find(entry_point);
if (it == entry_point_data_.end()) {
return nullptr;
}
return it->second.local_size_id;
}

/// Traverses call tree and computes function_to_entry_points_.
Expand Down Expand Up @@ -815,9 +855,8 @@ class ValidationState_t {
/// IDs that are entry points, ie, arguments to OpEntryPoint.
std::vector<uint32_t> entry_points_;

/// Maps an entry point id to its descriptions.
std::unordered_map<uint32_t, std::vector<EntryPointDescription>>
entry_point_descriptions_;
/// Maps an entry point id to all the information about it.
std::unordered_map<uint32_t, EntryPointData> entry_point_data_;

/// IDs that are entry points, ie, arguments to OpEntryPoint, and root a call
/// graph that recurses.
Expand Down Expand Up @@ -872,16 +911,6 @@ class ValidationState_t {
/// Maps function ids to function stat objects.
std::unordered_map<uint32_t, Function*> id_to_function_;

/// Mapping entry point -> execution models. It is presumed that the same
/// function could theoretically be used as 'main' by multiple OpEntryPoint
/// instructions.
std::unordered_map<uint32_t, std::set<SpvExecutionModel>>
entry_point_to_execution_models_;

/// Mapping entry point -> execution modes.
std::unordered_map<uint32_t, std::set<SpvExecutionMode>>
entry_point_to_execution_modes_;

/// Mapping function -> array of entry points inside this
/// module which can (indirectly) call the function.
std::unordered_map<uint32_t, std::vector<uint32_t>> function_to_entry_points_;
Expand Down
13 changes: 4 additions & 9 deletions test/val/val_annotation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,9 @@ OpDecorate %var BuiltIn )" +
%var = OpVariable %ptr Input
)";

// Workgroup are valid unless used with Vulkan env (VUID 04427)
CompileSuccessfully(text);
if (deco != "WorkgroupSize") {
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
} else {
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("must be a constant for WorkgroupSize"));
}
EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}

TEST_P(BuiltInDecorations, IntegerType) {
Expand All @@ -424,7 +419,7 @@ OpDecorate %int BuiltIn )" +
)";

CompileSuccessfully(text);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("BuiltIns can only target variables, structure members "
"or constants"));
Expand All @@ -448,7 +443,7 @@ OpFunctionEnd
)";

CompileSuccessfully(text);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("BuiltIns can only target variables, structure members "
"or constants"));
Expand Down
4 changes: 2 additions & 2 deletions test/val/val_builtins_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2923,7 +2923,7 @@ OpDecorate %copy BuiltIn WorkgroupSize
generator.entry_points_.push_back(std::move(entry_point));

CompileSuccessfully(generator.Build(), SPV_ENV_VULKAN_1_0);
ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions(SPV_ENV_VULKAN_1_0));
ASSERT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(SPV_ENV_VULKAN_1_0));
EXPECT_THAT(getDiagnosticString(),
HasSubstr("BuiltIns can only target variables, structure "
"members or constants"));
Expand Down Expand Up @@ -3825,7 +3825,7 @@ OpDecorate %void BuiltIn Position
)";

CompileSuccessfully(text);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("BuiltIns can only target variables, structure members "
"or constants"));
Expand Down
Loading