diff --git a/pycqed/analysis_v2/base_analysis.py b/pycqed/analysis_v2/base_analysis.py index e7a7fcaaf1..7b1afb7e71 100644 --- a/pycqed/analysis_v2/base_analysis.py +++ b/pycqed/analysis_v2/base_analysis.py @@ -8,7 +8,6 @@ import numbers from matplotlib import pyplot as plt -from matplotlib import cm from pycqed.analysis import analysis_toolbox as a_tools from pycqed.utilities.general import NumpyJsonEncoder from pycqed.analysis.analysis_toolbox import get_color_order as gco @@ -16,8 +15,6 @@ from pycqed.analysis.tools.plotting import set_xlabel, set_ylabel from pycqed.analysis.tools.plotting import ( flex_colormesh_plot_vs_xy, flex_color_plot_vs_x) -# import pycqed.analysis_v2.default_figure_settings_analysis as def_fig -from . import default_figure_settings_analysis as def_fig from mpl_toolkits.axes_grid1 import make_axes_locatable import datetime import json @@ -191,7 +188,10 @@ def run_analysis(self): self.prepare_plots() # specify default plots if not self.extract_only: self.plot(key_list='auto') # make the plots - self.save_figures(close_figs=self.options_dict['close_figs']) + + + if self.options_dict.get('save_figs', False): + self.save_figures(close_figs=self.options_dict.get('close_figs', False)) def get_timestamps(self): """ @@ -306,6 +306,23 @@ def extract_data(self): self.raw_data_dict[ 'measured_values_ord_dict'] = measured_values_dict + def extract_data_json(self): + file_name = self.t_start + with open(file_name, 'r') as f: + raw_data_dict = json.load(f) + # print [[key, type(val[0]), len(val)] for key, val in + # raw_data_dict.items()] + self.raw_data_dict = {} + for key, val in list(raw_data_dict.items()): + if type(val[0]) is dict: + self.raw_data_dict[key] = val[0] + else: + self.raw_data_dict[key] = np.double(val) + # print [[key, type(val), len(val)] for key, val in + # self.raw_data_dict.items()] + self.raw_data_dict['timestamps'] = [self.t_start] + + def process_data(self): """ process_data: overloaded in child classes, @@ -347,14 +364,14 @@ def save_figures(self, savedir: str=None, savebase: str =None, for key in key_list: if self.presentation_mode: savename = os.path.join(savedir, savebase+key+tstag+'presentation'+'.'+fmt) - self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt=fmt) + self.figs[key].savefig(savename, bbox_inches='tight', fmt=fmt) savename = os.path.join(savedir, savebase+key+tstag+'presentation'+'.svg') - self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt='svg') + self.figs[key].savefig(savename, bbox_inches='tight', fmt='svg') else: savename = os.path.join(savedir, savebase+key+tstag+'.'+fmt) - self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt=fmt) + self.figs[key].savefig(savename, bbox_inches='tight', fmt=fmt) if close_figs: - plt.close(self.axs[key].figure) + plt.close(self.figs[key]) def save_data(self, savedir: str=None, savebase: str=None, tag_tstamp: bool=True, @@ -481,7 +498,7 @@ def save_fit_results(self): # Delete the old group and create a new group (overwrite). del analysis_group[fr_key] fr_group = analysis_group.create_group(fr_key) - + # TODO: convert the params object to a simple dict # write_dict_to_hdf5(fit_res.params, entry_point=fr_group) write_dict_to_hdf5(fit_res.best_values, entry_point=fr_group) @@ -521,7 +538,8 @@ def plot(self, key_list=None, axs_dict=None, pdict.get('numplotsy', 1), pdict.get('numplotsx', 1), sharex=pdict.get('sharex', False), sharey=pdict.get('sharey', False), - figsize=pdict.get('plotsize', None)) # (8, 6))) + figsize=pdict.get('plotsize', None) #plotsize None uses .rc_default of matplotlib + ) # transparent background around axes for presenting data self.figs[pdict['ax_id']].patch.set_alpha(0) @@ -531,10 +549,20 @@ def plot(self, key_list=None, axs_dict=None, else: for key in key_list: pdict = self.plot_dicts[key] + + plot_id_y = pdict.get('plot_id_y', None) + plot_id_x = pdict.get('plot_id_x', None) + plot_touching = pdict.get('touching', False) + if type(pdict['plotfn']) is str: plotfn = getattr(self, pdict['plotfn']) else: plotfn = pdict['plotfn'] + + # used to ensure axes are touching + if plot_touching: + self.axs[pdict['ax_id']].figure.subplots_adjust(wspace=0, hspace=0) + # ensures the argument convention is preserved if hasattr(self, plotfn.__name__): plotfn(pdict, axs=self.axs[pdict['ax_id']]) @@ -584,6 +612,7 @@ def plot_bar(self, pdict, axs): dataset_desc = pdict.get('setdesc', '') dataset_label = pdict.get('setlabel', list(range(len(plot_yvals)))) do_legend = pdict.get('do_legend', False) + plot_touching = pdict.get('touching', False) plot_xwidth = (plot_xedges[1:]-plot_xedges[:-1]) # center is left edge + widht /2 @@ -594,8 +623,7 @@ def plot_bar(self, pdict, axs): for ii, this_yvals in enumerate(plot_yvals): p_out.append(pfunc(plot_centers, this_yvals, width=plot_xwidth, color=gco(ii, len(plot_yvals)-1), - label='%s%s' % ( - dataset_desc, dataset_label[ii]), + label='%s%s' % (dataset_desc, dataset_label[ii]), **plot_barkws)) else: @@ -624,6 +652,9 @@ def plot_bar(self, pdict, axs): legend_pos = pdict.get('legend_pos', 'best') axs.legend(title=legend_title, loc=legend_pos, ncol=legend_ncol) + if plot_touching: + axs.figure.subplots_adjust(wspace=0, hspace=0) + if self.tight_fig: axs.figure.tight_layout() @@ -924,7 +955,7 @@ def plot_color2D(self, pfunc, pdict, axs): clim=fig_clim, cmap=plot_cmap, xvals=traces['xvals'][tt], yvals=traces['yvals'][tt], - zvals=traces['zvals'][tt], # .transpose(), + zvals=traces['zvals'][tt], transpose=plot_transpose, normalize=plot_normalize) @@ -946,10 +977,16 @@ def plot_color2D(self, pfunc, pdict, axs): if plot_yrange is None: if plot_xwidth is not None: - ymin, ymax = min([min(yvals[0]) - for tt, yvals in enumerate(plot_yvals)]), \ - max([max(yvals[0]) - for tt, yvals in enumerate(plot_yvals)]) + ymin_list, ymax_list = [], [] + for ytraces in block['yvals']: + ymin_trace, ymax_trace = [], [] + for yvals in ytraces: + ymin_trace.append(min(yvals)) + ymax_trace.append(max(yvals)) + ymin_list.append(min(ymin_trace)) + ymax_list.append(max(ymax_trace)) + ymin = min(ymin_list) + ymax = max(ymax_list) else: ymin = np.min(plot_yvals) - plot_yvals_step / 2. ymax = np.max(plot_yvals) + plot_yvals_step/2.