In [None]:
"""
Utility functions for plotting Bloch spheres on matplotlib axes.
Extracted and refactored from QuTiP's Bloch sphere implementation.
"""

import numpy as np
from numpy import cos, sin, outer, ones
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
from packaging.version import parse as parse_version
import matplotlib


class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose two endpoints live in 3D data coordinates.

    The 3D endpoints are kept in ``self._verts3d`` and re-projected to 2D
    with the parent axes' projection matrix (``self.axes.M``) every time
    the patch is projected or drawn.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Start from a degenerate 2D arrow; real positions are filled in
        # on each projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project_to_2d(self):
        """Project the 3D endpoints, update the 2D arrow, return projected zs."""
        x3, y3, z3 = self._verts3d
        px, py, pz = proj3d.proj_transform(x3, y3, z3, self.axes.M)
        self.set_positions((px[0], py[0]), (px[1], py[1]))
        return pz

    def draw(self, renderer):
        self._project_to_2d()
        super().draw(renderer)

    def do_3d_projection(self, renderer=None):
        # Only called by matplotlib >= 3.5; the returned value is the
        # smallest projected z of the two endpoints.
        return np.min(self._project_to_2d())


def state_to_cartesian(state):
    """
    Convert quantum state to Cartesian coordinates on Bloch sphere.
    
    Parameters
    ----------
    state : Qobj or array-like
        Quantum state object or expectation values [x, y, z]
        
    Returns
    -------
    list : [x, y, z] coordinates
    """
    # If it's already Cartesian coordinates
    if hasattr(state, '__len__') and len(state) == 3:
        return list(state)
    
    # If it's a QuTiP Qobj, compute expectation values
    # This would require QuTiP imports - placeholder for now
    try:
        from qutip import expect, sigmax, sigmay, sigmaz
        return [expect(sigmax(), state),
                expect(sigmay(), state), 
                expect(sigmaz(), state)]
    except ImportError:
        raise ValueError("Either provide [x,y,z] coordinates or install QuTiP for state conversion")


def setup_bloch_axes(ax, view=None, background=False):
    """
    Reset and configure a 3D axes for drawing a Bloch sphere.

    The axes are cleared and their grid is switched off. The viewing angle,
    axis limits and (on matplotlib >= 3.3) a cubic box aspect are then set.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    view : list, optional
        [azimuth, elevation] viewing angles. Default [-60, 30]
    background : bool, optional
        If True, keep matplotlib's own axis decorations visible and use
        limits of +/-1.3. If False (default), hide them with
        ``set_axis_off`` and use tighter limits of +/-0.7.
    """
    azim, elev = (-60, 30) if view is None else (view[0], view[1])

    ax.view_init(elev=elev, azim=azim)
    ax.clear()
    ax.grid(False)

    if background:
        half_width = 1.3
    else:
        ax.set_axis_off()
        half_width = 0.7
    for set_limits in (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d):
        set_limits(-half_width, half_width)

    # Cubic plotting box; only applied on matplotlib >= 3.3.
    if parse_version(matplotlib.__version__) >= parse_version('3.3'):
        ax.set_box_aspect((1, 1, 1))


def plot_bloch_sphere(ax, sphere_color='#FFDDDD', sphere_alpha=0.2,
                      frame_color='gray', frame_alpha=0.2, frame_width=1,
                      show_equator=False):
    """
    Plot the basic Bloch sphere surface and wireframe.

    The unit sphere is drawn in two halves: azimuth 0..pi (back), then
    -pi..0 (front). Each half is a translucent surface (stride 2) followed by
    a coarser wireframe (stride 5) on a 25x25 grid.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    sphere_color : str, optional
        Color of sphere surface
    sphere_alpha : float, optional
        Transparency of sphere surface
    frame_color : str, optional
        Color of the wireframe and of the optional equator lines
    frame_alpha : float, optional
        Transparency of wireframe
    frame_width : int, optional
        Line width of the equator lines; only used when ``show_equator``
        is True.
    show_equator : bool, optional
        If True, also draw a full circle in the z=0 plane and a half circle
        drawn with ``zdir='x'``. Default False, which keeps the output of
        earlier versions of this function.
    """
    polar = np.linspace(0, np.pi, 25)
    back_azimuth = np.linspace(0, np.pi, 25)
    front_azimuth = np.linspace(-np.pi, 0, 25)

    for azimuth in (back_azimuth, front_azimuth):
        x = outer(cos(azimuth), sin(polar))
        y = outer(sin(azimuth), sin(polar))
        z = outer(ones(np.size(azimuth)), cos(polar))
        ax.plot_surface(x, y, z, rstride=2, cstride=2,
                        color=sphere_color, linewidth=0, alpha=sphere_alpha)
        ax.plot_wireframe(x, y, z, rstride=5, cstride=5,
                          color=frame_color, alpha=frame_alpha)

    if show_equator:
        full_turn = np.linspace(0, 2 * np.pi, 100)
        ax.plot(1.0 * cos(full_turn), 1.0 * sin(full_turn), zs=0, zdir='z',
                lw=frame_width, color=frame_color)
        half_turn = np.linspace(-np.pi / 2, np.pi / 2, 100)
        ax.plot(1.0 * cos(half_turn), 1.0 * sin(half_turn), zs=0, zdir='x',
                lw=frame_width, color=frame_color)


