Skip to content

Commit

Permalink
Enable solution output for converted models
Browse files Browse the repository at this point in the history
Adding new variables and/or algebraic constraints broke solution output,
AMPL did not accept .sol with more items than in the NL. This was solved
with a ModelAdapter which reports the original model sizes to
mp::SolutionWriter
  • Loading branch information
glebbelov committed Apr 20, 2021
1 parent f312133 commit 80e3583
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 60 deletions.
40 changes: 20 additions & 20 deletions include/mp/convert/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
#define BACKEND_H_

#include "mp/clock.h"
#include "mp/convert/converter_query.h".h"
#include "mp/convert/converter_query.h"
#include "mp/convert/constraint_keeper.h"
#include "mp/convert/std_constr.h"
#include "mp/convert/std_obj.h"
#include "mp/convert/model.h"
#include "mp/convert/model_adapter.h"

namespace mp {

Expand All @@ -38,17 +39,17 @@ namespace mp {
template <class Impl>
class BasicBackend :
public BasicConstraintAdder,
private SolverImpl< BasicModel<> > // mp::Solver stuff, hidden
private SolverImpl< ModelAdapter< BasicModel<> > > // mp::Solver stuff, hidden
{
ConverterQuery *p_converter_query_object = nullptr;
using MPSolverBase = SolverImpl< BasicModel<> >;
using MPSolverBase = SolverImpl< ModelAdapter< BasicModel<> > >;
public:
using MPUtils = MPSolverBase; // Allow Converter access the SolverImpl
const MPUtils& GetMPUtils() const { return *this; }
MPUtils& GetMPUtils() { return *this; }
public:
BasicBackend() :
SolverImpl< BasicModel<> >(
MPSolverBase(
Impl::GetSolverInvocationName(),
Impl::GetAMPLSolverLongName(),
Impl::Date(), Impl::Flags())
Expand Down Expand Up @@ -191,13 +192,13 @@ class BasicBackend :
}


void SolveAndReport(Model &p, SolutionHandler &sh) {
void SolveAndReport() {
MP_DISPATCH( PrepareSolve() );
MP_DISPATCH( DoSolve() );
MP_DISPATCH( WrapupSolve() );

ObtainSolutionStatus();
ObtainAndReportSolution(p, sh);
ObtainAndReportSolution();
if (MP_DISPATCH( timing() ))
PrintTimingInfo();
}
Expand All @@ -216,7 +217,7 @@ class BasicBackend :
ConvertSolutionStatus(*MP_DISPATCH( interrupter() ), solve_code) );
}

void ObtainAndReportSolution(Model& p, SolutionHandler &sh) {
void ObtainAndReportSolution() {
fmt::MemoryWriter writer;
writer.write("{}: {}", MP_DISPATCH( long_name() ), solve_status);
if (solve_code < sol::INFEASIBLE) {
Expand All @@ -234,18 +235,9 @@ class BasicBackend :
MP_DISPATCH( DualSolution(dual_solution) );
}

auto toySuf =
GetCQ().AddIntSuffix("toy_var_suffix", suf::VAR | suf::OUTPUT | suf::OUTONLY);
/// TODO SuffixHandler should also allow GetValue()
for (int i = 0, n = MP_DISPATCH( NumberOfVariables() ); i < n; ++i) {
if (i % 2 == 0)
toySuf.SetValue(i, 900+i);
}
sh.HandleSolution(solve_code, writer.c_str(),
solution.empty() ? 0 : solution.data(),
dual_solution.empty() ? 0 : dual_solution.data(), obj_value);
HandleSolution(solve_code, writer.c_str(),
solution.empty() ? 0 : solution.data(),
dual_solution.empty() ? 0 : dual_solution.data(), obj_value);
}

void PrintTimingInfo() {
Expand Down Expand Up @@ -291,6 +283,12 @@ class BasicBackend :
using Solver::set_option_header;
using Solver::add_to_option_header;

protected:
void HandleSolution(int status, fmt::CStringRef msg,
const double *x, const double *y, double obj) {
GetCQ().HandleSolution(status, msg, x, y, obj);
}

///////////////////////////// OPTIONS /////////////////////////////////
/// TODOs
/// - hide all Solver stuff behind an abstract interface
Expand Down Expand Up @@ -364,11 +362,13 @@ class BasicBackend :
}

/// Adding solver options of types int/double/string/...
/// The type is deduced from the two last parameters min, max
/// (currently unused otherwise - TODO)
/// If min/max omitted, assume ValueType=std::string
/// Assumes existence of Impl::Get/SetSolverOption(KeyType, ValueType(&))
template <class KeyType, class ValueType=std::string>
void AddSolverOption(const char *name, const char *description,
KeyType k,
/// If min/max omitted, assume ValueType=std::string
ValueType ={}, ValueType ={}) {
AddOption(Solver::OptionPtr(
new ConcreteOptionWrapper<
Expand Down
2 changes: 1 addition & 1 deletion include/mp/convert/basic_constr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace mp {
class BasicConstraint {
public:
static const char* GetConstraintName() { return "BasicConstraint"; }
void print(std::ostream& os) const { }
void print(std::ostream& ) const { }
static constexpr bool HasContext() { return false; }
void SetContext(Context ) const { }
Context GetContext() const { return Context::CTX_NONE; }
Expand Down
88 changes: 63 additions & 25 deletions include/mp/convert/basic_converters.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
#ifndef BASIC_CONVERTERS_H_
#define BASIC_CONVERTERS_H_

#include "mp/convert/model.h"
#include "mp/problem.h"
#include "mp/convert/model_adapter.h"
#include "mp/convert/backend.h"
#include "mp/solver.h"

Expand All @@ -37,34 +38,59 @@ template <class Impl, class Backend,
class Model = BasicProblem< > >
class BasicMPConverter :
public BasicConstraintConverter {
Model model_;

ModelAdapter<Model> model_adapter_;
Backend backend_;
/// This is to wrap some dependencies from MP
using SolverAdapter = SolverImpl<Model>;
std::unique_ptr<ConverterQuery> p_converter_query;
using SolverAdapter = SolverImpl< ModelAdapter<Model> >;

std::unique_ptr<ConverterQuery> p_converter_query_;
SolutionHandler* p_sol_h_;

public:
static const char* GetConverterName() { return "BasicMPConverter"; }
using Converter = Impl;
using ModelType = Model;
using OutputModelType = ModelAdapter<Model>;
using BackendType = Backend;

/// MP API requires a 'ProblemBuilder' type
using ProblemBuilder = Model;
using ProblemBuilder = OutputModelType;
using MPUtils = typename Backend::MPUtils;
const Model& GetModel() const { return model_; } // Can be used for NL file input
Model& GetModel() { return model_; } // Can be used for NL file input

/// The working model
const Model& GetModel() const { return GetOutputModel().GetModel(); }
/// The working model
Model& GetModel() { return GetOutputModel().GetModel(); }

/// Dirty: returning the output model
/// // Can be used for NL file input
const OutputModelType& GetInputModel() const { return GetOutputModel(); }
OutputModelType& GetInputModel() { return GetOutputModel(); }

const OutputModelType& GetOutputModel() const { return model_adapter_; } // TODO
OutputModelType& GetOutputModel() { return model_adapter_; }

const Backend& GetBackend() const { return backend_; }
Backend& GetBackend() { return backend_; }

const MPUtils& GetMPUtils() const { return GetBackend().GetMPUtils(); }
MPUtils& GetMPUtils() { return GetBackend().GetMPUtils(); }

const ConverterQuery& GetCQ() const {
assert(p_converter_query);
return *p_converter_query;
assert(p_converter_query_);
return *p_converter_query_;
}
ConverterQuery& GetCQ() {
assert(p_converter_query);
return *p_converter_query;
assert(p_converter_query_);
return *p_converter_query_;
}

const SolutionHandler& GetSolH() const { assert(p_sol_h_); return *p_sol_h_; }
SolutionHandler& GetSolH() { assert(p_sol_h_); return *p_sol_h_; }
void SetSolHandler(SolutionHandler& sh) { assert(&sh); p_sol_h_ = &sh; }
void RemoveSolHandler() { p_sol_h_=nullptr; }

public:

BasicMPConverter() {
Expand All @@ -73,7 +99,7 @@ class BasicMPConverter :
}

void InitConverterQueryObject() {
p_converter_query = MP_DISPATCH( MakeConverterQuery() );
p_converter_query_ = MP_DISPATCH( MakeConverterQuery() );
GetBackend().ProvideConverterQueryObject( &MP_DISPATCH( GetCQ() ) );
}

Expand All @@ -91,7 +117,7 @@ class BasicMPConverter :
NLReadResult ReadNLFile(const std::string& nl_filename, int nl_reader_flags) {
NLReadResult result;
result.handler_.reset(
new internal::SolverNLHandler<SolverAdapter>(GetModel(), GetMPUtils()));
new internal::SolverNLHandler<SolverAdapter>(GetInputModel(), GetMPUtils()));
internal::NLFileReader<> reader;
reader.Read(nl_filename, *result.handler_, nl_reader_flags);
return result;
Expand All @@ -106,18 +132,18 @@ class BasicMPConverter :
/// These guys used from outside to feed a model to be converted
/// and forwarded to a backend
void InputVariables(int n, const double* lb, const double* ub, const var::Type* ty) {
model_.AddVars(n, lb, ub, ty);
GetModel().AddVars(n, lb, ub, ty);
}
void InputObjective(obj::Type t,
int nnz, const double* c, const int* v, NumericExpr e=NumericExpr()) {
typename Model::LinearObjBuilder lob = model_.AddObj(t, e);
typename Model::LinearObjBuilder lob = GetModel().AddObj(t, e);
for (int i=0; i!=nnz; ++i) {
lob.AddTerm(c[i], v[i]);
}
}
void InputAlgebraicCon(int nnz, const double* c, const int* v,
double lb, double ub, NumericExpr e=NumericExpr()) {
typename Model::MutAlgebraicCon mac = model_.AddCon(lb, ub);
typename Model::MutAlgebraicCon mac = GetModel().AddCon(lb, ub);
typename Model::LinearConBuilder lcb = mac.set_linear_expr(nnz);
for (int i=0; i!=nnz; ++i)
lcb.AddTerm(v[i], c[i]);
Expand All @@ -132,30 +158,42 @@ class BasicMPConverter :
}

void Solve(SolutionHandler &sh) {
GetBackend().SolveAndReport(GetModel(), sh); // TODO no model any more
SetSolHandler(sh);
GetBackend().SolveAndReport();
RemoveSolHandler();
}


protected:
/// Convert the whole model, e.g., after reading from NL
void ConvertModel() {
MP_DISPATCH( PrepareConversion() );
MP_DISPATCH( ConvertStandardItems() );
MP_DISPATCH( ConvertExtraItems() );
}

void PrepareConversion() {
MP_DISPATCH( MemorizeModelSize() );
}

void MemorizeModelSize() {
GetOutputModel().set_num_vars(GetModel().num_vars());
GetOutputModel().set_num_alg_cons(GetModel().num_algebraic_cons());
}

void ConvertStandardItems() {
int num_common_exprs = model_.num_common_exprs();
int num_common_exprs = GetModel().num_common_exprs();
for (int i = 0; i < num_common_exprs; ++i)
MP_DISPATCH( Convert( model_.common_expr(i) ) );
if (int num_objs = model_.num_objs())
MP_DISPATCH( Convert( GetModel().common_expr(i) ) );
if (int num_objs = GetModel().num_objs())
for (int i = 0; i < num_objs; ++i)
MP_DISPATCH( Convert( model_.obj(i) ) );
if (int n_cons = model_.num_algebraic_cons())
MP_DISPATCH( Convert( GetModel().obj(i) ) );
if (int n_cons = GetModel().num_algebraic_cons())
for (int i = 0; i < n_cons; ++i)
MP_DISPATCH( Convert( model_.algebraic_con(i) ) );
if (int n_lcons = model_.num_logical_cons())
MP_DISPATCH( Convert( GetModel().algebraic_con(i) ) );
if (int n_lcons = GetModel().num_logical_cons())
for (int i = 0; i < n_lcons; ++i)
MP_DISPATCH( Convert( model_.logical_con(i) ) );
MP_DISPATCH( Convert( GetModel().logical_con(i) ) );
}

void ConvertExtraItems() { }
Expand Down
18 changes: 17 additions & 1 deletion include/mp/convert/converter_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,8 @@ class BasicMPFlatConverter
}

bool HasInitExpression(int var) const {
return var_info_.size()>var && nullptr!=var_info_[var].pInitExpr;
return int(var_info_.size())>var &&
nullptr!=var_info_[var].pInitExpr;
}

BasicConstraintKeeper* GetInitExpression(int var) {
Expand All @@ -920,6 +921,21 @@ class BasicMPFlatConverter
}


///////////////////////////////////////////////////////////////////////
//////////////////// SOLUTION REPORTING FROM BACKEND //////////////////
///////////////////////////////////////////////////////////////////////
public:
void HandleSolution(int status, fmt::CStringRef msg,
const double *x, const double *y, double obj) {
MP_DISPATCH( GetSolH() ).HandleSolution(status, msg, x, y, obj);
}

using OutputModelType = typename BaseConverter::OutputModelType;
typename OutputModelType::IntSuffixHandler
AddIntSuffix(fmt::StringRef name, int kind, int =0) {
return MP_DISPATCH( GetOutputModel() ).AddIntSuffix(name, kind);
}

///////////////////////////////////////////////////////////////////////
/////////////////////// OPTIONS /////////////////////////
///
Expand Down
13 changes: 9 additions & 4 deletions include/mp/convert/converter_flat_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ class FlatConverterQuery : public ConverterQuery {
const Converter& GetCvt() const { return cvt_; }
Converter& GetCvt() { return cvt_; }

using Model = typename Converter::ModelType;
using Model = typename Converter::OutputModelType;

const Model& GetPB() const { return GetCvt().GetModel(); }
Model& GetPB() { return GetCvt().GetModel(); }
const Model& GetOutputModel() const { return GetCvt().GetOutputModel(); }
Model& GetOutputModel() { return GetCvt().GetOutputModel(); }

IntSuffixHandler AddIntSuffix(fmt::StringRef name, int kind, int) override {
return GetPB().AddIntSuffix(name, kind);
return GetCvt().AddIntSuffix(name, kind);
}

void HandleSolution(int status, fmt::CStringRef msg,
const double *x, const double * y, double obj) override {
GetCvt().HandleSolution(status, msg, x, y, obj);
}

};
Expand Down
12 changes: 8 additions & 4 deletions include/mp/convert/converter_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/

#include "mp/solver.h"
#include "mp/convert/model.h".h"
#include "mp/convert/model.h"

namespace mp {

Expand All @@ -20,9 +20,13 @@ class ConverterQuery {

using IntSuffixHandler = Model::SuffixHandler<int>;

// Adds an integer suffix.
// name: Suffix name that may not be null-terminated.
virtual IntSuffixHandler AddIntSuffix(fmt::StringRef name, int kind, int=0) { }
/// Adds an integer suffix.
/// name: Suffix name that may not be null-terminated.
/// TODO put values right there (1. vector, 2. sparse vector)
virtual IntSuffixHandler AddIntSuffix(fmt::StringRef name, int kind, int=0) = 0;

virtual void HandleSolution(int, fmt::CStringRef,
const double *, const double *, double) = 0;

};

Expand Down

0 comments on commit 80e3583

Please sign in to comment.