Skip to content

Commit

Permalink
live plotting for MC with interpolation WIP do not merge
Browse files Browse the repository at this point in the history
  • Loading branch information
AdriaanRol committed May 14, 2018
1 parent b7cea14 commit c5136d1
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 29 deletions.
67 changes: 67 additions & 0 deletions pycqed/analysis/tools/plot_interpolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
import logging
from scipy import interpolate

def areas(ip):
p = ip.tri.points[ip.tri.vertices]
q = p[:, :-1, :] - p[:, -1, None, :]
areas = abs(q[:, 0, 0] * q[:, 1, 1] - q[:, 0, 1] * q[:, 1, 0]) / 2
return areas


def scale(points, xy_mean, xy_scale):
points = np.asarray(points, dtype=float)
return (points - xy_mean) / xy_scale


def unscale(points, xy_mean, xy_scale):
points = np.asarray(points, dtype=float)
return points * xy_scale + xy_mean

def interpolate_heatmap(x, y, z, n=None):
"""
Args:
Returns:
x_grid : N*1 array of x-values of the interpolated grid
y_grid : N*1 array of x-values of the interpolated grid
z_grid : N*N array of z-values that form a grid.
The output of this method can directly be used for
plt.imshow(z_grid, extent=extent, aspect='auto')
where the extent is determined by the min and max of the x_grid and
y_grid
"""

points = list(zip(x, y))
lbrt = np.min(points, axis=0), np.max(points, axis=0)
lbrt = lbrt[0][0], lbrt[0][1], lbrt[1][0], lbrt[1][1]

xy_mean = np.mean([lbrt[0], lbrt[2]]), np.mean([lbrt[1], lbrt[3]])
xy_scale = np.ptp([lbrt[0], lbrt[2]]), np.ptp([lbrt[1], lbrt[3]])

# interpolation needs to happen on a rescaled grid, this is somewhat akin to an
# assumption in the interpolation that the scale of the experiment is chosen sensibly.
ip = interpolate.LinearNDInterpolator(scale(points, xy_mean=xy_mean, xy_scale=xy_scale),
z)

if n is None:
# Calculate how many grid points are needed.
# factor from A=√3/4 * a² (equilateral triangle)
n = int(0.658 / np.sqrt(areas(ip).min()))
n = max(n, 10)
if n > 500:
logging.warning('n: {} larger than 500'.format(n))
n=500

x_lin = y_lin = np.linspace(-0.5, 0.5, n)
# Interpolation is evaulated linearly in the domain for interpolation
z_grid = ip(x_lin[:, None], y_lin[None, :]).squeeze()

# x and y grid points need to be rescaled from the linearly chosen points
points_grid = unscale(list(zip(x_lin, y_lin)), xy_mean=xy_mean, xy_scale=xy_scale)
x_grid = points_grid[:, 0]
y_grid = points_grid[:, 1]


return x_grid, y_grid, (z_grid).T
7 changes: 6 additions & 1 deletion pycqed/instrument_drivers/virtual_instruments/mock_device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import time
from pycqed.analysis.fitting_models import hanger_func_complex_SI
from qcodes.instrument.base import Instrument
from qcodes.utils.validators import Numbers, Enum, Ints
Expand Down Expand Up @@ -54,13 +55,17 @@ def __init__(self, name, **kw):
'magn_phase', 'magn'),
parameter_class=ManualParameter)

self.add_parameter('acq_delay', initial_value=0,
unit='s',
parameter_class=ManualParameter)

self.add_parameter('cw_noise_level', initial_value=0,
parameter_class=ManualParameter)

self.add_parameter('S21', unit='V', get_cmd=self.measure_transmission)

def measure_transmission(self):

time.sleep(self.acq_delay())
# TODO: add attenuation and gain
transmission = dBm_to_Vpeak(self.mw_pow())

Expand Down
117 changes: 89 additions & 28 deletions pycqed/measurement/measurement_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from pycqed.measurement.mc_parameter_wrapper import wrap_par_to_det
from pycqed.analysis.tools.data_manipulation import get_generation_means

from pycqed.analysis.tools.plot_interpolation import interpolate_heatmap

from qcodes.instrument.base import Instrument
from qcodes.instrument.parameter import ManualParameter
from qcodes.utils import validators as vals
Expand All @@ -30,7 +32,7 @@
print('Could not import msvcrt (used for detecting keystrokes)')

