In [1]:
pip install cirq qiskit gudhi networkx matplotlib -q

Note: you may need to restart the kernel to use updated packages.


In [6]:
import itertools
import numpy as np
import networkx as nx
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from qiskit import QuantumCircuit, QuantumRegister, Aer, transpile
from qiskit.quantum_info import DensityMatrix, partial_trace, entropy
from sklearn.manifold import MDS
from qiskit.extensions import UnitaryGate
from qiskit.quantum_info.random import random_unitary
import gudhi as gd
from ipywidgets import interact, FloatSlider, IntSlider, FloatLogSlider

def qiskit_graph_state_circuit(num_qubits, seed=None):
    """Build a random, graph-state-like circuit on ``num_qubits`` qubits.

    Every qubit first gets a Hadamard. Then, for every edge (i, j) of the
    *complete* graph on ``num_qubits`` nodes, a random single-qubit unitary
    (from ``qiskit.quantum_info.random.random_unitary``) controlled on qubit i
    is applied to qubit j. A textbook graph state would use CZ on each edge;
    here each edge gets an independent random controlled-U instead.

    Args:
        num_qubits (int): number of qubits (and graph nodes).
        seed (int | numpy.random.Generator | None): seed for the random
            unitaries. ``None`` (default) draws a fresh random circuit on
            every call, as before. Pass a value to make the circuit
            reproducible.

    Returns:
        QuantumCircuit: the constructed circuit (no measurements).
    """
    # One generator shared by all edges, so a single seed fixes the whole circuit.
    rng = np.random.default_rng(seed)

    # Complete graph: every pair of qubits is connected (nothing random here).
    graph = nx.complete_graph(num_qubits)

    qubits = QuantumRegister(num_qubits)
    circuit = QuantumCircuit(qubits)

    # Hadamard on every qubit.
    circuit.h(qubits)

    # One random controlled unitary per edge.
    for control_index, target_index in graph.edges():
        random_unitary_matrix = random_unitary(2, seed=rng).data
        controlled_unitary_gate = UnitaryGate(random_unitary_matrix).control()
        circuit.append(controlled_unitary_gate, [qubits[control_index], qubits[target_index]])

    return circuit

def compute_inverse_mutual_information(rho):
    """
    Computes a mutual-information-based distance for every pair of qubits.

    For qubits i and j, the quantum mutual information is
    I(i:j) = S(rho_i) + S(rho_j) - S(rho_ij). It is computed with natural
    logarithms (nats), so for two qubits 0 <= I(i:j) <= 2*ln(2). The returned
    "inverse" mutual information is d(i, j) = 2*ln(2) - I(i:j): maximally
    correlated pairs are at distance 0 and uncorrelated pairs at 2*ln(2).
    This is the same 2*ln(2) scale used for the Rips-complex edge length and
    for the threshold slider.

    Args:
        rho (DensityMatrix): a DensityMatrix object representing the state of n qubits.

    Returns:
        numpy.ndarray: a symmetric n x n matrix with zero diagonal, where element
        i,j is the inverse mutual information between qubits i and j.
    """
    # Upper bound of the two-qubit QMI in nats; also the distance of an uncorrelated pair.
    max_qmi = 2 * np.log(2)

    # Convert the DensityMatrix to a numpy.ndarray.
    rho = rho.data

    # Determine the number of qubits n from the shape of rho.
    n = int(np.log2(rho.shape[0]))

    inverse_mutual_information_matrix = np.zeros((n, n))

    for i in range(n):
        for j in range(i + 1, n):
            # Trace out every qubit except i and j. In rho_ij, subsystem 0 is qubit i
            # and subsystem 1 is qubit j.
            trace_indices = [k for k in range(n) if k != i and k != j]
            rho_ij = partial_trace(rho, trace_indices)

            # qiskit's entropy() defaults to base 2 (bits). Use base e so the QMI is in
            # nats and matches the 2*ln(2) scale; otherwise distances can go negative.
            entropy_i = entropy(partial_trace(rho_ij, [1]), base=np.e)
            entropy_j = entropy(partial_trace(rho_ij, [0]), base=np.e)
            qmi = entropy_i + entropy_j - entropy(rho_ij, base=np.e)

            # Clip floating-point noise so distances stay inside [0, 2*ln(2)].
            qmi = min(max(qmi, 0.0), max_qmi)

            distance = max_qmi - qmi
            inverse_mutual_information_matrix[i, j] = distance
            inverse_mutual_information_matrix[j, i] = distance

    return inverse_mutual_information_matrix

