In [None]:
from qiskit import QuantumCircuit, transpile, assemble
from qiskit.quantum_info import Statevector, DensityMatrix, Pauli, random_clifford, random_statevector, partial_trace
from qiskit_aer import AerSimulator, Aer
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm, logm
import cvxpy as cp
import pickle
import itertools, functools, collections, warnings, copy, winsound, os
from IPython.display import clear_output

def tensor_prod(*tensors):
    """Return the Kronecker product of all arguments, in the order given.

    Args:
        *tensors: one or more array-likes accepted by ``np.kron``.

    Returns:
        ``tensors[0] (x) tensors[1] (x) ...`` as a numpy array.  A single
        argument is returned converted with ``np.asarray``.

    Raises:
        ValueError: if called without arguments.
    """
    if not tensors:
        raise ValueError('tensor_prod requires at least one tensor')
    # Fold from the right so the arithmetic matches kron(a, kron(b, c)),
    # while also supporting the single-factor case (e.g. one-qubit systems).
    result = np.asarray(tensors[-1])
    for factor in reversed(tensors[:-1]):
        result = np.kron(factor, result)
    return result
    
def hermitian(matrix):
    """Return True when ``matrix`` equals its conjugate transpose (within ``np.allclose`` tolerances)."""
    adjoint = matrix.conj().transpose()
    return np.allclose(matrix, adjoint)

def trace_one(matrix):
    """Return True when the trace of ``matrix`` is close to 1 (``np.isclose`` defaults)."""
    total = np.trace(matrix)
    return np.isclose(total, 1)

def positive_semi_definite(matrix, tol=1e-5):
    """Return True when every eigenvalue from ``np.linalg.eigvalsh`` is at least ``-tol``."""
    spectrum = np.linalg.eigvalsh(matrix)
    shifted = spectrum + tol
    return np.all(shifted >= 0)

def is_legal(matrix):
    """Check that ``matrix`` is Hermitian, has unit trace and is positive semidefinite.

    The three predicates are evaluated in that order and evaluation stops at
    the first one that fails; the value of the last predicate evaluated is
    returned.
    """
    verdict = hermitian(matrix)
    if verdict:
        verdict = trace_one(matrix)
    if verdict:
        verdict = positive_semi_definite(matrix)
    return verdict

def check_legal(matrix, print_msg=True):
    """Run all three density-matrix checks and report every failure.

    Unlike ``is_legal`` this does not short-circuit: each predicate is
    evaluated so the message can list all violated properties.

    Args:
        matrix: candidate density matrix.
        print_msg: when true, the message is also printed.

    Returns:
        ``(legal, msg)`` where ``legal`` is a bool and ``msg`` the summary text.
    """
    named_checks = (
        (hermitian, 'not hermitian'),
        (trace_one, 'trace not equal to one'),
        (positive_semi_definite, 'not positive semidefinite'),
    )
    errors = [label for check, label in named_checks if not check(matrix)]
    legal = (len(errors) == 0)
    if legal:
        msg = 'input is a legal density matrix'
    else:
        msg = 'input is not legal: ' + '; '.join(errors)
    if print_msg:
        print(msg)
    return legal, msg
        
        
def generate_prob_lst(num_states):
    """Draw ``num_states`` uniform random weights and normalise them to sum to one.

    Returns:
        1-D numpy array of length ``num_states``.
    """
    draws = []
    for _ in range(num_states):
        draws.append(np.random.random())
    weights = np.array(draws)
    weights /= np.sum(weights)
    return weights

def get_rank(dm, tol=1e-10):
    """Count the eigenvalues of ``dm`` (via ``np.linalg.eigvalsh``) that exceed ``tol``."""
    spectrum = np.linalg.eigvalsh(dm)
    above = spectrum > tol
    return int(np.sum(above))

def get_fidelity(dm1, dm2, tol=1e-5):
    """Compute ``(Tr sqrt( sqrt(dm1) @ dm2 @ sqrt(dm1) ))**2`` and return its real part.

    Inputs that fail ``is_legal`` only trigger a ``UserWarning``; the
    computation is still attempted.

    Args:
        dm1, dm2: square matrices of equal dimension.
        tol: threshold on the imaginary part above which a warning is issued.

    Returns:
        The real part of the computed fidelity.

    Raises:
        ValueError: if the matrix expression cannot be evaluated.  Previously
            this case only printed a message and then failed with an
            unrelated ``UnboundLocalError``.
    """
    if not is_legal(dm1):
        warnings.warn("input dm1 is not a legal density matrix", UserWarning)
    if not is_legal(dm2):
        warnings.warn("input dm2 is not a legal density matrix", UserWarning)
    try:
        # sqrtm(dm1) appears twice in the formula; compute it once.
        root_dm1 = sqrtm(dm1)
        fidelity = (np.trace(sqrtm(root_dm1 @ dm2 @ root_dm1))) ** 2
    except ValueError as err:
        raise ValueError('fidelity cannot be computed for the given inputs') from err
    imag_part = np.abs(np.imag(fidelity))
    if imag_part > tol:
        warnings.warn(f"the fidelity has an imaginary part larger than tol: {imag_part}")
    return fidelity.real

def get_purity(dm, tol=1e-5):
    """Return the real part of ``Tr(dm @ dm)``.

    An input that fails ``is_legal`` only triggers a ``UserWarning``.

    Args:
        dm: square matrix.
        tol: threshold on the imaginary part above which a warning is issued.

    Raises:
        ValueError: if ``dm @ dm`` or its trace cannot be evaluated.
            Previously this case only printed a message and then failed with
            an unrelated ``UnboundLocalError``.
    """
    if not is_legal(dm):
        warnings.warn("input is not a legal density matrix", UserWarning)
    try:
        purity = np.trace(dm @ dm)
    except ValueError as err:
        raise ValueError('purity cannot be computed for the given inputs') from err
    imag_part = np.abs(np.imag(purity))
    if imag_part > tol:
        warnings.warn(f"the purity has an imaginary part larger than tol: {imag_part}")
    return purity.real

