# Optical Neural Network with numpy/jax

![alt text](mzi_mesh.jpg "Mesh type used")

## Trainable photonic circuit
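Building blocks defined below: a single MZI (Mach-Zehnder interferometer), a column of MZIs, a mesh made of successive columns, and helpers for the mesh layout and the weight initialisation.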

In [1]:
import jax
from jax import numpy as np

random_seed=1

def MZI(X, teta):
    """Apply the 2x2 rotation of angle ``teta`` to the input ``X``.

    The matrix is ``[[cos(teta), -sin(teta)], [sin(teta), cos(teta)]]`` and the
    returned value is ``np.dot(matrix, X)``.
    """
    cos_t = np.cos(teta)
    sin_t = np.sin(teta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.dot(rotation, X)

def MZI_col(X, nb_mzi, W):
    """Apply one column of ``nb_mzi`` MZIs to the input vector ``X``.

    Two layouts are accepted:

    * ``len(X) == 2 * nb_mzi``: the MZIs cover every pin, starting at pin 0;
    * ``len(X) == 2 * nb_mzi + 2``: the MZIs start at pin 1, and the first and
      last pins are copied to the output unchanged.

    ``W[k]`` is the angle given to the k-th MZI of the column. The returned
    vector has the same length as ``X``.

    Raises:
        ValueError: if ``len(X)`` matches neither layout.
    """
    n_inputs = len(X)
    covered_pins = 2 * nb_mzi
    if covered_pins == n_inputs:
        offset = 0
    elif covered_pins + 2 == n_inputs:
        offset = 1
    else:
        raise ValueError("This mesh patern is not compatible with this input size and #MZIs")

    # Each MZI consumes two consecutive pins, starting at `offset`.
    mixed = [
        MZI(X[offset + 2 * k:offset + 2 * k + 2], W[k])
        for k in range(nb_mzi)
    ]

    if offset == 1:
        # The outermost pins bypass the column.
        pieces = [np.array([X[0]]), *mixed, np.array([X[-1]])]
    else:
        pieces = mixed

    return np.concatenate(pieces)

def onn(X, nb_mzis, weights):
    """Propagate ``X`` through a mesh of MZI columns.

    Column ``k`` is built with ``nb_mzis[k]`` MZIs and the angles
    ``weights[k]``. The columns are applied from the LAST index down to index
    0: ``weights[-1]`` acts directly on ``X`` and ``weights[0]`` produces the
    final output.

    Args:
        X: input vector fed to the first applied column.
        nb_mzis: number of MZIs of each column, indexed like ``weights``.
        weights: one sequence of angles per column.

    Returns:
        The output of the column of index 0, or ``X`` itself when ``weights``
        is empty.
    """
    Y = X
    # An explicit loop replaces the former recursion: it has no depth limit
    # and terminates cleanly when there is no column at all.
    for id_layer in reversed(range(len(weights))):
        Y = MZI_col(Y, nb_mzis[id_layer], weights[id_layer])
    return Y

def spec_mesh(cols, mzi_per_col):
    """Return the number of MZIs of each column of the mesh.

    Even-indexed columns hold ``mzi_per_col`` MZIs and odd-indexed columns
    hold one fewer, e.g. ``spec_mesh(6, 3) -> [3, 2, 3, 2, 3, 2]``.

    Args:
        cols: number of columns in the mesh.
        mzi_per_col: number of MZIs in the even-indexed columns.

    Returns:
        A list of ``cols`` integers.
    """
    # The arguments are used as given; they are no longer overwritten from the
    # global ``n_comp``.
    return [mzi_per_col - i % 2 for i in range(cols)]

def glorot_init(nb_mzi, key=None):
    """Draw ``nb_mzi`` initial weights from a normal law scaled by ``sqrt(0.5)``.

    Args:
        nb_mzi: number of weights to draw (one per MZI of a column).
        key: optional JAX PRNG key. When omitted, a key is built from the
            global ``random_seed``. In that case every call with the same
            ``nb_mzi`` returns the same values, so pass distinct keys (e.g.
            from ``jax.random.split``) to give each column its own weights.

    Returns:
        A float32 array of shape ``(nb_mzi,)``.
    """
    if key is None:
        key = jax.random.PRNGKey(random_seed)  # random seed is explicit
    return jax.random.normal(key, shape=(nb_mzi,), dtype=np.float32) * np.sqrt(0.5)

## Simple function learning: [1,0,1,0] -> [0,1,0,1]
Simple problem before solving harder problems

Configuration of the training dataset and ONN

In [41]:
# Training pair: the circuit has to map X onto Y
X = np.array([1, 0, 1, 0])
Y = np.array([0, 1, 0, 1])

# Number of MZIs in each column of the mesh
nb_MZIs = (2, 1, 2, 1)

# One weight vector per column
W = [glorot_init(n) for n in nb_MZIs]

Backward

In [42]:
def circuit(X, W):
    """Forward pass of the ONN whose column layout is the global ``nb_MZIs``."""
    prediction = onn(X, nb_MZIs, W)
    return prediction
    
# Loss to minimise: the circuit followed by its error metric
def circuit_to_opt(*args):
    """Mean squared error between the global target ``Y`` and ``circuit(*args)``."""
    residual = Y - circuit(*args)
    return np.mean(residual ** 2)

# Gradient of the loss with respect to its last positional argument (the weights W).
# argnums is given as a tuple, so the gradient comes back wrapped in a tuple:
# callers take element [0] of the result.
deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))

Training loop

In [40]:
lr=0.5      # learning rate of the plain gradient descent
n_steps=10  # number of gradient steps
print("First pred.:", circuit(X,W))
for step in range(n_steps):

    # forward phase
    print("current loss:", circuit_to_opt(X,W))

    # backward phase: one gradient array per column, in the same order as W
    dW=deriv_circuit_to_opt(X,W)[0]

    # Update using the gradient information.
    # The column index has its own name so it no longer shadows the step counter.
    for layer_id, dWi in enumerate(dW):
        W[layer_id] = W[layer_id] - lr * dWi
print("Final pred.:", circuit(X,W))

First pred.: [0.03837496 1.0827922  0.01405865 0.90878546]
current loss: 0.0042112293
current loss: 0.0031466344
current loss: 0.002370951
current loss: 0.0017973367
current loss: 0.0013688633
current loss: 0.0010464268
current loss: 0.000802356
current loss: 0.0006167479
current loss: 0.00047505228
current loss: 0.00036653606
Final pred.: [0.00749612 1.0225742  0.00562081 0.9768595 ]


Generate Python code of the backward Jax representation

In [44]:
import sys
import os
# Location of the JaxDecompiler checkout. Set the JAX_DECOMPILER_PATH environment
# variable to use another location; the historical path is kept as the default.
jax_decompiler_path = os.environ.get("JAX_DECOMPILER_PATH", "/home/pierrick/PycharmProjects/JaxDecompiler/")
if jax_decompiler_path not in sys.path:  # re-running the cell must not stack duplicates
    sys.path.append(jax_decompiler_path)
# Development helper (disabled): force a reload of the decompiler modules.
#for module_name in ["decompiler", "primitive_mapping"]:
#    os.remove(sys.modules[module_name].__cached__)  # remove cached bytecode
#    del sys.modules[module_name]
import decompiler # JaxDecompiler
# c holds the generated Python source printed below; df is presumably the
# matching callable -- TODO confirm against the JaxDecompiler API.
df, c= decompiler.python_jaxpr_python(deriv_circuit_to_opt, (X, W), is_python_returned=True)
print(c)

import jax
from jax.numpy import *
def f(b, c, d, e, f):
    a = Array([0, 1, 0, 1], dtype=int32)
    g = b[0:0+(1,)[0]] # dynamic slice
    h = squeeze(g)
    i = array(broadcast_to(h, (1,)))
    j = b[1:1+(2,)[0]] # dynamic slice
    k = f[0:0+(1,)[0]] # dynamic slice
    l = squeeze(k)
    m = cos(l)
    n = sin(l)
    o = sin(l)
    p = cos(l)
    q = -o
    r = sin(l)
    s = cos(l)
    t = cos(l)
    u = sin(l)
    v = array(broadcast_to(m, (1,)))
    w = array(broadcast_to(q, (1,)))
    x = concatenate((v, w), axis=0)
    y = array(broadcast_to(r, (1,)))
    z = array(broadcast_to(t, (1,)))
    ba = concatenate((y, z), axis=0)
    bb = array(broadcast_to(x, (1, 2)))
    bc = array(broadcast_to(ba, (1, 2)))
    bd = concatenate((bb, bc), axis=0)
    be = array(j).astype(float32)
    bf = dot(bd, be)
    bg = -1 + 4
    bh = array(bg).astype(int32)
    bi = array(broadcast_to(bh, (1,)))
    bj = squeeze( b[bi[0] if len(bi)>0 else 0:bi[0]+1] , axis=(0,))
    bk = array(broadcast_to

# MNIST classification

Read the raw dataset

In [19]:
from keras.datasets import mnist
import numpy as npo  # plain NumPy; the name `np` is bound to jax.numpy in this notebook
# Raw MNIST data as (inputs, labels) pairs for the train and the test splits
(train_X, train_y), (test_X, test_y) = mnist.load_data()

Preprocessing dataset (scaling, reshaping, projection, shuffling...)

In [20]:
# data processing: flatten every 28x28 image into a 784-vector and divide by 255
train_X=train_X.reshape((len(train_X),28*28))/255.
test_X=test_X.reshape((len(test_X),28*28))/255.

# projection onto the n_comp first principal components (fitted on the train split only)
n_comp=10
from sklearn.decomposition import PCA
proj = PCA(n_components = n_comp, random_state=random_seed)  # seeded for reproducibility
train_X = proj.fit_transform(train_X)
test_X=proj.transform(test_X)

# label processing into one-hot vector
# The vectors are n_comp wide because a column of MZIs returns as many values
# as it receives, so every label must fit within n_comp positions.
if max(int(train_y.max()), int(test_y.max())) >= n_comp:
    raise ValueError("n_comp must be greater than the largest class label")
train_y2=npo.zeros((len(train_y),n_comp),dtype=float)
test_y2=npo.zeros((len(test_y),n_comp), dtype=float)
train_y2[npo.arange(len(train_y)), train_y]=1.
test_y2[npo.arange(len(test_y)), test_y]=1.

# shuffling (seeded generator so that a re-run gives the same order)
rng=npo.random.default_rng(random_seed)
ids=rng.permutation(len(train_X))
train_X=train_X[ids]
train_y2=train_y2[ids]

ids=rng.permutation(len(test_X))
test_X=test_X[ids]
test_y2=test_y2[ids]

# Dimension check
print(train_X.shape)
print(train_y2.shape)
print(test_X.shape)
print(test_y2.shape)

(60000, 10)
(60000, 10)
(10000, 10)
(10000, 10)


Definition of the ONN (forward)

In [176]:
# Place holders with the input/output size of the circuit
dumb_X = np.zeros(n_comp)
dumb_Y = np.zeros(n_comp)

# Init weights: one vector per column of the mesh
nb_mzis = spec_mesh(10, 5)
W = [glorot_init(n) for n in nb_mzis]

# Jit-compiled forward pass of the 10-column mesh
@jax.jit
def onn10(X, W):
    """Forward pass of the mesh described by ``spec_mesh(10, 5)``."""
    # The layout is rebuilt inside the function rather than received as an
    # argument: passing it in would prevent onn10 from being compiled.
    mesh_layout = spec_mesh(10, 5)  # 10 columns of 5 or 4 MZIs
    return onn(X, mesh_layout, W)

circuit = onn10

Backward definition

In [21]:
# Loss to minimise: the circuit followed by its error metric
def circuit_to_opt(*args):
    """Mean squared error of the circuit; positional arguments are (X, Y, W)."""
    x_in, weights = args[0], args[2]
    y_pred = circuit(x_in, weights)
    y_expected = args[1]
    return np.mean((y_expected - y_pred) ** 2)

# Gradient of the loss with respect to its last positional argument (the weights W).
# argnums is given as a tuple, so the gradient comes back wrapped in a tuple:
# callers take element [0] of the result.
deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))

Training loop

In [None]:
n_eval=1000        # number of test samples used for each evaluation
n_train=60000      # number of training samples visited (one gradient step each)
eval_every=10000   # evaluate every `eval_every` training samples
lr=0.5             # initial learning rate, set here instead of inherited from an earlier cell
for ds_id in range(0, n_train): # for each data sample
    # backward phase: gradient of the loss w.r.t. W for this single sample
    dW=deriv_circuit_to_opt(train_X[ds_id], train_y2[ds_id], W)[0]

    # Update using the gradient information
    if isinstance(dW, list): # W is a list of arrays, so dW is one too
        for layer_id, dWi in enumerate(dW):
            W[layer_id] = W[layer_id] - lr * dWi
    else:
        W = W - lr * dW

    # EVALUATION
    if ds_id%eval_every==0 and ds_id>0:
        is_ok=0
        # dedicated index: the evaluation must not reuse the training index
        for eval_id in range(0, n_eval):
            y_pred=circuit(test_X[eval_id], W)
            is_ok+=np.argmax(y_pred)==np.argmax(test_y2[eval_id])
        print(f"accuracy:{float(is_ok)/n_eval}")
        lr/=2.  # halve the learning rate after every evaluation

accuracy:0.565
accuracy:0.604
accuracy:0.621
accuracy:0.624
accuracy:0.626
