# In [21]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import glob
import ipywidgets
from scipy import constants as C
import scipy.interpolate
import IPython.display

# Length unit of the diagnostic data in metres: lambda0/(2*pi) for
# lambda0 = 0.8 um (alternative 10.6 um value kept commented below).
mks_length = 0.8e-6/(2*np.pi)
#mks_length = 10.6e-6/(2*np.pi)
base_diagnostic = 'out/test'
# Files collected in uppe_list are treated as real fields by load_data.
uppe_list = glob.glob(base_diagnostic+'*_uppe_wave*.npy')
uppe_list += glob.glob(base_diagnostic+'*_uppe_source*.npy')
uppe_list += glob.glob(base_diagnostic+'*_uppe_plasma*.npy')
para_list = glob.glob(base_diagnostic+'*_paraxial_wave*.npy')
total_list = uppe_list + para_list
if not total_list:
    # Fail with an explicit message rather than an IndexError on total_list[0].
    raise FileNotFoundError('no diagnostic files found matching '+base_diagnostic+'*_uppe_*.npy or '
                            +base_diagnostic+'*_paraxial_wave*.npy')

mpl.rcParams['text.usetex'] = False
mpl.rcParams['font.size'] = 24
# Length unit expressed in mm, and the matching time unit (length/c) in ps.
l1_mm = 1e3*mks_length
t1_ps = 1e12*mks_length/C.c

file_w = ipywidgets.Dropdown(options=total_list,value=total_list[0],description='File')
frame_w = ipywidgets.IntSlider(min=0,max=0,step=1,value=0,continuous_update=False)
disp_w = ipywidgets.Dropdown(options=['Falsecolor2D','Lineout'],value='Falsecolor2D',description='Display')
rep_w = ipywidgets.Dropdown(options=['Time','Frequency','Wavelength'],value='Time',description='Representation')
color_w = ipywidgets.Dropdown(options=['viridis','gray','jet','plasma','inferno','ocean','seismic','bwr','prism',
                                      'nipy_spectral'],value='viridis',description='Color')
shape_w = ipywidgets.Textarea(value='no data',description='Shape')
bounds_w = ipywidgets.Textarea(value='no data',description='Bounds')
lim1_w = ipywidgets.FloatText(value=-1.0,description='Min')
lim2_w = ipywidgets.FloatText(value=1.0,description='Max')
transform_list = ['abs(x)','Log_e(abs(x))','Log_10(abs(x))','x^2','-x']
transform_select_w = ipywidgets.Dropdown(options=transform_list,value=transform_list[0],layout=ipywidgets.Layout(width='100px'))
transform_w = ipywidgets.Button(description='Transform')
autoscale_frame_w = ipywidgets.Button(description='Autoscale Frame')
autoscale_all_w = ipywidgets.Button(description='Autoscale All')
# Read the logo and close the file immediately; the handle was previously leaked.
with open('docs/logo.png','rb') as logo:
    image_w = ipywidgets.Image(value=logo.read(),format='png')

def load_data(file):
    """Read a diagnostic file and publish its contents as module globals.

    Globals written:
      A          -- the array stored in *file* (axis 3 indexes frames)
      real_field -- True when *file* is one of the UPPE files in uppe_list
      data_ext   -- extents from the matching *_plot_ext.npy file, with
                    entries 2:4 multiplied by l1_mm
      qty_label  -- (time, frequency, wavelength) colorbar labels, chosen by
                    the 'wave' / 'source' / 'plasma' substring of the name
    Also refreshes the Bounds and Shape text boxes and the frame slider range.
    """
    global A,real_field,data_ext,qty_label
    real_field = file in uppe_list
    A = np.load(file)
    # Extent file = first three '_'-separated fields + '_plot_ext.npy', e.g.
    # out/test_air_uppe_wave.npy -> out/test_air_uppe_plot_ext.npy
    parts = file.split('_')
    ext_file = '_'.join((parts[0], parts[1], parts[2])) + '_plot_ext.npy'
    raw_ext = np.load(ext_file)
    data_ext = np.concatenate((raw_ext[0:2], l1_mm*raw_ext[2:4]))
    bounds_w.value = str(data_ext)
    shape_w.value = str(A.shape)
    frame_w.value = 0
    frame_w.max = A.shape[3]-1
    labels_by_kind = (
        ('wave', (r'$eA(t)/mc^2$', r'$|E|^2(\omega)$', r'$|E|^2(\lambda)$')),
        ('source', (r'$J(t)$', r'$|J|(\omega)$', r'$|J|(\lambda)$')),
        ('plasma', (r'$n_e(t)/n_c$', r'$|n_e|(\omega)/n_c$', r'$|n_e|(\lambda)/n_c$')),
    )
    # No break: a later match overrides an earlier one, as successive ifs would.
    for kind, labels in labels_by_kind:
        if kind in file:
            qty_label = labels

