In [3]:
import os

import jax
import jax.numpy as jnp



In [4]:
# Global JAX configuration; placed ahead of the cells that create keys and arrays.
# Pin the platform name to CPU.
jax.config.update('jax_platform_name', 'cpu')
# Switch on the x64 flag (64-bit numeric types instead of JAX's 32-bit default).
jax.config.update("jax_enable_x64", True)


In [5]:
# Notebook-level PRNG key built from the fixed seed 0, so draws made with it are repeatable.
key = jax.random.PRNGKey(0)

In [8]:
def generate_gaussian(n_samples,n_features,sigma_d = 0.8,key=jax.random.PRNGKey(0)):
    """Draw a balanced two-class dataset of isotropic Gaussian clouds.

    Each class contributes ``n_samples // 2`` points in ``n_features``
    dimensions. Per-dimension class means (``d`` is the feature index) are

    * class 0: ``mu_1[d] = ((2*pi/16) * (d % 8))     % (2*pi)``
    * class 1: ``mu_2[d] = ((2*pi/16) * (8 + d % 8)) % (2*pi)``

    Args:
        n_samples: Requested total number of samples. When odd, one sample is
            dropped so that both classes have the same size.
        n_features: Dimensionality ``D`` of every sample.
        sigma_d: Standard deviation of the Gaussian noise added to the means.
        key: JAX PRNG key. It is split so that the two classes receive
            independent noise.

    Returns:
        ``(features, labels)`` where ``features`` has shape
        ``(2 * (n_samples // 2), n_features)`` with all class-0 rows first,
        and ``labels`` is the matching integer vector of 0s then 1s.
    """
    D = n_features
    n_samples_per_class = n_samples // 2
    dims = jnp.arange(D)

    # First class: mean = 2π/16 * (d mod 8), wrapped into [0, 2π)
    mu_1 = ((2 * jnp.pi / 16) * (dims % 8)) % (2 * jnp.pi)
    # Second class: mean = 2π/16 * (8 + d mod 8), wrapped into [0, 2π)
    mu_2 = ((2 * jnp.pi / 16) * (8 + dims % 8)) % (2 * jnp.pi)

    # Use one sub-key per class: reusing a single key would give both classes
    # the exact same noise matrix, leaving them different only by a constant.
    key_1, key_2 = jax.random.split(key)
    noise_shape = (n_samples_per_class, D)
    features_1 = mu_1 + sigma_d * jax.random.normal(key_1, shape=noise_shape)
    features_2 = mu_2 + sigma_d * jax.random.normal(key_2, shape=noise_shape)

    features = jnp.vstack([features_1, features_2])
    labels = jnp.hstack([
        jnp.zeros(n_samples_per_class, dtype=int),
        jnp.ones(n_samples_per_class, dtype=int),
    ])

    return features,labels

In [20]:
def generate_dataset(features,labels,M_train,M_test,key=jax.random.PRNGKey(0)):
    """Subsample disjoint train and test splits from a labelled dataset.

    Args:
        features: Array whose leading axis indexes samples.
        labels: 1-D array of labels, one per sample.
        M_train: Number of training samples to draw (without replacement).
        M_test: Number of test samples to draw (without replacement) from the
            samples not selected for training.
        key: JAX PRNG key. It is split so the two draws use independent
            randomness.

    Returns:
        ``(x_train, y_train, x_test, y_test)`` as JAX arrays.

    Raises:
        ValueError: If ``M_train + M_test`` exceeds the number of samples.
    """
    n_total = len(labels)
    if M_train + M_test > n_total:
        raise ValueError(
            f"Cannot draw {M_train} train and {M_test} test samples "
            f"from only {n_total} samples"
        )

    # Independent sub-keys: reusing one key for both draws correlates the
    # test selection with the train selection.
    train_key, test_key = jax.random.split(key)

    # subsample train and test split
    train_indices = jax.random.choice(train_key, n_total, shape=(M_train,), replace=False)
    remaining = jnp.setdiff1d(jnp.arange(n_total), train_indices)
    test_indices = jax.random.choice(test_key, remaining, shape=(M_test,), replace=False)

    x_train, y_train = features[train_indices], labels[train_indices]
    x_test, y_test = features[test_indices], labels[test_indices]

    return jnp.array(x_train),jnp.array(y_train),jnp.array(x_test),jnp.array(y_test)

