# In [None]:
import torch
from quairkit import Circuit
from quairkit.database import *
from quairkit.qinfo import *
import itertools
import time
import random
import math
import numpy as np
import matplotlib.pyplot as plt
import quairkit as qkit
from quairkit import Circuit, to_state
from quairkit.loss import *
from quairkit.database.hamiltonian import ising_hamiltonian
from quairkit.ansatz import *
from quairkit.operator import ParamOracle
import datetime

# Select complex128 as the QuAIRKit dtype, consistent with the torch.cdouble
# tensors built explicitly by the helpers below.
qkit.set_dtype('complex128')

def generate_single_pair(a):
    """Return the two single-pair (2-qubit) density matrices to be discriminated.

    Args:
        a: Damping parameter used in the Kraus operators; ``sqrt(a)`` and
            ``sqrt(1 - a)`` are taken, so it is expected to lie in [0, 1].

    Returns:
        A tuple ``(rho_plus, rho_minus_damped)`` where ``rho_plus`` is the
        density matrix of ``bell_state(2)`` (the phi^+ state, left noiseless)
        and ``rho_minus_damped`` is phi^- = (|00> - |11>)/sqrt(2) after the
        damping Kraus operators have been applied independently to both qubits.
    """
    phi_plus = bell_state(2).density_matrix

    # Build |phi^-> explicitly from the computational basis vectors.
    ket0 = torch.tensor([1, 0], dtype=torch.cdouble)
    ket1 = torch.tensor([0, 1], dtype=torch.cdouble)
    phi_minus_vec = (torch.kron(ket0, ket0) - torch.kron(ket1, ket1)) / math.sqrt(2)
    phi_minus = phi_minus_vec[:, None] @ phi_minus_vec.conj()[None, :]

    # Single-qubit amplitude-damping Kraus operators.
    single_qubit_kraus = (
        torch.tensor([[1, 0], [0, math.sqrt(1 - a)]], dtype=torch.cdouble),
        torch.tensor([[0, math.sqrt(a)], [0, 0]], dtype=torch.cdouble),
    )

    # Apply every tensor-product Kraus operator to phi^- and accumulate.
    # The accumulator starts at the integer 0, exactly like the builtin sum().
    phi_minus_damped = 0
    for left in single_qubit_kraus:
        for right in single_qubit_kraus:
            joint = torch.kron(left, right)
            phi_minus_damped = phi_minus_damped + joint @ phi_minus @ joint.conj().T

    # Figure 7 setting: noiseless phi^+ versus noisy phi^-.
    # For the Figure S9 setting (noisy phi^+ versus noisy phi^-), the first
    # state would instead be the same Kraus sum applied to `phi_plus`:
    #   sum(Kij @ phi_plus @ Kij.conj().T for Kij in K)
    first_state = phi_plus

    return first_state, phi_minus_damped

def generate_initial_6qubit_states(a):
    """Build the batched 6-qubit input: three copies of each single-pair state.

    Args:
        a: Damping parameter forwarded to ``generate_single_pair``.

    Returns:
        The result of ``to_state`` on a stacked batch of two 64x64 matrices:
        index 0 is the triple tensor power of the first pair state, index 1
        the triple tensor power of the second (damped) pair state.
    """
    pair_states = generate_single_pair(a)
    triple_copies = [
        torch.kron(torch.kron(rho, rho), rho)
        for rho in pair_states
    ]
    return to_state(torch.stack(triple_copies), eps=None)


