diff --git a/qiskit/tools/visualization/bloch.py b/qiskit/tools/visualization/_bloch.py similarity index 98% rename from qiskit/tools/visualization/bloch.py rename to qiskit/tools/visualization/_bloch.py index 1f83cfd9509..cf9aa10d530 100644 --- a/qiskit/tools/visualization/bloch.py +++ b/qiskit/tools/visualization/_bloch.py @@ -126,7 +126,13 @@ def __init__(self, fig=None, axes=None, view=None, figsize=None, background=False): # Figure and axes + self._ext_fig = False + if fig is not None: + self._ext_fig = True self.fig = fig + self._ext_axes = False + if axes is not None: + self._ext_axes = True self.axes = axes # Background axes, default = False self.background = background @@ -368,9 +374,9 @@ def make_sphere(self): """ Plots Bloch sphere and data sets. """ - self.render(self.fig, self.axes) + self.render() - def render(self, fig=None, axes=None, title=''): + def render(self, title=''): """ Render the Bloch sphere and its data sets in on given figure and axes. """ @@ -380,10 +386,10 @@ def render(self, fig=None, axes=None, title=''): self._rendered = True # Figure instance for Bloch sphere plot - if not fig: + if not self._ext_fig: self.fig = plt.figure(figsize=self.figsize) - if not axes: + if not self._ext_axes: self.axes = Axes3D(self.fig, azim=self.view[0], elev=self.view[1]) if self.background: @@ -580,7 +586,7 @@ def show(self, title=''): """ Display Bloch sphere and corresponding data sets. """ - self.render(self.fig, self.axes, title=title) + self.render(title=title) if self.fig: plt.show(self.fig) @@ -597,7 +603,7 @@ def save(self, name=None, output='png', dirc=None): Directory for output images. Defaults to current working directory. """ - self.render(self.fig, self.axes) + self.render() if dirc: if not os.path.isdir(os.getcwd() + "/" + str(dirc)): os.makedirs(os.getcwd() + "/" + str(dirc)) diff --git a/qiskit/tools/visualization/_state_visualization.py b/qiskit/tools/visualization/_state_visualization.py index 1ddcb867dc7..adc99076a6f 100644 --- a/qiskit/tools/visualization/_state_visualization.py +++ b/qiskit/tools/visualization/_state_visualization.py @@ -20,7 +20,7 @@ import numpy as np from qiskit.tools.qi.pauli import pauli_group, pauli_singles from qiskit.tools.visualization import VisualizationError -from qiskit.tools.visualization.bloch import Bloch +from qiskit.tools.visualization._bloch import Bloch class Arrow3D(FancyArrowPatch):