# Comparison of dynamic methods for state distillation using 5 copies of the state S


In [125]:
import time
import random
import math

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


import quairkit as qkit
from quairkit import Circuit
from quairkit import to_state
from quairkit.database import *
from quairkit.loss import *
from quairkit.qinfo import *
from quairkit.database.hamiltonian import ising_hamiltonian
from quairkit.ansatz import *
from quairkit.operator import ParamOracle

# Use double-precision complex numbers (complex128) in QuAIRKit; presumably this is a
# global default that applies to the states and circuits created below.
qkit.set_dtype('complex128')

## Universal DLOCCNet

In [126]:
def dynloccnetcir(n):
    """Build one trainable round of the dynamic (universal) DLOCCNet protocol.

    The circuit acts on ``2 * n`` qubits and holds one ``universal_qudits``
    block on qubits (0, 2) and another on qubits (1, 3). Gate positions are
    hard-coded, so the layout assumes ``n == 2``.

    Args:
        n: Half the number of qubits in the circuit.

    Returns:
        Circuit: the parameterized circuit for one round.
    """
    circuit = Circuit(2 * n)
    # Each block couples one qubit of the first pair with the matching qubit
    # of the second pair.
    for pair in ([0, 2], [1, 3]):
        circuit.universal_qudits(qubits_idx=pair)
    return circuit

In [None]:
def loss_func4(cir1, cir2, cir3, cir4, target_state, noisy_state):
    """Run four sequential distillation rounds and score the final state.

    Each round k (k = 1..4) works as follows:

    1. Tensor the current 2-qubit state with a fresh copy of ``noisy_state``.
    2. Apply ``cir_k`` to the resulting 4-qubit state.
    3. Measure qubits 2 and 3 in the Z basis and post-select outcome ``'00'``.
    4. Trace out the second 4-dimensional subsystem.

    The first round starts from ``noisy_state`` itself, so it consumes two
    copies. Each later round consumes one more, for five copies in total.

    Args:
        cir1, cir2, cir3, cir4: Callables applied to a 4-qubit QuAIRKit state,
            one per round (``Circuit`` instances in this file).
        target_state: Target density matrix.
        noisy_state: Density matrix of one noisy copy. It is assumed to be
            4x4 (2 qubits), matching the ``[4, 4]`` dims passed to
            ``partial_trace``.

    Returns:
        tuple: ``(loss, output_state, f)``, where

        - ``f`` is ``state_fidelity(target_state, output_state) ** 2`` as a
          Python float.
        - ``loss`` is ``1 - f`` kept as a tensor, so that
          ``train_model_dyn4`` can call ``backward()`` on it.
        - ``output_state`` is the density matrix left after round 4.
    """
    measure_state = Measure('z' * 2)

    current_state = noisy_state
    for cir in (cir1, cir2, cir3, cir4):
        # Running output on the first two qubits, fresh noisy copy on the
        # last two (presumably qubits 0-1 and 2-3 respectively).
        input_state = torch.kron(current_state, noisy_state)
        state = cir(to_state(input_state, eps=None))
        # The first return value (presumably the post-selection probability)
        # is discarded, so the loss ignores how often post-selection succeeds.
        # NOTE(review): whether the kept state is renormalized depends on
        # QuAIRKit's Measure - confirm.
        _, m_state = measure_state(
            state,
            qubits_idx=list(range(2, 4)),
            keep_state=True,
            desired_result='0' * 2,
        )
        current_state = partial_trace(m_state, 1, [4, 4]).density_matrix

    # Evaluate the fidelity once and derive both the loss and the float from it.
    fid_sq = state_fidelity(target_state, current_state) ** 2
    loss = 1 - fid_sq
    return loss, current_state, fid_sq.item()

