#### JAX

In [1]:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
    assert device.device_type != 'GPU'

import os
import jax
jax.config.update('jax_default_device', jax.devices('cpu')[0])
jax.config.update("jax_enable_x64", True)
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import numpy as np
import time
from mrmustard import math, settings
from mrmustard.lab_dev import SqueezedVacuum, Vacuum, Sgate, Interferometer, Attenuator, BSgate, DM, TwoModeSqueezedVacuum, Circuit
from mrmustard.physics import fidelity
import os
import optax
from functools import partial
from typing import Sequence
import jax.numpy as jnp
from mrmustard.training import Optimizer
from tqdm import tqdm
import equinox as eqx

math._is_immutable = False
math.change_backend("jax")  #works on jax-development branch

#### Evaluate optimization with JAX - fully jitted

In [26]:
state_out = sq0 >> sq1 >> sq2 >> BS_01 >> BS_12

In [39]:
BS_01 = BSgate(modes=(0,1), theta=0.5, phi=0.5, theta_trainable=True, phi_trainable=True)
BS_12 = BSgate(modes=(1,2), theta=0.1, phi=0.5, theta_trainable=True, phi_trainable=True)
sq0 = SqueezedVacuum(mode=0, r=1.0, phi=0.1, r_trainable=True, phi_trainable=True)
sq1 = SqueezedVacuum(mode=1, r=1.0, phi=0.1, r_trainable=True, phi_trainable=True)
sq2 = SqueezedVacuum(mode=2, r=1.0, phi=0.1, r_trainable=True, phi_trainable=True)


# dummy cost function
def cost_fn(BS_01, BS_12, sq0, sq1, sq2):
    state_out = sq0 >> sq1 >> sq2 >> BS_01 >> BS_12
    return state_out.bargmann_triple()[0].sum().real


# convert the circuit to equinox module
class Objective(eqx.Module):
    sq0_r: jax.Array
    sq0_phi: jax.Array
    sq1_r: jax.Array
    sq1_phi: jax.Array
    sq2_r: jax.Array
    sq2_phi: jax.Array
    BS_01_theta: jax.Array
    BS_01_phi: jax.Array
    BS_12_theta: jax.Array
    BS_12_phi: jax.Array

    def __init__(self):
        self.sq0_r = self.sq1_r = self.sq2_r = jnp.array(1.0)
        self.sq0_phi = self.sq1_phi = self.sq2_phi = jnp.array(0.1)
        self.BS_01_theta = self.BS_01_phi = self.BS_12_theta = self.BS_12_phi = jnp.array(0.5)

    def __call__(self):
        sq0 = SqueezedVacuum(mode=0, r=self.sq0_r, phi=self.sq0_phi)
        sq1 = SqueezedVacuum(mode=1, r=self.sq1_r, phi=self.sq1_phi)
        sq2 = SqueezedVacuum(mode=2, r=self.sq2_r, phi=self.sq2_phi)
        BS_01 = BSgate(modes=(0,1), theta=self.BS_01_theta, phi=self.BS_01_phi)
        BS_12 = BSgate(modes=(1,2), theta=self.BS_12_theta, phi=self.BS_12_phi)
        return cost_fn(BS_01, BS_12, sq0, sq1, sq2)

# instantiate the model
model = Objective()

# partition the model into parameters and static variables
params, static = eqx.partition(model, eqx.is_array)

# loss function that accepts parameters and updates the circuit and returns the cost_fn
def loss(params, static):
    model = eqx.combine(params, static)
    return model()

# evaluate and compute gradients
loss_val, grads = jax.value_and_grad(loss)(params, static)
print(loss_val)

-1.5289120581876054


In [42]:
optim = optax.adamw(1e-1)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def make_step(
    model,
    opt_state,
):
    params, static = eqx.partition(model, eqx.is_array)
    loss_value, grads = jax.value_and_grad(loss)(params, static)
    updates, opt_state = optim.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

for epoch in range(10):
    model, opt_state, loss_value = make_step(model, opt_state)
    print(f"epoch: {epoch}, loss: {loss_value}")

epoch: 0, loss: -2.821398038349355
epoch: 1, loss: -2.88894909229339
epoch: 2, loss: -2.910706038872732
epoch: 3, loss: -2.9209236451888803
epoch: 4, loss: -2.933078795831404
epoch: 5, loss: -2.9447167869090434
epoch: 6, loss: -2.9554825285507116
epoch: 7, loss: -2.9641055767101827
epoch: 8, loss: -2.9694987345966597
epoch: 9, loss: -2.972231889905241
