In [None]:
import os
import pickle
from collections import defaultdict

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
# from pennylane import numpy as np
import optax
import pandas as pd

from models.pennylane_models import serial, strongly_parallel, all_to_all_crz, all_to_all_rzz, strongly_crz, strongly_rzz, basic_mixed, one_to_all_mixed, all_to_one_mixed

seed = 12

def keygenerator(seed=seed):
    """Infinite generator of fresh JAX PRNG subkeys.

    Starting from ``jax.random.PRNGKey(seed)``, every iteration splits the
    running key in two, keeps the first half as the new running key and
    yields the second half.
    """
    running_key = jax.random.PRNGKey(seed)
    while True:
        halves = jax.random.split(running_key)
        running_key = halves[0]
        yield halves[1]

In [None]:
# Complex Fourier coefficients c_k for frequencies k = 1..10 of the synthetic
# target series evaluated by ``target_function``.
coeffs = [
    (1.29 + 1.13j),  # c_1
    (0.43 + 0.89j),  # c_2
    (1.97 + 1.03j),  # c_3
    (0.17 + 0.59j),  # c_4
    (1.71 + 1.41j),  # c_5
    (0.61 + 0.37j),  # c_6
    (1.19 + 1.67j),  # c_7
    (0.73 + 1.61j),  # c_8
    (0.23 + 0.47j),  # c_9
    (1.83 + 0.83j),  # c_10
]

c0 = 0.0      # constant (k = 0) term of the series
scaling = 1   # multiplier applied to every integer frequency k

def target_function(x):
    """Real part of a truncated Fourier series evaluated at ``x``.

    Computes ``Re(c0 + sum_k [c_k * exp(i*s*k*x) + conj(c_k) * exp(-i*s*k*x)])``
    for ``k = 1 .. len(coeffs)`` with ``s = scaling``. All operations are
    elementwise, so ``x`` may be a scalar or an array.
    """
    def harmonic(k, c_k):
        # Frequency-k contribution: c_k e^{iskx} plus its conjugate partner.
        phase = scaling * k * x * 1j
        return c_k * jnp.exp(phase) + jnp.conjugate(c_k) * jnp.exp(-phase)

    series = sum((harmonic(k, c_k) for k, c_k in enumerate(coeffs, start=1)), c0)
    return jnp.real(series)

def minmax_scaler(y, feature_range=(0.0, 1.0)):
    """Linearly rescale ``y`` so its min maps to ``feature_range[0]`` and its max to ``feature_range[1]``.

    Args:
        y: Array-like of numbers (any shape); min/max are taken over all elements.
        feature_range: ``(lower, upper)`` target interval. Defaults to ``(0, 1)``,
            which reproduces the plain min-max scaling to ``[0, 1]``.

    Returns:
        A JAX array with the same shape as ``y``. If ``y`` is constant, every
        element maps to ``lower`` instead of becoming NaN.
    """
    y = jnp.asarray(y)
    y_min = jnp.min(y)
    y_max = jnp.max(y)
    span = y_max - y_min
    # A constant input has zero span, and (y - y_min) / 0 would be 0/0 = NaN.
    # Dividing by 1 instead maps every element to 0 (i.e. to ``lower``).
    # jnp.where keeps this usable under jit.
    safe_span = jnp.where(span == 0, 1, span)
    unit = (y - y_min) / safe_span
    lower, upper = feature_range
    return lower + unit * (upper - lower)

In [None]:
# Sample 200 evenly spaced points and map them onto [0, 2π]. Only the spacing of
# ``x_raw`` matters: its endpoints cancel out in the min-max scaling.
x_raw = jnp.linspace(-12, 12, 200)
x = minmax_scaler(x_raw) * 2 * jnp.pi
target_y = jax.vmap(target_function)(x)

# Rescale the target to [-1, 1]; the models below are trained against this
# scaled version (presumably to match a bounded model output — confirm).
target_y_scaled = minmax_scaler(target_y) * 2 - 1

fig, ax = plt.subplots()
ax.scatter(x, target_y_scaled, facecolor="white", edgecolor="black")
ax.set_ylim(-1.5, 1.5)
ax.set(xlabel="x", ylabel="Scaled target", title="Target Fourier series (scaled to [-1, 1])")
plt.show()

In [None]:
# Single key stream shared by model construction below AND by the mini-batch
# sampling in the training loop: every next(keys) call shifts the keys that all
# later consumers receive, so the order of these lines is part of the results.
keys = keygenerator()

