Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,32 @@ void client_group::run(void)
event_base_dispatch(m_base);
}

void client_group::interrupt(void)
{
// Mark all clients as interrupted
set_all_clients_interrupted();
// Break the event loop to stop processing
event_base_loopbreak(m_base);
// Set end time for all clients as close as possible to the loop break
finalize_all_clients();
}

void client_group::finalize_all_clients(void)
{
for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
client* c = *i;
c->set_end_time();
}
}

void client_group::set_all_clients_interrupted(void)
{
for (std::vector<client*>::iterator i = m_clients.begin(); i != m_clients.end(); i++) {
client* c = *i;
c->get_stats()->set_interrupted(true);
}
}

unsigned long int client_group::get_total_bytes(void)
{
unsigned long int total_bytes = 0;
Expand Down
3 changes: 3 additions & 0 deletions client.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ class client_group {
int create_clients(int count);
int prepare(void);
void run(void);
void interrupt(void);
void finalize_all_clients(void);
void set_all_clients_interrupted(void);

void write_client_stats(const char *prefix);

Expand Down
33 changes: 33 additions & 0 deletions memtier_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <errno.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <signal.h>

#ifdef USE_TLS
#include <openssl/crypto.h>
Expand Down Expand Up @@ -65,6 +66,16 @@


static int log_level = 0;

// Global flag for signal handling
static volatile sig_atomic_t g_interrupted = 0;

// Signal handler for Ctrl+C
static void sigint_handler(int signum)
{
(void)signum; // unused parameter
g_interrupted = 1;
}
void benchmark_log_file_line(int level, const char *filename, unsigned int line, const char *fmt, ...)
{
if (level > log_level)
Expand Down Expand Up @@ -1329,6 +1340,25 @@ run_stats run_benchmark(int run_id, benchmark_config* cfg, object_generator* obj
active_threads = 0;
sleep(1);

// Check for Ctrl+C interrupt
if (g_interrupted) {
// Calculate elapsed time before interrupting
unsigned long int elapsed_duration = 0;
unsigned int thread_counter = 0;
for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
thread_counter++;
float factor = ((float)(thread_counter - 1) / thread_counter);
elapsed_duration = factor * elapsed_duration + (float)(*i)->m_cg->get_duration_usec() / thread_counter;
}
fprintf(stderr, "\n[RUN #%u] Interrupted by user (Ctrl+C) after %.1f secs, stopping threads...\n",
run_id, (float)elapsed_duration / 1000000);
// Interrupt all threads (marks clients as interrupted, breaks event loops, and finalizes stats)
for (std::vector<cg_thread*>::iterator i = threads.begin(); i != threads.end(); i++) {
(*i)->m_cg->interrupt();
}
break;
}

unsigned long int total_ops = 0;
unsigned long int total_bytes = 0;
unsigned long int duration = 0;
Expand Down Expand Up @@ -1496,6 +1526,9 @@ static void cleanup_openssl(void)

int main(int argc, char *argv[])
{
// Install signal handler for Ctrl+C
signal(SIGINT, sigint_handler);

benchmark_config cfg = benchmark_config();
cfg.arbitrary_commands = new arbitrary_command_list();

Expand Down
7 changes: 7 additions & 0 deletions run_stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ inline timeval timeval_factorial_average(timeval a, timeval b, unsigned int weig

run_stats::run_stats(benchmark_config *config) :
m_config(config),
m_interrupted(false),
m_totals(),
m_cur_stats(0)
{
Expand Down Expand Up @@ -792,6 +793,11 @@ void run_stats::merge(const run_stats& other, int iteration)
m_start_time = timeval_factorial_average( m_start_time, other.m_start_time, iteration );
m_end_time = timeval_factorial_average( m_end_time, other.m_end_time, iteration );

// If any run was interrupted, mark the merged result as interrupted
if (other.m_interrupted) {
m_interrupted = true;
}

// aggregate the one_second_stats vectors. this is not efficient
// but it's not really important (small numbers, not realtime)
for (std::list<one_second_stats>::const_iterator other_i = other.m_stats.begin();
Expand Down Expand Up @@ -1221,6 +1227,7 @@ void run_stats::print_json(json_handler *jsonhandler, arbitrary_command_list& co
jsonhandler->write_obj("Finish time","%lld", end_time_ms);
jsonhandler->write_obj("Total duration","%lld", end_time_ms-start_time_ms);
jsonhandler->write_obj("Time unit","\"%s\"","MILLISECONDS");
jsonhandler->write_obj("Interrupted","\"%s\"", m_interrupted ? "true" : "false");
jsonhandler->close_nesting();
}
std::vector<unsigned int> timestamps = get_one_sec_cmd_stats_timestamp();
Expand Down
3 changes: 3 additions & 0 deletions run_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class run_stats {

struct timeval m_start_time;
struct timeval m_end_time;
bool m_interrupted;

totals m_totals;

Expand Down Expand Up @@ -122,6 +123,8 @@ class run_stats {
void setup_arbitrary_commands(size_t n_arbitrary_commands);
void set_start_time(struct timeval* start_time);
void set_end_time(struct timeval* end_time);
void set_interrupted(bool interrupted) { m_interrupted = interrupted; }
bool get_interrupted() const { return m_interrupted; }

void update_get_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency, unsigned int hits, unsigned int misses);
void update_set_op(struct timeval* ts, unsigned int bytes_rx, unsigned int bytes_tx, unsigned int latency);
Expand Down
89 changes: 89 additions & 0 deletions tests/tests_oss_simple_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import tempfile
import json
import time
import signal
import subprocess
import os
from include import *
from mb import Benchmark, RunConfig

Expand Down Expand Up @@ -907,3 +911,88 @@ def test_uri_invalid_database(env):
# benchmark.run() should return False for invalid database number
memtier_ok = benchmark.run()
env.assertFalse(memtier_ok)


def test_interrupt_signal_handling(env):
"""Test that Ctrl+C (SIGINT) properly stops the benchmark and outputs correct statistics"""
# Use a large number of requests so the test doesn't finish before we interrupt it
benchmark_specs = {"name": env.testName, "args": ['--requests=1000000', '--hide-histogram']}
addTLSArgs(benchmark_specs, env)
config = get_default_memtier_config(threads=4, clients=50, requests=1000000)
master_nodes_list = env.getMasterNodesList()

add_required_env_arguments(benchmark_specs, config, env, master_nodes_list)

# Create a temporary directory
test_dir = tempfile.mkdtemp()
config = RunConfig(test_dir, env.testName, config, {})
ensure_clean_benchmark_folder(config.results_dir)

benchmark = Benchmark.from_json(config, benchmark_specs)

# Start the benchmark process manually so we can send SIGINT
import logging
logging.debug(' Command: %s', ' '.join(benchmark.args))

stderr_file = open(os.path.join(config.results_dir, 'mb.stderr'), 'wb')
process = subprocess.Popen(
stdin=None, stdout=subprocess.PIPE, stderr=stderr_file,
executable=benchmark.binary, args=benchmark.args)

# Wait 3 seconds then send SIGINT
time.sleep(3)
process.send_signal(signal.SIGINT)

# Wait for process to finish
_stdout, _ = process.communicate()
stderr_file.close()

# Write stdout to file
benchmark.write_file('mb.stdout', _stdout)

# Read stderr to check for interrupt message
with open(os.path.join(config.results_dir, 'mb.stderr'), 'r') as stderr:
stderr_content = stderr.read()
# Check that the interrupt message is present and shows elapsed time
env.assertTrue("Interrupted by user (Ctrl+C) after" in stderr_content)
env.assertTrue("secs, stopping threads..." in stderr_content)

# Check JSON output
json_filename = '{0}/mb.json'.format(config.results_dir)
env.assertTrue(os.path.isfile(json_filename))

with open(json_filename) as results_json:
results_dict = json.load(results_json)

# Check that Runtime section exists and has Interrupted flag
env.assertTrue("ALL STATS" in results_dict)
env.assertTrue("Runtime" in results_dict["ALL STATS"])
runtime = results_dict["ALL STATS"]["Runtime"]

# Verify interrupted flag is set to "true"
env.assertTrue("Interrupted" in runtime)
env.assertEqual(runtime["Interrupted"], "true")

# Verify duration is reasonable (should be around 3 seconds, give or take)
env.assertTrue("Total duration" in runtime)
duration_ms = runtime["Total duration"]
env.assertTrue(duration_ms >= 2000) # At least 2 seconds
env.assertTrue(duration_ms <= 5000) # At most 5 seconds

# Verify that throughput metrics are NOT zero
totals_metrics = results_dict["ALL STATS"]["Totals"]

# Check ops/sec is not zero
env.assertTrue("Ops/sec" in totals_metrics)
total_ops_sec = totals_metrics["Ops/sec"]
env.assertTrue(total_ops_sec > 0)

# Check latency metrics are not zero
env.assertTrue("Latency" in totals_metrics)
total_latency = totals_metrics["Latency"]
env.assertTrue(total_latency > 0)

# Check that we actually processed some operations
env.assertTrue("Count" in totals_metrics)
total_count = totals_metrics["Count"]
env.assertTrue(total_count > 0)
Loading