# Compute persistent homology
def persistent_homology(qmi_distance_matrix, dimension, max_edge_length=2 * np.log(2)):
    """Compute the persistence diagram of the Vietoris-Rips complex of a distance matrix.

    Args:
        qmi_distance_matrix (numpy.ndarray): n x n pairwise distance matrix
            (e.g. from ``compute_inverse_mutual_information``).
        dimension (int): ``max_dimension`` passed to gudhi's
            ``create_simplex_tree``, i.e. the largest simplex dimension built.
        max_edge_length (float): edges longer than this are not added to the
            complex. Defaults to 2*ln(2), the largest QMI distance between two
            qubits, which was previously hard-coded.

    Returns:
        The list returned by gudhi's ``SimplexTree.persistence()``
        (presumably (homology dimension, (birth, death)) pairs; see the gudhi
        docs).
    """
    rips_complex = gd.RipsComplex(distance_matrix=qmi_distance_matrix, max_edge_length=max_edge_length)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=dimension)
    return simplex_tree.persistence()


# Partition circuit by moments for Qiskit
def partition_circuit_by_moments(circuit):
    """Return the quantum state after each successive instruction of ``circuit``.

    The circuit is decomposed one level (``circuit.decompose()``). Starting
    from |0...0>, the statevector is recorded after every single instruction.
    Despite the function name, the split is per instruction, not per moment
    (layer of parallel gates): entry ``k`` is the state after the first
    ``k + 1`` instructions of the decomposed circuit.

    The state is evolved incrementally with ``Statevector.evolve``. Previously,
    every prefix circuit was rebuilt, transpiled and re-simulated from scratch,
    which was quadratic in the number of instructions. Global phase is ignored,
    as before; it does not affect the density matrices computed downstream.

    Args:
        circuit (QuantumCircuit): a circuit of unitary gates (no measurements or resets).

    Returns:
        list[Statevector]: one state per instruction, in circuit order.
    """
    from qiskit.quantum_info import Statevector

    decomposed = circuit.decompose()
    # Map each qubit object to its integer position so qargs can be passed to evolve().
    qubit_index = {qubit: index for index, qubit in enumerate(decomposed.qubits)}

    state = Statevector.from_label('0' * decomposed.num_qubits)
    state_vectors = []
    for instruction in decomposed.data:
        qargs = [qubit_index[qubit] for qubit in instruction.qubits]
        # evolve() returns a new Statevector, so earlier entries are not mutated.
        state = state.evolve(instruction.operation, qargs=qargs)
        state_vectors.append(state)

    return state_vectors

def get_simplices(threshold, qmi_distance_matrix, dimension, max_edge_length=2 * np.log(2)):
    """Return the Rips-complex simplices present at filtration value ``threshold``.

    The Vietoris-Rips complex (edges up to ``max_edge_length``, simplices up to
    ``dimension``) is built with gudhi. Only the simplices whose filtration
    value is <= ``threshold`` are kept.

    Args:
        threshold (float): filtration cut-off; simplices above it are dropped.
        qmi_distance_matrix (numpy.ndarray): n x n pairwise distance matrix.
        dimension (int): largest simplex dimension built (gudhi ``max_dimension``).
        max_edge_length (float): longest edge admitted into the complex.
            Defaults to 2*ln(2), which was previously hard-coded.

    Returns:
        list[tuple[list[int], float]]: (vertex list, filtration value) pairs,
        in the order yielded by ``SimplexTree.get_filtration()``.
    """
    # Build the complete Rips complex; the threshold is applied afterwards.
    rips_complex = gd.RipsComplex(distance_matrix=qmi_distance_matrix, max_edge_length=max_edge_length)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=dimension)

    # Keep only the simplices that already exist at the requested threshold.
    return [
        (simplex, filtration)
        for simplex, filtration in simplex_tree.get_filtration()
        if filtration <= threshold
    ]

def get_coordinates(qmi_distance_matrix):
    """Place each qubit in 3-D space from the pairwise distance matrix.

    Runs scikit-learn ``MDS`` on the precomputed dissimilarities. The
    ``random_state`` is fixed so repeated calls give the same layout.

    Args:
        qmi_distance_matrix (numpy.ndarray): symmetric n x n distance matrix.

    Returns:
        numpy.ndarray: one (x, y, z) row per qubit.
    """
    embedding = MDS(n_components=3, dissimilarity='precomputed', random_state=42)
    return embedding.fit_transform(qmi_distance_matrix)


