Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Release Notes
Upcoming Version
----------------

* Blacklist highspy 1.14.0 which produces wrong results due to broken presolve and crashes on Windows (`HiGHS#2964 <https://github.com/ERGO-Code/HiGHS/issues/2964>`_).
* Add ``Model.copy()`` (default deep copy) with ``deep`` and ``include_solution`` options; support Python ``copy.copy`` and ``copy.deepcopy`` protocols via ``__copy__`` and ``__deepcopy__``.
* Harmonize coordinate alignment for operations with subset/superset objects:
- Multiplication and division fill missing coords with 0 (variable doesn't participate)
Expand Down
67 changes: 46 additions & 21 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from itertools import product, zip_longest
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, cast, overload
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -536,7 +536,7 @@ def _multiply_by_linear_expression(
# merge on factor dimension only returns v1 * v2 + c1 * c2
ds = other.data[["coeffs", "vars"]].sel(_term=0).broadcast_like(self.data)
ds = assign_multiindex_safe(ds, const=other.const)
res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression) # type: ignore
res = merge([self, ds], dim=FACTOR_DIM, cls=QuadraticExpression)
# deal with cross terms c1 * v2 + c2 * v1
if self.has_constant:
res = res + self.const * other.reset_const()
Expand Down Expand Up @@ -741,7 +741,7 @@ def add(
self, QuadraticExpression
):
other = other.to_quadexpr()
return merge([self, other], cls=self.__class__, join=join) # type: ignore[list-item]
return merge([self, other], cls=self.__class__, join=join)

def sub(
self: GenericExpression,
Expand Down Expand Up @@ -2332,18 +2332,39 @@ def as_expression(
return LinearExpression(obj, model)


Mergeable: TypeAlias = BaseExpression | variables.Variable | Dataset


@overload
def merge(
exprs: Sequence[Mergeable] | Mergeable,
*add_exprs: Mergeable,
dim: str = ...,
cls: type[GenericExpression],
join: str | None = ...,
**kwargs: Any,
) -> GenericExpression: ...


@overload
def merge(
exprs: Sequence[Mergeable] | Mergeable,
*add_exprs: Mergeable,
dim: str = ...,
cls: None = ...,
join: str | None = ...,
**kwargs: Any,
) -> BaseExpression: ...


def merge(
exprs: Sequence[
LinearExpression | QuadraticExpression | variables.Variable | Dataset
],
*add_exprs: tuple[
LinearExpression | QuadraticExpression | variables.Variable | Dataset
],
exprs: Sequence[Mergeable] | Mergeable,
*add_exprs: Mergeable,
dim: str = TERM_DIM,
cls: type[GenericExpression] = None, # type: ignore
cls: type[BaseExpression] | None = None,
join: str | None = None,
**kwargs: Any,
) -> GenericExpression:
) -> BaseExpression:
"""
Merge multiple expression together.

Expand Down Expand Up @@ -2374,34 +2395,38 @@ def merge(
-------
res : linopy.LinearExpression or linopy.QuadraticExpression
"""
if not isinstance(exprs, list) and len(add_exprs):
if not isinstance(exprs, Sequence):
warn(
"Passing a tuple to the merge function is deprecated. Please pass a list of objects to be merged",
DeprecationWarning,
)
exprs = [exprs] + list(add_exprs) # type: ignore
exprs = [exprs] + list(add_exprs)

has_quad_expression = any(type(e) is QuadraticExpression for e in exprs)
has_linear_expression = any(type(e) is LinearExpression for e in exprs)
has_quad_expression = any(isinstance(e, QuadraticExpression) for e in exprs)
has_linear_expression = any(isinstance(e, LinearExpression) for e in exprs)
if cls is None:
cls = QuadraticExpression if has_quad_expression else LinearExpression

if cls is QuadraticExpression and dim == TERM_DIM and has_linear_expression:
if (
issubclass(cls, QuadraticExpression)
and dim == TERM_DIM
and has_linear_expression
):
raise ValueError(
"Cannot merge linear and quadratic expressions along term dimension."
"Convert to QuadraticExpression first."
)

if has_quad_expression and cls is not QuadraticExpression:
if has_quad_expression and not issubclass(cls, QuadraticExpression):
raise ValueError("Cannot merge linear expressions to QuadraticExpression")

linopy_types = (variables.Variable, LinearExpression, QuadraticExpression)
linopy_types = (variables.Variable, BaseExpression)

model = exprs[0].model

if join is not None:
override = join == "override"
elif cls in linopy_types and dim in HELPER_DIMS:
elif issubclass(cls, linopy_types) and dim in HELPER_DIMS:
coord_dims = [
{k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs
]
Expand All @@ -2417,9 +2442,9 @@ def merge(
"coords": "minimal",
"compat": "override",
}
if cls == LinearExpression:
if issubclass(cls, LinearExpression):
kwargs["fill_value"] = FILL_VALUE
elif cls == variables.Variable:
elif issubclass(cls, variables.Variable):
kwargs["fill_value"] = variables.FILL_VALUE

if join is not None:
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"bottleneck",
"toolz",
"numexpr",
"xarray>=2024.2.0",
"xarray>=2024.2.0,<2026.4",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we merge this, we should remove this line again in #647

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please

"dask>=0.18.0",
"polars>=1.31.1",
"tqdm",
Expand Down Expand Up @@ -82,13 +82,13 @@ benchmarks = [
]
solvers = [
"gurobipy",
"highspy>=1.5.0; python_version < '3.12'",
"highspy>=1.7.1; python_version >= '3.12'",
"cplex; platform_system != 'Darwin' and python_version < '3.12'",
"highspy>=1.5.0,!=1.14.0; python_version < '3.12'",
"highspy>=1.7.1,!=1.14.0; python_version >= '3.12'",
"cplex; platform_system != 'Darwin'",
"mosek",
"mindoptpy; python_version < '3.12'",
"mindoptpy",
"coptpy!=7.2.1",
"xpress; platform_system != 'Darwin' and python_version < '3.11'",
"xpress; platform_system != 'Darwin'",
"pyscipopt; platform_system != 'Darwin'",
"knitro>=15.1.0",
# "cupdlpx>=0.1.2", pip package currently unstable
Expand Down
16 changes: 8 additions & 8 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2023,26 +2023,26 @@ def test_same_shape_add_join_override(self, a: Variable, c: Variable) -> None:

class TestMerge:
def test_merge_join_parameter(self, a: Variable, b: Variable) -> None:
result: LinearExpression = merge(
[a.to_linexpr(), b.to_linexpr()], join="inner"
result = merge(
[a.to_linexpr(), b.to_linexpr()], cls=LinearExpression, join="inner"
)
assert list(result.data.indexes["i"]) == [1, 2]

def test_merge_outer_join(self, a: Variable, b: Variable) -> None:
result: LinearExpression = merge(
[a.to_linexpr(), b.to_linexpr()], join="outer"
result = merge(
[a.to_linexpr(), b.to_linexpr()], cls=LinearExpression, join="outer"
)
assert set(result.coords["i"].values) == {0, 1, 2, 3}

def test_merge_join_left(self, a: Variable, b: Variable) -> None:
result: LinearExpression = merge(
[a.to_linexpr(), b.to_linexpr()], join="left"
result = merge(
[a.to_linexpr(), b.to_linexpr()], cls=LinearExpression, join="left"
)
assert list(result.data.indexes["i"]) == [0, 1, 2]

def test_merge_join_right(self, a: Variable, b: Variable) -> None:
result: LinearExpression = merge(
[a.to_linexpr(), b.to_linexpr()], join="right"
result = merge(
[a.to_linexpr(), b.to_linexpr()], cls=LinearExpression, join="right"
)
assert list(result.data.indexes["i"]) == [1, 2, 3]

Expand Down
6 changes: 3 additions & 3 deletions test/test_quadratic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def test_quadratic_expression_wrong_multiplication(x: Variable, y: Variable) ->
def merge_raise_deprecation_warning(x: Variable, y: Variable) -> None:
expr: QuadraticExpression = x * y # type: ignore
with pytest.warns(DeprecationWarning):
merge(expr, expr) # type: ignore
merge(expr, expr)


def test_merge_linear_expression_and_quadratic_expression(
Expand All @@ -238,11 +238,11 @@ def test_merge_linear_expression_and_quadratic_expression(
with pytest.raises(ValueError):
merge([linexpr, quadexpr], cls=QuadraticExpression)

new_quad_ex = merge([linexpr.to_quadexpr(), quadexpr]) # type: ignore
new_quad_ex = merge([linexpr.to_quadexpr(), quadexpr])
assert isinstance(new_quad_ex, QuadraticExpression)

with pytest.warns(DeprecationWarning):
merge(quadexpr, quadexpr, cls=QuadraticExpression) # type: ignore
merge(quadexpr, quadexpr, cls=QuadraticExpression)

quadexpr_2 = linexpr.to_quadexpr()
merged_expr = merge([quadexpr_2, quadexpr], cls=QuadraticExpression)
Expand Down
Loading