In [None]:
import os
import pennylane as qml
import jax.numpy as jnp
import numpy as np
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm,trange
from tqdm.notebook import tqdm


In [2]:
# Pin JAX to the CPU backend and enable 64-bit JAX dtypes (float64/complex128 defaults).
for _flag, _value in (("jax_platform_name", "cpu"), ("jax_enable_x64", True)):
    jax.config.update(_flag, _value)
del _flag, _value


In [3]:
def generate_hamiltonian(n,kappa,h):
    """
    Build the Hamiltonian of an open n-qubit ANNNI chain in a transverse field.

        H = - sum_{i=0}^{n-2} X_i X_{i+1}
            + kappa * sum_{i=0}^{n-3} X_i X_{i+2}
            - h * sum_{i=0}^{n-1} Z_i

    Boundaries are open: no term couples the last qubit back to the first.

    Args:
        - n (int): The number of qubits (spins) in the chain.
        - kappa: Coefficient of the next-nearest-neighbour X_i X_{i+2} coupling.
        - h: Transverse-field strength; every Z_i enters with coefficient -h.

    Returns:
        - cost_h (qml.Hamiltonian): the Hamiltonian written above, with terms ordered as
          nearest-neighbour couplings, then next-nearest-neighbour couplings, then field terms.
    """
    obs = []
    coeffs = []

    # Nearest-neighbour XX couplings, fixed coefficient -1.
    for i in range(n-1):
        obs.append(qml.PauliX(i) @ qml.PauliX(i+1))
        coeffs.append(-1)
    # Next-nearest-neighbour XX couplings, coefficient +kappa.
    for i in range(n-2):
        obs.append(qml.PauliX(i) @ qml.PauliX(i+2))
        coeffs.append(kappa)
    # Transverse field on every qubit, coefficient -h.
    for i in range(n):
        obs.append(qml.PauliZ(i))
        coeffs.append(-h)
    cost_h = qml.Hamiltonian(coeffs, obs)
    return cost_h

In [4]:
n: int = 6  # number of qubits (spins) in the chain passed to generate_hamiltonian

In [5]:
def get_state(cost_h):
    """
    Return a ground state of ``cost_h`` by exact dense diagonalisation.

    Args:
        cost_h: Hamiltonian exposing ``sparse_matrix()``, e.g. the ``qml.Hamiltonian``
            returned by ``generate_hamiltonian``.

    Returns:
        1-D numpy array: the normalised eigenvector belonging to the lowest eigenvalue.
        Its global phase/sign is whatever LAPACK returns. If the lowest level is
        degenerate, some vector of that eigenspace is returned.
    """
    Hmat = cost_h.sparse_matrix().toarray()
    # The Hamiltonian is Hermitian, so use eigh. It returns real eigenvalues in ascending
    # order and orthonormal eigenvectors, so column 0 is the ground state. np.linalg.eig
    # would give complex eigenvalues with round-off imaginary parts and possibly
    # non-orthogonal vectors in (near-)degenerate subspaces.
    _, eigenvectors = np.linalg.eigh(Hmat)
    ground_state = eigenvectors[:, 0]

    return ground_state

In [6]:
def phase_bound_left(kappa):
    """
    Value of h on the left-hand phase boundary at the given kappa.

    Returns 1.0 at kappa == 0, the closed-form boundary
    (1-kappa)/kappa * (1 - sqrt((1 - 3 kappa + 4 kappa^2) / (1 - kappa)))
    for 0 < kappa < 0.5, and 0 everywhere else.
    """
    if kappa == 0:
        return 1.0
    if not 0 < kappa < 0.5:
        return 0
    ratio = (1 - 3 * kappa + 4 * kappa ** 2) / (1 - kappa)
    return (1 - kappa) / kappa * (1 - jnp.sqrt(ratio))

In [7]:
def phase_bound_right(kappa):
    """
    Value of h on the right-hand phase boundary at the given kappa.

    Returns 1.05 * sqrt((kappa - 0.5) * (kappa - 0.1)) for kappa > 0.5, and 0 otherwise.
    """
    if not kappa > 0.5:
        return 0
    return 1.05 * jnp.sqrt((kappa - 0.5) * (kappa - 0.1))

In [8]:
def get_label(kappa,h):
    """
    Binary phase label for the point (kappa, h).

    Returns 1 when h lies strictly below either the left or the right phase boundary
    (phase_bound_left / phase_bound_right), and -1 otherwise.
    NOTE(review): presumably 1 marks the ordered phases and -1 the rest — confirm.
    """
    below_a_boundary = h < phase_bound_left(kappa) or h < phase_bound_right(kappa)
    return 1 if below_a_boundary else -1

In [9]:
# Regular 300 x 300 grid over kappa in [0, 1] and h in [0, 1.5]
# (presumably for plotting the phase diagram — not used elsewhere in this notebook section).
grid_points = 300
kappa_list = jnp.linspace(0, 1, grid_points)
h_list = jnp.linspace(0, 1.5, grid_points)
K, H = jnp.meshgrid(kappa_list, h_list)
rows, cols = K.shape

In [10]:
# Draw N (kappa, h) points uniformly from [0, 1) x [0, 1.5).
N = 20000
key = jax.random.PRNGKey(42)  # For reproducibility

