The source code is in the public domain and not licensed or under
copyright. The information and software may be used freely by the public.
As required by 17 U.S.C. 403, third parties producing copyrighted works
consisting predominantly of the material produced by U.S. government
agencies must provide notice with such work(s) identifying the U.S.
Government material incorporated and stating that such material is not
subject to copyright protection.

Derived works shall not identify themselves in a manner that implies an
endorsement by or an affiliation with the Naval Research Laboratory.

RECIPIENT BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE
SOFTWARE AND ANY RELATED MATERIALS, AND AGREES TO INDEMNIFY THE NAVAL
RESEARCH LABORATORY FOR ALL THIRD-PARTY CLAIMS RESULTING FROM THE ACTIONS
OF RECIPIENT IN THE USE OF THE SOFTWARE.

In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import glob
import json
import ipywidgets
import warnings
from scipy import constants as C
import scipy.interpolate
import IPython.display
#%matplotlib widget
# Layout of the control panel relative to the figure: 'left' or 'bottom'.
control_layout = 'left'
image_height_px = 700

# Globals left over from an earlier run of this cell; remove any that exist so
# stale figures or arrays are not picked up by the callbacks defined below.
undefine_list = ['ax','fig','lineout','image','bar','A','plot_data','real_field','data_ext','plot_ext','lbl5','qty_dict','init_fig']
for gv in undefine_list:
    globals().pop(gv, None)

# Find the data: simulation metadata plus the UPPE and paraxial field/source
# arrays, collected pattern by pattern in a fixed order (wave, j, chi, plasma).

base_diagnostic = 'out/'
sim_list = glob.glob(base_diagnostic+'*_sim.json')
uppe_list = [path
             for pattern in ('*_uppe_wave*.npy', '*_uppe_j*.npy', '*_uppe_chi*.npy', '*_uppe_plasma*.npy')
             for path in glob.glob(base_diagnostic + pattern)]
para_list = [path
             for pattern in ('*_paraxial_wave*.npy', '*_paraxial_j*.npy', '*_paraxial_chi*.npy', '*_paraxial_plasma*.npy')
             for path in glob.glob(base_diagnostic + pattern)]
total_list = uppe_list + para_list

# Units: the normalization length is read from the first *_sim.json file.

if not sim_list:
    raise FileNotFoundError("no simulation metadata")
if len(sim_list) > 1:
    warnings.warn("Multiple metadata files, normalization is based on " + sim_list[0])
with open(sim_list[0]) as f:
    sim_obj = json.load(f)
# presumably metres per simulation length unit (name suggests MKS) — confirm
mks_length = sim_obj['mks_length']

# One simulation length unit in mm, and the light-crossing time of it in ps.
l1_mm = 1e3*mks_length
t1_ps = 1e12*mks_length/C.c
mpl.rcParams['font.size'] = 12

# Control widgets. The file dropdown starts on the first data file, so fail
# clearly (like the metadata check) when none were found.
if not total_list:
    raise FileNotFoundError("no field data found in " + base_diagnostic)

fig_pane = ipywidgets.Output()
file_w = ipywidgets.Dropdown(options=total_list,value=total_list[0],description='File')
# Index sliders for axes 0-3 of the data and cyclic roll amounts; their ranges
# are set later from the data shape.
zslice_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
xslice_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
yslice_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
tslice_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
hroll_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
vroll_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
ax_w = ipywidgets.Dropdown(options=['Falsecolor-tx','Falsecolor-xy','Falsecolor-zx','Lineout','Power'],value='Falsecolor-tx',description='Display')
rep_w = ipywidgets.Dropdown(options=['A(t)'],value='A(t)',description='Representation') # options generated on the fly
color_w = ipywidgets.Dropdown(options=['viridis','gray','jet','plasma','inferno','ocean','seismic','bwr','prism','flag',
                                      'nipy_spectral','gist_ncar','gist_earth'],value='viridis',description='Color',disabled=False)
