This notebook walks you through the optimization procedure used to compare ML optimizers for use in invDes. The notebook is currently configured for optimization of a 100 x 100 two-way mode splitter, but the structure and cost function may be modified to target any device.

The parameters which determine structure, optimization phases, simulation, etc. can be found in the User Parameters section. Make sure to run each cell in sequence before performing an optimization in Run Optimization. See comments for details on what everything does.

## Imports

In [91]:
from functools import partial
from IPython import display
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.animation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image, ImagePalette
import pickle
import optax
import datetime
import pjz
from copy import deepcopy
from jax.interpreters import xla
import os
import interpax
import jaxopt
jax.config.update("jax_debug_nans", False)
cpus = jax.devices("cpu")
gpus = jax.devices("gpu")
if len(gpus)==0:
    print('The fdtdz simulator requires a CUDA-equipped GPU. If you are reading this, no such GPU was detected by JAX. This means either a) you do not have a CUDA-equipped GPU, or b) your GPU is not currently communicating with JAX within this Python environment. This will need to be rectified before proceeding.')
else:
    print('GPU recognized, all is well.')

GPU recognized, all is well.


## User Parameters

### The most important ones which you will frequently change

In [75]:
verbose=1 #Verbosity level. How often to print out current value of the cost function. Set to 0 for silent running.

optimizer='adam' #What optimizer to use. Should be an optax optimizer, or 'L-BFGS-B'.
learningRate=0.1 #Learning rate. Maximum learning rate if schedule is 'auto', static learning rate otherwise. Ignored for lbfgs
staticMetaparameters={} #Any non-default hyperparams BESIDES learning rate. Keys are parameter names, entries are values.

schedule='auto' #What scheduler to use. You can add your own, set it to 'auto' for independent cosine annealing w/ warm-up on each phase, or set it to None for static hyperparams. If 'auto' and metaOptimizer is not None, scheduling will NOT be used
timePerIteration=2.165 #Seconds per iteration. Needed to determine the cosine periods when schedule is 'auto'. This depends entirely upon your GPU, the simulation domain, the simulation time, and launchParams in Simulation Parameters. Just run a few iterations and see how long they take.

metaOptimizer=None#'sgd' #What meta-optimizer to use. Set to None for no meta-optimization
parametersToMetaOptimize={'learning_rate':0.1} #Parameters to meta-optimize, with initial values. If you want to metaOpt learning rate, include 'learning_rate' here, with an init value which will override learningRate
metaLearningRate=1. #Learning rate for meta-optimizer
innerPerOuter=4 #Number of inner iterations per outer iteration
staticMetaMetaparameters={} #Any non-default hyperparams BESIDES learning rate for the metaoptimizer.

inputFile='inTest.txt' #Input file, storing the initial density distribution as a pickled jax.numpy.array. If None, a random init will be generated using parameters in Optimization Parameters. If not None, but pointing to a nonexistent file, this random init will be pickled to the requested file.
modeFile='modeTest.txt' #Input file, a pickled dictionary with mode wavenumbers as entry 'invariantBetas' and mode profiles as entry 'invariantModes'. If None, modes will be computed prior to computation. If not None, but pointing to a nonexistent file, these modes will be pickled to the requested file.
outputFile='outTest.txt' #Output file. If not None, the evo dictionary will be pickled to this location. Otherwise, data will not be automatically saved to file.

### PRNG key

In [76]:
prngkey=jax.random.key(2)

### Structure Parameters

In [77]:
nanometersPerCell=40 #nm per cell. Currently this is configured for cubic meshes only (250,220)
designRegionSize=jnp.array((200,200)) #Size of the xy plane of the design region, in cells. In both the 'PV' and 'TW' methods, this is the size of the array of control parameters
symmetryAxis=None #Axis about which design will be symmetric. Options are 'x', 'y', and 'xy' (or None for no symmetry)
designRegionThickness=6 #Size of the design region in the z-direction. The xy plane structures will be extruded in this dimension
xyPadWidth=100 #Thickness of the pad region in the xy plane, which surrounds the design region, in cells. This many cells will be added to each side of the design region
zPadWith=49 #Thickness of the pad region in the z-direction, above and below the design region, in cells.
xyAbsorptionWidth=50 #How many layers of adiabatic absorber BEYOND the xyPad to add to each side of the sim region
xyAbsorptionCoeff=1e-4 #Strength of absorption layer. Too small and waves will sail through the absorption region and reflect off the far wall. Too large and they will reflect off the front edge of the absorption region
zPMLWidth=12 #How many PMLs to add in the z-direction. This will be reduced to the nearest multiple of 4 for efficiency reasons
permitivityRange=jnp.array((2.1,12.1)) #Minimum/maximum permitivity. The binarized design will consist entirely of either permitivityRange[0] or permitivityRange[1]
filtRad=1 #Controls the density filter, which "blurs" the design to correlate nearby pixels. This is done by convolution with a cone filter of this radius.

#This is how input/output is defined. Each waveguide gets a dictionary in this list. Waveguides will be added in the order in which they appear in this list (so the first
#dictionary defines waveguide 0, the second waveguide 1, etc.)
waveguideDictionary=[{'wgWidth':36, #Width of waveguide. If >=1, width in cells. If <1, width in % design region length (so 0.1 = 10% of that face's length)
                      'wgPerm':12.1, #Permitivity of the waveguide. Must be within permitivityRange
                      'wgFace':'w', #Which face to place the waveguide on. Options are 'n', 's', 'e', 'w'.
                      'wgLocation':0.5, #How far along face to place the waveguide. If >=1, distance in cells from start of DESIGN REGION (no pad). If <1, width in % face length
                      'wgMode':0, #Which mode to inject (if Input) or search for (if Output) in this waveguide. 0 is (usually) TE00, 1 is (usually) TE10, etc. Modes are sorted by descending wavenumber
                      'wgInjectionLocation':2, #How far along the waveguide to place the source plane, in cells. 0 means the source plane will be at the border of the sim domain.
                      'wgInput':True}, #Whether this waveguide is an Input or Output
                     {'wgWidth':36,'wgPerm':12.1,'wgFace':'w', 'wgLocation':0.5,'wgMode':1,'wgInjectionLocation':2,'wgInput':True},
                     {'wgWidth':20,'wgPerm':12.1,'wgFace':'e', 'wgLocation':0.2,'wgMode':0,'wgInjectionLocation':1,'wgInput':False},
                     {'wgWidth':20,'wgPerm':12.1,'wgFace':'e', 'wgLocation':0.8,'wgMode':0,'wgInjectionLocation':1,'wgInput':False},]
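
The cells-versus-fraction convention used by `wgWidth` and `wgLocation` can be illustrated with a small standalone sketch. The helper name `resolve_extent` is hypothetical and not part of the notebook; it only mirrors the rule stated in the comments above (values >= 1 are cells, values < 1 are a fraction of the face length). The notebook additionally offsets locations by `xyPadWidth`, which is omitted here.

```python
def resolve_extent(value, face_length):
    """Interpret `value` as cells if >= 1, else as a fraction of `face_length`."""
    if value < 1:
        return int(value * face_length)
    return int(value)

face_length = 200  # cells along one face of the design region (assumed from designRegionSize)
width_cells = resolve_extent(36, face_length)     # given directly in cells
centre_cells = resolve_extent(0.5, face_length)   # halfway along the face
print(width_cells, centre_cells)
```

With these inputs the first waveguide is 36 cells wide and centred 100 cells along the west face, before padding is added.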

### Simulation Parameters

In [78]:
picosecondsPerStep=0.6 #Scaled time per timestep
totalTimeSteps='auto' #Total number of time steps in each simulation. If it's too small, we won't reach steady state. If it's too large, numerical error will make the sim explode. Set to 'auto' to base it on design region size, assuming 1550nm light and a timestep of 0.6.
inputWavelengths=jnp.array([1550]) #Wavelengths of inputs, in nanometers
outputWavelengthRange=(1450,1650) #Wavelength range of outputs, for use in the FFT. Ideal number of output snapshots will be recorded based on this
sourceRamp=8. #How many periods to spin up the input to max strength
sourceDelay=4. #How many periods to wait before starting to spin up the input. Should be nonzero for stability reasons.

#The fdtdz simulator is very fast because it is designed to maximize GPU utilization by efficiently partitioning the spatio-temporal space to minimize the amount of memory
#that needs to be cached at any given time. How to do this partitioning optimally depends upon your GPU, so you will need to enter some specs. If using Colab's cloud hardware,
#set launchParams to None and the code will automatically choose optimal settings. Descriptions come from fdtdz docstring
launchParams=((2, 4), #determines the layout of warps in the u- and v-directions within a block and should be ``(2, 4)`` or ``(4, 2)``
               (7,6), #specify the layout of blocks on the GPU and must be equal to or less than the number of streaming multiprocessors on the GPU because of the need for grid-wide synchronization
               1, #controls the number of buffers used between each block and its downstream neighbor and should be tuned to balance between reducing grid synchronization overhead and staying within the limits of the L2 cache
               (8, 0)) #major and minor compute capability of the device. Used to determine which precompiled kernel to use. Currently allowed values are ``(3, 7)``, ``(6, 0)``, ``(7, 0)``, ``(7, 5)``, and ``(8, 0)``.

### Optimization Parameters

In [80]:
maxIterations=500 #Maximum total iterations before forcible termination.
maxTime=5*60#Maximum time before forcible termination, in seconds
ftol=1e-12 #Maximum relative change in objective before termination. Will terminate when abs(L[-2]-L[-1])/L[-1]<=ftol
corrLen=80 #Correlation length if autogenerating init_theta. Make sure it is not bigger than the simulation domain size. Has no effect if init_theta is not None
minv=-0.05 #min initial value if autogenerating init_theta. Theta is unconstrained, varying in theory from -inf to inf. In practice, rarely leaves [-3,3]. It is passed through a sigmoid to project to [0,1]
maxv=0.05 #max initial value if autogenerating init_theta.

#Optimization is governed by a scheduler (different from the "schedule" above). The scheduler is called every iteration and returns a dictionary of one or more values that change the course of optimization. Iteration=0 is called before optimization begins. You do not need to return every entry of this dictionary on
#any given iteration. Parameters not updated in this dictionary retain their values from the last iteration or, in the case of iteration=0, adopt default values.
#In my scheduler, I break optimization into four phases: continuum, discretization, settling, and DRC
t1=6*60*(5/20) #Time at which continuum ends/ discretization begins
t2=12*60*(5/20) #Time at which discretization ends / settling begins
t3=16*60*(5/20) #Time at which settling ends / DRC begins
def scheduler(iteration, #Current iteration. Iteration=0 is before optimization begins.
              params, #Current set of PROJECTED optimization parameters (the density, in [0,1], 1 being permitivityRange[1], 0 being permitivityRange[0])
              previousL, #Last iteration's cost function
              previousPreviousL, #Iteration before last's cost function (np.inf for iteration=0)
              time): #Elapsed optimization time, in seconds
    global currentEpoch,t1,t2,t3
    returnDictionary={'alpha':0, #Controls discretization. Must be in [0,1]. Density sent to permitivity computer is [continuum density]*(1-alpha)+[discretized density]*alpha
                      'eta':0.5, #Discretization threshold. Any density below this number is mapped to permitivityRange[0] in [discretized density], and density above is mapped to permitivityRange[1]
                      'gamma':0, #Controls fabrication constraint penalty. Full objective function is [performance objective]+gamma*[fabrication objective]. For TW method.
                      'c':1e3, #Inflection detection sensitivity for fab constraint computation. For TW method.
                      'eta_lo':0.25, #Minimum feature size for density=0 features. For TW method.
                      'eta_hi':0.75, #Minimum feature size for density=1 features. For TW method.
                      'noiseInjection':None, #Whether to inject artificial noise into the density parameters. If not None, [new params]=[old params]*noiseInjection*jax.random.uniform(key=prngkey,shape=params.shape,minval=-1+noiseInjection,maxval=1-noiseInjection). For TW method.
                      't':1.0, #Gaussian kernel length scale. For PV method.
                      'd':0.005, #Minimum gap size, in % total length. For PV method.
                      'r':0.005, #Minimum radius of curvature, in % total length. For PV method.
                      'beta':0.33, #Gap detection parameter. For PV method.
                      'gapPenalty':0, #Scale factor on gapLoss. For PV method.
                      'curvePenalty':0} #Scale factor on curveLoss. For PV method.
    dt=t2-t1
    currentEpoch=0
    if time>=t1:
        returnDictionary['alpha']=(time-t1)/dt #alpha increases linearly from 0 to 1 during discretization
        currentEpoch=1
    if time>=t2:
        returnDictionary['alpha']=1
        currentEpoch=2
    if time>=t3:
        returnDictionary['gamma']=2#Gamma is just a static blast as soon as DRC starts. A dynamic gamma would be more elegant.
        currentEpoch=3
    return returnDictionary

#This is where you define your performance objective. THIS FUNCTION MUST BE JAX-TRACEABLE (use jnp or jax.scipy, no normal numpy, scipy, torch, etc.)
#An additional component of fabrication loss, from three-wave inflection detection, will be added external to this function. Use floss if you have some other fab constraint beyond this
@jax.jit
def performanceObjective(svals, #The scattering matrix. svals[i][j][k]=mode overlap from waveguide i to waveguide j at frequency k
                         frequencies, #The angular frequencies, in mesh units
                         params): #The current set of RAW unconstrained optimization parameters, before projection

    ploss,floss=-jnp.linalg.norm(svals[0][2][0])-jnp.linalg.norm(svals[1][3][0]),0
    return floss+ploss, (ploss,floss) #Return two values: the total loss, and a tuple with separate performance and fab loss to record later
performanceObjectiveRecord='-jnp.linalg.norm(svals[0][2][0])-jnp.linalg.norm(svals[1][3][0])' #I cannot save the performance objective function, so write out something that will let you reconstruct it later to be added to the pickled results file. Keep this string in sync with performanceObjective above.
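
The four-phase timing implemented by `scheduler` can be checked in isolation. The sketch below is a simplified standalone restatement (the function name `phase_and_alpha` is hypothetical) that returns only the phase index and `alpha`, using the same `t1`, `t2`, `t3` expressions as above.

```python
def phase_and_alpha(time, t1, t2, t3):
    """Return (phase index, alpha) for an elapsed time in seconds."""
    if time >= t3:
        return 3, 1.0                          # DRC: fully discretized, gamma switched on
    if time >= t2:
        return 2, 1.0                          # settling: fully discretized
    if time >= t1:
        return 1, (time - t1) / (t2 - t1)      # discretization: alpha ramps linearly 0 -> 1
    return 0, 0.0                              # continuum

t1, t2, t3 = 6*60*(5/20), 12*60*(5/20), 16*60*(5/20)  # 90 s, 180 s, 240 s
for t in (0, 135, 200, 250):
    print(t, phase_and_alpha(t, t1, t2, t3))
```

With the values above, halfway through the discretization phase (135 s) `alpha` is 0.5.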

## Parameter Processing

### init generation

In [82]:
def correllatedRandom(size,corrLen=corrLen,minv=minv,maxv=maxv,prngkey=prngkey):
    x,y=jnp.arange(-corrLen,corrLen),jnp.arange(-corrLen,corrLen)
    X,Y=jnp.meshgrid(x,y)
    dist=jnp.sqrt(X**2+Y**2)
    filt=jnp.exp(-dist**2/(2*corrLen))
    noise=jax.random.uniform(key=prngkey,shape=size,minval=minv,maxval=maxv)
    return jax.scipy.signal.convolve(noise,filt,mode='same')
if inputFile is not None:
    if os.path.isfile(inputFile):
         f=open(inputFile,'rb')
         init_theta=pickle.load(f)
         f.close()
    else:
        init_theta=correllatedRandom(designRegionSize)
        write_file=open(inputFile,'wb')
        pickle.dump(init_theta,write_file)
        write_file.close()
else:
    init_theta=correllatedRandom(designRegionSize)
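
The load-if-present, otherwise-generate-and-save pattern used for `inputFile` (and again for `modeFile`) could be factored into one helper. This is only a sketch: `load_or_create` is a hypothetical name, and a plain list stands in for the density array so the example runs without JAX.

```python
import os
import pickle
import tempfile

def load_or_create(path, factory):
    """Return the pickled object at `path`, creating it with `factory()` if absent."""
    if path is not None and os.path.isfile(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    value = factory()
    if path is not None:           # path given but file missing: save for next time
        with open(path, 'wb') as f:
            pickle.dump(value, f)
    return value

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'inTest.txt')
    first = load_or_create(path, lambda: [0.0, 0.01, -0.02])   # file missing: created and saved
    second = load_or_create(path, lambda: [9.9])               # file present: loaded, factory unused
print(first == second)
```

Using `with` blocks also guarantees the file is closed if pickling raises.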

### deal with autos

In [83]:
if totalTimeSteps=='auto' or totalTimeSteps is None:
    if jnp.max(designRegionSize)//2<=50:
        totalTimeSteps,timeFactor=2500,2.4
    elif jnp.max(designRegionSize)//2<=100:
        totalTimeSteps,timeFactor=3000,3.3
    elif jnp.max(designRegionSize)//2<=150:
        totalTimeSteps,timeFactor=4000,4.1
    elif jnp.max(designRegionSize)//2<=200:
        totalTimeSteps,timeFactor=5000,5.8
    elif jnp.max(designRegionSize)//2<=250:
        totalTimeSteps,timeFactor=5500,7.36
    elif jnp.max(designRegionSize)//2<=300:
        totalTimeSteps,timeFactor=6000,9.6
    else:
        totalTimeSteps,timeFactor=7000,13.1

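if isinstance(schedule,str) and schedule.lower()=='auto': #schedule may legitimately be None (static hyperparams), so check the type before calling .lower()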
    periods=np.array([t1,t2-t1,t3-t2,maxTime-t3])/timePerIteration #1.055 seconds per iteration for 180 x 180 wl trimux with 3000steps, 3.3sec factor, so this is #iterations per epoch (1.453 for 230 x 230 with 4000steps, 4.1sec, 1.535 for 280 x 280 with 5000steps, 5.8sec)
    runningPeriods=np.array([t1,t2,t3])/timePerIteration
    scheduleLis=[]
    for period in periods:
        scheduleLis.append(optax.warmup_cosine_decay_schedule(init_value= learningRate/10, peak_value= learningRate, warmup_steps= period*0.1, decay_steps= period, end_value = learningRate/10, exponent = 1.0))
    if metaOptimizer is None:
        schedule=optax.join_schedules(scheduleLis,runningPeriods)
    else:
        schedule=None
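
To see what the 'auto' schedule produces, the sketch below converts the phase durations into iteration counts and evaluates a hand-written warm-up plus cosine decay curve. The function `warmup_cosine` is an illustrative plain-Python stand-in, not optax's implementation, and the numbers assume the default values set earlier in this notebook (`t1=90`, `t2=180`, `t3=240`, `maxTime=300`, `timePerIteration=2.165`, `learningRate=0.1`).

```python
import math

def warmup_cosine(step, init_value, peak_value, warmup_steps, decay_steps, end_value):
    """Linear warm-up to peak_value, then cosine decay to end_value at decay_steps."""
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    return end_value + 0.5 * (peak_value - end_value) * (1 + math.cos(math.pi * progress))

t1, t2, t3, max_time = 90.0, 180.0, 240.0, 300.0
seconds_per_iteration = 2.165
# Length of each phase, in iterations
periods = [p / seconds_per_iteration for p in (t1, t2 - t1, t3 - t2, max_time - t3)]

lr = 0.1
first = periods[0]
# Learning rate at the start, at the end of warm-up, and at the end of the first phase
samples = [warmup_cosine(s, lr / 10, lr, 0.1 * first, first, lr / 10)
           for s in (0, 0.1 * first, first)]
print([round(p, 2) for p in periods], [round(s, 4) for s in samples])
```

Each phase restarts the curve, so the learning rate climbs back to its peak shortly after every phase boundary.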

### set globals

In [84]:
#I've gone to great trouble to avoid having to register a custom class as a tree with jax. This involves a lot of global variables
def process_parameters():
    global omegas,optimizers
    omegas=2*jnp.pi/(inputWavelengths/nanometersPerCell)
    optimizers={'adabelief':optax.adabelief,'adafactor':optax.adafactor,'adagrad':optax.adagrad,'adam':optax.adam,'adamw':optax.adamw,'adamax':optax.adamax,
                'adamaxw':optax.adamaxw,'amsgrad':optax.amsgrad,'fromage':optax.fromage,'lamb':optax.lamb,'lars':optax.lars,'lion':optax.lion,'nadam':optax.nadam,'nadamw':optax.nadamw,
                'noisy_sgd':optax.noisy_sgd,'novograd':optax.novograd,'optimistic_gradient_descent':optax.optimistic_gradient_descent,'polyak_sgd':optax.polyak_sgd,'radam':optax.radam,
                'sgd':optax.sgd,'sm3':optax.sm3,'yogi':optax.yogi,'rprop':optax.rprop,'rmsprop':optax.rmsprop,'adadelta':optax.adadelta}
    return
process_parameters()

def get_globalParamDict():
    global prngkey,nanometersPerCell,designRegionSize,schedule,timePerIteration,inputFile,modeFile,symmetryAxis,designRegionThickness,xyPadWidth,zPadWith,xyAbsorptionWidth,xyAbsorptionCoeff,zPMLWidth,permitivityRange,filtRad,waveguideDictionary,picosecondsPerStep,totalTimeSteps,inputWavelengths,outputWavelengthRange,sourceRamp,sourceDelay,launchParams,optimizer,learningRate,optimizerParameters,parameterizationMethod,verbose,maxIterations,maxTime,ftol,corrLen,minv,maxv,performanceObjectiveRecord
    return{'prngkey':prngkey,'nanometersPerCell':nanometersPerCell,'designRegionSize':designRegionSize,'timePerIteration':timePerIteration,'inputFile':inputFile,'modeFile':modeFile,
           'symmetryAxis':symmetryAxis,'designRegionThickness':designRegionThickness,'xyPadWidth':xyPadWidth,'zPadWith':zPadWith,'xyAbsorptionWidth':xyAbsorptionWidth,'xyAbsorptionCoeff':xyAbsorptionCoeff,'zPMLWidth':zPMLWidth,
           'permitivityRange':permitivityRange,'filtRad':filtRad,'waveguideDictionary':waveguideDictionary,'picosecondsPerStep':picosecondsPerStep,'totalTimeSteps':totalTimeSteps,'inputWavelengths':inputWavelengths,
           'outputWavelengthRange':outputWavelengthRange,'sourceRamp':sourceRamp,'sourceDelay':sourceDelay,'launchParams':launchParams,'optimizer':optimizer,'learningRate':learningRate,'staticMetaparameters':staticMetaparameters,
           'verbose':verbose,'maxIterations':maxIterations,'maxTime':maxTime,'ftol':ftol,'corrLen':corrLen,'minv':minv,'maxv':maxv,'performanceObjectiveRecord':performanceObjectiveRecord,
           'metaOptimizer':metaOptimizer,'metaLearningRate':metaLearningRate,'staticMetaMetaparameters':staticMetaMetaparameters,'parametersToMetaOptimize':parametersToMetaOptimize,'innerPerOuter':innerPerOuter}
def set_globalParamDict(gpd):
    #Restores the globals from a dictionary produced by get_globalParamDict(). Keys here must match the keys written there.
    global prngkey,nanometersPerCell,designRegionSize,schedule,timePerIteration,inputFile,modeFile,symmetryAxis,designRegionThickness,xyPadWidth,zPadWith,xyAbsorptionWidth,xyAbsorptionCoeff,zPMLWidth,permitivityRange,filtRad,waveguideDictionary,picosecondsPerStep,totalTimeSteps,inputWavelengths,outputWavelengthRange,sourceRamp,sourceDelay,launchParams,optimizer,learningRate,staticMetaparameters,verbose,maxIterations,maxTime,ftol,corrLen,minv,maxv,performanceObjectiveRecord,metaOptimizer,metaLearningRate,staticMetaMetaparameters,parametersToMetaOptimize,innerPerOuter
    prngkey=gpd['prngkey']
    nanometersPerCell=gpd['nanometersPerCell']
    designRegionSize=gpd['designRegionSize']
    timePerIteration=gpd['timePerIteration']
    inputFile=gpd['inputFile']
    modeFile=gpd['modeFile']
    symmetryAxis=gpd['symmetryAxis']
    designRegionThickness=gpd['designRegionThickness']
    xyPadWidth=gpd['xyPadWidth']
    zPadWith=gpd['zPadWith']
    xyAbsorptionWidth=gpd['xyAbsorptionWidth']
    xyAbsorptionCoeff=gpd['xyAbsorptionCoeff']
    zPMLWidth=gpd['zPMLWidth']
    permitivityRange=gpd['permitivityRange']
    filtRad=gpd['filtRad']
    waveguideDictionary=gpd['waveguideDictionary']
    picosecondsPerStep=gpd['picosecondsPerStep']
    totalTimeSteps=gpd['totalTimeSteps']
    inputWavelengths=gpd['inputWavelengths']
    outputWavelengthRange=gpd['outputWavelengthRange']
    sourceRamp=gpd['sourceRamp']
    sourceDelay=gpd['sourceDelay']
    launchParams=gpd['launchParams']
    optimizer=gpd['optimizer']
    learningRate=gpd['learningRate']
    staticMetaparameters=gpd['staticMetaparameters']
    verbose=gpd['verbose']
    maxIterations=gpd['maxIterations']
    maxTime=gpd['maxTime']
    ftol=gpd['ftol']
    corrLen=gpd['corrLen']
    minv=gpd['minv']
    maxv=gpd['maxv']
    performanceObjectiveRecord=gpd['performanceObjectiveRecord']
    metaOptimizer=gpd['metaOptimizer']
    metaLearningRate=gpd['metaLearningRate']
    staticMetaMetaparameters=gpd['staticMetaMetaparameters']
    parametersToMetaOptimize=gpd['parametersToMetaOptimize']
    innerPerOuter=gpd['innerPerOuter']
    return
defaultSchedDict={'alpha':0,'eta':0.5,'gamma':0,'c':1e3,'eta_lo':0.25,'eta_hi':0.75,'radius':10,'noiseInjection':None,'t':1.0,'d':0.005,'r':0.005,'beta':0.33,'gapPenalty':0,'curvePenalty':0}
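
As a quick sanity check on the `omegas` conversion in `process_parameters()`, the arithmetic for the default 1550 nm input on a 40 nm mesh can be worked through by hand. The variable names below are illustrative only.

```python
import math

nanometers_per_cell = 40
wavelength_nm = 1550

wavelength_cells = wavelength_nm / nanometers_per_cell   # 38.75 cells per wavelength
omega = 2 * math.pi / wavelength_cells                    # angular frequency in mesh units
print(wavelength_cells, round(omega, 5))
```

A finer mesh (smaller `nanometersPerCell`) gives more cells per wavelength and therefore a smaller mesh-unit frequency.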

## Simulation code

### Params -> permitivity

In [85]:
def addWaveguides(u,theta,schedDict):
    """
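    Adds waveguides to u.

    Arguments:
    u: jax array. the padded density without waveguides
    theta: jax array. The control parameter array; only its shape is used, to size and place the waveguides
    schedDict: dictionary. The scheduler dictionary.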

    Returns:
    u: jax array. the padded density with waveguides
    wgDs: list of lists. Waveguide slices, such that wgDs[i]=[x-direction left extent, x-direction right extent, y-direction left extent, y-direction right extent] for waveguide i
    """
    dS=jnp.shape(theta)
    wgDs,wgFaceMap=[],{'n':0,'s':0,'e':1,'w':1} #The wgFaceMap maps cardinal direction strings to the axis normal to the waveguide
    for wgd in waveguideDictionary: #For each input and output waveguide
        wgw,wgl=int(wgd['wgWidth']*(dS[wgFaceMap[wgd['wgFace'].lower()]])**(wgd['wgWidth']<1)),xyPadWidth+int(wgd['wgLocation']*(dS[wgFaceMap[wgd['wgFace'].lower()]])**(wgd['wgLocation']<1)) #Determine the width and length of this waveguide, depending upon whether provided dimensions are in cells or in % total length
        face=wgd['wgFace'].lower() #Compare case-insensitively, consistent with the wgFaceMap lookup above
        if face=='n': #For each cardinal direction, make a note of which cells lie inside the waveguide
            wgDs.append([wgl-wgw//2,wgl+wgw//2,-1*xyPadWidth,None])
        elif face=='s':
            wgDs.append([wgl-wgw//2,wgl+wgw//2,None,xyPadWidth])
        elif face=='e':
            wgDs.append([-int(xyPadWidth),None,wgl-wgw//2,wgl+wgw//2])
        elif face=='w':
            wgDs.append([None,xyPadWidth,wgl-wgw//2,wgl+wgw//2])
        u=u.at[wgDs[-1][0]:wgDs[-1][1],wgDs[-1][2]:wgDs[-1][3]].set((wgd['wgPerm']-permitivityRange[0])/(permitivityRange[1]-permitivityRange[0]))#Add the waveguide to the density.
    return u,wgDs

def density(params,schedDict,modeComp=False):
    """
    Transforms data into a smooth function across the simulation space using the Three Wave method. All distances are measured as a % of the design region length (so 15% of design region means t=0.15,d=0.15, etc.)

    Arguments:
    params: jax array, sig (Dx,Dy). The RAW control parameter array
    schedDict: dictionary. The scheduler dictionary. For use with params2densityTW(), schedDict should have entries 'c' (the inflection detection threshold, try 1e3), 'eta' (level-set threshold, try 0.5), 'eta_lo' (minimum feature size for permitivity0, try 0.25), 'eta_hi' (minimum feature size for permitivity1, try 0.25), 'alpha' (% discretization), 'gamma' (fabrication constraint penalty)
    modeComp: bool. Whether this function has been called in order to compute modes, or else to process design parameters
    Returns:
    if modeComp:
        u: jax array, the padded density with waveguides
        ubase: jax array, the padded density without waveguides. Jax trace is broken at ubase creation
        wgDs: list of lists. Waveguide slices, such that wgDs[i]=[x-direction left extent, x-direction right extent, y-direction left extent, y-direction right extent] for waveguide i
    else:
        density: jax array. The padded density with waveguides and appropriate continuum-discrete ratio
        fabLoss: scalar. The total fabrication loss, scaled by schedDict['gamma']
    """
    sad={'x':1,'y':0,'xy':(0,1)}#The symmetry axis dictionary. The symmetry axis dictionary is perennially depressed.
    if symmetryAxis is not None: #if the user desires symmetry, make it so
        params=(params+ jnp.flip(params.T, axis=sad[symmetryAxis.lower()])) / 2
    thetaProj=jax.nn.sigmoid(params)#The raw parameters are unconstrained. map (-inf,inf) to (0,1)
    u = jnp.pad(thetaProj, xyPadWidth) #Pad the density
    if modeComp:
        ubase=jax.lax.stop_gradient(jnp.copy(u)) # retrieve the waveguide-free density
    u,wgDs=addWaveguides(u,params,schedDict)
    a = jnp.pad(schedDict['alpha'] * jnp.ones_like(params), xyPadWidth, constant_values=1)#The pad region should always be binary, so make the discretization parameter 1 across this region
    density, density_loss = pjz.density(u, filtRad, a, c=schedDict['c'],eta=schedDict['eta'],eta_lo=schedDict['eta_lo'],eta_hi=schedDict['eta_hi'])
    density,wgDs=addWaveguides(density,params,schedDict) #add the waveguides
    if modeComp: #return what the user asked for based on modeComp
        return density,ubase,wgDs
    else:
        return density,schedDict['gamma']*density_loss

def permittivity(density):
  layers = jnp.pad(density[None, :, :], ((1, 1), (0, 0), (0, 0)))
  layers = layers * (permitivityRange[1] - permitivityRange[0]) + permitivityRange[0]
  thicknesses = [zPadWith, designRegionThickness, zPadWith]
  return pjz.epsilon(layers,interface_positions=jnp.cumsum(jnp.array(thicknesses[:-1])),magnification=1,zz=sum(thicknesses))
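
The roles of `alpha` and `eta` described in the scheduler comments, and the density-to-permittivity scaling in `permittivity()`, can be illustrated on a single scalar. This is a sketch of the stated formula only; it is not how `pjz.density` is implemented (which also filters and computes a fabrication loss), and the helper names are hypothetical.

```python
def blend(density, alpha, eta=0.5):
    """Mix a continuous density with its thresholded (binary) version."""
    discretized = 1.0 if density > eta else 0.0
    return density * (1 - alpha) + discretized * alpha

def to_permittivity(density, eps_lo=2.1, eps_hi=12.1):
    """Linearly map a density in [0, 1] onto the permittivity range."""
    return density * (eps_hi - eps_lo) + eps_lo

for alpha in (0.0, 0.5, 1.0):
    d = blend(0.7, alpha)
    print(alpha, round(d, 3), round(to_permittivity(d), 3))
```

As `alpha` goes from 0 to 1, a density of 0.7 is pushed toward 1, so its permittivity moves from an intermediate value toward `permitivityRange[1]`.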

### Modes

In [86]:
def modes(theta,schedDict,):
    _,ubase,wgDs=density(theta,schedDict,modeComp=True)
    betas,profiles=[],[]
    ppMapDict={'n':[None,None,-1,None],'s':[None,None,0,1],'e':[-1,None,None,None],'w':[0,1,None,None]}
    for wgd,wgD in zip(waveguideDictionary, wgDs):
        uu=ubase.at[wgD[0]:wgD[1],wgD[2]:wgD[3]].set((wgd['wgPerm']-permitivityRange[0])/(permitivityRange[1]-permitivityRange[0]))
        pp=permittivity(uu)
        plt.imshow(uu)
        plt.show()
        ppM=ppMapDict[wgd['wgFace'].lower()]
        b,m,_,_=pjz.mode(pp[:, ppM[0]:ppM[1], ppM[2]:ppM[3], :],omegas,num_modes=wgd['wgMode']+1)
        betas.append(jax.lax.stop_gradient(b[..., wgd['wgMode']]))
        profiles.append(jax.lax.stop_gradient(m[..., wgd['wgMode']]))
    return jnp.array(betas),profiles
if modeFile is not None:
    if os.path.isfile(modeFile):
        f=open(modeFile,'rb')
        modesDict=pickle.load(f)
        f.close()
        invariantBetas,invariantModes=modesDict['invariantBetas'],modesDict['invariantModes']
    else:
        invariantBetas,invariantModes=modes(jnp.ones(designRegionSize)*1000,scheduler(0,jnp.ones(designRegionSize),jnp.inf,jnp.inf,0))
        modesDict={'invariantBetas':invariantBetas,'invariantModes':invariantModes}
        write_file=open(modeFile,'wb')
        pickle.dump(modesDict,write_file)
        write_file.close()
else:
    invariantBetas,invariantModes=modes(jnp.ones(designRegionSize)*1000,scheduler(0,jnp.ones(designRegionSize),jnp.inf,jnp.inf,0))

### Params -> loss

In [87]:
def convertInjectLoc(p):
    pS=jnp.shape(p)
    rl=[]
    for wgd in waveguideDictionary:
        if wgd['wgFace'].lower()=='w' or wgd['wgFace'].lower()=='s':
            rl.append(wgd['wgInjectionLocation'])
        elif wgd['wgFace'].lower()=='e':
            rl.append(pS[1]-wgd['wgInjectionLocation'])
        elif wgd['wgFace'].lower()=='n':
            rl.append(pS[2]-wgd['wgInjectionLocation'])
    return tuple(rl)

def build(theta, schedDict):
    d, d_loss = density(theta,schedDict)
    p = permittivity(d)
    sim_params = pjz.SimParams(omega_range=(2*np.pi/(outputWavelengthRange[1]/nanometersPerCell),2*np.pi/(outputWavelengthRange[0]/nanometersPerCell)),dt=picosecondsPerStep,tt=totalTimeSteps,
                                pml_sigma_lnr=0.1,pml_sigma_m=1.0,pml_widths=(zPMLWidth,zPMLWidth),launch_params=launchParams,
                                source_ramp=sourceRamp,source_delay=sourceDelay,absorption_padding=xyAbsorptionWidth,
                                absorption_coeff=xyAbsorptionCoeff)
    return p, d, d_loss, convertInjectLoc(p), tuple([wgd['wgInput'] for wgd in waveguideDictionary]), sim_params

def loss(theta,schedDict):
  p, d, d_loss, mpos, is_fwd, sim_params = build(theta, schedDict)
  svals=jnp.array(pjz.scatter(p, omegas, tuple(invariantModes), tuple(invariantBetas), tuple(mpos), tuple(is_fwd), sim_params))
  performance_loss,(ploss,floss) = performanceObjective(svals, omegas,d)
  fabrication_loss = jnp.sum(jnp.square(d_loss))
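  total_loss = performance_loss + schedDict['gamma'] * jnp.nan_to_num(fabrication_loss,nan=0.0) #Pass nan by keyword; the second positional argument of jnp.nan_to_num is copy, not the NaN replacement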
  return total_loss, (ploss, fabrication_loss+floss)
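
How `loss()` combines its terms can be seen with plain numbers. The sketch below restates only the final combination step (performance loss plus `gamma` times the summed squares of the fabrication terms, with a NaN fabrication loss treated as zero); the function name and inputs are made up for illustration.

```python
import math

def total_loss(performance_loss, fab_terms, gamma):
    """performance + gamma * sum(fab_terms**2), treating a NaN fabrication loss as 0."""
    fabrication_loss = sum(x * x for x in fab_terms)
    if math.isnan(fabrication_loss):
        fabrication_loss = 0.0
    return performance_loss + gamma * fabrication_loss

print(total_loss(-1.5, [0.1, 0.2], gamma=0))   # before DRC: fabrication term has no effect
print(total_loss(-1.5, [0.1, 0.2], gamma=2))   # during DRC: penalty is added
```

Because `gamma` is 0 until the DRC phase begins, the fabrication terms only start influencing the gradient after `t3`.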

## Optimization code

In [88]:
def optimize(optimizerParameters,debug=None):
    if optimizer in ['Nelder-Mead','Powell','CG','BFGS','Newton-CG','L-BFGS-B','TNC','COBYLA','SLSQP','trust-constr','dogleg','trust-ncg','trust-exact','trust-krylov']:
        evo=optimize_jaxopt(optimizerParameters,debug=debug)
    else:
        evo=optimize_optax(optimizerParameters,debug=debug)
    return evo

def optimize_jaxopt(optimizerParameters,debug=None):
    schedDict=scheduler(0,init_theta,np.inf,np.inf,0)
    evo,it=[{'loss':0,'ploss':0,'floss':0,'params':None,'time':0,'schedDict':deepcopy(schedDict)}],0
    ct=0
    gradient=None
    lsit=0
    def dummyLoss(params):
        nonlocal evo,gradient,lsit
        (value, (ploss, floss)), gradient = jax.value_and_grad(loss, has_aux=True)(params, schedDict)
        evo[-1]['loss']=float(jax.device_put(value,cpus[0]))
        evo[-1]['ploss']=float(jax.device_put(ploss,cpus[0]))
        evo[-1]['floss']=float(jax.device_put(floss,cpus[0]))
        lsit+=1
        if lsit>=9:
            raise ValueError("CF not converging")
        elif it>=maxIterations or ct>=maxTime or lsit>=6:
            return 1000,jnp.zeros_like(gradient)
        return value,gradient

    def cback(params):
        nonlocal evo,schedDict,it,ct,lsit,debug
        evo[-1]['params']=np.array(jax.device_put(params,cpus[0]))
        ct=datetime.datetime.now()-startt
        ct=ct.total_seconds()
        evo[-1]['time']=ct
        if verbose>0 and it%verbose==0:
            print('Iteration '+str(it)+', time '+str(evo[-1]['time'])+': '+str(evo[-1]['loss']))
        schedDict=scheduler(it,params,evo[-1]['loss'],evo[-1]['ploss'],ct)
        evo.append({'loss':0,'ploss':0,'floss':0,'params':None,'time':0,'schedDict':deepcopy(schedDict)})
        it+=1
        lsit=0
        if debug is not None and it==debug:
            del evo[-1]
            raise ValueError("debug exit")
        return
    jsm=jaxopt.ScipyMinimize(method=optimizer,has_aux=False,value_and_grad=True,jit=False,callback=cback,fun=dummyLoss,options=deepcopy(optimizerParameters),tol=0)
    if verbose>0:
        print('Beginning optimization with '+optimizer)
    startt=datetime.datetime.now()
    params=init_theta
    while it<maxIterations and ct<maxTime:
        try:
            params,info=jsm.run(params)
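        except ValueError:
            #Raised by dummyLoss (stalled line search) or by cback (debug exit); refresh the schedule before continuing
            ct=datetime.datetime.now()-startt
            ct=ct.total_seconds()
            schedDict=scheduler(it,params,evo[-1]['loss'],evo[-1]['ploss'],ct)
        if debug is not None and it==debug:
            #cback already removed the placeholder entry before raising, so evo is returned unchanged
            return evo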
    del evo[-1]
    ct=datetime.datetime.now()-startt
    if verbose>0:
        print('Optimization done; final cost function '+str(evo[-1]['loss'])+', time '+str(ct)+', iteration '+str(it))
    if outputFile is not None:
        po={'evo':evo,'globalParamDict':get_globalParamDict()}
        with open(outputFile,'wb') as f:
            pickle.dump(po,f)
    return evo

def innerStep(theta,state,schedDict,opt):
    (v, (ploss, floss)),g=jax.value_and_grad(loss,has_aux=True)(jax.lax.stop_gradient(theta.copy()),schedDict) #Compute the value and grad of the user-defined performance loss and the fabrication loss
    updates, state = opt.update(len(omegas)*g,state,theta) #Get an update for theta and an updated state from the optimizer
    theta = optax.apply_updates(theta, updates) #Apply the update to get the new theta
    return theta,state,v,ploss,floss

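def outerLoss(eta, theta, state,schedDict,etaOrder,opt):
    state.hyperparams.update(zip(etaOrder,jax.nn.sigmoid(eta))) #Squash each meta-parameter into (0,1) and write it into the optimizer state
    if 'learning_rate' in etaOrder:
        #Rescale the learning rate logarithmically between learningRate/10 and learningRate*10
        state.hyperparams['learning_rate']=learningRate*1.*10**(2*(state.hyperparams['learning_rate']-0.5))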
    plosses,flosses,losses=[],[],[]
    for i in range(innerPerOuter):
        theta,state,v,ploss,floss=innerStep(theta,state,schedDict,opt)
        plosses.append(ploss)
        flosses.append(floss)
        losses.append(v)
    finalLoss,(finalPloss,finalFloss)=loss(theta,schedDict)
    plosses.append(finalPloss)
    flosses.append(finalFloss)
    losses.append(finalLoss)
    return finalLoss,(theta,state,plosses,flosses,losses)

def outerStep(eta, theta, meta_state, state,schedDict,etaOrder,opt,metaOpt):
    ((v,(theta,state,plosses,flosses,losses)),g) = jax.value_and_grad(outerLoss, has_aux=True)(eta, theta, state,schedDict,etaOrder,opt)
    meta_updates, meta_state = metaOpt.update(g, meta_state)
    eta = optax.apply_updates(eta, meta_updates)
    return eta, theta, v, meta_state, state,plosses,flosses,losses

def optimize_optax(optimizerParameters,debug=None):
    global gradFac
    evo=[]
    if schedule is not None and metaOptimizer is None:
        opt=optax.inject_hyperparams(optimizers[optimizer.lower()])(learning_rate=schedule,**optimizerParameters)
    else:
        optimizerParameters['learning_rate']=learningRate
        opt=optax.inject_hyperparams(optimizers[optimizer.lower()])(**optimizerParameters)
    if metaOptimizer is not None:
        staticMetaMetaparameters['learning_rate']=metaLearningRate
        metaOpt=optimizers[metaOptimizer.lower()](**staticMetaMetaparameters)
        eta,etaOrder=[],[]
    params=init_theta.copy()
    opt_state = opt.init(params)
    eta=[]
    for kk in parametersToMetaOptimize:
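        if kk=='learning_rate':
            eta.append(0.) # The LR varies logarithmically from base/10 to base*10, so it is initialized at eta=0
        else:
            # Inverse of h=0.8+0.2*sigmoid(eta), which assumes 0.8<h<1. Note that outerLoss maps eta back with a plain sigmoid.
            eta.append(-np.log(0.2 / (opt_state.hyperparams[kk]-0.8) - 1))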
    eta=jnp.array(eta)
    if metaOptimizer is not None:
        meta_state = metaOpt.init(eta)
    schedDict=scheduler(0,params,np.inf,np.inf,0)
    (value, (ploss, floss)),g=jax.value_and_grad(loss,has_aux=True)(jax.lax.stop_gradient(params.copy()),schedDict)
    evo.append({'loss':float(jax.device_put(value,cpus[0])),'ploss':float(jax.device_put(ploss,cpus[0])),'floss':float(jax.device_put(floss,cpus[0])),
                        'params':np.array(jax.device_put(params,cpus[0])),'time': 0,'schedDict':deepcopy(schedDict)})
    if verbose>0:
        print('Beginning optimization with '+optimizer)
    startt=datetime.datetime.now()
    ct,it,ft,pL=datetime.datetime.now()-startt,0,jnp.inf,jnp.inf
    while ct.total_seconds()<=maxTime and it<=maxIterations:# and ft>ftol:
        if metaOptimizer is not None:
            eta, params, value, meta_state, opt_state,plosses,flosses,losses = outerStep(eta, params, meta_state, opt_state,schedDict,parametersToMetaOptimize,opt,metaOpt)
            ct=datetime.datetime.now()-startt
            evo.append({'loss':tuple(jax.device_put(losses,cpus[0])),'ploss':tuple(jax.device_put(plosses,cpus[0])),'floss':tuple(jax.device_put(flosses,cpus[0])),
                        'params':np.array(jax.device_put(params,cpus[0])),'time':ct.total_seconds(),'schedDict':deepcopy(schedDict),'eta':tuple(jax.device_put(eta,cpus[0])),'learningRate':float(opt_state.hyperparams['learning_rate'])})
        else:
            params,opt_state,value,ploss,floss=innerStep(params,opt_state,schedDict,opt)
            ct=datetime.datetime.now()-startt
            evo.append({'loss':float(jax.device_put(value,cpus[0])),'ploss':float(jax.device_put(ploss,cpus[0])),'floss':float(jax.device_put(floss,cpus[0])),
                        'params':np.array(jax.device_put(params,cpus[0])),'time':ct.total_seconds(),'schedDict':deepcopy(schedDict),'learningRate':float(opt_state.hyperparams['learning_rate'])})
        if verbose>0 and it%verbose==0:
            print('Iteration '+str(it)+', time '+str(ct)+': '+str(value)+', LR: '+str(opt_state.hyperparams['learning_rate']))
        if debug is not None and it==debug:
            return evo
        ft=jnp.abs(pL-value)/abs(value)
        it+=1
        schedDict=scheduler(it,params,value,pL,ct.total_seconds())
        pL=value
    if verbose>0:
        print('Optimization done; final cost function '+str(value)+', time '+str(ct)+', iteration '+str(it-1)+', ftol '+str(ft))
    if outputFile is not None:
        po={'evo':evo,'globalParamDict':get_globalParamDict()}
        with open(outputFile,'wb') as f:
            pickle.dump(po,f)
    return evo
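
When a meta-optimizer is used, `outerLoss` turns the unconstrained meta-parameter `eta` into a learning rate by applying a sigmoid and then a base-10 exponent. The standalone sketch below reproduces only that arithmetic in plain Python so the range can be checked by hand. The names `sigmoid`, `eta_to_learning_rate`, and `base_lr` are illustrative and do not appear in the notebook.

```python
import math

def sigmoid(x):
    """Logistic function, mapping any real x into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def eta_to_learning_rate(eta, base_lr):
    """Map an unconstrained eta to a learning rate between base_lr/10 and base_lr*10."""
    s = sigmoid(eta)                            # squashed meta-parameter in (0, 1)
    return base_lr * 10.0 ** (2.0 * (s - 0.5))  # exponent runs from -1 to +1

base_lr = 0.1  # hypothetical base learning rate
for eta in (-10.0, 0.0, 10.0):
    print(eta, eta_to_learning_rate(eta, base_lr))
```

With `eta = 0` the sigmoid gives 0.5, the exponent is zero, and the base learning rate is returned unchanged, which is why the learning-rate entry of `eta` is initialized to zero.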

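Both optimization routines pickle a dictionary with the keys `'evo'` and `'globalParamDict'` when `outputFile` is set. As a rough sketch of how such a file could be read back, the snippet below extracts the time and loss history. The helper `load_history`, the file name `demo_run.pkl`, and the stand-in data are hypothetical and only serve to make the example self-contained.

```python
import os
import pickle
import tempfile

def load_history(path):
    """Return (times, losses) from a file holding {'evo': [...], 'globalParamDict': ...}."""
    with open(path, 'rb') as f:
        po = pickle.load(f)
    evo = po['evo']
    times = [entry['time'] for entry in evo]
    losses = [entry['loss'] for entry in evo]
    return times, losses

# Stand-in record with made-up values, mimicking the layout written by the notebook
demo = {'evo': [{'loss': -0.5, 'time': 0.0}, {'loss': -1.0, 'time': 2.5}],
        'globalParamDict': {}}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, 'demo_run.pkl')
    with open(path, 'wb') as f:
        pickle.dump(demo, f)
    times, losses = load_history(path)

print(times, losses)
```

For runs that use a meta-optimizer, each `'loss'` entry after the first is a tuple of per-inner-step values rather than a single float, so a plotting routine would need to flatten it first.
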
## Run Optimization

In [73]:
ev=optimize(staticMetaparameters,debug=None)

Beginning optimization with adam
Iteration 0, time 0:00:02.642730: -0.37069046, LR: 0.009999998
Iteration 1, time 0:00:05.296836: -0.40233153, LR: 0.064125
Iteration 2, time 0:00:07.960540: -0.6037177, LR: 0.09988732
Iteration 3, time 0:00:10.627572: -0.8826083, LR: 0.09823869
Iteration 4, time 0:00:13.296966: -1.0881629, LR: 0.0946916
Iteration 5, time 0:00:15.967820: -1.228209, LR: 0.089401774
Iteration 6, time 0:00:18.635639: -1.3159552, LR: 0.0826015
Iteration 7, time 0:00:21.298765: -1.3755691, LR: 0.07458933
Iteration 8, time 0:00:23.968054: -1.4227599, LR: 0.06571706
Iteration 9, time 0:00:26.635339: -1.4611628, LR: 0.056374222
Iteration 10, time 0:00:29.302450: -1.4950337, LR: 0.046971064
Iteration 11, time 0:00:31.965456: -1.5228448, LR: 0.037920427
Iteration 12, time 0:00:34.624141: -1.5484986, LR: 0.029619694
Iteration 13, time 0:00:37.273452: -1.566438, LR: 0.022433331
Iteration 14, time 0:00:39.923851: -1.571307, LR: 0.016676862
Iteration 15, time 0:00:42.573465: -1.550591