Skip to content

Commit

Permalink
Adding save/load functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamrow committed Jun 12, 2018
1 parent a85ad0a commit bb01929
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 80 deletions.
174 changes: 104 additions & 70 deletions QGL/ChannelLibraries.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,49 +54,60 @@ def set_from_dict(obj, settings):
except Exception as e:
print(f"{obj.label}: Error loading {prop_name} from config")

def copy_objs(chans, srcs):
def copy_objs(chans, srcs, new_channel_db):
new_chans = []
new_srcs = []
old_to_new_chan = {}
old_to_new_src = {}

for chan in chans:
c = copy_entity(chan)
print("Copied", chan, "to", c)
c = copy_entity(chan, new_channel_db)
new_chans.append(c)
old_to_new_chan[chan] = c

for src in srcs:
c = copy_entity(src)
print("Copied", src, "to", c)
c = copy_entity(src, new_channel_db)
new_srcs.append(c)
old_to_new_src[src] = c

# # Fix links... pony updates the relationships symmetriacally so we get some for free
# for thing in new_chans + new_srcs:
# print(f"Fixing {thing}")
# for attr in thing._attrs_:
# print(f"\t{attr}")
# if attr:
# if isinstance(getattr(thing, attr.name), Channels.Channel):
# if getattr(thing, attr.name) in old_to_new_chan.keys():
# print(f"setting {thing} {attr.name} to {old_to_new_chan[getattr(thing, attr.name)]}")
# setattr(thing, attr.name, old_to_new_chan[getattr(thing, attr.name)])
# elif isinstance(getattr(thing, attr.name), Channels.MicrowaveSource):
# if getattr(thing, attr.name) in old_to_new_src.keys():
# print(f"setting {thing} {attr.name} to {old_to_new_src[getattr(thing, attr.name)]}")
# setattr(thing, attr.name, old_to_new_src[getattr(thing, attr.name)])
# Fix links... pony updates the relationships symmetriacally so we get some for free
for thing in new_chans + new_srcs:
for attr in thing._attrs_:
if attr:
if isinstance(getattr(thing, attr.name), Channels.Channel):
if getattr(thing, attr.name) in old_to_new_chan.keys():
setattr(thing, attr.name, old_to_new_chan[getattr(thing, attr.name)])
elif isinstance(getattr(thing, attr.name), Channels.MicrowaveSource):
if getattr(thing, attr.name) in old_to_new_src.keys():
setattr(thing, attr.name, old_to_new_src[getattr(thing, attr.name)])

return new_chans, new_srcs

def copy_entity(obj):
def copy_entity(obj, new_channel_db):
"""Copy a pony entity instance"""
kwargs = {a.name: getattr(obj, a.name) for a in obj._attrs_}
kwargs.pop("id")
# kwargs.pop("classtype")
kwargs.pop("channel_db")
kwargs = {a.name: getattr(obj, a.name) for a in obj.__class__._attrs_ if a.name not in ["id", "classtype", "pulse_params"]}
# if "pulse_params" in kwargs.keys():
# kwargs["pulse_params"] = dict(kwargs["pulse_params"])
kwargs["channel_db"] = new_channel_db
return obj.__class__(**kwargs)

# def copy_entity(obj, new_channel_db):
# """Copy a pony entity instance"""
# attr_names = [a.name for a in obj.__class__._attrs_]
# skip = ["id", "classtype", "channel_db"]
# print(obj)
# kwargs = {"channel_db": new_channel_db}
# for a in attr_names: #["label", "source_type"]:
# print("\t", a)
# if a in dir(obj):
# val = getattr(obj, a)
# if a == "pulse_params":
# val = dict(val)
# if val is not None and a not in skip:
# print("\t\t", getattr(obj, a))
# kwargs[a] = val
# return obj.__class__(**kwargs)

class ChannelLibrary(object):

