In [29]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

In [30]:
class QuantumModuleGraph:
    def __init__(self, qubits, snails, edges, driven_snail):
        self.G = nx.Graph()
        self.qubits = qubits  # List of qubits
        self.snails = snails  # List of SNAILs
        self.edges = edges  # Edge list
        self.driven_snail = f"SNAIL{driven_snail}"  # Define the driven SNAIL
        self._add_edges()

    def _add_edges(self):
        for u, v in self.edges:
            node_u = f"Q{u}" if u in self.qubits else f"SNAIL{u}"
            node_v = f"Q{v}" if v in self.qubits else f"SNAIL{v}"
            self.G.add_edge(node_u, node_v, interaction="default", color="black")

    def define_special_edges(self):
        # Highlight intra and inter edges based on the driven SNAIL
        for u, v in self.edges:
            node_u = f"Q{u}" if u in self.qubits else f"SNAIL{u}"
            node_v = f"Q{v}" if v in self.qubits else f"SNAIL{v}"

            if node_u == self.driven_snail or node_v == self.driven_snail:
                # Intra-module qubit-SNAIL edges
                if node_u.startswith("Q") or node_v.startswith("Q"):
                    self.G.edges[node_u, node_v]["interaction"] = "snail-qubit (intra)"
                    self.G.edges[node_u, node_v]["color"] = "orange"

                # Intra-module qubit-qubit edges
                for neighbor in self.G.neighbors(self.driven_snail):
                    if neighbor.startswith("Q"):
                        for other_neighbor in self.G.neighbors(self.driven_snail):
                            if (
                                other_neighbor.startswith("Q")
                                and neighbor != other_neighbor
                            ):
                                self.G.add_edge(
                                    neighbor,
                                    other_neighbor,
                                    interaction="qubit-qubit",
                                    color="blue",
                                )

            elif node_u.startswith("SNAIL") and node_v.startswith("Q"):
                # Inter-module qubit-SNAIL edges
                if node_u != self.driven_snail and node_v not in self.G.neighbors(
                    self.driven_snail
                ):
                    self.G.edges[node_u, node_v]["interaction"] = "snail-qubit (inter)"
                    self.G.edges[node_u, node_v]["color"] = "purple"

    def plot_graph(self, qubit_frequencies, snail_frequencies):
        pos = nx.spring_layout(self.G, seed=42)
        labels = {
            node: (
                f"{node}\n{qubit_frequencies[self.qubits.index(int(node[1:]))]:.2f} GHz"
                if node.startswith("Q")
                else f"SNAIL{node[5:]} {snail_frequencies[self.snails.index(int(node[5:]))]:.2f} GHz"
            )
            for node in self.G.nodes
        }
        node_colors = [
            "gray" if node.startswith("Q") else "black" for node in self.G.nodes
        ]
        plt.figure()
        nx.draw(
            self.G,
            pos,
            with_labels=True,
            node_color=node_colors,
            edgecolors="black",
            width=2,
        )
        nx.draw_networkx_edges(
            self.G,
            pos,
            edge_color=[self.G.edges[e]["color"] for e in self.G.edges],
            width=2,
        )
        plt.show()

    def get_graph(self):
        """Returns the NetworkX graph object."""
        return self.G

In [31]:
qubits_square = [0, 2, 4, 6, 11, 13, 15, 17, 22, 24, 26, 28, 33, 35, 37, 39]
snails_square = [
    1,
    3,
    5,
    7,
    8,
    9,
    10,
    12,
    14,
    16,
    18,
    19,
    20,
    21,
    23,
    25,
    27,
    29,
    30,
    31,
    32,
    34,
    36,
    38,
]
edges_square = [
    (0, 1),
    (1, 2),
    (2, 3),
    (3, 4),
    (4, 5),
    (5, 6),
    (0, 7),
    (2, 8),
    (4, 9),
    (6, 10),
    (7, 11),
    (8, 13),
    (9, 15),
    (10, 17),
    (11, 12),
    (12, 13),
    (13, 14),
    (14, 15),
    (15, 16),
    (16, 17),
    (11, 18),
    (13, 19),
    (15, 20),
    (17, 21),
    (18, 22),
    (19, 24),
    (20, 26),
    (21, 28),
    (22, 23),
    (23, 24),
    (24, 25),
    (25, 26),
    (26, 27),
    (27, 28),
    (22, 29),
    (24, 30),
    (26, 31),
    (28, 32),
    (29, 33),
    (30, 35),
    (31, 37),
    (32, 39),
    (33, 34),
    (34, 35),
    (35, 36),
    (36, 37),
    (37, 38),
    (38, 39),
]

In [33]:
m = QuantumModuleGraph(qubits_square, snails_square, edges_square, 0)
qubit_frequencies = np.random.uniform(4.5, 5.5, len(qubits_square))
snail_frequencies = np.random.uniform(4.5, 5.5, len(snails_square))
m.define_special_edges()
m.plot_graph(qubit_frequencies, snail_frequencies)

NetworkXError: The node SNAIL0 is not in the graph.