Skip to content

Commit

Permalink
Merge changes from thij's branch #455
Browse files Browse the repository at this point in the history
  • Loading branch information
Aypac committed Apr 13, 2018
1 parent d9ed43d commit d4b8878
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions pycqed/analysis_v2/base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@

import numbers
from matplotlib import pyplot as plt
from matplotlib import cm
from pycqed.analysis import analysis_toolbox as a_tools
from pycqed.utilities.general import NumpyJsonEncoder
from pycqed.analysis.analysis_toolbox import get_color_order as gco
from pycqed.analysis.analysis_toolbox import get_color_list
from pycqed.analysis.tools.plotting import set_xlabel, set_ylabel
from pycqed.analysis.tools.plotting import (
flex_colormesh_plot_vs_xy, flex_color_plot_vs_x)
# import pycqed.analysis_v2.default_figure_settings_analysis as def_fig
from . import default_figure_settings_analysis as def_fig
from mpl_toolkits.axes_grid1 import make_axes_locatable
import datetime
import json
Expand Down Expand Up @@ -191,7 +188,10 @@ def run_analysis(self):
self.prepare_plots() # specify default plots
if not self.extract_only:
self.plot(key_list='auto') # make the plots
self.save_figures(close_figs=self.options_dict['close_figs'])


if self.options_dict.get('save_figs', False):
self.save_figures(close_figs=self.options_dict.get('close_figs', False))

def get_timestamps(self):
"""
Expand Down Expand Up @@ -306,6 +306,23 @@ def extract_data(self):
self.raw_data_dict[
'measured_values_ord_dict'] = measured_values_dict

def extract_data_json(self):
file_name = self.t_start
with open(file_name, 'r') as f:
raw_data_dict = json.load(f)
# print [[key, type(val[0]), len(val)] for key, val in
# raw_data_dict.items()]
self.raw_data_dict = {}
for key, val in list(raw_data_dict.items()):
if type(val[0]) is dict:
self.raw_data_dict[key] = val[0]
else:
self.raw_data_dict[key] = np.double(val)
# print [[key, type(val), len(val)] for key, val in
# self.raw_data_dict.items()]
self.raw_data_dict['timestamps'] = [self.t_start]


def process_data(self):
"""
process_data: overloaded in child classes,
Expand Down Expand Up @@ -347,14 +364,14 @@ def save_figures(self, savedir: str=None, savebase: str =None,
for key in key_list:
if self.presentation_mode:
savename = os.path.join(savedir, savebase+key+tstag+'presentation'+'.'+fmt)
self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt=fmt)
self.figs[key].savefig(savename, bbox_inches='tight', fmt=fmt)
savename = os.path.join(savedir, savebase+key+tstag+'presentation'+'.svg')
self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt='svg')
self.figs[key].savefig(savename, bbox_inches='tight', fmt='svg')
else:
savename = os.path.join(savedir, savebase+key+tstag+'.'+fmt)
self.axs[key].figure.savefig(savename, bbox_inches='tight', fmt=fmt)
self.figs[key].savefig(savename, bbox_inches='tight', fmt=fmt)
if close_figs:
plt.close(self.axs[key].figure)
plt.close(self.figs[key])

def save_data(self, savedir: str=None, savebase: str=None,
tag_tstamp: bool=True,
Expand Down Expand Up @@ -481,7 +498,7 @@ def save_fit_results(self):
# Delete the old group and create a new group (overwrite).
del analysis_group[fr_key]
fr_group = analysis_group.create_group(fr_key)

# TODO: convert the params object to a simple dict
# write_dict_to_hdf5(fit_res.params, entry_point=fr_group)
write_dict_to_hdf5(fit_res.best_values, entry_point=fr_group)
Expand Down Expand Up @@ -521,7 +538,8 @@ def plot(self, key_list=None, axs_dict=None,
pdict.get('numplotsy', 1), pdict.get('numplotsx', 1),
sharex=pdict.get('sharex', False),
sharey=pdict.get('sharey', False),
figsize=pdict.get('plotsize', None)) # (8, 6)))
figsize=pdict.get('plotsize', None) #plotsize None uses .rc_default of matplotlib
)

# transparent background around axes for presenting data
self.figs[pdict['ax_id']].patch.set_alpha(0)
Expand All @@ -531,10 +549,20 @@ def plot(self, key_list=None, axs_dict=None,
else:
for key in key_list:
pdict = self.plot_dicts[key]

plot_id_y = pdict.get('plot_id_y', None)
plot_id_x = pdict.get('plot_id_x', None)
plot_touching = pdict.get('touching', False)

if type(pdict['plotfn']) is str:
plotfn = getattr(self, pdict['plotfn'])
else:
plotfn = pdict['plotfn']

# used to ensure axes are touching
if plot_touching:
self.axs[pdict['ax_id']].figure.subplots_adjust(wspace=0, hspace=0)

# ensures the argument convention is preserved
if hasattr(self, plotfn.__name__):
plotfn(pdict, axs=self.axs[pdict['ax_id']])
Expand Down Expand Up @@ -584,6 +612,7 @@ def plot_bar(self, pdict, axs):
dataset_desc = pdict.get('setdesc', '')
dataset_label = pdict.get('setlabel', list(range(len(plot_yvals))))
do_legend = pdict.get('do_legend', False)
plot_touching = pdict.get('touching', False)

plot_xwidth = (plot_xedges[1:]-plot_xedges[:-1])
# center is left edge + widht /2
Expand All @@ -594,8 +623,7 @@ def plot_bar(self, pdict, axs):
for ii, this_yvals in enumerate(plot_yvals):
p_out.append(pfunc(plot_centers, this_yvals, width=plot_xwidth,
color=gco(ii, len(plot_yvals)-1),
label='%s%s' % (
dataset_desc, dataset_label[ii]),
label='%s%s' % (dataset_desc, dataset_label[ii]),
**plot_barkws))

else:
Expand Down Expand Up @@ -624,6 +652,9 @@ def plot_bar(self, pdict, axs):
legend_pos = pdict.get('legend_pos', 'best')
axs.legend(title=legend_title, loc=legend_pos, ncol=legend_ncol)

if plot_touching:
axs.figure.subplots_adjust(wspace=0, hspace=0)

if self.tight_fig:
axs.figure.tight_layout()

Expand Down Expand Up @@ -924,7 +955,7 @@ def plot_color2D(self, pfunc, pdict, axs):
clim=fig_clim, cmap=plot_cmap,
xvals=traces['xvals'][tt],
yvals=traces['yvals'][tt],
zvals=traces['zvals'][tt], # .transpose(),
zvals=traces['zvals'][tt],
transpose=plot_transpose,
normalize=plot_normalize)

Expand All @@ -946,10 +977,16 @@ def plot_color2D(self, pfunc, pdict, axs):

if plot_yrange is None:
if plot_xwidth is not None:
ymin, ymax = min([min(yvals[0])
for tt, yvals in enumerate(plot_yvals)]), \
max([max(yvals[0])
for tt, yvals in enumerate(plot_yvals)])
ymin_list, ymax_list = [], []
for ytraces in block['yvals']:
ymin_trace, ymax_trace = [], []
for yvals in ytraces:
ymin_trace.append(min(yvals))
ymax_trace.append(max(yvals))
ymin_list.append(min(ymin_trace))
ymax_list.append(max(ymax_trace))
ymin = min(ymin_list)
ymax = max(ymax_list)
else:
ymin = np.min(plot_yvals) - plot_yvals_step / 2.
ymax = np.max(plot_yvals) + plot_yvals_step/2.
Expand Down

0 comments on commit d4b8878

Please sign in to comment.