Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Naohnakazawa committed Mar 4, 2024
2 parents 5aeb353 + 6845440 commit 3dac8ce
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 13 deletions.
14 changes: 14 additions & 0 deletions docs/tutorials/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ Plotters have two sets of options that customize their behavior and figure conte
and ``figure_options``, which have figure-specific parameters that control aspects of the
figure itself, such as axis labels and series colors.

To see the residual plot, set ``plot_residuals=True`` in the analysis options:

.. jupyter-execute::

# Set to ``True`` analysis option for residual plot
rabi.analysis.set_options(plot_residuals=True)

# Run experiment
rabi_data = rabi.run().block_for_results()
rabi_data.figure(0)


This option works for experiments without subplots in their figures.

Here is a more complicated experiment in which we customize the figure of a DRAG
experiment before it's run, so that we don't need to regenerate the figure like in
the previous example. First, we run the experiment without customizing the options
Expand Down
3 changes: 3 additions & 0 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def _default_options(cls) -> Options:
the analysis result.
plot_raw_data (bool): Set ``True`` to draw processed data points,
dataset without formatting, on canvas. This is ``False`` by default.
plot_residuals (bool): Set ``True`` to draw the residuals data for the
fitting model. This is ``False`` by default.
plot (bool): Set ``True`` to create figure for fit result or ``False`` to
not create a figure. This overrides the behavior of ``generate_figures``.
return_fit_parameters (bool): (Deprecated) Set ``True`` to return all fit model parameters
Expand Down Expand Up @@ -207,6 +209,7 @@ def _default_options(cls) -> Options:

options.plotter = CurvePlotter(MplDrawer())
options.plot_raw_data = False
options.plot_residuals = False
options.return_fit_parameters = True
options.return_data_points = False
options.data_processor = None
Expand Down
184 changes: 177 additions & 7 deletions qiskit_experiments/curve_analysis/curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Dict, List, Tuple, Union, Optional
from functools import partial

from copy import deepcopy
import lmfit
import numpy as np
import pandas as pd
Expand All @@ -31,6 +32,7 @@
)
from qiskit_experiments.framework.containers import FigureType, ArtifactData
from qiskit_experiments.data_processing.exceptions import DataProcessorError
from qiskit_experiments.visualization import PlotStyle