def generate_dm(num_qubits, num_states, state_lst=None, prob_lst=None, prime_prob=None):
    """Build a mixture ``sum_i p_i |psi_i><psi_i|`` of ``num_states`` states.

    Args:
        num_qubits: number of qubits; random states have dimension ``2**num_qubits``.
        num_states: number of terms in the mixture.
        state_lst: optional states; random state vectors are drawn when omitted.
        prob_lst: optional mixing weights; drawn with ``generate_prob_lst`` when omitted.
        prime_prob: optional weight of the first state; the remaining weights
            are random and rescaled to fill ``1 - prime_prob``.  Cannot be
            combined with ``prob_lst``.

    Returns:
        The weighted sum of the ``DensityMatrix(...).data`` arrays.
    """
    both_given = (prob_lst is not None) and (prime_prob is not None)
    assert not both_given, 'cannot set prob_lst and prime_prob together'
    if state_lst is None:
        dimension = 2**num_qubits
        state_lst = [random_statevector(dimension) for _ in range(num_states)]
    if prime_prob is not None:
        tail_weights = generate_prob_lst(num_states - 1) * (1 - prime_prob)
        prob_lst = np.array([prime_prob] + tail_weights.tolist())
    elif prob_lst is None:
        prob_lst = generate_prob_lst(num_states)
    weighted_terms = []
    for idx in range(num_states):
        weighted_terms.append(DensityMatrix(state_lst[idx]).data * prob_lst[idx])
    return sum(weighted_terms)

def single_sample(prob_list):
    """Draw one index from a discrete distribution using an alias table.

    The table is rebuilt on every call, then one column is chosen uniformly
    and either that column or its alias is returned.

    Args:
        prob_list: probabilities that must sum to one (checked with ``np.isclose``).

    Returns:
        The sampled index.
    """
    assert np.isclose(sum(prob_list), 1), "probability does not sum up to 1"

    # --- build the alias table -------------------------------------------
    size = len(prob_list)
    scaled = [p * size for p in prob_list]
    threshold = [0] * size
    fallback = [0] * size
    under = [k for k, weight in enumerate(scaled) if weight < 1.0]
    over = [k for k, weight in enumerate(scaled) if not (weight < 1.0)]
    while under and over:
        lo = under.pop()
        hi = over.pop()
        threshold[lo] = scaled[lo]
        fallback[lo] = hi
        # The donor column gives away the mass needed to top up ``lo``.
        scaled[hi] = (scaled[hi] + scaled[lo]) - 1.0
        if scaled[hi] < 1.0:
            under.append(hi)
        else:
            over.append(hi)
    # Leftover columns (from either stack) keep their own index with certainty.
    for k in over:
        threshold[k] = 1.0
    for k in under:
        threshold[k] = 1.0

    # --- draw one sample --------------------------------------------------
    column = np.random.randint(0, size)
    coin = np.random.random()
    if coin < threshold[column]:
        return column
    return fallback[column]



def sample_from_dict(d, n_samples):
    """Draw ``n_samples`` keys of ``d`` with the dict values used as probabilities.

    An alias table is built once and reused for every draw.  The values are
    not checked for normalisation here.

    Args:
        d: mapping ``outcome -> probability``.
        n_samples: number of independent draws.

    Returns:
        List of sampled keys, in draw order.
    """
    keys = list(d.keys())
    probs = list(d.values())

    # --- alias table ------------------------------------------------------
    size = len(probs)
    alias = np.zeros(size, dtype=int)
    prob = np.zeros(size, dtype=np.float64)
    scaled = np.array(probs) * size
    under, over = [], []
    for idx, weight in enumerate(scaled):
        (under if weight < 1.0 else over).append(idx)
    while under and over:
        lo = under.pop()
        hi = over.pop()
        prob[lo] = scaled[lo]
        alias[lo] = hi
        scaled[hi] = scaled[hi] + scaled[lo] - 1.0
        (under if scaled[hi] < 1.0 else over).append(hi)
    # Whatever is left on either stack is accepted with certainty.
    for idx in over:
        prob[idx] = 1.0
    for idx in under:
        prob[idx] = 1.0

    # --- draws ------------------------------------------------------------
    samples = []
    for _ in range(n_samples):
        column = np.random.randint(size)
        if np.random.rand() < prob[column]:
            chosen = column
        else:
            chosen = alias[column]
        samples.append(keys[chosen])
    return samples

def expand_to_tensor_product(array):
    """Split a computational-basis vector into its single-qubit basis vectors.

    The position of the largest entry of ``array`` is written as a binary
    string of width ``log2(len(array))``; each bit becomes ``[1, 0]`` for
    ``'0'`` or ``[0, 1]`` otherwise.

    Returns:
        List of length-2 numpy arrays, most significant bit first.
    """
    hot_index = np.argmax(array)
    width = int(np.log2(len(array)))
    bits = format(hot_index, f'0{width}b')
    return [np.array([1, 0]) if bit == '0' else np.array([0, 1]) for bit in bits]

def int_to_bin_list(n, length):
    """Return a float vector of ``length`` zeros with a single 1 at position ``n``.

    Despite the name this is a one-hot encoding, not a binary expansion.
    """
    one_hot = np.zeros(length)
    one_hot[n] = 1
    return one_hot

def split_and_calculate_mean(values, group_size):
    """Average consecutive chunks of ``values``.

    ``values`` is cut into slices of ``group_size`` items (the last slice may
    be shorter) and each slice is summed along axis 0 and divided by its own
    length.

    Returns:
        List with one mean per slice.
    """
    means = []
    for start in range(0, len(values), group_size):
        chunk = values[start:start + group_size]
        means.append(np.sum(chunk, axis=0) / len(chunk))
    return means

def generate_random_01_strings(num_strings, length):
    """Generate ``num_strings`` distinct random bit strings, never the all-zero one.

    Candidates are drawn character by character with ``np.random.choice`` and
    rejected when they are all zeros or already accepted.

    Args:
        num_strings: how many strings to return; must be below ``2 ** length``.
        length: number of characters per string.

    Returns:
        List of strings in the order they were accepted.
    """
    assert num_strings < 2 ** length, 'too much strings to generate'
    alphabet = ['0', '1']
    all_zero = '0' * length
    accepted = []
    while len(accepted) < num_strings:
        candidate = ''.join(np.random.choice(alphabet) for _ in range(length))
        if candidate == all_zero or candidate in accepted:
            continue
        accepted.append(candidate)
    return accepted

def generate_uv_Pauli_matrix(u_vec, v_vec):
    """Translate paired bit sequences into a Pauli label string.

    Each pair ``(u, v)`` maps to one character:
    ``00 -> I``, ``01 -> Z``, ``10 -> X``, ``11 -> Y``.

    Returns:
        The Pauli label as a string (not a matrix, despite the name).

    Raises:
        ValueError: if a pair is not made of 0/1 entries.
    """
    symbols = {'00': 'I', '01': 'Z', '10': 'X', '11': 'Y'}
    letters = []
    for u_bit, v_bit in zip(u_vec, v_vec):
        key = str(u_bit) + str(v_bit)
        if key not in symbols:
            raise ValueError('u or v list contains elements neither 0 or 1')
        letters.append(symbols[key])
    return ''.join(letters)

