In [1]:
import sys
import numpy as np

import jax
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import jit, value_and_grad
from functools import partial
from itertools import product
import matplotlib.pyplot as plt
from scipy.stats import unitary_group

!pip install optax
import optax

!git clone https://github.com/JasvithBasani/CasOptAx.git
import CasOptAx as conn
from CasOptAx.linear_optics import Linear_Optics
from CasOptAx.circuit_builder import Circuit_singlemode

Cloning into 'CasOptAx'...
remote: Enumerating objects: 143, done.[K
remote: Counting objects: 100% (1/1), done.[K
remote: Total 143 (delta 0), reused 0 (delta 0), pack-reused 142[K
Receiving objects: 100% (143/143), 46.07 KiB | 873.00 KiB/s, done.
Resolving deltas: 100% (82/82), done.


In [2]:
#Define the number of photons in the network
N_photons = 3
#Define the number of modes/waveguides in the network
N_modes = 4
#Define the number of layers in the network
N_layers = 2

#Let us initialize the singlemode circuit class with a randomly select N_photon, N_mode state
#Here, 'singlemode' refers to the retention of the gaussian spectral profile after all the scattering processes
input_photons = tuple([N_photons] + [0] * (N_modes - 1))
circ = Circuit_singlemode(N_modes, N_photons, input_photons)

#The N_photon, N_mode states in the circ class have a very particular structure. This structure lies in the ordering of the states
#Let us look at this state
init_state = circ.state_amps.copy()
init_state

Initialzing Circuit, Please Wait .  .  .
***Circuit Ready For Compilation***




{(0, 0, 0, 3): 0j,
 (0, 0, 1, 2): 0j,
 (0, 0, 2, 1): 0j,
 (0, 0, 3, 0): 0j,
 (0, 1, 0, 2): 0j,
 (0, 1, 1, 1): 0j,
 (0, 1, 2, 0): 0j,
 (0, 2, 0, 1): 0j,
 (0, 2, 1, 0): 0j,
 (0, 3, 0, 0): 0j,
 (1, 0, 0, 2): 0j,
 (1, 0, 1, 1): 0j,
 (1, 0, 2, 0): 0j,
 (1, 1, 0, 1): 0j,
 (1, 1, 1, 0): 0j,
 (1, 2, 0, 0): 0j,
 (2, 0, 0, 1): 0j,
 (2, 0, 1, 0): 0j,
 (2, 1, 0, 0): 0j,
 (3, 0, 0, 0): (1+0j)}

In [3]:
#In this tutorial, we want to train the QPNN to map an input state to a Haar-random state
#Let us generate this Haar-random state. First, we need to know the dimension in the qudit basis
N_dim = circ.num_possible_states
U_haar = jnp.array(unitary_group.rvs(circ.num_possible_states), dtype = jnp.complex128)
target_state = init_state.copy() #To maintain the structure of the pytree

input_amps = []
for idx, s in enumerate(list(init_state.keys())):
  input_amps.append(init_state[s])
output_amps = jnp.matmul(U_haar, jnp.array(input_amps))

for idx, s in enumerate(list(init_state.keys())):
  target_state[s] = output_amps[idx]

target_state

{(0, 0, 0, 3): Array(-0.06048939-0.07918852j, dtype=complex128),
 (0, 0, 1, 2): Array(-0.20830109+0.40231447j, dtype=complex128),
 (0, 0, 2, 1): Array(-0.11191342-0.22556897j, dtype=complex128),
 (0, 0, 3, 0): Array(0.07679562+0.08252124j, dtype=complex128),
 (0, 1, 0, 2): Array(0.00054541-0.02020634j, dtype=complex128),
 (0, 1, 1, 1): Array(-0.26533795+0.01508944j, dtype=complex128),
 (0, 1, 2, 0): Array(-0.069569+0.11455554j, dtype=complex128),
 (0, 2, 0, 1): Array(0.06269142+0.05989441j, dtype=complex128),
 (0, 2, 1, 0): Array(-0.08216262-0.09862449j, dtype=complex128),
 (0, 3, 0, 0): Array(0.07306662-0.41923714j, dtype=complex128),
 (1, 0, 0, 2): Array(0.1814922+0.00389952j, dtype=complex128),
 (1, 0, 1, 1): Array(0.02801461-0.26674467j, dtype=complex128),
 (1, 0, 2, 0): Array(0.06081417+0.20706778j, dtype=complex128),
 (1, 1, 0, 1): Array(-0.22801683-0.16676589j, dtype=complex128),
 (1, 1, 1, 0): Array(-0.18186234-0.13296931j, dtype=complex128),
 (1, 2, 0, 0): Array(-0.08905806-0.

In [4]:
#Now we define the parameters of the network
lo = Linear_Optics(N_modes)

theta, phi, D = [], [], []
chi_1, chi_2 = [], []

#Initialize the network randomly. Each layer is initialized to a Haar-random matrix, approximated by the clements decomposition
for layer in range(N_layers):
  U = lo.haar_mat(N_modes) + 0j
  theta_, phi_, D_ = lo.get_clements_phases(U)
  theta.append(theta_); phi.append(phi_); D.append(D_)
  chi_1_val = jnp.array(np.random.randn(N_modes))
  chi_2_val = jnp.array(np.random.randn(N_modes))
  chi_1.append(chi_1_val); chi_2.append(chi_2_val)

#Typecast everything into jax arrays
theta, phi, D = jnp.array(theta), jnp.array(phi), jnp.array(D)
chi_1, chi_2 = jnp.array(chi_1), jnp.array(chi_2)
#alpha and beta are the beam-splitter errors, i.e., deviation from 50:50 splitting. Assumed to be zero for now
alpha = jnp.zeros(theta[0].shape)
beta = alpha

In [5]:
#QPNN function defined here - cascaded linear and 3-level system nonlinearities
def QPNN(theta, phi, D, chi_1, chi_2, amps):
  for l_num in range(N_layers):
    amps = circ.add_linear_layer(amps, theta[l_num], phi[l_num], D[l_num], alpha, beta)
    amps = circ.add_3ls_nonlinear_layer(amps, chi_1[l_num], chi_2[l_num])
  return amps

out_state = QPNN(theta, phi, D, chi_1, chi_2, init_state)

In [6]:
#Define the loss function
def loss_func(theta, phi, D, chi_1, chi_2, init_state, target_state):
  out_state = QPNN(theta, phi, D, chi_1, chi_2, init_state)

  @jit
  def MSE(out_state, target_state):
    #mse_val = jax.tree_map(lambda amp_1, amp_2: jnp.abs(amp_1 - amp_2)**2, out_state, target_state)
    mse_val, _ = jax.flatten_util.ravel_pytree(jax.tree_map(lambda amp_1, amp_2: jnp.abs(amp_1 - amp_2)**2, out_state, target_state))
    return jnp.mean(mse_val)

  @jit
  def dot(out_state, target_state):
    #dot_prod = jax.tree_map(lambda amp_1, amp_2: amp_1 * amp_2, out_state, target_state)
    dot_val, _ = jax.flatten_util.ravel_pytree(jax.tree_map(lambda amp_1, amp_2: amp_1 * amp_2, out_state, target_state))
    return jnp.abs(jnp.sum(dot_val))

  #Fidelity defined as inner product of the output state and target state
  fid_val = dot(out_state, target_state)
  #Loss minimized to maximize fidelity to 1
  loss_val = jnp.abs(1 - fid_val)**2
  return loss_val, (fid_val)

loss_out = loss_func(theta, phi, D, chi_1, chi_2, init_state, target_state)
loss_out

(Array(0.69835682, dtype=float64), Array(0.16432254, dtype=float64))

In [8]:
#Define training hyperparameters - number of epochs, learning rate scheduling
num_epochs = 2001
scheduler = optax.exponential_decay(init_value=0.025, transition_steps = num_epochs, decay_rate=0.1)
optimizer = optax.adam(scheduler)

def update(step, theta, phi, D, chi_1, chi_2, opt_state, in_state, target_state):
  #Calculate gradients of parameters 0-4
  (loss_val, (fid_val)), grads = value_and_grad(loss_func, (0, 1, 2, 3, 4), has_aux = True)(theta, phi, D, chi_1, chi_2, in_state, target_state)
  #Perform updates on the parameters
  updates, opt_state = optimizer.update(grads, opt_state)
  theta, phi, D, chi_1, chi_2 = optax.apply_updates((theta, phi, D, chi_1, chi_2), updates)
  return loss_val, fid_val, opt_state, (theta, phi, D, chi_1, chi_2)

#Initialize optimizer only for the parameters to be optimized
opt_state = optimizer.init((theta, phi, D, chi_1, chi_2))
loss_val, fid_val, opt_state, (theta, phi, D, chi_1, chi_2) = update(0, theta, phi, D, chi_1, chi_2, opt_state, init_state, target_state)
print (loss_val, fid_val)

0.6983568212157405 0.16432253756862605


In [11]:
loss_array = []
fid_array = []

#Run the optimization
for epoch in range(num_epochs):
  loss_val, fid_val, opt_state, (theta, phi, D, chi_1, chi_2) = update(0, theta, phi, D, chi_1, chi_2, opt_state, init_state, target_state)
  loss_array.append(loss_val); fid_array.append(fid_val)
  if (epoch%20) == 0:
    print (f"Epoch Number: {epoch}, Loss: {loss_val}, Fidelity: {fid_val}")


Epoch Number: 0, Loss: 0.22792792725037958, Fidelity: 0.5225820203947283
Epoch Number: 20, Loss: 0.06065023133719564, Fidelity: 0.7537273232021148
Epoch Number: 40, Loss: 0.02562745866099215, Fidelity: 0.8399142146816522
Epoch Number: 60, Loss: 0.017270383629769202, Fidelity: 0.8685831683924422
Epoch Number: 80, Loss: 0.014167350352914585, Fidelity: 0.8809733208355598
Epoch Number: 100, Loss: 0.011603122777556825, Fidelity: 0.8922822077019918
Epoch Number: 120, Loss: 0.009333294235172328, Fidelity: 0.9033910240444899
Epoch Number: 140, Loss: 0.007717416622725001, Fidelity: 0.9121511717623678
Epoch Number: 160, Loss: 0.006664820225059984, Fidelity: 0.9183616497896976
Epoch Number: 180, Loss: 0.005913012294725586, Fidelity: 0.9231038863483623
Epoch Number: 200, Loss: 0.005309700649862288, Fidelity: 0.9271323072283588
Epoch Number: 220, Loss: 0.004783115366323919, Fidelity: 0.9308399293932986
Epoch Number: 240, Loss: 0.004302197140667781, Fidelity: 0.9344088638559463
Epoch Number: 260, Lo