def prepare_new_pair_for_stage(outputs, a, stage_num):
    """Swap the last pair of each batch element for a freshly generated pair.

    For batch elements 0 and 1 of ``outputs`` the state is reduced with
    ``partial_trace(state, 1, [16, 4])`` and the result is tensored with a new
    single-pair state from ``generate_single_pair(a)`` (first pair state for
    element 0, second for element 1).

    Args:
        outputs: Output of the previous stage; either an object exposing
            ``density_matrix`` or something indexable that yields matrices.
        a: Damping parameter forwarded to ``generate_single_pair``.
        stage_num: Accepted for interface compatibility; not used here.

    Returns:
        ``to_state`` applied to the stacked batch of the two rebuilt states.
    """
    fresh_pairs = generate_single_pair(a)
    # Fallback source of matrices when an indexed element has no density_matrix.
    fallback_matrices = getattr(outputs, 'density_matrix', outputs)

    rebuilt = []
    for position, fresh_pair in enumerate(fresh_pairs):
        element = outputs[position]
        if hasattr(element, 'density_matrix'):
            rho = element.density_matrix
        else:
            rho = fallback_matrices[position]

        # NOTE(review): presumably this discards the trailing 4-dimensional
        # subsystem (the last two qubits) and keeps the leading 16-dimensional
        # one -- confirm against the quairkit `partial_trace` documentation.
        reduced = partial_trace(to_state(rho, eps=None), 1, [16, 4])
        reduced = getattr(reduced, 'density_matrix', reduced)

        rebuilt.append(torch.kron(reduced, fresh_pair))

    return to_state(torch.stack(rebuilt), eps=None)

def create_2copy_circuit():
    """Build the single 6-qubit circuit used when only two copies are processed.

    Returns:
        A ``Circuit(6)`` with a swap on qubits 1 and 2, a universal two-qubit
        gate on qubits 0 and 1, and one parameterised LOCC block labelled
        ``'M1'`` on the system list ``[[0, 1], 2, 3]``.
    """
    circuit = Circuit(6)
    circuit.swap([1, 2])
    circuit.universal_two_qubits([0, 1])
    # NOTE(review): presumably the first entry of the system list is what gets
    # measured and the 15-parameter unitary acts on the remaining systems --
    # confirm against the QuAIRKit `param_locc` documentation.
    locc_systems = [[0, 1], 2, 3]
    circuit.param_locc(universal2, 15, locc_systems, label='M1', support_batch=False)
    return circuit

def create_stage1_circuit():
    """Build the first-stage circuit of the multi-stage (3+ copies) protocol.

    Returns:
        A ``Circuit(6)`` with a swap on qubits 1 and 2, a universal two-qubit
        gate on qubits 0 and 1, and one parameterised LOCC block labelled
        ``'M1'`` on the system list ``[0, 2, 3]``.
    """
    circuit = Circuit(6)
    circuit.swap([1, 2])
    circuit.universal_two_qubits([0, 1])
    # NOTE(review): presumably qubit 0 is measured and the 15-parameter unitary
    # acts on qubits 2 and 3 -- confirm against the QuAIRKit `param_locc` docs.
    locc_systems = [0, 2, 3]
    circuit.param_locc(universal2, 15, locc_systems, label='M1', support_batch=False)
    return circuit

def create_stage2_circuit():
    """Build the stage-2 circuit.

    Stage 2 has exactly the generic middle-stage layout (two LOCC blocks
    labelled ``'M21'`` and ``'M22'`` followed by swaps on ``[0, 4]`` and
    ``[2, 5]``), so it is built by ``create_middle_stage_circuit(2)`` instead
    of duplicating that body.

    Returns:
        The 6-qubit ``Circuit`` for stage 2.
    """
    return create_middle_stage_circuit(2)

def create_middle_stage_circuit(stage_num):
    """Build the circuit for an intermediate stage.

    Args:
        stage_num: Stage number, used only to form the LOCC labels
            ``f'M{stage_num}1'`` and ``f'M{stage_num}2'``.

    Returns:
        A ``Circuit(6)`` with two parameterised LOCC blocks on the system
        lists ``[2, 4, 1]`` and ``[4, 5, 3]``, followed by swaps on
        ``[0, 4]`` and ``[2, 5]``.
    """
    circuit = Circuit(6)

    # (system list, label suffix) for the two LOCC blocks, in application order.
    locc_blocks = (([2, 4, 1], '1'), ([4, 5, 3], '2'))
    for systems, suffix in locc_blocks:
        circuit.param_locc(
            universal2, 15, systems,
            label=f'M{stage_num}{suffix}', support_batch=False,
        )

    for swapped_pair in ([0, 4], [2, 5]):
        circuit.swap(swapped_pair)
    return circuit