def generate_all_binary(length):
    """List every bit string of the given length except the all-zero one, in lexicographic order."""
    result = []
    for bits in itertools.product('01', repeat=length):
        if '1' in bits:
            result.append(''.join(bits))
    return result

def generate_Pauli_expectations(dm, obsv):
    """Return the real part of ``Tr(dm @ P)`` where ``P`` is the matrix of the Pauli label ``obsv``."""
    operator = Pauli(obsv).to_matrix()
    return np.trace(dm @ operator).real

def get_trace_norm(dm):
    """Return the sum of the singular values of ``dm``."""
    singular_values = np.linalg.svd(dm, compute_uv=False)
    return np.sum(singular_values)

In [None]:
# Column vectors of the two single-qubit computational basis states.
single_states = {
    '0': np.array([[1], [0]]),
    '1': np.array([[0], [1]]),
}


class State():
    """Computational-basis state described by a bit string such as ``'010'``."""

    def __init__(self, strings):
        """Store the column vector for ``strings``.

        A one-character label reuses the shared array from ``single_states``
        (no copy is made); longer labels are combined with ``tensor_prod``.
        """
        if len(strings) != 1:
            factors = [single_states[bit] for bit in strings]
            self.state = tensor_prod(*factors)
            return
        self.state = single_states[strings]

    def to_vector(self):
        """Return the stored column vector."""
        return self.state

def generate_random_Pauli_strings(num_strings, num_qubits, pattern='balanced'):
    """Pick ``num_strings`` distinct Pauli labels on ``num_qubits`` qubits.

    Patterns:
        * ``'only_XYZ'``: uniform choice without replacement among the
          ``3**num_qubits`` labels that contain no ``'I'``.  Returned as the
          numpy array produced by ``np.random.choice``.
        * ``'balanced'``: rejection sampling of uniformly random labels,
          skipping the all-identity label and repeats.
        * ``'uv_pair'``: labels built from bit-string pairs ``(u, v)`` with
          ``00->I, 01->Z, 10->X, 11->Y``.  Whole blocks share a random
          non-zero ``u`` and run over every ``v``; the remainder uses
          ``u = 0...0`` with random non-zero ``v``.
        * ``'pro_I'`` / ``'pro_XYZ'``: non-identity labels grouped by their
          number of ``'I'`` characters; groups are consumed starting from the
          most (``pro_I``) or fewest (``pro_XYZ``) identities, and the last,
          partially used group is sampled at random.

    Fixes relative to the previous implementation:
        * ``pro_I`` / ``pro_XYZ`` sampled the partial group *with*
          replacement and could therefore return duplicate labels; sampling
          is now without replacement, like every other pattern.
        * The full label set for ``pro_I`` / ``pro_XYZ`` is enumerated
          directly instead of being collected by rejection sampling, so
          fully consumed groups now come out in enumeration order.

    Args:
        num_strings: number of labels requested.
        num_qubits: length of each label.
        pattern: one of the names listed above.

    Returns:
        List of label strings (numpy array of strings for ``'only_XYZ'``).
    """
    assert pattern in ['balanced', 'pro_I', 'pro_XYZ', 'uv_pair', 'only_XYZ'], 'please choose pattern from: balanced, pro_I, pro_XYZ, uv_pair, only_XYZ'

    if pattern == 'only_XYZ':
        assert 0 < num_strings <= 3 ** num_qubits, 'too much or too few strings to generate'
        candidates = [''.join(obsv) for obsv in itertools.product('XYZ', repeat=num_qubits)]
        return np.random.choice(candidates, num_strings, replace=False)

    # Every remaining pattern draws from the 4**n - 1 non-identity labels.
    assert 0 < num_strings <= 4 ** num_qubits - 1, 'too much or too few strings to generate'
    generated_strings = []

    if pattern == 'balanced':
        characters = ['X', 'Y', 'Z', 'I']
        identity = 'I' * num_qubits
        seen = set()  # O(1) duplicate check; the list keeps draw order
        while len(generated_strings) < num_strings:
            random_string = ''.join(np.random.choice(characters) for _ in range(num_qubits))
            if random_string != identity and random_string not in seen:
                seen.add(random_string)
                generated_strings.append(random_string)
        return generated_strings

    if pattern == 'uv_pair':
        uv_map = {'00':'I', '01':'Z', '10':'X', '11':'Y'}
        nonzero_bits = [format(i, f'0{num_qubits}b') for i in range(1, 2 ** num_qubits)]
        whole, remain = divmod(num_strings, 2 ** num_qubits)
        if whole > 0:
            v_lst = [format(i, f'0{num_qubits}b') for i in range(2 ** num_qubits)]
            u_lst = np.random.choice(nonzero_bits, whole, replace=False).tolist()
            for u, v in itertools.product(u_lst, v_lst):
                generated_strings.append(''.join(uv_map[u_char + v_char] for u_char, v_char in zip(u, v)))
        if remain > 0:
            u = '0' * num_qubits
            v_lst = np.random.choice(nonzero_bits, remain, replace=False).tolist()
            for v in v_lst:
                generated_strings.append(''.join(uv_map[u_char + v_char] for u_char, v_char in zip(u, v)))
        return generated_strings

    # 'pro_I' / 'pro_XYZ': group all non-identity labels by identity count.
    grouped = {count: [] for count in range(num_qubits)}
    for letters in itertools.product('XYZI', repeat=num_qubits):
        label = ''.join(letters)
        identity_count = label.count('I')
        if identity_count < num_qubits:  # skip the all-identity label
            grouped[identity_count].append(label)
    if pattern == 'pro_I':
        group_order = range(num_qubits - 1, -1, -1)
    else:
        group_order = range(num_qubits)
    for count in group_order:
        missing = num_strings - len(generated_strings)
        if missing <= 0:
            break
        group = grouped[count]
        if missing >= len(group):
            generated_strings += group
        else:
            generated_strings += np.random.choice(group, missing, replace=False).tolist()
    return generated_strings
    
def broadcast_string(s):
    """Return every string obtained from ``s`` by replacing any subset of characters with ``'I'``.

    The result is a set and always contains ``s`` itself and the all-``'I'``
    string of the same length.
    """
    variants = {''}
    for char in s:
        # Each prefix branches into "keep the character" and "use 'I'".
        variants = {prefix + tail for prefix in variants for tail in (char, 'I')}
    return variants

def broadcast_all_strings(strings):
    """Return the union of ``broadcast_string(s)`` over every ``s`` in ``strings``."""
    return set().union(*(broadcast_string(s) for s in strings))
    