In [21]:
def save_dataset(n_qubits,n_layers,M_train,M_test,data_type,M_total,datasets_path):
    """Generate a Gaussian dataset, split it, and save the splits as ``.npy`` files.

    The feature dimension is ``n_qubits * n_layers * 3``; every sample is
    reshaped to ``(n_layers, n_qubits, 3)`` before splitting. Four files are
    written under ``{datasets_path}/{data_type}/``:
    ``x_train_*``, ``y_train_*`` (suffix ``sample_{M_train}``) and
    ``x_test_*``, ``y_test_*`` (suffix ``sample_{M_test}``).

    Args:
        n_qubits: Number of qubits (second-to-last grouping of the features).
        n_layers: Number of layers (first grouping of the features).
        M_train: Number of training samples to keep.
        M_test: Number of test samples to keep.
        data_type: Sub-directory name and dataset tag.
        M_total: Number of samples requested from ``generate_gaussian``.
        datasets_path: Root directory for the saved datasets.
    """
    # Create directory if it doesn't exist
    out_dir = f"{datasets_path}/{data_type}"
    os.makedirs(out_dir, exist_ok=True)

    n_features = n_qubits * n_layers * 3  # Dimensionality
    features,labels = generate_gaussian(M_total,n_features)
    print(features.shape)

    # Infer the sample axis instead of hard-coding M_total: generate_gaussian
    # returns 2 * (M_total // 2) rows, so an odd M_total would make a fixed
    # reshape fail.
    features = features.reshape(-1,n_layers,n_qubits,3)
    x_train,y_train,x_test,y_test = generate_dataset(features,labels,M_train,M_test)

    tag = f"qubit_{n_qubits}_layer_{n_layers}"
    jnp.save(f"{out_dir}/x_train_{tag}_sample_{M_train}.npy",x_train)
    jnp.save(f"{out_dir}/y_train_{tag}_sample_{M_train}.npy",y_train)
    jnp.save(f"{out_dir}/x_test_{tag}_sample_{M_test}.npy",x_test)
    jnp.save(f"{out_dir}/y_test_{tag}_sample_{M_test}.npy",y_test)
    

In [22]:
# Parameters for the dataset-generation run.
M_train, M_test = 2000, 1_000_000   # sizes of the train / test splits
M_total = M_train + M_test          # 1002000 samples generated before splitting
n_qubits = 1                        # the generation loop passes its own qubit counts
data_type = "gaussian"              # dataset tag, also the output sub-directory
datasets_path = '../../datasets'    # root directory for the saved files

In [24]:
# Write one dataset per (qubit count, layer count) pair: 2/4/6 qubits, 1 to 8 layers.
for n in (2, 4, 6):
    for l in range(1, 9):
        save_dataset(
            n_qubits=n,
            n_layers=l,
            M_train=M_train,
            M_test=M_test,
            data_type=data_type,
            M_total=M_total,
            datasets_path=datasets_path,
        )

(1002000, 6)
(1002000, 12)
(1002000, 18)
(1002000, 24)
(1002000, 30)
(1002000, 36)
(1002000, 42)
(1002000, 48)
(1002000, 12)
(1002000, 24)
(1002000, 36)
(1002000, 48)
(1002000, 60)
(1002000, 72)
(1002000, 84)
(1002000, 96)
(1002000, 18)
(1002000, 36)
(1002000, 54)
(1002000, 72)
(1002000, 90)
(1002000, 108)
(1002000, 126)
(1002000, 144)
