In [9]:
"""
Utility functions for plotting Bloch spheres on matplotlib ax.
Extracted and refactored from QuTiP's Bloch sphere implementation.
"""

import numpy as np
from numpy import linspace, cos, sin, ones, outer, pi, size
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

import torch
import qutip

sphere_color="#FFDDDD"
sphere_alpha = 0.2
frame_color = 'gray'
frame_alpha = 0.2
frame_width = 1

class Arrow3D(FancyArrowPatch):
    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0,0), (0,0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0],ys[0]),(xs[1],ys[1]))

        return np.min(zs)

def show_label(ax, **opts):
    ax.scatter(-1,0,0, "$-1$", color="black")
    ax.scatter(1,0,0, "$1$", color="black")
    ax.scatter(0,-1,0, "$-1$", color="black")
    ax.scatter(0,1,0, "$1$", color="black")
    ax.scatter(0,0,-1, "$-1$", color="black")
    ax.scatter(0,0,1, "$1$", color="black")

    ax.text(-1.2, 0, 0, "$-x$", **opts)
    ax.text( 1.2, 0, 0,  "$x$", **opts)

    ax.text(0, -1.2, 0, "$-y$", **opts)
    ax.text(0,  1.2, 0,  "$y$", **opts)

    ax.text(0, 0, -1.2, "$-z$", **opts)
    ax.text(0, 0,  1.2,  "$z$", **opts)

def plot_back(ax):
    # back half of sphere
    u = linspace(0, pi, 25)
    v = linspace(0, pi, 25)
    x = outer(cos(u), sin(v))
    y = outer(sin(u), sin(v))
    z = outer(ones(size(u)), cos(v))
    ax.plot_surface(x, y, z, rstride=2, cstride=2,
                            color=sphere_color, linewidth=0,
                            alpha=sphere_alpha)
    # wireframe
    ax.plot_wireframe(x, y, z, rstride=5, cstride=5,
                                color=frame_color,
                                alpha=frame_alpha)
    # equator
    ax.plot(1.0 * cos(u), 1.0 * sin(u), zs=0, zdir='z',
                    lw=frame_width, color=frame_color)
    ax.plot(1.0 * cos(u), 1.0 * sin(u), zs=0, zdir='x',
                    lw=frame_width, color=frame_color)

def plot_front(ax):
    # front half of sphere
    u = linspace(-pi, 0, 25)
    v = linspace(0, pi, 25)
    x = outer(cos(u), sin(v))
    y = outer(sin(u), sin(v))
    z = outer(ones(size(u)), cos(v))
    ax.plot_surface(x, y, z, rstride=2, cstride=2,
                            color=sphere_color, linewidth=0,
                            alpha=sphere_alpha)
    # wireframe
    ax.plot_wireframe(x, y, z, rstride=5, cstride=5,
                                color=frame_color,
                                alpha=frame_alpha)
    # equator
    ax.plot(1.0 * cos(u), 1.0 * sin(u),
                    zs=0, zdir='z', lw=frame_width,
                    color=frame_color)
    ax.plot(1.0 * cos(u), 1.0 * sin(u),
                    zs=0, zdir='x', lw=frame_width,
                    color=frame_color)

In [5]:
!rm ./temp/* ./bloch.mp4

rm: cannot remove './bloch.mp4': No such file or directory


In [18]:
fig, ax = plt.subplots(1,1, figsize=(5,5), subplot_kw={"projection" : "3d"})

full_trajectory = torch.load("./asset/test_trajectory.pt").numpy() * 2

for i in range(int(full_trajectory.shape[0]/10)):
    plot_back(ax)

    a = Arrow3D(xs=[0, 0], ys=[0, 0], zs=[0, 1.5], mutation_scale=20, 
                lw=3, arrowstyle="-|>", color="r")
    ax.add_artist(a)

    trajectory = full_trajectory[max(0, i*10 - 100):min(i*10, len(full_trajectory))]
    ax.scatter(trajectory[:,0], trajectory[:,1], trajectory[:,2], c=np.arange(len(trajectory)), cmap="rainbow", s=1)
      
    
    plot_front(ax)
    show_label(ax)
    ax.axis("off")
    ax.legend()
    fig.tight_layout(pad=0)
    fig.savefig(fname=f'./temp/bloch_{i}.png')  # saving images to temp directory in current working directory
    ax.clear()

plt.close(fig)

  ax.legend()


In [19]:
!ffmpeg -i temp/bloch_%01d.png bloch.mp4

ffmpeg version 6.1.1-3ubuntu5 Copyright (c) 2000-2023 the FFmpeg developers
  built with gcc 13 (Ubuntu 13.2.0-23ubuntu3)
  configuration: --prefix=/usr --extra-version=3ubuntu5 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --disable-omx --enable-gnutls --enable-libaom --enable-libass --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libglslang --enable-libgme --enable-libgsm --enable-libharfbuzz --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --ena