def load(file):
    """Load *file*, rebuild the current representation and redraw the frame.

    Callback for the File dropdown (via load_iw).
    """
    load_data(file)
    change_rep_data(rep_w.value)
    # disp_frame requires its sixth argument (lim2); omitting it raised TypeError.
    disp_frame(file,frame_w.value,disp_w.value,rep_w.value,color_w.value,lim2_w.value)

def get_freq_time_data(A,ext):
    """Return frequency-grid parameters for the first axis of *A*.

    Parameters
    ----------
    A : array
        Data whose axis 0 is the frequency axis (only A.shape[0] is used).
    ext : sequence
        ext[0] and ext[1] are the lower and upper cell walls of the grid.

    Returns
    -------
    dw : float
        Cell width, (ext[1]-ext[0])/A.shape[0].
    wc : float
        Reference frequency (ext[0]+ext[1]+dw)/2; for even A.shape[0] this is
        exactly the node at 0-based index A.shape[0]//2.
    tmax : float
        Time window 2*pi/dw conjugate to the frequency spacing.
    w_nodes : ndarray
        Cell-centred frequencies, one per element of axis 0.
    """
    # The same formulas serve both layouts, so no branch on real_field is needed:
    #   real fields:    nodes [0,1,2,3],   walls [-0.5,0.5,1.5,2.5,3.5]
    #   complex fields: nodes [-2,-1,0,1], walls [-2.5,-1.5,-.5,.5,1.5]
    n = A.shape[0]
    dw = (ext[1] - ext[0])/n
    wc = 0.5*ext[0] + 0.5*(ext[1] + dw)
    tmax = 2*np.pi/dw
    w_nodes = np.linspace(ext[0]+dw/2,ext[1]-dw/2,n)
    return dw,wc,tmax,w_nodes

