From bb2e377c4e6ac24e3872511aaf0a29e613be057e Mon Sep 17 00:00:00 2001 From: Graham Rowlands Date: Tue, 15 Aug 2017 23:43:28 -0400 Subject: [PATCH] Deferred eval working with, e.g., RabiWidth --- QGL/ChannelLibrary.py | 14 ++++++++++++++ QGL/Compiler.py | 13 +++++++------ QGL/PulsePrimitives.py | 41 +++++++++++++++++++++++++++++------------ 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/QGL/ChannelLibrary.py b/QGL/ChannelLibrary.py index bf365d86..936bd77d 100644 --- a/QGL/ChannelLibrary.py +++ b/QGL/ChannelLibrary.py @@ -32,6 +32,7 @@ import re import traceback import importlib +from functools import wraps from atom.api import Atom, Str, Int, Typed import networkx as nx import yaml @@ -454,6 +455,15 @@ def on_awg_change(self, oldName, newName): print("Changing {0} to {1}".format(chName, newLabel)) self.physicalChannelManager.name_changed(chName, newLabel) +def _defer_factory(factFunc): + @wraps(factFunc) + def defer(*args, **kwargs): + if kwargs: + return lambda: factFunc(*args, **kwargs) + return lambda: factFunc(*args) + return defer + +@_defer_factory def MarkerFactory(label, **kwargs): '''Return a marker channel by name. Must be defined under top-level `markers` keyword in measurement configuration YAML. @@ -464,8 +474,10 @@ def MarkerFactory(label, **kwargs): else: raise ValueError("Marker channel {} not found in channel library.".format(label)) +@_defer_factory def QubitFactory(label, **kwargs): ''' Return a saved qubit channel or create a new one. ''' + print("Running the factory") if channelLib and label in channelLib and isinstance(channelLib[label], Channels.Qubit): return channelLib[label] @@ -473,6 +485,7 @@ def QubitFactory(label, **kwargs): return Channels.Qubit(label=label, **kwargs) +@_defer_factory def MeasFactory(label, meas_type='autodyne', **kwargs): ''' Return a saved measurement channel or create a new one. ''' if channelLib and label in channelLib and isinstance(channelLib[label], @@ -482,6 +495,7 @@ def MeasFactory(label, meas_type='autodyne', **kwargs): return Channels.Measurement(label=label, meas_type=meas_type, **kwargs) +@_defer_factory def EdgeFactory(source, target): if not channelLib: raise ValueError('Connectivity graph not found') diff --git a/QGL/Compiler.py b/QGL/Compiler.py index ed5fb2ab..0771092e 100644 --- a/QGL/Compiler.py +++ b/QGL/Compiler.py @@ -317,7 +317,8 @@ def compile_to_hardware(seqs, ''' logger.debug("Compiling %d sequence(s)", len(seqs)) - # Expand the + # Evaluate any objects that have been deferred + seqs = eval_deferred_pulses(seqs) # save input code to file save_code(seqs, fileName + suffix) @@ -627,7 +628,7 @@ def normalize(seq, channels=None): blocklen = block.length emptyChannels = channels - set(block.pulses.keys()) for ch in emptyChannels: - block.pulses[ch] = Id(ch, blocklen) + block.pulses[ch] = Id(ch, blocklen)() return seq class Waveform(object): @@ -713,12 +714,12 @@ def flatten_to_pulses(obj): # no padding element required return pulses elif alignment == "left": - return pulses + [Id(channel, padLength)] + return pulses + [Id(channel, padLength)()] elif alignment == "right": - return [Id(channel, padLength)] + pulses + return [Id(channel, padLength)()] + pulses else: # center - return [Id(channel, padLength / 2)] + pulses + [Id(channel, padLength / - 2)] + return [Id(channel, padLength / 2)()] + pulses + [Id(channel, padLength / + 2)()] def validate_linklist_channels(linklistChannels): diff --git a/QGL/PulsePrimitives.py b/QGL/PulsePrimitives.py index 62f264f2..b57163ea 100644 --- a/QGL/PulsePrimitives.py +++ b/QGL/PulsePrimitives.py @@ -41,6 +41,7 @@ def _memoize(pulseFunc): @wraps(pulseFunc) def cacheWrap(*args, **kwargs): + # import pdb; pdb.set_trace() if kwargs: return pulseFunc(*args, **kwargs) key = (pulseFunc, args) @@ -53,6 +54,16 @@ def cacheWrap(*args, **kwargs): def clear_pulse_cache(): _memoize.cache = {} +def _defer(pulseFunc): + @wraps(pulseFunc) + def defer(*args, **kwargs): + # import pdb; pdb.set_trace() + if kwargs: + return lambda: pulseFunc(*((args[0]() if callable(args[0]) else args[0],) + args[1:]), **kwargs) + return lambda: pulseFunc(*((args[0]() if callable(args[0]) else args[0],) + args[1:])) + return defer + +@_defer @_memoize def Id(channel, *args, **kwargs): ''' @@ -65,7 +76,7 @@ def Id(channel, *args, **kwargs): if len(args) > 0 and isinstance(args[0], (int, float)): params['length'] = args[0] - return lambda: TAPulse("Id", + return TAPulse("Id", channel, params['length'], 0, @@ -73,6 +84,7 @@ def Id(channel, *args, **kwargs): # the most generic pulse is Utheta +@_defer def Utheta(qubit, angle=0, phase=0, @@ -109,17 +121,18 @@ def Utheta(qubit, else: # linearly scale based upon the 'pi/2' amplitude amp = (angle / pi/2) * qubit.pulse_params['pi2Amp'] - return lambda: Pulse(label, qubit, params, amp, phase, 0.0, ignoredStrParams) + return Pulse(label, qubit, params, amp, phase, 0.0, ignoredStrParams) # generic pulses around X, Y, and Z axes +@_defer def Xtheta(qubit, angle=0, label='Xtheta', ignoredStrParams=None, **kwargs): ''' A generic X rotation with a variable rotation angle ''' if ignoredStrParams is None: ignoredStrParams = ['phase', 'frameChange'] else: ignoredStrParams += ['phase', 'frameChange'] - return Utheta(qubit, angle, 0, label, ignoredStrParams, **kwargs) + return Utheta(qubit, angle, 0, label, ignoredStrParams, **kwargs)() def Ytheta(qubit, angle=0, label='Ytheta', ignoredStrParams=None, **kwargs): @@ -130,14 +143,14 @@ def Ytheta(qubit, angle=0, label='Ytheta', ignoredStrParams=None, **kwargs): ignoredStrParams += ['phase', 'frameChange'] return Utheta(qubit, angle, pi/2, label, ignoredStrParams, **kwargs) - +@_defer def Ztheta(qubit, angle=0, label='Ztheta', ignoredStrParams=['amp', 'phase', 'length'], **kwargs): # special cased because it can be done with a frame update - return lambda: TAPulse(label, + return TAPulse(label, qubit, length=0, amp=0, @@ -147,13 +160,14 @@ def Ztheta(qubit, #Setup the default 90/180 rotations +@_defer @_memoize def X90(qubit, **kwargs): return Xtheta(qubit, pi/2, label="X90", ignoredStrParams=['amp'], - **kwargs) + **kwargs)() @_memoize def X90m(qubit, **kwargs): @@ -276,7 +290,7 @@ def Z90(qubit, **kwargs): def Z90m(qubit, **kwargs): return Ztheta(qubit, -pi / 2, label="Z90m", **kwargs) - +@_defer def arb_axis_drag(qubit, nutFreq, rotAngle=0, @@ -333,7 +347,7 @@ def arb_axis_drag(qubit, params['rotAngle'] = rotAngle params['polarAngle'] = polarAngle params['shape_fun'] = PulseShapes.arb_axis_drag - return lambda: Pulse(kwargs["label"] if "label" in kwargs else "ArbAxis", qubit, + return Pulse(kwargs["label"] if "label" in kwargs else "ArbAxis", qubit, params, 1.0, aziAngle, frameChange) @@ -726,13 +740,14 @@ def CNOT(source, target, **kwargs): return cnot_impl(source, target, **kwargs) ## Measurement operators +@_defer @_memoize def MEAS(qubit, **kwargs): ''' MEAS(q1) measures a qubit. Applies to the pulse with the label M-q1 ''' channelName = "M-" + qubit.label - measChan = ChannelLibrary.MeasFactory(channelName) + measChan = ChannelLibrary.MeasFactory(channelName)() params = overrideDefaults(measChan, kwargs) if measChan.meas_type == 'autodyne': params['frequency'] = measChan.autodyne_freq @@ -742,7 +757,7 @@ def MEAS(qubit, **kwargs): ignoredStrParams = ['phase', 'frameChange'] if 'amp' not in kwargs: ignoredStrParams.append('amp') - return lambda: Pulse("MEAS", measChan, params, amp, 0.0, 0.0, ignoredStrParams) + return Pulse("MEAS", measChan, params, amp, 0.0, 0.0, ignoredStrParams) #MEAS and ring-down time on one qubit, echo on every other @@ -775,13 +790,15 @@ def MeasEcho(qM, qD, delay, piShift=None, phase=0): return measEcho # Gating/blanking pulse primitives +@_defer def BLANK(chan, length): - return lambda: TAPulse("BLANK", chan.gate_chan, length, 1, 0, 0) + return TAPulse("BLANK", chan.gate_chan, length, 1, 0, 0) +@_defer def TRIG(marker_chan, length): '''TRIG(marker_chan, length) generates a trigger output of amplitude 1 on a LogicalMarkerChannel. ''' if not isinstance(marker_chan, Channels.LogicalMarkerChannel): raise ValueError("TRIG pulses can only be generated on LogicalMarkerChannels.") - return lambda: TAPulse("TRIG", marker_chan, length, 1.0, 0., 0.) + return TAPulse("TRIG", marker_chan, length, 1.0, 0., 0.)