In [None]:
#| default_exp sampler.direct

In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
from qsample.sampler.base import Sampler, err_probs_tomatrix
import numpy as np

In [None]:
#| export
class DirectSampler(Sampler):
    """Sampler whose statistics come from each tree's `direct_rate` / `direct_variance`.

    Faulty locations for a circuit are drawn by `err_model.choose_p` from the
    per-group probabilities (see `optimize`).
    """

    def __init__(self, protocol, simulator, err_model=None, err_probs=None):
        """Initialise the base sampler and store `err_probs` in matrix form.

        Args:
            protocol: protocol to sample; forwarded to `Sampler`.
            simulator: simulator class; forwarded to `Sampler`.
            err_model: error model; forwarded to `Sampler`. Its `groups` define
                the layout of the error-probability matrix.
            err_probs: error probabilities keyed by group name (e.g. ``{"q": 0.1}``);
                converted with `err_probs_tomatrix`.
        """
        super().__init__(protocol, simulator, err_probs=err_probs, err_model=err_model)
        # Read err_model back from self rather than using the argument: the base
        # class may substitute a default when None is passed.
        # NOTE(review): confirm against Sampler.__init__.
        self.err_probs = err_probs_tomatrix(err_probs, self.err_model.groups)

    def stats(self, tree_idx=None):
        """Return the direct rate estimate(s) and their standard deviation(s).

        Args:
            tree_idx: key into `self.trees` (presumably the physical error
                probability), or None to report every tree.

        Returns:
            (p_L, std_L). For a single tree, these are the tree's `direct_rate`
            and the square root of its `direct_variance`. Otherwise, p_L is a list
            and std_L a numpy array, both in `self.trees` iteration order.
        """
        # Compare against None explicitly: a falsy key such as 0 or 0.0 is a valid
        # tree key and must not fall through to the "all trees" branch.
        if tree_idx is not None:
            tree = self.trees[tree_idx]
            return tree.direct_rate, np.sqrt(tree.direct_variance)
        p_L = [tree.direct_rate for tree in self.trees.values()]
        v_L = [tree.direct_variance for tree in self.trees.values()]
        return p_L, np.sqrt(v_L)

    def optimize(self, tree_node, circuit, grp_probs):
        """Choose faulty locations for `circuit` from the per-group probabilities.

        Args:
            tree_node: current tree node (not used by this sampler).
            circuit: circuit whose `id` selects its location groups in
                `self.protocol_groups`.
            grp_probs: per-group probabilities passed to `err_model.choose_p`.

        Returns:
            dict with two keys. 'flocs' holds the chosen locations per group, as
            returned by `choose_p`. 'subset' is the tuple of how many locations
            were chosen in each group, in `flocs` iteration order.
        """
        locgrps = self.protocol_groups[circuit.id]
        flocs = self.err_model.choose_p(locgrps, grp_probs)
        subset = tuple(len(locs) for locs in flocs.values())
        return {'subset': subset, 'flocs': flocs}

In [None]:
from qsample.examples import ghz3
from qsample.noise import E1, E2
from qsample.sim.stabilizer import StabilizerSimulator as CHP
from qsample.callbacks import *
from qsample.sampler.auto import SamplerSwitch

# Demo: sample the 3-qubit GHZ example protocol with the E1 error model at a
# single physical error probability, with a SamplerSwitch callback attached.
N_SHOTS = 50000        # argument passed to sam.run
SWITCH_PERIOD = 1000   # `period` passed to the SamplerSwitch callback

err_model = E1
err_probs = {"q": 0.1}  # probability for the "q" error group of the model

sam = DirectSampler(ghz3, CHP, err_probs=err_probs, err_model=err_model)

sam.run(N_SHOTS, callbacks=[SamplerSwitch(period=SWITCH_PERIOD)])

p_phy=1.00E-01:   0%|          | 0/50000 [00:00<?, ?it/s]

2.6397567675619875e-05 8.83635849516693e-06 0.003724008139945334 False
1.704484798547869e-05 6.771105452254702e-06 0.0035897920171034103 False
9.70677931938663e-06 5.244447473287157e-06 0.0012155498688917188 False
8.779527824996818e-06 4.215707665621288e-06 0.0012935985534637817 False
6.838783419605407e-06 3.641153511963312e-06 0.0013600003290819007 False
6.01482857837642e-06 3.1788255078853225e-06 0.0016703433178606453 False
4.96459192865616e-06 2.8854963395405257e-06 0.0017299028138866301 False
4.336386144810397e-06 2.640885832570148e-06 0.0013996299143217428 False
3.7096392492719742e-06 2.444173277082957e-06 0.001383935580533291 False
3.3097148442467728e-06 2.2716769164659404e-06 0.0011584527525297927 False
3.1119764575596894e-06 2.094197535543215e-06 0.001081938496391821 False
2.80048445225718e-06 1.95371780686471e-06 0.001077332756805327 False
2.5219774319700646e-06 1.7437833609518985e-06 0.0009547733469863484 False
2.5033122337583186e-06 1.6118162803700508e-06 0.00090274165707526

In [None]:
from ipywidgets import interact
import ipywidgets as widgets

def get_tree(i):
    """Print the key of the i-th sampled tree and return that tree's drawing."""
    p_phy, tree = list(sam.trees.items())[i]
    print("p_phy=", p_phy)
    return tree.draw()

tree_slider = widgets.IntSlider(min=0, max=len(sam.trees) - 1, step=1, value=0)
interact(get_tree, i=tree_slider);

interactive(children=(IntSlider(value=0, description='i', max=0), Output()), _dom_classes=('widget-interact',)…

In [None]:
#| slow
# Persist the sampler to './test' so the next cell can reload it with `Sampler.load`.
sam.save('./test')

In [None]:
#| slow
# Reload the sampler written by the previous cell through the base-class loader
# and run it for 100 iterations.
# NOTE(review): if `load` unpickles the file, only load trusted paths.
sam2 = Sampler.load('./test')
sam2.run(100)

p_phy=1.00E-01:   0%|          | 0/100 [00:00<?, ?it/s]