# From ANN to ONN





Given a keras file (neural network) and calibration data, build and train the photonic equivalent.

## Read calibration data

In [1]:
import numpy as np
import keras

# synthetic data
#calib_data=np.random.uniform(-1., +1., (100, 256))

# realistic data
N = 8
nb_data = 100
from ANN import get_db
(train_X, train_Y), (test_X, test_Y) = get_db("MNIST", N, shuffle=True)
calib_data=train_X[:nb_data]


print(calib_data.shape)
print(np.argmax(train_Y[:nb_data],axis=1))

(60000, 64)
(60000, 10)
(10000, 64)
(10000, 10)
(100, 64)
[0 6 2 4 1 5 9 7 2 1 2 9 8 7 1 8 9 5 9 6 8 7 3 0 8 9 9 1 5 2 7 1 5 5 5 9 4
 9 1 0 4 6 0 4 8 8 9 6 7 4 7 7 0 9 5 5 7 9 5 3 5 5 1 2 7 1 3 0 9 7 1 7 9 2
 1 7 9 5 1 8 7 8 1 8 4 1 7 1 4 2 2 6 8 2 8 1 9 3 8 5]


## Read weights

In [3]:
keras_path = "tmp/MNIST/model_0/"
ann = keras.models.load_model(keras_path)

W1 = ann.layers[1].get_weights()[0]
W2 = ann.layers[3].get_weights()[0]

# free memory
keras.backend.clear_session()
del ann

## Teacher (ANN) Student (ONN) settings using SVD decomposition

In [18]:
from TeacherStudent import TeacherMatrix, TeacherStudent, TeacherStudent_SVD
import ONN

LR=0.1
LR_DECAY=10.
EPOCHS=1
DEBUG=True

M,N=W1.shape
col_M=min(M,4) if DEBUG else M # 4 MZI column for debugging
col_N=min(N,4) if DEBUG else N
hp_v = {"lr": LR, "lr_decay": LR_DECAY, "layers": [M], "pattern": ["rectangle"], "col_layer_limit": [col_M]}
hp_u = {"lr": LR, "lr_decay": LR_DECAY, "layers": [N], "pattern": ["rectangle"],
        "col_layer_limit": [col_N]}
student_v_W1 = ONN.ONN(hp_v, {}, {"epochs": EPOCHS, "loss": ONN.clipped_MSE, "metrics": ONN.MSE})
student_u_W1 = ONN.ONN(hp_u, {}, {"epochs": EPOCHS, "loss": ONN.clipped_MSE, "metrics": ONN.MSE})
teacher_student_W1 = TeacherStudent_SVD(W1, student_u_W1, student_v_W1 )


M,N=W2.shape
col_M=min(M,4) if DEBUG else M
col_N=min(N,4) if DEBUG else N
hp_v = {"lr": LR, "lr_decay": LR_DECAY, "layers": [M], "pattern": ["rectangle"], "col_layer_limit": [col_M]}
hp_u = {"lr": LR, "lr_decay": LR_DECAY, "layers": [N], "pattern": ["rectangle"], "col_layer_limit": [col_N]}

student_v_W2 = ONN.ONN(hp_v, {}, {"epochs": EPOCHS, "loss": ONN.clipped_MSE, "metrics": ONN.MSE})
student_u_W2 = ONN.ONN(hp_u, {}, {"epochs": EPOCHS, "loss": ONN.clipped_MSE, "metrics": ONN.MSE})
teacher_student_W2 = TeacherStudent_SVD(W2, student_u_W2, student_v_W2 )

## Teacher-Student training with LR scheduling

In [19]:
# Teacher student training for W1
teacher_student_W1.fit(calib_data)

# Training for W2
A=teacher_student_W1.predict(calib_data) # Notice: Student W2 is trained on W1 prediction
A=np.maximum(A,0)
score=teacher_student_W2.fit(A)