def estimate_Pauli_expectations(dm, obsv, num_samples, simulation=False):
    """Return a noisy estimate of ``Tr(dm @ P)`` for the Pauli label ``obsv``.

    Args:
        dm: density matrix.
        obsv: Pauli label understood by ``Pauli``.
        num_samples: number of measurement shots (converted with ``int``).
        simulation: when true, draw ``num_samples`` outcomes from ``{+1, -1}``
            and average them; otherwise draw a single value from a normal
            distribution with the exact mean and standard deviation
            ``sqrt(1 - mean**2) / sqrt(num_samples)``.

    Returns:
        The estimated expectation value.
    """
    num_samples = int(num_samples)
    exp = np.real(np.trace(dm @ Pauli(obsv).to_matrix()))
    # Rounding can push the exact value marginally outside [-1, 1], which
    # would yield a negative probability (ValueError in np.random.choice) or
    # the square root of a negative number below.  Note this also silently
    # clamps values that are out of range because ``dm`` is unphysical.
    exp = np.clip(exp, -1.0, 1.0)
    if simulation: # simulate the process of sampling
        prob_p1 = (1 + exp) / 2
        prob_m1 = 1 - prob_p1
        samples = np.random.choice([+1, -1], size=num_samples, p=[prob_p1, prob_m1])
        return np.mean(samples)
    # use the approximate distribution instead
    num_samples_root = num_samples ** .5
    std_dev = (1 - exp ** 2) ** .5 / num_samples_root
    return np.random.normal(exp, std_dev)
    
def broadcast_Pauli_expectations(dm, num_qubits, obserables, num_samples):
    """Estimate Pauli expectations from simulated basis measurements.

    The shot budget is split evenly over ``obserables``.  For each measured
    label the state is rotated so the label becomes diagonal, bit strings are
    sampled from the resulting diagonal, and the parity average gives the
    expectation.  The same samples are reused for every label obtained by
    replacing characters with ``'I'`` ("broadcast" labels); each broadcast
    label is estimated from the first measured label that covers it.

    Args:
        dm: density matrix of dimension ``2**num_qubits``.
        num_qubits: number of qubits.
        obserables: measured Pauli labels over the characters X, Y, Z.
        num_samples: total number of shots; ``num_samples // len(obserables)``
            are used per label.  If that quotient is 0 every estimate is 0.

    Returns:
        Dict mapping each measured and broadcast label to its estimate.
    """
    repetition = num_samples // len(obserables)
    original = set(obserables)
    broadcasted = broadcast_all_strings(obserables).difference(original)
    expectations = {key: [] for key in broadcasted.union(original)}
    remaining = set(broadcasted)
    # Single-qubit basis changes: columns are the eigenvectors of each Pauli.
    converter = {
        'X': np.array([[1, 1], [1, -1]]) / np.sqrt(2),
        'Y': np.array([[1, 0], [0, 1j]], dtype=np.complex128) @ (np.array([[1, 1], [1, -1]]) / np.sqrt(2)),
        'Z': np.array([[1, 0], [0, 1]])
    }
    all_states = [''.join(state) for state in itertools.product('01', repeat=num_qubits)]

    def parity_average(probabilities, indices):
        """Average of (-1)**(number of '1' bits on ``indices``) under ``probabilities``."""
        return sum(
            prob * (-1) ** sum(1 for i in indices if state[i] == '1')
            for prob, state in zip(probabilities, all_states)
        )

    for obsv in original:
        overall_converter = tensor_prod(*[converter[s] for s in obsv])
        converted_dm = overall_converter.conj().T @ dm @ overall_converter
        # <s|rho'|s> is the diagonal entry at index int(s, 2); all_states is
        # enumerated in exactly that order.
        populations = np.real(np.diag(converted_dm))
        distributions = dict(zip(all_states, populations))
        samples = collections.Counter(sample_from_dict(distributions, repetition))
        probabilities = [samples[state] / repetition if state in samples else 0 for state in all_states]
        expectations[obsv] = parity_average(probabilities, range(len(obsv)))
        if remaining:
            # Subtract the label itself; passing the bare string would
            # subtract its individual characters instead.
            for obs in broadcast_all_strings([obsv]).difference({obsv}):
                if obs in remaining:
                    indices = [i for i in range(len(obs)) if not obs[i] == 'I']
                    expectations[obs] = parity_average(probabilities, indices)
                    remaining.remove(obs)
    return expectations
            
    
def get_partial_trace(dm, subsystems):
    """Trace out ``subsystems`` from ``dm`` with ``partial_trace`` and return the result's ``.data`` array."""
    reduced = partial_trace(dm, subsystems)
    return reduced.data

def get_von_neumann_entropy(dm, r=None):
    """Return ``-sum(l * log(l))`` over the eigenvalues of ``dm`` (natural log).

    Hermiticity of ``dm`` is not checked; ``np.linalg.eigvalsh`` is used.

    Args:
        dm: square matrix.
        r: optional number of largest eigenvalues to keep.  Values larger
            than the matrix dimension are clamped to the dimension, so one
            ``r`` can be shared between matrices of different sizes.

    Returns:
        The entropy.  Non-positive eigenvalues contribute zero in both modes;
        previously the ``r`` mode produced NaN for a zero eigenvalue and
        raised when ``r`` exceeded the dimension.
    """
    eigenvalues = np.linalg.eigvalsh(dm)
    if r is not None:
        keep = min(r, eigenvalues.size)
        eigenvalues = np.partition(eigenvalues, -keep)[-keep:]
    # 0 * log(0) is taken as 0; tiny negative values are numerical noise.
    eigenvalues = eigenvalues[eigenvalues > 0]
    entropy = - np.sum(eigenvalues * np.log(eigenvalues))
    return entropy

def get_TMI(dm, r=None):
    """Combine three mutual informations of a three-subsystem state.

    With subsystems labelled 0, 1, 2 as understood by ``get_partial_trace``,
    the result is ``I(0:1) + I(0:2) - I(0:12)`` where each ``I(A:B)`` is
    ``S(A) + S(B) - S(AB)`` computed with ``get_von_neumann_entropy``.

    Args:
        dm: density matrix of the full system.
        r: forwarded to every ``get_von_neumann_entropy`` call.
    """
    def entropy_of(state):
        return get_von_neumann_entropy(state, r)

    # Entropies of the reduced states; the list passed to get_partial_trace
    # names the subsystems that are traced OUT.
    s_s = entropy_of(get_partial_trace(dm, [1, 2]))
    s_m1 = entropy_of(get_partial_trace(dm, [0, 2]))
    s_m2 = entropy_of(get_partial_trace(dm, [0, 1]))
    s_m = entropy_of(get_partial_trace(dm, [0]))
    s_sm1 = entropy_of(get_partial_trace(dm, [2]))
    s_sm2 = entropy_of(get_partial_trace(dm, [1]))
    s_total = entropy_of(dm)

    i2_sm1 = s_s + s_m1 - s_sm1
    i2_sm2 = s_s + s_m2 - s_sm2
    i2_sm = s_s + s_m - s_total
    return i2_sm1 + i2_sm2 - i2_sm