# Each list holds one factory result per number of trainable layers l; the
# training loop unpacks each entry as (qm, weights, name).
# All families are built (and consume PRNG keys) even though only those listed
# in ``varying_models`` are trained; dropping a line would change the initial
# weights of every later family and the batch-sampling keys.
# NOTE(review): ``serial`` is called positionally — presumably
# (x, n_qubits, trainable_layers, scaling, random_key) like the keyword calls; confirm.
serial_models = [serial(x, 10, l, 1, next(keys)) for l in range(1, 10)]
strongly_entangling_parallel_models = [strongly_parallel(x, n_qubits=10, trainable_layers=l, scaling=1, random_key=next(keys)) for l in range(1, 8)]
basic_entangling_mixed_models = [basic_mixed(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 8)]
one_to_all_mixed_models = [one_to_all_mixed(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 8)]
all_to_one_mixed_models = [all_to_one_mixed(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 8)]
strongly_entangling_crz_models = [strongly_crz(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 10)]
strongly_entangling_rzz_models = [strongly_rzz(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 10)]
all_to_all_crz_models = [all_to_all_crz(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 5)]
all_to_all_rzz_models = [all_to_all_rzz(x, n_qubits=10, trainable_layers=l, scaling=1, random_key1=next(keys), random_key2=next(keys)) for l in range(1, 8)]

# Model families trained by the next cell (one figure per family).
varying_models = [strongly_entangling_crz_models]

# Debug helper: draw the first all-to-all CRZ circuit.
# NOTE(review): ``qml`` (pennylane) is never imported in this notebook; add
# ``import pennylane as qml`` before uncommenting these lines.
# qm, weights, name = all_to_all_crz_models[0]
# print(qml.draw(qm, level="device")(weights, x))

In [None]:
def square_loss(targets, predictions):
    """Half mean-squared error: ``0.5 * mean((targets - predictions) ** 2)``."""
    residuals = targets - predictions
    return 0.5 * jnp.mean(residuals ** 2)


def cost(weights, model, x, y):
    """Half-MSE of ``model`` over a batch of inputs.

    ``model(weights, x_i)`` is evaluated for every leading-axis element ``x_i``
    of ``x`` via ``jax.vmap``, and the predictions are scored against ``y``.
    """
    evaluate_at = lambda point: model(weights, point)
    batch_predictions = jax.vmap(evaluate_at)(x)
    return square_loss(y, batch_predictions)

def r2_score(y_true, y_pred):
    """Coefficient of determination ``R² = 1 - SS_res / SS_tot``.

    Args:
        y_true: Ground-truth values (JAX/NumPy array).
        y_pred: Predictions, broadcast-compatible with ``y_true``.

    Returns:
        A scalar JAX array. When ``y_true`` is constant (``SS_tot == 0``), R² is
        undefined. Following scikit-learn's ``force_finite`` convention, this
        returns 1.0 for a perfect prediction and 0.0 otherwise, instead of
        ``nan`` / ``-inf``.
    """
    ss_resid = jnp.sum((y_true - y_pred) ** 2)
    ss_total = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
    constant_target = ss_total == 0
    # Divide by 1 in the degenerate case so no nan/inf is produced; jnp.where
    # keeps this traceable under jit/vmap.
    r2 = 1 - ss_resid / jnp.where(constant_target, 1.0, ss_total)
    fallback = jnp.where(ss_resid == 0, 1.0, 0.0)
    return jnp.where(constant_target, fallback, r2)


class GradientLogger:
    """Per-step record of the training loss and summary statistics of the gradient.

    Every call to ``update`` appends exactly one value to each list in
    ``self.log``, so all lists stay the same length and align by position.
    """

    # Keys of ``self.log``, in insertion order.
    _FIELDS = ("step", "loss", "grad_mean", "grad_std", "grad_min", "grad_max")

    def __init__(self):
        self.log = {field: [] for field in self._FIELDS}

    def get_gradients(self, weights, model, x, target_y):
        """Gradient of ``cost(weights, model, x, target_y)`` w.r.t. ``weights``, flattened to 1-D.

        Every leaf of the gradient pytree is raveled and concatenated in
        ``jax.tree_util.tree_leaves`` order.
        """
        def objective(params):
            return cost(params, model, x, target_y)

        grad_tree = jax.grad(objective)(weights)
        leaves = jax.tree_util.tree_leaves(grad_tree)
        return jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])

    def update(self, step, loss_val, flat_grads):
        """Append ``step``, ``loss_val`` and the mean/std/min/max of ``flat_grads``."""
        entry = (
            ("step", step),
            ("loss", loss_val),
            ("grad_mean", jnp.mean(flat_grads)),
            ("grad_std", jnp.std(flat_grads)),
            ("grad_min", jnp.min(flat_grads)),
            ("grad_max", jnp.max(flat_grads)),
        )
        for field, value in entry:
            self.log[field].append(value)

    def get_logs(self):
        """Return the live log dictionary (not a copy)."""
        return self.log