def create_final_stage_circuit(stage_num):
    """Build the circuit for the last stage.

    Identical in shape to an intermediate stage except that the second LOCC
    block is given the system list ``[[4, 1], 5, 3]`` instead of ``[4, 5, 3]``.

    Args:
        stage_num: Stage number, used only to form the LOCC labels
            ``f'M{stage_num}1'`` and ``f'M{stage_num}2'``.

    Returns:
        A ``Circuit(6)`` with the two LOCC blocks followed by swaps on
        ``[0, 4]`` and ``[2, 5]``.
    """
    circuit = Circuit(6)

    # (system list, label suffix) for the two LOCC blocks, in application order.
    locc_blocks = (([2, 4, 1], '1'), ([[4, 1], 5, 3], '2'))
    for systems, suffix in locc_blocks:
        circuit.param_locc(
            universal2, 15, systems,
            label=f'M{stage_num}{suffix}', support_batch=False,
        )

    for swapped_pair in ([0, 4], [2, 5]):
        circuit.swap(swapped_pair)
    return circuit

def create_all_circuits(n_copy):
    """Create the ordered list of stage circuits for an ``n_copy``-copy protocol.

    The list always has ``n_copy - 1`` entries:

    * ``n_copy == 2``: the dedicated 2-copy circuit only.
    * ``n_copy == 3``: stage 1 followed by the final stage.
    * ``n_copy >= 4``: stage 1, stage 2, middle stages ``3 .. n_copy - 2``,
      then the final stage.

    Args:
        n_copy: Number of copies processed by the protocol; must be >= 2.

    Returns:
        List of circuits, in the order in which they are applied.

    Raises:
        ValueError: If ``n_copy`` is smaller than 2. Previously such a value
            silently produced a three-circuit list that matched no protocol.
    """
    if n_copy < 2:
        raise ValueError(f"n_copy must be at least 2, got {n_copy}")

    if n_copy == 2:
        return [create_2copy_circuit()]

    n_stages = n_copy - 1
    circuits = [create_stage1_circuit()]
    if n_stages > 2:
        circuits.append(create_stage2_circuit())
        circuits.extend(
            create_middle_stage_circuit(stage) for stage in range(3, n_stages)
        )
    circuits.append(create_final_stage_circuit(n_stages))
    return circuits


def loss_func_general(outputs, measure_qubits=None):
    """Compute the average error probability of the two-state discrimination.

    ``Measure('zz')`` is applied to ``measure_qubits`` of ``outputs``. The
    measurement probabilities are reshaped to ``(batch, branches, 4)``,
    weighted per branch and summed over branches, then normalised per batch
    element. Batch element 0 counts outcome indices 0 and 1 as errors;
    batch element 1 counts outcome indices 2 and 3 as errors.

    Args:
        outputs: Batched state produced by the stage circuits; at least two
            batch elements are required (index 0 and index 1 are read).
        measure_qubits: Qubit indices handed to the measurement. Defaults to
            ``[2, 3]``; a fresh list is created on each call so no mutable
            default is shared between calls.

    Returns:
        Tuple ``(loss, outputs)`` where ``loss`` is the equally weighted
        average of the two error probabilities and ``outputs`` is the
        argument returned unchanged (not the post-measurement state).
    """
    if measure_qubits is None:
        measure_qubits = [2, 3]

    meas = Measure('zz')
    outcome_probs, measured_state = meas(
        outputs, qubits_idx=measure_qubits, keep_state=True
    )
    state_probs = measured_state.probability

    # Collapse every axis between the batch axis and the final outcome axis
    # into a single "branch" axis (1 branch when there are no such axes).
    batch_size = state_probs.shape[0]
    num_branches = math.prod(state_probs.shape[1:-1])
    state_probs = state_probs.reshape(batch_size, num_branches, 4)
    outcome_probs = outcome_probs.reshape(batch_size, num_branches, 4)

    # NOTE(review): this treats `measured_state.probability` summed over the
    # outcome axis as the branch weight and `outcome_probs` as the per-branch
    # outcome distribution -- confirm against the QuAIRKit Measure/State docs.
    branch_weights = state_probs.sum(dim=2)

    # Weighted sum over branches, then per-batch normalisation. The cast keeps
    # the double-precision result dtype the loop-based version produced.
    marginal = (branch_weights.unsqueeze(-1) * outcome_probs).sum(dim=1)
    marginal = marginal.to(torch.double)
    normalized = marginal / marginal.sum(dim=1, keepdim=True)

    p0 = normalized[0].real
    p1 = normalized[1].real

    eps = 1e-10  # guards the divisions below against a zero denominator

    # Batch element 0: outcomes 0/1 are wrong guesses, 2/3 are correct.
    wrong0 = p0[0] + p0[1]
    right0 = p0[2] + p0[3]
    error0 = wrong0 / (wrong0 + right0 + eps)

    # Batch element 1: outcomes 2/3 are wrong guesses, 0/1 are correct.
    wrong1 = p1[2] + p1[3]
    right1 = p1[0] + p1[1]
    error1 = wrong1 / (wrong1 + right1 + eps)

    loss = 0.5 * error0 + 0.5 * error1

    return loss, outputs