def __init__(self, channel_db_name=None, database_file=None, channelDict={}, **kwargs):
Expand Down Expand Up @@ -125,20 +136,20 @@ def __init__(self, channel_db_name=None, database_file=None, channelDict={}, **k
self.channelDict = {}
self.channels = []
self.sources = []
self.channelDatabase = None
self.channelDatabase = Channels.ChannelDatabase(label="__temp__", time=datetime.datetime.now())
self.channel_db_name = channel_db_name if channel_db_name else "temp"

config.load_config()

self.load_most_recent()
# self.load_most_recent()
# config.load_config()

# Update the global reference
global channelLib
channelLib = self

def get_current_channels(self):
return list(select(c for c in Channels.Channel if (c.channel_db == self.channelDatabase) or (c.channel_db is None)))
return list(select(c for c in Channels.Channel if c.channel_db is None)) + list(select(c for c in Channels.MicrowaveSource if c.channel_db is None))

def update_channelDict(self):
self.channelDict = {c.label: c for c in self.get_current_channels()}
Expand All @@ -150,45 +161,68 @@ def load_by_id(self, id_num):
obj = select(c for c in Channels.ChannelDatabase if c.id==id_num).first()
self.load(obj)

self.channels = list(obj.channels)
self.sources = list(obj.sources)
self.channel_db_name = obj.label
def clear(self):
select(c for c in Channels.Channel if c.channel_db == self.channelDatabase).delete(bulk=True)
select(c for c in Channels.MicrowaveSource if c.channel_db == self.channelDatabase).delete(bulk=True)
self.channelDatabase.time = datetime.datetime.now()

def load(self, obj): #, delete=True):
self.clear()

def load(self, obj):
self.channels = list(obj.channels)
self.sources = list(obj.sources)
chans = list(obj.channels)
srcs = list(obj.sources)

# self.channelDatabase = Channels.ChannelDatabase(label="__temp__", time=datetime.datetime.now())
new_chans, new_srcs = copy_objs(chans, srcs, self.channelDatabase)

self.channels = new_chans
self.sources = new_srcs
self.channel_db_name = obj.label
self.channelDatabase = obj

def load_most_recent(self, name=None):
if name is None:
name = self.channel_db_name
mrcd = Channels.ChannelDatabase.select(lambda d: d.label==name).order_by(desc(Channels.ChannelDatabase.time)).first()
if mrcd:
self.load(mrcd)

def new(self, name):
self.channelDatabase = None
self.channel_db_name = name
self.channels = []
self.sources = []
# self.channelDatabase = None

# def load_most_recent(self, name=None):
# if name is None:
# name = self.channel_db_name
# mrcd = Channels.ChannelDatabase.select(lambda d: d.label==name).order_by(desc(Channels.ChannelDatabase.time)).first()
# if mrcd:
# self.load(mrcd)

# def new(self, name):
# # self.channelDatabase.delete()
# self.clear()
# commit()

# # self.channelDatabase = Channels.ChannelDatabase(label="__temp__", time=datetime.datetime.now())

# # self.channelDatabase = None
# self.channel_db_name = name
# self.channels = []
# self.sources = []

def save(self):
self.save_as(self.channel_db_name)

def save_as(self, name):
# Get channels that are part of the currently active db, or find those that aren't yet part of a db
chans = list(select(c for c in Channels.Channel if (c.channel_db == self.channelDatabase) or (c.channel_db is None)))
srcs = list(select(c for c in Channels.MicrowaveSource if (c.channel_db == self.channelDatabase) or (c.channel_db is None)))
# self.channelDatabase.label = name
# self.channelDatabase.time = datetime.datetime.now()
# cd = self.channelDatabase
# self.channelDatabase = None
# commit()
# self.load(cd, delete=False)

# chans, src = copy_objs(chans, srcs)

cd = Channels.ChannelDatabase(label=name, time=datetime.datetime.now(), channels=chans, sources=srcs)
self.channels = chans
self.sources = srcs
self.channelDatabase = cd
# Get channels that are part of the currently active db
# chans = list(select(c for c in Channels.Channel if c.channel_db is None))
# srcs = list(select(c for c in Channels.MicrowaveSource if c.channel_db is None))
chans = list(self.channelDatabase.channels)
srcs = list(self.channelDatabase.sources)
cd = Channels.ChannelDatabase(label=name, time=datetime.datetime.now(), channels=chans, sources=srcs)
new_chans, new_srcs = copy_objs(chans, srcs, cd)

# self.channels = new_chans
# self.sources = new_srcs
# self.channelDatabase = None
commit()
self.channel_db_name = name
# self.channel_db_name = name

#Dictionary methods
def __getitem__(self, key):
Expand Down Expand Up @@ -222,11 +256,11 @@ def build_connectivity_graph(self):

class APS2(object):
def __init__(self, label, address=None, delay=0.0):
self.chan12 = Channels.PhysicalQuadratureChannel(label=f"{label}-12", instrument=label, translator="APS2Pattern")
self.m1 = Channels.PhysicalMarkerChannel(label=f"{label}-12m1", instrument=label, translator="APS2Pattern")
self.m2 = Channels.PhysicalMarkerChannel(label=f"{label}-12m2", instrument=label, translator="APS2Pattern")
self.m3 = Channels.PhysicalMarkerChannel(label=f"{label}-12m3", instrument=label, translator="APS2Pattern")
self.m4 = Channels.PhysicalMarkerChannel(label=f"{label}-12m4", instrument=label, translator="APS2Pattern")
self.chan12 = Channels.PhysicalQuadratureChannel(label=f"{label}-12", instrument=label, translator="APS2Pattern", channel_db=channelLib.channelDatabase)
self.m1 = Channels.PhysicalMarkerChannel(label=f"{label}-12m1", instrument=label, translator="APS2Pattern", channel_db=channelLib.channelDatabase)
self.m2 = Channels.PhysicalMarkerChannel(label=f"{label}-12m2", instrument=label, translator="APS2Pattern", channel_db=channelLib.channelDatabase)
self.m3 = Channels.PhysicalMarkerChannel(label=f"{label}-12m3", instrument=label, translator="APS2Pattern", channel_db=channelLib.channelDatabase)
self.m4 = Channels.PhysicalMarkerChannel(label=f"{label}-12m4", instrument=label, translator="APS2Pattern", channel_db=channelLib.channelDatabase)

self.trigger_interval = None
self.trigger_source = "External"
Expand All @@ -236,8 +270,8 @@ def __init__(self, label, address=None, delay=0.0):

class X6(object):
def __init__(self, label, address=None):
self.chan1 = Channels.ReceiverChannel(label=f"RecvChan-{label}-1")
self.chan2 = Channels.ReceiverChannel(label=f"RecvChan-{label}-2")
self.chan1 = Channels.ReceiverChannel(label=f"RecvChan-{label}-1", channel_db=channelLib.channelDatabase)
self.chan2 = Channels.ReceiverChannel(label=f"RecvChan-{label}-2", channel_db=channelLib.channelDatabase)

self.address = address
self.reference = "external"
Expand All @@ -246,21 +280,21 @@ def __init__(self, label, address=None):
self.acquire_mode = "digitizer"

def new_qubit(label):
return Channels.Qubit(label=label)
return Channels.Qubit(label=label, channel_db=channelLib.channelDatabase)

def new_source(label, source_type, address, power=-30.0):
return Channels.MicrowaveSource(label=label, source_type=source_type, address=address, power=power)
return Channels.MicrowaveSource(label=label, source_type=source_type, address=address, power=power, channel_db=channelLib.channelDatabase)

def set_control(qubit, awg, generator=None):
qubit.phys_chan = awg.chan12
if generator:
qubit.phys_chan.generator = generator

def set_measure(qubit, awg, dig, generator=None, dig_channel=1, trig_channel=1, gate=False, gate_channel=2, trigger_length=1e-7):
meas = Channels.Measurement(label=f"M-{qubit.label}")
meas = Channels.Measurement(label=f"M-{qubit.label}", channel_db=channelLib.channelDatabase)
meas.phys_chan = awg.chan12

meas.trig_chan = Channels.LogicalMarkerChannel(label=f"digTrig-{qubit.label}")
meas.trig_chan = Channels.LogicalMarkerChannel(label=f"digTrig-{qubit.label}", channel_db=channelLib.channelDatabase)
meas.trig_chan.phys_chan = getattr(awg, f"m{trig_channel}")
meas.trig_chan.pulse_params = {"length": trigger_length, "shape_fun": "constant"}
meas.receiver_chan = getattr(dig, f"chan{dig_channel}")
Expand All @@ -269,11 +303,11 @@ def set_measure(qubit, awg, dig, generator=None, dig_channel=1, trig_channel=1,
meas.phys_chan.generator = generator

if gate:
meas.gate_chan = Channels.LogicalMarkerChannel(label=f"M-{qubit.label}-gate")
meas.gate_chan = Channels.LogicalMarkerChannel(label=f"M-{qubit.label}-gate", channel_db=channelLib.channelDatabase)
meas.gate_chan.phys_chan = getattr(awg, f"m{gate_channel}")

def set_master(awg, trig_channel=2, pulse_length=1e-7):
st = Channels.LogicalMarkerChannel(label="slave_trig")
st = Channels.LogicalMarkerChannel(label="slave_trig", channel_db=channelLib.channelDatabase)
st.phys_chan = getattr(awg, f"m{trig_channel}")
st.pulse_params = {"length": pulse_length, "shape_fun": "constant"}
awg.master = True
Expand Down
20 changes: 10 additions & 10 deletions QGL/Channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def define_entities(db):

class ChannelDatabase(db.Entity):
label = Required(str)
channels = Set("Channel")
sources = Set("MicrowaveSource")
channels = Set("Channel", cascade_delete=True)
sources = Set("MicrowaveSource", cascade_delete=True)
time = Optional(datetime.datetime)

class MicrowaveSource(db.Entity):
Expand All @@ -57,14 +57,14 @@ class MicrowaveSource(db.Entity):
address = Optional(str)
power = Optional(float)
logical_channel = Optional("PhysicalChannel")
channel_db = Optional("ChannelDatabase")
channel_db = Required("ChannelDatabase")

class Channel(db.Entity):
'''
Every channel has a label and some printers.
'''
label = Required(str)
channel_db = Optional("ChannelDatabase")
channel_db = Required("ChannelDatabase")

def __repr__(self):
return str(self)
Expand All @@ -83,9 +83,9 @@ class PhysicalChannel(Channel):

# Required reverse connections
logical_channel = Optional("LogicalChannel")
quad_channel_I = Optional("PhysicalQuadratureChannel", reverse="I_channel")
quad_channel_Q = Optional("PhysicalQuadratureChannel", reverse="Q_channel")
marker_channel = Optional("PhysicalMarkerChannel")
# quad_channel_I = Optional("PhysicalQuadratureChannel", reverse="I_channel")
# quad_channel_Q = Optional("PhysicalQuadratureChannel", reverse="Q_channel")
# marker_channel = Optional("PhysicalMarkerChannel")

class LogicalChannel(Channel):
'''
Expand All @@ -108,14 +108,14 @@ class PhysicalMarkerChannel(PhysicalChannel):
'''
gate_buffer = Required(float, default=0.0)
gate_min_width = Required(float, default=0.0)
phys_channel = Optional(PhysicalChannel)
# phys_channel = Optional(PhysicalChannel)

class PhysicalQuadratureChannel(PhysicalChannel):
'''
Something used to implement a standard qubit channel with two analog channels and a microwave gating channel.
'''
I_channel = Optional(PhysicalChannel)
Q_channel = Optional(PhysicalChannel)
# I_channel = Optional(PhysicalChannel)
# Q_channel = Optional(PhysicalChannel)
amp_factor = Required(float, default=1.0)
phase_skew = Required(float, default=0.0)
# marker_channel = Optional(PhysicalMarkerChannel)
Expand Down

0 comments on commit bb01929

Please sign in to comment.