try:
from qcodes.plots.pyqtgraph import QtPlot
from qcodes.plots.pyqtgraph import QtPlot, TransformState
except Exception:
print('pyqtgraph plotting not supported, '
'try "from qcodes.plots.pyqtgraph import QtPlot" '
Expand Down Expand Up @@ -264,7 +266,6 @@ def measure_soft_adaptive(self, method=None):
self.save_optimization_settings()
self.adaptive_function = self.af_pars.pop('adaptive_function')
if self.live_plot_enabled():
# self.initialize_plot_monitor()
self.initialize_plot_monitor_adaptive()
for sweep_function in self.sweep_functions:
sweep_function.prepare()
Expand Down Expand Up @@ -703,6 +704,63 @@ def update_plotmon_2D(self, force_update=False):
except Exception as e:
logging.warning(e)

def initialize_plot_monitor_2D_interp(self):
"""
Initialize a 2D plot monitor for interpolated (adaptive) plots
"""
if self.live_plot_enabled() and len(self.sweep_function_names) ==2:
self.time_last_2Dplot_update = time.time()

self.secondary_QtPlot.clear()
slabels = self.sweep_par_names
sunits = self.sweep_par_units
zlabels = self.detector_function.value_names
zunits = self.detector_function.value_units

for j in range(len(self.detector_function.value_names)):
self.secondary_QtPlot.add(x=[0, 1],
y=[0, 1],
z=np.zeros([2,2]),
xlabel=slabels[0], xunit=sunits[0],
ylabel=slabels[1], yunit=sunits[1],
zlabel=zlabels[j], zunit=zunits[j],
subplot=j+1,
cmap='viridis')

def update_plotmon_2D_interp(self, force_update=False):
'''
Updates the interpolated 2D heatmap
'''
if self.live_plot_enabled() and len(self.sweep_function_names) ==2:
# try:
if (time.time() - self.time_last_2Dplot_update >
self.plotting_interval() or force_update):
# exists to force reset the x- and y-axis scale
new_sc = TransformState(0, 1, True)

x_vals = self.dset[:, 0]
y_vals = self.dset[:, 1]
for j in range(len(self.detector_function.value_names)):
z_ind = len(self.sweep_functions) + j
z_vals = self.dset[:, z_ind]

# Interpolate points
x_grid, y_grid, z_grid = interpolate_heatmap(
x_vals, y_vals, z_vals)
trace = self.secondary_QtPlot.traces[j]
trace['config']['x'] = x_grid
trace['config']['y'] = y_grid
trace['config']['z'] = z_grid
# force rescale the axes
trace['plot_object']['scales']['x'] = new_sc
trace['plot_object']['scales']['y'] = new_sc

self.time_last_2Dplot_update = time.time()
self.secondary_QtPlot.update_plot()
# except Exception as e:
# logging.warning(e)


def initialize_plot_monitor_adaptive(self):
'''
Uses the Qcodes plotting windows for plotting adaptive plot updates
Expand All @@ -711,41 +769,44 @@ def initialize_plot_monitor_adaptive(self):
return self.initialize_plot_monitor_adaptive_cma()
else:
self.initialize_plot_monitor()
# init plotmon 2d interp checks if there are 2 sweep funcs
self.initialize_plot_monitor_2D_interp()
self.time_last_ad_plot_update = time.time()
self.secondary_QtPlot.clear()
# self.secondary_QtPlot.clear()

zlabels = self.detector_function.value_names
zunits = self.detector_function.value_units
# zlabels = self.detector_function.value_names
# zunits = self.detector_function.value_units

for j in range(len(self.detector_function.value_names)):
self.secondary_QtPlot.add(x=[0],
y=[0],
xlabel='iteration',
ylabel=zlabels[j],
yunit=zunits[j],
subplot=j+1,
symbol='o', symbolSize=5)
# for j in range(len(self.detector_function.value_names)):
# self.secondary_QtPlot.add(x=[0],
# y=[0],
# xlabel='iteration',
# ylabel=zlabels[j],
# yunit=zunits[j],
# subplot=j+1,
# symbol='o', symbolSize=5)

def update_plotmon_adaptive(self, force_update=False):
if self.adaptive_function.__module__ == 'cma.evolution_strategy':
return self.update_plotmon_adaptive_cma(force_update=force_update)
else:
self.update_plotmon(force_update=force_update)

if self.live_plot_enabled():
try:
if (time.time() - self.time_last_ad_plot_update >
self.plotting_interval() or force_update):
for j in range(len(self.detector_function.value_names)):
y_ind = len(self.sweep_functions) + j
y = self.dset[:, y_ind]
x = range(len(y))
self.secondary_QtPlot.traces[j]['config']['x'] = x
self.secondary_QtPlot.traces[j]['config']['y'] = y
self.time_last_ad_plot_update = time.time()
self.secondary_QtPlot.update_plot()
except Exception as e:
logging.warning(e)
self.update_plotmon_2D_interp(force_update=force_update)

# if self.live_plot_enabled():
# try:
# if (time.time() - self.time_last_ad_plot_update >
# self.plotting_interval() or force_update):
# for j in range(len(self.detector_function.value_names)):
# y_ind = len(self.sweep_functions) + j
# y = self.dset[:, y_ind]
# x = range(len(y))
# self.secondary_QtPlot.traces[j]['config']['x'] = x
# self.secondary_QtPlot.traces[j]['config']['y'] = y
# self.time_last_ad_plot_update = time.time()
# self.secondary_QtPlot.update_plot()
# except Exception as e:
# logging.warning(e)

def initialize_plot_monitor_adaptive_cma(self):
'''
Expand Down

0 comments on commit c5136d1

Please sign in to comment.