In [1]:
%matplotlib qt
import os
from tkinter import *
from tkinter import filedialog

import h5py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
from scipy import interpolate
import scipy.constants as const
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d
from scipy.special import expit

cwd = os.getcwd()  # working directory at the time this cell runs; not referenced in the cells shown here

from Helpful_functions import *

In [3]:
def open_file():
    '''Ask for a NeXus (.nxs) file, load it and (re)build the analysis GUI.

    Loads the analyser image(s), angles and energies from ``entry1`` of the
    chosen file. The scan is treated as 3D when the analyser group contains
    ``sapolar`` or ``deflector_x``, and as 2D otherwise. Dead pixels are
    clipped. The function then creates the Tk widgets on ``root`` that the
    other functions read: the Fermi slider, the k/E range entries, the
    curvature controls, the optional sapolar slider, and the alignment and
    Fermi-surface panels.

    All results are published through the module-level globals declared
    below. If the file dialog is cancelled, the function returns without
    changing the loaded data or the angular offset.
    '''
    global sample
    global data, angles, energies, sapolar
    
    global state
    global offset
    global file_no
    
    global polar_slider
    global Fermi_slider
    global Ef
    
    global peak_no
    global EDC_entry
    global MDC_entry
    global MDC_save_cbox
    global EDC_save_cbox
    
    global add_parab_btn, WaterFall_btn, fit_parab_btn, FWHMvsE_btn
    global WF_upper_e, WF_lower_e, WF_step_e, WF_check_e, add_max
    
    global low_k, upp_k, low_E, upp_E
    global Gold_var
    global Curv_sldr, Curv_a
    global Fermi_range, Gamma_cor_lbl
    
    global var1, var2, red_dict, align_room
    
    plt.close()
    
    directory = '\\\\cmdaq3.physics.ox.ac.uk\\CMDAQ_RestOfCMGroups\\ColdeaAGroup\\Data2024\\ARPES_Jan24\\AAH76_B'
    root.filename = filedialog.askopenfilename(initialdir=directory, 
                                               title='Select file', 
                                               filetypes=(('nexus','*.nxs'), ('all files', '*.*')))
    # Cancelling the dialog gives an empty result: keep whatever is loaded
    if not root.filename:
        return
    offset = 0
    
    #Load the data
    with h5py.File(root.filename, 'r') as I05:
        # entry1 is the first and only group name for a static scan.
        # Almost all data is within analyser; for details read the
        # instrument or sample groups.
        state = '2D'
        encoding = 'utf-8'
        mode = str(I05['entry1/instrument/analyser/acquisition_mode'][0], encoding)
        Pol = str(I05['entry1/instrument/insertion_device/beam/final_polarisation_label'][0], encoding)
        phE = sigfig(np.squeeze(I05['entry1/instrument/monochromator/energy']), 4)
        T = sigfig(np.squeeze(I05['entry1/sample/cryostat_temperature'][:]), 3)
        data = np.squeeze(I05['entry1/analyser/data'][:])
        angles = np.squeeze(I05['entry1/analyser/angles'][:])
        energies = np.squeeze(I05['entry1/analyser/energies'][:])
        if energies.ndim != 1:
            energies = np.squeeze(I05['entry1/analyser/energies'][:])[0]
        analyser_keys = [*I05['entry1/analyser'].keys()]
        
        # A sapolar or deflector_x axis in the analyser group means a 3D map
        label3D = set(['sapolar', 'deflector_x'])
        myset = set(analyser_keys)
        common = myset.intersection(label3D)
        if len(common) == 0:
            state = '2D'
            sapolar = None
        else:
            state = '3D'
            sapolar_key = list(common)[0]
            sapolar_raw = I05['entry1/analyser/'+sapolar_key][:]
            # Rebuild the scan axis as evenly spaced between its end points
            sapolar = np.linspace(sapolar_raw[0], sapolar_raw[-1], len(sapolar_raw))
            # May need to add functionality for when deflector output is dodgy    
        
        if data.ndim == 4:
            energies = energies[0]
        
        # Whatever remains are extra motor axes offered in the "Align room"
        analyser_keys.remove('data')
        analyser_keys.remove('angles')
        analyser_keys.remove('energies')
        if state == '3D':
            analyser_keys.remove(sapolar_key)
        perplong = set(['saperp', 'salong'])
        myset = set(analyser_keys)
        if len(myset.intersection(perplong)) != 0:
            # sax/say are dropped when saperp/salong exist; tolerate files
            # that do not record them
            for redundant in ('sax', 'say'):
                if redundant in analyser_keys:
                    analyser_keys.remove(redundant)
        red_dict = {}
        for i, key in enumerate(analyser_keys):
            if data.ndim == 4:
                red_dict[key] = np.mean(np.squeeze(I05['entry1/analyser/'+key][:]), axis=1-i)
            else:
                red_dict[key] = np.squeeze(I05['entry1/analyser/'+key][:])
                
    # remove dead pixels
    data = dead_pixels(data, 20)
    
    #Print the file number
    index = root.filename.find('i05-')
    if index == -1:
        # Not a standard i05 file name: fall back to the bare file name
        file_no = os.path.splitext(os.path.basename(root.filename))[0]
    else:
        file_no = root.filename[index:-4]
    Label(root, text=file_no).grid(row=0, column=1)
    
    Button(root, text='Fermi level', command=Gold_plot).grid(row=0,column=2)
    
    #Print the dimensions of the data
    description = 'Data is '+state+': '+mode+' mode - '+str(phE)+'eV - '+Pol+' - '+str(T)+'K'
    Label(root, text=description).grid(row=1,column=0, columnspan=3)
    
    #Create a display button that plots the raw data
    Button(root, text='Initial Plot', command=lambda: int_plot(0)).grid(row=2,column=0)
    
    #Create a button that opens the centering (angular offset) plot
    Button(root, text='Centering', command=Center).grid(row=2,column=1)
    
    #Create a display button that plots the data in k space
    Button(root, text='full K plot', command=K_plot).grid(row=2,column=2)
    
    #Create a slider
    if state == '3D':
        # The slider selects a slice index. Its own value display is
        # suppressed, and the sapolar value is shown in a separate label.
        Label(root, text='sapolar slice').grid(row=4, column=0, rowspan=2)
        # valid slice indices run from 0 to len(sapolar)-1
        polar_slider = Scale(root, from_=0, to=len(sapolar)-1, resolution=1, orient=HORIZONTAL, length=200, showvalue=0)
        polar_slider.set(int(len(sapolar)/2))
        polar_slider.bind("<ButtonRelease-1>", slider_label)
        polar_slider.bind("<ButtonRelease-1>", chg_slice, add='+')
        polar_slider.grid(row=5, column=1)
        polar_slice = sapolar[int(polar_slider.get())]
        Label(root, text='polar = {}'.format(round(polar_slice,3))).grid(row=4, column=1)

    #Create a Fermi slider
    Label(root, text='Fermi E').grid(row=6, column=0)
    de = energies[1]-energies[0]
    Fermi_slider = Scale(root, from_=energies[0], to=energies[-1], resolution=de, orient=HORIZONTAL, length=200)
    try:
        # Keep the Fermi level found for a previously loaded file
        Fermi_slider.set(Ef)
    except NameError:
        # First file loaded: no Fermi level yet, start mid-window
        Ef = energies[int(len(energies)/2)]
        Fermi_slider.set(Ef)
    Fermi_slider.bind("<ButtonRelease-1>", add_Fermi)
    Fermi_slider.grid(row=6, column=1)
    
    #Create the features for a reduce K plot
    Label(root, text='k range (\u212B\u207B\u00B9)').grid(row=7, column=0)
    Label(root, text='E range (meV)').grid(row=8, column=0)
    low_k = Entry(root, width=10)
    low_k.insert(0, -0.5)
    low_k.grid(row=7, column=1)
    upp_k = Entry(root, width=10)
    upp_k.insert(0, 0.5)
    upp_k.grid(row=7, column=2)
    low_E = Entry(root, width=10)
    low_E.insert(0, -1000)
    low_E.grid(row=8, column=1)
    upp_E = Entry(root, width=10)
    upp_E.insert(0, 150)
    upp_E.grid(row=8, column=2)
    
    #Buttons for the reduced k_plot
    Button(root, text='1. refined K plot', command=K2_plot).grid(row=9,column=0)
    Button(root, text='2. add MDC', command=MDC_plot).grid(row=9,column=1)
    Button(root, text='3. plot EDC', command=EDC_plot).grid(row=9,column=2)
    
    #Create Curvature box
    curv_room = LabelFrame(root, text='Curvature', padx=5, pady=5)
    curv_room.grid(row=10, column=0, columnspan=2)
    
    # The slider sets log10(a) for the curvature parameter a
    powers = np.linspace(-4, 1, 11)
    dpowers = powers[1]-powers[0]
    Curv_a = StringVar()
    Curv_a.set('a = 1')
    
    Curv_lbl = Label(curv_room, textvariable = Curv_a)
    Curv_lbl.grid(row=0, column=0, columnspan=3)
    Curv_sldr = Scale(curv_room, from_=powers[0], to=powers[-1], resolution=dpowers, orient=HORIZONTAL, length=200, showvalue=0)
    Curv_sldr.set(0)
    Curv_sldr.bind("<ButtonRelease-1>", chCurv_a)
    Curv_sldr.grid(row=1, column=0, columnspan=3)
    
    Button(curv_room, text='Calculate Curvature', command=Curvature).grid(row=2, column=0, columnspan=3)
    Button(curv_room, text='1D in k', padx=8, command=Curv_1D_k).grid(row=3, column=0)
    Button(curv_room, text='1D in E', padx=8, command=Curv_1D_E).grid(row=3, column=1)
    Button(curv_room, text='2D', padx=8, command=Curv_2D).grid(row=3, column=2)
    
    #Create saz, salong, saperp thang
    if len(analyser_keys) != 0:
        align_room = LabelFrame(root, text='Align room', padx=5, pady=5)
        align_room.grid(row=11, column=0, columnspan=3)
        var1 = StringVar()
        var1.set(analyser_keys[0])
        drop1 = OptionMenu(align_room, var1, *analyser_keys)
        drop1.grid(row=0, column=0)
        var2 = StringVar()
        var2.set('None')
        drop2 = OptionMenu(align_room, var2, 'None', *analyser_keys)
        drop2.grid(row=0, column=1)
        Button(align_room, text='Plot', padx=8, command=alignment_plot).grid(row=1, column=0)
        Button(align_room, text='2D sweep', padx=8, command=alignment_sweep).grid(row=2, column=0)
    
    #Create the Button the plots the fermi surface
    if state == '3D':
        fermi_room = LabelFrame(root, text='Fermi-surface', padx=5, pady=5)
        fermi_room.grid(row=10, column=2)
        Button(fermi_room, text='Raw Fermi plot', command=Raw_Fermi_plot).grid(row=0, column=0)
        Button(fermi_room, text='Fermi plot', command=Fermi_plot).grid(row=0,column=1)
        Button(fermi_room, text='Axis Cuts', command=axis_cuts).grid(row=1,column=0)
        Label(fermi_room, text='\u0393 correction (deg)').grid(row=3, column=0)
        Gamma_cor_lbl = Entry(fermi_room, width=10)
        Gamma_cor_lbl.insert(0, 0)
        Gamma_cor_lbl.grid(row=3, column=1)
        Label(fermi_room, text='k range (\u212B\u207B\u00B9)').grid(row=4, column=0)
        Fermi_range = Entry(fermi_room, width=10)
        Fermi_range.insert(0, 0.25)
        Fermi_range.grid(row=4, column=1)
        