def plot_bloch_axes(ax, frame_color='gray', frame_width=1):
    """
    Draw three straight axis lines through the origin, spanning -1 to 1.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    frame_color : str, optional
        Line color
    frame_width : int, optional
        Line width

    Notes
    -----
    Every line is drawn with ``zs=0``. The first two use ``zdir='z'``
    (lines in the z=0 plane); the third uses ``zdir='y'``. The labels
    'X', 'Y', 'Z' are attached to the lines as artist labels.
    """
    ends = np.linspace(-1.0, 1.0, 2)
    zeros = 0 * ends
    # (first coordinate array, second coordinate array, zdir, label)
    segments = (
        (ends, zeros, 'z', 'X'),
        (zeros, ends, 'z', 'Y'),
        (zeros, ends, 'y', 'Z'),
    )
    for first, second, direction, name in segments:
        ax.plot(first, second, zs=0, zdir=direction,
                lw=frame_width, color=frame_color, label=name)


def plot_bloch_labels(ax, xlabel=['$x$', ''], ylabel=['$y$', ''],
                      zlabel=[r'$|0\rangle$', r'$|1\rangle$'],
                      xlpos=[1.2, -1.2], ylpos=[1.2, -1.2], zlpos=[1.2, -1.2],
                      font_size=20, font_color='black'):
    """
    Place the six axis-end labels and hide all tick marks and tick labels.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    xlabel, ylabel, zlabel : list
        Two labels each: index 0 for the positive end, index 1 for the
        negative end.
    xlpos, ylpos, zlpos : list
        Two distances each from the origin, matching the label order.
    font_size : int, optional
        Font size for labels
    font_color : str, optional
        Font color for labels

    Notes
    -----
    Physical x labels go at plot position (0, -d, 0) and physical y labels
    at (d, 0, 0). This is the same x/y swap applied to vectors and points
    elsewhere in this module.
    """
    text_style = dict(fontsize=font_size,
                      color=font_color,
                      horizontalalignment='center',
                      verticalalignment='center')

    # (distance -> plot position, distances, labels) for each physical axis.
    anchors = (
        (lambda d: (0, -d, 0), xlpos, xlabel),
        (lambda d: (d, 0, 0), ylpos, ylabel),
        (lambda d: (0, 0, d), zlpos, zlabel),
    )
    for to_plot_xyz, distances, labels in anchors:
        for end in (0, 1):
            ax.text(*to_plot_xyz(distances[end]), labels[end], **text_style)

    # Hide tick marks and tick labels on all three axes.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        for artist in [*axis.get_ticklines(), *axis.get_ticklabels()]:
            artist.set_visible(False)


def plot_bloch_vector(ax, vector, color=None, alpha=1.0, width=3,
                      style='-|>', mutation=20):
    """
    Draw one Bloch vector as a 3D arrow from the origin to its tip.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    vector : array-like
        Bloch vector [x, y, z]; only the first three entries are read.
    color : str, optional
        Arrow color; 'red' when None.
    alpha : float, optional
        Arrow transparency
    width : int, optional
        Arrow line width
    style : str, optional
        matplotlib arrow style string
    mutation : int, optional
        Arrow head size (passed as ``mutation_scale``)
    """
    coords = np.asarray(vector)
    # Plot frame: plot-x = physical y, plot-y = -(physical x), plot-z = physical z.
    tail_to_tip = np.array([0, 1])
    ax.add_artist(Arrow3D(coords[1] * tail_to_tip,
                          -coords[0] * tail_to_tip,
                          coords[2] * tail_to_tip,
                          mutation_scale=mutation,
                          lw=width,
                          arrowstyle=style,
                          color='red' if color is None else color,
                          alpha=alpha))


