In [1]:
import numpy as np
from tqdm import tqdm
import xarray as xr
import matplotlib.pyplot as plt
import warnings
from itertools import product
from polarization_controller import optical_fiber, polarization_splitter_rotator, mzi, phase_aligner, polarization_bidi, polarization_bidi_double_fiber, polarization_bidi_single_fiber
from scipy.optimize import minimize
warnings.filterwarnings("ignore")
plt.style.use("plot_style.mplstyle")

In [2]:
p_cont = polarization_bidi(psr_bool=False, tx_num_mzi_stages=1)

In [3]:
sample_points = 201

voltages = np.linspace(0, 3, sample_points)
phase_shifts = np.linspace(-np.pi, np.pi, sample_points)

rotation_rx = phase_shifts[np.random.permutation(sample_points)]
rotation_ry = phase_shifts[np.random.permutation(sample_points)]
rotation_rz = phase_shifts[np.random.permutation(sample_points)]
rotation_delta = phase_shifts[np.random.permutation(sample_points)]

p_cont.fiber_1._rotation = (rotation_rx[0], rotation_ry[0], rotation_rz[0], rotation_delta[0])
p_cont.fiber_2._rotation = (rotation_rx[1], rotation_ry[1], rotation_rz[1], rotation_delta[1])
p_cont.fiber_3._rotation = (rotation_rx[2], rotation_ry[2], rotation_rz[2], rotation_delta[2])
p_cont.fiber_4._rotation = (rotation_rx[3], rotation_ry[3], rotation_rz[3], rotation_delta[3])
p_cont.fiber_5._rotation = (rotation_rx[4], rotation_ry[4], rotation_rz[4], rotation_delta[4])
p_cont.fiber_6._rotation = (rotation_rx[5], rotation_ry[5], rotation_rz[5], rotation_delta[5])
p_cont.fiber_7._rotation = (rotation_rx[6], rotation_ry[6], rotation_rz[6], rotation_delta[6])

p_cont.recursive_update()

In [4]:
no_sweep = 11
voltages = np.linspace(0, 3.6, no_sweep)

tap_1 = []
tap_2 = []
tap_3 = []

input_state = np.array([0, 0, 1, 0])

for xps0_voltages, xps1_voltages, xps2_voltages, xps3_voltages in tqdm(product(voltages, voltages, voltages, voltages), total=no_sweep**4):
  p_cont.bidi_tx.pa.XPS1.heater_voltage = xps0_voltages
  p_cont.bidi_tx.pa.XPS2.heater_voltage = xps1_voltages
  p_cont.bidi_tx.mzi_1.XPS1.heater_voltage = xps2_voltages
  p_cont.bidi_tx.mzi_1.XPS2.heater_voltage = xps3_voltages
  p_cont.recursive_update()

  output_state = p_cont.smatrix @ input_state

  tap_1.append(np.abs(output_state[0])**2)

for i, j, k, l in tqdm(product(range(no_sweep), range(no_sweep), range(no_sweep), range(no_sweep))):
  if tap_1[i*no_sweep**3 + j*no_sweep**2 + k*no_sweep + l] == np.min(tap_1):
    break

xps0_voltage_step1 = voltages[i]
xps1_voltage_step1 = voltages[j]
xps2_voltage_step1 = voltages[k]
xps3_voltage_step1 = voltages[l]

if xps0_voltage_step1 == np.min(voltages):
  index_min = 0 
else:
  index_min = 1
if xps0_voltage_step1 == np.max(voltages):
  index_max = 0
else:
  index_max = 1

xps0_voltage_range = np.linspace(xps0_voltage_step1 - index_min*np.diff(voltages)[0], xps0_voltage_step1 + index_max*np.diff(voltages)[0], no_sweep)

if xps1_voltage_step1 == np.min(voltages):
  index_min = 0
else:
  index_min = 1 

if xps1_voltage_step1 == np.max(voltages):
  index_max = 0