In [None]:
# Accumulators, filled with one entry per trained model.
predictions_r2scores_models = []  # (predictions on x, R², name, d)
gradient_logger = []              # (GradientLogger log dict, name, d)
costs = []                        # (full-data cost: initial + after each step, name, d)

# Families whose factories return a plain weight array; every other family
# returns a dict with 'W' and 'final' entries.
array_weight_families = (serial_models, strongly_entangling_parallel_models)

max_steps = 5000
batch_size = 25

# savefig below fails if the folder is missing — which would abort the cell
# after training but before the results are pickled.
os.makedirs("plots", exist_ok=True)

for models in varying_models:
    plt.plot(x, target_y_scaled, c="black")
    plt.scatter(x, target_y_scaled, facecolor="white", edgecolor="black")

    for qm, weights, name in models:
        opt = optax.adam(0.01)
        opt_state = opt.init(weights)

        # Identity check: ``models in (list_a, list_b)`` falls back to
        # element-wise list equality for non-identical lists, comparing the
        # qnodes/weights of unrelated models.
        if any(models is family for family in array_weight_families):
            d = weights.size
        else:
            d = weights['W'].size + weights['final'].size

        # Re-defined per model so the jitted step closes over this model's qnode and optimizer.
        @jax.jit
        def update_step(weights, opt_state, x_batch, y_batch):
            loss_fn = lambda w: cost(w, qm, x_batch, y_batch)
            loss, grads = jax.value_and_grad(loss_fn)(weights)
            updates, opt_state = opt.update(grads, opt_state)
            weights = optax.apply_updates(weights, updates)
            return weights, opt_state, loss

        # The initial cost must use the same scaled target as every later entry,
        # otherwise the first point of the cost curve is on a different scale.
        cst = [cost(weights, qm, x, target_y_scaled)]
        logger = GradientLogger()

        for step in range(max_steps):
            # Mini-batch of distinct indices drawn with a fresh key each step.
            batch_index = jax.random.choice(next(keys), len(x), (batch_size,), replace=False)
            x_batch = x[batch_index]
            y_batch = target_y_scaled[batch_index]

            weights, opt_state, _ = update_step(weights, opt_state, x_batch, y_batch)

            # Full-dataset cost and gradient statistics after the update.
            c = cost(weights, qm, x, target_y_scaled)
            cst.append(c)

            grads = logger.get_gradients(weights, qm, x, target_y_scaled)
            logger.update(step, c, grads)

            if (step + 1) % 1000 == 0:
                print("Cost at step {0:3} for {1} params: {2}".format(step + 1, d, c))

        costs.append((cst, name, d))
        gradient_logger.append((logger.get_logs(), name, d))

        predictions = jax.vmap(lambda x_: qm(weights, x_))(x)
        r2 = r2_score(target_y_scaled, predictions)
        predictions_r2scores_models.append((predictions, r2, name, d))

        plt.plot(x, predictions, label=f"{name}_{d}: R² = {r2:.4f}")

    plt.ylim(-1, 1)
    plt.legend()
    # Save BEFORE show(): with the inline backend, show() closes the current
    # figure, and a later savefig() would write an empty image.
    plt.savefig(f"plots/model_performances_{name}.png")
    plt.show()

# Tag shared by all output files of this run. Update it when ``varying_models``
# changes, so results of different model families do not overwrite each other.
results_tag = f"10qubits_str_crz{seed}"

# open(..., "wb") fails if the directory does not exist yet.
os.makedirs("preds_and_r2", exist_ok=True)

# NOTE(review): pickle files execute code when loaded; only reload files this project wrote.
for prefix, payload in (
    ("predictions_r2scores_models", predictions_r2scores_models),
    ("gradients", gradient_logger),
    ("costs", costs),
):
    with open(f"preds_and_r2/{prefix}_{results_tag}", "wb") as f:
        pickle.dump(payload, f)

In [None]:
# Group predictions and R² by (model name, parameter count) so repeated runs of
# the same configuration (e.g. different seeds) can be averaged.
predictions_by_model = defaultdict(list)
r2s_by_model = defaultdict(list)

for preds, r2, name, d in predictions_r2scores_models:
    predictions_by_model[(name, d)].append(preds)
    r2s_by_model[(name, d)].append(r2)

