From f61db0bcc6717b4b60a7e38e7c3c58ad02110aa8 Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Fri, 3 Jun 2016 21:24:24 -0400 Subject: [PATCH] Validator structured flow checks: back-edge, constructs Skip structured control flow chekcs for non-shader capability. Fix infinite loop in dominator algorithm when there's an unreachable block. --- source/val/BasicBlock.cpp | 52 ++- source/val/BasicBlock.h | 87 ++++- source/val/Construct.cpp | 54 +++- source/val/Construct.h | 102 +++++- source/val/Function.cpp | 73 ++++- source/val/Function.h | 29 +- source/val/ValidationState.h | 3 + source/validate.h | 25 +- source/validate_cfg.cpp | 260 +++++++++++---- test/Validate.CFG.cpp | 596 ++++++++++++++++++++++++++--------- test/ValidateFixtures.cpp | 1 + 11 files changed, 1007 insertions(+), 275 deletions(-) diff --git a/source/val/BasicBlock.cpp b/source/val/BasicBlock.cpp index 55be3b8df8..4325736222 100644 --- a/source/val/BasicBlock.cpp +++ b/source/val/BasicBlock.cpp @@ -35,25 +35,38 @@ namespace libspirv { BasicBlock::BasicBlock(uint32_t id) : id_(id), immediate_dominator_(nullptr), + immediate_post_dominator_(nullptr), predecessors_(), successors_(), + type_(0), reachable_(false) {} void BasicBlock::SetImmediateDominator(BasicBlock* dom_block) { immediate_dominator_ = dom_block; } +void BasicBlock::SetImmediatePostDominator(BasicBlock* pdom_block) { + immediate_post_dominator_ = pdom_block; +} + const BasicBlock* BasicBlock::GetImmediateDominator() const { return immediate_dominator_; } +const BasicBlock* BasicBlock::GetImmediatePostDominator() const { + return immediate_post_dominator_; +} + BasicBlock* BasicBlock::GetImmediateDominator() { return immediate_dominator_; } +BasicBlock* BasicBlock::GetImmediatePostDominator() { + return immediate_post_dominator_; +} -void BasicBlock::RegisterSuccessors(vector next_blocks) { +void BasicBlock::RegisterSuccessors(const vector& next_blocks) { for (auto& block : next_blocks) { block->predecessors_.push_back(this); successors_.push_back(block); - if (block->reachable_ == false) block->set_reachability(reachable_); + if (block->reachable_ == false) block->set_reachable(reachable_); } } @@ -63,24 +76,29 @@ void BasicBlock::RegisterBranchInstruction(SpvOp branch_instruction) { } BasicBlock::DominatorIterator::DominatorIterator() : current_(nullptr) {} -BasicBlock::DominatorIterator::DominatorIterator(const BasicBlock* block) - : current_(block) {} + +BasicBlock::DominatorIterator::DominatorIterator( + const BasicBlock* block, + std::function dominator_func) + : current_(block), dom_func_(dominator_func) {} BasicBlock::DominatorIterator& BasicBlock::DominatorIterator::operator++() { - if (current_ == current_->GetImmediateDominator()) { + if (current_ == dom_func_(current_)) { current_ = nullptr; } else { - current_ = current_->GetImmediateDominator(); + current_ = dom_func_(current_); } return *this; } const BasicBlock::DominatorIterator BasicBlock::dom_begin() const { - return DominatorIterator(this); + return DominatorIterator( + this, [](const BasicBlock* b) { return b->GetImmediateDominator(); }); } BasicBlock::DominatorIterator BasicBlock::dom_begin() { - return DominatorIterator(this); + return DominatorIterator( + this, [](const BasicBlock* b) { return b->GetImmediateDominator(); }); } const BasicBlock::DominatorIterator BasicBlock::dom_end() const { @@ -91,6 +109,24 @@ BasicBlock::DominatorIterator BasicBlock::dom_end() { return DominatorIterator(); } +const BasicBlock::DominatorIterator BasicBlock::pdom_begin() const { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); }); +} + +BasicBlock::DominatorIterator BasicBlock::pdom_begin() { + return DominatorIterator( + this, [](const BasicBlock* b) { return b->GetImmediatePostDominator(); }); +} + +const BasicBlock::DominatorIterator BasicBlock::pdom_end() const { + return DominatorIterator(); +} + +BasicBlock::DominatorIterator BasicBlock::pdom_end() { + return DominatorIterator(); +} + bool operator==(const BasicBlock::DominatorIterator& lhs, const BasicBlock::DominatorIterator& rhs) { return lhs.current_ == rhs.current_; diff --git a/source/val/BasicBlock.h b/source/val/BasicBlock.h index 0cdc459ad8..8818faa542 100644 --- a/source/val/BasicBlock.h +++ b/source/val/BasicBlock.h @@ -30,10 +30,24 @@ #include "spirv/1.1/spirv.h" #include + +#include +#include #include namespace libspirv { +enum BlockType : uint32_t { + kBlockTypeUndefined, + kBlockTypeHeader, + kBlockTypeLoop, + kBlockTypeMerge, + kBlockTypeBreak, + kBlockTypeContinue, + kBlockTypeReturn, + kBlockTypeCOUNT ///< Total number of block types. (must be the last element) +}; + // This class represents a basic block in a SPIR-V module class BasicBlock { public: @@ -61,27 +75,53 @@ class BasicBlock { /// Returns the successors of the BasicBlock std::vector* get_successors() { return &successors_; } - /// Returns true if the block should be reachable in the CFG + /// Returns true if the block is reachable in the CFG bool is_reachable() const { return reachable_; } - void set_reachability(bool reachability) { reachable_ = reachability; } + /// Returns true if BasicBlock is of the given type + bool is_type(BlockType type) const { + if (type == kBlockTypeUndefined) return type_.none(); + return type_.test(type); + } + + /// Sets the reachability of the basic block in the CFG + void set_reachable(bool reachability) { reachable_ = reachability; } + + /// Sets the type of the BasicBlock + void set_type(BlockType type) { + if (type == kBlockTypeUndefined) + type_.reset(); + else + type_.set(type); + } /// Sets the immedate dominator of this basic block /// /// @param[in] dom_block The dominator block void SetImmediateDominator(BasicBlock* dom_block); + /// Sets the immedate post dominator of this basic block + /// + /// @param[in] pdom_block The post dominator block + void SetImmediatePostDominator(BasicBlock* pdom_block); + /// Returns the immedate dominator of this basic block BasicBlock* GetImmediateDominator(); /// Returns the immedate dominator of this basic block const BasicBlock* GetImmediateDominator() const; + /// Returns the immedate post dominator of this basic block + BasicBlock* GetImmediatePostDominator(); + + /// Returns the immedate post dominator of this basic block + const BasicBlock* GetImmediatePostDominator() const; + /// Ends the block without a successor void RegisterBranchInstruction(SpvOp branch_instruction); /// Adds @p next BasicBlocks as successors of this BasicBlock - void RegisterSuccessors(std::vector next = {}); + void RegisterSuccessors(const std::vector& next = {}); /// Returns true if the id of the BasicBlock matches bool operator==(const BasicBlock& other) const { return other.id_ == id_; } @@ -91,7 +131,7 @@ class BasicBlock { /// @brief A BasicBlock dominator iterator class /// - /// This iterator will iterate over the dominators of the block + /// This iterator will iterate over the (post)dominators of the block class DominatorIterator : public std::iterator { public: @@ -104,8 +144,12 @@ class BasicBlock { /// @brief Constructs an iterator for the given block which points to /// @p block /// - /// @param block The block which is referenced by the iterator - explicit DominatorIterator(const BasicBlock* block); + /// @param block The block which is referenced by the iterator + /// @param dominator_func This function will be called to get the immediate + /// (post)dominator of the current block + DominatorIterator( + const BasicBlock* block, + std::function dominator_func); /// @brief Advances the iterator DominatorIterator& operator++(); @@ -118,16 +162,36 @@ class BasicBlock { private: const BasicBlock* current_; + std::function dom_func_; }; - /// Returns an iterator which points to the current block + /// Returns a dominator iterator which points to the current block const DominatorIterator dom_begin() const; + + /// Returns a dominator iterator which points to the current block DominatorIterator dom_begin(); - /// Returns an iterator which points to one element past the first block + /// Returns a dominator iterator which points to one element past the first + /// block const DominatorIterator dom_end() const; + + /// Returns a dominator iterator which points to one element past the first + /// block DominatorIterator dom_end(); + /// Returns a post dominator iterator which points to the current block + const DominatorIterator pdom_begin() const; + /// Returns a post dominator iterator which points to the current block + DominatorIterator pdom_begin(); + + /// Returns a post dominator iterator which points to one element past the + /// last block + const DominatorIterator pdom_end() const; + + /// Returns a post dominator iterator which points to one element past the + /// last block + DominatorIterator pdom_end(); + private: /// Id of the BasicBlock const uint32_t id_; @@ -135,12 +199,19 @@ class BasicBlock { /// Pointer to the immediate dominator of the BasicBlock BasicBlock* immediate_dominator_; + /// Pointer to the immediate dominator of the BasicBlock + BasicBlock* immediate_post_dominator_; + /// The set of predecessors of the BasicBlock std::vector predecessors_; /// The set of successors of the BasicBlock std::vector successors_; + /// The type of the block + std::bitset type_; + + /// True if the block is reachable in the CFG bool reachable_; }; diff --git a/source/val/Construct.cpp b/source/val/Construct.cpp index 91140bfcf9..6bd1b5d239 100644 --- a/source/val/Construct.cpp +++ b/source/val/Construct.cpp @@ -26,19 +26,51 @@ #include "val/Construct.h" +#include +#include + namespace libspirv { -Construct::Construct(BasicBlock* header_block, BasicBlock* merge_block, - BasicBlock* continue_block) - : header_block_(header_block), - merge_block_(merge_block), - continue_block_(continue_block) {} +Construct::Construct(ConstructType type, BasicBlock* entry, + BasicBlock* exit, std::vector constructs) + : type_(type), + corresponding_constructs_(constructs), + entry_block_(entry), + exit_block_(exit) {} + +ConstructType Construct::get_type() const { return type_; } + +const std::vector& Construct::get_corresponding_constructs() const { + return corresponding_constructs_; +} +std::vector& Construct::get_corresponding_constructs() { + return corresponding_constructs_; +} + +bool ValidateConstructSize(ConstructType type, size_t size) { + switch (type) { + case ConstructType::kSelection: return size == 0; + case ConstructType::kContinue: return size == 1; + case ConstructType::kLoop: return size == 1; + case ConstructType::kCase: return size >= 1; + default: assert(1 == 0 && "Type not defined"); + } + return false; +} + +void Construct::set_corresponding_constructs( + std::vector constructs) { + assert(ValidateConstructSize(type_, constructs.size())); + corresponding_constructs_ = constructs; +} + +const BasicBlock* Construct::get_entry() const { return entry_block_; } +BasicBlock* Construct::get_entry() { return entry_block_; } -const BasicBlock* Construct::get_header() const { return header_block_; } -const BasicBlock* Construct::get_merge() const { return merge_block_; } -const BasicBlock* Construct::get_continue() const { return continue_block_; } +const BasicBlock* Construct::get_exit() const { return exit_block_; } +BasicBlock* Construct::get_exit() { return exit_block_; } -BasicBlock* Construct::get_header() { return header_block_; } -BasicBlock* Construct::get_merge() { return merge_block_; } -BasicBlock* Construct::get_continue() { return continue_block_; } +void Construct::set_exit(BasicBlock* exit_block) { + exit_block_ = exit_block; } +} /// namespace libspirv diff --git a/source/val/Construct.h b/source/val/Construct.h index ef5fae43fd..b87c99afe7 100644 --- a/source/val/Construct.h +++ b/source/val/Construct.h @@ -28,29 +28,109 @@ #define LIBSPIRV_VAL_CONSTRUCT_H_ #include +#include namespace libspirv { +enum class ConstructType { + kNone, + /// The set of blocks dominated by a selection header, minus the set of blocks + /// dominated by the header's merge block + kSelection, + /// The set of blocks dominated by an OpLoopMerge's Continue Target and post + /// dominated by the corresponding back + kContinue, + /// The set of blocks dominated by a loop header, minus the set of blocks + /// dominated by the loop's merge block, minus the loop's corresponding + /// continue construct + kLoop, + /// The set of blocks dominated by an OpSwitch's Target or Default, minus the + /// set of blocks dominated by the OpSwitch's merge block (this construct is + /// only defined for those OpSwitch Target or Default that are not equal to + /// the OpSwitch's corresponding merge block) + kCase +}; + class BasicBlock; /// @brief This class tracks the CFG constructs as defined in the SPIR-V spec class Construct { public: - Construct(BasicBlock* header_block, BasicBlock* merge_block, - BasicBlock* continue_block = nullptr); + Construct(ConstructType type, BasicBlock* dominator, + BasicBlock* exit = nullptr, + std::vector constructs = {}); + + /// Returns the type of the construct + ConstructType get_type() const; + + const std::vector& get_corresponding_constructs() const; + std::vector& get_corresponding_constructs(); + void set_corresponding_constructs(std::vector constructs); + + /// Returns the dominator block of the construct. + /// + /// This is usually the header block or the first block of the construct. + const BasicBlock* get_entry() const; - const BasicBlock* get_header() const; - const BasicBlock* get_merge() const; - const BasicBlock* get_continue() const; + /// Returns the dominator block of the construct. + /// + /// This is usually the header block or the first block of the construct. + BasicBlock* get_entry(); - BasicBlock* get_header(); - BasicBlock* get_merge(); - BasicBlock* get_continue(); + /// Returns the exit block of the construct. + /// + /// For a continue construct it is the backedge block of the corresponding + /// loop construct. For the case construct it is the block that branches to + /// the OpSwitch merge block or other case blocks. Otherwise it is the merge + /// block of the corresponding header block + const BasicBlock* get_exit() const; + + /// Returns the exit block of the construct. + /// + /// For a continue construct it is the backedge block of the corresponding + /// loop construct. For the case construct it is the block that branches to + /// the OpSwitch merge block or other case blocks. Otherwise it is the merge + /// block of the corresponding header block + BasicBlock* get_exit(); + + /// Sets the exit block for this construct. This is useful for continue + /// constructs which do not know the back-edge block during construction + void set_exit(BasicBlock* exit_block); private: - BasicBlock* header_block_; ///< The header block of a loop or selection - BasicBlock* merge_block_; ///< The merge block of a loop or selection - BasicBlock* continue_block_; ///< The continue block of a loop block + /// The type of the construct + ConstructType type_; + + /// These are the constructs that are related to this construct. These + /// constructs can be the continue construct, for the corresponding loop + /// construct, the case construct that are part of the same OpSwitch + /// instruction + /// + /// Here is a table that describes what constructs are included in + /// @p corresponding_constructs_ + /// | this construct | corresponding construct | + /// |----------------|----------------------------------| + /// | loop | continue | + /// | continue | loop | + /// | case | other cases in the same OpSwitch | + /// + /// kContinue and kLoop constructs will always have corresponding + /// constructs even if they are represented by the same block + std::vector corresponding_constructs_; + + /// @brief Dominator block for the construct + /// + /// The dominator block for the construct. Depending on the construct this may + /// be a selection header, a continue target of a loop, a loop header or a + /// Target or Default block of a switch + BasicBlock* entry_block_; + + /// @brief Exiting block for the construct + /// + /// The exit block for the construct. This can be a merge block for the loop + /// and selection constructs, a back-edge block for a continue construct, or + /// the branching block for the case construct + BasicBlock* exit_block_; }; } /// namespace libspirv diff --git a/source/val/Function.cpp b/source/val/Function.cpp index 3756949f32..d2c89fbf50 100644 --- a/source/val/Function.cpp +++ b/source/val/Function.cpp @@ -29,13 +29,18 @@ #include #include +#include #include "val/BasicBlock.h" #include "val/Construct.h" #include "val/ValidationState.h" +using std::ignore; using std::list; +using std::make_pair; +using std::pair; using std::string; +using std::tie; using std::vector; namespace libspirv { @@ -66,6 +71,7 @@ Function::Function(uint32_t id, uint32_t result_type_id, declaration_type_(FunctionDecl::kFunctionDeclUnknown), blocks_(), current_block_(nullptr), + pseudo_exit_block_(kInvalidId), cfg_constructs_(), variable_ids_(), parameter_ids_() {} @@ -93,15 +99,33 @@ spv_result_t Function::RegisterLoopMerge(uint32_t merge_id, uint32_t continue_id) { RegisterBlock(merge_id, false); RegisterBlock(continue_id, false); - cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id), - &blocks_.at(continue_id)); + BasicBlock& merge_block = blocks_.at(merge_id); + BasicBlock& continue_block = blocks_.at(continue_id); + assert(current_block_ && + "RegisterLoopMerge must be called when called within a block"); + + current_block_->set_type(kBlockTypeLoop); + merge_block.set_type(kBlockTypeMerge); + continue_block.set_type(kBlockTypeContinue); + cfg_constructs_.emplace_back(ConstructType::kLoop, current_block_, + &merge_block); + Construct& loop_construct = cfg_constructs_.back(); + cfg_constructs_.emplace_back(ConstructType::kContinue, &continue_block); + Construct& continue_construct = cfg_constructs_.back(); + continue_construct.set_corresponding_constructs({&loop_construct}); + loop_construct.set_corresponding_constructs({&continue_construct}); return SPV_SUCCESS; } spv_result_t Function::RegisterSelectionMerge(uint32_t merge_id) { RegisterBlock(merge_id, false); - cfg_constructs_.emplace_back(get_current_block(), &blocks_.at(merge_id)); + BasicBlock& merge_block = blocks_.at(merge_id); + current_block_->set_type(kBlockTypeHeader); + merge_block.set_type(kBlockTypeMerge); + + cfg_constructs_.emplace_back(ConstructType::kSelection, get_current_block(), + &merge_block); return SPV_SUCCESS; } @@ -152,7 +176,7 @@ spv_result_t Function::RegisterBlock(uint32_t id, bool is_definition) { undefined_blocks_.erase(id); current_block_ = &inserted_block->second; ordered_blocks_.push_back(current_block_); - if (IsFirstBlock(id)) current_block_->set_reachability(true); + if (IsFirstBlock(id)) current_block_->set_reachable(true); } else if (success) { // Block doesn't exsist but this is not a definition undefined_blocks_.insert(id); } @@ -182,6 +206,11 @@ void Function::RegisterBlockEnd(vector next_list, next_blocks.push_back(&inserted_block->second); } + if (branch_instruction == SpvOpReturn || + branch_instruction == SpvOpReturnValue) { + assert(next_blocks.empty()); + next_blocks.push_back(&pseudo_exit_block_); + } current_block_->RegisterBranchInstruction(branch_instruction); current_block_->RegisterSuccessors(next_blocks); current_block_ = nullptr; @@ -202,6 +231,11 @@ vector& Function::get_blocks() { return ordered_blocks_; } const BasicBlock* Function::get_current_block() const { return current_block_; } BasicBlock* Function::get_current_block() { return current_block_; } +BasicBlock* Function::get_pseudo_exit_block() { return &pseudo_exit_block_; } +const BasicBlock* Function::get_pseudo_exit_block() const { + return &pseudo_exit_block_; +} + const list& Function::get_constructs() const { return cfg_constructs_; } @@ -216,17 +250,32 @@ BasicBlock* Function::get_first_block() { return ordered_blocks_[0]; } -bool Function::IsMergeBlock(uint32_t merge_block_id) const { - const auto b = blocks_.find(merge_block_id); +bool Function::IsBlockType(uint32_t merge_block_id, BlockType type) const { + bool ret = false; + const BasicBlock* block; + tie(block, ignore) = GetBlock(merge_block_id); + if (block) { + ret = block->is_type(type); + } + return ret; +} + +pair Function::GetBlock(uint32_t id) const { + const auto b = blocks_.find(id); if (b != end(blocks_)) { - return cfg_constructs_.end() != - find_if(begin(cfg_constructs_), end(cfg_constructs_), - [&](const Construct& construct) { - return construct.get_merge() == &b->second; - }); + const BasicBlock* block = &(b->second); + bool defined = + undefined_blocks_.find(block->get_id()) == end(undefined_blocks_); + return make_pair(block, defined); } else { - return false; + return make_pair(nullptr, false); } } +pair Function::GetBlock(uint32_t id) { + const BasicBlock* out; + bool defined; + tie(out, defined) = const_cast(this)->GetBlock(id); + return make_pair(const_cast(out), defined); +} } /// namespace libspirv diff --git a/source/val/Function.h b/source/val/Function.h index c9f0746945..4344fe9d6e 100644 --- a/source/val/Function.h +++ b/source/val/Function.h @@ -28,9 +28,9 @@ #define LIBSPIRV_VAL_FUNCTION_H_ #include -#include -#include #include +#include +#include #include "spirv/1.1/spirv.h" #include "spirv-tools/libspirv.h" @@ -100,12 +100,19 @@ class Function { void RegisterBlockEnd(std::vector successors_list, SpvOp branch_instruction); - /// Returns true if the \p merge_block_id is a merge block - bool IsMergeBlock(uint32_t merge_block_id) const; - - /// Returns true if the \p id is the first block of this function + /// Returns true if the \p id block is the first block of this function bool IsFirstBlock(uint32_t id) const; + /// Returns true if the \p merge_block_id is a BlockType of \p type + bool IsBlockType(uint32_t merge_block_id, BlockType type) const; + + /// Returns a pair consisting of the BasicBlock with \p id and a bool + /// which is true if the block has been defined, and false if it is + /// declared but not defined. This function will return nullptr if the + /// \p id was not declared and not defined at the current point in the binary + std::pair GetBlock(uint32_t id) const; + std::pair GetBlock(uint32_t id); + /// Returns the first block of the current function const BasicBlock* get_first_block() const; @@ -142,6 +149,12 @@ class Function { /// Returns the block that is currently being parsed in the binary const BasicBlock* get_current_block() const; + /// Returns the psudo exit block + BasicBlock* get_pseudo_exit_block(); + + /// Returns the psudo exit block + const BasicBlock* get_pseudo_exit_block() const; + /// Prints a GraphViz digraph of the CFG of the current funciton void printDotGraph() const; @@ -179,6 +192,9 @@ class Function { /// The block that is currently being parsed BasicBlock* current_block_; + /// A pseudo exit block that is the successor to all return blocks + BasicBlock pseudo_exit_block_; + /// The constructs that are available in this function std::list cfg_constructs_; @@ -191,5 +207,4 @@ class Function { } /// namespace libspirv - #endif /// LIBSPIRV_VAL_FUNCTION_H_ diff --git a/source/val/ValidationState.h b/source/val/ValidationState.h index 9002634b06..9bfab09a80 100644 --- a/source/val/ValidationState.h +++ b/source/val/ValidationState.h @@ -42,6 +42,9 @@ namespace libspirv { +// Universal Limit of ResultID + 1 +static const uint32_t kInvalidId = 0x400000; + // Info about a result ID. typedef struct spv_id_info_t { /// Id value. diff --git a/source/validate.h b/source/validate.h index 6f2b89e440..74b350688a 100644 --- a/source/validate.h +++ b/source/validate.h @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -52,16 +53,25 @@ namespace libspirv { class ValidationState_t; -/// @brief Calculates dominator edges of a root basic block +/// A function that returns a vector of BasicBlocks given a BasicBlock. Used to +/// get the successor and predecessor nodes of a CFG block +using get_blocks_func = + std::function*(const BasicBlock*)>; + +/// @brief Calculates dominator edges for a set of blocks /// -/// This function calculates the dominator edges form a root BasicBlock. Uses -/// the dominator algorithm by Cooper et al. +/// This function calculates the dominator edges for a set of blocks in the CFG. +/// Uses the dominator algorithm by Cooper et al. /// -/// @param[in] first_block the root or entry BasicBlock of a function +/// @param[in] postorder A vector of blocks in post order traversal order +/// in a CFG +/// @param[in] predecessor_func Function used to get the predecessor nodes of a +/// block /// /// @return a set of dominator edges represented as a pair of blocks std::vector> CalculateDominators( - const BasicBlock& first_block); + const std::vector& postorder, + get_blocks_func predecessor_func); /// @brief Performs the Control Flow Graph checks /// @@ -76,8 +86,11 @@ spv_result_t PerformCfgChecks(ValidationState_t& _); /// provided by the @p dom_edges parameter /// /// @param[in,out] dom_edges The edges of the dominator tree +/// @param[in] set_func This function will be called to updated the Immediate +/// dominator void UpdateImmediateDominators( - std::vector>& dom_edges); + const std::vector>& dom_edges, + std::function set_func); /// @brief Prints all of the dominators of a BasicBlock /// diff --git a/source/validate_cfg.cpp b/source/validate_cfg.cpp index a1e86bba28..b687661d88 100644 --- a/source/validate_cfg.cpp +++ b/source/validate_cfg.cpp @@ -30,6 +30,8 @@ #include #include +#include +#include #include #include #include @@ -43,9 +45,13 @@ using std::find; using std::function; using std::get; +using std::ignore; using std::make_pair; using std::numeric_limits; using std::pair; +using std::set; +using std::string; +using std::tie; using std::transform; using std::unordered_map; using std::unordered_set; @@ -61,8 +67,6 @@ using bb_ptr = BasicBlock*; using cbb_ptr = const BasicBlock*; using bb_iter = vector::const_iterator; -using get_blocks_func = function*(const BasicBlock*)>; - struct block_info { cbb_ptr block; ///< pointer to the block bb_iter iter; ///< Iterator to the current child node being processed @@ -92,8 +96,8 @@ bool FindInWorkList(const vector& work_list, uint32_t id) { /// @param[in] entry The root BasicBlock of a CFG tree /// @param[in] successor_func A function which will return a pointer to the /// successor nodes -/// @param[in] preorder A function that will be called for every block in a CFG -/// following preorder traversal semantics +/// @param[in] preorder A function that will be called for every block in a +/// CFG following preorder traversal semantics /// @param[in] postorder A function that will be called for every block in a /// CFG following postorder traversal semantics /// @param[in] backedge A function that will be called when a backedge is @@ -143,45 +147,44 @@ const vector* successor(const BasicBlock* b) { return b->get_successors(); } +const vector* predecessor(const BasicBlock* b) { + return b->get_predecessors(); +} + } // namespace vector> CalculateDominators( - vector& postorder) { + const vector& postorder, get_blocks_func predecessor_func) { struct block_detail { size_t dominator; ///< The index of blocks's dominator in post order array size_t postorder_index; ///< The index of the block in the post order array }; - - const size_t undefined_dom = static_cast(postorder.size()); + const size_t undefined_dom = postorder.size(); unordered_map idoms; for (size_t i = 0; i < postorder.size(); i++) { idoms[postorder[i]] = {undefined_dom, i}; } - idoms[postorder.back()].dominator = idoms[postorder.back()].postorder_index; bool changed = true; while (changed) { changed = false; for (auto b = postorder.rbegin() + 1; b != postorder.rend(); b++) { - size_t& b_dom = idoms[*b].dominator; - const vector* predecessors = (*b)->get_predecessors(); - - // first processed predecessor + const vector* predecessors = predecessor_func(*b); + // first processed/reachable predecessor auto res = find_if(begin(*predecessors), end(*predecessors), [&idoms, undefined_dom](BasicBlock* pred) { - return idoms[pred].dominator != undefined_dom; + return idoms[pred].dominator != undefined_dom && + pred->is_reachable(); }); - assert(res != end(*predecessors)); + if (res == end(*predecessors)) continue; BasicBlock* idom = *res; size_t idom_idx = idoms[idom].postorder_index; // all other predecessors for (auto p : *predecessors) { - if (idom == p || p->is_reachable() == false) { - continue; - } + if (idom == p || p->is_reachable() == false) continue; if (idoms[p].dominator != undefined_dom) { size_t finger1 = idoms[p].postorder_index; size_t finger2 = idom_idx; @@ -196,8 +199,8 @@ vector> CalculateDominators( idom_idx = finger1; } } - if (b_dom != idom_idx) { - b_dom = idom_idx; + if (idoms[*b].dominator != idom_idx) { + idoms[*b].dominator = idom_idx; changed = true; } } @@ -213,13 +216,15 @@ vector> CalculateDominators( return out; } -void UpdateImmediateDominators(vector>& dom_edges) { +void UpdateImmediateDominators( + const vector>& dom_edges, + function set_func) { for (auto& edge : dom_edges) { - get<0>(edge)->SetImmediateDominator(get<1>(edge)); + set_func(get<0>(edge), get<1>(edge)); } } -void printDominatorList(BasicBlock& b) { +void printDominatorList(const BasicBlock& b) { std::cout << b.get_id() << " is dominated by: "; const BasicBlock* bb = &b; while (bb->GetImmediateDominator() != bb) { @@ -244,7 +249,7 @@ spv_result_t FirstBlockAssert(ValidationState_t& _, uint32_t target) { } spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { - if (_.get_current_function().IsMergeBlock(merge_block)) { + if (_.get_current_function().IsBlockType(merge_block, kBlockTypeMerge)) { return _.diag(SPV_ERROR_INVALID_CFG) << "Block " << _.getIdName(merge_block) << " is already a merge block for another header"; @@ -252,21 +257,188 @@ spv_result_t MergeBlockAssert(ValidationState_t& _, uint32_t merge_block) { return SPV_SUCCESS; } +/// Update the continue construct's exit blocks once the backedge blocks are +/// identified in the CFG. +void UpdateContinueConstructExitBlocks( + Function& function, const vector>& back_edges) { + auto& constructs = function.get_constructs(); + // TODO(umar): Think of a faster way to do this + for (auto& edge : back_edges) { + uint32_t back_edge_block_id; + uint32_t loop_header_block_id; + tie(back_edge_block_id, loop_header_block_id) = edge; + + auto is_this_header = [=](Construct& c) { + return c.get_type() == ConstructType::kLoop && + c.get_entry()->get_id() == loop_header_block_id; + }; + + for (auto construct : constructs) { + if (is_this_header(construct)) { + Construct* continue_construct = + construct.get_corresponding_constructs().back(); + assert(continue_construct->get_type() == ConstructType::kContinue); + + BasicBlock* back_edge_block; + tie(back_edge_block, ignore) = function.GetBlock(back_edge_block_id); + continue_construct->set_exit(back_edge_block); + } + } + } +} + +/// Constructs an error message for construct validation errors +string ConstructErrorString(const Construct& construct, + const string& header_string, + const string& exit_string, + bool post_dominate = false) { + string construct_name; + string header_name; + string exit_name; + string dominate_text; + if (post_dominate) { + dominate_text = "is not post dominated by"; + } else { + dominate_text = "does not dominate"; + } + + switch (construct.get_type()) { + case ConstructType::kSelection: + construct_name = "selection"; + header_name = "selection header"; + exit_name = "merge block"; + break; + case ConstructType::kLoop: + construct_name = "loop"; + header_name = "loop header"; + exit_name = "merge block"; + break; + case ConstructType::kContinue: + construct_name = "continue"; + header_name = "continue target"; + exit_name = "back-edge block"; + break; + case ConstructType::kCase: + construct_name = "case"; + header_name = "case block"; + exit_name = "exit block"; // TODO(umar): there has to be a better name + break; + default: + assert(1 == 0 && "Not defined type"); + } + // TODO(umar): Add header block for continue constructs to error message + return "The " + construct_name + " construct with the " + header_name + " " + + header_string + " " + dominate_text + " the " + exit_name + " " + + exit_string; +} + +spv_result_t StructuredControlFlowChecks( + const ValidationState_t& _, const Function& function, + const vector>& back_edges) { + /// Check all backedges target only loop headers and have exactly one + /// back-edge branching to it + set loop_headers; + for (auto back_edge : back_edges) { + uint32_t back_edge_block; + uint32_t header_block; + tie(back_edge_block, header_block) = back_edge; + if (!function.IsBlockType(header_block, kBlockTypeLoop)) { + return _.diag(SPV_ERROR_INVALID_CFG) + << "Back-edges (" << _.getIdName(back_edge_block) << " -> " + << _.getIdName(header_block) + << ") can only be formed between a block and a loop header."; + } + bool success; + tie(ignore, success) = loop_headers.insert(header_block); + if (!success) { + // TODO(umar): List the back-edge blocks that are branching to loop + // header + return _.diag(SPV_ERROR_INVALID_CFG) + << "Loop header " << _.getIdName(header_block) + << " targeted by multiple back-edges"; + } + } + + // Check construct rules + for (const Construct& construct : function.get_constructs()) { + auto header = construct.get_entry(); + auto merge = construct.get_exit(); + + // if the merge block is reachable then it's dominated by the header + if (merge->is_reachable() && + find(merge->dom_begin(), merge->dom_end(), header) == + merge->dom_end()) { + return _.diag(SPV_ERROR_INVALID_CFG) + << ConstructErrorString(construct, _.getIdName(header->get_id()), + _.getIdName(merge->get_id())); + } + if (construct.get_type() == ConstructType::kContinue) { + if (find(header->pdom_begin(), header->pdom_end(), merge) == + merge->pdom_end()) { + return _.diag(SPV_ERROR_INVALID_CFG) + << ConstructErrorString(construct, _.getIdName(header->get_id()), + _.getIdName(merge->get_id()), true); + } + } + // TODO(umar): an OpSwitch block dominates all its defined case + // constructs + // TODO(umar): each case construct has at most one branch to another + // case construct + // TODO(umar): each case construct is branched to by at most one other + // case construct + // TODO(umar): if Target T1 branches to Target T2, or if Target T1 + // branches to the Default and the Default branches to Target T2, then + // T1 must immediately precede T2 in the list of the OpSwitch Target + // operands + } + return SPV_SUCCESS; +} + spv_result_t PerformCfgChecks(ValidationState_t& _) { for (auto& function : _.get_functions()) { + // Check all referenced blocks are defined within a function + if (function.get_undefined_block_count() != 0) { + string undef_blocks("{"); + for (auto undefined_block : function.get_undefined_blocks()) { + undef_blocks += _.getIdName(undefined_block) + " "; + } + return _.diag(SPV_ERROR_INVALID_CFG) + << "Block(s) " << undef_blocks << "\b}" + << " are referenced but not defined in function " + << _.getIdName(function.get_id()); + } + // Updates each blocks immediate dominators vector postorder; + vector postdom_postorder; vector> back_edges; if (auto* first_block = function.get_first_block()) { + /// calculate dominators DepthFirstTraversal(*first_block, successor, [](cbb_ptr) {}, [&](cbb_ptr b) { postorder.push_back(b); }, [&](cbb_ptr from, cbb_ptr to) { back_edges.emplace_back(from->get_id(), to->get_id()); }); - auto edges = libspirv::CalculateDominators(postorder); - libspirv::UpdateImmediateDominators(edges); + auto edges = libspirv::CalculateDominators(postorder, predecessor); + libspirv::UpdateImmediateDominators( + edges, [](bb_ptr block, bb_ptr dominator) { + block->SetImmediateDominator(dominator); + }); + + /// calculate post dominators + auto exit_block = function.get_pseudo_exit_block(); + DepthFirstTraversal(*exit_block, predecessor, [](cbb_ptr) {}, + [&](cbb_ptr b) { postdom_postorder.push_back(b); }, + [&](cbb_ptr, cbb_ptr) {}); + auto postdom_edges = + libspirv::CalculateDominators(postdom_postorder, successor); + libspirv::UpdateImmediateDominators( + postdom_edges, [](bb_ptr block, bb_ptr dominator) { + block->SetImmediatePostDominator(dominator); + }); } + UpdateContinueConstructExitBlocks(function, back_edges); // Check if the order of blocks in the binary appear before the blocks they // dominate @@ -284,41 +456,10 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) { } } - // Check all referenced blocks are defined within a function - if (function.get_undefined_block_count() != 0) { - std::stringstream ss; - ss << "{"; - for (auto undefined_block : function.get_undefined_blocks()) { - ss << _.getIdName(undefined_block) << " "; - } - return _.diag(SPV_ERROR_INVALID_CFG) - << "Block(s) " << ss.str() << "\b}" - << " are referenced but not defined in function " - << _.getIdName(function.get_id()); + /// Structured control flow checks are only required for shader capabilities + if (_.hasCapability(SpvCapabilityShader)) { + spvCheckReturn(StructuredControlFlowChecks(_, function, back_edges)); } - - // Check all headers dominate their merge blocks - for (Construct& construct : function.get_constructs()) { - auto header = construct.get_header(); - auto merge = construct.get_merge(); - // auto cont = construct.get_continue(); - - if (merge->is_reachable() && - find(merge->dom_begin(), merge->dom_end(), header) == - merge->dom_end()) { - return _.diag(SPV_ERROR_INVALID_CFG) - << "Header block " << _.getIdName(header->get_id()) - << " doesn't dominate its merge block " - << _.getIdName(merge->get_id()); - } - } - - // TODO(umar): All CFG back edges must branch to a loop header, with each - // loop header having exactly one back edge branching to it - - // TODO(umar): For a given loop, its back-edge block must post dominate the - // OpLoopMerge's Continue Target, and that Continue Target must dominate the - // back-edge block } return SPV_SUCCESS; } @@ -331,7 +472,6 @@ spv_result_t CfgPass(ValidationState_t& _, spvCheckReturn(_.get_current_function().RegisterBlock(inst->result_id)); break; case SpvOpLoopMerge: { - // TODO(umar): mark current block as a loop header uint32_t merge_block = inst->words[inst->operands[0].offset]; uint32_t continue_block = inst->words[inst->operands[1].offset]; CFG_ASSERT(MergeBlockAssert, merge_block); diff --git a/test/Validate.CFG.cpp b/test/Validate.CFG.cpp index 28229ccedd..5a713c8773 100644 --- a/test/Validate.CFG.cpp +++ b/test/Validate.CFG.cpp @@ -56,7 +56,7 @@ using ::testing::MatchesRegex; using libspirv::BasicBlock; using libspirv::ValidationState_t; -using ValidateCFG = spvtest::ValidateBase; +using ValidateCFG = spvtest::ValidateBase; using spvtest::ScopedContext; namespace { @@ -160,34 +160,52 @@ Block& operator>>(Block& lhs, Block& successor) { return lhs; } -string header = - "OpCapability Shader\n" - "OpMemoryModel Logical GLSL450\n"; +const char* header(SpvCapability cap) { + static const char* shader_header = + "OpCapability Shader\n" + "OpMemoryModel Logical GLSL450\n"; -string types_consts = - "%voidt = OpTypeVoid\n" - "%boolt = OpTypeBool\n" - "%intt = OpTypeInt 32 1\n" - "%one = OpConstant %intt 1\n" - "%two = OpConstant %intt 2\n" - "%ptrt = OpTypePointer Function %intt\n" - "%funct = OpTypeFunction %voidt\n"; + static const char* kernel_header = + "OpCapability Kernel\n" + "OpMemoryModel Logical OpenCL\n"; -TEST_F(ValidateCFG, Simple) { - Block first("first"); + return (cap == SpvCapabilityShader) ? shader_header : kernel_header; +} + +const char* types_consts() { + static const char* types = + "%voidt = OpTypeVoid\n" + "%boolt = OpTypeBool\n" + "%intt = OpTypeInt 32 1\n" + "%one = OpConstant %intt 1\n" + "%two = OpConstant %intt 2\n" + "%ptrt = OpTypePointer Function %intt\n" + "%funct = OpTypeFunction %voidt\n"; + + return types; +} + +INSTANTIATE_TEST_CASE_P(StructuredControlFlow, ValidateCFG, + ::testing::Values(SpvCapabilityShader, + SpvCapabilityKernel)); + +TEST_P(ValidateCFG, Simple) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); Block loop("loop", SpvOpBranchConditional); Block cont("cont"); Block merge("merge", SpvOpReturn); - loop.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpLoopMerge %merge %cont None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) { + loop.setBody("OpLoopMerge %merge %cont None\n"); + } - string str = header + nameOps("loop", "first", "cont", "merge", - make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + nameOps("loop", "entry", "cont", "merge", + make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; - str += first >> loop; + str += entry >> loop; str += loop >> vector({cont, merge}); str += cont >> loop; str += merge; @@ -197,15 +215,15 @@ TEST_F(ValidateCFG, Simple) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, Variable) { +TEST_P(ValidateCFG, Variable) { Block entry("entry"); Block cont("cont"); Block exit("exit", SpvOpReturn); entry.setBody("%var = OpVariable %ptrt Function\n"); - string str = header + nameOps(make_pair("func", "Main")) + types_consts + - " %func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + nameOps(make_pair("func", "Main")) + + types_consts() + " %func = OpFunction %voidt None %funct\n"; str += entry >> cont; str += cont >> exit; str += exit; @@ -215,7 +233,7 @@ TEST_F(ValidateCFG, Variable) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, VariableNotInFirstBlockBad) { +TEST_P(ValidateCFG, VariableNotInFirstBlockBad) { Block entry("entry"); Block cont("cont"); Block exit("exit", SpvOpReturn); @@ -223,8 +241,8 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) { // This operation should only be performed in the entry block cont.setBody("%var = OpVariable %ptrt Function\n"); - string str = header + nameOps(make_pair("func", "Main")) + types_consts + - " %func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + nameOps(make_pair("func", "Main")) + + types_consts() + " %func = OpFunction %voidt None %funct\n"; str += entry >> cont; str += cont >> exit; @@ -239,18 +257,19 @@ TEST_F(ValidateCFG, VariableNotInFirstBlockBad) { "Variables can only be defined in the first block of a function")); } -TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) { +TEST_P(ValidateCFG, BlockAppearsBeforeDominatorBad) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block cont("cont"); Block branch("branch", SpvOpBranchConditional); Block merge("merge", SpvOpReturn); - branch.setBody( - " %cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %merge None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) branch.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("cont", "branch", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("cont", "branch", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; str += cont >> merge; // cont appears before its dominator @@ -265,20 +284,22 @@ TEST_F(ValidateCFG, BlockAppearsBeforeDominatorBad) { "before its dominator .\\[branch\\]")); } -TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { +TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block loop("loop"); Block selection("selection", SpvOpBranchConditional); Block merge("merge", SpvOpReturn); - loop.setBody( - " %cond = OpSLessThan %intt %one %two\n" - " OpLoopMerge %merge %loop None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n"); + // cannot share the same merge - selection.setBody("OpSelectionMerge %merge None\n"); + if (is_shader) selection.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("merge", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("merge", make_pair("func", "Main")) + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; str += loop >> selection; @@ -287,26 +308,32 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksBad) { str += "OpFunctionEnd\n"; CompileSuccessfully(str); - ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - MatchesRegex("Block .\\[merge\\] is already a merge block " - "for another header")); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block .\\[merge\\] is already a merge block " + "for another header")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } } -TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { +TEST_P(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block loop("loop", SpvOpBranchConditional); Block selection("selection", SpvOpBranchConditional); Block merge("merge", SpvOpReturn); - selection.setBody( - " %cond = OpSLessThan %intt %one %two\n" - " OpSelectionMerge %merge None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) selection.setBody(" OpSelectionMerge %merge None\n"); + // cannot share the same merge - loop.setBody(" OpLoopMerge %merge %loop None\n"); + if (is_shader) loop.setBody(" OpLoopMerge %merge %loop None\n"); - string str = header + nameOps("merge", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("merge", make_pair("func", "Main")) + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> selection; str += selection >> vector({merge, loop}); @@ -315,18 +342,23 @@ TEST_F(ValidateCFG, MergeBlockTargetedByMultipleHeaderBlocksSelectionBad) { str += "OpFunctionEnd\n"; CompileSuccessfully(str); - ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); - EXPECT_THAT(getDiagnosticString(), - MatchesRegex("Block .\\[merge\\] is already a merge block " - "for another header")); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("Block .\\[merge\\] is already a merge block " + "for another header")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } } -TEST_F(ValidateCFG, BranchTargetFirstBlockBad) { +TEST_P(ValidateCFG, BranchTargetFirstBlockBad) { Block entry("entry"); Block bad("bad"); Block end("end", SpvOpReturn); - string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("entry", "bad", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; str += bad >> entry; // Cannot target entry block @@ -340,17 +372,17 @@ TEST_F(ValidateCFG, BranchTargetFirstBlockBad) { "is targeted by block .\\[bad\\]")); } -TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { +TEST_P(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { Block entry("entry"); Block bad("bad", SpvOpBranchConditional); Block exit("exit", SpvOpReturn); - bad.setBody( - " %cond = OpSLessThan %intt %one %two\n" - " OpLoopMerge %entry %exit None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + bad.setBody(" OpLoopMerge %entry %exit None\n"); - string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("entry", "bad", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; str += bad >> vector({entry, exit}); // cannot target entry block @@ -364,19 +396,19 @@ TEST_F(ValidateCFG, BranchConditionalTrueTargetFirstBlockBad) { "is targeted by block .\\[bad\\]")); } -TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { +TEST_P(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { Block entry("entry"); Block bad("bad", SpvOpBranchConditional); Block t("t"); Block merge("merge"); Block end("end", SpvOpReturn); - bad.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpLoopMerge %merge %cont None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + bad.setBody("OpLoopMerge %merge %cont None\n"); - string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("entry", "bad", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; str += bad >> vector({t, entry}); @@ -391,7 +423,7 @@ TEST_F(ValidateCFG, BranchConditionalFalseTargetFirstBlockBad) { "is targeted by block .\\[bad\\]")); } -TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) { +TEST_P(ValidateCFG, SwitchTargetFirstBlockBad) { Block entry("entry"); Block bad("bad", SpvOpSwitch); Block block1("block1"); @@ -401,12 +433,12 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) { Block merge("merge"); Block end("end", SpvOpReturn); - bad.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %merge None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + bad.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("entry", "bad", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("entry", "bad", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> bad; str += bad >> vector({def, block1, block2, block3, entry}); @@ -425,21 +457,21 @@ TEST_F(ValidateCFG, SwitchTargetFirstBlockBad) { "is targeted by block .\\[bad\\]")); } -TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) { +TEST_P(ValidateCFG, BranchToBlockInOtherFunctionBad) { Block entry("entry"); Block middle("middle", SpvOpBranchConditional); Block end("end", SpvOpReturn); - middle.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %end None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + middle.setBody("OpSelectionMerge %end None\n"); Block entry2("entry2"); Block middle2("middle2"); Block end2("end2", SpvOpReturn); - string str = header + nameOps("middle2", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("middle2", make_pair("func", "Main")) + types_consts() + + "%func = OpFunction %voidt None %funct\n"; str += entry >> middle; str += middle >> vector({end, middle2}); @@ -460,7 +492,8 @@ TEST_F(ValidateCFG, BranchToBlockInOtherFunctionBad) { "defined in function .\\[Main\\]")); } -TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) { +TEST_P(ValidateCFG, HeaderDoesntDominatesMergeBad) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block head("head", SpvOpBranchConditional); Block f("f"); @@ -468,10 +501,11 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) { entry.setBody("%cond = OpSLessThan %intt %one %two\n"); - head.setBody("OpSelectionMerge %merge None\n"); + if (is_shader) head.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("head", "merge", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("head", "merge", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> merge; str += head >> vector({merge, f}); @@ -479,26 +513,33 @@ TEST_F(ValidateCFG, HeaderDoesntDominatesMergeBad) { str += merge; CompileSuccessfully(str); - ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); - EXPECT_THAT( - getDiagnosticString(), - MatchesRegex("Header block .\\[head\\] doesn't dominate its merge block " - ".\\[merge\\]")); + + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("The selection construct with the selection header " + ".\\[head\\] does not dominate the merge block " + ".\\[merge\\]")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } } -TEST_F(ValidateCFG, UnreachableMerge) { +TEST_P(ValidateCFG, UnreachableMerge) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block branch("branch", SpvOpBranchConditional); Block t("t", SpvOpReturn); Block f("f", SpvOpReturn); Block merge("merge", SpvOpReturn); - branch.setBody( - " %cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %merge None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) branch.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("branch", "merge", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("branch", "merge", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; str += branch >> vector({t, f}); @@ -511,19 +552,20 @@ TEST_F(ValidateCFG, UnreachableMerge) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { +TEST_P(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block branch("branch", SpvOpBranchConditional); Block t("t", SpvOpReturn); Block f("f", SpvOpReturn); Block merge("merge", SpvOpUnreachable); - branch.setBody( - " %cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %merge None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) branch.setBody("OpSelectionMerge %merge None\n"); - string str = header + nameOps("branch", "merge", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("branch", "merge", make_pair("func", "Main")) + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> branch; str += branch >> vector({t, f}); @@ -536,14 +578,14 @@ TEST_F(ValidateCFG, UnreachableMergeDefinedByOpUnreachable) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, UnreachableBlock) { +TEST_P(ValidateCFG, UnreachableBlock) { Block entry("entry"); Block unreachable("unreachable"); Block exit("exit", SpvOpReturn); - string str = header + + string str = header(GetParam()) + nameOps("unreachable", "exit", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> exit; str += unreachable >> exit; @@ -554,7 +596,8 @@ TEST_F(ValidateCFG, UnreachableBlock) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, UnreachableBranch) { +TEST_P(ValidateCFG, UnreachableBranch) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block unreachable("unreachable", SpvOpBranchConditional); Block unreachablechildt("unreachablechildt"); @@ -562,12 +605,11 @@ TEST_F(ValidateCFG, UnreachableBranch) { Block merge("merge"); Block exit("exit", SpvOpReturn); - unreachable.setBody( - " %cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %merge None\n"); - string str = header + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) unreachable.setBody("OpSelectionMerge %merge None\n"); + string str = header(GetParam()) + nameOps("unreachable", "exit", make_pair("func", "Main")) + - types_consts + "%func = OpFunction %voidt None %funct\n"; + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> exit; str += unreachable >> vector({unreachablechildt, unreachablechildf}); @@ -581,25 +623,25 @@ TEST_F(ValidateCFG, UnreachableBranch) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, EmptyFunction) { - string str = header + types_consts + +TEST_P(ValidateCFG, EmptyFunction) { + string str = header(GetParam()) + string(types_consts()) + "%func = OpFunction %voidt None %funct\n" + "OpFunctionEnd\n"; CompileSuccessfully(str); ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, SingleBlockLoop) { +TEST_P(ValidateCFG, SingleBlockLoop) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block loop("loop", SpvOpBranchConditional); Block exit("exit", SpvOpReturn); - loop.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpLoopMerge %exit %loop None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody("OpLoopMerge %exit %loop None\n"); - string str = - header + types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop; str += loop >> vector({loop, exit}); @@ -610,7 +652,8 @@ TEST_F(ValidateCFG, SingleBlockLoop) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, NestedLoops) { +TEST_P(ValidateCFG, NestedLoops) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block loop1("loop1"); Block loop1_cont_break_block("loop1_cont_break_block", @@ -620,14 +663,14 @@ TEST_F(ValidateCFG, NestedLoops) { Block loop1_merge("loop1_merge"); Block exit("exit", SpvOpReturn); - loop1.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpLoopMerge %loop1_merge %loop2 None\n"); - - loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) { + loop1.setBody("OpLoopMerge %loop1_merge %loop2 None\n"); + loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); + } - string str = - header + types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + nameOps("loop2", "loop2_merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop1; str += loop1 >> loop1_cont_break_block; @@ -641,29 +684,33 @@ TEST_F(ValidateCFG, NestedLoops) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -TEST_F(ValidateCFG, NestedSelection) { +TEST_P(ValidateCFG, NestedSelection) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); const int N = 256; vector if_blocks; vector merge_blocks; Block inner("inner"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if_blocks.emplace_back("if0", SpvOpBranchConditional); - if_blocks[0].setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpSelectionMerge %if_merge0 None\n"); + + if (is_shader) if_blocks[0].setBody("OpSelectionMerge %if_merge0 None\n"); merge_blocks.emplace_back("if_merge0", SpvOpReturn); for (int i = 1; i < N; i++) { stringstream ss; ss << i; if_blocks.emplace_back("if" + ss.str(), SpvOpBranchConditional); - if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n"); + if (is_shader) + if_blocks[i].setBody("OpSelectionMerge %if_merge" + ss.str() + " None\n"); merge_blocks.emplace_back("if_merge" + ss.str(), SpvOpBranch); } - string str = - header + types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + string(types_consts()) + + "%func = OpFunction %voidt None %funct\n"; + str += entry >> if_blocks[0]; for (int i = 0; i < N - 1; i++) { str += if_blocks[i] >> vector({if_blocks[i + 1], merge_blocks[i]}); } @@ -679,37 +726,282 @@ TEST_F(ValidateCFG, NestedSelection) { ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } -// TODO(umar): enable this test -TEST_F(ValidateCFG, DISABLED_BackEdgeBlockDoesntPostDominateContinueTargetBad) { +TEST_P(ValidateCFG, BackEdgeBlockDoesntPostDominateContinueTargetBad) { + bool is_shader = GetParam() == SpvCapabilityShader; Block entry("entry"); Block loop1("loop1", SpvOpBranchConditional); Block loop2("loop2", SpvOpBranchConditional); - Block loop2_merge("loop2_merge"); - Block loop1_merge("loop1_merge", SpvOpBranchConditional); + Block loop2_merge("loop2_merge", SpvOpBranchConditional); + Block be_block("be_block"); Block exit("exit", SpvOpReturn); - loop1.setBody( - "%cond = OpSLessThan %intt %one %two\n" - "OpLoopMerge %loop1_merge %loop2 None\n"); - - loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) { + loop1.setBody("OpLoopMerge %exit %loop2_merge None\n"); + loop2.setBody("OpLoopMerge %loop2_merge %loop2 None\n"); + } - string str = - header + types_consts + "%func = OpFunction %voidt None %funct\n"; + string str = header(GetParam()) + + nameOps("loop1", "loop2", "be_block", "loop2_merge") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; str += entry >> loop1; - str += loop1 >> vector({loop2, loop1_merge}); + str += loop1 >> vector({loop2, exit}); str += loop2 >> vector({loop2, loop2_merge}); - str += loop2_merge >> loop1_merge; - str += loop1_merge >> vector({loop1, exit}); + str += loop2_merge >> vector({be_block, exit}); + str += be_block >> loop1; str += exit; str += "OpFunctionEnd"; CompileSuccessfully(str); - ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + if (GetParam() == SpvCapabilityShader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[loop2_merge\\] is not post dominated by the " + "back-edge block .\\[be_block\\]")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchingToNonLoopHeaderBlockBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block split("split", SpvOpBranchConditional); + Block t("t"); + Block f("f"); + Block exit("exit", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) split.setBody("OpSelectionMerge %exit None\n"); + + string str = header(GetParam()) + nameOps("split", "f") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> split; + str += split >> vector({t, f}); + str += t >> exit; + str += f >> split; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + MatchesRegex("Back-edges \\(.\\[f\\] -> .\\[split\\]\\) can only " + "be formed between a block and a loop header.")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchingToSameNonLoopHeaderBlockBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block split("split", SpvOpBranchConditional); + Block exit("exit", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) split.setBody("OpSelectionMerge %exit None\n"); + + string str = header(GetParam()) + nameOps("split") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> split; + str += split >> vector({split, exit}); + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex( + "Back-edges \\(.\\[split\\] -> .\\[split\\]\\) can only be " + "formed between a block and a loop header.")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, MultipleBackEdgesToLoopHeaderBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n"); + + string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> vector({cont, merge}); + str += cont >> vector({loop, loop}); + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex( + "Loop header .\\[loop\\] targeted by multiple back-edges")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, ContinueTargetMustBePostDominatedByBackEdge) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cheader("cheader", SpvOpBranchConditional); + Block be_block("be_block"); + Block merge("merge", SpvOpReturn); + Block exit("exit", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody("OpLoopMerge %merge %cheader None\n"); + + string str = header(GetParam()) + nameOps("cheader", "be_block") + + types_consts() + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> vector({cheader, merge}); + str += cheader >> vector({exit, be_block}); + str += exit; // Branches out of a continue construct + str += be_block >> loop; + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[cheader\\] is not post dominated by the " + "back-edge block .\\[be_block\\]")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchOutOfConstructToMergeBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont", SpvOpBranchConditional); + Block merge("merge", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n"); + + string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> vector({cont, merge}); + str += cont >> vector({loop, merge}); + str += merge; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[loop\\] is not post dominated by the " + "back-edge block .\\[cont\\]")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_P(ValidateCFG, BranchOutOfConstructBad) { + bool is_shader = GetParam() == SpvCapabilityShader; + Block entry("entry"); + Block loop("loop", SpvOpBranchConditional); + Block cont("cont", SpvOpBranchConditional); + Block merge("merge"); + Block exit("exit", SpvOpReturn); + + entry.setBody("%cond = OpSLessThan %intt %one %two\n"); + if (is_shader) loop.setBody("OpLoopMerge %merge %loop None\n"); + + string str = header(GetParam()) + nameOps("cont", "loop") + types_consts() + + "%func = OpFunction %voidt None %funct\n"; + + str += entry >> loop; + str += loop >> vector({cont, merge}); + str += cont >> vector({loop, exit}); + str += merge >> exit; + str += exit; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + if (is_shader) { + ASSERT_EQ(SPV_ERROR_INVALID_CFG, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + MatchesRegex("The continue construct with the continue target " + ".\\[loop\\] is not post dominated by the " + "back-edge block .\\[cont\\]")); + } else { + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); + } +} + +TEST_F(ValidateCFG, OpSwitchToUnreachableBlock) { + Block entry("entry", SpvOpSwitch); + Block case0("case0"); + Block case1("case1"); + Block case2("case2"); + Block def("default", SpvOpUnreachable); + Block phi("phi", SpvOpReturn); + + string str = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %main "main" %id +OpExecutionMode %main LocalSize 1 1 1 +OpSource GLSL 430 +OpName %main "main" +OpDecorate %id BuiltIn GlobalInvocationId +%void = OpTypeVoid +%voidf = OpTypeFunction %void +%u32 = OpTypeInt 32 0 +%f32 = OpTypeFloat 32 +%uvec3 = OpTypeVector %u32 3 +%fvec3 = OpTypeVector %f32 3 +%uvec3ptr = OpTypePointer Input %uvec3 +%id = OpVariable %uvec3ptr Input +%one = OpConstant %u32 1 +%three = OpConstant %u32 3 +%main = OpFunction %void None %voidf +)"; + + entry.setBody( + "%idval = OpLoad %uvec3 %id\n" + "%x = OpCompositeExtract %u32 %idval 0\n" + "%selector = OpUMod %u32 %x %three\n" + "OpSelectionMerge %phi None\n"); + str += entry >> vector({def, case0, case1, case2}); + str += case1 >> phi; + str += def; + str += phi; + str += case0 >> phi; + str += case2 >> phi; + str += "OpFunctionEnd"; + + CompileSuccessfully(str); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); } /// TODO(umar): Switch instructions -/// TODO(umar): CFG branching outside of CFG construct /// TODO(umar): Nested CFG constructs -} +} /// namespace diff --git a/test/ValidateFixtures.cpp b/test/ValidateFixtures.cpp index 7467960cb7..409263617c 100644 --- a/test/ValidateFixtures.cpp +++ b/test/ValidateFixtures.cpp @@ -90,4 +90,5 @@ template class spvtest::ValidateBase< template class spvtest::ValidateBase< std::tuple, std::function>>>; +template class spvtest::ValidateBase; }