diff --git a/.cspell/custom-dictionary.txt b/.cspell/custom-dictionary.txt index e69de29..944bb8b 100644 --- a/.cspell/custom-dictionary.txt +++ b/.cspell/custom-dictionary.txt @@ -0,0 +1,14 @@ +ARPES +cmap +codemirror +ipython +kernelspec +matplotlib +mpes +nbconvert +nbformat +numpy +nxarray +pygments +pyplot +venv diff --git a/.gitignore b/.gitignore index 9b30dad..4e9e2d1 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -poetry.toml \ No newline at end of file +poetry.toml +*.nxs \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cb04945..1404c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "numpy>=1.26.1,<2.0", "PyQt5>=5.0.0", "xarray>=0.20.2", + "nxarray>=0.4.4", + "superqt >=0.3.0", ] [project.urls] diff --git a/src/mpes_tools/Gui_3d.py b/src/mpes_tools/Gui_3d.py index 1c8e478..bec027b 100644 --- a/src/mpes_tools/Gui_3d.py +++ b/src/mpes_tools/Gui_3d.py @@ -1,50 +1,95 @@ -from PyQt5.QtWidgets import QMainWindow, QVBoxLayout, QWidget, QCheckBox, QAction, QSlider, QHBoxLayout, QLabel +import sys +from PyQt5.QtWidgets import QApplication,QMainWindow, QVBoxLayout, QWidget, QCheckBox, QAction, QSlider, QHBoxLayout, QLabel,QLineEdit,QPushButton from PyQt5.QtCore import Qt from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas import matplotlib.pyplot as plt import numpy as np from matplotlib.patches import Circle from matplotlib.lines import Line2D - -from mpes_tools.fit_panel import MainWindow - +import json +import pickle +from mpes_tools.fit_panel import fit_panel +from mpes_tools.fit_panel_single import fit_panel_single +from IPython.core.getipython import get_ipython import xarray as xr - -# %matplotlib qt - -class GraphWindow(QMainWindow): - def __init__(self,data_array: xr.DataArray,t,dt): - global t_final +from mpes_tools.right_click_handler import RightClickHandler +from PyQt5.QtWidgets import QMenu,QGridLayout,QHBoxLayout, QSizePolicy,QLabel +from PyQt5.QtGui import QCursor +from mpes_tools.cursor_handler import Cursor_handler +from mpes_tools.dot_handler import Dot_handler +from mpes_tools.colorscale_slider_handler import colorscale_slider +from matplotlib.figure import Figure +#graphic window showing a 2d map controllable with sliders for the third dimension, with cursors showing cuts along the x direction for MDC and y direction for EDC +# Two vertical cursors and two horizontal cursors are defined in the main graph with each same color for the cursors being horizontal and vertical intercept each other in a dot so one can move either each cursor or the dot itself which will move both cursors. +class Gui_3d(QMainWindow): + def __init__(self,data_array: xr.DataArray,t=None,dt=None): super().__init__() self.setWindowTitle("Graph Window") - self.setGeometry(100, 100, 800, 600) + self.setGeometry(100, 100, 1200, 1000) - # Create a central widget for the graph - central_widget = QWidget() - self.setCentralWidget(central_widget) + # Create a main widget for the graph + main_widget = QWidget() + self.setCentralWidget(main_widget) layout = QVBoxLayout() - central_widget.setLayout(layout) - - self.fig, self.axs = plt.subplots(2,2,figsize=(20,16)) - self.canvas = FigureCanvas(self.fig) - + main_widget.setLayout(layout) + + self.click_handlers = [] + self.handler_list = [] + # plt.ioff() + # add the checkboxes for EDC and MDC integration and the button to save the results self.checkbox_e = QCheckBox("Integrate_energy") self.checkbox_e.stateChanged.connect(self.checkbox_e_changed) self.checkbox_k = QCheckBox("Integrate_k") self.checkbox_k.stateChanged.connect(self.checkbox_k_changed) + + + #create the layout + h_layout = QHBoxLayout() + self.cursor_label=[] + self.cursor_inputs = [] + cursors_names=['yellow_vertical', 'yellow_horizontal','green_vertical', 'green_horizontal'] + for i in range(4): + sub_layout = QVBoxLayout() + # label = QLabel(f"Cursor {i+1}:") + label=QLabel(cursors_names[i]) + input_field = QLineEdit() + input_field.setPlaceholderText("Value") + input_field.setFixedWidth(80) + input_field.editingFinished.connect(lambda i=i: self.main_graph_cursor_changed(i)) + self.cursor_inputs.append(input_field) + self.cursor_label.append(label) + sub_layout.addWidget(label) + sub_layout.addWidget(input_field) + h_layout.addLayout(sub_layout) - self.checkbox_cursors = QCheckBox("energy_cursors") - self.checkbox_cursors.stateChanged.connect(self.checkbox_cursors_changed) + self.canvases = [] + self.axes = [] + + for i in range(4): + fig = Figure(figsize=(10, 8)) # optional: smaller size per plot + plt.close(fig) + canvas = FigureCanvas(fig) + ax = fig.add_subplot(111) + self.canvases.append(canvas) + self.axes.append(ax) + + canvas_layout = QGridLayout() + + canvas_layout.addWidget(self.canvases[0], 0, 0) + canvas_layout.addWidget(self.canvases[1], 0, 1) + canvas_layout.addWidget(self.canvases[2], 1, 0) + canvas_layout.addWidget(self.canvases[3], 1, 1) + checkbox_layout= QHBoxLayout() # Add the canvas to the layout checkbox_layout.addWidget(self.checkbox_e) checkbox_layout.addWidget(self.checkbox_k) layout.addLayout(checkbox_layout) - layout.addWidget(self.canvas) - layout.addWidget(self.checkbox_cursors) + layout.addLayout(h_layout) + layout.addLayout(canvas_layout) slider_layout= QHBoxLayout() self.slider1 = QSlider(Qt.Horizontal) @@ -57,414 +102,434 @@ def __init__(self,data_array: xr.DataArray,t,dt): self.slider2.setValue(0) self.slider2_label = QLabel("0") - self.slider1.setFixedSize(200, 12) # Change the width and height as needed - self.slider2.setFixedSize(200, 12) # Change the width and height as needed + # self.slider1.setFixedSize(200, 12) # Change the width and height as needed + # self.slider2.setFixedSize(200, 12) # Change the width and height as needed slider_layout.addWidget(self.slider1) slider_layout.addWidget(self.slider1_label) slider_layout.addWidget(self.slider2) slider_layout.addWidget(self.slider2_label) layout.addLayout(slider_layout) - # Create a layout for the central widget + + + for idx, ax in enumerate(self.axes): + handler = RightClickHandler(self.canvases[idx], ax,self.show_pupup_window) + self.canvases[idx].mpl_connect("button_press_event", handler.on_right_click) + self.handler_list.append(handler) + + + #define the data_array + self.data=data_array + self.axis=[data_array.coords[dim].data for dim in data_array.dims] + + + if t is not None and dt is not None: + self.t=t + self.dt=dt + else: + self.t=0 + self.dt=0 + # define the cut for the spectra of the main graph + self.data2D_plot=self.data.sel({self.data.dims[2]:slice(self.axis[2][self.t], self.axis[2][self.t + self.dt])}).mean(dim=self.data.dims[2]) + + #Initialize the relevant prameters self.active_cursor = None - self.cursorlinev1=1 - self.cursorlinev2=0 - # self.v1_pixel=None - # self.v2_pixel=None self.Line1=None self.Line2=None - self.square_artists = [] # To store the artists representing the dots - self.square_coords = [(0, 0), (0, 0)] # To store the coordinates of the dots - self.square_count = 0 # To keep track of the number of dots drawn - + self.cursor_vert1 = [] + self.cursor_horiz1 = [] + self.cursor_vert2 =[] + self.cursor_horiz2 = [] + self.integrated_edc=None + self.integrated_mdc=None - self.cid_press2= None - self.line_artists=[] - self.cid_press3 = None - self.cid_press4 = None - self.cid_press = None - - # Create a figure and canvas for the graph - - self.data_o=data_array.data - self.axis=[data_array.coords[dim].data for dim in data_array.dims] - self.dt=dt - self.datae=np.zeros((len(self.axis[0]),len(self.axis[1]))) - # Plot data - self.plot_graph(t,dt) - self.ssshow(t,dt) + # sliders for the delay self.slider1.setRange(0,len(self.axis[2])-1) - self.plot=np.zeros_like(self.data[1,:]) - + self.slider1.setValue(self.t) + self.slider2.setValue(self.dt) + self.slider1_label.setText(self.data.dims[2]+ f": {self.data[self.data.dims[2]][self.t].item():.2f}") + self.slider2_label.setText("Δ"+self.data.dims[2]+f": {self.dt}") + self.slider1.valueChanged.connect(self.slider1_changed) self.slider2.valueChanged.connect(self.slider2_changed) - t_final=self.axis[2].shape[0] - - fit_panel_action = QAction('Fit_Panel',self) - fit_panel_action.triggered.connect(self.fit_panel) + #create a menu for the fit panel menu_bar = self.menuBar() - - # Create a 'Graph' menu + fit_menu = menu_bar.addMenu("Fit Panel") - graph_menu1 = menu_bar.addMenu("Fit Panel") + energy_panel_action = QAction('EDC',self) + energy_panel_action.triggered.connect(self.fit_energy_panel) + fit_menu.addAction(energy_panel_action) - graph_menu1.addAction(fit_panel_action) - - # Add the actions to the menu + momentum_panel_action = QAction('MDC',self) + momentum_panel_action.triggered.connect(self.fit_momentum_panel) + fit_menu.addAction(momentum_panel_action) + + box_panel_action = QAction('box',self) + box_panel_action.triggered.connect(self.fit_box_panel) + fit_menu.addAction(box_panel_action) self.graph_windows=[] - self.t=t - def slider1_changed(self,value): - self.slider1_label.setText(str(value)) - self.plot_graph(self.slider1.value(),self.slider2.value()) - # print(self.slider1.value(),self.slider2.value()) - self.update_show(self.slider1.value(),self.slider2.value()) - self.t=self.slider1.value() - # self.us() - # update_show(self.slider1.value(),self.slider2.value()) - def slider2_changed(self,value): - self.slider2_label.setText(str(value)) - self.plot_graph(self.slider1.value(),self.slider2.value()) - self.update_show(self.slider1.value(),self.slider2.value()) - self.dt=self.slider2.value() - # self.ssshow(self.slider1.value(),self.slider2.value()).update_show() - # self.us() - # update_show(self.slider1.value(),self.slider2.value()) - def checkbox_e_changed(self, state): - if state == Qt.Checked: - # print("Checkbox is checked") - self.integrate_E() - else: - # print("Checkbox is unchecked") - self.update_show(self.slider1.value(),self.slider2.value()) - def checkbox_k_changed(self, state): - if state == Qt.Checked: - # print("Checkbox is checked") - self.integrate_k() - else: - # print("Checkbox is unchecked") - self.update_show(self.slider1.value(),self.slider2.value()) - def checkbox_cursors_changed(self, state): - if state == Qt.Checked: - self.put_cursors() - # self.integrate_k() - else: - # print("Checkbox is unchecked") - self.remove_cursors() - def plot_graph(self,t,dt): - # Plot on the graph - x = [1, 2, 3, 4, 5] - y = [2, 3, 5, 7, 11] - self.data=np.zeros((len(self.axis[0]),len(self.axis[1]))) - # self.ax.plot(x, y) - for i in range (t,t+dt+1): - self.data+= self.data_o[:,:,i] + # plot the main graph + self.im = self.data2D_plot.plot(ax=self.axes[0], cmap='terrain', add_colorbar=False) + self.axes[0].figure.colorbar(self.im, ax=self.axes[0]) + self.colorscale_2dplot=colorscale_slider(canvas_layout, self.im, self.axes[0].figure.canvas) - self.axs[0,0].imshow(self.data, extent=[self.axis[1][0], self.axis[1][-1], self.axis[0][0], self.axis[0][-1]], origin='lower', cmap='viridis',aspect='auto') - self.axs[0,0].set_title('Sample Graph') - self.axs[0,0].set_xlabel('X') - self.axs[0,0].set_ylabel('Y') - self.fig.tight_layout() - self.canvas.draw() - - def fit_panel(self,event): - print('forfit',len(self.plot),'axis',len(self.axis)) - graph_window= MainWindow( self.data_o, self.axis,self.square_coords[0][1], self.square_coords[1][1],self.t,self.dt) - graph_window.show() - self.graph_windows.append(graph_window) + # define the initial positions of the cursors in the main graph - def lz_fit(self, event): - two_lz_fit(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt,self.a).fit() - def fit(self, event): - fit_4d(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD(self, event): - fit_FD(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD_conv(self, event): - # print('ax0test=',self.ax[0]) - # print('ax1test=',self.ax[1]) + initial_x = self.data[self.data.dims[1]].values[int(len(self.data[self.data.dims[1]])/3)] + initial_y = self.data[self.data.dims[0]].values[int(len(self.data[self.data.dims[0]])/3)] + initial_x2 = self.data[self.data.dims[1]].values[int(2*len(self.data[self.data.dims[1]])/3)] + initial_y2 = self.data[self.data.dims[0]].values[int(2*len(self.data[self.data.dims[0]])/3)] + ax = self.axes[0] + # define the lines for the cursors + ymin, ymax = self.axes[0].get_ylim() + xmin, xmax = self.axes[0].get_ylim() + ymin, ymax = 5 * ymin, 5 * ymax + xmin, xmax = 5 * xmin, 5 * xmax + self.cursor_vert1 = Line2D([initial_x, initial_x], [ymin, ymax], color='yellow', linewidth=2, picker=10, linestyle='--') + self.cursor_horiz1 = Line2D([xmin, xmax], [initial_y, initial_y], color='yellow', linewidth=2, picker=10, linestyle='--') + self.cursor_vert2 = Line2D([initial_x2, initial_x2], [ymin, ymax], color='green', linewidth=2, picker=10, linestyle='--') + self.cursor_horiz2 = Line2D([xmin, xmax], [initial_y2, initial_y2], color='green', linewidth=2, picker=10, linestyle='--') - fit_FD_lor_conv(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD_conv_2(self, event): + # show the initial values of the cursors + base = self.cursor_label[0].text().split(':')[0] + self.cursor_label[0].setText(f"{base}: {initial_x:.2f}") + base = self.cursor_label[1].text().split(':')[0] + self.cursor_label[1].setText(f"{base}: {initial_x:.2f}") + base = self.cursor_label[2].text().split(':')[0] + self.cursor_label[2].setText(f"{base}: {initial_x2:.2f}") + base = self.cursor_label[3].text().split(':')[0] + self.cursor_label[3].setText(f"{base}: {initial_y2:.2f}") - f=fit_FD_conv(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt) - f.show() - def ssshow(self,t,dt): - def test(self): - print('whatever test') - print('show is running') - c= self.data.shape[1]// 10 ** (len(str(self.data.shape[1])) - 1) + # define the dots that connect the cursors + self.dot1 = Circle((initial_x, initial_y), radius=0.05, color='yellow', picker=10) + self.dot2 = Circle((initial_x2, initial_y2), radius=0.05, color='green', picker=10) - def put_cursors(): - self.Line1=axe.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - self.Line2=axe.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - plt.draw() - self.fig.canvas.draw() - def remove_cursors(): - self.Line1.remove() - self.Line2.remove() - plt.draw() - self.fig.canvas.draw() - - - def integrate_E(): - self.plote=np.zeros_like(self.data[1,:]) - self.axs[1,0].clear() - plt.draw() - x_min = int(min(self.square_coords[1][1], self.square_coords[0][1])) - x_max = int(max(self.square_coords[1][1], self.square_coords[0][1])) + 1 - for i in range(x_min, x_max): - self.plote += self.data[i, :] - # if self.square_coords[1][1]self.square_coords[0][1]: - # for i in range(self.square_coords[0][1],self.square_coords[1][1]+1): - # self.plot+=self.data[i,:] - # else: - # self.plot+=self.data[self.square_coords[0][1],:] - - self.axs[1, 0].plot(self.axis[1][:],self.plote/abs(self.square_coords[0][1]-self.square_coords[1][1]),color='red') - - # save_data(self.axis[1], plot/abs(self.square_coords[0][1]-self.square_coords[1][1]),"EDC_time="+str(slider_t.val)+"_", [0.42,0.46],self.fig) - def integrate_k(): - self.plotk=np.zeros_like(self.data[:,1]) - self.axs[0,1].clear() - plt.draw() - x_min = int(min(self.square_coords[0][0], self.square_coords[1][0])) - x_max = int(max(self.square_coords[0][0], self.square_coords[1][0])) + 1 - for i in range(x_min, x_max): - self.plotk += self.data[:, i] - # if self.square_coords[0][0]0: - self.axs[2]=self.axs[2][:-self.dt] + self.axs=self.axs[:-self.dt] for pname, par in self.params.items(): self.fit_results.append(getattr(self, pname)[:-self.dt]) else: for pname, par in self.params.items(): self.fit_results.append(getattr(self, pname)) - print('fit_results',len(self.fit_results)) - print('thelengthis=',self.fit_results[0].shape) - sg=showgraphs(self.axs[2], self.fit_results) + # sg=showgraphs(self.axs[min_val:max_val-self.dt], self.fit_results) + sg=showgraphs(self.data[self.data.dims[1]][min_val:max_val-self.dt], self.fit_results) + sg.show() + self.graph_windows.append(sg) + + def fit_all(self): + # C=False + list_plot_fits=[] + + fixed_list=[] + names=[] + self.fit_results=[] + self.fit_results_err=[] + def zero(x): + return 0 + cursors= self.cursor_handler.cursors() + + self.mod= Model(zero) + j=0 + for f in self.function_list: + self.mod+=Model(f,prefix='f'+str(j)+'_') + j+=1 + if self.FD_state == True: + self.mod= self.mod* Model(self.fermi_dirac) + if self.CV_state == True: + self.mod = CompositeModel(self.mod, Model(self.centered_kernel), self.convolve) + if self.offset_state==True: + self.mod= self.mod+Model(self.offset_function) + m1=make_model(self.mod, self.table_widget) + self.mod=m1.current_model() + self.params=m1.current_params() + + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + if self.offset_state==True: + self.params['offset'].set(value=self.y_f.data.min()) + list_axis=[[self.y[self.dim]],[self.x_f]] + # print('the items',self.params.items()) + for pname, par in self.params.items(): + if not par.vary: # Check if vary is False + # print(f"Parameter '{pname}' is fixed at {par.value}") + fixed_list.append(pname) + # print('the paramsnames or',pname, par) + setattr(self, pname, np.zeros((len(self.axs)))) + + if self.t0_state==False: + for i in range(len(self.axs)-self.dt): + self.y=self.data.isel({self.data.dims[1]:slice(i, i+self.dt+1)}).sum(dim=self.data.dims[1]) + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + self.axis.clear() + out = self.mod.fit(self.y_f, self.params, x=self.x_f) + self.y.plot(ax=self.axis) + self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') + list_plot_fits.append([[self.y],[out.best_fit]]) + for pname, par in self.params.items(): + array=getattr(self, pname) + array[i]=out.best_values[pname] + setattr(self, pname,array) + + err_array = getattr(self, f"{pname}_err",np.zeros_like(array)) + stderr = out.params[pname].stderr + err_array[i] = stderr + setattr(self, f"{pname}_err", err_array) + + else: + if self.mid_value_input.text() is not None: + mid_val = int(self.mid_value_input.text()) + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + + for i in range(0,mid_val-self.dt): + self.y=self.data.isel({self.data.dims[1]:slice(i, i+self.dt+1)}).sum(dim=self.data.dims[1]) + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + self.axis.clear() + out = self.mod.fit(self.y_f, self.params, x=self.x_f) + self.y.plot(ax=self.axis) + self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') + list_plot_fits.append([[self.y],[out.best_fit]]) + for pname, par in self.params.items(): + array=getattr(self, pname) + array[i]=out.best_values[pname] + setattr(self, pname,array) + + err_array = getattr(self, f"{pname}_err",np.zeros_like(array)) + stderr = out.params[pname].stderr + err_array[i] = stderr + setattr(self, f"{pname}_err", err_array) + sigma_mean= getattr(self, 'sigma')[0:mid_val-self.dt].mean() + self.params['sigma'].set(value=sigma_mean, vary=False ) + # print(sigma_mean) + for p in fixed_list: + self.params[p].vary=True + # print(p) + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + + + for i in range(mid_val-self.dt,len(self.axs)-self.dt): + self.y=self.data.isel({self.data.dims[1]:slice(i, i+self.dt+1)}).sum(dim=self.data.dims[1]) + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + self.axis.clear() + out = self.mod.fit(self.y_f, self.params, x=self.x_f) + self.y.plot(ax=self.axis) + self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') + list_plot_fits.append([[self.y],[out.best_fit]]) + for pname, par in self.params.items(): + array=getattr(self, pname) + array[i]=out.best_values[pname] + setattr(self, pname,array) + + err_array = getattr(self, f"{pname}_err",np.zeros_like(array)) + stderr = out.params[pname].stderr + err_array[i] = stderr + setattr(self, f"{pname}_err", err_array) + # print('second T',getattr(self, 'T')) + if self.dt>0: + # self.axs=self.axs[:-self.dt] + for pname, par in self.params.items(): + self.fit_results.append(getattr(self, pname)[:-self.dt]) + self.fit_results_err.append(getattr(self, f"{pname}_err")[:-self.dt]) + names.append(pname) + else: + for pname, par in self.params.items(): + self.fit_results.append(getattr(self, pname)) + self.fit_results_err.append(getattr(self, f"{pname}_err")) + names.append(pname) + sg=showgraphs(self.data[self.data.dims[1]][:len(self.data[self.data.dims[1]])-self.dt], self.fit_results,self.fit_results_err,names,list_axis,list_plot_fits) sg.show() self.graph_windows.append(sg) - # pname='T' - # print(getattr(self, pname)) - # out.best_values['A1'] - # self.axis.clear() + self.cursor_handler.redraw() if __name__ == "__main__": app = QApplication(sys.argv) - window = MainWindow() + window = fit_panel() window.show() sys.exit(app.exec_()) diff --git a/tests/fit_panel6.py b/src/mpes_tools/fit_panel_single.py similarity index 58% rename from tests/fit_panel6.py rename to src/mpes_tools/fit_panel_single.py index cbf2070..0041491 100644 --- a/tests/fit_panel6.py +++ b/src/mpes_tools/fit_panel_single.py @@ -1,701 +1,570 @@ -import sys -from PyQt5.QtGui import QBrush, QColor -from PyQt5.QtWidgets import QTextEdit, QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget, QSlider, QLabel, QAction, QCheckBox, QPushButton, QListWidget, QTableWidget, QTableWidgetItem, QTableWidget, QCheckBox, QSplitter -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import QTableWidgetItem, QHBoxLayout, QCheckBox, QWidget -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -import matplotlib.pyplot as plt -from scipy.optimize import curve_fit -import numpy as np -from lmfit.models import ExpressionModel,Model -from lmfit import CompositeModel, Model -from lmfit.lineshapes import gaussian, step -import inspect -from numpy import loadtxt -from movable_vertical_cursors_graph import MovableCursors -from make_model import make_model -from graphs2 import showgraphs - - - -class MainWindow(QMainWindow): - def __init__(self,data,axis,c1,c2,t,dt): - super().__init__() - - self.setWindowTitle("Main Window") - self.setGeometry(100, 100, 1500, 800) - - # Create a menu bar - menu_bar = self.menuBar() - - # Create a 'View' menu - view_menu = menu_bar.addMenu("View") - - # Create actions for showing and hiding the graph window - show_graph_action = QAction("Show Graph", self) - show_graph_action.triggered.connect(self.show_graph_window) - view_menu.addAction(show_graph_action) - - # Store references to graph windows to prevent garbage collection - self.graph_windows = [] - - # Create a central widget - central_widget = QWidget() - self.setCentralWidget(central_widget) - - # Create a layout for the central widget - layout = QHBoxLayout() - central_widget.setLayout(layout) - - # Create a splitter for two panels - splitter = QSplitter(Qt.Horizontal) - - # Create a left panel widget and its layout - left_panel = QWidget() - left_layout = QVBoxLayout() - left_panel.setLayout(left_layout) - - # Create a right panel widget and its layout - right_panel = QWidget() - right_layout = QVBoxLayout() - right_panel.setLayout(right_layout) - - # Add the panels to the splitter - splitter.addWidget(left_panel) - splitter.addWidget(right_panel) - - self.figure, self.axis = plt.subplots() - self.canvas = FigureCanvas(self.figure) - # Create two checkboxes - self.checkbox0 = QCheckBox("Cursors") - self.checkbox0.stateChanged.connect(self.checkbox0_changed) - - self.slider = QSlider(Qt.Horizontal) - self.slider.setMinimum(0) - self.slider.setMaximum(len(axis[2])-1) - self.slider.setValue(t) - self.slider.valueChanged.connect(self.update_label) - self.slider2 = QSlider(Qt.Horizontal) - self.slider2.setMinimum(0) - self.slider2.setMaximum(10) - self.slider2.setValue(dt) - self.slider2.valueChanged.connect(self.update_label2) - - self.label = QLabel("Slider Value: {t}") - self.label2 = QLabel("Slider Value: {dt}") - - # Create two checkboxes - self.checkbox1 = QCheckBox("Multiply with Fermi Dirac") - self.checkbox1.stateChanged.connect(self.checkbox1_changed) - - self.checkbox2 = QCheckBox("Convolve with a Gaussian") - self.checkbox2.stateChanged.connect(self.checkbox2_changed) - - # Create a QListWidget - self.list_widget = QListWidget() - self.list_widget.addItems(["linear","Lorentz", "Gauss", "sinusoid","constant","jump"]) - self.list_widget.setMaximumSize(120,150) - self.list_widget.itemClicked.connect(self.item_selected) - - self.add_button = QPushButton("add") - self.add_button.clicked.connect(self.button_add_clicked) - - self.remove_button = QPushButton("remove") - self.remove_button.clicked.connect(self.button_remove_clicked) - - - self.graph_button = QPushButton("clear graph") - self.graph_button.clicked.connect(self.show_graph_window) - - self.fit_button = QPushButton("Fit") - self.fit_button.clicked.connect(self.fit) - - self.fitall_button = QPushButton("Fit all") - self.fitall_button.clicked.connect(self.fit_all) - - left_buttons=QVBoxLayout() - left_sublayout=QHBoxLayout() - - left_buttons.addWidget(self.add_button) - left_buttons.addWidget(self.remove_button) - left_buttons.addWidget(self.graph_button) - left_buttons.addWidget(self.fit_button) - left_buttons.addWidget(self.fitall_button) - - left_sublayout.addWidget(self.list_widget) - left_sublayout.addLayout(left_buttons) - - # Add widgets to the left layout - left_layout.addWidget(self.canvas) - left_layout.addWidget(self.checkbox0) - left_layout.addWidget(self.slider) - left_layout.addWidget(self.label) - left_layout.addWidget(self.slider2) - left_layout.addWidget(self.label2) - left_layout.addLayout(left_sublayout) - - # left_layout.addWidget(self.list_widget) - # left_layout.addWidget(self.add_button) - # left_layout.addWidget(self.remove_button) - # left_layout.addWidget(self.graph_button) - # left_layout.addWidget(self.fit_button) - # left_layout.addWidget(self.fitall_button) - - self.text_equation = QTextEdit() - # self.text_equation.setMinimumSize(50, 50) # Set minimum size - self.text_equation.setMaximumSize(500, 30) # Set maximum size - - # Create a table widget for the right panel - self.table_widget = QTableWidget(0, 4) # 6 rows and 4 columns (including the special row) - self.table_widget.setHorizontalHeaderLabels(['min', 'value', 'max', 'fix']) - # self.table_widget.setVerticalHeaderLabels(['Row 1', 'The ROW', 'Row 2', 'Row 3', 'Row 4', 'Row 5']) - self.table_widget.itemChanged.connect(self.table_item_changed) - self.table_widget.setMaximumSize(700,500) - # Add checkboxes to the last column of the table, except for the special row - for row in range(6): - if row != 1: # Skip 'The ROW' - checkbox_widget = QWidget() - checkbox_layout = QHBoxLayout() - checkbox_layout.setAlignment(Qt.AlignCenter) - checkbox = QCheckBox() - checkbox_layout.addWidget(checkbox) - checkbox_widget.setLayout(checkbox_layout) - self.table_widget.setCellWidget(row, 3, checkbox_widget) - - # Set 'The ROW' with uneditable empty cells - for col in range(4): - # if col == 3: # Skip the checkbox column for 'The ROW' - # continue - item = QTableWidgetItem('') - item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable - self.table_widget.setItem(1, col, item) - - # Add the table to the right layout - checkboxes=QVBoxLayout() - top_lay = QHBoxLayout() - checkboxes.addWidget(self.checkbox1) - checkboxes.addWidget(self.checkbox2) - top_lay.addWidget(self.text_equation) - top_lay.addLayout(checkboxes) - right_layout.addLayout(top_lay) - right_layout.addWidget(self.table_widget) - - # Add the splitter to the main layout - layout.addWidget(splitter) - def zero(x): - return 0 - self.equation= None - self.mod= Model(zero) - self.total_function=zero - self.function_before_Fermi= zero - self.function_before_convoluted= zero - self.update_text_edit_boxes() - self.i=0 - - self.function_list=[] - self.function_names_list=[] - # Add a button to activate cursors - # self.add_cursors_button = QPushButton("Add Movable Cursors", self) - # self.add_cursors_button.clicked.connect(self.add_movable_cursors) - # right_layout.addWidget(self.add_cursors_button) - - # To hold the MovableCursors instance - self.cursor_handler = None - self.FD_state = False - self.CV_state = False - print('data=',data.shape,'axs',len(axis),'axis1',len(axis[0]),'axis2',len(axis[1]),'axis3',len(axis[2])) - self.axs=axis - self.data_t=np.zeros((data.shape[1],data.shape[2])) - x_min = int(min(c1, c2)) - x_max = int(max(c1, c2)) + 1 - print('xmin=',x_min,'xmax=',x_max) - for i in range(x_min, x_max): - self.data_t += data[i, :,:] - print('data_t',self.data_t.shape) - self.t=t - self.dt=dt - print(t,dt) - self.slider.setValue(self.t) - self.slider2.setValue(self.dt) - self.plot_graph(t,dt) - self.fit_results=[] - - # def add_movable_cursors(self): - # if self.cursor_handler is None: - # # Initialize and add the cursors to the existing plot - # self.cursor_handler = MovableCursors(self.axis) - # self.canvas.draw() - - def plot_graph(self,t,dt): - # Sample data - # self.x = np.linspace(-5,5,100) - # self.y = np.linspace(10,100,100) - self.y=np.zeros((self.data_t.shape[0])) - print('thecomp',self.y.shape,self.data_t[:,1].shape) - # data = loadtxt('C:/Users/admin-nisel131/Documents/CVS_TR_flatband_fig/EDC_time=0_2024-05-08_093401_6994.txt') - self.axis.clear() - # self.x= data[:,0] - # self.y= data[:,1] - self.x=self.axs[1][:] - for i in range(0,dt+1): - self.y +=self.data_t[:,t+i] - - - self.axis.plot(self.x, self.y, 'bo', label='Data') - - self.axis.set_title('Sample Graph') - self.axis.set_xlabel('X') - self.axis.set_ylabel('Y') - self.axis.legend() - self.figure.tight_layout() - self.canvas.draw() - print('sliders=',self.slider.value(),self.slider2.value()) - def update_text_edit_boxes(self): - self.text_equation.setPlaceholderText("Top Right Text Edit Box") - - def constant (self,x,b): - return 0*x+b - def linear (self,x,a,b): - return a*x+b - def lorentzian(self,x, A, x0, gamma): - c=0.0002 - return A / (1 + ((x - x0) / (gamma+c)) ** 2) - # def fermi_dirac(self,x, mu, T,off): - # kb = 8.617333262145 * 10**(-5) # Boltzmann constant in eV/K - # return 1 / (1 + np.exp((x - mu) / (kb * T)))+off - def fermi_dirac(self,x, mu, T): - kb = 8.617333262145 * 10**(-5) # Boltzmann constant in eV/K - return 1 / (1 + np.exp((x - mu) / (kb * T))) - def gaussian(self,x,A, mu, sigma): - return A* np.exp(-(x - mu)**2 / (2 * sigma**2)) - def gaussian_conv(self,x,sigma): - return np.exp(-(x)**2 / (2 * sigma**2)) - def jump(self,x, mid): - """Heaviside step function.""" - o = np.zeros(x.size) - imid = max(np.where(x <= mid)[0]) - o[imid:] = 1.0 - return o - def jump2(self,x, mid,Amp): - """Heaviside step function.""" - o = np.zeros(x.size) - imid = max(np.where(x <= mid)[0]) - o[:imid] = Amp - return o - - - def convolve(self, arr, kernel): - """Simple convolution of two arrays.""" - npts = min(arr.size, kernel.size) - pad = np.ones(npts) - tmp = np.concatenate((pad*arr[0], arr, pad*arr[-1])) - out = np.convolve(tmp, kernel, mode='valid') - noff = int((len(out) - npts) / 2) - return out[noff:noff+npts] - - def convolution(x, func, *args, sigma=1.0): - N = 20 # Assuming N is intended to be a local variable here - x_step = x[1] - x[0] - - # Create the shifted input signal 'y' for convolution - y = np.zeros(N + len(x)) - for i in range(N): - y[i] = x[0] - (N - i) * x_step - y[N:] = x # Append the original signal x to y - - # Create the Gaussian kernel - x_gauss = np.linspace(-0.5, 0.5, len(x)) - gaussian_values = np.exp(-0.5 * (x_gauss / sigma)**2) / (sigma * np.sqrt(2 * np.pi)) - - # Evaluate the function values with parameters - function_values = func(x, *args) - - # Perform convolution - convolution_result = np.convolve(function_values, gaussian_values, mode='same') - - return convolution_result[N-1:-1] - - # def convolution(self,x,sigma,f): - # global N - # xmax=x[-1] - # xmin=x[0] - # N=20 - # y=np.zeros(N+len(x)) - # x_step=x[1]-x[0] - # for i in range(0,N): - # y[i]=x[0]-(N-i)*x_step - # for i in range(0,len(x)): - # y[i+N]=x[i] - # x_gauss = np.linspace(-0.5, 0.5, len(self.ax)) - # gaussian_values = self.gaussian(x_gauss, 0, sigma) - # # function_values = - # convolution = np.convolve( function_values, gaussian_values, mode='same') - # return convolution[N-1:-1] - - def show_graph_window(self): - # Create a new graph window and show it - # graph_window = GraphWindow() - # graph_window.show() - - # # Store a reference to the window to prevent it from being garbage collected - # self.graph_windows.append(graph_window) - self.axis.clear() - self.plot_graph(self.t,self.dt) - - - def update_label(self, value): - self.label.setText(f"Slider Value: {value}") - self.t=self.slider.value() - self.plot_graph(self.t,self.dt) - def update_label2(self, value): - self.label2.setText(f"Slider Value: {value}") - self.dt=self.slider2.value() - self.plot_graph(self.t,self.dt) - - def checkbox0_changed(self, state): - # MovableCursors(self.axis) - if state == Qt.Checked: - if self.cursor_handler is None: - # Initialize and add the cursors to the existing plot - self.cursor_handler = MovableCursors(self.axis) - self.canvas.draw() - else: - self.cursor_handler.redraw() - else: - self.cursor_handler.remove() - # self.cursor_handler= None - - def checkbox1_changed(self, state): - if self.CV_state== True: - pos=2 - else: - pos=0 - if state == Qt.Checked: - self.FD_state = True - self.update_equation() - # pos=0 - - print("Checkbox 1 is checked") - # new_row_name = QTableWidgetItem('Fermi') - self.table_widget.insertRow(pos) - label_item = QTableWidgetItem("Fermi") - # label_item.setTextAlignment(0x0004 | 0x0080) # Align center - self.table_widget.setVerticalHeaderItem(pos, label_item) - # self.table_widget.setVerticalHeaderItem(0, new_row_name) - for col in range(4): - item = QTableWidgetItem('') - item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable - self.table_widget.setItem(pos, col, item) - item.setBackground(QBrush(QColor('grey'))) - c=self.table_widget.rowCount() - self.table_widget.insertRow(pos+1) - label_item1 = QTableWidgetItem("Fermi level") - # label_item1.setTextAlignment(0x0004 | 0x0080) # Align center - checkbox_widget = QWidget() - checkbox_layout = QHBoxLayout() - checkbox_layout.setAlignment(Qt.AlignCenter) - checkbox = QCheckBox() - # checkbox.stateChanged.connect(self.handle_checkbox_state_change) - checkbox.stateChanged.connect(lambda state, row= pos+1: self.handle_checkbox_state_change(state, row)) - print('thecount',c+1) - checkbox_layout.addWidget(checkbox) - checkbox_widget.setLayout(checkbox_layout) - self.table_widget.setCellWidget(pos+1, 3, checkbox_widget) - self.table_widget.setVerticalHeaderItem(pos+1, label_item1) - - self.table_widget.insertRow(pos+2) - label_item2 = QTableWidgetItem("Temperature") - checkbox_widget = QWidget() - checkbox_layout = QHBoxLayout() - checkbox_layout.setAlignment(Qt.AlignCenter) - checkbox = QCheckBox() - # checkbox.stateChanged.connect(self.handle_checkbox_state_change) - checkbox.stateChanged.connect(lambda state, row= pos+2: self.handle_checkbox_state_change(state, row)) - checkbox_layout.addWidget(checkbox) - checkbox_widget.setLayout(checkbox_layout) - self.table_widget.setCellWidget(pos+2, 3, checkbox_widget) - # label_item2.setTextAlignment(0x0004 | 0x0080) # Align center - self.table_widget.setVerticalHeaderItem(pos+2, label_item2) - - - - else: - self.FD_state = False - self.update_equation() - print("Checkbox 1 is unchecked") - - self.table_widget.removeRow(pos) - self.table_widget.removeRow(pos) - self.table_widget.removeRow(pos) - - def checkbox2_changed(self, state): - if state == Qt.Checked: - self.CV_state = True - - self.update_equation() - - - print("Checkbox 1 is checked") - # new_row_name = QTableWidgetItem('Fermi') - self.table_widget.insertRow(0) - label_item = QTableWidgetItem("Convolution") - # label_item.setTextAlignment(0x0004 | 0x0080) # Align center - self.table_widget.setVerticalHeaderItem(0, label_item) - # self.table_widget.setVerticalHeaderItem(0, new_row_name) - for col in range(4): - item = QTableWidgetItem('') - item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable - self.table_widget.setItem(0, col, item) - item.setBackground(QBrush(QColor('grey'))) - - self.table_widget.insertRow(1) - label_item1 = QTableWidgetItem("sigma") - checkbox_widget = QWidget() - checkbox_layout = QHBoxLayout() - checkbox_layout.setAlignment(Qt.AlignCenter) - checkbox = QCheckBox() - # checkbox.stateChanged.connect(self.handle_checkbox_state_change) - checkbox.stateChanged.connect(lambda state, row= 1: self.handle_checkbox_state_change(state, row)) - checkbox_layout.addWidget(checkbox) - checkbox_widget.setLayout(checkbox_layout) - self.table_widget.setCellWidget(1, 3, checkbox_widget) - # label_item1.setTextAlignment(0x0004 | 0x0080) # Align center - self.table_widget.setVerticalHeaderItem(1, label_item1) - - # self.table_widget.insertRow(2) - # label_item2 = QTableWidgetItem("Temperature") - # # label_item2.setTextAlignment(0x0004 | 0x0080) # Align center - # self.table_widget.setVerticalHeaderItem(2, label_item2) - - - - else: - self.CV_state = False - self.update_equation() - print("Checkbox 1 is unchecked") - - self.table_widget.removeRow(0) - self.table_widget.removeRow(0) - # self.table_widget.removeRow(0) - - def item_selected(self, item): - print(f"Selected: {item.text()}") - if item.text() == 'Lorentz': - self.function_selected = self.lorentzian - elif item.text() == 'Gauss': - self.function_selected = self.gaussian - elif item.text()=='linear': - self.function_selected =self.linear - elif item.text()=='constant': - self.function_selected =self.constant - elif item.text()=='jump': - self.function_selected =self.jump2 - # print(self.list_widget.currentItem().text()) - - def button_remove_clicked(self): - if self.i>0: - self.i-=1 - # self.mod= - # print('removal') - current_row_count = self.table_widget.rowCount() - print(current_row_count) - sig = inspect.signature(self.function_list[-1]) - params = sig.parameters - - for p in range(len(params)): - # print('p=',p) - # print('count=',current_row_count-1-p) - self.table_widget.removeRow(current_row_count-1-p) - - self.function_list.remove(self.function_list[-1]) - self.function_names_list.remove(self.function_names_list[-1]) - self.update_equation() - - def button_add_clicked(self): - # print(self.cursor_handler.cursors()) - def zero(x): - return 0 - - - self.i+=1 - self.function_list.append(self.function_selected) - self.function_names_list.append(self.list_widget.currentItem().text()) - - print('the list=',self.function_list,'iten',self.function_list[0]) - print('listlength=',len(self.function_list)) - j=0 - for p in self.function_list: - # j=0 - print('j==',j) - current_function=Model(p,prefix='f'+str(j)+'_') - j+=1 - - - current_row_count = self.table_widget.rowCount() - - self.table_widget.insertRow(current_row_count) - # self.table_widget.setVerticalHeaderLabels([self.list_widget.currentItem().text()]) - new_row_name = QTableWidgetItem(self.list_widget.currentItem().text()) - self.table_widget.setVerticalHeaderItem(current_row_count, new_row_name) - for col in range(4): - item = QTableWidgetItem('') - item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable - self.table_widget.setItem(current_row_count, col, item) - item.setBackground(QBrush(QColor('grey'))) - c=current_row_count - # self.table_widget.insertRow(1) - # self.table_widget.insertRow(2) - for p in range(len(current_function.param_names)): - # c+=1 - # print(c+p+1) - self.table_widget.insertRow(c+p+1) - print(current_function.param_names[p]) - new_row_name = QTableWidgetItem(current_function.param_names[p]) - self.table_widget.setVerticalHeaderItem(c+p+1, new_row_name) - checkbox_widget = QWidget() - checkbox_layout = QHBoxLayout() - checkbox_layout.setAlignment(Qt.AlignCenter) - checkbox = QCheckBox() - # checkbox.stateChanged.connect(self.handle_checkbox_state_change) - checkbox.stateChanged.connect(lambda state, row=c + p + 1: self.handle_checkbox_state_change(state, row)) - checkbox_layout.addWidget(checkbox) - checkbox_widget.setLayout(checkbox_layout) - self.table_widget.setCellWidget(c+p+1, 3, checkbox_widget) - # self.table_widget.setVerticalHeaderLabels([Model(self.function_selected).param_names[p]]) - # print(self.Mod.param_names) - # params['A'].set(value=1.35, vary=True, expr='') - - self.update_equation() - # print(self.params) - - def update_equation(self): - self.equation='' - print('names',self.function_names_list) - for j,n in enumerate(self.function_names_list): - if len(self.function_names_list)==1: - self.equation= n - else: - if j==0: - self.equation= n - else: - self.equation+= '+' + n - if self.FD_state: - self.equation= '('+ self.equation+ ')* Fermi_Dirac' - self.text_equation.setPlainText(self.equation) - print('equation',self.equation) - - - def table_item_changed(self, item): - print(f"Table cell changed at ({item.row()}, {item.column()}): {item.text()}") - header_item = self.table_widget.verticalHeaderItem(item.row()) - # print(header_item.text()) - print('theeeeeeitem=',item.text()) - - def handle_checkbox_state_change(self,state,row): - if state == Qt.Checked: - print("Checkbox is checked") - print(row) - header_item = self.table_widget.verticalHeaderItem(row) - # self.params[header_item.text()].vary = False - - else: - print("Checkbox is unchecked") - header_item = self.table_widget.verticalHeaderItem(row) - # self.params[header_item.text()].vary = True - def fit(self): - - def zero(x): - return 0 - self.mod= Model(zero) - j=0 - for f in self.function_list: - self.mod+=Model(f,prefix='f'+str(j)+'_') - j+=1 - if self.FD_state == True: - self.mod= self.mod* Model(self.fermi_dirac) - if self.CV_state == True: - # self.mod=CompositeModel(self.mod, Model(self.gaussian_conv), self.convolve) - self.mod=CompositeModel(self.mod, Model(self.gaussian_conv), self.convolve) - # self.mod=CompositeModel(Model(self.jump), Model(gaussian), self.convolve) - - m1=make_model(self.mod, self.table_widget) - self.mod=m1.current_model() - # self.mod = CompositeModel(m1.current_model(), Model(gaussian), self.convolve) - self.params=m1.current_params() - # self.params=self.mod.make_params() - cursors= self.cursor_handler.cursors() - self.x_f=self.x[cursors[0]:cursors[1]] - self.y_f=self.y[cursors[0]:cursors[1]] - print(self.params) - # params['b'].set(value=0, vary=True, expr='') - # out = mod.fit(self.data[:,1], params, x=self.data[:,0],method='nelder') - out = self.mod.fit(self.y_f, self.params, x=self.x_f) - # dely = out.eval_uncertainty(sigma=3) - print(out.fit_report(min_correl=0.25)) - self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') - # self.axis.plot(self.x_f,1e5*self.gaussian_conv(self.x_f,out.best_values['sigma'])) - def fit_all(self): - self.fit_results=[] - def zero(x): - return 0 - self.mod= Model(zero) - j=0 - for f in self.function_list: - self.mod+=Model(f,prefix='f'+str(j)+'_') - j+=1 - if self.FD_state == True: - self.mod= self.mod* Model(self.fermi_dirac) - if self.CV_state == True: - # self.mod=CompositeModel(self.mod, Model(self.gaussian_conv), self.convolve) - self.mod=CompositeModel(self.mod, Model(self.gaussian_conv), self.convolve) - m1=make_model(self.mod, self.table_widget) - self.mod=m1.current_model() - self.params=m1.current_params() - for pname, par in self.params.items(): - print('the paramsnames or',pname, par) - setattr(self, pname, np.zeros((len(self.axs[2])))) - # self.pname=np.zeros((len(self.axs[2]))) - cursors= self.cursor_handler.cursors() - for i in range(len(self.axs[2])-self.dt): - self.y=np.zeros((self.data_t.shape[0])) - for j in range (self.dt+1): - self.y+= self.data_t[:,i+j] - self.x_f=self.x[cursors[0]:cursors[1]] - self.y_f=self.y[cursors[0]:cursors[1]] - # print(self.params) - # params['b'].set(value=0, vary=True, expr='') - # out = mod.fit(self.data[:,1], params, x=self.data[:,0],method='nelder') - self.axis.clear() - out = self.mod.fit(self.y_f, self.params, x=self.x_f) - # dely = out.eval_uncertainty(sigma=3) - # print(out.fit_report(min_correl=0.25)) - self.axis.plot(self.x,self.y, 'bo', label='Data') - self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') - for pname, par in self.params.items(): - array=getattr(self, pname) - array[i]=out.best_values[pname] - setattr(self, pname,array) - if self.dt>0: - self.axs[2]=self.axs[2][:-self.dt] - for pname, par in self.params.items(): - self.fit_results.append(getattr(self, pname)[:-self.dt]) - else: - for pname, par in self.params.items(): - self.fit_results.append(getattr(self, pname)) - print('fit_results',len(self.fit_results)) - print('thelengthis=',self.fit_results[0].shape) - - - sg=showgraphs(self.axs[2], self.fit_results) - sg.show() - self.graph_windows.append(sg) - # pname='T' - # print(getattr(self, pname)) - # out.best_values['A1'] - # self.axis.clear() - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = MainWindow() - window.show() - sys.exit(app.exec_()) +import sys +from PyQt5.QtGui import QBrush, QColor +from PyQt5.QtWidgets import QTextEdit, QLineEdit,QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget, QSlider, QLabel, QAction, QCheckBox, QPushButton, QListWidget, QTableWidget, QTableWidgetItem, QTableWidget, QCheckBox, QSplitter +from PyQt5.QtCore import Qt +from PyQt5.QtWidgets import QTableWidgetItem, QHBoxLayout, QCheckBox, QWidget +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +import matplotlib.pyplot as plt +from scipy.optimize import curve_fit +import numpy as np +from lmfit.models import ExpressionModel,Model +from lmfit import CompositeModel, Model +from lmfit.lineshapes import gaussian, step +import inspect +from mpes_tools.movable_vertical_cursors_graph import MovableCursors +from mpes_tools.make_model import make_model +from mpes_tools.graphs import showgraphs + + + + +class fit_panel_single(QMainWindow): + def __init__(self,data): + super().__init__() + + self.setWindowTitle("Main Window") + self.setGeometry(100, 100, 1500, 800) + + # Create a menu bar + menu_bar = self.menuBar() + + # Create a 'View' menu + view_menu = menu_bar.addMenu("View") + + # Create actions for showing and hiding the graph window + clear_graph_action = QAction("Show Graph", self) + clear_graph_action.triggered.connect(self.clear_graph_window) + view_menu.addAction(clear_graph_action) + + # Store references to graph windows to prevent garbage collection + self.graph_windows = [] + + # Create a central widget + central_widget = QWidget() + self.setCentralWidget(central_widget) + + # Create a layout for the central widget + layout = QHBoxLayout() + central_widget.setLayout(layout) + + # Create a splitter for two panels + splitter = QSplitter(Qt.Horizontal) + + # Create a left panel widget and its layout + left_panel = QWidget() + left_layout = QVBoxLayout() + left_panel.setLayout(left_layout) + + # Create a right panel widget and its layout + right_panel = QWidget() + right_layout = QVBoxLayout() + right_panel.setLayout(right_layout) + + # Add the panels to the splitter + splitter.addWidget(left_panel) + splitter.addWidget(right_panel) + + self.figure, self.axis = plt.subplots() + plt.close(self.figure) + self.canvas = FigureCanvas(self.figure) + # Create two checkboxes + self.checkbox0 = QCheckBox("Cursors") + self.checkbox0.stateChanged.connect(self.checkbox0_changed) + + + # Create two checkboxes + self.checkbox1 = QCheckBox("Multiply with Fermi Dirac") + self.checkbox1.stateChanged.connect(self.checkbox1_changed) + + self.checkbox2 = QCheckBox("Convolve with a Gaussian") + self.checkbox2.stateChanged.connect(self.checkbox2_changed) + + self.checkbox3 = QCheckBox("add background offset") + self.checkbox3.stateChanged.connect(self.checkbox3_changed) + + + self.guess_button = QPushButton("Guess") + self.guess_button.clicked.connect(self.button_guess_clicked) + + bigger_layout = QVBoxLayout() + bigger_layout.addWidget(self.guess_button) + # Create a QListWidget + self.list_widget = QListWidget() + self.list_widget.addItems(["linear","Lorentz", "Gauss", "sinusoid","constant","jump"]) + self.list_widget.setMaximumSize(120,150) + self.list_widget.itemClicked.connect(self.item_selected) + + self.add_button = QPushButton("add") + self.add_button.clicked.connect(self.button_add_clicked) + + self.remove_button = QPushButton("remove") + self.remove_button.clicked.connect(self.button_remove_clicked) + + + self.graph_button = QPushButton("clear graph") + self.graph_button.clicked.connect(self.clear_graph_window) + + self.fit_button = QPushButton("Fit") + self.fit_button.clicked.connect(self.fit) + + + + left_buttons=QVBoxLayout() + left_sublayout=QHBoxLayout() + + left_buttons.addWidget(self.add_button) + left_buttons.addWidget(self.remove_button) + left_buttons.addWidget(self.graph_button) + left_buttons.addWidget(self.fit_button) + + + left_sublayout.addWidget(self.list_widget) + left_sublayout.addLayout(left_buttons) + + # Add widgets to the left layout + left_layout.addWidget(self.canvas) + left_layout.addWidget(self.checkbox0) + left_layout.addLayout(left_sublayout) + + + self.text_equation = QTextEdit() + # self.text_equation.setMinimumSize(50, 50) # Set minimum size + self.text_equation.setMaximumSize(500, 30) # Set maximum size + + # Create a table widget for the right panel + self.table_widget = QTableWidget(0, 4) # 6 rows and 4 columns (including the special row) + self.table_widget.setHorizontalHeaderLabels(['min', 'value', 'max', 'fix']) + # self.table_widget.setVerticalHeaderLabels(['Row 1', 'The ROW', 'Row 2', 'Row 3', 'Row 4', 'Row 5']) + self.table_widget.itemChanged.connect(self.table_item_changed) + self.table_widget.setMaximumSize(700,500) + # Add checkboxes to the last column of the table, except for the special row + for row in range(6): + if row != 1: # Skip 'The ROW' + checkbox_widget = QWidget() + checkbox_layout = QHBoxLayout() + checkbox_layout.setAlignment(Qt.AlignCenter) + checkbox = QCheckBox() + checkbox_layout.addWidget(checkbox) + checkbox_widget.setLayout(checkbox_layout) + self.table_widget.setCellWidget(row, 3, checkbox_widget) + + # Set 'The ROW' with uneditable empty cells + for col in range(4): + # if col == 3: # Skip the checkbox column for 'The ROW' + # continue + item = QTableWidgetItem('') + item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable + self.table_widget.setItem(1, col, item) + + # Add the table to the right layout + checkboxes=QVBoxLayout() + top_lay = QHBoxLayout() + above_table=QVBoxLayout() + checkboxes.addWidget(self.checkbox1) + checkboxes.addWidget(self.checkbox2) + checkboxes.addWidget(self.checkbox3) + top_lay.addWidget(self.text_equation) + top_lay.addLayout(checkboxes) + above_table.addLayout(top_lay) + above_table.addLayout(bigger_layout) + right_layout.addLayout(above_table) + right_layout.addWidget(self.table_widget) + + # Add the splitter to the main layout + layout.addWidget(splitter) + def zero(x): + return 0 + self.equation= None + self.mod= Model(zero) + self.total_function=zero + self.function_before_Fermi= zero + self.function_before_convoluted= zero + self.update_text_edit_boxes() + self.i=0 + + self.function_list=[] + self.function_names_list=[] + self.cursor_handler = None + self.FD_state = False + self.CV_state = False + self.t0_state = False + self.offset_state = False + self.data=data + self.y=data + self.dim=data.dims[0] + self.plot_graph() + + def plot_graph(self): + self.axis.clear() + self.y.plot(ax=self.axis) + if self.checkbox0.isChecked(): + if self.cursor_handler is None: + self.cursor_handler = MovableCursors(self.axis) + else: + self.cursor_handler.redraw() + self.figure.tight_layout() + self.canvas.draw() + def update_text_edit_boxes(self): + self.text_equation.setPlaceholderText("Top Right Text Edit Box") + + def offset_function (self,x,offset): + return 0*x+offset + def constant (self,x,A): + return 0*x+A + def linear (self,x,a,b): + return a*x+b + def lorentzian(self,x, A, x0, gamma): + c=0.0000 + return A / (1 + ((x - x0) / (gamma+c)) ** 2) + def fermi_dirac(self,x, mu, T): + kb = 8.617333262145 * 10**(-5) # Boltzmann constant in eV/K + return 1 / (1 + np.exp((x - mu) / (kb * T))) + def gaussian(self,x,A, x0, gamma): + return A* np.exp(-(x -x0)**2 / (2 * gamma**2)) + def gaussian_conv(self,x,sigma): + return np.exp(-(x)**2 / (2 * sigma**2)) + def jump(self,x, mid): + """Heaviside step function.""" + o = np.zeros(x.size) + imid = max(np.where(x <= mid)[0]) + o[imid:] = 1.0 + return o + def jump2(self,x, mid,Amp): + """Heaviside step function.""" + o = np.zeros(x.size) + imid = max(np.where(x <= mid)[0]) + o[:imid] = Amp + return o + def sinusoid(self,x,A,omega,phi): + return A* np.sin(omega*x+phi) + + def centered_kernel(self,x, sigma): + mean = x.mean() + return np.exp(-(x-mean)**2/(2*sigma/2.3548200)**2) + + def convolve(self,arr, kernel): + """Simple convolution of two arrays.""" + npts = min(arr.size, kernel.size) + pad = np.ones(npts) + tmp = np.concatenate((pad*arr[0], arr, pad*arr[-1])) + out = np.convolve(tmp, kernel/kernel.sum(), mode='valid') + noff = int((len(out) - npts) / 2) + return out[noff:noff+npts] + + + def convolution(x, func, *args, sigma=1.0): + N = 20 # Assuming N is intended to be a local variable here + x_step = x[1] - x[0] + + # Create the shifted input signal 'y' for convolution + y = np.zeros(N + len(x)) + for i in range(N): + y[i] = x[0] - (N - i) * x_step + y[N:] = x # Append the original signal x to y + + # Create the Gaussian kernel + x_gauss = np.linspace(-0.5, 0.5, len(x)) + gaussian_values = np.exp(-0.5 * (x_gauss / sigma)**2) / (sigma * np.sqrt(2 * np.pi)) + + # Evaluate the function values with parameters + function_values = func(x, *args) + + # Perform convolution + convolution_result = np.convolve(function_values, gaussian_values, mode='same') + + return convolution_result[N-1:-1] + + + def clear_graph_window(self): + self.axis.clear() + self.plot_graph() + + def checkbox0_changed(self, state): + if state == Qt.Checked: + if self.cursor_handler is None: + self.cursor_handler = MovableCursors(self.axis) + self.canvas.draw() + else: + self.cursor_handler.redraw() + else: + self.cursor_handler.remove() + + def checkbox1_changed(self, state): + if self.CV_state== True: + pos=2 + else: + pos=0 + if state == Qt.Checked: + self.FD_state = True + self.update_equation() + self.table_widget.insertRow(pos) + label_item = QTableWidgetItem("Fermi") + self.table_widget.setVerticalHeaderItem(pos, label_item) + for col in range(4): + item = QTableWidgetItem('') + item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable + self.table_widget.setItem(pos, col, item) + item.setBackground(QBrush(QColor('grey'))) + c=self.table_widget.rowCount() + self.table_widget.insertRow(pos+1) + label_item1 = QTableWidgetItem("Fermi level") + checkbox_widget = QWidget() + checkbox_layout = QHBoxLayout() + checkbox_layout.setAlignment(Qt.AlignCenter) + checkbox = QCheckBox() + checkbox.stateChanged.connect(lambda state, row= pos+1: self.handle_checkbox_state_change(state, row)) + # print('thecount',c+1) + checkbox_layout.addWidget(checkbox) + checkbox_widget.setLayout(checkbox_layout) + self.table_widget.setCellWidget(pos+1, 3, checkbox_widget) + self.table_widget.setVerticalHeaderItem(pos+1, label_item1) + + self.table_widget.insertRow(pos+2) + label_item2 = QTableWidgetItem("Temperature") + checkbox_widget = QWidget() + checkbox_layout = QHBoxLayout() + checkbox_layout.setAlignment(Qt.AlignCenter) + checkbox = QCheckBox() + checkbox.stateChanged.connect(lambda state, row= pos+2: self.handle_checkbox_state_change(state, row)) + checkbox_layout.addWidget(checkbox) + checkbox_widget.setLayout(checkbox_layout) + self.table_widget.setCellWidget(pos+2, 3, checkbox_widget) + self.table_widget.setVerticalHeaderItem(pos+2, label_item2) + else: + self.FD_state = False + self.update_equation() + # print("Checkbox 1 is unchecked") + + self.table_widget.removeRow(pos) + self.table_widget.removeRow(pos) + self.table_widget.removeRow(pos) + + def checkbox2_changed(self, state): + if state == Qt.Checked: + self.CV_state = True + + self.update_equation() + + self.table_widget.insertRow(0) + label_item = QTableWidgetItem("Convolution") + self.table_widget.setVerticalHeaderItem(0, label_item) + # self.table_widget.setVerticalHeaderItem(0, new_row_name) + for col in range(4): + item = QTableWidgetItem('') + item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable + self.table_widget.setItem(0, col, item) + item.setBackground(QBrush(QColor('grey'))) + + self.table_widget.insertRow(1) + label_item1 = QTableWidgetItem("sigma") + checkbox_widget = QWidget() + checkbox_layout = QHBoxLayout() + checkbox_layout.setAlignment(Qt.AlignCenter) + checkbox = QCheckBox() + checkbox.stateChanged.connect(lambda state, row= 1: self.handle_checkbox_state_change(state, row)) + checkbox_layout.addWidget(checkbox) + checkbox_widget.setLayout(checkbox_layout) + self.table_widget.setCellWidget(1, 3, checkbox_widget) + self.table_widget.setVerticalHeaderItem(1, label_item1) + + else: + self.CV_state = False + self.update_equation() + # print("Checkbox 1 is unchecked") + + self.table_widget.removeRow(0) + self.table_widget.removeRow(0) + def checkbox3_changed(self, state): + if state == Qt.Checked: + self.offset_state=True + else: + self.offset_state=False + + def item_selected(self, item): + # print(f"Selected: {item.text()}") + if item.text() == 'Lorentz': + self.function_selected = self.lorentzian + elif item.text() == 'Gauss': + self.function_selected = self.gaussian + elif item.text()=='linear': + self.function_selected =self.linear + elif item.text()=='constant': + self.function_selected =self.constant + elif item.text()=='jump': + self.function_selected =self.jump2 + elif item.text()=='sinusoid': + self.function_selected =self.sinusoid + + def button_guess_clicked(self): + cursors= self.cursor_handler.cursors() + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + max_value= self.y_f.data.max() + min_value= self.y_f.data.min() + mean_value= self.y_f.data.mean() + max_arg=self.y_f.data.argmax() + # print(self.x_f[max_arg].item()) + for row in range(self.table_widget.rowCount()): + header_item = self.table_widget.verticalHeaderItem(row) + if "A" in header_item.text(): + self.params[header_item.text()].set(value=max_value) + item = QTableWidgetItem(str(max_value)) + self.table_widget.setItem(row, 1, item) + elif "x0" in header_item.text(): + self.params[header_item.text()].set(value=self.x_f[max_arg].item()) + item = QTableWidgetItem(str(self.x_f[max_arg].item())) + self.table_widget.setItem(row, 1, item) + elif "gamma" in header_item.text(): + self.params[header_item.text()].set(value=0.2) + item = QTableWidgetItem(str(0.2)) + self.table_widget.setItem(row, 1, item) + + + + def button_remove_clicked(self): + if self.i>0: + self.i-=1 + current_row_count = self.table_widget.rowCount() + sig = inspect.signature(self.function_list[-1]) + params = sig.parameters + + for p in range(len(params)): + self.table_widget.removeRow(current_row_count-1-p) + + self.function_list.remove(self.function_list[-1]) + self.function_names_list.remove(self.function_names_list[-1]) + self.update_equation() + self.create() + + def button_add_clicked(self): + def zero(x): + return 0 + + + self.i+=1 + self.function_list.append(self.function_selected) + self.function_names_list.append(self.list_widget.currentItem().text()) + j=0 + for p in self.function_list: + current_function=Model(p,prefix='f'+str(j)+'_') + j+=1 + + + current_row_count = self.table_widget.rowCount() + + self.table_widget.insertRow(current_row_count) + new_row_name = QTableWidgetItem(self.list_widget.currentItem().text()) + self.table_widget.setVerticalHeaderItem(current_row_count, new_row_name) + for col in range(4): + item = QTableWidgetItem('') + item.setFlags(Qt.ItemIsEnabled) # Make cell uneditable + self.table_widget.setItem(current_row_count, col, item) + item.setBackground(QBrush(QColor('grey'))) + c=current_row_count + for p in range(len(current_function.param_names)): + + self.table_widget.insertRow(c+p+1) + # print(current_function.param_names[p]) + new_row_name = QTableWidgetItem(current_function.param_names[p]) + self.table_widget.setVerticalHeaderItem(c+p+1, new_row_name) + checkbox_widget = QWidget() + checkbox_layout = QHBoxLayout() + checkbox_layout.setAlignment(Qt.AlignCenter) + checkbox = QCheckBox() + checkbox.stateChanged.connect(lambda state, row=c + p + 1: self.handle_checkbox_state_change(state, row)) + checkbox_layout.addWidget(checkbox) + checkbox_widget.setLayout(checkbox_layout) + self.table_widget.setCellWidget(c+p+1, 3, checkbox_widget) + + self.update_equation() + self.create() + + def update_equation(self): + self.equation='' + # print('names',self.function_names_list) + for j,n in enumerate(self.function_names_list): + if len(self.function_names_list)==1: + self.equation= n + else: + if j==0: + self.equation= n + else: + self.equation+= '+' + n + if self.FD_state: + self.equation= '('+ self.equation+ ')* Fermi_Dirac' + self.text_equation.setPlainText(self.equation) + # print('equation',self.equation) + + + def table_item_changed(self, item): + # print(f"Table cell changed at ({item.row()}, {item.column()}): {item.text()}") + header_item = self.table_widget.verticalHeaderItem(item.row()) + # print('theeeeeeitem=',item.text()) + + def handle_checkbox_state_change(self,state,row): + if state == Qt.Checked: + header_item = self.table_widget.verticalHeaderItem(row) + + else: + header_item = self.table_widget.verticalHeaderItem(row) + def create(self): + def zero(x): + return 0 + cursors= self.cursor_handler.cursors() + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + # print(self.y_f) + if self.offset_state==True: + self.params['offset'].set(value=self.y_f.data.min()) + list_axis=[[self.y[self.dim]],[self.x_f]] + self.mod= Model(zero) + j=0 + for f in self.function_list: + self.mod+=Model(f,prefix='f'+str(j)+'_') + j+=1 + if self.FD_state == True: + self.mod= self.mod* Model(self.fermi_dirac) + if self.CV_state == True: + self.mod = CompositeModel(self.mod, Model(self.centered_kernel), self.convolve) + if self.offset_state==True: + self.mod= self.mod+Model(self.offset_function) + m1=make_model(self.mod, self.table_widget) + self.mod=m1.current_model() + self.params=m1.current_params() + def fit(self): + + def zero(x): + return 0 + self.mod= Model(zero) + cursors= self.cursor_handler.cursors() + j=0 + for f in self.function_list: + self.mod+=Model(f,prefix='f'+str(j)+'_') + j+=1 + if self.FD_state == True: + self.mod= self.mod* Model(self.fermi_dirac) + if self.CV_state == True: + self.mod = CompositeModel(self.mod, Model(self.centered_kernel), self.convolve) + if self.offset_state==True: + self.mod= self.mod+Model(self.offset_function) + m1=make_model(self.mod, self.table_widget) + self.mod=m1.current_model() + self.params=m1.current_params() + self.y_f=self.y.isel({self.dim:slice(cursors[0], cursors[1])}) + self.x_f=self.y_f[self.dim] + if self.offset_state==True: + self.params['offset'].set(value=self.y_f.data.min()) + # print(self.params) + out = self.mod.fit(self.y_f, self.params, x=self.x_f) + print(out.fit_report(min_correl=0.25)) + self.axis.plot(self.x_f,out.best_fit,color='red',label='fit') + self.figure.tight_layout() + self.canvas.draw() + + + +if __name__ == "__main__": + app = QApplication(sys.argv) + window = fit_panel_single() + window.show() + sys.exit(app.exec_()) diff --git a/src/mpes_tools/graphs.py b/src/mpes_tools/graphs.py index 7b50216..f11a999 100644 --- a/src/mpes_tools/graphs.py +++ b/src/mpes_tools/graphs.py @@ -1,80 +1,241 @@ import sys import numpy as np -from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QGridLayout +from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QGridLayout,QSlider,QLabel,QCheckBox +from PyQt5.QtCore import Qt from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas import matplotlib.pyplot as plt - +from IPython.core.getipython import get_ipython +from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar +import xarray as xr +from mpes_tools.right_click_handler import RightClickHandler +from PyQt5.QtWidgets import QMenu +from PyQt5.QtGui import QCursor class showgraphs(QMainWindow): - def __init__(self, x, y_arrays): + def __init__(self, x, y_arrays,y_arrays_err,names,list_axis,list_plot_fits): super().__init__() self.setWindowTitle("Multiple Array Plots") self.setGeometry(100, 100, 800, 600) # Store x and y data - self.x = x + self.dim=x.dims[0] + self.x = x.data self.y_arrays = y_arrays + self.y_arrays_err = y_arrays_err self.num_plots = len(y_arrays) - + self.list_plot_fits=list_plot_fits + self.list_axis=list_axis # Create a central widget and layout central_widget = QWidget(self) self.setCentralWidget(central_widget) layout = QGridLayout(central_widget) - + + # print(len(x),len(list_plot_fits)) + # print(list_plot_fits[0]) + self.slider = QSlider() + self.slider.setOrientation(1) # 1 = Qt.Horizontal + self.slider.setMinimum(0) + self.slider.setMaximum(len(x)-1) # Adjust as needed + self.slider.setValue(0) # Default value + # Function to update parameter + + self.slider_label = QLabel(f"{x.dims[0]}:0") + + self.figure, self.axis = plt.subplots() + self.canvas = FigureCanvas(self.figure) + plt.close(self.figure) + + self.toolbar = NavigationToolbar(self.canvas, self) + + + layout_plot = QVBoxLayout() + layout_plot.addWidget(self.toolbar) # assuming `layout` is your QVBoxLayout or similar + layout_plot.addWidget(self.canvas) + + widget_plot = QWidget() + widget_plot.setLayout(layout_plot) + + vbox = QVBoxLayout() + vbox.addWidget(widget_plot) + vbox.addWidget(self.slider_label) + vbox.addWidget(self.slider) + + layout.addLayout(vbox, 0, 0) # Place in top-left + + self.click_handlers=[] + self.handler_list=[] + self.ax_list=[] + self.data_list=[] + self.cursor_list=[] # Create and add buttons and plots for each y array in a 3x3 layout for i, y in enumerate(y_arrays): # Create a button to show the plot in a new window - button = QPushButton(f"Show Plot {i+1}") - button.setFixedSize(80, 30) # Set a fixed size for the button - button.clicked.connect(lambda checked, y=y, index=i+1: self.show_plot(y, index)) - + + data_array = xr.DataArray( + data=y, + dims=[self.dim], # e.g., 'energy', 'time', etc. + coords={self.dim: self.x}, + name=names[i] # Optional: give it a name (like the plot title) + ) + self.data_list.append(data_array) + # Calculate grid position - row = (i // 3) * 2 # Each function will take 2 rows: one for the plot, one for the button - col = i % 3 - + row = ((i+1) // 3) * 2 # Each function will take 2 rows: one for the plot, one for the button + col = (i+1) % 3 + widget,canvas,ax=self.create_plot_widget(data_array,y_arrays_err[i], names[i]) # Add the plot canvas to the grid - layout.addWidget(self.create_plot_widget(y, f"Plot {i+1}"), row, col) # Plot in a 3x3 grid - layout.addWidget(button, row + 1, col) # Button directly below the corresponding plot + checkbox = QCheckBox(f"Show error bars {i+1}") + checkbox.setFixedSize(120, 30) # Adjust size if needed + checkbox.stateChanged.connect(lambda state, data_array=data_array, y_err=y_arrays_err[i], index=i: self.show_err(state, data_array, y_err, index)) + + layout.addWidget(widget, row, col) # Plot in a 3x3 grid + # layout.addWidget(self.create_plot_widget(y, f"Plot {i+1}_"+names[i]), row, col) # Plot in a 3x3 grid + layout.addWidget(checkbox, row + 1, col) # Button directly below the corresponding plot + handler = RightClickHandler(canvas, ax,self.show_pupup_window) + canvas.mpl_connect("button_press_event", handler.on_right_click) + self.handler_list.append(handler) + # handler = SubplotClickHandler(ax, self.external_callback) + # canvas.mpl_connect("button_press_event", handler.handle_double_click) + # self.click_handlers.append(handler) + self.ax_list.append(ax) + self.cursor=ax.axvline(x=self.x[0], color='r', linestyle='--') + self.cursor_list.append(self.cursor) + # self.update_parameter(0) + self.axis.plot(self.list_axis[0][0],self.list_plot_fits[0][0][0],'o', label='data') + self.axis.plot(self.list_axis[1][0],self.list_plot_fits[0][1][0],'r--', label='fit') + self.axis.legend() + self.figure.tight_layout() + self.canvas.draw() + self.slider.valueChanged.connect(self.update_parameter) + + def show_pupup_window(self,canvas,ax): + # print(f"External callback: clicked subplot ({i},{j})") + for i, ax_item in enumerate(self.ax_list): + if ax == ax_item: + data = self.data_list[i] + coords = {k: data.coords[k].values.tolist() for k in data.coords} + dims = data.dims + name = data.name if data.name else f"data_{i}" + menu = QMenu(canvas) + action1 = menu.addAction(f"{data.name} plot") + action = menu.exec_(QCursor.pos()) + + if action == action1: + print(f''' +import xarray as xr +import numpy as np + +data_array = xr.DataArray( + data=np.array({data.values.tolist()}), + dims={dims}, + coords={coords}, + name="{name}" +) +''') + + + + def external_callback(self,ax): + # print(f"External callback: clicked subplot ({i},{j})") + for i, ax_item in enumerate(self.ax_list): + if ax == ax_item: + data = self.data_list[i] + coords = {k: data.coords[k].values.tolist() for k in data.coords} + dims = data.dims + name = data.name if data.name else f"data_{i}" + content = f""" +import xarray as xr +import numpy as np - def create_plot_widget(self, y, title): +data_array = xr.DataArray( + data=np.array({data.values.tolist()}), + dims={dims}, + coords={coords}, + name="{name}" +) +""" + break + shell = get_ipython() + payload = dict( + source='set_next_input', + text=content, + replace=False, + ) + shell.payload_manager.write_payload(payload, single=False) + # shell.run_cell("%gui qt") + QApplication.processEvents() + print('results extracted!') + + def create_plot_widget(self, data_array, y_err , title): """Creates a plot widget for displaying a function.""" + figure, ax = plt.subplots() - ax.plot(self.x, y) + plt.close(figure) + + # ax.errorbar(data_array[data_array.dims[0]].values, data_array.values, yerr=y_err, fmt='o', capsize=3) + ax.plot(data_array[data_array.dims[0]].values, data_array.values,marker='o', linestyle='-') + # data_array.plot(ax=ax,fmt='o', capsize=3) ax.set_title(title) - ax.grid(True) - ax.set_xlabel('x') - ax.set_ylabel('y') - + # print('create_plot'+f"self.ax id: {id(ax)}") # Create a FigureCanvas to embed in the Qt layout canvas = FigureCanvas(figure) - return canvas # Return the canvas so it can be used in the layout + toolbar = NavigationToolbar(canvas, self) - def show_plot(self, y, index): - """Show the plot in a new window.""" - figure, ax = plt.subplots() - ax.plot(self.x, y) - ax.set_title(f"Plot {index}") - ax.grid(True) - ax.set_xlabel('x') - ax.set_ylabel('y') - plt.show() # Show the figure in a new window + # Wrap canvas and toolbar in a widget with a layout + widget = QWidget() + layout = QVBoxLayout() + widget.setLayout(layout) + + layout.addWidget(toolbar) + layout.addWidget(canvas) + return widget,canvas,ax # Return the canvas so it can be used in the layout + + def show_err(self,state,data_array,y_err,i): + self.ax_list[i].clear() + if state == Qt.Checked: + self.ax_list[i].errorbar(data_array[data_array.dims[0]].values, data_array.values, yerr=y_err, fmt='o', capsize=3) + else: + self.ax_list[i].plot(data_array[data_array.dims[0]].values, data_array.values,marker='o', linestyle='-') + # data_array.plot(ax=self.ax_list[i], fmt='o', capsize=3) + self.ax_list[i].set_title(data_array.name) + self.cursor_list[i]=self.ax_list[i].axvline(x=self.x[self.slider.value()], color='r', linestyle='--') + self.ax_list[i].figure.canvas.draw_idle() + + def update_parameter(self, value): + for i, c in enumerate(self.cursor_list): + if c is not None: + c.remove() + self.cursor_list[i]=self.ax_list[i].axvline(x=self.x[value], color='r', linestyle='--') + self.ax_list[i].figure.canvas.draw_idle() + base = self.slider_label.text().split(':')[0] + self.slider_label.setText(f"{base}: {self.x[value]:.2f}") + yscale = self.axis.get_yscale() + ylim = self.axis.get_ylim() + self.axis.clear() + + self.axis.plot(self.list_axis[0][0],self.list_plot_fits[value][0][0],'o', label='data') + self.axis.plot(self.list_axis[1][0],self.list_plot_fits[value][1][0],'r--', label='fit') + self.axis.set_yscale(yscale) + self.axis.set_ylim(ylim) + self.axis.legend() + self.figure.tight_layout() + self.canvas.draw() + # def create_plot_widget1(self,x_data, y_data, title, return_axes=False): + # from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas + # import matplotlib.pyplot as plt + + # fig, ax = plt.subplots() + # canvas = FigureCanvas(fig) + + # ax.plot(x_data,y_data) + # ax.set_title(title) + + # if return_axes: + # return canvas, ax # Allow updating later + # return canvas if __name__ == "__main__": app = QApplication(sys.argv) - # # Example data: Define x and multiple y arrays - # x = np.linspace(-10, 10, 400) - # y_arrays = [ - # np.sin(x), - # np.cos(x), - # np.tan(x), - # np.exp(x / 10), - # x**2, - # x**3, - # np.abs(x), - # np.log(x + 11), # Shift to avoid log(0) - # np.sqrt(x + 11) # Shift to avoid sqrt of negative - # ] - main_window = showgraphs() main_window.show() sys.exit(app.exec_()) diff --git a/src/mpes_tools/hdf5.py b/src/mpes_tools/hdf5.py index 5b133c9..5d4d45d 100644 --- a/src/mpes_tools/hdf5.py +++ b/src/mpes_tools/hdf5.py @@ -46,7 +46,7 @@ def recursive_write_metadata(h5group: h5py.Group, node: dict): try: h5group.create_dataset(key, data=str(item)) print(f"Saved {key} as string.") - except Exception as exc: + except BaseException as exc: raise ValueError( f"Unknown error occurred, cannot save {item} of type {type(item)}.", ) from exc diff --git a/src/mpes_tools/make_model.py b/src/mpes_tools/make_model.py index 940b1e2..d13d969 100644 --- a/src/mpes_tools/make_model.py +++ b/src/mpes_tools/make_model.py @@ -1,63 +1,53 @@ -import sys -from PyQt5.QtGui import QBrush, QColor -from PyQt5.QtWidgets import QTextEdit, QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget, QSlider, QLabel, QAction, QCheckBox, QPushButton, QListWidget, QTableWidget, QTableWidgetItem, QTableWidget, QCheckBox, QSplitter -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import QTableWidgetItem, QHBoxLayout, QCheckBox, QWidget -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -import matplotlib.pyplot as plt - +from PyQt5.QtWidgets import QCheckBox class make_model: - # from matplotlib.widgets import CheckButtons, Button - # %matplotlib qt - def __init__(self,mod,table_widget): self.mod=mod self.params=mod.make_params() - print('otherpalce',self.params) - print('thefuuuuTable',table_widget) - print('count',table_widget.rowCount()) + # print('otherpalce',self.params) + # print('thefuuuuTable',table_widget) + # print('count',table_widget.rowCount()) for row in range(table_widget.rowCount()): item = table_widget.item(row, 1) checkbox_widget = table_widget.cellWidget(row, 3) - print('tableitenm=',item) + # print('tableitenm=',item) if item is not None and item.text().strip(): header_item = table_widget.verticalHeaderItem(item.row()) checkbox=checkbox_widget.findChild(QCheckBox) print(header_item.text(),item.text()) if header_item.text()== "Fermi level": self.params['mu'].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: + if table_widget.item(row, 0) is not None and table_widget.item(row, 0).text().strip(): self.params['mu'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: + if table_widget.item(row, 2) is not None and table_widget.item(row, 2).text().strip(): self.params['mu'].set(max=float(table_widget.item(row, 2).text())) if checkbox.isChecked(): self.params['mu'].vary = False elif header_item.text()== "Temperature": self.params['T'].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: + if table_widget.item(row, 0) is not None and table_widget.item(row, 0).text().strip(): self.params['T'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: + if table_widget.item(row, 2) is not None and table_widget.item(row, 2).text().strip(): self.params['T'].set(max=float(table_widget.item(row, 2).text())) if checkbox.isChecked(): self.params['T'].vary = False elif header_item.text()== "sigma": self.params['sigma'].set(value=float(item.text())) self.params['sigma'].set(min=0) - if table_widget.item(row, 0) is not None: + if table_widget.item(row, 0) is not None and table_widget.item(row, 0).text().strip(): self.params['sigma'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: + if table_widget.item(row, 2) is not None and table_widget.item(row, 2).text().strip(): self.params['sigma'].set(max=float(table_widget.item(row, 2).text())) if checkbox.isChecked(): self.params['sigma'].vary = False else: self.params[header_item.text()].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: + if table_widget.item(row, 0) is not None and table_widget.item(row, 0).text().strip(): self.params[header_item.text()].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: + if table_widget.item(row, 2) is not None and table_widget.item(row, 2).text().strip(): self.params[header_item.text()].set(max=float(table_widget.item(row, 2).text())) if checkbox.isChecked(): self.params[header_item.text()].vary = False diff --git a/src/mpes_tools/movable_vertical_cursors_graph.py b/src/mpes_tools/movable_vertical_cursors_graph.py index 580f4a8..74cb907 100644 --- a/src/mpes_tools/movable_vertical_cursors_graph.py +++ b/src/mpes_tools/movable_vertical_cursors_graph.py @@ -13,9 +13,7 @@ def __init__(self, ax): self.cursorlinev1=self.axis[int(len(self.axis)/4)] self.cursorlinev2=self.axis[int(3*len(self.axis)/4)] - # Create initial cursors (at the middle of the plot) - # self.v1_cursor = self.ax.axvline(x=5, color='r', linestyle='--', label='Cursor X') - # self.v2_cursor = self.ax.axhline(y=0, color='g', linestyle='--', label='Cursor Y') + self.Line1=self.ax.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) self.Line2=self.ax.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) @@ -41,8 +39,6 @@ def on_motion(self,event): elif self.active_cursor == self.Line2: self.Line2.set_xdata([event.xdata, event.xdata]) self.cursorlinev2= event.xdata - # print(dot1.center) - # print(self.cursorlinev1,self.cursorlinev2) self.ax.figure.canvas.draw() plt.draw() def find_nearest_index(array, value): @@ -51,10 +47,6 @@ def find_nearest_index(array, value): self.v1_pixel=find_nearest_index(self.axis, self.cursorlinev1) self.v2_pixel=find_nearest_index(self.axis, self.cursorlinev2) - # self.v1_pixel=int((self.cursorlinev1 - self.axis[0]) / (self.axis[-1] - self.axis[0]) * (self.axis.shape[0] - 1) + 0.5) - # self.v2_pixel=int((self.cursorlinev2 - self.axis[0]) / (self.axis[-1] - self.axis[0]) * (self.axis.shape[0] - 1) + 0.5) - print(self.v1_pixel,self.v2_pixel) - # print(self.v1_pixel,self.v2_pixel) def on_release(self,event): # global self.active_cursor @@ -64,14 +56,11 @@ def remove(self): self.cursorlinev2= self.Line2.get_xdata()[0] self.Line1.remove() self.Line2.remove() - # plt.draw() self.ax.figure.canvas.draw() - def redraw(self): # print(self.cursorlinev1,self.cursorlinev2) self.Line1=self.ax.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) self.Line2=self.ax.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - # plt.draw() self.ax.figure.canvas.draw() def cursors(self): return [self.v1_pixel,self.v2_pixel] \ No newline at end of file diff --git a/src/mpes_tools/right_click_handler.py b/src/mpes_tools/right_click_handler.py new file mode 100644 index 0000000..8b54ddd --- /dev/null +++ b/src/mpes_tools/right_click_handler.py @@ -0,0 +1,14 @@ +from PyQt5.QtWidgets import QMenu +from matplotlib.backend_bases import MouseButton +from PyQt5.QtGui import QCursor + +class RightClickHandler: + def __init__(self, canvas, ax, show_popup=None): + self.canvas = canvas + self.ax = ax + self.show_popup=show_popup + + def on_right_click(self, event): + if event.button == MouseButton.RIGHT and event.inaxes == self.ax: + if self.show_popup: + self.show_popup(self.canvas,self.ax) \ No newline at end of file diff --git a/src/mpes_tools/show_4d_window.py b/src/mpes_tools/show_4d_window.py index 6d1f355..3339e6d 100644 --- a/src/mpes_tools/show_4d_window.py +++ b/src/mpes_tools/show_4d_window.py @@ -5,13 +5,17 @@ import matplotlib.pyplot as plt import numpy as np import h5py -from mpes_tools.Gui_3d import GraphWindow +from mpes_tools.Gui_3d import Gui_3d import xarray as xr from mpes_tools.hdf5 import load_h5 - - -class MainWindow(QMainWindow): - def __init__(self): +from IPython.core.getipython import get_ipython +from mpes_tools.right_click_handler import RightClickHandler +from PyQt5.QtWidgets import QMenu +from PyQt5.QtGui import QCursor +from mpes_tools.colorscale_slider_handler import colorscale_slider + +class show_4d_window(QMainWindow): + def __init__(self,data_array: xr.DataArray): super().__init__() self.setWindowTitle("Main Window") @@ -20,11 +24,10 @@ def __init__(self): # Create a central widget for the graph and slider central_widget = QWidget() self.setCentralWidget(central_widget) - + # Create a layout for the central widget layout = QGridLayout() central_widget.setLayout(layout) - # Create four graphs and sliders self.graphs = [] self.slider1 = [] @@ -33,20 +36,35 @@ def __init__(self): self.slider4 = [] self.sliders = [] self.slider_labels = [] - + self.canvases = [] + self.click_handlers=[] + self.handler_list=[] + self.axis_list=[] + self.graph_layout_list=[] + self.color_graph_list=[] + self.list=[] plt.ioff() - for i in range(2): for j in range(2): - graph_window = QWidget() - graph_layout = QVBoxLayout() - graph_window.setLayout(graph_layout) + self.graph_window = QWidget() + self.graph_layout = QVBoxLayout() + self.graph_window.setLayout(self.graph_layout) # Create a figure and canvas for the graph - figure, axis = plt.subplots(figsize=(20, 20)) + figure, axis = plt.subplots(figsize=(10, 10)) + plt.close(figure) canvas = FigureCanvas(figure) - graph_layout.addWidget(canvas) + handler = RightClickHandler(canvas, axis,self.show_pupup_window) + canvas.mpl_connect("button_press_event", handler.on_right_click) + self.handler_list.append(handler) + + self.graph_layout.addWidget(canvas) + self.axis_list.append(axis) + self.canvases.append(canvas) + + + slider_layout= QHBoxLayout() slider_layout_2= QHBoxLayout() # Create a slider widget @@ -54,11 +72,6 @@ def __init__(self): slider1.setRange(0, 100) slider1.setValue(0) slider1_label = QLabel("0") - # slider.valueChanged.connect(self.slider_changed) - # Set the size of the slider - - # default_size = slider1.sizeHint() - # print(f"Default size of the slider: {default_size.width()}x{default_size.height()}") slider2 = QSlider(Qt.Horizontal) slider2.setRange(0, 10) @@ -94,12 +107,12 @@ def __init__(self): # slider2.valueChanged.connect(self.slider_changed) # Add the slider to the layout - graph_layout.addLayout(slider_layout) - graph_layout.addLayout(slider_layout_2) + self.graph_layout.addLayout(slider_layout) + self.graph_layout.addLayout(slider_layout_2) # graph_layout.addWidget(slider3) # graph_layout.addWidget(slider2) - layout.addWidget(graph_window, i, j) + layout.addWidget(self.graph_window, i, j) self.graphs.append(figure) self.slider1.append(slider1) self.slider2.append(slider2) @@ -107,6 +120,9 @@ def __init__(self): self.slider4.append(slider4) self.sliders.extend([slider1, slider2,slider3, slider4]) self.slider_labels.extend([slider1_label, slider2_label,slider3_label, slider4_label]) + self.graph_layout_list.append(self.graph_layout) + + for slider in self.slider1: slider.valueChanged.connect(self.slider_changed) for slider in self.slider2: @@ -115,30 +131,14 @@ def __init__(self): slider.valueChanged.connect(self.slider_changed) for slider in self.slider4: slider.valueChanged.connect(self.slider_changed) - - self.xv = None - self.yv = None - self.ev = None - self.eh = None - - # print(self.sliders) - # Create a menu bar - menu_bar = self.menuBar() - - # Create a 'File' menu - file_menu = menu_bar.addMenu("File") - # Create actions for opening a file and exiting - open_file_action = QAction("Open File", self) - open_file_action.triggered.connect(self.open_file_dialog) - file_menu.addAction(open_file_action) open_graphe_action = QAction("Energy", self) open_graphe_action.triggered.connect(self.open_graph_kxkydt) - open_graphy_action = QAction("ky_cut", self) - open_graphy_action.triggered.connect(self.open_graph_kyedt) open_graphx_action = QAction("kx_cut", self) - open_graphx_action.triggered.connect(self.open_graph_kxedt) + open_graphx_action.triggered.connect(self.open_graph_kyedt) + open_graphy_action = QAction("ky_cut", self) + open_graphy_action.triggered.connect(self.open_graph_kxedt) menu_bar = self.menuBar() @@ -151,15 +151,84 @@ def __init__(self): graph_menu.addAction(open_graphy_action) # file_menu.addAction(open_graph_action) self.graph_windows = [] - self.ce=None + self.colorscale_energy=[] + self.colorscale_ky=[] + self.colorscale_kx=[] + self.colorscale_dt=[] self.show() + self.load_data(data_array) + + def closeEvent(self, event): + # Remove references to graphs and canvases to prevent lingering objects + self.graphs = [] + self.canvases = [] + self.axis_list = [] + + # Update window state + self.window_open = False + event.accept() + def show_pupup_window(self,canvas,ax): + if ax==self.axis_list[0]: + menu = QMenu(canvas) + action1 = menu.addAction("energy plot") + action = menu.exec_(QCursor.pos()) + + if action == action1: + print(f"""# ENERGY plot +data.loc[{{ + '{self.axes[2]}': slice({self.data_array[self.axes[2]][self.slider1[0].value()].item()}, {self.data_array[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item()}), + '{self.axes[3]}': slice({self.data_array[self.axes[3]][self.slider3[0].value()].item()}, {self.data_array[self.axes[3]][self.slider3[0].value() + self.slider4[0].value()].item()}) +}}].mean(dim=('{self.axes[2]}', '{self.axes[3]}')).T +""") + + elif ax==self.axis_list[1]: + menu = QMenu(canvas) + action1 = menu.addAction("ky plot") + action = menu.exec_(QCursor.pos()) + + if action == action1: + print(f"""# KY plot +data.loc[{{ + '{self.axes[1]}': slice({self.data_array[self.axes[1]][self.slider1[1].value()].item()}, {self.data_array[self.axes[1]][self.slider1[1].value() + self.slider2[1].value()].item()}), + '{self.axes[3]}': slice({self.data_array[self.axes[3]][self.slider3[1].value()].item()}, {self.data_array[self.axes[3]][self.slider3[1].value() + self.slider4[1].value()].item()}) +}}].mean(dim=('{self.axes[1]}', '{self.axes[3]}')).T +""") + + + elif ax==self.axis_list[2]: + menu = QMenu(canvas) + action1 = menu.addAction("kx plot") + action = menu.exec_(QCursor.pos()) + + if action == action1: + print(f"""# KX plot +data.loc[{{ + '{self.axes[0]}': slice({self.data_array[self.axes[0]][self.slider1[2].value()].item()}, {self.data_array[self.axes[0]][self.slider1[2].value() + self.slider2[2].value()].item()}), + '{self.axes[3]}': slice({self.data_array[self.axes[3]][self.slider3[2].value()].item()}, {self.data_array[self.axes[3]][self.slider3[2].value() + self.slider4[2].value()].item()}) +}}].mean(dim=('{self.axes[0]}', '{self.axes[3]}')).T +""") + + + elif ax==self.axis_list[3]: + menu = QMenu(canvas) + action1 = menu.addAction("kx ky plot") + action = menu.exec_(QCursor.pos()) + + if action == action1: + print(f"""# KX-KY plot +data.loc[{{ + '{self.axes[1]}': slice({self.data_array[self.axes[1]][self.slider1[3].value()].item()}, {self.data_array[self.axes[1]][self.slider1[3].value() + self.slider2[3].value()].item()}), + '{self.axes[0]}': slice({self.data_array[self.axes[0]][self.slider3[3].value()].item()}, {self.data_array[self.axes[0]][self.slider3[3].value() + self.slider4[3].value()].item()}) +}}].mean(dim=('{self.axes[1]}', '{self.axes[0]}')) +""") + def open_graph_kxkydt(self): E1=self.data_array[self.axes[2]][self.slider1[0].value()].item() E2=self.data_array[self.axes[2]][self.slider1[0].value()+self.slider2[0].value()+1].item() data_kxkydt = self.data_array.loc[{self.axes[2]:slice(E1,E2)}].mean(dim=(self.axes[2])) - graph_window=GraphWindow(data_kxkydt, self.slider3[0].value(), self.slider4[0].value()) + graph_window=Gui_3d(data_kxkydt, self.slider3[0].value(), self.slider4[0].value()) # Show the graph window graph_window.show() self.graph_windows.append(graph_window) @@ -168,7 +237,7 @@ def open_graph_kxedt(self): ky1=self.data_array[self.axes[1]][self.slider1[1].value()].item() ky2=self.data_array[self.axes[1]][self.slider1[1].value()+self.slider2[1].value()+1].item() data_kxedt = self.data_array.loc[{self.axes[1]:slice(ky1,ky2)}].mean(dim=(self.axes[1])) - graph_window = GraphWindow(data_kxedt, self.slider3[1].value(), self.slider4[1].value()) + graph_window = Gui_3d(data_kxedt, self.slider3[1].value(), self.slider4[1].value()) # Show the graph window graph_window.show() self.graph_windows.append(graph_window) @@ -177,24 +246,15 @@ def open_graph_kyedt(self): kx1=self.data_array[self.axes[0]][self.slider1[2].value()].item() kx2=self.data_array[self.axes[0]][self.slider1[2].value()+self.slider2[2].value()+1].item() data_kyedt = self.data_array.loc[{self.axes[0]:slice(kx1,kx2)}].mean(dim=(self.axes[0])) - graph_window = GraphWindow(data_kyedt, self.slider3[2].value(), self.slider4[2].value()) + graph_window = Gui_3d(data_kyedt, self.slider3[2].value(), self.slider4[2].value()) # Show the graph window + graph_window.show() self.graph_windows.append(graph_window) - - def open_file_dialog(self): - # Open file dialog to select a .h5 file - file_path, _ = QFileDialog.getOpenFileName(self, "Open hdf5", "", "h5 Files (*.h5)") - print(file_path) - if file_path: - data_array = load_h5(file_path) - - self.load_data(data_array) - + def load_data(self, data_array: xr.DataArray): self.data_array = data_array self.axes = data_array.dims - self.slider1[0].setRange(0,len(self.data_array.coords[self.axes[2]])-1) self.slider1[1].setRange(0,len(self.data_array.coords[self.axes[0]])-1) self.slider1[2].setRange(0,len(self.data_array.coords[self.axes[1]])-1) @@ -203,35 +263,110 @@ def load_data(self, data_array: xr.DataArray): self.slider3[0].setRange(0,len(self.data_array.coords[self.axes[3]])-1) self.slider3[1].setRange(0,len(self.data_array.coords[self.axes[3]])-1) self.slider3[2].setRange(0,len(self.data_array.coords[self.axes[3]])-1) - - self.update_energy(self.slider1[0].value(),self.slider2[0].value() , self.slider1[1].value(), self.slider2[1].value()) - # self.ce= update_color(self.im,self.graphs[0],self.graphs[0].gca()) - # self.ce.slider_plot.on_changed(self.ce.update) + self.slider_labels[0].setText(self.axes[2]) + self.slider_labels[1].setText("Δ"+self.axes[2]) + self.slider_labels[2].setText(self.axes[3]) + self.slider_labels[3].setText("Δ"+self.axes[3]) + + self.slider_labels[4].setText(self.axes[1]) + self.slider_labels[5].setText("Δ"+self.axes[1]) + self.slider_labels[6].setText(self.axes[3]) + self.slider_labels[7].setText("Δ"+self.axes[3]) + + self.slider_labels[8].setText(self.axes[0]) + self.slider_labels[9].setText("Δ"+self.axes[0]) + self.slider_labels[10].setText(self.axes[3]) + self.slider_labels[11].setText("Δ"+self.axes[3]) + + self.slider_labels[12].setText(self.axes[1]) + self.slider_labels[13].setText("Δ"+self.axes[1]) + self.slider_labels[14].setText(self.axes[0]) + self.slider_labels[15].setText("Δ"+self.axes[0]) - self.update_ky(self.slider1[2].value(), self.slider2[2].value(), self.slider3[0].value(), self.slider4[0].value()) + + + self.initialize_plots() + self.initialize_cursors() + + self.update_energy(self.slider1[0].value(),self.slider2[0].value(),self.slider3[0].value(), self.slider4[0].value()) + + self.update_ky(self.slider1[1].value(), self.slider2[1].value(),self.slider3[1].value(), self.slider4[1].value()) - self.update_kx(self.slider3[1].value(), self.slider4[1].value(), self.slider3[2].value(), self.slider4[2].value()) + self.update_kx(self.slider1[2].value(), self.slider2[2].value(),self.slider3[2].value(), self.slider4[2].value()) + + self.update_dt(self.slider1[3].value(), self.slider2[3].value(), self.slider3[3].value(), self.slider4[3].value()) + + + + def initialize_plots(self): + data_avg=self.data_array.isel({self.axes[2]:slice(0,0), self.axes[3]:slice(0,0)}).mean(dim=(self.axes[2], self.axes[3])) + self.im0=data_avg.T.plot(ax=self.graphs[0].gca(),cmap='terrain', add_colorbar=False) + + data_avg=self.data_array.isel({self.axes[1]:slice(0,0), self.axes[3]:slice(0,0)}).mean(dim=(self.axes[1], self.axes[3])) + self.im1=data_avg.T.plot(ax=self.graphs[1].gca(),cmap='terrain', add_colorbar=False) + + data_avg=self.data_array.isel({self.axes[0]:slice(0,0), self.axes[3]:slice(0,0)}).mean(dim=(self.axes[0], self.axes[3])) + self.im2=data_avg.T.plot(ax=self.graphs[2].gca(),cmap='terrain', add_colorbar=False) - self.update_dt(self.slider1[3].value(), self.slider3[3].value(), self.slider2[3].value(), self.slider4[3].value()) + data_avg=self.data_array.isel({self.axes[1]:slice(0,0), self.axes[0]:slice(0,0)}).mean(dim=(self.axes[1], self.axes[0])) + self.im3=data_avg.plot(ax=self.graphs[3].gca(),cmap='terrain', add_colorbar=False) + + self.graphs[0].gca().figure.colorbar(self.im0, ax=self.graphs[0].gca()) + self.graphs[1].gca().figure.colorbar(self.im1, ax=self.graphs[1].gca()) + self.graphs[2].gca().figure.colorbar(self.im2, ax=self.graphs[2].gca()) + self.graphs[3].gca().figure.colorbar(self.im3, ax=self.graphs[3].gca()) + + self.im0.set_clim([self.data_array.min(),self.data_array.max()]) + self.im1.set_clim([self.data_array.min(),self.data_array.max()]) + self.im2.set_clim([self.data_array.min(),self.data_array.max()]) + self.im3.set_clim([self.data_array.min(),self.data_array.max()]) + self.colorscale_energy=colorscale_slider(self.graph_layout_list[0], self.im0, self.canvases[0], [self.data_array.min(),self.data_array.max()]) + self.colorscale_ky=colorscale_slider(self.graph_layout_list[1], self.im1, self.canvases[1], [self.data_array.min(),self.data_array.max()]) + self.colorscale_kx=colorscale_slider(self.graph_layout_list[2], self.im2, self.canvases[2], [self.data_array.min(),self.data_array.max()]) + self.colorscale_dt=colorscale_slider(self.graph_layout_list[3], self.im3, self.canvases[3], [self.data_array.min(),self.data_array.max()]) + + def initialize_cursors(self): + + ax=self.graphs[0].gca() + self.energy_kx_cursor = ax.axvline(x=self.data_array.coords[self.axes[1]][self.slider1[2].value()].item(), color='r', linestyle='--') + self.energy_ky_cursor = ax.axhline(y=self.data_array.coords[self.axes[1]][self.slider1[1].value()].item(), color='r', linestyle='--') + self.energy_kxky_x = ax.axvline(x=self.data_array.coords[self.axes[1]][self.slider1[3].value()].item(), color='b', linestyle='--') + self.energy_kxky_y = ax.axhline(y=self.data_array.coords[self.axes[0]][self.slider3[3].value()].item(), color='b', linestyle='--') + self.energy_delta_kx_cursor = self.graphs[0].gca().axvline(x=self.data_array.coords[self.axes[1]][self.slider1[2].value()+self.slider2[2].value()].item(), color='r', linestyle='--') + self.energy_delta_ky_cursor = self.graphs[0].gca().axhline(y=self.data_array.coords[self.axes[0]][self.slider1[1].value()+self.slider2[1].value()].item(), color='r', linestyle='--') + self.energy_delta_kxky_y = self.graphs[0].gca().axhline(y=self.data_array.coords[self.axes[1]][self.slider1[3].value()+self.slider2[3].value()].item(), color='b', linestyle='--') + self.energy_delta_kxky_x = self.graphs[0].gca().axvline(x=self.data_array.coords[self.axes[0]][self.slider3[3].value()+self.slider4[3].value()].item(), color='b', linestyle='--') + ax=self.graphs[1].gca() + self.ky_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(), color='r', linestyle='--') + self.ky_delta_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()+self.slider2[0].value()].item(), color='r', linestyle='--') + ax=self.graphs[2].gca() + self.kx_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(), color='r', linestyle='--') + self.kx_delta_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()+self.slider2[0].value()].item(), color='r', linestyle='--') + ax=self.graphs[3].gca() + self.kx_ky_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(), color='r', linestyle='--') + self.kx_ky_delta_energy_cursor = ax.axhline(y=self.data_array.coords[self.axes[2]][self.slider1[0].value()+self.slider2[0].value()].item(), color='r', linestyle='--') + self.energy_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[0].value()].item(), color='r', linestyle='--') + self.delta_energy_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[0].value()+self.slider4[0].value()].item(), color='r', linestyle='--') + self.ky_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[1].value()].item(), color='b', linestyle='--') + self.delta_ky_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[1].value()+self.slider4[1].value()].item(), color='b', linestyle='--') + self.kx_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[2].value()].item(), color='g', linestyle='--') + self.delta_kx_time_cursor = self.graphs[3].gca().axvline(x=self.data_array.coords[self.axes[3]][self.slider3[2].value()+self.slider4[2].value()].item(), color='g', linestyle='--') + def update_energy(self,Energy,dE,te,dte): - self.ce_state=True E1=self.data_array[self.axes[2]][Energy].item() E2=self.data_array[self.axes[2]][Energy+dE].item() te1=self.data_array[self.axes[3]][te].item() te2=self.data_array[self.axes[3]][te+dte].item() - - self.graphs[0].clear() ax=self.graphs[0].gca() - - self.im=self.data_array.loc[{self.axes[2]:slice(E1,E2), self.axes[3]:slice(te1,te2)}].mean(dim=(self.axes[2], self.axes[3])).T.plot(ax=ax) - - self.ev = ax.axvline(x=self.data_array.coords[self.axes[0]][self.slider1[1].value()], color='r', linestyle='--') - self.eh = ax.axhline(y=self.data_array.coords[self.axes[1]][self.slider1[2].value()], color='r', linestyle='--') + data_avg=self.data_array.loc[{self.axes[2]:slice(E1,E2), self.axes[3]:slice(te1,te2)}].mean(dim=(self.axes[2], self.axes[3])) + self.im0.set_array(data_avg.T.values) + ax.set_aspect('auto') + ax.set_title(f'energy: {E1:.2f}, E+dE: {E2:.2f} , t: {te1:.2f}, t+dt: {te2:.2f}') self.graphs[0].tight_layout() - self.graphs[0].canvas.draw() + self.graphs[0].canvas.draw_idle() def update_ky(self,ypos,dy,ty,dty): @@ -240,14 +375,13 @@ def update_ky(self,ypos,dy,ty,dty): ty1=self.data_array[self.axes[3]][ty].item() ty2=self.data_array[self.axes[3]][ty+dty].item() - self.graphs[1].clear() ax=self.graphs[1].gca() - self.data_array.loc[{self.axes[1]:slice(y1,y2), self.axes[3]:slice(ty1,ty2)}].mean(dim=(self.axes[1], self.axes[3])).T.plot(ax=ax) - - self.yv = ax.axvline(x=self.data_array.coords[self.axes[2]][self.slider1[0].value()], color='r', linestyle='--') - + data_avg=self.data_array.loc[{self.axes[1]:slice(y1,y2), self.axes[3]:slice(ty1,ty2)}].mean(dim=(self.axes[1], self.axes[3])) + self.im1.set_array(data_avg.T.values) + ax.set_aspect('auto') + ax.set_title(f'ky: {y1:.2f}, ky+dky: {y2:.2f} , t: {ty1:.2f}, t+dt: {ty2:.2f}') self.graphs[1].tight_layout() - self.graphs[1].canvas.draw() + self.graphs[1].canvas.draw_idle() def update_kx(self,xpos,dx,tx,dtx): @@ -256,49 +390,89 @@ def update_kx(self,xpos,dx,tx,dtx): tx1=self.data_array[self.axes[3]][tx].item() tx2=self.data_array[self.axes[3]][tx+dtx].item() - self.graphs[2].clear() ax=self.graphs[2].gca() - self.data_array.loc[{self.axes[0]:slice(x1,x2), self.axes[3]:slice(tx1,tx2)}].mean(dim=(self.axes[0], self.axes[3])).T.plot(ax=ax) - - self.xv = ax.axvline(x=self.data_array.coords[self.axes[2]][self.slider1[0].value()], color='r', linestyle='--') - + data_avg=self.data_array.loc[{self.axes[0]:slice(x1,x2), self.axes[3]:slice(tx1,tx2)}].mean(dim=(self.axes[0], self.axes[3])) + self.im2.set_array(data_avg.T.values) + ax.set_aspect('auto') + ax.set_title(f'kx: {x1:.2f}, kx+dkx: {x2:.2f} , t: {tx1:.2f}, t+dt: {tx2:.2f}') self.graphs[2].tight_layout() - self.graphs[2].canvas.draw() + self.graphs[2].canvas.draw_idle() - def update_dt(self,yt,xt,dyt,dxt): + def update_dt(self,yt,dyt,xt,dxt): yt1=self.data_array[self.axes[1]][yt].item() yt2=self.data_array[self.axes[1]][yt+dyt].item() xt1=self.data_array[self.axes[0]][xt].item() xt2=self.data_array[self.axes[0]][xt+dxt].item() - self.graphs[3].clear() ax=self.graphs[3].gca() - self.data_array.loc[{self.axes[1]:slice(yt1,yt2), self.axes[0]:slice(xt1,xt2)}].mean(dim=(self.axes[1], self.axes[0])).plot(ax=ax) - + + data_avg=self.data_array.loc[{self.axes[1]:slice(yt1,yt2), self.axes[0]:slice(xt1,xt2)}].mean(dim=(self.axes[1], self.axes[0])) + self.im3.set_array(data_avg.values) + ax.set_aspect('auto') + ax.set_title(f'ky: {yt1:.2f}, ky+dky: {yt2:.2f} , kx: {xt1:.2f}, kx+dkx: {xt2:.2f}') self.graphs[3].tight_layout() - self.graphs[3].canvas.draw() + self.graphs[3].canvas.draw_idle() + def slider_changed(self, value): sender = self.sender() # Get the slider that emitted the signal index = self.sliders.index(sender) # Find the index of the slider - self.slider_labels[index].setText(str(value)) # Update the corresponding label text + # self.slider_labels[index].setText(str(value)) # Update the corresponding label text + base = self.slider_labels[index].text().split(':')[0] + self.slider_labels[index].setText(f"{base}: {value}") if index in range(0,4): - # self.ce.slider_plot.on_changed(self.ce.update) + self.kx_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value()].item()]) + self.ky_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value()].item()]) + self.kx_ky_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value()].item()]) + + self.kx_delta_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item()]) + self.ky_delta_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item()]) + self.kx_ky_delta_energy_cursor.set_ydata([self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item(),self.data_array.coords[self.axes[2]][self.slider1[0].value() + self.slider2[0].value()].item()]) + + self.energy_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[0].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[0].value()].item()]) + self.delta_energy_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[0].value() + self.slider4[0].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[0].value() + self.slider4[0].value()].item()]) + + + self.graphs[2].canvas.draw_idle() + self.graphs[1].canvas.draw_idle() + self.graphs[3].canvas.draw_idle() self.update_energy(self.slider1[0].value(),self.slider2[0].value(),self.slider3[0].value(), self.slider4[0].value()) - # self.update_line() + elif index in range(4,8): + + self.energy_ky_cursor.set_ydata([self.data_array.coords[self.axes[0]][self.slider1[1].value()].item(),self.data_array.coords[self.axes[0]][self.slider1[1].value()].item()]) + self.energy_delta_ky_cursor.set_ydata([self.data_array.coords[self.axes[0]][self.slider1[1].value() + self.slider2[1].value()].item(),self.data_array.coords[self.axes[0]][self.slider1[1].value() + self.slider2[1].value()].item()]) + self.ky_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[1].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[1].value()].item()]) + self.delta_ky_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[1].value() + self.slider4[1].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[1].value() + self.slider4[1].value()].item()]) + + + self.graphs[0].canvas.draw_idle() + self.graphs[3].canvas.draw_idle() self.update_ky(self.slider1[1].value(), self.slider2[1].value(),self.slider3[1].value(), self.slider4[1].value()) elif index in range (8,12): + self.energy_kx_cursor.set_xdata([self.data_array.coords[self.axes[1]][self.slider1[2].value()].item(),self.data_array.coords[self.axes[1]][self.slider1[2].value()].item()]) + self.energy_delta_kx_cursor.set_xdata([self.data_array.coords[self.axes[1]][self.slider1[2].value() + self.slider2[2].value()].item(),self.data_array.coords[self.axes[1]][self.slider1[2].value() + self.slider2[2].value()].item()]) + self.kx_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[2].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[2].value()].item()]) + self.delta_kx_time_cursor.set_xdata([self.data_array.coords[self.axes[3]][self.slider3[2].value() + self.slider4[2].value()].item(),self.data_array.coords[self.axes[3]][self.slider3[2].value() + self.slider4[2].value()].item()]) + + self.graphs[3].canvas.draw_idle() + self.graphs[0].canvas.draw_idle() self.update_kx(self.slider1[2].value(), self.slider2[2].value(),self.slider3[2].value(), self.slider4[2].value()) elif index in range (12,16): - self.update_dt(self.slider1[3].value(), self.slider3[3].value(), self.slider2[3].value(), self.slider4[3].value()) + + self.energy_kxky_y.set_ydata([self.data_array.coords[self.axes[1]][self.slider1[3].value()].item(),self.data_array.coords[self.axes[1]][self.slider1[3].value()].item()]) + self.energy_kxky_x.set_xdata([self.data_array.coords[self.axes[0]][self.slider3[3].value()].item(),self.data_array.coords[self.axes[0]][self.slider3[3].value()].item()]) + self.energy_delta_kxky_y.set_ydata([self.data_array.coords[self.axes[1]][self.slider1[3].value() + self.slider2[3].value()].item(),self.data_array.coords[self.axes[1]][self.slider1[3].value() + self.slider2[3].value()].item()]) + self.energy_delta_kxky_x.set_xdata([self.data_array.coords[self.axes[0]][self.slider3[3].value() + self.slider4[3].value()].item(),self.data_array.coords[self.axes[0]][self.slider3[3].value() + self.slider4[3].value()].item()]) + self.graphs[0].canvas.draw_idle() + self.update_dt(self.slider1[3].value(), self.slider2[3].value(), self.slider3[3].value(), self.slider4[3].value()) if __name__ == "__main__": app = QApplication(sys.argv) - window = MainWindow() + window = show_4d_window() window.show() sys.exit(app.exec_()) diff --git a/tests/Arpes_gui.py b/tests/Arpes_gui.py deleted file mode 100644 index 0270868..0000000 --- a/tests/Arpes_gui.py +++ /dev/null @@ -1,392 +0,0 @@ -import sys -from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QAction, QFileDialog, QSlider, QGridLayout,QHBoxLayout, QSizePolicy,QLabel -from PyQt5.QtCore import Qt -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -import matplotlib.pyplot as plt -import numpy as np -import h5py -from matplotlib.widgets import CheckButtons, Button -from matplotlib.patches import Circle -from matplotlib.lines import Line2D -from additional_window import GraphWindow -from color_scale import update_color -import xarray as xr -from Drawwindow import DrawWindow -from h5toxarray import h5toxarray_loader -# from k_path_4d_4 import drawKpath - -class MainWindow(QMainWindow): - def __init__(self): - super().__init__() - - self.setWindowTitle("Main Window") - self.setGeometry(100, 100, 800, 600) - - # Create a central widget for the graph and slider - central_widget = QWidget() - self.setCentralWidget(central_widget) - - # Create a layout for the central widget - layout = QGridLayout() - central_widget.setLayout(layout) - - # Create four graphs and sliders - self.graphs = [] - self.slider1 = [] - self.slider2 = [] - self.slider3 = [] - self.slider4 = [] - self.sliders = [] - self.slider_labels = [] - - for i in range(2): - for j in range(2): - graph_window = QWidget() - graph_layout = QVBoxLayout() - graph_window.setLayout(graph_layout) - - # Create a figure and canvas for the graph - figure, axis = plt.subplots(figsize=(20, 20)) - canvas = FigureCanvas(figure) - graph_layout.addWidget(canvas) - - slider_layout= QHBoxLayout() - slider_layout_2= QHBoxLayout() - # Create a slider widget - slider1 = QSlider(Qt.Horizontal) - slider1.setRange(0, 100) - slider1.setValue(0) - slider1_label = QLabel("0") - # slider.valueChanged.connect(self.slider_changed) - # Set the size of the slider - - # default_size = slider1.sizeHint() - # print(f"Default size of the slider: {default_size.width()}x{default_size.height()}") - - slider2 = QSlider(Qt.Horizontal) - slider2.setRange(0, 10) - slider2.setValue(0) - slider2_label = QLabel("0") - - - - slider3 = QSlider(Qt.Horizontal) - slider3.setRange(0, 100) - slider3.setValue(0) - slider3_label = QLabel("0") - - slider4 = QSlider(Qt.Horizontal) - slider4.setRange(0, 10) - slider4.setValue(0) - slider4_label = QLabel("0") - - slider1.setFixedSize(200, 12) # Change the width and height as needed - slider2.setFixedSize(200, 12) # Change the width and height as needed - slider3.setFixedSize(200, 12) # Change the width and height as needed - slider4.setFixedSize(155, 10) # Change the width and height as needed - - slider_layout.addWidget(slider1) - slider_layout.addWidget(slider1_label) - slider_layout.addWidget(slider2) - slider_layout.addWidget(slider2_label) - - slider_layout_2.addWidget(slider3) - slider_layout_2.addWidget(slider3_label) - slider_layout_2.addWidget(slider4) - slider_layout_2.addWidget(slider4_label) - # slider2.valueChanged.connect(self.slider_changed) - - # Add the slider to the layout - graph_layout.addLayout(slider_layout) - graph_layout.addLayout(slider_layout_2) - # graph_layout.addWidget(slider3) - # graph_layout.addWidget(slider2) - - layout.addWidget(graph_window, i, j) - self.graphs.append(figure) - self.slider1.append(slider1) - self.slider2.append(slider2) - self.slider3.append(slider3) - self.slider4.append(slider4) - self.sliders.extend([slider1, slider2,slider3, slider4]) - self.slider_labels.extend([slider1_label, slider2_label,slider3_label, slider4_label]) - for slider in self.slider1: - slider.valueChanged.connect(self.slider_changed) - for slider in self.slider2: - slider.valueChanged.connect(self.slider_changed) - for slider in self.slider3: - slider.valueChanged.connect(self.slider_changed) - for slider in self.slider4: - slider.valueChanged.connect(self.slider_changed) - - self.xv = None - self.yv = None - self.ev = None - self.eh = None - self.ph= None - self.pxv=None - self.pyh=None - self.axis=[] - # print(self.sliders) - # Create a menu bar - menu_bar = self.menuBar() - - # Create a 'File' menu - file_menu = menu_bar.addMenu("File") - - # Create actions for opening a file and exiting - open_file_action = QAction("Open File", self) - open_file_action.triggered.connect(self.open_file) - file_menu.addAction(open_file_action) - - open_graphe_action = QAction("Energy", self) - open_graphe_action.triggered.connect(self.open_graph_energy) - open_graphy_action = QAction("kx_cut", self) - open_graphy_action.triggered.connect(self.open_graph_y_cut) - open_graphx_action = QAction("ky_cut", self) - open_graphx_action.triggered.connect(self.open_graph_x_cut) - - menu_bar = self.menuBar() - - # Create a 'Graph' menu - graph_menu = menu_bar.addMenu("Graph") - - # Add the actions to the menu - graph_menu.addAction(open_graphe_action) - graph_menu.addAction(open_graphx_action) - graph_menu.addAction(open_graphy_action) - - open_draw_action = QAction("k-path", self) - open_draw_action.triggered.connect(self.open_draw_k_path) - - draw_menu= menu_bar.addMenu("Draw path") - draw_menu.addAction(open_draw_action) - # file_menu.addAction(open_graph_action) - self.graph_windows = [] - self.ce=None - - def open_draw_k_path(self): - D=DrawWindow(self.data_array,self.slider1[0].value(),self.slider2[0].value() , self.slider1[1].value(), self.slider2[1].value()) - D.show() - self.graph_windows.append(D) - - def open_graph_energy(self): - print('energy') - self.dataet=np.zeros((len(self.axis[0]),len(self.axis[1]),len(self.axis[3]))) - self.axet=[self.axis[0],self.axis[1],self.axis[3]] - - for i in range(self.slider1[0].value(),self.slider1[0].value()+self.slider2[0].value()+1): - self.dataet += self.data_updated[:, :, i,:] - graph_window= GraphWindow(self.dataet,self.axet,self.slider3[0].value(),self.slider4[0].value()) - - graph_window.show() - self.graph_windows.append(graph_window) - def open_graph_x_cut(self): - self.dataxt=np.zeros((len(self.axis[0]),len(self.axis[2]),len(self.axis[3]))) - self.axxt=[self.axis[0],self.axis[2],self.axis[3]] - for i in range(self.slider1[1].value(),self.slider1[1].value()+self.slider2[1].value()+1): - self.dataxt += self.data_updated[:, i, :,:] - graph_window = GraphWindow(self.dataxt,self.axxt,self.slider3[1].value(),self.slider4[1].value()) - # Show the graph window - graph_window.show() - self.graph_windows.append(graph_window) - def open_graph_y_cut(self): - self.datayt=np.zeros((len(self.axis[1]),len(self.axis[2]),len(self.axis[3]))) - self.axyt=[self.axis[1],self.axis[2],self.axis[3]] - - for i in range(self.slider1[2].value(),self.slider1[2].value()+self.slider2[2].value()+1): - self.datayt += self.data_updated[i, :, :,:] - graph_window = GraphWindow(self.datayt,self.axyt,self.slider3[2].value(),self.slider4[2].value()) - # Show the graph window - graph_window.show() - self.graph_windows.append(graph_window) - def open_graph_xy_cut(self): - self.datapt=np.zeros((len(self.axis[0]),len(self.axis[1]),len(self.axis[3]))) - self.axpt=[self.axis[0],self.axis[1],self.axis[3]] - - for i in range(self.slider1[2].value(),self.slider1[2].value()+self.slider2[2].value()+1): - self.datayt += self.data_updated[i, :, :,:] - graph_window = GraphWindow(self.datayt,self.axyt,self.slider3[2].value(),self.slider4[2].value()) - # Show the graph window - graph_window.show() - self.graph_windows.append(graph_window) - def open_file(self): - # Open file dialog to select a .txt file - # file_path, _ = QFileDialog.getOpenFileName(self, "Open Text File", "", "Text Files (*.txt)") - file_path, _ = QFileDialog.getOpenFileName(self, "Open Text File", "", "Text Files (*.h5)") - print(file_path) - if file_path: - # Load data from the file - # x, y = self.load_data(file_path) - # self.axis,self.data_updated = self.load_data2(file_path) - # Convert to an xarray.DataArray with named dimensions - df = h5py.File(file_path, 'r') - loader= h5toxarray_loader(df) - self.data_array= loader.get_data_array() - self.data_updated= loader.get_original_array() - self.axis=[self.data_array['kx'].data,self.data_array['ky'].data,self.data_array['E'].data,self.data_array['dt'].data] - - # print(self.axis[2]) - self.slider1[0].setRange(0,len(self.data_array['E'].data)-1) - self.slider1[1].setRange(0,len(self.data_array['kx'].data)-1) - self.slider1[2].setRange(0,len(self.data_array['ky'].data)-1) - self.slider1[3].setRange(0,len(self.data_array['kx'].data)-1) - self.slider3[3].setRange(0,len(self.data_array['ky'].data)-1) - self.slider3[0].setRange(0,len(self.data_array['dt'].data)-1) - self.slider3[1].setRange(0,len(self.data_array['dt'].data)-1) - self.slider3[2].setRange(0,len(self.data_array['dt'].data)-1) - - - - # self.update_plot(self.slider1[0].value(),self.slider2[0].value() , self.slider1[1].value(), self.slider2[0].value(), self.slider1[2].value(), self.slider2[2].value(), self.slider3[0].value(), self.slider4[0].value(),self.slider3[1].value(), self.slider4[1].value(), self.slider3[2].value(), self.slider4[2].value(), self.slider1[3].value(), self.slider3[3].value(), self.slider2[3].value(), self.slider4[3].value()) - self.update_energy(self.slider1[0].value(),self.slider2[0].value() , self.slider1[1].value(), self.slider2[1].value()) - - # self.ce= update_color(self.im,self.graphs[0],self.graphs[0].gca()) - # self.ce.slider_plot.on_changed(self.ce.update) - - self.update_y(self.slider1[2].value(), self.slider2[2].value(), self.slider3[0].value(), self.slider4[0].value()) - - self.update_x(self.slider3[1].value(), self.slider4[1].value(), self.slider3[2].value(), self.slider4[2].value()) - - self.update_point(self.slider1[3].value(), self.slider3[3].value(), self.slider2[3].value(), self.slider4[3].value()) - - def update_energy(self,Energy,dE,te,dte): - E1=self.data_array['E'][Energy].item() - # print(Energy,E1) - E2=self.data_array['E'][Energy+dE].item() - te1=self.data_array['dt'][te].item() - te2=self.data_array['dt'][te+dte].item() - - self.graphs[0].clear() - ax=self.graphs[0].gca() - self.im=self.data_array.sel(E=slice(E1,E2), dt=slice(te1,te2)).mean(dim=("E", "dt")).T.plot(ax=ax) - # ax.set_title('Loaded Data') - ax.set_xlabel('kx') - ax.set_ylabel('ky') - ax.set_title(f'Energy: {E1:.2f}, dE: {E2-E1}') - # self.graphs[0].tight_layout() - self.graphs[0].canvas.draw() - self.ev = self.graphs[0].gca().axvline(x=self.axis[1][self.slider1[2].value()], color='r', linestyle='--') - self.eh = self.graphs[0].gca().axhline(y=self.axis[0][self.slider1[1].value()], color='r', linestyle='--') - self.pxv = self.graphs[0].gca().axvline(x=self.axis[1][self.slider1[3].value()], color='b', linestyle='--') - self.pyh = self.graphs[0].gca().axhline(y=self.axis[0][self.slider3[3].value()], color='b', linestyle='--') - # if self.ce is not None: - # self.ce.slider_plot.on_changed(self.ce.update) - - def update_y(self,ypos,dy,ty,dty): - - y1=self.data_array['ky'][ypos].item() - y2=self.data_array['ky'][ypos+dy].item() - ty1=self.data_array['dt'][ty].item() - ty2=self.data_array['dt'][ty+dty].item() - - self.graphs[1].clear() - ax=self.graphs[1].gca() - self.data_array.sel(ky=slice(y1,y2), dt=slice(ty1,ty2)).mean(dim=("ky", "dt")).plot(ax=ax) - # ax.set_title('Loaded Data') - - ax.set_xlabel('Energy (eV)') - ax.set_ylabel('kx (1/A)') - ax.set_title(f'ky_pos: {y1:.2f}, dky: {y2-y1}') - self.graphs[1].tight_layout() - self.graphs[1].canvas.draw() - self.yv = ax.axvline(x=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - - def update_x(self,xpos,dx,tx,dtx): - x1=self.data_array['kx'][xpos].item() - x2=self.data_array['kx'][xpos+dx].item() - tx1=self.data_array['dt'][tx].item() - tx2=self.data_array['dt'][tx+dtx].item() - - self.graphs[2].clear() - ax=self.graphs[2].gca() - self.data_array.sel(kx=slice(x1,x2), dt=slice(tx1,tx2)).mean(dim=("kx", "dt")).plot(ax=ax) - # ax.set_title('Loaded Data') - ax.set_xlabel('Energy (eV)') - ax.set_ylabel('ky (1/A)') - ax.set_title(f'kx_pos: {x1:.2f}, dkx: {x2-x1}') - self.graphs[2].tight_layout() - self.graphs[2].canvas.draw() - self.xv = ax.axvline(x=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - def update_point(self,xt,yt,dxt,dyt): - yt1=self.data_array['ky'][yt].item() - yt2=self.data_array['ky'][yt+dyt].item() - xt1=self.data_array['kx'][xt].item() - xt2=self.data_array['kx'][xt+dxt].item() - - self.graphs[3].clear() - ax=self.graphs[3].gca() - self.data_array.sel(kx=slice(xt1,xt2), ky=slice(yt1,yt2)).mean(dim=("kx", "ky")).plot(ax=ax) - # ax.set_title('Loaded Data') - ax.set_xlabel('time (fs)') - ax.set_ylabel('Energy (eV)') - ax.set_title(f'kx_pos: {xt1:.2f}, dkx: {xt2-xt1},ky_pos: {yt1:.2f}, dky: {yt2-yt1}') - self.graphs[3].tight_layout() - self.graphs[3].canvas.draw() - self.ph = ax.axhline(y=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - - - def load_data2(self, file_path): - # Load data from the text file - # r'C:\Users\admin-nisel131\Documents\\' - # 'Scan130_scan130_Amine_100x100x300x50_spacecharge4_gamma850_amp_3p3.h5', 'r') - df = h5py.File(file_path, 'r') - # print(df.keys()) - print(df['axes'].keys()) - - axis=[df['axes/ax0'][: ],df['axes/ax1'][: ],df['axes/ax2'][: ],df['axes/ax3'][: ]] - # print(df['binned/BinnedData'].keys()) - data=df['binned/BinnedData'] - - - return axis,data - - def slider_changed(self, value): - sender = self.sender() # Get the slider that emitted the signal - index = self.sliders.index(sender) # Find the index of the slider - # print(index) - - self.slider_labels[index].setText(str(value)) # Update the corresponding label text - - if index in range(0,4): - - self.update_energy(self.slider1[0].value(),self.slider2[0].value(),self.slider3[0].value(), self.slider4[0].value()) - # self.update_line() - if self.xv is not None: - self.xv.remove() - if self.yv is not None: - self.yv.remove() - if self.ph is not None: - self.ph.remove() - - self.xv = self.graphs[1].gca().axvline(x=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - self.yv = self.graphs[2].gca().axvline(x=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - self.ph = self.graphs[3].gca().axhline(y=self.axis[2][self.slider1[0].value()], color='r', linestyle='--') - elif index in range(4,8): - - if self.eh is not None: - self.eh.remove() - - self.eh = self.graphs[0].gca().axhline(y=self.axis[0][self.slider1[1].value()], color='r', linestyle='--') - - self.update_y(self.slider1[1].value(), self.slider2[1].value(),self.slider3[1].value(), self.slider4[1].value()) - print('here') - elif index in range (8,12): - if self.ev is not None: - self.ev.remove() - self.ev = self.graphs[0].gca().axvline(x=self.axis[1][self.slider1[2].value()], color='r', linestyle='--') - self.update_x(self.slider1[2].value(), self.slider2[2].value(),self.slider3[2].value(), self.slider4[2].value()) - else: - if self.pxv is not None: - self.pxv.remove() - if self.pyh is not None: - self.pyh.remove() - self.update_point(self.slider1[3].value(), self.slider3[3].value(), self.slider2[3].value(), self.slider4[3].value()) - self.pxv = self.graphs[0].gca().axvline(x=self.axis[1][self.slider1[3].value()], color='b', linestyle='--') - self.pyh = self.graphs[0].gca().axhline(y=self.axis[0][self.slider3[3].value()], color='b', linestyle='--') - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = MainWindow() - window.show() - sys.exit(app.exec_()) diff --git a/tests/Drawwindow.py b/tests/Drawwindow.py deleted file mode 100644 index 77f70fe..0000000 --- a/tests/Drawwindow.py +++ /dev/null @@ -1,173 +0,0 @@ -import sys -import numpy as np -import matplotlib.pyplot as plt -from PyQt5.QtCore import Qt -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QTextEdit, \ - QHBoxLayout, QSizePolicy,QSlider,QLabel -# from k_path_4d_4 import drawKpath - -class DrawWindow(QMainWindow): - def __init__(self,data,s1,s2,s3,s4): - super().__init__() - - # Set the title and size of the main window - self.setWindowTitle("PyQt5 Matplotlib Example") - self.setGeometry(100, 100, 800, 600) - self.data_array=data - print(data['E'][0]) - # Create the main layout - main_layout = QVBoxLayout() - - # Create a widget to hold the layout - widget = QWidget() - widget.setLayout(main_layout) - self.setCentralWidget(widget) - - # Create a horizontal layout for the top row - top_row_layout = QHBoxLayout() - - - # Create top left graph - self.figure1, self.axis1 = plt.subplots() - self.canvas1 = FigureCanvas(self.figure1) - top_row_layout.addWidget(self.canvas1) - - # Create bottom right graph - self.figure2, self.axis2 = plt.subplots() - self.canvas2 = FigureCanvas(self.figure2) - top_row_layout.addWidget(self.canvas2) - - layout = QVBoxLayout() - - slider_layout= QHBoxLayout() - self.slider1 = QSlider(Qt.Horizontal) - self.slider1.setRange(0, len(data['E'].data)) - self.slider1.setValue(s1) - self.slider1_label = QLabel("0") - - self.slider2 = QSlider(Qt.Horizontal) - self.slider2.setRange(0, 10) - self.slider2.setValue(s2) - self.slider2_label = QLabel("0") - - self.slider1.setFixedSize(200, 12) # Change the width and height as needed - self.slider2.setFixedSize(200, 12) # Change the width and height as needed - - slider_layout.addWidget(self.slider1) - slider_layout.addWidget(self.slider1_label) - slider_layout.addWidget(self.slider2) - slider_layout.addWidget(self.slider2_label) - # layout.addLayout(slider_layout) - slider_layout2= QHBoxLayout() - self.slider3 = QSlider(Qt.Horizontal) - self.slider3.setRange(0, 100) - self.slider3.setValue(s3) - self.slider3_label = QLabel("0") - - self.slider4 = QSlider(Qt.Horizontal) - self.slider4.setRange(0, 10) - self.slider4.setValue(s4) - self.slider4_label = QLabel("0") - - self.slider3.setFixedSize(200, 12) # Change the width and height as needed - self.slider4.setFixedSize(200, 12) # Change the width and height as needed - - slider_layout2.addWidget(self.slider3) - slider_layout2.addWidget(self.slider3_label) - slider_layout2.addWidget(self.slider4) - slider_layout2.addWidget(self.slider4_label) - - # layout.addLayout(slider_layout2) - - self.slider1.valueChanged.connect(self.slider1_changed) - self.slider2.valueChanged.connect(self.slider2_changed) - self.slider3.valueChanged.connect(self.slider3_changed) - self.slider4.valueChanged.connect(self.slider4_changed) - - - main_layout.addLayout(top_row_layout) - main_layout.addLayout(slider_layout) - main_layout.addLayout(slider_layout2) - - - # Set size policy for the graph widgets - self.canvas1.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - self.canvas2.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - - self.update_energy(s1, s2, s3, s4) - # self.d=drawKpath(data, axis, fig, ax, ax2, linewidth, slider, N) - - # Plot data - # self.plot_graphs() - # self.update_text_edit_boxes() - - def slider1_changed(self,value): - self.slider1_label.setText(str(value)) - print(value) - self.update_energy(self.slider1.value(),self.slider2.value() , self.slider3.value(), self.slider4.value()) - def slider2_changed(self,value): - self.slider2_label.setText(str(value)) - self.update_energy(self.slider1.value(),self.slider2.value() , self.slider3.value(), self.slider4.value()) - def slider3_changed(self,value): - self.slider3_label.setText(str(value)) - self.update_energy(self.slider1.value(),self.slider2.value() , self.slider3.value(), self.slider4.value()) - def slider4_changed(self,value): - self.slider4_label.setText(str(value)) - # self.plot_graph(self.slider1.value(),self.slider2.value()) - # print(self.slider1.value(),self.slider2.value()) - # self.update_show(self.slider1.value(),self.slider2.value()) - self.update_energy(self.slider1.value(),self.slider2.value() , self.slider3.value(), self.slider4.value()) - - def update_energy(self,Energy,dE,te,dte): - - # self.ce_state=True - E1=self.data_array['E'][Energy].item() - # print(Energy,E1) - E2=self.data_array['E'][Energy+dE].item() - te1=self.data_array['dt'][te].item() - te2=self.data_array['dt'][te+dte].item() - # print(E1,E2,te1) - self.figure1.clear() - ax = self.figure1.add_subplot(111) # Recreate the axis on the figure - self.im=self.data_array.sel(E=slice(E1,E2), dt=slice(te1,te2)).mean(dim=("E", "dt")).plot(ax=ax) - # ax.set_title('Loaded Data') - ax.set_xlabel('X') - ax.set_ylabel('Y') - # self.graphs[0].tight_layout() - self.figure1.canvas.draw() - # self.ev = self.graphs[0].gca().axvline(x=self.axis[0][self.slider1[1].value()], color='r', linestyle='--') - # self.eh = self.graphs[0].gca().axhline(y=self.axis[1][self.slider1[2].value()], color='r', linestyle='--') - - - def plot_graphs(self): - # Plot on the top left graph - x1 = np.linspace(0, 10, 100) - y1 = np.sin(x1) - self.axis1.plot(x1, y1) - self.axis1.set_title('Top Left Graph') - self.axis1.set_xlabel('X') - self.axis1.set_ylabel('Y') - - # Plot on the bottom right graph - x2 = np.linspace(0, 10, 100) - y2 = np.cos(x2) - self.axis2.plot(x2, y2) - self.axis2.set_title('Bottom Right Graph') - self.axis2.set_xlabel('X') - self.axis2.set_ylabel('Y') - - # Update the canvas - self.canvas1.draw() - self.canvas2.draw() - - # def update_text_edit_boxes(self): - # # self.text_edit_top_right.setPlaceholderText("Top Right Text Edit Box") - # self.text_edit_bottom_left.setPlaceholderText("Bottom Left Text Edit Box") - - -if __name__ == "__main__": - app = QApplication(sys.argv) - window = DrawWindow() - window.show() - sys.exit(app.exec_()) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/additional_window.py b/tests/additional_window.py deleted file mode 100644 index c63c2a1..0000000 --- a/tests/additional_window.py +++ /dev/null @@ -1,473 +0,0 @@ -import sys -from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QCheckBox, QAction, QFileDialog, QSlider, QGridLayout,QHBoxLayout, QSizePolicy,QLabel -from PyQt5.QtCore import Qt -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -import matplotlib.pyplot as plt -import numpy as np -import h5py -from matplotlib.widgets import CheckButtons, Button -from matplotlib.patches import Circle -from matplotlib.lines import Line2D - - -from fit_panel6 import MainWindow - -# %matplotlib qt - -class GraphWindow(QMainWindow): - def __init__(self,data,axis,t,dt): - global t_final - super().__init__() - - self.setWindowTitle("Graph Window") - self.setGeometry(100, 100, 800, 600) - - # Create a central widget for the graph - central_widget = QWidget() - self.setCentralWidget(central_widget) - - layout = QVBoxLayout() - central_widget.setLayout(layout) - - self.fig, self.axs = plt.subplots(2,2,figsize=(20,16)) - self.canvas = FigureCanvas(self.fig) - - self.checkbox_e = QCheckBox("Integrate_energy") - self.checkbox_e.stateChanged.connect(self.checkbox_e_changed) - - self.checkbox_k = QCheckBox("Integrate_k") - self.checkbox_k.stateChanged.connect(self.checkbox_k_changed) - - self.checkbox_cursors = QCheckBox("energy_cursors") - self.checkbox_cursors.stateChanged.connect(self.checkbox_cursors_changed) - checkbox_layout= QHBoxLayout() - # Add the canvas to the layout - checkbox_layout.addWidget(self.checkbox_e) - checkbox_layout.addWidget(self.checkbox_k) - layout.addLayout(checkbox_layout) - layout.addWidget(self.canvas) - layout.addWidget(self.checkbox_cursors) - - slider_layout= QHBoxLayout() - self.slider1 = QSlider(Qt.Horizontal) - self.slider1.setRange(0, 100) - self.slider1.setValue(0) - self.slider1_label = QLabel("0") - - self.slider2 = QSlider(Qt.Horizontal) - self.slider2.setRange(0, 10) - self.slider2.setValue(0) - self.slider2_label = QLabel("0") - - self.slider1.setFixedSize(200, 12) # Change the width and height as needed - self.slider2.setFixedSize(200, 12) # Change the width and height as needed - - slider_layout.addWidget(self.slider1) - slider_layout.addWidget(self.slider1_label) - slider_layout.addWidget(self.slider2) - slider_layout.addWidget(self.slider2_label) - layout.addLayout(slider_layout) - # Create a layout for the central widget - self.active_cursor = None - self.cursorlinev1=1 - self.cursorlinev2=0 - # self.v1_pixel=None - # self.v2_pixel=None - self.Line1=None - self.Line2=None - self.square_artists = [] # To store the artists representing the dots - self.square_coords = [(0, 0), (0, 0)] # To store the coordinates of the dots - self.square_count = 0 # To keep track of the number of dots drawn - - - self.cid_press2= None - self.line_artists=[] - self.cid_press3 = None - self.cid_press4 = None - self.cid_press = None - - # Create a figure and canvas for the graph - - self.data_o=data - self.axis=axis - self.dt=dt - self.datae=np.zeros((len(self.axis[0]),len(self.axis[1]))) - # Plot data - self.plot_graph(t,dt) - self.ssshow(t,dt) - self.slider1.setRange(0,len(self.axis[2])-1) - self.plot=np.zeros_like(self.data[1,:]) - - self.slider1.valueChanged.connect(self.slider1_changed) - self.slider2.valueChanged.connect(self.slider2_changed) - t_final=self.axis[2].shape[0] - - - fit_panel_action = QAction('Fit_Panel',self) - fit_panel_action.triggered.connect(self.fit_panel) - - menu_bar = self.menuBar() - - # Create a 'Graph' menu - - graph_menu1 = menu_bar.addMenu("Fit Panel") - - graph_menu1.addAction(fit_panel_action) - - # Add the actions to the menu - - self.graph_windows=[] - self.t=t - - def slider1_changed(self,value): - self.slider1_label.setText(str(value)) - self.plot_graph(self.slider1.value(),self.slider2.value()) - # print(self.slider1.value(),self.slider2.value()) - self.update_show(self.slider1.value(),self.slider2.value()) - self.t=self.slider1.value() - # self.us() - # update_show(self.slider1.value(),self.slider2.value()) - def slider2_changed(self,value): - self.slider2_label.setText(str(value)) - self.plot_graph(self.slider1.value(),self.slider2.value()) - self.update_show(self.slider1.value(),self.slider2.value()) - self.dt=self.slider2.value() - # self.ssshow(self.slider1.value(),self.slider2.value()).update_show() - # self.us() - # update_show(self.slider1.value(),self.slider2.value()) - def checkbox_e_changed(self, state): - if state == Qt.Checked: - # print("Checkbox is checked") - self.integrate_E() - else: - # print("Checkbox is unchecked") - self.update_show(self.slider1.value(),self.slider2.value()) - def checkbox_k_changed(self, state): - if state == Qt.Checked: - # print("Checkbox is checked") - self.integrate_k() - else: - # print("Checkbox is unchecked") - self.update_show(self.slider1.value(),self.slider2.value()) - def checkbox_cursors_changed(self, state): - if state == Qt.Checked: - self.put_cursors() - # self.integrate_k() - else: - # print("Checkbox is unchecked") - self.remove_cursors() - def plot_graph(self,t,dt): - # Plot on the graph - x = [1, 2, 3, 4, 5] - y = [2, 3, 5, 7, 11] - self.data=np.zeros((len(self.axis[0]),len(self.axis[1]))) - # self.ax.plot(x, y) - for i in range (t,t+dt+1): - self.data+= self.data_o[:,:,i] - - self.axs[0,0].imshow(self.data, extent=[self.axis[1][0], self.axis[1][-1], self.axis[0][0], self.axis[0][-1]], origin='lower', cmap='viridis',aspect='auto') - self.axs[0,0].set_title('Sample Graph') - self.axs[0,0].set_xlabel('X') - self.axs[0,0].set_ylabel('Y') - self.fig.tight_layout() - self.canvas.draw() - - def fit_panel(self,event): - print('forfit',len(self.plot),'axis',len(self.axis)) - graph_window= MainWindow( self.data_o, self.axis,self.square_coords[0][1], self.square_coords[1][1],self.t,self.dt) - graph_window.show() - self.graph_windows.append(graph_window) - - def lz_fit(self, event): - two_lz_fit(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt,self.a).fit() - def fit(self, event): - fit_4d(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD(self, event): - fit_FD(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD_conv(self, event): - # print('ax0test=',self.ax[0]) - # print('ax1test=',self.ax[1]) - - fit_FD_lor_conv(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt).fit() - def fit_FD_conv_2(self, event): - - f=fit_FD_conv(self.data_o, self.axis, self.square_coords[0][1], self.square_coords[1][1], 0, t_final, self.v1_pixel, self.v2_pixel,self.dt) - f.show() - def ssshow(self,t,dt): - def test(self): - print('whatever test') - print('show is running') - c= self.data.shape[1]// 10 ** (len(str(self.data.shape[1])) - 1) - - def put_cursors(): - self.Line1=axe.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - self.Line2=axe.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - plt.draw() - self.fig.canvas.draw() - def remove_cursors(): - self.Line1.remove() - self.Line2.remove() - plt.draw() - self.fig.canvas.draw() - - - def integrate_E(): - self.plote=np.zeros_like(self.data[1,:]) - self.axs[1,0].clear() - plt.draw() - x_min = int(min(self.square_coords[1][1], self.square_coords[0][1])) - x_max = int(max(self.square_coords[1][1], self.square_coords[0][1])) + 1 - for i in range(x_min, x_max): - self.plote += self.data[i, :] - # if self.square_coords[1][1]self.square_coords[0][1]: - # for i in range(self.square_coords[0][1],self.square_coords[1][1]+1): - # self.plot+=self.data[i,:] - # else: - # self.plot+=self.data[self.square_coords[0][1],:] - - self.axs[1, 0].plot(self.axis[1][:],self.plote/abs(self.square_coords[0][1]-self.square_coords[1][1]),color='red') - - # save_data(self.axis[1], plot/abs(self.square_coords[0][1]-self.square_coords[1][1]),"EDC_time="+str(slider_t.val)+"_", [0.42,0.46],self.fig) - def integrate_k(): - self.plotk=np.zeros_like(self.data[:,1]) - self.axs[0,1].clear() - plt.draw() - x_min = int(min(self.square_coords[0][0], self.square_coords[1][0])) - x_max = int(max(self.square_coords[0][0], self.square_coords[1][0])) + 1 - for i in range(x_min, x_max): - self.plotk += self.data[:, i] - # if self.square_coords[0][0]1: - first_key = sorted(df['binned'].keys(), key=lambda x: int(x[1:]))[0] - data_shape = df['binned/' + first_key][:].shape - self.M = np.empty((data_shape[0], data_shape[1], data_shape[2], len(df['binned']))) - axis=[] - for idx, v in enumerate(sorted(df['binned'].keys(), key=lambda x: int(x[1:]))): - self.M[:, :, :, idx] = df['binned/' + v][:] - else: - self.M= df['binned/' + list(df['binned'].keys())[0]][:] - - - # Define the desired order lists - desired_orders = [ - ['ax0', 'ax1', 'ax2', 'ax3'], - ['kx', 'ky', 'E', 'delay'], - ['kx', 'ky', 'E', 'ADC'] - ] - - axes_list = [] - - matched_order = None - for i, order in enumerate(desired_orders): - # Check if all keys in the current order exist in df['axes'] - if all(f'axes/{axis}' in df for axis in order): - # If match is found, generate axes_list based on this order - axes_list = [df[f'axes/{axis}'] for axis in order] - matched_order = i + 1 # Store which list worked (1-based index) - break # Stop once the first matching list is found - - if matched_order: - print(f"Matched desired list {matched_order}: {desired_orders[matched_order - 1]}") - else: - print("No matching desired list found.") - - # print("Axes list:", axes_list) - # print(M[12,50,4,20]) - self.data_array = xr.DataArray( - self.M, - coords={"kx": axes_list[0], "ky": axes_list[1], "E": axes_list[2], "dt": axes_list[3]}, - dims=["kx", "ky", "E","dt"] - ) - def get_data_array(self): - return self.data_array - def get_original_array(self): - return self.M -# df =h5py.File(r'C:\Users\admin-nisel131\Documents\Scan130_scan130_Amine_100x100x300x50_spacecharge4_gamma850_amp_3p3.h5', 'r') -# test=h5toxarray_loader(df) diff --git a/tests/k_path_4d_4.py b/tests/k_path_4d_4.py deleted file mode 100644 index 13876c7..0000000 --- a/tests/k_path_4d_4.py +++ /dev/null @@ -1,422 +0,0 @@ -import numpy as np -import h5py -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -from matplotlib.widgets import CheckButtons, Button -from scipy.ndimage import rotate -import h5py -# import mplcursors -from matplotlib.widgets import Slider, Cursor, SpanSelector -from matplotlib.gridspec import GridSpec -from matplotlib.lines import Line2D -from matplotlib.patches import Circle -from AdditionalInterface import AdditionalInterface -from AxisInteractor import AxisInteractor -from LinePixelGetter import LinePixelGetter -from update_plot_cut_4d import update_plot_cut -import json -import csv -from datetime import datetime - -class drawKpath: - # print(True) - def __init__(self, data,axis,fig, ax,ax2,linewidth,slider,N): - self.active_cursor=None - self.dots_count=0 - self.ax=ax - self.fig=fig - self.dot1_x=0 - self.do1_y=0 - self.dot2_x=0 - self.do2_y=0 - self.cid_press=None - self.linewidth=1 - self.line_artist=None - self.cb_line=None - self.button_update=None - self.dot1=None - self.dot2=None - self.method_running = False - self.pixels_along_line=[] - self.number=N - self.og_number=N - self.dots_list=[] - self.line_artist_list=[None]*N - self.pixels_along_path=[None]*N - # self.number=N - self.is_drawn= False - self.is_loaded= False - self.new=False - self.add_pressed=False - self.lw=linewidth - self.ax2=ax2 - self.data=data[:,:,:,slider] - self.axis=axis - self.pixels=[] - self.slider=slider - self.data2=data - self.slider_ax7 = plt.axes([0.55, 0.14, 0.02, 0.3]) - self.slider_dcut= Slider(self.slider_ax7, 'dcut_kpath', 0, 15, valinit=1, valstep=1, orientation='vertical') - # def update_plot_cut(self): - # update_plot_cut.update_plot_cut(self.data2[:,:,:,self.slider],self.ax2,self.pixels,self.lw) - def isdrawn(self): - return self.is_drawn - - - def get_pixels(self): - if self.pixels is not None: - return self.pixels - def get_pixels_along_line(self): - if self.pixels_along_line is not None: - return self.pixels_along_line - - def get_status(self): - if self.cb_line is not None: - return self.cb_line.get_status()[0] - else: - return False - - def draw(self): - # print('beginning') - def add_path(event): - self.add_pressed= True - - for i in range (self.number): - self.line_artist_list.append(None) - self.pixels_along_path.append(None) - # self.dots_list - print('line list=', len(self.line_artist_list)) - self.number=self.number+self.og_number - self.is_drawn=False - self.dots_count=0 - self.cid_press = self.fig.canvas.mpl_connect('button_press_event', drawdots) - - def drawline(dot1,dot2,pos): - self.pixels=[] - if self.dots_count ==0 and self.line_artist_list[len(self.dots_list)-2] is not None : - if not self.loaded: - self.line_artist_list[len(self.dots_list)-2].remove() # Remove the previous line if exists - print('test,code') - # if self.dots_count==2: - # line = Line2D([self.dots_list[len(self.dots_list)].center[0], self.dots_list[len(self.dots_list)-1].center[0]], [self.dots_list[len(self.dots_list)].center[1], self.dots_list[len(self.dots_list)-1].center[1]], linewidth=self.linewidth, color='red') - if self.dots_count==2 : - line = Line2D([dot1.center[0], dot2.center[0]], [dot1.center[1], dot2.center[1]], linewidth=self.linewidth, color='red') - - self.ax.add_line(line) - # print('movement',len(self.line_artist_list)) - print('currentline=',self.line_artist_list[pos]) - if self.line_artist_list[pos] is not None: - # print(pos,self.line_artist_list[pos].get_data()) - self.line_artist_list[pos].remove() - # if self.line_artist is not None: - # self.line_artist.remove() # Remove the previous line if exists - - self.line_artist = line - # self.line_artist_list.append(line) - self.line_artist_list[pos]=line - # print(pos,self.line_artist_list[pos].get_data()) - # x1=self.line_artist_list[pos].get_data()[0][1] - # y1=self.line_artist_list[pos].get_data()[1][1] - # x2= self.line_artist_list[pos].get_data()[0][0] - # y2=self.line_artist_list[pos].get_data()[1][0] - x1_pixel=int((self.line_artist_list[pos].get_data()[0][1] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - y1_pixel=int((self.line_artist_list[pos].get_data()[1][1] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - x2_pixel=int((self.line_artist_list[pos].get_data()[0][0] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - y2_pixel=int((self.line_artist_list[pos].get_data()[1][0] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - - self.pixels_along_path[pos] = LinePixelGetter.get_pixels_along_line(x1_pixel, y1_pixel, x2_pixel, y2_pixel, self.linewidth) - # print(x1_pixel,y1_pixel) - # self.pixels_along_path[pos]=LinePixelGetter.get_pixels_along_line(self.line_artist_list[pos].get_data()[0][1], self.line_artist_list[pos].get_data()[1][1], self.line_artist_list[pos].get_data()[0][0], self.line_artist_list[pos].get_data()[1][0], self.linewidth) - # for i in self.pixels_along_path: - for i in range (0,self.number): - if self.pixels_along_path[i] is not None: - self.pixels+=self.pixels_along_path[i] - - def drawdots(event): - # if self.add_pressed: - - - if self.cb_line.get_status()[0] and len(self.dots_list) < self.number and (self.new or not self.is_drawn): - x = event.xdata # Round the x-coordinate to the nearest integer - y = event.ydata # Round the y-coordinate to the nearest integer - print('you hereeee') - print(self.number) - # print('line list=', len(self.line_artist_list)) - if self.dots_count==0: - self.dot= Circle((x, y), radius=0.1, color='yellow', picker=20) - self.ax.add_patch(self.dot) - # self.dot_coords[self.dots_count] = (x, y) - self.dots_list.append(self.dot) - self.dots_count += 1 - self.fig.canvas.draw() - else: - self.dot= Circle((x, y), radius=0.1, color='yellow', picker=20) - self.ax.add_patch(self.dot) - # self.dot_coords[self.dots_count] = (x, y) - self.dots_count += 1 - self.dots_list.append(self.dot) - print('dots list=',len(self.dots_list)) - drawline(self.dots_list[len(self.dots_list)-1],self.dots_list[len(self.dots_list)-2],len(self.dots_list)-2) - self.dots_count -= 1 - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - - self.fig.canvas.draw() - if len(self.dots_list)== self.number: - self.is_drawn=True - # print(self.is_drawn) - def on_checkbox_change(label): - if self.cb_line.get_status()[0]: - if self.is_loaded: - for i in range(len(self.dots_list)): - self.ax.add_patch(self.dots_list[i]) - plt.draw() - for i in range(len(self.line_artist_list)): - if self.line_artist_list[i] is not None: - self.ax.add_line(self.line_artist_list[i]) - plt.draw() - elif self.is_drawn: - for i in range(len(self.dots_list)): - self.ax.add_patch(self.dots_list[i]) - plt.draw() - for i in range(len(self.line_artist_list)): - if self.line_artist_list[i] is not None: - self.ax.add_line(self.line_artist_list[i]) - plt.draw() - - else: - self.cid_press = self.fig.canvas.mpl_connect('button_press_event', drawdots) - - else: - # Disconnect the click event - self.is_loaded= False - self.fig.canvas.mpl_disconnect(self.cid_press) - for i in range(len(self.dots_list)): - self.dots_list[i].remove() - plt.draw() - for i in range(len(self.line_artist_list)): - if self.line_artist_list[i] is not None: - self.line_artist_list[i].remove() - plt.draw() - def new(event): - - for i in range(len(self.dots_list)): - print(i) - self.dots_list[i].remove() - plt.draw() - for i in range(len(self.line_artist_list)): - print(i) - if self.line_artist_list[i] is not None: - self.line_artist_list[i].remove() - plt.draw() - self.new=True - self.is_drawn= False - self.is_loaded= False - self.dots_list=[] - self.line_artist_list=[None]*self.number - self.pixels_along_path=[None]*self.number - self.dots_count=0 - self.cid_press = self.fig.canvas.mpl_connect('button_press_event', drawdots) - def on_pick(event): - for i in range(len(self.dots_list)): - if event.artist == self.dots_list[i]: - self.active_cursor = self.dots_list[i] - def on_motion(event): - # global dot1,dot2 - if self.active_cursor is not None and event.inaxes == self.ax: - # Initialize a list of dictionaries to store dot information - dot_info_list = [{"dot": dot, "center": dot.center} for dot in self.dots_list] - self.dots_count=2 - - # Iterate through the list to find the selected dot - selected_dot_index = None - for i, dot_info in enumerate(dot_info_list): - dot = dot_info["dot"] - contains, _ = dot.contains(event) - if contains: - selected_dot_index = i - break # Exit the loop when a matching dot is found - - if selected_dot_index is not None: - selected_dot_info = dot_info_list[selected_dot_index] - selected_dot = selected_dot_info["dot"] - # print(f"Selected dot index: {selected_dot_index}") - # print(f"Selected dot center: {selected_dot_info['center']}") - selected_dot.center = (event.xdata, event.ydata) - plt.draw() - i=selected_dot_index - if i==len(self.dots_list)-1: - # self.line_artist_list[i-1].remove() - drawline(self.dots_list[i],self.dots_list[i-1],i-1) - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - elif i==0: - drawline(self.dots_list[i+1],self.dots_list[i],i) - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - else: - # self.line_artist_list[i-1].remove() - # self.line_artist_list[i].remove() - drawline(self.dots_list[i+1],self.dots_list[i],i) - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - drawline(self.dots_list[i],self.dots_list[i-1],i-1) - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - plt.draw() - - - def on_release(event): - self.active_cursor = None - def get_status(): - return self.cb_line.get_status()[0] - def dots_coord(): - return [self.dot1.center, self.dot2.center] - - def update_dcut(val): - self.linewidth=self.slider_dcut.val - self.pixels=[] - for position, line in enumerate(self.line_artist_list): - if line is not None: - line.set_linewidth(self.linewidth+1) - x1_pixel=int((line.get_data()[0][1] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - y1_pixel=int((line.get_data()[1][1] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - x2_pixel=int((line.get_data()[0][0] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - y2_pixel=int((line.get_data()[1][0] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - # print(x1_pixel,y1_pixel,x2_pixel,y2_pixel) - self.pixels_along_path[position] = LinePixelGetter.get_pixels_along_line(x1_pixel, y1_pixel, x2_pixel, y2_pixel, self.linewidth) - self.pixels+=self.pixels_along_path[position] - - print('before before line') - # for pos in range(0,self.number): - # print('before line') - # if self.line_artist_list[pos] is not None: - # x1_pixel=int((self.line_artist_list[pos].get_data()[0][1] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - # y1_pixel=int((self.line_artist_list[pos].get_data()[1][1] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - # x2_pixel=int((self.line_artist_list[pos].get_data()[0][0] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - # y2_pixel=int((self.line_artist_list[pos].get_data()[1][0] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - # print(x1_pixel,y1_pixel,x2_pixel,y2_pixel) - # self.pixels_along_path[pos] = LinePixelGetter.get_pixels_along_line(x1_pixel, y1_pixel, x2_pixel, y2_pixel, self.linewidth) - # self.pixels+=self.pixels_along_path[pos] - - # self.pixels_along_line = LinePixelGetter.get_pixels_along_line(self.dot1_x_pixel, self.dot1_y_pixel, self.dot2_x_pixel, self.dot2_y_pixel, self.linewidth) - # update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels_along_line,self.slider_dcut.val) - update_plot_cut.update_plot_cut(self.data,self.ax2,self.pixels,self.slider_dcut.val) - def draw_load(): - if self.is_loaded: - for i in range(len(self.dots_list)): - self.ax.add_patch(self.dots_list[i]) - plt.draw() - for i in range(len(self.line_artist_list)): - if self.line_artist_list[i] is not None: - self.ax.add_line(self.line_artist_list[i]) - plt.draw() - def save_path(event): - def c_to_string(circle): - return f"{circle.center[0]}, {circle.center[1]}, {circle.radius}" - def l_to_string(line): - x_data, y_data = line.get_data() - linewidth = line.get_linewidth() - return f"{x_data[0]}, {y_data[0]}, {x_data[1]},{y_data[1]},{linewidth}" - # self.positions= np.array([[0]*4]*len(self.line_artist_list)) - # for position, line in enumerate(self.line_artist_list): - # if line is not None: - # line.set_linewidth(self.linewidth+1) - # x1_pixel=int((line.get_data()[0][1] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - # y1_pixel=int((line.get_data()[1][1] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - # x2_pixel=int((line.get_data()[0][0] - self.axis[0][0]) / (self.axis[0][-1] - self.axis[0][0]) * (self.axis[0].shape[0] - 1) + 0.5) - # y2_pixel=int((line.get_data()[1][0] - self.axis[1][0]) / (self.axis[1][-1] - self.axis[1][0]) * (self.axis[1].shape[0] - 1) + 0.5) - # self.positions[position][0] - output_directory = "C:/Users/admin-nisel131/Documents/CVS_TR_flatband_fig/" - current_time = datetime.now() - current_time_str = current_time.strftime("%Y-%m-%d_%H%M%S") - file_name = "k_path" - output_path = f"{output_directory}/{file_name}_{current_time_str}.txt" - with open(output_path, "w",newline="") as file: - file.write(f"{self.number}" + "\n") - for circle in self.dots_list: - file.write(c_to_string(circle) + "\n") - for line in self.line_artist_list: - if line is not None: - file.write(l_to_string(line) + "\n") - def load_path(event): - self.fig.canvas.mpl_disconnect(self.cid_press) - circle_list=[] - line_list=[] - file_path1="C:/Users/admin-nisel131/Documents/CVS_TR_flatband_fig/" - # file="k_path_2023-10-06_153243.txt" - # file= "k_path_2023-10-10_221437.txt" - # file= "k_path_2024-04-03_125248.txt" - file= "k_path_2024-04-03_140548.txt" - - - file_path=file_path1+file - with open(file_path, "r") as file: - lines=file.readlines() - # print(lines) - # for line_number, line in enumerate(file, start=1): - for line_number, line in enumerate(lines, start =1): - # if line_number==2: - # a,b,c= map(float, line.strip().split(', ')) - # print(a,b,c) - # print(map(float, line.strip().split(', '))) - # print('linenumber=',line_number) - # Split the line into individual values - # if line_number==1: - # print('firstline',line_number) - # number=7 - # first_line = file.readline().strip() # Read and strip whitespace - # print(line) - # first_line = file.readline() - - # number= int(first_line) - # print(number) - # print(first_line) - # print() - if line_number==1: - number= int(line) - # print(number) - elif line_number>=2 and line_number<=number+1: - x, y, radius = map(float, line.strip().split(', ')) - # print(x,y,radius) - circle = Circle((x, y), radius=radius, color='yellow', picker=20) - circle_list.append(circle) - self.dots_list=circle_list - else: - x0,y0,x1,y1,lw=map(float, line.strip().split(',')) - line1=Line2D([x0,x1], [y0, y1], linewidth=lw, color='red') - line_list.append(line1) - self.line_artist_list=line_list - self.is_loaded= True - self.dots_count=2 - # draw_load() - # print(len(self.line_artist_list),len(self.dots_list)) - - # print(x0,y0,x1,y1,lw) - # on_checkbox_change('K path') - - - self.slider_dcut.on_changed(update_dcut) - self.fig.canvas.mpl_connect('pick_event', on_pick) - self.fig.canvas.mpl_connect('motion_notify_event', on_motion) - self.fig.canvas.mpl_connect('button_release_event', on_release) - - rax_line = self.fig.add_axes([0.45, 0.02, 0.06, 0.03]) # [left, bottom, width, height] - self.cb_line = CheckButtons(rax_line, ['K path'], [False]) - self.cb_line.on_clicked(on_checkbox_change) - - rax_button = self.fig.add_axes([0.52, 0.02, 0.06, 0.03]) - self.button_update = Button(rax_button, 'new k') - self.button_update.on_clicked(new) - - savepath_button = self.fig.add_axes([0.52, 0.06, 0.06, 0.03]) - self.button_save = Button(savepath_button, 'save k-path') - self.button_save.on_clicked(save_path) - - loadpath_button = self.fig.add_axes([0.45, 0.06, 0.06, 0.03]) - self.button_load = Button(loadpath_button, 'load k-path') - self.button_load.on_clicked(load_path) - - addpath_button = self.fig.add_axes([0.37, 0.06, 0.06, 0.03]) - self.button_add = Button(addpath_button, 'add k-path') - self.button_add.on_clicked(add_path) - - plt.show() - self.fig.canvas.draw() - \ No newline at end of file diff --git a/tests/make_model.py b/tests/make_model.py deleted file mode 100644 index 940b1e2..0000000 --- a/tests/make_model.py +++ /dev/null @@ -1,70 +0,0 @@ -import sys -from PyQt5.QtGui import QBrush, QColor -from PyQt5.QtWidgets import QTextEdit, QApplication, QMainWindow, QVBoxLayout, QHBoxLayout, QWidget, QSlider, QLabel, QAction, QCheckBox, QPushButton, QListWidget, QTableWidget, QTableWidgetItem, QTableWidget, QCheckBox, QSplitter -from PyQt5.QtCore import Qt -from PyQt5.QtWidgets import QTableWidgetItem, QHBoxLayout, QCheckBox, QWidget -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -import matplotlib.pyplot as plt - - - -class make_model: - # from matplotlib.widgets import CheckButtons, Button - # %matplotlib qt - - def __init__(self,mod,table_widget): - - self.mod=mod - self.params=mod.make_params() - print('otherpalce',self.params) - print('thefuuuuTable',table_widget) - print('count',table_widget.rowCount()) - for row in range(table_widget.rowCount()): - item = table_widget.item(row, 1) - checkbox_widget = table_widget.cellWidget(row, 3) - print('tableitenm=',item) - if item is not None and item.text().strip(): - header_item = table_widget.verticalHeaderItem(item.row()) - checkbox=checkbox_widget.findChild(QCheckBox) - print(header_item.text(),item.text()) - if header_item.text()== "Fermi level": - self.params['mu'].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: - self.params['mu'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: - self.params['mu'].set(max=float(table_widget.item(row, 2).text())) - if checkbox.isChecked(): - self.params['mu'].vary = False - - elif header_item.text()== "Temperature": - self.params['T'].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: - self.params['T'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: - self.params['T'].set(max=float(table_widget.item(row, 2).text())) - if checkbox.isChecked(): - self.params['T'].vary = False - elif header_item.text()== "sigma": - self.params['sigma'].set(value=float(item.text())) - self.params['sigma'].set(min=0) - if table_widget.item(row, 0) is not None: - self.params['sigma'].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: - self.params['sigma'].set(max=float(table_widget.item(row, 2).text())) - if checkbox.isChecked(): - self.params['sigma'].vary = False - else: - self.params[header_item.text()].set(value=float(item.text())) - if table_widget.item(row, 0) is not None: - self.params[header_item.text()].set(min=float(table_widget.item(row, 0).text())) - if table_widget.item(row, 2) is not None: - self.params[header_item.text()].set(max=float(table_widget.item(row, 2).text())) - if checkbox.isChecked(): - self.params[header_item.text()].vary = False - - - def current_model(self): - return self.mod - def current_params(self): - return self.params - \ No newline at end of file diff --git a/tests/movable_vertical_cursors_graph.py b/tests/movable_vertical_cursors_graph.py deleted file mode 100644 index 580f4a8..0000000 --- a/tests/movable_vertical_cursors_graph.py +++ /dev/null @@ -1,77 +0,0 @@ -# movable_cursors.py - -import numpy as np -import matplotlib.pyplot as plt - -class MovableCursors: - def __init__(self, ax): - self.ax = ax - line = self.ax.lines[0] - self.active_cursor=None - - self.axis = line.get_xdata() - - self.cursorlinev1=self.axis[int(len(self.axis)/4)] - self.cursorlinev2=self.axis[int(3*len(self.axis)/4)] - # Create initial cursors (at the middle of the plot) - # self.v1_cursor = self.ax.axvline(x=5, color='r', linestyle='--', label='Cursor X') - # self.v2_cursor = self.ax.axhline(y=0, color='g', linestyle='--', label='Cursor Y') - - self.Line1=self.ax.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - self.Line2=self.ax.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - - # Connect mouse events for the canvas of the axes - self.ax.figure.canvas.mpl_connect('pick_event', self.on_pick) - self.ax.figure.canvas.mpl_connect('motion_notify_event', self.on_motion) - self.ax.figure.canvas.mpl_connect('button_release_event', self.on_release) - - def on_pick(self,event): - - if event.artist == self.Line1: - self.active_cursor =self.Line1 - elif event.artist == self.Line2: - self.active_cursor =self.Line2 - # self.active_cursor=None - def on_motion(self,event): - - if self.active_cursor is not None and event.inaxes == self.ax: - if self.active_cursor == self.Line1: - self.Line1.set_xdata([event.xdata, event.xdata]) - self.cursorlinev1= event.xdata - elif self.active_cursor == self.Line2: - self.Line2.set_xdata([event.xdata, event.xdata]) - self.cursorlinev2= event.xdata - # print(dot1.center) - # print(self.cursorlinev1,self.cursorlinev2) - self.ax.figure.canvas.draw() - plt.draw() - def find_nearest_index(array, value): - idx = (np.abs(array - value)).argmin() - return idx - self.v1_pixel=find_nearest_index(self.axis, self.cursorlinev1) - self.v2_pixel=find_nearest_index(self.axis, self.cursorlinev2) - - # self.v1_pixel=int((self.cursorlinev1 - self.axis[0]) / (self.axis[-1] - self.axis[0]) * (self.axis.shape[0] - 1) + 0.5) - # self.v2_pixel=int((self.cursorlinev2 - self.axis[0]) / (self.axis[-1] - self.axis[0]) * (self.axis.shape[0] - 1) + 0.5) - print(self.v1_pixel,self.v2_pixel) - - # print(self.v1_pixel,self.v2_pixel) - def on_release(self,event): - # global self.active_cursor - self.active_cursor = None - def remove(self): - self.cursorlinev1= self.Line1.get_xdata()[0] - self.cursorlinev2= self.Line2.get_xdata()[0] - self.Line1.remove() - self.Line2.remove() - # plt.draw() - self.ax.figure.canvas.draw() - - def redraw(self): - # print(self.cursorlinev1,self.cursorlinev2) - self.Line1=self.ax.axvline(x=self.cursorlinev1, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - self.Line2=self.ax.axvline(x=self.cursorlinev2, color='red', linestyle='--',linewidth=2, label='Vertical Line',picker=10) - # plt.draw() - self.ax.figure.canvas.draw() - def cursors(self): - return [self.v1_pixel,self.v2_pixel] \ No newline at end of file diff --git a/tutorials/template.ipynb b/tutorials/template.ipynb new file mode 100644 index 0000000..f3b6d87 --- /dev/null +++ b/tutorials/template.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6d2e0046", + "metadata": {}, + "outputs": [], + "source": [ + "# import the 4D data\n", + "import numpy as np\n", + "import nxarray\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "from mpes_tools import show_4d_window\n", + "from mpes_tools import Gui_3d\n", + "from mpes_tools import fit_panel\n", + "from mpes_tools import ARPES_Analyser\n", + "\n", + "%gui qt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7a3e58c", + "metadata": {}, + "outputs": [], + "source": [ + "# Loading panel\n", + "ARPES_Analyser()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e923734", + "metadata": {}, + "outputs": [], + "source": [ + "# get data from the NOMAD repository\n", + "if not os.path.exists(\"Scan49_binned.nxs\"):\n", + " ! curl -o Scan49_binned.nxs https://nomad-lab.eu/prod/v1/oasis-b/api/v1/entries/MehgoizphpnxG_t-J0WGgbUZKlTp/raw/Scan49_binned.nxs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a046f68", + "metadata": {}, + "outputs": [], + "source": [ + "data_array = nxarray.load(\"Scan49_binned.nxs\").data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aeb6fe2", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the 4D Gui\n", + "graph_4d = show_4d_window(data_array)\n", + "graph_4d.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f416bc6e", + "metadata": {}, + "outputs": [], + "source": [ + "data = data_array\n", + "# the energy plot\n", + "en = (\n", + " data.loc[\n", + " {\n", + " \"energy\": slice(-0.04999999999999982, 0.0),\n", + " \"delay\": slice(-0.00399999999999999, 0.014000000000000012),\n", + " }\n", + " ]\n", + " .mean(dim=(\"energy\", \"delay\"))\n", + " .T\n", + ")\n", + "fig, ax = plt.subplots(1, 1)\n", + "en.plot(ax=ax, cmap=\"terrain\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6a92293", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the 3D Gui\n", + "# select the 3D data\n", + "data = data_array.loc[\n", + " {\n", + " \"kx\": slice(0.48, 0.6800000000000002),\n", + " }\n", + "].mean(dim=(\"kx\"))\n", + "\n", + "# print(data.dims)\n", + "graph_window = Gui_3d(data)\n", + "graph_window.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c14ca2d1", + "metadata": {}, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(1, 1)\n", + "data.loc[{data.dims[0]: slice(-0.97, -0.60), data.dims[1]: slice(0.99, 1.56)}].mean(\n", + " dim=(data.dims[0], data.dims[1])\n", + ").plot(ax=ax) # Box integration\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c78b3de", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the fit panel on the extracted data\n", + "data_edc = data.sel({data.dims[0]: slice(0.86, 1.08)}).mean(dim=data.dims[0])\n", + "graph_window = fit_panel(data_edc, 0, 5, \"\")\n", + "graph_window.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f46368fe", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5a0b003", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}