# Bootstrap Fitting

In [1]:
import sys
sys.path.insert(0, '../../../src/')

import numpy as np
import tensorflow as tf
import random
from tqdm.notebook import tqdm

from quantum_tools import  resample
from kraus_channels import KrausMap
from synthetic_data import generate_map_data, generate_spam_data, generate_spam_benchmark
from optimization import ModelQuantumMap, ModelSPAM, Logger, model_saver
from loss_functions import ProbabilityMSE, ProbabilityRValue, channel_fidelity_loss
from spam import SPAM, InitialState, POVMwQR as POVM, CorruptionMatrix
from utils import loader

#np.set_printoptions(threshold=sys.maxsize)
np.set_printoptions(precision=1)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)





## Circuits

In [2]:
def fit_spam(inputs, 
             targets,
             num_iter = 3000,
             verbose = False):
    d = targets.shape[1]
    spam_model = SPAM(init = InitialState(d),
                    povm = CorruptionMatrix(d),
                    )

    spam_opt = ModelSPAM(spam_model, tf.keras.optimizers.Adam(learning_rate=0.01))
        
    spam_opt.pretrain(100, verbose=False)

    spam_opt.train(inputs = inputs,
                    targets = targets,
                    num_iter = num_iter,
                    verbose = verbose,
                )
    
    return spam_model
    

def fit_model(inputs, 
              targets,
              channel, 
              spam_model,
              num_iter = 3000,
              verbose=False):
    d = targets.shape[1]
    model = ModelQuantumMap(channel = KrausMap(d = d, 
                                        rank = d**2,
                                        spam = spam_model,
                                        ),
                    loss_function = ProbabilityMSE(),
                    optimizer = tf.optimizers.Adam(learning_rate=0.01),
                    logger = Logger(loss_function_list = [ProbabilityRValue(), channel_fidelity_loss], sample_freq=100),
                )
    model.train(inputs = inputs, 
                targets = targets,
                inputs_val = [inputs, None],
                targets_val = [targets, [channel]],
                N=500,
                num_iter = num_iter,
                verbose = verbose,)
    
    return model

## Free Fermion

### Generate Benchmarks

In [3]:
n = 4
d = 2**n

[channel_FF_list, spectra_FF_list, csr_FF_list] =  loader("data/FF_synthetic_benchmark.pkl")

### Generate Synthetic Data and Fit

In [4]:
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)
bs_samples = 10

for i, channel in tqdm(list(enumerate(channel_FF_list))):
    if i > 8:
        model_list = []

        spam_target = generate_spam_benchmark(n=4, c1=0.95, c2=0.95, type="CM")
        inputs_spam, targets_spam = generate_spam_data(spam_target, shots=12000)

        inputs_map, targets_map = generate_map_data(channel, 
                                                    spam_target = spam_target,
                                                    N_map=5000-6**n, 
                                                    shots = 12000)


        for bs in tqdm(range(bs_samples)):
            targets_spam_bs = resample(targets_spam, 12000)
            targets_map_bs = resample(targets_map, 12000)
            
            spam_model = fit_spam(inputs_spam,
                                targets_spam_bs,
                                num_iter = 3000,
                                verbose = False)
                

            model = fit_model(inputs_map,
                            targets_map_bs,
                            channel,
                            spam_model,
                            num_iter = 3000,
                            verbose = False
                            )

            model_list.append(model)

        model_saver(model_list, f"models/FF_bootstrap_{i}.model")


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

[0.9972453382065255, -0.9895525462890887]
[0.9972559505373514, -0.9900624070568305]
[0.9972929323795786, -0.989088689883409]
[0.997265329307438, -0.9896801060403454]
[0.9972989942058017, -0.9901588922030107]
[0.9972584772284604, -0.9899666552745838]
[0.9972883448184904, -0.990237312355385]
[0.9972641934610764, -0.9897628269669766]
[0.9972509803841384, -0.9899946625367713]
[0.9972807597990212, -0.9894718786980997]


## Chaotic Haar Random

### Generate Benchmarks

In [7]:
n = 4
d = 2**n

[channel_chaotic_list, spectra_chaotic_list, csr_chaotic_list] =  loader("data/chaotic_synthetic_benchmark.pkl")

### Generate Synthetic Data and Fit

In [8]:
np.random.seed(42)
random.seed(42)
tf.random.set_seed(42)
bs_samples = 10

for i, channel in tqdm(list(enumerate(channel_chaotic_list))):
    model_list = []

    spam_target = generate_spam_benchmark(n=4, c1=0.95, c2=0.95, type="CM")
    inputs_spam, targets_spam = generate_spam_data(spam_target, shots=12000)

    inputs_map, targets_map = generate_map_data(channel, 
                                                spam_target = spam_target,
                                                N_map=5000-6**n, 
                                                shots = 12000)


    for bs in tqdm(range(bs_samples)):
        targets_spam_bs = resample(targets_spam, 12000)
        targets_map_bs = resample(targets_map, 12000)
        
        spam_model = fit_spam(inputs_spam,
                            targets_spam_bs,
                            num_iter = 3000,
                            verbose = False)

        model = fit_model(inputs_map,
                        targets_map_bs,
                        channel,
                        spam_model,
                        num_iter = 3000,
                        verbose = False
                        )

        model_list.append(model)

    model_saver(model_list, f"models/chaotic_bootstrap_{i}.model")


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