def plot_bloch_points(ax, points, color=None, size=25, marker='o', alpha=1.0):
    """
    Scatter points on the Bloch sphere.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    points : array-like
        A single [x, y, z], or an array of shape (3, N) or (N, 3). A square
        (3, 3) array is read as (3, N).
    color : str or list, optional
        Point colors; 'blue' when None.
    size : int or list, optional
        Point sizes
    marker : str, optional
        Point marker style
    alpha : float, optional
        Point transparency
    """
    coords = np.asarray(points)
    if coords.ndim == 1:
        # Single point -> one column.
        coords = coords.reshape(3, 1)
    elif coords.shape[0] != 3:
        # (N, 3) rows -> (3, N) columns.
        coords = coords.T

    phys_x, phys_y, phys_z = (np.real(coords[k]) for k in range(3))
    # Plot frame: plot-x = physical y, plot-y = -(physical x).
    ax.scatter(phys_y, -phys_x, phys_z,
               s=size, marker=marker,
               c='blue' if color is None else color,
               alpha=alpha, edgecolor=None)


def plot_bloch_trajectory(ax, points, color=None, alpha=1.0, linewidth=2):
    """
    Draw a connected line through successive points on the Bloch sphere.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    points : array-like
        A single [x, y, z], or an array of shape (3, N) or (N, 3). A square
        (3, 3) array is read as (3, N).
    color : str, optional
        Line color; 'blue' when None.
    alpha : float, optional
        Line transparency
    linewidth : float, optional
        Line width
    """
    arr = np.asarray(points)
    if arr.ndim == 1:
        arr = arr.reshape(3, 1)
    elif arr.shape[0] != 3:
        arr = arr.T

    # Only the real part is plotted.
    real_rows = np.real(arr)
    # Plot frame: plot-x = physical y, plot-y = -(physical x).
    ax.plot(real_rows[1], -real_rows[0], real_rows[2],
            color='blue' if color is None else color,
            alpha=alpha, linewidth=linewidth)


def plot_bloch_annotation(ax, position, text, font_size=20, font_color='black', **kwargs):
    """
    Write a text annotation at a point on (or near) the Bloch sphere.

    Parameters
    ----------
    ax : Axes3D
        3D matplotlib axes
    position : array-like
        Physical [x, y, z] position of the annotation
    text : str
        Annotation text
    font_size : int, optional
        Font size
    font_color : str, optional
        Font color
    **kwargs
        Extra ``ax.text`` options; these override the defaults above
        (including ``color`` and the centered alignment).
    """
    where = np.asarray(position)
    text_opts = {'fontsize': font_size,
                 'color': font_color,
                 'horizontalalignment': 'center',
                 'verticalalignment': 'center',
                 **kwargs}
    # Plot frame: plot-x = physical y, plot-y = -(physical x).
    ax.text(where[1], -where[0], where[2], text, **text_opts)


def create_bloch_sphere(ax=None, view=None, background=False, **kwargs):
    """
    Create a complete Bloch sphere with standard appearance.

    Parameters
    ----------
    ax : Axes3D, optional
        3D matplotlib axes. If None, a new 8x8-inch figure with a single 3D
        subplot is created.
    view : list, optional
        [azimuth, elevation] viewing angles (passed to ``setup_bloch_axes``).
    background : bool, optional
        If True, matplotlib's own axis decorations stay visible with wider
        limits, and the drawn axis lines and labels are skipped.
    **kwargs
        Styling options forwarded to ``plot_bloch_sphere``. ``frame_color``
        and ``frame_width``, when given, are also applied to the axis lines
        so they match the wireframe.

    Returns
    -------
    ax : Axes3D
        The 3D axes with Bloch sphere
    """
    if ax is None:
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_subplot(111, projection='3d')

    setup_bloch_axes(ax, view=view, background=background)
    plot_bloch_sphere(ax, **kwargs)
    if not background:
        # Forward only what the caller set, so plot_bloch_axes otherwise
        # keeps its own defaults.
        axes_style = {key: kwargs[key]
                      for key in ('frame_color', 'frame_width') if key in kwargs}
        plot_bloch_axes(ax, **axes_style)
        plot_bloch_labels(ax)

    return ax