from .base_curve_analysis import BaseCurveAnalysis, DATA_ENTRY_PREFIX, PARAMS_ENTRY_PREFIX
from .curve_data import FitOptions, CurveFitResult
Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(

self._models = models or []
self._name = name or self.__class__.__name__
self._plot_config_cache = {}

@property
def name(self) -> str:
Expand All @@ -148,6 +151,118 @@ def model_names(self) -> List[str]:
"""Return model names."""
return [getattr(m, "_name", f"model-{i}") for i, m in enumerate(self._models)]

def set_options(self, **fields):
"""Set the analysis options for :meth:`run` method.
Args:
fields: The fields to update the options
Raises:
KeyError: When removed option ``curve_fitter`` is set.
"""
if fields.get("plot_residuals") and not self.options.get("plot_residuals"):
# checking there are no subplots for the figure to prevent collision in subplot indices.
if self.plotter.options.get("subplots") != (1, 1):
warnings.warn(
"Residuals plotting is currently supported for analysis with 1 subplot.",
UserWarning,
stacklevel=2,
)
fields["plot_residuals"] = False
else:
self._add_residuals_plot_config()
if not fields.get("plot_residuals", True) and self.options.get("plot_residuals"):
self._remove_residuals_plot_config()

super().set_options(**fields)

def _add_residuals_plot_config(self):
"""Configure plotter options for residuals plot."""
# check we have model to fit into
residual_plot_y_axis_size = 3
if self.models:
# Cache figure options.
self._plot_config_cache["figure_options"] = {}
self._plot_config_cache["figure_options"]["ylabel"] = self.plotter.figure_options.get(
"ylabel"
)
self._plot_config_cache["figure_options"]["series_params"] = deepcopy(
self.plotter.figure_options.get("series_params")
)
self._plot_config_cache["figure_options"]["sharey"] = self.plotter.figure_options.get(
"sharey"
)

self.plotter.set_figure_options(
ylabel=[
self.plotter.figure_options.get("ylabel", ""),
"Residuals",
],
)

model_names = self.model_names()
series_params = self.plotter.figure_options["series_params"]
for model_name in model_names:
if series_params.get(model_name):
series_params[model_name]["canvas"] = 0
else:
series_params[model_name] = {"canvas": 0}
series_params[model_name + "_residuals"] = series_params[model_name].copy()
series_params[model_name + "_residuals"]["canvas"] = 1
self.plotter.set_figure_options(sharey=False, series_params=series_params)

# Cache plotter options.
self._plot_config_cache["plotter"] = {}
self._plot_config_cache["plotter"]["subplots"] = self.plotter.options.get("subplots")
self._plot_config_cache["plotter"]["style"] = deepcopy(
self.plotter.options.get("style", PlotStyle({}))
)

# removing the name from the plotter style, so it will not clash with the new name
previous_plotter_style = self._plot_config_cache["plotter"]["style"].copy()
previous_plotter_style.pop("style_name", "")

# creating new fig size based on previous size
new_figsize = self.plotter.drawer.options.get("figsize", (8, 5))
new_figsize = (new_figsize[0], new_figsize[1] + residual_plot_y_axis_size)

# Here add the configuration for the residuals plot:
self.plotter.set_options(
subplots=(2, 1),
style=PlotStyle.merge(
PlotStyle(
{
"figsize": new_figsize,
"textbox_rel_pos": (0.28, -0.10),
"sub_plot_heights_list": [7 / 10, 3 / 10],
"sub_plot_widths_list": [1],
"style_name": "residuals",
}
),
previous_plotter_style,
),
)

def _remove_residuals_plot_config(self):
"""set options for a single plot to its cached values."""
if self.models:
self.plotter.set_figure_options(
ylabel=self._plot_config_cache["figure_options"]["ylabel"],
sharey=self._plot_config_cache["figure_options"]["sharey"],
series_params=self._plot_config_cache["figure_options"]["series_params"],
)

# Here add the style_name so the plotter will know not to print the residual data.
self.plotter.set_options(
subplots=self._plot_config_cache["plotter"]["subplots"],
style=PlotStyle.merge(
self._plot_config_cache["plotter"]["style"],
PlotStyle({"style_name": "canceled_residuals"}),
),
)

self._plot_config_cache = {}

def _run_data_processing(
self,
raw_data: List[Dict],
Expand Down Expand Up @@ -335,8 +450,13 @@ def _run_curve_fit(
fit_options = [fit_options]

# Create convenient function to compute residual of the models.
partial_residuals = []
partial_weighted_residuals = []
valid_uncertainty = np.all(np.isfinite(curve_data.y_err))

# creating storage for residual plotting
if self.options.get("plot_residuals"):
residual_weights_list = []

for idx, sub_data in curve_data.iter_by_series_id():
if valid_uncertainty:
nonzero_yerr = np.where(
Expand All @@ -350,16 +470,23 @@ def _run_curve_fit(
# some yerr values might be very close to zero, yielding significant weights.
# With such outlier, the fit doesn't sense residual of other data points.
maximum_weight = np.percentile(raw_weights, 90)
weights = np.clip(raw_weights, 0.0, maximum_weight)
weights_list = np.clip(raw_weights, 0.0, maximum_weight)
else:
weights = None
model_residual = partial(
weights_list = None
model_weighted_residual = partial(
self._models[idx]._residual,
data=sub_data.y,
weights=weights,
weights=weights_list,
x=sub_data.x,
)
partial_residuals.append(model_residual)
partial_weighted_residuals.append(model_weighted_residual)

# adding weights to weights_list for residuals
if self.options.get("plot_residuals"):
if weights_list is None:
residual_weights_list.append(None)
else:
residual_weights_list.append(weights_list)

# Run fit for each configuration
res = None
Expand All @@ -379,7 +506,7 @@ def _run_curve_fit(
try:
with np.errstate(all="ignore"):
new = lmfit.minimize(
fcn=lambda x: np.concatenate([p(x) for p in partial_residuals]),
fcn=lambda x: np.concatenate([p(x) for p in partial_weighted_residuals]),
params=guess_params,
method=self.options.fit_method,
scale_covar=not valid_uncertainty,
Expand All @@ -396,11 +523,30 @@ def _run_curve_fit(
if new.success and res.redchi > new.redchi:
res = new

# if `plot_residuals` is ``False`` I would like the `residuals_model` be None to emphasize it
# wasn't calculated.
residuals_model = [] if self.options.get("plot_residuals") else None
if res and res.success and self.options.get("plot_residuals"):
for weights in residual_weights_list:
if weights is None:
residuals_model.append(res.residual)
else:
residuals_model.append(
[
weighted_res / np.abs(weight)
for weighted_res, weight in zip(res.residual, weights)
]
)

if residuals_model is not None:
residuals_model = np.array(residuals_model)

return convert_lmfit_result(
res,
self._models,
curve_data.x,
curve_data.y,
residuals_model,
)

def _create_figures(
Expand Down Expand Up @@ -449,6 +595,14 @@ def _create_figures(
y_interp_err=fit_stdev,
)

if self.options.get("plot_residuals"):
residuals_data = sub_data.filter(category="residuals")
self.plotter.set_series_data(
series_name=model_name,
x_residuals=residuals_data.x,
y_residuals=residuals_data.y,
)

return [self.plotter.figure()]

def _run_analysis(
Expand Down Expand Up @@ -526,6 +680,22 @@ def _run_analysis(
category="fitted",
analysis=self.name,
)

if self.options.get("plot_residuals"):
# need to add here the residuals plot.
xval_residual = sub_data.x
yval_residuals = unp.nominal_values(fit_data.residuals[series_id])

for xval, yval in zip(xval_residual, yval_residuals):
table.add_row(
xval=xval,
yval=yval,
series_name=model_names[series_id],
series_id=series_id,
category="residuals",
analysis=self.name,
)

result_data.extend(
self._create_analysis_results(
fit_data=fit_data,
Expand Down
6 changes: 6 additions & 0 deletions qiskit_experiments/curve_analysis/curve_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def __init__(
var_names: Optional[List[str]] = None,
x_data: Optional[np.ndarray] = None,
y_data: Optional[np.ndarray] = None,
weighted_residuals: Optional[np.ndarray] = None,
residuals: Optional[np.ndarray] = None,
covar: Optional[np.ndarray] = None,
):
"""Create new Qiskit curve analysis result object.
Expand All @@ -188,6 +190,8 @@ def __init__(
var_names: Name of variables, i.e. fixed parameters are excluded from the list.
x_data: X values used for the fitting.
y_data: Y values used for the fitting.
weighted_residuals: The residuals from the fitting after assigning weights for each ydata.
residuals: residuals of the fitted model.
covar: Covariance matrix of fitting variables.
"""
self.method = method
Expand All @@ -205,6 +209,8 @@ def __init__(
self.var_names = var_names
self.x_data = x_data
self.y_data = y_data
self.weighted_residuals = weighted_residuals
self.residuals = residuals
self.covar = covar

@property
Expand Down
4 changes: 4 additions & 0 deletions qiskit_experiments/curve_analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def convert_lmfit_result(
models: List[lmfit.Model],
xdata: np.ndarray,
ydata: np.ndarray,
residuals: np.ndarray,
) -> CurveFitResult:
"""A helper function to convert LMFIT ``MinimizerResult`` into :class:`.CurveFitResult`.
Expand All @@ -128,6 +129,7 @@ def convert_lmfit_result(
models: Model used for the fitting. Function description is extracted.
xdata: X values used for the fitting.
ydata: Y values used for the fitting.
residuals: The residuals of the ydata from the model.
Returns:
QiskitExperiments :class:`.CurveFitResult` object.
Expand Down Expand Up @@ -169,6 +171,8 @@ def convert_lmfit_result(
var_names=result.var_names,
x_data=xdata,
y_data=ydata,
weighted_residuals=result.residual,
residuals=residuals,
covar=covar,
)

Expand Down
Loading

0 comments on commit 3dac8ce

Please sign in to comment.