In [4]:
def dead_pixels(data, m):
    '''Remove dead pixels: clip the ``m`` brightest pixels of each image.

    Each of the ``m`` largest values in an image is replaced by the
    (m+1)-th largest value of the same image.

    Parameters
    ----------
    data : numpy.ndarray
        Either a single 2D image, or a 3D stack of 2D images indexed by the
        first axis. A 3D stack is modified in place and also returned. A 2D
        image is left untouched, and a cleaned copy is returned.
    m : int
        Number of brightest pixels to clip per image.

    Returns
    -------
    numpy.ndarray
        The cleaned data. Arrays of any other dimensionality are returned
        unchanged after printing a message.
    '''
    def _clip_frame(frame):
        # flatten() copies, so the caller's frame is not modified here
        pixels = frame.flatten()
        order = np.argsort(pixels)
        ceiling = pixels[order[-m-1]]
        # the last m entries of the sort order are the m brightest pixels
        pixels[order[pixels.size - m:]] = ceiling
        return pixels.reshape(frame.shape)

    if data.ndim == 2:
        return _clip_frame(data)
    if data.ndim == 3:
        for layer in range(data.shape[0]):
            data[layer] = _clip_frame(data[layer])
        return data
    print('I havent got this far')
    return data

In [5]:
def fermi_func(x, m, c, e, T, sigma, bkg_m, bkg_c):
    '''Gaussian-broadened Fermi edge on top of a linear background.

        model = G_sigma[(m*x + c) * f(x)] + bkg_m*x + bkg_c
        f(x)  = 1 / (exp((x - e) * (e_charge/k_B) / T) + 1)

    Parameters
    ----------
    x : numpy.ndarray
        Energies in eV (1D; assumed evenly spaced, since the blur works in
        samples).
    m, c : float
        Slope and intercept of the linear intensity below the edge.
    e : float
        Fermi level in eV.
    T : float
        Temperature in K.
    sigma : float
        Gaussian width in samples of ``x`` (not meV). Gold_plot converts it
        to meV using the energy step.
    bkg_m, bkg_c : float
        Slope and intercept of the additive background.

    Returns
    -------
    numpy.ndarray
        Model intensity at each ``x``.
    '''
    linear = m*x + c
    # e/k_B turns an energy in eV into a temperature in K
    a = const.e/const.k
    # expit(-z) == 1/(exp(z)+1), but it does not overflow for sharp (low-T) edges
    fermi = expit(-(x-e)*a/T)
    conv = gaussian_filter1d(fermi*linear, sigma, mode='nearest')
    return conv + bkg_m*x + bkg_c

def Find_Fermi(X, Y):
    '''Interactively fit ``fermi_func`` to an energy distribution curve.

    The user clicks, in turn:
      1. two points on the linear part below the step,
      2. a rough guess of the Fermi level,
      3. two points on the background after the step.
    These clicks seed ``curve_fit``. The initial guess is shown for one
    second, and then a fresh figure with the data and the fit is drawn. That
    figure's axes are left in the global ``ax`` so the caller can annotate it.

    Parameters
    ----------
    X : array_like
        Energies in eV.
    Y : array_like
        Intensities at those energies.

    Returns
    -------
    fermi_popt : numpy.ndarray
        Best-fit (m, c, e, T, sigma, bkg_m, bkg_c); see ``fermi_func``.
    fermi_perr : numpy.ndarray
        One-sigma errors from the diagonal of the covariance matrix.
    '''
    global ax
    fig, ax = plt.subplots()
    ax.plot(X, Y)
    # 1) line through the intensity below the edge
    plt.title('select 2 points to form\nlinear plot before step')
    plt.pause(0.01)
    P = np.asarray(plt.ginput(2, timeout=-1))
    m = (P[1,1] - P[0,1])/(P[1,0] - P[0,0])
    c = P[1,1] - P[1,0]*m
    # 2) rough position of the edge
    plt.title('select a rough guess of\nFermi level')
    plt.pause(0.01)
    P = np.asarray(plt.ginput(1, timeout=-1))
    e = P[0,0]
    # 3) background line after the edge
    plt.title('select 2 points to form\nlinear plot after step')
    plt.pause(0.01)
    P = np.asarray(plt.ginput(2, timeout=-1))
    bkg_m = (P[1,1] - P[0,1])/(P[1,0] - P[0,0])
    bkg_c = P[1,1] - P[1,0]*bkg_m
    # fermi_func adds the background itself, so remove it from the lower line
    m, c = m - bkg_m, c - bkg_c
    plt.pause(0.01)
    # initial T = 10 K and sigma = 2 samples
    guess = [m,c,e,10,2,bkg_m,bkg_c]

    ax.plot(X, fermi_func(X, *guess))
    plt.title('This is guess')
    plt.pause(1)

    plt.close()
    fig, ax = plt.subplots()
    ax.plot(X, Y)
    # T and sigma must be non-negative; the other parameters are free
    bounds = ([-np.inf, -np.inf, -np.inf, 0, 0, -np.inf, -np.inf], np.inf)
    XX, YY = np.array(X, dtype=np.float64), np.array(Y, dtype=np.float64)
    fermi_popt, pcov = curve_fit(fermi_func, XX, YY, p0 = guess, bounds=bounds, maxfev=10000)
    fermi_perr = np.sqrt(np.diag(pcov))

    ax.plot(X, fermi_func(X,*fermi_popt), color='tab:red')
    plt.title('')
    plt.pause(0.01)
    return fermi_popt, fermi_perr