def plot_simplicial_complex(threshold, qmi_distance_matrix, dimension, coordinates, fig, row, col):
    """Draw the qubits and the edges of the thresholded Rips complex into one subplot.

    Args:
        threshold (float): only simplices with filtration value <= threshold are drawn.
        qmi_distance_matrix (numpy.ndarray): n x n pairwise qubit distances.
        dimension (int): largest simplex dimension of the Rips complex
            (only the edges, i.e. 1-simplices, are drawn).
        coordinates (numpy.ndarray): per-qubit positions; columns 0, 1, 2 are x, y, z.
        fig (plotly.graph_objects.Figure): a figure from ``make_subplots`` whose
            (row, col) cell is a 'scatter3d' scene.
        row (int): 1-based subplot row.
        col (int): 1-based subplot column.
    """
    simplices = get_simplices(threshold, qmi_distance_matrix, dimension)

    # Qubit nodes, labelled by index.
    fig.add_trace(go.Scatter3d(
        x=coordinates[:, 0], y=coordinates[:, 1], z=coordinates[:, 2],
        mode='markers+text',
        text=[str(i) for i in range(coordinates.shape[0])],
        textposition="top center",
        textfont=dict(size=12),
        marker=dict(size=6),
        showlegend=False
    ), row=row, col=col)

    # Gather all edges into one trace. A None between segments breaks the line,
    # so this draws the same picture as one trace per edge, with far fewer traces.
    edge_x, edge_y, edge_z = [], [], []
    for simplex, _ in simplices:
        if len(simplex) == 2:
            i, j = simplex
            edge_x += [coordinates[i, 0], coordinates[j, 0], None]
            edge_y += [coordinates[i, 1], coordinates[j, 1], None]
            edge_z += [coordinates[i, 2], coordinates[j, 2], None]

    if edge_x:
        fig.add_trace(go.Scatter3d(
            x=edge_x, y=edge_y, z=edge_z,
            mode='lines',
            line=dict(width=1.5, color='gray'),
            showlegend=False
        ), row=row, col=col)

    # update_layout(scene=...) would only configure the FIRST subplot's scene
    # ("scene"), not scene2, scene3, ... update_scenes targets this subplot's own scene.
    fig.update_scenes(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='cube',
        row=row, col=col
    )
    # Leave a little top margin; with t=0 the top row's subplot titles are clipped.
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=30))


def grid_plot(num_qubits, threshold):
    """Show the thresholded QMI simplicial complex after every instruction of the circuit.

    A new random circuit is built only when there is none yet or the qubit count
    changed; changing only the threshold reuses the cached circuit. The 3-D qubit
    layout is computed once, from the first time step of the current circuit, and
    reused for every panel so the panels are visually comparable.

    Args:
        num_qubits (int): number of qubits of the circuit.
        threshold (float): filtration cut-off for the simplices drawn.
    """
    global global_qubit_coordinates, global_circuit

    needs_new_circuit = global_circuit is None or global_circuit.num_qubits != num_qubits
    if needs_new_circuit:
        global_circuit = qiskit_graph_state_circuit(num_qubits)
        global_qubit_coordinates = None

    states = partition_circuit_by_moments(global_circuit)
    step_count = len(states)

    # Roughly square grid: ceil(sqrt(k)) rows, enough columns to fit k panels.
    grid_rows = int(np.ceil(np.sqrt(step_count)))
    grid_cols = int(np.ceil(step_count / grid_rows))

    titles = [f'Time Step {step}' for step in range(step_count)]
    fig = make_subplots(
        rows=grid_rows,
        cols=grid_cols,
        specs=[[{'type': 'scatter3d'}] * grid_cols] * grid_rows,
        subplot_titles=titles,
    )

    for step, state in enumerate(states):
        distances = compute_inverse_mutual_information(DensityMatrix(state))

        if global_qubit_coordinates is None:
            global_qubit_coordinates = get_coordinates(distances)

        panel_row, panel_col = divmod(step, grid_cols)
        plot_simplicial_complex(threshold, distances, 2, global_qubit_coordinates, fig, panel_row + 1, panel_col + 1)

    fig.update_layout(height=250 * grid_rows, width=250 * grid_cols)
    fig.show()

# Module-level caches used by grid_plot: the current random circuit and its 3-D qubit layout.
global_qubit_coordinates = None
global_circuit = None

# Create sliders for the number of qubits and the threshold.
num_qubits_slider = IntSlider(min=2, max=10, step=1, value=5, description='Qubit Count', continuous_update=False)
# Each redraw re-simulates the circuit and recomputes every QMI matrix, so redraw
# only when the slider is released, not on every drag step (same as the qubit slider).
threshold_slider = FloatSlider(min=0, max=2*np.log(2), step=0.01, value=1.18, description='Threshold',
                               continuous_update=False)

# Display the interactive plot. The trailing ';' suppresses the repr of the
# function that interact() returns.
interact(grid_plot, num_qubits=num_qubits_slider, threshold=threshold_slider);



interactive(children=(IntSlider(value=5, continuous_update=False, description='Qubit Count', max=10, min=2), F…

<function __main__.grid_plot(num_qubits, threshold)>