def regularize(dm):
    """Return a unit-trace, Hermitian version of ``dm``.

    The matrix is divided by its trace and then averaged with its conjugate
    transpose.  The input array is left untouched (the previous in-place
    division modified the caller's array and failed for integer dtypes).
    """
    normalized = dm / np.trace(dm)
    return (normalized + normalized.conj().T) / 2

        

In [None]:
class QuantumState():
    def __init__(self, num_qubits:int, num_shots:int|list):
        self._num_qubits = num_qubits
        self._num_shots = num_shots
        self._dm = None
        self._entangled = None
        self._params = dict()
        
    @property
    def dm(self):
        return self._dm
    
    @dm.setter
    def dm(self, new_dm):
        if not is_legal(new_dm):
            raise ValueError("density matrix is not physical")
        else:
            self._dm = new_dm
            
    @dm.deleter
    def dm(self):
        del self._dm
        
    @property
    def purity(self):
        return get_purity(self.dm)
    
    @property
    def TMI(self):
        return get_TMI(self.dm)

    @property
    def params(self):
        return self._params
    
    # @params.setter
    # def params(self, mode, **kwargs):
    #     assert mode in ['Classical Shadow', 'Compressed Sensing', 'QST'], 'allowed modes are: Classical Shadow, Compressed Sensing, QST'
    #     if mode == 'Classical Shadow':
    #         assert all([k in kwargs.keys() for k in ['target_func', 'sub_mode', 'batch_size']]), 'Classical Shadow method lacks necessary parameters: target_func, sub_mode, batch_size'
    #         self._params['Classical Shadow'] = kwargs
    #     if mode == 'Compressed Sensing':
    #         assert all([k in kwargs.keys() for k in ['target_func', 'num_bases', 'sub_mode']]), 'Compressed Sensing method lacks necessary parameters: target_func, num_bases, sub_mode'
    #         self._params['Compressed Sensing'] = kwargs
    
    def set_params(self, mode, **kwargs):
        assert mode in ['Classical Shadow', 'Compressed Sensing', 'Quantum State Tomography'], 'allowed modes are: Classical Shadow, Compressed Sensing, quantum_state_tomography'
        if mode == 'Classical Shadow':
            required_keys = ['target_func', 'sub_mode', 'batch_size']
            assert all(k in kwargs for k in required_keys), f'Classical Shadow method lacks necessary parameters: {", ".join(required_keys)}'
            self._params['Classical Shadow'] = kwargs
        elif mode == 'Compressed Sensing':
            required_keys = ['target_func', 'num_bases', 'sub_mode']
            assert all(k in kwargs for k in required_keys), f'Compressed Sensing method lacks necessary parameters: {", ".join(required_keys)}'
            self._params['Compressed Sensing'] = kwargs
        elif mode == 'Quantum State Tomography':
            required_keys = ['target_func']
            assert all(k in kwargs for k in required_keys), f'Quantum State Tomography method lacks necessary parameters: {", ".join(required_keys)}'
            self._params['Quantum State Tomography'] = kwargs
                    
    def compressed_sensing(self):
        assert 'Compressed Sensing' in self._params.keys(), 'parameters for Compressed Sensing have not been set'
        num_bases = self._params['Compressed Sensing']['num_bases']
        target_func = self._params['Compressed Sensing']['target_func']
        observables = generate_random_Pauli_strings(num_bases, self._num_qubits, pattern='only_XYZ')
        expectations = broadcast_Pauli_expectations(self.dm, self._num_qubits, observables, self._num_shots)
        
        def optimize(dim, expct, tol=1e-5):
            sigma = cp.Variable((dim, dim), complex=True)
            objective = cp.Minimize(cp.abs(5 * cp.norm(sigma, 'nuc') + 0 * cp.norm(sigma, 'fro') ** 2))
            constraints = [cp.trace(sigma) == 1]
            for o, e in expct.items():
                constraints.append(cp.abs(cp.trace(sigma @ Pauli(o).to_matrix()) - e) <= tol)
            problem = cp.Problem(objective, constraints)
            problem.solve()
            if problem.status not in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                raise ValueError(f"Optimization failed with given observables")
            return sigma.value
        
        tolerance = 2 * (self._num_shots / num_bases) ** (-.5)
        sigma = optimize(2 ** self._num_qubits, expectations, tol = tolerance)
        sigma = regularize(sigma)
        return target_func(sigma)
    
    def quantum_state_tomography(self):
        assert 'Quantum State Tomography' in self._params.keys(), 'parameters for Quantum State Tomography have not been set'
        target_func = self._params['Quantum State Tomography']['target_func']
        all_bases = [''.join(obsv) for obsv in list(itertools.product(*['XYZ' for _ in range(self._num_qubits)]))]
        all_expectations = broadcast_Pauli_expectations(self.dm, self._num_qubits, all_bases, self._num_shots)
        dm = sum([Pauli(obsv).to_matrix() * expct for obsv, expct in all_expectations.items()]) / 2 ** self._num_qubits
        return target_func(dm)
    
    def random_evolve(self, sub_mode):
        if sub_mode == 'Clifford':
            self._U = random_clifford(self._num_qubits).to_matrix()
            self.dm = self._U @ self.dm @ np.conj(self._U).T
        elif sub_mode == 'Pauli':
            self._U = [random_clifford(1).to_matrix() for _ in range(self._num_qubits)]
            self.dm = tensor_prod(*self._U) @ self.dm @ np.conj(tensor_prod(*self._U)).T
    
    def single_shot_measure(self, sub_mode):
        prob_list = [self._dm[i, i] for i in range(2 ** self._num_qubits)]
        single_shot_state = int_to_bin_list(single_sample(prob_list), 2 ** self._num_qubits)
        del self._dm
        if sub_mode == 'Clifford':
            self._state = single_shot_state
        elif sub_mode == 'Pauli':
            self._state = expand_to_tensor_product(single_shot_state)
    
    def reconstruct_dm(self, sub_mode):
        dim = 2 ** self._num_qubits
        if sub_mode == 'Clifford':
            return (dim + 1) * (np.conj(self._U).T @ np.outer(self._state, self._state) @ self._U) - np.eye(dim)
        elif sub_mode == 'Pauli':
            return tensor_prod(*[3 * (np.conj(single_U).T @ np.outer(single_state, single_state) @ single_U) - np.eye(2) 
                                 for single_U, single_state in zip(self._U, self._state)])
    
    def classical_shadow(self):
        assert 'Classical Shadow' in self._params.keys(), 'parameters for Classical Shadow have not been set'
        sub_mode = self._params['Classical Shadow']['sub_mode']
        target_func = self._params['Classical Shadow']['target_func']
        batch_sizes = self._params['Classical Shadow']['batch_size']
        assert sub_mode in ['Clifford', 'Pauli'], 'sub_mode can only be set to Clifford or Pauli'
        real_value = target_func(self.dm) # only used to specify data type of the target function
        if isinstance(real_value, (float, np.float16, np.float32, np.float64)):
            results, result_type = [], 'scalar'
        elif isinstance(real_value, dict):
            results, result_type = {key: [] for key in real_value.keys()}, 'dict'
        dm_copy = self.dm.copy()        
        
        if isinstance(batch_sizes, (list, np.ndarray)):
            if isinstance(batch_sizes, np.ndarray):
                batch_sizes = batch_sizes.tolist()
                
            def get_gcd(a, b):
                while b:
                    a, b = b, a % b
                return a
            def get_gcd_multiple(*numbers):
                return functools.reduce(get_gcd, numbers)
            
            batch_gcd = get_gcd_multiple(*batch_sizes)
            reduced_sizes = np.array(batch_sizes) // batch_gcd
            long_results = copy.deepcopy(results)
            all_results = {bs: copy.deepcopy(results) for bs in batch_sizes}
            
            for _ in range(self._num_shots // batch_gcd):
                snapshots = []
                for _ in range(batch_sizes):
                    self._dm = dm_copy.copy()
                    self.random_evolve(sub_mode)
                    self.single_shot_measure(sub_mode)
                    snapshots.append(self.reconstruct_dm(sub_mode))
                mean_dm = np.mean(np.stack(snapshots), axis=0)
                sample = target_func(mean_dm)
                if result_type == 'scalar':
                    long_results.append(sample)
                elif result_type == 'dict':
                    long_results = {key: value + [sample[key]] for key, value in long_results.items()}
            if result_type == 'scalar':
                long_results /= np.sum(long_results)
            elif result_type == 'dict':
                long_results = {key: np.mean(value) for key, value in long_results.items()}
            
            def get_means(lst, n):
                means = []
                for i in range(0, len(lst), n):
                    chunk = lst[i : i + n]
                    chunk_mean = np.sum(chunk) / n
                    means.append(chunk_mean)
                return means
            
            for reduced_size in reduced_sizes:
                if result_type == 'scalar':
                    all_results[reduced_size * batch_gcd] = np.median(get_means(long_results, reduced_size))
                elif result_type == 'dict':
                    all_results[reduced_size * batch_gcd] = {key: np.median(get_means(long_results[key], reduced_size)) for key in long_results.keys()}
            return all_results
                                
        else: # one single batch_size
            for _ in range(self._num_shots // batch_sizes):
                snapshots = []
                for _ in range(batch_sizes):
                    self._dm = dm_copy.copy()
                    self.random_evolve(sub_mode)
                    self.single_shot_measure(sub_mode)
                    snapshots.append(self.reconstruct_dm(sub_mode))
                mean_dm = np.mean(np.stack(snapshots), axis=0)
                sample = target_func(mean_dm)
                if result_type == 'scalar':
                    results.append(sample)
                elif result_type == 'dict':
                    results = {key: value + [sample[key]] for key, value in results.items()}
            if result_type == 'scalar':
                results = np.median(results)
            elif result_type == 'dict':
                results = {key: np.median(value) for key, value in results.items()}
            return results
        
    def quantum_state_tomography_for_purity(self): # deprecated
        warnings.warn("this method has been deprecated", DeprecationWarning)
        max_shots = np.max(self._num_shots)
        max_repetition = max_shots // (4 ** self._num_qubits)
        all_observables = [''.join(obsv) for obsv in list(itertools.product(*['IXYZ' for _ in range(self._num_qubits)]))]
        
        def pm1_sample(expct, num_samples):
            prob_p1 = (1 + expct) / 2
            return [+1 if np.random.random() < prob_p1 else -1 for _ in range(num_samples)]
        
        all_samples = dict()
        for obsv in all_observables:
            all_samples[obsv] = pm1_sample(np.trace(self._dm @ Pauli(obsv).to_matrix()).real, max_repetition)
        purities = []
        for num_shots in self._num_shots:
            repetition = num_shots // (4 ** self._num_qubits)
            temp_samples = {k: v[:repetition] for k, v in all_samples.items()}
            estm_dm = np.sum(np.stack([np.mean(new_v) * Pauli(k).to_matrix() for k, new_v in temp_samples.items()]), axis=0) / (2 ** self._num_qubits)
            purities.append(np.trace(estm_dm @ estm_dm).real)
        return purities
    
    def classical_shadow_multi_batches(self): # deprecated
        warnings.warn("this method has been deprecated", DeprecationWarning)
        assert self._meas in ['Clifford', 'Pauli'], 'only Clifford and Pauli pattern have classical_shadow method'
        assert isinstance(self._batch_size, list) or isinstance(self._batch_size, np.ndarray), 'there must be more than one batch size'
        if not self._compute_purity:
            all_shadows = [{obs: [] for obs in self._observables} for _ in self._batch_size]
            dm_copy = self._dm
            snapshots = []
            for _ in range(self._num_shots):
                self._dm = dm_copy
                self.random_evolve()
                self.single_shot_measure()
                snapshots.append(self.reconstruct_dm())
            for index, size in enumerate(self._batch_size):
                snapshots = split_and_calculate_mean(snapshots, size)
                for k in self._observables:
                    samples = [np.trace(snapshot @ Pauli(k).to_matrix()).real for snapshot in snapshots]
                    all_shadows[index][k] = np.median(samples)
            return all_shadows
    
    def direct_sample(self): # deprecated
        warnings.warn("this method has been deprecated", DeprecationWarning)
        assert self._meas == 'direct', 'only direct pattern have direct_sample method'
        all_samples = {obs: [] for obs in self._observables}
        repetition = self._num_shots // len(self._observables)
        for k in all_samples.keys():
            expct = np.trace(Pauli(k).to_matrix() @ self._dm).real
            prob_p1 = (1 + expct) / 2
            prob_m1 = (1 - expct) / 2
            all_samples[k] = [+1 if np.random.random() < prob_p1 else -1 for _ in range(repetition)]
            all_samples[k] = np.mean(all_samples[k])
        return all_samples

In [None]:
# def parse(i_idx, j_idx):
#     os.chdir(r'C:\Users\Neville\Documents\Quantum Computing 3\Entanglement Detection\von Neumann Entropy')
#     name = str(i_idx) + '-' + str(j_idx) + '.txt'
#     path = r'output\resources01'
#     with open(os.path.join(path, name), 'r') as file:
#         data = file.read()
#     data = data.replace('(', '').replace(')', '').replace('j', 'j ')
#     complex_numbers = data.split()
#     complex_array = np.array(complex_numbers, dtype=np.complex128)
#     sample_dm = complex_array.reshape((8, 8))
#     return str(i_idx) + '-' + str(j_idx), sample_dm

# i, j, filtered_dms = 0, 0, dict()
# while i < 6:
#     if j > 150:
#         i += 1
#         j = 0
#     try: 
#         name, dm = parse(i, j)
#         if np.abs(get_TMI(dm)) > 1e-3:
#             filtered_dms[name] = dm
#         j += 1
#     except FileNotFoundError:
#         j += 1

# Work from the project directory so the relative 'output/...' paths used in
# this and the following cells resolve consistently.  The change of directory
# now happens BEFORE the file is opened (it used to run after `open`, so the
# pickle was looked up relative to whatever the previous directory was), and
# it is skipped when the author-specific path does not exist on this machine.
PROJECT_DIR = r'C:\Users\Neville\Documents\Quantum Computing 3\Entanglement Detection\von Neumann Entropy'
if os.path.isdir(PROJECT_DIR):
    os.chdir(PROJECT_DIR)

# NOTE(security): pickle.load can execute arbitrary code; only load files
# produced by this project.
with open(r'output/resources01.pkl', 'rb') as file:
    # Load the data from the file
    filtered_dms = pickle.load(file)

In [None]:
# Compressed-sensing TMI estimates: 4 repetitions per state and shot budget
# (10**4 ... 10**7 shots, 14 random measurement bases each).
nums_samples = np.logspace(4, 7, num=4, endpoint=True, base=10.0, dtype=None, axis=0)
nums_samples = [int(n) for n in nums_samples]
all_data_cs = dict()
for name, dm in filtered_dms.items():
    all_data_cs[name] = [[] for _ in range(len(nums_samples))]
    for i in range(len(nums_samples)):
        for _ in range(4):
            num_samples = nums_samples[i]
            qstate = QuantumState(3, num_samples)
            qstate.dm = dm
            qstate.set_params('Compressed Sensing', num_bases=14, sub_mode='', target_func=get_TMI)
            try:
                cs_tmi = qstate.compressed_sensing()
            except Exception:
                # A failed reconstruction is recorded as NaN.  Catching
                # Exception (not a bare except) keeps Ctrl-C / kernel
                # interrupt able to stop this long-running loop.
                cs_tmi = np.nan
            all_data_cs[name][i].append(cs_tmi)
        all_data_cs[name][i] = np.array(all_data_cs[name][i])
with open (r'output/tmis02-cs.pkl', 'wb') as file:
    pickle.dump(all_data_cs, file)

In [None]:
# Full-tomography TMI estimates: 4 repetitions per state and shot budget
# (10**4 ... 10**7 shots).
nums_samples = np.logspace(4, 7, num=4, endpoint=True, base=10.0, dtype=None, axis=0)
nums_samples = [int(n) for n in nums_samples]
all_data_qst = dict()
for name, dm in filtered_dms.items():
    all_data_qst[name] = [[] for _ in range(len(nums_samples))]
    for i in range(len(nums_samples)):
        for _ in range(4):
            num_samples = nums_samples[i]
            qstate = QuantumState(3, num_samples)
            qstate.dm = dm
            qstate.set_params('Quantum State Tomography', target_func=get_TMI)
            try:
                qst_tmi = qstate.quantum_state_tomography()
            except Exception:
                # A failed run is recorded as NaN.  Catching Exception (not a
                # bare except) keeps Ctrl-C / kernel interrupt able to stop
                # this long-running loop.
                qst_tmi = np.nan
            all_data_qst[name][i].append(qst_tmi)
        all_data_qst[name][i] = np.array(all_data_qst[name][i])
with open (r'output/tmis02-qst.pkl', 'wb') as file:
    pickle.dump(all_data_qst, file)

In [None]:
# Progress bookkeeping for the classical-shadow run in the next cell.
# Re-running this cell resets the progress and discards collected results.
remaining_names = set(filtered_dms)
completed_names = set()
all_data_csp = {}

In [None]:
# Classical-shadow (single-qubit Clifford) TMI estimates.  States already in
# `completed_names` are skipped, so the cell can be re-run to resume after an
# interruption.
nums_samples = np.logspace(4, 5, num=2, endpoint=True, base=10.0, dtype=None, axis=0)
nums_samples = [int(n) for n in nums_samples]
try:
    for name, dm in filtered_dms.items():
        if name in remaining_names:
            all_data_csp[name] = [[] for _ in range(len(nums_samples))]
            for i in range(len(nums_samples)):
                for _ in range(4):
                    num_samples = nums_samples[i]
                    qstate = QuantumState(3, num_samples)
                    qstate.dm = dm
                    qstate.set_params('Classical Shadow', target_func=get_TMI, sub_mode='Pauli', batch_size=int(5e3))
                    # try:
                    qst_tmi = qstate.classical_shadow()
                    # except:
                    #     qst_tmi = np.nan
                    all_data_csp[name][i].append(qst_tmi)
                all_data_csp[name][i] = np.array(all_data_csp[name][i])
            remaining_names.remove(name)
            completed_names.add(name)
            print(f"name: {name}'s computation finished")
except Exception:
    # Audible alarm, then re-raise: the previous bare `except` discarded the
    # error and its traceback, leaving no trace of why the run stopped.
    winsound.Beep(5000, 5000)
    raise
# with open (r'output/tmis02-qst.pkl', 'wb') as file:
#     pickle.dump(all_data_qst, file)

In [None]:
def get_avg_Pauli_expectations(dm, num_qubits):
    """Return the mean absolute expectation value of `dm` over Pauli strings.

    Every length-`num_qubits` string over the letters X, Y and Z is used (no
    identity factors). For each string P the real part of trace(dm @ P) is
    taken, and the mean of the absolute values is returned.
    """
    expectations = []
    for letters in itertools.product('XYZ', repeat=num_qubits):
        observable = Pauli(''.join(letters)).to_matrix()
        expectations.append(np.trace(dm @ observable).real)
    magnitudes = np.abs(expectations)
    return np.mean(magnitudes)

def _plot_method_series(ax, series_by_shots, xdata, true_tmi, color, label, connect):
    """Draw one method's mean +/- std markers and return its best relative error.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    series_by_shots : sequence whose i-th entry holds the repeated TMI estimates
        plotted at ``xdata[i]``.
    xdata : x positions, one per entry of ``series_by_shots``.
    true_tmi : reference value used to form relative errors.
    color : colour for the line, markers and error bars.
    label : legend text for this series.
    connect : when true, additionally join the means with a solid line (the
        line then carries the legend label).

    Returns
    -------
    The signed relative error ``(mean - true_tmi) / true_tmi`` of smallest
    magnitude, or ``None`` if no mean yields a finite relative error.
    """
    mean_data = [np.mean(series) for series in series_by_shots]
    std_dev_data = [np.std(series) for series in series_by_shots]
    if connect:
        ax.plot(xdata, mean_data, c=color, linewidth=2, label=label)
    # Exactly one artist per method carries the label so the legend lists each method once.
    ax.errorbar(xdata, mean_data, yerr=std_dev_data, fmt='o', markersize=4,
                color=color, ecolor=color, capsize=1, label=None if connect else label)
    with np.errstate(divide='ignore', invalid='ignore'):
        rel_errors = (np.array(mean_data) - true_tmi) / true_tmi
    # A NaN mean (e.g. a failed run stored as NaN) or a zero true_tmi gives
    # non-finite entries; argmin would select them and hide the valid points.
    finite = rel_errors[np.isfinite(rel_errors)]
    if finite.size == 0:
        return None
    return finite[np.argmin(np.abs(finite))]


def plot(name, qst=None, cs=None, csp=None, show=True, save=False):
    """Plot estimated TMI against number of shots for the state `name`.

    Reads the module-level `filtered_dms` for the density matrix and, for each
    enabled method, the matching module-level result dictionary
    (`all_data_qst`, `all_data_cs`, `all_data_csp`).

    Parameters
    ----------
    name : key of the state in `filtered_dms` and in the result dictionaries.
    qst, cs, csp : truthy to include the corresponding method's results.
    show : call ``plt.show()`` when true.
    save : write ``figs/hybrid/<name>.png`` (dpi 300) when true.
    """
    dm = filtered_dms[name]
    true_tmi = get_TMI(dm)
    purity = get_purity(dm)
    avg_Pauli = get_avg_Pauli_expectations(dm, 3)
    plt.close()
    fig, ax = plt.subplots()
    full_range = np.logspace(4, 7, num=4, endpoint=True, base=10.0, dtype=None, axis=0)
    short_range = np.logspace(4, 5, num=2, endpoint=True, base=10.0, dtype=None, axis=0)
    ax.plot(full_range, [true_tmi] * len(full_range), linestyle='--', c='black', label='theoretical')

    # The result dictionaries are looked up only for enabled methods, so a
    # disabled method's dictionary does not have to exist.
    candidates = []
    if qst:
        candidates.append(_plot_method_series(
            ax, all_data_qst[name], full_range, true_tmi, 'purple', 'Quantum State Tomography', True))
    if cs:
        candidates.append(_plot_method_series(
            ax, all_data_cs[name], full_range, true_tmi, 'green', 'Compressed Sensing', False))
    if csp:
        candidates.append(_plot_method_series(
            ax, all_data_csp[name], short_range, true_tmi, 'blue', 'Classical Shadow Pauli', False))

    # Smallest-magnitude relative error over all plotted methods; stays inf when
    # nothing was plotted. Strict '<' keeps the earlier method on ties.
    optm = np.inf
    for best in candidates:
        if best is not None and np.abs(best) < np.abs(optm):
            optm = best

    ax.set_xscale('log')
    ax.set_ylabel(r'$TMI$')
    ax.set_xlabel(r'number of shots')
    # NOTE(review): the caption says "five times", but the QST and CSP sweeps in
    # this notebook repeat each point 4 times -- confirm which is intended.
    ax.set_title(f'state {name}: purity={purity:.2f}, avg={avg_Pauli:.2f}, optm={optm*100:.2f}% \n TMI={true_tmi:.4f} (each point repeated five times)')
    ax.legend(fontsize=8)
    ax.grid()
    if save:
        out_dir = os.path.join('figs', 'hybrid')
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, name + '.png'), dpi=300)
    if show:
        plt.show()