def Gold_plot():
    '''Determine the Fermi level by fitting a Fermi edge to the current image.

    Workflow:
      1. The current image is shown: the whole 2D scan, or the sapolar slice
         selected by ``polar_slider``. The user zooms to the region to use
         and presses a key (the title asks for Space).
      2. The image is averaged over the visible angles and cut to the
         visible energies. The user clicks two energies that bound the fit,
         in either order.
      3. ``Find_Fermi`` fits ``fermi_func`` interactively.
    The fitted edge is marked and annotated, the globals ``Ef`` and
    ``fermi_popt`` are updated, and the Fermi slider is moved to ``Ef``.
    '''
    global angles, energies, data
    global fermi_popt, Ef
    global Fermi_slider, polar_slider
    global file_no
    global ax
    
    if state == '2D':
        data_2D = data
    else:
        index = int(polar_slider.get())
        data_2D = data[index,:,:] 
    
    plt.close()
    
    # Step 1: let the user zoom to the region used for the fit
    fig, ax = plt.subplots()
    ready = False
    # waitforbuttonpress() returns False after a mouse click (e.g. while
    # zooming) and True after a key press, so loop until a key is pressed
    while ready == False:
        ax.clear()
        ax.contourf(angles, energies, data_2D.T, 50, cmap='Greys')
        plt.title('When ready Press Space')
        plt.show()
        plt.pause(0.01)
        ready = plt.waitforbuttonpress()
    an_bot, an_top = ax.get_xlim()
    en_bot, en_top = ax.get_ylim()
    plt.close()

    # Step 2: average over the visible angles, keep the visible energies
    k_mask = (an_bot<angles)&(angles<an_top)
    data_1D = np.mean(data_2D[k_mask,:], axis=0)
    e_mask = (en_bot<energies)&(energies<en_top)
    X, Y = energies[e_mask], data_1D[e_mask]

    fig, ax = plt.subplots()
    plt.title('select 2 points which we\'ll fit between')
    ax.plot(X, Y)
    plt.show()
    P = np.asarray(plt.ginput(2, timeout=-1))
    # accept the two limits in either order
    e_low, e_high = np.sort(P[:,0])
    x_mask = (e_low<X)&(X<e_high)
    XX, YY = X[x_mask], Y[x_mask]
    plt.close()
    
    # Step 3: interactive fit; Find_Fermi leaves its axes in the global ax
    fermi_popt, fermi_perr = Find_Fermi(XX, YY)

    Ef_er = fermi_perr[2]
    m, c, Ef, T, sigma, bkg_m, bkg_c = fermi_popt
    ax.axvline(Ef, c='0', ls='--', lw=0.8)
    # fermi_func's sigma is in samples: convert it to meV with the energy step
    dX = X[1] - X[0]
    res = sigma*dX*1e3
    
    # Coordinates for the text box (relative to the current axis limits)
    xbot, xtop = ax.get_xlim()
    ybot, ytop = ax.get_ylim()
    x = xbot + 0.55*(xtop-xbot)
    y = ybot + 0.6*(ytop-ybot)
    text1 = f'{sigfig(an_bot, 3)} < angles < {sigfig(an_top, 3)}\n'
    text2 = f'Fermi E = \n{np.round(Ef, 4)} \u00B1 {round(Ef_er, 4)}\nT = {round(T, 1)}\nsigma={round(res, 3)}meV'
    plt.text(x,y, text1+text2, fontsize=15, bbox=dict(facecolor='1'))
    plt.pause(0.01)
    
    Fermi_slider.set(Ef)

In [6]:
def chCurv_a(var):
    '''Curvature-slider callback: show a = 10**(slider value) to 2 significant figures.

    ``var`` is the Tk event and is not used.
    '''
    global Curv_a, Curv_sldr
    exponent = Curv_sldr.get()
    shown = str(sigfig(10**exponent, 2))
    Curv_a.set('a = ' + shown)

In [7]:
def int_plot(var):
    '''Plot the raw image (analyser angle vs energy) of the current slice.

    For 2D data the whole image is shown. For 3D data the sapolar slice
    chosen by ``polar_slider`` is shown. Angles are shifted by the global
    ``offset``. The Fermi-slider energy is drawn as a dashed line on a twin
    axis (``ax2``) so that ``add_Fermi`` can redraw it later.

    ``var`` is not used (the GUI button passes 0).
    '''
    global data, angles, energies, sapolar
    global state, offset
    global polar_slider, Fermi_slider
    global ax1, ax2

    # keep a single figure open
    plt.close()

    if state != '2D':
        image = data[int(polar_slider.get()), :, :]
    else:
        image = data

    angle_grid, energy_grid = np.meshgrid(angles - offset, energies)

    fermi_level = Fermi_slider.get()
    fig, ax1 = plt.subplots()
    ax1.contourf(angle_grid, energy_grid, image.T, 50, cmap='Greys')

    # invisible overlay axis carrying the Fermi-level marker
    ax2 = ax1.twinx()
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_yticks([])
    ax2.axhline(fermi_level, linestyle='dashed')
    plt.show()

In [8]:
def add_Fermi(var):
    '''Fermi-slider callback: redraw the dashed Fermi-level line on ``ax2``.

    ``var`` is the Tk event and is not used. It expects ``ax1``/``ax2`` from
    an earlier plot (for example ``int_plot``).
    '''
    global Fermi_slider
    global ax1, ax2
    overlay = ax2
    overlay.clear()
    overlay.set_ylim(ax1.get_ylim())
    overlay.set_yticks([])
    overlay.axhline(Fermi_slider.get(), linestyle='dashed')
    # give the GUI event loop a moment to repaint
    plt.pause(0.001)

In [9]:
def slider_label(var):
    '''sapolar-slider callback: show the sapolar value of the selected slice.

    The label goes in grid cell (row 4, column 1), the cell ``open_file``
    uses for it. Labels already in that cell are destroyed first, so values
    no longer pile up on top of each other. ``var`` is the Tk event and is
    not used.
    '''
    global polar_slider
    # clamp: the slider range may extend one step past the last slice
    index = min(int(polar_slider.get()), len(sapolar)-1)
    polar_slice = sapolar[index]
    for old in root.grid_slaves(row=4, column=1):
        if isinstance(old, Label):
            old.destroy()
    Label(root, text='polar = {}'.format(round(polar_slice,3)), padx=5).grid(row=4, column=1)
    
def chg_slice(var):
    '''sapolar-slider callback: redraw ``ax1`` with the newly selected slice.

    The existing axes are reused, so no new figure is opened. Angles are
    shifted by the global ``offset``. ``var`` is the Tk event and is not used.
    '''
    global data, angles, energies, sapolar
    global state
    global offset
    
    global polar_slider
    global ax1, ax2
    
    # clear the current image axes (the figure stays open)
    ax1.clear()
    
    if state == '2D':
        data_2D = data
    else:
        # clamp: the slider range may extend one step past the last slice
        index = min(int(polar_slider.get()), data.shape[0]-1)
        data_2D = data[index,:,:]
        
    X, Y = np.meshgrid(angles-offset, energies)
    
    ax1.contourf(X, Y, data_2D.T, 50, cmap='Greys')
    plt.pause(0.001)

In [10]:
def center_click(event):
    '''Mouse-click callback of ``Center``: set the angular offset to the clicked angle.

    Draws a dashed vertical line at the click on the overlay axes ``ax2_c``
    and stores the angle in the global ``offset`` used by the other plots.
    Clicks outside the axes (``event.xdata is None``) are ignored, so
    ``offset`` can never become ``None``.
    '''
    global ax1_c, ax2_c
    global offset
    if event.xdata is None:
        return
    ax2_c.clear()
    offset = event.xdata
    ax2_c.axvline(offset, color='0', ls='--', lw=0.8)
    ax2_c.set_ylim(ax1_c.get_ylim())
    ax2_c.set_yticks([])
    plt.title('Press Space to redo Contrast')
    plt.pause(0.01)
    
