# In [None]:
import jax
import jax.numpy as jnp
import optax
import pennylane as qml
import numpy as np
from functools import partial


# Loss Functions

def mse_loss(px, py):
    """Return the mean of the squared element-wise differences between px and py."""
    residual = px - py
    squared = residual * residual
    return jnp.mean(squared)


def mae_loss(px, py):
    """Return the mean of the absolute element-wise differences between px and py."""
    abs_error = jnp.abs(px - py)
    return jnp.mean(abs_error)


def kl_divergence_loss(px, py):
    """Kullback-Leibler Divergence Loss: sum(py * log(py / px)).

    Parameters:
    - px: Model probabilities (the quantity that is differentiated during training).
    - py: Target probabilities.

    Returns:
    - Scalar KL(py || px). Entries with py <= 0 contribute exactly zero, and
      non-finite log-ratios are replaced through jnp.nan_to_num.

    Entries with py == 0 are masked out *before* the logarithm is taken.
    Evaluating log(0) and multiplying by zero afterwards gives the right
    forward value but a 0/0 = NaN in the backward pass, which poisons every
    gradient as soon as the target contains a single zero.
    """
    support = py > 0
    # Substitute harmless values off the support so log() never sees zero there.
    safe_py = jnp.where(support, py, 1.0)
    safe_px = jnp.where(support, px, 1.0)
    log_ratio = jnp.nan_to_num(jnp.log(safe_py / safe_px))
    return jnp.sum(jnp.where(support, py * log_ratio, 0.0))


def cross_entropy_loss(px, py):
    """Return -sum(py * log(px)), with non-finite logs replaced via jnp.nan_to_num."""
    log_px = jnp.nan_to_num(jnp.log(px))
    weighted = py * log_px
    return -jnp.sum(weighted)


def rmse_loss(px, py):
    """Return the square root of the mean squared difference between px and py."""
    residual = px - py
    mean_square = jnp.mean(residual * residual)
    return jnp.sqrt(mean_square)


# Kernel Functions
def gaussian_kernel(x1, x2, bandwidth=1.0):
    """Return exp(-||x1 - x2||^2 / (2 * bandwidth^2)).

    The squared distance is the sum over all elements of (x1 - x2) ** 2.
    """
    delta = x1 - x2
    squared_distance = jnp.sum(delta * delta)
    denominator = 2 * bandwidth ** 2
    return jnp.exp(-squared_distance / denominator)


def linear_kernel(x1, x2):
    """Return the dot product of x1 and x2 (computed with jnp.dot)."""
    inner_product = jnp.dot(x1, x2)
    return inner_product


def polynomial_kernel(x1, x2, degree=3):
    """Return (1 + <x1, x2>) ** degree, where <., .> is jnp.dot."""
    inner_product = jnp.dot(x1, x2)
    base = 1 + inner_product
    return base ** degree


# MMD Loss Function with Kernel Choice
def mmd_loss(px, py, kernel, **kernel_params):
    """Maximum Mean Discrepancy (MMD) Loss with a chosen kernel.

    Computes mean k(px, px) + mean k(py, py) - 2 * mean k(px, py), where each
    mean runs over every pair of entries taken along the leading axis of the
    two inputs.

    Parameters:
    - px, py: Arrays whose leading axis enumerates the samples.
    - kernel: Callable kernel(a, b, **kernel_params) returning a scalar; it
      must be traceable by JAX because it is evaluated under jax.vmap.
    - kernel_params: Extra keyword arguments forwarded to every kernel call.

    Returns:
    - Scalar MMD estimate.
    """
    px = jnp.asarray(px)
    py = jnp.asarray(py)

    def pair_kernel(a, b):
        return kernel(a, b, **kernel_params)

    def mean_gram(left, right):
        # jnp.mean rejects Python lists, so build the full Gram matrix as an
        # array; nested vmap also avoids tracing len(left) * len(right)
        # separate kernel calls.
        gram = jax.vmap(lambda a: jax.vmap(lambda b: pair_kernel(a, b))(right))(left)
        return jnp.mean(gram)

    return mean_gram(px, px) + mean_gram(py, py) - 2 * mean_gram(px, py)


