In [1]:
import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import NesterovMomentumOptimizer
import LHC_QML_module as lqm
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import time
from matplotlib import pyplot as plt

from warnings import simplefilter
import random
import os
import jax
import jax.numpy as jnp
import optax

In [2]:
# Global seed (used for numpy and the sklearn splits).
# NOTE: it does not make the quantum circuit's shot sampling reproducible;
# no way to seed that device is known yet.
seed = 123

# Features to train on -- one qubit per feature (see n_qubits below).
training_feature_keys = ["f_mass4l", "f_eta4l", "f_Z2mass", "f_Z1mass"]

# save_folder = os.path.join("saved", "model1-testspeed-ionq")

# Training-loop settings.
batch_size = 1
n_epochs = 1

n_qubits = len(training_feature_keys)

In [3]:
# Seed the global numpy RNG (drives np.random.randn in the weight-initialisation cell).
np.random.seed(seed)
# Make JAX use 64-bit floats instead of its float32 default.
jax.config.update("jax_enable_x64", True)

In [4]:
# IonQ simulator backend with one wire per feature; with shots=1024 every
# expectation value is a sampled estimate rather than an exact value.
dev = qml.device("ionq.simulator", shots=1024, wires=n_qubits)

In [6]:
@qml.qnode(dev)
def qnode(weights, inputs):
    """Variational circuit executed on ``dev``.

    Angle-encodes ``inputs`` onto the ``n_qubits`` wires, applies
    ``BasicEntanglerLayers`` parametrised by ``weights`` (created elsewhere
    with shape (num_layers, n_qubits)), and returns the expectation value
    of PauliX on wire 0.
    """
    register = range(n_qubits)
    qml.AngleEmbedding(inputs, wires=register)
    qml.BasicEntanglerLayers(weights, wires=register)
    return qml.expval(qml.PauliX(0))

# jax.jit is not applied here (decorator left commented out):
# @jax.jit
def model(weights, inputs):
    """Signal probability predicted by the circuit for ``inputs``.

    Rescales the circuit's <X_0> expectation value from [-1, 1] to [0, 1].
    """
    expectation = qnode(weights, inputs)
    return (expectation + 1) / 2

# @jax.jit
def loss(prob, label):
    """Mean binary cross-entropy between probabilities ``prob`` and 0/1 ``label``.

    A small epsilon (1e-5) is added inside each logarithm so that
    probabilities of exactly 0 or 1 do not produce log(0).
    """
    eps = 1e-5
    log_p = jnp.log(prob + eps)
    log_not_p = jnp.log(1 - prob + eps)
    return -jnp.mean(label * log_p + (1 - label) * log_not_p)

# @jax.jit
def accuracy(pred, label, threshold=0.5):
    """Fraction of samples whose predicted class matches ``label``.

    ``pred`` may be hard 0/1 predictions or class probabilities such as the
    output of ``model``. Either way it is thresholded at ``threshold`` before
    comparison. Comparing raw probabilities with 0/1 labels via ``isclose``
    would count almost every sample as misclassified.

    Args:
        pred: array of predictions or probabilities.
        label: array of 0/1 labels with the same shape as ``pred``.
        threshold: probability at or above which a sample is assigned class 1.

    Returns:
        Scalar mean accuracy in [0, 1].
    """
    hard_pred = jnp.where(jnp.asarray(pred) >= threshold, 1.0, 0.0)
    return jnp.mean(jnp.isclose(hard_pred, label))

# @jax.jit
def cost(weights, features, labels):
    """Binary cross-entropy of the classifier over ``features``/``labels``.

    ``model`` is evaluated one sample at a time, not as a single batched
    call. The per-sample probabilities are collected into one array and
    passed to ``loss``.
    """
    per_sample = []
    for sample in features:
        per_sample.append(model(weights, sample))
    return loss(jnp.array(per_sample), labels)

In [7]:
num_layers = 2

# Random initial circuit weights: one row per entangling layer, one column per
# qubit, drawn from N(0, 0.5^2) with the seeded global numpy RNG.
weights_init = jnp.array(0.5 * np.random.randn(num_layers, n_qubits))

print(weights_init.shape)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


(2, 4)


In [8]:
# Input ROOT-file folders for the 4e channel.
signals_folder = "./data/signal/4e"
backgrounds_folder = "./data/background/4e"

# Catalogue of per-event features available in the input files.
# NOTE(review): not referenced in the visible code (training uses
# training_feature_keys); presumably kept as a reference list -- confirm before removing.
choice_feature_keys = [
    "f_lept1_pt",
    "f_lept1_eta",
    "f_lept1_phi",
    "f_lept1_pfx",
    "f_lept2_pt",
    "f_lept2_eta",
    "f_lept2_phi",
    "f_lept2_pfx",
    "f_lept3_pt",
    "f_lept3_eta",
    "f_lept3_phi",
    "f_lept4_pt",
    "f_lept4_eta",
    "f_lept4_phi",
    "f_Z1mass",
    "f_Z2mass",
    "f_angle_costhetastar",
    "f_angle_costheta1",
    "f_angle_costheta2",
    "f_angle_phi",
    "f_angle_phistar1",
    "f_pt4l",
    "f_eta4l",
    "f_mass4l",
    "f_deltajj",
    "f_massjj",
    "f_jet1_pt",
    "f_jet1_eta",
    "f_jet1_phi",
    "f_jet1_e",
    "f_jet2_pt",
    "f_jet2_eta",
    "f_jet2_phi",
    "f_jet2_e",
]