def re_contrast(event):
    '''Key-press callback of ``Center``: re-contrast the visible region.

    The image in ``ax1_c`` is redrawn using only the data inside the
    current (zoomed) axis limits, so the contour levels adapt to that
    region. The limits are then restored. Any key triggers it, although
    the title asks for Space.
    '''
    global ax1_c
    global angles, energies, data_2D
    x_lo, x_hi = ax1_c.get_xlim()
    y_lo, y_hi = ax1_c.get_ylim()

    in_x = (angles >= x_lo) & (angles <= x_hi)
    in_y = (energies >= y_lo) & (energies <= y_hi)
    visible = data_2D[in_x][:, in_y]

    ax1_c.clear()
    plt.title('Press Space to redo Contrast')
    ax1_c.contourf(angles[in_x], energies[in_y], visible.T, 100, cmap='Greys')
    ax1_c.set_xlim(x_lo, x_hi)
    ax1_c.set_ylim(y_lo, y_hi)
    plt.pause(0.01)
    
def Center():
    '''Open an interactive plot for choosing the angular offset.

    Shows the current image in raw angles: the whole 2D scan, or for 3D
    data the sapolar slice picked by ``polar_slider``. Clicking sets the
    global ``offset`` (via ``center_click``). Pressing a key re-contrasts
    the visible region (via ``re_contrast``). The displayed image is kept
    in the global ``data_2D`` for ``re_contrast``.
    '''
    global data, angles, energies, sapolar
    global data_2D
    global state, polar_slider
    global ax1_c, ax2_c

    # keep a single figure open
    plt.close()

    data_2D = data if state == '2D' else data[int(polar_slider.get()), :, :]

    fig, ax1_c = plt.subplots()
    ax1_c.clear()
    ax1_c.contourf(angles, energies, data_2D.T, 100, cmap='Greys')
    # overlay axis that carries the offset marker
    ax2_c = ax1_c.twinx()
    ax2_c.set_ylim(ax1_c.get_ylim())
    ax2_c.set_yticks([])

    canvas = fig.canvas
    canvas.mpl_connect('button_press_event', center_click)
    canvas.mpl_connect('key_press_event', re_contrast)

In [11]:
def K_plot():
    '''Convert the current image to (k, E - E_F) and plot it.

    Uses k proportional to sqrt(E) * sin(angle - offset). The ``energies``
    are presumably kinetic energies in eV, since their square root is
    taken. The common k grid is built at the lowest energy, where the
    accessible k range is smallest. Every energy column is then resampled
    onto that grid with a bivariate spline.

    Results are stored in globals for ``EDC_plot``:
        k_full      : k values in 1/Angstrom,
        E_full      : energies in meV relative to the Fermi-slider value,
        k_data_full : image with rows = energies and columns = k.
    '''
    global data, angles, energies
    global k_full, E_full, k_data_full
    global state
    global offset
    global Fermi_slider
    global polar_slider

    # keep a single figure open
    plt.close()

    if state != '2D':
        image = data[int(polar_slider.get()), :, :]
    else:
        image = data

    sin_angle = np.sin((angles - offset)*np.pi/180)
    spline = interpolate.RectBivariateSpline(sin_angle, energies, image)

    # sqrt(E)*sin(angle) is proportional to k; fix the grid at the lowest energy
    k_scaled = np.sqrt(np.min(energies))*sin_angle

    resampled = np.zeros(image.shape)
    for col, energy in enumerate(energies):
        resampled[:, col] = spline(k_scaled/np.sqrt(energy), energy)[:, 0]

    E_full = (energies - Fermi_slider.get())*1000  # meV relative to the Fermi level

    to_inv_angstrom = np.sqrt(2*const.m_e*const.e)/const.hbar *1e-10
    k_full = k_scaled*to_inv_angstrom
    k_data_full = resampled.T

    plt.figure()
    plt.contourf(k_full, E_full, k_data_full, 300, cmap='Greys')
    plt.axhline(0, linestyle='dashed', color='0', linewidth=1)
    plt.xlabel(r'$k$'+' '+r'$(\AA^{-1})$', fontsize=14)
    plt.ylabel(r'$E - E_f$'+' (meV)', fontsize=14)
    plt.tick_params('both', labelsize=11)
    plt.tight_layout()
    plt.show()

In [12]:
def K2_plot():
    '''Plot a cropped (k, E - E_F) image and prepare MDC/EDC cuts.

    The angle-to-k conversion works as in ``K_plot``: the k grid is fixed at
    the lowest energy, and each energy column is resampled with a bivariate
    spline. The result is then cropped to the k range [low_k, upp_k]
    (1/Angstrom) and the energy range [low_E, upp_E] (meV) typed into the
    GUI entries.

    Side effects (globals):
        k_red, E_red, k_data_red : cropped axes and image (image rows are
            energies, columns are k), used by MDC_plot and Curvature.
        ax1, ax2 : the image axes and a twin overlay axes.
        MDC_data, MDC_data_red : [k, intensity] averaged over
            |E - E_F| < 2.5 meV, for the full and the cropped k range.
        EDC_data : [E, intensity] averaged over |k| < 0.02 1/Angstrom.
    '''
    global data, angles, energies
    global k_data_red, k_red, E_red
    global state
    global offset
    global Fermi_slider
    global polar_slider
    global file_no
    
    global low_k, upp_k, low_E, upp_E
    
    global ax1, ax2
    
    global MDC_data, MDC_data_red, EDC_data
    
    #close any open plots
    plt.close()
    
    # 2D scan: use it directly; 3D scan: take the slice picked by the slider
    if state == '2D':
        data_2D = data
    else:
        index = int(polar_slider.get())
        data_2D = data[index,:,:] 
    
    # at fixed energy, sin(angle) is proportional to k
    new_angles = (angles - offset)*np.pi/180
    vec_y = np.sin(new_angles)

    # spline of the intensity over (sin(angle), energy); data_2D rows are angles
    my_interp = interpolate.RectBivariateSpline(vec_y, energies, data_2D)

    # k grid (in units of sqrt(eV)) fixed at the lowest energy, where the
    # accessible k range is smallest
    E_min = np.min(energies)
    k_range = np.sqrt(E_min)*vec_y

    # resample every energy column onto the common k grid
    new_data = np.zeros(data_2D.shape)
    for i, E in enumerate(energies):
        new_ang_vec = k_range/np.sqrt(E)
        new_data[:,i] = my_interp(new_ang_vec, E)[:,0]
    
    # local Fermi level from the slider (does not touch the global Ef)
    Ef = Fermi_slider.get()
    Energy = (energies - Ef)*1000 #center on Fermi energy and convert to meV
    
    f = np.sqrt(2*const.m_e*const.e)/const.hbar *1e-10 #convert to inverse angstrom
    k_range = k_range*f
    
    # full image with rows = energies, columns = k
    k_data = new_data.T
    
    # crop to the k (1/Angstrom) and E (meV) ranges typed into the GUI
    k_mask = (float(low_k.get())<k_range)&(k_range<float(upp_k.get()))
    E_mask = (float(low_E.get())<Energy)&(Energy<float(upp_E.get()))
    
    k_red, E_red = k_range[k_mask], Energy[E_mask]
    k_data_red = new_data[k_mask][:,E_mask].T
    
    fig, ax1 = plt.subplots()
    ax1.contourf(k_red, E_red, k_data_red, 300, cmap='Greys')
    ax1.axhline(0, linestyle='dashed', color='0', linewidth=1)
    # overlay axis used later by MDC_plot
    ax2 = ax1.twinx()
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_yticks([])
    ax1.set_xlabel(r'$k$'+' '+r'$(\AA^{-1})$', fontsize=14)
    ax1.set_ylabel(r'$E - E_f$'+' (meV)', fontsize=14)
    # put "file number - Ef" just above the top-left corner of the plot
    xbot, xtop = ax1.get_xlim()
    ybot, ytop = ax1.get_ylim()
    x = xbot + 0.025*(xtop-xbot)
    y = ybot + 1.025*(ytop-ybot)
    label = file_no+' - Ef='+str(round(Ef, 4))+'eV'
    ax1.text(x,y,label)
    plt.tick_params('both', labelsize=11)
    plt.tight_layout()
    plt.show()
    
    #Define MDC: average over energies within +/-dE meV of E_F
    dE = 2.5
    MDC_mask = (-dE<Energy)&(Energy<dE)
    MDC = np.mean(k_data[MDC_mask, :], axis=0)
    
    MDC_data = [k_range, MDC]
    MDC_data_red = [k_red, MDC[k_mask]]
    
    #Define EDC: average over |k| < dk (1/Angstrom)
    dk = 0.02
    EDC_mask = (-dk<k_range)&(k_range<dk)
    EDC = np.mean(k_data[:, EDC_mask], axis=1)
    
    EDC_data = [Energy, EDC]

