diff --git a/phlex/core/framework_graph.cpp b/phlex/core/framework_graph.cpp index 160e13c15..179269211 100644 --- a/phlex/core/framework_graph.cpp +++ b/phlex/core/framework_graph.cpp @@ -23,6 +23,9 @@ namespace phlex::experimental { layer_sentry::~layer_sentry() { + // To consider: We may want to skip the following logic if the framework prematurely + // needs to shut down. Keeping it enabled allows in-flight folds to + // complete. However, in some cases it may not be desirable to do this. auto flush_result = counters_.extract(store_->id()); auto flush_store = store_->make_flush(); if (not flush_result.empty()) { @@ -65,7 +68,16 @@ namespace phlex::experimental { eoms_.push(nullptr); } - framework_graph::~framework_graph() = default; + framework_graph::~framework_graph() + { + if (shutdown_on_error_) { + // When in an error state, we need to sanely pop the layer stack and wait for any tasks to finish. + while (!layers_.empty()) { + layers_.pop(); + } + graph_.wait_for_all(); + } + } std::size_t framework_graph::execution_counts(std::string const& node_name) const { @@ -82,10 +94,14 @@ namespace phlex::experimental { finalize(); run(); } catch (std::exception const& e) { + driver_.stop(); spdlog::error(e.what()); + shutdown_on_error_ = true; throw; } catch (...) 
{ + driver_.stop(); spdlog::error("Unknown exception during graph execution"); + shutdown_on_error_ = true; throw; } diff --git a/phlex/core/framework_graph.hpp b/phlex/core/framework_graph.hpp index b129755d1..ef824e32f 100644 --- a/phlex/core/framework_graph.hpp +++ b/phlex/core/framework_graph.hpp @@ -186,7 +186,7 @@ namespace phlex::experimental { std::queue pending_stores_; flush_counters counters_; std::stack layers_; - bool shutdown_{false}; + bool shutdown_on_error_{false}; }; } diff --git a/phlex/core/fwd.hpp b/phlex/core/fwd.hpp index 394717b11..b4e191fda 100644 --- a/phlex/core/fwd.hpp +++ b/phlex/core/fwd.hpp @@ -2,7 +2,6 @@ #define PHLEX_CORE_FWD_HPP #include "phlex/model/fwd.hpp" -#include "phlex/utilities/async_driver.hpp" #include @@ -20,10 +19,6 @@ namespace phlex::experimental { using end_of_message_ptr = std::shared_ptr; } -namespace phlex { - using framework_driver = experimental::async_driver; -} - #endif // PHLEX_CORE_FWD_HPP // Local Variables: diff --git a/phlex/driver.hpp b/phlex/driver.hpp index 51625abf6..e190180fa 100644 --- a/phlex/driver.hpp +++ b/phlex/driver.hpp @@ -6,10 +6,15 @@ #include "phlex/configuration.hpp" #include "phlex/core/fwd.hpp" #include "phlex/model/product_store.hpp" +#include "phlex/utilities/async_driver.hpp" #include #include +namespace phlex { + using framework_driver = experimental::async_driver; +} + namespace phlex::experimental::detail { // See note below. 
diff --git a/phlex/utilities/async_driver.hpp b/phlex/utilities/async_driver.hpp index 1f3847a07..46482cb76 100644 --- a/phlex/utilities/async_driver.hpp +++ b/phlex/utilities/async_driver.hpp @@ -50,12 +50,23 @@ namespace phlex::experimental { return std::exchange(current_, std::nullopt); } + void stop() + { + // API that should only be called by the framework_graph. + // NOTE(review): gear_ is written here without holding mutex_, while yield() reads it + // under the lock -- an unsynchronized access. If stop() runs after yield() has set + // current_ but before it calls cv_.wait(), this notification can also be lost. + // Consider acquiring std::unique_lock lock{mutex_} before setting gear_ -- verify. + gear_ = states::park; + cv_.notify_one(); + } + void yield(RT rt) { std::unique_lock lock{mutex_}; current_ = std::make_optional(std::move(rt)); cv_.notify_one(); cv_.wait(lock); + // NOTE(review): cv_.wait(lock) has no predicate, so a spurious wakeup resumes yield() + // early -- consider waiting on an explicit condition. + if (gear_ == states::park) { + // We can only observe park here if the framework needs to shut down prematurely. + throw std::runtime_error("Framework shutdown"); + } } private: diff --git a/test/framework_graph.cpp b/test/framework_graph.cpp index beb879ee8..5703240fb 100644 --- a/test/framework_graph.cpp +++ b/test/framework_graph.cpp @@ -1,4 +1,5 @@ #include "phlex/core/framework_graph.hpp" +#include "phlex/utilities/max_allowed_parallelism.hpp" #include "plugins/layer_generator.hpp" #include "catch2/catch_test_macros.hpp" @@ -39,3 +40,47 @@ TEST_CASE("Make progress with one thread", "[graph]") CHECK(g.execution_counts("provide_number") == 1000); CHECK(g.execution_counts("observe_number") == 1000); } + +TEST_CASE("Stop driver when workflow throws exception", "[graph]") +{ + experimental::layer_generator gen; + gen.add_layer("spill", {"job", 1000}); + + experimental::framework_graph g{driver_for_test(gen)}; + g.provide( + "throw_exception", + [](data_cell_index const&) -> unsigned int { + throw std::runtime_error("Error to stop driver"); + }, + concurrency::unlimited) + .output_product("number"_in("spill")); + + // Must have at least one downstream node that requires something of the + // provider...otherwise provider will not be executed.
+ g.observe( + "downstream_of_exception", [](unsigned int) {}, concurrency::unlimited) + .input_family("number"_in("spill")); + + CHECK_THROWS(g.execute()); + + // There are N + 1 potentially existing threads for a framework job, where N corresponds + // to the number configured by the user, and 1 corresponds to the separate std::jthread + // created by the async_driver. Each "pull" from the async_driver happens in a + // serialized way. However, once an index has been pulled from the async_driver by the + // flow graph, that index is sent to downstream nodes for further processing. + // + // The first node that processes that index is a provider that immediately throws an + // exception. This places the framework graph in an error state, where the async_driver + // is short-circuited from doing further processing. + // + // We assume that one of those threads triggers the exception and that the remaining + // threads must be permitted to complete. + CHECK(gen.emitted_cells("/job/spill") <= + static_cast<std::size_t>(experimental::max_allowed_parallelism::active_value() + 1)); + + // A node has not "executed" until it has returned successfully. For that reason, + // neither the "throw_exception" provider nor the "downstream_of_exception" observer + // will have executed. + CHECK(g.execution_counts("throw_exception") == 0ull); + CHECK(g.execution_counts("downstream_of_exception") == 0ull); +}