diff --git a/client.cpp b/client.cpp
index a088d47..fffb38e 100755
--- a/client.cpp
+++ b/client.cpp
@@ -646,6 +646,32 @@ void client_group::run(void)
     event_base_dispatch(m_base);
 }
 
+void client_group::interrupt(void)
+{
+    // Mark all clients as interrupted
+    set_all_clients_interrupted();
+    // Break the event loop to stop processing
+    event_base_loopbreak(m_base);
+    // Set end time for all clients as close as possible to the loop break
+    finalize_all_clients();
+}
+
+void client_group::finalize_all_clients(void)
+{
+    for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
+        client* c = *i;
+        c->set_end_time();
+    }
+}
+
+void client_group::set_all_clients_interrupted(void)
+{
+    for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
+        client* c = *i;
+        c->get_stats()->set_interrupted(true);
+    }
+}
+
 unsigned long int client_group::get_total_bytes(void)
 {
     unsigned long int total_bytes = 0;
diff --git a/client.h b/client.h
index 6f599a4..5696cf4 100755
--- a/client.h
+++ b/client.h
@@ -210,6 +210,9 @@ class client_group {
     int create_clients(int count);
     int prepare(void);
     void run(void);
+    void interrupt(void);
+    void finalize_all_clients(void);
+    void set_all_clients_interrupted(void);
 
     void write_client_stats(const char *prefix);
 
diff --git a/memtier_benchmark.cpp b/memtier_benchmark.cpp
index f05eb2d..7ed7113 100755
--- a/memtier_benchmark.cpp
+++ b/memtier_benchmark.cpp
@@ -32,6 +32,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <signal.h>
 
 #ifdef USE_TLS
 #include <openssl/ssl.h>
@@ -65,6 +66,16 @@
 
 static int log_level = 0;
+
+// Global flag for signal handling
+static volatile sig_atomic_t g_interrupted = 0;
+
+// Signal handler for Ctrl+C
+static void sigint_handler(int signum)
+{
+    (void)signum; // unused parameter
+    g_interrupted = 1;
+}
 
 void benchmark_log_file_line(int level, const char *filename, unsigned int line, const char *fmt, ...)
 {
     if (level > log_level)
@@ -1329,6 +1340,25 @@ run_stats run_benchmark(int run_id, benchmark_config* cfg, object_generator* obj
         active_threads = 0;
         sleep(1);
 
+        // Check for Ctrl+C interrupt
+        if (g_interrupted) {
+            // Calculate elapsed time before interrupting
+            unsigned long int elapsed_duration = 0;
+            unsigned int thread_counter = 0;
+            for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
+                thread_counter++;
+                float factor = ((float)(thread_counter - 1) / thread_counter);
+                elapsed_duration = factor * elapsed_duration + (float)(*i)->m_cg->get_duration_usec() / thread_counter;
+            }
+            fprintf(stderr, "\n[RUN #%u] Interrupted by user (Ctrl+C) after %.1f secs, stopping threads...\n",
+                    run_id, (float)elapsed_duration / 1000000);
+            // Interrupt all threads (marks clients as interrupted, breaks event loops, and finalizes stats)
+            for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
+                (*i)->m_cg->interrupt();
+            }
+            break;
+        }
+
         unsigned long int total_ops = 0;
         unsigned long int total_bytes = 0;
         unsigned long int duration = 0;
@@ -1496,6 +1526,9 @@ static void cleanup_openssl(void)
 
 int main(int argc, char *argv[])
 {
+    // Install signal handler for Ctrl+C
+    signal(SIGINT, sigint_handler);
+
     benchmark_config cfg = benchmark_config();
     cfg.arbitrary_commands = new arbitrary_command_list();
 
diff --git a/run_stats.cpp b/run_stats.cpp
index a5e38a0..1110d9a 100644
--- a/run_stats.cpp
+++ b/run_stats.cpp
@@ -112,6 +112,7 @@ inline timeval timeval_factorial_average(timeval a, timeval b, unsigned int weig
 
 run_stats::run_stats(benchmark_config *config) :
     m_config(config),
+    m_interrupted(false),
     m_totals(),
     m_cur_stats(0)
 {
@@ -792,6 +793,11 @@ void run_stats::merge(const run_stats& other, int iteration)
     m_start_time = timeval_factorial_average( m_start_time, other.m_start_time, iteration );
     m_end_time = timeval_factorial_average( m_end_time, other.m_end_time, iteration );
 
+    // If any run was interrupted, mark the merged result as interrupted
+    if (other.m_interrupted) {
+        m_interrupted = true;
+    }
+
     // aggregate the one_second_stats vectors. this is not efficient
     // but it's not really important (small numbers, not realtime)
     for (std::list<one_second_stats>::const_iterator other_i = other.m_stats.begin();
@@ -1221,6 +1227,7 @@ void run_stats::print_json(json_handler *jsonhandler, arbitrary_command_list& co
         jsonhandler->write_obj("Finish time","%lld", end_time_ms);
         jsonhandler->write_obj("Total duration","%lld", end_time_ms-start_time_ms);
         jsonhandler->write_obj("Time unit","\"%s\"","MILLISECONDS");
+        jsonhandler->write_obj("Interrupted","\"%s\"", m_interrupted ? "true" : "false");
"true" : "false"); jsonhandler->close_nesting(); } std::vector timestamps = get_one_sec_cmd_stats_timestamp(); diff --git a/run_stats.h b/run_stats.h index 14796ea..2d137cf 100644 --- a/run_stats.h +++ b/run_stats.h @@ -93,6 +93,7 @@ class run_stats { struct timeval m_start_time; struct timeval m_end_time; + bool m_interrupted; totals m_totals; @@ -122,6 +123,8 @@ class run_stats { void setup_arbitrary_commands(size_t n_arbitrary_commands); void set_start_time(struct timeval* start_time); void set_end_time(struct timeval* end_time); + void set_interrupted(bool interrupted) { m_interrupted = interrupted; } + bool get_interrupted() const { return m_interrupted; } void update_get_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency, unsigned int hits, unsigned int misses); void update_set_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency); diff --git a/tests/tests_oss_simple_flow.py b/tests/tests_oss_simple_flow.py index c62ffec..bbb450f 100644 --- a/tests/tests_oss_simple_flow.py +++ b/tests/tests_oss_simple_flow.py @@ -1,5 +1,9 @@ import tempfile import json +import time +import signal +import subprocess +import os from include import * from mb import Benchmark, RunConfig @@ -907,3 +911,88 @@ def test_uri_invalid_database(env): # benchmark.run() should return False for invalid database number memtier_ok = benchmark.run() env.assertFalse(memtier_ok) + + +def test_interrupt_signal_handling(env): + """Test that Ctrl+C (SIGINT) properly stops the benchmark and outputs correct statistics""" + # Use a large number of requests so the test doesn't finish before we interrupt it + benchmark_specs = {"name": env.testName, "args": ['--requests=1000000', '--hide-histogram']} + addTLSArgs(benchmark_specs, env) + config = get_default_memtier_config(threads=4, clients=50, requests=1000000) + master_nodes_list = env.getMasterNodesList() + + add_required_env_arguments(benchmark_specs, config, env, master_nodes_list) + + # Create a temporary directory + test_dir = tempfile.mkdtemp() + config = RunConfig(test_dir, env.testName, config, {}) + ensure_clean_benchmark_folder(config.results_dir) + + benchmark = Benchmark.from_json(config, benchmark_specs) + + # Start the benchmark process manually so we can send SIGINT + import logging + logging.debug(' Command: %s', ' '.join(benchmark.args)) + + stderr_file = open(os.path.join(config.results_dir, 'mb.stderr'), 'wb') + process = subprocess.Popen( + stdin=None, stdout=subprocess.PIPE, stderr=stderr_file, + executable=benchmark.binary, args=benchmark.args) + + # Wait 3 seconds then send SIGINT + time.sleep(3) + process.send_signal(signal.SIGINT) + + # Wait for process to finish + _stdout, _ = process.communicate() + stderr_file.close() + + # Write stdout to file + benchmark.write_file('mb.stdout', _stdout) + + # Read stderr to check for interrupt message + with open(os.path.join(config.results_dir, 'mb.stderr'), 'r') as stderr: + stderr_content = stderr.read() + # Check that the interrupt message is present and shows elapsed time + env.assertTrue("Interrupted by user (Ctrl+C) after" in stderr_content) + env.assertTrue("secs, stopping threads..." 
+
+    # Check JSON output
+    json_filename = '{0}/mb.json'.format(config.results_dir)
+    env.assertTrue(os.path.isfile(json_filename))
+
+    with open(json_filename) as results_json:
+        results_dict = json.load(results_json)
+
+    # Check that Runtime section exists and has Interrupted flag
+    env.assertTrue("ALL STATS" in results_dict)
+    env.assertTrue("Runtime" in results_dict["ALL STATS"])
+    runtime = results_dict["ALL STATS"]["Runtime"]
+
+    # Verify interrupted flag is set to "true"
+    env.assertTrue("Interrupted" in runtime)
+    env.assertEqual(runtime["Interrupted"], "true")
+
+    # Verify duration is reasonable (should be around 3 seconds, give or take)
+    env.assertTrue("Total duration" in runtime)
+    duration_ms = runtime["Total duration"]
+    env.assertTrue(duration_ms >= 2000)  # At least 2 seconds
+    env.assertTrue(duration_ms <= 5000)  # At most 5 seconds
+
+    # Verify that throughput metrics are NOT zero
+    totals_metrics = results_dict["ALL STATS"]["Totals"]
+
+    # Check ops/sec is not zero
+    env.assertTrue("Ops/sec" in totals_metrics)
+    total_ops_sec = totals_metrics["Ops/sec"]
+    env.assertTrue(total_ops_sec > 0)
+
+    # Check latency metrics are not zero
+    env.assertTrue("Latency" in totals_metrics)
+    total_latency = totals_metrics["Latency"]
+    env.assertTrue(total_latency > 0)
+
+    # Check that we actually processed some operations
+    env.assertTrue("Count" in totals_metrics)
+    total_count = totals_metrics["Count"]
+    env.assertTrue(total_count > 0)
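
Editor's note, not part of the patch: the diff relies on the classic async-signal-safe interrupt pattern. The SIGINT handler does nothing but set a volatile sig_atomic_t flag; the flag is then polled from normal code (the sleep(1) loop in run_benchmark()), which calls event_base_loopbreak() so that event_base_dispatch() returns. The standalone sketch below shows the same pattern in miniature, assuming libevent 2.x (build with -levent); the periodic timer stands in for run_benchmark()'s polling loop.

// Sketch only: the flag-and-loopbreak interrupt pattern in miniature.
#include <signal.h>
#include <stdio.h>
#include <event2/event.h>

static volatile sig_atomic_t g_interrupted = 0;

static void sigint_handler(int signum)
{
    (void)signum;
    g_interrupted = 1;   // async-signal-safe: only set a flag
}

// Fires once per second; the safe place to react to the flag
static void check_interrupt_cb(evutil_socket_t fd, short what, void *arg)
{
    (void)fd; (void)what;
    struct event_base *base = (struct event_base *) arg;
    if (g_interrupted) {
        fprintf(stderr, "Interrupted, breaking event loop...\n");
        event_base_loopbreak(base);   // makes event_base_dispatch() return
    }
}

int main(void)
{
    signal(SIGINT, sigint_handler);

    struct event_base *base = event_base_new();
    struct event *tick = event_new(base, -1, EV_PERSIST, check_interrupt_cb, base);
    struct timeval one_sec = { 1, 0 };
    event_add(tick, &one_sec);

    event_base_dispatch(base);   // runs until Ctrl+C breaks the loop

    event_free(tick);
    event_base_free(base);
    return 0;
}

Keeping the handler to a single flag store avoids calling non-async-signal-safe functions such as fprintf() in signal context; libevent's evsignal_new() is an alternative that dispatches signals inside the loop itself.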
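A second note, on the elapsed_duration arithmetic in run_benchmark(): the update elapsed = ((n - 1) / n) * elapsed + x_n / n is the incremental form of the arithmetic mean, so after the loop elapsed_duration holds the average get_duration_usec() across all thread groups, mirroring how timeval_factorial_average() merges timestamps in run_stats.cpp. A small hypothetical check with made-up durations (the values are illustrative, not from the patch):

// Sketch only: verifies the incremental-mean update used for
// elapsed_duration against a direct sum-and-divide mean.
#include <cstdio>
#include <vector>

int main(void)
{
    // Hypothetical per-thread durations in usec (illustrative values)
    std::vector<unsigned long int> durations;
    durations.push_back(3000000);
    durations.push_back(3100000);
    durations.push_back(2900000);
    durations.push_back(3050000);

    // Incremental form, as in the loop in run_benchmark()
    unsigned long int elapsed = 0;
    unsigned int n = 0;
    for (std::vector<unsigned long int>::iterator i = durations.begin(); i != durations.end(); i++) {
        n++;
        float factor = ((float)(n - 1) / n);
        elapsed = factor * elapsed + (float)(*i) / n;
    }

    // Direct mean for comparison
    unsigned long int sum = 0;
    for (std::vector<unsigned long int>::iterator i = durations.begin(); i != durations.end(); i++)
        sum += *i;

    printf("incremental=%lu direct=%lu\n", elapsed,
           (unsigned long int)(sum / durations.size()));
    return 0;
}

Both print roughly 3012500 usec; any small drift comes from float rounding and the truncating assignment back into the unsigned long.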