diff --git a/docs/notebooks/structural_reliability.ipynb b/docs/notebooks/structural_reliability.ipynb
index 1f0ec14..2a4b4af 100644
--- a/docs/notebooks/structural_reliability.ipynb
+++ b/docs/notebooks/structural_reliability.ipynb
@@ -12,32 +12,35 @@
},
{
"cell_type": "code",
+ "execution_count": 10,
"id": "cedd5ec9-31f7-4e7f-91be-f73b1d1d00f1",
"metadata": {
- "scrolled": true,
"ExecuteTime": {
"end_time": "2024-05-30T09:06:58.579066Z",
"start_time": "2024-05-30T09:06:58.552735Z"
- }
+ },
+ "scrolled": true
},
+ "outputs": [],
"source": [
"import pathlib\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import simdec as sd"
- ],
- "outputs": [],
- "execution_count": 1
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
- "source": "Let's first load the dataset. It's a CSV file, each row represent a simulation or sample. The first column is the output or quantity of interest and other columns are parameters' values.",
- "id": "8700ed278bb1c06d"
+ "id": "8700ed278bb1c06d",
+ "metadata": {},
+ "source": [
+    "Let's first load the dataset. It's a CSV file, each row represents a simulation or sample. The first column is the output or quantity of interest and the other columns are parameters' values."
+ ]
},
{
"cell_type": "code",
+ "execution_count": 11,
"id": "0b21846d-edff-4e39-a423-b247f81c4520",
"metadata": {
"ExecuteTime": {
@@ -45,25 +48,9 @@
"start_time": "2024-05-30T09:06:58.579870Z"
}
},
- "source": [
- "fname = pathlib.Path(\"../../tests/data/stress.csv\")\n",
- "\n",
- "data = pd.read_csv(fname)\n",
- "output_name, *inputs_names = list(data.columns)\n",
- "inputs, output = data[inputs_names], data[output_name]\n",
- "inputs.head()"
- ],
"outputs": [
{
"data": {
- "text/plain": [
- " Kf sigma_res Rp0.2 R\n",
- "0 2.454866 -84.530638 297.406169 -0.834480\n",
- "1 2.774116 347.586947 379.499452 -0.131827\n",
- "2 2.504617 946.567040 940.477667 -0.039126\n",
- "3 2.466723 74.222224 406.622486 0.440311\n",
- "4 2.615602 -32.937734 979.498038 0.419690"
- ],
"text/html": [
"
\n",
"\n",
- "
\n",
+ "\n",
" \n",
" \n",
" | | \n",
" | \n",
- " N° | \n",
- " colour | \n",
- " std | \n",
- " min | \n",
- " mean | \n",
- " max | \n",
- " probability | \n",
+ " N° | \n",
+ " colour | \n",
+ " std | \n",
+ " min | \n",
+ " mean | \n",
+ " max | \n",
+ " probability | \n",
"
\n",
" \n",
" | sigma_res | \n",
@@ -478,119 +482,133 @@
"
\n",
" \n",
" \n",
- " | low | \n",
- " low | \n",
- " 9 | \n",
- " | \n",
- " 95.34 | \n",
- " 11.19 | \n",
- " 282.08 | \n",
- " 460.07 | \n",
- " 0.19 | \n",
+ " low | \n",
+ " low | \n",
+ " 9 | \n",
+ " | \n",
+ " 84.73 | \n",
+ " 11.74 | \n",
+ " 226.72 | \n",
+ " 397.62 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 8 | \n",
- " | \n",
- " 87.76 | \n",
- " 67.53 | \n",
- " 407.79 | \n",
- " 622.35 | \n",
- " 0.12 | \n",
+ " medium | \n",
+ " 8 | \n",
+ " | \n",
+ " 82.65 | \n",
+ " 11.19 | \n",
+ " 385.26 | \n",
+ " 619.78 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 7 | \n",
- " | \n",
- " 108.03 | \n",
- " 237.13 | \n",
- " 541.32 | \n",
- " 819.41 | \n",
- " 0.26 | \n",
+ " high | \n",
+ " 7 | \n",
+ " | \n",
+ " 101.23 | \n",
+ " 384.75 | \n",
+ " 567.44 | \n",
+ " 817.84 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " low | \n",
- " 6 | \n",
- " | \n",
- " 34.92 | \n",
- " 350.30 | \n",
- " 434.90 | \n",
- " 523.84 | \n",
- " 0.09 | \n",
+ " medium | \n",
+ " low | \n",
+ " 6 | \n",
+ " | \n",
+ " 43.37 | \n",
+ " 268.77 | \n",
+ " 376.10 | \n",
+ " 515.25 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 5 | \n",
- " | \n",
- " 44.39 | \n",
- " 398.42 | \n",
- " 485.72 | \n",
- " 650.98 | \n",
- " 0.06 | \n",
+ " medium | \n",
+ " 5 | \n",
+ " | \n",
+ " 63.56 | \n",
+ " 318.15 | \n",
+ " 485.57 | \n",
+ " 711.72 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 4 | \n",
- " | \n",
- " 75.80 | \n",
- " 414.21 | \n",
- " 534.19 | \n",
- " 814.43 | \n",
- " 0.11 | \n",
+ " high | \n",
+ " 4 | \n",
+ " | \n",
+ " 106.22 | \n",
+ " 420.61 | \n",
+ " 580.03 | \n",
+ " 819.41 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " low | \n",
- " 3 | \n",
- " | \n",
- " 35.43 | \n",
- " 630.24 | \n",
- " 703.90 | \n",
- " 794.81 | \n",
- " 0.06 | \n",
+ " high | \n",
+ " low | \n",
+ " 3 | \n",
+ " | \n",
+ " 132.84 | \n",
+ " 383.11 | \n",
+ " 576.80 | \n",
+ " 794.81 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | medium | \n",
- " 2 | \n",
- " | \n",
- " 33.95 | \n",
- " 656.34 | \n",
- " 725.15 | \n",
- " 816.48 | \n",
- " 0.04 | \n",
+ " medium | \n",
+ " 2 | \n",
+ " | \n",
+ " 129.07 | \n",
+ " 410.77 | \n",
+ " 611.25 | \n",
+ " 824.87 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
- " | high | \n",
- " 1 | \n",
- " | \n",
- " 39.51 | \n",
- " 668.50 | \n",
- " 755.65 | \n",
- " 851.00 | \n",
- " 0.08 | \n",
+ " high | \n",
+ " 1 | \n",
+ " | \n",
+ " 127.28 | \n",
+ " 448.85 | \n",
+ " 643.91 | \n",
+ " 851.00 | \n",
+ " 0.11 | \n",
"
\n",
" \n",
"
\n"
+ ],
+ "text/plain": [
+ ""
]
},
- "execution_count": 9,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
- "execution_count": 9
+ "source": [
+ "table, styler = sd.tableau(\n",
+ " statistic=res.statistic,\n",
+ " var_names=res.var_names,\n",
+ " states=res.states,\n",
+ " bins=res.bins,\n",
+ " palette=palette,\n",
+ ")\n",
+ "styler"
+ ]
},
{
- "metadata": {},
"cell_type": "markdown",
- "source": "Congratulations, now you know how to use SimDec to get more insights on your problem!",
- "id": "3cdf58c4bd3dbbca"
+ "id": "3cdf58c4bd3dbbca",
+ "metadata": {},
+ "source": [
+ "Congratulations, now you know how to use SimDec to get more insights on your problem!"
+ ]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -604,7 +622,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.1"
+ "version": "3.11.14"
}
},
"nbformat": 4,
diff --git a/panel/index.html b/panel/index.html
index 52dda01..9cc0ffc 100644
--- a/panel/index.html
+++ b/panel/index.html
@@ -203,34 +203,36 @@
-
+
+
-
-
+
-
+
-
-
+
-
+
-
+
diff --git a/panel/simdec_app.py b/panel/simdec_app.py
index cc98c62..93b7a12 100644
--- a/panel/simdec_app.py
+++ b/panel/simdec_app.py
@@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):
def explained_variance_80(sensitivity_indices_table):
- si = sensitivity_indices_table.value["Indices"]
- pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
+ df = sensitivity_indices_table.value
+ df = df[df["Inputs"] != "Sum of Indices"]
+ si = df["Indices"].values
+ target = 0.8 * np.sum(si)
+ pos_80 = bisect.bisect_right(np.cumsum(si), target)
# pos_80 = max(2, pos_80)
# pos_80 = min(len(si), pos_80)
diff --git a/src/simdec/__init__.py b/src/simdec/__init__.py
index 9394a8f..c91c1dc 100644
--- a/src/simdec/__init__.py
+++ b/src/simdec/__init__.py
@@ -2,6 +2,7 @@
from simdec.decomposition import *
from simdec.sensitivity_indices import *
from simdec.visualization import *
+from simdec.heterogeneity_indices import *
__all__ = [
"sensitivity_indices",
@@ -11,4 +12,5 @@
"two_output_visualization",
"tableau",
"palette",
+ "heterogeneity_indices",
]
diff --git a/src/simdec/decomposition.py b/src/simdec/decomposition.py
index 958d2aa..2969ffd 100644
--- a/src/simdec/decomposition.py
+++ b/src/simdec/decomposition.py
@@ -65,7 +65,7 @@ def __reduce__(self):
def decomposition(
inputs: pd.DataFrame,
- output: pd.DataFrame,
+ output: pd.DataFrame | np.ndarray,
*,
sensitivity_indices: np.ndarray,
dec_limit: float | None = None,
@@ -116,7 +116,11 @@ def decomposition(
inputs[cat_col] = codes
inputs = inputs.to_numpy()
- output = output.to_numpy().flatten()
+
+ if hasattr(output, "to_numpy"):
+ output = output.to_numpy().flatten()
+ else:
+ output = np.asarray(output).flatten()
# 1. variables for decomposition
var_order = np.argsort(sensitivity_indices)[::-1]
diff --git a/src/simdec/heterogeneity_indices.py b/src/simdec/heterogeneity_indices.py
new file mode 100644
index 0000000..5a9bb88
--- /dev/null
+++ b/src/simdec/heterogeneity_indices.py
@@ -0,0 +1,171 @@
+from .sensitivity_indices import sensitivity_indices
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+__all__ = ["heterogeneity_indices"]
+
+
+def heterogeneity_indices(
+ output: pd.Series,
+ inputs: pd.DataFrame,
+ split_variable: str | pd.Series,
+ n_subdivisions: int | None = None,
+ plot: bool = False,
+) -> pd.DataFrame:
+ """
+ Compute sensitivity-based heterogeneity across subdivisions of a variable.
+
+ Parameters
+ ----------
+ output : pd.Series
+ Model output vector.
+ inputs : pd.DataFrame
+ Input/feature matrix.
+ split_variable : str or pd.Series
+ Variable to split on. If string, must be a column in 'inputs'.
+ n_subdivisions : int, optional
+ Number of regions for continuous variables. Defaults to 4.
+ plot : bool, default False
+ If True, displays a stacked bar chart of regional sensitivities.
+
+    Returns
+    -------
+    summary : pd.DataFrame
+        A summary of calculated heterogeneity indices.
+ """
+ y = pd.Series(output).reset_index(drop=True)
+ X = pd.DataFrame(inputs).reset_index(drop=True)
+
+ if isinstance(split_variable, str):
+ if split_variable not in X.columns:
+ raise ValueError(f"'{split_variable}' not found in inputs.")
+ z = X[split_variable].reset_index(drop=True)
+ split_name = split_variable
+ else:
+ z = pd.Series(split_variable).reset_index(drop=True)
+ split_name = getattr(split_variable, "name", "split_variable")
+
+ unique_vals = z.dropna().unique()
+ n_unique = len(unique_vals)
+
+ # Determine if variable is categorical/binary
+ is_categorical = (
+ pd.api.types.is_categorical_dtype(z)
+ or pd.api.types.is_object_dtype(z)
+ or pd.api.types.is_bool_dtype(z)
+ or n_unique <= 2
+ )
+
+ if is_categorical:
+ regions = z.astype("category")
+ else:
+ q = n_subdivisions if n_subdivisions is not None else 4
+ try:
+ regions = pd.qcut(z, q=q, duplicates="drop")
+ except ValueError as e:
+ raise ValueError(
+ f"Failed to bin '{split_name}' into {q} quantiles: {e}"
+ ) from e
+
+ regional_profiles = []
+ skipped = []
+
+ for region in regions.cat.categories:
+ mask = regions == region
+ n_in_region = mask.sum()
+
+ if n_in_region < 10:
+ # Need enough samples for meaningful sensitivity indices
+ skipped.append((region, n_in_region, "too few samples (< 10)"))
+ continue
+
+ X_sub = X.loc[mask]
+ y_sub = y.loc[mask]
+
+ # Skip if output has zero or near-zero variance in this region
+ if y_sub.var() < 1e-12:
+ skipped.append((region, n_in_region, "output variance ≈ 0"))
+ continue
+
+ try:
+ res = sensitivity_indices(inputs=X_sub, output=y_sub)
+ si_vals = np.asarray(res.si).ravel()
+
+ # Guard against NaN/Inf from degenerate sensitivity computation
+ if not np.all(np.isfinite(si_vals)):
+ skipped.append((region, n_in_region, "non-finite SI values"))
+ continue
+
+ si_region = pd.Series(si_vals, index=X.columns, name=region)
+ regional_profiles.append(si_region)
+
+ except Exception as e:
+ skipped.append((region, n_in_region, f"exception: {e}"))
+ continue
+
+ if skipped:
+ print(
+ f"[heterogeneity_indices] Skipped {len(skipped)} region(s) of '{split_name}':"
+ )
+ for reg, n, reason in skipped:
+ print(f" - region={reg!r}, n={n}, reason={reason}")
+
+ if len(regional_profiles) < 2:
+ total_regions = len(regions.cat.categories)
+ valid = len(regional_profiles)
+ raise ValueError(
+ f"Not enough valid subdivisions to compute heterogeneity: "
+ f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
+ f"Skipped regions:\n"
+ + "\n".join(f" {r!r}: n={n}, {reason}" for r, n, reason in skipped)
+ + "\n\nTry: (1) reducing n_subdivisions, "
+ "(2) using a different split_variable, or "
+ "(3) ensuring more samples per region."
+ )
+
+ regional_si = pd.concat(regional_profiles, axis=1)
+
+ res_global = sensitivity_indices(inputs=X, output=y)
+ overall_si = pd.Series(
+ np.asarray(res_global.si).ravel(),
+ index=X.columns,
+ name="Overall_SI",
+ )
+
+ # Heterogeneity = 2 × population std dev across regions
+ hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
+ total_hetero = hetero_scores.mean()
+
+ hetero_col_name = f"Heterogeneity (across {split_name})"
+ summary = pd.DataFrame(
+ {"Overall_SI": overall_si, hetero_col_name: hetero_scores}
+ ).sort_values(by=hetero_col_name, ascending=False)
+ summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]
+
+ if plot:
+ plot_order = summary.index[:-1]
+ data_to_plot = regional_si.loc[plot_order].T
+
+ cmap = plt.get_cmap("terrain")
+ colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(plot_order))]
+
+ _ = data_to_plot.plot(
+ kind="bar",
+ stacked=True,
+ figsize=(10, 6),
+ color=colors,
+ edgecolor="white",
+ width=0.8,
+ )
+
+ plt.title(f"Sensitivity Profiles across {split_name}", fontsize=14)
+ plt.ylabel("Variance Contribution", fontsize=12)
+ plt.xlabel(f"Regions of {split_name}", fontsize=12)
+ plt.legend(title="Input Variables", bbox_to_anchor=(1.05, 1), loc="upper left")
+ plt.xticks(rotation=45)
+ plt.grid(axis="y", linestyle="--", alpha=0.7)
+ plt.tight_layout()
+ plt.show()
+
+ return summary
diff --git a/src/simdec/sensitivity_indices.py b/src/simdec/sensitivity_indices.py
index ed9a07b..4e5e6ed 100644
--- a/src/simdec/sensitivity_indices.py
+++ b/src/simdec/sensitivity_indices.py
@@ -37,7 +37,9 @@ class SensitivityAnalysisResult:
def sensitivity_indices(
- inputs: pd.DataFrame | np.ndarray, output: pd.DataFrame | np.ndarray
+ inputs: pd.DataFrame | np.ndarray,
+ output: pd.DataFrame | np.ndarray,
+ print_indices: bool = False,
) -> SensitivityAnalysisResult:
"""Sensitivity indices.
@@ -50,6 +52,8 @@ def sensitivity_indices(
Input variables.
output : ndarray or DataFrame of shape (n_runs, 1)
Target variable.
+    print_indices : bool, default False
+        If True, print a table of the computed indices to stdout.
Returns
-------
@@ -97,11 +101,18 @@ def sensitivity_indices(
"""
# Handle inputs conversion
if isinstance(inputs, pd.DataFrame):
- cat_columns = inputs.select_dtypes(["category", "O"]).columns
- inputs[cat_columns] = inputs[cat_columns].apply(
- lambda x: x.astype("category").cat.codes
- )
+ var_names = inputs.columns.tolist()
+ cat_cols = inputs.select_dtypes(["category", "O"]).columns
+ if not cat_cols.empty:
+ inputs = inputs.copy() # Avoid SettingWithCopyWarning
+ inputs[cat_cols] = inputs[cat_cols].apply(
+ lambda x: x.astype("category").cat.codes
+ )
inputs = inputs.to_numpy()
+ else:
+ inputs = np.asarray(inputs)
+ # Fallback names if it's just a numpy array
+ var_names = [f"x{i}" for i in range(inputs.shape[1])]
# Handle output conversion first, then flatten
if isinstance(output, (pd.DataFrame, pd.Series)):
@@ -181,4 +192,14 @@ def sensitivity_indices(
for k in range(n_factors):
si[k] = foe[k] + (soe[:, k].sum() / 2)
+ if print_indices:
+ df_foe = pd.DataFrame(foe, index=var_names, columns=["First-order effect"])
+ df_soe = pd.DataFrame(soe, index=var_names, columns=var_names)
+ df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])
+
+ df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
+ print(f"{'-'*69}")
+ print(df_indices)
+ print(f"{'-'*69}")
+
return SensitivityAnalysisResult(si, foe, soe)
diff --git a/src/simdec/visualization.py b/src/simdec/visualization.py
index e77adf4..81459af 100644
--- a/src/simdec/visualization.py
+++ b/src/simdec/visualization.py
@@ -139,6 +139,8 @@ def visualization(
n_bins: str | int = "auto",
kind: Literal["histogram", "boxplot"] = "histogram",
ax=None,
+ print_legend: bool = False,
+ decomposition=None,
) -> plt.Axes:
"""Histogram plot of scenarios.
@@ -154,6 +156,10 @@ def visualization(
Histogram or Box Plot.
ax : Axes, optional
Matplotlib axis.
+    print_legend : bool, optional
+        If True, display the decomposition legend table.
+    decomposition : DecompositionResult, optional
+        Decomposition result; required when ``print_legend`` is True.
Returns
-------
@@ -186,6 +192,28 @@ def visualization(
)
else:
raise ValueError("'kind' can only be 'histogram' or 'boxplot'")
+
+ if print_legend:
+ from IPython.display import display
+
+ if decomposition is None:
+ import warnings
+
+ warnings.warn(
+ "print_legend=True requires the decomposition object. Table skipped."
+ )
+ else:
+ try:
+ _, styler = tableau(
+ var_names=decomposition.var_names,
+ statistic=decomposition.statistic,
+ states=decomposition.states,
+ bins=decomposition.bins,
+ palette=palette,
+ )
+ display(styler)
+ except ImportError:
+ pass
return ax
@@ -200,6 +228,8 @@ def two_output_visualization(
xlim: tuple[float, float] | None = None,
ylim: tuple[float, float] | None = None,
r_scatter: float = 1.0,
+ print_legend: bool = False,
+ decomposition=None,
) -> tuple[plt.Figure, np.ndarray]:
"""Two-output visualization.
@@ -229,6 +259,10 @@ def two_output_visualization(
Limits for the secondary output axis (scatter y / right histogram).
r_scatter : float, default 1.0
Fraction of data points shown in the scatter plot.
+    print_legend : bool, optional
+        If True, display the decomposition legend table.
+    decomposition : DecompositionResult, optional
+        Decomposition result; required when ``print_legend`` is True.
Returns
-------
@@ -286,6 +320,28 @@ def two_output_visualization(
axs[1, 1].axis("off")
fig.subplots_adjust(wspace=-0.015, hspace=0)
+
+ if print_legend:
+ from IPython.display import display
+
+ if decomposition is None:
+ import warnings
+
+ warnings.warn(
+ "print_legend=True requires the decomposition object. Table skipped."
+ )
+ else:
+ try:
+ _, styler = tableau(
+ var_names=decomposition.var_names,
+ statistic=decomposition.statistic,
+ states=decomposition.states,
+ bins=decomposition.bins,
+ palette=palette,
+ )
+ display(styler)
+ except ImportError:
+ pass
return fig, axs
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index a974ae0..cf35d5a 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -1,4 +1,6 @@
import pytest
+import pathlib
+import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import simdec as sd
@@ -62,3 +64,71 @@ def test_two_output_visualization_r_scatter():
bins=bins, bins2=bins2, palette=palette, r_scatter=0.5
)
assert isinstance(fig, plt.Figure)
+
+
+# Set up the data path to match the decomposition tests.
+path_data = pathlib.Path(__file__).parent / "data"
+
+
+@pytest.fixture
+def stress_results():
+    """Run the decomposition on the stress dataset and return the result."""
+ fname = path_data / "stress.csv"
+ data = pd.read_csv(fname)
+ output_name, *v_names = list(data.columns)
+ inputs, output = data[v_names], data[output_name]
+ si = np.array([0.04, 0.50, 0.11, 0.28])
+
+ res = sd.decomposition(
+ inputs=inputs, output=output, sensitivity_indices=si, dec_limit=1
+ )
+ return res
+
+
+def test_visualization_with_legend(stress_results):
+ """Verify visualization works with print_legend using live decomposition results."""
+ # Generate palette based on the live results
+ palette = sd.palette(stress_results.states)
+
+ # Test single visualization
+ ax = sd.visualization(
+ bins=stress_results.bins,
+ palette=palette,
+ print_legend=True,
+ decomposition=stress_results,
+ )
+
+ assert isinstance(ax, plt.Axes)
+ # Check that the columns were handled (RangeIndex is applied inside visualization)
+ assert isinstance(stress_results.bins.columns, pd.RangeIndex)
+
+
+def test_two_output_visualization_with_legend(stress_results):
+ """Verify two_output works with print_legend using live decomposition results."""
+ palette = sd.palette(stress_results.states)
+
+ # Using the same bins for both axes for testing purposes
+ fig, axs = sd.two_output_visualization(
+ bins=stress_results.bins,
+ bins2=stress_results.bins,
+ palette=palette,
+ print_legend=True,
+ decomposition=stress_results,
+ output_name="Primary",
+ output_name2="Secondary",
+ )
+
+ assert isinstance(fig, plt.Figure)
+ assert axs.shape == (2, 2)
+ assert axs[1, 0].get_xlabel() == "Primary"
+ assert axs[1, 0].get_ylabel() == "Secondary"
+
+
+def test_visualization_missing_decomposition_warning():
+ """Verify that omitting the decomposition object triggers a warning, not a crash."""
+ # Using small dummy data for a quick standalone check
+ bins = pd.DataFrame({"s1": [1, 2]})
+ pal = [[1, 0, 0, 1]]
+
+ with pytest.warns(UserWarning, match="requires the decomposition object"):
+ sd.visualization(bins=bins, palette=pal, print_legend=True, decomposition=None)