shape_w = ipywidgets.Textarea(value='no data',description='Shape')
bounds_w = ipywidgets.Textarea(value='no data',description='Bounds')
lim1_w = ipywidgets.FloatText(value=-1.0,description='Min')
lim2_w = ipywidgets.FloatText(value=1.0,description='Max')
dpi_w = ipywidgets.FloatText(value=100.0,description='DPI')
fig_size_w = ipywidgets.Text(value='7,5',description='size (w,h)')
bar_orient_w = ipywidgets.Dropdown(options=['horizontal','vertical'],value='vertical',description='Colorbar')
tex_w = ipywidgets.Checkbox(value=False,description='Use TeX')
transform_list = ['abs(x)','Log_e(abs(x))','Log_10(abs(x))','x^2','-x','const*x']
transform_select_w = ipywidgets.Dropdown(options=transform_list,value=transform_list[0],layout=ipywidgets.Layout(width='100px'))
transform_w = ipywidgets.Button(description='Transform')
autoscale_frame_w = ipywidgets.Button(description='Autoscale Frame')
autoscale_all_w = ipywidgets.Button(description='Autoscale All')

def load_data(file):
    """Load a data array and its extents, and reset the widgets to match.

    ``file`` is one of the paths in ``uppe_list``/``para_list``. Its extents
    are read from the sibling file obtained by replacing the final
    ``_<kind>.npy`` component with ``_plot_ext.npy``.

    Sets the globals ``A`` (loaded array; axes 0-3 are driven by the t/x/y/z
    sliders), ``data_ext`` (entries 2:8 rescaled to mm), ``real_field``
    (True for UPPE files), ``qty_dict`` (representation name -> label) and
    ``init_fig``. Also repopulates ``rep_w`` and the slice sliders.

    Raises:
        ValueError: if the quantity kind (wave/j/chi/plasma) cannot be
            determined from the final ``_`` component of ``file``.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    init_fig = True
    real_field = file in uppe_list
    # if a field is inherently real, but went through complex FFT, call it complexified
    complexified_field = not real_field and ('_chi.npy' in file or '_plasma.npy' in file)
    A = np.load(file)
    # 'out/run_uppe_wave.npy' -> stem 'out/run_uppe', kind 'wave.npy'
    stem = file.rsplit('_', 1)[0]
    kind = file.rsplit('_', 1)[-1]
    data_ext = np.load(stem + '_plot_ext.npy')
    data_ext[2:8] *= l1_mm
    bounds_w.value = str(data_ext)
    shape_w.value = str(A.shape)
    # Widget assignments may trigger redraw callbacks; order kept as-is.
    tslice_w.max = A.shape[0]-1
    xslice_w.max = A.shape[1]-1
    yslice_w.max = A.shape[2]-1
    zslice_w.max = A.shape[3]-1
    tslice_w.value = int(A.shape[0]/2)
    xslice_w.value = int(A.shape[1]/2)
    yslice_w.value = int(A.shape[2]/2)
    zslice_w.value = 0
    rep_w.options = ('A(t)','|A(t)|','I(t)','|E(w)|^2',r'arg{E(w)}','|E(lambda)|^2')
    # Pick labels from the final '_' component only, so words such as 'plasma'
    # or 'wave' elsewhere in the path cannot select the wrong label set.
    if kind.startswith('wave'):
        temp = [ r'$eA(t)/mc^2$' , r'$e|A(t)|/mc^2$' , r'$I(t)$ (W$/$cm$^2$)' , r'$|E|^2(\omega)$' , r'arg$\{E(\omega)\}$' , r'$|E|^2(\lambda)$' ]
    elif kind.startswith('j'):
        temp = [ r'$J(t)/n_cec$' ,  r'$|J(t)|/n_cec$' , r'?' , r'$\omega^2|J|^2(\omega)$' , r'arg$\{i\omega J(\omega)\}$' , r'?' ]
    elif kind.startswith('chi'):
        temp = [ r'$\chi(t)$' , r'?' , r'?' , r'$\omega^2|\chi|^2(\omega)$' , r'arg$\{i\omega\chi(\omega)\}$' , r'?' ]
    elif kind.startswith('plasma'):
        temp = [ r'$n_e(t)/n_c$' , r'?' , r'?' , r'$\omega^2|n_e|^2(\omega)$' , r'arg$\{i\omega n_e(\omega)\}$' , r'?' ]
    else:
        raise ValueError("cannot determine quantity type of " + file)
    if not complexified_field:
        # insert \Re just inside the opening '$' of the time-domain label
        temp[0] = temp[0][0] + r'\Re ' + temp[0][1:]
    qty_dict = dict(zip(rep_w.options, temp))
    if real_field or complexified_field:
        rep_w.value = 'A(t)'
    else:
        rep_w.value = '|A(t)|'

def load(file):
    """Load ``file``, rebuild the current representation and redraw the frame."""
    load_data(file)
    change_rep_data(rep_w.value)
    sliders = (tslice_w, xslice_w, yslice_w, zslice_w, hroll_w, vroll_w)
    disp_frame(*[s.value for s in sliders])

def get_freq_time_data(A,ext):
    """Return grid quantities for the frequency axis (axis 0) of ``A``.

    ``ext[0]`` and ``ext[1]`` are the outer cell walls of that axis, which
    holds ``A.shape[0]`` nodes. The same formulas serve both node layouts:

    * real fields: nodes like [0,1,2,3], walls like [-0.5,0.5,1.5,2.5,3.5]
      (a user-requested upper bound of 4 is thrown out);
    * envelope fields: nodes like [-2,-1,0,1], walls like [-2.5,...,1.5].

    Returns:
        (dw, wc, tmax, w_nodes): node spacing; the midpoint of
        ``[ext[0], ext[1] + dw]``; the time window ``2*pi/dw``; and the node
        (cell-centre) frequencies.
    """
    n_nodes = A.shape[0]
    lo, hi = ext[0], ext[1]
    dw = (hi - lo)/n_nodes
    wc = 0.5*lo + 0.5*(hi + dw)
    tmax = 2*np.pi/dw
    w_nodes = np.linspace(lo + dw/2, hi - dw/2, n_nodes)
    return dw, wc, tmax, w_nodes

def change_rep_data(rep):
    """Compute ``plot_data`` for representation ``rep`` of the loaded array ``A``.

    ``rep`` is one of the ``rep_w`` options set by ``load_data``. The
    time-domain representations inverse-FFT ``A`` along axis 0 and reverse
    that axis. The spectral ones use ``A`` multiplied by the node frequencies.

    Also sets the globals ``plot_ext`` (axis-0 extent replaced by a time,
    frequency or wavelength range), ``lbl5`` (labels for axes 0-3 plus the
    plotted quantity) and ``init_fig``.

    Raises:
        ValueError: for ``'|E(lambda)|^2'`` on an envelope (non-real) field.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    init_fig = True
    plot_ext = np.copy(data_ext)
    # Cartesian alternative: [r'$x_1$ (mm)',r'$x_2$ (mm)',r'$x_3$ (mm)']
    lbl3 = [r'$\varrho$ (mm)',r'$\varphi$',r'$z$ (mm)']
    dw,wc,tmax,wn = get_freq_time_data(A,data_ext)
    # node frequencies shaped to broadcast along axis 0 of the 4D array
    wnx = wn[...,np.newaxis,np.newaxis,np.newaxis]
    plot_ext[0:2] = [0.0,t1_ps*tmax]
    qty_label_now = [qty_dict[rep]]
    lbl5 = [r'$t-z/c$ (ps)'] + lbl3 + qty_label_now
    if rep == 'A(t)':
        if real_field:
            plot_data = np.fft.irfft(A,axis=0)[::-1,...]
        else:
            plot_data = np.real(np.fft.ifft(np.fft.ifftshift(A,axes=0),axis=0))[::-1,...]
    elif rep == 'I(t)':
        # Presumably converts the normalized field to V/m, then |E|^2/377 ohm
        # (free-space impedance) gives W/m^2 and /1e4 gives W/cm^2 — confirm.
        field_to_si = C.m_e*C.c*(C.c/mks_length)/C.e
        if real_field:
            plot_data = (np.fft.irfft(A*wnx,axis=0)[::-1,...])**2
            plot_data *= field_to_si**2 / 377 / 1e4
        else:
            plot_data = (np.abs(np.fft.ifft(np.fft.ifftshift(A*wnx,axes=0),axis=0))[::-1,...])**2
            plot_data *= field_to_si**2 / (2*377) / 1e4
    elif rep == '|A(t)|':
        if real_field:
            # Roll the strongest frequency node (sampled at x=y=z index 0) to
            # index 0 before the complex inverse FFT (presumably to remove the
            # carrier). A is 4D, so all four indices are needed: A[:,0,0] is 2D
            # and argmax would return a flattened (w,z) index.
            idx = np.argmax(np.abs(A[:,0,0,0]))
            plot_data = np.abs(np.fft.ifft(np.roll(A,-idx,axis=0),axis=0))[::-1,...]
        else:
            plot_data = np.abs(np.fft.ifft(np.fft.ifftshift(A,axes=0),axis=0))[::-1,...]
    elif rep == '|E(w)|^2':
        plot_ext[0:2] = [wn[0],wn[-1]]
        lbl5 = [r'$\omega/\omega_0$'] + lbl3 + qty_label_now
        plot_data = np.abs(A*wnx)**2
    elif rep == r'arg{E(w)}':
        plot_ext[0:2] = [wn[0],wn[-1]]
        lbl5 = [r'$\omega/\omega_0$'] + lbl3 + qty_label_now
        plot_data = np.angle(A*wnx)
    elif rep == '|E(lambda)|^2':
        if not real_field:
            raise ValueError("not supported for envelope fields")
        max_microns = 20
        lbl5 = [r'$\lambda$ ($\mu$m)'] + lbl3 + qty_label_now
        # Skip the lowest node (presumably zero frequency) and reverse so that
        # wavelength (in microns) ascends along axis 0.
        spectrum = (np.abs(A*wnx)**2)[1:][::-1]
        l_array_nu = (1000*l1_mm*2*np.pi/wn[1:])[::-1]
        # interp1d raises outside the sampled range, so never go past the data.
        l_max = min(max_microns, l_array_nu[-1])
        l_array = np.linspace(l_array_nu[0],l_max,spectrum.shape[0])
        # Resample every pixel at once; 1/lambda^2 is the d(omega)/d(lambda) Jacobian.
        resample = scipy.interpolate.interp1d(l_array_nu,spectrum,axis=0)
        plot_data = resample(l_array)/l_array[:,np.newaxis,np.newaxis,np.newaxis]**2
        plot_ext[0:2] = [l_array[0],l_array[-1]]

