# Optical Neural Network with numpy/jax

![alt text](mzi_mesh.jpg "Mesh type used")

## Trainable photonic circuit
MZI, column of MZI, mesh, ...

In [1]:
import jax
from jax import numpy as np

random_seed=1

def MZI(X, teta):
    R = np.array([
      [np.cos(teta), -np.sin(teta)],
      [np.sin(teta), np.cos(teta)]
    ])
    out_vector=np.dot(R, X)
    return out_vector

def MZI_col(X, nb_mzi, W):
    # Column type: odd or even ?
    nb_pins=nb_mzi*2
    if nb_pins==len(X):
        start_pin_id=0
    elif nb_pins+2==len(X):
        start_pin_id=1
    else:
        raise ValueError("This mesh patern is not compatible with this input size and #MZIs")

    # pin them
    layer_outputs=[]
    if start_pin_id==1:
        layer_outputs.append(np.array([X[0]]))
    
    for ID in range(0, nb_mzi):
        # take input vector
        first_pin_pos=2*ID+start_pin_id
        second_pin_pos=first_pin_pos+1
        local_inp = X[first_pin_pos:second_pin_pos+1]
        # compute the output vector
        local_out=MZI(local_inp, W[ID])
        layer_outputs.append(local_out)
    
    if start_pin_id==1:
        layer_outputs.append(np.array([X[-1]]))
    
    Y=np.concatenate(layer_outputs)
    return Y

def onn(X, nb_mzis, weights):
    nb_layers=len(weights)

    def recusive_layer_builder(id_layer=0):
        if id_layer==nb_layers-1: # last layer. No dependency
            input_shape=X.shape
            y=MZI_col(X, nb_mzis[id_layer], weights[id_layer])
        else:
            y = recusive_layer_builder(id_layer + 1)
            input_shape=y.shape
            y=MZI_col(y, nb_mzis[id_layer], weights[id_layer])
        return y

    Y=recusive_layer_builder()
    return Y

def spec_mesh(cols, mzi_per_col): #e.g. 6,6->3,2,3,2,3,2
    cols=n_comp
    mzi_per_col=n_comp//2
    nb_mzis=[]
    for i in range(cols):
        nb_mzis.append( mzi_per_col-i%2 )
    return nb_mzis

def glorot_init(nb_mzi):
    key = jax.random.PRNGKey(random_seed)  # random seed is explicit
    weights=jax.random.normal(shape=(nb_mzi,), key=key,dtype=np.float32) * np.sqrt(0.5)
    return weights

## Simple function learning: [0,1,0,1] -> [1,0,1,0]
Simple problem before solving harder problems

Configuration of the training dataset and ONN

In [45]:
X=np.array([1, 0, 1, 0])
Y=np.array([0, 1, 0, 1])
nb_MZIs=(2,1,)
W=[]
for n in nb_MZIs:
    W.append(glorot_init(n))

Backward

In [46]:
def circuit(X, W):
    return onn(X, nb_MZIs, W)
    
# Create the circuit with metric
def circuit_to_opt(*args):
    y_=circuit(*args)
    loss=np.mean((Y-y_)**2)
    return loss

deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))

Training loop

In [47]:
lr=0.5
print("First pred.:", circuit(X,W))
for i in range(10):
    
    # forward phase
    print("current loss:", circuit_to_opt(X,W))

    # backward phase
    dW=deriv_circuit_to_opt(X,W)[0]

    # Update using the gradient information
    for i, dWi in enumerate(dW):
        W[i] = W[i] - lr * dWi
print("Final pred.:", circuit(X,W))

First pred.: [1.0575861  0.6583579  0.00422926 0.6693718 ]
current loss: 0.33613512
current loss: 0.20870842
current loss: 0.12832996
current loss: 0.08206382
current loss: 0.055986747
current loss: 0.040702645
current loss: 0.031109665
current loss: 0.024656296
current loss: 0.020063689
current loss: 0.016656145
Final pred.: [9.6791387e-02 1.1383200e+00 6.7896367e-04 8.3358181e-01]


Generate Python code of the backward Jax representation

In [48]:
import sys
import os
sys.path.append("/home/pierrick/PycharmProjects/JaxDecompiler/")
#for module_name in ["decompiler", "primitive_mapping"]:
#    os.remove(sys.modules[module_name].__cached__)  # remove cached bytecode
#    del sys.modules[module_name]
import decompiler # JaxDecompiler
df, c= decompiler.python_jaxpr_python(deriv_circuit_to_opt, (X, W), is_python_returned=True)
print(c)

