---
# Circuit Design based on Feature Similarity for Quantum Generative Modeling
---
This notebook includes code for plotting all figures from the paper.

Author: Mathis Makarski \
GitHub: [@4d6174686973](https://github.com/4d6174686973) \
Email: makarski@keio.jp \
DOI: [https://doi.org/10.48550/arXiv.2503.11983](https://doi.org/10.48550/arXiv.2503.11983)

Download the JGB dataset from: [https://www.mof.go.jp/english/policy/jgbs/reference/interest_rate/index.htm](https://www.mof.go.jp/english/policy/jgbs/reference/interest_rate/index.htm)

In [None]:
# Change into the repository root so that `src.*` imports and the relative
# `plots/` and `results/` paths used below resolve.
# NOTE(review): absolute path assumes the dev-container workspace layout — adjust when running elsewhere.
%cd /workspaces/circuit-design/

# SU(4) Gate

In [None]:
import os
from src.extension import add_su4_gate
from qiskit import QuantumCircuit
from qiskit.circuit import ParameterVector

# Output directory for every figure produced by this notebook
os.makedirs('plots', exist_ok=True)

# Two-qubit circuit holding one parameterized SU(4) block with 15 free angles
qc_test = QuantumCircuit(2)
params = ParameterVector('θ', 15)
add_su4_gate(qc_test, 0, 1, params)

# Draw the circuit with the matplotlib backend and store it as PDF and PNG
qc_test.draw(output='mpl', filename='plots/SU4_gate.pdf')
qc_test.draw(output='mpl', filename='plots/SU4_gate.png')


# Use SciencePlots Style

In [None]:
import matplotlib.pyplot as plt
import scienceplots
# Apply the SciencePlots 'science' and 'ieee' styles to all following figures;
# 'no-latex' presumably turns off LaTeX text rendering — confirm against the SciencePlots docs.
plt.style.use(['science','ieee','no-latex'])

___
# Bars and Stripes
___

## BAS Dataset

In [None]:
def plot_bas(bas_images, width, height, save=False, figsize=(10, 3)):
    """Plot Bars-and-Stripes images side by side in a single row.

    Parameters
    ----------
    bas_images : sequence of array-like
        Images to draw. Each one is reshaped with ``reshape(width, height)``.
        Each is displayed in grayscale on a fixed [0, 1] colour range.
        NOTE(review): numpy's reshape is (rows, cols), so ``width`` is the row
        count here; this only matters for non-square images — confirm.
    width, height : int
        Reshape dimensions for each image (see note above).
    save : bool, optional
        If True, also write ``plots/BAS_images.pdf`` and ``plots/BAS_images.png``.
    figsize : tuple, optional
        Figure size forwarded to ``plt.subplots``.

    Raises
    ------
    ValueError
        If ``bas_images`` is empty.
    """
    import os

    num_images = len(bas_images)
    if num_images == 0:
        raise ValueError("bas_images must contain at least one image")

    # squeeze=False always returns a 2-D array of Axes, so a single image works too
    fig, axes = plt.subplots(1, num_images, figsize=figsize, squeeze=False)

    for image, ax in zip(bas_images, axes.flatten()):
        ax.imshow(image.reshape(width, height), norm=plt.Normalize(0, 1), cmap='gray')

        # add a thin black border around every image
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('black')
            spine.set_linewidth(1)

        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    if save:
        os.makedirs('plots', exist_ok=True)
        fig.savefig('plots/BAS_images.pdf', bbox_inches='tight', transparent=True)
        fig.savefig('plots/BAS_images.png', bbox_inches='tight', transparent=False, dpi=300)
    plt.show()
    plt.close(fig)

In [None]:
from src.data import BAS, DataLoader

# 3x3 Bars-and-Stripes dataset: one qubit per pixel
w, h = 3, 3
dim = w * h

bas_data = BAS(w, h)
dl = DataLoader(bas_data)
X = dl.binary  # binary samples, one image per row

plot_bas(X, w, h, save=True)

## BAS Extensions

In [None]:
import numpy as np
import seaborn as sns
from scipy.spatial import distance  # for hamming distance

# Pairwise Hamming distance between pixel columns of the BAS samples
# (rows of X are images, so X.T puts one pixel/feature per row)
X = dl.binary
hamming = distance.cdist(X.T, X.T, 'hamming')

# Feature pairs closer than this threshold get an extra entangling connection
threshold = 0.5

# Binary connection mask from the distance threshold
dist_filter = np.zeros_like(hamming)
dist_filter[hamming < threshold] = 1.0
# Zero the diagonal so self-connections are removed for any threshold value.
# Subtracting np.eye would leave -1 whenever the diagonal does not pass the threshold.
np.fill_diagonal(dist_filter, 0.0)

fig, axs = plt.subplots(1, 2, figsize=(6, 2.6))
sns.heatmap(hamming, annot=False, cmap="Blues", ax=axs[0], vmin=0.0, vmax=1.0)
axs[0].set_title("a) Hamming distance")
# Fix the colour range so the binary mask renders consistently (matches the JGB figure)
sns.heatmap(dist_filter, annot=False, cmap="Blues", ax=axs[1], vmin=0.0, vmax=1.0)
axs[1].set_title("b) Circuit Extension")
plt.tight_layout()
plt.savefig("plots/BAS_extension.pdf", bbox_inches='tight', transparent=True)
plt.savefig("plots/BAS_extension.png", bbox_inches='tight', transparent=False, dpi=300)
plt.show()

In [None]:
import networkx as nx
from src.extension import metric_based_topology, nearest_neighbor_topology, all_to_all_topology, linear_topology
from src.data import init_qubit_order_bas


# Candidate qubit connectivities; the linear chain is the common base of every panel
n_qubits = 9
# The linear qubit order is looked up for the 3x3 dataset, so fail loudly if w/h change
assert n_qubits == w * h, "init_qubit_order_bas['3x3'] assumes a 3x3 BAS dataset"
edges_lin = linear_topology(init_qubit_order_bas["3x3"])
edges_nearest_neighbor = nearest_neighbor_topology(w, h)
edges_metric_based = metric_based_topology(hamming, threshold)
edges_all_to_all = all_to_all_topology(n_qubits)

extensions = {
    "a) Linear" : edges_lin,
    "b) Nearest-Neighbor" : edges_nearest_neighbor,
    "c) Metric-Based" : edges_metric_based,
    "d) All-to-All" : edges_all_to_all
    }

# plot one topology per panel
fig, axs = plt.subplots(2, 2, figsize=(6, 6))
for ax, (ext_name, edges_ext) in zip(axs.flatten(), extensions.items()):
    # edges added on top of the linear chain are highlighted in blue
    edges_new = sorted(set(edges_ext) - set(edges_lin))

    G = nx.Graph()
    G.add_nodes_from(range(n_qubits))
    G.add_edges_from(edges_lin)
    G.add_edges_from(edges_ext)

    pos = nx.circular_layout(G)

    nx.draw_networkx_nodes(G, pos, node_color='white', edgecolors='black', node_size=200, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist=edges_lin, edge_color='black', ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist=edges_new, edge_color='cornflowerblue', ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold', font_color='black', ax=ax)
    ax.set_title(ext_name)

for ax in axs.flatten():
    ax.set_aspect('equal', adjustable='box')  # make figure symmetric

    # hide axes, ticks, spines and frame so only the graph is drawn
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
plt.tight_layout()
plt.savefig('plots/BAS_topology.pdf', bbox_inches='tight', transparent=True)
plt.savefig('plots/BAS_topology.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()


## BAS Training

In [None]:
from src.utils import plot_mmd_two_sets

# Result directories of the BAS training runs, five per circuit topology.
# Named bas_runs (not bas_data) so it does not shadow the BAS dataset object defined above.
bas_runs = {
    'linear': [
        'results/runs/2025-02-14/15-44-28/0',
        'results/runs/2025-02-15/06-02-13/0',
        'results/runs/2025-03-04/02-29-41/0',
        'results/runs/2025-03-04/02-29-41/1',
        'results/runs/2025-03-04/02-29-41/2',
    ],
    'nearest-neighbor': [
        'results/runs/2025-02-14/15-44-28/1',
        'results/runs/2025-02-15/06-02-13/1',
        'results/runs/2025-03-03/17-57-03/3',
        'results/runs/2025-03-03/17-57-03/4',
        'results/runs/2025-03-03/17-57-03/5',
    ],
    'random':  [
        'results/runs/2025-03-01/10-41-54/0',  # seed=42 [(0, 2), (0, 5), (1, 3), (1, 4), (1, 5), (1, 7), (2, 5), (2, 7), (2, 8), (5, 8)]
        'results/runs/2025-03-01/10-41-54/1',  # seed=43 [(0, 2), (0, 4), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 5), (3, 5), (3, 8)]
        'results/runs/2025-03-01/10-41-54/2',  # seed=44 [(0, 5), (0, 6), (1, 3), (1, 4), (1, 5), (1, 6), (1, 8), (2, 6), (2, 8), (3, 7)]
        'results/runs/2025-03-02/12-25-45/0',  # seed=45 [(0, 2), (0, 6), (1, 3), (1, 4), (2, 5), (2, 7), (2, 8), (3, 6), (4, 7), (5, 8)]
        'results/runs/2025-03-02/12-25-45/1',  # seed=46 [(0, 2), (0, 8), (1, 3), (1, 4), (1, 5), (1, 7), (1, 8), (2, 6), (3, 5), (3, 6)]
    ],
    'metric-based': [
        'results/runs/2025-02-14/15-44-28/2',
        'results/runs/2025-02-15/06-02-13/2',
        'results/runs/2025-02-26/06-41-31/0',
        'results/runs/2025-02-26/06-41-31/1',
        'results/runs/2025-02-26/06-41-31/2',
    ],
    'all-to-all': [
        'results/runs/2024-12-20/00-00-00/3',
        'results/runs/2025-02-14/15-44-28/3',
        'results/runs/2025-02-15/06-02-13/3',
        'results/runs/2025-03-03/01-16-32/0',
        'results/runs/2025-03-03/01-16-32/1',
    ],
}

# Five blues, darkest first; the metric-based topology is highlighted in orange
blues = plt.cm.Blues(np.linspace(0.3, 1, 5))[::-1].tolist()

colors = {
    'linear': blues[4],
    'nearest-neighbor': 'slategray',
    'random': blues[2],
    'metric-based': 'darkorange',
    'all-to-all': blues[0],
}

fig = plot_mmd_two_sets(bas_runs, colors=colors, mode="meanstd", window=50, filename='BAS_MMD', save=True)


___
# Japanese Government Bond Interest Rates
___

## JGB Dataset

In [None]:
from src.data import JGB

# JGB dataset built from n_qubits qubits and n_features rate series
n_qubits, n_features = 12, 3
jgb = JGB(n_qubits, n_features)
dl = DataLoader(jgb)

df0 = jgb.raw.copy()      # raw rates, plotted below in percent
df = jgb.decimal.copy()

# One blue shade per rate column, darkest first
colors = plt.cm.Blues(np.linspace(0.4, 1, len(df0.columns)))[::-1].tolist()

fig, ax = plt.subplots(1, 1, figsize=(5, 3))
for column, line_color in zip(df0.columns, colors):
    # the column name without its final character is used as the maturity in the label
    ax.plot(df0[column], label=f'{column[:-1]}-year Rate', color=line_color, ls='-')

ax.set_xlabel('Date')
ax.set_ylabel('Interest Rate [%]')
ax.legend()
plt.savefig('plots/JGB_raw_data.pdf', bbox_inches='tight', transparent=True)
plt.savefig('plots/JGB_raw_data.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()

In [None]:
from src.utils import array_to_str, get_features_for_quasi_dist
from collections import Counter
from qiskit.visualization import plot_histogram
from matplotlib.colors import to_hex


# Empirical distribution of the binarized JGB data, split into one histogram per feature
target_str = array_to_str(jgb.binary)
target_dict = Counter(target_str)
bits_per_feature = n_qubits // n_features
bin_feat_dicts = get_features_for_quasi_dist(target_dict, bits_per_feature, n_features)

# Panel titles for each supported number of features
panel_labels = {
    3: ['a) 5-year Rate', 'b) 10-year Rate', 'c) 20-year Rate'],
    4: ['a) 2-year Rate', 'b) 5-year Rate', 'c) 10-year Rate', 'd) 20-year Rate'],
}
if n_features not in panel_labels:
    raise ValueError(f"No panel labels defined for n_features={n_features}")
labels = panel_labels[n_features]

colors = [to_hex(c) for c in plt.cm.Blues(np.linspace(0.4, 1, n_features))[::-1]]

# Three features stack vertically, four use a 2x2 grid.
# squeeze=False keeps axs 2-D in both cases, so axs[row, col] indexing is valid.
n_rows, n_cols = (2, 2) if n_features == 4 else (n_features, 1)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5, 5), squeeze=False)
for i, ax in enumerate(axs.flatten()):
    plot_histogram(bin_feat_dicts[i], ax=ax, bar_labels=False, color=colors[i])
    ax.grid(False)  # per axes: plt.grid would only affect the current (last) axes
    ax.set_title(labels[i])

if n_cols == 2:
    # drop the duplicated y-label on the right-hand column
    for ax in axs[:, 1]:
        ax.set_ylabel('')
plt.tight_layout()

plt.savefig('plots/JGB_binary_histograms.pdf', bbox_inches='tight', transparent=True)
plt.savefig('plots/JGB_binary_histograms.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()

## JGB Extensions

In [None]:
# Split the JGB samples; 0.8 is presumably the training fraction — confirm in DataLoader.
# Only the training part X feeds the variation-of-information analysis below.
X, X_test, _, _ = dl.train_test_split(0.8)

print(X.shape, X_test.shape)
print(len(X) + len(X_test))                      # total number of samples
print(jgb.raw.index[0], jgb.raw.index[-1])       # first and last index entry (dates)

In [None]:
from src.utils import varInfoMat
import pandas as pd

# Sweep the VI threshold and count how many qubit pairs would be connected at each value
num_steps = 100
steps = np.linspace(0 + 1 / num_steps, 1, num_steps)  # thresholds 1/num_steps, ..., 1
num_connections = np.zeros(num_steps)  # one count per threshold in `steps`

# The VI matrix does not depend on the threshold, so compute it once instead of per step
varinfo = varInfoMat(pd.DataFrame(X), norm=True)

for i, threshold in enumerate(steps):
    dist_filter = np.zeros_like(varinfo)
    dist_filter[varinfo < threshold] = 1.0
    np.fill_diagonal(dist_filter, 0.0)  # remove self-connections
    # the mask is counted over both (a, b) and (b, a), so halve it for undirected pairs
    num_connections[i] = np.sum(dist_filter) / 2

In [None]:
# Highlight the threshold used for the JGB circuit extension (0.95, see the next cell).
# Deriving the index from the value avoids the magic index 94.
marker_threshold = 0.95
marker_idx = int(np.argmin(np.abs(steps - marker_threshold)))
print(marker_idx, steps[marker_idx], num_connections[marker_idx])

blues = plt.cm.Blues(np.linspace(0.2, 1, 5))[::-1].tolist()

fig, ax = plt.subplots(figsize=(4, 2))
ax.plot(steps, num_connections, color=blues[0])
ax.plot(steps[marker_idx], num_connections[marker_idx], 'o', markersize=5, color=blues[3])
ax.set_xlabel('Threshold')
ax.set_ylabel('Number of Connections')
plt.savefig('plots/JGB_threshold.pdf', bbox_inches='tight', transparent=True)
plt.savefig('plots/JGB_threshold.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()

In [None]:
# Normalized variation of information between features of the JGB training data
varinfo = varInfoMat(pd.DataFrame(X), norm=True)
threshold = 0.95  # connect feature pairs whose normalized VI is below this value

# Binary connection mask from the distance threshold
dist_filter = np.zeros_like(varinfo)
dist_filter[varinfo < threshold] = 1.0
# Zero the diagonal so self-connections are removed for any threshold value
np.fill_diagonal(dist_filter, 0.0)

# plot both varinfo and the filtered connections, use blue color palette
fig, axs = plt.subplots(1, 2, figsize=(6, 2.6))
sns.heatmap(varinfo, ax=axs[0], cmap='Blues', vmin=0, vmax=1)
axs[0].set_title('a) Variation of Information')
sns.heatmap(dist_filter, ax=axs[1], cmap='Blues', vmin=0, vmax=1)
axs[1].set_title('b) Circuit Extension')
plt.tight_layout()
plt.savefig('plots/JGB_extension.pdf', bbox_inches='tight', transparent=True)
# dpi=300 to match every other PNG export in this notebook
plt.savefig('plots/JGB_extension.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()

In [None]:
import networkx as nx
from src.extension import metric_based_topology

# edges
# Linear chain over all qubits, plus the metric-based extension from the VI matrix
edges_lin = [(i, i + 1) for i in range(n_qubits - 1)]
edges_ext = metric_based_topology(varinfo, threshold)
# extension edges not already in the chain are highlighted in blue
edges_new = sorted(set(edges_ext) - set(edges_lin))

# (title, edges added to the graph, edges drawn in blue) for each panel
panels = [
    ('a) Linear', [], []),
    ('b) Extended', edges_ext, edges_new),
]

fig, axs = plt.subplots(1, 2, figsize=(6, 2.9))
for ax, (title, extra_edges, highlight_edges) in zip(axs, panels):
    G = nx.Graph()
    G.add_nodes_from(range(n_qubits))
    G.add_edges_from(edges_lin)
    G.add_edges_from(extra_edges)

    pos = nx.circular_layout(G)

    nx.draw_networkx_nodes(G, pos, node_color='white', edgecolors='black', node_size=200, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist=edges_lin, edge_color='black', ax=ax)
    if highlight_edges:
        nx.draw_networkx_edges(G, pos, edgelist=highlight_edges, edge_color='cornflowerblue', ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=8, font_weight='bold', font_color='black', ax=ax)
    ax.set_title(title)

for ax in axs:
    ax.set_aspect('equal', adjustable='box')  # make figure symmetric

    # hide axes, ticks, spines and frame so only the graph is drawn
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_frame_on(False)
plt.tight_layout()
plt.savefig('plots/JGB_topology.pdf', bbox_inches='tight', transparent=True)
# dpi=300 to match every other PNG export in this notebook
plt.savefig('plots/JGB_topology.png', bbox_inches='tight', transparent=False, dpi=300)
plt.show()


## JGB Training

In [None]:
from src.utils import plot_mmd_one_set

# Result directories of the JGB training runs, five per circuit topology
jgb_data = {
    'linear': [
        'results/runs/2025-03-06/10-42-56/0',
        'results/runs/2025-03-06/13-48-33/2',
        'results/runs/2025-03-09/11-17-41/0',
        'results/runs/2025-03-09/11-17-41/1',
        'results/runs/2025-03-09/11-17-41/2',
        ],
    'random': [
        'results/runs/2025-03-06/13-48-33/0',  # seed=42 [(0, 2), (0, 5), (0, 8), (1, 5), (1, 9), (2, 11), (3, 6), (3, 7), (4, 6), (4, 7), (4, 9), (5, 7), (5, 9)]
        'results/runs/2025-03-08/17-30-38/0',  # seed=43 [(0, 6), (1, 7), (1, 8), (1, 9), (1, 11), (2, 5), (2, 7), (3, 8), (4, 7), (4, 10), (4, 11), (6, 10), (6, 11)]
        'results/runs/2025-03-08/17-30-38/1',  # seed=44 [(0, 2), (0, 10), (1, 3), (1, 5), (1, 11), (2, 10), (3, 6), (3, 8), (4, 7), (5, 9), (6, 10), (7, 9), (7, 11)]
        'results/runs/2025-03-08/17-30-38/2',  # seed=45 [(0, 4), (0, 7), (0, 8), (0, 11), (1, 3), (2, 6), (2, 8), (2, 9), (3, 5), (4, 8), (4, 10), (5, 8), (8, 11)]
        'results/runs/2025-03-10/16-02-23/0',  # seed=46 [(0, 9), (0, 11), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (2, 4), (2, 5), (2, 6), (3, 9), (6, 10), (8, 11)]
        ],
    'metric-based': [
        'results/runs/2025-03-06/03-26-40/0',
        'results/runs/2025-03-06/13-48-33/1',
        'results/runs/2025-03-07/09-19-47/2',
        'results/runs/2025-03-09/11-17-41/3',
        'results/runs/2025-03-09/11-17-41/4',
        ],
}

# Two blues, darkest first: light for linear, dark for random; metric-based in orange
blues = plt.cm.Blues(np.linspace(0.3, 0.9, 2))[::-1].tolist()
colors = dict(linear=blues[1], random=blues[0])
colors['metric-based'] = 'darkorange'

# MMD curves on the training set only
fig = plot_mmd_one_set(jgb_data, colors=colors, evalset='train', mode="meanstd", window=50, filename="JGB_MMD_train", save=True)