def change_rep(rep):
    """Callback for ``rep_w``: rebuild ``plot_data`` for ``rep`` and redraw."""
    change_rep_data(rep)
    positions = [w.value for w in (tslice_w, xslice_w, yslice_w, zslice_w, hroll_w, vroll_w)]
    disp_frame(*positions)

def change_axes_data(ax):
    """Configure the slice and roll sliders for display mode ``ax``.

    Enables the sliders that choose the fixed indices for the selected view,
    disables the others, and sets the ``hroll_w``/``vroll_w`` ranges to the
    length of the axis drawn horizontally/vertically. Sizes are taken from
    the current global ``plot_data``, so a representation must already have
    been computed.

    NOTE(review): the slice widgets are bound to ``disp_frame`` through an
    ``ipywidgets.interactive``, so each assignment below may trigger a redraw;
    the statement order is deliberate.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    init_fig = True
    # start every new view unrolled
    hroll_w.value = 0
    vroll_w.value = 0
    if ax=='Falsecolor-tx':
        # image of axis 0 (horizontal) vs axis 1 (vertical); y and z are fixed
        tslice_w.disabled = True
        xslice_w.disabled = True
        yslice_w.disabled = False
        zslice_w.disabled = False
        yslice_w.max = plot_data.shape[2]-1
        yslice_w.value = int(plot_data.shape[2]/2)
        hroll_w.max = plot_data.shape[0]-1
        vroll_w.max = plot_data.shape[1]-1
    if ax=='Falsecolor-xy':
        # image of axis 1 (horizontal) vs axis 2 (vertical); t and z are fixed
        tslice_w.disabled = False
        xslice_w.disabled = True
        yslice_w.disabled = True
        zslice_w.disabled = False
        tslice_w.max = plot_data.shape[0]-1
        tslice_w.value = int(plot_data.shape[0]/2)
        hroll_w.max = plot_data.shape[1]-1
        vroll_w.max = plot_data.shape[2]-1
    if ax=='Falsecolor-zx':
        # image of axis 3 (horizontal) vs axis 1 (vertical); t and y are fixed
        tslice_w.disabled = False
        xslice_w.disabled = True
        yslice_w.disabled = False
        zslice_w.disabled = True
        tslice_w.max = plot_data.shape[0]-1
        tslice_w.value = int(plot_data.shape[0]/2)
        hroll_w.max = plot_data.shape[3]-1
        vroll_w.max = plot_data.shape[1]-1
    if ax=='Lineout':
        # 1D plot along axis 0 at fixed x, y, z; no vertical roll
        tslice_w.disabled = True
        xslice_w.disabled = False
        yslice_w.disabled = False
        zslice_w.disabled = False
        xslice_w.max = plot_data.shape[1]-1
        xslice_w.value = int(plot_data.shape[1]/2)
        yslice_w.max = plot_data.shape[2]-1
        yslice_w.value = int(plot_data.shape[2]/2)
        hroll_w.max = plot_data.shape[0]-1
        vroll_w.max = 0
    if ax=='Power':
        # 1D plot along axis 0 of data summed over axes 1 and 2 at fixed z
        tslice_w.disabled = True
        xslice_w.disabled = True
        yslice_w.disabled = True
        zslice_w.disabled = False
        hroll_w.max = plot_data.shape[0]-1
        vroll_w.max = 0

def change_axes(ax):
    """Callback for ``ax_w``: reconfigure the sliders for view ``ax`` and redraw."""
    global init_fig
    init_fig = True
    change_axes_data(ax)
    positions = [w.value for w in (tslice_w, xslice_w, yslice_w, zslice_w, hroll_w, vroll_w)]
    disp_frame(*positions)

def transform_data(the_button):
    """Apply the transform selected in ``transform_select_w`` to ``plot_data`` and redraw.

    Transforms compound: each click acts on the already-transformed data. The
    log transforms add a tiny offset so zeros stay finite, and prefix the
    quantity label ``lbl5[4]``.

    ``the_button`` is the clicked Button; it is not used.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    init_fig = True
    smallest_num = 1e-25
    choice = transform_select_w.value
    # The quantity labels usually start with '$'. The space after the prefix
    # keeps '$..$' + '$..$' from forming '$$', which TeX (the 'Use TeX' option)
    # reads as the start of display math.
    if choice=='Log_e(abs(x))':
        lbl5[4] = r'$\log_{e}$ ' + lbl5[4]
        plot_data = np.log(smallest_num+np.abs(plot_data))
    elif choice=='Log_10(abs(x))':
        lbl5[4] = r'$\log_{10}$ ' + lbl5[4]
        plot_data = np.log10(smallest_num+np.abs(plot_data))
    elif choice=='x^2':
        plot_data = plot_data**2
    elif choice=='-x':
        plot_data = -plot_data
    elif choice=='abs(x)':
        plot_data = np.abs(plot_data)
    elif choice=='const*x':
        # NOTE(review): the meaning of 23900 is not documented in this file — confirm.
        plot_data = plot_data * 23900
    disp_frame(tslice_w.value,xslice_w.value,yslice_w.value,zslice_w.value,hroll_w.value,vroll_w.value)