In [13]:
def MDC_plot():
    '''Overlay the momentum distribution curve at E_F on the current k plot.

    Sums ``k_data_red`` over energies within +/-2 meV of E_F and draws the
    result on ``ax2``. It is scaled so that its maximum sits at 95% of the
    top of the energy axis. The curve and the scale factor are left in the
    globals ``MDC`` and ``f``.
    '''
    global ax1, ax2
    global k_data_red, k_red, E_red
    global Ef, MDC, f

    x_lo, x_hi = ax1.get_xlim()
    y_lo, y_hi = ax1.get_ylim()

    half_width = 2  # meV
    near_fermi = (E_red > -half_width) & (E_red < half_width)
    MDC = k_data_red[near_fermi, :].sum(axis=0)
    f = 0.95*y_hi/MDC.max()

    ax2.plot(k_red, MDC*f)
    ax2.set_xlim(x_lo, x_hi)
    ax2.set_ylim(y_lo, y_hi)
    ax2.set_yticks([])
    plt.pause(0.01)

In [14]:
def EDC_plot():
    '''Plot the energy distribution curve at k ~ 0 from the full k plot.

    Sums ``k_data_full`` over |k| < 0.025 1/Angstrom and plots the result
    against ``E_full`` (meV) in a new figure.

    NOTE(review): the GUI button for this is labelled '3. plot EDC' and
    follows '1. refined K plot'. However, this function reads the
    ``K_plot`` globals (``k_full`` etc.), not the ``K2_plot`` ones. Confirm
    which is intended.
    '''
    global k_full, E_full, k_data_full

    plt.close()
    half_width = 0.025  # 1/Angstrom
    near_gamma = np.abs(k_full) < half_width
    edc = k_data_full[:, near_gamma].sum(axis=1)

    plt.figure()
    plt.plot(E_full, edc)
    plt.show()

In [15]:
def Curvature():
    '''Compute the 1D and 2D curvature of the refined k plot (``k_data_red``).

    Follows the method of https://doi.org/10.1063/1.3585113. Derivatives
    come from cubic Savitzky-Golay filters. The free parameter ``a`` is
    10**(curvature-slider value), rounded to 2 significant figures.

    Results are stored in globals for the plotting functions:
        C1D_M : 1D curvature along k (at constant energy, MDC-like),
        C1D_E : 1D curvature along E (at constant k, EDC-like),
        C2D   : 2D curvature,
        indx  : half of the filter window; the Curv_* plots crop this many
                samples from every edge.
    '''
    global k_data_red, k_red, E_red
    global C1D_M, C1D_E, C2D
    global indx, Curv_sldr
    '''Curvature
    As defined in https://doi.org/10.1063/1.3585113'''
    
    a = sigfig(10**(Curv_sldr.get()), 2)
    # Savitzky-Golay window length in samples. savgol_filter requires it not
    # to exceed the number of k or E points. indx is half of it.
    window = 31
    indx = int((window-1)/2)

    # grid steps (assumed uniform): dx in k (1/Angstrom), dy in E (meV)
    dx = k_red[1]-k_red[0]
    dy = E_red[1]-E_red[0]

    #slice along constant E
    # rows of k_data_red are energies and columns are k: filter each row along k,
    # then divide by the step to get derivatives in physical units
    M_2nd = np.zeros(k_data_red.shape)
    M_1st = np.zeros(k_data_red.shape)
    for i in range(E_red.size):
        M_1st[i,:] = savgol_filter(k_data_red[i,:], window, 3, deriv=1)
        M_2nd[i,:] = savgol_filter(k_data_red[i,:], window, 3, deriv=2)
    M_1st = M_1st/dx
    M_2nd = M_2nd/(dx*dx)

    #slice along constant k
    # filter each column along E
    E_2nd = np.zeros(k_data_red.shape)
    E_1st = np.zeros(k_data_red.shape)
    for i in range(k_red.size):
        E_1st[:,i] = savgol_filter(k_data_red[:,i], window, 3, deriv=1)
        E_2nd[:,i] = savgol_filter(k_data_red[:,i], window, 3, deriv=2)
    E_1st = E_1st/dy
    E_2nd = E_2nd/(dy*dy)

    # mixed derivative d/dE (dI/dk)
    # NOTE(review): this uses a fixed 51-sample window instead of `window`.
    # It therefore smooths differently and needs at least 51 energy points.
    # Confirm this is intended.
    EM_1st = np.zeros(k_data_red.shape)
    for i in range(k_red.size):
        EM_1st[:,i] = savgol_filter(M_1st[:,i], 51, 3, deriv=1)
    EM_1st = EM_1st/dy

    # 1D curvature at constant E; C0 = a * max(first derivative**2)
    Co_M = a*np.max(M_1st**2)
    C1D_M = M_2nd/(Co_M + M_1st**2)**1.5
    # 1D curvature at constant k
    Co_E = a*np.max(E_1st**2)
    C1D_E = E_2nd/(Co_E + E_1st**2)**1.5
    
    # weights of the two directions in the 2D curvature
    # NOTE(review): Cy is Cx scaled by (dy/dx)**2. Confirm this ratio against
    # the reference.
    Cx = 1/Co_M
    Cy = Cx*(dy/dx)**2

    # 2D curvature
    Top_a = (1 + Cx*M_1st**2)*Cy*E_2nd
    Top_b = -2*Cy*Cx*M_1st*E_1st*EM_1st
    Top_c = (1 + Cy*E_1st**2)*Cx*M_2nd
    Bot = (1 + Cx*M_1st**2 + Cy*E_1st**2)**1.5
    C2D = (Top_a+Top_b+Top_c)/Bot

def Curv_1D_k():
    '''Plot the 1D curvature along k (``C1D_M`` from ``Curvature``).

    Half a Savitzky-Golay window (``indx`` samples) is cropped from every
    edge. Opens a new figure and updates the globals ``ax1``/``ax2``.
    '''
    global k_data_red, k_red, E_red
    global C1D_M, indx
    global ax1, ax2
    plt.close()

    crop = slice(indx, -indx)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_yticks([])
    plt.title('Curvature 1D - MDC')
    ax1.contourf(k_red[crop], E_red[crop], C1D_M[crop, crop], 150, cmap='Greys')
    ax1.set_xlabel(r'$k$'+' '+r'$(\AA^{-1})$', fontsize=14)
    ax1.set_ylabel(r'$E - E_f$'+' (meV)', fontsize=14)
    plt.tight_layout()
    plt.show()

def Curv_1D_E():
    '''Plot the 1D curvature along E (``C1D_E`` from ``Curvature``).

    Half a Savitzky-Golay window (``indx`` samples) is cropped from every
    edge. Opens a new figure and updates the globals ``ax1``/``ax2``.
    '''
    global k_data_red, k_red, E_red
    global C1D_E, indx
    global ax1, ax2
    plt.close()

    crop = slice(indx, -indx)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_yticks([])
    plt.title('Curvature 1D - EDC')
    ax1.contourf(k_red[crop], E_red[crop], C1D_E[crop, crop], 150, cmap='Greys')
    ax1.set_xlabel(r'$k$'+' '+r'$(\AA^{-1})$', fontsize=14)
    ax1.set_ylabel(r'$E - E_f$'+' (meV)', fontsize=14)
    plt.tight_layout()
    plt.show()