def loss_func_all_stages(circuits, inputs, a, n_copy):
    """Run the input through every stage circuit and evaluate the loss.

    Args:
        circuits: Stage circuits in application order.
        inputs: Batched initial 6-qubit state.
        a: Damping parameter, forwarded to ``prepare_new_pair_for_stage``
            whenever a fresh pair has to be supplied.
        n_copy: Number of copies; the protocol has ``n_copy - 1`` stages.

    Returns:
        The ``(loss, state)`` pair produced by ``loss_func_general`` when
        measuring qubits ``[2, 3]`` of the final state.
    """
    stage_count = n_copy - 1

    # Stage 1 (or the single 2-copy circuit) acts directly on the input.
    state = circuits[0](inputs)

    if n_copy != 2:
        # Stage 2 acts on the stage-1 output without a new pair being prepared.
        state = circuits[1](state)

        # Stage 3 to the final stage: prepare a fresh pair before each circuit.
        for position in range(2, stage_count):
            state = prepare_new_pair_for_stage(state, a, position + 1)
            state = circuits[position](state)

    loss, measured_state = loss_func_general(state, measure_qubits=[2, 3])
    return loss, measured_state

def train_all_stages(num_itr, lr, a, n_copy, seed=None):
    """Jointly train all stage circuits and return the best loss found.

    Each circuit gets its own Adam optimizer and ``ReduceLROnPlateau``
    scheduler; all of them are stepped on the same scalar loss.

    Args:
        num_itr: Number of optimisation iterations.
        lr: Initial learning rate for every optimizer.
        a: Damping parameter of the input states.
        n_copy: Number of copies, which determines the stage circuits.
        seed: Optional seed applied to ``torch``, ``random`` and ``numpy``.

    Returns:
        Tuple ``(best_loss, circuits)``. ``best_loss`` is the lowest loss
        observed (``inf`` when ``num_itr`` is 0) and ``circuits`` are the
        trained circuits with the parameters that produced it loaded.
    """
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    inputs = generate_initial_6qubit_states(a)
    circuits = create_all_circuits(n_copy)

    optimizers = [torch.optim.Adam(cir.parameters(), lr=lr) for cir in circuits]
    schedulers = [
        torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', factor=0.98, patience=12)
        for opt in optimizers
    ]

    best_loss = float('inf')
    best_states = None

    for itr in range(num_itr):
        for opt in optimizers:
            opt.zero_grad()

        loss, _ = loss_func_all_stages(circuits, inputs, a, n_copy)
        lv = loss.item()

        # Snapshot BEFORE the optimizer step: `lv` was computed with the
        # current parameters, so these are the ones that achieve it. Taking
        # the snapshot after `opt.step()` would store parameters whose loss
        # has not been evaluated yet.
        if lv < best_loss:
            best_loss = lv
            best_states = [
                {k: v.cpu().clone() for k, v in cir.state_dict().items()}
                for cir in circuits
            ]

        loss.backward()

        for opt in optimizers:
            opt.step()
        for sched in schedulers:
            # Pass the plain float; the scheduler only needs the metric value.
            sched.step(lv)

        if itr % 600 == 0 or itr == num_itr - 1:
            print(f"    iter {itr:4d}/{num_itr} | loss = {lv:.8f} | best_loss = {best_loss:.8f}")

    # Restore the best parameters so the returned circuits match `best_loss`.
    if best_states is not None:
        for cir, snapshot in zip(circuits, best_states):
            cir.load_state_dict(snapshot)

    return best_loss, circuits

