In [None]:
def plot_dm(ax, density_matrix,
            xslice, yslice,
            ticks_font_size=(8.5, 8.5, 14),
            view=(45, -20),
            zlimits=None, xlabel=None, ylabel=None, zlabel=None, plot_title=None, real_part=True):
    """Draw a sub-block of a matrix as shaded 3D bars on ``ax``.

    Every selected matrix entry becomes one unit-footprint bar.  Its side
    faces are split into horizontal slices coloured with the "Spectral"
    colormap (normalised to the rounded data min/max), and every bar is
    additionally outlined in black.

    Parameters
    ----------
    ax : 3D axes
        Target axes; must provide ``add_collection3d``, ``set_zticks``,
        ``set_zlim`` and ``view_init`` (i.e. a ``projection='3d'`` axes).
    density_matrix : qutip.Qobj or array-like
        Matrix to plot.  A ``Qobj`` is converted with ``.full()``.
    xslice, yslice : (start, stop) pairs
        Select ``density_matrix[xslice[0]:xslice[1], yslice[0]:yslice[1]]``.
        Matrix columns (the ``yslice`` range) are laid out along the plot's
        x axis and matrix rows (the ``xslice`` range) along its y axis; tick
        labels show the original matrix index as ``|k>``.
    ticks_font_size : sequence of 3 numbers
        Font sizes for the x, y and z tick labels.
    view : sequence of 2 numbers
        ``(elev, azim)`` passed to ``ax.view_init``.
    zlimits : pair or None
        Explicit z limits; when None the rounded data min/max are used.
    xlabel, ylabel, zlabel, plot_title : str or None
        Axis labels and title.
    real_part : bool
        Plot the real part if True, otherwise the imaginary part.
    """

    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    # qutip is only needed to recognise Qobj inputs; plain arrays work without it.
    try:
        from qutip import Qobj
    except ImportError:
        Qobj = None

    if Qobj is not None and isinstance(density_matrix, Qobj):
        full_matrix = density_matrix.full()
    else:
        full_matrix = np.asarray(density_matrix)
    block = full_matrix[xslice[0]:xslice[1], yslice[0]:yslice[1]]
    data_matrix = np.real(block) if real_part else np.imag(block)

    rows, cols = data_matrix.shape
    # Column index -> x position, row index -> y position.
    x, y = np.meshgrid(np.arange(cols), np.arange(rows))
    X, Y = x.ravel(), y.ravel()
    z = data_matrix.ravel()
    width = depth = 1

    # Diverging colormap so positive and negative heights get distinct hues.
    cmap = plt.get_cmap("Spectral")

    # Normalise colours between the (rounded) global min and max height.
    min_height = np.round(np.min(z), 2)
    max_height = np.round(np.max(z), 2)
    if max_height == min_height:
        max_height += 1  # prevent division by zero for a constant matrix
    height_span = max_height - min_height

    def to_color(height):
        """Map a bar height to an RGBA colour on the shared min/max scale."""
        return cmap((height - min_height) / height_span)

    ###############################################
    #          Part that colors the bars
    ###############################################

    def get_bar_faces_with_smooth_gradient(x, y, z, width, depth, num_shades=20):
        """Return the shaded side faces, their colours, and the bottom/top faces
        of one bar of height ``z`` (which may be negative)."""
        dz = z / num_shades
        faces = []
        face_colors = []

        for i in range(num_shades):
            z_start = i * dz
            z_end = (i + 1) * dz

            # For negative bars the slices go downward; keep z_start as the lower edge.
            if z < 0:
                z_start, z_end = z_end, z_start

            color = to_color(z_start)

            faces.extend([
                [[x, y, z_start], [x + width, y, z_start], [x + width, y, z_end], [x, y, z_end]],  # Front
                [[x, y + depth, z_start], [x + width, y + depth, z_start], [x + width, y + depth, z_end], [x, y + depth, z_end]],  # Back
                [[x, y, z_start], [x, y + depth, z_start], [x, y + depth, z_end], [x, y, z_end]],  # Left
                [[x + width, y, z_start], [x + width, y + depth, z_start], [x + width, y + depth, z_end], [x + width, y, z_end]],  # Right
            ])
            face_colors.extend([color] * 4)

        # Base and top face; for negative bars the "top" is the z = 0 plane.
        bottom_z, top_z = (z, 0) if z < 0 else (0, z)

        bottom_face = [[x, y, bottom_z], [x + width, y, bottom_z], [x + width, y + depth, bottom_z], [x, y + depth, bottom_z]]
        top_face = [[x, y, top_z], [x + width, y, top_z], [x + width, y + depth, top_z], [x, y + depth, top_z]]

        return faces, face_colors, bottom_face, top_face, top_z

    ########################################################
    #           Part that outlines each face of the bars
    ########################################################

    def get_bar_faces(x, y, z, width, depth):
        """Return the six faces of a bar spanning heights 0..z (used for outlines)."""
        bottom_face = [[x, y, 0], [x + width, y, 0], [x + width, y + depth, 0], [x, y + depth, 0]]
        top_face = [[x, y, z], [x + width, y, z], [x + width, y + depth, z], [x, y + depth, z]]
        front_face = [[x, y, 0], [x + width, y, 0], [x + width, y, z], [x, y, z]]
        back_face = [[x, y + depth, 0], [x + width, y + depth, 0], [x + width, y + depth, z], [x, y + depth, z]]
        left_face = [[x, y, 0], [x, y + depth, 0], [x, y + depth, z], [x, y, z]]
        right_face = [[x + width, y, 0], [x + width, y + depth, 0], [x + width, y + depth, z], [x + width, y, z]]
        return [bottom_face, top_face, front_face, back_face, left_face, right_face]

    ########################################################
    #           Plot the colored bars
    ########################################################

    bottom_color = to_color(0)
    for xi, yi, zi in zip(X, Y, z):
        faces, face_colors, bottom_face, top_face, top_z = get_bar_faces_with_smooth_gradient(xi, yi, zi, width, depth)

        ax.add_collection3d(Poly3DCollection(faces, facecolors=face_colors, edgecolors=None, alpha=0.6))

        # Color the top and bottom faces
        ax.add_collection3d(Poly3DCollection([top_face], facecolors=to_color(top_z), edgecolors=None, alpha=0.5))
        ax.add_collection3d(Poly3DCollection([bottom_face], facecolors=bottom_color, edgecolors=None, alpha=0.5))

    #########################################################
    #        Plot the outlines
    #########################################################

    for xi, yi, zi in zip(X, Y, z):
        outline = Poly3DCollection(get_bar_faces(xi, yi, zi, width, depth), facecolors=None, edgecolors="k", alpha=0)
        outline.set_linewidth(0.8)  # outline thickness
        ax.add_collection3d(outline)

    # One tick at the centre of each bar: x positions index the selected
    # columns (yslice range), y positions index the selected rows (xslice range).
    ax.set_xticks([i + 0.5 for i in range(cols)])
    ax.set_yticks([i + 0.5 for i in range(rows)])
    ax.set_zticks([min_height, max_height])

    ax.tick_params(axis='z', labelsize=ticks_font_size[2])

    # Smaller pad = tick labels sit closer to the axis.
    ax.tick_params(axis='x', pad=-5)
    ax.tick_params(axis='y', pad=-5)
    ax.tick_params(axis='z', pad=0)

    if zlimits is not None:
        ax.set_zlim(zlimits)
    else:
        ax.set_zlim(min_height, max_height)

    # Label ticks with the original (unsliced) matrix indices as kets.
    xticklabels = [rf'$|{i}\rangle$' for i in range(yslice[0], yslice[0] + cols)]
    yticklabels = [rf'$|{i}\rangle$' for i in range(xslice[0], xslice[0] + rows)]
    ax.set_xticklabels(xticklabels, fontsize=ticks_font_size[0])
    ax.set_yticklabels(yticklabels, fontsize=ticks_font_size[1])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.set_title(plot_title, fontsize=20)

    # White plot background, no grid lines.
    ax.set_facecolor('white')
    ax.grid(False)

    # Fill the XY, XZ and YZ panes with a light blue-grey.
    ax.xaxis.pane.fill = True
    ax.xaxis.pane.set_facecolor('#E5ECF6')      # plotly color: #E5ECF6  other color: #f5f9feff
    ax.yaxis.pane.fill = True
    ax.yaxis.pane.set_facecolor('#E5ECF6')
    ax.zaxis.pane.fill = True
    ax.zaxis.pane.set_facecolor('#E5ECF6')

    # Hide the axis boundary lines by drawing them in white.
    ax.xaxis.line.set_color('white')
    ax.yaxis.line.set_color('white')
    ax.zaxis.line.set_color('white')

    # Camera angle.
    ax.view_init(elev=view[0], azim=view[1], roll=0, vertical_axis='z')

    # Rotate x and y tick labels by 20 degrees so they look projected onto the floor plane.
    for label in [*ax.get_yticklabels(), *ax.get_xticklabels()]:
        label.set_rotation(20)

    # Make the background of the figure that owns ``ax`` transparent
    # (the original referenced an undefined/global ``fig``).
    ax.figure.patch.set_alpha(0.0)

