In [71]:
from qiskit.circuit.random import random_circuit 
from qiskit.primitives import Sampler 
from qiskit import QuantumCircuit
from qiskit.quantum_info import Statevector, Operator 
import numpy as np

from qiskit.circuit.library.n_local.real_amplitudes import RealAmplitudes
from qiskit.circuit.library.n_local.efficient_su2 import EfficientSU2

import treelib 

In [35]:
def get_samples(circ):
    sampler = Sampler()
    circ_m = circ.measure_all(inplace=False)
    res = sampler.run(
        circ_m
    ).result()
    res = res.quasi_dists[0]
    proba = np.zeros(2**circ.num_qubits)
    for k,v in res.items():
        proba[k] = v 
    return proba

def get_amplitude(circ):
    return Statevector(circ).data.real

def add_h(circ, idx):
    qc = QuantumCircuit(circ.num_qubits)
    qc.append(circ, range(circ.num_qubits))
    qc.h(idx)
    return qc

In [36]:
nqbits = 2
circ = RealAmplitudes(nqbits, entanglement='full', reps=3, insert_barriers=False)
parameters = 4*np.pi*np.random.rand(circ.num_parameters)
circ = circ.bind_parameters(parameters)

In [63]:
sample = get_samples(circ)
amp = get_amplitude(circ)


In [64]:
def get_all_h_circ(circ):
    circuits = []
    
    for idx in range(circ.num_qubits):
        circuits.append(add_h(circ, idx)) 
    return circuits

In [65]:
hcircs = get_all_h_circ(circ)

In [66]:
h_samples = [get_samples(c) for c in hcircs]
h_amps = [get_amplitude(c) for c in hcircs]

In [67]:
amp = -amp 
h_amps = [-a for a in h_amps]

In [68]:
amp, h_amps

(array([-0.29522751, -0.16433723,  0.92260287,  0.1861127 ]),
 [array([-0.32496134, -0.09255341,  0.7839803 ,  0.52077719]),
  array([ 0.44362137,  0.01539758, -0.86113612, -0.24780552])])

In [69]:
1/np.sqrt(2)*(amp[0]+amp[1])

-0.3249613439302672

In [49]:
sample

array([0.08715928, 0.02700672, 0.85119605, 0.03463794])

In [47]:
h_samples

[array([0.10559988, 0.00856613, 0.61462511, 0.27120888]),
 array([1.96799920e-01, 2.37085621e-04, 7.41555417e-01, 6.14075774e-02])]

In [70]:
[np.sign(2*hs - sample) for hs in h_samples]

[array([ 1., -1.,  1.,  1.]), array([ 1., -1.,  1.,  1.])]

In [43]:
amp

array([ 0.29522751,  0.16433723, -0.92260287, -0.1861127 ])

In [None]:
def init_weight(samples, h0_samples):
    signs = np.sign(2*h0_samples - samples)
    N = len(signs)
    weights = np.zeros(N)
    for i in range(int(N/2)):
        b = 2*(i)+1
        weights[b] = signs[b]
    return weights  

In [72]:
def init_tree(N):
    trees = []
    for i in range(int(N/2)):
        tree = treelib.Tree()
        a,b = 2*(i), 2*(i)+1
        tree.create_node(a,a, data = 1)
        tree.create_node(b,b,parent=a, data=2*np.random.randint(1)-1)
        trees.append(tree)
    return trees 
N = 2**4
trees = init_tree(N)

In [73]:
trees[0]

<treelib.tree.Tree at 0x7fdead706610>

In [None]:
def link_weights(samples, h_samples):
    N = len(samples)
    ntree = N/2
    for iter in range(1, int(np.log2(N))):
        for iroot in range(0,int(ntree),2**iter):
            new_root = trees[iroot].root
            print(iroot, new_root)
            new_leaf = iroot + 2**(iter -1)
            trees[iroot].paste(new_root, trees[new_leaf])

In [74]:
def link_trees(trees):
    ntree = len(trees)
    for iter in range(1, int(np.log2(N))):
        for iroot in range(0,int(ntree),2**iter):
            new_root = trees[iroot].root
            print(iroot, new_root)
            new_leaf = iroot + 2**(iter -1)
            trees[iroot].paste(new_root, trees[new_leaf])
    
link_trees(trees)

0 0
2 4
4 8
6 12
0 0
4 8
0 0


In [78]:
trees[3].show()

6
└── 7



In [None]:
tree = trees[0]

In [None]:
tree.paths_to_leaves()

[[0, 1],
 [0, 2, 3],
 [0, 4, 5],
 [0, 4, 6, 7],
 [0, 8, 9],
 [0, 8, 10, 11],
 [0, 8, 12, 13],
 [0, 8, 12, 14, 15]]

In [None]:
for node in tree.all_nodes():
    print(node.data)

1
-1
1
-1
1
-1
1
-1
1
-1
1
-1
1
-1
1
-1


In [None]:
tree.get_node(0).data = -1

In [None]:
tree.get_node(0).data

-1