# From NN to ONN, different strategies comparison








Given Y=WX a standard trained NN layer, with Y the prediction, X input data, and W ALREADY TRAINED weights.

Convert the weights W into a sequence of 2x2 rotation matrices (describing MZI) and scaling vector (photonic attenuator).


**Strategy 1 simple 2x2 decomposition**

**Strategy 2 Teacher-Student**

[3] https://arxiv.org/pdf/1503.02531.pdf Knowledge Distilling

**Strategy 3 Teacher-Student with prior decompotion**

**Strategy 4 TT-decompotion**

[1] https://aip.scitation.org/doi/10.1063/5.0070913# Inspiration for designing the strategy 1

[2] https://github.com/Bihaqo/t3f/ T3F framework for TensorTrain arithmetics based on Tensorflow. I implement with numpy.

In [24]:
import numpy as np
np.random.seed(0)
N=4
weights = np.random.uniform(-1., +1., (N, N))
X=np.random.uniform(-1., +1., (N, 1))

Y=np.dot(weights, X)
print("X")
print(X)
print("weights")
print(weights)
print("Y")
print(Y)

X
[[ 0.10438494]
 [ 0.16895214]
 [ 0.92387276]
 [-0.41570495]
 [-0.51834244]
 [-0.79941212]
 [-0.96714074]
 [ 0.85905863]
 [ 0.33983309]
 [ 0.57030582]
 [-0.43653979]
 [ 0.17282033]
 [-0.87208947]
 [-0.02874481]
 [ 0.95499028]
 [ 0.75301049]]