import jax
from jax.numpy import *
def f(b, c, d):
    a = Array([0, 1, 0, 1], dtype=int32)
    e = b[0:0+(1,)[0]] # dynamic slice
    f = squeeze(e)
    g = array(broadcast_to(f, (1,)))
    h = b[1:1+(2,)[0]] # dynamic slice
    i = d[0:0+(1,)[0]] # dynamic slice
    j = squeeze(i)
    k = cos(j)
    l = sin(j)
    m = sin(j)
    n = cos(j)
    o = -m
    p = sin(j)
    q = cos(j)
    r = cos(j)
    s = sin(j)
    t = array(broadcast_to(k, (1,)))
    u = array(broadcast_to(o, (1,)))
    v = concatenate((t, u), axis=0)
    w = array(broadcast_to(p, (1,)))
    x = array(broadcast_to(r, (1,)))
    y = concatenate((w, x), axis=0)
    z = array(broadcast_to(v, (1, 2)))
    ba = array(broadcast_to(y, (1, 2)))
    bb = concatenate((z, ba), axis=0)
    bc = array(h).astype(float32)
    bd = dot(bb, bc)
    be = -1 + 4
    bf = array(be).astype(int32)
    bg = array(broadcast_to(bf, (1,)))
    bh = squeeze( b[bg[0] if len(bg)>0 else 0:bg[0]+1] , axis=(0,))
    bi = array(broadcast_to(bh, (1,))

# MNIST classification

Read the raw dataset

In [82]:
from keras.datasets import mnist
import numpy as npo
(train_X, train_y), (test_X, test_y) = mnist.load_data()

Preprocessing dataset (croping, interpolating, intensity scaling, reshaping, projection, shuffling...)

In [83]:
# Cropping
train_X=train_X[:,4:24,4:24]
test_X=test_X[:,4:24,4:24]

# Interpolating
from scipy.ndimage import zoom
train_X= zoom(train_X,(1.,.5,.5),order=3)  #order = 3 for cubic interpolation
test_X= zoom(test_X,(1.,.5,.5),order=3)

# intensity scaling and flatting
train_X=train_X.reshape((len(train_X),10*10))/255.
test_X=test_X.reshape((len(test_X),10*10))/255.

# projection
n_comp=10 # variance explained is only 52%
from sklearn.decomposition import PCA
proj = PCA(n_components = n_comp)
train_X = proj.fit_transform(train_X)
test_X=proj.transform(test_X)
print(f"PCA variance explained: {sum(proj.explained_variance_ratio_)}")

# label processing into one-hot vector
train_y2=npo.zeros((len(train_X),n_comp),dtype=float)
test_y2=npo.zeros((len(test_X),n_comp), dtype=float)
for i,v in enumerate(train_y):
    train_y2[i][v]=1.

for i,v in enumerate(test_y):
    test_y2[i][v]=1.

# shuffling
ids=npo.array(range(len(train_X)))
npo.random.shuffle(ids)
train_X=train_X[ids]
train_y2=train_y2[ids]

ids=npo.array(range(len(test_X)))
npo.random.shuffle(ids)
test_X=test_X[ids]
test_y2=test_y2[ids]

# Dimension check
print(train_X.shape)
print(train_y2.shape)
print(test_X.shape)
print(test_y2.shape)

PCA variance explained: 0.5216947862659245
(60000, 10)
(60000, 10)
(10000, 10)
(10000, 10)


Definition of the ONN (forward)

In [116]:
# Init weights
nb_mzis=spec_mesh(10, 5)
W=[]
for n in nb_mzis:
    W.append(glorot_init(n))

# compilation of the onn
def onn10(X, W):
    return onn(X, nb_mzis, W)
circuit=jax.jit(onn10) # JIT circuit is ~3300 times faster!

Backward definition

In [117]:
# Create the circuit with metric
def circuit_to_opt(*args):
    y_pred=circuit(*(args[0], args[2])) #0:X, 2:W
    y_expected=args[1] #1:Y
    loss=np.mean((y_expected-y_pred)**2)
    return loss

deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))
deriv_circuit_to_opt=jax.jit(deriv_circuit_to_opt) # JIT circuit is ~3300 times faster!

Training loop

In [118]:
random_seed=1
lr=0.1
for e in range(5): # for each epoch
    
    # Data shuffling
    ids=npo.array(range(len(train_X)))
    npo.random.shuffle(ids)
    train_X=train_X[ids]
    train_y2=train_y2[ids]
    
    # Training
    for X,Y in zip(train_X, train_y2): # for each data sample
        # backward phase
        dW=deriv_circuit_to_opt(X, Y, W)[0]

        # Update using the gradient information
        for i, dWi in enumerate(dW):
            W[i] = W[i] - lr * dWi

    # Evaluation
    nb_correct=0
    for X,Y in zip(test_X, test_y2):        
        y_pred=circuit(X, W)
        nb_correct+=np.argmax(y_pred)==np.argmax(Y)
    print(f"accuracy:{float(nb_correct)/len(test_y2)}")
    lr/=10.

accuracy:0.6496
accuracy:0.6677
accuracy:0.6686
accuracy:0.669
accuracy:0.6691


## Ensemble of ONN

In [113]:
# BUILDING A SECOND MNIST CLASSIFIER

nb_mzis=spec_mesh(10, 5)
lr=0.1

Ensemble_W=[]
for ens in range(5):
    random_seed=ens
    
    W2=[]
    for n in nb_mzis:
        W2.append(glorot_init(n))
    
    # Data shuffling
    ids=npo.array(range(len(train_X)))
    npo.random.shuffle(ids)
    train_X=train_X[ids]
    train_y2=train_y2[ids]
    
    #training
    for X,Y in zip(train_X, train_y2): # for each data sample
        # backward phase
        dW=deriv_circuit_to_opt(X, Y, W2)[0]

        # Update using the gradient information
        for i, dWi in enumerate(dW):
            W2[i] = W2[i] - lr * dWi

    # Evaluation
    nb_correct=0
    for X,Y in zip(test_X, test_y2):        
        y_pred=circuit(X, W2)
        nb_correct+=np.argmax(y_pred)==np.argmax(Y)
    print(f"Accuracy of the model {ens}: {float(nb_correct)/len(test_y2)}")

    Ensemble_W.append(W2)
    
# Ensemble evaluation
nb_correct=0
for X,Y in zip(test_X, test_y2):        
    y_pred=np.average(np.array([circuit(X, Wi) for Wi in Ensemble_W]),axis=0)
    nb_correct+=np.argmax(y_pred)==np.argmax(Y)
print(f"Ensemble accuracy: {float(nb_correct)/len(test_y2)}")

Accuracy of the model 0: 0.6672
Accuracy of the model 1: 0.6527
Accuracy of the model 2: 0.6741
Accuracy of the model 3: 0.6434
Accuracy of the model 4: 0.6703
Ensemble accuracy: 0.6756