In [None]:
def train_model_dyn4(num_itr, LR, n, target_state, noisy_state):
    """Jointly train a four-round dynamic DLOCCNet protocol; return its fidelity.

    Four fresh ``dynloccnetcir(n)`` circuits are optimized together with Adam
    by minimizing the loss from ``loss_func4``. A ``ReduceLROnPlateau``
    scheduler (mode ``'min'``, factor 0.5) halves the learning rate when the
    loss plateaus. Progress is printed at every 500th iteration and at the
    last one. ``avg_time`` is the mean step time since the previous print.

    Args:
        num_itr: Number of optimization steps; must be >= 1.
        LR: Initial Adam learning rate.
        n: Forwarded to ``dynloccnetcir`` (each circuit has ``2 * n`` qubits).
        target_state: Target density matrix.
        noisy_state: Density matrix of one noisy copy.

    Returns:
        float: squared ``state_fidelity`` between the output of the last
        forward pass and ``target_state``. That forward pass happens before
        the final optimizer step.

    Raises:
        ValueError: if ``num_itr`` < 1, because no output state would exist.
    """
    if num_itr < 1:
        raise ValueError(f"num_itr must be a positive integer, got {num_itr}")

    circuits = [dynloccnetcir(n) for _ in range(4)]

    # Adam updates each parameter independently, and all circuits shared the
    # same lr and the same plateau schedule. One optimizer plus one scheduler
    # over every parameter is therefore equivalent to four separate ones
    # stepped with the same loss.
    params = [param for cir in circuits for param in cir.parameters()]
    optimizer = torch.optim.Adam(lr=LR, params=params)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5)

    print('Training:')
    time_list = []
    final_state = None
    for itr in range(num_itr):
        start_time = time.time()
        optimizer.zero_grad()

        # loss_func4 returns the state after round 4, not round 3.
        loss, final_state, _ = loss_func4(*circuits, target_state, noisy_state)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # Feed the scheduler a plain float so it does not hold the graph.
        scheduler.step(loss_value)
        time_list.append(time.time() - start_time)

        if itr % 500 == 0 or itr == num_itr - 1:
            print(
                f"iter: {itr}, loss: {loss_value:.8f}, lr: {scheduler.get_last_lr()[0]:.2E}, avg_time: {np.mean(time_list):.4f}s"
            )
            time_list = []

    output_state = final_state.detach()
    return state_fidelity(output_state, target_state).item() ** 2

## DEJMPS

In [129]:
def dejmps(n):
    """Build the DEJMPS round used as a fixed (untrained) baseline.

    The circuit acts on ``2 * n`` qubits and applies, in order:

    1. ``Rx(pi/2)`` on qubits 0 and 2.
    2. ``Rx(-pi/2)`` on qubits 1 and 3.
    3. ``cnot([0, 2])``.
    4. ``cnot([1, 3])``.

    Qubit indices are hard-coded, so the layout assumes ``n == 2``.

    Args:
        n: Half the number of qubits in the circuit.

    Returns:
        Circuit: the fixed DEJMPS circuit.
    """
    circuit = Circuit(2 * n)
    circuit.rx(param=np.pi / 2, qubits_idx=[0, 2])
    circuit.rx(param=-np.pi / 2, qubits_idx=[1, 3])
    # Bilateral CNOTs pairing qubit k of the first pair with qubit k of the second.
    for control, target in ((0, 2), (1, 3)):
        circuit.cnot([control, target])
    return circuit

## Our protocol

In [None]:
def simcir1(n, p):
    """Build the first (input-dependent) round of the proposed protocol.

    The circuit applies, in order:

    1. ``cnot([2, 0])``, ``cnot([3, 1])`` and ``cnot([1, 3])``.
    2. ``Ry(arccos(1 - p) + pi)`` on qubits 2 and 3.

    In the script below, ``p`` is the Bell-state weight of the noisy input.
    ``np.arccos`` needs ``0 <= p <= 2`` to return a finite angle. Qubit
    indices are hard-coded, so the layout assumes ``n == 2``.

    Args:
        n: Half the number of qubits in the circuit.
        p: Parameter that sets the rotation angle.

    Returns:
        Circuit: the fixed first-round circuit.
    """
    angle = np.arccos(1 - p) + np.pi
    circuit = Circuit(2 * n)
    for control, target in ((2, 0), (3, 1), (1, 3)):
        circuit.cnot([control, target])
    circuit.ry(param=angle, qubits_idx=[2, 3])
    return circuit