# Passed to lqm.preprocess_data below; PCA is disabled for this run.
use_pca = False



# Passed to lqm.preprocess_data; presumably the number of output features
# (PCA components when use_pca is True) -- confirm in LHC_QML_module.
num_features = len(training_feature_keys)


# load the training features from the signal and background files
signal_dict, background_dict, files_used = lqm.load_data(
    signals_folder, backgrounds_folder, training_feature_keys
)

# formats data for input into vqc; labels are 1 for signal, 0 for background
# (as used by the balancing code below)
features, labels = lqm.format_data(signal_dict, background_dict)

# for some reason, if you want to use jax.jit and jax.vmap with default.qubit, you need to use float64
# if you use float32, it will give you an error message
features = features.astype(np.float64)

# this is clunky, might want to make this its own function or something
# makes sure we use an equal amount of signal and background even if we have more signal than background
# NOTE(review): the slices below assume format_data returns all signal rows first,
# followed by all background rows -- confirm in LHC_QML_module. Under that assumption:
#   fewer signal  -> take the first 2*n_signal rows (all signal + first n_signal background)
#   more signal   -> take the last 2*n_background rows (last n_background signal + all background)
n_signal_events = (labels == 1).sum()
n_background_events = (labels == 0).sum()
if n_signal_events <= n_background_events:
    start = 0
    stop = 2 * n_signal_events
else:
    start = -2 * n_background_events
    stop = None

# splits data into a training set (75%) and a remainder (25%) that is later
# divided into validation and testing sets
# data is first cut to include an equal number of signal and background events
# stratify keeps the signal/background ratio the same in both parts
# TODO: maybe split signal and backgrounds separately to ensure equal number of signal/background in each test/training set and then combine and randomize order
train_features, rest_features, train_labels, rest_labels = train_test_split(
    features[start:stop, :],
    labels[start:stop],
    train_size=0.75,
    random_state=seed,
    stratify=labels[start:stop]
)

# preprocess data (preprocessing is determined from the training split and
# applied to the remainder -- presumably fit on train only; confirm in LHC_QML_module)
train_features, rest_features = lqm.preprocess_data(
    train_features, rest_features, use_pca, num_features, seed
)

# split the remainder into validation (20% of it) and testing (80% of it) sets
valid_features, test_features, valid_labels, test_labels = train_test_split(
    rest_features,
    rest_labels,
    train_size=0.2,  # meaning testing set will be 20% of the whole, while validation set 5% of the whole
    random_state=seed,
    stratify=rest_labels
)


signal data from:
data/signal/4e/4e_1-output_GluGluToHToZZTo4L_M-125_8TeV-powheg15-pythia6.root

background data from:
data/background/4e/4e_1-output_GluGluToZZTo4L_8TeV-gg2zz-pythia6.root

data loaded

# of signal events: 7057
# of background events: 21500

data formatted
data preprocessed



In [9]:
# Convert the training and validation splits to JAX arrays (possibly redundant,
# kept as a safeguard). The test split is intentionally left unconverted for now:
# test_features = jnp.array(test_features)
# test_labels = jnp.array(test_labels)
train_features, train_labels = jnp.array(train_features), jnp.array(train_labels)
valid_features, valid_labels = jnp.array(valid_features), jnp.array(valid_labels)

In [10]:
# Adam (learning rate 0.01) from optax; PennyLane's
# NesterovMomentumOptimizer(0.01) is the commented-out alternative.
# opt = NesterovMomentumOptimizer(0.01)
optimizer = optax.adam(0.01)

num_train = train_features.shape[0]  # number of training samples

# Training state for the variational classifier: current weights, plus the
# best-so-far weights and their loss, all starting from the random initialisation.
weights = weights_init
weights_best_loss = weights_init
best_loss = np.inf

In [11]:
# Initialise the optimizer's internal state for the current weights.
opt_state = optimizer.init(weights)

In [12]:
# Inspect one validation sample: a 1-D array with one value per feature.
valid_features[0]

Array([0.03042169, 0.42738457, 0.16541518, 0.55854005], dtype=float64)

In [13]:
# The same sample as a batch of one: a 2-D array of shape (1, n_features).
valid_features[:1]

Array([[0.03042169, 0.42738457, 0.16541518, 0.55854005]], dtype=float64)

In [14]:
# Wall-clock time of one forward pass on a single validation sample.
last = time.time()
model(weights, valid_features[0])
now = time.time()
print(now - last)


2.801501989364624


In [15]:
# Wall-clock time of one batched forward pass on the first 10 validation samples.
last = time.time()
model(weights, valid_features[:10])
now = time.time()
print(now - last)

92.07054090499878


In [16]:
now = time.time()

cost_grads = jax.grad(cost)     
grads = cost_grads(weights, train_features[:1], train_labels[:1])

now, last = time.time(), now
print(now - last)

36.98423647880554


In [17]:
now = time.time()

cost_grads = jax.grad(cost)     
grads = cost_grads(weights, train_features[:5], train_labels[:5])

now, last = time.time(), now
print(now - last)

304.20461320877075


In [None]:
# NOTE(review): same measurement as the preceding 5-sample gradient timing cell,
# and not yet executed; presumably a repeat run to check timing variance --
# consider removing or parameterising the sample count.
last = time.time()
cost_grads = jax.grad(cost)
grads = cost_grads(weights, train_features[:5], train_labels[:5])
now = time.time()
print(now - last)