def Curv_2D():
    '''Plot the 2D curvature (``C2D`` from ``Curvature``).

    Half a Savitzky-Golay window (``indx`` samples) is cropped from every
    edge. Opens a new figure and updates the globals ``ax1``/``ax2``.
    '''
    global k_data_red, k_red, E_red
    global C2D, indx
    global ax1, ax2
    plt.close()

    crop = slice(indx, -indx)
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.set_ylim(ax1.get_ylim())
    ax2.set_yticks([])
    plt.title('Curvature 2D')
    ax1.contourf(k_red[crop], E_red[crop], C2D[crop, crop], 150, cmap='Greys')
    ax1.set_xlabel(r'$k$'+' '+r'$(\AA^{-1})$', fontsize=14)
    ax1.set_ylabel(r'$E - E_f$'+' (meV)', fontsize=14)
    plt.tight_layout()
    plt.show()

In [16]:
def _draw_rotated_cursor():
    '''Redraw the cross-hair (l1/l2 lines, m1/m2 tick markers) at
    (sapolar_o, angles_o), rotated by ``az`` degrees; ``size`` sets its extent.'''
    tan = np.tan(az*np.pi/180)
    sine = size*np.sin(az*np.pi/180)/3
    cosine = size*np.cos(az*np.pi/180)/3
    #line1 - horizontal
    l1.set_data([sapolar_o-2*size, sapolar_o+2*size], [angles_o-2*size*tan, angles_o+2*size*tan])
    m1.set_offsets(np.array([[sapolar_o-2*cosine, sapolar_o-cosine, sapolar_o+cosine, sapolar_o+2*cosine],
                             [angles_o-2*sine, angles_o-sine, angles_o+sine, angles_o+2*sine]]).T)
    #line2 - vertical
    l2.set_data([sapolar_o+2*size*tan, sapolar_o-2*size*tan], [angles_o-2*size, angles_o+2*size])
    m2.set_offsets(np.array([[sapolar_o-2*sine, sapolar_o-sine, sapolar_o+sine, sapolar_o+2*sine],
                             [angles_o+2*cosine, angles_o+cosine, angles_o-cosine, angles_o-2*cosine]]).T)

def cursor_center(event):
    '''Mouse-click callback of ``Raw_Fermi_plot``: move the cross-hair centre to the click.

    Clicks outside the axes (``xdata``/``ydata`` is ``None``) are ignored;
    without this check they would crash the redraw and leave the centre
    set to ``None``.
    '''
    global sapolar_o, angles_o
    if event.xdata is None or event.ydata is None:
        return
    sapolar_o, angles_o = event.xdata, event.ydata
    _draw_rotated_cursor()
    plt.pause(0.001)

def cursor_rotate(event):
    '''Key-press callback of ``Raw_Fermi_plot``: rotate the cross-hair.

    The left/right arrow keys change the azimuth ``az`` by -/+0.2 degrees.
    Any key redraws the cursor and shows the azimuth in the title.
    '''
    global az
    if event.key == 'left':
        az -= 0.2
    if event.key == 'right':
        az += 0.2
    _draw_rotated_cursor()
    plt.title(f'azimuth = {round(az, 2)}')
    plt.pause(0.001)

def Raw_Fermi_plot():
    '''Plot the Fermi-surface map in raw angles with a rotatable cross-hair.

    Averages ``data`` over energies within +/-10 meV of the Fermi-slider
    value and shows the result as sapolar (Theta) vs analyser angle (alpha).
    A red cross-hair starts at the centre of the angle ranges. Clicking
    moves it (``cursor_center``), and the left/right arrow keys rotate it
    (``cursor_rotate``). The cross-hair state lives in module globals,
    including ``size``, the largest absolute angle at the ends of the
    sapolar and alpha axes.
    '''
    global data, angles, sapolar, energies
    global l1, l2, m1, m2, sapolar_o, angles_o, az, size
    global Fermi_slider

    size = max(abs(edge) for edge in (sapolar[-1], sapolar[0], angles[-1], angles[0]))

    fermi_level = Fermi_slider.get()
    half_window = 10e-3  # eV
    in_window = (energies > fermi_level - half_window) & (energies < fermi_level + half_window)
    fermi_map = data[:, :, in_window].mean(axis=2)

    sapolar_o, angles_o = np.mean(sapolar), np.mean(angles)
    az = 0

    fig, ax = plt.subplots()
    ax.contourf(sapolar, angles, fermi_map.T, 300, cmap='Greys')
    ax.set_xlabel(r'$\Theta$', fontsize=14)
    ax.set_ylabel(r'$\alpha$', fontsize=14)
    ax.set_aspect('equal', adjustable='box')

    dashed = dict(color='r', lw=0.8, ls='--')
    dots = dict(color='r', marker='.', s=10)
    l1, = ax.plot([sapolar_o-size, sapolar_o+size], [angles_o, angles_o], **dashed)
    l2, = ax.plot([sapolar_o, sapolar_o], [angles_o-size, angles_o+size], **dashed)
    m1 = ax.scatter([sapolar_o-2*size/3, sapolar_o-size/3, sapolar_o+size/3, sapolar_o+2*size/3],
                    [angles_o, angles_o, angles_o, angles_o],
                    **dots)
    m2 = ax.scatter([sapolar_o, sapolar_o, sapolar_o, sapolar_o],
                    [angles_o-2*size/3, angles_o-size/3, angles_o+size/3, angles_o+2*size/3],
                    **dots)

    canvas = fig.canvas
    canvas.mpl_connect('button_press_event', cursor_center)
    canvas.mpl_connect('key_press_event', cursor_rotate)
    plt.show()
    
def FindIndex(element, array):
    '''Return the index of the entry of ``array`` closest to ``element``.

    Parameters
    ----------
    element : float
        Value to look up.
    array : array_like
        1D values to search; NaN entries are ignored.

    Returns
    -------
    int
        Position of the nearest entry; ties resolve to the lowest index.

    Raises
    ------
    ValueError
        If ``array`` is empty or has no comparable (non-NaN) distance.
    '''
    distance = np.abs(np.asarray(array) - element)
    return int(np.nanargmin(distance))

def recontrast_axis_cut(event):
    '''Key-press callback of ``axis_cuts``: re-contrast both energy cuts.

    Each cut is redrawn using only the data inside its current (zoomed)
    limits, so the contour levels adapt to the visible region. The alpha cut
    averages up to ``2*sap_w+1`` sapolar slices around ``sapolar_o``. The
    sapolar cut averages up to ``2*ang_w+1`` analyser angles around
    ``angles_o``. Fewer are used at the array edges. Any key triggers it.
    '''
    global data, angles, sapolar, energies
    global sapolar_o, angles_o
    global axs, lh0, lv0, l1, l2
    global ang_w, sap_w
    # masks selecting the visible region of each cut
    left1, right1 = axs[1].get_xlim()
    bot1, top1 = axs[1].get_ylim()
    kmask1 = (left1<angles)&(angles<right1)
    Emask1 = (bot1<energies)&(energies<top1)
    left2, right2 = axs[2].get_xlim()
    bot2, top2 = axs[2].get_ylim()
    kmask2 = (left2<sapolar)&(sapolar<right2)
    Emask2 = (bot2<energies)&(energies<top2)
    
    axs[1].clear()
    sap_i = FindIndex(sapolar_o, sapolar)
    # clamp the slice start: a negative start would wrap around the array
    Z = np.mean(data[max(sap_i-sap_w, 0):sap_i+1+sap_w], axis=0).T
    axs[1].contourf(angles[kmask1], energies[Emask1], Z[Emask1][:,kmask1], 100, cmap='Greys')
    l1 = axs[1].axvline(angles_o, color='0.5', ls='--')
    axs[1].set_yticks([])
    axs[2].clear()
    ang_i = FindIndex(angles_o, angles)
    Z = np.mean(data[:,max(ang_i-ang_w, 0):ang_i+1+ang_w], axis=1).T
    axs[2].contourf(sapolar[kmask2], energies[Emask2], Z[Emask2][:,kmask2], 100, cmap='Greys')
    l2 = axs[2].axvline(sapolar_o, color='0.5', ls='--')
    axs[2].set_yticks([])
    plt.pause(0.01)