# Split BEFORE sampling so each key is consumed exactly once. The original code drew
# kappas from `key` and then split that same `key`, which is JAX key reuse.
key, kappa_key, subkey = jax.random.split(key, 3)
kappas = jax.random.uniform(kappa_key, shape=(N,), minval=0.0, maxval=1.0)
hs = jax.random.uniform(subkey, shape=(N,), minval=0.0, maxval=1.5)


In [None]:

# Exact-diagonalise the Hamiltonian at every sampled (kappa, h) point:
# X_train collects the ground states, y_train the phase labels.
X_train, y_train = [], []
samples = zip(kappas, hs)
for kappa, h in tqdm(samples, total=N, desc='Generating training data'):
    cost_h = generate_hamiltonian(n, kappa, h)
    state = get_state(cost_h)
    label = get_label(kappa, h)
    X_train.append(state)
    y_train.append(label)


In [12]:
# Stack the per-sample lists into arrays: one ground-state vector per row of X_train.
X_train, y_train = np.array(X_train), np.array(y_train)


In [13]:
# Full sample pool under separate names; save_dataset reads X_total / Y_total.
X_total, Y_total = X_train, y_train


In [14]:
np.shape(X_train)

(20000, 64)

In [15]:
def generate_dataset(features,labels,M_total,M_train,M_test,key=jax.random.PRNGKey(0)):
    """
    Randomly split the first ``M_total`` samples into disjoint train and test sets.

    Args:
        features: Array indexable along axis 0, one sample per row.
        labels: Per-sample labels aligned with ``features``.
        M_total (int): Number of leading samples to draw from.
        M_train (int): Number of training samples.
        M_test (int): Number of test samples.
        key: JAX PRNG key controlling the split. Defaults to PRNGKey(0), so callers
            that do not pass a key always get the same split.

    Returns:
        (x_train, y_train, x_test, y_test) as jax arrays.

    Raises:
        ValueError: if M_train + M_test exceeds M_total.
    """
    if M_train + M_test > M_total:
        raise ValueError(
            f"M_train + M_test ({M_train} + {M_test}) exceeds M_total ({M_total})")

    # Draw a single permutation and consume the key once. The first M_test indices form
    # the test set and the next M_train form the training set, so the sets are disjoint
    # without reusing the key or computing a setdiff. For fixed key, M_total and M_test,
    # the test set does not depend on M_train, and smaller training sets are prefixes
    # of larger ones.
    perm = np.asarray(jax.random.permutation(key, M_total))
    test_indices = perm[:M_test]
    train_indices = perm[M_test:M_test + M_train]

    x_train, y_train = features[train_indices], labels[train_indices]
    x_test, y_test = features[test_indices], labels[test_indices]

    return jnp.array(x_train),jnp.array(y_train),jnp.array(x_test),jnp.array(y_test)

In [None]:
def save_dataset(n_qubits,M_train,M_test,datasets_path,data_type,seed,M_total = 20000):
    """
    Split the global (X_total, Y_total) pool for one seed and save it as .npy files.

    Files are written into ``{datasets_path}/{data_type}_{seed}/`` as
    ``{x,y}_{train,test}_qubit_{n_qubits}_sample_{M}_{data_type}.npy``.

    Args:
        n_qubits (int): Qubit count; used only in the file names.
        M_train (int): Number of training samples.
        M_test (int): Number of test samples.
        datasets_path (str): Root directory for the datasets.
        data_type (str): Dataset tag used in the directory and file names.
        seed (int): Seeds the JAX PRNG key that draws the split; also names the directory.
        M_total (int): Number of leading samples of the pool to split (default 20000).
    """
    out_dir = f"{datasets_path}/{data_type}_{seed}"
    os.makedirs(out_dir, exist_ok=True)
    features,labels = X_total,Y_total
    print(features.shape)
    # Derive the split from `seed`. Without an explicit key, generate_dataset falls back
    # to its default PRNGKey(0) and every seed would get the identical split.
    x_train,y_train,x_test,y_test = generate_dataset(
        features,labels,M_total,M_train,M_test,key=jax.random.PRNGKey(seed))
    # Write into the per-seed directory so datasets for different seeds don't overwrite each other.
    jnp.save(os.path.join(out_dir, f"x_train_qubit_{n_qubits}_sample_{M_train}_{data_type}.npy"),x_train)
    jnp.save(os.path.join(out_dir, f"y_train_qubit_{n_qubits}_sample_{M_train}_{data_type}.npy"),y_train)
    jnp.save(os.path.join(out_dir, f"x_test_qubit_{n_qubits}_sample_{M_test}_{data_type}.npy"),x_test)
    jnp.save(os.path.join(out_dir, f"y_test_qubit_{n_qubits}_sample_{M_test}_{data_type}.npy"),y_test)
    

In [None]:
datasets_path = '../../datasets'
data_type = "classification_phase"
# Tie the qubit count in the saved file names to `n`, the qubit count actually used to
# generate X_total, so the two cannot silently drift apart.
n_qubits = n

In [None]:

# One saved split per (seed, training-set size); the test set always has 10000 samples.
train_sizes = [10, 500, 1000, 1500, 2000]
for seed in range(1, 10):
    for M_train in train_sizes:
        save_dataset(
            n_qubits=n_qubits,
            M_train=M_train,
            M_test=10000,
            datasets_path=datasets_path,
            data_type=data_type,
            seed=seed,
        )