def change_rep_data(rep):
    """Rebuild the plotted array and axis labels for representation *rep*.

    rep : 'Time', 'Frequency' or 'Wavelength'.

    Reads the globals A, data_ext, real_field, qty_label, l1_mm and t1_ps.
    Sets the globals plot_data, plot_ext (the imshow extent used by
    disp_frame), hstr / vstr (axis labels) and cstr (colorbar label).
    For any other rep, only plot_ext is reset to a copy of data_ext.
    """
    global plot_data,hstr,vstr,cstr,plot_ext
    plot_ext = np.copy(data_ext)
    if rep not in ('Time','Frequency','Wavelength'):
        return
    _, _, tmax, wn = get_freq_time_data(A,plot_ext)
    vstr = r'$\rho$ (mm)'
    if rep=='Time':
        plot_ext[0] = 0.0
        cstr = qty_label[0]
        if real_field:
            # A stores only non-negative frequencies of a real signal: inverse real FFT.
            plot_ext[1] = 1000*l1_mm*tmax  # time window expressed as a length in um
            hstr = r'$ct-z$ (um)'
            plot_data = np.fft.irfft(A,axis=0)[::-1,...]
        else:
            # A is a centred complex spectrum: undo the centring, invert, take the magnitude.
            plot_ext[1] = t1_ps*tmax
            hstr = r'$t-z/c$ (ps)'
            plot_data = np.abs(np.fft.ifft(np.fft.ifftshift(A,axes=0),axis=0))[::-1,...]
    elif rep=='Frequency':
        hstr = r'$\omega/\omega_0$'
        cstr = qty_label[1]
        # Weight by the frequency node before squaring (for wave files the label is |E|^2).
        # NOTE(review): source/plasma labels read |J| / |n_e| although |w*X|^2 is plotted -- confirm.
        plot_data = np.abs(A*wn[:,np.newaxis,np.newaxis,np.newaxis])**2
    else:  # 'Wavelength'
        # Drop node 0 (for real fields it sits at w=0, where 2*pi/w diverges) and
        # reverse so wavelength increases along axis 0.
        spectrum = (np.abs(A*wn[:,np.newaxis,np.newaxis,np.newaxis])**2)[1:][::-1]
        l_array_nu = (1000*l1_mm*2*np.pi/wn[1:])[::-1]
        # Uniform wavelength grid up to 50 um, clamped to the longest available
        # wavelength so interp1d is never asked to extrapolate (it raised ValueError).
        l_max = min(50.0, np.max(l_array_nu))
        l_array = np.linspace(l_array_nu[0],l_max,spectrum.shape[0])
        # One interpolation along axis 0 replaces the former per-column triple loop.
        resampled = scipy.interpolate.interp1d(l_array_nu,spectrum,axis=0)(l_array)
        plot_data = resampled/l_array[:,np.newaxis,np.newaxis,np.newaxis]
        hstr = r'$\lambda$ (um)'
        cstr = qty_label[2]
        plot_ext[0] = l_array[0]
        plot_ext[1] = l_array[-1]

def change_rep(rep):
    """Switch to representation *rep* and redraw (callback for rep_iw)."""
    change_rep_data(rep)
    # disp_frame requires its sixth argument (lim2); omitting it raised TypeError.
    disp_frame(file_w.value,frame_w.value,disp_w.value,rep,color_w.value,lim2_w.value)

def transform_data(the_button):
    """Apply the transform selected in transform_select_w to the global plot_data.

    Button callback; *the_button* is unused. The logarithmic transforms also
    prefix the colorbar label cstr. This function does not redraw the figure.
    """
    global plot_data,cstr
    smallest_num = 1e-25  # offset keeps the log finite where the data is exactly zero
    choice = transform_select_w.value
    if choice=='Log_e(abs(x))':
        # Label the natural log just like the base-10 log; it was previously unlabelled.
        cstr = r'$\ln$' + cstr
        plot_data = np.log(smallest_num+np.abs(plot_data))
    elif choice=='Log_10(abs(x))':
        cstr = r'$\log_{10}$' + cstr
        plot_data = np.log10(smallest_num+np.abs(plot_data))
    elif choice=='x^2':
        plot_data = plot_data**2
    elif choice=='-x':
        plot_data = -plot_data
    elif choice=='abs(x)':
        plot_data = np.abs(plot_data)

def auto_scale_frame(the_button):
    """Set Min/Max to the data range of the slice disp_frame shows.

    Button callback; *the_button* is unused. Uses the middle index of axis 2
    and the current frame (axis 3) of plot_data.
    """
    # np.int was removed in NumPy 1.24; floor division gives the same index.
    yslice = plot_data.shape[2]//2
    frame_slice = plot_data[:,:,yslice,frame_w.value]
    lim1_w.value = np.min(frame_slice)
    lim2_w.value = np.max(frame_slice)

def auto_scale_all(the_button):
    """Button callback: stretch Min/Max to span every value of plot_data.

    *the_button* is unused.
    """
    lowest, highest = np.min(plot_data), np.max(plot_data)
    lim1_w.value = lowest
    lim2_w.value = highest

# Horizontal zoom window for Time plots; set to None to show the full extent.
# It is not applied to other representations: the Wavelength axis ends at or
# below 50 um, so this window would leave those plots empty.
TIME_XLIM = (130,225)