for (name, d), preds_list in predictions_by_model.items():
    plt.figure(figsize=(10, 6))
    # The models were trained against ``target_y_scaled`` (range [-1, 1]).
    # Plotting the unscaled ``target_y`` would put the reference curve on a
    # completely different vertical scale from the predictions.
    plt.plot(x, target_y_scaled, color='black', label='Target function')

    preds_array = np.asarray(preds_list)  # shape: (n_runs, n_points)
    mean_preds = preds_array.mean(axis=0)
    std_preds = preds_array.std(axis=0)  # all zeros when there is a single run

    # R² statistics across runs
    r2s = r2s_by_model[(name, d)]
    mean_r2 = np.mean(r2s)
    std_r2 = np.std(r2s)

    # Mean prediction with a ±1 std band
    plt.plot(x, mean_preds, label=f"{name}_{d}: R² = {mean_r2:.4f} ± {std_r2:.4f}")
    plt.fill_between(x, mean_preds - std_preds, mean_preds + std_preds, alpha=0.2)

    plt.xlabel("x")
    plt.ylabel("Prediction")
    plt.title(f"Model Predictions Averaged Over Seeds — {name} (d={d})")
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Collect R² values per parameter count d (one value per trained model with that d).
# NOTE(review): models from different families that share the same d are pooled
# into the same column here.
r2_by_d = defaultdict(list)

for _, r2, name, d in predictions_r2scores_models:
    r2_by_d[d].append(float(r2))

sorted_ds = sorted(r2_by_d.keys())

# One column per d, one row per run. Building the frame from Series pads
# shorter columns with NaN (shown via ``na_rep``). The previous zip(*...)
# silently truncated every column to the shortest one.
df = pd.DataFrame({d: pd.Series(r2_by_d[d], dtype=float) for d in sorted_ds}, columns=sorted_ds)
df.columns.name = "Number of parameters (d)"
df.index.name = "Seed"

display(df.style.format("{:.4f}", na_rep='–'))

In [None]:
# Group the recorded cost curves by (model name, parameter count) so repeated
# runs of the same configuration can be averaged.
costs_by_model = defaultdict(list)

for cost_list, name, d in costs:
    # Plain floats let NumPy build a float array directly from the padded lists.
    costs_by_model[(name, d)].append([float(v) for v in cost_list])

# A single figure comparing all models. Labels, title, legend and show() come
# after the loop so every curve lands on the same axes.
plt.figure(figsize=(10, 6))

for (name, d), all_costs in costs_by_model.items():
    # Pad curves of different lengths with NaN so nanmean/nanstd ignore missing steps.
    max_len = max(len(c) for c in all_costs)
    padded = np.array([c + [np.nan] * (max_len - len(c)) for c in all_costs])

    mean_cost = np.nanmean(padded, axis=0)
    std_cost = np.nanstd(padded, axis=0)
    steps = np.arange(len(mean_cost))

    plt.plot(steps, mean_cost, label=f"{name} ({d} params)")
    plt.fill_between(steps, mean_cost - std_cost, mean_cost + std_cost, alpha=0.2)

plt.xlabel("Training steps")
plt.ylabel("Cost")
plt.title("Average Training Cost over Seeds")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Pivot results into a Model × parameter-count table of R² scores.
# If the same (name, d) pair occurs more than once, the last entry wins.
table_data = {}

for _, r2, name, d in predictions_r2scores_models:
    table_data.setdefault(name, {})[d] = r2

df = pd.DataFrame.from_dict(table_data, orient='index')
df.columns.name = 'Number of parameters (d)'
df.index.name = 'Model'

# Order the columns by ascending parameter count.
df = df[sorted(df.columns)]

display(df.style.format("{:.4f}", na_rep='–'))

In [None]:
# Parameter count of the configuration(s) to inspect on their own.
# NOTE(review): 790 is one specific model size from the sweep — adjust as needed.
highlight_d = 790

# Labelled so the legend has an entry even when no model matches highlight_d.
plt.plot(x, target_y_scaled, c="black", label="Target (scaled)")

for predictions, r2, name, d in predictions_r2scores_models:
    if d == highlight_d:
        plt.plot(x, predictions, label=f"{name}_{d}: R² = {r2:.4f}")

plt.ylim(-1, 1)
plt.xlabel("x")
plt.ylabel("Prediction")
plt.title(f"Models with d = {highlight_d} parameters")
plt.legend()
plt.show()

In [None]:
def _load_pickle(path):
    """Unpickle and return the single object stored at ``path``.

    NOTE(review): pickle.load can execute arbitrary code; only use it on files
    this project wrote.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Reload the results of a previously saved run (the "dummy_ata_rzz" files).
# This replaces the in-memory results produced earlier in the notebook.
predictions_r2scores_models = _load_pickle("preds_and_r2/predictions_r2scores_models_10qubits_dummy_ata_rzz")
gradient_loggers = _load_pickle("preds_and_r2/gradients_10qubits_dummy_ata_rzz")
costs = _load_pickle("preds_and_r2/costs_10qubits_dummy_ata_rzz")