#### JAX

In [1]:
import os

# Select CPU-only execution *before* TensorFlow/JAX are imported: both libraries
# read these variables when they initialise their backends, so assigning them
# after `jax.devices()` (or TF device queries) has already run has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # legacy JAX variable, kept for older versions
os.environ["JAX_PLATFORMS"] = "cpu"      # current JAX platform selector

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
# Sanity check: TensorFlow must not see any GPU.
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'
import jax
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)

import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab_dev import SqueezedVacuum, Vacuum, Sgate, Interferometer, Attenuator, BSgate, DM
from mrmustard.physics import fidelity
import optax
from functools import partial
from typing import Sequence
import jax.numpy as jnp
from mrmustard.training import Optimizer
from tqdm import tqdm

math.change_backend("jax")  #works on jax-development branch

2025-01-20 18:42:26.269286: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-20 18:42:26.285814: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737398546.306012  315696 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737398546.312085  315696 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-20 18:42:26.332869: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

#### Evaluate optimization with JAX - fully jitted

In [3]:
BS_01 = BSgate(modes=(0,1), theta=0.5, phi=0.5, theta_trainable=True, phi_trainable=True)
BS_12 = BSgate(modes=(1,2), theta=0.1, phi=0.5, theta_trainable=True, phi_trainable=True)
att = Attenuator(modes=(0,1,2), transmissivity=0.5, transmissivity_trainable=False)
initial_state = SqueezedVacuum(modes=(0,1,2), r=1.0, phi=0.1, r_trainable=True, phi_trainable=True)
state_out = initial_state >> BS_01 >> BS_12 >> att


optimizer = optax.adam(learning_rate=0.001)
# Flat parameter list in the exact positional order `compute_state_out` unpacks:
# [r, phi] of the squeezed state, [theta, phi] of BS_01, [theta, phi] of BS_12.
# Values are looked up by name rather than by iterating `all_parameters.values()`,
# so the list cannot silently fall out of sync with those hard-coded indices if a
# component's parameter dict is ordered differently or holds additional entries.
param_spec = (
    (initial_state, "r"),
    (initial_state, "phi"),
    (BS_01, "theta"),
    (BS_01, "phi"),
    (BS_12, "theta"),
    (BS_12, "phi"),
)
params = [component.parameters.all_parameters[name].value for component, name in param_spec]
opt_state = optimizer.init(params)

def compute_state_out(params):
    """Write ``params`` into the circuit components and return the output Bargmann triple.

    Positional layout of ``params``:
      0, 1 -> squeezed state (r, phi)
      2, 3 -> BS_01 (theta, phi)
      4, 5 -> BS_12 (theta, phi)

    Not jitted on its own: it runs inside the jitted ``update_step`` via ``compute_loss``.
    Side effect: mutates the module-level ``initial_state``, ``BS_01`` and ``BS_12``.
    """
    slots = (
        (initial_state, 'r'),
        (initial_state, 'phi'),
        (BS_01, 'theta'),
        (BS_01, 'phi'),
        (BS_12, 'theta'),
        (BS_12, 'phi'),
    )
    for position, (component, name) in enumerate(slots):
        component.parameters.all_parameters[name].assign(params[position])
    circuit_output = initial_state >> BS_01 >> BS_12 >> att
    return circuit_output.bargmann_triple()


def compute_loss(params):
    """Training loss: minus the trace of the mode-0 block of the output state's
    Fock array, with modes 1 and 2 both fixed at Fock index 4.

    The output triple from ``compute_state_out`` is wrapped as a ``DM`` on modes
    (0, 1, 2) and converted to a Fock array with cutoffs (20, 5, 5) per side.
    """
    output_dm = DM.from_bargmann(modes=(0,1,2), triple=compute_state_out(params))
    dm_fock = output_dm.fock_array(shape=(20,5,5,20,5,5))
    # Assumed axis layout (TODO confirm): (ket0, ket1, ket2, bra0, bra1, bra2).
    block_4_4 = dm_fock[:, 4, 4, :, 4, 4]
    return -math.real(math.trace(block_4_4))