weights
[[ 0.09762701  0.43037873  0.20552675  0.08976637 -0.1526904   0.29178823
  -0.12482558  0.783546    0.92732552 -0.23311696  0.58345008  0.05778984
   0.13608912  0.85119328 -0.85792788 -0.8257414 ]
 [-0.95956321  0.66523969  0.5563135   0.7400243   0.95723668  0.59831713
  -0.07704128  0.56105835 -0.76345115  0.27984204 -0.71329343  0.88933783
   0.04369664 -0.17067612 -0.47088878  0.54846738]
 [-0.08769934  0.1368679  -0.9624204   0.23527099  0.22419145  0.23386799
   0.88749616  0.3636406  -0.2809842  -0.12593609  0.39526239 -0.87954906
   0.33353343  0.34127574 -0.57923488 -0.7421474 ]
 [-0.3691433  -0.27257846  0.14039354 -0.12279697  0.97674768 -0.79591038
  -0.58224649 -0.67738096  0.30621665 -0.49341679 -0.06737845 -0.51114882
  -0

# Strategy 1: rank-2 TensorTrain decomposition

Illustration of the strategy:

![alt text](SVD.png "svd decomposition")



### Util functions:
* from_arr_to_tt : Converts np.ndarray into TensorTrain
* from_tt_to_arr : Converts TensorTrain format into np.ndarray
* tt_dot : Dot product between TensorTrain tensors

## Strategy 1: Tiled 2x2 matrices decomposition

Steps:
* Split NxN matrix into 2x2 tiles
* Using eiven values/vectors decomposition to compute 2x2 rotation matrices 2d scaling vectors
* Using those matrices/vectors, to compute phase shift in MZIs/attenuators based on arccos/arcsin

## Utils
* from_arr_to_tiles
* from_tiles_to_arr
* dot_tile_vec

In [79]:
def from_arr_to_tiles(matrix):
    n = matrix.shape[0]
    if n % 2 != 0:
        raise ValueError("Matrix must be of even size")

    block_matrices = []
    for i in range(0, n, 2):
        for j in range(0, n, 2):
            block = matrix[i:i+2, j:j+2]
            block_matrices.append(block)
    return block_matrices

def from_tiles_to_arr(block_matrices):
    n = len(block_matrices)
    size = int(np.sqrt(n))
    if size ** 2 != n:
        raise ValueError("Number of block matrices must be a perfect square")

    C = np.zeros((size*2, size*2))
    for i in range(size):
        for j in range(size):
            C[i*2:i*2+2, j*2:j*2+2] = block_matrices[i*size+j]
    return C

def dot_tile_vec(block_matrices, x):
    """tiled matrix and normal vector multiplication """
    n = len(block_matrices)
    size = int(np.sqrt(n))
    if size ** 2 != n:
        raise ValueError("Number of block matrices must be a perfect square")
    if x.shape[0] != size*2:
        raise ValueError("Input vector must have the same size as the matrix")
    y = x.copy()
    result = np.zeros(y.shape)
    for i in range(size):
        for j in range(size):
            b1=block_matrices[i*size+j]
            b2=y[j*2:j*2+2]
            result[i*2:i*2+2] += np.dot(b1, b2) #<----- 2x2 dot between matrix and vector
    return result

def tiled_prediction(tiled_mat_a, tiled_vec_b, vec_x):
    """tiled computation between: (a.x)+b. 
    With a tiled matrix format and b a tiled vector format. x is a standard vector. """
    n = len(tiled_mat_a)
    size = int(np.sqrt(n))
    y = vec_x.copy()
    result = np.zeros(y.shape)
    for i in range(size):
        for j in range(size):
            b1=tiled_mat_a[i * size + j] # <---- 2x2 matrix
            b2=y[j*2:j*2+2] # <------ 1x2 numpy array
            b3=np.array([tiled_vec_b[i * size + j]]) # <------- 1x2 numpy array
            result[i*2:i*2+2] += np.dot(b1, b2) + np.diag(b3) #<----- The core computing is here
    return result


def from_random_to_rotation(x):
    eigenvalues, eigenvectors = np.linalg.eig(x)
    
    # rotation matrix are not unique. Different runs may produces different rotation. However, the dot product of them is unique.
    rot0=[[eigenvectors[0][0], -eigenvectors[1][0]],
          [eigenvectors[1][0], eigenvectors[0][0]]]
    rot1=[[eigenvectors[0][1], -eigenvectors[1][1]],
          [eigenvectors[1][1], eigenvectors[0][1]]]
    rot=np.dot(rot0,rot1)
    
    make_real=lambda x: x.real
    rot[0][0]=make_real(rot[0][0])
    rot[1][0]=make_real(rot[1][0])
    rot[0][1]=make_real(rot[0][1])
    rot[1][1]=make_real(rot[1][1])
    
    eigenvalues[0]=make_real(eigenvalues[0])
    eigenvalues[1]=make_real(eigenvalues[1])
    
    return rot.astype(np.float32), eigenvalues.astype(np.float32)

def check_and_validate_2x2_rot_mat(x):
    epsilon=1e-2
    if x.shape != (2, 2):
        raise ValueError(f"Matrix {x} is not 2x2")
    det = np.linalg.det(x)
    if abs(det - 1) > epsilon: #The determinant of a rotation matrix is always 1, because the matrix preserves the volume of the space it acts on. 
        raise ValueError(f"Matrix {x} is not a rotation matrix. Determinant: {det}")
    if np.allclose(np.dot(x, x.T),np.eye(2),atol=epsilon): #
        raise ValueError(f"Matrix {x} is not orthogonal. x.xT = {I}")

def rotation_mzi_angle(x):
    #check_and_validate_2x2_rot_mat(x)

    clipped_x=np.clip(x, -1, 1)
    angle = np.arccos(clipped_x[0][0])
    # revert the angle if needed
    if clipped_x[1,0] < 0:
        angle = 2*np.pi - angle 
    return angle

def rotation_2attenuators_angle(x):
    clipped_x=np.clip(x, -1, 1)
    return np.arccos(clipped_x)

### Converts NxN random weights -> 2x2 rotation matrices and 2d scaling vectors

In [80]:
# Tiling weights 4x4 into 4 2x2 matrices
w_t=from_arr_to_tiles(weights)

# From 2x2 matrix to 2x2 rotation matrix
w_rot_t=[]
w_val_t=[]
for w_ti in w_t:
    rot_matrix, eigen_values=from_random_to_rotation(w_ti)
    w_rot_t.append(rot_matrix)
    w_val_t.append(eigen_values)

  return rot.astype(np.float32), eigenvalues.astype(np.float32)


### Prediction with the 2x2 rotation matrices and 2D scaling vec.

In [81]:
# Predicting with 2x2 rotations and 2d scaling vectors 
reconstructed_y=tiled_prediction(w_rot_t, w_val_t, X)
print(f"Expected Y:", Y)
print(f"Reconstructed Y:", reconstructed_y)
print(f"Prediction MSE: {np.mean((Y-Y_reconstructed)**2)}")

Expected Y: [[ 0.37339233]
 [ 0.85102612]
 [-0.67755906]
 [-0.65268413]]
Reconstructed Y: [[-1.03054864]
 [ 0.92809649]
 [ 0.67841055]
 [ 2.26272296]]
Prediction MSE: 0.0539663482169657


### MZI and attenuators angles of the phase shift (radian)

In [82]:
# MZI angles
for i in range(len(w_rot_t)):
    mzi_theta=rotation_mzi_angle(w_rot_t[i])
    atts_theta=rotation_2attenuators_angle(w_val_t[i])
    print(f"MZI#{i} rotation angle = {str(round(mzi_theta,7))} , associated attenuators: {np.round(atts_theta,7)}")


MZI#0 rotation angle = 1.074408 , associated attenuators: [1.3748369 1.3748369]
MZI#1 rotation angle = 1.6313761 , associated attenuators: [1.3432273 0.7021385]
MZI#2 rotation angle = 1.3048719 , associated attenuators: [0.4750728 0.4750728]
MZI#3 rotation angle = 1.0543903 , associated attenuators: [0.9916114 2.4810073]


## Strategy 2: Teacher-Student approach 

The teacher is a standard ANN (NxN weights) and the student ONN (set of 2x2 rotation matrices). This method is a data driven procedure.

### Generate random data to get the teacher predictions

In [107]:
def teacher(W, nb_data):
    synt_data=np.random.uniform(-1., +1., (nb_data , 4, 1))
    teacher_pred=[]
    for x in synt_data:
        pred_Y=np.dot(W,x)# <-- teacher model
        teacher_pred.append( pred_Y )
    
    #reformating by squeezing
    synt_data=synt_data.squeeze()
    teacher_pred=np.array(teacher_pred).squeeze()
    return synt_data, teacher_pred
teacher_X, teacher_pred = teacher(weights, 1000)

### The student is trained to mimic the teacher

In [109]:
from ONN import ONN # contains photonic neural network
hp={"lr":1., "lr_decay":10., "layers":[4], "epochs": 5}
student_onn=ONN(hp)
student_onn.initialize()   
student_onn.fit(teacher_X, teacher_pred, teacher_X, teacher_pred)

0.10698833
0.07487212
0.07161769
0.07159352
0.07159182


### Prediction with the student model (ONN)

In [110]:
Y_reconstructed=student_onn.predict(np.array([X.squeeze()])).squeeze()
Y_expected=Y.squeeze()

print(f"Y expected: {Y_expected}")
print(f"Y reconstructed: {Y_reconstructed}")
print(f"Prediction MSE: {np.mean((Y_expected-Y_reconstructed)**2)}")

Y expected: [ 0.37339233  0.85102612 -0.67755906 -0.65268413]
Y reconstructed: [ 0.93145835  0.62140274 -0.75870734 -0.62532663]
Prediction MSE: 0.09287450462579727


### Theta of phase shifters

In [111]:
print(model.W)

[Array([-3.1486027,  0.282853 ], dtype=float32), Array([1.6477056], dtype=float32), Array([ 1.4729869, -2.0515711], dtype=float32), Array([-0.7082593], dtype=float32)]


NB: The teacher is 4x4 matrix, and the student contains 4 columns of respectively 2,1,2,1 MZIs. The teacher contains 16 parameters and the student only 6 params.

## Strategy 4: Teacher-Student with prior SVD decomposition

After SVD decomposition, we have two teachers and two associated students. The procedure is similar than strategy 3 but applied indpedantly on S and V.

teacher_u, teacher_v produce preds

In [112]:
u, s, vT = np.linalg.svd(weights, full_matrices=False)
s_diag=np.diag(s)
teacher_u_X, teacher_u_pred = teacher(u, 1000)
teacher_v_X, teacher_v_pred = teacher(vT, 1000)

student_u, student_v are trained to mimic their teacher

In [113]:
from ONN import ONN # contains photonic neural network
hp={"lr":1., "lr_decay":10., "layers":[4], "epochs": 5}
student_u=ONN(hp)
student_u.initialize() 
print("Teacher_u -> Student_u training")
student_u.fit(teacher_u_X, teacher_u_pred, teacher_u_X, teacher_u_pred)

hp={"lr":1., "lr_decay":10., "layers":[4], "epochs": 5}
student_v=ONN(hp)
student_v.initialize() 
print("Teacher_v -> Student_v training")
student_v.fit(teacher_v_X, teacher_v_pred, teacher_v_X, teacher_v_pred)

Teacher_u -> Student_u training
0.33702737
0.33489755
0.3336078
0.33354986
0.33354458
Teacher_v -> Student_v training
0.34730762
0.34290767
0.34273097
0.34270787
0.34270602


Prediction on new data with student_u, student_v and diag_s

In [114]:
# compute v.T*x
vTx=student_u.predict(np.array([X.squeeze()])).squeeze()
# compute u*v.T*x
uvTx=student_v.predict(np.array([vTx.squeeze()])).squeeze()
# compute diag_s*u*v.T*x
Y_reconstructed=np.dot(s_diag , uvTx)

Y_expected=Y.squeeze()
print(f"Y expected: {Y_expected}")
print(f"Y reconstructed: {Y_reconstructed}")
print(f"Prediction MSE: {np.mean((Y_expected-Y_reconstructed)**2)}")

Y expected: [ 0.37339233  0.85102612 -0.67755906 -0.65268413]
Y reconstructed: [ 1.60870897  0.63157953 -0.01950631  0.29593791]
Prediction MSE: 0.7267703008886807


### Theta of phase shifters

In [115]:
print("U phase shifters:")
print(student_u.W)
print("V phase shifters:")
print(student_v.W)
print("S attenuators:")
print(s)

U phase shifters:
[Array([-2.5794332, -5.7190742], dtype=float32), Array([5.4699335], dtype=float32), Array([-0.9637948,  0.1765687], dtype=float32), Array([4.623848], dtype=float32)]
V phase shifters:
[Array([-1.0001415, -3.054548 ], dtype=float32), Array([1.2980582], dtype=float32), Array([-4.163986, -8.60073 ], dtype=float32), Array([9.581584], dtype=float32)]
S attenuators:
[1.57967707 1.09579685 0.77645471 0.32089482]


Conclusion:  Different random seeds show that strat3 is more accurate than strat4. "y_pred=approx(w).x" seems more accurate than successive errors "y_pred=s.approx(u).approx(v).x" .

## Strategy 5: Auto-Encoder from NN weights to ONN weights