Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add callbacks to highspy, use std::function #1447

Merged
merged 14 commits into from
Oct 31, 2023
Merged
4 changes: 2 additions & 2 deletions check/TestCAPI.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ const HighsInt dev_run = 0;
const double double_equal_tolerance = 1e-5;

static void userCallback(const int callback_type, const char* message,
const struct HighsCallbackDataOut* data_out,
struct HighsCallbackDataIn* data_in,
const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in,
void* user_callback_data) {
// Extract the double value pointed to from void* user_callback_data
const double local_callback_data = user_callback_data == NULL ? -1 : *(double*)user_callback_data;
Expand Down
175 changes: 89 additions & 86 deletions check/TestCallbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "HCheckConfig.h"
#include "Highs.h"
#include "catch.hpp"
#include "lp_data/HighsCallback.h"

const bool dev_run = false;

Expand All @@ -26,98 +27,100 @@ using std::strncmp;
using std::strstr;

// Callback that saves message for comparison
static void myLogCallback(const int callback_type, const char* message,
HighsCallbackFunctionType myLogCallback =
[](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in,
void* user_callback_data) { strcpy(printed_log, message.c_str()); };

HighsCallbackFunctionType userInterruptCallback =
[](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out, HighsCallbackDataIn* data_in,
void* user_callback_data) {
// Extract local_callback_data from user_callback_data unless it
// is nullptr
if (callback_type == kCallbackMipImprovingSolution) {
// Use local_callback_data to maintain the objective value from
// the previous callback
assert(user_callback_data);
// Extract the double value pointed to from void* user_callback_data
const double local_callback_data = *(double*)user_callback_data;
if (dev_run)
printf(
"userCallback(type %2d; data %11.4g): %s with objective %g and "
"solution[0] = %g\n",
callback_type, local_callback_data, message.c_str(),
data_out->objective_function_value, data_out->mip_solution[0]);
REQUIRE(local_callback_data >= data_out->objective_function_value);
// Update the double value pointed to from void* user_callback_data
*(double*)user_callback_data = data_out->objective_function_value;
} else {
const int local_callback_data =
user_callback_data ? static_cast<int>(reinterpret_cast<intptr_t>(
user_callback_data))
: kUserCallbackNoData;
if (user_callback_data) {
REQUIRE(local_callback_data == kUserCallbackData);
} else {
REQUIRE(local_callback_data == kUserCallbackNoData);
}
if (callback_type == kCallbackLogging) {
if (dev_run)
printf("userInterruptCallback(type %2d; data %2d): %s",
callback_type, local_callback_data, message.c_str());
} else if (callback_type == kCallbackSimplexInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->simplex_iteration_count));
data_in->user_interrupt = data_out->simplex_iteration_count >
adlittle_simplex_iteration_limit;
} else if (callback_type == kCallbackIpmInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message.c_str(),
int(data_out->ipm_iteration_count));
data_in->user_interrupt =
data_out->ipm_iteration_count > adlittle_ipm_iteration_limit;
} else if (callback_type == kCallbackMipInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with Bounds "
"(%11.4g, %11.4g); Gap = %11.4g; Objective = "
"%g\n",
callback_type, local_callback_data, message.c_str(),
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value);
data_in->user_interrupt =
data_out->objective_function_value < egout_objective_target;
}
}
};

std::function<void(int, const std::string&, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*)>
userDataCallback = [](int callback_type, const std::string& message,
const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in,
void* user_callback_data) {
strcpy(printed_log, message);
}

static void userInterruptCallback(const int callback_type, const char* message,
const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in,
void* user_callback_data) {
// Extract local_callback_data from user_callback_data unless it
// is nullptr
if (callback_type == kCallbackMipImprovingSolution) {
// Use local_callback_data to maintain the objective value from
// the previous callback
assert(user_callback_data);
// Extract the double value pointed to from void* user_callback_data
const double local_callback_data = *(double*)user_callback_data;
if (dev_run)
printf(
"userCallback(type %2d; data %11.4g): %s with objective %g and "
"solution[0] = %g\n",
callback_type, local_callback_data, message,
data_out->objective_function_value, data_out->mip_solution[0]);
REQUIRE(local_callback_data >= data_out->objective_function_value);
// Update the double value pointed to from void* user_callback_data
*(double*)user_callback_data = data_out->objective_function_value;
} else {
const int local_callback_data =
user_callback_data
? static_cast<int>(reinterpret_cast<intptr_t>(user_callback_data))
: kUserCallbackNoData;
if (user_callback_data) {
REQUIRE(local_callback_data == kUserCallbackData);
} else {
REQUIRE(local_callback_data == kUserCallbackNoData);
}
if (callback_type == kCallbackLogging) {
if (dev_run)
printf("userInterruptCallback(type %2d; data %2d): %s", callback_type,
local_callback_data, message);
} else if (callback_type == kCallbackSimplexInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message,
int(data_out->simplex_iteration_count));
data_in->user_interrupt =
data_out->simplex_iteration_count > adlittle_simplex_iteration_limit;
} else if (callback_type == kCallbackIpmInterrupt) {
assert(callback_type == kCallbackMipInterrupt ||
callback_type == kCallbackMipLogging ||
callback_type == kCallbackMipImprovingSolution);
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with iteration "
"count = "
"%d\n",
callback_type, local_callback_data, message,
int(data_out->ipm_iteration_count));
data_in->user_interrupt =
data_out->ipm_iteration_count > adlittle_ipm_iteration_limit;
} else if (callback_type == kCallbackMipInterrupt) {
if (dev_run)
printf(
"userInterruptCallback(type %2d; data %2d): %s with Bounds "
"(%11.4g, %11.4g); Gap = %11.4g; Objective = "
"%g\n",
callback_type, local_callback_data, message,
"userDataCallback: Node count = %" PRId64
"; Time = %6.2f; "
"Bounds (%11.4g, %11.4g); Gap = %11.4g; Objective = %11.4g: %s\n",
data_out->mip_node_count, data_out->running_time,
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value);
data_in->user_interrupt =
data_out->objective_function_value < egout_objective_target;
}
}
}