@jax.jit
def update_step(params, opt_state):
    """Run one jitted optimiser step with the module-level ``optimizer``.

    Returns ``(new_params, new_opt_state, loss)``. ``loss`` is evaluated at the
    *incoming* ``params``, before the update is applied.
    """
    loss_and_grads = jax.value_and_grad(compute_loss)
    loss, gradients = loss_and_grads(params)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, param_updates)
    return next_params, next_opt_state, loss


# Warm-up step: triggers JIT compilation of `update_step`, so compile time is
# excluded from the timed loop below.
params, opt_state, loss_val = update_step(params, opt_state)
jax.block_until_ready(params)

N_STEPS = 1000
start = time.perf_counter()
for _ in tqdm(range(N_STEPS)):
    params, opt_state, loss_val = update_step(params, opt_state)
# JAX dispatches work asynchronously: wait for the last step to finish before
# reading the clock, otherwise the elapsed time can be under-reported.
jax.block_until_ready((params, opt_state, loss_val))
end = time.perf_counter()
print(end - start)
print(loss_val)

100%|██████████| 1000/1000 [01:27<00:00, 11.47it/s]


87.22015571594238
-0.003090005071263963


#### Evaluate forward passes with JAX - partially jitted (only `compute_state_out` is jitted; the Fock conversion in `compute_loss` is not)

In [5]:
BS_01 = BSgate(modes=(0,1), theta=0.5, phi=0.5, theta_trainable=True, phi_trainable=True)
BS_12 = BSgate(modes=(1,2), theta=0.1, phi=0.5, theta_trainable=True, phi_trainable=True)
att = Attenuator(modes=(0,1,2), transmissivity=0.5, transmissivity_trainable=False)
initial_state = SqueezedVacuum(modes=(0,1,2), r=1.0, phi=0.1, r_trainable=True, phi_trainable=True)
state_out = initial_state >> BS_01 >> BS_12 >> att


optimizer = optax.adam(learning_rate=0.001)
# Flat parameter list in the exact positional order `compute_state_out` unpacks:
# [r, phi] of the squeezed state, [theta, phi] of BS_01, [theta, phi] of BS_12.
# Values are looked up by name rather than by iterating `all_parameters.values()`,
# so the list cannot silently fall out of sync with those hard-coded indices if a
# component's parameter dict is ordered differently or holds additional entries.
param_spec = (
    (initial_state, "r"),
    (initial_state, "phi"),
    (BS_01, "theta"),
    (BS_01, "phi"),
    (BS_12, "theta"),
    (BS_12, "phi"),
)
params = [component.parameters.all_parameters[name].value for component, name in param_spec]
opt_state = optimizer.init(params)

@jax.jit
def compute_state_out(params):
    """Jitted: write ``params`` into the circuit components and return the output
    state's Bargmann triple.

    Positional layout of ``params``:
      0, 1 -> squeezed state (r, phi)
      2, 3 -> BS_01 (theta, phi)
      4, 5 -> BS_12 (theta, phi)

    NOTE(review): Python side effects inside ``jax.jit`` only execute while the
    function is being traced, so the ``.assign`` calls run on the first call per
    input signature. They presumably leave the module-level components holding
    tracer values afterwards; avoid reading those parameters outside this
    function. TODO confirm against mrmustard's ``Parameter.assign``.
    """
    slots = (
        (initial_state, 'r'),
        (initial_state, 'phi'),
        (BS_01, 'theta'),
        (BS_01, 'phi'),
        (BS_12, 'theta'),
        (BS_12, 'phi'),
    )
    for position, (component, name) in enumerate(slots):
        component.parameters.all_parameters[name].assign(params[position])
    traced_output = initial_state >> BS_01 >> BS_12 >> att
    return traced_output.bargmann_triple()