[0.9936199917776504, -0.9860866101741259]
[0.9936288340466856, -0.9868338079060498]
[0.9936277657343732, -0.9853946977711365]
[0.9936297590560923, -0.9862237097412877]
[0.9936903483682091, -0.9862750106809443]
[0.9935552532151936, -0.9861729919068993]
[0.9937024686093543, -0.9860160216441454]
[0.9936009906195824, -0.9855669281526243]
[0.9935979955359504, -0.9867304268818546]
[0.9936908306854039, -0.9855037512157829]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9934327294403668, -0.984453758267463]
[0.9935015761344449, -0.9852636483515333]
[0.9935348297273354, -0.9853564761300662]
[0.9934765720660532, -0.9853481289556578]
[0.99346412284834, -0.9855078086425194]
[0.9935142445225605, -0.9859027842116544]
[0.9934738932678419, -0.9855495192320419]
[0.99355224855667, -0.9860409101889029]
[0.9934576990654285, -0.9852399616981451]
[0.9935108573121362, -0.9853835143841787]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935846155438701, -0.9865358722906676]
[0.9935682085083164, -0.9859131302776905]
[0.9935550332436763, -0.986224970937649]
[0.9935803853264262, -0.9865408321419462]
[0.9935478893237463, -0.985774474723154]
[0.9934994188089303, -0.9865531959208677]
[0.9935381101769087, -0.9852039946262473]
[0.9935433999240536, -0.9869270184006546]
[0.9935483639685502, -0.9862782681109832]
[0.9935097536247324, -0.9857519818843182]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935928955637826, -0.9854332367151167]
[0.9936421463848943, -0.9851942562718363]
[0.9935939699218971, -0.9854232015315184]
[0.9935079931234553, -0.985764238830421]
[0.9935305405300213, -0.9857628684270895]
[0.9935889737995248, -0.9851397945006134]
[0.9935787761857082, -0.9853063286095012]
[0.9934960319888891, -0.9848718755388286]
[0.9936480295388842, -0.9845687761701798]
[0.9935824291123245, -0.9858652931429116]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.993480160614772, -0.9851900297042487]
[0.99342616187697, -0.9853016071965265]
[0.9935215711961177, -0.9851566428727492]
[0.993564822925072, -0.9852523460731192]
[0.9935359809548557, -0.9859669003362739]
[0.9935911155391844, -0.9857902180495536]
[0.9934689020930434, -0.9854040548852452]
[0.9935470834673676, -0.9852806350415095]
[0.9935637292933099, -0.9848228380618124]
[0.9935721865486432, -0.9848902244340478]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935899958656051, -0.9849264151925575]
[0.9936410391631829, -0.9853392896133066]
[0.9935507183191014, -0.9849747326131457]
[0.9936999889248057, -0.9849468312927092]
[0.9936461954801478, -0.9849194100011502]
[0.9937012428885673, -0.9849157296306805]
[0.9935990904916631, -0.9844323883824944]
[0.9935609567564547, -0.9851278361579032]
[0.9936299096104658, -0.9853903399451891]
[0.9937119316573292, -0.9853268999339804]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9936253004958568, -0.9846929081644596]
[0.9936130353969075, -0.9858073892689863]
[0.9935510220846632, -0.9857882690352331]
[0.9935934296048643, -0.9855900016791115]
[0.9936614659632351, -0.985423205382599]
[0.9935842741144435, -0.9858742455024997]
[0.9936000518768823, -0.985037904725082]
[0.9935809778812461, -0.9852699834017721]
[0.9935908143634103, -0.9857492379798888]
[0.9935709765747106, -0.9850858093895303]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935979522406784, -0.9858116302708468]
[0.9935782542578038, -0.9859559781112794]
[0.9936309008490084, -0.9853496345167394]
[0.993573115899992, -0.9853467875723465]
[0.9936144145159345, -0.985340904230346]
[0.9935784012573334, -0.98525746268026]
[0.993591967591361, -0.9856664233030485]
[0.9935804545345016, -0.9855343269467817]
[0.9935960052802247, -0.984981358284573]
[0.9936282740910379, -0.9861238310652074]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935452685940963, -0.984959743370851]
[0.9935294647117412, -0.9859135608019839]
[0.9936405944821257, -0.985575117281238]
[0.9935796292942948, -0.9857953916789826]
[0.99350731062461, -0.9850063415469745]
[0.9935555725511646, -0.9855280626265062]
[0.9935408837171561, -0.9856289707836488]
[0.9935675149587735, -0.985528121106157]
[0.9935114884341419, -0.9852445721215055]
[0.9935183009235826, -0.9860177784344236]


  0%|          | 0/10 [00:00<?, ?it/s]

[0.9935283921419641, -0.9860645864214027]
[0.9935050634365632, -0.9858294954562774]
[0.9934909127451294, -0.9855355583168887]
[0.9935814526978526, -0.985433092655006]
[0.9934043025375839, -0.985346455034373]
[0.9935079253245271, -0.9859516850733513]
[0.9934882016976522, -0.9857821501222743]
[0.9933701353149337, -0.9853485499335699]
[0.9934977884254286, -0.9850966339681981]
[0.9935475406503004, -0.9853545284206426]