In [None]:
def simcir2(n):
    """Build a later (fixed) round of the proposed protocol.

    The circuit applies ``cnot([2, 0])`` and ``cnot([3, 1])``, then
    ``Ry(-pi/2)`` on qubits 2 and 3. Qubit indices are hard-coded, so the
    layout assumes ``n == 2``.

    Args:
        n: Half the number of qubits in the circuit.

    Returns:
        Circuit: the fixed circuit used for rounds 2-4.
    """
    circuit = Circuit(2 * n)
    for control, target in ((2, 0), (3, 1)):
        circuit.cnot([control, target])
    circuit.ry(param=-np.pi / 2, qubits_idx=[2, 3])
    return circuit

# Train

In [None]:
# Experiment settings: 2-qubit copies, so every 2-copy circuit acts on 4 qubits.
n = 2
NUM_ITR = 1500
LR = 0.1
target_state = bell_state(2).density_matrix

# Squared fidelity with the Bell target, one entry per noise level.
fid_s = []        # noisy input without distillation

fid_dynamic = []  # trained dynamic universal DLOCCNet
fid_de = []       # dynamic DEJMPS
fid_iso1 = []     # NOTE(review): never filled in this cell; kept in case later cells use it
fid_sim = []      # our (fixed) protocol

# The DEJMPS and later-round circuits have fixed angles, so they are built once.
cir_de1 = dejmps(n)

cir_sim2 = simcir2(n)
cir_sim3 = simcir2(n)
cir_sim4 = simcir2(n)

# Bell-state weight p = 0.3, 0.4, ..., 1.0 mixed with |00><00|.
for p1 in range(3, 11):
    s1 = (p1/10) * bell_state(2).density_matrix + (1-p1/10) * zero_state(2).density_matrix
    fid1 = state_fidelity(s1, target_state).item()**2
    fid_s.append(fid1)

    # dynamic universal DLOCCNet
    fdyn4 = train_model_dyn4(NUM_ITR, LR, n, target_state, s1)
    fid_dynamic.append(fdyn4)

    # The remaining protocols are only evaluated, never trained, so skip
    # building an autograd graph for them.
    with torch.no_grad():
        # our protocol
        cir_sim1 = simcir1(n, p1/10)
        _, _, fsim4 = loss_func4(cir_sim1, cir_sim2, cir_sim3, cir_sim4, target_state, s1)
        fid_sim.append(fsim4)

        # dynamic DEJMPS
        _, _, fde4 = loss_func4(cir_de1, cir_de1, cir_de1, cir_de1, target_state, s1)
        fid_de.append(fde4)

    print('dynamic universal dloccnet',fid_dynamic)
    print('our protocol',fid_sim)
    print('dynamic dejmps',fid_de)
    print('no distillation',fid_s)


Training:
iter: 0, loss: 0.95730093, lr: 1.00E-01, avg_time: 0.0079s
iter: 500, loss: 0.13909109, lr: 2.50E-02, avg_time: 0.0049s
iter: 1000, loss: 0.04230156, lr: 3.13E-03, avg_time: 0.0047s
iter: 1499, loss: 0.04230014, lr: 1.19E-08, avg_time: 0.0046s
dynamic [0.9576998511712534]
simplify [0.9746251484448544]
dejmps [0.9224533512234494]
no distillation [0.6500000004837357]
Training:
iter: 0, loss: 0.91947402, lr: 1.00E-01, avg_time: 0.0047s
iter: 500, loss: 0.01752289, lr: 5.00E-02, avg_time: 0.0045s
iter: 1000, loss: 0.01219715, lr: 1.25E-02, avg_time: 0.0045s
iter: 1499, loss: 0.00849160, lr: 6.25E-03, avg_time: 0.0046s
dynamic [0.9576998511712534, 0.9915084087841272]
simplify [0.9746251484448544, 0.9913294803661782]
dejmps [0.9224533512234494, 0.9673650359393162]
no distillation [0.6500000004837357, 0.7000000064444352]
Training:
iter: 0, loss: 0.85885657, lr: 1.00E-01, avg_time: 0.0046s
iter: 500, loss: 0.01664705, lr: 2.50E-02, avg_time: 0.0048s
iter: 1000, loss: 0.00542861, lr: 