if __name__ == '__main__':
    # Multi-seed sweep over the damping parameter: for every value `a` the
    # circuits are trained once per seed, the best run is kept, and the
    # results are written to three text files before the curve is plotted.

    N_COPY = 4  # N-copy, >= 2
    NUM_ITR = 600
    print(f"\n>>> Running {N_COPY}-copy version <<<\n")

    # Index into `a_vals_full` at which the sweep starts (lets a partially
    # completed sweep be resumed from a later point).
    START_POINT = 0

    a_vals_full = np.linspace(0.0, 1.0, 11)

    # ==========================================
    # Per-point hyper-parameters, one row per entry of `a_vals_full`:
    # (learning rate, number of seeds, first seed, seed interval)
    point_configs = [
        (0.15, 1, 10, 50),    # a=0.0
        (0.15, 10, 50, 50),   # a=0.1
        (0.15, 10, 100, 50),  # a=0.2
        (0.15, 10, 150, 50),  # a=0.3
        (0.15, 10, 200, 50),  # a=0.4
        (0.15, 10, 250, 50),  # a=0.5
        (0.15, 10, 300, 50),  # a=0.6
        (0.15, 10, 350, 50),  # a=0.7
        (0.15, 10, 400, 50),  # a=0.8
        (0.15, 10, 450, 50),  # a=0.9
        (0.15, 10, 500, 50),  # a=1.0
    ]
    # ==========================================

    if len(point_configs) != len(a_vals_full):
        raise ValueError(
            f"point_configs has {len(point_configs)} rows but there are "
            f"{len(a_vals_full)} values of a"
        )
    if not 0 <= START_POINT < len(a_vals_full):
        raise ValueError(
            f"START_POINT must be in [0, {len(a_vals_full) - 1}], got {START_POINT}"
        )

    lr_list, seed_nums_list, seed_start_list, seed_interval_list = (
        list(column) for column in zip(*point_configs)
    )

    a_vals = a_vals_full[START_POINT:]
    results = []
    best_seeds_used = []
    detailed_results = []

    print(f"Starting from a = {a_vals[0]:.1f} (START_POINT = {START_POINT})")
    print(f"Will train for a values: {[f'{a:.1f}' for a in a_vals]}")

    for idx, a in enumerate(a_vals):
        print(f"\n{'='*60}")
        print(f"Processing a = {a:.2f}")
        print(f"{'='*60}")

        # Position of this point in the full (unsliced) configuration table.
        lr, num_seeds, seed_start, seed_interval = point_configs[idx + START_POINT]
        seed_list = [seed_start + k * seed_interval for k in range(num_seeds)]

        best_loss_for_a = float('inf')
        best_seed_for_a = None
        best_sp_for_a = 0
        seed_results = []

        print(f"Trying {num_seeds} seeds for a={a:.2f}, lr={lr}")
        print(f"Seed range: {seed_start} to {seed_start + (num_seeds-1)*seed_interval} with interval {seed_interval}")

        for seed_no, current_seed in enumerate(seed_list, start=1):
            print(f"\n  Seed {seed_no}/{num_seeds}: seed={current_seed}")

            best_loss, _ = train_all_stages(NUM_ITR, lr, a, N_COPY, seed=current_seed)

            # The loss is an error probability, so success = 1 - loss.
            sp = 1 - best_loss
            seed_results.append(sp)

            print(f"  -> Result: success_prob={sp:.6f}, loss={best_loss:.6f}")

            if best_loss < best_loss_for_a:
                best_loss_for_a = best_loss
                best_seed_for_a = current_seed
                best_sp_for_a = sp

        results.append(best_sp_for_a)
        best_seeds_used.append(best_seed_for_a)
        detailed_results.append({
            'a': a,
            'best_sp': best_sp_for_a,
            'best_seed': best_seed_for_a,
            'all_seeds': list(seed_list),
            'all_results': list(seed_results),
            'lr': lr,
            'num_seeds': num_seeds,
            'seed_start': seed_start,
            'seed_interval': seed_interval
        })

        print(f"\nBest result for a={a:.2f}:")
        print(f"  Best seed: {best_seed_for_a}")
        print(f"  Best success probability: {best_sp_for_a:.6f}")
        print(f"  Best loss: {best_loss_for_a:.6f}")
        print(f"  All results: {[f'{sp:.4f}' for sp in seed_results]}")

    print(f"\n{'='*60}")
    print("FINAL RESULTS SUMMARY")
    print(f"{'='*60}")
    for a, sp, seed in zip(a_vals, results, best_seeds_used):
        print(f"a={a:.2f} | success_prob={sp:.6f} | best_seed={seed}")

    # Persist everything BEFORE opening the plot: `plt.show()` can block until
    # the window is closed (or fail without a display), and the training
    # results must not depend on that.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    origin_tag = f"_from{START_POINT}" if START_POINT > 0 else ""
    file_stem = f"{N_COPY}copy_6qubit_multiseed{origin_tag}_{timestamp}.txt"
    filename = f"results_{file_stem}"
    data_filename = f"data_{file_stem}"
    detail_filename = f"detailed_{file_stem}"

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f"{N_COPY}-copy 6-qubit state discrimination results (Multi-Seed)\n")
        f.write(f"Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of iterations: {NUM_ITR}\n")
        f.write(f"Start point: {START_POINT} (starting from a={a_vals[0]:.1f})\n")
        f.write("\nSeed configuration for each point:\n")
        for detail in detailed_results:
            f.write(f"  a={detail['a']:.2f}: {detail['num_seeds']} seeds, start={detail['seed_start']}, interval={detail['seed_interval']}\n")
        f.write("="*50 + "\n")
        f.write("a\tSuccess Probability\tBest Seed\tLR\tSeeds Tried\n")
        f.write("-"*50 + "\n")

        for detail in detailed_results:
            f.write(f"{detail['a']:.2f}\t{detail['best_sp']:.6f}\t{detail['best_seed']}\t{detail['lr']}\t{detail['num_seeds']}\n")

        f.write("-"*50 + "\n")
        f.write(f"Average success probability: {np.mean(results):.6f}\n")
        f.write(f"Maximum success probability: {np.max(results):.6f} at a={a_vals[np.argmax(results)]:.2f}\n")
        f.write(f"Minimum success probability: {np.min(results):.6f} at a={a_vals[np.argmin(results)]:.2f}\n")
        f.write(f"Total seeds used: {sum(d['num_seeds'] for d in detailed_results)}\n")

    print(f"\nResults saved to {filename}")

    with open(detail_filename, 'w', encoding='utf-8') as f:
        f.write(f"{N_COPY}-copy 6-qubit State Discrimination Detailed Results\n")
        f.write(f"Date: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write("="*80 + "\n")

        for detail in detailed_results:
            f.write(f"\na = {detail['a']:.2f} (lr = {detail['lr']}, num_seeds = {detail['num_seeds']}):\n")
            f.write(f"  Seed configuration: start={detail['seed_start']}, interval={detail['seed_interval']}\n")
            f.write(f"  Best result: {detail['best_sp']:.6f} (seed {detail['best_seed']})\n")
            f.write(f"  All seeds: {detail['all_seeds']}\n")
            f.write(f"  All results: {[f'{sp:.6f}' for sp in detail['all_results']]}\n")
            # Spread statistics only make sense with more than one run.
            if len(detail['all_results']) > 1:
                f.write(f"  Improvement over worst: {detail['best_sp'] - min(detail['all_results']):.6f}\n")
                f.write(f"  Standard deviation: {np.std(detail['all_results']):.6f}\n")

    print(f"Detailed results saved to {detail_filename}")

    np.savetxt(data_filename, np.column_stack([a_vals, results, best_seeds_used]),
               delimiter='\t', header='a\tsuccess_probability\tbest_seed', comments='')
    print(f"Data saved to {data_filename}")

    plt.figure(figsize=(6, 4))
    plt.plot(a_vals, results, marker='o')
    plt.xlabel(r'$\gamma$')
    plt.ylabel('Success Probability (1 - loss)')
    plt.title(f'Success Probability vs AD channel parameter $\\gamma$ ({N_COPY}-copy)')
    plt.grid(True)
    plt.show()