# 🦾🤖 Noise-aware optimization scheme with ✨ XLuminA ✨: 

This notebook is a step-by-step guide for building a robust (noise-aware) optimization scheme in XLuminA.

We will set up an optimization scheme for *the sharp focus of a radially polarized light beam*, using **robust_discovery** from **optical_elements.py**.

In [1]:
import os
import sys

# Setting the path for XLuminA modules:
current_path = os.path.abspath(os.path.join('..'))
module_path = os.path.join(current_path)

if module_path not in sys.path:
    sys.path.append(module_path)

In [17]:
from xlumina import um, nm, mm
from xlumina.vectorized_optics import *
from xlumina.optical_elements import robust_discovery
from xlumina.toolbox import space, softmin
from xlumina.loss_functions import vectorized_loss_hybrid
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
import optax

## System specs: define the light source, the output dimensions and the static parameters (fixed during optimization):

In [5]:
# 1. System specs:
sensor_lateral_size = 512  # Resolution
wavelength1 = 650*nm
x_total = 2500*um
x, y = space(x_total, sensor_lateral_size)
shape = jnp.shape(x)[0]

# 2. Define the light source: a single Gaussian beam with Jones vector (1, 1), i.e., linear polarization at 45 degrees:
w0 = (1200*um, 1200*um)
ls1 = PolarizedLightSource(x, y, wavelength1)
ls1.gaussian_beam(w0=w0, jones_vector=(1, 1))

# 3. Define the output (high-resolution) detection:
x_out, y_out = jnp.array(space(10*um, 400)) # Detector pixel size: 20 um / 400 pixels (0.05 um)

# 4. High NA objective lens specs:
NA = 0.9 
radius_lens = 3.6*mm/2 
f_lens = radius_lens / NA

# 5. Static parameters - don't change during optimization:
fixed_params = [radius_lens, f_lens, x_out, y_out]
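The objective-lens numbers can be checked by hand. Under the Abbe sine condition the entrance-pupil radius is r = f * NA, which is the relation used for `f_lens`; with r = 1.8 mm and NA = 0.9 this gives f = 2 mm. The sketch below assumes XLuminA measures lengths in microns (`um = 1`, `mm = 1000`); these constants are local stand-ins, not imports from the library.

```python
import math

# Local stand-in unit constants (assumption: micron is the base length unit).
um = 1.0
mm = 1000.0 * um

NA = 0.9
radius_lens = 3.6 * mm / 2   # entrance-pupil radius: 1.8 mm
f_lens = radius_lens / NA    # Abbe sine condition: r = f * NA

# In air (n = 1), NA = sin(alpha) gives the half-angle of the focusing cone.
alpha_deg = math.degrees(math.asin(NA))

print(f"f_lens = {f_lens / mm:.3f} mm")  # f_lens = 2.000 mm
print(f"alpha  = {alpha_deg:.1f} deg")   # alpha  = 64.2 deg
```

The large half-angle (about 64 degrees) is why a vectorial, high-NA focusing model is needed here rather than a paraxial one.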

## Define the optical setup:

1. Vectorize the optical setup (`optical_elements.py` > `robust_discovery`) over a new axis defined by the noise, so that several noisy optical tables are simulated in parallel.

    The arguments of the function are: the light sources `ls1` to `ls6`, `parameters`, `fixed_params` and `distance_offset`, which are common to all tables, and the noise arrays, which are DIFFERENT for each table. 

In [6]:
def batch_robust_discovery(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, noise_distances, noise_slms, noise_wps, noise_amps, distance_offset):
    """
    Vectorized (efficient) version of robust_discovery() for batch optimization over noisy optical tables. 
    
    Parameters: 
        ls1, ls2, ls3, ls4, ls5, ls6 (PolarizedLightSource): input light sources, shared by all tables
        parameters (list): parameters to optimize, shared by all tables
            BB 1: [phase1_1, phase1_2, a1_1, a1_2, eta1, theta1, z1_1, z1_2]
            BB 2: [phase2_1, phase2_2, a2_1, a2_2, eta2, theta2, z2_1, z2_2] 
            BB 3: [phase3_1, phase3_2, a3_1, a3_2, eta3, theta3, z3_1, z3_2]
            BS ratios: [bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9] <- automatically contains 
            Extra distances: [z4, z5]
        fixed_params (list): parameters kept fixed during optimization, [r, f, x_out, y_out]: the radius and focal length of the objective lens and the output detector coordinates.
         
        noise_distances (jnp.array): misalignment (in microns): [noise_z1_1, noise_z1_2, ...]
        noise_slms (jnp.array): SLM phase noise (in radians): [noise_phase1_1, noise_phase1_2, ...]
        noise_wps (jnp.array): wave plate noise (in radians): [noise_eta1, noise_theta1, ...]
        noise_amps (jnp.array): SLM amplitude noise (in AU): [noise_A1, noise_A2, ...]
        distance_offset (float): distance offset passed unchanged to robust_discovery(), shared by all tables
        
    Returns the detected intensities stacked over tables, with shape (# tables, 6, resolution, resolution).
    """
    # Noise shapes (leading axis = # tables): 
    # distances: (# tables, 8)
    # SLM phases and amplitudes: (# tables, 6, resolution, resolution)
    # wave plates: (# tables, 6)
    # in_axes: None -> argument shared by all tables; 0 -> mapped over the leading (table) axis
    detected_intensities_z = vmap(robust_discovery, in_axes=(None, None, None, None, None, None, None, None, 
                                                             0, 0, 0, 0, None))(ls1, ls2, ls3, ls4, ls5, ls6, parameters, fixed_params, 
                                                                             noise_distances, noise_slms, noise_wps, noise_amps, distance_offset)
    return detected_intensities_z
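The `in_axes` pattern can be illustrated with a toy function. `toy_setup` is a hypothetical stand-in for `robust_discovery`, not part of XLuminA: arguments mapped with `None` are shared by every table, while arguments mapped with `0` contribute one slice per table, so the output gains a leading table axis.

```python
import jax
import jax.numpy as jnp

def toy_setup(params, fixed, noise):
    """Hypothetical stand-in for robust_discovery: one 'detector image' per table."""
    # params and fixed are shared; noise perturbs the params on this table only.
    return fixed * (params + noise)

params = jnp.ones((4, 4))  # shared by all tables
fixed = 2.0                # shared static parameter
n_tables = 3
# One noise slice per table: table i adds the constant i to every pixel.
noise = jnp.arange(n_tables, dtype=jnp.float32)[:, None, None] * jnp.ones((n_tables, 4, 4))

# None -> broadcast the argument to every table; 0 -> map over its leading axis.
batched = jax.vmap(toy_setup, in_axes=(None, None, 0))(params, fixed, noise)
print(batched.shape)     # (3, 4, 4)
print(batched[:, 0, 0])  # [2. 4. 6.]
```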

2. Vectorize the loss function: `vmap` the loss (imported from `loss_functions.py` > `vectorized_loss_hybrid`) across the optical tables.

In [7]:
def mean_batch_discover(detected_light):
    """ vmap loss over optical tables: """
    # detected_light with shape (#optical tables, (6, N, N))
    # vmap loss in axis #optical tables 
    return vmap(vectorized_loss_hybrid, in_axes=(0,))(detected_light)
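Mapping a per-table loss over the table axis yields one row of detector losses per table. The sketch uses a hypothetical per-detector metric (`toy_loss_per_detector`, a plain mean intensity) in place of `vectorized_loss_hybrid`, whose actual formula differs; only the shapes are the point here.

```python
import jax
import jax.numpy as jnp

def toy_loss_per_detector(detected):
    """Hypothetical per-detector loss: input (6, N, N) -> output (6,)."""
    # Stand-in metric: mean intensity on each detector.
    return jnp.mean(detected, axis=(1, 2))

n_tables, n_detectors, N = 3, 6, 8
# Table i has constant intensity i + 1 on every detector.
detected = jnp.ones((n_tables, n_detectors, N, N)) * jnp.arange(1, n_tables + 1)[:, None, None, None]

losses = jax.vmap(toy_loss_per_detector, in_axes=(0,))(detected)
print(losses.shape)  # (3, 6)
```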

3. Define the loss: first compute the detected light on the parallelized optical tables (`batch_robust_discovery`), then evaluate the loss for each detector on each table. Finally, average the loss across the optical tables and take the (soft) minimum over detectors with `softmin`.

In [8]:
@jit 
def loss_batch_discovery(parameters, noise_d, noise_slm, noise_wp, noise_amp):
    """
    Loss function. It computes L = Area/I_{epsilon} for each detector on each optical table.
    
    Parameters:
        parameters (list): parameters to optimize.
        noise_d (jnp.array): misalignment (in microns) per table: [noise_d1, noise_d2, ...]
        noise_slm (jnp.array): SLM phase noise (in radians) per table: [noise_slm_1, noise_slm_2, ...]
        noise_wp (jnp.array): wave plate noise (in radians) per table: [noise_eta, noise_theta, ...]
        noise_amp (jnp.array): SLM amplitude noise (in AU) per table: [noise_A1, noise_A2, ...]

    Returns the soft minimum, over detectors, of the loss averaged across optical tables (a scalar). 
    """
    # Output from batch_robust_discovery has shape (# optical tables, 6, N, N): 6 detectors per table.
    # The same light source ls1 feeds all six inputs.
    detected_z_intensities = batch_robust_discovery(ls1, ls1, ls1, ls1, ls1, ls1, parameters, fixed_params, noise_d, noise_slm, noise_wp, noise_amp, distance_offset=15) 

    # Output from mean_batch_discover has shape (# optical tables, 6, 1): one loss value per detector and table.
    # Average across tables (axis 0), then take the soft minimum over detectors with softmin.
    loss_val = softmin(jnp.mean(mean_batch_discover(detected_z_intensities), axis=0, keepdims=True))
    return loss_val # scalar, as required by jax.value_and_grad
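The reduction inside `loss_batch_discovery` can be sketched on toy data: averaging over the table axis rewards parameters that work on every noisy table, and a smooth minimum over detectors turns the result into a scalar that `jax.value_and_grad` can differentiate. `smooth_min` (a log-sum-exp soft minimum) and `toy_loss` are illustrative assumptions; XLuminA's `softmin` may use a different formula.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def smooth_min(x, beta=10.0):
    """Hypothetical soft minimum: -(1/beta) * log(sum(exp(-beta * x))).
    Always <= min(x), and approaches min(x) as beta grows."""
    return -logsumexp(-beta * x) / beta

def toy_loss(params, noise):
    # (n_tables, 6) per-detector losses for a toy model; noise differs per table.
    per_table = (params[None, :] - 1.0) ** 2 + noise
    # Average over tables (robustness), then soft-min over detectors -> scalar.
    return smooth_min(jnp.mean(per_table, axis=0))

params = jnp.array([0.0, 0.5, 1.5, 2.0, 3.0, 1.2])
noise = jnp.full((3, 6), 0.1)
value, grad = jax.value_and_grad(toy_loss)(params, noise)
print(value.shape, grad.shape)  # () (6,)
```

Because the soft minimum weights each detector by a softmax of its negative loss, the gradient flows mostly to the best detector while the others still receive a small signal.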

## Optimizer settings:

In [9]:
# Global variable: shape of the SLM parameter arrays, used for the random initialization
shape = jnp.array([sensor_lateral_size, sensor_lateral_size])
# Define the loss function:
loss_function = loss_batch_discovery

In [10]:
# Optimization settings
OS = {'n_best': 500,
      'best_loss': 3*1e2,
      'num_iterations': 50000,
      'num_samples': 1,
      'WEIGHT_DECAY': 1e-3,
      'BASE_lr': 0.05,
      'END_lr': 0.001,
      'DECAY_STEPS': 4000
     }

## [!!] Define the noise settings dictionary:

**NS (dict)**: 

    NS = {'n_tables': __, 
          'number of distances': __,
          'number of sSLM': __, 
          'number of wps': __,
          'noise_level': __, 
          'misalignment': (minval, maxval), 
          'phase_noise': (minval, maxval),
          'discretize': __}

where

**n_tables (int)**: number of optical tables to compute in parallel.

**number of distances, number of sSLM, number of wps (int)**: number of distances, sSLMs and wave plates in the optical setup.

**noise_level (str)**:

    1. low: noise in SLMs and WPs $\pm$(0.01 to 0.05) rads and misalignment of $\pm$(0.01 to 0.05) mm 
    2. mild: noise in SLMs and WPs $\pm$(0.05 to 0.5) rads and misalignment of $\pm$(0.05 to 0.5) mm 
    3. high: noise in SLMs and WPs $\pm$(0.5 to 1) rads and misalignment of $\pm$(0.5 to 1) mm 
    4. all: noise in SLMs and WPs $\pm$(0.01 to 1) rads and misalignment of $\pm$(0.01 to 1) mm 
    5. tunable: tunable noise via NS dictionary

**discretize (bool)**: if true, discretize SLM noise to 8-bit.
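The level table can be expressed as a lookup that a noise generator reads before sampling. `NOISE_LEVELS` and `noise_ranges` are hypothetical names; the ranges are the ones listed above, with misalignment in microns (assuming XLuminA's `um = 1`). Raising on an unknown level makes a typo fail immediately, instead of surfacing later as an undefined-variable error.

```python
# Ranges from the list above: ((misalignment in um), (phase/amplitude noise in rad or AU)).
NOISE_LEVELS = {
    'low':  ((10.0, 50.0),    (0.01, 0.05)),
    'mild': ((50.0, 500.0),   (0.05, 0.5)),
    'high': ((500.0, 1000.0), (0.5, 1.0)),
    'all':  ((10.0, 1000.0),  (0.01, 1.0)),
}

def noise_ranges(ns):
    """Hypothetical helper: return ((minval_d, maxval_d), (minval_phase, maxval_phase))."""
    level = ns['noise_level']
    if level == 'tunable':
        # Tunable noise: read the bounds directly from the settings dictionary.
        return tuple(ns['misalignment']), tuple(ns['phase_noise'])
    if level not in NOISE_LEVELS:
        raise ValueError(f"Unknown noise_level {level!r}; expected one of {list(NOISE_LEVELS) + ['tunable']}")
    return NOISE_LEVELS[level]

print(noise_ranges({'noise_level': 'mild'}))
# ((50.0, 500.0), (0.05, 0.5))
print(noise_ranges({'noise_level': 'tunable', 'misalignment': (10, 100), 'phase_noise': (0.01, 0.1)}))
# ((10, 100), (0.01, 0.1))
```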


In [11]:
# Noise settings:
NS = {'n_tables': 3, 
      'number of distances': 8,
      'number of sSLM': 3, 
      'number of wps': 3,
      'noise_level': 'tunable',
      'misalignment': (10, 100), 
      'phase_noise': (0.01, 0.1),
      'discretize': False}

1. Define a keychain with one PRNG key for each optical table (`NS['n_tables']` keys).

In [12]:
# Keychain for optical tables
def keychain_optical_tables(seed, number_of_tables):
    """ 
    Generates a keychain with one PRNGKey per optical table (seeds seed, seed + 1, ...).
    """
    keychain = []
    for num in range(number_of_tables):
        key_table = random.PRNGKey(seed + num)
        keychain.append(key_table)
            
    return jnp.array(keychain)
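The loop above can also be written as a single vectorized call, since `jax.random.PRNGKey` can be mapped with `jax.vmap` over an array of seeds. `keychain_vmap` is a hypothetical alternative, not an XLuminA function; the sketch checks that it reproduces the loop's keys exactly.

```python
import jax
import jax.numpy as jnp
from jax import random

def keychain_loop(seed, number_of_tables):
    # Same construction as keychain_optical_tables: one PRNGKey per table.
    return jnp.array([random.PRNGKey(seed + num) for num in range(number_of_tables)])

def keychain_vmap(seed, number_of_tables):
    # Vectorized variant: build all keys at once by vmapping PRNGKey over the seeds.
    return jax.vmap(random.PRNGKey)(seed + jnp.arange(number_of_tables))

k_loop = keychain_loop(42, 3)
k_vmap = keychain_vmap(42, 3)
print(k_loop.shape)                           # (3, 2)
print(bool(jnp.array_equal(k_loop, k_vmap)))  # True
```

Consecutive seeds are fine here because each table's key is only ever split, never reused directly for sampling.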

2. Define `shake_setup()` and `batch_shake_setup()`, which draw a new noise realization for the setup at every iteration. 

    Two types of shaking functions are provided in `optical_elements.py`. 

    1. `shake_setup` takes the noise settings `NS` (dict) as an argument, so `batch_shake_setup` built on it cannot be compiled with `@jit` or `@partial(jit)`. 

    2. If you want to `@jit` `batch_shake_setup`, copy `shake_setup_jit` into your optimizer file instead: it reads `NS` as a global variable rather than taking it as an argument. 

Here we copy `shake_setup_jit` and use `NS` as a global variable. 
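The dict blocks `@jit` because `jax.jit` traces its arguments as arrays, and string values such as `'tunable'` are not valid JAX types. Besides a global variable, a closure achieves the same effect while keeping the settings explicit. `make_shake` below is a hypothetical, distances-only sketch of that pattern, not an XLuminA function. As with the global, the settings are baked in at trace time, so changing the dict afterwards requires building a new function.

```python
import jax
import jax.numpy as jnp
from jax import random

def make_shake(ns):
    """Hypothetical factory: closes over the noise settings instead of reading a global.

    The dict is consumed as plain Python at trace time, so the returned function
    only takes JAX-compatible arguments and can be jitted and vmapped."""
    n_d = ns['number of distances']
    minval_d, maxval_d = ns['misalignment']

    def shake(key):
        key0, key_d = random.split(key)  # key0 is returned for the next step
        noise_d = random.uniform(key_d, shape=(n_d,), minval=minval_d, maxval=maxval_d)
        return noise_d, key0

    return jax.jit(shake)

NS_demo = {'number of distances': 8, 'misalignment': (10, 100)}
shake = make_shake(NS_demo)
keys = jnp.stack([random.PRNGKey(s) for s in range(3)])  # one key per table
noise_d, next_keys = jax.vmap(shake)(keys)
print(noise_d.shape, next_keys.shape)  # (3, 8) (3, 2)
```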

In [13]:
def shake_setup_jit(key, resolution):
    """
    [THIS FUNCTION IS INTENDED TO BE PASTED IN THE OPTIMIZER FILE TO ENABLE @jit COMPILATION FOR `batch_shake_setup`]
    
    Creates noise for all the different optical variables on an optical table.
    
    Parameters:
        key (PRNGKey): JAX random key for reproducibility
        resolution (int): number of pixels for space
        
        global variable NS (dict): noise settings as

            NS = {'n_tables': __, 'number of distances': __,
                  'number of sSLM': __, 'number of wps': __,
                  'noise_level': __, 
                  'misalignment': (minval, maxval), 
                  'phase_noise': (minval, maxval),
                  'discretize': __}
    
    Returns:
        random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0 (new key to split in the next iteration), key (old key)
    """
    # Split into one key per random draw (distances, plus a sign and a magnitude for SLM phases,
    # SLM amplitudes and WPs) and key0, which is renewed every step.
    # Reusing the same JAX key for two different draws would correlate them.
    key0, key_d, key_slm_sign, key_slm_mag, key_amp_sign, key_amp_mag, key_wp_sign, key_wp_mag = random.split(key, 8)
    d_type = 'int8'

    # NS is read as a global (not passed as an input) so that the function can be vmapped and jitted.
    level = NS['noise_level']
    discretize = NS['discretize']
    
    # level can be: 'low', 'mild', 'high', 'all' or 'tunable'

    if level == 'low':
        # Misalignment (um)
        minval_d = 10*um  # 0.01 mm
        maxval_d = 50*um  # 0.05 mm
        # SLM / WP phase and amplitude (rads and AU, respectively)
        minval_phase = 0.01 
        maxval_phase = 0.05

    elif level == 'mild':
        # Misalignment (um)
        minval_d = 50*um  # 0.05 mm
        maxval_d = 500*um # 0.5 mm
        # SLM / WP phase and amplitude (rads and AU, respectively)
        minval_phase = 0.05 
        maxval_phase = 0.5

    elif level == 'high':
        # Misalignment (um)
        minval_d = 500*um  # 0.5 mm
        maxval_d = 1000*um # 1 mm
        # SLM / WP phase and amplitude (rads and AU, respectively)
        minval_phase = 0.5
        maxval_phase = 1

    elif level == 'all':
        # Misalignment (um)
        minval_d = 10*um   # 0.01 mm
        maxval_d = 1000*um # 1 mm
        # SLM / WP phase and amplitude (rads and AU, respectively)
        minval_phase = 0.01 
        maxval_phase = 1

    elif level == 'tunable':
        # Misalignment (um)
        minval_d, maxval_d = NS['misalignment']  # in um
        # SLM / WP phase and amplitude (rads and AU, respectively)
        minval_phase, maxval_phase = NS['phase_noise']

    else:
        raise ValueError(f"Unknown noise_level: {level!r}")

    if discretize: 
        d_type = 'uint8'

    # Noise for distances: shape = (NS['number of distances'],)
    random_noise_distances = random.uniform(key_d, shape=(NS['number of distances'],), minval=minval_d, maxval=maxval_d)
    # Noise for SLM amplitudes and phases (random sign times random magnitude): shape = (2*NS['number of sSLM'], resolution, resolution)
    slm_shape = (2*NS['number of sSLM'], resolution, resolution)
    random_noise_amps = random.choice(key_amp_sign, jnp.array([-1,1]), shape=slm_shape).astype(d_type) * random.uniform(key_amp_mag, shape=slm_shape, minval=minval_phase, maxval=maxval_phase) 
    random_noise_slms = random.choice(key_slm_sign, jnp.array([-1,1]), shape=slm_shape).astype(d_type) * random.uniform(key_slm_mag, shape=slm_shape, minval=minval_phase, maxval=maxval_phase) 
    # Noise for WP angles (eta and theta of each wave plate): shape = (2*NS['number of wps'],)
    wp_shape = (2*NS['number of wps'],)
    random_noise_wps = random.choice(key_wp_sign, jnp.array([-1,1]), shape=wp_shape).astype('int8') * random.uniform(key_wp_mag, shape=wp_shape, minval=minval_phase, maxval=maxval_phase) 
    
    return random_noise_distances, random_noise_slms, random_noise_wps, random_noise_amps, key0, key


@jit
def batch_shake_setup(key_array, array_for_shape):
    """
    Creates noise for all the optical variables on multiple optical tables, one table per key in key_array.
    
    Parameters: 
        key_array (PRNGKey): array of keys, renewed at every optimization step. 
        Its length sets the number of optical tables computed in parallel.
        array_for_shape (jnp.array): array whose first dimension is the resolution; only its (static) shape is used, which keeps the function jittable. 
        
    Returns:
        random_noise_distances [with shape = (size(key_array), NS['number of distances'])], 
        random_noise_slms [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], 
        random_noise_wps [with shape = (size(key_array), 2*NS['number of wps'])], 
        random_noise_amps [with shape = (size(key_array), 2*NS['number of sSLM'], resolution, resolution)], 
        key0 (PRNGKey): array of new keys to split in the next iteration step,
        key (PRNGKey): array of the keys used in this step
    return vmap(shake_setup_jit, in_axes = (0, None))(key_array, jnp.shape(array_for_shape)[0])
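The sign-times-magnitude pattern used for SLM and wave plate noise can be sketched in isolation. `signed_uniform` and `toy_shake` are illustrative helpers, not XLuminA functions. The sign and the magnitude come from separate subkeys, since calling two JAX random functions with the same key yields correlated samples. Mapping the key with `0` and the resolution with `None` reproduces the per-table batch axis that `batch_shake_setup` adds.

```python
import jax
import jax.numpy as jnp
from jax import random

def signed_uniform(key, shape, minval, maxval):
    """Random sign times a uniform magnitude: |values| lie in [minval, maxval)."""
    key_sign, key_mag = random.split(key)  # independent keys for sign and magnitude
    sign = random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=shape)
    magnitude = random.uniform(key_mag, shape=shape, minval=minval, maxval=maxval)
    return sign * magnitude

def toy_shake(key, resolution, n_slm=3, n_wp=3):
    key0, k_slm, k_wp = random.split(key, 3)  # key0 is kept for the next step
    slm = signed_uniform(k_slm, (2 * n_slm, resolution, resolution), 0.01, 0.1)
    wp = signed_uniform(k_wp, (2 * n_wp,), 0.01, 0.1)
    return slm, wp, key0

keys = jnp.stack([random.PRNGKey(s) for s in range(3)])  # 3 optical tables
slm, wp, next_keys = jax.vmap(toy_shake, in_axes=(0, None))(keys, 16)
print(slm.shape, wp.shape)  # (3, 6, 16, 16) (3, 6)
```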

## Define the optimizer (AdamW, with an optional learning-rate schedule)

In [14]:
def adamw_schedule(base_lr, end_lr, decay_steps, weight_decay) -> optax.GradientTransformation:
    """
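    Custom optimizer: AdamW with a constant learning rate `base_lr`.
    To decay the learning rate linearly from `base_lr` to `end_lr` over `decay_steps` (starting at step 500),
    uncomment the `optax.linear_schedule` line below.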
    """
    lr_schedule = base_lr
    #lr_schedule = optax.linear_schedule(init_value= base_lr, end_value = end_lr, transition_steps = decay_steps, transition_begin = 500)                                           
    return optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay)

## Optimization loop: 

In [15]:
def fit(params: optax.Params, optimizer: optax.GradientTransformation, num_iterations, keys, x) -> optax.Params:
    
    # Init the optimizer with initial parameters
    opt_state = optimizer.init(params)

    @jit
    def update(parameters, opt_state, noise_d, noise_slm, noise_wp, noise_amp):
        # Single update step for the current noise realization:
        loss_value, grads = jax.value_and_grad(loss_function)(parameters, noise_d, noise_slm, noise_wp, noise_amp)
        # Update the state of the optimizer
        updates, state = optimizer.update(grads, opt_state, parameters)
        # Update the parameters
        new_params = optax.apply_updates(parameters, updates)
        
        return new_params, parameters, state, loss_value, updates

    # Initialize some parameters    
    n_best = OS['n_best']
    best_loss = OS['best_loss']
    best_params = None
    best_keys = None
    best_step = 0
    
    print('Starting Optimization', flush=True)
    
    for step in range(num_iterations):
        
        # Add noise: update noise and keys each iteration
        noise_d, noise_slm, noise_wp, noise_amp, keys, old_keys = batch_shake_setup(keys, x) # 'x' is the space variable from optical table
        
        # Apply update step
        params, old_params, opt_state, loss_value, updates = update(params, opt_state, noise_d, noise_slm, noise_wp, noise_amp)
        
        print(f"Step {step}")
        print(f"Loss {loss_value}")
        
        # Update the `best_loss` value:
        if loss_value < best_loss:
            # Best loss value
            best_loss = loss_value
            # Best optimized parameters
            best_params = old_params
            # Keys for best params
            best_keys = old_keys
            best_step = step
            print('Best loss value is updated')

        if step % 100 == 0:
            # Stopping criterion (checked every 100 steps): stop if best_loss has not improved for more than n_best steps.
            if step - best_step > n_best:
                print(f'Stopping criterion: no improvement in loss value for {n_best} steps')
                break
    
    print(f'Best loss: {best_loss} at step {best_step}')
    print(f'Best parameters: {best_params}')  
    return best_params, best_loss, best_keys

In [None]:
# Optimizer settings
num_iterations = OS['num_iterations']
num_samples = OS['num_samples']

for i in range(num_samples):
    tic = time.perf_counter()
    
    # seed1 to ensure randomness among samples
    seed1 = np.random.randint(9999)
    
    # Init keychain for noise -- as many init keys as optical tables to parallelize
    keys = keychain_optical_tables(seed1, NS['n_tables'])
        
    # Optimizer settings
    WEIGHT_DECAY = OS['WEIGHT_DECAY']
    BASE_lr = OS['BASE_lr']
    END_lr = OS['END_lr']
    DECAY_STEPS = OS['DECAY_STEPS']
    
    # Random init parameters:
    phase1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    phase1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a1_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a1_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    
    phase2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    phase2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a2_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a2_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    
    phase3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    phase3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a3_1 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    a3_2 = jnp.array([np.random.uniform(0, 1, shape)], dtype=jnp.float64)[0]
    
    eta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    theta1 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    eta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    theta2 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    eta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    theta3 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    eta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    theta4 = jnp.array([np.random.uniform(0, 1, 1)], dtype=jnp.float64)[0]
    
    z1_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z1_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z2_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z2_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z3_1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z3_2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    z5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    
    bs1 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs2 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs3 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs4 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs5 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs6 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs7 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs8 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    bs9 = jnp.array([np.random.uniform(0, 1)], dtype=jnp.float64)
    
    # Init params for 3x3 robust discovery
    init_params = [phase1_1, phase1_2, a1_1, a1_2, eta1, theta1, z1_1, z1_2, 
                   phase2_1, phase2_2, a2_1, a2_2, eta2, theta2, z2_1, z2_2, 
                   phase3_1, phase3_2, a3_1, a3_2, eta3, theta3, z3_1, z3_2, 
                   bs1, bs2, bs3, bs4, bs5, bs6, bs7, bs8, bs9, 
                   z4, z5]
                   
    # Init optimizer:
    optimizer = adamw_schedule(BASE_lr, END_lr, DECAY_STEPS, WEIGHT_DECAY)

    # Apply fit function:
    best_params, best_loss, best_keys = fit(init_params, optimizer, num_iterations, keys, x)
    
    print(f"Time taken to optimize one sample - in seconds {(time.perf_counter() - tic):.4f}")
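The 35 initial parameters follow a fixed pattern per building block, so they can also be generated in a loop. `random_init_params` is a hypothetical helper, not part of XLuminA: it uses NumPy's `Generator` API instead of the legacy `np.random` functions and returns NumPy arrays in the same order and shapes as `init_params` above (convert them with `jnp.asarray` before passing them to the optimizer). It draws random numbers in a different sequence, so the values differ from the cell above even for the same seed.

```python
import numpy as np

def random_init_params(shape, rng):
    """Hypothetical helper: the 35-entry parameter list, in the same order as `init_params`.

    Per building block: 2 SLM phases and 2 SLM amplitudes (arrays of `shape`),
    then eta, theta and two distances (each of shape (1,))."""
    params = []
    for _ in range(3):  # three building blocks
        params += [rng.uniform(0, 1, shape) for _ in range(4)]  # phase_1, phase_2, a_1, a_2
        params += [rng.uniform(0, 1, 1) for _ in range(2)]      # eta, theta
        params += [rng.uniform(0, 1, 1) for _ in range(2)]      # z_1, z_2
    params += [rng.uniform(0, 1, 1) for _ in range(9)]          # bs1 to bs9
    params += [rng.uniform(0, 1, 1) for _ in range(2)]          # z4, z5
    return params

rng = np.random.default_rng(seed=1234)
init_params = random_init_params((512, 512), rng)
print(len(init_params))  # 35
```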