64
1.335458
1.0525422
0.90156543
0.817489
0.76182693
0.7213937
0.6917838
0.67041993
0.65565836
0.6450027
0.6371933
0.63150054
0.6280534
0.62430245
0.6212123
0.6181451
0.6146537
0.6115196
0.60887533
0.6072415
0.60656375
0.6042567
0.6012192
0.59984183
0.5982474
0.59715146
0.5946223
0.59360754
0.5921987
0.5915034
0.58991086
0.5883945
0.5883978
New LR:  0.01
0.58648217
0.5840174
0.58243173
0.582582
New LR:  0.001
0.5818692
0.58201694
New LR:  0.0001
0.5809755
0.58082527
0.57930344
0.5793025
New LR:  1e-05
Stopping
Training score:  0.6034349
Validation score:  0.5793025
Testing score:  0.60408986
Time:  8.72546124458313
64
1.315587
1.0195557
0.86459625
0.7891711
0.74582356
0.7171211
0.6964451
0.6809429
0.6670236
0.6587871
0.6490838
0.64226496
0.635923
0.6289617
0.6223427
0.6183684
0.6114818
0.6062661
0.6018616
0.59801584
0.5930597
0.5894712
0.585773
0.5797347
0.5741379
0.5697549
0.5659505
0.560411
0.55621356
0.5510205
0.54732525
0.5437691
0.5406737
0.5376903
0.53511035
0.5321638
0.5313734
0

## ONN prediction

In [21]:
A=teacher_student_W1.predict(A)
A=np.maximum(A,0) #<- relu
Y2=teacher_student_W2.predict(A)

## Evaluate the prediction

In [22]:
score=np.mean(   np.argmax(Y2,axis=1) == np.argmax(train_Y[:nb_data],axis=1)   )
print("ACCURACY : ", score) # NOTICE: when DEBUG is enabled, this is normal the accuracy is not good!

ACCURACY :  0.05


## Models interpretation

In [27]:
np.histogram(W1)

(array([   2,    4,   12,   74,  387, 1792, 1495,  280,   45,    5]),
 array([-1.5707824 , -1.304939  , -1.0390958 , -0.7732524 , -0.5074091 ,
        -0.24156576,  0.02427757,  0.2901209 ,  0.55596423,  0.82180756,
         1.0876509 ], dtype=float32))

In [28]:
np.histogram(W2)

(array([  4,   2,  26,  51,  77, 164, 153, 126,  28,   9]),
 array([-1.7247293 , -1.4425162 , -1.1603031 , -0.87808996, -0.5958769 ,
        -0.31366378, -0.03145068,  0.25076243,  0.53297555,  0.81518865,
         1.0974017 ], dtype=float32))

In [47]:
print( np.histogram(np.concatenate(teacher_student_W1.student_v.W[0], axis=0)) )
print( np.histogram( teacher_student_W1.s) )
print( np.histogram(np.concatenate(teacher_student_W1.student_u.W[0], axis=0)) )

(array([ 5,  8, 18, 26, 28, 16, 19,  3,  2,  1]), array([-2.2346728 , -1.7132144 , -1.191756  , -0.6702977 , -0.14883932,
        0.37261903,  0.8940774 ,  1.4155358 ,  1.9369941 ,  2.4584525 ,
        2.9799109 ], dtype=float32))
(array([15, 13,  9,  7,  6,  3,  4,  1,  3,  3]), array([0.0079032 , 0.45109016, 0.8942771 , 1.337464  , 1.780651  ,
       2.2238379 , 2.6670249 , 3.1102118 , 3.5533986 , 3.9965856 ,
       4.4397726 ], dtype=float32))
(array([ 9, 10, 11, 17, 22, 24, 23,  5,  4,  1]), array([-2.3940642 , -1.8652995 , -1.3365347 , -0.8077701 , -0.27900538,
        0.24975932,  0.77852404,  1.3072888 ,  1.8360534 ,  2.364818  ,
        2.8935828 ], dtype=float32))


In [48]:
print( np.histogram(np.concatenate(teacher_student_W2.student_v.W[0], axis=0)) )
print( np.histogram( teacher_student_W2.s) )
print( np.histogram(np.concatenate(teacher_student_W2.student_u.W[0], axis=0)) )

(array([10, 21, 31, 21, 22, 13,  5,  2,  0,  1]), array([-1.2621078 , -0.85013866, -0.43816942, -0.02620022,  0.38576898,
        0.7977382 ,  1.2097074 ,  1.6216766 ,  2.0336459 ,  2.445615  ,
        2.8575842 ], dtype=float32))
(array([3, 1, 1, 1, 0, 1, 1, 0, 1, 1]), array([2.235886 , 2.540027 , 2.844168 , 3.1483088, 3.4524498, 3.7565906,
       4.0607314, 4.3648725, 4.6690135, 4.973154 , 5.277295 ],
      dtype=float32))
(array([3, 1, 1, 3, 1, 1, 2, 3, 1, 2]), array([-1.7171109 , -1.3682885 , -1.0194662 , -0.67064387, -0.3218215 ,
        0.02700084,  0.3758232 ,  0.72464556,  1.0734679 ,  1.4222902 ,
        1.7711126 ], dtype=float32))