def update_axis_cut(event):
    '''Mouse-click callback of ``axis_cuts``: move the cut positions.

    * Click in panel '0' (Fermi map): move both cuts to the clicked
      (sapolar, alpha) point and recompute both energy cuts.
    * Click in panel '1' (alpha cut): move only the alpha marker.
    * Click in panel '2' (sapolar cut): move only the sapolar marker.
    The alpha cut averages up to ``2*sap_w+1`` sapolar slices, and the
    sapolar cut up to ``2*ang_w+1`` analyser angles (fewer at the array
    edges). Clicks outside every axes are ignored.
    '''
    global data, angles, sapolar, energies
    global sapolar_o, angles_o
    global axs, lh0, lv0, l1, l2
    global ang_w, sap_w
    if event.inaxes is None:
        return
    panel = event.inaxes.get_label()
    if panel == '0':
        sapolar_o, angles_o = event.xdata, event.ydata
        lh0.set_ydata(angles_o)
        lv0.set_xdata(sapolar_o)
        axs[1].clear()
        sap_i = FindIndex(sapolar_o, sapolar)
        # clamp the slice start: a negative start would wrap around the array
        Z = np.mean(data[max(sap_i-sap_w, 0):sap_i+1+sap_w], axis=0).T
        axs[1].contourf(angles, energies, Z, 100, cmap='Greys')
        l1 = axs[1].axvline(angles_o, color='0.5', ls='--')
        axs[1].set_yticks([])
        axs[2].clear()
        ang_i = FindIndex(angles_o, angles)
        Z = np.mean(data[:,max(ang_i-ang_w, 0):ang_i+1+ang_w], axis=1).T
        axs[2].contourf(sapolar, energies, Z, 100, cmap='Greys')
        l2 = axs[2].axvline(sapolar_o, color='0.5', ls='--')
        axs[2].set_yticks([])
    elif panel == '1':
        angles_o = event.xdata
        lh0.set_ydata(angles_o)
        l1.set_xdata(angles_o)
    elif panel == '2':
        sapolar_o = event.xdata
        lv0.set_xdata(sapolar_o)
        l2.set_xdata(sapolar_o)
    plt.pause(0.01)

def axis_cuts():
    '''Show the Fermi map with two interactive energy cuts through it.

    Panel '0': map averaged over +/-10 meV around the Fermi-slider energy,
      with cross-hairs at (sapolar_o, angles_o), which start at (0, 0).
    Panel '1': energy vs alpha, averaged over up to 2*sap_w+1 = 9 sapolar
      slices around sapolar_o.
    Panel '2': energy vs sapolar, averaged over up to 2*ang_w+1 = 81
      analyser angles around angles_o.
    Fewer slices are used at the array edges. Clicks are handled by
    ``update_axis_cut`` and key presses by ``recontrast_axis_cut``.
    '''
    global data, angles, sapolar, energies
    global sapolar_o, angles_o
    global axs, lh0, lv0, l1, l2
    global ang_w, sap_w, Fermi_slider
    plt.close()
    
    E_fermi = Fermi_slider.get()
    dE = 10e-3
    mask = (E_fermi-dE<energies)&(energies<E_fermi+dE)
    Fermi_data = np.mean(data[:,:,mask], axis=2)
    
    fig, axs = plt.subplots(1, 3)

    # labels let update_axis_cut tell the panels apart
    axs[0].set_label('0')
    axs[1].set_label('1')
    axs[2].set_label('2')

    axs[0].contourf(sapolar, angles, Fermi_data.T, 100, cmap='Greys')
    sapolar_o, angles_o = 0, 0
    lh0 = axs[0].axhline(angles_o, color='0.5', ls='--')
    lv0 = axs[0].axvline(sapolar_o, color='0.5', ls='--')
    axs[0].set_aspect('equal')

    sap_i = FindIndex(sapolar_o, sapolar)
    sap_w = 4
    # clamp the slice start: if 0 is near the first slice, a negative start would wrap
    Z = np.mean(data[max(sap_i-sap_w, 0):sap_i+1+sap_w], axis=0).T
    axs[1].contourf(angles, energies, Z, 100, cmap='Greys')
    l1 = axs[1].axvline(angles_o, color='0.5', ls='--')
    axs[1].set_yticks([])

    ang_i = FindIndex(angles_o, angles)
    ang_w = 40
    Z = np.mean(data[:,max(ang_i-ang_w, 0):ang_i+1+ang_w], axis=1).T
    axs[2].contourf(sapolar, energies, Z, 100, cmap='Greys')
    l2 = axs[2].axvline(sapolar_o, color='0.5', ls='--')
    axs[2].set_yticks([])
    
    axs[0].set_xlabel('sapolar')
    axs[0].set_ylabel('alpha')
    axs[1].set_xlabel('alpha')
    axs[2].set_xlabel('sapolar')

    fig.canvas.mpl_connect('button_press_event', update_axis_cut)
    fig.canvas.mpl_connect('key_press_event', recontrast_axis_cut)
    plt.show()
    
def Fermi_plot():
    '''Plot a constant-energy (Fermi-surface) map in momentum space.

    Intensity is averaged over energies within +/-dE of the Fermi-slider energy.
    Each (sapolar, alpha) pixel is then mapped to momentum with
        kx = f * sqrt(E) * sin(theta) * cos(alpha)
        ky = f * sqrt(E) * sin(alpha)
    where theta = sapolar - sapolar_o + Gamma_cor and alpha = angles - angles_o
    (degrees), E is the slider energy in eV, and f = sqrt(2 m_e e)/hbar * 1e-10
    (~0.5123) gives k in inverse angstroms.

    NOTE(review): sqrt(E_fermi) treats the slider energy as the photoelectron
    kinetic energy; on a binding-energy scale this would be wrong (and negative
    energies give NaN). Confirm against how `energies` is loaded.

    The view is centred on (kx_o, 0), where kx_o is the kx of theta = Gamma_cor,
    alpha = 0, with half-width taken from the Fermi_range entry. file_no is written
    just above the top-left corner.
    '''
    global data, angles, energies, sapolar
    global file_no
    global sapolar_o, angles_o
    global Gamma_cor_lbl, Fermi_range, Fermi_slider

    plt.close()

    E_fermi = Fermi_slider.get()
    # energies is in eV: average over a +/-10 meV window (20 meV total)
    dE = 10e-3
    mask = (E_fermi-dE<energies)&(energies<E_fermi+dE)
    if not np.any(mask):
        # Energy step coarser than the window: use the nearest slice instead of
        # averaging over nothing (which would give an all-NaN map).
        mask = np.zeros_like(mask)
        mask[np.argmin(np.abs(energies - E_fermi))] = True
    Fermi_data = np.mean(data[:,:,mask], axis=2)

    Gamma_cor = float(Gamma_cor_lbl.get())
    # Both grids have shape (len(sapolar), len(angles)), matching Fermi_data
    alpha, theta = np.meshgrid(angles-angles_o, sapolar-sapolar_o+Gamma_cor)

    f = np.sqrt(2*const.m_e*const.e)/const.hbar *1e-10 #convert to inverse angstrom
    d2r = np.pi/180 #to convert from degrees to radians
    Fermi_kx = np.sin(theta*d2r) * np.cos(alpha*d2r) * np.sqrt(E_fermi) * f
    kx_o = np.sin((Gamma_cor)*d2r) * np.sqrt(E_fermi) * f
    Fermi_ky = np.sin(alpha*d2r) * np.sqrt(E_fermi) * f

    fig, ax = plt.subplots()
    ax.contourf(Fermi_kx, Fermi_ky, Fermi_data, 300, cmap='Greys')
    ax.set_xlabel(r'$k_x$'+' '+r'$(\AA^{-1})$', fontsize=14)
    ax.set_ylabel(r'$k_y$'+' '+r'$(\AA^{-1})$', fontsize=14)
    k_lim = float(Fermi_range.get())
    ax.set_xlim(kx_o-k_lim, kx_o+k_lim)
    ax.set_ylim(-k_lim, k_lim)
    ax.set_aspect('equal', adjustable='box')
    # Place the file number just above the top-left corner of the axes
    xbot, xtop = ax.get_xlim()
    ybot, ytop = ax.get_ylim()
    x = xbot + 0.025*(xtop-xbot)
    y = ybot + 1.025*(ytop-ybot)
    ax.text(x,y,file_no)
    plt.tight_layout()
    plt.show()

In [17]:
def alignment_plot():
    '''Plot summed intensity against the scan variable(s) chosen in the GUI.

    Copies the current selections of the var1/var2 widgets into the globals
    Var1/Var2 (read by the plotting helpers). If the second selection is the
    string 'None', a line plot is drawn (alignment_plot_1D); otherwise a 2D
    map is drawn (alignment_plot_2D).
    '''
    global var1, var2
    global Var1, Var2

    plt.close()

    Var1, Var2 = var1.get(), var2.get()
    draw = alignment_plot_1D if Var2 == 'None' else alignment_plot_2D
    draw()
    