static void userDataCallback(const int callback_type, const char* message,
const HighsCallbackDataOut* data_out,
HighsCallbackDataIn* data_in,
void* user_callback_data) {
assert(callback_type == kCallbackMipInterrupt ||
callback_type == kCallbackMipLogging ||
callback_type == kCallbackMipImprovingSolution);
if (dev_run)
printf("userDataCallback: Node count = %" PRId64
"; Time = %6.2f; "
"Bounds (%11.4g, %11.4g); Gap = %11.4g; Objective = %11.4g: %s\n",
data_out->mip_node_count, data_out->running_time,
data_out->mip_dual_bound, data_out->mip_primal_bound,
data_out->mip_gap, data_out->objective_function_value, message);
}
data_out->mip_gap, data_out->objective_function_value,
message.c_str());
};

TEST_CASE("my-callback-logging", "[highs-callback]") {
bool output_flag = true; // Still runs quietly
Expand Down
16 changes: 8 additions & 8 deletions check/pythontest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
execfile("../src/interfaces/highs_lp_solver.py")

cc = (1.0,-2.0)
cl = (0.0,0.0)
cu = (10.0,10.0)
ru = (2.0,1.0)
rl = (0.0,0.0)
astart = (0,2,4)
aindex = (0,1,0,1)
avalue = (1.0,2.0,1.0,3.0)
cc = (1.0, -2.0)
cl = (0.0, 0.0)
cu = (10.0, 10.0)
ru = (2.0, 1.0)
rl = (0.0, 0.0)
astart = (0, 2, 4)
aindex = (0, 1, 0, 1)
avalue = (1.0, 2.0, 1.0, 3.0)
call_highs(cc, cl, cu, rl, ru, astart, aindex, avalue)
25 changes: 15 additions & 10 deletions docs/src/callbacks.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Callbacks

The HiGHS callback allows user actions to be performed within HiGHS. There is one generic callback method that can be defined by a user, with specific callback scenarios communicated to the user via a parameter. Particular callbacks must be activated (and can be deactivated) as described below. The user callback can be given any name and, below, is called `userCallback`. Its definition is

```bash
The HiGHS callback allows user actions to be performed within HiGHS. There is
one generic callback method that can be defined by a user, with specific
callback scenarios communicated to the user via a parameter. Particular
callbacks must be activated (and can be deactivated) as described below. The
user callback can be given any name and, below, is called `userCallback`. Its
definition is

```cpp
void userCallback(const int callback_type,
const char* message,
const HighsCallbackDataOut* data_out,
Expand All @@ -22,12 +27,12 @@ The logging callback type is a cast of the relevant member of the C++ enum
`HighsCallbackType`, and is available in C as a constant.

The user's callback method is communicated to HiGHS via the method that in the HiGHS C++ class is
```bash
```cpp
HighsStatus setCallback(void (*userCallback)(const int, const char*, const HighsCallbackDataOut*,
HighsCallbackDataIn*, void*), void* user_callback_data);
```
and, in the HiGHS C API is
```bash
```cpp
HighsInt Highs_setCallback(
void* highs,
void (*userCallback)(const int, const char*,
Expand All @@ -37,19 +42,19 @@ HighsInt Highs_setCallback(
```
There current callback scenarios are set out below, and the particular callback is activated in C++ by calling

```bash
```cpp
HighsStatus startCallback(const int callback_type);
```
and, in C, by calling
```bash
```cpp
HighsInt Highs_startCallback(void* highs, const int callback_type);
```
, and de-activated in C++ by calling
```bash
```cpp
HighsStatus stopCallback(const int callback_type);
```
and, in C, by calling
```bash
```cpp
HighsInt Highs_stopCallback(void* highs, const int callback_type);
```

Expand Down Expand Up @@ -122,7 +127,7 @@ For each of the MIP callbacks, the following `HighsCallbackDataOut` struct membe
* `num_nodes`: the number of MIP nodes explored to date
* `primal_bound`: the primal bound
* `dual_bound`: the dual bound
* `mip_gap`: the (relative) difference between tht primal and dual bounds
* `mip_gap`: the (relative) difference between the primal and dual bounds



Loading