diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 7f0205b5..3fa3f319 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: true matrix: - platform: [windows-latest, ubuntu-latest, macos-13, macos-14] + platform: [windows-2022, ubuntu-latest, macos-13, macos-14] env: CIBW_SKIP: 'pp*' CIBW_ARCHS: 'auto64' @@ -34,9 +34,9 @@ jobs: - name: Install OMP (MacOS Intel) if: matrix.platform == 'macos-13' run: | - brew install llvm libomp - echo "export CC=/usr/local/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/usr/local/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/usr/local/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/usr/local/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp\"" >> ~/.bashrc @@ -44,9 +44,9 @@ jobs: - name: Install OMP (MacOS M1) if: matrix.platform == 'macos-14' run: | - brew install llvm libomp - echo "export CC=/opt/homebrew/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/opt/homebrew/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/opt/homebrew/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/opt/homebrew/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp\"" >> ~/.bashrc diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index a2de2b66..6936d3ca 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - platform: [windows-latest, ubuntu-latest, macos-13, macos-14] + platform: [windows-2022, ubuntu-latest, macos-13, macos-14] version: ["3.10", "3.13"] defaults: run: @@ -38,9 +38,9 @@ jobs: - name: Install OMP (MacOS Intel) if: matrix.platform == 'macos-13' run: | - brew install llvm libomp - echo "export CC=/usr/local/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/usr/local/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/usr/local/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/usr/local/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp\"" >> ~/.bashrc @@ -48,9 +48,9 @@ jobs: - name: Install OMP (MacOS M1) if: matrix.platform == 'macos-14' run: | - brew install llvm libomp - echo "export CC=/opt/homebrew/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/opt/homebrew/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/opt/homebrew/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/opt/homebrew/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp\"" >> ~/.bashrc diff --git a/pyproject.toml b/pyproject.toml index 35febb57..d52f72a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,5 @@ mark-parentheses = false [tool.ruff.lint.pydocstyle] convention = "numpy" - +[tool.ruff.lint.isort] +known-first-party = ["ratapi.rat_core"] diff --git a/ratapi/controls.py b/ratapi/controls.py index c5408411..ea62f5d2 100644 --- a/ratapi/controls.py +++ b/ratapi/controls.py @@ -233,19 +233,16 @@ def delete_IPC(self): os.remove(self._IPCFilePath) return None - def save(self, path: Union[str, Path], filename: str = "controls"): + def save(self, filepath: Union[str, Path] = "./controls.json"): """Save a controls object to a JSON file. Parameters ---------- - path : str or Path - The directory in which the controls object will be written. - filename : str - The name for the JSON file containing the controls object. - + filepath : str or Path + The path to where the controls file will be written. """ - file = Path(path, f"{filename.removesuffix('.json')}.json") - file.write_text(self.model_dump_json()) + filepath = Path(filepath).with_suffix(".json") + filepath.write_text(self.model_dump_json()) @classmethod def load(cls, path: Union[str, Path]) -> "Controls": diff --git a/ratapi/examples/domains/domains_XY_model.py b/ratapi/examples/domains/domains_XY_model.py index 8aeb8c77..00567666 100644 --- a/ratapi/examples/domains/domains_XY_model.py +++ b/ratapi/examples/domains/domains_XY_model.py @@ -1,8 +1,9 @@ """Custom model file for the domains custom XY example.""" -import math +from math import sqrt import numpy as np +from scipy.special import erf def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): @@ -19,13 +20,13 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): z = np.arange(0, 141) # Make the volume fraction distribution for our Silicon substrate - [vfSilicon, siSurf] = makeLayer(z, -25, 50, 1, subRough, subRough) + [vfSilicon, siSurf] = make_layer(z, -25, 50, 1, subRough, subRough) # ... and the Oxide ... - [vfOxide, oxSurface] = makeLayer(z, siSurf, oxideThick, 1, subRough, subRough) + [vfOxide, oxSurface] = make_layer(z, siSurf, oxideThick, 1, subRough, subRough) # ... and also our layer. - [vfLayer, laySurface] = makeLayer(z, oxSurface, layerThick, 1, subRough, layerRough) + [vfLayer, laySurface] = make_layer(z, oxSurface, layerThick, 1, subRough, layerRough) # Everything that is not already occupied will be filled will water totalVF = vfSilicon + vfOxide + vfLayer @@ -53,7 +54,7 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): return SLD, subRough -def makeLayer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): +def make_layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): """Produce a layer, with a defined thickness, height and roughness. Each side of the layer has its own roughness value. @@ -63,12 +64,9 @@ def makeLayer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): right = prevLaySurf + thickness # Make our heaviside - a = (z - left) / ((2**0.5) * Sigma_L) - b = (z - right) / ((2**0.5) * Sigma_R) + erf_left = erf((z - left) / (sqrt(2) * Sigma_L)) + erf_right = erf((z - right) / (sqrt(2) * Sigma_R)) - erf_a = np.array([math.erf(value) for value in a]) - erf_b = np.array([math.erf(value) for value in b]) - - VF = np.array((height / 2) * (erf_a - erf_b)) + VF = np.array((0.5 * height) * (erf_left - erf_right)) return VF, right diff --git a/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py b/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py index dc1d1013..93e25b08 100644 --- a/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py +++ b/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py @@ -1,8 +1,9 @@ """A custom XY model for a supported DSPC bilayer.""" -import math +from math import sqrt import numpy as np +from scipy.special import erf def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): @@ -51,10 +52,10 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): z = np.arange(0, 141) # Make our Silicon substrate - vfSilicon, siSurf = layer(z, -25, 50, 1, subRough, subRough) + vfSilicon, siSurf = make_layer(z, -25, 50, 1, subRough, subRough) # Add the Oxide - vfOxide, oxSurface = layer(z, siSurf, oxideThick, 1, subRough, subRough) + vfOxide, oxSurface = make_layer(z, siSurf, oxideThick, 1, subRough, subRough) # We fill in the water at the end, but there may be a hydration layer between the bilayer and the oxide, # so we start the bilayer stack an appropriate distance away @@ -65,15 +66,15 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): headThick = vHead / lipidAPM # ... and make a box for the volume fraction (1 for now, we correct for coverage later) - vfHeadL, headLSurface = layer(z, watSurface, headThick, 1, bilayerRough, bilayerRough) + vfHeadL, headLSurface = make_layer(z, watSurface, headThick, 1, bilayerRough, bilayerRough) # ... also do the same for the tails # We'll make both together, so the thickness will be twice the volume tailsThick = (2 * vTail) / lipidAPM - vfTails, tailsSurf = layer(z, headLSurface, tailsThick, 1, bilayerRough, bilayerRough) + vfTails, tailsSurf = make_layer(z, headLSurface, tailsThick, 1, bilayerRough, bilayerRough) # Finally the upper head ... - vfHeadR, headSurface = layer(z, tailsSurf, headThick, 1, bilayerRough, bilayerRough) + vfHeadR, headSurface = make_layer(z, tailsSurf, headThick, 1, bilayerRough, bilayerRough) # Making the model # We've created the volume fraction profiles corresponding to each of the groups. @@ -114,12 +115,12 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): totSLD = sldSilicon + sldOxide + sldHeadL + sldTails + sldHeadR + sldWat # Make the SLD array for output - SLD = [[a, b] for (a, b) in zip(z, totSLD)] + SLD = np.column_stack((z, totSLD)) return SLD, subRough -def layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): +def make_layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): """Produce a layer, with a defined thickness, height and roughness. Each side of the layer has its own roughness value. @@ -129,12 +130,9 @@ def layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): right = prevLaySurf + thickness # Make our heaviside - a = (z - left) / ((2**0.5) * Sigma_L) - b = (z - right) / ((2**0.5) * Sigma_R) + erf_left = erf((z - left) / (sqrt(2) * Sigma_L)) + erf_right = erf((z - right) / (sqrt(2) * Sigma_R)) - erf_a = np.array([math.erf(value) for value in a]) - erf_b = np.array([math.erf(value) for value in b]) - - VF = np.array((height / 2) * (erf_a - erf_b)) + VF = np.array((0.5 * height) * (erf_left - erf_right)) return VF, right diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index a6f2f557..c2823b8f 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -94,7 +94,7 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool def plot_ref_sld_helper( data: PlotEventData, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.pyplot.figure, delay: bool = True, confidence_intervals: Union[dict, None] = None, linear_x: bool = False, @@ -112,8 +112,8 @@ def plot_ref_sld_helper( data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure, optional - The figure class that has two subplots + fig : matplotlib.pyplot.figure + The figure object that has two subplots delay : bool, default: True Controls whether to delay 0.005s after plot is created confidence_intervals : dict or None, default None @@ -134,19 +134,13 @@ def plot_ref_sld_helper( animated : bool, default: False Controls whether the animated property of foreground plot elements should be set. - Returns - ------- - fig : matplotlib.pyplot.figure - The figure class that has two subplots - """ preserve_zoom = False - if fig is None: - fig = plt.subplots(1, 2)[0] - elif len(fig.axes) != 2: + if len(fig.axes) != 2: fig.clf() fig.subplots(1, 2) + fig.subplots_adjust(wspace=0.3) ref_plot: plt.Axes = fig.axes[0] @@ -233,13 +227,12 @@ def plot_ref_sld_helper( if delay: plt.pause(0.005) - return fig - def plot_ref_sld( project: ratapi.Project, results: Union[ratapi.outputs.Results, ratapi.outputs.BayesResults], block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, bayes: Literal[65, 95, None] = None, linear_x: bool = False, @@ -259,6 +252,8 @@ def plot_ref_sld( The result from the calculation block : bool, default: False Indicates the plot should block until it is closed + fig : matplotlib.pyplot.figure, optional + The figure object that has two subplots return_fig : bool, default False If True, return the figure instead of displaying it. bayes : 65, 95 or None, default None @@ -336,11 +331,15 @@ def plot_ref_sld( else: confidence_intervals = None - figure = plt.subplots(1, 2)[0] + if fig is None: + fig = plt.subplots(1, 2)[0] + elif len(fig.axes) != 2: + fig.clf() + fig.subplots(1, 2) plot_ref_sld_helper( data, - figure, + fig, confidence_intervals=confidence_intervals, linear_x=linear_x, q4=q4, @@ -351,7 +350,7 @@ def plot_ref_sld( ) if return_fig: - return figure + return fig plt.show(block=block) @@ -486,7 +485,7 @@ def update_plot(self, data): """ if self.figure is not None: self.figure.clf() - self.figure = ratapi.plotting.plot_ref_sld_helper( + plot_ref_sld_helper( data, self.figure, linear_x=self.linear_x, @@ -520,7 +519,7 @@ def update_foreground(self, data): """ self.set_animated(True) self.figure.canvas.restore_region(self.bg) - plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) + plot_data = _extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) offset = 2 if self.show_error_bar else 1 for i in range( @@ -649,9 +648,11 @@ def plot_corner( params: Union[list[Union[int, str]], None] = None, smooth: bool = True, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, hist_kwargs: Union[dict, None] = None, hist2d_kwargs: Union[dict, None] = None, + progress_callback: Union[Callable[[int, int], None], None] = None, ): """Create a corner plot from a Bayesian analysis. @@ -666,6 +667,8 @@ def plot_corner( Whether to apply Gaussian smoothing to the corner plot. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. hist_kwargs : dict @@ -674,6 +677,9 @@ def plot_corner( hist2d_kwargs : dict Extra keyword arguments to pass to the 2d histograms. Default is {'density': True, 'bins': 25} + progress_callback: Union[Callable[[int, int], None], None] + Callback function for providing progress during plot creation + First argument is current completed sub plot and second is total number of sub plots Returns ------- @@ -695,24 +701,32 @@ def plot_corner( hist2d_kwargs = {} num_params = len(params) + total_count = num_params + (num_params**2 - num_params) // 2 + + if fig is None: + fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10), subplot_kw={"visible": False}) + else: + fig.clf() + axes = fig.subplots(num_params, num_params, subplot_kw={"visible": False}) - fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10)) # i is row, j is column - for i, row_param in enumerate(params): - for j, col_param in enumerate(params): - current_axes: Axes = axes[i][j] + current_count = 0 + for i in range(num_params): + for j in range(i + 1): + row_param = params[i] + col_param = params[j] + current_axes: Axes = axes if isinstance(axes, matplotlib.axes.Axes) else axes[i][j] current_axes.tick_params(which="both", labelsize="medium") current_axes.xaxis.offsetText.set_fontsize("small") current_axes.yaxis.offsetText.set_fontsize("small") - + current_axes.set_visible(True) if i == j: # diagonal: histograms plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs) elif i > j: # lower triangle: 2d histograms plot_contour( results, x_param=col_param, y_param=row_param, smooth=smooth, axes=current_axes, **hist2d_kwargs ) - elif i < j: # upper triangle: no plot - current_axes.set_visible(False) + # remove label if on inside of corner plot if j != 0: current_axes.get_yaxis().set_visible(False) @@ -725,6 +739,9 @@ def plot_corner( current_axes.yaxis.offset_text_position = "center" current_axes.set_ylabel("") current_axes.set_xlabel("") + if progress_callback is not None: + current_count += 1 + progress_callback(current_count, total_count) if return_fig: return fig plt.show(block=block) @@ -956,7 +973,9 @@ def plot_contour( plt.show(block=block) -def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.figure.Figure: +def panel_plot_helper( + plot_func: Callable, indices: list[int], fig: Optional[matplotlib.pyplot.figure] = None +) -> matplotlib.figure.Figure: """Generate a panel-based plot from a single plot function. Parameters @@ -965,6 +984,8 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig A function which plots one parameter on an Axes object, given its index. indices : list[int] The list of indices to pass into ``plot_func``. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. Returns ------- @@ -974,10 +995,18 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig """ nplots = len(indices) nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots)) - fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] + + if fig is None: + fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] + else: + fig.clf() + fig.subplots(nrows, ncols) axs = fig.get_axes() for plot_num, index in enumerate(indices): + axs[plot_num].tick_params(which="both", labelsize="medium") + axs[plot_num].xaxis.offsetText.set_fontsize("small") + axs[plot_num].yaxis.offsetText.set_fontsize("small") plot_func(axs[plot_num], index) # blank unused plots @@ -998,6 +1027,7 @@ def plot_hists( dict[Literal["normal", "lognor", "kernel", None]], Literal["normal", "lognor", "kernel", None] ] = None, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, **hist_settings, ): @@ -1031,6 +1061,8 @@ def plot_hists( e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. hist_settings : @@ -1090,6 +1122,7 @@ def validate_dens_type(dens_type: Union[str, None], param: str): **hist_settings, ), params, + fig, ) if return_fig: return fig @@ -1102,6 +1135,7 @@ def plot_chain( params: Union[list[Union[int, str]], None] = None, maxpoints: int = 15000, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, ): """Plot the MCMC chain for each parameter of a Bayesian analysis. @@ -1117,6 +1151,8 @@ def plot_chain( The maximum number of points to plot for each parameter. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -1127,7 +1163,7 @@ def plot_chain( """ chain = results.chain - nsimulations, nplots = chain.shape + nsimulations, _ = chain.shape # skip is to evenly distribute points plotted # all points will be plotted if maxpoints < nsimulations skip = max(floor(nsimulations / maxpoints), 1) @@ -1142,9 +1178,9 @@ def plot_chain( def plot_one_chain(axes: Axes, i: int): axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip]) - axes.set_title(results.fitNames[i]) + axes.set_title(results.fitNames[i], fontsize="small") - fig = panel_plot_helper(plot_one_chain, params) + fig = panel_plot_helper(plot_one_chain, params, fig=fig) if return_fig: return fig plt.show(block=block) diff --git a/setup.py b/setup.py index b5871644..4c996362 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ class BuildExt(build_ext): """A custom build extension for adding compiler-specific options.""" c_opts = { - "msvc": ["/O2", "/EHsc"], + "msvc": ["/O2", "/EHsc", "/openmp"], "unix": ["-O2", "-fopenmp", "-std=c++11"], } l_opts = { diff --git a/tests/test_controls.py b/tests/test_controls.py index 72f0c745..61d68331 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -3,6 +3,7 @@ import contextlib import os import tempfile +from pathlib import Path from typing import Any, Union import pydantic @@ -45,6 +46,30 @@ def test_extra_property_error() -> None: controls.test = 1 +@pytest.mark.parametrize( + "inputs", + [ + {"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66}, + {"procedure": "simplex"}, + {"procedure": "dream", "nSamples": 504, "nChains": 1200}, + {"procedure": "de", "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm}, + {"procedure": "ns", "nMCMC": 4, "propScale": 0.6}, + ], +) +def test_save_load(inputs): + """Test that saving and loading controls returns the same controls.""" + + original_controls = Controls(**inputs) + with tempfile.TemporaryDirectory() as tmp: + # ignore relative path warnings + path = Path(tmp, "controls.json") + original_controls.save(path) + converted_controls = Controls.load(path) + + for field in Controls.model_fields: + assert getattr(converted_controls, field) == getattr(original_controls, field) + + class TestCalculate: """Tests the Calculate class.""" diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c42bfeea..222d5142 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -50,7 +50,8 @@ def fig(request) -> plt.figure: """Creates the fixture for the tests.""" plt.close("all") figure = plt.subplots(1, 2)[0] - return RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data()) + RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data()) + return figure @pytest.fixture @@ -68,7 +69,8 @@ def bayes_fig(request) -> plt.figure: for sld in dat.sldProfiles ], } - return RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals) + RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals) + return figure @pytest.mark.parametrize("fig", [False, True], indirect=True) @@ -120,8 +122,7 @@ def test_ref_sld_color_formatting(fig: plt.figure) -> None: assert sld_plot.get_lines()[i].get_color() == sld_plot.get_lines()[i + 1].get_color() -@pytest.mark.parametrize("bayes", [65, 95]) -def test_ref_sld_bayes(fig, bayes_fig, bayes): +def test_ref_sld_bayes(fig, bayes_fig): """Test that shading is correctly added to the figure when confidence intervals are supplied.""" # the shading is of type PolyCollection for axes in fig.axes: @@ -137,7 +138,7 @@ def test_sld_profile_function_call(mock: MagicMock) -> None: """Tests the makeSLDProfile function called with correct args. """ - RATplot.plot_ref_sld_helper(data()) + RATplot.plot_ref_sld_helper(data(), plt.subplots(1, 2)[0]) assert mock.call_count == 3 assert mock.call_args_list[0].args[0] == 2.07e-06 @@ -211,9 +212,9 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r def test_ref_sld_subplot_correction(): """Test that if an incorrect number of subplots is corrected in the figure helper.""" fig = plt.subplots(1, 3)[0] - ref_sld_fig = RATplot.plot_ref_sld_helper(data=data(), fig=fig) - assert ref_sld_fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) - assert len(ref_sld_fig.axes) == 2 + RATplot.plot_ref_sld_helper(data=data(), fig=fig) + assert fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) + assert len(fig.axes) == 2 @patch("ratapi.utils.plotting.plot_ref_sld_helper")