diff --git a/doc/release_notes.rst b/doc/release_notes.rst index e5b7033f..fa8e7184 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -52,6 +52,7 @@ Most users should keep calling ``model.solve(...)``. If you want more control, y **Bug Fixes** +* Setting pandas bounds with missing coords now broadcasts to target coords * SOS constraints on masked variables no longer cause solver-specific failures (Gurobi ``IndexError``, Xpress ``?404 Invalid column number``, LP parse errors, silent set corruption). ``Model.solve()`` and ``Model.to_file()`` now raise a clear ``NotImplementedError`` referring users to `#688 `__; pass ``reformulate_sos=True`` as a workaround. * ``Model.solve(..., reformulate_sos=True)`` now actually reformulates SOS constraints even when the solver supports them natively. Previously it was silently ignored with a warning. diff --git a/linopy/model.py b/linopy/model.py index 48a8200b..820825de 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -126,45 +126,89 @@ def _coords_to_dict( return result -def _validate_dataarray_bounds(arr: Any, coords: Any) -> Any: +def _sanitize_pandas(arr: pd.Series | pd.DataFrame) -> DataArray | None: """ - Validate and expand DataArray bounds against explicit coords. + Attempt to convert the pandas series or dataframe into a datarray with named coords. + """ + if isinstance(arr, pd.DataFrame): + # A pandas dataframe, possible with multi-level columns and multi-level index + # Unstack all layers of columns + while isinstance(arr, pd.DataFrame): + arr = arr.unstack() + if not isinstance(arr, pd.Series): + # This should not happen + logger.warning("Failed to unstack dataframe") + return None + + assert isinstance(arr, pd.Series) + # A pandas series, possible with a multi-level index + index = arr.index + + # We can only process pandas series/dataframes with named dimensions + if isinstance(index, pd.MultiIndex): + for name in index.names: + if name is None: + return None + else: + if index.name is None: + return None + return arr.to_xarray() + + +def _validate_dataarray_bounds( + arr: DataArray | pd.Series | pd.DataFrame | Any, coords: Any +) -> Any: + """ + Validate and expand DataArray (or pandas array with all named dimensions) against explicit coords. - If ``arr`` is not a DataArray, return it unchanged (``as_dataarray`` - will handle conversion). For DataArray inputs: + If ``arr`` is not a DataArray or pandas with all named dimensions, it will be returned unchanged. + If ``arr`` is a pandas series or dataframe, it will be converted to a DataArray. - Raises ``ValueError`` if the array has dimensions not in coords. - Raises ``ValueError`` if shared dimension coordinates don't match. - Expands missing dimensions via ``expand_dims``. """ - if not isinstance(arr, DataArray): + if not isinstance(arr, (DataArray, pd.Series, pd.DataFrame)): return arr + type_name = { + pd.Series: "Series", + pd.DataFrame: "DataFrame", + DataArray: "DataArray", + }[type(arr)] + + if isinstance(arr, (pd.Series, pd.DataFrame)): + xarr = _sanitize_pandas(arr) + if xarr is None: + return arr + else: + xarr = arr + expected = _coords_to_dict(coords) if not expected: - return arr + return xarr - extra = set(arr.dims) - set(expected) + extra = set(xarr.dims) - set(expected) if extra: - raise ValueError(f"DataArray has extra dimensions not in coords: {extra}") + raise ValueError(f"{type_name} has extra dimensions not in coords: {extra}") for dim, coord_values in expected.items(): - if dim not in arr.dims: + if dim not in xarr.dims: continue - if isinstance(arr.indexes.get(dim), pd.MultiIndex): + if isinstance(xarr.indexes.get(dim), pd.MultiIndex): continue expected_idx = ( coord_values if isinstance(coord_values, pd.Index) else pd.Index(coord_values) ) - actual_idx = arr.coords[dim].to_index() + actual_idx = xarr.coords[dim].to_index() if not actual_idx.equals(expected_idx): # Same values, different order → reindex to match expected order if len(actual_idx) == len(expected_idx) and set(actual_idx) == set( expected_idx ): - arr = arr.reindex({dim: expected_idx}) + xarr = xarr.reindex({dim: expected_idx}) else: raise ValueError( f"Coordinates for dimension '{dim}' do not match: " @@ -172,11 +216,11 @@ def _validate_dataarray_bounds(arr: Any, coords: Any) -> Any: ) # Expand missing dimensions - expand = {k: v for k, v in expected.items() if k not in arr.dims} + expand = {k: v for k, v in expected.items() if k not in xarr.dims} if expand: - arr = arr.expand_dims(expand) + xarr = xarr.expand_dims(expand) - return arr + return xarr class Model: diff --git a/test/test_variable.py b/test/test_variable.py index b14b746e..303224a0 100644 --- a/test/test_variable.py +++ b/test/test_variable.py @@ -433,18 +433,43 @@ def test_dataarray_extra_dims(self, model: "Model") -> None: model.add_variables(lower=lower, coords=self.DICT_COORDS, name="x") # -- Broadcasting missing dims ----------------------------------------- - - def test_dataarray_broadcast_missing_dim(self, model: "Model") -> None: + @pytest.mark.parametrize( + "bound", + [ + pytest.param( + DataArray([1, 2, 3], dims=["time"], coords={"time": range(3)}), + id="xr.DataArray", + ), + pytest.param( + pd.Series(index=pd.RangeIndex(3, name="time"), data=[1, 2, 3]), + id="pd.Series", + ), + pytest.param( + pd.DataFrame( + index=pd.RangeIndex(3, name="time"), + columns=pd.Index(["red"], name="colour"), + data=[1, 2, 3], + ), + id="pd.DataFrame", + ), + ], + ) + def test_broadcast_missing_dim( + self, model: "Model", bound: DataArray | pd.Series + ) -> None: time = pd.RangeIndex(3, name="time") space = pd.Index(["a", "b"], name="space") - lower = DataArray([1, 2, 3], dims=["time"], coords={"time": range(3)}) - var = model.add_variables(lower=lower, coords=[time, space], name="x") - assert set(var.data.dims) == {"time", "space"} - assert var.data.sizes == {"time": 3, "space": 2} + colour = pd.Index(["red"], name="colour") + + var = model.add_variables( + lower=-bound, upper=bound, coords=[time, space, colour], name="x" + ) + assert set(var.data.dims) == {"time", "space", "colour"} + assert var.data.sizes == {"time": 3, "space": 2, "colour": 1} # Verify broadcast filled with actual values, not NaN assert not var.data.lower.isnull().any() - assert (var.data.lower.sel(space="a") == [1, 2, 3]).all() - assert (var.data.lower.sel(space="b") == [1, 2, 3]).all() + assert (var.data.lower.sel(space="a", colour="red") == [-1, -2, -3]).all() + assert (var.data.lower.sel(space="b", colour="red") == [-1, -2, -3]).all() # -- Special coord formats ---------------------------------------------