def compute_loss(params):
    """Forward-pass loss used for benchmarking.

    Returns minus the trace of the mode-0 block of the output state's Fock array,
    with modes 1 and 2 both fixed at Fock index 4.

    The first entry of ``params`` is shifted by -0.001 before the forward pass.
    The shift is applied to a *copy*, so the caller's ``params`` list is left
    untouched. Repeated calls (e.g. under ``%timeit``) therefore all evaluate the
    same point instead of drifting a further 0.001 on every call.
    """
    shifted = list(params)
    shifted[0] = shifted[0] - 0.001
    triples = compute_state_out(shifted)
    state_out = DM.from_bargmann(modes=(0,1,2), triple=triples)
    output_fock_state = state_out.fock_array(shape=(20,5,5,20,5,5))
    # Assumed axis layout (TODO confirm): (ket0, ket1, ket2, bra0, bra1, bra2).
    marginal = output_fock_state[:,4,4,:,4,4]
    return -1*math.real(math.trace(marginal))

_ = compute_loss(params)
%timeit _ = compute_loss(params)

10 ms ± 35.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### NumPy backend (Numba)

In [None]:
import os

# Select CPU-only execution *before* TensorFlow/JAX are imported: both libraries
# read these variables when they initialise their backends, so assigning them
# after `jax.devices()` (or TF device queries) has already run has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # legacy JAX variable, kept for older versions
os.environ["JAX_PLATFORMS"] = "cpu"      # current JAX platform selector

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
# Sanity check: TensorFlow must not see any GPU.
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'
import jax
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)

import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab_dev import SqueezedVacuum, Vacuum, Sgate, Interferometer, Attenuator, BSgate, DM
from mrmustard.physics import fidelity
import optax
from functools import partial
from typing import Sequence
import jax.numpy as jnp
from mrmustard.training import Optimizer
from tqdm import tqdm

math.change_backend("numpy")

