Let $U$ denote any component of the electric or magnetic field. In a medium with refractive-index profile $n(r)$, $U$ must satisfy the Helmholtz equation:

$$
\nabla^2 U + n^2(r)k_0^2 U=0
$$

In cylindrical coordinates $(r,\phi,z)$, its guided solutions have the form:

$U(r,\phi,z)=u(r)e^{-jl\phi}e^{-j\beta z}$  for  $l=0,\pm 1,\pm 2...$

$$
u(r) = \left\{
\begin{array}{ll}
        J_l(k_T r) & r<a  \\
        K_l(\gamma r) & r>a
\end{array}
\right.
$$

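where $J_l$ is the Bessel function of the first kind and $K_l$ is the modified Bessel function of the second kind, both of order $l$. Here $a$ is the core radius, $k_T^2 = n_{core}^2k_0^2-\beta^2$ is the squared transverse wavenumber in the core and $\gamma^2=\beta^2-n_{clad}^2k_0^2$ is the squared decay constant in the cladding. With the normalized quantities $X=k_T a$ and $Y=\gamma a$, these satisfy $X^2+Y^2=V^2$, where $V=k_0 a\,\mathrm{NA}$ and $\mathrm{NA}^2=n_{core}^2-n_{clad}^2$.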



In [15]:
import bqplot
from bqplot import pyplot as plt
import scipy.special as spe
import numpy as np
import ipywidgets as widgets
from IPython.display import display
import bqplot
from bqplot import pyplot as bqplt

# Short alias for pi, used by the calculations and plots below.
pi = np.pi


class Mode(object):
    """Plain attribute container describing one LP mode of the fiber.

    Instances are created empty; ``find_modes`` attaches the attributes
    (``X``, ``L``, ``M``, ``Er``, ``neff``, ``label``, ``degeneracy``,
    ``V`` and ``a``) after construction.
    """

def zero_func(X, V, L):
    """Residual of the LP-mode characteristic (eigenvalue) equation.

    The guided LP(L, m) modes are the values of ``X`` for which

        X * J_{L+1}(X) / J_L(X) == Y * K_{L+1}(Y) / K_L(Y),  Y = sqrt(V**2 - X**2)

    This returns the left side minus the right side, so its roots are the
    mode solutions. It works element-wise on NumPy arrays as well as on
    scalars.

    Parameters
    ----------
    X : float or ndarray
        Normalized transverse wavenumber in the core (k_T * a), 0 < X < V.
    V : float
        Normalized frequency (V-number) of the fiber.
    L : int
        Azimuthal order of the mode.
    """
    clad_arg = np.sqrt(V**2 - X**2)
    core_side = X * spe.jv(L + 1, X) / (spe.jv(L, X))
    clad_side = clad_arg * spe.kv(L + 1, clad_arg) / (spe.kv(L, clad_arg))
    return core_side - clad_side

def find_zeros_exact(X, Y, V, L):
    """Locate the roots of the LP characteristic equation for order ``L``.

    ``X`` and ``Y`` are matching samples of the core and cladding parameters
    (``find_modes`` builds them so that X**2 + Y**2 == V**2). The residual is
    evaluated on that grid, and every interval where it changes sign becomes
    a candidate bracket. Each accepted bracket is then refined with Brent's
    method applied to ``zero_func``.

    A sign change can also come from a pole of X*J_{L+1}(X)/J_L(X) (a zero of
    J_L) rather than a true root. Such intervals are rejected by comparing
    the jump across the interval with the jump across its two neighbours:
    near a pole the inner jump dominates. The first and last intervals lack
    a neighbour on one side and are always accepted.

    Returns a list of root values of X, in grid order.
    """
    from scipy import optimize

    residual = X*spe.jv(L+1, X)/(spe.jv(L, X)) - Y*spe.kv(L+1, Y)/(spe.kv(L, Y))
    last = len(X) - 2  # index of the final interval [X[-2], X[-1]]

    brackets = []
    for i in range(last + 1):
        if not residual[i]*residual[i+1] < 0:
            continue  # no sign change (or NaN) across this interval
        at_edge = i == 0 or i == last
        if at_edge or abs(residual[i-1] - residual[i+2]) > abs(residual[i] - residual[i+1]):
            brackets.append([X[i], X[i+1]])

    return [optimize.root_scalar(zero_func, args=(V, L), bracket=br, method='brentq').root
            for br in brackets]
                             
   

def find_modes(a=8.2/2, Na=0.12, n_cladding=1.444, w=1.55):
    """Find every guided LP mode of a step-index fiber.

    Parameters
    ----------
    a : float
        Core radius, in the same length unit as ``w`` (the UI passes µm).
    Na : float
        Numerical aperture; the core index is sqrt(Na**2 + n_cladding**2).
    n_cladding : float
        Refractive index of the cladding.
    w : float
        Free-space wavelength.

    Returns
    -------
    modes : list of Mode
        One entry per distinct LP(L, M) mode, with attributes X, L, M, Er
        (radial field normalized to max |Er| == 1), neff, label, degeneracy,
        V and a.
    r : ndarray
        Radial grid (0 to 3*a, 300 points) on which every ``mode.Er`` is sampled.
    tot_modes : int
        Total number of modes counting degeneracy (2 for L == 0, 4 for L > 0),
        so it is usually larger than ``len(modes)``.
    """
    k0 = 2*pi/w
    n_core = np.sqrt(Na**2 + n_cladding**2)

    r = np.linspace(0, 3*a, num=300)
    in_cladding = r >= a  # radial samples outside the core

    V = k0*a*Na
    # Sample X in (0, V) through an angle so that X**2 + Y**2 == V**2; the tiny
    # offsets keep both X and Y strictly positive (K_L diverges at 0).
    phi = np.linspace(1E-10, pi/2 - 1E-10, 5000)
    X = V*np.sin(phi)
    Y = V*np.cos(phi)

    modes = []
    tot_modes = 0
    L = 0
    while True:
        # The residual divides by J_L and K_L and may produce invalid values
        # on the grid; those points never count as sign changes, so silence
        # the warnings.
        with np.errstate(invalid='ignore'):
            sols = find_zeros_exact(X, Y, V, L)

        # Roots of a given order L are numbered M = 1, 2, ... in grid order.
        for M, sol in enumerate(sols, start=1):
            kt = sol/a
            gamma = np.sqrt(V**2 - sol**2)/a
            Er = spe.jv(L, kt*r)
            # Cladding field: K_L decay scaled to be continuous with the core
            # field exactly at the core boundary r = a (as in fig_mode_profile).
            Er[in_cladding] = (spe.kv(L, gamma*r[in_cladding])/spe.kv(L, gamma*a)
                               * spe.jv(L, kt*a))
            Er = Er/np.max(np.abs(Er))

            mode = Mode()
            mode.X = sol
            mode.L = L
            mode.M = M
            mode.Er = Er
            mode.neff = np.sqrt(n_core**2 - (kt/k0)**2)
            mode.label = "LP({0},{1})".format(L, M)
            mode.degeneracy = 2 if L == 0 else 4
            mode.V = V
            mode.a = a
            modes.append(mode)
            tot_modes += mode.degeneracy

        # Stop at the first order without a guided mode; this assumes every
        # higher order is cut off as well.
        if not sols:
            break
        L += 1
    return modes, r, tot_modes

# Normalized Cartesian grid for the 2-D mode-profile plots. Coordinates are in
# units of the core radius, so the core is the region r_mesh <= 1.
xx = np.linspace(-1.7, 1.7, 60)
yy = np.linspace(-1.7, 1.7, 60)
x_mesh, y_mesh = np.meshgrid(xx, yy)
r_mesh = np.sqrt(x_mesh**2 + y_mesh**2)
phi_mesh = np.arctan2(y_mesh, x_mesh)

# Indicator masks (1.0 inside the region, 0.0 outside). They are derived from
# r_mesh so they always have the meshgrid shape (len(yy), len(xx)), even if
# xx and yy are given different lengths.
ones_mesh = np.ones_like(r_mesh)
zeros_mesh = np.zeros_like(r_mesh)
in_core_mesh = np.where(r_mesh > 1, zeros_mesh, ones_mesh)   # ones in the core
in_clad_mesh = np.where(r_mesh <= 1, zeros_mesh, ones_mesh)  # ones in the cladding

# Unit circle used to draw the core perimeter on the profile plots.
phi_core_shape = np.linspace(0, 2*np.pi, 60)
x_core_shape = 1*np.cos(phi_core_shape)
y_core_shape = 1*np.sin(phi_core_shape)




# Cladding refractive index, shared by the mode solver and the info text.
n_cladding = 1.444

# Input controls: core diameter (µm), numerical aperture, wavelength (µm).
slider_diam = widgets.FloatSlider(min=0.1, max=80, value=8.2)
slider_Na = widgets.FloatSlider(min=0.0001, max=0.5, step=0.01, value=0.12)
slider_lambda = widgets.FloatSlider(min=0.5, max=2.0, step=0.01, value=1.55)

# Text outputs and the button that launches the mode calculation.
text_n_core = widgets.Text()
text_results = widgets.Text(value="")
btn_calc = widgets.Button(description='Calculate modes')
btn_calc.style.button_color = 'lightgray'

def update_text(change):
    """Refresh the info line with the cladding/core indices and the V-number.

    Used as an observer callback on the sliders; ``change`` (the change
    notification) is not used, so any value may be passed.
    """
    diam = slider_diam.value
    na = slider_Na.value
    wavelength = slider_lambda.value
    V = 2*pi/wavelength*na*diam/2
    n_core = np.sqrt(na**2 + n_cladding**2)
    # Report the n_cladding actually used for n_core instead of a hard-coded copy.
    text_n_core.value = 'n cladding = {0:.3f}, n core = {1:.3f}, V = {2:.3f}'.format(
        n_cladding, n_core, V)

# Populate the info line once, then keep it in sync with every slider.
update_text(0)
for _slider in (slider_Na, slider_lambda, slider_diam):
    _slider.observe(update_text, names='value')


# Snapshot of the initial slider settings (core radius, NA, wavelength) and
# the corresponding V-number.
# NOTE(review): not referenced elsewhere in this cell (the button handler
# re-reads the sliders); confirm other cells don't use them before removing.
# For comparison, a typical multimode fiber has a 50 µm core and NA = 0.22.
w = slider_lambda.value
Na = slider_Na.value
a = slider_diam.value/2.0
V = 2*pi/w*a*Na
    
    
def fig_mode_profile(mode):
    """Build a small heatmap of the mode's transverse field in the x-y plane.

    The field is evaluated on the module-level normalized grid (``r_mesh``,
    ``phi_mesh``, in units of the core radius). Inside the core it is J_L;
    in the cladding it is a K_L tail matched to the core value at r = a.
    Both carry a cos(L*phi) azimuthal factor. The core perimeter is drawn
    on top. The min/max of the first axis's scale (presumably the heatmap
    color scale) are set symmetric around zero, and all axes are hidden.

    Returns the bqplot Figure.
    """
    L = mode.L
    X = mode.X
    a = mode.a
    kt = X/a
    gamma = np.sqrt(mode.V**2 - X**2)/a

    angular = np.cos(L*phi_mesh)
    core_field = spe.jv(L, kt*a*r_mesh)*angular
    clad_field = spe.kv(L, gamma*a*r_mesh)*angular/spe.kv(L, gamma*a)*spe.jv(L, kt*a)
    field = core_field*in_core_mesh + clad_field*in_clad_mesh
    peak = np.amax(abs(field))

    fig = bqplt.figure(fig_margin=dict(top=10, bottom=10, left=10, right=10))
    fig.layout.width = '140px'
    fig.layout.height = '140px'
    bqplt.heatmap(field, x=xx, y=yy, cmap='RdBu')
    bqplt.plot(x_core_shape, y_core_shape, 'k', stroke_width=.5)
    # Symmetric limits so zero field sits at the middle of the diverging map.
    # NOTE(review): assumes axes[0] belongs to the heatmap's color scale — confirm.
    first_scale = fig.axes[0].scale
    first_scale.min = -peak
    first_scale.max = peak
    for axis in fig.axes:
        axis.visible = False
    return fig


def fig_rad_plot(mode, r):
    """Build a small line plot of the mode's normalized radial field.

    Draws ``mode.Er`` against ``r`` on the range 0..3*mode.a. It adds a
    vertical line at the core radius and a horizontal zero line. Ticks are
    suppressed and the x axis is hidden to keep the thumbnail compact.

    Returns the bqplot Figure.
    """
    lo = min(mode.Er)
    hi = max(mode.Er)
    r_max = 3*mode.a
    x_scale = bqplot.LinearScale(min=0, max=r_max)
    y_scale = bqplot.LinearScale(min=lo, max=hi)

    field_line = bqplt.Lines(x=r, y=mode.Er, scales={'x': x_scale, 'y': y_scale})
    core_edge = bqplt.Lines(x=[mode.a, mode.a], y=[lo, hi],
                            scales={'x': x_scale, 'y': y_scale},
                            stroke_width=0.5, colors=['black'])
    zero_axis = bqplt.Lines(x=[0, r_max], y=[0, 0],
                            scales={'x': x_scale, 'y': y_scale},
                            stroke_width=0.5, colors=['black'])
    x_axis = bqplt.Axis(orientation='horizontal', scale=x_scale,
                        tick_values=[], num_ticks=0, visible=False)
    y_axis = bqplt.Axis(orientation='vertical', scale=y_scale,
                        tick_values=[], num_ticks=0)

    fig = bqplt.figure(axes=[x_axis, y_axis], marks=[field_line, core_edge, zero_axis])
    fig.fig_margin = dict(top=10, bottom=10, left=10, right=10)
    fig.layout.height = '140px'
    fig.layout.width = '140px'
    return fig
   
   
def update_table(modes, r):
    """Fill the first ``num_modes_show`` table rows and display the table.

    Each row receives a text summary (index, label, degeneracy, effective
    index), the radial-field plot and the 2-D profile plot. The function
    reads the module-level ``num_modes_show``, ``mode_row_list`` and
    ``box_modes``.
    """
    for idx in range(num_modes_show):
        mode = modes[idx]
        summary = widgets.Textarea(
            value='Mode {0}:\n{1}\nDegeneracy {2}\nneff = {3:0.6f}'.format(
                idx + 1, mode.label, mode.degeneracy, mode.neff))
        summary.layout.display = 'flex'
        summary.layout.height = '140px'
        mode_row_list[idx].children = [summary, fig_rad_plot(mode, r), fig_mode_profile(mode)]
    box_modes.children = mode_row_list[:len(modes)]

# Number of table rows populated by the last calculation (at most
# len(mode_row_list)); read by update_table and by the next button click.
num_modes_show = 0


def btn_eventhandler(obj):
    """Recompute the fiber modes from the current slider values and redraw.

    Registered with ``btn_calc.on_click``; ``obj`` (the clicked button) is not
    used. It disposes of the widgets shown by the previous calculation, solves
    for the modes and sorts them by decreasing effective index. It then
    updates the summary text and refills the mode table.
    """
    global num_modes_show
    box_modes.children = []
    # Dispose of the previous rows' widgets. bqplot.pyplot.close() expects a
    # figure-registry *key*, and these figures were created without one, so
    # passing the Figure objects to it did nothing; close the widgets directly.
    for row in mode_row_list[:num_modes_show]:
        for child in row.children:
            child.close()

    w = slider_lambda.value
    diam_core = slider_diam.value  # um
    Na = slider_Na.value
    a = diam_core/2

    modes, r, tot_modes = find_modes(a=a, Na=Na, w=w, n_cladding=n_cladding)
    modes.sort(key=lambda m: m.neff, reverse=True)

    num_modes = len(modes)
    text_results.value = "Distinct modes: {}. Total modes: {}".format(num_modes, tot_modes)
    # Only as many modes as there are pre-built rows can be shown.
    num_modes_show = min(num_modes, len(mode_row_list))
    update_table(modes, r)





# Control panel: labels in columns 0-3, inputs in columns 4-9, and the
# derived-parameter readout to the right of the NA row. Raw strings keep the
# LaTeX backslash without relying on the invalid escape sequence "\m".
grid = widgets.GridspecLayout(4, 20)

grid[0, :4] = widgets.Label(r'Core diameter ($\mu$m)')
grid[1, :4] = widgets.Label('Num. Aperture (NA)')
grid[2, :4] = widgets.Label(r'Wavelength ($\mu$m)')
grid[0, 4:10] = slider_diam
grid[1, 4:10] = slider_Na
grid[2, 4:10] = slider_lambda
grid[3, :4] = btn_calc
grid[3, 4:10] = text_results
grid[1, 11:] = text_n_core



# Placeholder widget shown in rows that have not been filled yet.
mode_label = widgets.Text()

# Pre-build 100 fixed-size rows; update_table fills in as many as it needs.
mode_row_layout = widgets.Layout(width='600px', height='150px', border='solid 1px')
mode_row_list = [widgets.HBox([mode_label] * 3, layout=mode_row_layout)
                 for _ in range(100)]

# Vertical container holding the populated rows.
mode_list_layout = widgets.Layout(width='600px', height='', flex_flow='column', display='flex')
box_modes = widgets.Box(children=[], layout=mode_list_layout)

display(grid)
display(box_modes)

btn_calc.on_click(btn_eventhandler)




GridspecLayout(children=(Label(value='Core diameter ($\\mu$m)', layout=Layout(grid_area='widget001')), Label(v…

Box(layout=Layout(display='flex', flex_flow='column', height='', width='600px'))