def alignment_plot_1D():
    '''Line plot of total intensity versus the scan variable Var1.

    Every axis of `data` whose length differs from len(red_dict[Var1]) is summed
    out. For 2D data the summed axis is instead picked from Var1's position among
    the keys of red_dict (position 0 -> sum over axis 1, else over axis 0). This
    handles the case where both axes have the same length.
    '''
    global Var1, red_dict
    global data

    # Still ambiguous for >2D data when two scan axes share a length (e.g. a sax
    # and say scan of equal size).
    summed_axes = tuple(
        dim for dim in range(data.ndim)
        if data.shape[dim] != len(red_dict[Var1])
    )
    red_data = np.sum(data, axis=summed_axes)

    if data.ndim == 2:
        # NOTE(review): assumes red_dict key order matches data axis order — confirm
        position = list(red_dict).index(Var1)
        summed_axis = 1 if position == 0 else 0
        red_data = np.sum(data, axis=summed_axis)

    fig, ax = plt.subplots()
    ax.plot(red_dict[Var1], red_data)
    ax.set_xlabel(Var1)
    plt.show()
    
def alignment_plot_2D():
    '''2D map of total intensity with Var1 on the x-axis and Var2 on the y-axis.

    Every axis of `data` whose length matches neither scan variable is summed
    out. The two remaining axes keep their original order.

    Raises:
        ValueError: if the reduction does not leave exactly two axes. For example,
            another axis has the same length as a scan variable, so the scan
            axes cannot be identified by length alone.
    '''
    global Var1, Var2, red_dict
    global data

    lens = (len(red_dict[Var1]), len(red_dict[Var2]))
    axes = tuple(i for i in range(data.ndim) if data.shape[i] not in lens)
    red_data = np.sum(data, axis=axes)

    if red_data.ndim != 2:
        raise ValueError(
            f'Cannot isolate the {Var1}/{Var2} axes of data with shape {data.shape}'
        )

    # contourf(x, y, Z) needs Z shaped (len(y), len(x)). The surviving axes are
    # in data order, so transpose whenever Var1's axis precedes Var2's.
    # NOTE(review): relies on red_dict key order matching data axis order — confirm.
    keys = list(red_dict)
    if keys.index(Var1) < keys.index(Var2):
        red_data = red_data.T

    fig, ax = plt.subplots()
    ax.contourf(red_dict[Var1], red_dict[Var2], red_data, 100)
    ax.set_xlabel(Var1)
    ax.set_ylabel(Var2)
    plt.show()

In [18]:
def align_sweep_update(event):
    '''Key-press handler for the alignment-sweep figure.

    'left'/'right' move the global scan index one step down/up, clamped to
    [0, len(X) - 1]. Both panels are then redrawn:
      axs[0]: integrated intensity, scan variable X vs angle (|angle| < 20),
              with a red marker at the current scan point;
      axs[1]: angle-energy image at the current scan point.
    Other keys, and steps that would leave the valid range, are ignored.
    '''
    global red_data, angles, energies, integrated
    global axs, X, index

    step = {'left': -1, 'right': 1}.get(event.key)
    if step is None:
        return
    # Clamp: without it 'left' at 0 silently wraps to the last scan point
    # (negative indexing) and 'right' at the end raises IndexError.
    new_index = min(max(index + step, 0), len(X) - 1)
    if new_index == index:
        return
    index = new_index

    axs[0].clear()
    axs[1].clear()

    # Same +/-20 angular window as alignment_sweep
    angle_mask = (-20<angles)&(angles<20)

    axs[0].contourf(X, angles[angle_mask], integrated[:,angle_mask].T, 25, cmap='Greys')
    axs[0].axvline(X[index], color='tab:red', ls='--')

    axs[1].contourf(angles[angle_mask], energies, red_data[index][angle_mask].T, 25, cmap='Greys')
    axs[1].set_title(X[index])
    plt.tight_layout()
    plt.pause(0.01)
    
def alignment_sweep():
    '''Interactive sweep through the scan variable selected in var1.

    1. Sum `data` over every axis whose length is not len(red_dict[Var1]),
       len(angles) or len(energies). The result goes in the global red_data,
       which is indexed below as red_data[scan point, angle, energy].
    2. Show the scan-summed angle-energy image (|angle| < 20). The user zooms with
       the toolbar and then presses a key; the y-limits at that moment become
       the energy integration window.
    3. Integrate red_data over that window (global `integrated`) and show
       scan variable vs angle next to the angle-energy image at the middle
       scan point. Left/right arrows then step through the scan via
       align_sweep_update.
    '''
    global var1, red_dict
    global data, angles, energies, red_data, integrated
    global axs, X, index

    plt.close()

    Var1 = var1.get()
    X = red_dict[Var1]

    lens = (len(X), len(angles), len(energies))
    axes = tuple(i for i in range(data.ndim) if data.shape[i] not in lens)
    red_data = np.sum(data, axis=axes)

    # Angular window shown in every panel (same as align_sweep_update)
    angle_mask = (-20<angles)&(angles<20)

    # Define window that we want to integrate between
    fig, ax = plt.subplots()
    full_sum = np.sum(red_data, axis=0)
    ax.contourf(angles[angle_mask], energies, full_sum[angle_mask].T, 50, cmap='Greys')
    ax.set_title('When ready Press Space')
    plt.show()
    plt.pause(0.01)
    # Draw once only: clearing and re-plotting after every mouse click resets
    # the axis limits and undoes the toolbar zoom/pan the user is doing.
    # waitforbuttonpress: True = key press, False = mouse click, None = closed/timeout.
    while fig.waitforbuttonpress() is False:
        pass
    en_bot, en_top = ax.get_ylim()
    plt.close(fig)

    # Now integrate and plot the parameter vs k
    energy_mask = (en_bot<energies)&(energies<en_top)
    integrated = np.sum(red_data[:,:,energy_mask], axis=2)

    index = len(X) // 2

    fig, axs = plt.subplots(1,2, figsize=(9, 5))
    axs[0].contourf(X, angles[angle_mask], integrated[:,angle_mask].T, 25, cmap='Greys')
    axs[0].axvline(X[index], color='tab:red', ls='--')
    axs[1].contourf(angles[angle_mask], energies, red_data[index][angle_mask].T, 25, cmap='Greys')
    axs[1].set_title(X[index])
    fig.canvas.mpl_connect('key_press_event', align_sweep_update)
    plt.tight_layout()
    plt.show()

In [19]:
# Launch the ARchPES GUI. Run this cell last: root.mainloop() blocks until the
# window is closed. open_file stores the chosen path on the global `root`.

root = Tk()
root.title('ARchPES')
# root.geometry('510x700')

sample = 'None'

# grid() returns None, so build the widget first and lay it out separately;
# chaining them would leave open_btn set to None instead of the Button.
open_btn = Button(root, text='Open File', command=open_file)
open_btn.grid(row=0, column=0)

root.mainloop()

Exception in Tkinter callback
Traceback (most recent call last):
  File "C:\Users\amorf\anaconda3\lib\tkinter\__init__.py", line 1892, in __call__
    return self.func(*args)
  File "C:\Users\amorf\AppData\Local\Temp\ipykernel_17132\525349747.py", line 69, in Gold_plot
    ax.contourf(angles, energies, data_2D.T, 50, cmap='Greys')
  File "C:\Users\amorf\anaconda3\lib\site-packages\matplotlib\__init__.py", line 1412, in inner
    return func(ax, *map(sanitize_sequence, args), **kwargs)
  File "C:\Users\amorf\anaconda3\lib\site-packages\matplotlib\axes\_axes.py", line 6313, in contourf
    contours = mcontour.QuadContourSet(self, *args, **kwargs)
  File "C:\Users\amorf\anaconda3\lib\site-packages\matplotlib\contour.py", line 812, in __init__
    kwargs = self._process_args(*args, **kwargs)
  File "C:\Users\amorf\anaconda3\lib\site-packages\matplotlib\contour.py", line 1446, in _process_args
    x, y, z = self._contour_args(args, kwargs)
  File "C:\Users\amorf\anaconda3\lib\site-packages\