# Optical Neural Network with numpy/jax

Simple example Y=WX
![alt text](mzi_mesh.jpg "Mesh type used")

## Trainable photonic circuit


###  Noise function according to AnalogVNN formula
source: https://arxiv.org/pdf/2210.10048.pdf

![alt text](analogVNN.png "AnalogVNN big picture figure")

In [89]:
import jax
from jax import numpy as np

random_seed=1
noisy=True

def get_key():
    global random_seed
    random_seed+=1
    return jax.random.PRNGKey(random_seed)

def no_back(f): 
    """ Decorator to avoid backpropagation of the decorated function.
    For example it is useful for "round(x)" """
    def decorated_f(x, *args):
        # Create an exactly-zero expression with Sterbenz lemma that has
        # an exactly-one gradient.
        # URL : https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html
        zero = x - jax.lax.stop_gradient(x)
        return zero + jax.lax.stop_gradient(f(x, *args))
    return decorated_f

# Previous function implement noise AnalogVNN: https://arxiv.org/pdf/2210.10048.pdf
def _rounding_with_thresh(g, r):
    g_abs = np.abs(g)
    g_floor = np.floor(g_abs)
    g_ceil = np.ceil(g_abs)
    prob_floor = 1. - np.abs(g_floor - g)
    do_floor = np.array( r <= prob_floor, dtype=np.float32)
    do_ceil = np.array( r > prob_floor, dtype=np.float32)
    return do_floor * g_floor + do_ceil * g_ceil

@no_back
def precion_reduction(x, p):
    """warning precision=4 means 5 potential value:  {0,0.25,0.5,0.75,1}
    substracting by 1 before calling it is maybe always required"""
    r=0.5
    g = x * p
    f = np.sign(g) * _rounding_with_thresh(g, r) * (1. / p)
    return f

@no_back
def stochastic_reduce_precision(x, p):
    g = x * p
    r=jax.random.uniform(shape=x.shape, key=get_key(), dtype=np.float32)
    f = np.sign(g) * _rounding_with_thresh(g, r) * (1. / p)
    return f

def normalization(x):
    return np.clip(x,-1.,+1.)

@no_back
def additive_noise(x, std):
    noise=jax.random.normal(shape=x.shape, key=get_key(),dtype=np.float32) * std
    return x+noise


### Circuits: MZI, column of MZI, mesh, ...

In [90]:
def MZI(X, teta):
    R = np.array([
      [np.cos(teta), -np.sin(teta)],
      [np.sin(teta), np.cos(teta)]
    ])
    out_vector=np.dot(R, X)
    return out_vector

def noisy_MZI(X, teta):
    p_signal=2.**4
    p_weights=2.**4
    noise_signal=1e-3
    noise_weights=1e-3
    
    X=additive_noise( precion_reduction( normalization(X) , p_signal ) , noise_signal)
    teta=additive_noise( stochastic_reduce_precision( normalization(teta) , p_weights), noise_weights)
    
    y=MZI(X, teta)

    y=precion_reduction( normalization( additive_noise(y, noise_signal ) ) , p_signal)

    return y

def MZI_col(X, nb_mzi, W):
    
    if noisy:
        MZI_strat=noisy_MZI
    else:
        MZI_strat=MZI
    
    # Column type: odd or even ?
    nb_pins=nb_mzi*2
    if nb_pins==len(X):
        start_pin_id=0
    elif nb_pins+2==len(X):
        start_pin_id=1
    else:
        raise ValueError("This mesh patern is not compatible with this input size and #MZIs")

    # pin them
    layer_outputs=[]
    if start_pin_id==1:
        layer_outputs.append(np.array([X[0]]))
    
    for ID in range(0, nb_mzi):
        # take input vector
        first_pin_pos=2*ID+start_pin_id
        second_pin_pos=first_pin_pos+1
        local_inp = X[first_pin_pos:second_pin_pos+1]
        
        # compute the output vector
        local_out=MZI_strat(local_inp, W[ID])
        layer_outputs.append(local_out)
    
    if start_pin_id==1:
        layer_outputs.append(np.array([X[-1]]))
    
    Y=np.concatenate(layer_outputs)
    return Y

def onn(X, nb_mzis, weights):
    nb_layers=len(weights)

    def recusive_layer_builder(id_layer=0):
        if id_layer==nb_layers-1: # last layer. No dependency
            input_shape=X.shape
            y=MZI_col(X, nb_mzis[id_layer], weights[id_layer])
        else:
            y = recusive_layer_builder(id_layer + 1)
            input_shape=y.shape
            y=MZI_col(y, nb_mzis[id_layer], weights[id_layer])
        return y

    Y=recusive_layer_builder()
    return Y