def auto_scale_frame(the_button):
    """Set the Min/Max widgets to the data range of the frame currently displayed.

    The frame is selected the same way ``disp_frame`` selects it for each
    view, including 'Power' (axis-0 profile summed over axes 1 and 2 at the
    current z). ``the_button`` is the clicked Button; it is not used.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    init_fig = True
    t = tslice_w.value
    x = xslice_w.value
    y = yslice_w.value
    z = zslice_w.value
    view = ax_w.value
    if view=='Falsecolor-tx':
        frame = plot_data[:,:,y,z]
    elif view=='Falsecolor-xy':
        frame = plot_data[t,:,:,z]
    elif view=='Falsecolor-zx':
        frame = plot_data[t,:,y,:]
    elif view=='Lineout':
        frame = plot_data[:,x,y,z]
    elif view=='Power':
        frame = np.sum(plot_data[...,z],axis=(1,2))
    else:
        return
    lim1_w.value = np.min(frame)
    lim2_w.value = np.max(frame)

def auto_scale_all(the_button):
    """Set the Min/Max widgets to the range of the whole ``plot_data`` array.

    ``the_button`` is the clicked Button; it is not used.
    """
    global init_fig
    init_fig = True
    data_min = np.min(plot_data)
    lim1_w.value = data_min
    data_max = np.max(plot_data)
    lim2_w.value = data_max

def force_update(color,lim1,lim2,tex,dpi,fig_size,bar_orient):
    """Redraw after a styling widget changes.

    The arguments mirror the styling widgets this is bound to. They are
    ignored because ``disp_frame`` reads those widgets directly.
    """
    global init_fig
    init_fig = True
    positions = [w.value for w in (tslice_w, xslice_w, yslice_w, zslice_w, hroll_w, vroll_w)]
    disp_frame(*positions)
    
def disp_frame(t,x,y,z,hroll,vroll):
    """Draw the current view of ``plot_data`` into ``fig_pane``.

    ``t``, ``x``, ``y``, ``z`` index axes 0-3 of ``plot_data``; indices of
    the axes being displayed are ignored. ``hroll``/``vroll`` cyclically
    shift the data along the horizontal/vertical plot axis. Colormap,
    limits, DPI, size, colorbar orientation and TeX use are read from the
    widgets.

    A new figure is created on every call and this function never clears
    ``init_fig``, so the first branch of each ``if init_fig`` is the one
    used. The ``else`` branches would act on artists of a closed figure.
    """
    global A, plot_data, real_field, data_ext, plot_ext, lbl5, qty_dict, init_fig
    global ax, fig, lineout, image, bar
    fig_pane.clear_output(wait=True)
    mpl.rcParams['text.usetex'] = tex_w.value
    plt.close('all')
    with fig_pane:
        fig,ax = plt.subplots(constrained_layout=True,dpi=dpi_w.value)
        size = fig_size_w.value.split(',')
        fig.set_figwidth(float(size[0]))
        fig.set_figheight(float(size[1]))
        aspect_ratio = 'auto'
        if ax_w.value=='Lineout' or ax_w.value=='Power':
            if ax_w.value=='Lineout':
                slice_data = np.roll(plot_data[:,x,y,z],hroll)
                ylab = lbl5[4]
            else:
                # sum over axes 1 and 2 at the chosen z
                slice_data = np.roll(np.sum(plot_data[...,z],axis=(1,2)),hroll)
                ylab = 'integrated'
            if init_fig:
                hor = np.linspace(plot_ext[0],plot_ext[1],slice_data.shape[0])
                lineout, = ax.plot(hor,slice_data)
                ax.set_ylim(lim1_w.value,lim2_w.value)
                ax.set_xlabel(lbl5[0],size=12)
                ax.set_ylabel(ylab,size=12)
            else:
                lineout.set_ydata(slice_data)
        if ax_w.value[:10]=='Falsecolor':
            # The last two letters of the view name give the horizontal and
            # vertical data axes, e.g. 'Falsecolor-zx' -> h=3, v=1.
            axis_map = { 't' : 0 , 'x' : 1 , 'y' : 2 , 'z' : 3 }
            h = axis_map[ax_w.value[-2]]
            v = axis_map[ax_w.value[-1]]
            # Integer-index the two fixed axes so the result is always 2D, even
            # when a displayed axis has length 1 (np.squeeze would drop it).
            index = [t,x,y,z]
            index[h] = slice(None)
            index[v] = slice(None)
            slice_data = plot_data[tuple(index)]
            # slice_data keeps the original axis order, so the horizontal axis
            # is axis 0 only when h < v.
            h_axis, v_axis = (0, 1) if h < v else (1, 0)
            slice_data = np.roll(slice_data,hroll,axis=h_axis)
            slice_data = np.roll(slice_data,vroll,axis=v_axis)
            this_plot_ext = tuple(plot_ext[h*2:h*2+2]) + tuple(plot_ext[v*2:v*2+2])
            if h<v:
                # imshow draws axis 0 vertically, so move the horizontal axis to axis 1
                slice_data = slice_data.swapaxes(0,1)
            if init_fig:
                image = ax.imshow(slice_data,vmin=lim1_w.value,vmax=lim2_w.value,origin='lower',cmap=color_w.value,aspect=aspect_ratio,extent=this_plot_ext)
                bar = fig.colorbar(image,ax=ax,orientation=bar_orient_w.value)
                ax.set_xlabel(lbl5[h],size=12)
                ax.set_ylabel(lbl5[v],size=12)
                bar.set_label(lbl5[4],size=12)
            else:
                image.set_data(slice_data)
        plt.show()

# Bind widgets to callbacks. These interactive objects are never displayed
# themselves; changing a bound widget's value re-invokes the bound function.
rep_iw = ipywidgets.interactive(change_rep, rep=rep_w)
ax_iw = ipywidgets.interactive(change_axes, ax=ax_w)
load_iw = ipywidgets.interactive(load, file=file_w)
update_iw = ipywidgets.interactive(
    force_update,
    color=color_w, lim1=lim1_w, lim2=lim2_w, tex=tex_w,
    dpi=dpi_w, fig_size=fig_size_w, bar_orient=bar_orient_w,
)
advance_iw = ipywidgets.interactive(
    disp_frame,
    t=tslice_w, x=xslice_w, y=yslice_w, z=zslice_w,
    hroll=hroll_w, vroll=vroll_w,
)

transform_w.on_click(transform_data)
autoscale_frame_w.on_click(auto_scale_frame)
autoscale_all_w.on_click(auto_scale_all)

# Assemble the control panel and place it beside or below the figure.
autoscale_box = ipywidgets.HBox([autoscale_frame_w, autoscale_all_w])
transform_box = ipywidgets.HBox([transform_select_w, transform_w])
left_box = ipywidgets.VBox([
    file_w, tslice_w, xslice_w, yslice_w, zslice_w, hroll_w, vroll_w,
    ax_w, rep_w, color_w, shape_w, bounds_w, lim1_w, lim2_w,
    dpi_w, fig_size_w, bar_orient_w, tex_w, transform_box, autoscale_box,
])

main_view = (ipywidgets.HBox([left_box, fig_pane]) if control_layout == 'left'
             else ipywidgets.VBox([fig_pane, left_box]))

# Initial state: load the first file, build its representation and view,
# fit the colour limits to all data, then show the UI.
load_data(file_w.value)
change_rep_data(rep_w.value)
change_axes_data(ax_w.value)
auto_scale_all(autoscale_all_w)
IPython.display.display(main_view)