# Convenience function for quick plotting
def quick_bloch_plot(vectors=None, points=None, states=None, figsize=(8, 8),
                     view=None, save_path=None, **kwargs):
    """
    Quick function to create and display a Bloch sphere with data.

    Parameters
    ----------
    vectors : array-like, optional
        A single [x, y, z] vector or a sequence of them, drawn as arrows
    points : array-like, optional
        Points to scatter plot (any shape accepted by ``plot_bloch_points``)
    states : Qobj, array-like or list, optional
        A single state or a list of states. Each state is either a QuTiP
        state object (requires QuTiP) or an [x, y, z] sequence.
    figsize : tuple, optional
        Figure size
    view : list, optional
        Viewing angles [azimuth, elevation]
    save_path : str, optional
        Path to save figure
    **kwargs
        Additional styling options, forwarded to ``create_bloch_sphere``

    Returns
    -------
    fig, ax : matplotlib figure and axes
    """
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    create_bloch_sphere(ax, view=view, **kwargs)

    if states is not None:
        # Don't run states through np.asarray. That turns a Qobj, or a list
        # of them, into plain complex matrices, which state_to_cartesian
        # can no longer evaluate with QuTiP.
        if isinstance(states, np.ndarray):
            state_list = [states] if states.ndim == 1 else list(states)
        elif isinstance(states, (list, tuple)):
            # A flat sequence of numbers is one [x, y, z] state; anything
            # else is a collection of states.
            is_single = len(states) > 0 and all(np.isscalar(s) for s in states)
            state_list = [states] if is_single else list(states)
        else:
            state_list = [states]  # e.g. a single Qobj
        for state in state_list:
            plot_bloch_vector(ax, state_to_cartesian(state))

    if vectors is not None:
        vectors = np.asarray(vectors)
        if vectors.ndim == 1:
            vectors = [vectors]
        for vec in vectors:
            plot_bloch_vector(ax, vec)

    if points is not None:
        # plot_bloch_points performs its own array conversion/reshaping.
        plot_bloch_points(ax, points)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig, ax

In [None]:
import torch
import qutip

# Load a precomputed trajectory.
# - map_location="cpu" and .detach() let .numpy() work even if the tensor
#   was saved from a GPU or with requires_grad.
# - weights_only=True stops torch.load (pickle-based) from unpickling
#   arbitrary objects.
# NOTE(review): the factor 2 is unexplained. It presumably rescales the
# stored values onto the unit Bloch sphere (e.g. spin <S> -> Pauli <sigma>);
# confirm.
trajectory = torch.load("../simulation/asset/test_trajectory.pt",
                        map_location="cpu", weights_only=True).detach().numpy() * 2

fig, ax = plt.subplots(1, 1, figsize=(5, 5), subplot_kw={"projection": "3d"})
# setup_bloch_axes is not used here, so force a cubic box explicitly
# (same matplotlib >= 3.3 guard as setup_bloch_axes) to avoid distorting
# the sphere.
if parse_version(matplotlib.__version__) >= parse_version('3.3'):
    ax.set_box_aspect((1, 1, 1))
plot_bloch_sphere(ax)
plot_bloch_trajectory(ax, trajectory)

ax.axis("off")
fig.tight_layout()



In [3]:
import numpy as np
import qutip