def disp_frame(file,frame,disp,rep,color,lim2):
    """Draw one frame of the global plot_data.

    Parameters
    ----------
    file : str
        Current file; unused here, apparently present so the File widget
        triggers a redraw through update_iw.
    frame : int
        Index along axis 3 of plot_data.
    disp : str
        'Lineout' for a 1-D cut along axis 0, otherwise a false-colour image.
    rep : str
        Representation name; TIME_XLIM is applied only when rep == 'Time'.
    color : str
        Matplotlib colormap name for the image.
    lim2 : float
        Unused directly (the image reads lim1_w / lim2_w); its presence makes
        changes to Max trigger a redraw.
    """
    # np.int was removed in NumPy 1.24; floor division gives the same index.
    yslice = A.shape[2]//2
    fig_size = (21,10)
    plt.figure(1,figsize=fig_size)
    if disp=='Lineout':
        x = np.linspace(plot_ext[0],plot_ext[1],plot_data.shape[0])
        # NOTE(review): axis 1 is also indexed with yslice (computed from axis 2);
        # confirm this is the intended transverse position.
        plt.plot(x,plot_data[:,yslice,yslice,frame])
        # The lineout's vertical axis is the plotted quantity, not rho.
        ylabel = cstr
    else:
        aspect_ratio_base = (plot_ext[1]-plot_ext[0]) / (plot_ext[3]-plot_ext[2])
        aspect_ratio = 0.3*(fig_size[1]/fig_size[0])*aspect_ratio_base
        plt.imshow(plot_data[:,:,yslice,frame].swapaxes(0,1),vmin=lim1_w.value,vmax=lim2_w.value,origin='lower',cmap=color,aspect=aspect_ratio,extent=plot_ext)
        b=plt.colorbar()
        b.set_label(cstr,size=24)
        ylabel = vstr
    if rep=='Time' and TIME_XLIM is not None:
        plt.xlim(*TIME_XLIM)
    plt.xlabel(hstr,size=24)
    plt.ylabel(ylabel,size=24)
    plt.tight_layout()
    plt.show()

# Callback wiring. rep_iw and load_iw are never displayed (main_view below holds
# only left_box and update_iw's last child); they exist for the widget observers
# they register. NOTE(review): this appears to rely on observers firing in
# registration order, so that change_rep / load rebuild plot_data before
# update_iw's disp_frame redraws it -- confirm before reordering these lines.
rep_iw = ipywidgets.interactive(change_rep,rep=rep_w)
load_iw = ipywidgets.interactive(load,file=file_w)
# Only lim2_w is wired in, so editing Min alone does not trigger a redraw.
update_iw = ipywidgets.interactive(disp_frame,file=file_w,frame=frame_w,disp=disp_w,rep=rep_w,color=color_w,lim2=lim2_w)
# Button callbacks. transform_data does not redraw by itself; the autoscale
# callbacks redraw only through their change to lim2_w.
transform_w.on_click(transform_data)
autoscale_frame_w.on_click(auto_scale_frame)
autoscale_all_w.on_click(auto_scale_all)

# Left-hand control panel: selectors, info boxes, colour limits and buttons.
autoscale_box = ipywidgets.HBox([autoscale_frame_w,autoscale_all_w])
transform_box = ipywidgets.HBox([transform_select_w,transform_w])
left_box = ipywidgets.VBox([file_w,frame_w,disp_w,rep_w,color_w,shape_w,bounds_w,lim1_w,lim2_w,transform_box,autoscale_box],
                            layout=ipywidgets.Layout(justify_content='flex-start',flex='1 3 auto'))

# update_iw's last child (presumably its output area, where disp_frame draws)
# is placed beside the controls and allowed to grow.
update_iw.children[-1].layout.flex = '4 2 auto'
main_view = ipywidgets.HBox([left_box, update_iw.children[-1]],
                            layout=ipywidgets.Layout(display='inline-flex',
                                                     align_items='stretch',
                                                     align_content='stretch',
                                                     justify_content='flex-start'))

# Initialise the globals for the first file and the default 'Time'
# representation (rep_w starts at 'Time'), then show the UI.
load_data(file_w.value)
change_rep_data('Time')
IPython.display.display(main_view)

# Output: HBox(children=(VBox(children=(Dropdown(description='File', options=('out/test_air_uppe_wave.npy', 'out/test_ai…