In [1]:
import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab.gates import Sgate
from mrmustard.lab.states import SqueezedVacuum, Vacuum
from mrmustard.physics import fidelity
import os
import jax
import optax
from functools import partial
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"

math.change_backend("jax")

##### JAX updates

In [2]:
# JAX
S = Sgate(r=0.1, phi=0, r_trainable=True, phi_trainable=True)
target_state = SqueezedVacuum(r=1.0, phi=np.pi/2)

optimizer = optax.adam(learning_rate=0.1)
params = [p.value for p in S.parameter_set.all_parameters.values()]
opt_state = optimizer.init(params)

@jax.jit
def compute_loss(params):
    S.parameter_set.all_parameters['r'].assign(params[0])
    S.parameter_set.all_parameters['phi'].assign(params[1])
    return 1-fidelity(Vacuum(1) >> S, target_state)

@jax.jit
def update_step(params, opt_state):
    grads = jax.grad(compute_loss)(params)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state


start = time.time()
for i in range(100):
    params, opt_state = update_step(params, opt_state)

    
end = time.time()
print(compute_loss(params))
print(end-start)

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


6.334027528032138e-07
1.3647212982177734


##### Tensorflow benchmark
* Restart the kernel before running the cell below to change the backend

In [1]:
import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab.gates import Sgate
from mrmustard.lab.states import SqueezedVacuum, Vacuum
from mrmustard.physics import fidelity
from mrmustard.training import Optimizer
import os
import jax
import optax
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"
math.change_backend("tensorflow")

E0000 00:00:1736452355.525083   13034 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1736452355.531181   13034 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-09 19:52:37.576006: I external/xla/xla/pjrt/pjrt_c_api_client.cc:127] PjRtCApiClient created.


In [2]:
# tensorflow
S = Sgate(r=0.1, phi=0, r_trainable=True, phi_trainable=True)
target_state = SqueezedVacuum(r=1.0, phi=np.pi/2)

def cost_fn_sympl():
    state_out = Vacuum(1) >> S
    return 1 - fidelity(state_out, target_state)

start = time.time()
opt = Optimizer(symplectic_lr=0.1, euclidean_lr=0.05)
opt.minimize(cost_fn_sympl, by_optimizing=[S], max_steps=100)
end = time.time()
print(cost_fn_sympl())
print("Time taken: ", end-start)

I0000 00:00:1736452359.576105   13034 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 663 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:1e.0, compute capability: 7.5


Output()

I0000 00:00:1736452359.838742   13034 cuda_solvers.cc:178] Creating GpuSolver handles for stream 0x63af94563e20


tf.Tensor(3.248519471532596e-06, shape=(), dtype=float64)
Time taken:  10.761883974075317