In [None]:
# Preview a single state. To export a figure for every state instead, run:
#     for name in filtered_dms:
#         plot(name, cs=True, qst=True, csp=False, show=False, save=True)
plot('0-0', qst=True, cs=True, csp=False, save=False, show=True)

In [None]:
# (key, lower, upper) purity buckets of width 0.05 covering [0.5, 1]. The order
# matters: later cells iterate purities.values() in this order.
purity_bins = [
    ('5-6-1', .5, .55),
    ('5-6-2', .55, .6),
    ('6-7-1', .6, .65),
    ('6-7-2', .65, .7),
    ('7-8-1', .7, .75),
    ('7-8-2', .75, .8),
    ('8-9-1', .8, .85),
    ('8-9-2', .85, .9),
    ('9-0-1', .9, .95),
    ('9-0-2', .95, 1),
]
# Slack on the top edge so a pure state whose computed purity is 1 (or exceeds
# it by rounding noise) still lands in the last bucket.
purity_tol = 1e-9
last_key = purity_bins[-1][0]

purities = {key: [] for key, _lower, _upper in purity_bins}
for name, series in all_data_cs.items():
    dm = filtered_dms[name]
    true_tmi = get_TMI(dm)
    purity = get_purity(dm)
    # Index 3 is the fourth shot-count entry (drawn at 1e7 shots by plot()).
    rel_errs = (series[3] - true_tmi) / true_tmi
    mean_rel_err = np.mean(rel_errs)
    for key, lower, upper in purity_bins:
        # Buckets are half-open [lower, upper) so a purity exactly on an edge is
        # counted once instead of being dropped; only the last one is closed.
        if key == last_key:
            below_upper = purity <= upper + purity_tol
        else:
            below_upper = purity < upper
        if lower <= purity and below_upper:
            purities[key].append(mean_rel_err)
            break


In [None]:
# Show the `purities` dictionary as this cell's output.
purities

In [None]:
# Preview the values 5.0, 5.5, ..., 9.5 that the next cell uses as x positions.
np.arange(5, 10, .5)

In [None]:
# One x position per purity bucket, in the insertion order of `purities`.
xdata = np.arange(5, 10, .5)
# Mean absolute relative error per bucket. An empty bucket has no mean, so it is
# recorded as NaN explicitly (left as a gap in the plot) rather than triggering
# numpy's "Mean of empty slice" warning.
data = [np.mean(np.abs(values)) if len(values) else np.nan for values in purities.values()]
plt.plot(xdata, data)
plt.scatter(xdata, data)
plt.xlabel('purity bucket lower edge (x10)')
plt.ylabel('mean |relative TMI error|')
plt.show()