From 1ef95ae7b9e2c765abbbc6af6dff6c29c1fa4d13 Mon Sep 17 00:00:00 2001 From: Ashley Nelson Date: Tue, 27 Feb 2024 17:09:30 -0800 Subject: [PATCH] [Outlining] Fixes break reconstruction (#6352) Adds new visitBreakWithType and visitSwitchWithType functions to the IRBuilder API. These functions work around an assumption in IRBuilder that the module is being traversed in the fully nested format, i.e., that the destination scope of a break or switch has been visited before visiting the break or switch. Instead, the type of the destination scope is passed to IRBuilder. --- src/passes/Outlining.cpp | 14 ++++- src/wasm-ir-builder.h | 20 ++++++- src/wasm/wasm-ir-builder.cpp | 46 ++++++++++++++- test/lit/passes/outlining.wast | 100 +++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 6 deletions(-) diff --git a/src/passes/Outlining.cpp b/src/passes/Outlining.cpp index 345b3f9a2da..68c3d039707 100644 --- a/src/passes/Outlining.cpp +++ b/src/passes/Outlining.cpp @@ -132,7 +132,19 @@ struct ReconstructStringifyWalker : state == NotInSeq ? &existingBuilder : nullptr; if (builder) { - ASSERT_OK(builder->visit(curr)); + if (auto* expr = curr->dynCast()) { + Type type = expr->value ? expr->value->type : Type::none; + ASSERT_OK(builder->visitBreakWithType(expr, type)); + } else if (auto* expr = curr->dynCast()) { + Type type = expr->value ? expr->value->type : Type::none; + ASSERT_OK(builder->visitSwitchWithType(expr, type)); + } else { + // Assert ensures new unhandled branch instructions + // will quickly cause an error. Serves as a reminder to + // implement a new special-case visit*WithType. + assert(curr->is() || !Properties::isBranch(curr)); + ASSERT_OK(builder->visit(curr)); + } } DBG(printVisitExpression(curr)); diff --git a/src/wasm-ir-builder.h b/src/wasm-ir-builder.h index b37b352e3f6..d31a532a74c 100644 --- a/src/wasm-ir-builder.h +++ b/src/wasm-ir-builder.h @@ -224,10 +224,26 @@ class IRBuilder : public UnifiedExpressionVisitor> { [[nodiscard]] Result<> visitStructNew(StructNew*); [[nodiscard]] Result<> visitArrayNew(ArrayNew*); [[nodiscard]] Result<> visitArrayNewFixed(ArrayNewFixed*); + // Used to visit break exprs when traversing the module in the fully nested + // format. Break label destinations are assumed to have already been visited, + // with a corresponding push onto the scope stack. As a result, an error will + // return if a corresponding scope is not found for the break. [[nodiscard]] Result<> visitBreak(Break*, std::optional label = std::nullopt); + // Used to visit break nodes when traversing a single block without its + // context. The type indicates how many values the break carries to its + // destination. + [[nodiscard]] Result<> visitBreakWithType(Break*, Type); [[nodiscard]] Result<> + // Used to visit switch exprs when traversing the module in the fully nested + // format. Switch label destinations are assumed to have already been visited, + // with a corresponding push onto the scope stack. As a result, an error will + // return if a corresponding scope is not found for the switch. visitSwitch(Switch*, std::optional defaultLabel = std::nullopt); + // Used to visit switch nodes when traversing a single block without its + // context. The type indicates how many values the switch carries to its + // destination. + [[nodiscard]] Result<> visitSwitchWithType(Switch*, Type); [[nodiscard]] Result<> visitCall(Call*); [[nodiscard]] Result<> visitCallIndirect(CallIndirect*); [[nodiscard]] Result<> visitCallRef(CallRef*); @@ -535,8 +551,8 @@ class IRBuilder : public UnifiedExpressionVisitor> { [[nodiscard]] Result<> packageHoistedValue(const HoistedVal&, size_t sizeHint = 1); - [[nodiscard]] Result getBranchValue(Name labelName, - std::optional label); + [[nodiscard]] Result + getBranchValue(Expression* curr, Name labelName, std::optional label); void dump(); }; diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index b1bf8c8552c..8d87d0f3335 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -419,8 +419,14 @@ Result<> IRBuilder::visitArrayNewFixed(ArrayNewFixed* curr) { return Ok{}; } -Result IRBuilder::getBranchValue(Name labelName, +Result IRBuilder::getBranchValue(Expression* curr, + Name labelName, std::optional label) { + // As new branch instructions are added, one of the existing branch visit* + // functions is likely to be copied, along with its call to getBranchValue(). + // This assert serves as a reminder to also add an implementation of + // visit*WithType() for new branch instructions. + assert(curr->is() || curr->is()); if (!label) { auto index = getLabelIndex(labelName); CHECK_ERR(index); @@ -440,23 +446,57 @@ Result<> IRBuilder::visitBreak(Break* curr, std::optional label) { CHECK_ERR(cond); curr->condition = *cond; } - auto value = getBranchValue(curr->name, label); + auto value = getBranchValue(curr, curr->name, label); CHECK_ERR(value); curr->value = *value; return Ok{}; } +Result<> IRBuilder::visitBreakWithType(Break* curr, Type type) { + if (curr->condition) { + auto cond = pop(); + CHECK_ERR(cond); + curr->condition = *cond; + } + if (type == Type::none) { + curr->value = nullptr; + } else { + auto value = pop(type.size()); + CHECK_ERR(value) + curr->value = *value; + } + curr->finalize(); + push(curr); + return Ok{}; +} + Result<> IRBuilder::visitSwitch(Switch* curr, std::optional defaultLabel) { auto cond = pop(); CHECK_ERR(cond); curr->condition = *cond; - auto value = getBranchValue(curr->default_, defaultLabel); + auto value = getBranchValue(curr, curr->default_, defaultLabel); CHECK_ERR(value); curr->value = *value; return Ok{}; } +Result<> IRBuilder::visitSwitchWithType(Switch* curr, Type type) { + auto cond = pop(); + CHECK_ERR(cond); + curr->condition = *cond; + if (type == Type::none) { + curr->value = nullptr; + } else { + auto value = pop(type.size()); + CHECK_ERR(value) + curr->value = *value; + } + curr->finalize(); + push(curr); + return Ok{}; +} + Result<> IRBuilder::visitCall(Call* curr) { auto numArgs = wasm.getFunction(curr->target)->getNumParams(); curr->operands.resize(numArgs); diff --git a/test/lit/passes/outlining.wast b/test/lit/passes/outlining.wast index befce7513de..76f305db721 100644 --- a/test/lit/passes/outlining.wast +++ b/test/lit/passes/outlining.wast @@ -614,6 +614,106 @@ ) ) +;; Tests branch with condition is reconstructed without error. +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (func $outline$ (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + + ;; CHECK: (func $a (type $0) + ;; CHECK-NEXT: (block $label1 + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (loop $loop-in + ;; CHECK-NEXT: (br $label1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $a + (block $label1 + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (loop + (br $label1) + ) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + ) + ) +) + +;; Tests br_table instruction is reconstructed without error. +(module + ;; CHECK: (type $0 (func)) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (func $outline$ (type $0) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + + ;; CHECK: (func $a (type $1) (param $0 i32) (result i32) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (block $block + ;; CHECK-NEXT: (block $block0 + ;; CHECK-NEXT: (br_table $block $block0 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (i32.const 21) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (return + ;; CHECK-NEXT: (i32.const 20) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (call $outline$) + ;; CHECK-NEXT: (i32.const 22) + ;; CHECK-NEXT: ) + (func $a (param i32) (result i32) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (block + (block + (br_table 1 0 (local.get $0)) + (return (i32.const 21)) + ) + (return (i32.const 20)) + ) + (drop + (i32.const 2) + ) + (drop + (i32.const 1) + ) + (i32.const 22) + ) +) + ;; Tests return instructions are correctly filtered from being outlined. (module ;; CHECK: (type $0 (func (result i32)))