else:
  index_max = 1

xps1_voltage_range = np.linspace(xps1_voltage_step1 - index_min*np.diff(voltages)[0], xps1_voltage_step1 + index_max*np.diff(voltages)[0], no_sweep)

if xps2_voltage_step1 == np.min(voltages):
  index_min = 0 
else:
  index_min = 1
if xps3_voltage_step1 == np.max(voltages):
  index_max = 0
else:
  index_max = 1

xps2_voltage_range = np.linspace(xps2_voltage_step1 - index_min*np.diff(voltages)[0], xps2_voltage_step1 + index_max*np.diff(voltages)[0], no_sweep)

if xps3_voltage_step1 == np.min(voltages):
  index_min = 0
else:
  index_min = 1 

if xps3_voltage_step1 == np.max(voltages):
  index_max = 0
else:
  index_max = 1

xps3_voltage_range = np.linspace(xps3_voltage_step1 - index_min*np.diff(voltages)[0], xps3_voltage_step1 + index_max*np.diff(voltages)[0], no_sweep)

for xps0_voltages, xps1_voltages, xps2_voltages, xps3_voltages in tqdm(product(xps0_voltage_range, xps1_voltage_range, xps2_voltage_range, xps3_voltage_range), total=no_sweep**4):
  p_cont.bidi_tx.pa.XPS1.heater_voltage = xps0_voltages
  p_cont.bidi_tx.pa.XPS2.heater_voltage = xps1_voltages
  p_cont.bidi_tx.mzi_1.XPS1.heater_voltage = xps2_voltages
  p_cont.bidi_tx.mzi_1.XPS2.heater_voltage = xps3_voltages
  p_cont.recursive_update()

  output_state = p_cont.smatrix @ input_state

  tap_1.append(np.abs(output_state[0])**2)

for i, j, k, l in tqdm(product(range(no_sweep), range(no_sweep), range(no_sweep), range(no_sweep))):
  if tap_1[i*no_sweep**3 + j*no_sweep**2 + k*no_sweep + l] == np.min(tap_1):
    break

xps0_voltage_opt = xps0_voltage_range[i]
xps1_voltage_opt = xps1_voltage_range[j]
xps2_voltage_opt = xps2_voltage_range[i]
xps3_voltage_opt = xps3_voltage_range[j]

print(f"XPS0 voltage: {xps0_voltage_opt}")
print(f"XPS1 voltage: {xps1_voltage_opt}")
print(f"XPS2 voltage: {xps2_voltage_opt}")
print(f"XPS3 voltage: {xps3_voltage_opt}")


p_cont.bidi_tx.pa.XPS1.heater_voltage = xps0_voltage_opt
p_cont.bidi_tx.pa.XPS2.heater_voltage = xps1_voltage_opt
p_cont.bidi_tx.mzi_1.XPS1.heater_voltage = xps2_voltage_opt
p_cont.bidi_tx.mzi_1.XPS2.heater_voltage = xps3_voltage_opt

p_cont.recursive_update()



100%|██████████| 14641/14641 [00:08<00:00, 1746.11it/s]
13343it [00:03, 3614.72it/s]
100%|██████████| 14641/14641 [00:08<00:00, 1745.41it/s]
14641it [00:07, 1839.85it/s]

XPS0 voltage: 3.6
XPS1 voltage: 0.36
XPS2 voltage: 1.44
XPS3 voltage: 0.36





In [5]:
input_state = np.array([0, 1, 0, 0])
output_state = p_cont.smatrix @ input_state
print("Output state direction: ", np.abs(output_state)**2)

input_state_reverse = np.array([0, 0, 1, 0])
output_state_reverse = p_cont.smatrix @ input_state_reverse
print("Output state reverse direction: ", np.abs(output_state_reverse)**2)

Output state direction:  [0.00178273 0.00070823 0.07132276 0.00126877]
Output state reverse direction:  [0.00195863 0.07132276 0.00064078 0.00141281]