In [2]:
# Circuit: two trainable beam splitters, a fixed transmissivity-0.5 attenuator on
# all three modes, and a trainable three-mode squeezed-vacuum input.
BS_01 = BSgate(
    modes=(0, 1), theta=0.5, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
BS_12 = BSgate(
    modes=(1, 2), theta=0.1, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
att = Attenuator(
    modes=(0, 1, 2), transmissivity=0.5,
    transmissivity_trainable=False,
)
initial_state = SqueezedVacuum(
    modes=(0, 1, 2), r=1.0, phi=0.1,
    r_trainable=True, phi_trainable=True,
)

def cost_fn():
    """Loss: minus the trace of the mode-0 block of the circuit output's Fock array,
    with modes 1 and 2 both fixed at Fock index 4.

    Takes no arguments; reads the module-level circuit components.
    """
    circuit = initial_state >> BS_01 >> BS_12 >> att
    dm_fock = circuit.fock_array(shape=(20, 5, 5, 20, 5, 5))
    # Assumed axis layout (TODO confirm): (ket0, ket1, ket2, bra0, bra1, bra2).
    block_4_4 = dm_fock[:, 4, 4, :, 4, 4]
    return -math.real(math.trace(block_4_4))

# Warm-up call so one-off costs are excluded from the timing, then benchmark one
# full forward pass (circuit composition + Fock conversion + loss).
_ = cost_fn()
%timeit _ = cost_fn()

27.3 ms ± 274 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### TensorFlow backend

In [1]:
import os

# Select CPU-only execution *before* TensorFlow/JAX are imported: both libraries
# read these variables when they initialise their backends, so assigning them
# after `jax.devices()` (or TF device queries) has already run has no effect.
# '-1' (no GPUs) matches the other benchmark cells and the no-GPU assertion below.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # legacy JAX variable, kept for older versions
os.environ["JAX_PLATFORMS"] = "cpu"      # current JAX platform selector

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
# Sanity check: TensorFlow must not see any GPU.
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'
import jax
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)

import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab_dev import SqueezedVacuum, Vacuum, Sgate, Interferometer, Attenuator, BSgate, DM
from mrmustard.physics import fidelity
import optax
from functools import partial
from typing import Sequence
import jax.numpy as jnp
from mrmustard.training import Optimizer
from tqdm import tqdm

math.change_backend("tensorflow")  #works on  branch tfLabDev

2025-01-20 18:54:19.925163: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-20 18:54:19.941444: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1737399259.961501  316216 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1737399259.967791  316216 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-20 18:54:19.988721: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [2]:
# Circuit: two trainable beam splitters, a fixed transmissivity-0.5 attenuator on
# all three modes, and a trainable three-mode squeezed-vacuum input.
BS_01 = BSgate(
    modes=(0, 1), theta=0.5, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
BS_12 = BSgate(
    modes=(1, 2), theta=0.1, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
att = Attenuator(
    modes=(0, 1, 2), transmissivity=0.5,
    transmissivity_trainable=False,
)
initial_state = SqueezedVacuum(
    modes=(0, 1, 2), r=1.0, phi=0.1,
    r_trainable=True, phi_trainable=True,
)

def cost_fn():
    """Loss: minus the trace of the mode-0 block of the circuit output's Fock array,
    with modes 1 and 2 both fixed at Fock index 4.

    Takes no arguments (as `Optimizer.minimize` calls it); reads the module-level
    circuit components.
    """
    circuit = initial_state >> BS_01 >> BS_12 >> att
    dm_fock = circuit.fock_array(shape=(20, 5, 5, 20, 5, 5))
    # Assumed axis layout (TODO confirm): (ket0, ket1, ket2, bra0, bra1, bra2).
    block_4_4 = dm_fock[:, 4, 4, :, 4, 4]
    return -math.real(math.trace(block_4_4))

# Warm-up call so one-off costs are not attributed to the optimisation below.
_ = cost_fn()

opt = Optimizer(euclidean_lr=0.1)
# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
opt.minimize(cost_fn, by_optimizing=[initial_state, BS_01, BS_12], max_steps=10)
end = time.perf_counter()
# NOTE: this figure has been observed to disagree with the cell's total run time;
# refer to the time reported at the bottom left of this cell for an estimate.
print(end - start)

In [4]:
# Fresh circuit for the forward-pass benchmark: two trainable beam splitters, a
# fixed transmissivity-0.5 attenuator, and a trainable squeezed-vacuum input.
BS_01 = BSgate(
    modes=(0, 1), theta=0.5, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
BS_12 = BSgate(
    modes=(1, 2), theta=0.1, phi=0.5,
    theta_trainable=True, phi_trainable=True,
)
att = Attenuator(
    modes=(0, 1, 2), transmissivity=0.5,
    transmissivity_trainable=False,
)
initial_state = SqueezedVacuum(
    modes=(0, 1, 2), r=1.0, phi=0.1,
    r_trainable=True, phi_trainable=True,
)

def cost_fn():
    """Loss: minus the trace of the mode-0 block of the circuit output's Fock array,
    with modes 1 and 2 both fixed at Fock index 4.

    Takes no arguments; reads the module-level circuit components.
    """
    circuit = initial_state >> BS_01 >> BS_12 >> att
    dm_fock = circuit.fock_array(shape=(20, 5, 5, 20, 5, 5))
    # Assumed axis layout (TODO confirm): (ket0, ket1, ket2, bra0, bra1, bra2).
    block_4_4 = dm_fock[:, 4, 4, :, 4, 4]
    return -math.real(math.trace(block_4_4))

# Warm-up call so one-off costs are excluded from the timing, then benchmark one
# full forward pass (circuit composition + Fock conversion + loss).
_ = cost_fn()
%timeit _ = cost_fn()

170 ms ± 893 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