def learn_qcbm(target_data, n_layers=3, n_iterations=100, learning_rate=0.1, loss_function="mse", kernel=None, kernel_params=None):
    """
    Train a QCBM (quantum circuit Born machine) to learn a target distribution.

    The target is shifted by its minimum and divided by the sum of the shifted
    values so that it forms a probability distribution, zero-padded up to the
    next power of two, and then fitted by a StronglyEntanglingLayers circuit
    whose weights are optimised with Adam.

    Parameters:
    - target_data: The target probability distribution or function data (will be normalized).
    - n_layers: The number of strongly entangling layers in the quantum circuit.
    - n_iterations: The number of training iterations.
    - learning_rate: The learning rate for the Adam optimizer.
    - loss_function: The loss function to use (options: "mse", "mae", "kl", "ce", "rmse", "mmd"). Default is "mse".
    - kernel: Kernel used by the "mmd" loss; ignored by every other loss. Either a
      callable kernel(x1, x2, **kernel_params) or one of the names "gaussian",
      "linear", "polynomial". Default is None.
    - kernel_params: Dictionary of parameters for the kernel function (e.g., bandwidth
      for Gaussian kernel, degree for Polynomial kernel). None means no extra parameters.

    Returns:
    - qcbm_probs: The learned probability distribution from the QCBM. Its length is
      2 ** n_qubits, i.e. it includes the zero-padded entries.
    - rescaled: qcbm_probs mapped back to the scale of the original data by undoing
      the shift-and-normalize step.

    Raises:
    - ValueError: If target_data is empty, non-finite or constant, if the loss
      function is not supported, if "mmd" is requested without a kernel, or if a
      kernel name is not recognised.
    """
    # Step 1: Turn the target into a probability distribution, remembering the
    # shift and scale so that the learned distribution can be mapped back.
    target_data = np.array(target_data, dtype=float)
    if target_data.size == 0:
        raise ValueError("target_data must contain at least one value.")
    shift = np.min(target_data)
    shifted = target_data - shift  # Ensure non-negative values
    # The scale must be taken AFTER shifting: it is the divisor actually used for
    # normalization, so it is what has to be multiplied back in at the end.
    scale = np.sum(shifted)
    if not np.isfinite(scale) or scale <= 0:
        raise ValueError(
            "target_data must be finite and contain at least two distinct values; "
            "otherwise it cannot be normalized into a probability distribution."
        )
    target_data = shifted / scale  # Normalize to make it a probability distribution

    # Choose the loss function (done before building the circuit so that bad
    # arguments fail fast).
    simple_losses = {
        "mse": mse_loss,
        "mae": mae_loss,
        "kl": kl_divergence_loss,
        "ce": cross_entropy_loss,
        "rmse": rmse_loss,
    }
    if loss_function == "mmd":
        if kernel is None:
            raise ValueError("MMD loss requires a kernel to be specified.")
        if isinstance(kernel, str):
            named_kernels = {
                "gaussian": gaussian_kernel,
                "linear": linear_kernel,
                "polynomial": polynomial_kernel,
            }
            if kernel not in named_kernels:
                raise ValueError(f"Unsupported kernel '{kernel}'. Supported options are: 'gaussian', 'linear', 'polynomial'.")
            kernel_fn = named_kernels[kernel]
        else:
            kernel_fn = kernel
        # kernel_params defaults to None, which cannot be unpacked with **.
        mmd_kwargs = dict(kernel_params) if kernel_params else {}

        # Wrap the loss function with the chosen kernel
        def loss_fn(px, py):
            return mmd_loss(px, py, kernel_fn, **mmd_kwargs)
    elif loss_function in simple_losses:
        loss_fn = simple_losses[loss_function]
    else:
        raise ValueError(f"Unsupported loss function '{loss_function}'. Supported options are: 'mse', 'mae', 'kl', 'ce', 'rmse', 'mmd'.")

    # Step 2: Determine the number of qubits needed
    data_size = len(target_data)
    n_qubits = max(1, int(np.ceil(np.log2(data_size))))
    x_max = 2 ** n_qubits  # This is the total number of possible basis states

    # Pad the target data if necessary to match 2^n_qubits
    if data_size < x_max:
        target_data = np.pad(target_data, (0, x_max - data_size), 'constant')

    # Define the device and the QCBM circuit using PennyLane
    dev = qml.device("default.qubit", wires=n_qubits)

    @qml.qnode(dev, interface="jax")
    def circuit(weights):
        qml.StronglyEntanglingLayers(weights, wires=range(n_qubits))
        return qml.probs(wires=range(n_qubits))

    # JIT compile the circuit for speed
    jit_circuit = jax.jit(circuit)

    # Define the QCBM class to compute the loss
    class QCBM:
        def __init__(self, circ, py, loss_fn):
            self.circ = circ
            self.py = py  # Target distribution π(x)
            self.loss_fn = loss_fn

        # General loss function (e.g., MSE, MAE, KL Divergence, MMD with kernel).
        # `self` is marked static so jit treats the instance as a hashable constant.
        @partial(jax.jit, static_argnums=0)
        def loss(self, params):
            px = self.circ(params)  # Get probabilities from QCBM circuit
            return self.loss_fn(px, self.py), px  # Return the loss and probabilities

    # Initialize QCBM with the compiled circuit and target distribution
    qcbm = QCBM(jit_circuit, target_data, loss_fn)

    # Initialize the optimizer
    opt = optax.adam(learning_rate=learning_rate)

    # Randomly initialize the weights for the Strongly Entangling Layers
    wshape = qml.StronglyEntanglingLayers.shape(n_layers=n_layers, n_wires=n_qubits)
    weights = np.random.random(size=wshape)

    # Initialize the optimizer state
    opt_state = opt.init(weights)

    # Define the update step for gradient descent
    @jax.jit
    def update_step(params, opt_state):
        # Compute the loss and gradients (the probabilities come back as aux data)
        (loss_val, _), grads = jax.value_and_grad(qcbm.loss, has_aux=True)(params)
        # Update the optimizer state and parameters
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val

    # Training loop
    for i in range(n_iterations):
        weights, opt_state, loss_val = update_step(weights, opt_state)
        if i % 10 == 0:
            print(f"Iteration {i}, Loss: {loss_val:.4f}")

    # Get the final QCBM probabilities after training
    qcbm_probs = qcbm.circ(weights)

    # Undo the normalization: probabilities -> original data scale.
    return qcbm_probs, scale * qcbm_probs + shift

# Fit a 10-layer QCBM to `silu` evaluated on 64 evenly spaced points in [-2.5, 2.5].
# NOTE(review): `silu` is not defined in the visible code; presumably it is provided
# elsewhere (e.g. an earlier cell) -- confirm before running this in isolation.
# The kernel arguments are only consulted when loss_function == "mmd".
norm_res, res = learn_qcbm(
    target_data=silu(np.linspace(-2.5, 2.5, num=64)),
    n_layers=10,
    n_iterations=150,
    learning_rate=0.1,
    loss_function="rmse",
    kernel=gaussian_kernel,
    kernel_params={"bandwidth": 1.0},
)
