From be21711c939d1e705edd68b6c6b79f9d3e134143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B8ren=20Fuglede=20J=C3=B8rgensen?= Date: Wed, 19 Apr 2023 11:45:35 +0200 Subject: [PATCH] Add writeBasis interface method to highspy --- highspy/highs_bindings.cpp | 1 + highspy/tests/test_highspy.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/highspy/highs_bindings.cpp b/highspy/highs_bindings.cpp index 4f33c46a7e..dd7c8bd509 100644 --- a/highspy/highs_bindings.cpp +++ b/highspy/highs_bindings.cpp @@ -735,6 +735,7 @@ PYBIND11_MODULE(highs_bindings, m) .def("passRowName", &Highs::passRowName) .def("readModel", &Highs::readModel) .def("readBasis", &Highs::readBasis) + .def("writeBasis", &Highs::writeBasis) .def("presolve", &Highs::presolve) .def("run", &Highs::run) .def("postsolve", &Highs::postsolve) diff --git a/highspy/tests/test_highspy.py b/highspy/tests/test_highspy.py index 80a4516dc1..2c46a0c717 100644 --- a/highspy/tests/test_highspy.py +++ b/highspy/tests/test_highspy.py @@ -1,3 +1,4 @@ +import tempfile import unittest import highspy import numpy as np @@ -404,3 +405,37 @@ def test_ranging(self): self.assertEqual(ranging.row_bound_up.value_[1], inf); self.assertEqual(ranging.row_bound_up.objective_[1], inf); + def test_write_basis_before_running(self): + h = self.get_basic_model() + with tempfile.NamedTemporaryFile() as f: + h.writeBasis(f.name) + contents = f.read() + self.assertEqual(contents, b'HiGHS v1\nNone\n') + + def test_write_basis_after_running(self): + h = self.get_basic_model() + h.run() + with tempfile.NamedTemporaryFile() as f: + h.writeBasis(f.name) + contents = f.read() + self.assertEqual( + contents, b'HiGHS v1\nValid\n# Columns 2\n1 1 \n# Rows 2\n0 0 \n' + ) + + def test_read_basis(self): + # Read basis from one run model into an unrun model + expected_status_before = highspy.HighsBasisStatus.kLower + expected_status_after = highspy.HighsBasisStatus.kBasic + + h1 = self.get_basic_model() + self.assertEqual(h1.getBasis().col_status[0], expected_status_before) + h1.run() + self.assertEqual(h1.getBasis().col_status[0], expected_status_after) + + h2 = self.get_basic_model() + self.assertEqual(h2.getBasis().col_status[0], expected_status_before) + + with tempfile.NamedTemporaryFile() as f: + h1.writeBasis(f.name) + h2.readBasis(f.name) + self.assertEqual(h2.getBasis().col_status[0], expected_status_after)