def qubit_integrate(w, theta, gamma1, gamma2, psi0, tlist, n_th=0.5):
    """
    Integrate a damped qubit with ``qutip.mesolve``.

    Returns the Pauli expectation values over time.

    Parameters
    ----------
    w : float
        Prefactor of the Hamiltonian
        H = w * (cos(theta) * sigma_z + sin(theta) * sigma_x).
    theta : float
        Angle of the Hamiltonian axis from sigma_z toward sigma_x.
    gamma1 : float
        Relaxation rate. It sets the collapse operators
        sqrt(gamma1 * (n_th + 1)) * sigma_minus and
        sqrt(gamma1 * n_th) * sigma_plus.
    gamma2 : float
        Dephasing rate (collapse operator sqrt(gamma2) * sigma_z).
    psi0 : Qobj
        Initial state.
    tlist : array-like
        Times at which expectation values are recorded.
    n_th : float, optional
        Occupation factor entering the two gamma1 rates above. It was
        previously hard-coded as 0.5 with the comment "temperature".
        Default 0.5.

    Returns
    -------
    tuple
        (<sigma_x>, <sigma_y>, <sigma_z>) evaluated at each time in
        ``tlist``, taken from ``mesolve(...).expect``.
    """
    sx = qutip.sigmax()
    sy = qutip.sigmay()
    sz = qutip.sigmaz()
    sm = qutip.sigmam()
    H = w * (np.cos(theta) * sz + np.sin(theta) * sx)

    # (rate, operator) pairs; channels with zero rate are skipped.
    rates_and_ops = [
        (gamma1 * (n_th + 1), sm),
        (gamma1 * n_th, sm.dag()),
        (gamma2, sz),
    ]
    c_op_list = [np.sqrt(rate) * op for rate, op in rates_and_ops if rate > 0.0]

    output = qutip.mesolve(H, psi0, tlist, c_ops=c_op_list, e_ops=[sx, sy, sz])
    return output.expect[0], output.expect[1], output.expect[2]

## calculate the dynamics
# NOTE: sx, sy, sz and theta defined here are read by the Bloch-animation
# cell below; keep these names.
w     = 1.0 * 2 * np.pi  # qubit angular frequency
theta = 0.2 * np.pi      # qubit angle from sigma_z axis (toward sigma_x axis)
gamma1 = 0.5             # qubit relaxation rate
gamma2 = 0.2             # qubit dephasing rate
# initial state: normalized a|0> + (1-a)|1>; with a = 1.0 this is basis(2, 0)
a = 1.0
psi0 = (a*qutip.basis(2, 0) + (1-a)*qutip.basis(2, 1))/np.sqrt(a**2 + (1-a)**2)
# 250 evenly spaced time points on [0, 4]
tlist = np.linspace(0, 4, 250)
# expectation values <sigma_x>, <sigma_y>, <sigma_z> at each time in tlist, used for plotting
sx, sy, sz = qubit_integrate(w, theta, gamma1, gamma2, psi0, tlist)


In [4]:
import numpy as np
# Render one image per time step. Each frame shows the fixed Hamiltonian
# axis vector plus the trail of expectation values up to that step.
b = qutip.Bloch()
b.vector_color = ['r']
b.view = [-40, 30]  # presumably [azimuth, elevation], as in setup_bloch_axes above
for i in range(len(sx)):
    # start each frame from an empty sphere (presumably clear() only drops
    # added vectors/points, keeping view and colours)
    b.clear()
    # Hamiltonian axis: tilted by theta from z toward x
    b.add_vectors([np.sin(theta), 0, np.cos(theta)])
    # trail of all expectation values up to and including step i
    b.add_points([sx[:i+1], sy[:i+1], sz[:i+1]])
    # NOTE(review): the ffmpeg cell reads temp/bloch_%01d.png, so this presumably
    # writes temp/bloch_<n>.png with n counting from 0. Frames left in temp/ by a
    # previous, longer run would also be picked up; clear temp/ first if needed.
    b.save(dirc='temp')  # saving images to temp directory in current working directory


In [7]:
!ffmpeg -i temp/bloch_%01d.png bloch.mp4

ffmpeg version 8.0 Copyright (c) 2000-2025 the FFmpeg developers
  built with Apple clang version 17.0.0 (clang-1700.0.13.3)
  configuration: --prefix=/opt/homebrew/Cellar/ffmpeg/8.0_1 --enable-shared --enable-pthreads --enable-version3 --cc=clang --host-cflags= --host-ldflags='-Wl,-ld_classic' --enable-ffplay --enable-gnutls --enable-gpl --enable-libaom --enable-libaribb24 --enable-libbluray --enable-libdav1d --enable-libharfbuzz --enable-libjxl --enable-libmp3lame --enable-libopus --enable-librav1e --enable-librist --enable-librubberband --enable-libsnappy --enable-libsrt --enable-libssh --enable-libsvtav1 --enable-libtesseract --enable-libtheora --enable-libvidstab --enable-libvmaf --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxml2 --enable-libxvid --enable-lzma --enable-libfontconfig --enable-libfreetype --enable-frei0r --enable-libass --enable-libopencore-amrnb --enable-libopencore-amrwb --enable-libopenjpeg --enable-libspeex --e