def spec_mesh(cols, mzi_per_col): #e.g. 6,3->3,2,3,2,3,2
    cols=n_comp
    mzi_per_col=n_comp//2
    nb_mzis=[]
    for i in range(cols):
        nb_mzis.append( mzi_per_col-i%2 )
    return nb_mzis

def glorot_init(nb_mzi):
    weights=jax.random.normal(shape=(nb_mzi,), key=get_key(),dtype=np.float32) * np.sqrt(0.5)
    return weights

## Simple function learning: [0,1,] -> [1,0]
Simple problem before solving harder problems

#### Configuration of the training dataset and ONN

In [91]:
X=np.array([1, 0])
Y=np.array([0, 1])
nb_MZIs=(1,)
W=[]
for n in nb_MZIs:
    W.append(glorot_init(n))

#### Backward

In [92]:
def circuit(X, W):
    return onn(X, nb_MZIs, W)
    
# Create the circuit with metric
def circuit_to_opt(*args):
    y_=circuit(*args)
    loss=np.mean((Y-y_)**2)
    return loss

deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))

#### Training loop

In [93]:
lr=0.5
print("First pred.:", circuit(X,W))
for i in range(5):
    
    # forward phase
    print("current loss:", circuit_to_opt(X,W))

    # backward phase
    dW=deriv_circuit_to_opt(X,W)[0]

    # Update using the gradient information
    for i, dWi in enumerate(dW):
        W[i] = W[i] - lr * dWi
print("Final pred.:", circuit(X,W))

First pred.: [ 1.     -0.1875]
current loss: 1.2050781
current loss: 0.6347656
current loss: 0.26757812
current loss: 0.17578125
current loss: 0.17578125
Final pred.: [0.5625 0.8125]


## On-chip learning

#### Information about on-chip learning

* Stochastic Gradient Descent optimization (code below):
    * Pros:
        * Speed and Scalability when the dimensionality (#params) increase
    * Cons: 
        * Above code need to be embedded on-chip
        * Noisy gradient (E.g. MZI noise) -> catastrophic performance (E.g. > 0.001)
* Other optimizer exists:
    * Example:
        * Forward gradient descent
        * Simulated annealing
        * ...
    * Pros:
        * Simpler to implement (no backpropagation)
    * Cons: 
        * They do not scale well when the dimensionality increase

In [94]:
import sys
import os
sys.path.append("/home/pierrick/PycharmProjects/JaxDecompiler/") # <-- Your path here
for module_name in ["decompiler", "primitive_mapping"]:
    if module_name in sys.modules:
        os.remove(sys.modules[module_name].__cached__)  # remove cached bytecode
        del sys.modules[module_name]
import decompiler # JaxDecompiler
df, c= decompiler.python_jaxpr_python(deriv_circuit_to_opt, (X, W), is_python_returned=True)
print("\n".join(c.split("\n")[:20])) # print the 20 first lines

import jax
from jax.numpy import *
from jax._src import prng
def f(b, c):
    a = array([0, 1], dtype=int32)
    d = jax.lax.dynamic_slice_in_dim(c, 0, (1,)[0], axis=0)
    e = squeeze(array(d))
    def local_f0(a, b, c):
        d = array(a).astype(float32)
        e = array([max(b)])
        f = array([min(c)])
        return f
    f = local_f0(b, -1.0, 1.0)
    g = f # stop grad
    h = f - g
    i = f * 16.0
    def local_f1(a):
        b = sign(a)
        return b
    j = local_f1(i)


## MNIST classification

Commonly used dataset

#### Read the raw dataset

In [95]:
from keras.datasets import mnist
import numpy as npo
(train_X, train_y), (test_X, test_y) = mnist.load_data()

#### Preprocessing dataset 
Croping, interpolating, intensity scaling, reshaping, projection, shuffling...

In [96]:
# Cropping
train_X=train_X[:,4:24,4:24]
test_X=test_X[:,4:24,4:24]

# Interpolating
from scipy.ndimage import zoom
train_X= zoom(train_X,(1.,.5,.5),order=3)  #order = 3 for cubic interpolation
test_X= zoom(test_X,(1.,.5,.5),order=3)

# intensity scaling and flatting
train_X=train_X.reshape((len(train_X),10*10))/255.
test_X=test_X.reshape((len(test_X),10*10))/255.

# projection
n_comp=10 # variance explained is only 52%
from sklearn.decomposition import PCA
proj = PCA(n_components = n_comp)
train_X = proj.fit_transform(train_X)
test_X=proj.transform(test_X)
print(f"PCA variance explained: {sum(proj.explained_variance_ratio_)}")

# label processing into one-hot vector
train_y2=npo.zeros((len(train_X),n_comp),dtype=float)
test_y2=npo.zeros((len(test_X),n_comp), dtype=float)
for i,v in enumerate(train_y):
    train_y2[i][v]=1.

for i,v in enumerate(test_y):
    test_y2[i][v]=1.

# shuffling
ids=npo.array(range(len(train_X)))
npo.random.shuffle(ids)
train_X=train_X[ids]
train_y2=train_y2[ids]

ids=npo.array(range(len(test_X)))
npo.random.shuffle(ids)
test_X=test_X[ids]
test_y2=test_y2[ids]

# Dimension check
print(train_X.shape)
print(train_y2.shape)
print(test_X.shape)
print(test_y2.shape)

PCA variance explained: 0.521593761795937
(60000, 10)
(60000, 10)
(10000, 10)
(10000, 10)


#### Definition of the ONN (forward)

In [97]:
# Init weights
nb_mzis=spec_mesh(10, 5)
W=[]
for n in nb_mzis:
    W.append(glorot_init(n))

# compilation of the onn
def onn10(X, W):
    return onn(X, nb_mzis, W)
circuit=jax.jit(onn10) # JIT circuit is ~3300 times faster!

#### Backward definition

In [98]:
# Create the circuit with metric
def circuit_to_opt(*args):
    y_pred=circuit(*(args[0], args[2])) #0:X, 2:W
    y_expected=args[1] #1:Y
    loss=np.mean((y_expected-y_pred)**2)
    return loss

deriv_circuit_to_opt=jax.grad(circuit_to_opt, argnums=(-1,))
deriv_circuit_to_opt=jax.jit(deriv_circuit_to_opt) # JIT circuit is ~3300 times faster!

#### Training loop

In [99]:
random_seed=1
lr=0.1

for e in range(5): # for each epoch
    
    # Data shuffling
    ids=npo.array(range(len(train_X)))
    npo.random.shuffle(ids)
    train_X=train_X[ids]
    train_y2=train_y2[ids]
    
    # Training
    for X,Y in zip(train_X, train_y2): # for each data sample
        # backward phase
        dW=deriv_circuit_to_opt(X, Y, W)[0]

        # Update using the gradient information
        for i, dWi in enumerate(dW):
            W[i] = W[i] - lr * dWi

    # Evaluation
    nb_correct=0
    for X,Y in zip(test_X, test_y2):        
        y_pred=circuit(X, W)
        nb_correct+=np.argmax(y_pred)==np.argmax(Y)
    print(f"accuracy:{float(nb_correct)/len(test_y2)}")
    lr/=10.

accuracy:0.3709
accuracy:0.3736
accuracy:0.3752
accuracy:0.371
accuracy:0.3676


#### Ensemble of ONN

In [None]:
nb_mzis=spec_mesh(10, 5)
lr=0.1

Ensemble_W=[]
for ens in range(5):
    
    W2=[]
    for n in nb_mzis:
        W2.append(glorot_init(n))
    
    # Data shuffling
    ids=npo.array(range(len(train_X)))
    npo.random.shuffle(ids)
    train_X=train_X[ids]
    train_y2=train_y2[ids]
    
    #training
    for X,Y in zip(train_X, train_y2): # for each data sample
        # backward phase
        dW=deriv_circuit_to_opt(X, Y, W2)[0]

        # Update using the gradient information
        for i, dWi in enumerate(dW):
            W2[i] = W2[i] - lr * dWi

    # Evaluation
    nb_correct=0
    for X,Y in zip(test_X, test_y2):        
        y_pred=circuit(X, W2)
        nb_correct+=np.argmax(y_pred)==np.argmax(Y)
    print(f"Accuracy of the model {ens}: {float(nb_correct)/len(test_y2)}")

    Ensemble_W.append(W2)
    
# Ensemble evaluation
nb_correct=0
for X,Y in zip(test_X, test_y2):        
    y_pred=np.average(np.array([circuit(X, Wi) for Wi in Ensemble_W]),axis=0)
    nb_correct+=np.argmax(y_pred)==np.argmax(Y)
print(f"Ensemble accuracy: {float(nb_correct)/len(test_y2)}")

Accuracy of the model 0: 0.4347
Accuracy of the model 1: 0.3755
Accuracy of the model 2: 0.4251
Accuracy of the model 3: 0.4344
Accuracy of the model 4: 0.4163


Conclusion about the ensemble (with no noise):

    * The ensemble is better than base ONN in it
    * An ensemble of 5 base ONNs trained 1 epoch (67.56%) > 1 ONN trained 5 epochs (66.91%)  -> Better usage of computing ressources.